import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import math

from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix
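

# Reenactment: drives a trained Gaussian head avatar with tracked pose and
# expression parameters, renders it through a Gaussian splatting camera, and
# upsamples the result with a super-resolution network.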
class Reenactment:
    def __init__(self, dataloader, gaussianhead, supres, camera, recorder, gpu_id, freeview):
        self.dataloader = dataloader
        self.gaussianhead = gaussianhead
        self.supres = supres
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
        self.freeview = freeview
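
    # run(): per-frame reenactment driven by the tracked parameters coming
    # from the dataloader; optionally stops at frame index stop_fid.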
    def run(self, stop_fid=None):
        for idx, data in tqdm(enumerate(self.dataloader)):
            # Move the per-frame tensors onto the target GPU.
            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix',
                       'full_proj_transform', 'camera_center', 'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)
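
            # Temporal smoothing: blend this frame's pose and expression 50/50
            # with the previous frame's to suppress tracking jitter. In freeview
            # mode the head pose is zeroed so only the camera moves.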
            if not self.freeview:
                if idx > 0:
                    data['pose'] = pose_last * 0.5 + data['pose'] * 0.5
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                pose_last = data['pose']
                exp_last = data['exp_coeff']
            else:
                data['pose'] *= 0
                if idx > 0:
                    data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
                exp_last = data['exp_coeff']

            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                # Upsample the raw splatting output with the super-resolution network.
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)

            if stop_fid is not None and idx == stop_fid:
                print('# Reached stop frame index (%d)' % stop_fid)
                break
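
    # run_for_offline_stitching(): replays a precomputed per-frame camera/pose
    # sequence (an .npz archive) so the rendered head can later be composited
    # onto a target video. A second rendering pass, with per-Gaussian blend
    # weights in place of colors, produces the blending map used for stitching.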
    def run_for_offline_stitching(self, offline_rendering_param_fpath):
        # Precomputed per-frame camera and head parameters (.npz archive).
        head_offline_rendering_param = np.load(offline_rendering_param_fpath)
        cam_extr = head_offline_rendering_param['cam_extr']
        cam_intr = head_offline_rendering_param['cam_intr']
        cam_intr_zoom = head_offline_rendering_param['cam_intr_zoom']
        zoom_image_size = head_offline_rendering_param['zoom_image_size']
        head_pose = head_offline_rendering_param['head_pose']
        head_scale = head_offline_rendering_param['head_scale']
        head_color_bw = head_offline_rendering_param['head_color_bw']
        zoom_scale = head_offline_rendering_param['zoom_scale']

        head_pose = torch.from_numpy(head_pose.astype(np.float32)).to(self.device)
        head_color_bw = torch.from_numpy(head_color_bw.astype(np.float32)).to(self.device)
        render_size = 512
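
        # Step through the driving frames in lockstep with the precomputed
        # camera path; stop when the offline parameters run out.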
        for idx, data in enumerate(tqdm(self.dataloader)):
            if idx >= len(cam_extr):
                print('# Reached the end of the offline stitching parameters; rendering stopped.')
                break

            # Build splatting camera tensors for this frame's precomputed view
            # and override the tracked head pose with the precomputed one.
            new_gs_camera_param_dict = self.prepare_camera_data_for_gs_rendering(cam_extr[idx], cam_intr_zoom[idx], render_size, render_size)
            for k in new_gs_camera_param_dict.keys():
                if isinstance(new_gs_camera_param_dict[k], torch.Tensor):
                    new_gs_camera_param_dict[k] = new_gs_camera_param_dict[k].unsqueeze(0).to(self.device)
            new_gs_camera_param_dict['pose'] = head_pose.unsqueeze(0).to(self.device)

            to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix',
                       'full_proj_transform', 'camera_center', 'pose', 'scale', 'exp_coeff', 'pose_code']
            for data_item in to_cuda:
                data[data_item] = data[data_item].to(device=self.device)
            data.update(new_gs_camera_param_dict)
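
            # Two rendering passes follow: a normal color pass that gets
            # super-resolved, then a blend-weight pass used for stitching.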
            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                render_images = data['render_images']
                supres_images = self.supres(render_images)
                data['supres_images'] = supres_images
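
                # Second pass: temporarily swap the per-Gaussian colors for the
                # stitching blend weights (head_color_bw) and re-render; the
                # result is a per-pixel blending map for compositing.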
                data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
                data['color_bk'] = data.pop('color')
                data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
                data['color'][:, :, 1] = 1
                data['color'] = torch.clamp(data['color'], 0., 1.)
                data = self.camera.render_gaussian(data, render_size)
                render_bw = data['render_images'][:, :3, :, :]
                data['color'] = data.pop('color_bk')
                data['render_bw'] = render_bw

            log = {
                'data': data,
                'iter': idx
            }
            self.recorder.log(log)
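
    # prepare_camera_data_for_gs_rendering(): converts an OpenCV-style
    # (extrinsic, intrinsic) camera pair into the tensors the Gaussian
    # splatting rasterizer expects: world-to-view and projection transforms,
    # fields of view, camera center, and viewing direction.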
    def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
        extrinsic = np.copy(extrinsic)
        intrinsic = np.copy(intrinsic)

        # Intrinsics rescaled to the target resolution (kept in pixel units)
        # for the projection matrix below.
        new_intrinsic = np.copy(intrinsic)
        new_intrinsic[:2] *= new_resolution / original_resolution

        # Normalize the pinhole intrinsics into NDC: focal lengths become NDC
        # units and the principal point maps into [-1, 1].
        intrinsic[0, 0] = intrinsic[0, 0] * 2 / original_resolution
        intrinsic[0, 2] = intrinsic[0, 2] * 2 / original_resolution - 1
        intrinsic[1, 1] = intrinsic[1, 1] * 2 / original_resolution
        intrinsic[1, 2] = intrinsic[1, 2] * 2 / original_resolution - 1

        # Recover horizontal and vertical fields of view from the NDC focals.
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])

        world_view_transform = torch.tensor(getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
        projection_matrix = getProjectionMatrix(
            znear=0.01, zfar=100, fovX=None, fovY=None,
            K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        # Viewing direction of the camera in world coordinates.
        c2w = np.linalg.inv(extrinsic)
        viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
        viewdir = torch.from_numpy(viewdir.astype(np.float32))

        return {
            'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
            'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
            'viewdir': viewdir,
            'fovx': torch.Tensor([fovx]),
            'fovy': torch.Tensor([fovy]),
            'world_view_transform': world_view_transform,
            'projection_matrix': projection_matrix,
            'full_proj_transform': full_proj_transform,
            'camera_center': camera_center
        }
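

# --- Usage sketch (illustrative only; component construction lives elsewhere
# in the GHA codebase, so every name below is a placeholder) ---
#
#   reenactor = Reenactment(dataloader, gaussianhead, supres, camera,
#                           recorder, gpu_id=0, freeview=False)
#   reenactor.run(stop_fid=500)  # render frames 0..500, then stop
#   # or replay precomputed stitching parameters from an .npz file:
#   reenactor.run_for_offline_stitching('head_offline_rendering_param.npz')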