|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from torchaudio.models import Tacotron2 |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
from transformers.utils import ModelOutput |
|
|
|
|
|
# Version string of this Tacotron2 wrapper module.
__version__: str = "0.1.0"
|
|
|
@dataclass
class Tacotron2Output(ModelOutput):
    """
    Output of :class:`Tacotron2Model` (inference through ``Tacotron2.infer``).

    mel_outputs_postnet
        The predicted mel spectrogram with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    mel_specgram_lengths
        The length of the predicted mel spectrogram with shape `(n_batch, )`.
    alignments
        Sequence of attention weights from the decoder with shape
        `(n_batch, max of mel_specgram_lengths, max of lengths)`.
    """

    # Every field defaults to None, so the annotations are Optional.
    mel_outputs_postnet: Optional[Tensor] = None
    mel_specgram_lengths: Optional[Tensor] = None
    alignments: Optional[Tensor] = None
|
|
|
|
|
@dataclass
class Tacotron2ForPreTrainingOutput(ModelOutput):
    """
    Output of :class:`Tacotron2ForPreTraining`.

    mel_specgram
        Mel spectrogram before Postnet with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    mel_specgram_postnet
        Mel spectrogram after Postnet with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    gate_outputs
        The output for stop token at each time step with shape
        `(n_batch, max of mel_specgram_lengths)`.
    alignments
        Sequence of attention weights from the decoder with shape
        `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
    loss
        Sum of ``mel_loss``, ``mel_postnet_loss`` and ``gate_loss``.
        ``None`` when no stop-token targets were supplied.
    mel_loss
        Mean squared error between ``mel_specgram`` and the target
        mel spectrogram. ``None`` when no loss was computed.
    mel_postnet_loss
        Mean squared error between ``mel_specgram_postnet`` and the target
        mel spectrogram. ``None`` when no loss was computed.
    gate_loss
        Mean binary cross entropy (with logits) of ``gate_outputs`` against
        the stop-token targets. ``None`` when no loss was computed.
    """

    # Every field defaults to None, so the annotations are Optional.
    mel_specgram: Optional[Tensor] = None
    mel_specgram_postnet: Optional[Tensor] = None
    gate_outputs: Optional[Tensor] = None
    alignments: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    mel_loss: Optional[Tensor] = None
    mel_postnet_loss: Optional[Tensor] = None
    gate_loss: Optional[Tensor] = None
|
|
|
|
|
class Tacotron2Config(PretrainedConfig):
    """Hyper-parameters for the Tacotron2 models defined in this module.

    Each named constructor argument is stored on the config under the same
    name. :class:`Tacotron2Model` and :class:`Tacotron2ForPreTraining` pass
    them one-for-one to :class:`torchaudio.models.Tacotron2`. Any extra
    keyword arguments are forwarded to :class:`transformers.PretrainedConfig`.

    Raises:
        ValueError: If ``n_frames_per_step`` is not 1, the only supported value.
    """

    def __init__(
        self,
        mask_padding: bool = False,
        n_mels: int = 80,
        n_symbol: int = 392,
        n_frames_per_step: int = 1,
        symbol_embedding_dim: int = 512,
        encoder_embedding_dim: int = 512,
        encoder_n_convolution: int = 3,
        encoder_kernel_size: int = 5,
        decoder_rnn_dim: int = 1024,
        decoder_max_step: int = 2000,
        decoder_dropout: float = 0.1,
        decoder_early_stopping: bool = True,
        attention_rnn_dim: int = 1024,
        attention_hidden_dim: int = 128,
        attention_location_n_filter: int = 32,
        attention_location_kernel_size: int = 31,
        attention_dropout: float = 0.1,
        prenet_dim: int = 256,
        postnet_n_convolution: int = 5,
        postnet_kernel_size: int = 5,
        postnet_embedding_dim: int = 512,
        gate_threshold: float = 0.5,
        **kwargs,
    ):
        # Fail fast, before any attribute has been stored.
        if n_frames_per_step != 1:
            raise ValueError(
                f"n_frames_per_step: only 1 is supported, got {n_frames_per_step}"
            )

        # Listed in signature order, so attributes are set in that order too.
        architecture = dict(
            mask_padding=mask_padding,
            n_mels=n_mels,
            n_symbol=n_symbol,
            n_frames_per_step=n_frames_per_step,
            symbol_embedding_dim=symbol_embedding_dim,
            encoder_embedding_dim=encoder_embedding_dim,
            encoder_n_convolution=encoder_n_convolution,
            encoder_kernel_size=encoder_kernel_size,
            decoder_rnn_dim=decoder_rnn_dim,
            decoder_max_step=decoder_max_step,
            decoder_dropout=decoder_dropout,
            decoder_early_stopping=decoder_early_stopping,
            attention_rnn_dim=attention_rnn_dim,
            attention_hidden_dim=attention_hidden_dim,
            attention_location_n_filter=attention_location_n_filter,
            attention_location_kernel_size=attention_location_kernel_size,
            attention_dropout=attention_dropout,
            prenet_dim=prenet_dim,
            postnet_n_convolution=postnet_n_convolution,
            postnet_kernel_size=postnet_kernel_size,
            postnet_embedding_dim=postnet_embedding_dim,
            gate_threshold=gate_threshold,
        )
        for field_name, field_value in architecture.items():
            setattr(self, field_name, field_value)

        super().__init__(**kwargs)
