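# File: full_gaussian_avatar/GHA/lib/trainer/MeshHeadTrainer.py
# The remaining header items are residue from the hosting site's file viewer
# and are not part of the program:
#   uploader avatar: "pengc02's picture"; commit message: "all"; revision: ec9a6bc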
import torch
import torch.nn.functional as F
from tqdm import tqdm
import kaolin
import lpips
from einops import rearrange
from pytorch3d.transforms import so3_exponential_map
from kaolin.ops.mesh import index_vertices_by_faces
from kaolin.metrics.trianglemesh import point_to_mesh_distance
def laplace_regularizer_const(mesh_verts, mesh_faces):
    """Return the mean squared, valence-normalised Laplacian offset of a mesh.

    For every triangle, each of its three corners accumulates the two edge
    vectors that point from that corner to the other two corners of the same
    triangle, together with a count of 2.  The accumulated vector of each
    vertex is then divided by its accumulated count (clamped to at least 1,
    so a vertex referenced by no face divides by 1 instead of 0), and the
    mean of the squared components over all vertices is returned.

    Args:
        mesh_verts: Vertex positions, indexed by row; presumably ``(V, 3)``
            since the scatter index is repeated three times along the last
            axis.
        mesh_faces: Per-triangle vertex indices (columns 0, 1 and 2 are
            read); presumably ``(F, 3)`` with an integer dtype accepted as a
            ``scatter_add_`` index.

    Returns:
        A scalar tensor holding ``mean(offset ** 2)``.
    """
    # Gather the positions of corner 0, 1 and 2 of every face.
    corners = [mesh_verts[mesh_faces[:, k], :] for k in range(3)]

    laplacian = torch.zeros_like(mesh_verts)
    valence = torch.zeros_like(mesh_verts[..., 0:1])

    # Accumulate, per vertex, the edge vectors towards its face neighbours.
    for k in range(3):
        here = corners[k]
        first, second = [corners[j] for j in range(3) if j != k]
        target_rows = mesh_faces[:, k:k + 1].repeat(1, 3)
        laplacian.scatter_add_(0, target_rows, (first - here) + (second - here))

    # Every face contributes two neighbours to each of its corners.
    # The index is one column wide, so only column 0 of the source is used.
    neighbours_per_face = torch.ones_like(corners[0]) * 2.0
    for k in range(3):
        valence.scatter_add_(0, mesh_faces[:, k:k + 1], neighbours_per_face)

    laplacian = laplacian / torch.clamp(valence, min=1.0)
    return torch.mean(laplacian ** 2)
class MeshHeadTrainer():
    """Training loop that fits a mesh head model to multi-view image data.

    The trainer only orchestrates: batches come from ``dataloader``, geometry
    and deformation come from ``meshhead``, rendering from ``camera``,
    parameter updates from ``optimizer`` and logging from ``recorder``.
    """

    def __init__(self, dataloader, meshhead, camera, optimizer, recorder, gpu_id):
        """Store the collaborators and select the CUDA device.

        Args:
            dataloader: Iterable of batch dicts; must support ``len()``.
            meshhead: Model exposing ``get_landmarks``, ``deform``,
                ``query_sdf`` and ``reconstruct``.
            camera: Renderer exposing ``render(data, resolution)``.
            optimizer: Object exposing ``zero_grad()`` and ``step()``.
            recorder: Logger exposing ``log(dict)``.
            gpu_id: Integer CUDA device index.
        """
        self.dataloader = dataloader
        self.meshhead = meshhead
        self.camera = camera
        self.optimizer = optimizer
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)

    def train(self, start_epoch=0, epochs=1, max_iters=10000):
        """Run the optimisation loop.

        For every batch the method moves the listed tensors to the device,
        computes the landmark, photometric, silhouette, offset and Laplacian
        losses, takes one optimizer step and hands a log dict to the
        recorder.

        Args:
            start_epoch: First epoch index (inclusive).
            epochs: Last epoch index (exclusive); the loop runs over
                ``range(start_epoch, epochs)``.
            max_iters: Global iteration index at which training stops. The
                check happens after the optimizer step, so the iteration
                whose index equals ``max_iters`` is still executed. Pass
                ``None`` to disable the cap.

        Returns:
            None.
        """
        to_cuda = ['images', 'masks', 'visibles', 'intrinsics', 'extrinsics', 'pose', 'scale', 'exp_coeff', 'landmarks_3d', 'exp_id']

        for epoch in range(start_epoch, epochs):
            num_batches = len(self.dataloader)
            # Wrap the loader (not the enumerate object) so tqdm can read its
            # length and display a total / ETA.
            for idx, data in enumerate(tqdm(self.dataloader)):
                global_iter = idx + epoch * num_batches

                # prepare data
                for data_item in to_cuda:
                    data[data_item] = data[data_item].to(device=self.device)

                # Move axis 2 to the end of the 5-D tensors
                # (presumably (B, V, C, H, W) -> (B, V, H, W, C) — TODO confirm).
                images = data['images'].permute(0, 1, 3, 4, 2)
                masks = data['masks'].permute(0, 1, 3, 4, 2)
                visibles = data['visibles'].permute(0, 1, 3, 4, 2)
                # The silhouette loss below reshapes to (resolution, resolution),
                # so square images are assumed.
                resolution = images.shape[2]

                # Rigid head pose: first three pose entries go through the SO(3)
                # exponential map, the remaining ones are used as translation.
                R = so3_exponential_map(data['pose'][:, :3])
                T = data['pose'][:, 3:, None]
                S = data['scale'][:, :, None]
                # Undo translation, rotation and scale: R^T (x - T) / S.
                landmarks_3d_can = (torch.bmm(R.permute(0, 2, 1), (data['landmarks_3d'].permute(0, 2, 1) - T)) / S).permute(0, 2, 1)
                landmarks_3d_neutral = self.meshhead.get_landmarks()[None].repeat(data['landmarks_3d'].shape[0], 1, 1)
                data['landmarks_3d_neutral'] = landmarks_3d_neutral

                # Landmark deformation loss: deformed neutral landmarks should
                # match the transformed observed landmarks.
                deform_data = {
                    'exp_coeff': data['exp_coeff'],
                    'query_pts': landmarks_3d_neutral
                }
                deform_data = self.meshhead.deform(deform_data)
                pred_landmarks_3d_can = deform_data['deformed_pts']
                loss_def = F.mse_loss(pred_landmarks_3d_can, landmarks_3d_can)

                # Landmark SDF loss: drive the queried SDF value (channel 0)
                # at the landmarks towards zero.
                deform_data = self.meshhead.query_sdf(deform_data)
                sdf_landmarks_3d = deform_data['sdf']
                loss_lmk = torch.abs(sdf_landmarks_3d[:, :, 0]).mean()

                # Mesh extraction and rendering.
                data = self.meshhead.reconstruct(data)
                data = self.camera.render(data, resolution)
                render_images = data['render_images']
                render_soft_masks = data['render_soft_masks']
                exp_deform = data['exp_deform']
                pose_deform = data['pose_deform']
                verts_list = data['verts_list']
                faces_list = data['faces_list']

                # Photometric loss on the first three rendered channels,
                # weighted by the visibility map.
                loss_rgb = F.l1_loss(render_images[:, :, :, :, 0:3] * visibles, images * visibles)

                # Silhouette loss via kaolin's mask IoU metric, again weighted
                # by visibility.
                pred_sil = (render_soft_masks * visibles[:, :, :, :, 0]).view(-1, resolution, resolution)
                gt_sil = (masks * visibles).squeeze().view(-1, resolution, resolution)
                loss_sil = kaolin.metrics.render.mask_iou(pred_sil, gt_sil)

                # Penalise the squared magnitude of both deformation fields.
                loss_offset = (exp_deform ** 2).sum(-1).mean() + (pose_deform ** 2).sum(-1).mean()

                # Laplacian regulariser summed over the meshes in the batch.
                loss_lap = 0.0
                for b, verts in enumerate(verts_list):
                    loss_lap += laplace_regularizer_const(verts, faces_list[b])

                loss = loss_rgb * 1e-1 + loss_sil * 1e-1 + loss_def * 1e0 + loss_offset * 1e-2 + loss_lmk * 1e-1 + loss_lap * 1e2

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                log = {
                    'data': data,
                    'meshhead' : self.meshhead,
                    'loss_rgb' : loss_rgb,
                    'loss_sil' : loss_sil,
                    'loss_def' : loss_def,
                    'loss_offset' : loss_offset,
                    'loss_lmk' : loss_lmk,
                    'loss_lap' : loss_lap,
                    'epoch' : epoch,
                    'iter' : global_iter
                }
                self.recorder.log(log)

                # Stop once the global iteration budget is reached.
                if max_iters is not None and global_iter >= max_iters:
                    return