"""Datasets for image / video generation training.

Every dataset returns a dict whose main tensor is stored under the key
``video`` (images included, for compatibility with the video pipeline) plus an
``identifier`` of ``'image'`` or ``'video'``.
"""
import itertools
import json
import math
import os
import random
import subprocess
from collections import OrderedDict

import cv2
import jsonlines
import numpy as np
import torch
from PIL import Image
from PIL import ImageFile
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

# Upper bound on load attempts per ``__getitem__`` call. Bounded so that a
# dataset in which every sample fails raises instead of looping forever.
_MAX_LOAD_RETRIES = 100


def _iter_annotations(anno_file, log_prefix=None, progress=False):
    """Yield every record from one or more jsonlines annotation files.

    params:
        anno_file: a path, or a list/tuple of paths, to jsonlines files
        log_prefix: if given, ``f"{log_prefix}{path}"`` is printed before each file is read
        progress: wrap each reader in ``tqdm`` to show a progress bar
    """
    if not isinstance(anno_file, (list, tuple)):
        anno_file = [anno_file]
    for path in anno_file:
        if log_prefix is not None:
            print(f"{log_prefix}{path}")
        with jsonlines.open(path, 'r') as reader:
            yield from (tqdm(reader) if progress else reader)


def _load_with_retry(annos, index, load_fn, error_msg):
    """Return ``load_fn(annos[index])``, substituting random samples on failure.

    An out-of-range ``index`` raises ``IndexError`` immediately; it is not
    treated as a bad sample. Any exception raised by ``load_fn`` is printed as
    ``"<error_msg> with <exception>"`` and a uniformly random annotation is
    tried instead, for at most ``_MAX_LOAD_RETRIES`` attempts in total.

    raises:
        IndexError: ``index`` is out of range for ``annos``
        RuntimeError: every attempt failed
    """
    anno = annos[index]
    for _ in range(_MAX_LOAD_RETRIES):
        try:
            return load_fn(anno)
        except Exception as e:
            print(f'{error_msg} with {e}')
            anno = annos[random.randint(0, len(annos) - 1)]
    raise RuntimeError(f'{error_msg}: all {_MAX_LOAD_RETRIES} load attempts failed')


