import os
import random
import json

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from packaging import version as pver


# Horizontal flip that exposes its per-frame flip decisions so camera parameters
# can be flipped consistently (see the flip_flag handling in ray_condition).
class RandomHorizontalFlipWithPose(nn.Module):
    def __init__(self, p=0.5):
        super(RandomHorizontalFlipWithPose, self).__init__()
        self.p = p

    def get_flip_flag(self, n_image):
        return torch.rand(n_image) < self.p

    def forward(self, image, flip_flag=None):
        n_image = image.shape[0]
        if flip_flag is not None:
            assert n_image == flip_flag.shape[0]
        else:
            flip_flag = self.get_flip_flag(n_image)

        ret_images = []
        for fflag, img in zip(flip_flag, image):
            if fflag:
                ret_images.append(F.hflip(img))
            else:
                ret_images.append(img)
        return torch.stack(ret_images, dim=0)


# One camera: intrinsics plus 4x4 world-to-camera / camera-to-world matrices,
# parsed from a single row of a pose file.
class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)


def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def ray_condition(K, c2w, H, W, device, flip_flag=None):
    # Per-pixel Plücker coordinates (o x d, d) for each view, built from the
    # pixel intrinsics K and camera-to-world poses c2w.
    # c2w: B, V, 4, 4
    # K: B, V, 4

    B, V = K.shape[:2]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5    # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5    # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)     # B, V, 1

    zs = torch.ones_like(i)                 # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)                      # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)     # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)    # B, V, HW, 3
    rays_o = c2w[..., :3, 3]                                    # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)               # B, V, HW, 3
    # c2w @ directions
    rays_dxo = torch.linalg.cross(rays_o, rays_d)               # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)         # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker
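
# Illustrative sketch (not part of the original module): how a toy camera turns into
# a Plücker embedding via ray_condition. The identity pose and the 0.5-normalized
# intrinsics below are made-up values; RealEstate10KPose.get_batch scales its
# normalized intrinsics by the sample width/height in the same way before calling
# ray_condition.
def _example_plucker_embedding(H=256, W=384):
    # One camera at the origin looking down +z.
    c2w = torch.eye(4, dtype=torch.float32)[None, None]    # [B=1, V=1, 4, 4]
    # Hypothetical pixel intrinsics (fx, fy, cx, cy) from normalized values of 0.5.
    K = torch.tensor([[[0.5 * W, 0.5 * H, 0.5 * W, 0.5 * H]]], dtype=torch.float32)  # [1, 1, 4]
    plucker = ray_condition(K, c2w, H, W, device='cpu')    # [1, 1, H, W, 6]
    return plucker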
class RealEstate10K(Dataset):
    def __init__(
            self,
            root_path,
            annotation_json,
            sample_stride=4,
            sample_n_frames=16,
            sample_size=[256, 384],
            is_image=False,
    ):
        self.root_path = root_path
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        pixel_transforms = [transforms.Resize(sample_size),
                            transforms.RandomHorizontalFlip(),
                            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]

        self.pixel_transforms = transforms.Compose(pixel_transforms)

    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]
        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)
        return video_reader, video_dict['caption']

    def get_batch(self, idx):
        video_reader, video_caption = self.load_video_reader(idx)
        total_frames = len(video_reader)

        if self.is_image:
            frame_indice = [random.randint(0, total_frames - 1)]
        else:
            if isinstance(self.sample_stride, int):
                current_sample_stride = self.sample_stride
            else:
                assert len(self.sample_stride) == 2
                assert (self.sample_stride[0] >= 1) and (self.sample_stride[1] >= self.sample_stride[0])
                current_sample_stride = random.randint(self.sample_stride[0], self.sample_stride[1])

            cropped_length = self.sample_n_frames * current_sample_stride
            start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
            end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

            assert end_frame_ind - start_frame_ind >= self.sample_n_frames
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, video_caption

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, video_caption = self.get_batch(idx)
                break
            except Exception as e:
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)
        sample = dict(pixel_values=video, caption=video_caption)

        return sample
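
# Sketch of one annotation entry, assumed from the keys the two dataset classes
# index: RealEstate10K above reads 'clip_path' and 'caption'; RealEstate10KPose
# below additionally reads 'clip_name' and 'pose_file'. The literal values here
# are hypothetical placeholders.
_EXAMPLE_ANNOTATION_ENTRY = {
    'clip_name': '0000cc6d8b108390',             # hypothetical clip id
    'clip_path': 'videos/0000cc6d8b108390.mp4',  # relative to root_path
    'pose_file': 'pose_files/0000cc6d8b108390.txt',
    'caption': 'A walk-through of a living room.',
}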
class RealEstate10KPose(Dataset):
    def __init__(
            self,
            root_path,
            annotation_json,
            sample_stride=4,
            minimum_sample_stride=1,
            sample_n_frames=16,
            relative_pose=False,
            zero_t_first_frame=False,
            sample_size=[256, 384],
            rescale_fxy=False,
            shuffle_frames=False,
            use_flip=False,
            return_clip_name=False,
    ):
        self.root_path = root_path
        self.relative_pose = relative_pose
        self.zero_t_first_frame = zero_t_first_frame
        self.sample_stride = sample_stride
        self.minimum_sample_stride = minimum_sample_stride
        self.sample_n_frames = sample_n_frames
        self.return_clip_name = return_clip_name

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.sample_size = sample_size
        if use_flip:
            pixel_transforms = [transforms.Resize(sample_size),
                                RandomHorizontalFlipWithPose(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        else:
            pixel_transforms = [transforms.Resize(sample_size),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        self.rescale_fxy = rescale_fxy
        self.sample_wh_ratio = sample_size[1] / sample_size[0]

        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames
        self.use_flip = use_flip

    def get_relative_pose(self, cam_params):
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        source_cam_c2w = abs_c2ws[0]
        if self.zero_t_first_frame:
            cam_to_origin = 0
        else:
            cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]
        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)
        return video_dict['clip_name'], video_reader, video_dict['caption']

    def load_cameras(self, idx):
        video_dict = self.dataset[idx]
        pose_file = os.path.join(self.root_path, video_dict['pose_file'])
        with open(pose_file, 'r') as f:
            poses = f.readlines()
        poses = [pose.strip().split(' ') for pose in poses[1:]]
        cam_params = [[float(x) for x in pose] for pose in poses]
        cam_params = [Camera(cam_param) for cam_param in cam_params]
        return cam_params
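
    # Note on the pose file parsed by load_cameras above (layout assumed from the
    # indexing in Camera, matching RealEstate10K-style camera trajectories): the
    # first line is treated as a header and skipped; every following line is split
    # on spaces and handed to Camera, which reads entry[1:5] as (fx, fy, cx, cy)
    # and entry[7:] as the row-major 3x4 world-to-camera matrix. That implies at
    # least 19 numbers per line; entry[0] (typically a timestamp) and entry[5:7]
    # are left unused.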
    def get_batch(self, idx):
        clip_name, video_reader, video_caption = self.load_video_reader(idx)
        cam_params = self.load_cameras(idx)
        assert len(cam_params) >= self.sample_n_frames
        total_frames = len(cam_params)

        current_sample_stride = self.sample_stride

        if total_frames < self.sample_n_frames * current_sample_stride:
            maximum_sample_stride = int(total_frames // self.sample_n_frames)
            current_sample_stride = random.randint(self.minimum_sample_stride, maximum_sample_stride)

        cropped_length = self.sample_n_frames * current_sample_stride
        start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
        end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

        assert end_frame_ind - start_frame_ind >= self.sample_n_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        condition_image_ind = random.sample(list(set(range(total_frames)) - set(frame_indices.tolist())), 1)
        condition_image = torch.from_numpy(video_reader.get_batch(condition_image_ind).asnumpy()).permute(0, 3, 1, 2).contiguous()
        condition_image = condition_image / 255.

        if self.shuffle_frames:
            perm = np.random.permutation(self.sample_n_frames)
            frame_indices = frame_indices[perm]

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        cam_params = [cam_params[indice] for indice in frame_indices]
        if self.rescale_fxy:
            ori_h, ori_w = pixel_values.shape[-2:]
            ori_wh_ratio = ori_w / ori_h
            if ori_wh_ratio > self.sample_wh_ratio:    # rescale fx
                resized_ori_w = self.sample_size[0] * ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fx = resized_ori_w * cam_param.fx / self.sample_size[1]
            else:                                      # rescale fy
                resized_ori_h = self.sample_size[1] / ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fy = resized_ori_h * cam_param.fy / self.sample_size[0]
        intrinsics = np.asarray([[cam_param.fx * self.sample_size[1],
                                  cam_param.fy * self.sample_size[0],
                                  cam_param.cx * self.sample_size[1],
                                  cam_param.cy * self.sample_size[0]]
                                 for cam_param in cam_params], dtype=np.float32)

        intrinsics = torch.as_tensor(intrinsics)[None]    # [1, n_frame, 4]
        if self.relative_pose:
            c2w_poses = self.get_relative_pose(cam_params)
        else:
            c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
        c2w = torch.as_tensor(c2w_poses)[None]            # [1, n_frame, 4, 4]
        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)
        plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu',
                                          flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()

        return pixel_values, condition_image, plucker_embedding, video_caption, flip_flag, clip_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, condition_image, plucker_embedding, video_caption, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception as e:
                idx = random.randint(0, self.length - 1)

        if self.use_flip:
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            for transform in self.pixel_transforms[2:]:
                video = transform(video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
        for transform in self.pixel_transforms:
            condition_image = transform(condition_image)

        if self.return_clip_name:
            sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding,
                          video_caption=video_caption, clip_name=clip_name)
        else:
            sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding,
                          video_caption=video_caption)
        return sample
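

# Minimal usage sketch (not part of the original module; the dataset root and
# annotation file below are hypothetical placeholders): wrap RealEstate10KPose
# in a DataLoader and inspect one batch.
if __name__ == "__main__":
    dataset = RealEstate10KPose(
        root_path='/path/to/RealEstate10K',        # hypothetical dataset root
        annotation_json='annotations/train.json',  # hypothetical annotation file
        sample_n_frames=16,
        sample_size=[256, 384],
        relative_pose=True,
        rescale_fxy=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['pixel_values'].shape)       # [1, 16, 3, 256, 384]
    print(batch['plucker_embedding'].shape)  # [1, 16, 6, 256, 384]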