# full_gaussian_avatar/GHA/lib/module/GaussianHeadModule.py
import torch
from torch import nn
from einops import rearrange
import tqdm
from pytorch3d.ops.knn import knn_gather, knn_points
from pytorch3d.transforms import so3_exponential_map
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion
from simple_knn._C import distCUDA2
from GHA.lib.network.MLP import MLP
from GHA.lib.network.PositionalEmbedding import get_embedder
from GHA.lib.utils.general_utils import inverse_sigmoid


class GaussianHeadModule(nn.Module):
    """Canonical 3D Gaussian head model whose color, geometry and per-Gaussian
    attributes are deformed by expression coefficients and head pose."""

    def __init__(self, cfg, xyz, feature, landmarks_3d_neutral, add_mouth_points=False):
        super(GaussianHeadModule, self).__init__()
        # Optionally scatter extra points inside the mouth cavity, sampled in a box
        # around the mouth landmarks (indices 48:66), so the interior is not empty.
        if add_mouth_points and cfg.num_add_mouth_points > 0:
            mouth_keypoints = landmarks_3d_neutral[48:66]
            mouth_center = torch.mean(mouth_keypoints, dim=0, keepdim=True)
            mouth_center[:, 2] = mouth_keypoints[:, 2].min()
            max_dist = (mouth_keypoints - mouth_center).abs().max(0)[0]
            points_add = (torch.rand([cfg.num_add_mouth_points, 3]) - 0.5) * 1.6 * max_dist + mouth_center
            xyz = torch.cat([xyz, points_add])
            feature = torch.cat([feature, torch.zeros([cfg.num_add_mouth_points, feature.shape[1]])])
        self.xyz = nn.Parameter(xyz)
        self.feature = nn.Parameter(feature)
        self.register_buffer('landmarks_3d_neutral', landmarks_3d_neutral)

        # Initialize per-Gaussian log-scales from the mean squared distance to the
        # nearest neighbors (distCUDA2 from the simple_knn CUDA extension).
        dist2 = torch.clamp_min(distCUDA2(self.xyz.cuda()), 0.0000001).cpu()
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        self.scales = nn.Parameter(scales)
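        # Note (assumption, not from the original code): if the simple_knn CUDA
        # extension is unavailable, a rough stand-in for distCUDA2 could average the
        # squared distances to the three nearest neighbours with pytorch3d, e.g.
        #     knn = knn_points(self.xyz[None], self.xyz[None], K=4)
        #     dist2 = knn.dists[0, :, 1:].mean(-1).clamp_min(1e-7)
        # This is only a sketch of an approximately equivalent initialization,
        # not a verified drop-in replacement.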
        # Identity quaternions (w, x, y, z) and an initial opacity of 0.3.
        rots = torch.zeros((xyz.shape[0], 4), device=xyz.device)
        rots[:, 0] = 1
        self.rotation = nn.Parameter(rots)
        self.opacity = nn.Parameter(inverse_sigmoid(0.3 * torch.ones((xyz.shape[0], 1))))
        # MLP heads: expression- and pose-conditioned color, attribute offsets
        # (scale / rotation / opacity) and xyz deformation fields.
        self.exp_color_mlp = MLP(cfg.exp_color_mlp, last_op=None)
        self.pose_color_mlp = MLP(cfg.pose_color_mlp, last_op=None)
        self.exp_attributes_mlp = MLP(cfg.exp_attributes_mlp, last_op=None)
        self.pose_attributes_mlp = MLP(cfg.pose_attributes_mlp, last_op=None)
        self.exp_deform_mlp = MLP(cfg.exp_deform_mlp, last_op=nn.Tanh())
        self.pose_deform_mlp = MLP(cfg.pose_deform_mlp, last_op=nn.Tanh())

        self.pos_embedding, _ = get_embedder(cfg.pos_freq)

        self.exp_coeffs_dim = cfg.exp_coeffs_dim
        self.dist_threshold_near = cfg.dist_threshold_near
        self.dist_threshold_far = cfg.dist_threshold_far
        self.deform_scale = cfg.deform_scale
        self.attributes_scale = cfg.attributes_scale
    def generate(self, data):
        """Produce posed, expression-dependent Gaussian attributes for a batch.

        Expects data['exp_coeff'] of shape [B, C] (expression coefficients),
        data['pose'] of shape [B, 6] (axis-angle rotation + translation) and
        data['scale'] of shape [B, 1]; writes xyz, color, scales, rotation and
        opacity back into data and returns it.
        """
        B = data['exp_coeff'].shape[0]
        xyz = self.xyz.unsqueeze(0).repeat(B, 1, 1)
        feature = torch.tanh(self.feature).unsqueeze(0).repeat(B, 1, 1)

        # Distance of each Gaussian to the nearest neutral 3D landmark (squared, as
        # returned by knn_points) decides how strongly it is driven by expression
        # (near the face) vs. head pose (far from the face).
        dists, _, _ = knn_points(xyz, self.landmarks_3d_neutral.unsqueeze(0).repeat(B, 1, 1))
        exp_weights = torch.clamp((self.dist_threshold_far - dists) / (self.dist_threshold_far - self.dist_threshold_near), 0.0, 1.0)
        pose_weights = 1 - exp_weights
        exp_controlled = (dists < self.dist_threshold_far).squeeze(-1)
        pose_controlled = (dists > self.dist_threshold_near).squeeze(-1)

        color = torch.zeros([B, xyz.shape[1], self.exp_color_mlp.dims[-1]], device=xyz.device)
        delta_xyz = torch.zeros_like(xyz, device=xyz.device)
        delta_attributes = torch.zeros([B, xyz.shape[1], self.scales.shape[1] + self.rotation.shape[1] + self.opacity.shape[1]], device=xyz.device)
        # Per-sample loop: the number of expression- / pose-controlled Gaussians
        # differs per sample, so each batch element is processed separately.
        for b in range(B):
            # Expression-conditioned color for Gaussians near the face landmarks.
            feature_exp_controlled = feature[b, exp_controlled[b], :]
            exp_color_input = torch.cat([feature_exp_controlled.t(),
                                         data['exp_coeff'][b].unsqueeze(-1).repeat(1, feature_exp_controlled.shape[0])], 0)[None]
            exp_color = self.exp_color_mlp(exp_color_input)[0].t()
            color[b, exp_controlled[b], :] += exp_color * exp_weights[b, exp_controlled[b], :]

            # Pose-conditioned color for Gaussians far from the landmarks.
            feature_pose_controlled = feature[b, pose_controlled[b], :]
            pose_color_input = torch.cat([feature_pose_controlled.t(),
                                          self.pos_embedding(data['pose'][b]).unsqueeze(-1).repeat(1, feature_pose_controlled.shape[0])], 0)[None]
            pose_color = self.pose_color_mlp(pose_color_input)[0].t()
            color[b, pose_controlled[b], :] += pose_color * pose_weights[b, pose_controlled[b], :]

            # Offsets to scale / rotation / opacity, conditioned the same way.
            exp_attributes_input = exp_color_input
            exp_delta_attributes = self.exp_attributes_mlp(exp_attributes_input)[0].t()
            delta_attributes[b, exp_controlled[b], :] += exp_delta_attributes * exp_weights[b, exp_controlled[b], :]

            pose_attributes_input = pose_color_input
            pose_attributes = self.pose_attributes_mlp(pose_attributes_input)[0].t()
            delta_attributes[b, pose_controlled[b], :] += pose_attributes * pose_weights[b, pose_controlled[b], :]

            # Expression- and pose-driven xyz deformation of the canonical points.
            xyz_exp_controlled = xyz[b, exp_controlled[b], :]
            exp_deform_input = torch.cat([self.pos_embedding(xyz_exp_controlled).t(),
                                          data['exp_coeff'][b].unsqueeze(-1).repeat(1, xyz_exp_controlled.shape[0])], 0)[None]
            exp_deform = self.exp_deform_mlp(exp_deform_input)[0].t()
            delta_xyz[b, exp_controlled[b], :] += exp_deform * exp_weights[b, exp_controlled[b], :]

            xyz_pose_controlled = xyz[b, pose_controlled[b], :]
            pose_deform_input = torch.cat([self.pos_embedding(xyz_pose_controlled).t(),
                                           self.pos_embedding(data['pose'][b]).unsqueeze(-1).repeat(1, xyz_pose_controlled.shape[0])], 0)[None]
            pose_deform = self.pose_deform_mlp(pose_deform_input)[0].t()
            delta_xyz[b, pose_controlled[b], :] += pose_deform * pose_weights[b, pose_controlled[b], :]
        # Apply the deformation and attribute offsets, then map to valid ranges.
        xyz = xyz + delta_xyz * self.deform_scale

        delta_scales = delta_attributes[:, :, 0:3]
        scales = self.scales.unsqueeze(0).repeat(B, 1, 1) + delta_scales * self.attributes_scale
        scales = torch.exp(scales)

        delta_rotation = delta_attributes[:, :, 3:7]
        rotation = self.rotation.unsqueeze(0).repeat(B, 1, 1) + delta_rotation * self.attributes_scale
        rotation = torch.nn.functional.normalize(rotation, dim=2)

        delta_opacity = delta_attributes[:, :, 7:8]
        opacity = self.opacity.unsqueeze(0).repeat(B, 1, 1) + delta_opacity * self.attributes_scale
        opacity = torch.sigmoid(opacity)
        if 'pose' in data:
            # Rigid head transform: scale, rotate and translate the points, and
            # compose the global rotation with each Gaussian's local quaternion.
            R = so3_exponential_map(data['pose'][:, :3])
            T = data['pose'][:, None, 3:]
            S = data['scale'][:, :, None]
            xyz = torch.bmm(xyz * S, R.permute(0, 2, 1)) + T

            rotation_matrix = quaternion_to_matrix(rotation)
            rotation_matrix = rearrange(rotation_matrix, 'b n x y -> (b n) x y')
            R = rearrange(R.unsqueeze(1).repeat(1, rotation.shape[1], 1, 1), 'b n x y -> (b n) x y')
            rotation_matrix = rearrange(torch.bmm(R, rotation_matrix), '(b n) x y -> b n x y', b=B)
            rotation = matrix_to_quaternion(rotation_matrix)

            scales = scales * S
        # Note: exp_deform holds the expression deformation of the last sample in the loop.
        data['exp_deform'] = exp_deform
        data['xyz'] = xyz
        data['color'] = color
        data['scales'] = scales
        data['rotation'] = rotation
        data['opacity'] = opacity
        return data
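

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): it reproduces the
    # rigid head-pose transform applied at the end of generate() on random placeholder
    # data, assuming the 6-DoF pose layout used above (3 axis-angle + 3 translation).
    B, N = 2, 512
    xyz = torch.rand(B, N, 3) - 0.5                         # canonical Gaussian centres
    rotation = torch.zeros(B, N, 4)
    rotation[..., 0] = 1                                    # identity quaternions (w, x, y, z)
    pose = torch.rand(B, 6) * 0.1                           # [axis-angle | translation]
    scale = torch.ones(B, 1)

    R = so3_exponential_map(pose[:, :3])                    # (B, 3, 3) rotation matrices
    T = pose[:, None, 3:]                                   # (B, 1, 3) translations
    S = scale[:, :, None]                                   # (B, 1, 1) isotropic scales
    xyz_posed = torch.bmm(xyz * S, R.permute(0, 2, 1)) + T

    # Compose the global rotation with each Gaussian's local quaternion, as in generate().
    rot_mat = rearrange(quaternion_to_matrix(rotation), 'b n x y -> (b n) x y')
    R_rep = rearrange(R.unsqueeze(1).repeat(1, N, 1, 1), 'b n x y -> (b n) x y')
    rotation_posed = matrix_to_quaternion(rearrange(torch.bmm(R_rep, rot_mat), '(b n) x y -> b n x y', b=B))

    print(xyz_posed.shape, rotation_posed.shape)            # (B, N, 3) and (B, N, 4)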