Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import random | |
from tqdm import tqdm | |
import pandas as pd | |
from decord import VideoReader, cpu | |
import torch | |
from torch.utils.data import Dataset | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
#import torchvision.transforms._transforms_video as transforms_video | |
def string_not_contains_any(substrings, target_string): | |
return not any(substring in target_string for substring in substrings) | |
word = ['digital', 'Digital', 'DIGITAL', 'concept', 'Concept', 'CONCEPT', 'abstract', 'Abstract', 'ABSTRACT', 'particle', 'Particle', 'PARTICLE', 'loop', 'Loop','LOOP'] | |
class WebVid(Dataset): | |
""" | |
WebVid Dataset. | |
Assumes webvid data is structured as follows. | |
Webvid/ | |
videos/ | |
000001_000050/ ($page_dir) | |
1.mp4 (videoid.mp4) | |
... | |
5000.mp4 | |
... | |
""" | |
def __init__(self, | |
meta_path, | |
data_dir, | |
subsample=None, | |
video_length=16, | |
resolution=[256, 512], | |
frame_stride=1, | |
frame_stride_min=1, | |
spatial_transform=None, | |
crop_resolution=None, | |
fps_max=None, | |
load_raw_resolution=False, | |
fixed_fps=None, | |
random_fs=False, | |
filter_CG=False, | |
human_dynamic=False, | |
sample_basedon_keyframe=False, | |
): | |
self.meta_path = meta_path | |
self.data_dir = data_dir | |
self.subsample = subsample | |
self.video_length = video_length | |
self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution | |
self.fps_max = fps_max | |
self.frame_stride = frame_stride | |
self.frame_stride_min = frame_stride_min | |
self.fixed_fps = fixed_fps | |
self.load_raw_resolution = load_raw_resolution | |
self.random_fs = random_fs | |
self.filter_CG = filter_CG | |
self.human_dynamic = human_dynamic | |
self.sample_basedon_keyframe = sample_basedon_keyframe | |
self._load_metadata() | |
if spatial_transform is not None: | |
if spatial_transform == "random_crop": | |
self.spatial_transform = transforms.RandomCrop(crop_resolution) | |
elif spatial_transform == "center_crop": | |
self.spatial_transform = transforms.Compose([ | |
transforms.CenterCrop(resolution), | |
]) | |
elif spatial_transform == "resize_center_crop": | |
# assert(self.resolution[0] == self.resolution[1]) | |
self.spatial_transform = transforms.Compose([ | |
transforms.Resize(min(self.resolution)), | |
transforms.CenterCrop(self.resolution), | |
]) | |
elif spatial_transform == "resize": | |
self.spatial_transform = transforms.Compose([ | |
transforms.Resize((self.resolution)), | |
]) | |
else: | |
raise NotImplementedError | |
else: | |
self.spatial_transform = None | |
def _load_metadata(self): | |
metadata = pd.read_csv(self.meta_path) | |
print('Loaded: ', len(metadata)) | |
metadata['caption'] = metadata['name'] | |
del metadata['name'] | |
self.metadata = metadata | |
self.metadata.dropna(inplace=True) | |
def _get_video_path(self, sample): | |
full_video_fp = os.path.join(self.data_dir, sample['oripath'][1:] if sample['oripath'][0] == '/' else sample['oripath']) | |
cond_full_video_fp = os.path.join(self.data_dir, sample['videopath'][1:] if sample['videopath'][0] == '/' else sample['videopath']) | |
return full_video_fp, cond_full_video_fp | |
def __getitem__(self, index): | |
## | |
if self.random_fs: | |
frame_stride = random.randint(self.frame_stride_min, self.frame_stride) | |
else: | |
frame_stride = self.frame_stride | |
## get frames until success | |
while True: | |
index = index % len(self.metadata) | |
sample = self.metadata.iloc[index] | |
video_path, cond_video_path = self._get_video_path(sample) | |
#video_path = "/apdcephfs/share_1290939/0_public_datasets/WebVid/videos/002001_002050/1023214570.mp4" | |
caption = sample['caption'] | |
frameid = int(sample['frameid']) | |
try: | |
if self.load_raw_resolution: | |
video_reader = VideoReader(video_path, ctx=cpu(0)) | |
cond_video_reader = VideoReader(cond_video_path, ctx=cpu(0)) | |
else: | |
NotImplementedError("Must use load_raw_resolution=True") | |
if len(video_reader) < self.video_length or len(cond_video_reader) < self.video_length: | |
print(f"video length ({len(video_reader)}) or Cond video length ({len(cond_video_reader)}) is smaller than target length({self.video_length})") | |
index += 1 | |
continue | |
else: | |
pass | |
except: | |
index += 1 | |
print(f"Load video failed! path = {video_path}") | |
continue | |
frame_stride = 1 | |
start_idx = 0 | |
frame_indices = [start_idx + frame_stride*i for i in range(self.video_length)] | |
try: | |
frames = video_reader.get_batch(frame_indices) | |
frames_cond = cond_video_reader.get_batch(frame_indices) | |
break | |
except: | |
print(f"Get frames failed! path = {video_path}") | |
index += 1 | |
continue | |
## process data | |
assert(frames.shape[0] == self.video_length),f'{len(frames)}, self.video_length={self.video_length}' | |
assert(frames_cond.shape[0] == self.video_length),f'{len(frames_cond)}, self.video_length={self.video_length}' | |
frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] | |
frames_cond = torch.tensor(frames_cond.asnumpy()).permute(3, 0, 1, 2).float() # [t,h,w,c] -> [c,t,h,w] | |
if self.spatial_transform is not None: | |
frames = self.spatial_transform(frames) | |
frames_cond = self.spatial_transform(frames_cond) | |
if self.resolution is not None: | |
assert (frames.shape[2], frames.shape[3]) == (self.resolution[0], self.resolution[1]), f'frames={frames.shape}, self.resolution={self.resolution}' | |
frames = (frames / 255 - 0.5) * 2 | |
frames_cond = (frames_cond / 255 - 0.5) * 2 | |
frames_cond[:,frameid,:,:] = frames[:,frameid,:,:] | |
data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': 10, 'frame_stride': frame_stride, 'video_cond': frames_cond, 'frameid': frameid} | |
return data | |
def __len__(self): | |
return len(self.metadata) |