from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torchaudio.models import Tacotron2
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput

__version__ = "0.1.0"


@dataclass
class Tacotron2Output(ModelOutput):
    """
    Output type of :class:`Tacotron2Model`.

    Args:
        mel_outputs_postnet (Tensor):
            The predicted mel spectrogram with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        mel_specgram_lengths (Tensor):
            The length of each predicted mel spectrogram with shape `(n_batch, )`.
        alignments (Tensor):
            Sequence of attention weights from the decoder with shape
            `(n_batch, max of mel_specgram_lengths, max of lengths)`.
    """

    mel_outputs_postnet: Optional[Tensor] = None
    mel_specgram_lengths: Optional[Tensor] = None
    alignments: Optional[Tensor] = None


@dataclass
class Tacotron2ForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`Tacotron2ForPreTraining`.

    Args:
        mel_specgram (Tensor):
            Mel spectrogram before Postnet with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        mel_specgram_postnet (Tensor):
            Mel spectrogram after Postnet with shape
            `(n_batch, n_mels, max of mel_specgram_lengths)`.
        gate_outputs (Tensor):
            The output for the stop token at each time step with shape
            `(n_batch, max of mel_specgram_lengths)`.
        alignments (Tensor):
            Sequence of attention weights from the decoder with shape
            `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
        loss (Tensor, optional):
            The total training loss, i.e. the sum of ``mel_loss``,
            ``mel_postnet_loss`` and ``gate_loss``. Only returned when
            ``gate_padded`` is provided.
        mel_loss (Tensor, optional):
            MSE between the pre-Postnet mel spectrogram and the target.
        mel_postnet_loss (Tensor, optional):
            MSE between the post-Postnet mel spectrogram and the target.
        gate_loss (Tensor, optional):
            Binary cross entropy loss on the stop-token prediction.
    """

    mel_specgram: Optional[Tensor] = None
    mel_specgram_postnet: Optional[Tensor] = None
    gate_outputs: Optional[Tensor] = None
    alignments: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    mel_loss: Optional[Tensor] = None
    mel_postnet_loss: Optional[Tensor] = None
    gate_loss: Optional[Tensor] = None


class Tacotron2Config(PretrainedConfig):
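    """Configuration for :class:`Tacotron2Model` and
    :class:`Tacotron2ForPreTraining`.

    The defaults mirror the arguments of :class:`torchaudio.models.Tacotron2`.
    Only ``n_frames_per_step=1`` is supported.

    Example (a minimal sketch; the ``n_symbol`` value below is arbitrary)::

        >>> config = Tacotron2Config(n_symbol=100)
        >>> config.n_symbol, config.n_mels
        (100, 80)
    """
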
def __init__(
self,
mask_padding: bool = False,
n_mels: int = 80,
n_symbol: int = 392,
n_frames_per_step: int = 1,
symbol_embedding_dim: int = 512,
encoder_embedding_dim: int = 512,
encoder_n_convolution: int = 3,
encoder_kernel_size: int = 5,
decoder_rnn_dim: int = 1024,
decoder_max_step: int = 2000,
decoder_dropout: float = 0.1,
decoder_early_stopping: bool = True,
attention_rnn_dim: int = 1024,
attention_hidden_dim: int = 128,
attention_location_n_filter: int = 32,
attention_location_kernel_size: int = 31,
attention_dropout: float = 0.1,
prenet_dim: int = 256,
postnet_n_convolution: int = 5,
postnet_kernel_size: int = 5,
postnet_embedding_dim: int = 512,
gate_threshold: float = 0.5,
**kwargs,
):
# https://pytorch.org/audio/stable/generated/torchaudio.models.Tacotron2.html#torchaudio.models.Tacotron2 # noqa
if n_frames_per_step != 1:
raise ValueError(
f"n_frames_per_step: only 1 is supported, got {n_frames_per_step}"
)
self.mask_padding = mask_padding
self.n_mels = n_mels
self.n_symbol = n_symbol
self.n_frames_per_step = n_frames_per_step
self.symbol_embedding_dim = symbol_embedding_dim
self.encoder_embedding_dim = encoder_embedding_dim
self.encoder_n_convolution = encoder_n_convolution
self.encoder_kernel_size = encoder_kernel_size
self.decoder_rnn_dim = decoder_rnn_dim
self.decoder_max_step = decoder_max_step
self.decoder_dropout = decoder_dropout
self.decoder_early_stopping = decoder_early_stopping
self.attention_rnn_dim = attention_rnn_dim
self.attention_hidden_dim = attention_hidden_dim
self.attention_location_n_filter = attention_location_n_filter
self.attention_location_kernel_size = attention_location_kernel_size
self.attention_dropout = attention_dropout
self.prenet_dim = prenet_dim
self.postnet_n_convolution = postnet_n_convolution
self.postnet_kernel_size = postnet_kernel_size
self.postnet_embedding_dim = postnet_embedding_dim
self.gate_threshold = gate_threshold
        super().__init__(**kwargs)


class Tacotron2PreTrainedModel(PreTrainedModel):
    config_class = Tacotron2Config
    base_model_prefix = "tacotron2"
    main_input_name = "input_ids"


class Tacotron2Model(Tacotron2PreTrainedModel):
def __init__(self, config: Tacotron2Config):
super().__init__(config)
self.tacotron2 = Tacotron2(
mask_padding=config.mask_padding,
n_mels=config.n_mels,
n_symbol=config.n_symbol,
n_frames_per_step=config.n_frames_per_step,
symbol_embedding_dim=config.symbol_embedding_dim,
encoder_embedding_dim=config.encoder_embedding_dim,
encoder_n_convolution=config.encoder_n_convolution,
encoder_kernel_size=config.encoder_kernel_size,
decoder_rnn_dim=config.decoder_rnn_dim,
decoder_max_step=config.decoder_max_step,
decoder_dropout=config.decoder_dropout,
decoder_early_stopping=config.decoder_early_stopping,
attention_rnn_dim=config.attention_rnn_dim,
attention_hidden_dim=config.attention_hidden_dim,
attention_location_n_filter=config.attention_location_n_filter,
attention_location_kernel_size=config.attention_location_kernel_size,
attention_dropout=config.attention_dropout,
prenet_dim=config.prenet_dim,
postnet_n_convolution=config.postnet_n_convolution,
postnet_kernel_size=config.postnet_kernel_size,
postnet_embedding_dim=config.postnet_embedding_dim,
gate_threshold=config.gate_threshold,
        )

def forward(
self,
input_ids: Tensor,
length: Optional[Tensor] = None,
return_dict: Optional[bool] = None,
):
r"""
Using Tacotron2 for inference. The input is a batch of encoded
sentences (``tokens``) and its corresponding lengths (``lengths``). The
output is the generated mel spectrograms, its corresponding lengths, and
the attention weights from the decoder.
The input `tokens` should be padded with zeros to length max of ``lengths``.
Args:
tokens (Tensor):
The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
lengths (Tensor or None, optional):
The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
If ``None``, it is assumed that the all the tokens are valid.
Default: ``None``
Returns:
(Tensor, Tensor, Tensor):
Tensor
The predicted mel spectrogram with shape
`(n_batch, n_mels, max of mel_specgram_lengths)`.
Tensor
The length of the predicted mel spectrogram with shape
`(n_batch, )`.
Tensor
Sequence of attention weights from the decoder with shape
`(n_batch, max of mel_specgram_lengths, max of lengths)`.
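
        Example (a minimal sketch: random token ids and untrained weights;
        ``decoder_max_step`` is lowered only to keep the decoding loop short)::

            >>> import torch
            >>> config = Tacotron2Config(decoder_max_step=10)
            >>> model = Tacotron2Model(config).eval()
            >>> input_ids = torch.randint(0, config.n_symbol, (2, 16))
            >>> length = torch.tensor([16, 12])  # sorted descending
            >>> with torch.no_grad():
            ...     out = model(input_ids=input_ids, length=length, return_dict=True)
            >>> out.mel_outputs_postnet.shape[:2]
            torch.Size([2, 80])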
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.tacotron2.infer(tokens=input_ids, lengths=length)
if not return_dict:
return outputs
return Tacotron2Output(
mel_outputs_postnet=outputs[0],
mel_specgram_lengths=outputs[1],
alignments=outputs[2],
        )


class Tacotron2Loss(nn.Module):
"""Tacotron2 loss function modified from:
https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py # noqa
"""
def __init__(self):
super().__init__()
self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

def forward(
self,
model_outputs: Tuple[Tensor, Tensor, Tensor],
targets: Tuple[Tensor, Tensor],
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Pass the input through the Tacotron2 loss.
The original implementation was introduced in
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
[:footcite:`shen2018natural`].
        Args:
            model_outputs (tuple of three Tensors): The outputs of the
                Tacotron2. These outputs should include three items:
                (1) the predicted mel spectrogram before the postnet
                (``mel_specgram``) with shape (batch, mel, time),
                (2) the predicted mel spectrogram after the postnet
                (``mel_specgram_postnet``) with shape (batch, mel, time), and
                (3) the stop token prediction (``gate_out``) with shape
                (batch, time).
            targets (tuple of two Tensors):
                The ground truth mel spectrogram with shape (batch, mel, time)
                and the stop token with shape (batch, time).

        Returns:
            mel_loss (Tensor): The mean MSE of the ``mel_specgram`` and the
                ground truth mel spectrogram with shape ``torch.Size([])``.
            mel_postnet_loss (Tensor): The mean MSE of the
                ``mel_specgram_postnet`` and the ground truth mel spectrogram
                with shape ``torch.Size([])``.
            gate_loss (Tensor): The mean binary cross entropy loss of the
                prediction on the stop token with shape ``torch.Size([])``.
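
        Example (a toy sanity check on random tensors; the shapes follow the
        descriptions above)::

            >>> loss_fct = Tacotron2Loss()
            >>> mel = torch.rand(2, 80, 100)
            >>> mel_postnet = torch.rand(2, 80, 100)
            >>> gate_out = torch.randn(2, 100)
            >>> targets = (torch.rand(2, 80, 100), torch.zeros(2, 100))
            >>> losses = loss_fct((mel, mel_postnet, gate_out), targets)
            >>> [loss.shape for loss in losses]
            [torch.Size([]), torch.Size([]), torch.Size([])]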
"""
mel_target, gate_target = targets[0], targets[1]
gate_target = gate_target.view(-1, 1)
mel_specgram, mel_specgram_postnet, gate_out = model_outputs
gate_out = gate_out.view(-1, 1)
mel_loss = self.mse_loss(mel_specgram, mel_target)
mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
gate_loss = self.bce_loss(gate_out, gate_target)
        return mel_loss, mel_postnet_loss, gate_loss


class Tacotron2ForPreTraining(Tacotron2PreTrainedModel):
def __init__(self, config: Tacotron2Config):
super().__init__(config)
self.tacotron2 = Tacotron2(
mask_padding=config.mask_padding,
n_mels=config.n_mels,
n_symbol=config.n_symbol,
n_frames_per_step=config.n_frames_per_step,
symbol_embedding_dim=config.symbol_embedding_dim,
encoder_embedding_dim=config.encoder_embedding_dim,
encoder_n_convolution=config.encoder_n_convolution,
encoder_kernel_size=config.encoder_kernel_size,
decoder_rnn_dim=config.decoder_rnn_dim,
decoder_max_step=config.decoder_max_step,
decoder_dropout=config.decoder_dropout,
decoder_early_stopping=config.decoder_early_stopping,
attention_rnn_dim=config.attention_rnn_dim,
attention_hidden_dim=config.attention_hidden_dim,
attention_location_n_filter=config.attention_location_n_filter,
attention_location_kernel_size=config.attention_location_kernel_size,
attention_dropout=config.attention_dropout,
prenet_dim=config.prenet_dim,
postnet_n_convolution=config.postnet_n_convolution,
postnet_kernel_size=config.postnet_kernel_size,
postnet_embedding_dim=config.postnet_embedding_dim,
gate_threshold=config.gate_threshold,
)
        self.loss_fct = Tacotron2Loss()

    def sync_batchnorm(self):
        """Convert all ``BatchNorm`` layers in ``self.tacotron2`` to
        :class:`torch.nn.SyncBatchNorm` for distributed data-parallel
        training."""
        self.tacotron2 = nn.SyncBatchNorm.convert_sync_batchnorm(self.tacotron2)

def forward(
self,
input_ids: Tensor,
length: Tensor,
mel_specgram: Tensor,
mel_specgram_length: Tensor,
gate_padded: Optional[Tensor] = None,
return_dict: Optional[bool] = None,
    ):
        r"""
        Run Tacotron2 in teacher-forced training mode and optionally compute
        the pre-training losses.

        Args:
            input_ids (Tensor):
                The input tokens to Tacotron2 with shape `(n_batch, max of length)`,
                padded with zeros.
            length (Tensor):
                The valid length of each sample in ``input_ids`` with shape
                `(n_batch, )`.
            mel_specgram (Tensor):
                The target mel spectrogram with shape
                `(n_batch, n_mels, max of mel_specgram_length)`.
            mel_specgram_length (Tensor):
                The length of each target mel spectrogram with shape `(n_batch, )`.
            gate_padded (Tensor or None, optional):
                The padded stop-token targets with shape
                `(n_batch, max of mel_specgram_length)`. If provided, the
                losses are computed. Default: ``None``
            return_dict (bool or None, optional):
                Whether to return a :class:`Tacotron2ForPreTrainingOutput`
                instead of a plain tuple. Defaults to
                ``self.config.use_return_dict``.
        """
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.tacotron2(
tokens=input_ids,
token_lengths=length,
mel_specgram=mel_specgram,
mel_specgram_lengths=mel_specgram_length,
)
        loss = mel_loss = mel_postnet_loss = gate_loss = None
        if gate_padded is not None:
            # Ground-truth targets should not receive gradients (this mirrors
            # the NVIDIA reference implementation).
            targets = (mel_specgram, gate_padded)
            targets[0].requires_grad = False
            targets[1].requires_grad = False
            mel_loss, mel_postnet_loss, gate_loss = self.loss_fct(outputs[:3], targets)
            loss = mel_loss + mel_postnet_loss + gate_loss
if not return_dict:
if loss is not None:
return outputs + (loss, mel_loss, mel_postnet_loss, gate_loss)
return outputs
return Tacotron2ForPreTrainingOutput(
mel_specgram=outputs[0],
mel_specgram_postnet=outputs[1],
gate_outputs=outputs[2],
alignments=outputs[3],
loss=loss,
mel_loss=mel_loss,
mel_postnet_loss=mel_postnet_loss,
gate_loss=gate_loss,
)
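

if __name__ == "__main__":
    # A minimal smoke test on random data (a sketch, not a training recipe);
    # the tensor shapes follow the forward() docstring above.
    torch.manual_seed(0)
    config = Tacotron2Config()
    model = Tacotron2ForPreTraining(config)

    n_batch, max_tokens, max_frames = 2, 20, 30
    input_ids = torch.randint(0, config.n_symbol, (n_batch, max_tokens))
    length = torch.tensor([max_tokens, 15])  # sorted descending for packing
    mel_specgram = torch.rand(n_batch, config.n_mels, max_frames)
    mel_specgram_length = torch.tensor([max_frames, 25])
    # Stop-token targets: 1.0 from the last valid frame onward, 0.0 before.
    gate_padded = torch.zeros(n_batch, max_frames)
    for i, n in enumerate(mel_specgram_length.tolist()):
        gate_padded[i, n - 1 :] = 1.0

    output = model(
        input_ids=input_ids,
        length=length,
        mel_specgram=mel_specgram,
        mel_specgram_length=mel_specgram_length,
        gate_padded=gate_padded,
        return_dict=True,
    )
    print("loss:", float(output.loss))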