# search-in-video / utils.py
# Author: Armen Gabrielyan
# Last change: switch from auto model to vision encoder decoder model (commit 16b0970)
from transformers import ViTFeatureExtractor
import torchvision
import torchvision.transforms.functional as fn
import torch as th
def video2image_from_path(video_path, feature_extractor_name):
    """Read the video file at ``video_path`` and convert it to a grid image.

    ``torchvision.io.read_video`` returns ``(frames, audio, info)``; only the
    frames tensor (T, H, W, C) is forwarded to :func:`video2image`.
    """
    frames, _audio, _info = torchvision.io.read_video(video_path)
    return video2image(frames, feature_extractor_name)
def video2image(video, feature_extractor_name, grid_size=7):
    """Convert a video tensor into a single square mosaic image.

    Samples ``grid_size ** 2`` frames evenly across the clip, preprocesses
    them with a ViT feature extractor, tiles them into a
    ``grid_size`` x ``grid_size`` mosaic, and resizes the result so its
    shorter side is 224 pixels.

    Args:
        video: frame tensor of shape (T, H, W, C), as returned by
            ``torchvision.io.read_video``.
        feature_extractor_name: model name or path passed to
            ``ViTFeatureExtractor.from_pretrained``.
        grid_size: number of frames per row/column of the mosaic.
            Defaults to 7 (49 frames), matching the original behavior.

    Returns:
        An image tensor of shape (C, 224, 224) — the mosaic is square, so
        resizing the shorter side to 224 yields a 224x224 image
        (assuming the extractor emits square 224x224 crops — TODO confirm
        for non-default extractors).

    Raises:
        ValueError: if ``video`` contains no frames (frame sampling would
            otherwise produce invalid negative indices).
    """
    if video.shape[0] == 0:
        raise ValueError("video contains no frames")
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        feature_extractor_name
    )
    # (T, H, W, C) -> (C, T, H, W) so frames can be indexed along dim 1.
    vid = th.permute(video, (3, 0, 1, 2))
    num_frames = grid_size * grid_size
    # Evenly spaced frame indices across the whole clip, endpoints included.
    samp = th.linspace(0, vid.shape[1] - 1, num_frames, dtype=th.long)
    vid = vid[:, samp, :, :]
    # The feature extractor expects a list of individual (C, H, W) frames.
    frames = [vid[:, i, :, :] for i in range(vid.shape[1])]
    inputs = feature_extractor(frames, return_tensors="pt")["pixel_values"]
    # Build each mosaic row by concatenating grid_size frames along width
    # (dim 2), then stack the rows along height (dim 1).
    rows = [
        th.cat([inputs[r * grid_size + c] for c in range(grid_size)], dim=2)
        for r in range(grid_size)
    ]
    mosaic = th.cat(rows, dim=1)
    # Resize so the shorter side is 224; the square mosaic becomes 224x224.
    return fn.resize(mosaic, size=[224])