"""Gradio demo: semantic search over a YouTube video with BridgeTower.

Pipeline (see ``process``): download the video and its transcript, embed one
frame/caption pair per transcript cue, index the embeddings with FAISS and
return the frames whose embeddings best match a free-text query.
"""
import os
import json
import pickle
import urllib.parse

# Third-party imports are kept in their original relative order.
import cv2
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BridgeTowerProcessor
from tqdm import tqdm
from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
import faiss
import webvtt
from pytube import YouTube
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import WebVTTFormatter

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
model = BridgeTowerForITC.from_pretrained(model_name).to(device)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
processor = BridgeTowerProcessor.from_pretrained(model_name)


def download_video(video_url, path='/tmp/'):
    """Download the highest-resolution progressive MP4 stream of a video.

    Args:
        video_url: URL of the YouTube video.
        path: Directory to download into; created if missing.

    Returns:
        Path of the video file (an existing file is reused, not re-downloaded).

    Raises:
        RuntimeError: If the video exposes no progressive MP4 stream.
    """
    yt = YouTube(video_url)
    stream = (
        yt.streams
        .filter(progressive=True, file_extension='mp4')
        .order_by('resolution')
        .desc()
        .first()
    )
    if stream is None:
        # ``first()`` yields None on an empty query; fail with a clear message
        # instead of an AttributeError further down.
        raise RuntimeError(f'No progressive mp4 stream available for {video_url}')

    os.makedirs(path, exist_ok=True)
    filepath = os.path.join(path, stream.default_filename)
    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        stream.download(path)
    return filepath


# Get transcript in webvtt
def get_transcript_vtt(video_id, path='/tmp'):
    """Fetch the transcript of ``video_id`` and store it as a WebVTT file.

    Args:
        video_id: YouTube video id.
        path: Directory holding the ``test_vm.vtt`` file.

    Returns:
        Path of the WebVTT file. An existing file at that path is reused
        as-is, whatever video it was generated for.
    """
    filepath = os.path.join(path, 'test_vm.vtt')
    if os.path.exists(filepath):
        return filepath

    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    webvtt_formatted = WebVTTFormatter().format_transcript(transcript)

    os.makedirs(path, exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)
    return filepath


