ViewCrafter / lvdm /data /webvid.py
Drexubery's picture
update
df13f4b
raw
history blame
8.13 kB
import os
import random
from tqdm import tqdm
import pandas as pd
from decord import VideoReader, cpu
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
class WebVid(Dataset):
"""
WebVid Dataset.
Assumes webvid data is structured as follows.
Webvid/
videos/
000001_000050/ ($page_dir)
1.mp4 (videoid.mp4)
...
5000.mp4
...
"""
def __init__(self,
meta_path,
data_dir,
subsample=None,
video_length=16,
resolution=[256, 512],
frame_stride=1,
frame_stride_min=1,
spatial_transform=None,
crop_resolution=None,
fps_max=None,
load_raw_resolution=False,
fixed_fps=None,
random_fs=False,
):
self.meta_path = meta_path
self.data_dir = data_dir
self.subsample = subsample
self.video_length = video_length
self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution
self.fps_max = fps_max
self.frame_stride = frame_stride
self.frame_stride_min = frame_stride_min
self.fixed_fps = fixed_fps
self.load_raw_resolution = load_raw_resolution
self.random_fs = random_fs
self._load_metadata()
if spatial_transform is not None:
if spatial_transform == "random_crop":
self.spatial_transform = transforms.RandomCrop(crop_resolution)
elif spatial_transform == "center_crop":
self.spatial_transform = transforms.Compose([
transforms.CenterCrop(resolution),
])
elif spatial_transform == "resize_center_crop":
# assert(self.resolution[0] == self.resolution[1])
self.spatial_transform = transforms.Compose([
transforms.Resize(min(self.resolution)),
transforms.CenterCrop(self.resolution),
])
elif spatial_transform == "resize":
self.spatial_transform = transforms.Resize(self.resolution)
else:
raise NotImplementedError
else:
self.spatial_transform = None
def _load_metadata(self):
metadata = pd.read_csv(self.meta_path)
print(f'>>> {len(metadata)} data samples loaded.')
if self.subsample is not None:
metadata = metadata.sample(self.subsample, random_state=0)
metadata['caption'] = metadata['name']
del metadata['name']
self.metadata = metadata
self.metadata.dropna(inplace=True)
def _get_video_path(self, sample):
rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
return full_video_fp
def __getitem__(self, index):
if self.random_fs:
frame_stride = random.randint(self.frame_stride_min, self.frame_stride)
else:
frame_stride = self.frame_stride
## get frames until success
while True:
index = index % len(self.metadata)
sample = self.metadata.iloc[index]
video_path = self._get_video_path(sample)
## video_path should be in the format of "....../WebVid/videos/$page_dir/$videoid.mp4"
caption = sample['caption']
try:
if self.load_raw_resolution:
video_reader = VideoReader(video_path, ctx=cpu(0))
else:
video_reader = VideoReader(video_path, ctx=cpu(0), width=530, height=300)
if len(video_reader) < self.video_length:
print(f"video length ({len(video_reader)}) is smaller than target length({self.video_length})")
index += 1
continue
else:
pass
except:
index += 1
print(f"Load video failed! path = {video_path}")
continue
fps_ori = video_reader.get_avg_fps()
if self.fixed_fps is not None:
frame_stride = int(frame_stride * (1.0 * fps_ori / self.fixed_fps))
## to avoid extreme cases when fixed_fps is used
frame_stride = max(frame_stride, 1)
## get valid range (adapting case by case)
required_frame_num = frame_stride * (self.video_length-1) + 1
frame_num = len(video_reader)
if frame_num < required_frame_num:
## drop extra samples if fixed fps is required
if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
index += 1
continue
else:
frame_stride = frame_num // self.video_length
required_frame_num = frame_stride * (self.video_length-1) + 1
## select a random clip
random_range = frame_num - required_frame_num
start_idx = random.randint(0, random_range) if random_range > 0 else 0
## calculate frame indices
frame_indices = [start_idx + frame_stride*i for i in range(self.video_length)]
try:
frames = video_reader.get_batch(frame_indices)
break
except:
print(f"Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]")
index += 1
continue
## process data
assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}'
frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w]
if self.spatial_transform is not None:
frames = self.spatial_transform(frames)
if self.resolution is not None:
assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}'
## turn frames tensors to [-1,1]
frames = (frames / 255 - 0.5) * 2
fps_clip = fps_ori // frame_stride
if self.fps_max is not None and fps_clip > self.fps_max:
fps_clip = self.fps_max
data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
return data
def __len__(self):
return len(self.metadata)
if __name__== "__main__":
meta_path = "" ## path to the meta file
data_dir = "" ## path to the data directory
save_dir = "" ## path to the save directory
dataset = WebVid(meta_path,
data_dir,
subsample=None,
video_length=16,
resolution=[256,448],
frame_stride=4,
spatial_transform="resize_center_crop",
crop_resolution=None,
fps_max=None,
load_raw_resolution=True
)
dataloader = DataLoader(dataset,
batch_size=1,
num_workers=0,
shuffle=False)
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
from utils.save_video import tensor_to_mp4
for i, batch in tqdm(enumerate(dataloader), desc="Data Batch"):
video = batch['video']
name = batch['path'][0].split('videos/')[-1].replace('/','_')
tensor_to_mp4(video, save_dir+'/'+name, fps=8)