|
import cv2 |
|
import torch as th |
|
import os |
|
import numpy as np |
|
from decord import VideoReader, cpu |
|
|
|
|
|
class Normalize(object):
    """Standardize a tensor channel-wise as ``(x - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` are stored as float tensors viewed to shape
    ``(1, 3, 1, 1)``, so they broadcast over a 4-D input whose second
    dimension holds the three channels.
    """

    def __init__(self, mean, std):
        # Reshape both statistics identically so they broadcast per channel.
        self.mean, self.std = (
            th.FloatTensor(stat).view(1, 3, 1, 1) for stat in (mean, std)
        )

    def __call__(self, tensor):
        # The tiny epsilon keeps the division finite if a std entry is zero.
        denominator = self.std + 1e-8
        centered = tensor - self.mean
        return centered / denominator
|
|
|
|
|
class Preprocessing(object):
    """Rescale pixel values by 1/255 and then standardize them per channel."""

    def __init__(self):
        # NOTE(review): these look like the image statistics commonly used
        # with CLIP-style encoders -- confirm against the model consuming
        # this loader's output.
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        scaled = tensor / 255.0
        return self.norm(scaled)
|
|
|
|
|
class VideoLoader:
    """Pytorch video loader.

    Decodes a video file with decord, keeps roughly ``framerate`` frames per
    second, resizes them with OpenCV (shorter side == ``size`` for an int
    ``size``, exactly ``size`` for an ``(height, width)`` tuple), optionally
    center-crops to ``size x size``, then subsamples / zero-pads the clip to
    exactly ``max_feats`` normalized frames.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
    ):
        self.centercrop = centercrop
        self.size = size
        # Approximate number of frames kept per second of video.
        self.framerate = framerate
        self.preprocess = Preprocessing()
        # Every clip returned by __call__ has exactly this many frames.
        self.max_feats = 10
        # No longer used for padding (frames are padded with frame-shaped
        # zeros); kept so existing code reading the attribute still works.
        self.features_dim = 768

    def _get_video_dim(self, video_path):
        """Return ``(height, width, avg_fps)`` of ``video_path`` via decord."""
        vr = VideoReader(video_path, ctx=cpu(0))
        height, width, _ = vr[0].shape
        frame_rate = vr.get_avg_fps()
        return height, width, frame_rate

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` each frame is resized to.

        A 2-tuple ``size`` is returned unchanged; an int ``size`` becomes the
        length of the shorter side, preserving aspect ratio (truncated).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _frame_step(self, fps):
        """Return the frame-index stride that keeps ~``framerate`` frames/s.

        Identical to the historical ``int(fps)`` stride when ``framerate`` is
        1; never smaller than 1 so ``range`` always advances.
        """
        return max(1, int(fps / self.framerate))

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float32 tensor of shape (T, 3, H, W).

        Returns ``{"video": tensor, "input": video_path}``. When the file is
        missing, its metadata cannot be read, its fps is below 1, or decoding
        fails, ``"video"`` is the placeholder ``th.zeros(1)``.
        """
        if not os.path.isfile(video_path):
            return {"video": th.zeros(1), "input": video_path}

        print("Decoding video: {}".format(video_path))
        try:
            h, w, fr = self._get_video_dim(video_path)
        except Exception:
            print("ffprobe failed at: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}

        if fr < 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}

        height, width = self._get_output_dim(h, w)

        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            indices = range(0, len(vr), self._frame_step(fr))
            frames = vr.get_batch(indices).asnumpy()
            # cv2.resize takes dsize as (width, height).
            video = np.array(
                [cv2.resize(frame, (width, height)) for frame in frames]
            )
            # A tuple size is already the exact target, so only an int size
            # needs the square center crop (the subtraction below would fail
            # for a tuple).
            if self.centercrop and isinstance(self.size, int):
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                video = video[:, y:y + self.size, x:x + self.size, :]
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return {"video": th.zeros(1), "input": video_path}

        # (T, H, W, C) -> (T, C, H, W)
        video = th.from_numpy(video.astype("float32")).permute(0, 3, 1, 2)
        return {"video": video, "input": video_path}

    def _sample_frames(self, video):
        """Uniformly subsample ``video`` along dim 0 to ``max_feats`` frames.

        Clips with at most ``max_feats`` frames are returned unchanged.
        """
        num_frames = len(video)
        if num_frames <= self.max_feats:
            return video
        picks = [(j * num_frames) // self.max_feats for j in range(self.max_feats)]
        return video[picks]

    def _pad_frames(self, video):
        """Append all-zero frames so ``video`` has ``max_feats`` frames."""
        missing = self.max_feats - len(video)
        if missing <= 0:
            return video
        padding = video.new_zeros((missing, *video.shape[1:]))
        return th.cat([video, padding], 0)

    def __call__(self, video_path):
        """Load ``video_path`` as a fixed-length, normalized clip.

        Returns:
            ``(video, video_len)``: ``video`` has exactly ``max_feats`` frames
            along dim 0; ``video_len`` counts the real (non-padding) frames.

        Raises:
            RuntimeError: if the video could not be decoded.
        """
        video = self._getvideo(video_path)["video"]
        if video.dim() != 4:
            # _getvideo signals every failure with a 1-D placeholder.
            raise RuntimeError("Failed to decode video: {}".format(video_path))

        video = self._sample_frames(video)
        video_len = len(video)
        # Normalize before padding so the padded frames stay exactly zero.
        video = self.preprocess(video)
        return self._pad_frames(video), video_len
|
|