import torch from torch import nn import numpy as np import kaolin import tqdm from pytorch3d.ops.knn import knn_gather, knn_points from pytorch3d.transforms import so3_exponential_map from lib.network.MLP import MLP from lib.network.PositionalEmbedding import get_embedder from lib.utils.dmtet_utils import marching_tetrahedra class MeshHeadModule(nn.Module): def __init__(self, cfg, init_landmarks_3d_neutral): super(MeshHeadModule, self).__init__() self.geo_mlp = MLP(cfg.geo_mlp, last_op=nn.Tanh()) self.exp_color_mlp = MLP(cfg.exp_color_mlp, last_op=None) self.pose_color_mlp = MLP(cfg.pose_color_mlp, last_op=None) self.exp_deform_mlp = MLP(cfg.exp_deform_mlp, last_op=nn.Tanh()) self.pose_deform_mlp = MLP(cfg.pose_deform_mlp, last_op=nn.Tanh()) self.landmarks_3d_neutral = nn.Parameter(init_landmarks_3d_neutral) self.pos_embedding, _ = get_embedder(cfg.pos_freq) self.model_bbox = cfg.model_bbox self.dist_threshold_near = cfg.dist_threshold_near self.dist_threshold_far = cfg.dist_threshold_far self.deform_scale = cfg.deform_scale tets_data = np.load('assets/tets_data.npz') self.register_buffer('tet_verts', torch.from_numpy(tets_data['tet_verts'])) self.register_buffer('tets', torch.from_numpy(tets_data['tets'])) self.grid_res = 128 if cfg.subdivide: self.subdivide() def geometry(self, geo_input): pred = self.geo_mlp(geo_input) return pred def exp_color(self, color_input): verts_color = self.exp_color_mlp(color_input) return verts_color def pose_color(self, color_input): verts_color = self.pose_color_mlp(color_input) return verts_color def exp_deform(self, deform_input): deform = self.exp_deform_mlp(deform_input) return deform def pose_deform(self, deform_input): deform = self.pose_deform_mlp(deform_input) return deform def get_landmarks(self): return self.landmarks_3d_neutral def subdivide(self): new_tet_verts, new_tets = kaolin.ops.mesh.subdivide_tetmesh(self.tet_verts.unsqueeze(0), self.tets) self.tet_verts = new_tet_verts[0] self.tets = new_tets self.grid_res *= 2 def reconstruct(self, data): B = data['exp_coeff'].shape[0] query_pts = self.tet_verts.unsqueeze(0).repeat(B, 1, 1) geo_input = self.pos_embedding(query_pts).permute(0, 2, 1) pred = self.geometry(geo_input) sdf, deform, features = pred[:, :1, :], pred[:, 1:4, :], pred[:, 4:, :] sdf = sdf.permute(0, 2, 1) features = features.permute(0, 2, 1) verts_deformed = (query_pts + torch.tanh(deform.permute(0, 2, 1)) / self.grid_res) verts_list, features_list, faces_list = marching_tetrahedra(verts_deformed, features, self.tets, sdf) data['verts0_list'] = verts_list data['faces_list'] = faces_list verts_batch = [] verts_features_batch = [] num_pts_max = 0 for b in range(B): if verts_list[b].shape[0] > num_pts_max: num_pts_max = verts_list[b].shape[0] for b in range(B): verts_batch.append(torch.cat([verts_list[b], torch.zeros([num_pts_max - verts_list[b].shape[0], verts_list[b].shape[1]], device=verts_list[b].device)], 0)) verts_features_batch.append(torch.cat([features_list[b], torch.zeros([num_pts_max - features_list[b].shape[0], features_list[b].shape[1]], device=features_list[b].device)], 0)) verts_batch = torch.stack(verts_batch, 0) verts_features_batch = torch.stack(verts_features_batch, 0) dists, idx, _ = knn_points(verts_batch, data['landmarks_3d_neutral']) exp_weights = torch.clamp((self.dist_threshold_far - dists) / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0) pose_weights = 1 - exp_weights exp_color_input = torch.cat([verts_features_batch.permute(0, 2, 1), data['exp_coeff'].unsqueeze(-1).repeat(1, 1, num_pts_max)], 1) verts_color_batch = self.exp_color(exp_color_input).permute(0, 2, 1) * exp_weights pose_color_input = torch.cat([verts_features_batch.permute(0, 2, 1), self.pos_embedding(data['pose']).unsqueeze(-1).repeat(1, 1, num_pts_max)], 1) verts_color_batch = verts_color_batch + self.pose_color(pose_color_input).permute(0, 2, 1) * pose_weights exp_deform_input = torch.cat([self.pos_embedding(verts_batch).permute(0, 2, 1), data['exp_coeff'].unsqueeze(-1).repeat(1, 1, num_pts_max)], 1) exp_deform = self.exp_deform(exp_deform_input).permute(0, 2, 1) verts_batch = verts_batch + exp_deform * exp_weights * self.deform_scale pose_deform_input = torch.cat([self.pos_embedding(verts_batch).permute(0, 2, 1), self.pos_embedding(data['pose']).unsqueeze(-1).repeat(1, 1, num_pts_max)], 1) pose_deform = self.pose_deform(pose_deform_input).permute(0, 2, 1) verts_batch = verts_batch + pose_deform * pose_weights * self.deform_scale if 'pose' in data: R = so3_exponential_map(data['pose'][:, :3]) T = data['pose'][:, None, 3:] S = data['scale'][:, :, None] verts_batch = torch.bmm(verts_batch * S, R.permute(0, 2, 1)) + T data['exp_deform'] = exp_deform data['pose_deform'] = pose_deform data['verts_list'] = [verts_batch[b, :verts_list[b].shape[0], :] for b in range(B)] data['verts_color_list'] = [verts_color_batch[b, :verts_list[b].shape[0], :] for b in range(B)] return data def reconstruct_neutral(self): query_pts = self.tet_verts.unsqueeze(0) geo_input = self.pos_embedding(query_pts).permute(0, 2, 1) pred = self.geometry(geo_input) sdf, deform, features = pred[:, :1, :], pred[:, 1:4, :], pred[:, 4:, :] sdf = sdf.permute(0, 2, 1) features = features.permute(0, 2, 1) verts_deformed = (query_pts + torch.tanh(deform.permute(0, 2, 1)) / self.grid_res) verts_list, features_list, faces_list = marching_tetrahedra(verts_deformed, features, self.tets, sdf) data = {} data['verts'] = verts_list[0] data['faces'] = faces_list[0] data['verts_feature'] = features_list[0] return data def query_sdf(self, data): query_pts = data['query_pts'] geo_input = self.pos_embedding(query_pts).permute(0, 2, 1) pred = self.geometry(geo_input) sdf = pred[:, :1, :] sdf = sdf.permute(0, 2, 1) data['sdf'] = sdf return data def deform(self, data): exp_coeff = data['exp_coeff'] query_pts = data['query_pts'] geo_input = self.pos_embedding(query_pts).permute(0, 2, 1) pred = self.geometry(geo_input) sdf, deform = pred[:, :1, :], pred[:, 1:4, :] query_pts = (query_pts + torch.tanh(deform).permute(0, 2, 1) / self.grid_res) exp_deform_input = torch.cat([self.pos_embedding(query_pts).permute(0, 2, 1), exp_coeff.unsqueeze(-1).repeat(1, 1, query_pts.shape[1])], 1) exp_deform = self.exp_deform(exp_deform_input).permute(0, 2, 1) deformed_pts = query_pts + exp_deform * self.deform_scale data['deformed_pts'] = deformed_pts return data def in_bbox(self, verts, bbox): is_in_bbox = (verts[:, :, 0] > bbox[0][0]) & \ (verts[:, :, 1] > bbox[1][0]) & \ (verts[:, :, 2] > bbox[2][0]) & \ (verts[:, :, 0] < bbox[0][1]) & \ (verts[:, :, 1] < bbox[1][1]) & \ (verts[:, :, 2] < bbox[2][1]) return is_in_bbox def pre_train_sphere(self, iter, device): loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-3) for i in tqdm.tqdm(range(iter)): query_pts = torch.rand((8, 1024, 3), device=device) * 3 - 1.5 ref_value = torch.sqrt((query_pts**2).sum(-1)) - 1.0 data = { 'query_pts': query_pts } data = self.query_sdf(data) sdf = data['sdf'] loss = loss_fn(sdf[:, :, 0], ref_value) optimizer.zero_grad() loss.backward() optimizer.step() print("Pre-trained MLP", loss.item())