# Spaces: Runtime error  (status banner captured together with the source; not part of the program)
from datetime import timedelta | |
import gradio as gr | |
from sentence_transformers import SentenceTransformer | |
import torchvision | |
from sklearn.metrics.pairwise import cosine_similarity | |
import numpy as np | |
from inference import Inference | |
import utils | |
# Model identifiers: the encoder name is forwarded to the captioning helpers
# on every call, the decoder name configures the Inference wrapper once.
encoder_model_name = 'google/vit-large-patch32-224-in21k'
decoder_model_name = 'gpt2'

# Number of frames in each video segment that gets captioned and ranked.
frame_step = 300

# Caption generator from the local `inference` module.
inference = Inference(decoder_model_name=decoder_model_name)

# Sentence-embedding model used to compare the query with generated captions.
model = SentenceTransformer('all-mpnet-base-v2')
def search_in_video(video, query):
    """Find the video segment whose generated caption best matches a query.

    The video is split into consecutive segments of ``frame_step`` frames.
    Each segment is captioned with ``inference.generate_text``; the captions
    and the query are embedded with ``model`` and ranked by cosine
    similarity. The best-matching segment is written to ``best.mp4``.

    Args:
        video: Video source accepted by ``torchvision.io.read_video``
            (presumably a file path supplied by Gradio -- confirm).
        query: Text to search for in the video.

    Returns:
        A tuple ``('best.mp4', labels_to_scores)``. ``labels_to_scores`` maps
        ``'<start>:<end>'`` labels (seconds, rounded to two decimals) to the
        similarity score of that segment, inserted in ascending score order.

    Raises:
        ValueError: If the video contains no frames.
    """
    frames, _, info = torchvision.io.read_video(video)
    total_frames = frames.shape[0]
    if total_frames == 0:
        # Without this guard the failure surfaces later as an obscure
        # "0 samples" error from scikit-learn (or a missing-fps KeyError).
        raise ValueError('The video contains no frames to search.')
    video_fps = info['video_fps']

    # Compute the (start, end) frame bounds once and reuse them for both the
    # tensor slices and the time labels so the two can never disagree.
    bounds = [
        (start, min(start + frame_step, total_frames))
        for start in range(0, total_frames, frame_step)
    ]
    video_segments = [frames[start:end] for start, end in bounds]

    generated_texts = [
        inference.generate_text(
            utils.video2image(segment, encoder_model_name),
            encoder_model_name,
        )
        for segment in video_segments
    ]

    # Row 0 of the embeddings is the query; the remaining rows are captions.
    embeddings = model.encode([query] + generated_texts)
    scores = cosine_similarity([embeddings[0]], embeddings[1:])[0]
    ranking = np.argsort(scores)  # ascending: the best match is last

    best_segment = video_segments[int(ranking[-1])]
    torchvision.io.write_video('best.mp4', best_segment, video_fps)

    labels = []
    for start, end in bounds:
        # The timedelta round trip normalises to microsecond precision before
        # rounding; kept so the labels stay exactly as they were.
        s = round(timedelta(seconds=start / video_fps).total_seconds(), 2)
        e = round(timedelta(seconds=end / video_fps).total_seconds(), 2)
        labels.append(f'{s}:{e}')

    labels_to_scores = {
        labels[idx]: score
        for idx, score in zip(ranking.tolist(), scores[ranking].tolist())
    }
    return 'best.mp4', labels_to_scores
# The legacy `gr.outputs` namespace is deprecated and absent from newer Gradio
# releases, where referencing it crashes the script at startup. Prefer the
# top-level `gr.Label` component and fall back to the legacy class only on
# older releases that do not provide it.
if hasattr(gr, 'Label'):
    _scores_output = gr.Label(num_top_classes=3)
else:
    _scores_output = gr.outputs.Label(num_top_classes=3, type='auto')

# Web UI: a video and a text query in; the best-matching clip and the
# per-segment similarity scores out.
app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=['video', _scores_output],
)
app.launch(share=True)