from pytorch3d.structures import Meshes, Pointclouds
import torch
import torch.nn.functional as F

from lib.common.render_utils import face_vertices
from lib.dataset.Evaluator import point_mesh_distance
from lib.dataset.mesh_util import SMPLX, barycentric_coordinates_of_projection


class PointFeat:
    """Per-point geometric features of query points relative to a triangle mesh.

    For each query point, ``query`` finds the closest face on the mesh and
    returns:
      * the square root of the residue reported by ``point_mesh_distance``;
      * the absolute cosine between the point-to-surface direction and the
        barycentrically interpolated vertex normal at the closest location.
    """

    # Vertex count that identifies an SMPL-X body mesh (triggers face fix-up).
    SMPLX_NUM_VERTS = 10475

    def __init__(self, verts, faces):
        """Build the query mesh.

        Args:
            verts: vertex positions, shape ``[B, N_vert, 3]``.
            faces: vertex indices of each triangle, shape ``[B, N_face, 3]``.
                Stored as ``long`` (int64) in ``self.faces``.

        SMPL has a watertight mesh, but SMPL-X has two eyeballs and an open
        mouth. For SMPL-X input (detected by vertex count), the eyeball faces
        are removed and the faces closing the mouth hole are appended.
        """
        self.Bsize = verts.shape[0]
        self.device = verts.device

        if verts.shape[1] == self.SMPLX_NUM_VERTS:
            faces = self._close_smplx_faces(faces)
        # Use int64 in both branches; pytorch3d documents padded faces as LongTensor.
        self.faces = faces.long()

        self.verts = verts.float()
        # Corner positions of every face: [B, N_face, 3, 3].
        self.triangles = face_vertices(self.verts, self.faces)
        self.mesh = Meshes(self.verts, self.faces).to(self.device)

    def _close_smplx_faces(self, faces):
        """Drop SMPL-X eyeball faces and append the faces that fill the mouth hole."""
        # Construct the SMPLX helper once rather than once per attribute access.
        smplx = SMPLX()
        faces = faces[:, ~smplx.smplx_eyeball_fid_mask].long()
        mouth_faces = torch.as_tensor(
            smplx.smplx_mouth_fid, dtype=torch.long, device=self.device
        )
        # Share the same mouth faces across the batch; torch.cat copies the data.
        mouth_faces = mouth_faces.unsqueeze(0).expand(self.Bsize, -1, -1)
        return torch.cat([faces.to(self.device), mouth_faces], dim=1)

    def query(self, points):
        """Compute distance and normal-alignment features for query points.

        Args:
            points: query points with a trailing dimension of 3. The
                ``pts_ind[None]`` gather and the ``unsqueeze(0)`` calls below
                treat everything as a single batch, so presumably shape
                ``[1, P, 3]`` — NOTE(review): B > 1 is not supported; confirm.

        Returns:
            ``(dist, angles)``:
              * ``dist`` is ``sqrt(residues)`` from ``point_mesh_distance``
                with a leading dimension of 1 added (residues are presumably
                squared distances — confirm).
              * ``angles`` is ``|cos|`` between the unit vector from the closest
                surface point to the query point and the unit interpolated
                normal there. It is 0 (not NaN) when either vector has zero
                length, e.g. for a point lying exactly on the surface.
        """
        points = points.float()
        residues, pts_ind = point_mesh_distance(
            self.mesh, Pointclouds(points), weighted=False
        )

        # One closest-face index per point; gather that face's 3 corners (x3 coords).
        gather_idx = pts_ind[None, :, None, None].expand(-1, -1, 3, 3)
        closest_triangles = torch.gather(self.triangles, 1, gather_idx).view(-1, 3, 3)
        face_normals = face_vertices(self.mesh.verts_normals_padded(), self.faces)
        closest_normals = torch.gather(face_normals, 1, gather_idx).view(-1, 3, 3)

        # reshape (not view) so non-contiguous point tensors are accepted too.
        bary_weights = barycentric_coordinates_of_projection(
            points.reshape(-1, 3), closest_triangles
        )
        weights = bary_weights[:, :, None]

        # Closest surface location and the normal interpolated there.
        shoot_verts = (closest_triangles * weights).sum(1).unsqueeze(0)
        shoot_normals = (closest_normals * weights).sum(1).unsqueeze(0)

        # F.normalize clamps the norm with eps, so zero-length vectors (a point
        # exactly on the surface, or a degenerate normal) yield 0 instead of NaN.
        pts2shoot_normals = F.normalize(points - shoot_verts, dim=-1)
        shoot_normals = F.normalize(shoot_normals, dim=-1)
        angles = (pts2shoot_normals * shoot_normals).sum(dim=-1).abs()

        return (torch.sqrt(residues).unsqueeze(0), angles)