import os

import cv2
import numpy as np
import torch as th
from decord import VideoReader, cpu


class Normalize(object):
    """Channel-wise normalisation ``(x - mean) / std`` for batches of frames.

    ``mean`` and ``std`` are reshaped to ``(1, 3, 1, 1)``, so the input is
    expected to be a 4-D ``(T, 3, H, W)`` tensor with channels on axis 1.
    """

    def __init__(self, mean, std):
        self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = th.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        # The epsilon keeps the division finite if a std entry is zero.
        return (tensor - self.mean) / (self.std + 1e-8)


class Preprocessing(object):
    """Scale pixel values from [0, 255] to [0, 1], then normalise per channel.

    NOTE(review): the mean/std constants appear to be the CLIP image
    normalisation values — confirm they match the downstream model.
    """

    def __init__(self):
        self.norm = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )

    def __call__(self, tensor):
        tensor = tensor / 255.0
        return self.norm(tensor)


class VideoLoader:
    """Pytorch video loader.

    Decodes a video with decord and keeps roughly ``framerate`` frames per
    second. It resizes frames so the shorter side equals ``size``, or to an
    explicit ``(height, width)`` tuple. With an int ``size`` it can optionally
    centre-crop to ``size x size``. Calling the loader returns a fixed-length,
    normalised ``(max_feats, 3, H, W)`` tensor plus the number of real
    (non-padding) frames.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
        max_feats=10,
    ):
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.preprocess = Preprocessing()
        # Frames returned per video: longer clips are uniformly subsampled,
        # shorter ones are zero-padded up to this length.
        self.max_feats = max_feats
        # Kept for backward compatibility only; padding now follows the
        # decoded frame shape instead of this feature size.
        self.features_dim = 768

    def _open_video(self, video_path):
        """Open ``video_path`` once and return ``(reader, height, width, fps)``."""
        vr = VideoReader(video_path, ctx=cpu(0))
        height, width, _ = vr[0].shape
        return vr, height, width, vr.get_avg_fps()

    def _get_video_dim(self, video_path):
        """Return ``(height, width, average_fps)`` of the video at ``video_path``."""
        _, height, width, frame_rate = self._open_video(video_path)
        return height, width, frame_rate

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` that frames are resized to.

        A 2-tuple ``size`` is returned unchanged. An int ``size`` becomes the
        length of the shorter side, preserving the aspect ratio.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        if h >= w:
            return int(h * self.size / w), self.size
        return self.size, int(w * self.size / h)

    def _empty_frames(self):
        """Return all-zero ``(max_feats, 3, H, W)`` frames for undecodable videos.

        H and W come from ``self.size``; an int size gives square frames.
        NOTE(review): with an int size and ``centercrop=False``, decoded videos
        are not square, so this placeholder shape may differ from them.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            height, width = self.size
        else:
            height = width = self.size
        return th.zeros(self.max_feats, 3, height, width)

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float ``(T, 3, H, W)`` tensor.

        Returns ``{"video": tensor, "input": video_path}``. Pixel values stay
        in [0, 255]; normalisation happens in ``__call__``. On any failure,
        ``"video"`` is the placeholder ``th.zeros(1)``. Failures are a missing
        file, unreadable metadata, a frame rate below 1, or a decode/resize
        error.
        """
        failed = {"video": th.zeros(1), "input": video_path}
        if not os.path.isfile(video_path):
            return failed
        print("Decoding video: {}".format(video_path))
        try:
            vr, h, w, fr = self._open_video(video_path)
        except Exception:
            print("ffprobe failed at: {}".format(video_path))
            return failed
        if fr < 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return failed
        height, width = self._get_output_dim(h, w)
        # Keep every `step`-th frame to approximate `self.framerate` fps.
        # With the default framerate=1 this is int(fr), i.e. about one frame
        # per second. max() avoids a zero step when framerate exceeds fps.
        step = max(1, int(fr / self.framerate))
        try:
            video = vr.get_batch(range(0, len(vr), step)).asnumpy()
            # cv2.resize takes dsize as (width, height).
            video = np.array([cv2.resize(frame, (width, height)) for frame in video])
            # Centre-crop only applies to a square int target; a tuple size
            # already fixes the output resolution.
            if self.centercrop and isinstance(self.size, int):
                x = (width - self.size) // 2
                y = (height - self.size) // 2
                video = video[:, y:y + self.size, x:x + self.size, :]
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return failed
        video = th.from_numpy(video.astype("float32"))
        video = video.permute(0, 3, 1, 2)  # t,c,h,w
        return {"video": video, "input": video_path}

    def __call__(self, video_path):
        """Load ``video_path`` as exactly ``self.max_feats`` normalised frames.

        Returns ``(video, video_len)``. ``video`` has shape
        ``(max_feats, 3, H, W)`` and ``video_len`` is the number of real
        frames; the remaining rows are exactly zero. A video that cannot be
        decoded yields all-zero frames and ``video_len == 0``.
        """
        video = self._getvideo(video_path)["video"]
        if video.dim() != 4:
            # _getvideo returned its th.zeros(1) failure placeholder.
            return self._empty_frames(), 0
        n_frames = len(video)
        if n_frames > self.max_feats:
            # Uniformly subsample max_feats frames across the whole clip.
            idx = [(j * n_frames) // self.max_feats for j in range(self.max_feats)]
            video = video[idx]
            video_len = self.max_feats
        else:
            video_len = n_frames
        video = self.preprocess(video)
        if video_len < self.max_feats:
            # Pad after normalisation so padding rows are exactly zero. Use
            # the real frame shape so th.cat's dimensions agree.
            pad_shape = (self.max_feats - video_len,) + tuple(video.shape[1:])
            video = th.cat([video, video.new_zeros(pad_shape)], 0)
        return video, video_len