# Hugging Face Spaces status when this source was captured: Runtime error
import os
import random
import subprocess
import tempfile
import time

import clip
import cv2
import gradio as gr
import numpy as np
import torch
import transformers
import yt_dlp
from multilingual_clip import pt_multilingual_clip
from PIL import Image
class SearchVideo:
    """Find the moment in a video whose frame best matches a text query.

    Frames are sampled at one per second with ffmpeg, embedded with the CLIP
    image encoder, and ranked by cosine similarity against the query embedded
    with the text encoder.
    """

    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: loaded CLIP model; ``encode_image`` gives image embeddings
        text_model: text encoder, called as ``text_model(text, tokenizer)``
        tokenizer: tokenizer passed through to ``text_model``
        compose: preprocessing transform applied to each PIL frame
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, video: str, text: str) -> tuple:
        """Return ``(index, segment_path, frame)`` for the frame best matching ``text``.

        ``index`` is the position of the winning frame in the 1-fps frame
        list, i.e. approximately its offset in seconds into the video.

        Raises:
            ValueError: if no frames could be extracted from ``video``.
        """
        torch.cuda.empty_cache()
        frames = self.video2frames_ffmpeg(video)
        if not frames:
            raise ValueError(f"No frames could be extracted from {video!r}")
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)
        # No need to repeat/scale the text embedding per frame: the matmul in
        # compare_embeddings already scores every frame against it.
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        scores = [float(logit[0]) for logit in logits_per_image]
        ind = int(np.argmax(scores))
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]

    def extract_seg(self, video: str, start: int) -> str:
        """Cut a 2-second clip from ``video`` starting at second ``start``.

        Streams are copied without re-encoding, so ffmpeg may begin the clip
        at a keyframe slightly before ``start``.

        Returns:
            Path of the written ``.mp4`` file in the working directory.

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails.
        """
        # Clamp to 0. The old `start if start > 5 else start-5` produced a
        # negative seek time for matches in the first five seconds.
        start_sec = max(int(start), 0)
        timestamp = time.strftime('%H:%M:%S', time.gmtime(start_sec))
        # Keep ':' out of the filename so it is valid on every OS.
        out_path = f'segment_{start_sec}.mp4'
        # Argument list, no shell: the video path may come from user input.
        subprocess.run(
            ['ffmpeg', '-ss', timestamp, '-i', video, '-t', '00:00:02',
             '-vcodec', 'copy', '-acodec', 'copy', '-y', out_path],
            check=True,
        )
        return out_path

    def video2frames_ffmpeg(self, video: str) -> list:
        """Sample ``video`` at 1 frame per second and return PIL images in order.

        Frames go to a private temporary directory, so concurrent requests
        cannot mix their frames. Every frame is loaded into memory before that
        directory is removed.

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails.
        """
        with tempfile.TemporaryDirectory(prefix='frames_') as frames_dir:
            subprocess.run(
                ['ffmpeg', '-i', video, '-r', '1',
                 os.path.join(frames_dir, 'output-%04d.jpg')],
                check=True,
            )
            # Numeric sort: '%04d' grows past 4 digits after 9999 frames,
            # which would break a plain lexical sort.
            names = sorted(
                os.listdir(frames_dir),
                key=lambda name: int(name[len('output-'):-len('.jpg')]),
            )
            images = []
            for name in names:
                # Image.open is lazy and holds the file open. copy() loads the
                # pixels so the handle closes, avoiding fd exhaustion on long
                # videos and letting the temp dir be deleted.
                with Image.open(os.path.join(frames_dir, name)) as img:
                    images.append(img.copy())
        return images

    def video2frames(self, video: str) -> list:
        """Sample about one frame per second from ``video`` using OpenCV.

        Alternative to ``video2frames_ffmpeg``; ``__call__`` does not use it.
        Uses the container's reported frame rate, falling back to 24 fps.
        """
        cap = cv2.VideoCapture(video)
        fps = cap.get(cv2.CAP_PROP_FPS)
        step = max(int(round(fps)), 1) if fps and fps > 0 else 24
        images = []
        frame_idx = 0
        try:
            while True:
                has_frame, frame = cap.read()
                if not has_frame:
                    break
                if frame_idx % step == 0:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    images.append(Image.fromarray(rgb))
                frame_idx += 1
        finally:
            cap.release()
        return images

    def get_img_embs(self, img_list: list, batch_size: int = 64) -> torch.Tensor:
        """Embed images with ``clip_model.encode_image``.

        Images are preprocessed with ``compose`` and encoded ``batch_size`` at
        a time, so long videos need not fit on the device all at once.

        Returns:
            float32 CPU tensor with one row per input image.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        chunks = []
        with torch.no_grad():
            for begin in range(0, len(img_list), batch_size):
                batch = torch.stack(
                    [self.compose(img) for img in img_list[begin:begin + batch_size]]
                ).to(self.device)
                chunks.append(self.clip_model.encode_image(batch).float().cpu())
        return torch.cat(chunks)

    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Embed ``text`` with the text encoder (gradients disabled)."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)

    def compare_embeddings(self, img_embs, txt_embs):
        """Cosine similarity of each image embedding against each text embedding.

        Returns:
            List with one entry per image row. Each entry is a 1-D tensor of
            similarities to every text row.
        """
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
        # One matmul instead of a Python loop. Iterating the result's rows
        # keeps the list-of-tensors return type.
        return list(image_features @ text_features.t())
def download_yt_video(url):
    """Download the video at ``url`` with yt-dlp (at most 360p) and return its local path.

    The path comes from yt-dlp rather than being rebuilt from the URL. That
    keeps it correct for URLs with extra query parameters (``&t=``, ``?si=``)
    and for short ``youtu.be`` links.
    """
    ydl_opts = {
        'quiet': True,
        "outtmpl": "%(id)s.%(ext)s",
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w',
        # When video and audio are downloaded separately, merge into .mp4
        # instead of letting yt-dlp choose another container (e.g. .mkv).
        'merge_output_format': 'mp4',
        # A watch URL carrying '&list=' must not pull the whole playlist.
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        return ydl.prepare_filename(info)
# Identifiers of the pretrained encoders.
CLIP_MODEL_NAME = 'ViT-B/32'
TEXT_MODEL_NAME = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'

# Loaded once at import time and reused by every search_video call.
# clip.load yields the image model together with its preprocessing transform.
clip_model, compose = clip.load(CLIP_MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(TEXT_MODEL_NAME)
def search_video(video_url, text, video=None):
    """Gradio handler: find where ``text`` appears in a video.

    An uploaded ``video`` takes precedence over ``video_url``. Returns a
    status message with the matching timestamp and the path of the 2-second
    clip starting there.

    Raises:
        ValueError: if neither an uploaded video nor a URL was provided.
    """
    if video is not None:
        video_url = None
    if video_url:
        video = download_yt_video(video_url)
    if not video:
        raise ValueError('Please provide a video URL or upload a video.')
    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose,
    )
    ind, seg_path, _frame = search(video, text)
    start = time.strftime('%H:%M:%S', time.gmtime(ind))
    return f'"{text}" found at {start}', seg_path
# The original title began with mis-encoded emoji; only the readable text is kept.
title = 'Search inside a video'
description = '''Just enter a search query, a video URL or upload your video and get a 2-sec fragment from the video which is visually closest to your query.'''
# One value per input component (URL, query, uploaded video); None leaves the upload empty.
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog", None]]
iface = gr.Interface(
    search_video,
    inputs=[
        gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'),
        gr.Textbox(value="a dog", label='Text query'),
        gr.Video(),
    ],
    outputs=[gr.Textbox(label="Output"), gr.Video(label="Video segment")],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)
def _main() -> None:
    """Start the Gradio app, showing Python exceptions in the browser UI."""
    iface.launch(show_error=True)


if __name__ == "__main__":
    _main()