import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import glob
import math
import os
import random
import cv2
from skimage import io
from pytorch3d.renderer.cameras import look_at_view_transform
from pytorch3d.transforms import so3_exponential_map
from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix


def CropImage(left_up, crop_size, image=None, K=None):
    """Crop a window from the image, zero-padding where the window leaves the
    image bounds, and shift the principal point of K accordingly.
    Note: K is modified in place."""
    crop_size = np.array(crop_size).astype(np.int32)
    left_up = np.array(left_up).astype(np.int32)
    if K is not None:
        K[0:2, 2] = K[0:2, 2] - left_up
    if image is not None:
        # Pad with zeros wherever the crop window extends beyond the image.
        if left_up[0] < 0:
            image_left = np.zeros([image.shape[0], -left_up[0], image.shape[2]], dtype=np.uint8)
            image = np.hstack([image_left, image])
            left_up[0] = 0
        if left_up[1] < 0:
            image_up = np.zeros([-left_up[1], image.shape[1], image.shape[2]], dtype=np.uint8)
            image = np.vstack([image_up, image])
            left_up[1] = 0
        if crop_size[0] + left_up[0] > image.shape[1]:
            image_right = np.zeros([image.shape[0], crop_size[0] + left_up[0] - image.shape[1], image.shape[2]], dtype=np.uint8)
            image = np.hstack([image, image_right])
        if crop_size[1] + left_up[1] > image.shape[0]:
            image_down = np.zeros([crop_size[1] + left_up[1] - image.shape[0], image.shape[1], image.shape[2]], dtype=np.uint8)
            image = np.vstack([image, image_down])
        image = image[left_up[1]:left_up[1] + crop_size[1], left_up[0]:left_up[0] + crop_size[0], :]
    return image, K


def ResizeImage(target_size, source_size, image=None, K=None):
    """Resize the image and rescale K's focal lengths and principal point by
    the same factor. Note: K is modified in place."""
    if K is not None:
        K[0, :] = (target_size[0] / source_size[0]) * K[0, :]
        K[1, :] = (target_size[1] / source_size[1]) * K[1, :]
    if image is not None:
        image = cv2.resize(image, dsize=target_size)
    return image, K
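

# Illustrative sanity check (hypothetical helper, not called anywhere in the
# pipeline): verifies the intrinsics bookkeeping of CropImage / ResizeImage on
# made-up numbers.
def _check_intrinsics_bookkeeping():
    K = np.array([[1000.0, 0.0, 300.0],
                  [0.0, 1000.0, 260.0],
                  [0.0, 0.0, 1.0]])
    # Cropping a 512x512 window at top-left (44, 4) shifts the principal
    # point by the same offset. Pass a copy: both helpers mutate K in place.
    _, K_crop = CropImage([44, 4], (512, 512), K=K.copy())
    assert np.allclose(K_crop[0:2, 2], [300.0 - 44, 260.0 - 4])
    # Resizing 512x512 -> 256x256 scales the first two rows of K by 0.5.
    _, K_small = ResizeImage((256, 256), (512, 512), K=K_crop.copy())
    assert np.allclose(K_small[0:2, :], K_crop[0:2, :] * 0.5)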


class MeshDataset(Dataset):
    def __init__(self, cfg):
        super(MeshDataset, self).__init__()
        self.dataroot = cfg.dataroot
        self.camera_ids = cfg.camera_ids
        self.original_resolution = cfg.original_resolution
        self.resolution = cfg.resolution
        self.num_sample_view = cfg.num_sample_view

        self.samples = []
        image_folder = os.path.join(self.dataroot, 'images')
        param_folder = os.path.join(self.dataroot, 'params')
        # Sort for a deterministic frame -> exp_id mapping (os.listdir order is arbitrary).
        frames = sorted(os.listdir(image_folder))
        self.num_exp_id = 0
        for frame in frames:
            image_paths = [os.path.join(image_folder, frame, 'image_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            mask_paths = [os.path.join(image_folder, frame, 'mask_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            visible_paths = [os.path.join(image_folder, frame, 'window_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            camera_paths = [os.path.join(image_folder, frame, 'camera_%s.npz' % camera_id) for camera_id in self.camera_ids]
            param_path = os.path.join(param_folder, frame, 'params.npz')
            landmarks_3d_path = os.path.join(param_folder, frame, 'lmk_3d.npy')
            vertices_path = os.path.join(param_folder, frame, 'vertices.npy')
            sample = (image_paths, mask_paths, visible_paths, camera_paths, param_path, landmarks_3d_path, vertices_path, self.num_exp_id)
            self.samples.append(sample)
            self.num_exp_id += 1

        # Build neutral-space landmarks from the first frame by undoing that
        # frame's fitted rotation, translation, and scale.
        init_landmarks_3d = torch.from_numpy(np.load(os.path.join(param_folder, frames[0], 'lmk_3d.npy'))).float()
        init_vertices = torch.from_numpy(np.load(os.path.join(param_folder, frames[0], 'vertices.npy'))).float()
        # Augment the landmarks with every 100th mesh vertex.
        init_landmarks_3d = torch.cat([init_landmarks_3d, init_vertices[::100]], 0)
        param = np.load(os.path.join(param_folder, frames[0], 'params.npz'))
        pose = torch.from_numpy(param['pose'][0]).float()
        R = so3_exponential_map(pose[None, :3])[0]
        T = pose[None, 3:]
        S = torch.from_numpy(param['scale']).float()
        self.init_landmarks_3d_neutral = torch.matmul(init_landmarks_3d - T, R) / S

    def get_item(self, index):
        data = self.__getitem__(index)
        return data

    def __getitem__(self, index):
        sample = self.samples[index]

        images = []
        masks = []
        visibles = []
        # Randomly pick num_sample_view distinct views of this frame.
        views = random.sample(range(len(self.camera_ids)), self.num_sample_view)
        for view in views:
            image_path = sample[0][view]
            image = cv2.resize(io.imread(image_path), (self.resolution, self.resolution))
            image = torch.from_numpy(image / 255).permute(2, 0, 1).float()
            images.append(image)

            mask_path = sample[1][view]
            mask = cv2.resize(io.imread(mask_path), (self.resolution, self.resolution))
            if len(mask.shape) == 3:
                mask = mask[:, :, 0:1]
            elif len(mask.shape) == 2:
                mask = mask[:, :, None]
            mask = torch.from_numpy(mask / 255).permute(2, 0, 1).float()
            masks.append(mask)

            visible_path = sample[2][view]
            if os.path.exists(visible_path):
                visible = cv2.resize(io.imread(visible_path), (self.resolution, self.resolution))
                if len(visible.shape) == 3:
                    visible = visible[:, :, 0:1]
                elif len(visible.shape) == 2:
                    visible = visible[:, :, None]
                visible = torch.from_numpy(visible / 255).permute(2, 0, 1).float()
            else:
                # No visibility map: treat everything as visible; match the
                # mask's single-channel shape so torch.stack below is consistent.
                visible = torch.ones_like(mask)
            visibles.append(visible)

        images = torch.stack(images)
        masks = torch.stack(masks)
        images = images * masks  # zero out the background
        visibles = torch.stack(visibles)

        cameras = [np.load(sample[3][view]) for view in views]
        intrinsics = torch.stack([torch.from_numpy(camera['intrinsic']).float() for camera in cameras])
        extrinsics = torch.stack([torch.from_numpy(camera['extrinsic']).float() for camera in cameras])
        # Normalize pixel-space intrinsics to [-1, 1] NDC.
        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * 2 / self.original_resolution
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * 2 / self.original_resolution - 1
        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * 2 / self.original_resolution
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * 2 / self.original_resolution - 1

        param_path = sample[4]
        param = np.load(param_path)
        pose = torch.from_numpy(param['pose'][0]).float()
        scale = torch.from_numpy(param['scale']).float()
        exp_coeff = torch.from_numpy(param['exp_coeff'][0]).float()

        landmarks_3d_path = sample[5]
        landmarks_3d = torch.from_numpy(np.load(landmarks_3d_path)).float()
        vertices_path = sample[6]
        vertices = torch.from_numpy(np.load(vertices_path)).float()
        # Augment the landmarks with every 100th mesh vertex, matching the
        # neutral landmarks built in __init__.
        landmarks_3d = torch.cat([landmarks_3d, vertices[::100]], 0)

        exp_id = sample[7]
        return {'images': images,
                'masks': masks,
                'visibles': visibles,
                'pose': pose,
                'scale': scale,
                'exp_coeff': exp_coeff,
                'landmarks_3d': landmarks_3d,
                'intrinsics': intrinsics,
                'extrinsics': extrinsics,
                'exp_id': exp_id}

    def __len__(self):
        return len(self.samples)
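

# Illustrative helper (hypothetical, not used by the datasets): the pixel -> NDC
# intrinsics normalization that the datasets apply in place. For an image of
# side `res`, fx' = 2 * fx / res and cx' = 2 * cx / res - 1, so a centered
# principal point (cx = res / 2) maps to cx' = 0 and a focal length of res / 2
# maps to fx' = 1.
def _normalize_intrinsics(K, res):
    K = K.clone()  # the dataset code instead modifies its tensors in place
    K[..., 0, 0] = K[..., 0, 0] * 2 / res
    K[..., 0, 2] = K[..., 0, 2] * 2 / res - 1
    K[..., 1, 1] = K[..., 1, 1] * 2 / res
    K[..., 1, 2] = K[..., 1, 2] * 2 / res - 1
    return K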


class GaussianDataset(Dataset):
    def __init__(self, cfg):
        super(GaussianDataset, self).__init__()
        self.dataroot = cfg.dataroot
        self.camera_ids = cfg.camera_ids
        self.original_resolution = cfg.original_resolution
        self.resolution = cfg.resolution

        self.samples = []
        image_folder = os.path.join(self.dataroot, 'images')
        param_folder = os.path.join(self.dataroot, 'params')
        # Sort for a deterministic frame -> exp_id mapping.
        frames = sorted(os.listdir(image_folder))
        self.num_exp_id = 0
        for frame in frames:
            image_paths = [os.path.join(image_folder, frame, 'image_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            mask_paths = [os.path.join(image_folder, frame, 'mask_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            visible_paths = [os.path.join(image_folder, frame, 'window_%s.jpg' % camera_id) for camera_id in self.camera_ids]
            camera_paths = [os.path.join(image_folder, frame, 'camera_%s.npz' % camera_id) for camera_id in self.camera_ids]
            param_path = os.path.join(param_folder, frame, 'params.npz')
            landmarks_3d_path = os.path.join(param_folder, frame, 'lmk_3d.npy')
            vertices_path = os.path.join(param_folder, frame, 'vertices.npy')
            sample = (image_paths, mask_paths, visible_paths, camera_paths, param_path, landmarks_3d_path, vertices_path, self.num_exp_id)
            self.samples.append(sample)
            self.num_exp_id += 1

    def get_item(self, index):
        data = self.__getitem__(index)
        return data

    def __getitem__(self, index):
        sample = self.samples[index]

        # Pick one random view of this frame.
        view = random.randrange(len(self.camera_ids))

        image_path = sample[0][view]
        image = cv2.resize(io.imread(image_path), (self.original_resolution, self.original_resolution)) / 255

        mask_path = sample[1][view]
        mask = cv2.resize(io.imread(mask_path), (self.original_resolution, self.original_resolution)) / 255
        if len(mask.shape) == 3:
            mask = mask[:, :, 0:1]
        elif len(mask.shape) == 2:
            mask = mask[:, :, None]
        # Composite the foreground onto a white background.
        image = image * mask + (1 - mask)

        visible_path = sample[2][view]
        if os.path.exists(visible_path):
            visible = cv2.resize(io.imread(visible_path), (self.original_resolution, self.original_resolution)) / 255
            if len(visible.shape) == 3:
                visible = visible[:, :, 0:1]
            elif len(visible.shape) == 2:
                visible = visible[:, :, None]
        else:
            # No visibility map: treat everything as visible; match the mask's
            # single-channel shape so the tensors below stay consistent.
            visible = np.ones_like(mask)

        camera = np.load(sample[3][view])
        extrinsic = torch.from_numpy(camera['extrinsic']).float()
        R = extrinsic[:3, :3].t()
        T = extrinsic[:3, 3]

        intrinsic = camera['intrinsic']
        # If the principal point is more than a pixel off-center, crop so it
        # lands at the image center, updating the intrinsics to match.
        if np.abs(intrinsic[0, 2] - self.original_resolution / 2) > 1 or np.abs(intrinsic[1, 2] - self.original_resolution / 2) > 1:
            left_up = np.around(intrinsic[0:2, 2] - np.array([self.original_resolution / 2, self.original_resolution / 2])).astype(np.int32)
            _, intrinsic = CropImage(left_up, (self.original_resolution, self.original_resolution), K=intrinsic)
            image, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=image)
            mask, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=mask)
            visible, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=visible)

        # Normalize pixel-space intrinsics to [-1, 1] NDC.
        intrinsic[0, 0] = intrinsic[0, 0] * 2 / self.original_resolution
        intrinsic[0, 2] = intrinsic[0, 2] * 2 / self.original_resolution - 1
        intrinsic[1, 1] = intrinsic[1, 1] * 2 / self.original_resolution
        intrinsic[1, 2] = intrinsic[1, 2] * 2 / self.original_resolution - 1
        intrinsic = torch.from_numpy(intrinsic).float()

        image = torch.from_numpy(cv2.resize(image, (self.resolution, self.resolution))).permute(2, 0, 1).float()
        mask = torch.from_numpy(cv2.resize(mask, (self.resolution, self.resolution)))[None].float()
        visible = torch.from_numpy(cv2.resize(visible, (self.resolution, self.resolution)))[None].float()
        # Quarter-resolution copies for coarse supervision.
        image_coarse = F.interpolate(image[None], scale_factor=0.25)[0]
        mask_coarse = F.interpolate(mask[None], scale_factor=0.25)[0]
        visible_coarse = F.interpolate(visible[None], scale_factor=0.25)[0]

        # With NDC intrinsics, fx' = 1 / tan(fovx / 2).
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])
        world_view_transform = torch.tensor(getWorld2View2(R.numpy(), T.numpy())).transpose(0, 1)
        projection_matrix = getProjectionMatrix(znear=0.01, zfar=100, fovX=fovx, fovY=fovy).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        param_path = sample[4]
        param = np.load(param_path)
        pose = torch.from_numpy(param['pose'][0]).float()
        scale = torch.from_numpy(param['scale']).float()
        exp_coeff = torch.from_numpy(param['exp_coeff'][0]).float()

        landmarks_3d_path = sample[5]
        landmarks_3d = torch.from_numpy(np.load(landmarks_3d_path)).float()
        vertices_path = sample[6]
        vertices = torch.from_numpy(np.load(vertices_path)).float()
        # Augment the landmarks with every 100th mesh vertex.
        landmarks_3d = torch.cat([landmarks_3d, vertices[::100]], 0)

        exp_id = torch.tensor(sample[7]).long()
        return {'images': image,
                'masks': mask,
                'visibles': visible,
                'images_coarse': image_coarse,
                'masks_coarse': mask_coarse,
                'visibles_coarse': visible_coarse,
                'pose': pose,
                'scale': scale,
                'exp_coeff': exp_coeff,
                'landmarks_3d': landmarks_3d,
                'exp_id': exp_id,
                'extrinsics': extrinsic,
                'intrinsics': intrinsic,
                'fovx': fovx,
                'fovy': fovy,
                'world_view_transform': world_view_transform,
                'projection_matrix': projection_matrix,
                'full_proj_transform': full_proj_transform,
                'camera_center': camera_center}

    def __len__(self):
        return len(self.samples)
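

# Illustrative helper (hypothetical): the camera tensors that
# GaussianDataset.__getitem__ derives per sample, collected in one place.
# getWorld2View2 / getProjectionMatrix are the GHA graphics utilities imported
# above; matrices are transposed to the row-major convention used by the
# 3D Gaussian Splatting rasterizer.
def _build_camera(R, T, intrinsic_ndc, znear=0.01, zfar=100):
    # With NDC intrinsics, fx' = 1 / tan(fovx / 2), hence the atan below.
    fovx = 2 * math.atan(1 / intrinsic_ndc[0, 0])
    fovy = 2 * math.atan(1 / intrinsic_ndc[1, 1])
    world_view_transform = torch.tensor(getWorld2View2(R.numpy(), T.numpy())).transpose(0, 1)
    projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, fovX=fovx, fovY=fovy).transpose(0, 1)
    # Row-vector convention: points are mapped by p @ world_view @ projection.
    full_proj_transform = world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)).squeeze(0)
    # Camera position in world space: translation row of the inverse view matrix.
    camera_center = world_view_transform.inverse()[3, :3]
    return fovx, fovy, world_view_transform, projection_matrix, full_proj_transform, camera_center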


class ReenactmentDataset(Dataset):
    def __init__(self, cfg):
        super(ReenactmentDataset, self).__init__()
        self.dataroot = cfg.dataroot
        self.original_resolution = cfg.original_resolution
        self.resolution = cfg.resolution
        self.freeview = cfg.freeview
        # 180-degree rotation about z, used to convert the PyTorch3D look-at
        # rotation to this project's camera convention (see update_camera).
        self.Rot_z = torch.eye(3)
        self.Rot_z[0, 0] = -1.0
        self.Rot_z[1, 1] = -1.0

        image_paths = sorted(glob.glob(os.path.join(self.dataroot, cfg.image_files)))
        param_paths = sorted(glob.glob(os.path.join(self.dataroot, cfg.param_files)))

        # The driving expressions come from a standalone coefficient sequence;
        # the first image is reused as a placeholder reference for every sample.
        self.exp_coeff = np.load(cfg.exp_path)
        self.samples = []
        for param_path in param_paths:
            image_path = image_paths[0]
            if os.path.exists(image_path) and os.path.exists(param_path):
                self.samples.append((image_path, param_path))
        # Pad the sample list to the length of the expression sequence.
        while len(self.samples) < self.exp_coeff.shape[0]:
            self.samples.append((image_paths[0], param_paths[0]))

        if os.path.exists(cfg.pose_code_path):
            self.pose_code = torch.from_numpy(np.load(cfg.pose_code_path)['pose'][0]).float()
        else:
            self.pose_code = None

        # Default frontal camera one unit from the head, with a focal length of
        # 3.5x the image size; overridden by cfg.camera_path when it exists.
        self.extrinsic = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                                       [0.0, -1.0, 0.0, 0.0],
                                       [0.0, 0.0, -1.0, 1.0]]).float()
        self.intrinsic = torch.tensor([[self.original_resolution * 3.5, 0.0, self.original_resolution / 2],
                                       [0.0, self.original_resolution * 3.5, self.original_resolution / 2],
                                       [0.0, 0.0, 1.0]]).float()
        if os.path.exists(cfg.camera_path):
            camera = np.load(cfg.camera_path)
            self.extrinsic = torch.from_numpy(camera['extrinsic']).float()
            if not self.freeview:
                self.intrinsic = torch.from_numpy(camera['intrinsic']).float()
        self.R = self.extrinsic[:3, :3].t()
        self.T = self.extrinsic[:3, 3]

        # Normalize pixel-space intrinsics to [-1, 1] NDC.
        self.intrinsic[0, 0] = self.intrinsic[0, 0] * 2 / self.original_resolution
        self.intrinsic[0, 2] = self.intrinsic[0, 2] * 2 / self.original_resolution - 1
        self.intrinsic[1, 1] = self.intrinsic[1, 1] * 2 / self.original_resolution
        self.intrinsic[1, 2] = self.intrinsic[1, 2] * 2 / self.original_resolution - 1

        self.fovx = 2 * math.atan(1 / self.intrinsic[0, 0])
        self.fovy = 2 * math.atan(1 / self.intrinsic[1, 1])
        self.world_view_transform = torch.tensor(getWorld2View2(self.R.numpy(), self.T.numpy())).transpose(0, 1)
        self.projection_matrix = getProjectionMatrix(znear=0.01, zfar=100, fovX=self.fovx, fovY=self.fovy).transpose(0, 1)
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]
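
    # Free-view playback: when cfg.freeview is set, __getitem__ calls
    # update_camera(index) below, which orbits the camera around the head
    # (elevation sweeping roughly [-8, 8] degrees, azimuth [-45, 45] degrees)
    # while the projection matrix from __init__ stays fixed.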
    def update_camera(self, index):
        # Smoothly orbit the camera with the frame index; one full cycle takes
        # about 126 frames, looking at a point slightly above the head center.
        elev = math.sin(index / 20) * 8
        azim = math.cos(index / 20) * 45
        R, T = look_at_view_transform(dist=1.2, elev=elev, azim=azim, at=((0.0, 0.0, 0.05),))
        R = torch.matmul(self.Rot_z, R[0].t())
        self.extrinsic = torch.cat([R, T.t()], -1)
        self.R = self.extrinsic[:3, :3].t()
        self.T = self.extrinsic[:3, 3]
        self.world_view_transform = torch.tensor(getWorld2View2(self.R.numpy(), self.T.numpy())).transpose(0, 1)
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]

    def get_item(self, index):
        data = self.__getitem__(index)
        return data

    def __getitem__(self, index):
        if self.freeview:
            self.update_camera(index)

        sample = self.samples[index]
        image_path = sample[0]
        image = torch.from_numpy(cv2.resize(io.imread(image_path), (self.resolution, self.resolution)) / 255).permute(2, 0, 1).float()

        param_path = sample[1]
        param = np.load(param_path)
        pose = torch.from_numpy(param['pose'][0]).float()
        scale = torch.from_numpy(param['scale']).float()
        id_coeff = torch.from_numpy(param['id_coeff'])[0].float()
        # Expressions are driven by the external sequence, not the per-frame fit.
        exp_coeff = torch.from_numpy(self.exp_coeff[index]).float()

        if self.pose_code is not None:
            pose_code = self.pose_code
        else:
            pose_code = pose

        return {'images': image,
                'pose': pose,
                'scale': scale,
                'id_coeff': id_coeff,
                'exp_coeff': exp_coeff,
                'pose_code': pose_code,
                'extrinsics': self.extrinsic,
                'intrinsics': self.intrinsic,
                'fovx': self.fovx,
                'fovy': self.fovy,
                'world_view_transform': self.world_view_transform,
                'projection_matrix': self.projection_matrix,
                'full_proj_transform': self.full_proj_transform,
                'camera_center': self.camera_center}

    def __len__(self):
        return len(self.samples)
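

# Minimal usage sketch, assuming a dataset laid out under dataroot as the
# classes above expect; the config fields and values are hypothetical
# placeholders, not project defaults.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.utils.data import DataLoader

    cfg = Namespace(dataroot='data/subject01', camera_ids=['00', '01', '02'],
                    original_resolution=2048, resolution=512, num_sample_view=2)
    dataset = MeshDataset(cfg)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    # images: [batch, num_sample_view, 3, resolution, resolution]
    print(batch['images'].shape, batch['exp_id'])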