|
|
|
|
|
class Tacotron2PreTrainedModel(PreTrainedModel):
    """Shared base class for the Tacotron2 models in this module.

    Declares the configuration class and naming metadata that
    :class:`transformers.PreTrainedModel` consults.
    """

    # Primary input fed to ``forward`` by the subclasses: token ids.
    main_input_name = "input_ids"
    # Matches the attribute under which subclasses store the torchaudio network.
    base_model_prefix = "tacotron2"
    config_class = Tacotron2Config
|
|
|
|
|
class Tacotron2Model(Tacotron2PreTrainedModel):
    """Inference wrapper around :class:`torchaudio.models.Tacotron2`.

    ``forward`` calls ``Tacotron2.infer``, which takes only the tokens and
    their lengths (no target mel spectrogram). For the loss-computing
    training path, see :class:`Tacotron2ForPreTraining`.
    """

    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        # Every architecture hyper-parameter is taken from the config.
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )

    def forward(
        self,
        input_ids: Tensor,
        length: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""Generate mel spectrograms for a batch of encoded sentences.

        The input is a batch of encoded sentences (``input_ids``) and their
        lengths (``length``). The output is the generated mel spectrograms,
        their lengths, and the attention weights from the decoder.

        ``input_ids`` should be padded with zeros to the maximum of ``length``.

        Args:
            input_ids (Tensor):
                The input tokens to Tacotron2 with shape
                `(n_batch, max of length)`.
            length (Tensor or None, optional):
                The valid length of each sample in ``input_ids`` with shape
                `(n_batch, )`. If ``None``, all tokens are assumed to be
                valid. Default: ``None``
            return_dict (bool or None, optional):
                If true, return a :class:`Tacotron2Output`; otherwise return a
                plain tuple. ``None`` falls back to
                ``self.config.use_return_dict``. Default: ``None``

        Returns:
            Tacotron2Output or (Tensor, Tensor, Tensor):
                In both forms, the three items are, in order:

                - The predicted mel spectrogram with shape
                  `(n_batch, n_mels, max of mel_specgram_lengths)`.
                - The length of the predicted mel spectrogram with shape
                  `(n_batch, )`.
                - Sequence of attention weights from the decoder with shape
                  `(n_batch, max of mel_specgram_lengths, max of length)`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.tacotron2.infer(tokens=input_ids, lengths=length)

        if not return_dict:
            return outputs

        return Tacotron2Output(
            mel_outputs_postnet=outputs[0],
            mel_specgram_lengths=outputs[1],
            alignments=outputs[2],
        )
|
|
|
|
|
class Tacotron2Loss(nn.Module): |
|
"""Tacotron2 loss function modified from: |
|
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py # noqa |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.mse_loss = nn.MSELoss(reduction="mean") |
|
self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean") |
|
|
|
def forward( |
|
self, |
|
model_outputs: Tuple[Tensor, Tensor, Tensor], |
|
targets: Tuple[Tensor, Tensor], |
|
) -> Tuple[Tensor, Tensor, Tensor]: |
|
r"""Pass the input through the Tacotron2 loss. |
|
The original implementation was introduced in |
|
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions* |
|
[:footcite:`shen2018natural`]. |
|
Args: |
|
model_outputs (tuple of three Tensors): The outputs of the |
|
Tacotron2. These outputs should include three items: |
|
(1) the predicted mel spectrogram before the postnet (``mel_specgram``) |
|
with shape (batch, mel, time). |
|
(2) predicted mel spectrogram after the postnet (``mel_specgram_postnet``) # noqa |
|
with shape (batch, mel, time), and |
|
(3) the stop token prediction (``gate_out``) with shape (batch, ). |
|
targets (tuple of two Tensors): |
|
The ground truth mel spectrogram (batch, mel, time) and |
|
stop token with shape (batch, ). |
|
|
|
Returns: |
|
mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth mel spectrogram # noqa |
|
with shape ``torch.Size([])``. |
|
mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and |
|
ground truth mel spectrogram with shape ``torch.Size([])``. |
|
gate_loss (Tensor): The mean binary cross entropy loss of |
|
the prediction on the stop token with shape ``torch.Size([])``. |
|
""" |
|
mel_target, gate_target = targets[0], targets[1] |
|
gate_target = gate_target.view(-1, 1) |
|
|
|
mel_specgram, mel_specgram_postnet, gate_out = model_outputs |
|
gate_out = gate_out.view(-1, 1) |
|
mel_loss = self.mse_loss(mel_specgram, mel_target) |
|
mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target) |
|
gate_loss = self.bce_loss(gate_out, gate_target) |
|
return mel_loss, mel_postnet_loss, gate_loss |
|
|
|
|
|
class Tacotron2ForPreTraining(Tacotron2PreTrainedModel):
    """Training wrapper around :class:`torchaudio.models.Tacotron2`.

    ``forward`` runs the network on the tokens together with the ground-truth
    mel spectrogram. When stop-token targets are supplied, it also computes
    the :class:`Tacotron2Loss` terms and their sum.
    """

    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        # Every architecture hyper-parameter is taken from the config.
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )

        self.loss_fct = Tacotron2Loss()

    def sync_batchnorm(self):
        """Replace the BatchNorm layers of ``self.tacotron2`` with SyncBatchNorm.

        Delegates to :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm` and
        stores the converted network back on ``self.tacotron2``.
        """
        self.tacotron2 = nn.SyncBatchNorm.convert_sync_batchnorm(self.tacotron2)

    def forward(
        self,
        input_ids: Tensor,
        length: Tensor,
        mel_specgram: Tensor,
        mel_specgram_length: Tensor,
        gate_padded: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""Run the network on ground-truth mel frames and optionally compute the loss.

        Args:
            input_ids (Tensor): Zero-padded input token ids, as expected by
                ``torchaudio.models.Tacotron2.forward`` (``tokens``).
            length (Tensor): Valid length of each sample in ``input_ids``,
                with shape `(n_batch, )`.
            mel_specgram (Tensor): Ground-truth mel spectrogram. It is passed
                to the network and, when ``gate_padded`` is given, it is also
                the target of both mel losses.
            mel_specgram_length (Tensor): Valid number of frames of each
                sample in ``mel_specgram``, with shape `(n_batch, )`.
            gate_padded (Tensor or None, optional): Stop-token targets compared
                against the gate outputs. If ``None``, no loss is computed.
            return_dict (bool or None, optional): If true, return a
                :class:`Tacotron2ForPreTrainingOutput`; otherwise return a
                tuple. ``None`` falls back to ``self.config.use_return_dict``.

        Returns:
            Tacotron2ForPreTrainingOutput, or the tuple
            ``(mel_specgram, mel_specgram_postnet, gate_outputs, alignments)``.
            When ``gate_padded`` is given, the tuple is extended with
            ``(loss, mel_loss, mel_postnet_loss, gate_loss)``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.tacotron2(
            tokens=input_ids,
            token_lengths=length,
            mel_specgram=mel_specgram,
            mel_specgram_lengths=mel_specgram_length,
        )

        loss = mel_loss = mel_postnet_loss = gate_loss = None
        if gate_padded is not None:
            # Targets must not receive gradients. ``detach`` leaves the
            # caller's tensors untouched. Assigning ``requires_grad = False``
            # instead would mutate them and raise for non-leaf tensors.
            targets = (mel_specgram.detach(), gate_padded.detach())
            mel_loss, mel_postnet_loss, gate_loss = self.loss_fct(outputs[:3], targets)
            loss = mel_loss + mel_postnet_loss + gate_loss

        if not return_dict:
            if loss is not None:
                return outputs + (loss, mel_loss, mel_postnet_loss, gate_loss)
            return outputs

        return Tacotron2ForPreTrainingOutput(
            mel_specgram=outputs[0],
            mel_specgram_postnet=outputs[1],
            gate_outputs=outputs[2],
            alignments=outputs[3],
            loss=loss,
            mel_loss=mel_loss,
            mel_postnet_loss=mel_postnet_loss,
            gate_loss=gate_loss,
        )
|
|