# Hosting-page residue (not code): uploaded by pengc02, commit ec9a6bc,
# file size 7.27 kB.
import os
import numpy as np
import torch
import torch.nn.functional as F
import cv2 as cv
import trimesh
import yaml
import tqdm
import smplx
from network.volume import CanoBlendWeightVolume
from utils.renderer import Renderer
import config
# Opt in to OpenCV's OpenEXR codec; the .exr maps in this script are written
# with cv.imwrite. NOTE(review): some OpenCV builds are reported to read this
# flag only when cv2 is first imported -- consider setting it above `import cv2`.
os.environ.update(OPENCV_IO_ENABLE_OPENEXR="1")
def save_pos_map(pos_map, path):
    """Export the occupied texels of a position map as a point-cloud file.

    A texel is treated as occupied when the vector stored along the last axis
    has a non-zero Euclidean norm; all other texels are discarded. The number
    of exported points is printed before writing.

    Args:
        pos_map: numpy array whose last axis holds a position per texel
            (presumably (H, W, 3) -- TODO confirm against callers).
        path: destination file, handed to ``trimesh.PointCloud.export``.
    """
    occupied = np.linalg.norm(pos_map, axis = -1) > 0.
    points = pos_map[occupied]
    print('Point nums %d' % points.shape[0])
    trimesh.PointCloud(points).export(path)
def interpolate_lbs(pts, vertices, faces, vertex_lbs):
    """Interpolate per-vertex skinning weights onto arbitrary query points.

    Each query point is matched to a mesh face with ``nearest_face_pytorch3d``,
    which also yields barycentric coordinates. ``barycentric_interpolate`` then
    blends the weights of that face's vertices with those coordinates.

    Args:
        pts: numpy array of query points (presumably (N, 3) -- TODO confirm).
        vertices: numpy array of mesh vertex positions.
        faces: numpy integer array of triangle vertex indices.
        vertex_lbs: per-vertex weights, either a torch tensor or a numpy array
            (presumably (V, J) -- TODO confirm).

    Returns:
        numpy array of interpolated weights, one row per query point.

    Requires a CUDA device (all inputs are moved with ``.cuda()``).
    """
    from utils.posevocab_custom_ops.nearest_face import nearest_face_pytorch3d
    from utils.geo_util import barycentric_interpolate

    # Upload the face indices once; both the nearest-face query and the
    # interpolation need them on the GPU.
    faces_t = torch.from_numpy(faces).to(torch.int64).cuda()
    # The result is converted to numpy, so no autograd graph is needed (and
    # .numpy() would fail on a tensor that requires grad).
    with torch.no_grad():
        _dists, face_ids, bc_coords = nearest_face_pytorch3d(
            torch.from_numpy(pts).to(torch.float32).cuda()[None],
            torch.from_numpy(vertices).to(torch.float32).cuda()[None],
            faces_t
        )
        lbs = barycentric_interpolate(
            # as_tensor accepts numpy arrays and returns torch tensors unchanged.
            vert_attris = torch.as_tensor(vertex_lbs)[None].to(torch.float32).cuda(),
            faces = faces_t[None],
            face_ids = face_ids,
            bc_coords = bc_coords
        )
    return lbs[0].cpu().numpy()
# Side length (pixels) of each canonical front/back render. Canonical maps are
# map_size x (2 * map_size); per-frame maps are downsampled by 2.
map_size = 1024


def _view_matrix(rot, center, distance = 10.):
    """Build a 4x4 model-view matrix for an orthographic-style canonical view.

    The mesh is rotated by ``rot`` (3x3) and translated so that ``center`` lands
    at depth ``-distance`` on the view axis. Rows 1-2 are then negated, which
    flips the y and z axes into the renderer's camera convention.
    """
    mv = np.identity(4, np.float32)
    mv[:3, :3] = rot
    mv[:3, 3] = -rot @ center + np.array([0, 0, -distance], np.float32)
    mv[1:3] *= -1
    return mv


def _render_front_back(renderer, vertices, attributes, front_mv, back_mv):
    """Render ``attributes`` from the front and back cameras into one map.

    The back render is mirrored horizontally, then placed to the right of the
    front render (concatenated along axis 1). Only the first 3 channels of
    each render are kept.
    """
    renderer.set_model(vertices, attributes)
    renderer.set_camera(front_mv)
    front = renderer.render()[:, :, :3]
    renderer.set_camera(back_mv)
    back = cv.flip(renderer.render()[:, :, :3], 1)
    return np.concatenate([front, back], 1)


def _write_exr(path, image):
    """Write ``image`` with cv.imwrite, raising instead of failing silently."""
    if not cv.imwrite(path, image):
        raise IOError('Failed to write %s' % path)


