# torchaudio_tacotron2_kss / modeling_tacotron2.py
# Author: Bingsu
# Version: 0.1.0 (revision 589d655, 13.1 kB)
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torchaudio.models import Tacotron2
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput
# Version string of this modeling module.
__version__ = "0.1.0"
@dataclass
class Tacotron2Output(ModelOutput):
    """Output of :class:`Tacotron2Model` (inference via ``Tacotron2.infer``).

    All fields default to ``None``.

    Args:
        mel_outputs_postnet (`Tensor`, optional):
            The predicted mel spectrogram with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        mel_specgram_lengths (`Tensor`, optional):
            The length of each predicted mel spectrogram with shape
            `(n_batch, )`.
        alignments (`Tensor`, optional):
            Sequence of attention weights from the decoder with shape
            `(n_batch, max of mel_specgram_lengths, max of lengths)`, where
            `lengths` are the valid input token lengths.
    """

    # Annotated as Optional because every field defaults to None.
    mel_outputs_postnet: Optional[Tensor] = None
    mel_specgram_lengths: Optional[Tensor] = None
    alignments: Optional[Tensor] = None
@dataclass
class Tacotron2ForPreTrainingOutput(ModelOutput):
    """Output of :class:`Tacotron2ForPreTraining` (teacher-forced forward pass).

    All fields default to ``None``.

    Args:
        mel_specgram (`Tensor`, optional):
            Mel spectrogram before Postnet with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        mel_specgram_postnet (`Tensor`, optional):
            Mel spectrogram after Postnet with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        gate_outputs (`Tensor`, optional):
            The output for stop token at each time step with shape
            `(n_batch, max of mel_specgram_lengths)`.
        alignments (`Tensor`, optional):
            Sequence of attention weights from the decoder with shape
            `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
        loss (`Tensor`, optional):
            Total training loss (sum of the three losses below); only set when
            stop-token targets were supplied.
        mel_loss (`Tensor`, optional):
            MSE between the pre-Postnet mel spectrogram and the target.
        mel_postnet_loss (`Tensor`, optional):
            MSE between the post-Postnet mel spectrogram and the target.
        gate_loss (`Tensor`, optional):
            Binary cross-entropy loss of the stop-token prediction.
    """

    # Annotated as Optional because every field defaults to None.
    mel_specgram: Optional[Tensor] = None
    mel_specgram_postnet: Optional[Tensor] = None
    gate_outputs: Optional[Tensor] = None
    alignments: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    mel_loss: Optional[Tensor] = None
    mel_postnet_loss: Optional[Tensor] = None
    gate_loss: Optional[Tensor] = None
class Tacotron2Config(PretrainedConfig):
    """Configuration storing the hyper-parameters of a torchaudio Tacotron2.

    Each named argument is stored as an attribute of the same name; any other
    keyword argument is forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: if ``n_frames_per_step`` is not 1.
    """

    def __init__(
        self,
        mask_padding: bool = False,
        n_mels: int = 80,
        n_symbol: int = 392,
        n_frames_per_step: int = 1,
        symbol_embedding_dim: int = 512,
        encoder_embedding_dim: int = 512,
        encoder_n_convolution: int = 3,
        encoder_kernel_size: int = 5,
        decoder_rnn_dim: int = 1024,
        decoder_max_step: int = 2000,
        decoder_dropout: float = 0.1,
        decoder_early_stopping: bool = True,
        attention_rnn_dim: int = 1024,
        attention_hidden_dim: int = 128,
        attention_location_n_filter: int = 32,
        attention_location_kernel_size: int = 31,
        attention_dropout: float = 0.1,
        prenet_dim: int = 256,
        postnet_n_convolution: int = 5,
        postnet_kernel_size: int = 5,
        postnet_embedding_dim: int = 512,
        gate_threshold: float = 0.5,
        **kwargs,
    ):
        # https://pytorch.org/audio/stable/generated/torchaudio.models.Tacotron2.html#torchaudio.models.Tacotron2 # noqa
        # Reject unsupported settings before storing anything.
        if n_frames_per_step != 1:
            raise ValueError(
                f"n_frames_per_step: only 1 is supported, got {n_frames_per_step}"
            )

        hyperparameters = dict(
            mask_padding=mask_padding,
            n_mels=n_mels,
            n_symbol=n_symbol,
            n_frames_per_step=n_frames_per_step,
            symbol_embedding_dim=symbol_embedding_dim,
            encoder_embedding_dim=encoder_embedding_dim,
            encoder_n_convolution=encoder_n_convolution,
            encoder_kernel_size=encoder_kernel_size,
            decoder_rnn_dim=decoder_rnn_dim,
            decoder_max_step=decoder_max_step,
            decoder_dropout=decoder_dropout,
            decoder_early_stopping=decoder_early_stopping,
            attention_rnn_dim=attention_rnn_dim,
            attention_hidden_dim=attention_hidden_dim,
            attention_location_n_filter=attention_location_n_filter,
            attention_location_kernel_size=attention_location_kernel_size,
            attention_dropout=attention_dropout,
            prenet_dim=prenet_dim,
            postnet_n_convolution=postnet_n_convolution,
            postnet_kernel_size=postnet_kernel_size,
            postnet_embedding_dim=postnet_embedding_dim,
            gate_threshold=gate_threshold,
        )
        # Same effect as one `self.<name> = <value>` statement per entry.
        for attr_name, attr_value in hyperparameters.items():
            setattr(self, attr_name, attr_value)

        super().__init__(**kwargs)
