# Spaces: Running on Zero
# File size: 2,725 Bytes
# fcb4edd
import os
from glob import glob
import random
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data.dataset import Dataset
class StableVideoDataset(Dataset):
    """Dataset of fixed-length frame clips sampled from per-video PNG folders.

    ``video_data_dir`` is expected to contain one sub-directory per video, each
    holding that video's frames as ``*.png`` files whose sorted file names give
    the frame order.  Every item is a randomly positioned window of frames from
    one video, resized and normalized to ``[-1, 1]``.

    Each sample is a dict with:
        pixel_values: float tensor of shape (frames, 3, frame_height,
            frame_width); reversed along the frame axis when
            ``is_reverse_video`` is true.
        conditions: the chronologically last frame of the sampled window,
            shape (3, frame_height, frame_width).
    """

    def __init__(self,
                 video_data_dir,
                 max_num_videos=None,
                 frame_hight=576, frame_width=1024, num_frames=14,
                 is_reverse_video=True,
                 random_seed=42,
                 double_sampling_rate=False,
                 ):
        """Index the video folders and configure frame sampling.

        Args:
            video_data_dir: Directory containing one sub-directory per video.
            max_num_videos: Optional cap on the number of videos used; the
                first ones in sorted name order are kept.
            frame_hight: Output frame height in pixels.  The misspelling is
                kept because it is part of the public keyword interface.
            frame_width: Output frame width in pixels.
            num_frames: Number of frames returned per sample.
            is_reverse_video: If true, samples are returned in reverse
                temporal order.
            random_seed: Seed applied to NumPy's *global* random generator,
                which picks the start frame of every sampled window.
            double_sampling_rate: If true, keep every second frame of a
                window spanning ``2 * num_frames - 1`` frames, so a sample
                still holds ``num_frames`` frames.

        Raises:
            ValueError: If ``max_num_videos`` is negative.
        """
        super().__init__()
        if max_num_videos is not None and max_num_videos < 0:
            raise ValueError(
                "max_num_videos must be non-negative, got %r" % (max_num_videos,))

        self.video_data_dir = video_data_dir
        video_names = sorted(
            name for name in os.listdir(video_data_dir)
            if os.path.isdir(os.path.join(video_data_dir, name))
        )
        if max_num_videos is None:
            self.length = len(video_names)
        else:
            self.length = min(len(video_names), max_num_videos)
        self.video_names = video_names[:self.length]

        if double_sampling_rate:
            # A window of 2n-1 consecutive frames read with stride 2 yields n.
            self.sample_frames = num_frames * 2 - 1
            self.sample_stride = 2
        else:
            self.sample_frames = num_frames
            self.sample_stride = 1

        self.frame_width = frame_width
        self.frame_height = frame_hight
        self.pixel_transforms = transforms.Compose([
            transforms.Resize((self.frame_height, self.frame_width),
                              interpolation=transforms.InterpolationMode.BILINEAR),
            # Frames arrive in [0, 1]; mean = std = 0.5 maps them to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.is_reverse_video = is_reverse_video
        # NOTE: this reseeds NumPy's global generator, not a private one.
        np.random.seed(random_seed)

    def get_batch(self, idx):
        """Load one randomly positioned window of frames from video ``idx``.

        Returns:
            Float32 tensor of shape (frames, 3, H, W) with values in [0, 1],
            at the frames' native resolution.

        Raises:
            IndexError: If ``idx`` is not a valid video index.
            ValueError: If the video has fewer PNG frames than the window
                needs.
        """
        video_name = self.video_names[idx]
        frame_paths = sorted(glob(os.path.join(self.video_data_dir, video_name, '*.png')))

        num_start_positions = len(frame_paths) - self.sample_frames + 1
        if num_start_positions <= 0:
            # Fail with an explicit message rather than NumPy's "low >= high".
            raise ValueError(
                "video %r has %d frame(s) but %d are required"
                % (video_name, len(frame_paths), self.sample_frames))

        start_idx = np.random.randint(num_start_positions)
        window = frame_paths[start_idx:start_idx + self.sample_frames:self.sample_stride]

        frames = []
        for frame_path in window:
            # The context manager closes the file handle as soon as the pixel
            # data has been copied out, instead of leaving that to the GC.
            with Image.open(frame_path) as image:
                frames.append(
                    np.asarray(image.convert('RGB')).astype(np.float32) / 255.0)

        stacked = np.stack(frames, axis=0)
        # (frames, H, W, C) -> (frames, C, H, W)
        return torch.from_numpy(stacked.transpose(0, 3, 1, 2))

    def __len__(self):
        """Return the number of videos in the dataset."""
        return self.length

    def __getitem__(self, idx):
        """Return the sample for video ``idx``, substituting another video on failure.

        If the requested video cannot be loaded, the remaining videos are
        tried in random order, each at most once.

        Raises:
            IndexError: If ``idx`` is out of range.  This is what terminates
                plain ``for sample in dataset`` iteration.
            RuntimeError: If no video in the dataset could be loaded.
        """
        if not -self.length <= idx < self.length:
            raise IndexError(
                "index %s is out of range for dataset of length %d"
                % (idx, self.length))

        fallback_indices = None
        while True:
            try:
                pixel_values = self.get_batch(idx)
                break
            except Exception as exc:
                import logging  # local import: the failure path is rare

                logging.getLogger(__name__).warning(
                    "Failed to load video index %s (%s); trying another video.",
                    idx, exc)
                if fallback_indices is None:
                    # Built lazily: every other video, in random order, so
                    # that the retry loop is guaranteed to terminate.
                    requested = idx % self.length
                    fallback_indices = [
                        i for i in range(self.length) if i != requested]
                    random.shuffle(fallback_indices)
                if not fallback_indices:
                    raise RuntimeError(
                        "none of the %d video(s) in %r could be loaded"
                        % (self.length, self.video_data_dir)) from exc
                idx = fallback_indices.pop()

        pixel_values = self.pixel_transforms(pixel_values)
        # Taken before the optional flip: always the chronologically last
        # frame, which becomes pixel_values[0] once the clip is reversed.
        conditions = pixel_values[-1]
        if self.is_reverse_video:
            pixel_values = torch.flip(pixel_values, (0,))
        return dict(pixel_values=pixel_values, conditions=conditions)