import torch
from torch import nn
from einops import rearrange
import tqdm
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.transforms import so3_exponential_map
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion
from simple_knn._C import distCUDA2

from GHA.lib.network.MLP import MLP
from GHA.lib.network.PositionalEmbedding import get_embedder
from GHA.lib.utils.general_utils import inverse_sigmoid


class GaussianHeadModule(nn.Module):
    """Set of 3D Gaussians whose colour, attributes and positions are
    conditioned on expression coefficients and head pose.

    Each point is driven by an expression branch, a pose branch, or a linear
    blend of both. Which branch applies depends on the point's distance to
    the nearest neutral landmark (see ``generate``).
    """

    def __init__(self, cfg, xyz, feature, landmarks_3d_neutral, add_mouth_points=False):
        """
        Args:
            cfg: config object. Reads ``num_add_mouth_points`` (only when
                ``add_mouth_points`` is True), the six ``*_mlp`` configs,
                ``pos_freq``, ``exp_coeffs_dim``, ``dist_threshold_near``,
                ``dist_threshold_far``, ``deform_scale`` and
                ``attributes_scale``.
            xyz: initial point positions, (N, 3).
            feature: per-point latent features, (N, C).
            landmarks_3d_neutral: neutral 3D landmarks. Rows 48:66 are used
                as mouth keypoints; presumably a 68-point layout (TODO confirm).
            add_mouth_points: if True and ``cfg.num_add_mouth_points > 0``,
                append random points in a box around the mouth, each with an
                all-zero feature.
        """
        super(GaussianHeadModule, self).__init__()

        if add_mouth_points and cfg.num_add_mouth_points > 0:
            num_add = cfg.num_add_mouth_points
            mouth_keypoints = landmarks_3d_neutral[48:66]
            mouth_center = torch.mean(mouth_keypoints, dim=0, keepdim=True)
            # Place the box centre at the smallest z among the mouth keypoints.
            mouth_center[:, 2] = mouth_keypoints[:, 2].min()
            max_dist = (mouth_keypoints - mouth_center).abs().max(0)[0]
            # Uniform samples in a box of half-extent 0.8 * max_dist around the
            # centre. Create them on xyz's device so the concatenation also
            # works when the inputs already live on the GPU.
            points_add = (torch.rand([num_add, 3], device=xyz.device) - 0.5) * 1.6 * max_dist + mouth_center
            xyz = torch.cat([xyz, points_add])
            feature = torch.cat([
                feature,
                torch.zeros([num_add, feature.shape[1]], device=feature.device, dtype=feature.dtype),
            ])

        self.xyz = nn.Parameter(xyz)
        self.feature = nn.Parameter(feature)
        self.register_buffer('landmarks_3d_neutral', landmarks_3d_neutral)

        # distCUDA2 (simple_knn) needs CUDA input. Move its per-point result
        # back to xyz's device so all parameters start on the same device.
        dist2 = torch.clamp_min(distCUDA2(self.xyz.cuda()), 0.0000001).to(xyz.device)
        # Isotropic initial log-scale: log(sqrt(dist2)) on all three axes.
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        self.scales = nn.Parameter(scales)

        # Identity quaternion (real part first, as quaternion_to_matrix expects).
        rots = torch.zeros((xyz.shape[0], 4), device=xyz.device)
        rots[:, 0] = 1
        self.rotation = nn.Parameter(rots)

        # Stored in logit space; generate() applies sigmoid, giving 0.3 initially.
        self.opacity = nn.Parameter(inverse_sigmoid(0.3 * torch.ones((xyz.shape[0], 1), device=xyz.device)))

        self.exp_color_mlp = MLP(cfg.exp_color_mlp, last_op=None)
        self.pose_color_mlp = MLP(cfg.pose_color_mlp, last_op=None)
        self.exp_attributes_mlp = MLP(cfg.exp_attributes_mlp, last_op=None)
        self.pose_attributes_mlp = MLP(cfg.pose_attributes_mlp, last_op=None)
        self.exp_deform_mlp = MLP(cfg.exp_deform_mlp, last_op=nn.Tanh())
        self.pose_deform_mlp = MLP(cfg.pose_deform_mlp, last_op=nn.Tanh())
        self.pos_embedding, _ = get_embedder(cfg.pos_freq)

        self.exp_coeffs_dim = cfg.exp_coeffs_dim
        self.dist_threshold_near = cfg.dist_threshold_near
        self.dist_threshold_far = cfg.dist_threshold_far
        self.deform_scale = cfg.deform_scale
        self.attributes_scale = cfg.attributes_scale

    def generate(self, data):
        """Compute posed Gaussian parameters for a batch of conditioning inputs.

        Args:
            data: dict. Reads:
                - ``'exp_coeff'``: (B, E) expression coefficients.
                - ``'pose'``: per-sample pose. ``[:, :3]`` is passed to
                  ``so3_exponential_map``; ``[:, 3:]`` is added as a
                  translation. Presumably (B, 6); TODO confirm.
                - ``'scale'``: per-sample scale, broadcast as
                  ``scale[:, :, None]``. Presumably (B, 1); TODO confirm.

        Returns:
            The same dict, with these keys added:
                - ``'xyz'``, ``'color'``, ``'scales'``, ``'rotation'``,
                  ``'opacity'``: one entry per point, batched over B.
                - ``'exp_deform'``: the expression-branch deformation of the
                  LAST batch sample only (``None`` when B == 0). Kept for
                  backward compatibility.
        """
        B = data['exp_coeff'].shape[0]

        xyz = self.xyz.unsqueeze(0).repeat(B, 1, 1)
        feature = torch.tanh(self.feature).unsqueeze(0).repeat(B, 1, 1)

        # knn_points (K=1) returns SQUARED distances to the nearest neutral
        # landmark, shape (B, N, 1). The thresholds are therefore compared
        # against squared distances.
        dists, _, _ = knn_points(xyz, self.landmarks_3d_neutral.unsqueeze(0).repeat(B, 1, 1))
        # Expression weight is 1 at or below the near threshold, 0 at or beyond
        # the far threshold, and linear in between. Pose takes the remainder.
        exp_weights = torch.clamp(
            (self.dist_threshold_far - dists) / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0)
        pose_weights = 1 - exp_weights
        # Only evaluate each branch on points where its weight can be non-zero.
        exp_controlled = (dists < self.dist_threshold_far).squeeze(-1)
        pose_controlled = (dists > self.dist_threshold_near).squeeze(-1)

        # Attribute channel sizes: [scales, rotation, opacity]. Used both to
        # allocate the deltas and to split them, so the two always agree.
        attr_sizes = [self.scales.shape[1], self.rotation.shape[1], self.opacity.shape[1]]

        color = torch.zeros([B, xyz.shape[1], self.exp_color_mlp.dims[-1]], device=xyz.device)
        delta_xyz = torch.zeros_like(xyz)
        delta_attributes = torch.zeros([B, xyz.shape[1], sum(attr_sizes)], device=xyz.device)

        # Defined up front so an empty batch does not raise NameError below.
        exp_deform = None
        for b in range(B):
            exp_mask = exp_controlled[b]
            pose_mask = pose_controlled[b]
            exp_w = exp_weights[b, exp_mask, :]
            pose_w = pose_weights[b, pose_mask, :]
            exp_code = data['exp_coeff'][b].unsqueeze(-1)
            # The pose embedding conditions both pose branches; compute it once.
            pose_code = self.pos_embedding(data['pose'][b]).unsqueeze(-1)

            # MLP inputs are channel-first: (1, C_in, n_points).
            feat_exp = feature[b, exp_mask, :]
            exp_color_input = torch.cat([feat_exp.t(), exp_code.repeat(1, feat_exp.shape[0])], 0)[None]
            exp_color = self.exp_color_mlp(exp_color_input)[0].t()
            color[b, exp_mask, :] += exp_color * exp_w

            feat_pose = feature[b, pose_mask, :]
            pose_color_input = torch.cat([feat_pose.t(), pose_code.repeat(1, feat_pose.shape[0])], 0)[None]
            pose_color = self.pose_color_mlp(pose_color_input)[0].t()
            color[b, pose_mask, :] += pose_color * pose_w

            # The attribute MLPs share their inputs with the colour MLPs.
            exp_delta_attributes = self.exp_attributes_mlp(exp_color_input)[0].t()
            delta_attributes[b, exp_mask, :] += exp_delta_attributes * exp_w

            pose_delta_attributes = self.pose_attributes_mlp(pose_color_input)[0].t()
            delta_attributes[b, pose_mask, :] += pose_delta_attributes * pose_w

            # The deformation MLPs take the embedded point positions instead of
            # the latent features.
            xyz_exp = xyz[b, exp_mask, :]
            exp_deform_input = torch.cat(
                [self.pos_embedding(xyz_exp).t(), exp_code.repeat(1, xyz_exp.shape[0])], 0)[None]
            exp_deform = self.exp_deform_mlp(exp_deform_input)[0].t()
            delta_xyz[b, exp_mask, :] += exp_deform * exp_w

            xyz_pose = xyz[b, pose_mask, :]
            pose_deform_input = torch.cat(
                [self.pos_embedding(xyz_pose).t(), pose_code.repeat(1, xyz_pose.shape[0])], 0)[None]
            pose_deform = self.pose_deform_mlp(pose_deform_input)[0].t()
            delta_xyz[b, pose_mask, :] += pose_deform * pose_w

        xyz = xyz + delta_xyz * self.deform_scale

        delta_scales, delta_rotation, delta_opacity = torch.split(delta_attributes, attr_sizes, dim=2)

        scales = self.scales.unsqueeze(0).repeat(B, 1, 1) + delta_scales * self.attributes_scale
        scales = torch.exp(scales)

        rotation = self.rotation.unsqueeze(0).repeat(B, 1, 1) + delta_rotation * self.attributes_scale
        rotation = torch.nn.functional.normalize(rotation, dim=2)

        opacity = self.opacity.unsqueeze(0).repeat(B, 1, 1) + delta_opacity * self.attributes_scale
        opacity = torch.sigmoid(opacity)

        # NOTE(review): the loop above already reads data['pose'], so this
        # guard is always true whenever B > 0.
        if 'pose' in data:
            # Apply the global similarity transform x' = R (S * x) + T.
            R = so3_exponential_map(data['pose'][:, :3])
            T = data['pose'][:, None, 3:]
            S = data['scale'][:, :, None]
            xyz = torch.bmm(xyz * S, R.permute(0, 2, 1)) + T

            # Rotate every Gaussian's orientation by the same per-sample R.
            rotation_matrix = quaternion_to_matrix(rotation)
            rotation_matrix = rearrange(rotation_matrix, 'b n x y -> (b n) x y')
            R = rearrange(R.unsqueeze(1).repeat(1, rotation.shape[1], 1, 1), 'b n x y -> (b n) x y')
            rotation_matrix = rearrange(torch.bmm(R, rotation_matrix), '(b n) x y -> b n x y', b=B)
            rotation = matrix_to_quaternion(rotation_matrix)

            scales = scales * S

        data['exp_deform'] = exp_deform
        data['xyz'] = xyz
        data['color'] = color
        data['scales'] = scales
        data['rotation'] = rotation
        data['opacity'] = opacity
        return data