# Hugging Face Space source (page status at capture time: "Runtime error").
# File size: 5,619 bytes; revisions: 0edd243, 9b1f236, 3406caf.
import os
import random
import subprocess
import tempfile
import time

import gradio as gr
import yt_dlp
import torch
import transformers
import clip
import numpy as np
import cv2
from PIL import Image
from multilingual_clip import pt_multilingual_clip
class SearchVideo:
    """Find the moment in a video that best matches a text query.

    Frames are sampled from the video at one frame per second (via ffmpeg),
    embedded with a CLIP image encoder, and scored by cosine similarity
    against the query embedding from the text encoder. Because sampling is
    1 fps, the index of the best frame is also its offset in seconds.
    """

    def __init__(
        self,
        clip_model,
        text_model,
        tokenizer,
        compose,
    ) -> None:
        """
        clip_model: CLIP model; its ``encode_image`` produces image embeddings
        text_model: text encoder, called as ``text_model(text, tokenizer)``
        tokenizer: tokenizer passed through to ``text_model``
        compose: preprocessing transform applied to each frame before it is
            passed to ``clip_model``
        """
        self.text_model = text_model
        self.tokenizer = tokenizer
        self.clip_model = clip_model
        self.compose = compose
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, video: str, text: str) -> tuple:
        """Search ``video`` for the frame that best matches ``text``.

        Returns:
            ``(index, segment_path, frame)``: the index of the best frame
            (its offset in seconds at 1 fps sampling), the path of the short
            clip cut at that offset, and the matching frame image.

        Raises:
            ValueError: if no frames could be extracted from ``video``.
        """
        torch.cuda.empty_cache()
        frames = self.video2frames_ffmpeg(video)
        if not frames:
            raise ValueError(f"No frames could be extracted from {video!r}")
        img_embs = self.get_img_embs(frames)
        txt_emb = self.get_txt_embs(text)
        # compare_embeddings normalizes both sides, so the single query
        # embedding is compared against every frame as-is (no tiling needed).
        logits_per_image = self.compare_embeddings(img_embs, txt_emb)
        scores = torch.stack(logits_per_image)[:, 0].numpy()
        ind = int(np.argmax(scores))
        seg_path = self.extract_seg(video, ind)
        return ind, seg_path, frames[ind]

    @staticmethod
    def _clip_start(second: int, lead_in: int = 0) -> str:
        """Return ``second - lead_in`` clamped at 0, formatted as HH:MM:SS."""
        # Clamp: time.gmtime of a negative value wraps to 23:59:xx, which
        # would seek past the end of the video.
        start = max(int(second) - int(lead_in), 0)
        return time.strftime('%H:%M:%S', time.gmtime(start))

    def extract_seg(self, video: str, start: int, duration: int = 2, lead_in: int = 0) -> str:
        """Cut a short clip from ``video`` and return its path.

        The clip begins at second ``start - lead_in`` (never before 0), lasts
        ``duration`` seconds and is written to ``segment_HH:MM:SS.mp4`` in the
        working directory. Streams are copied without re-encoding, so ffmpeg
        may snap the cut to a nearby keyframe.
        """
        start_ts = self._clip_start(start, lead_in)
        out_path = f'segment_{start_ts}.mp4'
        cmd = [
            'ffmpeg', '-ss', start_ts, '-i', video,
            '-t', time.strftime('%H:%M:%S', time.gmtime(duration)),
            '-vcodec', 'copy', '-acodec', 'copy', '-y', out_path,
        ]
        # Argument list, no shell: the path is passed verbatim even if it
        # contains spaces or shell metacharacters. Best-effort like before:
        # the exit status is not checked.
        subprocess.run(cmd, check=False)
        return out_path

    def video2frames_ffmpeg(self, video: str) -> list:
        """Decode ``video`` at 1 frame per second and return PIL images in order.

        Frames go to a private temporary directory, so concurrent calls
        cannot see each other's files. They are fully loaded into memory
        before the directory is removed.
        """
        images = []
        with tempfile.TemporaryDirectory(prefix='frames_') as frames_dir:
            # Six digits keep lexical order equal to frame order for long videos.
            pattern = os.path.join(frames_dir, 'output-%06d.jpg')
            subprocess.run(['ffmpeg', '-i', video, '-r', '1', pattern], check=False)
            for name in sorted(os.listdir(frames_dir)):
                with Image.open(os.path.join(frames_dir, name)) as frame:
                    # copy() loads the pixels, so the file handle can be
                    # closed and the directory deleted safely.
                    images.append(frame.copy())
        return images

    def video2frames(self, video: str, every_n_frames: int = 24) -> list:
        """Decode ``video`` with OpenCV, keeping every ``every_n_frames``-th frame.

        Alternative to ``video2frames_ffmpeg`` (not used by ``__call__``).
        The default of 24 assumes ~24 fps, i.e. about one frame per second.
        Returns RGB PIL images, starting with frame 0.
        """
        if every_n_frames < 1:
            raise ValueError("every_n_frames must be >= 1")
        cap = cv2.VideoCapture(video)
        images = []
        try:
            frame_idx = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                if frame_idx % every_n_frames == 0:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                frame_idx += 1
        finally:
            cap.release()
        return images

    def get_img_embs(self, img_list: list, batch_size: int = 64) -> torch.Tensor:
        """Embed images with ``clip_model`` and return them as a float32 CPU tensor.

        Images are preprocessed with ``compose`` and encoded in chunks of
        ``batch_size``, so long videos do not need every frame on the device
        at once. The first dimension of the result indexes the images.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        chunks = []
        with torch.no_grad():
            for begin in range(0, len(img_list), batch_size):
                batch = torch.stack(
                    [self.compose(img) for img in img_list[begin:begin + batch_size]]
                ).to(self.device)
                chunks.append(self.clip_model.encode_image(batch).float().cpu())
        return torch.cat(chunks)

    def get_txt_embs(self, text: str) -> torch.Tensor:
        """Embed ``text`` with the text encoder (no gradient tracking)."""
        with torch.no_grad():
            return self.text_model(text, self.tokenizer)

    def compare_embeddings(self, img_embs, txt_embs):
        """Cosine similarity of every image embedding against every text embedding.

        Both inputs are L2-normalized along the last dimension. Returns a list
        with one tensor per image, holding its similarity to each text row.
        """
        image_features = img_embs / img_embs.norm(dim=-1, keepdim=True)
        text_features = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
        # One matrix product replaces the per-image loop; iterating the
        # (n_images, n_texts) result yields the same per-image rows.
        return list(image_features @ text_features.t())
def download_yt_video(url):
    """Download the video at ``url`` with yt-dlp and return the local file path.

    Prefers a <=360p stream and falls back to the worst available quality.
    The path comes from yt-dlp's own result rather than being rebuilt from
    the URL text. Short links (youtu.be/...), extra query parameters
    (&t=..., &list=...) and non-mp4 containers therefore resolve to the
    file that was actually written.
    """
    ydl_opts = {
        'quiet': True,
        "outtmpl": "%(id)s.%(ext)s",
        'format': 'bv*[height<=360][ext=mp4]+ba/b[height<=360] / wv*+ba/w',
        # A watch URL carrying &list=... would otherwise download the whole playlist.
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        # After merging/post-processing, yt-dlp records the final file path here.
        downloads = info.get('requested_downloads') or []
        if downloads and downloads[0].get('filepath'):
            return downloads[0]['filepath']
        return ydl.prepare_filename(info)
# Pretrained model identifiers: the CLIP image encoder and the multilingual
# text encoder.
CLIP_MODEL_NAME = 'ViT-B/32'
TEXT_MODEL_NAME = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'

# Loaded once at import time and shared across requests; search_video reads
# the module-level names clip_model, compose, tokenizer and text_model.
clip_model, compose = clip.load(CLIP_MODEL_NAME)
tokenizer = transformers.AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(TEXT_MODEL_NAME)
def search_video(video_url, text, video=None):
    """Gradio handler: find where ``text`` appears in a video.

    An uploaded ``video`` takes precedence over ``video_url``; otherwise the
    URL is downloaded first. Returns a status message containing the match
    time (HH:MM:SS) and the path of the extracted segment.
    """
    search = SearchVideo(
        clip_model=clip_model,
        text_model=text_model,
        tokenizer=tokenizer,
        compose=compose,
    )
    if video is None:
        video_url = (video_url or '').strip()
        if not video_url:
            # Fail with a user-facing message instead of searching a None path.
            raise gr.Error('Enter a video URL or upload a video.')
        video = download_yt_video(video_url)
    ind, seg_path, _ = search(video, text)
    start = time.strftime('%H:%M:%S', time.gmtime(int(ind)))
    return f'"{text}" found at {start}', seg_path
# Plain-text title: the original emoji prefix was mis-encoded (mojibake) and
# could not be recovered.
title = 'Search inside a video'
description = '''Just enter a search query, a video URL or upload your video and get a 2-sec fragment from the video which is visually closest to your query.'''
# One value per input (URL, query, upload); the upload slot is left empty.
examples = [["https://www.youtube.com/watch?v=M93w3TjzVUE", "A dog", None]]
iface = gr.Interface(
    search_video,
    inputs=[
        gr.Textbox(value="https://www.youtube.com/watch?v=M93w3TjzVUE", label='Video URL'),
        gr.Textbox(value="a dog", label='Text query'),
        gr.Video(),
    ],
    outputs=[gr.Textbox(label="Output"), gr.Video(label="Video segment")],
    allow_flagging="never",
    title=title,
    description=description,
    examples=examples,
)
if __name__ == "__main__":
    # show_error=True displays exceptions raised by the handler in the UI.
    iface.launch(show_error=True)