search-in-video / utils.py
Armen Gabrielyan
add initial app
5e95a58
raw
history blame
1.27 kB
from transformers import ViTFeatureExtractor
import torchvision
import torchvision.transforms.functional as fn
import torch as th
import os
import pickle
def video2image_from_path(video_path, feature_extractor_name):
    """Read a video file and convert it to a single tiled image.

    Args:
        video_path: Path handed to ``torchvision.io.read_video``.
        feature_extractor_name: Forwarded unchanged to ``video2image``.

    Returns:
        Whatever ``video2image`` returns for the decoded frames.
    """
    # Only the first element of read_video's result is used here
    # (presumably the frame tensor; the remaining elements are ignored).
    frames = torchvision.io.read_video(video_path)[0]
    return video2image(frames, feature_extractor_name)
def video2image(video, feature_extractor_name):
    """Tile 49 evenly spaced frames of a video into one 7x7 grid image.

    Args:
        video: 4-D frame tensor. It is permuted with ``(3, 0, 1, 2)`` and
            then sampled along the new dimension 1, so the frame axis must
            be dimension 0 on input. Presumably laid out as
            (frames, height, width, channels) -- TODO confirm with callers.
        feature_extractor_name: Name or path passed to
            ``ViTFeatureExtractor.from_pretrained``.

    Returns:
        The tiled grid after ``fn.resize(..., size=[224])``.
    """
    grid_side = 7
    frame_count = grid_side * grid_side  # 49 sampled frames fill the grid

    extractor = ViTFeatureExtractor.from_pretrained(feature_extractor_name)

    # Move the last axis to the front so the frame axis becomes dimension 1.
    channels_first = th.permute(video, (3, 0, 1, 2))

    # Evenly spaced frame indices from the first to the last frame; short
    # videos simply yield repeated indices.
    picks = th.linspace(
        0, channels_first.shape[1] - 1, frame_count, dtype=th.long
    )
    sampled = channels_first[:, picks, :, :]

    # One tensor per sampled frame, in temporal order.
    frames = list(sampled.unbind(dim=1))
    pixels = extractor(frames, return_tensors="pt")["pixel_values"]

    # Each grid row joins seven consecutive frames side by side (dim 2 of
    # a per-frame tensor); rows are then stacked along dim 1.
    rows = [
        th.cat(
            [pixels[row * grid_side + col, :, :, :] for col in range(grid_side)],
            2,
        )
        for row in range(grid_side)
    ]
    return fn.resize(th.cat(rows, 1), size=[224])