import math

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix


class Reenactment():
    """Drives a trained GaussianHead model frame by frame: for each batch from
    the dataloader it generates Gaussians, rasterizes them, super-resolves the
    render, and hands the result to the recorder."""

    def __init__(self, dataloader, gaussianhead, supres, camera, recorder, gpu_id, freeview):
        self.dataloader = dataloader
        self.gaussianhead = gaussianhead
        self.supres = supres
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
        self.freeview = freeview

    def run(self, stop_fid=None):
        for idx, data in enumerate(tqdm(self.dataloader)):
            # Move every per-frame tensor the pipeline needs onto the GPU.
            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix', 'full_proj_transform', 'camera_center',
                       'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)

            if not self.freeview:
                # Blend the current head pose / expression with the previous
                # frame's to suppress frame-to-frame jitter.
                if idx > 0:
                    data['pose'] = pose_last * 0.5 + data['pose'] * 0.5
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                pose_last = data['pose']
                exp_last = data['exp_coeff']
            else:
                # Free-viewpoint mode: zero out the head pose (the camera moves
                # instead) and smooth only the expression coefficients.
                data['pose'] *= 0
                if idx > 0:
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                exp_last = data['exp_coeff']
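
            # The 0.5/0.5 blend above is a one-step exponential moving average,
            #   smoothed_t = 0.5 * smoothed_{t-1} + 0.5 * raw_t,
            # so e.g. raw pose values [0, 1, 1] become [0, 0.5, 0.75]: sudden
            # jumps are damped at the cost of roughly one frame of lag.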
            with torch.no_grad():
                # Generate per-frame Gaussians, rasterize at 512x512, then
                # upsample with the super-resolution network.
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)

            if stop_fid is not None and idx == stop_fid:
                print('# Reaching stop frame index (%d)' % stop_fid)
                break

    def run_for_offline_stitching(self, offline_rendering_param_fpath):
        # Per-frame camera and head parameters precomputed for stitching the
        # rendered head onto another (e.g. full-body) render.
        head_offline_rendering_param = np.load(offline_rendering_param_fpath)
        cam_extr = head_offline_rendering_param['cam_extr']
        cam_intr = head_offline_rendering_param['cam_intr']
        cam_intr_zoom = head_offline_rendering_param['cam_intr_zoom']
        zoom_image_size = head_offline_rendering_param['zoom_image_size']
        head_pose = head_offline_rendering_param['head_pose']
        head_scale = head_offline_rendering_param['head_scale']
        head_color_bw = head_offline_rendering_param['head_color_bw']
        zoom_scale = head_offline_rendering_param['zoom_scale']

        head_pose = torch.from_numpy(head_pose.astype(np.float32)).to(self.device)
        head_color_bw = torch.from_numpy(head_color_bw.astype(np.float32)).to(self.device)

        render_size = 512
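
        # A minimal sketch of how such a parameter file could be produced (the
        # shapes are assumptions inferred from how the arrays are used here:
        # n_frames 4x4 extrinsics and 3x3 intrinsics, one head pose shared by
        # all frames, and one blend weight per Gaussian):
        #
        #   np.savez('offline_params.npz',
        #            cam_extr=extr,            # (n_frames, 4, 4) world-to-camera
        #            cam_intr=intr,            # (n_frames, 3, 3) pixel intrinsics
        #            cam_intr_zoom=intr_zoom,  # (n_frames, 3, 3) zoomed-crop intrinsics
        #            zoom_image_size=size, zoom_scale=scale,
        #            head_pose=pose, head_scale=s,
        #            head_color_bw=bw)         # (n_gaussians,) blend weights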

        for idx, data in enumerate(tqdm(self.dataloader)):
            if idx >= len(cam_extr):
                print('# Reaching the end of offline stitching parameters! Rendering stopped.')
                break
            # Override the dataloader's camera with the precomputed one (using
            # the zoomed-in intrinsics) and the shared head pose.
            new_gs_camera_param_dict = self.prepare_camera_data_for_gs_rendering(cam_extr[idx], cam_intr_zoom[idx], render_size, render_size)
            for k in new_gs_camera_param_dict.keys():
                if isinstance(new_gs_camera_param_dict[k], torch.Tensor):
                    new_gs_camera_param_dict[k] = new_gs_camera_param_dict[k].unsqueeze(0).to(self.device)
            new_gs_camera_param_dict['pose'] = head_pose.unsqueeze(0).to(self.device)

            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix', 'full_proj_transform', 'camera_center',
                       'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)
            data.update(new_gs_camera_param_dict)

            with torch.no_grad():
                # Pass 1: render the head as usual and super-resolve it.
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images

                # Pass 2: re-render with every Gaussian's color replaced by its
                # blend weight, against a black background, then restore the
                # original colors.
                data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
                data['color_bk'] = data.pop('color')
                data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
                data['color'][:, :, 1] = 1
                data['color'] = torch.clamp(data['color'], 0., 1.)
                data = self.camera.render_gaussian(data, render_size)
                render_bw = data['render_images'][:, :3, :, :]
                data['color'] = data.pop('color_bk')
                data['render_bw'] = render_bw
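
                # Rasterizing the per-Gaussian weights this way yields a
                # screen-space blend-weight map ('render_bw') that the offline
                # stitching step can use to composite the head render onto the
                # target frame; channel 1 is set to 1 for every Gaussian, so
                # after rasterization it holds the accumulated coverage. (This
                # reading of the channel layout is an inference from the code,
                # not documented in the source.)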

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)

    def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
        extrinsic = np.copy(extrinsic)
        intrinsic = np.copy(intrinsic)
        new_intrinsic = np.copy(intrinsic)
        new_intrinsic[:2] *= new_resolution / original_resolution

        # Map the pixel-space intrinsics to NDC: focal lengths become 2f/res,
        # the principal point becomes 2c/res - 1 (i.e. lands in [-1, 1]).
        intrinsic[0, 0] = intrinsic[0, 0] * 2 / original_resolution
        intrinsic[0, 2] = intrinsic[0, 2] * 2 / original_resolution - 1  # was intrinsic[1, 2]: cx must come from row 0
        intrinsic[1, 1] = intrinsic[1, 1] * 2 / original_resolution
        intrinsic[1, 2] = intrinsic[1, 2] * 2 / original_resolution - 1
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])
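
        # With f_ndc = 2f/res, fov = 2*atan(1/f_ndc) = 2*atan(res/(2f)); e.g. a
        # 512 px focal length at 512 px resolution gives f_ndc = 2 and
        # fov = 2*atan(0.5), roughly 53.13 degrees.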

        # Build the (transposed, row-major 3DGS convention) view and projection
        # matrices expected by the Gaussian rasterizer.
        world_view_transform = torch.tensor(getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
        projection_matrix = getProjectionMatrix(
            znear=0.01, zfar=100, fovX=None, fovY=None,
            K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]
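
        # Because world_view_transform is stored transposed, its inverse is the
        # transposed camera-to-world matrix, so row 3 (columns 0..2) holds the
        # camera position in world coordinates.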

        # Camera viewing direction in world space: rotate the -z axis (camera
        # forward) by the camera-to-world rotation.
        c2w = np.linalg.inv(extrinsic)
        viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
        viewdir = torch.from_numpy(viewdir.astype(np.float32))

        return {
            'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
            'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
            'viewdir': viewdir,
            'fovx': torch.Tensor([fovx]),
            'fovy': torch.Tensor([fovy]),
            'world_view_transform': world_view_transform,
            'projection_matrix': projection_matrix,
            'full_proj_transform': full_proj_transform,
            'camera_center': camera_center
        }
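

# A minimal, hypothetical usage sketch. The dataloader, gaussianhead, supres,
# camera, and recorder objects are assumptions here; in GHA they are built
# from the config by the accompanying runner scripts:
#
#   reenactment = Reenactment(dataloader, gaussianhead, supres, camera,
#                             recorder, gpu_id=0, freeview=False)
#   reenactment.run(stop_fid=500)            # stop after frame index 500
#   reenactment.run_for_offline_stitching('offline_params.npz')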