import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
import yaml
import shutil
import collections
import torch
import torch.utils.data
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import glob
import datetime
import trimesh
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import importlib
from omegaconf import OmegaConf
import json
import math
import cv2

# AnimatableGaussians part
from AnimatableGaussians.network.lpips import LPIPS
from AnimatableGaussians.dataset.dataset_pose import PoseDataset
import AnimatableGaussians.utils.net_util as net_util
# import AnimatableGaussians.utils.visualize_util as visualize_util
from AnimatableGaussians.utils.camera_dir import get_camera_dir
from AnimatableGaussians.utils.renderer import Renderer
from AnimatableGaussians.utils.net_util import to_cuda
from AnimatableGaussians.utils.obj_io import save_mesh_as_ply
from AnimatableGaussians.gaussians.obj_io import save_gaussians_as_ply
import AnimatableGaussians.config as ag_config

# Gaussian-Head-Avatar part
from GHA.config.config import config_reenactment
from GHA.lib.dataset.Dataset import ReenactmentDataset
from GHA.lib.dataset.DataLoaderX import DataLoaderX
from GHA.lib.module.GaussianHeadModule import GaussianHeadModule
from GHA.lib.module.SuperResolutionModule import SuperResolutionModule
from GHA.lib.module.CameraModule import CameraModule
from GHA.lib.recorder.Recorder import ReenactmentRecorder
from GHA.lib.apps.Reenactment import Reenactment
from GHA.lib.utils.graphics_utils import getWorld2View2, getProjectionMatrix

# cat utils
from calc_offline_rendering_param import calc_offline_rendering_param
from calc_offline_rendering_param import load_camera_data
from render_utils.lib.networks.smpl_torch import SmplTorch
from render_utils.lib.utils.gaussian_np_utils import load_gaussians_from_ply
from render_utils.stitch_body_and_head import load_body_params, load_face_params, get_smpl_verts_and_head_transformation, calc_livehead2livebody
from render_utils.stitch_funcs import soften_blending_mask, paste_back_with_linear_interp


class Avatar:
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # AnimatableGaussians (body) init
        self.body = config.animatablegaussians
        self.body.mode = 'test'
        ag_config.set_opt(self.body)
        avatar_module = self.body['model'].get('module', 'AnimatableGaussians.network.avatar')
        print('Import AvatarNet from %s' % avatar_module)
        AvatarNet = importlib.import_module(avatar_module).AvatarNet
        self.avatar_net = AvatarNet(self.body.model).to(self.device)
        self.random_bg_color = self.body['train'].get('random_bg_color', True)
        self.bg_color = (1., 1., 1.)
        self.bg_color_cuda = torch.from_numpy(np.asarray(self.bg_color)).to(torch.float32).to(self.device)
        self.loss_weight = self.body['train']['loss_weight']
        self.finetune_color = self.body['train']['finetune_color']
        print('# Parameter number of AvatarNet is %d' % sum(p.numel() for p in self.avatar_net.parameters()))

        # Gaussian-Head-Avatar (head) init
        self.head = config.gha

        # cat utils init
        self.cat = config.cat
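    # Expected OmegaConf layout, inferred from the attributes this class
    # accesses (a sketch, not the authoritative schema -- see
    # configs/example.yaml for the real file):
    #
    #   animatablegaussians:     # AnimatableGaussians opt dict
    #     model: {...}
    #     train: {loss_weight: ..., finetune_color: ..., data: {...}}
    #     test:  {prev_ckpt: ..., pose_data: {...}, ...}
    #   gha:
    #     config_path: <path to the GHA reenactment config>
    #   cat:
    #     ref_head_gaussian_path: <reference head Gaussians, .ply>
    #     ref_head_param_path: <reference face params>
    #     body_head_blending_param_path: <blending params, .npz>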
    def build_dataset(self, body_pose_path=None, face_exp_path=None):
        # build the body dataset
        if body_pose_path is not None:
            self.body['test']['pose_data']['data_path'] = body_pose_path
            body_pose = np.load(body_pose_path, allow_pickle=True)
            self.body['test']['pose_data']['frame_range'] = [0, body_pose['poses'].shape[0]]
        dataset_module = self.body.get('dataset', 'MvRgbDatasetAvatarReX')
        MvRgbDataset = importlib.import_module('AnimatableGaussians.dataset.dataset_mv_rgb').__getattribute__(dataset_module)
        self.body_training_dataset = MvRgbDataset(**self.body['train']['data'], training=False)
        if self.body['test'].get('n_pca', -1) >= 1:
            self.body_training_dataset.compute_pca(n_components=self.body['test']['n_pca'])
        if 'pose_data' in self.body.test:
            testing_dataset = PoseDataset(**self.body['test']['pose_data'], smpl_shape=self.body_training_dataset.smpl_data['betas'][0])
            dataset_name = testing_dataset.dataset_name
            seq_name = testing_dataset.seq_name
        else:
            raise ValueError('No pose data in test config')
        self.body_dataset = testing_dataset
        iter_idx = self.load_ckpt(self.body['test']['prev_ckpt'], False)[1]

        # build the face dataset
        self.head_config = config_reenactment()
        self.head_config.load(self.head.config_path)
        if face_exp_path is not None:
            self.head_config.cfg.dataset.exp_path = face_exp_path
        self.head_config.freeze()
        self.head_config = self.head_config.get_cfg()
        self.head_dataset = ReenactmentDataset(self.head_config.dataset)
        self.head_dataloader = DataLoaderX(self.head_dataset, batch_size=1, shuffle=False, pin_memory=True)

        gaussianhead_state_dict = torch.load(self.head_config.load_gaussianhead_checkpoint, map_location=lambda storage, loc: storage)
        self.gaussianhead = GaussianHeadModule(self.head_config.gaussianheadmodule,
                                               xyz=gaussianhead_state_dict['xyz'],
                                               feature=gaussianhead_state_dict['feature'],
                                               landmarks_3d_neutral=gaussianhead_state_dict['landmarks_3d_neutral']).to(self.device)
        self.gaussianhead.load_state_dict(gaussianhead_state_dict)
        self.supres = SuperResolutionModule(self.head_config.supresmodule).to(self.device)
        self.supres.load_state_dict(torch.load(self.head_config.load_supres_checkpoint, map_location=lambda storage, loc: storage))
        self.head_camera = CameraModule()
        self.head_recorder = ReenactmentRecorder(self.head_config.recorder)

    def render_all(self):
        # render as many frames as the shorter of the two datasets provides
        length = min(len(self.body_dataset), len(self.head_dataloader))
        for idx in tqdm(range(length)):
            self.render_frame(idx)

    def render_frame(self, idx):
        # render the body and the various masks
        body_output = self.build_body(idx)
        # compute the rendering parameters for the head
        head_param = self.build_param(idx, body_output)
        # render the head
        head_output = self.build_head(idx, head_param)
        # stitch the head and the body together
        body_rendering = body_output['rgb_map_wo_hand'].astype(np.float32) / 255.0
        # cv.imwrite('./output' + '/body_rgb_%08d.jpg' % idx, (body_output['rgb_map']).astype(np.uint8))
        # cv.imwrite('./output' + '/body_rgb_wo_hand%08d.jpg' % idx, (body_output['rgb_map_wo_hand']).astype(np.uint8))
        body_mask = body_output['mask_map'].astype(np.float32) / 255.0
        body_torso_mask = body_output['torso_map'].astype(np.float32) / 255.0
        head_rendering = head_output['render_images'].astype(np.float32) / 255.0
        head_blending_mask = head_output['render_bw'].astype(np.float32) / 255.0
        body_head_blending_params = np.load(self.cat.body_head_blending_param_path)
        head_offline_rendering_param = head_param
        stitch_output = self.stitch_head_body(body_rendering, body_mask, body_torso_mask,
                                              head_rendering, head_blending_mask,
                                              body_head_blending_params, head_offline_rendering_param)
        cv.imwrite('./output' + '/%08d.jpg' % idx, stitch_output)
        # TODO: render the hands and their masks, then paste the hands back on
        return stitch_output
    def load_ckpt(self, path, load_optm=True):
        print('Loading networks from ', path + '/net.pt')
        net_dict = torch.load(path + '/net.pt')
        if 'avatar_net' in net_dict:
            self.avatar_net.load_state_dict(net_dict['avatar_net'])
        else:
            print('[WARNING] Cannot find "avatar_net" from the network checkpoint!')
        epoch_idx = net_dict['epoch_idx']
        iter_idx = net_dict['iter_idx']
        # if load_optm and os.path.exists(path + '/optm.pt'):
        #     print('Loading optimizers from ', path + '/optm.pt')
        #     optm_dict = torch.load(path + '/optm.pt')
        #     if 'avatar_net' in optm_dict:
        #         self.optm.load_state_dict(optm_dict['avatar_net'])
        #     else:
        #         print('[WARNING] Cannot find "avatar_net" from the optimizer checkpoint!')
        return epoch_idx, iter_idx
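    # The hand/body occlusion logic in build_body() keys its masks off solid
    # colors (red for the right hand, blue for the left): a pixel belongs to
    # the mask when its squared distance to the key color falls below a
    # threshold. A minimal sketch of that test, factored out for illustration
    # only -- this helper is hypothetical and not called by the pipeline:
    @staticmethod
    def _color_key_mask(rgb_map, color, thresh=0.01):
        # rgb_map: (H, W, 3) tensor in [0, 1]; color: e.g. [1., 0., 0.]
        key = torch.tensor(color, device=rgb_map.device)
        diff = rgb_map - key
        return (diff * diff).sum(dim=2) < thresh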
    @torch.no_grad()
    def build_body(self, idx):
        self.avatar_net.eval()

        geo_renderer = None
        item_0 = self.body_dataset.getitem(0, training=False)
        object_center = item_0['live_bounds'].mean(0)
        global_orient = item_0['global_orient'].cpu().numpy() if isinstance(item_0['global_orient'], torch.Tensor) else item_0['global_orient']
        use_pca = self.body['test'].get('n_pca', -1) >= 1
        # zero the x and z components so only the yaw of the root is kept
        global_orient[0] = 0
        global_orient[2] = 0
        global_orient = cv.Rodrigues(global_orient)[0]

        time_start = torch.cuda.Event(enable_timing=True)
        time_start_all = torch.cuda.Event(enable_timing=True)
        time_end = torch.cuda.Event(enable_timing=True)

        if self.body['test'].get('fix_hand', False):
            self.avatar_net.generate_mean_hands()

        img_scale = self.body['test'].get('img_scale', 1.0)
        view_setting = self.body['test'].get('view_setting', 'free')
        extr, intr, img_h, img_w = get_camera_dir(idx, object_center, global_orient, img_scale, view_setting)
        w2c = extr
        c2w = np.linalg.inv(w2c)
        pos = c2w[:3, 3]
        rot = c2w[:3, :3]
        serializable_array_2d = [x.tolist() for x in rot]
        camera_entry = {
            'width': int(img_w),
            'height': int(img_h),
            'position': pos.tolist(),
            'rotation': serializable_array_2d,
            'fy': float(intr[1, 1]),
            'fx': float(intr[0, 0]),
        }

        getitem_func = self.body_dataset.getitem_fast if hasattr(self.body_dataset, 'getitem_fast') else self.body_dataset.getitem
        item = getitem_func(idx, training=False, extr=extr, intr=intr, img_w=img_w, img_h=img_h)
        items = to_cuda(item, add_batch=False)

        if 'smpl_pos_map' not in items:
            self.avatar_net.get_pose_map(items)

        # PCA-project the pose conditions if requested
        if use_pca:
            mask = self.body_training_dataset.pos_map_mask
            live_pos_map = items['smpl_pos_map'].permute(1, 2, 0).cpu().numpy()
            front_live_pos_map, back_live_pos_map = np.split(live_pos_map, [3], 2)
            pose_conds = front_live_pos_map[mask]
            new_pose_conds = self.body_training_dataset.transform_pca(pose_conds, sigma_pca=float(self.body['test'].get('sigma_pca', 2.)))
            front_live_pos_map[mask] = new_pose_conds
            live_pos_map = np.concatenate([front_live_pos_map, back_live_pos_map], 2)
            items.update({
                'smpl_pos_map_pca': torch.from_numpy(live_pos_map).to(self.device).permute(2, 0, 1)
            })

        # render the full body, the body without hands, and the mask passes
        output = self.avatar_net.render(items, bg_color=self.bg_color, use_pca=use_pca)
        output_wo_hand = self.avatar_net.render_wo_hand(items, bg_color=self.bg_color, use_pca=use_pca)
        mask_output = self.avatar_net.render_mask(items, bg_color=self.bg_color, use_pca=use_pca)

        # postprocess
        rgb_map_wo_hand = output_wo_hand['rgb_map']
        full_body_mask = mask_output['full_body_rgb_map']
        full_body_mask.clip_(0., 1.)
        full_body_mask = (full_body_mask * 255).to(torch.uint8)
        hand_only_mask = mask_output['hand_only_rgb_map']
        hand_only_mask.clip_(0., 1.)
        hand_only_mask = (hand_only_mask * 255).to(torch.uint8)

        # build the covered-hand masks and the hand visibility flags;
        # the right hand is keyed in red, the left hand in blue
        body_red_mask = (mask_output['full_body_rgb_map'] - torch.tensor([1., 0., 0.], device=mask_output['full_body_rgb_map'].device))
        body_red_mask = (body_red_mask * body_red_mask).sum(dim=2) < 0.01  # need save
        hand_red_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([1., 0., 0.], device=mask_output['hand_only_rgb_map'].device))
        hand_red_mask = (hand_red_mask * hand_red_mask).sum(dim=2) < 0.01
        # flag the hand as occluded when >95% of its pixels are covered
        if_mask_r_hand = abs(body_red_mask.sum() - hand_red_mask.sum()) / hand_red_mask.sum() > 0.95
        if_mask_r_hand = if_mask_r_hand.cpu().numpy()
        body_blue_mask = (mask_output['full_body_rgb_map'] - torch.tensor([0., 0., 1.], device=mask_output['full_body_rgb_map'].device))
        body_blue_mask = (body_blue_mask * body_blue_mask).sum(dim=2) < 0.01  # need save
        hand_blue_mask = (mask_output['hand_only_rgb_map'] - torch.tensor([0., 0., 1.], device=mask_output['hand_only_rgb_map'].device))
        hand_blue_mask = (hand_blue_mask * hand_blue_mask).sum(dim=2) < 0.01
        if_mask_l_hand = abs(body_blue_mask.sum() - hand_blue_mask.sum()) / hand_blue_mask.sum() > 0.95
        if_mask_l_hand = if_mask_l_hand.cpu().numpy()

        # save the masks of the occluded parts of both hands
        red_mask = hand_red_mask ^ (hand_red_mask & body_red_mask)
        blue_mask = hand_blue_mask ^ (hand_blue_mask & body_blue_mask)
        all_mask = red_mask | blue_mask
        all_mask = (all_mask * 255).to(torch.uint8)
        r_hand_mask = (body_red_mask * 255).to(torch.uint8)
        l_hand_mask = (body_blue_mask * 255).to(torch.uint8)
        hand_visual = [if_mask_r_hand, if_mask_l_hand]

        # build the sleeve masks
        mask = (r_hand_mask > 128) | (l_hand_mask > 128) | (all_mask > 128)
        mask = mask.cpu().numpy().astype(np.uint8)
        # structuring element; its size controls how much the mask is dilated
        kernel = np.ones((5, 5), np.uint8)
        mask = cv.dilate(mask, kernel, iterations=3)
        mask = torch.from_numpy(mask).bool().to(self.device)

        left_hand_mask = mask_output['left_hand_rgb_map']
        left_hand_mask.clip_(0., 1.)
        # the non-white part is the mask
        left_hand_mask = (torch.tensor([1., 1., 1.], device=left_hand_mask.device) - left_hand_mask)
        left_hand_mask = (left_hand_mask * left_hand_mask).sum(dim=2) > 0.01
        # subtract the two hand masks
        left_hand_mask = left_hand_mask & ~mask
        right_hand_mask = mask_output['right_hand_rgb_map']
        right_hand_mask.clip_(0., 1.)
        right_hand_mask = (torch.tensor([1., 1., 1.], device=right_hand_mask.device) - right_hand_mask)
        right_hand_mask = (right_hand_mask * right_hand_mask).sum(dim=2) > 0.01
        right_hand_mask = right_hand_mask & ~mask
        left_sleeve_mask = (left_hand_mask * 255).to(torch.uint8)
        right_sleeve_mask = (right_hand_mask * 255).to(torch.uint8)

        # use r_hand_mask and l_hand_mask to overwrite the masked regions of
        # rgb_map with the hand-free rendering
        rgb_map = output['rgb_map']
        rgb_map.clip_(0., 1.)
        rgb_map = (rgb_map * 255).to(torch.uint8).cpu().numpy()
        rgb_map_wo_hand = output_wo_hand['rgb_map']
        rgb_map_wo_hand.clip_(0., 1.)
        rgb_map_wo_hand = (rgb_map_wo_hand * 255).to(torch.uint8).cpu().numpy()
        r_mask = (r_hand_mask > 128).cpu().numpy()
        l_mask = (l_hand_mask > 128).cpu().numpy()
        mask = r_mask | l_mask
        mask = mask.astype(np.uint8)
        # structuring element; its size controls how much the mask is dilated
        kernel = np.ones((5, 5), np.uint8)
        mask = cv.dilate(mask, kernel, iterations=3)
        mask = mask.astype(np.bool_)
        mask = np.expand_dims(mask, axis=2)
        # final rgb map with the hands removed
        mix = rgb_map_wo_hand.copy() * mask + rgb_map * ~mask

        torso_map = output['torso_map'][:, :, 0]
        torso_map.clip_(0., 1.)
        torso_map = (torso_map * 255).to(torch.uint8).cpu().numpy()
        mask_map = output['mask_map'][:, :, 0]
        mask_map.clip_(0., 1.)
        mask_map = (mask_map * 255).to(torch.uint8).cpu().numpy()

        output = {
            # smpl
            'betas': self.body_training_dataset.smpl_data['betas'].reshape([-1]).detach().cpu().numpy(),
            'global_orient': item['global_orient'].reshape([-1]).detach().cpu().numpy(),
            'transl': item['transl'].reshape([-1]).detach().cpu().numpy(),
            'body_pose': item['body_pose'].reshape([-1]).detach().cpu().numpy(),
            # camera
            'extr': extr,
            'intr': intr,
            'img_h': img_h,
            'img_w': img_w,
            'camera_entry': camera_entry,
            # rgb and masks
            'rgb_map': rgb_map,
            'rgb_map_wo_hand': mix,
            'torso_map': torso_map,
            'mask_map': mask_map,
            'all_mask': all_mask,
            'left_sleeve_mask': left_sleeve_mask,
            'right_sleeve_mask': right_sleeve_mask,
            'hand_visual': hand_visual,
        }
        return output
    def build_param(self, idx, body_output):
        head_gaussians = load_gaussians_from_ply(self.cat.ref_head_gaussian_path)
        head_pose, head_scale, id_coeff, exp_coeff = load_face_params(self.cat.ref_head_param_path)
        body_head_blending_params = np.load(self.cat.body_head_blending_param_path)
        smplx_to_faceverse = body_head_blending_params['smplx_to_faceverse']
        residual_transf = body_head_blending_params['residual_transf']
        head_color_bw = body_head_blending_params['head_color_bw']

        smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
        global_orient, transl, body_pose, betas = body_output['global_orient'], body_output['transl'], body_output['body_pose'], body_output['betas']
        smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(smpl, global_orient, body_pose, transl, betas)
        livehead2livebody = calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat)
        total_transf = np.matmul(livehead2livebody, residual_transf)

        cam, image_size = load_camera_data(body_output['camera_entry'])
        cam_extr = np.matmul(cam[0], total_transf)
        cam_intr = np.copy(cam[1])

        # project the head Gaussians to find a square crop around the head
        pts = np.copy(head_gaussians.xyz)
        pts_proj = np.matmul(pts, cam_extr[:3, :3].transpose()) + cam_extr[:3, 3]
        pts_proj = np.matmul(pts_proj, cam_intr.transpose())
        pts_proj = pts_proj / pts_proj[:, 2:]
        pts_min, pts_max = np.min(pts_proj, axis=0), np.max(pts_proj, axis=0)
        pts_center = (pts_min + pts_max) // 2
        pts_size = np.max(pts_max - pts_min)

        # zoom the intrinsics so the head spans ~350 px in a 512 px image
        tgt_pts_size = 350
        tgt_image_size = 512
        zoom_scale = tgt_pts_size / pts_size
        cam_intr_zoom = np.copy(cam_intr)
        cam_intr_zoom[:2] *= zoom_scale
        cam_intr_zoom[0, 2] = cam_intr_zoom[0, 2] - (pts_center[0] * zoom_scale - tgt_image_size / 2)
        cam_intr_zoom[1, 2] = cam_intr_zoom[1, 2] - (pts_center[1] * zoom_scale - tgt_image_size / 2)

        output = {
            'cam_extr': cam_extr,
            'cam_intr': cam_intr,
            'image_size': image_size,
            'cam_intr_zoom': cam_intr_zoom,
            'zoom_image_size': [tgt_image_size, tgt_image_size],
            'zoom_center': pts_center,
            'zoom_scale': zoom_scale,
            'head_pose': head_pose,
            'head_scale': head_scale,
            'head_color_bw': head_color_bw,
        }
        return output
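    # build_param() zooms onto the head purely by rescaling the intrinsics:
    # scale the focal lengths and principal point, then shift the principal
    # point so the crop center lands at the center of the target image. A
    # minimal standalone sketch of that step (illustrative only, not called
    # by the pipeline):
    @staticmethod
    def _zoom_intrinsics(cam_intr, crop_center, zoom_scale, tgt_image_size):
        K = np.copy(cam_intr)
        K[:2] *= zoom_scale
        K[0, 2] -= crop_center[0] * zoom_scale - tgt_image_size / 2
        K[1, 2] -= crop_center[1] * zoom_scale - tgt_image_size / 2
        return K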
    def build_head(self, idx, head_offline_rendering_param):
        cam_extr = head_offline_rendering_param['cam_extr']
        cam_intr = head_offline_rendering_param['cam_intr']
        cam_intr_zoom = head_offline_rendering_param['cam_intr_zoom']
        zoom_image_size = head_offline_rendering_param['zoom_image_size']
        head_pose = head_offline_rendering_param['head_pose']
        head_scale = head_offline_rendering_param['head_scale']
        head_color_bw = head_offline_rendering_param['head_color_bw']
        zoom_scale = head_offline_rendering_param['zoom_scale']
        head_pose = torch.from_numpy(head_pose.astype(np.float32)).to(self.device)
        head_color_bw = torch.from_numpy(head_color_bw.astype(np.float32)).to(self.device)

        render_size = 512
        data = self.head_dataset[idx]
        # add the batch dimension
        data = {k: v.unsqueeze(0) for k, v in data.items() if isinstance(v, torch.Tensor)}

        new_gs_camera_param_dict = self.prepare_camera_data_for_gs_rendering(cam_extr, cam_intr_zoom, render_size, render_size)
        for k in new_gs_camera_param_dict.keys():
            if isinstance(new_gs_camera_param_dict[k], torch.Tensor):
                new_gs_camera_param_dict[k] = new_gs_camera_param_dict[k].unsqueeze(0).to(self.device)
        new_gs_camera_param_dict['pose'] = head_pose.unsqueeze(0).to(self.device)

        keys_to_cuda = ['images', 'intrinsics', 'extrinsics', 'world_view_transform',
                        'projection_matrix', 'full_proj_transform', 'camera_center',
                        'pose', 'scale', 'exp_coeff', 'pose_code']
        for data_item in keys_to_cuda:
            data[data_item] = data[data_item].to(device=self.device)
        data.update(new_gs_camera_param_dict)

        with torch.no_grad():
            data = self.gaussianhead.generate(data)
            data = self.head_camera.render_gaussian(data, 512)
            render_images = data['render_images']
            supres_images = self.supres(render_images)
            data['supres_images'] = supres_images

            # render the blending weights: temporarily swap the Gaussian
            # colors for the per-Gaussian blend weights, render, then restore
            data['bg_color'] = torch.zeros([1, 32], device=self.device, dtype=torch.float32)
            data['color_bk'] = data.pop('color')
            data['color'] = torch.ones_like(data['color_bk']) * head_color_bw.reshape([1, -1, 1]) * 2.0
            data['color'][:, :, 1] = 1
            data['color'] = torch.clamp(data['color'], 0., 1.)
            data = self.head_camera.render_gaussian(data, render_size)
            render_bw = data['render_images'][:, :3, :, :]
            data['color'] = data.pop('color_bk')
            data['render_bw'] = render_bw

        supres_image = data['supres_images'][0].permute(1, 2, 0).detach().cpu().numpy()
        supres_image = (supres_image * 255).astype(np.uint8)[:, :, ::-1]
        render_bw = data['render_bw'][0].permute(1, 2, 0).detach().cpu().numpy()
        render_bw = np.clip(render_bw * 255, 0, 255).astype(np.uint8)[:, :, ::-1]
        # cv2.resize expects dsize as (width, height)
        render_bw = cv2.resize(render_bw, (supres_image.shape[1], supres_image.shape[0]))

        output = {
            'render_images': supres_image,
            'render_bw': render_bw,
        }
        return output

    def prepare_camera_data_for_gs_rendering(self, extrinsic, intrinsic, original_resolution, new_resolution):
        extrinsic = np.copy(extrinsic)
        intrinsic = np.copy(intrinsic)
        new_intrinsic = np.copy(intrinsic)
        new_intrinsic[:2] *= new_resolution / original_resolution

        # normalize the intrinsics to NDC: f -> 2f / res, c -> 2c / res - 1
        intrinsic[0, 0] = intrinsic[0, 0] * 2 / original_resolution
        intrinsic[0, 2] = intrinsic[0, 2] * 2 / original_resolution - 1
        intrinsic[1, 1] = intrinsic[1, 1] * 2 / original_resolution
        intrinsic[1, 2] = intrinsic[1, 2] * 2 / original_resolution - 1
        fovx = 2 * math.atan(1 / intrinsic[0, 0])
        fovy = 2 * math.atan(1 / intrinsic[1, 1])

        world_view_transform = torch.tensor(getWorld2View2(extrinsic[:3, :3].transpose(), extrinsic[:3, 3])).transpose(0, 1)
        projection_matrix = getProjectionMatrix(znear=0.01, zfar=100, fovX=None, fovY=None,
                                                K=new_intrinsic, img_h=new_resolution, img_w=new_resolution).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]

        c2w = np.linalg.inv(extrinsic)
        viewdir = np.matmul(c2w[:3, :3], np.array([0, 0, -1], np.float32).reshape([3, 1])).reshape([-1])
        viewdir = torch.from_numpy(viewdir.astype(np.float32))

        return {
            'extrinsics': torch.from_numpy(extrinsic.astype(np.float32)),
            'intrinsics': torch.from_numpy(intrinsic.astype(np.float32)),
            'viewdir': viewdir,
            'fovx': torch.Tensor([fovx]),
            'fovy': torch.Tensor([fovy]),
            'world_view_transform': world_view_transform,
            'projection_matrix': projection_matrix,
            'full_proj_transform': full_proj_transform,
            'camera_center': camera_center,
        }
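    # Note on conventions: the transposes above suggest getWorld2View2 /
    # getProjectionMatrix follow the 3DGS-style renderer, where matrices are
    # stored transposed and points multiply on the left as row vectors. A
    # minimal sketch of how a world point would map through
    # full_proj_transform under that assumption (illustrative only, not
    # called by the pipeline):
    @staticmethod
    def _project_point_rowvec(p_world, full_proj_transform):
        # p_world: (..., 3) tensor; returns clip-space xyz after the
        # perspective divide (w clamped to avoid division by zero).
        ones = torch.ones_like(p_world[..., :1])
        p_hom = torch.cat([p_world, ones], dim=-1)
        p_clip = p_hom @ full_proj_transform
        return p_clip[..., :3] / p_clip[..., 3:].clamp(min=1e-8)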
    def stitch_head_body(self, body_rendering, body_mask, body_torso_mask, head_rendering,
                         head_blending_mask, body_head_blending_params, head_offline_rendering_param):
        color_transfer = body_head_blending_params['color_transfer']
        zoom_image_size = head_offline_rendering_param['zoom_image_size']
        zoom_center = head_offline_rendering_param['zoom_center']
        zoom_scale = head_offline_rendering_param['zoom_scale']

        if len(body_mask.shape) == 3:
            body_mask = body_mask[:, :, 0]
        if len(body_torso_mask.shape) == 3:
            body_torso_mask = body_torso_mask[:, :, 0]

        head_rendering = cv2.resize(head_rendering, (int(zoom_image_size[0]), int(zoom_image_size[1])))
        head_blending_mask = cv2.resize(head_blending_mask, (int(zoom_image_size[0]), int(zoom_image_size[1])))
        head_mask = head_blending_mask[:, :, 1]
        head_blending_mask = head_blending_mask[:, :, 0]
        head_blending_mask = soften_blending_mask(head_blending_mask, head_mask)

        # paste the zoomed head rendering back into the body image space
        pasteback_center = zoom_center
        pasteback_scale = zoom_scale
        head_rendering_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_rendering, [body_rendering.shape[1], body_rendering.shape[0]])
        head_blending_mask_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_blending_mask, [body_rendering.shape[1], body_rendering.shape[0]])
        head_mask_back = paste_back_with_linear_interp(pasteback_scale, pasteback_center, head_mask, [body_rendering.shape[1], body_rendering.shape[0]])
        # head_blending_mask_back *= body_mask
        # head_mask_back *= body_mask
        head_blending_mask_back = head_blending_mask_back * (1 - body_torso_mask)

        # color-transfer the head to match the body, then alpha-blend
        head_rendering_back_shape = head_rendering_back.shape
        head_rendering_back = np.matmul(head_rendering_back.reshape(-1, 3), color_transfer[:3, :3].transpose()) + color_transfer[:3, 3][None]
        head_rendering_back = head_rendering_back.reshape(head_rendering_back_shape)
        head_rendering_back = head_rendering_back * head_mask_back[:, :, None] + (1 - head_mask_back[:, :, None])
        body_rendering = body_rendering * (1 - head_blending_mask_back[:, :, None]) + head_rendering_back * head_blending_mask_back[:, :, None]
        return np.uint8(np.clip(body_rendering, 0, 1) * 255)

    # def build_hand(self, betas, poses, camera):
    #     # build the hand rendering here
    #     output = {
    #         'hand_render': render,
    #         'hand_mask': mask,
    #     }
    #     return output


if __name__ == '__main__':
    conf = OmegaConf.load('configs/example.yaml')
    avatar = Avatar(conf)
    avatar.build_dataset()
    # avatar.test_body()
    avatar.render_all()
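# Example usage with new driving data (the paths below are hypothetical; the
# pose file must be an .npz containing a 'poses' array, as read in
# build_dataset, and the expression path is forwarded to the GHA reenactment
# dataset config):
#
#   conf = OmegaConf.load('configs/example.yaml')
#   avatar = Avatar(conf)
#   avatar.build_dataset(body_pose_path='driving/poses.npz',
#                        face_exp_path='driving/exp_params')
#   avatar.render_all()   # writes stitched frames to ./output/%08d.jpg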