File size: 4,044 Bytes
8a8dad9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import logging
import os, csv, random

import numpy as np
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from torch.utils.data.dataset import Dataset
class ChronoMagic(Dataset):
    """Video dataset indexed by a CSV file with ``videoid``, ``name`` and ``is_transmit`` columns.

    Each CSV row refers to ``<video_folder>/<videoid>.mp4``. ``__getitem__``
    decodes ``sample_n_frames`` frames (or one random frame when ``is_image``
    is True), then applies the following transforms:

    - a random horizontal flip;
    - a bicubic resize of the shorter edge to ``sample_size[0]``;
    - a center crop to ``sample_size``;
    - a normalization that maps [0, 1] to [-1, 1].

    It returns ``dict(pixel_values=..., text=name, id=videoid)``.

    Frame selection depends on the row's ``is_transmit`` value:

    - ``0``: frames are spread uniformly from the first to the last frame.
    - Any other integer: a randomly placed clip of at most
      ``(sample_n_frames - 1) * sample_stride + 1`` frames is sampled.

    Note: ``is_uniform`` is stored on the instance but no method here reads it.
    """

    def __init__(
        self,
        csv_path, video_folder,
        sample_size=512, sample_stride=4, sample_n_frames=16,
        is_image=False,
        is_uniform=True,
    ):
        """Load the CSV index and build the pixel transform pipeline.

        Args:
            csv_path: CSV file parsed with ``csv.DictReader``. Every row must
                provide ``videoid``, ``name`` and ``is_transmit``.
            video_folder: Directory containing the ``<videoid>.mp4`` files.
            sample_size: An int for a square output, or a (height, width)
                pair.
            sample_stride: Frame spacing used by random-clip sampling.
            sample_n_frames: Number of frames returned per video sample.
            is_image: If True, return a single random frame instead of a clip.
            is_uniform: Stored as ``self.is_uniform``; currently unused.
        """
        # The csv module requires newline='' so quoted fields containing line
        # breaks are parsed correctly.
        with open(csv_path, 'r', newline='') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image
        self.is_uniform = is_uniform
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(sample_size),
            # Maps values in [0, 1] (see get_batch) to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def _get_frame_indices_adjusted(self, video_length, n_frames):
        """Return ``n_frames`` sorted indices for a video with too few frames.

        Every frame index appears at least once. The ``n_frames - video_length``
        extra slots cycle through the frames from the start, so earlier frames
        are duplicated first. For example, ``(3, 5) -> [0, 0, 1, 1, 2]``.

        ``video_length`` must be >= 1 whenever ``n_frames > video_length``.
        Otherwise the modulo raises ``ZeroDivisionError``.
        """
        padding = [i % video_length for i in range(n_frames - video_length)]
        return sorted([*range(video_length), *padding])

    def _generate_frame_indices(self, video_length, n_frames, sample_stride, is_transmit):
        """Choose which frame indices to decode for one video.

        Args:
            video_length: Number of frames in the video.
            n_frames: Number of indices to return.
            sample_stride: Spacing between frames in random-clip mode.
            is_transmit: Value from the CSV row. A value that ``int()``
                parses to 0 selects uniform sampling; any other integer
                selects a random clip.

        Returns:
            A list of ``n_frames`` frame indices in ascending order.
        """
        # Parsed up front so a malformed value raises regardless of branch.
        uniform = int(is_transmit) == 0

        if video_length <= n_frames:
            # Not enough frames: use them all and pad with repeats.
            return self._get_frame_indices_adjusted(video_length, n_frames)

        if uniform:
            if n_frames == 1:
                # The interval below would divide by zero. Take the first
                # frame, as np.linspace(0, L - 1, 1) does.
                return [0]
            interval = (video_length - 1) / (n_frames - 1)
            indices = [int(round(i * interval)) for i in range(n_frames)]
            indices[-1] = video_length - 1  # guard against float drift
            return indices

        # Random clip: a window of up to (n_frames - 1) * sample_stride + 1
        # frames at a random offset, sampled at n_frames evenly spaced points.
        clip_length = min(video_length, (n_frames - 1) * sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        return np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()

    def get_batch(self, idx):
        """Decode the frames for CSV row ``idx``.

        Returns:
            ``(pixel_values, name, videoid)``. ``pixel_values`` is a float
            tensor scaled to [0, 1].

            - For clips it is permuted to (frames, channels, H, W).
            - When ``is_image`` is True it is the single (channels, H, W)
              frame.

            The permute assumes decord returns frames as (T, H, W, C).

        Raises:
            ValueError: If the video reports zero frames.
            Exception: Any error raised by decord while opening or decoding
                the file.
        """
        row = self.dataset[idx]
        videoid, name, is_transmit = row['videoid'], row['name'], row['is_transmit']
        video_path = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_path, num_threads=0)
        video_length = len(video_reader)
        if video_length == 0:
            raise ValueError(f"video {video_path!r} contains no frames")

        if self.is_image:
            batch_index = [random.randint(0, video_length - 1)]
        else:
            batch_index = self._generate_frame_indices(
                video_length, self.sample_n_frames, self.sample_stride, is_transmit)

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2) / 255.
        del video_reader  # drop the decoder reference as early as possible
        if self.is_image:
            pixel_values = pixel_values[0]
        return pixel_values, name, videoid

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return self.length

    def __getitem__(self, idx):
        """Return one transformed sample, retrying with random rows on failure.

        Videos that fail to load are skipped by retrying with a random index.
        Each failure is logged as a warning, so a systematic problem (e.g. a
        wrong ``video_folder``) shows up in the logs rather than as a silent
        hang.
        """
        while True:
            try:
                pixel_values, name, videoid = self.get_batch(idx)
                break
            except Exception as exc:  # best-effort: skip unreadable videos
                logging.getLogger(__name__).warning(
                    "ChronoMagic: failed to load sample %s (%s); retrying with a random index",
                    idx, exc)
                idx = random.randint(0, self.length - 1)
        pixel_values = self.pixel_transforms(pixel_values)
        return dict(pixel_values=pixel_values, text=name, id=videoid)