from calendar import c  # NOTE(review): unused in this file; looks like an accidental IDE auto-import
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
# Enable OpenEXR support in OpenCV; set here, ahead of the cv2 import below.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import yaml
import shutil
import collections
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import glob
import datetime
import trimesh
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import importlib
# import config
from omegaconf import OmegaConf
import json

# AnimatableGaussians part
from AnimatableGaussians.network.lpips import LPIPS
from AnimatableGaussians.dataset.dataset_pose import PoseDataset
import AnimatableGaussians.utils.net_util as net_util
import AnimatableGaussians.utils.visualize_util as visualize_util
from AnimatableGaussians.utils.renderer import Renderer
from AnimatableGaussians.utils.net_util import to_cuda
from AnimatableGaussians.utils.obj_io import save_mesh_as_ply
from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply
import AnimatableGaussians.config as ag_config

# Gaussian-Head-Avatar part
from GHA.config.config import config_reenactment
from GHA.lib.dataset.Dataset import ReenactmentDataset
from GHA.lib.dataset.DataLoaderX import DataLoaderX
from GHA.lib.module.GaussianHeadModule import GaussianHeadModule
from GHA.lib.module.SuperResolutionModule import SuperResolutionModule
from GHA.lib.module.CameraModule import CameraModule
from GHA.lib.recorder.Recorder import ReenactmentRecorder
from GHA.lib.apps.Reenactment import Reenactment

# cat utils
from calc_offline_rendering_param import calc_offline_rendering_param
import ipdb  # NOTE(review): only referenced by commented-out breakpoints