# Script entry point. Writes into <data_dir>/smpl_pos_map:
#   cano_smpl_pos_map.exr, cano_smpl_nml_map.exr  canonical position/normal maps
#   init_pts_lbs.npy                              skinning weights of map texels
#   %08d.exr                                      per-frame posed position maps
if __name__ == '__main__':
    from argparse import ArgumentParser
    import importlib

    arg_parser = ArgumentParser()
    arg_parser.add_argument('-c', '--config_path', type = str, required = True, help = 'Configuration file path.')
    args = arg_parser.parse_args()
    # FullLoader is kept so existing configs parse unchanged; the config is
    # assumed to be a trusted local file.
    with open(args.config_path, encoding = 'UTF-8') as config_file:
        opt = yaml.load(config_file, Loader = yaml.FullLoader)

    dataset_module = opt['train'].get('dataset', 'MvRgbDatasetAvatarReX')
    MvRgbDataset = getattr(importlib.import_module('dataset.dataset_mv_rgb'), dataset_module)
    dataset = MvRgbDataset(**opt['train']['data'])
    data_dir, frame_list = dataset.data_dir, dataset.pose_list
    out_dir = data_dir + '/smpl_pos_map'
    os.makedirs(out_dir, exist_ok = True)

    cano_renderer = Renderer(map_size, map_size, shader_name = 'vertex_attribute')
    smpl_model = smplx.SMPLX(config.PROJ_DIR + '/smpl_files/smplx', gender = 'neutral', use_pca = False, num_pca_comps = 45, flat_hand_mean = True, batch_size = 1)
    with np.load(data_dir + '/smpl_params.npz') as smpl_npz:
        smpl_data = {k: torch.from_numpy(v.astype(np.float32)) for k, v in smpl_npz.items()}

    with torch.no_grad():
        cano_smpl = smpl_model.forward(
            betas = smpl_data['betas'],
            global_orient = config.cano_smpl_global_orient[None],
            transl = config.cano_smpl_transl[None],
            body_pose = config.cano_smpl_body_pose[None]
        )
    cano_smpl_v = cano_smpl.vertices[0].cpu().numpy()
    # The cameras are centred on the canonical SMPL-X bounding box, even when a
    # custom template replaces the geometry below.
    cano_center = 0.5 * (cano_smpl_v.min(0) + cano_smpl_v.max(0))
    smpl_faces = smpl_model.faces.astype(np.int64)

    if os.path.exists(data_dir + '/template.ply'):
        print('# Loading template from %s' % (data_dir + '/template.ply'))
        template = trimesh.load(data_dir + '/template.ply', process = False)
        using_template = True
    else:
        print(f'# Cannot find template.ply from {data_dir}, using SMPL-X as template')
        template = trimesh.Trimesh(cano_smpl_v, smpl_faces, process = False)
        using_template = False

    cano_smpl_v = template.vertices.astype(np.float32)
    smpl_faces = template.faces.astype(np.int64)
    # Per-corner (face-unrolled) vertex positions and normals for the renderer.
    cano_smpl_v_dup = cano_smpl_v[smpl_faces.reshape(-1)]
    cano_smpl_n_dup = template.vertex_normals.astype(np.float32)[smpl_faces.reshape(-1)]

    # Front view looks along the default axis; the back view is rotated 180
    # degrees about y.
    front_mv = _view_matrix(np.identity(3, np.float32), cano_center)
    back_mv = _view_matrix(cv.Rodrigues(np.array([0, np.pi, 0], np.float32))[0], cano_center)

    # Canonical position map: each texel stores the canonical 3D position.
    cano_pos_map = _render_front_back(cano_renderer, cano_smpl_v_dup, cano_smpl_v_dup, front_mv, back_mv)
    _write_exr(out_dir + '/cano_smpl_pos_map.exr', cano_pos_map)
    # Canonical normal map: each texel stores the interpolated vertex normal.
    cano_nml_map = _render_front_back(cano_renderer, cano_smpl_v_dup, cano_smpl_n_dup, front_mv, back_mv)
    _write_exr(out_dir + '/cano_smpl_nml_map.exr', cano_nml_map)

    # Texels with a zero position are background.
    body_mask = np.linalg.norm(cano_pos_map, axis = -1) > 0.
    cano_pts = cano_pos_map[body_mask]
    if using_template:
        weight_volume = CanoBlendWeightVolume(data_dir + '/cano_weight_volume.npz')
        pts_lbs = weight_volume.forward_weight(torch.from_numpy(cano_pts)[None].cuda())[0]
    else:
        pts_lbs = interpolate_lbs(cano_pts, cano_smpl_v, smpl_faces, smpl_model.lbs_weights)
        pts_lbs = torch.from_numpy(pts_lbs).cuda()
    np.save(out_dir + '/init_pts_lbs.npy', pts_lbs.cpu().numpy())

    inv_cano_smpl_A = torch.linalg.inv(cano_smpl.A).cuda()
    body_mask = torch.from_numpy(body_mask).cuda()
    cano_pts = torch.from_numpy(cano_pts).cuda()
    pts_lbs = pts_lbs.cuda()
    for pose_idx in tqdm.tqdm(frame_list, desc = 'Generating positional maps...'):
        with torch.no_grad():
            # global_orient, transl and the hand poses are deliberately not
            # passed, so the model's defaults are used (presumably giving a
            # root-relative "woRoot" pose -- TODO confirm intent).
            live_smpl_woRoot = smpl_model.forward(
                betas = smpl_data['betas'],
                body_pose = smpl_data['body_pose'][pose_idx][None],
                jaw_pose = smpl_data['jaw_pose'][pose_idx][None],
                expression = smpl_data['expression'][pose_idx][None]
            )
            # Per-joint canonical -> live transforms, blended per texel by LBS.
            cano2live_jnt_mats_woRoot = torch.matmul(live_smpl_woRoot.A.cuda(), inv_cano_smpl_A)[0]
            pt_mats = torch.einsum('nj,jxy->nxy', pts_lbs, cano2live_jnt_mats_woRoot)
            live_pts = torch.einsum('nxy,ny->nx', pt_mats[..., :3, :3], cano_pts) + pt_mats[..., :3, 3]
            # Allocate directly with live_pts' dtype/device.
            live_pos_map = live_pts.new_zeros((map_size, 2 * map_size, 3))
            live_pos_map[body_mask] = live_pts
            live_pos_map = F.interpolate(live_pos_map.permute(2, 0, 1)[None], scale_factor = (0.5, 0.5), mode = 'nearest')[0]
        _write_exr(out_dir + '/%08d.exr' % pose_idx, live_pos_map.permute(1, 2, 0).cpu().numpy())