# NOTE(review): removed non-Python artifact lines ("Spaces:" / "Running" / "Running")
# left at the top of the file by an export/extraction tool.
import platform  # FIXME(review): appears unused — likely an accidental IDE auto-import
from turtle import left, right  # FIXME(review): appears unused; turtle needs a display — likely an accidental auto-import

import cv2 as cv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch3d.ops
import pytorch3d.transforms
from networkx import full_join  # FIXME(review): appears unused — likely an accidental auto-import

import AnimatableGaussians.config as config
from AnimatableGaussians.network.styleunet.dual_styleunet import DualStyleUNet
from AnimatableGaussians.gaussians.gaussian_model import GaussianModel
from AnimatableGaussians.gaussians.gaussian_renderer import render3
class AvatarNet(nn.Module): | |
def __init__(self, opt): | |
super(AvatarNet, self).__init__() | |
self.opt = opt | |
self.random_style = opt.get('random_style', False) | |
self.with_viewdirs = opt.get('with_viewdirs', True) | |
# init canonical gausssian model | |
self.max_sh_degree = 0 | |
self.cano_gaussian_model = GaussianModel(sh_degree = self.max_sh_degree) | |
cano_smpl_map = cv.imread(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/cano_smpl_pos_map.exr', cv.IMREAD_UNCHANGED) | |
self.cano_smpl_map = torch.from_numpy(cano_smpl_map).to(torch.float32).to(config.device) | |
self.cano_smpl_mask = torch.linalg.norm(self.cano_smpl_map, dim = -1) > 0. | |
self.init_points = self.cano_smpl_map[self.cano_smpl_mask] | |
self.lbs = torch.from_numpy(np.load(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/init_pts_lbs.npy')).to(torch.float32).to(config.device) | |
self.cano_gaussian_model.create_from_pcd(self.init_points, torch.rand_like(self.init_points), spatial_lr_scale = 2.5) | |
self.color_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 3, out_size = 1024, style_dim = 512, n_mlp = 2) | |
self.position_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 3, out_size = 1024, style_dim = 512, n_mlp = 2) | |
self.other_net = DualStyleUNet(inp_size = 512, inp_ch = 3, out_ch = 8, out_size = 1024, style_dim = 512, n_mlp = 2) | |
self.color_style = torch.ones([1, self.color_net.style_dim], dtype=torch.float32, device=config.device) / np.sqrt(self.color_net.style_dim) | |
self.position_style = torch.ones([1, self.position_net.style_dim], dtype=torch.float32, device=config.device) / np.sqrt(self.position_net.style_dim) | |
self.other_style = torch.ones([1, self.other_net.style_dim], dtype=torch.float32, device=config.device) / np.sqrt(self.other_net.style_dim) | |
if self.with_viewdirs: | |
cano_nml_map = cv.imread(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/cano_smpl_nml_map.exr', cv.IMREAD_UNCHANGED) | |
self.cano_nml_map = torch.from_numpy(cano_nml_map).to(torch.float32).to(config.device) | |
self.cano_nmls = self.cano_nml_map[self.cano_smpl_mask] | |
self.viewdir_net = nn.Sequential( | |
nn.Conv2d(1, 64, 4, 2, 1), | |
nn.LeakyReLU(0.2, inplace = True), | |
nn.Conv2d(64, 128, 4, 2, 1) | |
) | |
def generate_mean_hands(self): | |
# print('# Generating mean hands ...') | |
import glob | |
# get hand mask | |
lbs_argmax = self.lbs.argmax(1) | |
self.hand_mask = lbs_argmax == 20 | |
self.hand_mask = torch.logical_or(self.hand_mask, lbs_argmax == 21) | |
self.hand_mask = torch.logical_or(self.hand_mask, lbs_argmax >= 25) | |
pose_map_paths = sorted(glob.glob(config.opt['train']['data']['data_dir'] + '/smpl_pos_map/%08d.exr' % config.opt['test']['fix_hand_id'])) | |
smpl_pos_map = cv.imread(pose_map_paths[0], cv.IMREAD_UNCHANGED) | |
pos_map_size = smpl_pos_map.shape[1] // 2 | |
smpl_pos_map = np.concatenate([smpl_pos_map[:, :pos_map_size], smpl_pos_map[:, pos_map_size:]], 2) | |
smpl_pos_map = smpl_pos_map.transpose((2, 0, 1)) | |
pose_map = torch.from_numpy(smpl_pos_map).to(torch.float32).to(config.device) | |
pose_map = pose_map[:3] | |
cano_pts = self.get_positions(pose_map) | |
opacity, scales, rotations = self.get_others(pose_map) | |
colors, color_map = self.get_colors(pose_map) | |
self.hand_positions = cano_pts#[self.hand_mask] | |
self.hand_opacity = opacity#[self.hand_mask] | |
self.hand_scales = scales#[self.hand_mask] | |
self.hand_rotations = rotations#[self.hand_mask] | |
self.hand_colors = colors#[self.hand_mask] | |
# # debug | |
# hand_pts = trimesh.PointCloud(self.hand_positions.detach().cpu().numpy()) | |
# hand_pts.export('./debug/hand_template.obj') | |
# exit(1) | |
def transform_cano2live(self, gaussian_vals, items): | |
pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats']) | |
gaussian_vals['positions'] = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], gaussian_vals['positions']) + pt_mats[..., :3, 3] | |
rot_mats = pytorch3d.transforms.quaternion_to_matrix(gaussian_vals['rotations']) | |
rot_mats = torch.einsum('nxy,nyz->nxz', pt_mats[..., :3, :3], rot_mats) | |
gaussian_vals['rotations'] = pytorch3d.transforms.matrix_to_quaternion(rot_mats) | |
return gaussian_vals | |
def get_positions(self, pose_map, return_map = False): | |
position_map, _ = self.position_net([self.position_style], pose_map[None], randomize_noise = False) | |
front_position_map, back_position_map = torch.split(position_map, [3, 3], 1) | |
position_map = torch.cat([front_position_map, back_position_map], 3)[0].permute(1, 2, 0) | |
delta_position = 0.05 * position_map[self.cano_smpl_mask] | |
# delta_position = position_map[self.cano_smpl_mask] | |
positions = delta_position + self.cano_gaussian_model.get_xyz | |
if return_map: | |
return positions, position_map | |
else: | |
return positions | |
def get_others(self, pose_map): | |
other_map, _ = self.other_net([self.other_style], pose_map[None], randomize_noise = False) | |
front_map, back_map = torch.split(other_map, [8, 8], 1) | |
other_map = torch.cat([front_map, back_map], 3)[0].permute(1, 2, 0) | |
others = other_map[self.cano_smpl_mask] # (N, 8) | |
opacity, scales, rotations = torch.split(others, [1, 3, 4], 1) | |
opacity = self.cano_gaussian_model.opacity_activation(opacity + self.cano_gaussian_model.get_opacity_raw) | |
scales = self.cano_gaussian_model.scaling_activation(scales + self.cano_gaussian_model.get_scaling_raw) | |
rotations = self.cano_gaussian_model.rotation_activation(rotations + self.cano_gaussian_model.get_rotation_raw) | |
return opacity, scales, rotations | |
def get_colors(self, pose_map, front_viewdirs = None, back_viewdirs = None): | |
color_style = torch.rand_like(self.color_style) if self.random_style and self.training else self.color_style | |
color_map, _ = self.color_net([color_style], pose_map[None], randomize_noise = False, view_feature1 = front_viewdirs, view_feature2 = back_viewdirs) | |
front_color_map, back_color_map = torch.split(color_map, [3, 3], 1) | |
color_map = torch.cat([front_color_map, back_color_map], 3)[0].permute(1, 2, 0) | |
colors = color_map[self.cano_smpl_mask] | |
return colors, color_map | |
def get_viewdir_feat(self, items): | |
with torch.no_grad(): | |
pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats']) | |
live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.init_points) + pt_mats[..., :3, 3] | |
live_nmls = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.cano_nmls) | |
cam_pos = -torch.matmul(torch.linalg.inv(items['extr'][:3, :3]), items['extr'][:3, 3]) | |
viewdirs = F.normalize(cam_pos[None] - live_pts, dim = -1, eps = 1e-3) | |
if self.training: | |
viewdirs += torch.randn(*viewdirs.shape).to(viewdirs) * 0.1 | |
viewdirs = F.normalize(viewdirs, dim = -1, eps = 1e-3) | |
viewdirs = (live_nmls * viewdirs).sum(-1) | |
viewdirs_map = torch.zeros(*self.cano_nml_map.shape[:2]).to(viewdirs) | |
viewdirs_map[self.cano_smpl_mask] = viewdirs | |
viewdirs_map = viewdirs_map[None, None] | |
viewdirs_map = F.interpolate(viewdirs_map, None, 0.5, 'nearest') | |
front_viewdirs, back_viewdirs = torch.split(viewdirs_map, [512, 512], -1) | |
front_viewdirs = self.opt.get('weight_viewdirs', 1.) * self.viewdir_net(front_viewdirs) | |
back_viewdirs = self.opt.get('weight_viewdirs', 1.) * self.viewdir_net(back_viewdirs) | |
return front_viewdirs, back_viewdirs | |
def get_pose_map(self, items): | |
pt_mats = torch.einsum('nj,jxy->nxy', self.lbs, items['cano2live_jnt_mats_woRoot']) | |
live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], self.init_points) + pt_mats[..., :3, 3] | |
live_pos_map = torch.zeros_like(self.cano_smpl_map) | |
live_pos_map[self.cano_smpl_mask] = live_pts | |
live_pos_map = F.interpolate(live_pos_map.permute(2, 0, 1)[None], None, [0.5, 0.5], mode = 'nearest')[0] | |
live_pos_map = torch.cat(torch.split(live_pos_map, [512, 512], 2), 0) | |
items.update({ | |
'smpl_pos_map': live_pos_map | |
}) | |
return live_pos_map | |
def render(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False): | |
""" | |
Note that no batch index in items. | |
""" | |
bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device) | |
pose_map = items['smpl_pos_map'][:3] | |
assert not (use_pca and use_vae), "Cannot use both PCA and VAE!" | |
if use_pca: | |
pose_map = items['smpl_pos_map_pca'][:3] | |
if use_vae: | |
pose_map = items['smpl_pos_map_vae'][:3] | |
cano_pts, pos_map = self.get_positions(pose_map, return_map = True) | |
opacity, scales, rotations = self.get_others(pose_map) | |
# if not self.training: | |
# scales = torch.clip(scales, 0., 0.03) | |
if self.with_viewdirs: | |
front_viewdirs, back_viewdirs = self.get_viewdir_feat(items) | |
else: | |
front_viewdirs, back_viewdirs = None, None | |
colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs) | |
if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test': | |
# print('# fuse hands ...') | |
import ipdb | |
import AnimatableGaussians.utils.geo_util as geo_util | |
cano_xyz = self.init_points | |
wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0)) | |
wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0)) | |
wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
s = torch.maximum(wl + wr, torch.ones_like(wl)) | |
wl, wr = wl / s, wr / s | |
w = wl + wr | |
# ipdb.set_trace() | |
cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts | |
# new_opacity = torch.zeros_like(opacity) | |
opacity = w * self.hand_opacity + (1.0 - w) * opacity | |
# opacity = w * self.hand_opacity * 0 + (1.0 - w) * opacity | |
# opacity = opacity * 0 | |
scales = w * self.hand_scales + (1.0 - w) * scales | |
rotations = w * self.hand_rotations + (1.0 - w) * rotations | |
# colors = w * self.hand_colors + (1.0 - w) * colors | |
# new_hand_colors = torch.ones_like(colors) * 0.5 | |
# colors = w * new_hand_colors + (1.0 - w) * colors | |
gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': colors, | |
'max_sh_degree': self.max_sh_degree | |
} | |
# ipdb.set_trace() | |
nonrigid_offset = gaussian_vals['positions'] - self.init_points | |
gaussian_vals = self.transform_cano2live(gaussian_vals, items) | |
render_ret = render3( | |
gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
rgb_map = render_ret['render'].permute(1, 2, 0) | |
mask_map = render_ret['mask'].permute(1, 2, 0) | |
torso_flag = 1 - (self.lbs[:, 12] + self.lbs[:, 15] + self.lbs[:, 22] + self.lbs[:, 23] + self.lbs[:, 24]) | |
torso_weight = torch.stack([torso_flag, torso_flag, torso_flag], dim=-1) | |
orig_color, gaussian_vals['colors'] = gaussian_vals['colors'], torso_weight | |
render_ret = render3( | |
gaussian_vals, | |
torch.zeros_like(bg_color), | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
torso_map = render_ret['render'].permute(1, 2, 0) | |
gaussian_vals['colors'] = orig_color | |
ret = { | |
'rgb_map': rgb_map, | |
'torso_map': torso_map, | |
'mask_map': mask_map, | |
'offset': nonrigid_offset, | |
'pos_map': pos_map | |
} | |
if not self.training: | |
ret.update({ | |
'cano_tex_map': color_map, | |
'posed_gaussians': gaussian_vals | |
}) | |
return ret | |
def render_wo_hand(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False): | |
""" | |
Note that no batch index in items. | |
""" | |
bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device) | |
pose_map = items['smpl_pos_map'][:3] | |
assert not (use_pca and use_vae), "Cannot use both PCA and VAE!" | |
if use_pca: | |
pose_map = items['smpl_pos_map_pca'][:3] | |
if use_vae: | |
pose_map = items['smpl_pos_map_vae'][:3] | |
cano_pts, pos_map = self.get_positions(pose_map, return_map = True) | |
opacity, scales, rotations = self.get_others(pose_map) | |
# if not self.training: | |
# scales = torch.clip(scales, 0., 0.03) | |
if self.with_viewdirs: | |
front_viewdirs, back_viewdirs = self.get_viewdir_feat(items) | |
else: | |
front_viewdirs, back_viewdirs = None, None | |
colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs) | |
if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test': | |
import AnimatableGaussians.utils.geo_util as geo_util | |
cano_xyz = self.init_points | |
wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0)) | |
wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0)) | |
wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
s = torch.maximum(wl + wr, torch.ones_like(wl)) | |
wl, wr = wl / s, wr / s | |
w = wl + wr | |
# ipdb.set_trace() | |
cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts | |
# new_opacity = torch.zeros_like(opacity) | |
# opacity = w * self.hand_opacity + (1.0 - w) * opacity | |
opacity = w * self.hand_opacity * 0 + (1.0 - w) * opacity | |
# opacity = opacity * 0 | |
scales = w * self.hand_scales + (1.0 - w) * scales | |
rotations = w * self.hand_rotations + (1.0 - w) * rotations | |
# colors = w * self.hand_colors + (1.0 - w) * colors | |
# new_hand_colors = torch.ones_like(colors) * 0.5 | |
# colors = w * new_hand_colors + (1.0 - w) * colors | |
gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': colors, | |
'max_sh_degree': self.max_sh_degree | |
} | |
# ipdb.set_trace() | |
nonrigid_offset = gaussian_vals['positions'] - self.init_points | |
gaussian_vals = self.transform_cano2live(gaussian_vals, items) | |
render_ret = render3( | |
gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
rgb_map = render_ret['render'].permute(1, 2, 0) | |
ret = { | |
'rgb_map': rgb_map, | |
} | |
return ret | |
def render_mask(self, items, bg_color = (0., 0., 0.), use_pca = False, use_vae = False): | |
""" | |
Note that no batch index in items. | |
""" | |
bg_color = torch.from_numpy(np.asarray(bg_color)).to(torch.float32).to(config.device) | |
pose_map = items['smpl_pos_map'][:3] | |
assert not (use_pca and use_vae), "Cannot use both PCA and VAE!" | |
if use_pca: | |
pose_map = items['smpl_pos_map_pca'][:3] | |
if use_vae: | |
pose_map = items['smpl_pos_map_vae'][:3] | |
cano_pts, pos_map = self.get_positions(pose_map, return_map = True) | |
opacity, scales, rotations = self.get_others(pose_map) | |
# if not self.training: | |
# scales = torch.clip(scales, 0., 0.03) | |
if self.with_viewdirs: | |
front_viewdirs, back_viewdirs = self.get_viewdir_feat(items) | |
else: | |
front_viewdirs, back_viewdirs = None, None | |
colors, color_map = self.get_colors(pose_map, front_viewdirs, back_viewdirs) | |
if not self.training and config.opt['test'].get('fix_hand', False) and config.opt['mode'] == 'test': | |
# print('# fuse hands ...') | |
import ipdb | |
import AnimatableGaussians.utils.geo_util as geo_util | |
cano_xyz = self.init_points | |
wl = torch.sigmoid(2.5 * (geo_util.normalize_vert_bbox(items['left_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] + 2.0)) | |
wr = torch.sigmoid(-2.5 * (geo_util.normalize_vert_bbox(items['right_cano_mano_v'], attris = cano_xyz, dim = 0, per_axis = True)[..., 0:1] - 2.0)) | |
wl[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
wr[cano_xyz[..., 1] < items['cano_smpl_center'][1]] = 0. | |
s = torch.maximum(wl + wr, torch.ones_like(wl)) | |
wl, wr = wl / s, wr / s | |
w = wl + wr | |
# ipdb.set_trace() | |
cano_pts = w * self.hand_positions + (1.0 - w) * cano_pts | |
# opacity = w * self.hand_opacity + (1.0 - w) * opacity | |
body_opacity = torch.zeros_like(opacity) | |
no_body_opacity = w * self.hand_opacity * 0 + (1.0 - w) * body_opacity | |
only_hand_opacity = w * self.hand_opacity + (1.0 - w) * body_opacity | |
left_hand_opacity = wl * self.hand_opacity + (1.0 - wl) * body_opacity | |
right_hand_opacity = wr * self.hand_opacity + (1.0 - wr) * body_opacity | |
opacity = w * self.hand_opacity + (1.0 - w) * opacity | |
# opacity = opacity * 0 | |
scales = w * self.hand_scales + (1.0 - w) * scales | |
rotations = w * self.hand_rotations + (1.0 - w) * rotations | |
# colors = w * self.hand_colors + (1.0 - w) * colors | |
r_hand_color = torch.ones_like(colors) * torch.tensor([1., 0., 0.]).to(config.device) | |
l_hand_color = torch.ones_like(colors) * torch.tensor([0., 0., 1.]).to(config.device) | |
body_color = torch.ones_like(colors) * torch.tensor([0, 1, 0]).to(config.device) | |
full_colors = wr * r_hand_color + wl * l_hand_color + (1.0 - w) * body_color | |
full_gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': full_colors, | |
'max_sh_degree': self.max_sh_degree | |
} | |
hand_only_gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': only_hand_opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': full_colors, | |
'max_sh_degree': self.max_sh_degree | |
} | |
left_hand_gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': left_hand_opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': l_hand_color, | |
'max_sh_degree': self.max_sh_degree | |
} | |
right_hand_gaussian_vals = { | |
'positions': cano_pts, | |
'opacity': right_hand_opacity, | |
'scales': scales, | |
'rotations': rotations, | |
'colors': r_hand_color, | |
'max_sh_degree': self.max_sh_degree | |
} | |
full_gaussian_vals = self.transform_cano2live(full_gaussian_vals, items) | |
hand_only_gaussian_vals = self.transform_cano2live(hand_only_gaussian_vals, items) | |
left_hand_gaussian_vals = self.transform_cano2live(left_hand_gaussian_vals, items) | |
right_hand_gaussian_vals = self.transform_cano2live(right_hand_gaussian_vals, items) | |
full_render_ret = render3( | |
full_gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
hand_only_render_ret = render3( | |
hand_only_gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
left_hand_render_ret = render3( | |
left_hand_gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
right_hand_render_ret = render3( | |
right_hand_gaussian_vals, | |
bg_color, | |
items['extr'], | |
items['intr'], | |
items['img_w'], | |
items['img_h'] | |
) | |
full_rgb_map = full_render_ret['render'].permute(1, 2, 0) | |
hand_only_rgb_map = hand_only_render_ret['render'].permute(1, 2, 0) | |
left_hand_rgb_map = left_hand_render_ret['render'].permute(1, 2, 0) | |
right_hand_rgb_map = right_hand_render_ret['render'].permute(1, 2, 0) | |
ret = { | |
'full_body_rgb_map': full_rgb_map, | |
'hand_only_rgb_map': hand_only_rgb_map, | |
'left_hand_rgb_map': left_hand_rgb_map, | |
'right_hand_rgb_map': right_hand_rgb_map, | |
} | |
return ret | |