"""Fit an SMPL-X body to a single image.

The script refines a SPIN/HMR regressor on one image at a time
(``eft_step`` -> ``dc_step``), then fine-tunes the resulting pose vector
directly (``us_step``) and exports the mesh as ``us.glb``.
"""

import argparse
import math
from pathlib import Path

import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure

import fist_pose
import hist_cub
import losses
import pose_estimation
import spin

# Maps keypoint names of the 2D pose estimator to joint names used by SPIN.
PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}

MODELS_DIR = "models"


def parse_args():
    """Parse the command-line arguments of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pose-estimation-model-path",
        type=str,
        default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        help="Pose Estimation model",
    )
    parser.add_argument(
        "--contact-model-path",
        type=str,
        default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        help="Contact model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )
    parser.add_argument(
        "--spin-model-path",
        type=str,
        default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        help="SPIN model path",
    )
    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    parser.add_argument(
        "--smpl-model-dir",
        type=str,
        default=f"./{MODELS_DIR}/models/smplx",
        help="SMPL model dir",
    )
    parser.add_argument(
        "--smpl-mean-params-path",
        type=str,
        default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        help="SMPL mean params",
    )
    parser.add_argument(
        "--essentials-dir",
        type=str,
        default=f"./{MODELS_DIR}/smplify-xmc-essentials",
        help="SMPL Essentials folder for contacts",
    )
    parser.add_argument(
        "--parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        help="Parametrization path",
    )
    parser.add_argument(
        "--bone-parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        help="Bone parametrization path",
    )
    parser.add_argument(
        "--foot-inds-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        help="Foot indinces",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )
    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to img to test",
    )
    parser.add_argument(
        "--use-contacts",
        action="store_true",
        help="Use contact model",
    )
    parser.add_argument(
        "--use-msc",
        action="store_true",
        help="Use MSC loss",
    )
    parser.add_argument(
        "--use-natural",
        action="store_true",
        help="Use regularity",
    )
    parser.add_argument(
        "--use-cos",
        action="store_true",
        help="Use cos model",
    )
    parser.add_argument(
        "--use-angle-transf",
        action="store_true",
        help="Use cube foreshortening transformation",
    )
    parser.add_argument(
        "--c-mse",
        type=float,
        default=0,
        help="MSE weight",
    )
    parser.add_argument(
        "--c-par",
        type=float,
        default=10,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-f",
        type=float,
        default=1000,
        help="Cos coef",
    )
    parser.add_argument(
        "--c-parallel",
        type=float,
        default=100,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-reg",
        type=float,
        default=1000,
        help="Regularity weight",
    )
    parser.add_argument(
        "--c-cont2d",
        type=float,
        default=1,
        help="Contact 2D weight",
    )
    parser.add_argument(
        "--c-msc",
        type=float,
        default=17_500,
        help="MSC weight",
    )
    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
    )

    args = parser.parse_args()

    return args


def freeze_layers(model):
    """Put every batch-norm and dropout layer of ``model`` in eval mode.

    The parameters of those layers are also excluded from gradient updates.
    """
    frozen_types = (nn.modules.batchnorm._BatchNorm, nn.Dropout)
    for module in model.modules():
        if not isinstance(module, frozen_types):
            continue

        module.eval()
        for param in module.parameters():
            param.requires_grad = False


def project_and_normalize_to_spin(vertices_3d, camera):
    """Apply the camera ``(scale, tx, ty)`` and map points to SPIN pixels.

    The translation is added to the first two coordinates only; all three
    coordinates are then scaled, shifted by one and multiplied by half of
    ``spin.constants.IMG_RES``. The third coordinate is kept in the output.
    """
    scale = camera[0]
    translate = scale.new_zeros(3)
    translate[:2] = camera[1:]

    vertices_2d = vertices_3d + translate
    vertices_2d = scale * vertices_2d + 1
    vertices_2d = spin.constants.IMG_RES / 2 * vertices_2d

    return vertices_2d


def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    """Return 2D vectors derived from joints (2, 5) and (1, 4).

    ``A`` is a pair ``(transforms, joints)``; only the first batch element of
    each is used. ``vertices_3d`` is only used to pick dtype/device for the
    constant matrices.

    Returns:
        ``(rleg, lleg)``: the first two components of the rotated and scaled
        vectors ``J[5] - J[2]`` and ``J[4] - J[1]``.
    """
    A, J = A
    A = A[0]
    J = J[0]

    L = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    R = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )

    scale = camera[0]

    R = A[2, :3, :3] @ R  # 2 - right
    L = A[1, :3, :3] @ L  # 1 - left

    r_vec = J[5] - J[2]
    l_vec = J[4] - J[1]

    rleg = scale * spin.constants.IMG_RES / 2 * R @ r_vec
    lleg = scale * spin.constants.IMG_RES / 2 * L @ l_vec

    rleg = rleg[:2]
    lleg = lleg[:2]

    return rleg, lleg


def rotation_matrix_to_angle_axis(rotmat):
    """Convert ``(bs, n_joints, 3, 3)`` rotations to ``(bs, 3 * n_joints)``.

    ``torchgeometry.rotation_matrix_to_angle_axis`` takes 3x4 matrices, so a
    constant ``[0, 0, 1]`` column is appended to every rotation first.
    """
    bs, n_joints, *_ = rotmat.size()
    # One extra column per rotation matrix. Viewing the 3-element vector as
    # (1, 3, 1) and expanding over bs * n_joints works for any batch size
    # (the previous ``view(bs, 3, 1)`` only worked for bs == 1).
    hom_col = (
        rotmat.new_tensor([0, 0, 1], dtype=torch.float32)
        .view(1, 3, 1)
        .expand(bs * n_joints, -1, -1)
    )
    rotmat = torch.cat([rotmat.view(-1, 3, 3), hom_col], dim=-1)
    aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat)
    aa = aa.reshape(bs, 3 * n_joints)

    return aa


def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    """Run the body model on predicted rotations.

    For ``"SMPL-X"`` the rotation matrices are converted to axis-angle first
    and, when ``zero_hands`` is set, joints 20/21 are zeroed together with the
    y component of joints 12/15 (neck, head).

    Returns:
        ``(smpl_output, rotmat)`` where ``rotmat`` is the pose actually fed to
        the model (axis-angle for SMPL-X, the input matrices for SMPL).

    Raises:
        NotImplementedError: for any other ``smpl.name()``.
    """
    if smpl.name() == "SMPL":
        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
    elif smpl.name() == "SMPL-X":
        rotmat = rotation_matrix_to_angle_axis(rotmat)
        if zero_hands:
            for i in [20, 21]:
                rotmat[:, 3 * i : 3 * (i + 1)] = 0

            for i in [12, 15]:  # neck, head
                rotmat[:, 3 * i + 1] = 0  # y

        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 3:],
            global_orient=rotmat[:, :3],
            pose2rot=True,
        )
    else:
        raise NotImplementedError

    return smpl_output, rotmat


def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    """Run the regressor on one image and evaluate the body model.

    Returns:
        ``(rotmat, betas, camera, smpl_output, z)`` with the batch dimension
        squeezed out of everything but ``smpl_output``; ``z`` are the joints.
    """
    input_img = input_img.unsqueeze(0)
    rotmat, betas, camera = model_hmr(input_img)
    smpl_output, rotmat = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )

    rotmat = rotmat.squeeze(0)
    betas = betas.squeeze(0)
    camera = camera.squeeze(0)
    z = smpl_output.joints
    z = z.squeeze(0)

    return rotmat, betas, camera, smpl_output, z


def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    """Predict the body and project joints and vertices to SPIN pixels.

    Returns a 9-tuple: ``(rotmat, betas, camera, joints_2d[selector],
    joints[selector], vertices_2d, smpl_output, (rleg, lleg), joints_2d)``.
    """
    rotmat, betas, camera, smpl_output, zz = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )

    joints = smpl_output.joints.squeeze(0)
    joints_2d = project_and_normalize_to_spin(joints, camera)
    rleg, lleg = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)
    joints_2d_orig = joints_2d
    joints_2d = joints_2d[selector]

    vertices = smpl_output.vertices.squeeze(0)
    vertices_2d = project_and_normalize_to_spin(vertices, camera)

    zz = zz[selector]

    return (
        rotmat,
        betas,
        camera,
        joints_2d,
        zz,
        vertices_2d,
        smpl_output,
        (rleg, lleg),
        joints_2d_orig,
    )


def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Map image keypoints to the square SPIN input of side ``IMG_RES``.

    ``img_size`` is ``(h, w)``. The coordinate along the shorter image side is
    offset as if the image were padded to a square; everything is then scaled
    by ``IMG_RES / max(h, w)``.

    Returns:
        ``(keypoints, shift, scale, ax2)``, the last three being what
        ``unnormalize_keypoints_from_spin`` needs to invert the mapping.
    """
    h, w = img_size
    if h > w:  # vertically
        ax1 = 1
        ax2 = 0
    else:  # horizontal
        ax1 = 0
        ax2 = 1

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]

    keypoints_2d_normalized = np.copy(keypoints_2d)
    keypoints_2d_normalized[:, ax2] -= shift
    keypoints_2d_normalized *= scale

    return keypoints_2d_normalized, shift, scale, ax2


