import torch import torch.nn.functional as F from tqdm import tqdm import kaolin import lpips from einops import rearrange from pytorch3d.transforms import so3_exponential_map from kaolin.ops.mesh import index_vertices_by_faces from kaolin.metrics.trianglemesh import point_to_mesh_distance def laplace_regularizer_const(mesh_verts, mesh_faces): term = torch.zeros_like(mesh_verts) norm = torch.zeros_like(mesh_verts[..., 0:1]) v0 = mesh_verts[mesh_faces[:, 0], :] v1 = mesh_verts[mesh_faces[:, 1], :] v2 = mesh_verts[mesh_faces[:, 2], :] term.scatter_add_(0, mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) term.scatter_add_(0, mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) term.scatter_add_(0, mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) two = torch.ones_like(v0) * 2.0 norm.scatter_add_(0, mesh_faces[:, 0:1], two) norm.scatter_add_(0, mesh_faces[:, 1:2], two) norm.scatter_add_(0, mesh_faces[:, 2:3], two) term = term / torch.clamp(norm, min=1.0) return torch.mean(term**2) class MeshHeadTrainer(): def __init__(self, dataloader, meshhead, camera, optimizer, recorder, gpu_id): self.dataloader = dataloader self.meshhead = meshhead self.camera = camera self.optimizer = optimizer self.recorder = recorder self.device = torch.device('cuda:%d' % gpu_id) def train(self, start_epoch=0, epochs=1): for epoch in range(start_epoch, epochs): for idx, data in tqdm(enumerate(self.dataloader)): # prepare data to_cuda = ['images', 'masks', 'visibles', 'intrinsics', 'extrinsics', 'pose', 'scale', 'exp_coeff', 'landmarks_3d', 'exp_id'] for data_item in to_cuda: data[data_item] = data[data_item].to(device=self.device) images = data['images'].permute(0, 1, 3, 4, 2) masks = data['masks'].permute(0, 1, 3, 4, 2) visibles = data['visibles'].permute(0, 1, 3, 4, 2) resolution = images.shape[2] R = so3_exponential_map(data['pose'][:, :3]) T = data['pose'][:, 3:, None] S = data['scale'][:, :, None] landmarks_3d_can = (torch.bmm(R.permute(0,2,1), (data['landmarks_3d'].permute(0, 2, 1) - T)) / S).permute(0, 2, 1) landmarks_3d_neutral = self.meshhead.get_landmarks()[None].repeat(data['landmarks_3d'].shape[0], 1, 1) data['landmarks_3d_neutral'] = landmarks_3d_neutral deform_data = { 'exp_coeff': data['exp_coeff'], 'query_pts': landmarks_3d_neutral } deform_data = self.meshhead.deform(deform_data) pred_landmarks_3d_can = deform_data['deformed_pts'] loss_def = F.mse_loss(pred_landmarks_3d_can, landmarks_3d_can) deform_data = self.meshhead.query_sdf(deform_data) sdf_landmarks_3d = deform_data['sdf'] loss_lmk = torch.abs(sdf_landmarks_3d[:, :, 0]).mean() data = self.meshhead.reconstruct(data) data = self.camera.render(data, resolution) render_images = data['render_images'] render_soft_masks = data['render_soft_masks'] exp_deform = data['exp_deform'] pose_deform = data['pose_deform'] verts_list = data['verts_list'] faces_list = data['faces_list'] loss_rgb = F.l1_loss(render_images[:, :, :, :, 0:3] * visibles, images * visibles) loss_sil = kaolin.metrics.render.mask_iou((render_soft_masks * visibles[:, :, :, :, 0]).view(-1, resolution, resolution), (masks * visibles).squeeze().view(-1, resolution, resolution)) loss_offset = (exp_deform ** 2).sum(-1).mean() + (pose_deform ** 2).sum(-1).mean() loss_lap = 0.0 for b in range(len(verts_list)): loss_lap += laplace_regularizer_const(verts_list[b], faces_list[b]) loss = loss_rgb * 1e-1 + loss_sil * 1e-1 + loss_def * 1e0 + loss_offset * 1e-2 + loss_lmk * 1e-1 + loss_lap * 1e2 self.optimizer.zero_grad() loss.backward() self.optimizer.step() log = { 'data': data, 'meshhead' : self.meshhead, 'loss_rgb' : loss_rgb, 'loss_sil' : loss_sil, 'loss_def' : loss_def, 'loss_offset' : loss_offset, 'loss_lmk' : loss_lmk, 'loss_lap' : loss_lap, 'epoch' : epoch, 'iter' : idx + epoch * len(self.dataloader) } self.recorder.log(log) if idx + epoch * len(self.dataloader) >= 10000: return