stupidog04 committed on
Commit 83a4675
1 Parent(s): 658f973

Update app.py

Files changed (1)
  1. app.py +64 -31
app.py CHANGED
@@ -8,6 +8,7 @@ import os
 from pathlib import Path
 import cv2
 import pandas as pd
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
 
 #torch.hub.download_url_to_file('https://github.com/AaronCWacker/Yggdrasil/blob/main/images/BeautyIsTruthTruthisBeauty.JPG', 'BeautyIsTruthTruthisBeauty.JPG')
@@ -20,6 +21,7 @@ torch.hub.download_url_to_file('https://github.com/JaidedAI/EasyOCR/raw/master/e
 torch.hub.download_url_to_file('https://github.com/JaidedAI/EasyOCR/raw/master/examples/japanese.jpg', 'japanese.jpg')
 torch.hub.download_url_to_file('https://i.imgur.com/mwQFd7G.jpeg', 'Hindi.jpeg')
 
+
 def draw_boxes(image, bounds, color='yellow', width=2):
     draw = ImageDraw.Draw(image)
     for bound in bounds:
@@ -39,8 +41,29 @@ def box_size(box):
 def box_position(box):
     return (box[0][0][0] + box[0][2][0]) / 2, (box[0][0][1] + box[0][2][1]) / 2
 
-
-def inference(video, lang, time_step, full_scan=False):
+def filter_temporal_profiles(temporal_profiles, period_index):
+    filtered_profiles = []
+    for profile in temporal_profiles:
+        filtered_profile = []
+        for t, text in profile:
+            # Remove all non-digit characters from the text
+            filtered_text = ''.join(filter(str.isdigit, text))
+            # Insert a period at the specified index
+            filtered_text = filtered_text[:period_index] + "." + filtered_text[period_index:]
+            try:
+                filtered_value = float(filtered_text)
+            except ValueError:
+                continue
+            filtered_profile.append((t, filtered_value))
+        filtered_profiles.append(filtered_profile)
+    return filtered_profiles
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-printed')
+model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-printed').to(device)
+
+def inference(video, lang, time_step, full_scan, use_trocr, number_filter, period_index):
     output = 'results.mp4'
     reader = easyocr.Reader(lang)
     bounds = []
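For reference, a minimal sketch of what `filter_temporal_profiles` does to a profile of (time, raw OCR text) tuples, assuming the function defined in the hunk above is in scope; the sample readings are hypothetical:

profiles = [[(0.0, '12 3'), (1.0, '4O5'), (2.0, 'N/A')]]
# With period_index=1: '12 3' -> digits '123' -> '1.23';
# '4O5' -> '45' -> '4.5'; 'N/A' has no digits, so float('.')
# raises ValueError and the reading is dropped.
print(filter_temporal_profiles(profiles, period_index=1))
# [[(0.0, 1.23), (1.0, 4.5)]]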
@@ -50,22 +73,25 @@ def inference(video, lang, time_step, full_scan=False):
     frame_rate = vidcap.get(cv2.CAP_PROP_FPS)
     output_frames = []
     temporal_profiles = []
-    compress_mp4 = True
+    compress_mp4 = False
 
     # Get the positions of the largest boxes in the first frame
     bounds = reader.readtext(frame)
+    for i in reversed(range(len(bounds))):
+        box = bounds[i]
+        # Remove box if it doesn't contain a number
+        if not any(char.isdigit() for char in box[1]):
+            bounds.pop(i)
     im = PIL.Image.fromarray(frame)
     im_with_boxes = draw_boxes(im, bounds)
     largest_boxes = sorted(bounds, key=lambda x: box_size(x), reverse=True)
     positions = [box_position(b) for b in largest_boxes]
     temporal_profiles = [[] for _ in range(len(largest_boxes))]
-
-    # Match bboxes to position and store the text read by OCR
+
     # Match bboxes to position and store the text read by OCR
-    if full_scan:
-        # Match bboxes to position and store the text read by OCR
-        while success:
-            if count % (int(frame_rate * time_step)) == 0:
+    while success:
+        if count % (int(frame_rate * time_step)) == 0:
+            if full_scan:
                 bounds = reader.readtext(frame)
                 for box in bounds:
                     bbox_pos = box_position(box)
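A side note on the sampling test above: OCR runs once every `int(frame_rate * time_step)` frames. A small worked sketch; the guard at the end is an assumption this diff does not include:

frame_rate, time_step = 30.0, 1.0
stride = int(frame_rate * time_step)  # 30 -> OCR on frames 0, 30, 60, ...
# Assumption: a time_step smaller than 1 / frame_rate rounds the stride
# down to 0, and `count % stride` would raise ZeroDivisionError.
stride = max(1, stride)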
@@ -74,15 +100,7 @@
                         if distance < 50:
                             temporal_profiles[i].append((count / frame_rate, box[1]))
                             break
-                im = PIL.Image.fromarray(frame)
-                im_with_boxes = draw_boxes(im, bounds)
-                output_frames.append(np.array(im_with_boxes))
-            success, frame = vidcap.read()
-            count += 1
-    else:
-        # Match bboxes to position and store the text read by OCR
-        while success:
-            if count % (int(frame_rate * time_step)) == 0:
+            else:
                 for i, box in enumerate(largest_boxes):
                     x1, y1 = box[0][0]
                     x2, y2 = box[0][2]
@@ -94,15 +112,27 @@
                     y1 = max(0, int(y1 - ratio * box_height))
                     y2 = min(frame.shape[0], int(y2 + ratio * box_height))
                     cropped_frame = frame[y1:y2, x1:x2]
-                    text = reader.readtext(cropped_frame)
-                    if text:
-                        temporal_profiles[i].append((count / frame_rate, text[0][1]))
-                im = PIL.Image.fromarray(frame)
-                im_with_boxes = draw_boxes(im, bounds)
-                output_frames.append(np.array(im_with_boxes))
-            success, frame = vidcap.read()
-            count += 1
-
+                    if use_trocr:
+                        pixel_values = processor(images=cropped_frame, return_tensors="pt").pixel_values
+                        generated_ids = model.generate(pixel_values.to(device))
+                        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+                        temporal_profiles[i].append((count / frame_rate, generated_text))
+                    else:
+                        text = reader.readtext(cropped_frame)
+                        if text:
+                            temporal_profiles[i].append((count / frame_rate, text[0][1]))
+
+        im = PIL.Image.fromarray(frame)
+        im_with_boxes = draw_boxes(im, bounds)
+        output_frames.append(np.array(im_with_boxes))
+
+        success, frame = vidcap.read()
+        count += 1
+
+    if number_filter:
+        # Filter the temporal profiles by removing non-digit characters and converting to floats
+        temporal_profiles = filter_temporal_profiles(temporal_profiles, int(period_index))
+
     # Default resolutions of the frame are obtained. The default resolutions are system dependent.
     # We convert the resolutions from float to integer.
     width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
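The TrOCR branch above, pulled out as a standalone sketch for clarity. `ocr_crop` is a hypothetical helper name, and the BGR-to-RGB conversion is an assumption on my part (OpenCV frames are BGR while TrOCR was trained on RGB images; the diff passes the crop through unconverted):

import cv2
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-printed')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-printed').to(device)

def ocr_crop(crop_bgr):
    # Convert the OpenCV BGR crop to RGB before handing it to the processor.
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    pixel_values = processor(images=crop_rgb, return_tensors='pt').pixel_values
    generated_ids = model.generate(pixel_values.to(device))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]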
@@ -150,11 +180,11 @@
 
 
 title = '🖼️Video to Multilingual OCR👁️Gradio'
-description = 'Multilingual OCR that works conveniently on all devices in multiple languages. Adjust the time step and scan mode to suit your needs. With `Full Scan` enabled, the model scans the whole image; otherwise it scans only the box detected in the first video frame, which saves computation; note that the box is fixed in this case.'
+description = 'Multilingual OCR that works conveniently on all devices in multiple languages. Adjust the time step and scan mode to suit your needs. With `Full Screen Scan` enabled, the model scans the whole frame; otherwise it scans only the boxes detected in the first video frame, which speeds up inference but keeps the boxes fixed.'
 article = "<p style='text-align: center'></p>"
 
 examples = [
-    ['test.mp4',['en'],10,False]
+    ['test.mp4', ['en'], 10, False, False, False, '1']
 ]
 
 css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
@@ -176,12 +206,15 @@ gr.Interface(
         gr.inputs.Video(label='Input Video'),
         gr.inputs.CheckboxGroup(choices, type="value", default=['en'], label='Language'),
         gr.inputs.Number(label='Time Step (in seconds)', default=1.0),
-        gr.inputs.Dropdown(['True', 'False'], label='Full Scan', default='False')
+        gr.inputs.Checkbox(label='Full Screen Scan'),
+        gr.inputs.Checkbox(label='Use TrOCR large (only available when Full Screen Scan is disabled)'),
+        gr.inputs.Checkbox(label='Number Filter (remove non-digit characters and insert a period)'),
+        gr.inputs.Textbox(label='Period Position', default='1')
     ],
     [
         gr.outputs.Video(label='Output Video'),
         gr.outputs.Image(label='Output Preview', type='numpy'),
-        gr.outputs.Dataframe(headers=['Box', 'Time (s)', 'Text'], type='pandas')
+        gr.outputs.Dataframe(headers=['Box', 'Time (s)', 'Text'], type='pandas'),
     ],
     title=title,
     description=description,
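As a sanity check on the wiring: `gr.Interface` passes the inputs above to `inference` positionally, so a direct call with the example row's values would look like the sketch below; the three return values are inferred from the three declared outputs, which the diff does not show:

output_video, preview, df = inference(
    'test.mp4', ['en'], time_step=10,
    full_scan=False, use_trocr=False, number_filter=False,
    period_index='1',
)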
 