from datetime import timedelta
import gradio as gr
from sentence_transformers import SentenceTransformer
import torchvision
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

from inference import Inference
import utils

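# Encoder/decoder checkpoints for captioning, and the segment length in frames.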
encoder_model_name = 'google/vit-large-patch32-224-in21k'
decoder_model_name = 'gpt2-large'
frame_step = 300

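# Captioning model: generates a text description for each video segment.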
inference = Inference(
    decoder_model_name=decoder_model_name,
)

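# Sentence-embedding model used to match the query against the captions.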
model = SentenceTransformer('all-mpnet-base-v2')

def search_in_video(video, query):
    # Decode the uploaded video into a (T, H, W, C) uint8 frame tensor.
    video, _, info = torchvision.io.read_video(video)
    video_fps = info['video_fps']

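    # Split the video into consecutive segments of frame_step frames (the last may be shorter).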
    video_segments = [
        video[idx:idx + frame_step, :, :, :] for idx in range(0, video.shape[0], frame_step)
    ]

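    # Preprocess each segment into encoder pixel values and batch them.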
    pixel_values = [utils.video2image(video_seg, encoder_model_name) for video_seg in video_segments]
    pixel_values = torch.stack(pixel_values)

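    # Caption every segment with the encoder-decoder model.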
    generated_texts = inference.generate_texts(pixel_values)

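    # Embed the query together with the segment captions.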
    sentences = [query] + generated_texts

    sentence_embeddings = model.encode(sentences)

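    # Cosine similarity between the query and each caption; argsort is
    # ascending, so the last indices correspond to the best matches.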
    similarities = cosine_similarity(
        [sentence_embeddings[0]],
        sentence_embeddings[1:]
    )
    arg_sorted_similarities = np.argsort(similarities)

    ordered_similarity_scores = similarities[0][arg_sorted_similarities]

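    # Write the three best-matching segments to disk (assumes at least three segments).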
    top1 = video_segments[arg_sorted_similarities[0, -1]]
    top2 = video_segments[arg_sorted_similarities[0, -2]]
    top3 = video_segments[arg_sorted_similarities[0, -3]]
    torchvision.io.write_video('top1.mp4', top1, video_fps)
    torchvision.io.write_video('top2.mp4', top2, video_fps)
    torchvision.io.write_video('top3.mp4', top3, video_fps)

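    # Map each segment back to start/end timestamps using the frame rate.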
    total_frames = video.shape[0]

    video_frame_segs = [
        [idx, min(idx + frame_step, total_frames)] for idx in range(0, total_frames, frame_step)
    ]
    ordered_start_ends = []

    for start, end in video_frame_segs:
        s = timedelta(seconds=start / video_fps)
        e = timedelta(seconds=end / video_fps)

        # timedelta strings already contain colons, so separate start and end with ' - '.
        ordered_start_ends.append(f'{s} - {e}')

    ordered_start_ends = np.array(ordered_start_ends)[arg_sorted_similarities]

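    # Pair each time range with its similarity score for the Label output.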
    labels_to_scores = dict(
        zip(ordered_start_ends[0].tolist(), ordered_similarity_scores[0].tolist())
    )

    return 'top1.mp4', 'top2.mp4', 'top3.mp4', labels_to_scores

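# Three players for the best-matching clips plus a label with similarity scores.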
app = gr.Interface(
    fn=search_in_video,
    inputs=[
        gr.Video(label='Video'),
        gr.Textbox(label='Query'),
    ],
    outputs=[
        gr.Video(format='mp4', label='Top1'),
        gr.Video(format='mp4', label='Top2'),
        gr.Video(format='mp4', label='Top3'),
        gr.Label(num_top_classes=5, label='Scores'),
    ],
)
app.launch()