import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import glob
import math
import os
import random
import cv2
from skimage import io
from pytorch3d.renderer.cameras import look_at_view_transform
from pytorch3d.transforms import so3_exponential_map
from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix
def CropImage(left_up, crop_size, image=None, K=None):
    """Crop `image` to `crop_size` (width, height) starting at `left_up` (x, y),
    zero-padding wherever the crop window extends past the image borders, and
    shift the principal point of the intrinsic matrix `K` to match. Note that
    `K` is modified in place."""
    crop_size = np.array(crop_size).astype(np.int32)
    left_up = np.array(left_up).astype(np.int32)
    if K is not None:
        K[0:2, 2] = K[0:2, 2] - left_up
    if image is not None:
        # Zero-pad on any side where the crop window falls outside the image.
if left_up[0] < 0:
image_left = np.zeros([image.shape[0], -left_up[0], image.shape[2]], dtype=np.uint8)
image = np.hstack([image_left, image])
left_up[0] = 0
if left_up[1] < 0:
image_up = np.zeros([-left_up[1], image.shape[1], image.shape[2]], dtype=np.uint8)
image = np.vstack([image_up, image])
left_up[1] = 0
if crop_size[0] + left_up[0] > image.shape[1]:
image_right = np.zeros([image.shape[0], crop_size[0] + left_up[0] - image.shape[1], image.shape[2]], dtype=np.uint8)
image = np.hstack([image, image_right])
if crop_size[1] + left_up[1] > image.shape[0]:
image_down = np.zeros([crop_size[1] + left_up[1] - image.shape[0], image.shape[1], image.shape[2]], dtype=np.uint8)
image = np.vstack([image, image_down])
image = image[left_up[1]:left_up[1]+crop_size[1], left_up[0]:left_up[0]+crop_size[0], :]
return image, K
def ResizeImage(target_size, source_size, image=None, K=None):
    """Resize `image` from `source_size` to `target_size` (both (width, height))
    and rescale the focal lengths and principal point of `K` accordingly. `K`
    is modified in place."""
    if K is not None:
        K[0, :] = (target_size[0] / source_size[0]) * K[0, :]
        K[1, :] = (target_size[1] / source_size[1]) * K[1, :]
    if image is not None:
image = cv2.resize(image, dsize=target_size)
return image, K
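# Minimal usage sketch for the two helpers above (the values are illustrative,
# not taken from the real pipeline): re-center an off-center principal point by
# cropping, then downscale, keeping K consistent with the pixels throughout.
#
#   K = np.array([[1500., 0., 540.], [0., 1500., 500.], [0., 0., 1.]])
#   img = np.zeros([1024, 1024, 3], dtype=np.uint8)
#   left_up = np.around(K[0:2, 2] - 512).astype(np.int32)
#   img, K = CropImage(left_up, (1024, 1024), image=img, K=K)      # principal point -> (512, 512)
#   img, K = ResizeImage((512, 512), (1024, 1024), image=img, K=K) # focal lengths halved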
class MeshDataset(Dataset):
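    """Multi-view dataset for the mesh-fitting stage: every item bundles
    `num_sample_view` randomly chosen views of one frame with the tracked
    pose, scale, expression coefficients and 3D landmarks."""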
def __init__(self, cfg):
super(MeshDataset, self).__init__()
self.dataroot = cfg.dataroot
self.camera_ids = cfg.camera_ids
self.original_resolution = cfg.original_resolution
self.resolution = cfg.resolution
self.num_sample_view = cfg.num_sample_view
self.samples = []
image_folder = os.path.join(self.dataroot, 'images')
param_folder = os.path.join(self.dataroot, 'params')
camera_folder = os.path.join(self.dataroot, 'cameras')
        frames = sorted(os.listdir(image_folder))  # sort for a deterministic exp_id order
self.num_exp_id = 0
for frame in frames:
image_paths = [os.path.join(image_folder, frame, 'image_%s.jpg' % camera_id) for camera_id in self.camera_ids]
mask_paths = [os.path.join(image_folder, frame, 'mask_%s.jpg' % camera_id) for camera_id in self.camera_ids]
visible_paths = [os.path.join(image_folder, frame, 'window_%s.jpg' % camera_id) for camera_id in self.camera_ids]
camera_paths = [os.path.join(image_folder, frame, 'camera_%s.npz' % camera_id) for camera_id in self.camera_ids]
param_path = os.path.join(param_folder, frame, 'params.npz')
landmarks_3d_path = os.path.join(param_folder, frame, 'lmk_3d.npy')
vertices_path = os.path.join(param_folder, frame, 'vertices.npy')
sample = (image_paths, mask_paths, visible_paths, camera_paths, param_path, landmarks_3d_path, vertices_path, self.num_exp_id)
self.samples.append(sample)
self.num_exp_id += 1
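        # Build a canonical ("neutral") landmark set from the first frame:
        # concatenate the sparse landmarks with every 100th mesh vertex, then
        # undo that frame's similarity transform (rotation R from the axis-angle
        # pose, translation T, scale S) to map them into canonical space.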
init_landmarks_3d = torch.from_numpy(np.load(os.path.join(param_folder, frames[0], 'lmk_3d.npy'))).float()
init_vertices = torch.from_numpy(np.load(os.path.join(param_folder, frames[0], 'vertices.npy'))).float()
init_landmarks_3d = torch.cat([init_landmarks_3d, init_vertices[::100]], 0)
param = np.load(os.path.join(param_folder, frames[0], 'params.npz'))
pose = torch.from_numpy(param['pose'][0]).float()
R = so3_exponential_map(pose[None, :3])[0]
T = pose[None, 3:]
S = torch.from_numpy(param['scale']).float()
        self.init_landmarks_3d_neutral = torch.matmul(init_landmarks_3d - T, R) / S
def get_item(self, index):
data = self.__getitem__(index)
return data
def __getitem__(self, index):
sample = self.samples[index]
images = []
masks = []
visibles = []
views = random.sample(range(len(self.camera_ids)), self.num_sample_view)
for view in views:
image_path = sample[0][view]
image = cv2.resize(io.imread(image_path), (self.resolution, self.resolution))
image = torch.from_numpy(image / 255).permute(2, 0, 1).float()
images.append(image)
mask_path = sample[1][view]
mask = cv2.resize(io.imread(mask_path), (self.resolution, self.resolution))
if len(mask.shape) == 3:
mask = mask[:, :, 0:1]
elif len(mask.shape) == 2:
mask = mask[:, :, None]
mask = torch.from_numpy(mask / 255).permute(2, 0, 1).float()
masks.append(mask)
            visible_path = sample[2][view]
            if os.path.exists(visible_path):
                visible = cv2.resize(io.imread(visible_path), (self.resolution, self.resolution))
                if len(visible.shape) == 3:
                    visible = visible[:, :, 0:1]
                elif len(visible.shape) == 2:
                    visible = visible[:, :, None]
                visible = torch.from_numpy(visible / 255).permute(2, 0, 1).float()
            else:
                # Single channel, so views with and without a visibility map
                # stack to the same shape as `mask`.
                visible = torch.ones_like(mask)
            visibles.append(visible)
images = torch.stack(images)
masks = torch.stack(masks)
images = images * masks
visibles = torch.stack(visibles)
cameras = [np.load(sample[3][view]) for view in views]
intrinsics = torch.stack([torch.from_numpy(camera['intrinsic']).float() for camera in cameras])
extrinsics = torch.stack([torch.from_numpy(camera['extrinsic']).float() for camera in cameras])
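        # Normalize the pixel-space intrinsics to NDC: the focal lengths are
        # scaled so that half the image width/height maps to 1, and the
        # principal point is shifted into [-1, 1].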
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * 2 / self.original_resolution
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * 2 / self.original_resolution - 1
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * 2 / self.original_resolution
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * 2 / self.original_resolution - 1
param_path = sample[4]
param = np.load(param_path)
pose = torch.from_numpy(param['pose'][0]).float()
scale = torch.from_numpy(param['scale']).float()
exp_coeff = torch.from_numpy(param['exp_coeff'][0]).float()
landmarks_3d_path = sample[5]
landmarks_3d = torch.from_numpy(np.load(landmarks_3d_path)).float()
vertices_path = sample[6]
vertices = torch.from_numpy(np.load(vertices_path)).float()
landmarks_3d = torch.cat([landmarks_3d, vertices[::100]], 0)
exp_id = sample[7]
return {
'images': images,
'masks': masks,
'visibles': visibles,
'pose': pose,
'scale': scale,
'exp_coeff': exp_coeff,
'landmarks_3d': landmarks_3d,
'intrinsics': intrinsics,
'extrinsics': extrinsics,
'exp_id': exp_id}
def __len__(self):
return len(self.samples)
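# Hedged usage sketch (the cfg fields are inferred from the attributes read in
# __init__; the names and values below are placeholders, not part of this repo):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(dataroot='data/subject', camera_ids=['00', '01', '02'],
#                         original_resolution=2048, resolution=512, num_sample_view=2)
#   item = MeshDataset(cfg).get_item(0)
#   item['images'].shape  # -> [num_sample_view, 3, resolution, resolution]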
class GaussianDataset(Dataset):
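    """Single-view dataset for the Gaussian stage: each item is one randomly
    chosen view of a frame, composited onto a white background, together with
    the full camera model (FoV, world-view/projection transforms) expected by
    a Gaussian-splatting rasterizer."""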
def __init__(self, cfg):
super(GaussianDataset, self).__init__()
self.dataroot = cfg.dataroot
self.camera_ids = cfg.camera_ids
self.original_resolution = cfg.original_resolution
self.resolution = cfg.resolution
self.samples = []
image_folder = os.path.join(self.dataroot, 'images')
param_folder = os.path.join(self.dataroot, 'params')
camera_folder = os.path.join(self.dataroot, 'cameras')
        frames = sorted(os.listdir(image_folder))  # sort for a deterministic exp_id order
self.num_exp_id = 0
for frame in frames:
image_paths = [os.path.join(image_folder, frame, 'image_%s.jpg' % camera_id) for camera_id in self.camera_ids]
mask_paths = [os.path.join(image_folder, frame, 'mask_%s.jpg' % camera_id) for camera_id in self.camera_ids]
visible_paths = [os.path.join(image_folder, frame, 'window_%s.jpg' % camera_id) for camera_id in self.camera_ids]
camera_paths = [os.path.join(image_folder, frame, 'camera_%s.npz' % camera_id) for camera_id in self.camera_ids]
param_path = os.path.join(param_folder, frame, 'params.npz')
landmarks_3d_path = os.path.join(param_folder, frame, 'lmk_3d.npy')
vertices_path = os.path.join(param_folder, frame, 'vertices.npy')
sample = (image_paths, mask_paths, visible_paths, camera_paths, param_path, landmarks_3d_path, vertices_path, self.num_exp_id)
self.samples.append(sample)
self.num_exp_id += 1
def get_item(self, index):
data = self.__getitem__(index)
return data
def __getitem__(self, index):
sample = self.samples[index]
view = random.sample(range(len(self.camera_ids)), 1)[0]
image_path = sample[0][view]
image = cv2.resize(io.imread(image_path), (self.original_resolution, self.original_resolution)) / 255
mask_path = sample[1][view]
mask = cv2.resize(io.imread(mask_path), (self.original_resolution, self.original_resolution)) / 255
if len(mask.shape) == 3:
mask = mask[:, :, 0:1]
elif len(mask.shape) == 2:
mask = mask[:, :, None]
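        # Composite the foreground onto a white background: masked-out pixels
        # become 1.0 in every channel.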
image = image * mask + (1 - mask)
        visible_path = sample[2][view]
        if os.path.exists(visible_path):
            visible = cv2.resize(io.imread(visible_path), (self.original_resolution, self.original_resolution)) / 255
            if len(visible.shape) == 3:
                visible = visible[:, :, 0:1]
            elif len(visible.shape) == 2:
                visible = visible[:, :, None]
        else:
            # Single channel (like `mask`), so the cv2.resize below yields a
            # 2-D array in both branches.
            visible = np.ones_like(mask)
camera = np.load(sample[3][view])
extrinsic = torch.from_numpy(camera['extrinsic']).float()
R = extrinsic[:3,:3].t()
T = extrinsic[:3, 3]
intrinsic = camera['intrinsic']
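        # If the principal point is more than one pixel off the image center,
        # crop so that it lands exactly at the center; CropImage shifts the
        # principal point of `intrinsic` to match.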
if np.abs(intrinsic[0, 2] - self.original_resolution / 2) > 1 or np.abs(intrinsic[1, 2] - self.original_resolution / 2) > 1:
left_up = np.around(intrinsic[0:2, 2] - np.array([self.original_resolution / 2, self.original_resolution / 2])).astype(np.int32)
_, intrinsic = CropImage(left_up, (self.original_resolution, self.original_resolution), K=intrinsic)
image, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=image)
mask, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=mask)
visible, _ = CropImage(left_up, (self.original_resolution, self.original_resolution), image=visible)
intrinsic[0, 0] = intrinsic[0, 0] * 2 / self.original_resolution
intrinsic[0, 2] = intrinsic[0, 2] * 2 / self.original_resolution - 1
intrinsic[1, 1] = intrinsic[1, 1] * 2 / self.original_resolution
intrinsic[1, 2] = intrinsic[1, 2] * 2 / self.original_resolution - 1
intrinsic = torch.from_numpy(intrinsic).float()
image = torch.from_numpy(cv2.resize(image, (self.resolution, self.resolution))).permute(2, 0, 1).float()
mask = torch.from_numpy(cv2.resize(mask, (self.resolution, self.resolution)))[None].float()
visible = torch.from_numpy(cv2.resize(visible, (self.resolution, self.resolution)))[None].float()
image_coarse = F.interpolate(image[None], scale_factor=0.25)[0]
mask_coarse = F.interpolate(mask[None], scale_factor=0.25)[0]
visible_coarse = F.interpolate(visible[None], scale_factor=0.25)[0]
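        # With the intrinsics normalized to NDC, a focal length f maps the half
        # image extent to f, so the field of view is 2 * atan(1 / f). The
        # world-view and projection matrices below are transposed (row-major),
        # as expected by the Gaussian-splatting rasterizer.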
fovx = 2 * math.atan(1 / intrinsic[0, 0])
fovy = 2 * math.atan(1 / intrinsic[1, 1])
world_view_transform = torch.tensor(getWorld2View2(R.numpy(), T.numpy())).transpose(0, 1)
projection_matrix = getProjectionMatrix(znear=0.01, zfar=100, fovX=fovx, fovY=fovy).transpose(0,1)
full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
camera_center = world_view_transform.inverse()[3, :3]
param_path = sample[4]
param = np.load(param_path)
pose = torch.from_numpy(param['pose'][0]).float()
scale = torch.from_numpy(param['scale']).float()
exp_coeff = torch.from_numpy(param['exp_coeff'][0]).float()
landmarks_3d_path = sample[5]
landmarks_3d = torch.from_numpy(np.load(landmarks_3d_path)).float()
vertices_path = sample[6]
vertices = torch.from_numpy(np.load(vertices_path)).float()
landmarks_3d = torch.cat([landmarks_3d, vertices[::100]], 0)
exp_id = torch.tensor(sample[7]).long()
return {
'images': image,
'masks': mask,
'visibles': visible,
'images_coarse': image_coarse,
'masks_coarse': mask_coarse,
'visibles_coarse': visible_coarse,
'pose': pose,
'scale': scale,
'exp_coeff': exp_coeff,
'landmarks_3d': landmarks_3d,
'exp_id': exp_id,
'extrinsics': extrinsic,
'intrinsics': intrinsic,
'fovx': fovx,
'fovy': fovy,
'world_view_transform': world_view_transform,
'projection_matrix': projection_matrix,
'full_proj_transform': full_proj_transform,
'camera_center': camera_center}
def __len__(self):
return len(self.samples)
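# Hedged training-loop sketch (batch size, worker count and the `rasterize`
# consumer are placeholders, not part of this repo):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(GaussianDataset(cfg), batch_size=1, shuffle=True, num_workers=4)
#   for batch in loader:
#       render = rasterize(batch)  # hypothetical renderer consuming the camera dict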
class ReenactmentDataset(Dataset):
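    """Dataset for reenactment: drives a trained avatar with an external
    expression sequence (`cfg.exp_path`) under either a fixed camera or, when
    `cfg.freeview` is set, a camera that orbits the head."""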
def __init__(self, cfg):
super(ReenactmentDataset, self).__init__()
self.dataroot = cfg.dataroot
self.original_resolution = cfg.original_resolution
self.resolution = cfg.resolution
self.freeview = cfg.freeview
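        # 180-degree rotation about z, used in update_camera to convert
        # rotations from the PyTorch3D camera convention (+x left, +y up) to
        # the OpenCV-style convention (+x right, +y down) used here.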
self.Rot_z = torch.eye(3)
self.Rot_z[0,0] = -1.0
self.Rot_z[1,1] = -1.0
self.samples = []
image_paths = sorted(glob.glob(os.path.join(self.dataroot, cfg.image_files)))
param_paths = sorted(glob.glob(os.path.join(self.dataroot, cfg.param_files)))
        # The driving signal is the expression sequence stored at `cfg.exp_path`;
        # the source image is only a placeholder, so every sample reuses the
        # first image.
        exp_path = cfg.exp_path
        self.exp_coeff = np.load(exp_path)
        for param_path in param_paths:
            image_path = image_paths[0]
            if os.path.exists(image_path) and os.path.exists(param_path):
                self.samples.append((image_path, param_path))
        # Pad with the first pair so that __len__ covers the whole expression sequence.
        while len(self.samples) < self.exp_coeff.shape[0]:
            self.samples.append((image_paths[0], param_paths[0]))
if os.path.exists(cfg.pose_code_path):
self.pose_code = torch.from_numpy(np.load(cfg.pose_code_path)['pose'][0]).float()
else:
self.pose_code = None
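        # Fallback camera: a front-facing extrinsic one unit from the origin
        # and a pinhole intrinsic with focal length 3.5x the image resolution,
        # overridden below when `cfg.camera_path` exists.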
self.extrinsic = torch.tensor([[1.0000, 0.0000, 0.0000, 0.0000],
[0.0000, -1.0000, 0.0000, 0.0000],
[0.0000, 0.0000, -1.0000, 1.0000]]).float()
        self.intrinsic = torch.tensor([[self.original_resolution * 3.5, 0.0, self.original_resolution / 2],
                                       [0.0, self.original_resolution * 3.5, self.original_resolution / 2],
                                       [0.0, 0.0, 1.0]]).float()
if os.path.exists(cfg.camera_path):
camera = np.load(cfg.camera_path)
self.extrinsic = torch.from_numpy(camera['extrinsic']).float()
if not self.freeview:
self.intrinsic = torch.from_numpy(camera['intrinsic']).float()
self.R = self.extrinsic[:3,:3].t()
self.T = self.extrinsic[:3, 3]
self.intrinsic[0, 0] = self.intrinsic[0, 0] * 2 / self.original_resolution
self.intrinsic[0, 2] = self.intrinsic[0, 2] * 2 / self.original_resolution - 1
self.intrinsic[1, 1] = self.intrinsic[1, 1] * 2 / self.original_resolution
self.intrinsic[1, 2] = self.intrinsic[1, 2] * 2 / self.original_resolution - 1
self.fovx = 2 * math.atan(1 / self.intrinsic[0, 0])
self.fovy = 2 * math.atan(1 / self.intrinsic[1, 1])
self.world_view_transform = torch.tensor(getWorld2View2(self.R.numpy(), self.T.numpy())).transpose(0, 1)
self.projection_matrix = getProjectionMatrix(znear=0.01, zfar=100, fovX=self.fovx, fovY=self.fovy).transpose(0,1)
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
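    # Freeview mode: regenerate the extrinsics each frame so the camera orbits
    # the head on a sinusoidal elevation/azimuth path.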
def update_camera(self, index):
        elev = math.sin(index / 20) * 8
        azim = math.cos(index / 20) * 45
R, T = look_at_view_transform(dist=1.2, elev=elev, azim=azim, at=((0.0, 0.0, 0.05),))
R = torch.matmul(self.Rot_z, R[0].t())
self.extrinsic = torch.cat([R, T.t()], -1)
self.R = self.extrinsic[:3,:3].t()
self.T = self.extrinsic[:3, 3]
self.world_view_transform = torch.tensor(getWorld2View2(self.R.numpy(), self.T.numpy())).transpose(0, 1)
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
def get_item(self, index):
data = self.__getitem__(index)
return data
def __getitem__(self, index):
if self.freeview:
self.update_camera(index)
sample = self.samples[index]
image_path = sample[0]
image = torch.from_numpy(cv2.resize(io.imread(image_path), (self.resolution, self.resolution)) / 255).permute(2, 0, 1).float()
param_path = sample[1]
param = np.load(param_path)
pose = torch.from_numpy(param['pose'][0]).float()
scale = torch.from_numpy(param['scale']).float()
        id_coeff = torch.from_numpy(param['id_coeff'])[0].float()
        # Drive the avatar with the external expression sequence loaded in
        # __init__ rather than the frame's own tracked expression.
        exp_coeff = torch.from_numpy(self.exp_coeff[index]).float()
if self.pose_code is not None:
pose_code = self.pose_code
else:
pose_code = pose
return {
'images': image,
'pose': pose,
'scale': scale,
'id_coeff': id_coeff,
'exp_coeff': exp_coeff,
'pose_code': pose_code,
'extrinsics': self.extrinsic,
'intrinsics': self.intrinsic,
'fovx': self.fovx,
'fovy': self.fovy,
'world_view_transform': self.world_view_transform,
'projection_matrix': self.projection_matrix,
'full_proj_transform': self.full_proj_transform,
'camera_center': self.camera_center}
def __len__(self):
return len(self.samples)