class ImageTextDataset(Dataset):
    """
    Usage:
        The dataset class for image-text pairs, used for image generation training.
        It supports multi-aspect-ratio training.
    params:
        anno_file: The annotation file (or list of files); each jsonlines record
            needs an ``image`` path and a ``text`` caption (str)
        add_normalize: whether to normalize the input image pixel to [-1, 1], default: True
        ratios: The aspect ratios during training, format: width / height
        sizes: The resolution of training images, format: (width, height);
            ``sizes[i]`` is the size used for ``ratios[i]``
        crop_mode: 'center' or 'random' crop applied after resizing
        p_random_ratio: probability of ignoring the image's own aspect ratio and
            drawing a (ratio, size) pair uniformly at random
    """

    def __init__(
        self,
        anno_file,
        add_normalize=True,
        ratios=(1/1, 3/5, 5/3),
        sizes=((1024, 1024), (768, 1280), (1280, 768)),
        crop_mode='random',
        p_random_ratio=0.0,
    ):
        # Ratios and Sizes : (w h)
        super().__init__()
        self.image_annos = list(
            _iter_annotations(anno_file, log_prefix="Load image annotation files from ")
        )
        print(f"Totally Remained {len(self.image_annos)} images")

        transform_list = [
            transforms.ToTensor(),
        ]
        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform_list)
        print(f"Transform List is {transform_list}")

        assert crop_mode in ['center', 'random']
        self.crop_mode = crop_mode
        self.ratios = ratios
        self.sizes = sizes
        self.p_random_ratio = p_random_ratio

    def get_closest_size(self, x):
        """Return the (width, height) training size to use for PIL image ``x``.

        With probability ``p_random_ratio`` the size is drawn uniformly at random;
        otherwise it is the size whose ratio is closest to ``x.width / x.height``.
        """
        if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
            best_size_idx = np.random.randint(len(self.ratios))
        else:
            w, h = x.width, x.height
            best_size_idx = np.argmin([abs(w / h - r) for r in self.ratios])
        return self.sizes[best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        """Return the shorter-edge length to resize an image to before cropping.

        The result is passed as an int ``size`` to ``transforms.functional.resize``,
        which scales the image so that its shorter edge equals that value.

        params:
            orig_size: (width, height) of the source image
            tgt_size: (width, height) of the crop that follows
        """
        # Same orientation (or either one square): the product below is >= 0.
        if (tgt_size[1] / tgt_size[0] - 1) * (orig_size[1] / orig_size[0] - 1) >= 0:
            # Shorter edge must reach min(tgt) and the longer edge must reach max(tgt).
            alt_min = int(math.ceil(max(tgt_size) * min(orig_size) / max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            # Opposite orientation: the source's shorter edge maps onto the target's
            # longer edge.
            # NOTE(review): max(tgt_size) alone already covers the target here;
            # alt_max can upscale more than needed (smaller crop region) — confirm intent.
            alt_max = int(math.ceil(min(tgt_size) * max(orig_size) / min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))
        return resize_size

    def __len__(self):
        return len(self.image_annos)

    def _load_item(self, image_anno):
        """Load, resize and crop one annotated image; raises on any failure."""
        # Context manager closes the file handle; convert() loads the pixel data first.
        with Image.open(image_anno['image']) as raw:
            img = raw.convert("RGB")
        text = image_anno['text']
        if not isinstance(text, str):
            raise TypeError("Text should be str")

        size = self.get_closest_size(img)
        resize_size = self.get_resize_size((img.width, img.height), size)
        img = F.resize(img, resize_size, interpolation=InterpolationMode.BICUBIC, antialias=True)

        crop_hw = (size[1], size[0])  # torchvision crops take (height, width)
        if self.crop_mode == 'random':
            img = transforms.RandomCrop(crop_hw)(img)
        else:
            img = F.center_crop(img, crop_hw)

        return {
            "video": self.transform(img),  # using keyname `video`, to be compatible with video
            "text": text,
            "identifier": 'image',
        }

    def __getitem__(self, index):
        return _load_with_retry(self.image_annos, index, self._load_item, 'Load Image Error')


class LengthGroupedVideoTextDataset(Dataset):
    """
    Usage:
        The dataset class for video-text pairs, used for video generation training.
        It groups the video with the same frames together.
        NOTE(review): no grouping happens in this class itself; presumably a
        sampler elsewhere groups items by length — confirm.
        Now only supporting fixed resolution during training.
    params:
        anno_file: The annotation file list
        max_frames: The maximum temporal lengths (This is the vae latent temporal length)
            16 => (16 - 1) * 8 + 1 = 121 frames
        resolution: '384p' or '768p'; the loaded latent's last two dims must equal the
            corresponding pixel (height, width) divided by 8
        load_vae_latent: Loading the pre-extracted vae latents during training, we recommend
            extracting the latents in advance to reduce the time cost per batch
        load_text_fea: Loading the pre-extracted text features during training, we recommend
            extracting the prompt textual features in advance, since the T5 encoder costs a
            lot of GPU memory
    """

    # Expected latent spatial size (height, width) per resolution: pixel size // 8.
    _LATENT_HW = {
        '384p': (384 // 8, 640 // 8),
        '768p': (768 // 8, 1280 // 8),
    }

    def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
        super().__init__()
        self.max_frames = max_frames
        self.load_vae_latent = load_vae_latent
        self.load_text_fea = load_text_fea
        self.resolution = resolution
        assert load_vae_latent, "Now only support loading vae latents, we will support to directly load video frames in the future"
        # Fail fast: an unknown resolution would make every item fail to load.
        if resolution not in self._LATENT_HW:
            raise ValueError(
                f"Unsupported resolution {resolution!r}, expected one of {sorted(self._LATENT_HW)}"
            )

        self.video_annos = list(_iter_annotations(anno_file, progress=True))
        print(f"Totally Remained {len(self.video_annos)} videos")

    def __len__(self):
        return len(self.video_annos)

    def _load_item(self, video_anno):
        """Load one pre-extracted latent (and optionally text features); raises on failure."""
        text = video_anno['text']
        # NOTE(review): torch.load unpickles; only use with trusted, self-produced files.
        latent = torch.load(video_anno['latent'], map_location='cpu')  # loading the pre-extracted video latents

        # TODO: remove the hard code latent shape checking
        expected_hw = self._LATENT_HW[self.resolution]
        actual_hw = tuple(latent.shape[-2:])
        if actual_hw != expected_hw:
            raise ValueError(
                f"Latent spatial shape {actual_hw} does not match {expected_hw} for {self.resolution}"
            )

        # Dim 2 is the temporal axis (truncated to max_frames).
        cur_temp = min(latent.shape[2], self.max_frames)
        video_latent = latent[:, :, :cur_temp].float()
        # Dim 1 must be 16 — presumably the VAE latent channels; confirm.
        if video_latent.shape[1] != 16:
            raise ValueError(f"Expected 16 latent channels, got {video_latent.shape[1]}")

        if self.load_text_fea:
            text_fea = torch.load(video_anno['text_fea'], map_location='cpu')
            return {
                'video': video_latent,
                'prompt_embed': text_fea['prompt_embed'],
                'prompt_attention_mask': text_fea['prompt_attention_mask'],
                'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
                "identifier": 'video',
            }
        return {
            'video': video_latent,
            'text': text,
            "identifier": 'video',
        }

    def __getitem__(self, index):
        return _load_with_retry(self.video_annos, index, self._load_item, 'Load Video Error')


class VideoFrameProcessor:
    """Decode a video file into a normalized, fixed-length clip tensor.

    params:
        resolution: shorter-edge resize target, followed by a square center crop
        num_frames: number of frames in the returned clip
        add_normalize: map pixels from [0, 1] to [-1, 1]
        sample_fps: approximate target frame rate; frames are subsampled with an
            integer stride of ``max(int(fps / sample_fps), 1)``
    """

    def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
        image_size = resolution
        transform_list = [
            transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.CenterCrop(image_size),
        ]
        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        print(f"Transform List is {transform_list}")
        self.num_frames = num_frames
        self.transform = transforms.Compose(transform_list)
        self.sample_fps = sample_fps

    def _read_frames(self, video_path):
        """Decode every ``interval``-th frame of ``video_path`` as an RGB [c, h, w] tensor."""
        video_capture = cv2.VideoCapture(video_path)
        try:
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            interval = max(int(fps / self.sample_fps), 1)
            frames = []
            frame_idx = 0
            while True:
                flag, frame = video_capture.read()
                if not flag:
                    break
                # Keep only frames that survive temporal subsampling instead of
                # materializing the whole video first.
                if frame_idx % interval == 0:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
                    frames.append(torch.from_numpy(frame).permute(2, 0, 1))  # HWC -> CHW
                frame_idx += 1
        finally:
            # Release the capture even if decoding raises.
            video_capture.release()
        return frames

    def __call__(self, video_path):
        """Return ``(clip, None)`` with clip shaped [c, t, h, w], or ``(None, None)`` on failure."""
        try:
            frames = self._read_frames(video_path)
            if not frames:
                raise ValueError("no frames could be decoded")
            if len(frames) < self.num_frames:
                # Too short: loop the clip from its start until it is long enough.
                frames = list(itertools.islice(itertools.cycle(frames), self.num_frames))

            start_index = random.choice(range(len(frames) - self.num_frames + 1))
            clip = frames[start_index:start_index + self.num_frames]
            # / 255 assumes 8-bit frames from OpenCV.
            clip = torch.stack(clip).float() / 255  # [t, c, h, w]
            clip = self.transform(clip)
            return clip.permute(1, 0, 2, 3), None  # -> [c, t, h, w]
        except Exception as e:
            print(f"Load video: {video_path} Error, Exception {e}")
            return None, None


class VideoDataset(Dataset):
    """Raw-video dataset that decodes fixed-length clips on the fly.

    Each jsonlines record needs a ``video`` path. Items are
    ``{"video": tensor [c, t, h, w], "identifier": 'video'}`` with ``t == max_frames``.
    """

    def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
        super().__init__()
        self.max_frames = max_frames
        print(f"The training video clip frame number is {max_frames} ")
        self.video_annos = list(
            _iter_annotations(anno_file, log_prefix="Load annotation file from ", progress=True)
        )
        print(f"Totally Remained {len(self.video_annos)} videos")
        self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)

    def __len__(self):
        return len(self.video_annos)

    def _load_item(self, video_anno):
        """Decode one annotated video into a clip; raises on any failure."""
        video_path = video_anno['video']
        video_tensors, _ = self.video_processor(video_path)
        if video_tensors is None:
            raise RuntimeError(f"failed to decode {video_path}")
        if video_tensors.shape[1] != self.max_frames:
            raise ValueError(f"expected {self.max_frames} frames, got {video_tensors.shape[1]}")
        return {
            "video": video_tensors,
            "identifier": 'video',
        }

    def __getitem__(self, index):
        return _load_with_retry(self.video_annos, index, self._load_item, 'Loading Video Error')


class ImageDataset(Dataset):
    """Image dataset that packs ``max_frames`` consecutive images into one pseudo-video.

    Each jsonlines record needs an ``image`` path. Items are
    ``{"video": tensor [c, t, h, w], "identifier": 'image'}`` with ``t == max_frames``.
    """

    def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
        super().__init__()
        self.max_frames = max_frames
        image_paths = [
            item['image']
            for item in _iter_annotations(anno_file, log_prefix="Load annotation file from ", progress=True)
        ]
        print(f"Totally Remained {len(image_paths)} images")

        # pack multiple frames: consecutive groups of ``max_frames`` images form one sample
        self.image_annos = []
        for start in range(0, len(image_paths), self.max_frames):
            shard = image_paths[start:start + self.max_frames]
            missing = self.max_frames - len(shard)
            if missing > 0:
                # Pad the last group by wrapping around to the first images; cycling
                # also covers the case of fewer than ``max_frames`` images in total.
                shard += itertools.islice(itertools.cycle(image_paths), missing)
            self.image_annos.append(shard)

        image_size = resolution
        transform_list = [
            transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
        if add_normalize:
            transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        print(f"Transform List is {transform_list}")
        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_annos)

    def _load_item(self, image_paths):
        """Load and stack one group of images; raises on any failure."""
        frames = []
        for image_path in image_paths:
            # Context manager closes the file handle; convert() loads the pixel data first.
            with Image.open(image_path) as raw:
                frames.append(self.transform(raw.convert("RGB")))
        stacked = torch.stack(frames)  # [t, c, h, w]
        return {
            "video": stacked.permute(1, 0, 2, 3),  # [c, t, h, w]
            "identifier": 'image',
        }

    def __getitem__(self, index):
        return _load_with_retry(self.image_annos, index, self._load_item, 'Load Images Error')