# Source: conex/espnet2/tts/prosody_encoder.py
# Last commit: "Initial commit" (ad16788), committed by tobiasc
from typing import Sequence
import math
import torch
from torch import nn
from torch.nn import functional as F
from typeguard import check_argument_types
class VectorQuantizer(nn.Module):
    """Vector-quantization layer with a learnable codebook.

    Every latent vector is replaced by its nearest codebook entry (squared L2
    distance). Gradients are copied straight through the quantization step.

    Reference:
        [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py

    Args:
        num_embeddings (int): Number of codebook entries (K).
        hidden_dim (int): Size of each codebook entry (D).
        beta (float): Commitment-loss weight. NOTE: this argument is currently
            ignored; the weight is hard-coded to 0.05 in ``__init__``.
    """

    def __init__(self,
                 num_embeddings: int,
                 hidden_dim: int,
                 beta: float = 0.25):
        super().__init__()
        self.K = num_embeddings
        self.D = hidden_dim
        # NOTE(review): the ``beta`` argument is overridden on purpose by the
        # original author; kept so existing configurations behave the same.
        self.beta = 0.05  # beta override

        self.embedding = nn.Embedding(self.K, self.D)
        # Codebook entries start around 0.8 instead of the nn.Embedding default.
        self.embedding.weight.data.normal_(0.8, 0.1)  # override

    def forward(self, latents: torch.Tensor) -> tuple:
        """Quantize a batch of latent sequences.

        Args:
            latents (Tensor): Latent vectors (B, L, D).

        Returns:
            Tensor: Quantized latents with straight-through gradient (B, L, D).
            Tensor: VQ loss (embedding loss + beta * commitment loss), scalar.
            LongTensor: Selected codebook indices (B, L).
            nn.Embedding: The codebook itself.
            Tensor: Perplexity of the code usage in this batch, scalar.
        """
        latents_shape = latents.shape
        # reshape (not view) so that non-contiguous inputs are accepted too
        flat_latents = latents.reshape(-1, self.D)  # (BL, D)
        codebook = self.embedding.weight  # (K, D)

        # Squared L2 distance between every latent and every codebook entry,
        # expanded as |x|^2 + |e|^2 - 2 x.e
        dist = (
            torch.sum(flat_latents ** 2, dim=1, keepdim=True)
            + torch.sum(codebook ** 2, dim=1)
            - 2 * torch.matmul(flat_latents, codebook.t())
        )  # (BL, K)

        # Index of the closest codebook entry per latent
        encoding_inds = torch.argmin(dist, dim=1)  # (BL,)
        output_inds = encoding_inds.view(latents_shape[0], latents_shape[1])  # (B, L)

        # Direct row lookup: same values and gradients as a one-hot matmul, but
        # it follows the codebook dtype (works after .double()/.half()).
        quantized_latents = self.embedding(encoding_inds)  # (BL, D)
        quantized_latents = quantized_latents.view(latents_shape)  # (B, L, D)

        # VQ losses: pull the encoder output towards the codes and vice versa
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
        vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: forward value is the quantized latent,
        # backward gradient goes to ``latents`` unchanged.
        quantized_latents = latents + (quantized_latents - latents).detach()

        # Perplexity (exponential of the code-usage entropy) indicates how many
        # codes are 'active' on average. Computed in float32 so the 1e-10
        # epsilon does not underflow in reduced precision.
        avg_probs = F.one_hot(encoding_inds, self.K).float().mean(dim=0)  # (K,)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized_latents, vq_loss, output_inds, self.embedding, perplexity