def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Invert ``normalize_keypoints_to_spin`` (returns a new array)."""
    keypoints_2d_normalized = np.copy(keypoints_2d)
    keypoints_2d_normalized /= scale
    keypoints_2d_normalized[:, ax2] += shift

    return keypoints_2d_normalized


def get_vertices_in_heatmap(contact_heatmap):
    """Return one convex hull per connected component of the binary heatmap.

    Pixel coordinates of each component are normalized to SPIN pixels and
    truncated to integers before the hull is computed.
    """
    contact_heatmap_size = contact_heatmap.shape[:2]
    label = measure.label(contact_heatmap)

    y_data_conts = []
    for i in range(1, label.max() + 1):
        # np.nonzero yields (rows, cols); reverse to get (x, y) points.
        predicted_kps_contact = np.vstack(np.nonzero(label == i)[::-1]).T.astype(
            "float"
        )
        predicted_kps_contact_scaled, *_ = normalize_keypoints_to_spin(
            predicted_kps_contact, contact_heatmap_size
        )
        y_data_cont = torch.from_numpy(predicted_kps_contact_scaled).int().tolist()
        y_data_cont = shapely.geometry.MultiPoint(y_data_cont).convex_hull
        y_data_conts.append(y_data_cont)

    return y_data_conts


def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    """Infer the contact heatmap of an image.

    Returns:
        ``(binary, rgb, raw)``: the min-max normalized heatmap thresholded at
        ``thresh`` (uint8, 0/255), the normalized heatmap repeated over three
        channels (uint8), and a copy of the raw network output.
    """
    contact_heatmap = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    contact_heatmap = contact_heatmap.squeeze(0)
    contact_heatmap_orig = contact_heatmap.copy()

    mi = contact_heatmap.min()
    ma = contact_heatmap.max()
    value_range = ma - mi
    if value_range > 0:
        contact_heatmap = (contact_heatmap - mi) / value_range
    else:
        # A constant heatmap carries no contact information; avoid the 0 / 0
        # division (NaNs cast to uint8 are undefined).
        contact_heatmap = np.zeros_like(contact_heatmap)

    contact_heatmap_ = ((contact_heatmap > thresh) * 255).astype("uint8")
    contact_heatmap = np.repeat(contact_heatmap[..., None], repeats=3, axis=-1)
    contact_heatmap = (contact_heatmap * 255).astype("uint8")

    return contact_heatmap_, contact_heatmap, contact_heatmap_orig


def discretize(parametrization, n_bins=100):
    """Snap values to the left edge of their bin on a uniform [0, 1] grid."""
    bins = np.linspace(0, 1, n_bins + 1)
    inds = np.digitize(parametrization, bins)
    disc_parametrization = bins[inds - 1]

    return disc_parametrization


def get_mapping_from_params_to_verts(verts, params):
    """Group vertices by their parameter value: ``{param: [verts...]}``."""
    mapping = {}
    for v, t in zip(verts, params):
        mapping.setdefault(t, []).append(v)

    return mapping


def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    """Find mesh vertices lying under 2D contact regions.

    Each region is buffered by ``thresh`` and intersected with every skeleton
    bone drawn between 2D keypoints. The intersected part of the bone, in
    normalized bone coordinates, is matched against the discretized bone
    parameter of the vertices in ``bone_to_params``.

    Args:
        y_data_conts: shapely geometries of the contact regions.
        keypoints_2d: 2D keypoints; a torch tensor or an array.
        bone_to_params: ``{(i, j): (verts, params)}`` for each skeleton bone.
        thresh: buffer distance around each contact region.
        step: bin width used for discretization.

    Returns:
        ``(contact, contact_2d, for_mask)``: per region the sorted unique
        vertex indices, per region a list of ``(i, j, t_mid)`` bone hits, and
        the integer outlines of the buffered regions that produced a hit.
    """
    if torch.is_tensor(keypoints_2d):
        # shapely cannot consume (CUDA) tensors reliably; use plain arrays.
        keypoints_2d = keypoints_2d.detach().cpu().numpy()

    # step: mean face's circumradius
    n_bins = int(math.ceil(1 / step)) - 1

    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        buffer = y_data_cont.buffer(thresh)
        mask_add = False
        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue

            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])

            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            if len(lint.boundary.geoms) < 2:
                continue

            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
            assert t2d_start <= t2d_end

            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )
            to_add = False
            for t2d in t2ds:
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue

                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                c = t3d_to_verts_sorted[t2d_ind][1]
                contact_loc.extend(c)
                to_add = True
                mask_add = True

                # Also take the vertices of the neighbouring bins.
                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    c = t3d_to_verts_sorted[t2d_ind + 1][1]
                    contact_loc.extend(c)

                if t2d_ind > 0:
                    c = t3d_to_verts_sorted[t2d_ind - 1][1]
                    contact_loc.extend(c)

            if to_add:
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))

        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)

        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)

    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]

    return contact, contact_2d, for_mask


def _add_parallel_losses(
    loss,
    loss_parallel,
    args,
    c_new_mse,
    keypoints_3d_pred,
    keypoints_2d,
    z,
    legs,
    vertices_pred,
    mean_zfoot_val,
    global_step,
):
    """Add the ``loss_parallel`` terms and the foot/silhouette terms to ``loss``.

    ``mean_zfoot_val`` is a dict owned by the caller; the median coordinate of
    each foot is stored in it on first use and reused on later steps.

    Returns:
        ``(loss, lpar)``; the inputs are returned unchanged (``lpar = 0.0``)
        when ``c_new_mse`` is not positive or ``loss_parallel`` is ``None``.
    """
    if not (c_new_mse > 0 and loss_parallel is not None):
        return loss, 0.0

    Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
        keypoints_3d_pred,
        keypoints_2d,
        z,
        legs,
        global_step=global_step,
    )
    lpar = (
        Ltan
        + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
        + Lspine
        + args.c_reg * Lgr
        + args.c_reg * Lstraight3d
        + args.c_cont2d * Lcon2d
    )
    loss = loss + 300 * lpar

    for side in ["left", "right"]:
        attr = f"{side}_foot_inds"
        if not hasattr(loss_parallel, attr):
            continue

        foot_inds = getattr(loss_parallel, attr)
        zind = 1
        if attr not in mean_zfoot_val:
            with torch.no_grad():
                mean_zfoot_val[attr] = torch.median(
                    vertices_pred[0, foot_inds, zind], dim=0
                ).values

        loss_foot = (
            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr]) ** 2
        ).sum()
        loss = loss + args.c_reg * loss_foot

    if hasattr(loss_parallel, "silhuette_vertices_inds"):
        inds = loss_parallel.silhuette_vertices_inds
        loss_sh = ((vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2).sum()
        loss = loss + args.c_reg * loss_sh

    return loss, lpar


def _add_contact_losses(loss, args, vertices_pred, sc_crit, msc_crit, contact):
    """Add the self-contact and mimicked-self-contact terms to ``loss``.

    Returns:
        ``(loss, gsc_contact_loss, faces_angle_loss, msc_loss)``; terms that
        were not evaluated are ``0.0``.
    """
    gsc_contact_loss = faces_angle_loss = 0.0
    if sc_crit is not None:
        gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
        lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
        loss = loss + lgsc_a

    msc_loss = 0.0
    if contact is not None and len(contact) > 0 and msc_crit is not None:
        contacts = contact if isinstance(contact, list) else [contact]
        for cntct in contacts:
            msc_loss = msc_crit(cntct, vertices_pred)
            loss = loss + args.c_msc * msc_loss

    return loss, gsc_contact_loss, faces_angle_loss, msc_loss


def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
):
    """Refine the regressor ``model_hmr`` on a single image.

    Runs ``n_steps`` optimizer steps on the weighted sum of the enabled loss
    terms, then makes one final gradient-free prediction with
    ``zero_hands=True``.

    Args:
        i_ini: offset added to the step index passed to ``loss_parallel`` as
            ``global_step``.

    Returns:
        ``(rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred,
        vertices_2d_pred, smpl_output, z, joints_2d_orig)`` of the final
        prediction.
    """
    mean_zfoot_val = {}
    z_loss = 0.0  # kept only so the progress bar shows the same fields
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            loss = l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            loss, lpar = _add_parallel_losses(
                loss,
                loss_parallel,
                args,
                c_new_mse,
                keypoints_3d_pred,
                keypoints_2d,
                z,
                (rleg, lleg),
                vertices_pred,
                mean_zfoot_val,
                global_step,
            )

            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam

            loss, gsc_contact_loss, faces_angle_loss, msc_loss = _add_contact_losses(
                loss, args, vertices_pred, sc_crit, msc_crit, contact
            )

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )

    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )


def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    i_ini=0,
    zero_hands=False,
    fist=None,
):
    """Optimize the pose vector and camera directly, without the regressor.

    ``theta`` (flat axis-angle pose, global orientation first) and ``camera``
    are copied into free parameters and optimized with Adam for ``n_steps``
    steps. A squared-distance prior keeps them close to their initial values.

    Args:
        zero_hands: zero joints 20/21 and the y component of joints 12/15 of
            the optimized pose before the final body-model evaluation.
        fist: optional list of keys of ``fist_pose.INT_TO_FIST`` selecting
            hand/wrist poses for the final evaluation.

    Returns:
        ``(rotmat_pred, smpl_output)``: the detached pose vector and the
        body-model output evaluated on it.

    Raises:
        RuntimeError: if an entry of ``fist`` starts with neither ``l`` nor
            ``r``.
    """
    mean_zfoot_val = {}
    z_loss = 0.0  # kept only so the progress bar shows the same fields

    theta = theta.detach().clone()
    camera = camera.detach().clone()

    # nn.Parameter shares storage with the tensor it wraps, and the optimizer
    # updates parameters in place. Wrap separate clones so that ``theta`` and
    # ``camera`` keep the initial values the prior below is measured against
    # (wrapping them directly made the prior identically zero).
    rotmat_pred = nn.Parameter(theta.clone())
    camera_pred = nn.Parameter(camera.clone())
    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )

    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )

            z = smpl_output.joints
            z = z.squeeze(0)

            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )
            joints_2d = joints_2d[selector]
            z = z[selector]

            keypoints_3d_pred = joints_2d
            keypoints_2d_pred = keypoints_3d_pred[:, :2]

            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior

            l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

            vertices_pred = smpl_output.vertices

            loss, lpar = _add_parallel_losses(
                loss,
                loss_parallel,
                args,
                c_new_mse,
                keypoints_3d_pred,
                keypoints_2d,
                z,
                (rleg, lleg),
                vertices_pred,
                mean_zfoot_val,
                global_step,
            )

            loss, gsc_contact_loss, faces_angle_loss, msc_loss = _add_contact_losses(
                loss, args, vertices_pred, sc_crit, msc_crit, contact
            )

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    rotmat_pred = rotmat_pred.detach()

    if zero_hands:
        for i in [20, 21]:
            rotmat_pred[3 * i : 3 * (i + 1)] = 0

        for i in [12, 15]:  # neck, head
            rotmat_pred[3 * i + 1] = 0  # y

    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]  # a view: writes below also edit rotmat_pred

    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)

            # "lf"/"rf" keys replace the finger pose; other "l"/"r" keys write
            # a 3-vector into body_pose and drop the corresponding hand pose.
            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
            elif f.startswith("l"):
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None
            else:
                raise RuntimeError(f"No such hand pose: {f}")

    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )

    return rotmat_pred, smpl_output


def create_bone(i, j, keypoints_2d):
    """Return the unit vector pointing from keypoint ``i`` to keypoint ``j``."""
    a = keypoints_2d[i]
    b = keypoints_2d[j]
    ab = b - a
    ab = torch.nn.functional.normalize(ab, dim=0)

    return ab


def is_parallel_to_plane(bone, thresh=21):
    """Tell whether ``bone`` is within ``thresh`` degrees of axis 0."""
    return abs(bone[0]) > math.cos(math.radians(thresh))


def is_close_to_plane(bone, plane, thresh):
    """Tell whether ``bone[0]`` is within ``thresh`` of ``plane``."""
    dist = abs(bone[0] - plane)

    return dist < thresh


def get_selector():
    """Return, for every pose-estimator keypoint, its SPIN joint index."""
    selector = []
    for kp in pose_estimation.KPS:
        tmp = spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp])
        selector.append(tmp)

    return selector


def calc_cos(joints_2d, joints_3d):
    """Return, per skeleton bone, the dot product of two bone directions.

    One direction is the normalized 2D bone; the other is the first two
    components of the normalized 3D bone.
    """
    cos = []
    for i, j in pose_estimation.SKELETON:
        a = joints_2d[i] - joints_2d[j]
        a = nn.functional.normalize(a, dim=0)

        b = joints_3d[i] - joints_3d[j]
        b = nn.functional.normalize(b, dim=0)[:2]

        c = (a * b).sum()
        cos.append(c)

    cos = torch.stack(cos, dim=0)

    return cos


def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    """Detect pose regularities in the 2D keypoints and store them on the loss.

    Always sets ``loss_parallel.ground_parallel`` and
    ``loss_parallel.parallel_in_3d``. Depending on the detections it may also
    set ``right_foot_inds``, ``left_foot_inds``, ``silhuette_vertices_inds``
    and ``ground`` on ``loss_parallel``; ``optimize`` and ``optimize_ft`` test
    for those attributes with ``hasattr``.
    """
    # Extent and maximum of the keypoints along coordinate 0.
    height_2d = (
        keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]
    ).item()
    plane_2d = keypoints_2d.max(dim=0).values[0].item()

    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()

    # parallel chains
    for i, j, k in [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"),  # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"),  # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
        # ("Hips", "Spine", "Neck"),
        # ("Spine", "Neck", "Head"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        k = pose_estimation.KPS.index(k)

        upleg_leg = create_bone(i, j, keypoints_2d)
        leg_foot = create_bone(j, k, keypoints_2d)

        if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
            # NOTE(review): this compares a component of a unit direction
            # vector with a pixel coordinate (plane_2d) - confirm the intent.
            if is_close_to_plane(
                upleg_leg, plane_2d, thresh=0.1 * height_2d
            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
                ground_parallel.append(((i, j), 1))
                ground_parallel.append(((j, k), 1))

        if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
            parallel_in_3d.append(((i, j), (j, k)))
            parallel3d_bones.add((i, j))
            parallel3d_bones.add((j, k))

    # parallel feets
    for i, j in [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        if (i, j) in parallel3d_bones:
            continue

        foot_toe = create_bone(i, j, keypoints_2d)
        if is_parallel_to_plane(foot_toe, thresh=25):
            if "Right" in pose_estimation.KPS[i]:
                loss_parallel.right_foot_inds = right_foot_inds
            else:
                loss_parallel.left_foot_inds = left_foot_inds

    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d

    if len(ground_parallel) > 0:
        # Silhouette vertices: normals with a small third component that lie
        # within 15% of the mesh extent from its maximum along coordinate 1.
        vertices_np = vertices[0].cpu().numpy()
        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1
        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
        plane_3d = vertices_np[:, 1].max()
        silhuette_vertices_mask_2 = (
            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
        )
        silhuette_vertices_mask = np.logical_and(
            silhuette_vertices_mask_1, silhuette_vertices_mask_2
        )
        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
        if len(silhuette_vertices_inds) > 0:
            loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
            loss_parallel.ground = plane_3d


def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Compute per-bone target cosines and store them in ``loss_parallel.cos``.

    With ``use_angle_transf`` the angles are first shifted so that the
    smallest angle of each bone group is zero and then remapped with
    ``hist_cub.cub``.

    Returns:
        The cosines computed by ``calc_cos`` before any transformation.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
        # Rounding can push a cosine marginally outside [-1, 1]; clamp so that
        # acos does not return NaN.
        alpha = torch.acos(cos_r.clamp(-1.0, 1.0))

        if use_angle_transf:
            leg_inds = [
                5,
                6,  # right leg
                7,
                8,  # left leg
            ]
            foot_inds = [15, 16]
            nleg_inds = sorted(
                set(range(len(pose_estimation.SKELETON)))
                - set(leg_inds)
                - set(foot_inds)
            )
            alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()
            amli = alpha[leg_inds].min()
            leg_inds.extend(foot_inds)
            alpha[leg_inds] = alpha[leg_inds] - amli

            angles = alpha.detach().cpu().numpy()
            angles = hist_cub.cub(
                angles / (math.pi / 2),
                a=1.2121212121212122,
                b=-1.105527638190953,
                c=0.787878787878789,
            ) * (math.pi / 2)
            alpha = alpha.new_tensor(angles)

        loss_parallel.cos = torch.cos(alpha)

    return cos_r


def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
):
    """Select the vertices used by the contact loss.

    With ``--use-contacts`` the contacts come from ``find_contacts`` (falling
    back to ``sc_module.verts_in_contact`` if it returns nothing) and the 2D
    hits are stored in ``loss_parallel.contact_2d``. With only ``--use-msc``
    they come from ``sc_module.verts_in_contact``. Otherwise an empty array is
    returned.

    Raises:
        ValueError: if ``--use-contacts`` is combined with a non-zero
            ``--c-mse``.
    """
    use_contacts = args.use_contacts
    use_msc = args.use_msc
    c_mse = args.c_mse

    if use_contacts:
        if c_mse != 0:
            raise ValueError("--use-contacts requires --c-mse to be 0")

        contact, contact_2d, _ = find_contacts(
            y_data_conts, keypoints_2d, bone_to_params
        )
        if len(contact_2d) > 0:
            loss_parallel.contact_2d = contact_2d

        if len(contact) == 0:
            _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
            contact = contact.cpu().numpy().ravel()
    elif use_msc:
        _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
        contact = contact.cpu().numpy().ravel()
    else:
        contact = np.array([])

    return contact


def save_mesh(
    smpl,
    smpl_output,
    save_path,
    fname,
):
    """Export the first mesh of ``smpl_output`` to ``save_path/<fname>.glb``.

    The mesh is rotated by pi about the x axis and then by pi about the z
    axis before export.
    """
    mesh = trimesh.Trimesh(
        vertices=smpl_output.vertices[0].cpu().numpy(),
        faces=smpl.faces,
        process=False,
    )
    rot = trimesh.transformations.rotation_matrix(np.pi, [1, 0, 0])
    mesh.apply_transform(rot)
    rot = trimesh.transformations.rotation_matrix(np.pi, [0, 0, 1])
    mesh.apply_transform(rot)
    mesh.export(save_path / f"{fname}.glb")


def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
):
    """First stage: fit the regressor to the 2D keypoints, then find contacts.

    Runs ``optimize`` for 150 steps with only the MSE keypoint term enabled.

    Returns:
        ``(vertices, keypoints_3d_pred, contact)``.
    """
    (
        _,
        _,
        _,
        keypoints_3d_pred,
        _,
        smpl_output,
        _,
        _,
    ) = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=60 + 90,
    )

    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args,
        sc_module,
        y_data_conts,
        keypoints_2d,
        vertices,
        bone_to_params,
        loss_parallel,
    )

    return vertices, keypoints_3d_pred, contact


def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
):
    """Second stage: refine the regressor with the full set of losses.

    Runs 60 steps when ``c_new_mse > 0`` or a contact flag is set, otherwise
    only the final prediction of ``optimize`` is made.

    Returns:
        The pose vector predicted after refinement.
    """
    rotmat_pred, *_ = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if c_new_mse > 0 or use_contacts or use_msc else 0,  # + 60,,
        i_ini=60 + 90,
    )

    return rotmat_pred


def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    save_path,
):
    """Third stage: fine-tune the pose without betas and save ``us.glb``.

    The camera is taken from a regressor prediction made with
    ``use_betas=False``; the pose is optimized by ``optimize_ft`` (60 steps
    when a contact flag is set, otherwise none).
    """
    # Only the camera is used from this forward pass, so no graph is needed.
    with torch.no_grad():
        (_, _, camera_pred_us, *_) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            use_betas=False,
            zero_hands=True,
        )

    _, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if use_contacts or use_msc else 0,  # + 60,
        i_ini=60 + 90 + 60,
        zero_hands=True,
        fist=args.fist,
    )

    save_mesh(
        smpl,
        smpl_output_us,
        save_path,
        "us",
    )


def main():
    """Process ``--img-path`` (a file or a directory of images).

    Raises:
        ValueError: if ``--c-mse`` and ``--c-par`` are not mutually exclusive
            (exactly one of them must be positive when ``--c-mse >= 0``).
    """
    args = parse_args()
    print(args)

    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3

    # Validate before loading any model so that a bad combination fails fast.
    if c_mse > 0 and c_new_mse != 0:
        raise ValueError("--c-par must be 0 when --c-mse is positive")
    if c_mse == 0 and c_new_mse <= 0:
        raise ValueError("--c-par must be positive when --c-mse is 0")

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"

    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )
    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu",
    )

    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)

    selector = get_selector()

    use_contacts = args.use_contacts
    use_msc = args.use_msc

    # NOTE: allow_pickle=True executes pickled code; load only trusted files.
    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]

    sc_module = None
    sc_crit = None
    msc_crit = None
    if use_contacts or use_msc:
        # The self-contact module is needed by get_contacts and by the MSC
        # loss for either flag (previously --use-msc alone hit a None module).
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)

        if use_contacts:
            sc_crit = selfcontact.losses.SelfContactLoss(
                contact_module=sc_module,
                inside_loss_weight=0.5,
                outside_loss_weight=0.0,
                contact_loss_weight=0.5,
                align_faces=True,
                use_hd=True,
                test_segments=True,
                device=device,
                model_type=model_type,
            )
            sc_crit.to(device)

        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)

    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg

    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )

    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)

    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        # Sorted for a deterministic processing order.
        path_to_imgs = sorted(path_to_imgs.iterdir())
    else:
        path_to_imgs = [path_to_imgs]

    img_extensions = (".jpg", ".png", ".jpeg")
    for img_path in path_to_imgs:
        if not img_path.name.lower().endswith(img_extensions):
            continue

        img_name = img_path.stem

        # get_natural/get_cos/get_contacts attach per-image attributes to the
        # loss object, so build a fresh one for every image; otherwise
        # detections from one image would leak into the next.
        loss_parallel = losses.Parallel(
            skeleton=pose_estimation.SKELETON,
            ignore=ignore,
        )

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )

        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)

        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        img_size_original = img_original.shape[:2]

        keypoints_2d, *_ = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d)
        keypoints_2d = keypoints_2d.to(device)

        predicted_contact_heatmap, _, _ = get_contact_heatmap(model_contact, img_path)
        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)

        # Reset the regressor to the checkpoint for every image.
        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)

        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )

        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
        )

        if args.use_natural:
            get_natural(
                keypoints_2d,
                vertices,
                right_foot_inds,
                left_foot_inds,
                loss_parallel,
                smpl,
            )

        if args.use_cos:
            get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)

        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
        )

        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            save_path,
        )


if __name__ == "__main__":
    main()