import os

import numpy as np
import torch
import torch.nn as nn
import trimesh
from sklearn.neighbors import KDTree
import tqdm
import cv2 as cv

from .lib.networks.faceverse_torch import FaceVerseModel
from .lib.networks.smpl_torch import SmplTorch
from .lib.utils.gaussian_np_utils import GaussianAttributes, load_gaussians_from_ply, save_gaussians_as_ply, \
    apply_transformation_to_gaussians, combine_gaussians, select_gaussians, update_gaussian_attributes
from .lib.utils.geometry import search_nearest_correspondence, estimate_rigid_transformation
from .lib.utils.sh_utils import SH2RGB


def process_smpl_head():
    """Mark the SMPL-X head vertices and compute a per-vertex blend weight for head-body stitching."""
    smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
    smpl_v_template = smpl.v_template.detach().cpu().numpy()
    smpl_faces = smpl.faces.detach().cpu().numpy()

    # Alternative: derive the head region from skinning weights instead of a manual cut
    # head_skinning_weights = smpl.weights[:, 15] + smpl.weights[:, 22] + smpl.weights[:, 23] + smpl.weights[:, 24]
    # blend_weight = np.clip(head_skinning_weights * 1.2 - 0.2, 0, 1)
    # head_ids = np.where(blend_weight > 0)[0]
    # np.savez('./data/smplx_head_vidx_and_blendweight.npz', blend_weight=blend_weight, head_ids=head_ids)

    if not os.path.exists('./data/smpl_models/smplx_head_3.obj'):
        trimesh.Trimesh(vertices=smpl_v_template, faces=smpl_faces).export('./data/smpl_models/smplx_head_3.obj')
        print('Please cut out the SMPL-X head from the exported mesh!')
        import pdb; pdb.set_trace()

    smplx_head = trimesh.load('./data/smpl_models/smplx_head_3.obj')
    # Distance from every template vertex to the manually cropped head mesh
    smplx_to_head_dist = np.zeros([smpl_v_template.shape[0]])
    for vi, v in enumerate(smpl_v_template):
        nndist = np.min(np.linalg.norm(v.reshape([1, 3]) - smplx_head.vertices, axis=1))
        smplx_to_head_dist[vi] = nndist
    head_ids = np.where(smplx_to_head_dist < 0.001)[0]
    # Gaussian falloff: exp(-d^2 * 2000) corresponds to sigma = sqrt(1/4000) ~= 0.016 m
    blend_weight = np.exp(-smplx_to_head_dist * smplx_to_head_dist * 2000)
    np.savez('./data/smpl_models/smplx_head_vidx_and_blendweight.npz',
             blend_weight=blend_weight, head_ids=head_ids)


def load_body_params(path):
    """Load SMPL-X body parameters from an .npz file."""
    param = dict(np.load(path))
    global_orient = param['global_orient']
    transl = param['transl']
    body_pose = param['body_pose']
    betas = param['betas']
    return global_orient, transl, body_pose, betas


def load_face_params(path):
    """Load FaceVerse head parameters from an .npz file."""
    param = dict(np.load(path))
    pose = param['pose']
    scale = param['scale']
    id_coeff = param['id_coeff']
    exp_coeff = param['exp_coeff']
    return pose, scale, id_coeff, exp_coeff


def get_smpl_verts_and_head_transformation(smpl, global_orient, body_pose, transl, betas):
    """Run SMPL-X forward and return the posed vertices plus the head joint transformation."""
    # Zero-pad the jaw/eye (3) and hand (2 x 15) joint rotations
    pose = torch.cat([
        torch.from_numpy(global_orient.astype(np.float32)),
        torch.from_numpy(body_pose.astype(np.float32)),
        torch.zeros([(3 + 15 + 15) * 3], dtype=torch.float32)], dim=-1)
    beta = torch.from_numpy(betas.astype(np.float32))
    verts, skinning_dict = smpl.forward(pose.reshape(1, -1), beta.reshape(1, -1))
    verts = verts[0].detach().cpu().numpy()
    head_joint_transfmat = skinning_dict['G'][0, 15].detach().cpu().numpy()  # joint 15 = head
    verts += transl.reshape([1, 3])
    head_joint_transfmat[:3, 3] += transl.reshape([3])
    return verts, head_joint_transfmat


def crop_facial_area(faceverse_verts, points, dist_thres=0.025):
    """Select the points that lie within dist_thres of the FaceVerse head mesh."""
    min_x, min_y, min_z = np.min(faceverse_verts, axis=0)
    max_x, max_y, max_z = np.max(faceverse_verts, axis=0)
    pad = dist_thres * 2
    # Cheap bounding-box test first, then exact nearest-neighbor distances
    in_bbox_mask = (points[:, 0] > min_x - pad) & (points[:, 0] < max_x + pad) & \
                   (points[:, 1] > min_y - pad) & (points[:, 1] < max_y + pad) & \
                   (points[:, 2] > min_z - pad) & (points[:, 2] < max_z + pad)
    in_bbox_idx = np.where(in_bbox_mask)[0]
    facial_points = points[in_bbox_mask]
    nndist = np.ones([len(facial_points)]) * 1e10
    for i in tqdm.trange(len(facial_points), desc='calculating facial area'):
        nndist[i] = np.min(np.linalg.norm(faceverse_verts - facial_points[i:(i + 1)], axis=1))
    close_to_face_mask = nndist < dist_thres
    facial_points = facial_points[close_to_face_mask]
    facial_idx = in_bbox_idx[close_to_face_mask]
    return facial_points, facial_idx
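
# A minimal alternative sketch (an assumption, not part of the original pipeline):
# the exact nearest-distance loop above is O(N*M); a KDTree built over the
# FaceVerse vertices returns the same distances in one batched query, without
# the bounding-box prefilter.
def crop_facial_area_kdtree(faceverse_verts, points, dist_thres=0.025):
    tree = KDTree(faceverse_verts)
    nndist, _ = tree.query(points, k=1)        # (N, 1) distances to the nearest mesh vertex
    close_mask = nndist[:, 0] < dist_thres
    return points[close_mask], np.where(close_mask)[0]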
def crop_facial_area2(smpl_verts, smpl_head_vids, points):
    """Select the points whose nearest SMPL-X vertex belongs to the head region."""
    min_x, min_y, min_z = np.min(smpl_verts[smpl_head_vids], axis=0)
    max_x, max_y, max_z = np.max(smpl_verts[smpl_head_vids], axis=0)
    pad = 0.05
    in_bbox_mask = (points[:, 0] > min_x - pad) & (points[:, 0] < max_x + pad) & \
                   (points[:, 1] > min_y - pad) & (points[:, 1] < max_y + pad) & \
                   (points[:, 2] > min_z - pad) & (points[:, 2] < max_z + pad)
    in_bbox_idx = np.where(in_bbox_mask)[0]
    facial_points = points[in_bbox_mask]

    smpl_head_mask = np.zeros([len(smpl_verts)], dtype=np.bool_)
    smpl_head_mask[smpl_head_vids] = True
    close_to_face_mask = np.zeros([len(facial_points)], dtype=np.bool_)
    for i in tqdm.trange(len(facial_points)):
        nnid = np.argmin(np.linalg.norm(smpl_verts - facial_points[i:(i + 1)], axis=1))
        close_to_face_mask[i] = smpl_head_mask[nnid]
    facial_points = facial_points[close_to_face_mask]
    facial_idx = in_bbox_idx[close_to_face_mask]
    return facial_points, facial_idx


def transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat):
    """Map FaceVerse vertices into the live body space: FaceVerse -> canonical SMPL-X -> posed head."""
    faceverse_verts = np.matmul(faceverse_verts, faceverse_to_smplx[:3, :3].transpose()) \
        + faceverse_to_smplx[:3, 3].reshape(1, 3)
    faceverse_verts = np.matmul(faceverse_verts, head_joint_transfmat[:3, :3].transpose()) \
        + head_joint_transfmat[:3, 3].reshape(1, 3)
    return faceverse_verts


def calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat):
    """Compose the transformation from the live head-avatar space to the live body space."""
    head_cano2live = np.eye(4, dtype=np.float32)
    head_cano2live[:3, :3] = cv.Rodrigues(head_pose[:3])[0]
    head_cano2live[:3, 3] = head_pose[3:]
    head_live2cano = np.linalg.inv(head_cano2live)
    faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)
    # live head -> canonical head -> (y/z flip) -> canonical SMPL-X head -> posed head
    total_transf = np.eye(4, dtype=np.float32)
    for t in [head_live2cano, np.diag([1, -1, -1, 1]), faceverse_to_smplx, head_joint_transfmat]:
        total_transf = np.matmul(t, total_transf)
    return total_transf


def get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015):
    """Transfer the precomputed SMPL-X head blend weight to the facial Gaussians and smooth it."""
    smpl_blend_weight = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['blend_weight']
    corr_idx_, _ = search_nearest_correspondence(head_facial_points, smpl_verts)
    corr_bw = smpl_blend_weight[corr_idx_]
    # Smooth the weights by repeatedly averaging over the 4 nearest neighbors
    for _ in tqdm.trange(10):
        corr_bw_ = np.zeros_like(corr_bw)
        tree = KDTree(head_facial_points, leaf_size=2)
        for i in range(len(head_facial_points)):
            _, idx = tree.query(head_facial_points[i:(i + 1)], k=4)
            corr_bw_[i] = np.mean(corr_bw[idx])
        corr_bw = np.copy(corr_bw_)
    return corr_bw
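
# A minimal sketch (an assumption, not the original implementation): the per-point
# smoothing loop above can be batched, since sklearn's KDTree.query accepts an
# (N, 3) array directly, and the tree only needs to be built once because the
# points never move between iterations.
def smooth_blend_weight_batched(points, bw, iters=10, k=4):
    tree = KDTree(points, leaf_size=2)
    _, idx = tree.query(points, k=k)   # (N, k) neighbor indices, computed once
    for _ in range(iters):
        bw = bw[idx].mean(axis=1)      # average each point's weight over its k neighbors
    return bw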
def get_face_blend_weight2(head_facial_points, body_points, body_facial_idx):
    """Compute a head/body blend weight by diffusing a face indicator over the body points."""
    body_facial_bbox_min = np.min(body_points[body_facial_idx], axis=0) - 0.1
    body_facial_bbox_max = np.max(body_points[body_facial_idx], axis=0) + 0.1
    inside_bbox_flag = \
        np.int32(body_points[:, 0] > body_facial_bbox_min[0]) * \
        np.int32(body_points[:, 0] < body_facial_bbox_max[0]) * \
        np.int32(body_points[:, 1] > body_facial_bbox_min[1]) * \
        np.int32(body_points[:, 1] < body_facial_bbox_max[1]) * \
        np.int32(body_points[:, 2] > body_facial_bbox_min[2]) * \
        np.int32(body_points[:, 2] < body_facial_bbox_max[2])
    point_idx_inside_bbox = np.nonzero(inside_bbox_flag > 0)[0]

    # Binary face indicator on the body points, smoothed within the bbox by radius averaging
    body_blend_weight = np.zeros([len(body_points)], dtype=np.float32)
    body_blend_weight[body_facial_idx] = 1
    body_points_in_bbox = body_points[point_idx_inside_bbox]
    body_blend_weight_in_bbox = body_blend_weight[point_idx_inside_bbox]
    for _ in tqdm.trange(1, desc='Calculating body facial blend weight'):
        corr_bw_ = np.zeros_like(body_blend_weight_in_bbox)
        tree = KDTree(body_points_in_bbox, leaf_size=2)
        for i in tqdm.trange(len(body_points_in_bbox)):
            ind = tree.query_radius(body_points_in_bbox[i:(i + 1)], r=0.035)
            corr_bw_[i] = np.mean(body_blend_weight_in_bbox[ind[0]])
        body_blend_weight_in_bbox = np.copy(corr_bw_)
    body_blend_weight[point_idx_inside_bbox] = body_blend_weight_in_bbox
    with open('./debug/debug_body_facial_bw.obj', 'w') as fp:
        for p, w in zip(body_points, body_blend_weight):
            fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))

    # Pull the smoothed weight back onto the head facial points
    tree = KDTree(body_points, leaf_size=2)
    corr_bw = np.zeros([len(head_facial_points)], dtype=np.float32)
    for i in range(len(head_facial_points)):
        _, idx = tree.query(head_facial_points[i:(i + 1)], k=4)
        corr_bw[i] = np.mean(body_blend_weight[idx])
    # Stretch the 5th-95th percentile range to [0, 1]
    corr_bw_bmin, corr_bw_bmax = np.percentile(corr_bw, 5), np.percentile(corr_bw, 95)
    corr_bw = np.clip((corr_bw - corr_bw_bmin) / (corr_bw_bmax - corr_bw_bmin), 0, 1)
    with open('./debug/debug_head_facial_bw.obj', 'w') as fp:
        for p, w in zip(head_facial_points, corr_bw):
            fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], w, w, w))
    return corr_bw


def estimate_color_transfer(head_facial_points, body_facial_points, head_facial_color, body_facial_color,
                            head_facial_opacity):
    """Fit a 4x4 linear color transfer mapping head SH DC colors to the corresponding body colors."""
    # SH DC coefficient -> RGB (C0 = 0.28209479177387814)
    head_facial_color = head_facial_color * 0.28209479177387814 + 0.5
    body_facial_color = body_facial_color * 0.28209479177387814 + 0.5
    corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
    corr_color = body_facial_color[corr_idx]
    # Only trust reasonably opaque Gaussians: sigmoid(raw opacity) > 0.35
    opacity = 1 / (1 + np.exp(-head_facial_opacity))
    weight = np.float32(opacity > 0.35)
    head_facial_color = head_facial_color.reshape(len(head_facial_color), 3) * weight.reshape([-1, 1])
    corr_color = corr_color.reshape(len(corr_color), 3) * weight.reshape([-1, 1])
    # Pad a zero fourth channel; with a zero homogeneous coordinate the 4x4 transfer
    # acts as a purely linear (offset-free) map on RGB
    head_facial_color = np.concatenate([head_facial_color, np.zeros_like(head_facial_color[:, :1])], axis=1)
    corr_color = np.concatenate([corr_color, np.zeros_like(corr_color[:, :1])], axis=1)

    transfer = nn.Parameter(torch.eye(4, dtype=torch.float32))
    head_facial_color_th = torch.from_numpy(head_facial_color).float()
    corr_color_th = torch.from_numpy(corr_color).float()
    weight_th = torch.from_numpy(weight).float().reshape(-1, 1)  # (N, 1) so it broadcasts over channels
    optim = torch.optim.Adam([transfer], lr=1e-2)
    for i in range(500):
        optim.zero_grad()
        loss = torch.mean(torch.abs(corr_color_th - torch.matmul(head_facial_color_th, transfer.permute(1, 0)))
                          * weight_th)
        # Regularize towards identity to keep the transfer mild
        loss = loss + torch.sum(torch.square(transfer - torch.eye(4, dtype=torch.float32))) * 5e-2
        if i % 25 == 0:
            print(loss.item())
        loss.backward()
        optim.step()
    transfer = transfer.detach().cpu().numpy()
    print(transfer)
    return transfer
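
# A hedged alternative sketch (an assumption, not the author's method): up to the
# L1 norm, the objective above is a linear least-squares problem with an identity
# prior, so a closed-form ridge solution yields a similar transfer without Adam.
# head_color / body_color are the zero-padded (N, 4) arrays built above.
def estimate_color_transfer_lstsq(head_color, body_color, weight, reg=5e-2):
    A = head_color * weight.reshape(-1, 1)           # (N, 4) weighted inputs
    B = body_color * weight.reshape(-1, 1)           # (N, 4) weighted targets
    AtA = A.T @ A + reg * np.eye(4, dtype=A.dtype)
    AtB = A.T @ B + reg * np.eye(4, dtype=A.dtype)   # the prior pulls T towards identity
    # Solves argmin_T ||A @ T.T - B||^2 + reg * ||T - I||^2; apply as head @ T.T
    return np.linalg.solve(AtA, AtB).T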
def blend_color(head_facial_color, body_facial_color, blend_weight):
    """Linearly blend per-point attributes; blend_weight is broadcast over the trailing dims."""
    blend_weight = blend_weight.reshape([len(blend_weight)] + [1] * (len(head_facial_color.shape) - 1))
    result = head_facial_color * blend_weight + body_facial_color * (1 - blend_weight)
    return result


def save_body_face_stitching_data(
        result_path, smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
        head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer):
    """Pack the stitching parameters into a single .npz for later blending."""
    # Expand the per-facial-point arrays to full per-Gaussian arrays
    head_color_bw = np.zeros([len(head_nonface_mask)])
    head_color_bw[head_facial_idx] = face_color_bw
    head_corr_idx = np.zeros([len(head_nonface_mask)])
    head_corr_idx[head_facial_idx] = body_facial_idx[corr_idx]
    np.savez(result_path,
             smplx_to_faceverse=smplx_to_faceverse.astype(np.float32),
             residual_transf=residual_transf.astype(np.float32),
             body_nonface_mask=body_nonface_mask.astype(np.int32),
             head_nonface_mask=head_nonface_mask.astype(np.int32),
             head_facial_idx=head_facial_idx.astype(np.int32),
             body_facial_idx=body_facial_idx.astype(np.int32),
             head_body_corr_idx=head_corr_idx.astype(np.int32),
             head_color_bw=head_color_bw.astype(np.float32),
             color_transfer=color_transfer.astype(np.float32))
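
# A minimal consumer sketch (an assumption about downstream use, based only on
# the keys saved above): reload the blending parameters and look up, for each
# head facial Gaussian, its body correspondence and blend weight.
def load_body_face_stitching_data(path):
    param = dict(np.load(path))
    head_facial_idx = param['head_facial_idx']
    head_body_corr_idx = param['head_body_corr_idx'][head_facial_idx]  # body Gaussian index per facial Gaussian
    head_color_bw = param['head_color_bw'][head_facial_idx]            # blend weight per facial Gaussian
    return (param['residual_transf'], param['color_transfer'],
            head_facial_idx, head_body_corr_idx, head_color_bw)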
def manual_refine_facial_cropping(head_facial_points, head_facial_idx, head_facial_colors,
                                  body_facial_points, body_facial_idx, body_facial_colors):
    """Dump the cropped facial points as colored .obj files, pause for manual editing, then keep only the surviving points."""
    def _save_points_as_obj(fpath, points, points_color):
        points_color = np.clip(points_color, 0, 1)
        with open(fpath, 'w') as fp:
            for p, c in zip(points, points_color):
                fp.write('v %f %f %f %f %f %f\n' % (p[0], p[1], p[2], c[0], c[1], c[2]))

    _save_points_as_obj('./debug/head_facial_points.obj', head_facial_points, head_facial_colors)
    _save_points_as_obj('./debug/body_facial_points.obj', body_facial_points, body_facial_colors)
    print('Saved the facial points cropped by the algorithm. Please remove unnecessary points manually!')
    import pdb; pdb.set_trace()

    # Reload the edited point sets; a point survives if it still has an (almost) exact match
    head_facial_points_ = trimesh.load('./debug/head_facial_points.obj').vertices
    body_facial_points_ = trimesh.load('./debug/body_facial_points.obj').vertices
    _, head_nndist = search_nearest_correspondence(head_facial_points, head_facial_points_)
    _, body_nndist = search_nearest_correspondence(body_facial_points, body_facial_points_)
    head_flag = head_nndist < 1e-4
    body_flag = body_nndist < 1e-4
    return head_facial_points[head_flag], head_facial_idx[head_flag], \
        body_facial_points[body_flag], body_facial_idx[body_flag]
def stitch_body_and_head(ref_body_gaussian_path, ref_head_gaussian_path, ref_body_param_path, ref_head_param_path,
                         smplx2faceverse_path, result_folder):
    """Align a head Gaussian avatar to a body Gaussian avatar and save the blended result."""
    device = torch.device('cuda')
    body_gaussians = load_gaussians_from_ply(ref_body_gaussian_path)
    head_gaussians = load_gaussians_from_ply(ref_head_gaussian_path)
    global_orient, transl, body_pose, betas = load_body_params(ref_body_param_path)
    head_pose, head_scale, id_coeff, exp_coeff = load_face_params(ref_head_param_path)
    smplx_to_faceverse = np.load(smplx2faceverse_path)
    faceverse_to_smplx = np.linalg.inv(smplx_to_faceverse)

    smpl = SmplTorch(model_file='./AnimatableGaussians/smpl_files/smplx/SMPLX_NEUTRAL.npz')
    smpl_verts, head_joint_transfmat = get_smpl_verts_and_head_transformation(
        smpl, global_orient, body_pose, transl, betas)
    smpl_head_vids = dict(np.load('./data/smpl_models/smplx_head_vidx_and_blendweight.npz'))['head_ids']
    smpl_head_verts = smpl_verts[smpl_head_vids]

    model_dict = np.load('./data/faceverse_models/faceverse_simple_v2.npy', allow_pickle=True).item()
    faceverse_model = FaceVerseModel(model_dict, batch_size=1)
    faceverse_model.init_coeff_tensors(
        id_coeff=torch.from_numpy(id_coeff).reshape([1, -1]).to(device),
        scale_coeff=torch.from_numpy(head_scale).reshape([1, 1]).to(device),
    )
    faceverse_verts = faceverse_model.forward()['v'][0].detach().cpu().numpy()
    faceverse_verts = transform_faceverse_to_live_body_space(faceverse_verts, faceverse_to_smplx, head_joint_transfmat)

    # Bring the head Gaussians into the live body space
    livehead2livebody = calc_livehead2livebody(head_pose, smplx_to_faceverse, head_joint_transfmat)
    head_gaussians_xyz = np.matmul(head_gaussians.xyz, livehead2livebody[:3, :3].transpose()) \
        + livehead2livebody[:3, 3].reshape(1, 3)

    # head_facial_points, head_facial_idx = crop_facial_area(smpl_head_verts, head_gaussians_xyz)
    # body_facial_points, body_facial_idx = crop_facial_area(smpl_head_verts, body_gaussians.xyz)
    head_facial_points, head_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, head_gaussians_xyz)
    body_facial_points, body_facial_idx = crop_facial_area2(smpl_verts, smpl_head_vids, body_gaussians.xyz)

    residual_transf = np.eye(4)
    head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
        head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
        body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))
    # Alternate ICP fitting and manual cleanup until the alignment is accepted
    while True:
        for _ in tqdm.trange(4, desc='Fitting residual transformation'):
            corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
            corr = body_facial_points[corr_idx]
            transf = estimate_rigid_transformation(head_facial_points, corr)
            residual_transf = np.matmul(transf, residual_transf)
            head_facial_points = np.matmul(head_facial_points, transf[:3, :3].transpose()) \
                + transf[:3, 3].reshape(1, 3)
        if_crop_well = input('Is the facial cropping good enough? (y/n): ')
        if if_crop_well == 'y':
            break
        head_facial_points, head_facial_idx, body_facial_points, body_facial_idx = manual_refine_facial_cropping(
            head_facial_points, head_facial_idx, SH2RGB(head_gaussians.features_dc[head_facial_idx]),
            body_facial_points, body_facial_idx, SH2RGB(body_gaussians.features_dc[body_facial_idx]))

    print(np.matmul(residual_transf, livehead2livebody))
    # Conjugate the residual transformation into the live-head coordinate frame
    residual_transf = np.matmul(np.linalg.inv(livehead2livebody), np.matmul(residual_transf, livehead2livebody))
    corr_idx, _ = search_nearest_correspondence(head_facial_points, body_facial_points)
    total_transf = np.matmul(livehead2livebody, residual_transf)
    print(total_transf)

    color_transfer = estimate_color_transfer(
        head_facial_points, body_facial_points,
        head_gaussians.features_dc[head_facial_idx], body_gaussians.features_dc[body_facial_idx],
        head_gaussians.opacities[head_facial_idx])
    # face_color_bw = get_face_blend_weight(head_facial_points, smpl_verts, sigma=0.015)
    face_color_bw = get_face_blend_weight2(head_facial_points, body_gaussians.xyz, body_facial_idx)

    body_nonface_mask = np.ones([len(body_gaussians.xyz)], dtype=np.bool_)
    body_nonface_mask[body_facial_idx] = 0
    head_nonface_mask = np.ones([len(head_gaussians.xyz)], dtype=np.bool_)
    head_nonface_mask[head_facial_idx] = 0
    save_body_face_stitching_data(
        os.path.join(result_folder, 'body_head_blending_param.npz'),
        smplx_to_faceverse, residual_transf, body_nonface_mask, head_nonface_mask,
        head_facial_idx, body_facial_idx, corr_idx, face_color_bw, color_transfer)

    # Blend the facial Gaussians towards their body correspondences and merge
    body_gaussians = apply_transformation_to_gaussians(body_gaussians, np.eye(4))
    head_gaussians = apply_transformation_to_gaussians(head_gaussians, total_transf, np.eye(3))
    body_gaussians_wo_face = select_gaussians(body_gaussians, body_nonface_mask)
    head_gaussians_face_only = select_gaussians(head_gaussians, head_facial_idx)
    head_gaussians_face_only_new_color = blend_color(
        head_gaussians_face_only.features_dc, body_gaussians.features_dc[body_facial_idx][corr_idx], face_color_bw)
    head_gaussians_face_only_new_xyz = blend_color(
        head_gaussians_face_only.xyz, body_gaussians.xyz[body_facial_idx][corr_idx], face_color_bw)
    head_gaussians_face_only_new_opacities = blend_color(
        head_gaussians_face_only.opacities, body_gaussians.opacities[body_facial_idx][corr_idx], face_color_bw)
    head_gaussians_face_only_new_scales = blend_color(
        head_gaussians_face_only.scales, body_gaussians.scales[body_facial_idx][corr_idx], face_color_bw)
    head_gaussians_face_only = update_gaussian_attributes(
        head_gaussians_face_only, new_rgb=head_gaussians_face_only_new_color,
        new_xyz=head_gaussians_face_only_new_xyz, new_opacity=head_gaussians_face_only_new_opacities,
        new_scale=head_gaussians_face_only_new_scales)
    full_gaussians = combine_gaussians([body_gaussians_wo_face, head_gaussians_face_only])
    save_gaussians_as_ply(os.path.join(result_folder, 'full_gaussians.ply'), full_gaussians)
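

if __name__ == '__main__':
    # A hedged usage sketch: every path below is a hypothetical placeholder, not a
    # file shipped with the repo. process_smpl_head() only needs to run once to
    # produce smplx_head_vidx_and_blendweight.npz. Note the module uses relative
    # imports, so it must be launched with `python -m <package>.<module>`.
    process_smpl_head()
    stitch_body_and_head(
        ref_body_gaussian_path='./results/example/body_gaussians.ply',    # hypothetical path
        ref_head_gaussian_path='./results/example/head_gaussians.ply',    # hypothetical path
        ref_body_param_path='./results/example/body_params.npz',          # hypothetical path
        ref_head_param_path='./results/example/head_params.npz',          # hypothetical path
        smplx2faceverse_path='./results/example/smplx_to_faceverse.npy',  # hypothetical path
        result_folder='./results/example')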