# https://stackoverflow.com/a/57781047
# Resizes an image and maintains aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize ``image`` to a target width or height, keeping its aspect ratio.

    Args:
        image: Array whose first two dimensions are (height, width).
        width: Target width. Takes precedence over ``height`` when both are set.
        height: Target height, used only when ``width`` is None.
        inter: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The resized image, or ``image`` itself when neither dimension is given.
    """
    # Return original image if no need to resize
    if width is None and height is None:
        return image

    (h, w) = image.shape[:2]
    if width is None:
        # Resizing by height: scale the width by the same ratio
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # Resizing by width: scale the height by the same ratio
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim, interpolation=inter)


def time_to_frame(time, fps):
    """Convert a time in seconds into a zero-based frame number.

    The result is clamped to 0 so that times near the start of the video do
    not produce a negative frame number.
    """
    return max(0, int(time * fps - 1))


def str2time(strtime):
    """Convert a ``[[HH:]MM:]SS[.mmm]`` timestamp string into seconds.

    Surrounding double quotes are stripped. One to three colon-separated
    fields are accepted, so ``MM:SS.mmm`` timestamps work as well as the
    full ``HH:MM:SS.mmm`` form.

    Raises:
        ValueError: If a field is not numeric or there are more than 3 fields.
    """
    fields = [float(c) for c in strtime.strip('"').split(':')]
    if len(fields) > 3:
        raise ValueError(f'Invalid timestamp: {strtime!r}')

    total_seconds = 0.0
    for value in fields:
        # Each field to the left is worth 60x the one to its right.
        total_seconds = total_seconds * 60 + value
    return total_seconds


def collate_fn(batch_list):
    """Merge per-sample processor encodings into a single model batch.

    ``input_ids`` and ``attention_mask`` are padded to a common length with
    ``pad_sequence`` (zero padding); ``pixel_values`` and ``pixel_mask`` are
    concatenated along dim 0.
    """
    batch = {}
    batch['input_ids'] = pad_sequence(
        [encoding['input_ids'].squeeze(0) for encoding in batch_list],
        batch_first=True,
    )
    batch['attention_mask'] = pad_sequence(
        [encoding['attention_mask'].squeeze(0) for encoding in batch_list],
        batch_first=True,
    )
    batch['pixel_values'] = torch.cat(
        [encoding['pixel_values'] for encoding in batch_list], dim=0
    )
    batch['pixel_mask'] = torch.cat(
        [encoding['pixel_mask'] for encoding in batch_list], dim=0
    )
    return batch


def _embed_batch(batch_list):
    """Run the model on ``batch_list`` and return one record per sample."""
    batch = collate_fn(batch_list)
    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True)

    return [
        {
            # NOTE(review): index 2 presumably selects the cross-modal
            # embedding -- confirm against bridgetower_custom.
            'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
            'text': encoding['text'],
            'image_filepath': encoding['image_filepath'],
            'start_time': encoding['start_time'],
            'time': encoding['time'],
        }
        for i, encoding in enumerate(batch_list)
    ]


def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2, progress=gr.Progress()):
    """Embed one (frame, caption) pair per transcript cue of a video.

    For every cue in ``subtitles`` the frame at the cue's mid-point is read
    from ``video_path`` and embedded together with the cue text. Results are
    written to ``<output>/annotations.json`` and ``<output>/embeddings.pkl``.
    Nothing is done when ``embeddings.pkl`` already exists.

    Args:
        video_id: Used as prefix of the recorded frame file names.
        video_path: Path of the video file.
        subtitles: Path of the WebVTT transcript.
        output: Output directory; created if missing.
        expanded: Not supported; must be False.
        batch_size: Number of pairs per model forward pass.
        progress: Gradio progress tracker (not used directly by this function).

    Raises:
        NotImplementedError: If ``expanded`` is true.
    """
    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
        return
    if expanded:
        raise NotImplementedError

    os.makedirs(output, exist_ok=True)
    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)

    anno = []
    embeddings = []
    batch_list = []

    vtt = webvtt.read(subtitles)
    vidcap = cv2.VideoCapture(video_path)
    try:
        # Frames per second, used only to record an approximate frame number
        fps = vidcap.get(cv2.CAP_PROP_FPS)

        # The progress total is the number of cues; ``vtt.total_length`` is
        # the transcript duration, not a count.
        cues = tqdm(vtt, total=len(vtt.captions), desc="Generating embeddings")
        for idx, caption in enumerate(cues):
            st_time = str2time(caption.start)
            ed_time = str2time(caption.end)
            mid_time = (ed_time + st_time) / 2
            text = caption.text.replace('\n', ' ')

            frame_no = time_to_frame(mid_time, fps)
            mid_time_ms = mid_time * 1000
            # Seek by timestamp rather than by frame index
            vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
            print('Read a new frame: ', idx, mid_time, frame_no, text)
            success, frame = vidcap.read()
            if not success:
                # Stop at the first unreadable frame; pairs collected so far
                # are still embedded below.
                break

            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            img_fname = f'{video_id}_{idx:06d}'
            # NOTE: the frame itself is not written to disk; the path is only
            # recorded alongside the embedding.
            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
            anno.append({
                'image_id': idx,
                'img_fname': img_fname,
                'caption': text,
                'time': mid_time_ms,
                'frame_no': frame_no,
            })

            encoding = processor(frame, text, return_tensors="pt").to(device)
            encoding['text'] = text
            encoding['image_filepath'] = img_fpath
            encoding['start_time'] = caption.start
            encoding['time'] = mid_time_ms
            batch_list.append(encoding)

            if len(batch_list) == batch_size:
                embeddings.extend(_embed_batch(batch_list))
                batch_list = []
    finally:
        vidcap.release()

    # Flush the last, possibly incomplete, batch
    if batch_list:
        embeddings.extend(_embed_batch(batch_list))

    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
        json.dump(anno, fh)

    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
        pickle.dump(embeddings, fh)


def run_query(video_path, text_query, path='/tmp'):
    """Return the frames and captions that best match ``text_query``.

    Loads ``<path>/embeddings.pkl``, builds (or loads the cached)
    inner-product FAISS index at ``<path>/faiss_index.pkl`` and searches it
    with the text embedding of the query.

    Args:
        video_path: Path of the video the embeddings were computed from.
        text_query: Free-text query.
        path: Directory holding ``embeddings.pkl`` and the index cache.

    Returns:
        ``(clip_images, transcripts)``: up to 6 PIL images and the matching
        ``"(start_time) text"`` strings, best match first.
    """
    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
    faiss_filepath = os.path.join(path, 'faiss_index.pkl')

    # NOTE(security): pickle executes arbitrary code on load; these files must
    # only ever be ones written by this application.
    with open(embeddings_filepath, 'rb') as fh:
        embeddings = pickle.load(fh)
    if not embeddings:
        return [], []

    if os.path.exists(faiss_filepath):
        with open(faiss_filepath, 'rb') as fh:
            faiss_index = pickle.load(fh)
    else:
        vectors = np.stack([emb['embeddings'] for emb in embeddings], axis=0)
        faiss_index = faiss.IndexFlatIP(vectors.shape[1])
        faiss_index.add(vectors)
        with open(faiss_filepath, 'wb') as fh:
            pickle.dump(faiss_index, fh)

    print('Processing query')
    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = text_model(**encoding)
    emb_query = outputs.cpu().numpy()

    print('Running FAISS search')
    top_k = min(6, len(embeddings))
    _, indices = faiss_index.search(emb_query, top_k)

    clip_images = []
    transcripts = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        for idx in indices[0]:
            # FAISS pads missing results with -1, which would otherwise index
            # the last record; also guard against a stale cached index.
            if not 0 <= idx < len(embeddings):
                continue
            record = embeddings[idx]
            vidcap.set(cv2.CAP_PROP_POS_MSEC, record['time'])
            success, frame = vidcap.read()
            if success:
                frame = maintain_aspect_ratio_resize(frame, height=400)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                clip_images.append(Image.fromarray(frame))
                transcripts.append(f"({record['start_time']}) {record['text']}")
    finally:
        vidcap.release()

    return clip_images, transcripts


# https://stackoverflow.com/a/7936523
def get_video_id_from_url(video_url):
    """Extract the video id from a YouTube URL.

    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US

    Returns:
        The video id, or None when the URL is not a recognised YouTube URL.
    """
    if not video_url:
        return None

    url = urllib.parse.urlparse(video_url)
    if url.hostname == 'youtu.be':
        return url.path[1:] or None

    if url.hostname in ('www.youtube.com', 'youtube.com', 'm.youtube.com'):
        if url.path == '/watch':
            # A /watch URL without a ``v`` parameter has no video id.
            values = urllib.parse.parse_qs(url.query).get('v')
            return values[0] if values else None
        for prefix in ('/embed/', '/v/', '/shorts/'):
            if url.path.startswith(prefix):
                return url.path.split('/')[2] or None

    return None


def process(video_url, text_query, progress=gr.Progress(track_tqdm=True)):
    """Gradio callback: run the full search pipeline for one query.

    Args:
        video_url: YouTube URL entered by the user.
        text_query: Free-text query entered by the user.
        progress: Gradio progress tracker (mirrors tqdm progress bars).

    Returns:
        ``(video_file, gallery_items)`` where ``gallery_items`` is a list of
        ``(image, caption)`` tuples.

    Raises:
        gr.Error: If no video id can be extracted from ``video_url``.
    """
    tmp_dir = os.environ.get('TMPDIR', '/tmp')
    video_id = get_video_id_from_url(video_url)
    if video_id is None:
        # Without this the None id reaches os.path.join and raises TypeError.
        raise gr.Error(f'Could not extract a YouTube video id from: {video_url!r}')

    output_dir = os.path.join(tmp_dir, video_id)
    video_file = download_video(video_url, path=output_dir)
    subtitles = get_transcript_vtt(video_id, path=output_dir)
    extract_images_and_embeds(
        video_id=video_id,
        video_path=video_file,
        subtitles=subtitles,
        output=output_dir,
        expanded=False,
        batch_size=8,
        progress=progress,
    )
    frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
    return video_file, list(zip(frame_paths, transcripts))


description = "This Space lets you run semantic search on a video."

with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            video_url = gr.Text(label="Youtube url")
            text_query = gr.Text(label="Text query")
            btn = gr.Button("Run query")
        video_player = gr.Video(label="Video")

    with gr.Row():
        gallery = gr.Gallery(label="Images")

    gr.Examples(
        examples=[
            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake'],
            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'bunny'],
            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
        ],
        inputs=[video_url, text_query],
    )

    btn.click(
        fn=process,
        inputs=[video_url, text_query],
        outputs=[video_player, gallery],
    )

try:
    demo.queue(concurrency_count=3)
    demo.launch(share=True)
except Exception:
    # Fall back to a plain local launch if queueing or sharing fails.
    # Catching Exception (not a bare except) lets Ctrl-C / SystemExit through.
    demo.launch()