class Avatar:
    """Test-time driver for an AnimatableGaussians body plus a Gaussian-Head-Avatar head.

    ``config`` is an OmegaConf config with three sections:
      * ``animatablegaussians``: body model / dataset / test options (``self.body``),
      * ``gha``: head reenactment options, incl. ``config_path`` (``self.head``),
      * ``cat``: paths used by the body/head stitching step (``self.cat``).
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # AnimatableGaussians (body) part init
        self.body = config.animatablegaussians
        self.body.mode = 'test'
        ag_config.set_opt(self.body)
        avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.body.model).to(self.device)
        self.random_bg_color = self.body['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
        self.loss_weight = self.body['train']['loss_weight']
        self.finetune_color = self.body['train']['finetune_color']
        print('# Parameter number of AvatarNet is %d' % (sum([p.numel() for p in self.avatar_net.parameters()])))

        # Gaussian-Head-Avatar (head) part init
        self.head = config.gha
        self.head_config = config_reenactment()
        self.head_config.load(self.head.config_path)
        self.head_config = self.head_config.get_cfg()

        # cat utils part init
        self.cat = config.cat

    @torch.no_grad()
    def test_body(self):
        """Render the body avatar for every frame of the test pose sequence.

        Output goes under ``<test.output_dir>/pca_<n>_sigma_<s>`` (PCA enabled) or
        ``<test.output_dir>/vanilla``.

        Per frame, it writes the RGB render plus auxiliary outputs, each only when the
        networks return the corresponding map: hand / sleeve masks, hand-free composite,
        torso and silhouette maps, canonical texture map, posed Gaussians, and SMPL params.

        After the loop, it writes ``cfg_args``, ``cameras.json`` and ``camera_info.json``
        describing the camera of every rendered frame.

        Raises:
            ValueError: if the test config has no ``pose_data`` or no ``output_dir``, or
                ``view_setting`` is not recognised.
        """
        self.avatar_net.eval()
        test_opt = self.body['test']

        dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        training_dataset = MvRgbDataset(**self.body['train']['data'], training = False)
        use_pca = test_opt.get('n_pca', -1) >= 1
        if use_pca:
            training_dataset.compute_pca(n_components = test_opt['n_pca'])
        if 'pose_data' not in self.body.test:
            raise ValueError('No pose data in test config')
        self.dataset = PoseDataset(**test_opt['pose_data'], smpl_shape = training_dataset.smpl_data['betas'][0])
        # NOTE(review): no checkpoint is loaded here. The former
        # `self.load_ckpt(self.body['test']['prev_ckpt'], False)` call is disabled, and this
        # class has no load_ckpt method. Confirm AvatarNet restores its trained weights itself.

        output_dir = test_opt.get('output_dir', None)
        if output_dir is None:
            raise ValueError('No output_dir in test config')
        # One sigma drives both the folder label and transform_pca(), so the label always
        # matches the value actually applied.
        sigma_pca = float(test_opt.get('sigma_pca', 2.))
        if use_pca:
            output_dir += '/pca_%d_sigma_%.2f' % (test_opt.get('n_pca', -1), sigma_pca)
        else:
            output_dir += '/vanilla'
        print('# Output dir: \033[1;31m%s\033[0m' % output_dir)
        for sub_dir in ('live_skeleton', 'rgb_map', 'rgb_map_wo_hand', 'torso_map', 'mask_map',
                        'posed_gaussians', 'posed_params', 'full_body_mask', 'hand_only_mask'):
            os.makedirs(output_dir + '/' + sub_dir, exist_ok = True)

        geo_renderer = None
        item_0 = self.dataset.getitem(0, training = False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient']
        if isinstance(global_orient, torch.Tensor):
            global_orient = global_orient.cpu().numpy()
        # Copy before editing: .numpy() of a CPU tensor (or a raw array) may alias the
        # dataset's stored pose, and zeroing it in place would corrupt that frame.
        global_orient = np.array(global_orient, copy = True)
        # Zero the x and z components of the axis-angle vector, leaving a rotation about y only.
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]

        time_start = torch.cuda.Event(enable_timing = True)
        time_start_all = torch.cuda.Event(enable_timing = True)
        time_end = torch.cuda.Event(enable_timing = True)

        data_num = len(self.dataset)
        if test_opt.get('fix_hand', False):
            self.avatar_net.generate_mean_hands()
        log_time = bool(test_opt.get('log_time', False))

        # Loop-invariant test options.
        img_scale = test_opt.get('img_scale', 1.0)
        view_setting = test_opt.get('view_setting', 'free')
        orbit_orient = global_orient if test_opt.get('global_orient', False) else None
        getitem_func = self.dataset.getitem_fast if hasattr(self.dataset, 'getitem_fast') else self.dataset.getitem

        # Camera of every rendered frame, dumped once after the loop.
        extr_list, intr_list, img_h_list, img_w_list = [], [], [], []
        for idx in tqdm(range(data_num), desc = 'Rendering avatars...'):
            if log_time:
                time_start.record()
                time_start_all.record()

            extr, intr, img_h, img_w = self._make_view_camera(idx, view_setting, object_center, orbit_orient, img_scale)
            extr_list.append(extr)
            intr_list.append(intr)
            img_h_list.append(img_h)
            img_w_list.append(img_w)

            item = getitem_func(idx, training = False, extr = extr, intr = intr, img_w = img_w, img_h = img_h)
            items = to_cuda(item, add_batch = False)
            data_idx = item['data_idx']

            if view_setting.startswith('moving') or view_setting == 'free_moving':
                # Follow the subject: shift the orbit centre (x only) used by the next frame.
                current_center = items['live_bounds'].cpu().numpy().mean(0)
                delta = current_center - object_center
                object_center[0] += delta[0]

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Loading data costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if test_opt.get('render_skeleton', False):
                from AnimatableGaussians.utils.visualize_skeletons import construct_skeletons
                skel_vertices, skel_faces = construct_skeletons(item['joints'].cpu().numpy(), item['kin_parent'].cpu().numpy())
                skel_mesh = trimesh.Trimesh(skel_vertices, skel_faces, process = False)
                if geo_renderer is None:
                    geo_renderer = Renderer(item['img_w'], item['img_h'], shader_name = 'phong_geometry', bg_color = (1, 1, 1))
                geo_renderer.set_camera(item['extr'], item['intr'])
                geo_renderer.set_model(skel_vertices[skel_faces.reshape(-1)], skel_mesh.vertex_normals.astype(np.float32)[skel_faces.reshape(-1)])
                skel_img = geo_renderer.render()[:, :, :3]
                skel_img = (skel_img * 255).astype(np.uint8)
                cv.imwrite(output_dir + '/live_skeleton/%08d.jpg' % data_idx, skel_img)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering skeletons costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            if 'smpl_pos_map' not in items:
                self.avatar_net.get_pose_map(items)

            # PCA: project the front half of the pose map through the training-set PCA.
            if use_pca:
                pos_map_mask = training_dataset.pos_map_mask
                live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
                front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
                pose_conds = front_live_pos_map[pos_map_mask]
                new_pose_conds = training_dataset.transform_pca(pose_conds, sigma_pca = sigma_pca)
                front_live_pos_map[pos_map_mask] = new_pose_conds
                live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
                items.update({
                    'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
                })

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering pose conditions costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            output = self.avatar_net.render(items, bg_color = self.bg_color, use_pca = use_pca)
            output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color = self.bg_color, use_pca = use_pca)
            mask_output = self.avatar_net.render_mask(items, bg_color = self.bg_color, use_pca = use_pca)

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Rendering avatar costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                time_start.record()

            # Must run before anything else reads mask_output: it clips those maps in place.
            hand_masks = self._save_hand_masks(output_dir, data_idx, mask_output)

            rgb_map = self._clip_to_uint8(output['rgb_map']).cpu().numpy()
            cv.imwrite(output_dir + '/rgb_map/%08d.jpg' % data_idx, rgb_map)

            # Use the hand-free render inside the (dilated) visible-hand regions
            # (r_hand_mask | l_hand_mask), and the full render everywhere else.
            if 'rgb_map' in output_wo_hand and hand_masks is not None:
                body_red_mask, body_blue_mask = hand_masks
                rgb_map_wo_hand = self._clip_to_uint8(output_wo_hand['rgb_map']).cpu().numpy()
                paste = (body_red_mask | body_blue_mask).cpu().numpy().astype(np.uint8)
                # 5x5 kernel x 3 iterations; enlarge to widen the pasted region.
                paste = cv.dilate(paste, np.ones((5, 5), np.uint8), iterations = 3)
                paste = paste.astype(np.bool_)[:, :, None]
                mix = np.where(paste, rgb_map_wo_hand, rgb_map)
                cv.imwrite(output_dir + '/rgb_map_wo_hand/%08d.png' % data_idx, mix)

            if 'torso_map' in output:
                torso_map = self._clip_to_uint8(output['torso_map'][:, :, 0])
                cv.imwrite(output_dir + '/torso_map/%08d.png' % data_idx, torso_map.cpu().numpy())

            if 'mask_map' in output:
                mask_map = self._clip_to_uint8(output['mask_map'][:, :, 0])
                cv.imwrite(output_dir + '/mask_map/%08d.png' % data_idx, mask_map.cpu().numpy())

            if test_opt.get('save_tex_map', False):
                os.makedirs(output_dir + '/cano_tex_map', exist_ok = True)
                cano_tex_map = self._clip_to_uint8(output['cano_tex_map'])
                cv.imwrite(output_dir + '/cano_tex_map/%08d.png' % data_idx, cano_tex_map.cpu().numpy())

            if test_opt.get('save_ply', False):
                posed_gaussians = output['posed_gaussians']
                if data_idx == 0:
                    save_gaussians_as_ply(output_dir + '/posed_gaussians/%08d.ply' % data_idx, posed_gaussians)
                for k in posed_gaussians.keys():
                    if isinstance(posed_gaussians[k], torch.Tensor):
                        posed_gaussians[k] = posed_gaussians[k].detach().cpu().numpy()
                np.savez(output_dir + '/posed_gaussians/%08d.npz' % data_idx, **posed_gaussians)

            # SMPL parameters of this frame.
            np.savez(output_dir + ('/posed_params/%08d.npz' % data_idx),
                     betas = training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
                     global_orient = item['global_orient'].reshape([-1]).detach().cpu().numpy(),
                     transl = item['transl'].reshape([-1]).detach().cpu().numpy(),
                     body_pose = item['body_pose'].reshape([-1]).detach().cpu().numpy())

            if log_time:
                time_end.record()
                torch.cuda.synchronize()
                print('Saving images costs %.4f secs' % (time_start.elapsed_time(time_end) / 1000.))
                print('Animating one frame costs %.4f secs' % (time_start_all.elapsed_time(time_end) / 1000.))

            torch.cuda.empty_cache()

        # Camera dumps: written once, with one entry per rendered frame.
        self.dump_renderer_info(output_dir, extr_list, intr_list, img_h_list, img_w_list)
        camera_info = [
            {'extr': e.tolist(), 'intr': k.tolist(), 'img_h': h, 'img_w': w}
            for e, k, h, w in zip(extr_list, intr_list, img_h_list, img_w_list)
        ]
        with open(os.path.join(output_dir, 'camera_info.json'), 'w') as fp:
            json.dump(camera_info, fp)

    def _make_view_camera(self, idx, view_setting, object_center, global_orient, img_scale):
        """Return ``(extr, intr, img_h, img_w)`` of the virtual camera for frame ``idx``.

        ``global_orient`` is forwarded to ``calc_free_mv`` unchanged (None disables it).
        A ``*bird`` suffix on the orbit settings raises the camera via ``rot_X``.

        Raises:
            ValueError: for an unrecognised ``view_setting``.
        """
        bird = view_setting.endswith('bird')
        if view_setting == 'camera':
            # Reuse one of the dataset cameras.
            cam_id = self.body['test']['render_view_idx']
            intr = self.dataset.intr_mats[cam_id].copy()
            intr[:2] *= img_scale
            extr = self.dataset.extr_mats[cam_id].copy()
            img_h = int(self.dataset.img_heights[cam_id] * img_scale)
            img_w = int(self.dataset.img_widths[cam_id] * img_scale)
            return extr, intr, img_h, img_w
        if view_setting.startswith('free'):
            # Full orbit: one revolution every 360 frames.
            frame_num_per_circle = 360
            rot_Y = (idx % frame_num_per_circle) / float(frame_num_per_circle) * 2 * np.pi
            return self._orbit_camera(object_center, global_orient, rot_Y, 0.3 if bird else 0., img_scale)
        if view_setting.startswith('degree120'):
            print('we render 120 degree')
            # Sweep between -60 and +60 degrees, 480 frames per back-and-forth cycle.
            rot_Y = self._sweep_rot_y(idx, frame_per_cycle = 480, max_degree = 60)
            return self._orbit_camera(object_center, global_orient, rot_Y, 0.3 if bird else 0., img_scale)
        if view_setting.startswith('degree90'):
            print('we render 90 degree')
            # Sweep between -45 and +45 degrees, 360 frames per back-and-forth cycle.
            rot_Y = self._sweep_rot_y(idx, frame_per_cycle = 360, max_degree = 45)
            return self._orbit_camera(object_center, global_orient, rot_Y, 0.3 if bird else 0., img_scale)
        if view_setting.startswith('front'):
            return self._orbit_camera(object_center, global_orient, 0., 0.3 if bird else 0., img_scale)
        if view_setting.startswith('back'):
            return self._orbit_camera(object_center, global_orient, np.pi, 0.5 * np.pi / 4. if bird else 0., img_scale)
        if view_setting.startswith('moving'):
            # Front camera; test_body shifts object_center to track the subject.
            return self._orbit_camera(object_center, global_orient, 0., 0.3 if bird else 0., img_scale)
        if view_setting.startswith('cano'):
            # Canonical-space view: centre on the canonical bounds, flip 180 degrees about x,
            # then push the camera back along z for a long (f = 5000) focal length.
            cano_center = self.dataset.cano_bounds.mean(0)
            extr = np.identity(4, np.float32)
            extr[:3, 3] = -cano_center
            rot_x = np.identity(4, np.float32)
            rot_x[:3, :3] = cv.Rodrigues(np.array([np.pi, 0, 0], np.float32))[0]
            extr = rot_x @ extr
            f_len = 5000
            extr[2, 3] += f_len / 512
            intr = np.array([[f_len, 0, 512], [0, f_len, 512], [0, 0, 1]], np.float32)
            return extr, intr, 1024, 1024
        raise ValueError('Invalid view setting for animation!')

    def _orbit_camera(self, object_center, global_orient, rot_Y, rot_X, img_scale):
        """Camera from ``calc_free_mv`` around ``object_center`` with fixed 1024px / f=1100 intrinsics, both scaled by ``img_scale``."""
        extr = visualize_util.calc_free_mv(object_center,
                                           tar_pos = np.array([0, 0, 2.5]),
                                           rot_Y = rot_Y,
                                           rot_X = rot_X,
                                           global_orient = global_orient)
        intr = np.array([[1100, 0, 512], [0, 1100, 512], [0, 0, 1]], np.float32)
        intr[:2] *= img_scale
        img_h = int(1024 * img_scale)
        img_w = int(1024 * img_scale)
        return extr, intr, img_h, img_w

    @staticmethod
    def _sweep_rot_y(idx, frame_per_cycle, max_degree):
        """Yaw in radians for a back-and-forth sweep over [-max_degree, +max_degree].

        The first half of each cycle ramps -max -> +max and the second half ramps back.
        Negative angles are wrapped by adding 2*pi.
        """
        frame_half_cycle = frame_per_cycle // 2
        step = 2 * max_degree / frame_half_cycle
        if idx % frame_per_cycle < frame_per_cycle / 2:
            rot_Y = -max_degree + step * (idx % frame_half_cycle)
        else:
            rot_Y = max_degree - step * (idx % frame_half_cycle)
        rot_Y = rot_Y * np.pi / 180
        if rot_Y < 0:
            rot_Y = rot_Y + 2 * np.pi
        return rot_Y

    @staticmethod
    def _clip_to_uint8(image):
        """Clamp ``image`` to [0, 1] IN PLACE, then return ``image * 255`` as a uint8 tensor.

        The in-place clamp is relied upon: later mask computations read the same
        (now clipped) tensors out of the render outputs.
        """
        image.clip_(0., 1.)
        return (image * 255).to(torch.uint8)

    @staticmethod
    def _color_mask(rgb_map, color):
        """Bool mask of pixels whose squared RGB distance to ``color`` is below 0.01.

        Assumes ``rgb_map`` is (H, W, 3) with channels last. TODO confirm.
        """
        diff = rgb_map - torch.tensor(color, device = rgb_map.device)
        return (diff * diff).sum(dim = 2) < 0.01

    @staticmethod
    def _non_white_mask(rgb_map):
        """Clamp ``rgb_map`` to [0, 1] in place; return a bool mask of pixels farther than 0.01 (squared) from white."""
        rgb_map.clip_(0., 1.)
        diff = torch.tensor([1., 1., 1.], device = rgb_map.device) - rgb_map
        return (diff * diff).sum(dim = 2) > 0.01

    def _save_hand_masks(self, output_dir, data_idx, mask_output):
        """Write the full-body / hand-only / hand-occlusion / sleeve masks of one frame.

        In the mask renders, red pixels mark the right hand and blue the left (per the
        output folder names).

        Returns:
            ``(body_red_mask, body_blue_mask)``: bool masks of the hand pixels visible in
            the full-body render, when both ``full_body_rgb_map`` and ``hand_only_rgb_map``
            are present; otherwise None.
        """
        if 'full_body_rgb_map' in mask_output:
            full_body_mask = self._clip_to_uint8(mask_output['full_body_rgb_map'])
            cv.imwrite(output_dir + '/full_body_mask/%08d.png' % data_idx, full_body_mask.cpu().numpy())
        if 'hand_only_rgb_map' in mask_output:
            hand_only_mask = self._clip_to_uint8(mask_output['hand_only_rgb_map'])
            cv.imwrite(output_dir + '/hand_only_mask/%08d.png' % data_idx, hand_only_mask.cpu().numpy())
        if not ('full_body_rgb_map' in mask_output and 'hand_only_rgb_map' in mask_output):
            return None

        full_body, hand_only = mask_output['full_body_rgb_map'], mask_output['hand_only_rgb_map']
        body_red_mask = self._color_mask(full_body, (1., 0., 0.))
        hand_red_mask = self._color_mask(hand_only, (1., 0., 0.))
        body_blue_mask = self._color_mask(full_body, (0., 0., 1.))
        hand_blue_mask = self._color_mask(hand_only, (0., 0., 1.))

        # A hand is flagged when its pixel count in the full-body render differs from the
        # hand-only render by more than 95% of the hand-only count.
        if_mask_r_hand = (abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95).cpu().numpy()
        if_mask_l_hand = (abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95).cpu().numpy()

        # Masks of the occluded parts of each hand: in the hand-only render but not visible
        # in the full-body render.
        red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
        blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
        all_mask = red_mask | blue_mask

        for sub_dir in ('hand_mask', 'r_hand_mask', 'l_hand_mask', 'hand_visual'):
            os.makedirs(output_dir + '/' + sub_dir, exist_ok = True)
        cv.imwrite(output_dir + '/hand_mask/%08d.png' % data_idx, (all_mask * 255).to(torch.uint8).cpu().numpy())
        cv.imwrite(output_dir + '/r_hand_mask/%08d.png' % data_idx, (body_red_mask * 255).to(torch.uint8).cpu().numpy())
        cv.imwrite(output_dir + '/l_hand_mask/%08d.png' % data_idx, (body_blue_mask * 255).to(torch.uint8).cpu().numpy())
        with open(output_dir + '/hand_visual/%08d.npy' % data_idx, 'wb') as f:
            np.save(f, [if_mask_r_hand, if_mask_l_hand])

        # Sleeve masks: non-white pixels of each arm render, minus a dilated mask of both
        # hands (visible and occluded parts).
        if 'left_hand_rgb_map' in mask_output and 'right_hand_rgb_map' in mask_output:
            os.makedirs(output_dir + '/left_sleeve_mask', exist_ok = True)
            os.makedirs(output_dir + '/right_sleeve_mask', exist_ok = True)
            hands = (body_red_mask | body_blue_mask | all_mask).cpu().numpy().astype(np.uint8)
            # 5x5 kernel x 3 iterations; enlarge to widen the excluded hand region.
            hands = cv.dilate(hands, np.ones((5, 5), np.uint8), iterations = 3)
            left_rgb, right_rgb = mask_output['left_hand_rgb_map'], mask_output['right_hand_rgb_map']
            hands = torch.from_numpy(hands).to(left_rgb.device).bool()
            left_sleeve = self._non_white_mask(left_rgb) & ~hands
            right_sleeve = self._non_white_mask(right_rgb) & ~hands.to(right_rgb.device)
            cv.imwrite(output_dir + '/left_sleeve_mask/%08d.png' % data_idx, (left_sleeve * 255).to(torch.uint8).cpu().numpy())
            cv.imwrite(output_dir + '/right_sleeve_mask/%08d.png' % data_idx, (right_sleeve * 255).to(torch.uint8).cpu().numpy())

        return body_red_mask, body_blue_mask

    def dump_renderer_info(self, dump_dir, extrs, intrs, img_heights, img_widths):
        """Write ``cfg_args`` and ``cameras.json`` for the given cameras into ``dump_dir``.

        ``cfg_args`` is a Namespace-style string with ``sh_degree=3``, ``source_path`` set to
        the training ``data_dir`` and ``model_path`` set to ``dump_dir``.

        ``cameras.json`` has one entry per camera. Each entry holds the camera-to-world
        position and rotation (the inverse of the world-to-camera ``extr``), the image size,
        and ``fx`` / ``fy`` taken from ``intr``.
        """
        with open(os.path.join(dump_dir, 'cfg_args'), 'w') as fp:
            outstr = "Namespace(sh_degree=%d, source_path='%s', model_path='%s', images='images', resolution=-1, " \
                     "white_background=False, data_device='cuda', eval=False)" % (
                3, self.body['train']['data']['data_dir'], dump_dir)
            fp.write(outstr)

        cam_jsons = []
        for ci, (w2c, intr, img_h, img_w) in enumerate(zip(extrs, intrs, img_heights, img_widths)):
            c2w = np.linalg.inv(w2c)
            cam_jsons.append({
                'id': ci,
                'img_name': '%08d' % ci,
                'width': int(img_w),
                'height': int(img_h),
                'position': c2w[:3, 3].tolist(),
                'rotation': [row.tolist() for row in c2w[:3, :3]],
                'fy': float(intr[1, 1]),
                'fx': float(intr[0, 0]),
            })
        with open(os.path.join(dump_dir, 'cameras.json'), 'w') as fp:
            json.dump(cam_jsons, fp)

    def test_head(self):
        """Run Gaussian-Head-Avatar reenactment with the checkpoints named in the head config.

        If ``gha.offline_rendering_param_fpath`` is None, this runs plain reenactment up to
        frame ``gha.stop_fid`` (default 800). Otherwise it renders for offline stitching
        with that parameter file.
        """
        dataset = ReenactmentDataset(self.head_config.dataset)
        dataloader = DataLoaderX(dataset, batch_size=1, shuffle=False, pin_memory=True)

        device = torch.device('cuda:%d' % self.head_config.gpu_id)

        gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint, map_location=lambda storage, loc: storage)
        gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
                                          xyz=gaussianhead_state_dict['xyz'],
                                          feature=gaussianhead_state_dict['feature'],
                                          landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(device)
        gaussianhead.load_state_dict(gaussianhead_state_dict)

        supres = SuperResolutionModule(self.head_config.supresmodule).to(device)
        supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint, map_location=lambda storage, loc: storage))

        camera = CameraModule()
        recorder = ReenactmentRecorder(self.head_config.recorder)

        app = Reenactment(dataloader, gaussianhead, supres, camera, recorder, self.head_config.gpu_id, dataset.freeview)
        if self.head.offline_rendering_param_fpath is None:
            app.run(stop_fid=self.head.get('stop_fid', 800))
        else:
            app.run_for_offline_stitching(self.head.offline_rendering_param_fpath)

    def cal_cat_param(self):
        """Compute the body/head blending parameters using the paths in the ``cat`` config section."""
        calc_offline_rendering_param(
            self.cat.body_gaussian_root_dir,
            self.cat.ref_head_gaussian_path,
            self.cat.ref_head_param_path,
            self.cat.render_cam_fpath,
            self.cat.body_head_blending_param_path
        )


if __name__ == '__main__':
    conf = OmegaConf.load('configs/example.yaml')
    avatar = Avatar(conf)
    avatar.test_body()
    # avatar.test_head()