class ProsodyEncoder(nn.Module):
    """VQ-VAE prosody encoder module.

    Combines a convolutional reference encoder, a GRU-based global encoder, a
    fine-grained (per-phoneme) encoder followed by vector quantization, and an
    autoregressive prior over the quantized codes.

    Args:
        odim (int): Number of input channels (mel spectrogram channels).
        adim (int, optional): Dimension of the phoneme hidden states the
            prosody embeddings are integrated with. This value is not that
            important; it does not change the capacity of the
            information bottleneck.
        num_embeddings (int, optional): Number of codebook entries. The higher
            this value, the higher the capacity of the information bottleneck.
        hidden_dim (int, optional): Dimension of the fine-grained embeddings
            and of each codebook entry.
        beta (float, optional): Value forwarded to ``VectorQuantizer``.
        ref_enc_conv_layers (int, optional):
            The number of conv layers in the reference encoder.
        ref_enc_conv_chans_list: (Sequence[int], optional):
            List of the number of channels of conv layers in the reference
            encoder.
        ref_enc_conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder.
        ref_enc_conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        global_enc_gru_layers (int, optional):
            The number of GRU layers in the global encoder.
        global_enc_gru_units (int, optional):
            The number of GRU units in the global encoder.
        global_emb_integration_type (str, optional):
            How to integrate the global embedding: "add" or "concat".
    """

    def __init__(
        self,
        odim: int,
        adim: int = 64,
        num_embeddings: int = 10,
        hidden_dim: int = 3,
        beta: float = 0.25,
        ref_enc_conv_layers: int = 2,
        ref_enc_conv_chans_list: Sequence[int] = (32, 32),
        ref_enc_conv_kernel_size: int = 3,
        ref_enc_conv_stride: int = 1,
        global_enc_gru_layers: int = 1,
        global_enc_gru_units: int = 32,
        global_emb_integration_type: str = "add",
    ) -> None:
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.global_emb_integration_type = global_emb_integration_type

        padding = (ref_enc_conv_kernel_size - 1) // 2

        self.ref_encoder = RefEncoder(
            ref_enc_conv_layers=ref_enc_conv_layers,
            ref_enc_conv_chans_list=ref_enc_conv_chans_list,
            ref_enc_conv_kernel_size=ref_enc_conv_kernel_size,
            ref_enc_conv_stride=ref_enc_conv_stride,
            ref_enc_conv_padding=padding,
        )

        # get the number of ref enc output units: apply the conv output-size
        # formula once per layer along the feature axis, then multiply by the
        # channel count of the last layer (channels are flattened into features)
        ref_enc_output_units = odim
        for _ in range(ref_enc_conv_layers):
            ref_enc_output_units = (
                ref_enc_output_units - ref_enc_conv_kernel_size + 2 * padding
            ) // ref_enc_conv_stride + 1
        ref_enc_output_units *= ref_enc_conv_chans_list[-1]

        self.fg_encoder = FGEncoder(
            ref_enc_output_units + global_enc_gru_units,
            hidden_dim=hidden_dim,
        )

        self.global_encoder = GlobalEncoder(
            ref_enc_output_units,
            global_enc_gru_layers=global_enc_gru_layers,
            global_enc_gru_units=global_enc_gru_units,
        )

        # define a projection for the global embeddings
        if self.global_emb_integration_type == "add":
            self.global_projection = nn.Linear(global_enc_gru_units, adim)
        else:
            self.global_projection = nn.Linear(
                adim + global_enc_gru_units, adim
            )

        self.ar_prior = ARPrior(
            adim,
            num_embeddings=num_embeddings,
            hidden_dim=hidden_dim,
        )

        self.vq_layer = VectorQuantizer(num_embeddings, hidden_dim, beta)

        # define a projection for the quantized fine-grained embeddings
        self.qfg_projection = nn.Linear(hidden_dim, adim)

    def forward(
        self,
        ys: torch.Tensor,
        ds: torch.Tensor,
        hs: torch.Tensor,
        global_embs: torch.Tensor = None,
        train_ar_prior: bool = False,
        ar_prior_inference: bool = False,
        fg_inds: torch.Tensor = None,
    ) -> Sequence[torch.Tensor]:
        """Calculate forward propagation.

        Args:
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
                May be None when ``ar_prior_inference`` is True and
                ``global_embs`` is given.
            ds (LongTensor): Batch of padded durations (B, Tmax).
            hs (Tensor): Batch of phoneme embeddings (B, Tmax, adim).
            global_embs (Tensor, optional): Global embeddings
                (B, global_enc_gru_units). Recomputed from ``ys`` when ``ys``
                is given.
            train_ar_prior (bool): If True, the returned prosody embeddings and
                ``ar_prior_loss`` come from the AR prior instead of the
                quantizer output.
            ar_prior_inference (bool): If True, the fine-grained codes are
                predicted by the AR prior and the encoder path is skipped.
            fg_inds (LongTensor, optional): Codebook indices handed to the AR
                prior as a fixed prefix during ``ar_prior_inference``.

        Returns:
            Without ``ar_prior_inference``:
                Tensor: Prosody embeddings (B, Tmax, adim).
                Tensor: VQ loss.
                Tensor or int: AR prior loss (0 unless ``train_ar_prior``).
                Tensor: Codebook perplexity.
                Tensor: Global prosody embeddings (B, global_enc_gru_units).
            With ``ar_prior_inference``:
                Tensor: Prosody embeddings (B, Tmax, adim).
                int, int, int: Three zero placeholders.
                LongTensor: Top-scoring codebook indices from the AR prior.

        Raises:
            ValueError: If the inputs required by the selected path are missing.
        """
        if ys is not None:
            ref_embs = self.ref_encoder(ys)  # (B, L', ref_enc_output_units)
            global_embs = self.global_encoder(ref_embs)  # (B, global_enc_gru_units)
        elif global_embs is None:
            raise ValueError("either ys or global_embs must be given.")

        if ar_prior_inference:
            hs_integrated = self._integrate_with_global_embs(hs, global_embs)
            qs, top_inds = self.ar_prior.inference(
                hs_integrated, fg_inds, self.vq_layer.embedding
            )

            qs = self.qfg_projection(qs)  # (B, Tmax, adim)
            assert hs.size(2) == qs.size(2)

            p_embs = self._integrate_with_global_embs(qs, global_embs)
            assert hs.shape == p_embs.shape

            return p_embs, 0, 0, 0, top_inds  # (B, Tmax, adim)

        if ys is None:
            # the encoder path below needs the reference features
            raise ValueError("ys is required unless ar_prior_inference is True.")

        # concat global embs to ref embs
        global_embs_expanded = global_embs.unsqueeze(1).expand(-1, ref_embs.size(1), -1)
        # (B, L', ref_enc_output_units + global_enc_gru_units)
        ref_embs_integrated = torch.cat([ref_embs, global_embs_expanded], dim=-1)

        # (B, Tmax, hidden_dim)
        fg_embs = self.fg_encoder(ref_embs_integrated, ds, ys.size(1))

        # (B, Tmax, hidden_dim)
        qs, vq_loss, inds, codebook, perplexity = self.vq_layer(fg_embs)
        # Vector quantization should maintain length
        assert hs.size(1) == qs.size(1)

        qs = self.qfg_projection(qs)  # (B, Tmax, adim)
        assert hs.size(2) == qs.size(2)

        p_embs = self._integrate_with_global_embs(qs, global_embs)
        assert hs.shape == p_embs.shape

        ar_prior_loss = 0
        if train_ar_prior:
            # (B, Tmax, adim)
            hs_integrated = self._integrate_with_global_embs(hs, global_embs)
            qs, ar_prior_loss = self.ar_prior(hs_integrated, inds, codebook)
            qs = self.qfg_projection(qs)  # (B, Tmax, adim)
            assert hs.size(2) == qs.size(2)

            # the prosody embeddings handed back now come from the AR prior
            p_embs = self._integrate_with_global_embs(qs, global_embs)
            assert hs.shape == p_embs.shape

        return p_embs, vq_loss, ar_prior_loss, perplexity, global_embs

    def _integrate_with_global_embs(
        self,
        qs: torch.Tensor,
        global_embs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate the global embedding into a hidden state sequence.

        Args:
            qs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            global_embs (Tensor): Batch of global embeddings
                (B, global_enc_gru_units).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        Raises:
            NotImplementedError: If the integration type is neither "add" nor
                "concat".
        """
        if self.global_emb_integration_type == "add":
            # project the global embedding to adim and broadcast over time
            global_embs = self.global_projection(global_embs)
            res = qs + global_embs.unsqueeze(1)
        elif self.global_emb_integration_type == "concat":
            # concat hidden states with global embeds and then apply projection
            # (B, Tmax, global_enc_gru_units)
            global_embs = global_embs.unsqueeze(1).expand(-1, qs.size(1), -1)
            # (B, Tmax, adim)
            res = self.global_projection(torch.cat([qs, global_embs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return res
class RefEncoder(nn.Module):
    """Convolutional reference encoder.

    Runs the target features through a stack of Conv2d + ReLU layers and
    flattens channels and features into one vector per frame.
    """

    def __init__(
        self,
        ref_enc_conv_layers: int = 2,
        ref_enc_conv_chans_list: Sequence[int] = (32, 32),
        ref_enc_conv_kernel_size: int = 3,
        ref_enc_conv_stride: int = 1,
        ref_enc_conv_padding: int = 1,
    ):
        """Build the convolutional stack of the reference encoder."""
        assert check_argument_types()
        super().__init__()

        # sanity checks on the hyperparameters
        assert ref_enc_conv_kernel_size % 2 == 1, "kernel size must be odd."
        assert (
            len(ref_enc_conv_chans_list) == ref_enc_conv_layers
        ), "the number of conv layers and length of channels list must be the same."

        layers = []
        in_chans = 1  # the input is treated as a single-channel image
        for layer_idx in range(ref_enc_conv_layers):
            out_chans = ref_enc_conv_chans_list[layer_idx]
            layers.append(
                nn.Conv2d(
                    in_chans,
                    out_chans,
                    kernel_size=ref_enc_conv_kernel_size,
                    stride=ref_enc_conv_stride,
                    padding=ref_enc_conv_padding,
                )
            )
            layers.append(nn.ReLU(inplace=True))
            # the next layer consumes what this one produced
            in_chans = out_chans
        self.convs = nn.Sequential(*layers)

    def forward(self, ys: torch.Tensor) -> torch.Tensor:
        """Encode the target features frame by frame.

        Args:
            ys (Tensor): Batch of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of spectrogram hiddens (B, L', ref_enc_output_units)
        """
        batch_size = ys.size(0)
        feats = self.convs(ys.unsqueeze(1))  # (B, conv_out_chans, L', odim')
        # bring the time axis forward: (B, L', conv_out_chans, odim')
        feats = feats.transpose(1, 2).contiguous()
        n_frames = feats.size(1)
        # merge channels and features into one axis
        return feats.view(batch_size, n_frames, -1)
class GlobalEncoder(nn.Module):
    """GRU that condenses a hidden spectrogram sequence into one embedding.

    Args:
        ref_enc_output_units (int): Size of each input frame vector.
        global_enc_gru_layers (int, optional): Number of stacked GRU layers.
        global_enc_gru_units (int, optional): Hidden size of the GRU, which is
            also the size of the returned embedding.
    """

    def __init__(
        self,
        ref_enc_output_units: int,
        global_enc_gru_layers: int = 1,
        global_enc_gru_units: int = 32,
    ):
        super().__init__()
        self.gru = torch.nn.GRU(
            input_size=ref_enc_output_units,
            hidden_size=global_enc_gru_units,
            num_layers=global_enc_gru_layers,
            batch_first=True,
        )

    def forward(
        self,
        hs: torch.Tensor,
    ):
        """Summarise the sequence with the final GRU state.

        Args:
            hs (Tensor): Batch of spectrogram hiddens (B, L', ref_enc_output_units).

        Returns:
            Tensor: Global embedding (B, global_enc_gru_units).
        """
        self.gru.flatten_parameters()
        # second output holds the final hidden state of every layer:
        # (gru_layers, B, global_enc_gru_units)
        final_states = self.gru(hs)[1]
        # keep only the top layer
        return final_states[-1]
class FGEncoder(nn.Module):
    """Spectrogram to phoneme alignment module.

    Averages the frame-level hiddens over each phoneme's duration and projects
    the result down to ``hidden_dim`` with two Linear + ReLU + Dropout stages.

    Args:
        input_units (int): Size of each incoming frame vector.
        hidden_dim (int, optional): Size of the produced fine-grained embedding.
    """

    def __init__(
        self,
        input_units: int,
        hidden_dim: int = 3,
    ):
        assert check_argument_types()
        super().__init__()

        self.projection = nn.Sequential(
            nn.Sequential(
                nn.Linear(input_units, input_units // 2),
                nn.ReLU(),
                nn.Dropout(p=0.2),
            ),
            nn.Sequential(
                nn.Linear(input_units // 2, hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=0.2),
            )
        )

    def forward(
        self,
        hs: torch.Tensor,
        ds: torch.Tensor,
        Lmax: int
    ):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of spectrogram hiddens
                (B, L', ref_enc_output_units + global_enc_gru_units).
            ds (LongTensor): Batch of padded durations (B, Tmax).
            Lmax (int): Number of frames before the conv layers; used to scale
                the durations to the L' axis.

        Returns:
            Tensor: aligned spectrogram hiddens (B, Tmax, hidden_dim).
        """
        # (B, Tmax, ref_enc_output_units + global_enc_gru_units)
        hs = self._align_durations(hs, ds, Lmax)
        hs = self.projection(hs)  # (B, Tmax, hidden_dim)
        return hs

    def _align_durations(self, hs, ds, Lmax):
        """Transform the spectrogram hiddens according to the ground-truth durations
        so that there's only one hidden per phoneme hidden.

        Tokens with a non-positive duration, and tokens whose span falls
        entirely past the end of the frame axis, are left as zeros.

        NOTE(review): the result is filled under ``torch.no_grad()`` and is a
        fresh leaf tensor, so no gradient flows back into ``hs`` through this
        step. This matches the previous behaviour; confirm it is intended.

        Args:
            hs (Tensor): Batch of spectrogram hidden state sequences
                (B, L', ref_enc_output_units + global_enc_gru_units).
            ds (LongTensor): Batch of padded durations (B, Tmax)
            Lmax (int): Frame count the durations refer to.

        Returns:
            Tensor: Batch of averaged spectrogram hidden state sequences
                (B, Tmax, ref_enc_output_units + global_enc_gru_units),
                with the dtype and device of ``hs``.
        """
        batch_size = hs.size(0)
        n_frames = hs.size(1)
        n_tokens = ds.size(1)

        # same dtype/device as the input so reduced/double precision survives
        hs_res = hs.new_zeros(batch_size, n_tokens, hs.size(2))  # (B, Tmax, D)

        # take into account downsampling because of conv layers
        multiplier = n_frames / Lmax
        # one host transfer instead of two .item() calls per token
        all_durations = ds.tolist()

        with torch.no_grad():
            for b_i in range(batch_size):
                start = 0
                for d_i, duration in enumerate(all_durations[b_i]):
                    if duration <= 0:
                        continue
                    # every voiced token covers at least one frame
                    span = max(math.floor(duration * multiplier), 1)
                    hs_slice = hs[b_i, start:start + span, :]  # (<=span, D)
                    # rounding can push the cursor past the last frame; the
                    # mean of an empty slice would be NaN, so keep zeros then
                    if hs_slice.size(0) > 0:
                        hs_res[b_i, d_i, :] = torch.mean(hs_slice, 0)
                    start += span

        hs_res.requires_grad_(hs.requires_grad)
        return hs_res
class ARPrior(nn.Module):
    """Autoregressive prior over the quantized fine-grained prosody codes.

    This module is inspired by the AR prior described in `Generating diverse and
    natural text-to-speech samples using a quantized fine-grained VAE and
    auto-regressive prosody prior`. An LSTM cell whose hidden size equals the
    number of codebook entries produces, at every step, log-probabilities over
    the codebook; the embedding of the top-scoring entry is projected and
    added to the next step's input.

    Args:
        adim (int): Size of the incoming hidden states.
        num_embeddings (int, optional): Number of codebook entries.
        hidden_dim (int, optional): Size of each codebook entry.
    """

    def __init__(
        self,
        adim: int,
        num_embeddings: int = 10,
        hidden_dim: int = 3,
    ):
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.adim = adim
        self.hidden_dim = hidden_dim
        self.num_embeddings = num_embeddings

        # maps the previously chosen code embedding into the input space
        self.qs_projection = nn.Linear(hidden_dim, adim)

        # the hidden state doubles as the (unnormalised) score per code
        self.lstm = nn.LSTMCell(
            self.adim,
            self.num_embeddings,
        )

        self.criterion = nn.NLLLoss()

    def inds_to_embs(self, inds, codebook, device):
        """Returns the quantized embeddings from the codebook,
        corresponding to the indices.

        Args:
            inds (LongTensor): Batch of indices (B, Tmax, 1) or (B, Tmax).
            codebook (Embedding): (num_embeddings, D).
            device: Unused; kept for backward compatibility. The result lives
                on the device of ``codebook.weight``.

        Returns:
            Tensor: Quantized embeddings (B, Tmax, D).
        """
        # Plain row lookup: same values and gradients as multiplying a one-hot
        # matrix with the codebook, but independent of the codebook dtype.
        index = inds.reshape(inds.size(0), inds.size(1))  # (B, L)
        return F.embedding(index, codebook.weight)  # (B, L, D)

    def top_embeddings(self, emb_scores: torch.Tensor, codebook):
        """Returns the top quantized embeddings from the codebook using the scores.

        Args:
            emb_scores (Tensor): Batch of embedding scores (B, Tmax, num_embeddings).
            codebook (Embedding): (num_embeddings, D).

        Returns:
            Tensor: Top quantized embeddings (B, Tmax, D).
            LongTensor: Indices of the (up to) 3 best entries
                (B, Tmax, min(3, num_embeddings)).
        """
        _, top_inds = emb_scores.topk(1, dim=-1)  # (B, L, 1)
        quantized_embs = self.inds_to_embs(
            top_inds,
            codebook,
            emb_scores.device,
        )
        # never ask for more candidates than there are codebook entries
        n_best = min(3, emb_scores.size(-1))
        _, top3_inds = emb_scores.topk(n_best, dim=-1)  # (B, L, n_best)
        return quantized_embs, top3_inds

    def _forward(self, hs_ref_embs, codebook, fg_inds=None):
        """Run the autoregressive loop.

        Args:
            hs_ref_embs (Tensor): Input hidden states (B, Tmax, adim).
            codebook (Embedding): (num_embeddings, D).
            fg_inds (LongTensor, optional): Fixed code indices for the first
                steps (B, T0). Generation starts at step T0 with a zero LSTM
                state.

        Returns:
            Tensor: Code embeddings for all steps (B, Tmax, D).
            Tensor or list: Log-probabilities of the generated steps
                (B, Tmax - T0, num_embeddings); an empty list if no step was
                generated.
            LongTensor: Top indices of the generated steps; ``fg_inds`` if no
                step was generated.
        """
        inds = []
        scores = []
        embs = []
        start = 0
        if fg_inds is not None:
            init_embs = self.inds_to_embs(fg_inds, codebook, hs_ref_embs.device)
            # one (B, 1, D) tensor per given step
            embs = [init_emb.unsqueeze(1) for init_emb in init_embs.transpose(1, 0)]
            start = fg_inds.size(1)

        batch_size = hs_ref_embs.size(0)
        n_steps = hs_ref_embs.size(1)
        hidden = hs_ref_embs.new_zeros(batch_size, self.lstm.hidden_size)
        cell = hs_ref_embs.new_zeros(batch_size, self.lstm.hidden_size)

        for i in range(start, n_steps):
            # (B, adim)
            step_input = hs_ref_embs[:, i]
            if i != 0:
                # feed back the previously chosen code: (B, 1, adim)
                qs = self.qs_projection(embs[-1])
                # drop only the time axis so batch size 1 keeps its shape
                step_input = step_input + qs.squeeze(1)
            hidden, cell = self.lstm(step_input, (hidden, cell))  # (B, K)
            out = hidden.unsqueeze(1)  # (B, 1, K)
            # (B, 1, K)
            emb_scores = F.log_softmax(out, dim=2)
            quantized_embs, top_inds = self.top_embeddings(emb_scores, codebook)
            # (B, 1, hidden_dim)
            embs.append(quantized_embs)
            scores.append(emb_scores)
            inds.append(top_inds)

        out_embs = torch.cat(embs, dim=1)  # (B, L, hidden_dim)
        assert out_embs.size(0) == hs_ref_embs.size(0)
        assert out_embs.size(1) == hs_ref_embs.size(1)

        generated = start < n_steps
        out_emb_scores = torch.cat(scores, dim=1) if generated else scores
        out_inds = torch.cat(inds, dim=1) if generated else fg_inds
        return out_embs, out_emb_scores, out_inds

    def forward(self, hs_ref_embs, inds, codebook):
        """Calculate forward propagation.

        Args:
            hs_ref_embs (Tensor): Batch of phoneme embeddings
                with integrated global prosody embeddings (B, Tmax, adim).
            inds (LongTensor): Batch of ground-truth codebook indices
                (B, Tmax).
            codebook (Embedding): (num_embeddings, D).

        Returns:
            Tensor: Batch of predicted quantized latents (B, Tmax, D).
            Tensor: Negative log-likelihood loss value.
        """
        quantized_embs, emb_scores, _ = self._forward(hs_ref_embs, codebook)
        # NLLLoss expects the class axis second
        emb_scores = emb_scores.permute(0, 2, 1).contiguous()  # (B, num_embeddings, L)
        loss = self.criterion(emb_scores, inds)
        return quantized_embs, loss

    def inference(self, hs_ref_embs, fg_inds, codebook):
        """Predict the code sequence, optionally after a fixed prefix.

        Args:
            hs_ref_embs (Tensor): Batch of phoneme embeddings
                with integrated global prosody embeddings (B, Tmax, adim).
            fg_inds (LongTensor or None): Code indices to use for the first
                steps (B, T0), or None to generate every step.
            codebook (Embedding): (num_embeddings, D).

        Returns:
            Tensor: Batch of predicted quantized latents (B, Tmax, D).
            LongTensor: Top code indices of the generated steps.
        """
        quantized_embs, _, top_inds = self._forward(hs_ref_embs, codebook, fg_inds)
        return quantized_embs, top_inds