import glob import os import pickle import numpy as np import cv2 as cv import torch import trimesh from torch.utils.data import Dataset import yaml import json import AnimatableGaussians.smplx as smplx import AnimatableGaussians.dataset.commons as commons import AnimatableGaussians.utils.nerf_util as nerf_util import AnimatableGaussians.utils.visualize_util as visualize_util import AnimatableGaussians.config as config class PoseDataset(Dataset): @torch.no_grad() def __init__( self, data_path, frame_range = None, frame_interval = 1, smpl_shape = None, gender = 'neutral', frame_win = 0, fix_head_pose = True, fix_hand_pose = True, denoise = False, hand_pose_type = 'ori', constrain_leg_pose = False, device = 'cuda:0' ): super(PoseDataset, self).__init__() self.data_path = data_path self.training = False self.gender = gender data_name, ext = os.path.splitext(os.path.basename(data_path)) print(data_name) if ext == '.pkl': smpl_data = pickle.load(open(data_path, 'rb')) smpl_data = dict(smpl_data) self.body_poses = torch.from_numpy(smpl_data['smpl_poses']).to(torch.float32) self.transl = torch.from_numpy(smpl_data['smpl_trans']).to(torch.float32) * 1e-3 self.dataset_name = 'aist++' self.seq_name = data_name elif ext == '.npz': potential_datasets = ['thuman4', 'actorshq', 'avatarrex', 'AMASS'] for i, potential_dataset in enumerate(potential_datasets): start_pos = data_path.find(potential_dataset) if start_pos == -1: if i < len(potential_datasets) - 1: continue else: raise ValueError('Invalid data_path!') self.dataset_name = potential_dataset self.seq_name = data_path[start_pos:].replace(self.dataset_name, '').replace('/', '_').replace('\\', '_').replace('.npz', '') break # print(self.dataset_name) # print(f'# Dataset name: {self.dataset_name}, sequence name: {self.seq_name}') if self.dataset_name == 'thuman4' or self.dataset_name == 'actorshq' or self.dataset_name == 'avatarrex': smpl_data = np.load(data_path) # if smpl_data.shape[1] == 156: # # build dict # smpl_data = { # 'betas': smpl_data[:, :10], # 'global_orient': smpl_data[:, 10:13], # 'transl': smpl_data[:, 13:16], # 'body_pose': smpl_data[:, 16:88], # 'left_hand_pose': smpl_data[:, 88:133], # 'right_hand_pose': smpl_data[:, 133:] # } smpl_data = dict(smpl_data) for k in smpl_data.keys(): print(k, smpl_data[k].shape) else: # AMASS dataset pose_file = np.load(data_path) smpl_data = { 'betas': np.zeros((1, 10), np.float32), 'global_orient': pose_file['poses'][:, :3], 'transl': pose_file['trans'], 'body_pose': pose_file['poses'][:, 3: 22 * 3], 'left_hand_pose': pose_file['poses'][:, 22 * 3: 37 * 3], 'right_hand_pose': pose_file['poses'][:, 37 * 3:] } # smpl_data['body_pose'][:, 13 * 3 + 2] -= 0.3 # smpl_data['body_pose'][:, 12 * 3 + 2] += 0.3 # # smpl_data['body_pose'][:, 16 * 3 + 2] -= 0.1 # # smpl_data['body_pose'][:, 15 * 3 + 2] += 0.1 # smpl_data['body_pose'][:, 19 * 3: 20 * 3] = 0. # smpl_data['body_pose'][:, 20 * 3: 21 * 3] = 0. # smpl_data['body_pose'][:, 14 * 3] = 0. # print(smpl_data['body_pose'].shape) if self.seq_name == '_actor01': smpl_data['body_pose'][:, 6*3: 7*3] = 0. smpl_data['body_pose'][:, 7*3: 8*3] = 0. smpl_data = {k: torch.from_numpy(v).to(torch.float32) for k, v in smpl_data.items()} frame_num = smpl_data['body_pose'].shape[0] self.body_poses = torch.zeros((frame_num, 72), dtype = torch.float32) self.body_poses[:, :3] = smpl_data['global_orient'] self.body_poses[:, 3:3+21*3] = smpl_data['body_pose'] self.transl = smpl_data['transl'] # print(self.body_poses) data_dir = os.path.dirname(data_path) calib_path = os.path.basename(data_path).replace('.npz', '.json').replace('pose', 'calibration') calib_path = data_dir + '/' + calib_path if os.path.exists(calib_path): cam_data = json.load(open(calib_path, 'r')) self.view_num = len(cam_data) self.extr_mats = [] self.cam_names = list(cam_data.keys()) for view_idx in range(self.view_num): extr_mat = np.identity(4, np.float32) extr_mat[:3, :3] = np.array(cam_data[self.cam_names[view_idx]]['R'], np.float32).reshape(3, 3) extr_mat[:3, 3] = np.array(cam_data[self.cam_names[view_idx]]['T'], np.float32) self.extr_mats.append(extr_mat) self.intr_mats = [np.array(cam_data[self.cam_names[view_idx]]['K'], np.float32).reshape(3, 3) for view_idx in range(self.view_num)] self.img_heights = [cam_data[self.cam_names[view_idx]]['imgSize'][1] for view_idx in range(self.view_num)] self.img_widths = [cam_data[self.cam_names[view_idx]]['imgSize'][0] for view_idx in range(self.view_num)] else: raise AssertionError('Invalid data_path!') if 'left_hand_pose' in smpl_data: self.left_hand_pose = smpl_data['left_hand_pose'] else: self.left_hand_pose = config.left_hand_pose[None].expand(self.body_poses.shape[0], -1) if 'right_hand_pose' in smpl_data: self.right_hand_pose = smpl_data['right_hand_pose'] else: self.right_hand_pose = config.right_hand_pose[None].expand(self.body_poses.shape[0], -1) self.body_poses = self.body_poses.to(device) self.transl = self.transl.to(device) self.fix_head_pose = fix_head_pose self.fix_hand_pose = fix_hand_pose self.smpl_model = smplx.SMPLX(model_path = config.PROJ_DIR + '/smpl_files/smplx', gender = self.gender, use_pca = False, num_pca_comps = 45, flat_hand_mean = True, batch_size = 1).to(device) pose_list = list(range(0, self.body_poses.shape[0], frame_interval)) if frame_range is not None: frame_range = list(frame_range) if isinstance(frame_range, list): if isinstance(frame_range[0], list): self.pose_list = [] for interval in frame_range: if len(interval) == 2 or len(interval) == 3: self.pose_list += list(range(*interval)) else: for i in range(interval[3]): self.pose_list += list(range(interval[0], interval[1], interval[2])) else: if len(frame_range) == 2: print(f'# Selected frame indices: range({frame_range[0]}, {frame_range[1]})') frame_range = range(frame_range[0], frame_range[1]) elif len(frame_range) == 3: print(f'# Selected frame indices: range({frame_range[0]}, {frame_range[1]}, {frame_range[2]})') frame_range = range(frame_range[0], frame_range[1], frame_range[2]) self.pose_list = list(frame_range) else: self.pose_list = pose_list print('# Pose list: ', self.pose_list) print('# Dataset contains %d items' % len(self)) # SMPL related self.smpl_shape = smpl_shape.to(torch.float32).to(device) if smpl_shape is not None else torch.zeros(10, dtype = torch.float32) ret = self.smpl_model.forward(betas = self.smpl_shape[None], global_orient = config.cano_smpl_global_orient[None].to(device), transl = config.cano_smpl_transl[None].to(device), body_pose = config.cano_smpl_body_pose[None].to(device), # left_hand_pose = config.left_hand_pose[None], # right_hand_pose = config.right_hand_pose[None] ) self.cano_smpl = {k: v[0] for k, v in ret.items() if isinstance(v, torch.Tensor)} self.inv_cano_jnt_mats = torch.linalg.inv(self.cano_smpl['A']) min_xyz = self.cano_smpl['vertices'].min(0)[0] max_xyz = self.cano_smpl['vertices'].max(0)[0] self.cano_smpl_center = 0.5 * (min_xyz + max_xyz) min_xyz[:2] -= 0.05 max_xyz[:2] += 0.05 min_xyz[2] -= 0.15 max_xyz[2] += 0.15 self.cano_bounds = torch.stack([min_xyz, max_xyz], 0).to(torch.float32).cpu().numpy() self.smpl_faces = self.smpl_model.faces.astype(np.int32) self.frame_win = int(frame_win) self.denoise = denoise if self.denoise: win_size = 1 body_poses_clone = self.body_poses.clone() transl_clone = self.transl.clone() frame_num = body_poses_clone.shape[0] self.body_poses[win_size: frame_num-win_size] = 0 self.transl[win_size: frame_num-win_size] = 0 for i in range(-win_size, win_size + 1): self.body_poses[win_size: frame_num-win_size] += body_poses_clone[win_size+i: frame_num-win_size+i] self.transl[win_size: frame_num-win_size] += transl_clone[win_size+i: frame_num-win_size+i] self.body_poses[win_size: frame_num-win_size] /= (2 * win_size + 1) self.transl[win_size: frame_num-win_size] /= (2 * win_size + 1) self.hand_pose_type = hand_pose_type self.device = device self.last_data_idx = 0 commons._initialize_hands(self) self.left_cano_mano_v, self.left_cano_mano_n, self.right_cano_mano_v, self.right_cano_mano_n \ = commons.generate_two_manos(self, self.cano_smpl['vertices']) if constrain_leg_pose: # a = 14. # # print(self.body_poses[284, 1*3:2*3]) # # print(self.body_poses[284, 2*3:3*3]) # self.body_poses[:, 1*3] = torch.clip(self.body_poses[:, 1 * 3], -np.pi / a, np.pi / a) # self.body_poses[:, 2*3] = torch.clip(self.body_poses[:, 2 * 3], -np.pi / a, np.pi / a) # self.body_poses[:, 1 * 3+2] = torch.clip(self.body_poses[:, 1 * 3+2], -np.pi / a, np.pi / a) # self.body_poses[:, 2 * 3+2] = torch.clip(self.body_poses[:, 2 * 3+2], -np.pi / a, np.pi / a) # exit(1) self.body_poses[:, 4*3] = torch.clip(self.body_poses[:, 4*3], -0.3, 0.3) self.body_poses[:, 5*3] = torch.clip(self.body_poses[:, 5*3], -0.3, 0.3) def __len__(self): return len(self.pose_list) def __getitem__(self, index): return self.getitem(index) @torch.no_grad() def getitem(self, index, **kwargs): pose_idx = self.pose_list[index] if pose_idx == 0 or pose_idx > self.pose_list[min(index - 1, 0)]: data_idx = pose_idx else: data_idx = self.last_data_idx + 1 # print('data index: %d, pose index: %d' % (data_idx, pose_idx)) if self.hand_pose_type == 'fist': left_hand_pose = config.left_hand_pose.to(self.device).clone() right_hand_pose = config.right_hand_pose.to(self.device).clone() left_hand_pose[:3] = 0. right_hand_pose[:3] = 0. elif self.hand_pose_type == 'normal': left_hand_pose = torch.tensor([0.10859203338623047, 0.10181399434804916, -0.2822268009185791, 0.10211331397294998, -0.09689036756753922, -0.4484838545322418, -0.11360692232847214, -0.023141659796237946, 0.10571160167455673, -0.08793719857931137, -0.026760095730423927, -0.41390693187713623, -0.0923849567770958, 0.10266668349504471, -0.36039748787879944, 0.02140655182301998, -0.07156527787446976, -0.04903153330087662, -0.22358819842338562, -0.3716682195663452, -0.2683027982711792, -0.1506909281015396, 0.07079305499792099, -0.34404537081718445, -0.168443500995636, -0.014021224342286587, 0.09489774703979492, -0.050323735922575, -0.18992969393730164, -0.43895423412323, -0.1806418001651764, 0.0198075994849205, -0.25444355607032776, -0.10171788930892944, -0.10680688172578812, -0.09953738003969193, 0.8094075918197632, 0.5156061053276062, -0.07900168001651764, -0.45094889402389526, 0.24947893619537354, 0.23369410634040833, 0.45277315378189087, -0.17375235259532928, -0.3077943027019501], dtype = torch.float32, device = self.device) right_hand_pose = torch.tensor([0.06415501981973648, -0.06942438334226608, 0.282951682806015, 0.09073827415704727, 0.0775153785943985, 0.2961004376411438, -0.07659692317247391, 0.004730052314698696, -0.12084470689296722, 0.007974660955369473, 0.05222926288843155, 0.32775357365608215, -0.10166633129119873, -0.06862349808216095, 0.174485981464386, -0.0023323255591094494, 0.04998664930462837, -0.03490559384226799, 0.12949667870998383, 0.26883721351623535, 0.06881044059991837, -0.18259745836257935, -0.08183271437883377, 0.17669665813446045, -0.08099694550037384, 0.04115655645728111, -0.17928685247898102, 0.07734024524688721, 0.13419172167778015, 0.2600148022174835, -0.151871919631958, -0.01772170141339302, 0.1267814189195633, -0.08800505846738815, 0.09480107575654984, 0.0016392067773267627, 0.6149336695671082, -0.32634419202804565, 0.02278662845492363, -0.39148610830307007, -0.22757330536842346, -0.07884717732667923, 0.38199105858802795, 0.13064607977867126, 0.20154500007629395], dtype = torch.float32, device = self.device) elif self.hand_pose_type == 'zero': left_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) right_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) elif self.hand_pose_type == 'ori': left_hand_pose = self.left_hand_pose[pose_idx].to(self.device) right_hand_pose = self.right_hand_pose[pose_idx].to(self.device) else: raise ValueError('Invalid hand_pose_type!') # SMPL live_smpl = self.smpl_model.forward(betas = self.smpl_shape[None], global_orient = self.body_poses[pose_idx, :3][None], transl = self.transl[pose_idx][None], body_pose = self.body_poses[pose_idx, 3: 66][None], left_hand_pose = left_hand_pose[None], right_hand_pose = right_hand_pose[None] ) # live_smpl_trimesh = trimesh.Trimesh(vertices = live_smpl.vertices[0].cpu().numpy(), faces = self.smpl_model.faces, process = False) # live_smpl_trimesh.export('./debug/smpl_amass.ply') # exit(1) live_smpl_woRoot = self.smpl_model.forward(betas = self.smpl_shape[None], # global_orient = self.body_poses[pose_idx, :3][None], # transl = self.transl[pose_idx][None], body_pose = self.body_poses[pose_idx, 3: 66][None], # left_hand_pose = config.left_hand_pose[None], # right_hand_pose = config.right_hand_pose[None] ) # cano_smpl = self.smpl_model.forward(betas=self.smpl_shape[None], # global_orient=config.cano_smpl_global_orient[None], # transl=config.cano_smpl_transl[None], # body_pose=config.cano_smpl_body_pose[None], # # left_hand_pose = left_hand_pose[None], # # right_hand_pose = right_hand_pose[None] # ) data_item = dict() data_item['item_idx'] = index data_item['data_idx'] = data_idx data_item['global_orient'] = self.body_poses[pose_idx, :3] data_item['transl'] = self.transl[pose_idx] data_item['joints'] = live_smpl.joints[0, :22] data_item['kin_parent'] = self.smpl_model.parents[:22].to(torch.long) data_item['pose_1st'] = self.body_poses[0, 3: 66] if self.frame_win > 0: total_frame_num = len(self.pose_list) selected_frames = self.pose_list[max(0, index - self.frame_win): min(total_frame_num, index + self.frame_win + 1)] data_item['pose'] = self.body_poses[selected_frames, 3: 66].clone() else: data_item['pose'] = self.body_poses[pose_idx, 3: 66].clone() if self.fix_head_pose: data_item['pose'][..., 3 * 11: 3 * 11 + 3] = 0. data_item['pose'][..., 3 * 14: 3 * 14 + 3] = 0. if self.fix_hand_pose: data_item['pose'][..., 3 * 19: 3 * 19 + 3] = 0. data_item['pose'][..., 3 * 20: 3 * 20 + 3] = 0. data_item['lhand_pose'] = torch.zeros_like(config.left_hand_pose) data_item['rhand_pose'] = torch.zeros_like(config.right_hand_pose) data_item['time_stamp'] = np.array(pose_idx, np.float32) data_item['live_smpl_v'] = live_smpl.vertices[0] data_item['live_smpl_v_woRoot'] = live_smpl_woRoot.vertices[0] data_item['cano_smpl_v'] = self.cano_smpl['vertices'] data_item['cano_jnts'] = self.cano_smpl['joints'] inv_cano_jnt_mats = torch.linalg.inv(self.cano_smpl['A']) data_item['cano2live_jnt_mats'] = torch.matmul(live_smpl.A[0], inv_cano_jnt_mats) data_item['cano2live_jnt_mats_woRoot'] = torch.matmul(live_smpl_woRoot.A[0], inv_cano_jnt_mats) data_item['cano_smpl_center'] = self.cano_smpl_center data_item['cano_bounds'] = self.cano_bounds data_item['smpl_faces'] = self.smpl_faces min_xyz = live_smpl.vertices[0].min(0)[0] - 0.15 max_xyz = live_smpl.vertices[0].max(0)[0] + 0.15 live_bounds = torch.stack([min_xyz, max_xyz], 0).to(torch.float32).cpu().numpy() data_item['live_bounds'] = live_bounds # # mano # data_item['left_cano_mano_v'], data_item['left_cano_mano_n'], data_item['right_cano_mano_v'], data_item['right_cano_mano_n']\ # = commons.generate_two_manos(self, self.cano_smpl['vertices']) # data_item['left_live_mano_v'], data_item['left_live_mano_n'], data_item['right_live_mano_v'], data_item['right_live_mano_n'] \ # = commons.generate_two_manos(self, live_smpl.vertices[0]) """ synthesis config """ img_h = 512 if 'img_h' not in kwargs else kwargs['img_h'] img_w = 512 if 'img_w' not in kwargs else kwargs['img_w'] intr = np.array([[550, 0, 256], [0, 550, 256], [0, 0, 1]], np.float32) if 'intr' not in kwargs else kwargs['intr'] if 'extr' not in kwargs: extr = visualize_util.calc_front_mv(live_bounds.mean(0), tar_pos = np.array([0, 0, 2.5])) else: extr = kwargs['extr'] """ training data config of view_idx """ # view_idx = 0 # img_h = self.img_heights[view_idx] # img_w = self.img_widths[view_idx] # intr = self.intr_mats[view_idx] # extr = self.extr_mats[view_idx] uv = self.gen_uv(img_w, img_h) uv = uv.reshape(-1, 2) ray_d, ray_o = nerf_util.get_rays(uv, extr, intr) near, far, mask_at_bound = nerf_util.get_near_far(live_bounds, ray_o, ray_d) uv = uv[mask_at_bound] ray_o = ray_o[mask_at_bound] ray_d = ray_d[mask_at_bound] data_item.update({ 'uv': uv, 'ray_o': ray_o, 'ray_d': ray_d, 'near': near, 'far': far, 'dist': np.zeros_like(near), 'img_h': img_h, 'img_w': img_w, 'extr': extr, 'intr': intr }) return data_item def getitem_fast(self, index, **kwargs): pose_idx = self.pose_list[index] if pose_idx == 0 or pose_idx > self.last_data_idx: data_idx = pose_idx else: data_idx = self.last_data_idx + 1 # print('data index: %d, pose index: %d' % (data_idx, pose_idx)) if self.hand_pose_type == 'fist': left_hand_pose = config.left_hand_pose.to(self.device) right_hand_pose = config.right_hand_pose.to(self.device) elif self.hand_pose_type == 'normal': left_hand_pose = torch.tensor( [0.10859203338623047, 0.10181399434804916, -0.2822268009185791, 0.10211331397294998, -0.09689036756753922, -0.4484838545322418, -0.11360692232847214, -0.023141659796237946, 0.10571160167455673, -0.08793719857931137, -0.026760095730423927, -0.41390693187713623, -0.0923849567770958, 0.10266668349504471, -0.36039748787879944, 0.02140655182301998, -0.07156527787446976, -0.04903153330087662, -0.22358819842338562, -0.3716682195663452, -0.2683027982711792, -0.1506909281015396, 0.07079305499792099, -0.34404537081718445, -0.168443500995636, -0.014021224342286587, 0.09489774703979492, -0.050323735922575, -0.18992969393730164, -0.43895423412323, -0.1806418001651764, 0.0198075994849205, -0.25444355607032776, -0.10171788930892944, -0.10680688172578812, -0.09953738003969193, 0.8094075918197632, 0.5156061053276062, -0.07900168001651764, -0.45094889402389526, 0.24947893619537354, 0.23369410634040833, 0.45277315378189087, -0.17375235259532928, -0.3077943027019501], dtype = torch.float32, device = self.device) right_hand_pose = torch.tensor( [0.06415501981973648, -0.06942438334226608, 0.282951682806015, 0.09073827415704727, 0.0775153785943985, 0.2961004376411438, -0.07659692317247391, 0.004730052314698696, -0.12084470689296722, 0.007974660955369473, 0.05222926288843155, 0.32775357365608215, -0.10166633129119873, -0.06862349808216095, 0.174485981464386, -0.0023323255591094494, 0.04998664930462837, -0.03490559384226799, 0.12949667870998383, 0.26883721351623535, 0.06881044059991837, -0.18259745836257935, -0.08183271437883377, 0.17669665813446045, -0.08099694550037384, 0.04115655645728111, -0.17928685247898102, 0.07734024524688721, 0.13419172167778015, 0.2600148022174835, -0.151871919631958, -0.01772170141339302, 0.1267814189195633, -0.08800505846738815, 0.09480107575654984, 0.0016392067773267627, 0.6149336695671082, -0.32634419202804565, 0.02278662845492363, -0.39148610830307007, -0.22757330536842346, -0.07884717732667923, 0.38199105858802795, 0.13064607977867126, 0.20154500007629395], dtype = torch.float32, device = self.device) elif self.hand_pose_type == 'zero': left_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) right_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) elif self.hand_pose_type == 'ori': left_hand_pose = self.left_hand_pose[pose_idx].to(self.device) right_hand_pose = self.right_hand_pose[pose_idx].to(self.device) else: raise ValueError('Invalid hand_pose_type!') # SMPL live_smpl = self.smpl_model.forward(betas = self.smpl_shape[None], global_orient = self.body_poses[pose_idx, :3][None], transl = self.transl[pose_idx][None], body_pose = self.body_poses[pose_idx, 3: 66][None], left_hand_pose = left_hand_pose[None], right_hand_pose = right_hand_pose[None] ) live_smpl_woRoot = self.smpl_model.forward(betas = self.smpl_shape[None], # global_orient = self.body_poses[pose_idx, :3][None], # transl = self.transl[pose_idx][None], body_pose = self.body_poses[pose_idx, 3: 66][None], # left_hand_pose = config.left_hand_pose[None], # right_hand_pose = config.right_hand_pose[None] ) # cano_smpl = self.smpl_model.forward(betas = self.smpl_shape[None], # global_orient = config.cano_smpl_global_orient[None], # transl = config.cano_smpl_transl[None], # body_pose = config.cano_smpl_body_pose[None], # # left_hand_pose = left_hand_pose[None], # # right_hand_pose = right_hand_pose[None] # ) data_item = dict() data_item['item_idx'] = index data_item['data_idx'] = data_idx data_item['global_orient'] = self.body_poses[pose_idx, :3] data_item['body_pose'] = self.body_poses[pose_idx, 3:66] data_item['transl'] = self.transl[pose_idx] data_item['joints'] = live_smpl.joints[0, :22] data_item['kin_parent'] = self.smpl_model.parents[:22].to(torch.long) data_item['live_smpl_v'] = live_smpl.vertices[0] data_item['live_smpl_v_woRoot'] = live_smpl_woRoot.vertices[0] data_item['cano_smpl_v'] = self.cano_smpl['vertices'] data_item['cano_jnts'] = self.cano_smpl['joints'] inv_cano_jnt_mats = torch.linalg.inv(self.cano_smpl['A']) data_item['cano2live_jnt_mats'] = torch.matmul(live_smpl.A[0], inv_cano_jnt_mats) data_item['cano2live_jnt_mats_woRoot'] = torch.matmul(live_smpl_woRoot.A[0], inv_cano_jnt_mats) data_item['cano_smpl_center'] = self.cano_smpl_center data_item['cano_bounds'] = self.cano_bounds data_item['smpl_faces'] = self.smpl_faces min_xyz = live_smpl.vertices[0].min(0)[0] - 0.15 max_xyz = live_smpl.vertices[0].max(0)[0] + 0.15 live_bounds = torch.stack([min_xyz, max_xyz], 0).to(torch.float32).cpu().numpy() data_item['live_bounds'] = live_bounds data_item['left_cano_mano_v'], data_item['left_cano_mano_n'], data_item['right_cano_mano_v'], data_item['right_cano_mano_n'] \ = self.left_cano_mano_v, self.left_cano_mano_n, self.right_cano_mano_v, self.right_cano_mano_n """ synthesis config """ img_h = 512 if 'img_h' not in kwargs else kwargs['img_h'] img_w = 512 if 'img_w' not in kwargs else kwargs['img_w'] intr = np.array([[550, 0, 256], [0, 550, 256], [0, 0, 1]], np.float32) if 'intr' not in kwargs else kwargs['intr'] if 'extr' not in kwargs: extr = visualize_util.calc_front_mv(live_bounds.mean(0), tar_pos = np.array([0, 0, 2.5])) else: extr = kwargs['extr'] data_item.update({ 'img_h': img_h, 'img_w': img_w, 'extr': extr, 'intr': intr }) self.last_data_idx = data_idx return data_item def getitem_a_pose(self, **kwargs): hand_pose_type = 'fist' if hand_pose_type == 'fist': left_hand_pose = config.left_hand_pose.to(self.device) right_hand_pose = config.right_hand_pose.to(self.device) elif hand_pose_type == 'normal': left_hand_pose = torch.tensor( [0.10859203338623047, 0.10181399434804916, -0.2822268009185791, 0.10211331397294998, -0.09689036756753922, -0.4484838545322418, -0.11360692232847214, -0.023141659796237946, 0.10571160167455673, -0.08793719857931137, -0.026760095730423927, -0.41390693187713623, -0.0923849567770958, 0.10266668349504471, -0.36039748787879944, 0.02140655182301998, -0.07156527787446976, -0.04903153330087662, -0.22358819842338562, -0.3716682195663452, -0.2683027982711792, -0.1506909281015396, 0.07079305499792099, -0.34404537081718445, -0.168443500995636, -0.014021224342286587, 0.09489774703979492, -0.050323735922575, -0.18992969393730164, -0.43895423412323, -0.1806418001651764, 0.0198075994849205, -0.25444355607032776, -0.10171788930892944, -0.10680688172578812, -0.09953738003969193, 0.8094075918197632, 0.5156061053276062, -0.07900168001651764, -0.45094889402389526, 0.24947893619537354, 0.23369410634040833, 0.45277315378189087, -0.17375235259532928, -0.3077943027019501], dtype = torch.float32, device = self.device) right_hand_pose = torch.tensor( [0.06415501981973648, -0.06942438334226608, 0.282951682806015, 0.09073827415704727, 0.0775153785943985, 0.2961004376411438, -0.07659692317247391, 0.004730052314698696, -0.12084470689296722, 0.007974660955369473, 0.05222926288843155, 0.32775357365608215, -0.10166633129119873, -0.06862349808216095, 0.174485981464386, -0.0023323255591094494, 0.04998664930462837, -0.03490559384226799, 0.12949667870998383, 0.26883721351623535, 0.06881044059991837, -0.18259745836257935, -0.08183271437883377, 0.17669665813446045, -0.08099694550037384, 0.04115655645728111, -0.17928685247898102, 0.07734024524688721, 0.13419172167778015, 0.2600148022174835, -0.151871919631958, -0.01772170141339302, 0.1267814189195633, -0.08800505846738815, 0.09480107575654984, 0.0016392067773267627, 0.6149336695671082, -0.32634419202804565, 0.02278662845492363, -0.39148610830307007, -0.22757330536842346, -0.07884717732667923, 0.38199105858802795, 0.13064607977867126, 0.20154500007629395], dtype = torch.float32, device = self.device) elif self.hand_pose_type == 'zero': left_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) right_hand_pose = torch.zeros(45, dtype = torch.float32, device = self.device) else: raise ValueError('Invalid hand_pose_type!') body_pose = torch.zeros(21 * 3, dtype = torch.float32).to(self.device) body_pose[15 * 3 + 2] += -0.8 body_pose[16 * 3 + 2] += 0.8 # SMPL live_smpl = self.smpl_model.forward(betas = self.smpl_shape[None], global_orient = None, transl = None, body_pose = body_pose[None], left_hand_pose = left_hand_pose[None], right_hand_pose = right_hand_pose[None] ) live_smpl_woRoot = self.smpl_model.forward(betas = self.smpl_shape[None], # global_orient = self.body_poses[pose_idx, :3][None], # transl = self.transl[pose_idx][None], body_pose = body_pose[None], # left_hand_pose = config.left_hand_pose[None], # right_hand_pose = config.right_hand_pose[None] ) # cano_smpl = self.smpl_model.forward(betas = self.smpl_shape[None], # global_orient = config.cano_smpl_global_orient[None], # transl = config.cano_smpl_transl[None], # body_pose = config.cano_smpl_body_pose[None], # # left_hand_pose = left_hand_pose[None], # # right_hand_pose = right_hand_pose[None] # ) data_item = dict() data_item['item_idx'] = 0 data_item['data_idx'] = 0 data_item['global_orient'] = torch.zeros(3, dtype = torch.float32) data_item['joints'] = live_smpl.joints[0, :22] data_item['kin_parent'] = self.smpl_model.parents[:22].to(torch.long) data_item['live_smpl_v'] = live_smpl.vertices[0] data_item['live_smpl_v_woRoot'] = live_smpl_woRoot.vertices[0] data_item['cano_smpl_v'] = self.cano_smpl['vertices'] data_item['cano_jnts'] = self.cano_smpl['joints'] inv_cano_jnt_mats = torch.linalg.inv(self.cano_smpl['A']) data_item['cano2live_jnt_mats'] = torch.matmul(live_smpl.A[0], inv_cano_jnt_mats) data_item['cano2live_jnt_mats_woRoot'] = torch.matmul(live_smpl_woRoot.A[0], inv_cano_jnt_mats) data_item['cano_smpl_center'] = self.cano_smpl_center data_item['cano_bounds'] = self.cano_bounds data_item['smpl_faces'] = self.smpl_faces min_xyz = live_smpl.vertices[0].min(0)[0] - 0.15 max_xyz = live_smpl.vertices[0].max(0)[0] + 0.15 live_bounds = torch.stack([min_xyz, max_xyz], 0).to(torch.float32).cpu().numpy() data_item['live_bounds'] = live_bounds data_item['left_cano_mano_v'], data_item['left_cano_mano_n'], data_item['right_cano_mano_v'], data_item['right_cano_mano_n'] \ = self.left_cano_mano_v, self.left_cano_mano_n, self.right_cano_mano_v, self.right_cano_mano_n """ synthesis config """ img_h = 512 if 'img_h' not in kwargs else kwargs['img_h'] img_w = 300 if 'img_w' not in kwargs else kwargs['img_w'] intr = np.array([[550, 0, 150], [0, 550, 256], [0, 0, 1]], np.float32) if 'intr' not in kwargs else kwargs['intr'] if 'extr' not in kwargs: extr = visualize_util.calc_front_mv(live_bounds.mean(0), tar_pos = np.array([0, 0, 2.5])) else: extr = kwargs['extr'] data_item.update({ 'img_h': img_h, 'img_w': img_w, 'extr': extr, 'intr': intr }) return data_item @staticmethod def gen_uv(img_w, img_h): x, y = np.meshgrid(np.linspace(0, img_w - 1, img_w, dtype = np.int32), np.linspace(0, img_h - 1, img_h, dtype = np.int32)) uv = np.stack([x, y], axis = -1) return uv