Update
- app.py +25 -14
- bridgetower_custom.py +2 -2
app.py
CHANGED
@@ -87,7 +87,7 @@ def time_to_frame(time, fps):
     '''
     convert time in seconds into frame number
     '''
-    return time * fps - 1
+    return int(time * fps - 1)
 
 def str2time(strtime):
     strtime = strtime.strip('"')
@@ -105,7 +105,7 @@ def collate_fn(batch_list):
     batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
     return batch
 
-def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2):
+def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2, progress=gr.Progress()):
     if os.path.exists(os.path.join(output, 'embeddings.pkl')):
         return
 
@@ -123,7 +123,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
     # Get the total numer of frames in the video.
     frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
 
-    print(fps, frame_count)
+    # print(fps, frame_count)
 
     frame_number = 0
 
@@ -132,8 +132,9 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
 
     embeddings = []
     batch_list = []
+    vtt = webvtt.read(subtitles)
 
-    for idx, caption in enumerate(webvtt.read(subtitles)):
+    for idx, caption in progress.tqdm(enumerate(vtt), total=vtt.total_length, desc="Generating embeddings"):
         st_time = str2time(caption.start)
         ed_time = str2time(caption.end)
 
@@ -144,9 +145,10 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
             raise NotImplementedError
 
         frame_no = time_to_frame(mid_time, fps)
-
+        mid_time_ms = mid_time * 1000
+        # vidcap.set(1, frame_no) # added this line
+        vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
         print('Read a new frame: ', idx, mid_time, frame_no, text)
-        vidcap.set(1, frame_no) # added this line
         success, frame = vidcap.read()
         if success:
             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -161,7 +163,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
                 'image_id': idx,
                 'img_fname': img_fname,
                 'caption': text,
-                'time':
+                'time': mid_time_ms,
                 'frame_no': frame_no
             })
 
@@ -169,6 +171,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
             encoding['text'] = text
             encoding['image_filepath'] = img_fpath
             encoding['start_time'] = caption.start
+            encoding['time'] = mid_time_ms
 
             batch_list.append(encoding)
 
@@ -186,7 +189,7 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
                     'text': batch_list[i]['text'],
                     'image_filepath': batch_list[i]['image_filepath'],
                     'start_time': batch_list[i]['start_time'],
-                    '
+                    'time': batch_list[i]['time'],
                 })
             batch_list = []
 
@@ -201,9 +204,11 @@ def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=
                 'text': batch_list[i]['text'],
                 'image_filepath': batch_list[i]['image_filepath'],
                 'start_time': batch_list[i]['start_time'],
-                '
+                'time': batch_list[i]['time'],
             })
 
+        batch_list = []
+
     with open(os.path.join(output, 'annotations.json'), 'w') as fh:
         json.dump(anno, fh)
 
@@ -240,10 +245,14 @@ def run_query(video_path, text_query, path='/tmp'):
     clip_images = []
     transcripts = []
     for idx in I[0]:
-        frame_no = embeddings[idx]['frame_no']
-        vidcap.set(1, frame_no) # added this line
+        # frame_no = embeddings[idx]['frame_no']
+        # vidcap.set(1, frame_no) # added this line
+        frame_timestamp = embeddings[idx]['time']
+        vidcap.set(cv2.CAP_PROP_POS_MSEC, frame_timestamp)
+
        success, frame = vidcap.read()
        if success:
+            frame = maintain_aspect_ratio_resize(frame, height=400)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            clip_images.append(frame)
@@ -277,7 +286,7 @@ def get_video_id_from_url(video_url):
     return None
 
 
-def process(video_url, text_query):
+def process(video_url, text_query, progress=gr.Progress()):
     tmp_dir = os.environ.get('TMPDIR', '/tmp')
     video_id = get_video_id_from_url(video_url)
     output_dir = os.path.join(tmp_dir, video_id)
@@ -289,6 +298,7 @@ def process(video_url, text_query):
         output=output_dir,
         expanded=False,
         batch_size=8,
+        progress=gr.Progress(),
     )
     frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
     return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
@@ -311,8 +321,8 @@ with gr.Blocks() as demo:
     gr.Examples(
         examples=[
            ['https://www.youtube.com/watch?v=CvjoXdC-WkM','wedding'],
-            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake
-            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', '
+            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake'],
+            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'bunny'],
            ['https://www.youtube.com/watch?v=KCFYf4TJdN0' ,'sandwich'],
        ],
        inputs=[video_url, text_query],
@@ -324,6 +334,7 @@ with gr.Blocks() as demo:
     )
 
 try:
+    demo.queue(concurrency_count=3)
     demo.launch(share=True)
 except:
     demo.launch()
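Note: the main behavioral change in app.py is that frames are now located by timestamp (cv2.CAP_PROP_POS_MSEC) instead of by frame index (vidcap.set(1, frame_no), where 1 is cv2.CAP_PROP_POS_FRAMES), which is more robust when the computed frame number drifts from the caption midpoint. Below is a minimal standalone sketch of the two seek styles, not this Space's code; the video path and the 12.5 s midpoint are placeholders. The remaining additions (gr.Progress()/progress.tqdm in extract_images_and_embeds and demo.queue(concurrency_count=3) before launch) wire the long-running embedding loop into Gradio's progress bar and request queue.

import cv2

vidcap = cv2.VideoCapture('example.mp4')   # placeholder path
fps = vidcap.get(cv2.CAP_PROP_FPS)
mid_time = 12.5                            # caption midpoint in seconds (placeholder)

# Old style: derive a frame index from the midpoint and seek by index.
frame_no = int(mid_time * fps - 1)
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)   # equivalent to vidcap.set(1, frame_no)
ok_by_index, frame_by_index = vidcap.read()

# New style: seek directly by presentation time in milliseconds.
vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time * 1000)
ok_by_time, frame_by_time = vidcap.read()

vidcap.release()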
bridgetower_custom.py
CHANGED
@@ -96,8 +96,8 @@ class BridgeTowerTextFeatureExtractor(BridgeTowerPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
     ):
 
-        outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask)
-        final_hidden_cls = outputs.
+        outputs = self.bridgetower(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
+        final_hidden_cls = outputs.hidden_states[-1][:,0,:]
         final_hidden_cls = F.normalize(self.itc_text_head(final_hidden_cls), dim=-1, p=2)
 
         return final_hidden_cls
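Note: the bridgetower_custom.py change requests hidden states from the backbone and takes the [CLS] position of the last layer as the text feature before the itc_text_head projection and L2 normalization. A minimal sketch of that indexing pattern with a generic Hugging Face encoder; the bert-base-uncased checkpoint and the example sentence are placeholders, not this Space's model.

import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')   # placeholder encoder
model = AutoModel.from_pretrained('bert-base-uncased')

inputs = tokenizer('a slice of cheesecake', return_tensors='pt')
outputs = model(**inputs, output_hidden_states=True)

# hidden_states[-1] has shape (batch, seq_len, hidden); [:, 0, :] keeps the [CLS] token.
cls_state = outputs.hidden_states[-1][:, 0, :]

# L2-normalize, mirroring what the diff does after the projection head.
cls_state = F.normalize(cls_state, dim=-1, p=2)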