# Hugging Face Space status at capture time: Runtime error
from datetime import timedelta | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer | |
import torchvision | |
import torch | |
from sklearn.metrics.pairwise import cosine_similarity | |
import numpy as np | |
from inference import Inference | |
import utils | |
# Model identifiers. The encoder id is handed to ``utils.video2image`` when
# preprocessing segments; the decoder id configures the text-generation
# ``Inference`` wrapper.
encoder_model_name = 'google/vit-large-patch32-224-in21k'
decoder_model_name = 'gpt2-large'

# Number of frames per searchable video segment.
frame_step = 300

# Loaded once at import time and reused by every call to ``search_in_video``.
inference = Inference(decoder_model_name=decoder_model_name)
model = SentenceTransformer('all-mpnet-base-v2')
def _segment_bounds(total_frames, step):
    """Return ``(start, end)`` frame ranges tiling ``total_frames`` in chunks of ``step``.

    ``end`` is exclusive; the final chunk is truncated at ``total_frames``.
    """
    return [
        (start, min(start + step, total_frames))
        for start in range(0, total_frames, step)
    ]


def _span_label(start, end, fps):
    """Format a frame range as ``"<start>:<end>"`` using ``timedelta`` strings."""
    return f'{timedelta(seconds=start / fps)}:{timedelta(seconds=end / fps)}'


def search_in_video(video, query):
    """Find the video segments whose generated texts best match ``query``.

    The video is cut into consecutive chunks of ``frame_step`` frames. Each
    chunk is converted to pixel values by ``utils.video2image``, turned into
    text by ``inference.generate_texts``, and the texts are embedded together
    with ``query`` by the sentence-transformer ``model``. Chunks are ranked by
    cosine similarity between their text embedding and the query embedding.

    Args:
        video: Video source passed to ``torchvision.io.read_video`` (the
            Gradio ``'video'`` input).
        query: Free-text description to search for.

    Returns:
        ``(top1, top2, top3, labels_to_scores)``. ``top1``..``top3`` are paths
        to MP4 clips of the best-matching segments, best first; when the video
        has fewer than three segments the unused slots are ``None``.
        ``labels_to_scores`` maps every segment's ``"start:end"`` time label to
        its similarity score.

    Raises:
        ValueError: if no frames could be decoded from ``video``.
    """
    num_clips = 3  # matches the three clip paths returned below

    frames, _, info = torchvision.io.read_video(video)
    total_frames = frames.shape[0]
    if total_frames == 0:
        # Fail with a clear message instead of an opaque torch.stack([]) error.
        raise ValueError('The uploaded video contains no decodable frames.')
    video_fps = info['video_fps']

    # Split the decoded frames (first axis = frame index) into consecutive chunks.
    bounds = _segment_bounds(total_frames, frame_step)
    segments = [frames[start:end] for start, end in bounds]

    # Generate one text per segment, then embed the query alongside those texts.
    pixel_values = torch.stack(
        [utils.video2image(segment, encoder_model_name) for segment in segments]
    )
    generated_texts = inference.generate_texts(pixel_values)
    embeddings = model.encode([query] + list(generated_texts))

    # Similarity of the query (row 0) to each segment's text -> shape (n_segments,).
    scores = cosine_similarity(embeddings[:1], embeddings[1:])[0]
    # Segment indices ordered from best to worst match.
    ranking = np.argsort(scores)[::-1]

    # NOTE(review): clips are written to fixed filenames in the working
    # directory, so concurrent requests can overwrite each other's files.
    clip_paths = []
    for rank, seg_idx in enumerate(ranking[:num_clips], start=1):
        path = f'top{rank}.mp4'
        torchvision.io.write_video(path, segments[seg_idx], video_fps)
        clip_paths.append(path)
    # Short videos can yield fewer segments than clip outputs; pad with None
    # rather than indexing past the end of the ranking.
    clip_paths.extend([None] * (num_clips - len(clip_paths)))

    labels_to_scores = {
        _span_label(*bounds[seg_idx], video_fps): float(scores[seg_idx])
        for seg_idx in ranking
    }
    return clip_paths[0], clip_paths[1], clip_paths[2], labels_to_scores
# Gradio UI: upload a video and type a query. It shows the three best-matching
# clips and up to five segment similarity scores.
app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=[
        gr.Video(format='mp4', label='Top1'),
        gr.Video(format='mp4', label='Top2'),
        gr.Video(format='mp4', label='Top3'),
        # ``gr.outputs.*`` was deprecated in Gradio 3 and removed in Gradio 4,
        # where accessing it raises AttributeError at startup. ``gr.Label`` is
        # the supported component and takes no ``type`` argument.
        gr.Label(num_top_classes=5, label='Scores'),
    ],
)
app.launch()