class Tacotron2PreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the Tacotron2 wrappers in this file.

    Binds the models to :class:`Tacotron2Config`, uses ``"tacotron2"`` as the
    base-model attribute prefix, and declares ``input_ids`` as the main input.
    """

    main_input_name = "input_ids"
    base_model_prefix = "tacotron2"
    config_class = Tacotron2Config
class Tacotron2Model(Tacotron2PreTrainedModel):
    """Tacotron2 for inference: maps token ids to mel spectrograms."""

    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        # Every hyper-parameter is forwarded verbatim from the config.
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )

    def forward(
        self,
        input_ids: Tensor,
        length: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run Tacotron2 inference (``Tacotron2.infer``) on a batch of encoded
        sentences (``input_ids``) and their lengths (``length``). The output is
        the generated mel spectrograms, their lengths, and the attention weights
        from the decoder.
        ``input_ids`` should be padded with zeros to length max of ``length``.

        Args:
            input_ids (Tensor):
                The input tokens with shape `(n_batch, max of lengths)`.
            length (Tensor or None, optional):
                The valid length of each sample in ``input_ids`` with shape
                `(n_batch, )`. If ``None``, all the tokens are assumed to be
                valid. Default: ``None``
            return_dict (bool or None, optional):
                If true, return a :class:`Tacotron2Output`; otherwise return
                the raw tuple from ``Tacotron2.infer``. If ``None``, falls back
                to ``config.use_return_dict``.

        Returns:
            Tacotron2Output or (Tensor, Tensor, Tensor):
                Tensor
                    The predicted mel spectrogram with shape
                    `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The length of the predicted mel spectrogram with shape
                    `(n_batch, )`.
                Tensor
                    Sequence of attention weights from the decoder with shape
                    `(n_batch, max of mel_specgram_lengths, max of lengths)`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.tacotron2.infer(tokens=input_ids, lengths=length)
        if not return_dict:
            return outputs
        return Tacotron2Output(
            mel_outputs_postnet=outputs[0],
            mel_specgram_lengths=outputs[1],
            alignments=outputs[2],
        )
class Tacotron2Loss(nn.Module):
    """Tacotron2 loss function modified from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py # noqa
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(
        self,
        model_outputs: Tuple[Tensor, Tensor, Tensor],
        targets: Tuple[Tensor, Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Pass the input through the Tacotron2 loss.

        The original implementation was introduced in
        *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
        [:footcite:`shen2018natural`].

        Args:
            model_outputs (tuple of three Tensors): The outputs of the
                Tacotron2. These outputs should include three items:
                (1) the predicted mel spectrogram before the postnet (``mel_specgram``)
                with shape (batch, mel, time),
                (2) the predicted mel spectrogram after the postnet
                (``mel_specgram_postnet``) with shape (batch, mel, time), and
                (3) the stop-token logits (``gate_out``), e.g. of shape
                (batch, time); any shape works because it is flattened.
            targets (tuple of two Tensors):
                The ground truth mel spectrogram (batch, mel, time) and the
                stop-token targets with the same number of elements as
                ``gate_out``. Integer stop-token targets are accepted and cast
                to the dtype of ``gate_out``.

        Returns:
            mel_loss (Tensor): The mean MSE of the mel_specgram and ground truth
                mel spectrogram with shape ``torch.Size([])``.
            mel_postnet_loss (Tensor): The mean MSE of the mel_specgram_postnet and
                ground truth mel spectrogram with shape ``torch.Size([])``.
            gate_loss (Tensor): The mean binary cross entropy loss of
                the prediction on the stop token with shape ``torch.Size([])``.
        """
        mel_target, gate_target = targets[0], targets[1]
        mel_specgram, mel_specgram_postnet, gate_out = model_outputs
        # Flatten stop-token logits and targets to (N, 1). Unlike ``view``,
        # ``reshape`` also handles non-contiguous tensors (e.g. a transposed
        # target). BCEWithLogitsLoss needs a floating-point target of the same
        # dtype as the logits, hence the cast.
        gate_out = gate_out.reshape(-1, 1)
        gate_target = gate_target.reshape(-1, 1).to(gate_out.dtype)
        mel_loss = self.mse_loss(mel_specgram, mel_target)
        mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
        gate_loss = self.bce_loss(gate_out, gate_target)
        return mel_loss, mel_postnet_loss, gate_loss
class Tacotron2ForPreTraining(Tacotron2PreTrainedModel):
    """Tacotron2 trained with teacher forcing and the Tacotron2 loss."""

    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        # Every hyper-parameter is forwarded verbatim from the config.
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )
        self.loss_fct = Tacotron2Loss()

    def sync_batchnorm(self):
        """Replace the BatchNorm layers of the wrapped Tacotron2 with
        ``nn.SyncBatchNorm`` via ``nn.SyncBatchNorm.convert_sync_batchnorm``."""
        self.tacotron2 = nn.SyncBatchNorm.convert_sync_batchnorm(self.tacotron2)

    def forward(
        self,
        input_ids: Tensor,
        length: Tensor,
        mel_specgram: Tensor,
        mel_specgram_length: Tensor,
        gate_padded: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the Tacotron2 forward pass and, optionally, the Tacotron2 loss.

        Args:
            input_ids (Tensor): Input token ids.
            length (Tensor): Valid length of each token sequence.
            mel_specgram (Tensor): Ground-truth mel spectrogram. It is fed to
                the Tacotron2 forward pass and is also the mel loss target.
            mel_specgram_length (Tensor): Valid length of each mel spectrogram.
            gate_padded (Tensor, optional): Stop-token targets. Losses are
                computed only when this is given.
            return_dict (bool, optional): If true, return a
                :class:`Tacotron2ForPreTrainingOutput`; if ``None``, falls back
                to ``config.use_return_dict``.

        Returns:
            Tacotron2ForPreTrainingOutput, or a tuple of the Tacotron2 outputs
            (mel_specgram, mel_specgram_postnet, gate_outputs, alignments),
            followed by (loss, mel_loss, mel_postnet_loss, gate_loss) when
            ``gate_padded`` is given.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.tacotron2(
            tokens=input_ids,
            token_lengths=length,
            mel_specgram=mel_specgram,
            mel_specgram_lengths=mel_specgram_length,
        )

        loss = mel_loss = mel_postnet_loss = gate_loss = None
        if gate_padded is not None:
            # Targets must not propagate gradients. Use ``detach`` instead of
            # setting ``requires_grad = False`` in place. The in-place flag
            # mutates the caller's tensors and raises a RuntimeError for
            # non-leaf tensors.
            targets = (mel_specgram.detach(), gate_padded.detach())
            mel_loss, mel_postnet_loss, gate_loss = self.loss_fct(outputs[:3], targets)
            loss = mel_loss + mel_postnet_loss + gate_loss

        if not return_dict:
            if loss is not None:
                return outputs + (loss, mel_loss, mel_postnet_loss, gate_loss)
            return outputs

        return Tacotron2ForPreTrainingOutput(
            mel_specgram=outputs[0],
            mel_specgram_postnet=outputs[1],
            gate_outputs=outputs[2],
            alignments=outputs[3],
            loss=loss,
            mel_loss=mel_loss,
            mel_postnet_loss=mel_postnet_loss,
            gate_loss=gate_loss,
        )