from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from anim.tquat import *
# ===============================================
# Decoder
# ===============================================
class Decoder(nn.Module):
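    """Autoregressive pose decoder.

    Starting from a ground-truth first frame, it rolls the pose forward one
    frame at a time, feeding its own previous prediction (together with the
    speech and style encodings for the current frame) back into a recurrent
    decoder whose initial hidden state comes from CellStateEncoder.
    """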
def __init__(
self,
pose_input_size,
pose_output_size,
speech_encoding_size,
style_encoding_size,
hidden_size,
num_rnn_layers,
rnn_cond="normal",
):
super(Decoder, self).__init__()
if rnn_cond == "normal":
self.recurrent_decoder = RecurrentDecoderNormal(
pose_input_size,
speech_encoding_size,
style_encoding_size,
pose_output_size,
hidden_size,
num_rnn_layers,
)
elif rnn_cond == "film":
self.recurrent_decoder = RecurrentDecoderFiLM(
pose_input_size,
speech_encoding_size,
style_encoding_size,
pose_output_size,
hidden_size,
num_rnn_layers,
)
self.cell_state_encoder = CellStateEncoder(
pose_input_size + style_encoding_size, hidden_size, num_rnn_layers
)
def forward(
self,
Z_root_pos,
Z_root_rot,
Z_root_vel,
Z_root_vrt,
Z_lpos,
Z_ltxy,
Z_lvel,
Z_lvrt,
Z_gaze_pos,
speech_encoding,
style_encoding,
parents,
anim_input_mean,
anim_input_std,
anim_output_mean,
anim_output_std,
dt: float,
):
batchsize = speech_encoding.shape[0]
nframes = speech_encoding.shape[1]
# Getting initial values from ground truth
O_root_pos = [Z_root_pos]
O_root_rot = [Z_root_rot]
O_root_vel = [Z_root_vel]
O_root_vrt = [Z_root_vrt]
O_lpos = [Z_lpos]
O_ltxy = [Z_ltxy]
O_lvel = [Z_lvel]
O_lvrt = [Z_lvrt]
# Initialize the hidden state of decoder
        decoder_state = self.cell_state_encoder(
            vectorize_input(
                Z_root_pos, Z_root_rot, Z_root_vel, Z_root_vrt,
                Z_lpos, Z_ltxy, Z_lvel, Z_lvrt,
                Z_gaze_pos[:, 0], parents, anim_input_mean, anim_input_std,
            ),
            style_encoding[:, 0],
        )
for i in range(1, nframes):
# Prepare Input
pose_encoding = vectorize_input(
O_root_pos[-1],
O_root_rot[-1],
O_root_vel[-1],
O_root_vrt[-1],
O_lpos[-1],
O_ltxy[-1],
O_lvel[-1],
O_lvrt[-1],
Z_gaze_pos[:, i],
parents,
anim_input_mean,
anim_input_std,
)
# Predict
predicted, decoder_state = self.recurrent_decoder(
pose_encoding, speech_encoding[:, i], style_encoding[:, i], decoder_state
)
# Integrate Prediction
(
P_root_pos,
P_root_rot,
P_root_vel,
P_root_vrt,
P_lpos,
P_ltxy,
P_lvel,
P_lvrt,
) = devectorize_output(
predicted,
O_root_pos[-1],
O_root_rot[-1],
Z_lpos.shape[0],
Z_lpos.shape[1],
dt,
anim_output_mean,
anim_output_std,
)
# Append
O_root_pos.append(P_root_pos)
O_root_rot.append(P_root_rot)
O_root_vel.append(P_root_vel)
O_root_vrt.append(P_root_vrt)
O_lpos.append(P_lpos)
O_ltxy.append(P_ltxy)
O_lvel.append(P_lvel)
O_lvrt.append(P_lvrt)
return (
torch.cat([O[:, None] for O in O_root_pos], dim=1),
torch.cat([O[:, None] for O in O_root_rot], dim=1),
torch.cat([O[:, None] for O in O_root_vel], dim=1),
torch.cat([O[:, None] for O in O_root_vrt], dim=1),
torch.cat([O[:, None] for O in O_lpos], dim=1),
torch.cat([O[:, None] for O in O_ltxy], dim=1),
torch.cat([O[:, None] for O in O_lvel], dim=1),
torch.cat([O[:, None] for O in O_lvrt], dim=1),
)
class RecurrentDecoderNormal(nn.Module):
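    """Single-step GRU decoder: a linear pre-layer, a GRU over the concatenated
    pose/speech/style features, and a linear output projection."""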
def __init__(
self, pose_input_size, speech_size, style_size, output_size, hidden_size, num_rnn_layers
):
super(RecurrentDecoderNormal, self).__init__()
all_input_size = pose_input_size + speech_size + style_size
self.layer0 = nn.Linear(all_input_size, hidden_size)
self.layer1 = nn.GRU(
all_input_size + hidden_size, hidden_size, num_rnn_layers, batch_first=True
)
self.layer2 = nn.Linear(hidden_size, output_size)
def forward(self, pose, speech, style, cell_state):
hidden = F.elu(self.layer0(torch.cat([pose, speech, style], dim=-1)))
cell_output, cell_state = self.layer1(
torch.cat([hidden, pose, speech, style], dim=-1).unsqueeze(1), cell_state
)
output = self.layer2(cell_output.squeeze(1))
return output, cell_state
class RecurrentDecoderFiLM(nn.Module):
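    """Single-step GRU decoder with FiLM conditioning: the style embedding
    predicts per-channel scales (gammas, offset by +1 so the initial transform
    is near identity) and shifts (betas), applied before and after the GRU."""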
def __init__(
self, pose_input_size, speech_size, style_size, output_size, hidden_size, num_rnn_layers
):
super(RecurrentDecoderFiLM, self).__init__()
self.hidden_size = hidden_size
self.gammas_predictor = LinearNorm(
style_size, hidden_size * 2, w_init_gain="linear"
)
self.betas_predictor = LinearNorm(
style_size, hidden_size * 2, w_init_gain="linear"
)
self.layer0 = nn.Linear(pose_input_size + speech_size, hidden_size)
self.layer1 = nn.GRU(
pose_input_size + speech_size + hidden_size,
hidden_size,
num_rnn_layers,
batch_first=True,
dropout=0.0,
)
self.layer2 = nn.Linear(hidden_size, hidden_size)
self.layer3 = nn.Linear(hidden_size, output_size)
def forward(self, pose, speech, style, cell_state):
gammas = self.gammas_predictor(style)
gammas = gammas + 1
betas = self.betas_predictor(style)
hidden = F.elu(self.layer0(torch.cat([pose, speech], dim=-1)))
hidden = hidden * gammas[:, : self.hidden_size] + betas[:, : self.hidden_size]
cell_output, cell_state = self.layer1(
torch.cat([hidden, pose, speech], dim=-1).unsqueeze(1), cell_state
)
hidden = F.elu(self.layer2(cell_output.squeeze(1)))
hidden = hidden * gammas[:, self.hidden_size:] + betas[:, self.hidden_size:]
output = self.layer3(hidden)
return output, cell_state
class CellStateEncoder(nn.Module):
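    """Encodes the initial pose and style into the decoder's initial GRU hidden
    state, reshaped to (num_rnn_layers, batch, hidden_size)."""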
def __init__(self, input_size, hidden_size, num_rnn_layers):
super(CellStateEncoder, self).__init__()
self.num_rnn_layers = num_rnn_layers
self.layer0 = nn.Linear(input_size, hidden_size)
self.layer1 = nn.Linear(hidden_size, hidden_size)
self.layer2 = nn.Linear(hidden_size, hidden_size * num_rnn_layers)
def forward(self, pose, style):
hidden = F.elu(self.layer0(torch.cat([pose, style], dim=-1)))
hidden = F.elu(self.layer1(hidden))
output = self.layer2(hidden)
return output.reshape(output.shape[0], self.num_rnn_layers, -1).swapaxes(0, 1).contiguous()
# ===============================================
# Speech Encoder
# ===============================================
class SpeechEncoder(nn.Module):
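    """Encodes per-frame speech features with two 1D convolutions over time
    (the second with a wide kernel for temporal context) and a linear layer."""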
def __init__(self, input_size, hidden_size, output_size):
super(SpeechEncoder, self).__init__()
self.layer0 = nn.Conv1d(
input_size, hidden_size, kernel_size=1, padding="same", padding_mode="replicate"
)
self.drop0 = nn.Dropout(p=0.2)
self.layer1 = nn.Conv1d(
hidden_size, output_size, kernel_size=31, padding="same", padding_mode="replicate"
)
self.drop1 = nn.Dropout(p=0.2)
self.layer2 = nn.Linear(output_size, output_size)
def forward(self, x):
x = torch.swapaxes(x, 1, 2)
x = self.drop0(F.elu(self.layer0(x)))
x = self.drop1(F.elu(self.layer1(x)))
x = torch.swapaxes(x, 1, 2)
x = F.elu(self.layer2(x))
return x
# ===============================================
# Style Encoder
# ===============================================
class StyleEncoder(nn.Module):
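    """Wraps a GRU- or attention-based style encoder. With use_vae=True the
    encoder outputs (mu, logvar) and the style embedding is sampled with the
    reparameterization trick; otherwise the raw embedding is returned."""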
def __init__(self, input_size, hidden_size, style_embedding_size, type="attn", use_vae=False):
super(StyleEncoder, self).__init__()
self.use_vae = use_vae
self.style_embedding_size = style_embedding_size
output_size = 2 * style_embedding_size if use_vae else style_embedding_size
if type == "gru":
self.encoder = StyleEncoderGRU(input_size, hidden_size, output_size)
elif type == "attn":
self.encoder = StyleEncoderAttn(input_size, hidden_size, output_size)
    def forward(self, input, temperature: float = 1.0):
encoder_output = self.encoder(input)
if self.use_vae:
mu, logvar = (
encoder_output[:, : self.style_embedding_size],
encoder_output[:, self.style_embedding_size:],
)
# re-parameterization trick
            std = torch.exp(0.5 * logvar) / temperature
eps = torch.randn_like(std)
style_embedding = mu + eps * std
return style_embedding, mu, logvar
else:
return encoder_output, None, None
class StyleEncoderGRU(nn.Module):
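    """ Style Encoder Module (GRU variant):
        - 2x Conv 1D with ReLU
        - Bidirectional GRU (last time step projected to the style embedding)
    """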
def __init__(self, input_size, hidden_size, style_embedding_size):
super(StyleEncoderGRU, self).__init__()
self.convs = nn.Sequential(
ConvNorm1D(
input_size,
hidden_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.ReLU(),
# AvgPoolNorm1D(kernel_size=2),
ConvNorm1D(
hidden_size,
hidden_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.ReLU(),
)
self.rnn_layer = nn.GRU(hidden_size, hidden_size, 1, batch_first=True, bidirectional=True)
self.projection_layer = LinearNorm(
hidden_size * 2, style_embedding_size, w_init_gain="linear"
)
def forward(self, input):
input = self.convs(input)
output, _ = self.rnn_layer(input)
style_embedding = self.projection_layer(output[:, -1])
return style_embedding
class StyleEncoderAttn(nn.Module):
""" Style Encoder Module:
- Positional Encoding
- Nf x FFT Blocks
- Linear Projection Layer
"""
def __init__(self, input_size, hidden_size, style_embedding_size):
super(StyleEncoderAttn, self).__init__()
# positional encoding
self.pos_enc = PositionalEncoding(style_embedding_size)
self.convs = nn.Sequential(
ConvNorm1D(
input_size,
hidden_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.ReLU(),
nn.LayerNorm(hidden_size),
nn.Dropout(0.2),
ConvNorm1D(
hidden_size,
style_embedding_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.ReLU(),
nn.LayerNorm(style_embedding_size),
nn.Dropout(0.2),
)
# FFT blocks
blocks = []
for _ in range(1):
blocks.append(FFTBlock(style_embedding_size))
self.blocks = nn.ModuleList(blocks)
def forward(self, input):
""" Forward function of Prosody Encoder:
frames_energy = (B, T_max)
frames_pitch = (B, T_max)
mel_specs = (B, nb_mels, T_max)
speaker_ids = (B, )
output_lengths = (B, )
"""
output_lengths = torch.as_tensor(
len(input) * [input.shape[1]], device=input.device, dtype=torch.int32
)
# compute positional encoding
pos = self.pos_enc(output_lengths.unsqueeze(1)).to(input.device) # (B, T_max, hidden_embed_dim)
# pass through convs
outputs = self.convs(input) # (B, T_max, hidden_embed_dim)
# create mask
mask = ~get_mask_from_lengths(output_lengths) # (B, T_max)
# add encodings and mask tensor
outputs = outputs + pos # (B, T_max, hidden_embed_dim)
outputs = outputs.masked_fill(mask.unsqueeze(2), 0) # (B, T_max, hidden_embed_dim)
# pass through FFT blocks
for _, block in enumerate(self.blocks):
outputs = block(outputs, None, mask) # (B, T_max, hidden_embed_dim)
# average pooling on the whole time sequence
        style_embedding = torch.sum(outputs, dim=1) / output_lengths.unsqueeze(1)  # (B, hidden_embed_dim)
return style_embedding
# ===============================================
# Sub-modules
# ===============================================
class LinearNorm(nn.Module):
""" Linear Norm Module:
- Linear Layer
"""
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
super(LinearNorm, self).__init__()
self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
def forward(self, x):
""" Forward function of Linear Norm
x = (*, in_dim)
"""
x = self.linear_layer(x) # (*, out_dim)
return x
class PositionalEncoding(nn.Module):
""" Positional Encoding Module:
- Sinusoidal Positional Embedding
"""
def __init__(self, embed_dim, max_len=20000, timestep=10000.0):
super(PositionalEncoding, self).__init__()
self.embed_dim = embed_dim
pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # (max_len, 1)
div_term = torch.exp(
torch.arange(0, self.embed_dim, 2).float() * (-np.log(timestep) / self.embed_dim)
) # (embed_dim // 2, )
self.pos_enc = torch.FloatTensor(max_len, self.embed_dim).zero_() # (max_len, embed_dim)
self.pos_enc[:, 0::2] = torch.sin(pos * div_term)
self.pos_enc[:, 1::2] = torch.cos(pos * div_term)
def forward(self, x):
""" Forward function of Positional Encoding:
x = (B, N) -- Long or Int tensor
"""
# initialize tensor
        nb_frames_max = int(torch.max(torch.cumsum(x, dim=1)))
pos_emb = torch.FloatTensor(
x.size(0), nb_frames_max, self.embed_dim
).zero_() # (B, nb_frames_max, embed_dim)
# pos_emb = pos_emb.cuda(x.device, non_blocking=True).float() # (B, nb_frames_max, embed_dim)
# TODO: Check if we can remove the for loops
for line_idx in range(x.size(0)):
pos_idx = []
for column_idx in range(x.size(1)):
idx = x[line_idx, column_idx]
pos_idx.extend([i for i in range(idx)])
emb = self.pos_enc[pos_idx] # (nb_frames, embed_dim)
pos_emb[line_idx, : emb.size(0), :] = emb
return pos_emb
class FFTBlock(nn.Module):
""" FFT Block Module:
- Multi-Head Attention
- Position Wise Convolutional Feed-Forward
- FiLM conditioning (if film_params is not None)
"""
def __init__(self, hidden_size):
super(FFTBlock, self).__init__()
self.attention = MultiHeadAttention(hidden_size)
self.feed_forward = PositionWiseConvFF(hidden_size)
def forward(self, x, film_params, mask):
""" Forward function of FFT Block:
x = (B, L_max, hidden_embed_dim)
film_params = (B, nb_film_params)
mask = (B, L_max)
"""
# attend
attn_outputs, _ = self.attention(
x, x, x, key_padding_mask=mask
) # (B, L_max, hidden_embed_dim)
attn_outputs = attn_outputs.masked_fill(
mask.unsqueeze(2), 0
) # (B, L_max, hidden_embed_dim)
# feed-forward pass
outputs = self.feed_forward(attn_outputs, film_params) # (B, L_max, hidden_embed_dim)
outputs = outputs.masked_fill(mask.unsqueeze(2), 0) # (B, L_max, hidden_embed_dim)
return outputs
class MultiHeadAttention(nn.Module):
""" Multi-Head Attention Module:
- Multi-Head Attention
A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser and I. Polosukhin
"Attention is all you need",
in NeurIPS, 2017.
- Dropout
- Residual Connection
- Layer Normalization
"""
def __init__(self, hidden_size):
super(MultiHeadAttention, self).__init__()
self.multi_head_attention = nn.MultiheadAttention(hidden_size, 4, 0.1)
self.dropout = nn.Dropout(0.1)
self.layer_norm = nn.LayerNorm(hidden_size)
def forward(self, query, key, value, key_padding_mask=None, attn_mask=None):
""" Forward function of Multi-Head Attention:
query = (B, L_max, hidden_embed_dim)
key = (B, T_max, hidden_embed_dim)
value = (B, T_max, hidden_embed_dim)
key_padding_mask = (B, T_max) if not None
attn_mask = (L_max, T_max) if not None
"""
# compute multi-head attention
# attn_outputs = (L_max, B, hidden_embed_dim)
# attn_weights = (B, L_max, T_max)
attn_outputs, attn_weights = self.multi_head_attention(
query.transpose(0, 1),
key.transpose(0, 1),
value.transpose(0, 1),
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
attn_outputs = attn_outputs.transpose(0, 1) # (B, L_max, hidden_embed_dim)
# apply dropout
attn_outputs = self.dropout(attn_outputs) # (B, L_max, hidden_embed_dim)
# add residual connection and perform layer normalization
attn_outputs = self.layer_norm(attn_outputs + query) # (B, L_max, hidden_embed_dim)
return attn_outputs, attn_weights
class PositionWiseConvFF(nn.Module):
""" Position Wise Convolutional Feed-Forward Module:
- 2x Conv 1D with ReLU
- Dropout
- Residual Connection
- Layer Normalization
- FiLM conditioning (if film_params is not None)
"""
def __init__(self, hidden_size):
super(PositionWiseConvFF, self).__init__()
self.convs = nn.Sequential(
ConvNorm1D(
hidden_size,
hidden_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="relu",
),
nn.ReLU(),
ConvNorm1D(
hidden_size,
hidden_size,
kernel_size=3,
stride=1,
padding=int((3 - 1) / 2),
dilation=1,
w_init_gain="linear",
),
nn.Dropout(0.1),
)
self.layer_norm = nn.LayerNorm(hidden_size)
def forward(self, x, film_params):
""" Forward function of PositionWiseConvFF:
x = (B, L_max, hidden_embed_dim)
film_params = (B, nb_film_params)
"""
# pass through convs
outputs = self.convs(x) # (B, L_max, hidden_embed_dim)
# add residual connection and perform layer normalization
outputs = self.layer_norm(outputs + x) # (B, L_max, hidden_embed_dim)
# add FiLM transformation
if film_params is not None:
nb_gammas = int(film_params.size(1) / 2)
assert nb_gammas == outputs.size(2)
gammas = film_params[:, :nb_gammas].unsqueeze(1) # (B, 1, hidden_embed_dim)
betas = film_params[:, nb_gammas:].unsqueeze(1) # (B, 1, hidden_embed_dim)
outputs = gammas * outputs + betas # (B, L_max, hidden_embed_dim)
return outputs
class ConvNorm1D(nn.Module):
""" Conv Norm 1D Module:
- Conv 1D
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
):
super(ConvNorm1D, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
def forward(self, x):
""" Forward function of Conv Norm 1D
x = (B, L, in_channels)
"""
x = x.transpose(1, 2) # (B, in_channels, L)
x = self.conv(x) # (B, out_channels, L)
x = x.transpose(1, 2) # (B, L, out_channels)
return x
class AvgPoolNorm1D(nn.Module):
def __init__(
self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True
):
super(AvgPoolNorm1D, self).__init__()
self.avgpool1d = nn.AvgPool1d(kernel_size, stride, padding, ceil_mode, count_include_pad)
def forward(self, x):
x = x.transpose(1, 2) # (B, in_channels, L)
x = self.avgpool1d(x) # (B, out_channels, L)
x = x.transpose(1, 2) # (B, L, out_channels)
return x
# ===============================================
# Funcs
# ===============================================
@torch.jit.script
def normalize(x, eps: float = 1e-8):
return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
@torch.jit.script
def vectorize_input(
Z_root_pos,
Z_root_rot,
Z_root_vel,
Z_root_vrt,
Z_lpos,
Z_ltxy,
Z_lvel,
Z_lvrt,
Z_gaze_pos,
parents,
anim_input_mean,
anim_input_std,
):
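    """Flattens the current pose state (root and joint velocities, local joint
    positions, two-axis joint rotation features, and the gaze direction in
    root space) into a single normalized feature vector per batch element."""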
batchsize = Z_lpos.shape[0]
# Compute Local Gaze
# Z_gaze_dir = quat_inv_mul_vec(Z_root_rot, normalize(Z_gaze_pos - Z_root_pos))
Z_gaze_dir = quat_inv_mul_vec(Z_root_rot, Z_gaze_pos - Z_root_pos)
# Flatten the autoregressive input
pose_encoding = torch.cat(
[
Z_root_vel.reshape([batchsize, -1]),
Z_root_vrt.reshape([batchsize, -1]),
Z_lpos.reshape([batchsize, -1]),
Z_ltxy.reshape([batchsize, -1]),
Z_lvel.reshape([batchsize, -1]),
Z_lvrt.reshape([batchsize, -1]),
Z_gaze_dir.reshape([batchsize, -1]),
],
dim=1,
)
# Normalize
return (pose_encoding - anim_input_mean) / anim_input_std
@torch.jit.script
def devectorize_output(
predicted,
Z_root_pos,
Z_root_rot,
batchsize: int,
njoints: int,
dt: float,
anim_output_mean,
anim_output_std,
):
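    """Splits and denormalizes the network prediction, then integrates the
    predicted root linear and angular velocities over dt to obtain the new
    root position and rotation."""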
# Denormalize
predicted = (predicted * anim_output_std) + anim_output_mean
# Extract predictions
P_root_vel = predicted[:, 0:3]
P_root_vrt = predicted[:, 3:6]
P_lpos = predicted[:, 6 + njoints * 0: 6 + njoints * 3].reshape([batchsize, njoints, 3])
P_ltxy = predicted[:, 6 + njoints * 3: 6 + njoints * 9].reshape([batchsize, njoints, 2, 3])
P_lvel = predicted[:, 6 + njoints * 9: 6 + njoints * 12].reshape([batchsize, njoints, 3])
P_lvrt = predicted[:, 6 + njoints * 12: 6 + njoints * 15].reshape([batchsize, njoints, 3])
# Update pose state
P_root_pos = quat_mul_vec(Z_root_rot, P_root_vel * dt) + Z_root_pos
P_root_rot = quat_mul(quat_from_helical(quat_mul_vec(Z_root_rot, P_root_vrt * dt)), Z_root_rot)
return (P_root_pos, P_root_rot, P_root_vel, P_root_vrt, P_lpos, P_ltxy, P_lvel, P_lvrt)
def generalized_logistic_function(x, center=0.0, B=1.0, A=0.0, K=1.0, C=1.0, Q=1.0, nu=1.0):
""" Equation of the generalised logistic function
https://en.wikipedia.org/wiki/Generalised_logistic_function
:param x: abscissa point where logistic function needs to be evaluated
:param center: abscissa point corresponding to starting time
:param B: growth rate
:param A: lower asymptote
:param K: upper asymptote when C=1.
:param C: change upper asymptote value
:param Q: related to value at starting time abscissa point
:param nu: affects near which asymptote maximum growth occurs
:return: value of logistic function at abscissa point
"""
value = A + (K - A) / (C + Q * np.exp(-B * (x - center))) ** (1 / nu)
return value
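# Illustrative check of the annealing schedule used below: with the defaults
# A=0, K=1, C=1, Q=1, nu=1 the function reduces to a plain sigmoid, so
# generalized_logistic_function(7500, center=7500, B=0.005) == 0.5, and the
# value saturates towards K=1 for much larger iterations (before the
# kl_threshold clamp applied in compute_KL_div).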
def compute_KL_div(mu, logvar, iteration):
""" Compute KL divergence loss
mu = (B, embed_dim)
logvar = (B, embed_dim)
"""
# compute KL divergence
# see Appendix B from VAE paper:
# D.P. Kingma and M. Welling, "Auto-Encoding Variational Bayes", ICLR, 2014.
kl_weight_center = 7500 # iteration at which weight of KL divergence loss is 0.5
kl_weight_growth_rate = 0.005 # growth rate for weight of KL divergence loss
kl_threshold = 2e-1 # KL weight threshold
# kl_threshold = 1.0
kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) # (B, )
kl_div = torch.mean(kl_div)
# compute weight for KL cost annealing:
# S.R. Bowman, L. Vilnis, O. Vinyals, A.M. Dai, R. Jozefowicz, S. Bengio,
# "Generating Sentences from a Continuous Space", arXiv:1511.06349, 2016.
kl_div_weight = generalized_logistic_function(
iteration, center=kl_weight_center, B=kl_weight_growth_rate,
)
# apply weight threshold
kl_div_weight = min(kl_div_weight, kl_threshold)
return kl_div, kl_div_weight
def compute_kl_uni_gaus(q_params: Tuple, p_params: Tuple):
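    """Closed-form KL divergence KL(q || p) between two diagonal Gaussians,
    each given as a (mean, log-variance) tuple, summed over the last dimension
    and averaged over the batch."""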
mu_q, log_var_q = q_params
mu_p, log_var_p = p_params
    kl = (
        0.5 * (log_var_p - log_var_q)
        + (log_var_q.exp() + (mu_q - mu_p) ** 2) / (2 * log_var_p.exp())
        - 0.5
        + 1e-8
    )
kl = torch.sum(kl, dim=-1)
kl = torch.mean(kl)
return kl
def get_mask_from_lengths(lengths):
""" Create a masked tensor from given lengths
:param lengths: torch.tensor of size (B, ) -- lengths of each example
:return mask: torch.tensor of size (B, max_length) -- the masked tensor
"""
max_len = torch.max(lengths)
# ids = torch.arange(0, max_len).cuda(lengths.device, non_blocking=True).long()
ids = torch.arange(0, max_len).long().to(lengths.device)
mask = (ids < lengths.unsqueeze(1)).bool().to(lengths.device)
return mask