# search-in-video / app.py
# Armen Gabrielyan
# improve labels for seconds
# b82b5a6
from datetime import timedelta
import gradio as gr
from sentence_transformers import SentenceTransformer
import torchvision
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from inference import Inference
import utils
# Model identifier passed to utils.video2image when converting a video
# segment into encoder input.
encoder_model_name = 'google/vit-large-patch32-224-in21k'
# Model identifier handed to the Inference wrapper below.
decoder_model_name = 'gpt2-large'
# Number of frames per video segment; search_in_video slices the video in
# steps of this size.
frame_step = 300
# Text generator for video segments, built once at import time and reused by
# every search_in_video call.
inference = Inference(
decoder_model_name=decoder_model_name,
)
# Sentence embedding model used to encode the query and the generated texts
# before comparing them with cosine similarity.
model = SentenceTransformer('all-mpnet-base-v2')
def search_in_video(video, query):
    """Find the video segments whose generated captions best match a query.

    The video is cut into consecutive segments of ``frame_step`` frames, a
    text is generated for each segment, and the segments are ranked by the
    cosine similarity between the sentence embedding of their text and the
    sentence embedding of ``query``.

    Args:
        video: Video file location, passed to ``torchvision.io.read_video``.
        query: Text to search for.

    Returns:
        A 4-tuple ``(top1, top2, top3, labels_to_scores)``. ``top1`` to
        ``top3`` are the paths of the clips written for the three most
        similar segments, best first; a slot is ``None`` when the video has
        fewer segments than that rank. ``labels_to_scores`` maps each
        segment's ``'start:end'`` time label to its similarity score, with
        entries inserted in ascending score order.

    Raises:
        ValueError: If no frames were read from ``video``.
    """
    frames, _, info = torchvision.io.read_video(video)
    total_frames = frames.shape[0]
    if total_frames == 0:
        # Fail early with a clear message instead of an opaque error from
        # stacking an empty list of tensors further down.
        raise ValueError('The video contains no frames to search in.')
    video_fps = info['video_fps']

    # Frame ranges [start, end) of the consecutive segments; the last one is
    # clipped to the end of the video and may be shorter than frame_step.
    bounds = [
        (start, min(start + frame_step, total_frames))
        for start in range(0, total_frames, frame_step)
    ]
    video_segments = [frames[start:end] for start, end in bounds]

    pixel_values = torch.stack(
        [utils.video2image(segment, encoder_model_name) for segment in video_segments]
    )
    generated_texts = inference.generate_texts(pixel_values)

    # Embed the query together with the generated texts; row 0 is the query.
    embeddings = model.encode([query] + generated_texts)
    scores = cosine_similarity([embeddings[0]], embeddings[1:])[0]
    ascending = np.argsort(scores)  # segment indices, least similar first

    # Write up to three best-matching clips. Short videos yield fewer than
    # three segments, so missing ranks are reported as None rather than
    # indexing past the start of the ranking.
    clip_paths = []
    for rank in (1, 2, 3):
        if rank > len(video_segments):
            clip_paths.append(None)
            continue
        path = f'top{rank}.mp4'
        torchvision.io.write_video(path, video_segments[ascending[-rank]], video_fps)
        clip_paths.append(path)

    # Human-readable time range of every segment, derived from frame indices.
    labels = []
    for start, end in bounds:
        s = timedelta(seconds=(start / video_fps))
        e = timedelta(seconds=(end / video_fps))
        labels.append(f'{s}:{e}')

    ordered_labels = [labels[idx] for idx in ascending]
    labels_to_scores = dict(zip(ordered_labels, scores[ascending].tolist()))

    top1, top2, top3 = clip_paths
    return top1, top2, top3, labels_to_scores
# Gradio UI: takes a video and a text query, shows the three best-matching
# clips plus a label widget with the per-segment similarity scores.
top_clip_outputs = [gr.Video(format='mp4', label=f'Top{rank}') for rank in (1, 2, 3)]
scores_output = gr.outputs.Label(num_top_classes=5, type='auto', label='Scores')

app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=[*top_clip_outputs, scores_output],
)

# Start serving the interface as soon as the module is executed.
app.launch()