import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import math

from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix


class Reenactment():
    def __init__(self, dataloader, gaussianhead, supres, camera, recorder, gpu_id, freeview):
        self.dataloader = dataloader
        self.gaussianhead = gaussianhead
        self.supres = supres
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
        self.freeview = freeview

    def run(self, stop_fid=None):
        for idx, data in tqdm(enumerate(self.dataloader)):
            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix',
                       'full_proj_transform', 'camera_center', 'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)

            if not self.freeview:
                # Blend 50/50 with the previous frame's pose and expression for temporal smoothing.
                if idx > 0:
                    data['pose'] = pose_last * 0.5 + data['pose'] * 0.5
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                pose_last = data['pose']
                exp_last = data['exp_coeff']
            else:
                # Free-view rendering: zero out the head pose and only smooth the expression.
                data['pose'] *= 0
                if idx > 0:
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                exp_last = data['exp_coeff']

            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)

            if stop_fid is not None and idx == stop_fid:
                print('# Reaching stop frame index (%d)' % stop_fid)
                break

    def run_for_offline_stitching(self, offline_rendering_param_fpath):
        head_offline_rendering_param = np.load(offline_rendering_param_fpath)
        cam_extr = head_offline_rendering_param['cam_extr']
        cam_intr = head_offline_rendering_param['cam_intr']
        cam_intr_zoom = head_offline_rendering_param['cam_intr_zoom']
        zoom_image_size = head_offline_rendering_param['zoom_image_size']
        head_pose = head_offline_rendering_param['head_pose']
        head_scale = head_offline_rendering_param['head_scale']
        head_color_bw = head_offline_rendering_param['head_color_bw']
        zoom_scale = head_offline_rendering_param['zoom_scale']

        head_pose = torch.from_numpy(head_pose.astype(np.float32)).to(self.device)
        head_color_bw = torch.from_numpy(head_color_bw.astype(np.float32)).to(self.device)

        render_size = 512
        for idx, data in enumerate(tqdm(self.dataloader)):
            if idx >= len(cam_extr):
                print('# Reaching the end of offline stitching parameters! Rendering stopped.')
                break

            # Replace the dataset camera with the precomputed (zoomed) stitching camera.
            new_gs_camera_param_dict = self.prepare_camera_data_for_gs_rendering(
                cam_extr[idx], cam_intr_zoom[idx], render_size, render_size)
            for k in new_gs_camera_param_dict.keys():
                if isinstance(new_gs_camera_param_dict[k], torch.Tensor):
                    new_gs_camera_param_dict[k] = new_gs_camera_param_dict[k].unsqueeze(0).to(self.device)
            new_gs_camera_param_dict['pose'] = head_pose.unsqueeze(0).to(self.device)

            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix',
                       'full_proj_transform', 'camera_center', 'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)
            data.update(new_gs_camera_param_dict)

            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images

                # Second pass: temporarily replace the Gaussian colors with per-Gaussian
                # blend weights and render them, yielding a blend-weight map for stitching.
                data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
                data['color_bk'] = data.pop('color')
                data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
                data['color'][:, :, 1] = 1
                data['color'] = torch.clamp(data['color'], 0., 1.)
                data = self.camera.render_gaussian(data, render_size)
                render_bw = data['render_images'][:, :3, :, :]
                data['color'] = data.pop('color_bk')
                data['render_bw'] = render_bw

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)

    def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
        extrinsic = np.copy(extrinsic)
        intrinsic = np.copy(intrinsic)
        new_intrinsic = np.copy(intrinsic)
        new_intrinsic[:2] *= new_resolution / original_resolution

        # Normalize the intrinsics to NDC: focal lengths scaled by 2/W,
        # principal point mapped to [-1, 1].
        intrinsic[0, 0] = intrinsic[0, 0] * 2 / original_resolution
        intrinsic[0, 2] = intrinsic[0, 2] * 2 / original_resolution - 1
        intrinsic[1, 1] = intrinsic[1, 1] * 2 / original_resolution
        intrinsic[1, 2] = intrinsic[1, 2] * 2 / original_resolution - 1
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])

        world_view_transform = torch.tensor(
            getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
        projection_matrix = getProjectionMatrix(
            znear=0.01, zfar=100, fovX=None, fovY=None,
            K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        # View direction: the camera's -z axis expressed in world coordinates.
        c2w = np.linalg.inv(extrinsic)
        viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
        viewdir = torch.from_numpy(viewdir.astype(np.float32))

        return {
            'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
            'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
            'viewdir': viewdir,
            'fovx': torch.Tensor([fovx]),
            'fovy': torch.Tensor([fovy]),
            'world_view_transform': world_view_transform,
            'projection_matrix': projection_matrix,
            'full_proj_transform': full_proj_transform,
            'camera_center': camera_center
        }
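

# -----------------------------------------------------------------------------
# Minimal usage sketch for `prepare_camera_data_for_gs_rendering`, included to
# show the expected inputs: a 4x4 world-to-camera extrinsic and a 3x3 pixel-unit
# intrinsic matrix. The focal length, principal point, resolutions, and camera
# placement below are illustrative assumptions, not values from the GHA repo.
if __name__ == '__main__':
    original_resolution = 2048
    intrinsic = np.array([[1500.0, 0.0, original_resolution / 2],
                          [0.0, 1500.0, original_resolution / 2],
                          [0.0, 0.0, 1.0]], dtype=np.float32)
    extrinsic = np.eye(4, dtype=np.float32)
    extrinsic[2, 3] = 1.5  # camera 1.5 units from the origin along +z (assumed)

    # Only the camera-preparation helper is exercised here, so the heavy
    # components (dataloader, gaussianhead, supres, camera, recorder) are stubbed.
    reenactment = Reenactment(dataloader=None, gaussianhead=None, supres=None,
                              camera=None, recorder=None, gpu_id=0, freeview=False)
    cam = reenactment.prepare_camera_data_for_gs_rendering(
        extrinsic, intrinsic, original_resolution, 512)
    print('fovx/fovy:', cam['fovx'].item(), cam['fovy'].item())
    print('camera_center:', cam['camera_center'])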