|
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
|
|
|
from anim.tquat import * |
|
|
|
|
|
|
|
|
|
|
|
class Decoder(nn.Module): |
|
def __init__( |
|
self, |
|
pose_input_size, |
|
pose_output_size, |
|
speech_encoding_size, |
|
style_encoding_size, |
|
hidden_size, |
|
num_rnn_layers, |
|
rnn_cond="normal", |
|
): |
|
super(Decoder, self).__init__() |
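        # Two style-conditioning variants are supported: "normal" concatenates the style
        # embedding with the pose and speech features at every step, while "film" modulates
        # the decoder's hidden activations with FiLM (feature-wise scale/shift) parameters
        # predicted from the style embedding.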
|
|
|
if rnn_cond == "normal": |
|
self.recurrent_decoder = RecurrentDecoderNormal( |
|
pose_input_size, |
|
speech_encoding_size, |
|
style_encoding_size, |
|
pose_output_size, |
|
hidden_size, |
|
num_rnn_layers, |
|
) |
|
elif rnn_cond == "film": |
|
self.recurrent_decoder = RecurrentDecoderFiLM( |
|
pose_input_size, |
|
speech_encoding_size, |
|
style_encoding_size, |
|
pose_output_size, |
|
hidden_size, |
|
num_rnn_layers, |
|
            )
        else:
            raise ValueError(f"Unknown rnn_cond: {rnn_cond}")
|
|
|
self.cell_state_encoder = CellStateEncoder( |
|
pose_input_size + style_encoding_size, hidden_size, num_rnn_layers |
|
) |
|
|
|
def forward( |
|
self, |
|
Z_root_pos, |
|
Z_root_rot, |
|
Z_root_vel, |
|
Z_root_vrt, |
|
Z_lpos, |
|
Z_ltxy, |
|
Z_lvel, |
|
Z_lvrt, |
|
Z_gaze_pos, |
|
speech_encoding, |
|
style_encoding, |
|
parents, |
|
anim_input_mean, |
|
anim_input_std, |
|
anim_output_mean, |
|
anim_output_std, |
|
dt: float, |
|
): |
|
|
|
batchsize = speech_encoding.shape[0] |
|
nframes = speech_encoding.shape[1] |
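        # The output lists are seeded with the initial pose; the decoder then rolls out
        # autoregressively, feeding its prediction for frame i - 1 back in as the pose
        # input for frame i.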
|
|
|
|
|
O_root_pos = [Z_root_pos] |
|
O_root_rot = [Z_root_rot] |
|
O_root_vel = [Z_root_vel] |
|
O_root_vrt = [Z_root_vrt] |
|
O_lpos = [Z_lpos] |
|
O_ltxy = [Z_ltxy] |
|
O_lvel = [Z_lvel] |
|
O_lvrt = [Z_lvrt] |
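        # The GRU hidden state is initialised from the first pose and the first style
        # embedding by the cell-state encoder.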
|
|
|
|
|
        decoder_state = self.cell_state_encoder(
            vectorize_input(
                Z_root_pos, Z_root_rot, Z_root_vel, Z_root_vrt,
                Z_lpos, Z_ltxy, Z_lvel, Z_lvrt,
                Z_gaze_pos[:, 0], parents, anim_input_mean, anim_input_std,
            ),
            style_encoding[:, 0],
        )
|
|
|
for i in range(1, nframes): |
|
|
|
pose_encoding = vectorize_input( |
|
O_root_pos[-1], |
|
O_root_rot[-1], |
|
O_root_vel[-1], |
|
O_root_vrt[-1], |
|
O_lpos[-1], |
|
O_ltxy[-1], |
|
O_lvel[-1], |
|
O_lvrt[-1], |
|
Z_gaze_pos[:, i], |
|
parents, |
|
anim_input_mean, |
|
anim_input_std, |
|
) |
|
|
|
|
predicted, decoder_state = self.recurrent_decoder( |
|
pose_encoding, speech_encoding[:, i], style_encoding[:, i], decoder_state |
|
) |
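            # De-normalise the prediction and convert it back into pose tensors, integrating
            # the root velocities over dt to obtain the next root position and rotation.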
|
|
|
|
|
( |
|
P_root_pos, |
|
P_root_rot, |
|
P_root_vel, |
|
P_root_vrt, |
|
P_lpos, |
|
P_ltxy, |
|
P_lvel, |
|
P_lvrt, |
|
) = devectorize_output( |
|
predicted, |
|
O_root_pos[-1], |
|
O_root_rot[-1], |
|
Z_lpos.shape[0], |
|
Z_lpos.shape[1], |
|
dt, |
|
anim_output_mean, |
|
anim_output_std, |
|
) |
|
|
|
|
|
O_root_pos.append(P_root_pos) |
|
O_root_rot.append(P_root_rot) |
|
O_root_vel.append(P_root_vel) |
|
O_root_vrt.append(P_root_vrt) |
|
O_lpos.append(P_lpos) |
|
O_ltxy.append(P_ltxy) |
|
O_lvel.append(P_lvel) |
|
O_lvrt.append(P_lvrt) |
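        # Stack the per-frame outputs along a new time axis, giving tensors of shape
        # (batch, nframes, ...).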
|
|
|
return ( |
|
torch.cat([O[:, None] for O in O_root_pos], dim=1), |
|
torch.cat([O[:, None] for O in O_root_rot], dim=1), |
|
torch.cat([O[:, None] for O in O_root_vel], dim=1), |
|
torch.cat([O[:, None] for O in O_root_vrt], dim=1), |
|
torch.cat([O[:, None] for O in O_lpos], dim=1), |
|
torch.cat([O[:, None] for O in O_ltxy], dim=1), |
|
torch.cat([O[:, None] for O in O_lvel], dim=1), |
|
torch.cat([O[:, None] for O in O_lvrt], dim=1), |
|
) |
|
|
|
|
|
class RecurrentDecoderNormal(nn.Module): |
|
def __init__( |
|
self, pose_input_size, speech_size, style_size, output_size, hidden_size, num_rnn_layers |
|
): |
|
super(RecurrentDecoderNormal, self).__init__() |
|
|
|
all_input_size = pose_input_size + speech_size + style_size |
|
self.layer0 = nn.Linear(all_input_size, hidden_size) |
|
self.layer1 = nn.GRU( |
|
all_input_size + hidden_size, hidden_size, num_rnn_layers, batch_first=True |
|
) |
|
|
|
self.layer2 = nn.Linear(hidden_size, output_size) |
|
|
|
def forward(self, pose, speech, style, cell_state): |
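        # The decoder is stepped one frame at a time: unsqueeze(1) turns the input into a
        # length-1 sequence so the GRU processes a single step and returns the updated
        # hidden state for the next call.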
|
hidden = F.elu(self.layer0(torch.cat([pose, speech, style], dim=-1))) |
|
cell_output, cell_state = self.layer1( |
|
torch.cat([hidden, pose, speech, style], dim=-1).unsqueeze(1), cell_state |
|
) |
|
output = self.layer2(cell_output.squeeze(1)) |
|
return output, cell_state |
|
|
|
|
|
class RecurrentDecoderFiLM(nn.Module): |
|
def __init__( |
|
self, pose_input_size, speech_size, style_size, output_size, hidden_size, num_rnn_layers |
|
): |
|
super(RecurrentDecoderFiLM, self).__init__() |
|
|
|
self.hidden_size = hidden_size |
|
self.gammas_predictor = LinearNorm( |
|
style_size, hidden_size * 2, w_init_gain="linear" |
|
) |
|
self.betas_predictor = LinearNorm( |
|
style_size, hidden_size * 2, w_init_gain="linear" |
|
) |
|
|
|
all_input_size = pose_input_size + speech_size + style_size |
|
self.layer0 = nn.Linear(pose_input_size + speech_size, hidden_size) |
|
self.layer1 = nn.GRU( |
|
pose_input_size + speech_size + hidden_size, |
|
hidden_size, |
|
num_rnn_layers, |
|
batch_first=True, |
|
dropout=0.0, |
|
) |
|
self.layer2 = nn.Linear(hidden_size, hidden_size) |
|
self.layer3 = nn.Linear(hidden_size, output_size) |
|
|
|
def forward(self, pose, speech, style, cell_state): |
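        # FiLM conditioning: the style embedding predicts per-channel scales (gammas, offset
        # by +1 so the default is an identity scaling) and shifts (betas). The first half
        # modulates the pre-GRU hidden features, the second half the post-GRU features.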
|
gammas = self.gammas_predictor(style) |
|
gammas = gammas + 1 |
|
betas = self.betas_predictor(style) |
|
|
|
hidden = F.elu(self.layer0(torch.cat([pose, speech], dim=-1))) |
|
hidden = hidden * gammas[:, : self.hidden_size] + betas[:, : self.hidden_size] |
|
cell_output, cell_state = self.layer1( |
|
torch.cat([hidden, pose, speech], dim=-1).unsqueeze(1), cell_state |
|
) |
|
hidden = F.elu(self.layer2(cell_output.squeeze(1))) |
|
hidden = hidden * gammas[:, self.hidden_size:] + betas[:, self.hidden_size:] |
|
output = self.layer3(hidden) |
|
return output, cell_state |
|
|
|
|
|
class CellStateEncoder(nn.Module): |
|
def __init__(self, input_size, hidden_size, num_rnn_layers): |
|
super(CellStateEncoder, self).__init__() |
|
self.num_rnn_layers = num_rnn_layers |
|
self.layer0 = nn.Linear(input_size, hidden_size) |
|
self.layer1 = nn.Linear(hidden_size, hidden_size) |
|
self.layer2 = nn.Linear(hidden_size, hidden_size * num_rnn_layers) |
|
|
|
def forward(self, pose, style): |
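        # Map the initial pose and style embedding to an initial GRU hidden state, reshaped
        # to (num_rnn_layers, batch, hidden_size) as expected by nn.GRU.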
|
hidden = F.elu(self.layer0(torch.cat([pose, style], dim=-1))) |
|
hidden = F.elu(self.layer1(hidden)) |
|
output = self.layer2(hidden) |
|
|
|
return output.reshape(output.shape[0], self.num_rnn_layers, -1).swapaxes(0, 1).contiguous() |
|
|
|
|
|
|
|
|
|
|
|
class SpeechEncoder(nn.Module): |
|
def __init__(self, input_size, hidden_size, output_size): |
|
super(SpeechEncoder, self).__init__() |
|
|
|
self.layer0 = nn.Conv1d( |
|
input_size, hidden_size, kernel_size=1, padding="same", padding_mode="replicate" |
|
) |
|
self.drop0 = nn.Dropout(p=0.2) |
|
|
|
self.layer1 = nn.Conv1d( |
|
hidden_size, output_size, kernel_size=31, padding="same", padding_mode="replicate" |
|
) |
|
self.drop1 = nn.Dropout(p=0.2) |
|
|
|
self.layer2 = nn.Linear(output_size, output_size) |
|
|
|
def forward(self, x): |
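        # nn.Conv1d expects (batch, channels, time), hence the swapaxes around the
        # convolution stack; the wide kernel (31 frames) aggregates temporal context from
        # the speech features.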
|
x = torch.swapaxes(x, 1, 2) |
|
x = self.drop0(F.elu(self.layer0(x))) |
|
x = self.drop1(F.elu(self.layer1(x))) |
|
x = torch.swapaxes(x, 1, 2) |
|
x = F.elu(self.layer2(x)) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class StyleEncoder(nn.Module): |
|
def __init__(self, input_size, hidden_size, style_embedding_size, type="attn", use_vae=False): |
|
super(StyleEncoder, self).__init__() |
|
self.use_vae = use_vae |
|
self.style_embedding_size = style_embedding_size |
|
output_size = 2 * style_embedding_size if use_vae else style_embedding_size |
|
if type == "gru": |
|
self.encoder = StyleEncoderGRU(input_size, hidden_size, output_size) |
|
elif type == "attn": |
|
self.encoder = StyleEncoderAttn(input_size, hidden_size, output_size) |
|
|
|
    def forward(self, input, temperature: float = 1.0):
|
encoder_output = self.encoder(input) |
|
if self.use_vae: |
|
mu, logvar = ( |
|
encoder_output[:, : self.style_embedding_size], |
|
encoder_output[:, self.style_embedding_size:], |
|
) |
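            # Reparameterisation trick: style = mu + eps * std. Dividing std by the
            # temperature means values above 1 shrink the sampling noise and values below 1
            # enlarge it.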
|
|
|
|
|
            std = torch.exp(0.5 * logvar) / temperature
|
eps = torch.randn_like(std) |
|
|
|
style_embedding = mu + eps * std |
|
return style_embedding, mu, logvar |
|
else: |
|
return encoder_output, None, None |
|
|
|
|
|
class StyleEncoderGRU(nn.Module): |
|
def __init__(self, input_size, hidden_size, style_embedding_size): |
|
super(StyleEncoderGRU, self).__init__() |
|
|
|
self.convs = nn.Sequential( |
|
ConvNorm1D( |
|
input_size, |
|
hidden_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
), |
|
nn.ReLU(), |
|
|
|
ConvNorm1D( |
|
hidden_size, |
|
hidden_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
), |
|
nn.ReLU(), |
|
) |
|
self.rnn_layer = nn.GRU(hidden_size, hidden_size, 1, batch_first=True, bidirectional=True) |
|
self.projection_layer = LinearNorm( |
|
hidden_size * 2, style_embedding_size, w_init_gain="linear" |
|
) |
|
|
|
def forward(self, input): |
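        # The last time step of the bidirectional GRU output (forward and backward states
        # concatenated) is projected down to the style embedding.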
|
input = self.convs(input) |
|
output, _ = self.rnn_layer(input) |
|
style_embedding = self.projection_layer(output[:, -1]) |
|
return style_embedding |
|
|
|
|
|
class StyleEncoderAttn(nn.Module): |
|
""" Style Encoder Module: |
|
- Positional Encoding |
|
- Nf x FFT Blocks |
|
- Linear Projection Layer |
|
""" |
|
|
|
def __init__(self, input_size, hidden_size, style_embedding_size): |
|
super(StyleEncoderAttn, self).__init__() |
|
|
|
|
|
self.pos_enc = PositionalEncoding(style_embedding_size) |
|
|
|
self.convs = nn.Sequential( |
|
ConvNorm1D( |
|
input_size, |
|
hidden_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
), |
|
nn.ReLU(), |
|
nn.LayerNorm(hidden_size), |
|
nn.Dropout(0.2), |
|
ConvNorm1D( |
|
hidden_size, |
|
style_embedding_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
), |
|
nn.ReLU(), |
|
nn.LayerNorm(style_embedding_size), |
|
nn.Dropout(0.2), |
|
) |
|
|
|
blocks = [] |
|
for _ in range(1): |
|
blocks.append(FFTBlock(style_embedding_size)) |
|
self.blocks = nn.ModuleList(blocks) |
|
|
|
def forward(self, input): |
|
""" Forward function of Prosody Encoder: |
|
frames_energy = (B, T_max) |
|
frames_pitch = (B, T_max) |
|
mel_specs = (B, nb_mels, T_max) |
|
speaker_ids = (B, ) |
|
output_lengths = (B, ) |
|
""" |
|
output_lengths = torch.as_tensor( |
|
len(input) * [input.shape[1]], device=input.device, dtype=torch.int32 |
|
) |
|
|
|
pos = self.pos_enc(output_lengths.unsqueeze(1)).to(input.device) |
|
|
|
outputs = self.convs(input) |
|
|
|
|
|
mask = ~get_mask_from_lengths(output_lengths) |
|
|
|
outputs = outputs + pos |
|
outputs = outputs.masked_fill(mask.unsqueeze(2), 0) |
|
|
|
for _, block in enumerate(self.blocks): |
|
outputs = block(outputs, None, mask) |
|
|
|
        style_embedding = torch.sum(outputs, dim=1) / output_lengths.unsqueeze(1)
|
|
|
return style_embedding |
|
|
|
|
|
|
|
|
|
|
|
class LinearNorm(nn.Module): |
|
""" Linear Norm Module: |
|
- Linear Layer |
|
""" |
|
|
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"): |
|
super(LinearNorm, self).__init__() |
|
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias) |
|
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain)) |
|
|
|
def forward(self, x): |
|
""" Forward function of Linear Norm |
|
x = (*, in_dim) |
|
""" |
|
x = self.linear_layer(x) |
|
|
|
return x |
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
""" Positional Encoding Module: |
|
- Sinusoidal Positional Embedding |
|
""" |
|
|
|
def __init__(self, embed_dim, max_len=20000, timestep=10000.0): |
|
super(PositionalEncoding, self).__init__() |
|
self.embed_dim = embed_dim |
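        # Precompute a (max_len, embed_dim) table of sinusoidal encodings: even channels use
        # sine, odd channels use cosine, with wavelengths forming a geometric progression
        # controlled by `timestep` (as in "Attention Is All You Need").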
|
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
|
div_term = torch.exp( |
|
torch.arange(0, self.embed_dim, 2).float() * (-np.log(timestep) / self.embed_dim) |
|
) |
|
self.pos_enc = torch.FloatTensor(max_len, self.embed_dim).zero_() |
|
self.pos_enc[:, 0::2] = torch.sin(pos * div_term) |
|
self.pos_enc[:, 1::2] = torch.cos(pos * div_term) |
|
|
|
def forward(self, x): |
|
""" Forward function of Positional Encoding: |
|
x = (B, N) -- Long or Int tensor |
|
""" |
|
|
|
        nb_frames_max = int(torch.max(torch.cumsum(x, dim=1)))
|
pos_emb = torch.FloatTensor( |
|
x.size(0), nb_frames_max, self.embed_dim |
|
).zero_() |
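        # For each example, build the positional indices [0, 1, ..., n - 1] for every length
        # in x, look them up in the precomputed table, and copy the result into the
        # zero-padded output buffer.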
|
|
|
|
|
|
|
for line_idx in range(x.size(0)): |
|
pos_idx = [] |
|
for column_idx in range(x.size(1)): |
|
idx = x[line_idx, column_idx] |
|
                pos_idx.extend(range(int(idx)))
|
emb = self.pos_enc[pos_idx] |
|
pos_emb[line_idx, : emb.size(0), :] = emb |
|
|
|
return pos_emb |
|
|
|
|
|
class FFTBlock(nn.Module): |
|
""" FFT Block Module: |
|
- Multi-Head Attention |
|
- Position Wise Convolutional Feed-Forward |
|
- FiLM conditioning (if film_params is not None) |
|
""" |
|
|
|
def __init__(self, hidden_size): |
|
super(FFTBlock, self).__init__() |
|
self.attention = MultiHeadAttention(hidden_size) |
|
self.feed_forward = PositionWiseConvFF(hidden_size) |
|
|
|
def forward(self, x, film_params, mask): |
|
""" Forward function of FFT Block: |
|
x = (B, L_max, hidden_embed_dim) |
|
film_params = (B, nb_film_params) |
|
mask = (B, L_max) |
|
""" |
|
|
|
        attn_outputs, _ = self.attention(x, x, x, key_padding_mask=mask)
        attn_outputs = attn_outputs.masked_fill(mask.unsqueeze(2), 0)
|
|
|
outputs = self.feed_forward(attn_outputs, film_params) |
|
outputs = outputs.masked_fill(mask.unsqueeze(2), 0) |
|
|
|
return outputs |
|
|
|
|
|
class MultiHeadAttention(nn.Module): |
|
""" Multi-Head Attention Module: |
|
- Multi-Head Attention |
|
A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser and I. Polosukhin |
|
"Attention is all you need", |
|
in NeurIPS, 2017. |
|
- Dropout |
|
- Residual Connection |
|
- Layer Normalization |
|
""" |
|
|
|
def __init__(self, hidden_size): |
|
super(MultiHeadAttention, self).__init__() |
|
self.multi_head_attention = nn.MultiheadAttention(hidden_size, 4, 0.1) |
|
self.dropout = nn.Dropout(0.1) |
|
self.layer_norm = nn.LayerNorm(hidden_size) |
|
|
|
def forward(self, query, key, value, key_padding_mask=None, attn_mask=None): |
|
""" Forward function of Multi-Head Attention: |
|
query = (B, L_max, hidden_embed_dim) |
|
key = (B, T_max, hidden_embed_dim) |
|
value = (B, T_max, hidden_embed_dim) |
|
key_padding_mask = (B, T_max) if not None |
|
attn_mask = (L_max, T_max) if not None |
|
""" |
|
|
|
|
|
|
|
attn_outputs, attn_weights = self.multi_head_attention( |
|
query.transpose(0, 1), |
|
key.transpose(0, 1), |
|
value.transpose(0, 1), |
|
key_padding_mask=key_padding_mask, |
|
attn_mask=attn_mask, |
|
) |
|
attn_outputs = attn_outputs.transpose(0, 1) |
|
|
|
attn_outputs = self.dropout(attn_outputs) |
|
|
|
attn_outputs = self.layer_norm(attn_outputs + query) |
|
|
|
return attn_outputs, attn_weights |
|
|
|
|
|
class PositionWiseConvFF(nn.Module): |
|
""" Position Wise Convolutional Feed-Forward Module: |
|
- 2x Conv 1D with ReLU |
|
- Dropout |
|
- Residual Connection |
|
- Layer Normalization |
|
- FiLM conditioning (if film_params is not None) |
|
""" |
|
|
|
def __init__(self, hidden_size): |
|
super(PositionWiseConvFF, self).__init__() |
|
self.convs = nn.Sequential( |
|
ConvNorm1D( |
|
hidden_size, |
|
hidden_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="relu", |
|
), |
|
nn.ReLU(), |
|
ConvNorm1D( |
|
hidden_size, |
|
hidden_size, |
|
kernel_size=3, |
|
stride=1, |
|
padding=int((3 - 1) / 2), |
|
dilation=1, |
|
w_init_gain="linear", |
|
), |
|
nn.Dropout(0.1), |
|
) |
|
self.layer_norm = nn.LayerNorm(hidden_size) |
|
|
|
def forward(self, x, film_params): |
|
""" Forward function of PositionWiseConvFF: |
|
x = (B, L_max, hidden_embed_dim) |
|
film_params = (B, nb_film_params) |
|
""" |
|
|
|
outputs = self.convs(x) |
|
|
|
outputs = self.layer_norm(outputs + x) |
|
|
|
if film_params is not None: |
|
nb_gammas = int(film_params.size(1) / 2) |
|
assert nb_gammas == outputs.size(2) |
|
gammas = film_params[:, :nb_gammas].unsqueeze(1) |
|
betas = film_params[:, nb_gammas:].unsqueeze(1) |
|
outputs = gammas * outputs + betas |
|
|
|
return outputs |
|
|
|
|
|
class ConvNorm1D(nn.Module): |
|
""" Conv Norm 1D Module: |
|
- Conv 1D |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
stride=1, |
|
padding=None, |
|
dilation=1, |
|
bias=True, |
|
w_init_gain="linear", |
|
): |
|
super(ConvNorm1D, self).__init__() |
|
self.conv = nn.Conv1d( |
|
in_channels, |
|
out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
dilation=dilation, |
|
bias=bias, |
|
) |
|
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) |
|
|
|
def forward(self, x): |
|
""" Forward function of Conv Norm 1D |
|
x = (B, L, in_channels) |
|
""" |
|
x = x.transpose(1, 2) |
|
x = self.conv(x) |
|
x = x.transpose(1, 2) |
|
|
|
return x |
|
|
|
|
|
class AvgPoolNorm1D(nn.Module): |
|
def __init__( |
|
self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True |
|
): |
|
super(AvgPoolNorm1D, self).__init__() |
|
self.avgpool1d = nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad) |
|
|
|
def forward(self, x): |
|
x = x.transpose(1, 2) |
|
x = self.avgpool1d(x) |
|
x = x.transpose(1, 2) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.script |
|
def normalize(x, eps: float = 1e-8): |
|
return x / (torch.norm(x, dim=-1, keepdim=True) + eps) |
|
|
|
|
|
@torch.jit.script |
|
def vectorize_input( |
|
Z_root_pos, |
|
Z_root_rot, |
|
Z_root_vel, |
|
Z_root_vrt, |
|
Z_lpos, |
|
Z_ltxy, |
|
Z_lvel, |
|
Z_lvrt, |
|
Z_gaze_pos, |
|
parents, |
|
anim_input_mean, |
|
anim_input_std, |
|
): |
|
batchsize = Z_lpos.shape[0] |
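    # Express the gaze target in the character's root space so the encoding is invariant to
    # global position and heading.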
|
|
|
|
|
|
|
Z_gaze_dir = quat_inv_mul_vec(Z_root_rot, Z_gaze_pos - Z_root_pos) |
|
|
|
|
|
pose_encoding = torch.cat( |
|
[ |
|
Z_root_vel.reshape([batchsize, -1]), |
|
Z_root_vrt.reshape([batchsize, -1]), |
|
Z_lpos.reshape([batchsize, -1]), |
|
Z_ltxy.reshape([batchsize, -1]), |
|
Z_lvel.reshape([batchsize, -1]), |
|
Z_lvrt.reshape([batchsize, -1]), |
|
Z_gaze_dir.reshape([batchsize, -1]), |
|
], |
|
dim=1, |
|
) |
|
|
|
|
|
return (pose_encoding - anim_input_mean) / anim_input_std |
|
|
|
|
|
@torch.jit.script |
|
def devectorize_output( |
|
predicted, |
|
Z_root_pos, |
|
Z_root_rot, |
|
batchsize: int, |
|
njoints: int, |
|
dt: float, |
|
anim_output_mean, |
|
anim_output_std, |
|
): |
|
|
|
predicted = (predicted * anim_output_std) + anim_output_mean |
|
|
|
|
|
P_root_vel = predicted[:, 0:3] |
|
P_root_vrt = predicted[:, 3:6] |
|
P_lpos = predicted[:, 6 + njoints * 0: 6 + njoints * 3].reshape([batchsize, njoints, 3]) |
|
P_ltxy = predicted[:, 6 + njoints * 3: 6 + njoints * 9].reshape([batchsize, njoints, 2, 3]) |
|
P_lvel = predicted[:, 6 + njoints * 9: 6 + njoints * 12].reshape([batchsize, njoints, 3]) |
|
P_lvrt = predicted[:, 6 + njoints * 12: 6 + njoints * 15].reshape([batchsize, njoints, 3]) |
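    # Integrate the predicted root-space velocities over one frame (dt) to obtain the new
    # world-space root position and rotation.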
|
|
|
|
|
P_root_pos = quat_mul_vec(Z_root_rot, P_root_vel * dt) + Z_root_pos |
|
P_root_rot = quat_mul(quat_from_helical(quat_mul_vec(Z_root_rot, P_root_vrt * dt)), Z_root_rot) |
|
|
|
return (P_root_pos, P_root_rot, P_root_vel, P_root_vrt, P_lpos, P_ltxy, P_lvel, P_lvrt) |
|
|
|
|
|
def generalized_logistic_function(x, center=0.0, B=1.0, A=0.0, K=1.0, C=1.0, Q=1.0, nu=1.0): |
|
""" Equation of the generalised logistic function |
|
https://en.wikipedia.org/wiki/Generalised_logistic_function |
|
|
|
:param x: abscissa point where logistic function needs to be evaluated |
|
:param center: abscissa point corresponding to starting time |
|
:param B: growth rate |
|
:param A: lower asymptote |
|
:param K: upper asymptote when C=1. |
|
:param C: change upper asymptote value |
|
:param Q: related to value at starting time abscissa point |
|
:param nu: affects near which asymptote maximum growth occurs |
|
|
|
:return: value of logistic function at abscissa point |
|
""" |
|
value = A + (K - A) / (C + Q * np.exp(-B * (x - center))) ** (1 / nu) |
|
return value |
|
|
|
|
|
def compute_KL_div(mu, logvar, iteration): |
|
""" Compute KL divergence loss |
|
mu = (B, embed_dim) |
|
logvar = (B, embed_dim) |
|
""" |
|
|
|
|
|
|
|
|
|
kl_weight_center = 7500 |
|
kl_weight_growth_rate = 0.005 |
|
kl_threshold = 2e-1 |
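    # KL annealing: the weight follows a logistic warm-up centred on `kl_weight_center`
    # iterations with growth rate `kl_weight_growth_rate`, and is clamped at `kl_threshold`.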
|
|
|
|
|
kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) |
|
kl_div = torch.mean(kl_div) |
|
|
|
|
|
|
|
|
|
kl_div_weight = generalized_logistic_function( |
|
iteration, center=kl_weight_center, B=kl_weight_growth_rate, |
|
) |
|
|
|
kl_div_weight = min(kl_div_weight, kl_threshold) |
|
return kl_div, kl_div_weight |
|
|
|
|
|
def compute_kl_uni_gaus(q_params: Tuple, p_params: Tuple): |
|
mu_q, log_var_q = q_params |
|
mu_p, log_var_p = p_params |
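    # Closed-form KL divergence between two diagonal Gaussians q and p, summed over
    # dimensions and averaged over the batch:
    #   KL(q || p) = 0.5 * log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / (2 * var_p) - 0.5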
|
|
|
kl = 0.5 * (log_var_p - log_var_q) + (log_var_q.exp() + (mu_q - mu_p) ** 2) / (2 * log_var_p.exp()) - 0.5 + 1e-8 |
|
kl = torch.sum(kl, dim=-1) |
|
kl = torch.mean(kl) |
|
return kl |
|
|
|
|
|
def get_mask_from_lengths(lengths): |
|
""" Create a masked tensor from given lengths |
|
|
|
:param lengths: torch.tensor of size (B, ) -- lengths of each example |
|
|
|
:return mask: torch.tensor of size (B, max_length) -- the masked tensor |
|
""" |
|
max_len = torch.max(lengths) |
|
|
|
ids = torch.arange(0, max_len).long().to(lengths.device) |
|
mask = (ids < lengths.unsqueeze(1)).bool().to(lengths.device) |
|
return mask |
|
|