import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch3d.transforms
import cv2 as cv

import AnimatableGaussians.config as config
from AnimatableGaussians.network.styleunet.dual_styleunet import DualStyleUNet
from AnimatableGaussians.gaussians.gaussian_model import GaussianModel
from AnimatableGaussians.gaussians.gaussian_renderer import render3


class AvatarNet(nn.Module):
    def __init__(self, opt):
        super(AvatarNet, self).__init__()
        self.opt = opt

        self.random_style = opt.get('random_style', False)
        self.with_viewdirs = opt.get('with_viewdirs', True)

        # Init the canonical Gaussian model from the canonical SMPL position map.
        self.max_sh_degree = 0
        self.cano_gaussian_model = GaussianModel(sh_degree = self.max_sh_degree)
        cano_smpl_map = cv.imread(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/cano_smpl_pos_map.exr', cv.IMREAD_UNCHANGED)
        self.cano_smpl_map = torch.from_numpy(cano_smpl_map).to(torch.float32).to(config.device)
        # Valid texels are those with a nonzero position.
        self.cano_smpl_mask = torch.linalg.norm(self.cano_smpl_map, dim = -1) > 0.
        self.init_points = self.cano_smpl_map[self.cano_smpl_mask]
        # Per-point linear blend skinning (LBS) weights, shape (N, J).
        self.lbs = torch.from_numpy(np.load(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/init_pts_lbs.npy')).to(torch.float32).to(config.device)
        self.cano_gaussian_model.create_from_pcd(self.init_points, torch.rand_like(self.init_points), spatial_lr_scale = 2.5)

        # Pose-conditioned StyleUNets predicting colors, position offsets, and the
        # remaining Gaussian attributes (opacity + scale + rotation = 8 channels).
        self.color_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 3, out_size = 1024, style_dim = 512, n_mlp = 2)
        self.position_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 3, out_size = 1024, style_dim = 512, n_mlp = 2)
        self.other_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 8, out_size = 1024, style_dim = 512, n_mlp = 2)

        # Constant unit-norm style vectors.
        self.color_style = torch.ones([1, self.color_net.style_dim], dtype = torch.float32, device = config.device) / np.sqrt(self.color_net.style_dim)
        self.position_style = torch.ones([1, self.position_net.style_dim], dtype = torch.float32, device = config.device) / np.sqrt(self.position_net.style_dim)
        self.other_style = torch.ones([1, self.other_net.style_dim], dtype = torch.float32, device = config.device) / np.sqrt(self.other_net.style_dim)

        if self.with_viewdirs:
            # Canonical normal map, used to compute view-direction features.
            cano_nml_map = cv.imread(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/cano_smpl_nml_map.exr', cv.IMREAD_UNCHANGED)
            self.cano_nml_map = torch.from_numpy(cano_nml_map).to(torch.float32).to(config.device)
            self.cano_nmls = self.cano_nml_map[self.cano_smpl_mask]
            self.viewdir_net = nn.Sequential(
                nn.Conv2d(1, 64, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace = True),
                nn.Conv2d(64, 128, 4, 2, 1)
            )

    def generate_mean_hands(self):
        # Build a hand mask from the LBS weights: points dominated by the wrist
        # joints (20, 21) or the finger joints (>= 25) in the SMPL-X joint ordering.
        lbs_argmax = self.lbs.argmax(1)
        self.hand_mask = lbs_argmax == 20
        self.hand_mask = torch.logical_or(self.hand_mask, lbs_argmax == 21)
        self.hand_mask = torch.logical_or(self.hand_mask, lbs_argmax >= 25)

        # Load the pose map of the frame selected by test.fix_hand_id and stack
        # its front/back halves into a 6-channel map; the networks consume the
        # first 3 channels.
        pose_map_paths = sorted(glob.glob(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/%08d.exr' % config.opt['test']['fix_hand_id']))
        smpl_pos_map = cv.imread(pose_map_paths[0], cv.IMREAD_UNCHANGED)
        pos_map_size = smpl_pos_map.shape[1] // 2
        smpl_pos_map = np.concatenate([smpl_pos_map[:, :pos_map_size], smpl_pos_map[:, pos_map_size:]], 2)
        smpl_pos_map = smpl_pos_map.transpose((2, 0, 1))
        pose_map = torch.from_numpy(smpl_pos_map).to(torch.float32).to(config.device)
        pose_map = pose_map[:3]

        # Cache the Gaussian attributes predicted for this fixed hand pose.
        cano_pts = self.get_positions(pose_map)
        opacity, scales, rotations = self.get_others(pose_map)
        colors, color_map = self.get_colors(pose_map)
        self.hand_positions = cano_pts  # [self.hand_mask]
        self.hand_opacity = opacity  # [self.hand_mask]
        self.hand_scales = scales  # [self.hand_mask]
        self.hand_rotations = rotations  # [self.hand_mask]
        self.hand_colors = colors  # [self.hand_mask]

        # # debug
        # hand_pts = trimesh.PointCloud(self.hand_positions.detach().cpu().numpy())
        # hand_pts.export('./debug/hand_template.obj')
        # exit(1)
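
    # transform_cano2live applies linear blend skinning (LBS) to the Gaussians.
    # Shapes, read off the einsums below: self.lbs is (N, J) and
    # items['cano2live_jnt_mats'] is (J, 4, 4), so
    #     pt_mats = torch.einsum('nj,jxy->nxy', lbs, jnt_mats)   # (N, 4, 4)
    # is equivalent to pt_mats[i] = sum_j lbs[i, j] * jnt_mats[j]: a per-point
    # transform blended from the joint transforms. Its rotation part is also
    # composed onto each Gaussian's orientation quaternion.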
    def transform_cano2live(self, gaussian_vals, items):
        pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats'])
        gaussian_vals['positions'] = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], gaussian_vals['positions']) + pt_mats[..., :3, 3]
        rot_mats = pytorch3d.transforms.quaternion_to_matrix(gaussian_vals['rotations'])
        rot_mats = torch.einsum('nxy,nyz->nxz', pt_mats[..., :3, :3], rot_mats)
        gaussian_vals['rotations'] = pytorch3d.transforms.matrix_to_quaternion(rot_mats)
        return gaussian_vals

    def get_positions(self, pose_map, return_map = False):
        position_map, _ = self.position_net([self.position_style], pose_map[None], randomize_noise = False)
        front_position_map, back_position_map = torch.split(position_map, [3, 3], 1)
        position_map = torch.cat([front_position_map, back_position_map], 3)[0].permute(1, 2, 0)
        # Predicted offsets are scaled down and added to the canonical positions.
        delta_position = 0.05 * position_map[self.cano_smpl_mask]
        # delta_position = position_map[self.cano_smpl_mask]
        positions = delta_position + self.cano_gaussian_model.get_xyz
        if return_map:
            return positions, position_map
        else:
            return positions

    def get_others(self, pose_map):
        other_map, _ = self.other_net([self.other_style], pose_map[None], randomize_noise = False)
        front_map, back_map = torch.split(other_map, [8, 8], 1)
        other_map = torch.cat([front_map, back_map], 3)[0].permute(1, 2, 0)
        others = other_map[self.cano_smpl_mask]  # (N, 8)
        opacity, scales, rotations = torch.split(others, [1, 3, 4], 1)
        # Predicted residuals are added to the canonical Gaussians' raw
        # attributes before the respective activations.
        opacity = self.cano_gaussian_model.opacity_activation(opacity + self.cano_gaussian_model.get_opacity_raw)
        scales = self.cano_gaussian_model.scaling_activation(scales + self.cano_gaussian_model.get_scaling_raw)
        rotations = self.cano_gaussian_model.rotation_activation(rotations + self.cano_gaussian_model.get_rotation_raw)
        return opacity, scales, rotations

    def get_colors(self, pose_map, front_viewdirs = None, back_viewdirs = None):
        color_style = torch.rand_like(self.color_style) if self.random_style and self.training else self.color_style
        color_map, _ = self.color_net([color_style], pose_map[None], randomize_noise = False, view_feature1 = front_viewdirs, view_feature2 = back_viewdirs)
        front_color_map, back_color_map = torch.split(color_map, [3, 3], 1)
        color_map = torch.cat([front_color_map, back_color_map], 3)[0].permute(1, 2, 0)
        colors = color_map[self.cano_smpl_mask]
        return colors, color_map
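
    # get_viewdir_feat conditions the color net on the viewing direction. It
    # skins the canonical points and normals into the live frame, recovers the
    # camera center from the extrinsics (cam_pos = -R^{-1} t), takes the cosine
    # between each normal and its view direction, splats that scalar into the
    # position-map layout at half resolution, and encodes the front/back halves
    # with viewdir_net.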
    def get_viewdir_feat(self, items):
        with torch.no_grad():
            # Skin the canonical points and normals into the live frame.
            pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats'])
            live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.init_points) + pt_mats[..., :3, 3]
            live_nmls = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.cano_nmls)
            cam_pos = -torch.matmul(torch.linalg.inv(items['extr'][:3, :3]), items['extr'][:3, 3])
            viewdirs = F.normalize(cam_pos[None] - live_pts, dim = -1, eps = 1e-3)
            if self.training:
                # Perturb the view directions during training for robustness.
                viewdirs += torch.randn(*viewdirs.shape).to(viewdirs) * 0.1
                viewdirs = F.normalize(viewdirs, dim = -1, eps = 1e-3)
            viewdirs = (live_nmls * viewdirs).sum(-1)
            viewdirs_map = torch.zeros(*self.cano_nml_map.shape[:2]).to(viewdirs)
            viewdirs_map[self.cano_smpl_mask] = viewdirs
            viewdirs_map = viewdirs_map[None, None]
            viewdirs_map = F.interpolate(viewdirs_map, scale_factor = 0.5, mode = 'nearest')
            front_viewdirs, back_viewdirs = torch.split(viewdirs_map, [512, 512], -1)
        front_viewdirs = self.opt.get('weight_viewdirs', 1.) * self.viewdir_net(front_viewdirs)
        back_viewdirs = self.opt.get('weight_viewdirs', 1.) * self.viewdir_net(back_viewdirs)
        return front_viewdirs, back_viewdirs

    def get_pose_map(self, items):
        # Skin the canonical points with the root-less joint transforms and
        # rasterize them back into the (front | back) position-map layout.
        pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats_woRoot'])
        live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.init_points) + pt_mats[..., :3, 3]
        live_pos_map = torch.zeros_like(self.cano_smpl_map)
        live_pos_map[self.cano_smpl_mask] = live_pts
        live_pos_map = F.interpolate(live_pos_map.permute(2, 0, 1)[None], scale_factor = [0.5, 0.5], mode = 'nearest')[0]
        live_pos_map = torch.cat(torch.split(live_pos_map, [512, 512], 2), 0)
        items.update({
            'smpl_pos_map': live_pos_map
        })
        return live_pos_map

    def render(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False):
        """
        Note that there is no batch index in items.
        """
        bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device)
        pose_map = items['smpl_pos_map'][:3]
        assert not (use_pca and use_vae), "Cannot use both PCA and VAE!"
        if use_pca:
            pose_map = items['smpl_pos_map_pca'][:3]
        if use_vae:
            pose_map = items['smpl_pos_map_vae'][:3]

        cano_pts, pos_map = self.get_positions(pose_map, return_map = True)
        opacity, scales, rotations = self.get_others(pose_map)
        # if not self.training:
        #     scales = torch.clip(scales, 0., 0.03)
        if self.with_viewdirs:
            front_viewdirs, back_viewdirs = self.get_viewdir_feat(items)
        else:
            front_viewdirs, back_viewdirs = None, None
        colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs)

        if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test':
            # Replace the predicted hands with the cached mean hands
            # (see generate_mean_hands).
            import AnimatableGaussians.utils.geo_util as geo_util
            cano_xyz = self.init_points
            wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0))
            wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0))
            wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
            wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
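            # wl/wr are per-point blend weights that ramp up inside the
            # left/right MANO hand bounding boxes (a sigmoid over the
            # bbox-normalized x coordinate) and are zeroed below the body
            # center so only the hand regions are affected. Dividing by
            # max(wl + wr, 1) keeps w = wl + wr in [0, 1], so the blends
            # below stay convex.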
            s = torch.maximum(wl + wr, torch.ones_like(wl))
            wl, wr = wl / s, wr / s
            w = wl + wr
            cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts
            # new_opacity = torch.zeros_like(opacity)
            opacity = w * self.hand_opacity + (1.0 - w) * opacity
            # opacity = w * self.hand_opacity * 0 + (1.0 - w) * opacity
            # opacity = opacity * 0
            scales = w * self.hand_scales + (1.0 - w) * scales
            rotations = w * self.hand_rotations + (1.0 - w) * rotations
            # colors = w * self.hand_colors + (1.0 - w) * colors
            # new_hand_colors = torch.ones_like(colors) * 0.5
            # colors = w * new_hand_colors + (1.0 - w) * colors

        gaussian_vals = {
            'positions': cano_pts,
            'opacity': opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': colors,
            'max_sh_degree': self.max_sh_degree
        }

        nonrigid_offset = gaussian_vals['positions'] - self.init_points
        gaussian_vals = self.transform_cano2live(gaussian_vals, items)

        render_ret = render3(
            gaussian_vals,
            bg_color,
            items['extr'],
            items['intr'],
            items['img_w'],
            items['img_h']
        )
        rgb_map = render_ret['render'].permute(1, 2, 0)
        mask_map = render_ret['mask'].permute(1, 2, 0)

        # Render a torso weight map by temporarily swapping the Gaussian colors
        # for per-point torso weights: points influenced by joints 12, 15 and
        # 22-24 (neck, head, jaw and eyes in the SMPL-X ordering) are down-weighted.
        torso_flag = 1 - (self.lbs[:, 12] + self.lbs[:, 15] + self.lbs[:, 22] + self.lbs[:, 23] + self.lbs[:, 24])
        torso_weight = torch.stack([torso_flag, torso_flag, torso_flag], dim = -1)
        orig_color, gaussian_vals['colors'] = gaussian_vals['colors'], torso_weight
        render_ret = render3(
            gaussian_vals,
            torch.zeros_like(bg_color),
            items['extr'],
            items['intr'],
            items['img_w'],
            items['img_h']
        )
        torso_map = render_ret['render'].permute(1, 2, 0)
        gaussian_vals['colors'] = orig_color

        ret = {
            'rgb_map': rgb_map,
            'torso_map': torso_map,
            'mask_map': mask_map,
            'offset': nonrigid_offset,
            'pos_map': pos_map
        }

        if not self.training:
            ret.update({
                'cano_tex_map': color_map,
                'posed_gaussians': gaussian_vals
            })

        return ret

    def render_wo_hand(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False):
        """
        Same as render(), but zeroes the hand opacity so the hands are not
        rendered. Note that there is no batch index in items.
        """
        bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device)
        pose_map = items['smpl_pos_map'][:3]
        assert not (use_pca and use_vae), "Cannot use both PCA and VAE!"
        if use_pca:
            pose_map = items['smpl_pos_map_pca'][:3]
        if use_vae:
            pose_map = items['smpl_pos_map_vae'][:3]

        cano_pts, pos_map = self.get_positions(pose_map, return_map = True)
        opacity, scales, rotations = self.get_others(pose_map)
        # if not self.training:
        #     scales = torch.clip(scales, 0., 0.03)
        if self.with_viewdirs:
            front_viewdirs, back_viewdirs = self.get_viewdir_feat(items)
        else:
            front_viewdirs, back_viewdirs = None, None
        colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs)

        if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test':
            import AnimatableGaussians.utils.geo_util as geo_util
            cano_xyz = self.init_points
            wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0))
            wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0))
            wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
            wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
            s = torch.maximum(wl + wr, torch.ones_like(wl))
            wl, wr = wl / s, wr / s
            w = wl + wr
            cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts
            # Zero the hand opacity so the hand Gaussians are invisible.
            # opacity = w * self.hand_opacity + (1.0 - w) * opacity
            opacity = w * self.hand_opacity * 0 + (1.0 - w) * opacity
            # opacity = opacity * 0
            scales = w * self.hand_scales + (1.0 - w) * scales
            rotations = w * self.hand_rotations + (1.0 - w) * rotations
            # colors = w * self.hand_colors + (1.0 - w) * colors
            # new_hand_colors = torch.ones_like(colors) * 0.5
            # colors = w * new_hand_colors + (1.0 - w) * colors

        gaussian_vals = {
            'positions': cano_pts,
            'opacity': opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': colors,
            'max_sh_degree': self.max_sh_degree
        }

        nonrigid_offset = gaussian_vals['positions'] - self.init_points
        gaussian_vals = self.transform_cano2live(gaussian_vals, items)

        render_ret = render3(
            gaussian_vals,
            bg_color,
            items['extr'],
            items['intr'],
            items['img_w'],
            items['img_h']
        )
        rgb_map = render_ret['render'].permute(1, 2, 0)

        ret = {
            'rgb_map': rgb_map,
        }
        return ret

    def render_mask(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False):
        """
        Render color-coded segmentation maps for the full body and the hands.
        Note that there is no batch index in items.
        """
        bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device)
        pose_map = items['smpl_pos_map'][:3]
        assert not (use_pca and use_vae), "Cannot use both PCA and VAE!"
        if use_pca:
            pose_map = items['smpl_pos_map_pca'][:3]
        if use_vae:
            pose_map = items['smpl_pos_map_vae'][:3]

        cano_pts, pos_map = self.get_positions(pose_map, return_map = True)
        opacity, scales, rotations = self.get_others(pose_map)
        # if not self.training:
        #     scales = torch.clip(scales, 0., 0.03)
        if self.with_viewdirs:
            front_viewdirs, back_viewdirs = self.get_viewdir_feat(items)
        else:
            front_viewdirs, back_viewdirs = None, None
        colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs)

        if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test':
            import AnimatableGaussians.utils.geo_util as geo_util
            cano_xyz = self.init_points
            wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0))
            wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0))
            wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
            wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0.
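            # From the same blend weights, build several opacity sets: the full
            # body with color-coded regions, the hands only, and each hand in
            # isolation (right hand red, left hand blue, body green).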
            s = torch.maximum(wl + wr, torch.ones_like(wl))
            wl, wr = wl / s, wr / s
            w = wl + wr
            cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts
            # opacity = w * self.hand_opacity + (1.0 - w) * opacity
            body_opacity = torch.zeros_like(opacity)
            no_body_opacity = w * self.hand_opacity * 0 + (1.0 - w) * body_opacity  # all zeros (unused)
            only_hand_opacity = w * self.hand_opacity + (1.0 - w) * body_opacity
            left_hand_opacity = wl * self.hand_opacity + (1.0 - wl) * body_opacity
            right_hand_opacity = wr * self.hand_opacity + (1.0 - wr) * body_opacity
            opacity = w * self.hand_opacity + (1.0 - w) * opacity
            # opacity = opacity * 0
            scales = w * self.hand_scales + (1.0 - w) * scales
            rotations = w * self.hand_rotations + (1.0 - w) * rotations
            # colors = w * self.hand_colors + (1.0 - w) * colors
            r_hand_color = torch.ones_like(colors) * torch.tensor([1., 0., 0.]).to(config.device)
            l_hand_color = torch.ones_like(colors) * torch.tensor([0., 0., 1.]).to(config.device)
            body_color = torch.ones_like(colors) * torch.tensor([0., 1., 0.]).to(config.device)
            full_colors = wr * r_hand_color + wl * l_hand_color + (1.0 - w) * body_color

        # The opacities and colors above are only defined inside the fix_hand
        # branch, so this method assumes test.fix_hand is enabled.
        full_gaussian_vals = {
            'positions': cano_pts,
            'opacity': opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': full_colors,
            'max_sh_degree': self.max_sh_degree
        }
        hand_only_gaussian_vals = {
            'positions': cano_pts,
            'opacity': only_hand_opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': full_colors,
            'max_sh_degree': self.max_sh_degree
        }
        left_hand_gaussian_vals = {
            'positions': cano_pts,
            'opacity': left_hand_opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': l_hand_color,
            'max_sh_degree': self.max_sh_degree
        }
        right_hand_gaussian_vals = {
            'positions': cano_pts,
            'opacity': right_hand_opacity,
            'scales': scales,
            'rotations': rotations,
            'colors': r_hand_color,
            'max_sh_degree': self.max_sh_degree
        }

        full_gaussian_vals = self.transform_cano2live(full_gaussian_vals, items)
        hand_only_gaussian_vals = self.transform_cano2live(hand_only_gaussian_vals, items)
        left_hand_gaussian_vals = self.transform_cano2live(left_hand_gaussian_vals, items)
        right_hand_gaussian_vals = self.transform_cano2live(right_hand_gaussian_vals, items)

        full_render_ret = render3(
            full_gaussian_vals, bg_color,
            items['extr'], items['intr'], items['img_w'], items['img_h']
        )
        hand_only_render_ret = render3(
            hand_only_gaussian_vals, bg_color,
            items['extr'], items['intr'], items['img_w'], items['img_h']
        )
        left_hand_render_ret = render3(
            left_hand_gaussian_vals, bg_color,
            items['extr'], items['intr'], items['img_w'], items['img_h']
        )
        right_hand_render_ret = render3(
            right_hand_gaussian_vals, bg_color,
            items['extr'], items['intr'], items['img_w'], items['img_h']
        )

        full_rgb_map = full_render_ret['render'].permute(1, 2, 0)
        hand_only_rgb_map = hand_only_render_ret['render'].permute(1, 2, 0)
        left_hand_rgb_map = left_hand_render_ret['render'].permute(1, 2, 0)
        right_hand_rgb_map = right_hand_render_ret['render'].permute(1, 2, 0)

        ret = {
            'full_body_rgb_map': full_rgb_map,
            'hand_only_rgb_map': hand_only_rgb_map,
            'left_hand_rgb_map': left_hand_rgb_map,
            'right_hand_rgb_map': right_hand_rgb_map,
        }
        return ret
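

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes config.opt
# has already been loaded so AvatarNet.__init__ can read the data directory,
# that the network options live under config.opt['model'] (an assumption), and
# that a per-frame `items` dict provides the keys render() indexes:
# 'smpl_pos_map' (6, 512, 512), 'cano2live_jnt_mats' (J, 4, 4) matching the LBS
# weights, 'extr' (3, 4) or (4, 4), 'intr' (3, 3), 'img_w' and 'img_h'.
# `load_frame` is a hypothetical loader, not provided by this repo.
if __name__ == '__main__':
    net = AvatarNet(config.opt['model']).to(config.device)
    net.eval()
    # items = load_frame(frame_idx = 0)  # hypothetical: returns the dict above
    # with torch.no_grad():
    #     out = net.render(items, bg_color = (1., 1., 1.))
    # rgb = (out['rgb_map'].cpu().numpy()[..., ::-1] * 255.).clip(0., 255.)  # RGB -> BGR
    # cv.imwrite('./debug/avatar_rgb.png', rgb.astype(np.uint8))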