Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import random | |
from tqdm import tqdm | |
import pandas as pd | |
from decord import VideoReader, cpu | |
import torch | |
from torch.utils.data import Dataset | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
class WebVid(Dataset):
    """
    WebVid Dataset.

    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
            ...

    The metadata CSV must provide the columns ``page_dir``, ``videoid`` and
    ``name``; ``name`` is exposed as the sample caption. Rows containing
    missing values are dropped.

    Each item is a dict with the keys ``video`` (float tensor laid out as
    [c, t, h, w] and rescaled with ``(x / 255 - 0.5) * 2``), ``caption``,
    ``path``, ``fps`` and ``frame_stride``.

    Args:
        meta_path: Path of the metadata CSV file.
        data_dir: Root directory that contains the ``videos/`` folder.
        subsample: If not None, number of metadata rows to sample
            (with ``random_state=0``) before dropping incomplete rows.
        video_length: Number of frames in every returned clip.
        resolution: Expected output size, either an int (square) or
            ``[height, width]``. Checked after the spatial transform.
        frame_stride: Step between sampled frames (upper bound of the
            random stride when ``random_fs`` is set).
        frame_stride_min: Lower bound of the random stride.
        spatial_transform: One of ``"random_crop"``, ``"center_crop"``,
            ``"resize_center_crop"``, ``"resize"`` or None.
        crop_resolution: Crop size used by ``"random_crop"``.
        fps_max: If not None, upper bound for the reported clip fps.
        load_raw_resolution: If False, videos are decoded at 530x300
            (width x height) instead of their native size.
        fixed_fps: If not None, the stride is rescaled by
            ``fps_of_video / fixed_fps``.
        random_fs: If True, draw the stride uniformly from
            ``[frame_stride_min, frame_stride]`` for every item.
    """

    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=(256, 512),
                 frame_stride=1,
                 frame_stride_min=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fixed_fps=None,
                 random_fs=False,
                 ):
        super().__init__()
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        # Normalise to a private [h, w] list so the default argument (or a
        # caller-owned sequence) is never shared with the instance.
        if isinstance(resolution, int):
            self.resolution = [resolution, resolution]
        elif resolution is None:
            self.resolution = None
        else:
            self.resolution = list(resolution)
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self._load_metadata()

        if spatial_transform is None:
            self.spatial_transform = None
        elif spatial_transform == "random_crop":
            if crop_resolution is None:
                raise ValueError(
                    'spatial_transform="random_crop" requires crop_resolution.')
            self.spatial_transform = transforms.RandomCrop(crop_resolution)
        elif spatial_transform == "center_crop":
            self.spatial_transform = transforms.Compose([
                transforms.CenterCrop(self.resolution),
            ])
        elif spatial_transform == "resize_center_crop":
            # Resize the shorter side first, then crop to the exact target.
            self.spatial_transform = transforms.Compose([
                transforms.Resize(min(self.resolution)),
                transforms.CenterCrop(self.resolution),
            ])
        elif spatial_transform == "resize":
            self.spatial_transform = transforms.Resize(self.resolution)
        else:
            raise NotImplementedError(
                f'Unknown spatial_transform: {spatial_transform!r}')

    def _load_metadata(self):
        """Read the metadata CSV into ``self.metadata`` (captions in 'caption')."""
        metadata = pd.read_csv(self.meta_path)
        print(f'>>> {len(metadata)} data samples loaded.')
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        # Move the 'name' column to 'caption'.
        metadata['caption'] = metadata.pop('name')
        # Drop incomplete rows without mutating a frame pandas may share.
        self.metadata = metadata.dropna()

    def _get_video_path(self, sample):
        """Return ``<data_dir>/videos/<page_dir>/<videoid>.mp4`` for a metadata row."""
        rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
        return os.path.join(self.data_dir, 'videos', rel_video_fp)

    def _open_video(self, video_path):
        """Open ``video_path`` with decord, at raw or fixed 530x300 resolution."""
        if self.load_raw_resolution:
            return VideoReader(video_path, ctx=cpu(0))
        return VideoReader(video_path, ctx=cpu(0), width=530, height=300)

    def _plan_clip(self, base_stride, fps_ori, frame_num):
        """Choose the stride and frame indices for one video.

        Args:
            base_stride: Stride requested for this item, before any
                per-video adjustment.
            fps_ori: Average fps of the video.
            frame_num: Total number of frames in the video.

        Returns:
            ``(frame_stride, frame_indices)``, or None when the video
            should be skipped (too short for ``video_length`` frames, or
            less than half of the frames needed at ``fixed_fps``).
        """
        if frame_num < self.video_length:
            return None

        frame_stride = base_stride
        if self.fixed_fps is not None:
            frame_stride = int(frame_stride * (1.0 * fps_ori / self.fixed_fps))
        ## to avoid extreme cases when fixed_fps is used
        frame_stride = max(frame_stride, 1)

        ## get valid range (adapting case by case)
        required_frame_num = frame_stride * (self.video_length - 1) + 1
        if frame_num < required_frame_num:
            ## drop extra samples if fixed fps is required
            if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                return None
            # Shrink the stride so the clip fits; >= 1 because
            # frame_num >= video_length was checked above.
            frame_stride = frame_num // self.video_length
            required_frame_num = frame_stride * (self.video_length - 1) + 1

        ## select a random clip
        random_range = frame_num - required_frame_num
        start_idx = random.randint(0, random_range) if random_range > 0 else 0

        ## calculate frame indices
        frame_indices = [start_idx + frame_stride * i for i in range(self.video_length)]
        return frame_stride, frame_indices

    def __getitem__(self, index):
        """Return the clip for ``index``, falling back to following samples.

        Unreadable or unsuitable videos are skipped by moving on to the
        next metadata row (wrapping around).

        Raises:
            IndexError: If the dataset has no samples.
            RuntimeError: If every sample was tried once and none could
                be loaded.
        """
        num_samples = len(self.metadata)
        if num_samples == 0:
            raise IndexError('WebVid dataset is empty.')

        # The requested stride is drawn once per item. Per-video adjustments
        # are derived from it in _plan_clip, so they never leak from a
        # skipped video into the next candidate.
        if self.random_fs:
            base_stride = random.randint(self.frame_stride_min, self.frame_stride)
        else:
            base_stride = self.frame_stride

        ## get frames until success (at most one pass over the metadata)
        for offset in range(num_samples):
            sample = self.metadata.iloc[(index + offset) % num_samples]
            video_path = self._get_video_path(sample)
            ## video_path should be in the format of "....../WebVid/videos/$page_dir/$videoid.mp4"
            caption = sample['caption']

            try:
                video_reader = self._open_video(video_path)
                frame_num = len(video_reader)
            except Exception:
                print(f"Load video failed! path = {video_path}")
                continue

            if frame_num < self.video_length:
                print(f"video length ({frame_num}) is smaller than target length({self.video_length})")
                continue

            fps_ori = video_reader.get_avg_fps()
            plan = self._plan_clip(base_stride, fps_ori, frame_num)
            if plan is None:
                continue
            frame_stride, frame_indices = plan

            try:
                frames = video_reader.get_batch(frame_indices)
            except Exception:
                print(f"Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]")
                continue
            break
        else:
            raise RuntimeError(
                f'No loadable video found after trying all {num_samples} samples.')

        ## process data
        assert (frames.shape[0] == self.video_length), f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]

        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)

        if self.resolution is not None:
            assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'

        ## turn frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2

        # fps of the sampled clip, optionally capped by fps_max.
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        return data

    def __len__(self):
        """Return the number of usable metadata rows."""
        return len(self.metadata)
if __name__ == "__main__":
    # Smoke test: iterate the dataset once and dump every clip as an mp4.
    meta_path = ""  ## path to the meta file
    data_dir = ""  ## path to the data directory
    save_dir = ""  ## path to the save directory

    dataset = WebVid(
        meta_path,
        data_dir,
        subsample=None,
        video_length=16,
        resolution=[256, 448],
        frame_stride=4,
        spatial_transform="resize_center_crop",
        crop_resolution=None,
        fps_max=None,
        load_raw_resolution=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    # Make the repository root importable so the project-local video
    # writer can be loaded when this file is run directly.
    import sys
    repo_root = os.path.join(sys.path[0], '..', '..')
    sys.path.insert(1, repo_root)
    from utils.save_video import tensor_to_mp4

    for _, batch in tqdm(enumerate(dataloader), desc="Data Batch"):
        video = batch['video']
        # Flatten "<page_dir>/<videoid>.mp4" into a single file name.
        relative_path = batch['path'][0].split('videos/')[-1]
        name = relative_path.replace('/', '_')
        tensor_to_mp4(video, f"{save_dir}/{name}", fps=8)