# Extracted from a Spaces file view (status: Running; file size: 7,118 bytes; revision ec9a6bc).
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import math
from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Reenactment():
    """Render a head avatar frame by frame from a dataloader and log the results.

    Each frame is passed through ``gaussianhead.generate``,
    ``camera.render_gaussian`` and ``supres`` (in that order) and the resulting
    ``data`` dict is handed to ``recorder.log``.
    """

    # Keys of every dataloader batch that are moved onto ``self.device``.
    _DEVICE_KEYS = (
        'images', 'intrinsics', 'extrinsics', 'world_view_transform',
        'projection_matrix', 'full_proj_transform', 'camera_center',
        'pose', 'scale', 'exp_coeff', 'pose_code',
    )

    def __init__(self, dataloader, gaussianhead, supres, camera, recorder, gpu_id, freeview):
        """Store the collaborators used by the rendering loops.

        Args:
            dataloader: Iterable yielding one ``dict`` of tensors per frame; it
                must contain every key listed in ``_DEVICE_KEYS``.
            gaussianhead: Object exposing ``generate(data) -> data``.
            supres: Callable applied to ``data['render_images']``.
            camera: Object exposing ``render_gaussian(data, size) -> data``.
            recorder: Object exposing ``log(dict)``.
            gpu_id: Integer CUDA device index (formatted into ``'cuda:%d'``).
            freeview: If truthy, ``run`` zeroes the head pose instead of
                smoothing it.
        """
        self.dataloader = dataloader
        self.gaussianhead = gaussianhead
        self.supres = supres
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
        self.freeview = freeview

    def _move_to_device(self, data):
        """Move every ``_DEVICE_KEYS`` entry of ``data`` to ``self.device`` in place."""
        for key in self._DEVICE_KEYS:
            data[key] = data[key].to(device=self.device)

    def run(self, stop_fid=None):
        """Render every frame of the dataloader and log it.

        Expression coefficients are always blended 50/50 with the previous
        frame's (already blended) values. The pose is blended the same way
        unless ``self.freeview`` is set, in which case it is zeroed.

        Args:
            stop_fid: Optional frame index; the loop stops after logging the
                frame whose index equals it.
        """
        pose_last = None
        exp_last = None
        # enumerate(tqdm(...)) lets tqdm see the dataloader itself, matching
        # run_for_offline_stitching.
        for idx, data in enumerate(tqdm(self.dataloader)):
            self._move_to_device(data)

            if self.freeview:
                data['pose'] *= 0
            else:
                if idx > 0:
                    data['pose'] = pose_last * 0.5 + data['pose'] * 0.5
                pose_last = data['pose']

            if idx > 0:
                data['exp_coeff'] = exp_last * 0.5 + data['exp_coeff'] * 0.5
            exp_last = data['exp_coeff']

            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, 512)
                data['supres_images'] = self.supres(data['render_images'])

            self.recorder.log({
                'data': data,
                'iter': idx
            })

            if stop_fid is not None and idx == stop_fid:
                print('# Reaching stop frame index (%d)' % stop_fid)
                break

    def run_for_offline_stitching(self, offline_rendering_param_fpath):
        """Render frames with cameras and head pose loaded from a parameter file.

        The file is read with ``np.load`` and must provide ``cam_extr``,
        ``cam_intr_zoom``, ``head_pose`` and ``head_color_bw``. For each frame
        the avatar is rendered once normally and once with a colour derived
        from ``head_color_bw``; the first three channels of the second render
        are stored under ``data['render_bw']``.

        Rendering stops when either the dataloader or the per-frame camera
        parameters run out.

        Args:
            offline_rendering_param_fpath: Path passed to ``np.load``.
        """
        params = np.load(offline_rendering_param_fpath)
        cam_extr = params['cam_extr']
        cam_intr_zoom = params['cam_intr_zoom']
        head_pose = torch.from_numpy(params['head_pose'].astype(np.float32)).to(self.device)
        head_color_bw = torch.from_numpy(params['head_color_bw'].astype(np.float32)).to(self.device)

        render_size = 512
        # The same head pose is used for every frame, so add the batch axis once.
        batched_head_pose = head_pose.unsqueeze(0)

        for idx, data in enumerate(tqdm(self.dataloader)):
            if idx >= len(cam_extr):
                print('# Reaching the end of offline stiitching parameters! Rendering stopped. ')
                break

            camera_params = self.prepare_camera_data_for_gs_rendering(
                cam_extr[idx], cam_intr_zoom[idx], render_size, render_size)
            # Add a leading batch axis to every tensor and move it to the device.
            camera_params = {
                key: value.unsqueeze(0).to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in camera_params.items()
            }
            camera_params['pose'] = batched_head_pose

            self._move_to_device(data)
            # Loaded camera/pose values override the ones from the dataloader.
            data.update(camera_params)

            with torch.no_grad():
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, render_size)
                data['supres_images'] = self.supres(data['render_images'])

                # Second pass: swap in a colour built from head_color_bw,
                # render, then restore the original colour.
                data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
                data['color_bk'] = data.pop('color')
                data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
                data['color'][:, :, 1] = 1
                data['color'] = torch.clamp(data['color'], 0., 1.)
                data = self.camera.render_gaussian(data, render_size)
                render_bw = data['render_images'][:, :3, :, :]
                data['color'] = data.pop('color_bk')
                data['render_bw'] = render_bw

            self.recorder.log({
                'data': data,
                'iter': idx
            })

    @staticmethod
    def _normalize_intrinsic(intrinsic, resolution):
        """Return a copy of ``intrinsic`` with focal lengths and principal point rescaled.

        ``[0, 0]`` and ``[1, 1]`` are multiplied by ``2 / resolution``;
        ``[0, 2]`` and ``[1, 2]`` are mapped with ``v * 2 / resolution - 1``.
        The input array is not modified.
        """
        normalized = np.copy(intrinsic)
        normalized[0, 0] = normalized[0, 0] * 2 / resolution
        # Each principal-point component is derived from its own entry
        # (the x component previously read the y entry by mistake).
        normalized[0, 2] = normalized[0, 2] * 2 / resolution - 1
        normalized[1, 1] = normalized[1, 1] * 2 / resolution
        normalized[1, 2] = normalized[1, 2] * 2 / resolution - 1
        return normalized

    def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
        """Build the camera tensors needed for Gaussian rendering of one view.

        Args:
            extrinsic: Array indexed as ``[:3, :3]`` / ``[:3, 3]`` and inverted
                with ``np.linalg.inv`` (presumably a 4x4 world-to-camera
                matrix -- confirm against the caller).
            intrinsic: Camera matrix with focal lengths at ``[0, 0]``/``[1, 1]``
                and principal point at ``[0, 2]``/``[1, 2]``.
            original_resolution: Resolution the intrinsics refer to.
            new_resolution: Resolution of the image to be rendered.

        Returns:
            dict with unbatched tensors under the keys ``extrinsics``,
            ``intrinsics`` (normalized), ``viewdir``, ``fovx``, ``fovy``,
            ``world_view_transform``, ``projection_matrix``,
            ``full_proj_transform`` and ``camera_center``.
        """
        extrinsic = np.copy(extrinsic)

        # Pixel-space intrinsics rescaled to the target resolution.
        new_intrinsic = np.copy(intrinsic)
        new_intrinsic[:2] *= new_resolution / original_resolution

        intrinsic = Reenactment._normalize_intrinsic(intrinsic, original_resolution)
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])

        world_view_transform = torch.tensor(
            getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
        projection_matrix = getProjectionMatrix(
            znear=0.01, zfar=100, fovX=None, fovY=None,
            K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        # Direction obtained by rotating (0, 0, -1) with the inverse extrinsic.
        c2w = np.linalg.inv(extrinsic)
        viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
        viewdir = torch.from_numpy(viewdir.astype(np.float32))

        return {
            'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
            'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
            'viewdir': viewdir,
            'fovx': torch.Tensor([fovx]),
            'fovy': torch.Tensor([fovy]),
            'world_view_transform': world_view_transform,
            'projection_matrix': projection_matrix,
            'full_proj_transform': full_proj_transform,
            'camera_center': camera_center
        }