import os
from glob import glob
import random
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data.dataset import Dataset

class StableVideoDataset(Dataset):
    """Loads fixed-length clips of PNG frames for Stable Video Diffusion-style training.

    Each item is a dict with:
        pixel_values: (num_frames, 3, H, W) tensor normalized to [-1, 1]
        conditions:   the conditioning frame (the last frame of the clip before optional reversal)
    """
    def __init__(
        self,
        video_data_dir,
        max_num_videos=None,
        frame_height=576, frame_width=1024, num_frames=14,
        is_reverse_video=True,
        random_seed=42,
        double_sampling_rate=False,
    ):
        self.video_data_dir = video_data_dir
        # Each subdirectory of video_data_dir is treated as one video (a folder of PNG frames).
        video_names = sorted([video for video in os.listdir(video_data_dir)
                              if os.path.isdir(os.path.join(video_data_dir, video))])

        self.length = min(len(video_names), max_num_videos) if max_num_videos is not None else len(video_names)
        self.video_names = video_names[:self.length]

        # With double_sampling_rate, read twice as many source frames and keep every second one,
        # halving the effective frame rate while still yielding num_frames frames per clip.
        if double_sampling_rate:
            self.sample_frames = num_frames * 2 - 1
            self.sample_stride = 2
        else:
            self.sample_frames = num_frames
            self.sample_stride = 1

        self.frame_width = frame_width
        self.frame_height = frame_height
        # Resize to the target resolution and map [0, 1] pixel values to [-1, 1].
        self.pixel_transforms = transforms.Compose([
            transforms.Resize((self.frame_height, self.frame_width), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        self.is_reverse_video = is_reverse_video
        np.random.seed(random_seed)
        
    def get_batch(self, idx):
        video_name = self.video_names[idx]
        video_frame_paths = sorted(glob(os.path.join(self.video_data_dir, video_name, '*.png')))
        # Pick a random start so the clip fits within the video; raises if the video has
        # fewer than sample_frames frames (caught and retried in __getitem__).
        start_idx = np.random.randint(len(video_frame_paths) - self.sample_frames + 1)
        video_frame_paths = video_frame_paths[start_idx:start_idx + self.sample_frames:self.sample_stride]
        # Load frames as float32 RGB in [0, 1], then stack to (T, H, W, C).
        video_frames = [np.asarray(Image.open(frame_path).convert('RGB')).astype(np.float32) / 255.0
                        for frame_path in video_frame_paths]
        video_frames = np.stack(video_frames, axis=0)
        # Reorder to (T, C, H, W) for the torchvision transforms.
        pixel_values = torch.from_numpy(video_frames.transpose(0, 3, 1, 2))
        return pixel_values

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # If a clip cannot be sampled (e.g. the video has too few frames),
        # fall back to a randomly chosen video until one succeeds.
        while True:
            try:
                pixel_values = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        # Condition on the last frame of the clip; after reversal it becomes the first frame.
        conditions = pixel_values[-1]
        if self.is_reverse_video:
            pixel_values = torch.flip(pixel_values, (0,))

        sample = dict(pixel_values=pixel_values, conditions=conditions)
        return sample
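

# Minimal usage sketch, not part of the training pipeline. It assumes a hypothetical
# local directory "./videos" containing one subfolder of PNG frames per video; adjust
# the path, resolution, and batch size to your setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = StableVideoDataset(
        video_data_dir="./videos",  # hypothetical path: one subfolder of PNG frames per video
        max_num_videos=4,
        num_frames=14,
        is_reverse_video=True,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    # Expected shapes: pixel_values (B, 14, 3, 576, 1024) in [-1, 1], conditions (B, 3, 576, 1024).
    print(batch["pixel_values"].shape, batch["conditions"].shape)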