|
import cv2 |
|
import torch as th |
|
import os |
|
import numpy as np |
|
from decord import VideoReader, cpu |
|
import ffmpeg |
|
|
|
|
|
class Normalize(object):
    """Channel-wise standardization: ``(x - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` must each hold three values. They are stored with
    shape ``(1, 3, 1, 1)`` so that they broadcast against a 4-D tensor whose
    second axis has three channels.
    """

    def __init__(self, mean, std):
        """Store the per-channel statistics as broadcastable float tensors."""
        stats_shape = (1, 3, 1, 1)
        self.mean = th.FloatTensor(mean).view(*stats_shape)
        self.std = th.FloatTensor(std).view(*stats_shape)

    def __call__(self, tensor):
        """Return ``tensor`` shifted by the mean and divided by the std."""
        centered = tensor - self.mean
        # The small epsilon keeps the division finite if a std entry is 0.
        return centered / (self.std + 1e-8)
|
|
|
|
|
class Preprocessing(object):
    """Divide the input by 255 and apply a fixed channel-wise normalization."""

    def __init__(self):
        """Build the normalizer from hard-coded per-channel mean and std."""
        # NOTE(review): these statistics look like the ones published for
        # CLIP image preprocessing -- confirm against the model being used.
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        """Return ``self.norm(tensor / 255.0)``.

        Presumably ``tensor`` holds 0-255 pixel intensities laid out as
        ``(N, 3, H, W)`` -- verify against the caller.
        """
        return self.norm(tensor / 255.0)
|
|
|
|
|
class VideoLoader:
    """Pytorch video loader.

    Decodes a video file with ffmpeg at ``framerate`` frames per second,
    rescales (and optionally center-crops) every frame, then uniformly
    subsamples or zero-pads the clip to exactly ``max_feats`` frames before
    running it through ``Preprocessing``.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
    ):
        """Configure the loader.

        Args:
            framerate: Number of frames per second to extract.
            size: Either an int (target length of the shorter side, and the
                crop side when ``centercrop`` is set) or a
                ``(height, width)`` tuple giving the exact output size.
            centercrop: When true and ``size`` is an int, crop the rescaled
                frames to a centered ``size`` x ``size`` square.
        """
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.preprocess = Preprocessing()
        self.max_feats = 10
        # Not read by this class; kept because external code may rely on it.
        self.features_dim = 768

    @staticmethod
    def _placeholder(video_path):
        """Return the result dict used when a video cannot be decoded."""
        return {"video": th.zeros(1), "input": video_path}

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` an ``h`` x ``w`` frame is scaled to.

        A 2-tuple ``size`` is returned unchanged. With an int ``size`` the
        shorter side becomes ``size`` and the longer side is scaled by the
        same ratio (truncated to an int).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        if h >= w:
            return int(h * self.size / w), self.size
        return self.size, int(w * self.size / h)

    def _get_video_dim(self, video_path):
        """Probe ``video_path`` and return ``(height, width, frame_rate)``.

        Raises:
            ValueError: If the probe output contains no video stream.
            ZeroDivisionError: If the reported average frame rate has a zero
                denominator.
            Exception: Whatever ``ffmpeg.probe`` raises on failure.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe["streams"] if stream["codec_type"] == "video"),
            None,
        )
        if video_stream is None:
            raise ValueError("no video stream found in: {}".format(video_path))
        width = int(video_stream["width"])
        height = int(video_stream["height"])
        num, denum = video_stream["avg_frame_rate"].split("/")
        frame_rate = int(num) / int(denum)
        return height, width, frame_rate

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float32 frame tensor.

        Returns:
            A dict with keys ``"video"`` and ``"input"``. On success
            ``"video"`` has shape ``(T, 3, H, W)`` with values copied from
            the decoded 8-bit RGB frames. When the file is missing, cannot
            be probed, reports a frame rate below 1, or ffmpeg fails,
            ``"video"`` is the 1-D placeholder ``th.zeros(1)``.
        """
        if not os.path.isfile(video_path):
            return self._placeholder(video_path)

        print("Decoding video: {}".format(video_path))
        try:
            h, w, fr = self._get_video_dim(video_path)
        except Exception:
            print("ffprobe failed at: {}".format(video_path))
            return self._placeholder(video_path)
        if fr < 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return self._placeholder(video_path)

        height, width = self._get_output_dim(h, w)
        # A tuple size already fixes the exact output dimensions, so the
        # square center crop only makes sense for an int size.
        do_crop = self.centercrop and isinstance(self.size, int)

        try:
            cmd = (
                ffmpeg.input(video_path)
                .filter("fps", fps=self.framerate)
                .filter("scale", width, height)
            )
            if do_crop:
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                cmd = cmd.crop(x, y, self.size, self.size)
            out, _ = cmd.output("pipe:", format="rawvideo", pix_fmt="rgb24").run(
                capture_stdout=True, quiet=True
            )
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return self._placeholder(video_path)

        if do_crop:
            height, width = self.size, self.size
        # Raw rgb24 output is a flat byte stream of H x W x 3 frames.
        frames = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
        video = th.from_numpy(frames.astype("float32"))
        video = video.permute(0, 3, 1, 2)
        return {"video": video, "input": video_path}

    def _fit_to_max_feats(self, video):
        """Subsample or zero-pad ``video`` to exactly ``max_feats`` frames.

        Returns:
            ``(video, video_len)`` where ``video_len`` is the number of real
            (non-padding) frames kept, at most ``max_feats``.
        """
        n_frames = len(video)
        if n_frames > self.max_feats:
            # Evenly spaced frame indices across the whole clip.
            picks = [(j * n_frames) // self.max_feats for j in range(self.max_feats)]
            return th.stack([video[i] for i in picks]), self.max_feats
        if n_frames < self.max_feats:
            # Padding must match the frame shape, dtype and device to be
            # concatenated along the time axis.
            pad_shape = (self.max_feats - n_frames,) + tuple(video.shape[1:])
            return th.cat([video, video.new_zeros(pad_shape)], 0), n_frames
        return video, n_frames

    def __call__(self, video_path):
        """Load ``video_path`` as a fixed-length, preprocessed clip.

        Returns:
            ``(video, video_len)``: ``video`` has ``max_feats`` frames and
            has been passed through ``self.preprocess``; ``video_len`` is
            the number of frames that come from the actual video. Padding
            frames are zero *before* preprocessing.

        Raises:
            RuntimeError: If the video could not be decoded.
        """
        video = self._getvideo(video_path)["video"]
        # _getvideo signals failure with a 1-D placeholder tensor.
        if video.dim() != 4:
            raise RuntimeError("Could not decode video: {}".format(video_path))

        video, video_len = self._fit_to_max_feats(video)
        video = self.preprocess(video)
        return video, video_len
|
|