pengc02's picture
all
ec9a6bc
raw
history blame
1.46 kB
import torch
import torch.nn as nn
from pytorch3d.transforms import axis_angle_to_quaternion
from network.mlp import MLPLinear
from utils.embedder import get_embedder
class HandAvatar(nn.Module):
def __init__(self,
multires = 4,
view_multires = -1,
pose_dim = 15*4):
super(HandAvatar, self).__init__()
self.pos_embedder, self.pos_dim = get_embedder(multires, 3)
if view_multires == -1:
self.view_embedder, self.view_dim = None, 0
else:
self.view_embedder, self.view_dim = get_embedder(view_multires, 3)
self.pose_dim = pose_dim
self.tex_mlp = MLPLinear(
in_channels = self.pos_dim + 1 + self.view_dim + pose_dim,
inter_channels = [64, 64, 64, 64, 64],
out_channels = 3,
last_op = nn.Sigmoid()
)
def forward(self, cano_xyz, sdf, view_dir, hand_pose):
batch_size, n_pts = cano_xyz.shape[:2]
in_feat = torch.cat([self.pos_embedder(cano_xyz), sdf], -1)
hand_pose = axis_angle_to_quaternion(hand_pose.reshape(batch_size, -1, 3)).reshape(batch_size, -1)
if self.view_embedder is not None:
in_feat = torch.cat([in_feat, self.view_embedder(view_dir)], -1)
in_feat = torch.cat([in_feat, hand_pose[:, None].expand(-1, n_pts, -1)], -1)
color = self.tex_mlp(in_feat)
return color