Armen Gabrielyan committed · Commit 5e95a58 · 1 Parent(s): deb4867

add initial app
Files changed:
- app.py +84 -0
- inference.py +29 -0
- requirements.txt +4 -0
- utils.py +44 -0
app.py
ADDED
@@ -0,0 +1,84 @@
from datetime import timedelta

import gradio as gr
from sentence_transformers import SentenceTransformer
import torchvision
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

from inference import Inference
import utils

model_checkpoint = 'saved_model'
encoder_model_name = 'google/vit-large-patch32-224-in21k'
decoder_model_name = 'gpt2'
frame_step = 300

# Video-captioning model loaded from the fine-tuned ViT/GPT-2 checkpoint.
inference = Inference(
    decoder_model_name=decoder_model_name,
    model_checkpoint=model_checkpoint,
)

# Sentence-embedding model used to compare the text query with generated captions.
model = SentenceTransformer('all-mpnet-base-v2')


def search_in_video(video, query):
    # Read the uploaded video: frames tensor of shape (T, H, W, C) plus metadata.
    result = torchvision.io.read_video(video)
    video = result[0]
    video_fps = result[2]['video_fps']

    # Split the video into fixed-length segments of `frame_step` frames.
    video_segments = [
        video[idx:idx + frame_step, :, :, :] for idx in range(0, video.shape[0], frame_step)
    ]

    # Caption each segment.
    generated_texts = []

    for video_seg in video_segments:
        pixel_values = utils.video2image(video_seg, encoder_model_name)

        generated_text = inference.generate_text(pixel_values, encoder_model_name)
        generated_texts.append(generated_text)

    # Embed the query together with the captions and rank segments by cosine similarity.
    sentences = [query] + generated_texts

    sentence_embeddings = model.encode(sentences)

    similarities = cosine_similarity(
        [sentence_embeddings[0]],
        sentence_embeddings[1:]
    )
    arg_sorted_similarities = np.argsort(similarities)

    ordered_similarity_scores = similarities[0][arg_sorted_similarities]

    # Save the best-matching segment so Gradio can play it back.
    best_video = video_segments[arg_sorted_similarities[0, -1]]
    torchvision.io.write_video('best.mp4', best_video, video_fps)

    # Convert segment boundaries from frame indices to timestamps in seconds.
    total_frames = video.shape[0]

    video_frame_segs = [
        [idx, min(idx + frame_step, total_frames)] for idx in range(0, total_frames, frame_step)
    ]
    ordered_start_ends = []

    for start, end in video_frame_segs:
        td = timedelta(seconds=(start / video_fps))
        s = round(td.total_seconds(), 2)

        td = timedelta(seconds=(end / video_fps))
        e = round(td.total_seconds(), 2)

        ordered_start_ends.append(f'{s}:{e}')

    ordered_start_ends = np.array(ordered_start_ends)[arg_sorted_similarities]

    # Map "start:end" labels to similarity scores for the Label output.
    labels_to_scores = dict(
        zip(ordered_start_ends[0].tolist(), ordered_similarity_scores[0].tolist())
    )

    return 'best.mp4', labels_to_scores


app = gr.Interface(
    fn=search_in_video,
    inputs=['video', 'text'],
    outputs=['video', gr.outputs.Label(num_top_classes=3, type='auto')],
)
app.launch(share=True)
inference.py
ADDED
@@ -0,0 +1,29 @@
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel

import utils


class Inference:
    def __init__(self, decoder_model_name, model_checkpoint, max_length=32):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
        self.encoder_decoder_model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
        self.encoder_decoder_model.to(self.device)

        self.max_length = max_length

    def generate_text(self, video, encoder_model_name):
        # Accept either a path to a video file or precomputed pixel values.
        if isinstance(video, str):
            pixel_values = utils.video2image_from_path(video, encoder_model_name)
        else:
            pixel_values = video

        # GPT-2 has no padding token by default; add one so generation can pad.
        if not self.tokenizer.pad_token:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.encoder_decoder_model.decoder.resize_token_embeddings(len(self.tokenizer))

        generated_ids = self.encoder_decoder_model.generate(
            pixel_values.unsqueeze(0).to(self.device), max_length=self.max_length
        )
        generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
requirements.txt
ADDED
@@ -0,0 +1,4 @@
nltk==3.7
tqdm==4.64.0
scikit-learn==1.1.1
sentence-transformers==2.2.0
utils.py
ADDED
@@ -0,0 +1,44 @@
from transformers import ViTFeatureExtractor
import torchvision
import torchvision.transforms.functional as fn
import torch as th
import os
import pickle


def video2image_from_path(video_path, feature_extractor_name):
    video = torchvision.io.read_video(video_path)

    return video2image(video[0], feature_extractor_name)


def video2image(video, feature_extractor_name):
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        feature_extractor_name
    )

    # (T, H, W, C) -> (C, T, H, W), then sample 49 frames uniformly over time.
    vid = th.permute(video, (3, 0, 1, 2))
    samp = th.linspace(0, vid.shape[1] - 1, 49, dtype=th.long)
    vid = vid[:, samp, :, :]

    im_l = list()
    for i in range(vid.shape[1]):
        im_l.append(vid[:, i, :, :])

    # Preprocess each sampled frame with the ViT feature extractor.
    inputs = feature_extractor(im_l, return_tensors="pt")

    inputs = inputs['pixel_values']

    # Tile the 49 preprocessed frames into a 7x7 grid and resize it to height 224,
    # producing a single image the ViT encoder can consume.
    im_h = list()
    for i in range(7):
        im_v = th.cat((inputs[0 + i * 7, :, :, :],
                       inputs[1 + i * 7, :, :, :],
                       inputs[2 + i * 7, :, :, :],
                       inputs[3 + i * 7, :, :, :],
                       inputs[4 + i * 7, :, :, :],
                       inputs[5 + i * 7, :, :, :],
                       inputs[6 + i * 7, :, :, :]), 2)
        im_h.append(im_v)
    resize = fn.resize(th.cat(im_h, 1), size=[224])

    return resize