conex / espnet2 /tts /fastespeech.py
tobiasc's picture
Initial commit
ad16788
""" FastESpeech """
from typing import Dict
from typing import Sequence
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from espnet.nets.pytorch_backend.e2e_tts_fastspeech import (
FeedForwardTransformerLoss as FastSpeechLoss, # NOQA
)
from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor
from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder import (
Encoder as TransformerEncoder, # noqa: H301
)
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.torch_utils.initialize import initialize
from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.prosody_encoder import ProsodyEncoder
class FastESpeech(AbsTTS):
"""FastESpeech module.
This module adds a VQ-VAE prosody encoder to the FastSpeech model, and
takes cues from FastSpeech 2 for training.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/abs/1905.09263
.. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
https://arxiv.org/abs/2006.04558
Args:
idim (int): Dimension of the input -> size of the phoneme vocabulary.
odim (int): Dimension of the output -> dimension of the mel-spectrograms.
adim (int, optional): Dimension of the phoneme embeddings, dimension of the
prosody embedding, the hidden size of the self-attention, 1D convolution
in the FFT block.
aheads (int, optional): Number of attention heads.
elayers (int, optional): Number of encoder layers/blocks.
eunits (int, optional): Number of encoder hidden units
-> The number of units of position-wise feed forward layer.
dlayers (int, optional): Number of decoder layers/blocks.
dunits (int, optional): Number of decoder hidden units
-> The number of units of position-wise feed forward layer.
positionwise_layer_type (str, optional): Type of position-wise feed forward
layer - linear or conv1d.
positionwise_conv_kernel_size (int, optional): kernel size of positionwise
conv1d layer.
use_scaled_pos_enc (bool, optional):
Whether to use trainable scaled positional encoding.
encoder_normalize_before (bool, optional):
Whether to perform layer normalization before encoder block.
decoder_normalize_before (bool, optional):
Whether to perform layer normalization before decoder block.
encoder_concat_after (bool, optional): Whether to concatenate attention
layer's input and output in encoder.
decoder_concat_after (bool, optional): Whether to concatenate attention
layer's input and output in decoder.
duration_predictor_layers (int, optional): Number of duration predictor layers.
duration_predictor_chans (int, optional): Number of duration predictor channels.
duration_predictor_kernel_size (int, optional):
Kernel size of duration predictor.
reduction_factor (int, optional): Factor to multiply with output dimension.
encoder_type (str, optional): Encoder architecture type.
decoder_type (str, optional): Decoder architecture type.
# spk_embed_dim (int, optional): Number of speaker embedding dimensions.
# spk_embed_integration_type: How to integrate speaker embedding.
ref_enc_conv_layers (int, optional):
The number of conv layers in the reference encoder.
ref_enc_conv_chans_list: (Sequence[int], optional):
List of the number of channels of conv layers in the referece encoder.
ref_enc_conv_kernel_size (int, optional):
Kernal size of conv layers in the reference encoder.
ref_enc_conv_stride (int, optional):
Stride size of conv layers in the reference encoder.
ref_enc_gru_layers (int, optional):
The number of GRU layers in the reference encoder.
ref_enc_gru_units (int, optional):
The number of GRU units in the reference encoder.
ref_emb_integration_type: How to integrate reference embedding.
# reduction_factor (int, optional): Reduction factor.
prosody_num_embs (int, optional): The higher this value, the higher the
capacity in the information bottleneck.
prosody_hidden_dim (int, optional): Number of hidden channels.
prosody_emb_integration_type: How to integrate prosody embedding.
transformer_enc_dropout_rate (float, optional):
Dropout rate in encoder except attention & positional encoding.
transformer_enc_positional_dropout_rate (float, optional):
Dropout rate after encoder positional encoding.
transformer_enc_attn_dropout_rate (float, optional):
Dropout rate in encoder self-attention module.
transformer_dec_dropout_rate (float, optional):
Dropout rate in decoder except attention & positional encoding.
transformer_dec_positional_dropout_rate (float, optional):
Dropout rate after decoder positional encoding.
transformer_dec_attn_dropout_rate (float, optional):
Dropout rate in decoder self-attention module.
duration_predictor_dropout_rate (float, optional):
Dropout rate in duration predictor.
init_type (str, optional):
How to initialize transformer parameters.
init_enc_alpha (float, optional):
Initial value of alpha in scaled pos encoding of the encoder.
init_dec_alpha (float, optional):
Initial value of alpha in scaled pos encoding of the decoder.
use_masking (bool, optional):
Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool, optional):
Whether to apply weighted masking in loss calculation.
"""
def __init__(
self,
# network structure related
idim: int,
odim: int,
adim: int = 384,
aheads: int = 4,
elayers: int = 6,
eunits: int = 1536,
dlayers: int = 6,
dunits: int = 1536,
postnet_layers: int = 0, # 5
postnet_chans: int = 512,
postnet_filts: int = 5,
positionwise_layer_type: str = "conv1d",
positionwise_conv_kernel_size: int = 1,
use_scaled_pos_enc: bool = True,
use_batch_norm: bool = True,
encoder_normalize_before: bool = True,
decoder_normalize_before: bool = True,
encoder_concat_after: bool = False,
decoder_concat_after: bool = False,
duration_predictor_layers: int = 2,
duration_predictor_chans: int = 384,
duration_predictor_kernel_size: int = 3,
reduction_factor: int = 1,
encoder_type: str = "transformer",
decoder_type: str = "transformer",
# # only for conformer
# conformer_pos_enc_layer_type: str = "rel_pos",
# conformer_self_attn_layer_type: str = "rel_selfattn",
# conformer_activation_type: str = "swish",
# use_macaron_style_in_conformer: bool = True,
# use_cnn_in_conformer: bool = True,
# conformer_enc_kernel_size: int = 7,
# conformer_dec_kernel_size: int = 31,
# # pretrained spk emb
# spk_embed_dim: int = None,
# spk_embed_integration_type: str = "add",
# reference encoder
ref_enc_conv_layers: int = 2,
ref_enc_conv_chans_list: Sequence[int] = (32, 32),
ref_enc_conv_kernel_size: int = 3,
ref_enc_conv_stride: int = 1,
ref_enc_gru_layers: int = 1,
ref_enc_gru_units: int = 32,
ref_emb_integration_type: str = "add",
# prosody encoder
prosody_num_embs: int = 256,
prosody_hidden_dim: int = 128,
prosody_emb_integration_type: str = "add",
# training related
transformer_enc_dropout_rate: float = 0.1,
transformer_enc_positional_dropout_rate: float = 0.1,
transformer_enc_attn_dropout_rate: float = 0.1,
transformer_dec_dropout_rate: float = 0.1,
transformer_dec_positional_dropout_rate: float = 0.1,
transformer_dec_attn_dropout_rate: float = 0.1,
duration_predictor_dropout_rate: float = 0.1,
postnet_dropout_rate: float = 0.5,
init_type: str = "xavier_uniform",
init_enc_alpha: float = 1.0,
init_dec_alpha: float = 1.0,
use_masking: bool = False,
use_weighted_masking: bool = False,
):
"""Initialize FastESpeech module."""
assert check_argument_types()
super().__init__()
# store hyperparameters
self.idim = idim
self.odim = odim
self.eos = idim - 1
self.reduction_factor = reduction_factor
self.encoder_type = encoder_type
self.decoder_type = decoder_type
self.use_scaled_pos_enc = use_scaled_pos_enc
self.prosody_emb_integration_type = prosody_emb_integration_type
# self.spk_embed_dim = spk_embed_dim
# if self.spk_embed_dim is not None:
# self.spk_embed_integration_type = spk_embed_integration_type
# use idx 0 as padding idx, see:
# https://stackoverflow.com/questions/61172400/what-does-padding-idx-do-in-nn-embeddings
self.padding_idx = 0
# get positional encoding class
pos_enc_class = (
ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
)
# define encoder
encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
)
if encoder_type == "transformer":
self.encoder = TransformerEncoder(
idim=idim,
attention_dim=adim,
attention_heads=aheads,
linear_units=eunits,
num_blocks=elayers,
input_layer=encoder_input_layer,
dropout_rate=transformer_enc_dropout_rate,
positional_dropout_rate=transformer_enc_positional_dropout_rate,
attention_dropout_rate=transformer_enc_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=encoder_normalize_before,
concat_after=encoder_concat_after,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
)
# elif encoder_type == "conformer":
# self.encoder = ConformerEncoder(
# idim=idim,
# attention_dim=adim,
# attention_heads=aheads,
# linear_units=eunits,
# num_blocks=elayers,
# input_layer=encoder_input_layer,
# dropout_rate=transformer_enc_dropout_rate,
# positional_dropout_rate=transformer_enc_positional_dropout_rate,
# attention_dropout_rate=transformer_enc_attn_dropout_rate,
# normalize_before=encoder_normalize_before,
# concat_after=encoder_concat_after,
# positionwise_layer_type=positionwise_layer_type,
# positionwise_conv_kernel_size=positionwise_conv_kernel_size,
# macaron_style=use_macaron_style_in_conformer,
# pos_enc_layer_type=conformer_pos_enc_layer_type,
# selfattention_layer_type=conformer_self_attn_layer_type,
# activation_type=conformer_activation_type,
# use_cnn_module=use_cnn_in_conformer,
# cnn_module_kernel=conformer_enc_kernel_size,
# )
else:
raise ValueError(f"{encoder_type} is not supported.")
# define additional projection for prosody embedding
if self.prosody_emb_integration_type == "concat":
self.prosody_projection = torch.nn.Linear(
adim * 2, adim
)
# define prosody encoder
self.prosody_encoder = ProsodyEncoder(
odim,
adim=adim,
num_embeddings=prosody_num_embs,
hidden_dim=prosody_hidden_dim,
ref_enc_conv_layers=ref_enc_conv_layers,
ref_enc_conv_chans_list=ref_enc_conv_chans_list,
ref_enc_conv_kernel_size=ref_enc_conv_kernel_size,
ref_enc_conv_stride=ref_enc_conv_stride,
global_enc_gru_layers=ref_enc_gru_layers,
global_enc_gru_units=ref_enc_gru_units,
global_emb_integration_type=ref_emb_integration_type,
)
# # define additional projection for speaker embedding
# if self.spk_embed_dim is not None:
# if self.spk_embed_integration_type == "add":
# self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
# else:
# self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)
# define duration predictor
self.duration_predictor = DurationPredictor(
idim=adim,
n_layers=duration_predictor_layers,
n_chans=duration_predictor_chans,
kernel_size=duration_predictor_kernel_size,
dropout_rate=duration_predictor_dropout_rate,
)
# define length regulator
self.length_regulator = LengthRegulator()
# define decoder
# NOTE: we use encoder as decoder
# because fastspeech's decoder is the same as encoder
if decoder_type == "transformer":
self.decoder = TransformerEncoder(
idim=0,
attention_dim=adim,
attention_heads=aheads,
linear_units=dunits,
num_blocks=dlayers,
input_layer=None,
dropout_rate=transformer_dec_dropout_rate,
positional_dropout_rate=transformer_dec_positional_dropout_rate,
attention_dropout_rate=transformer_dec_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=decoder_normalize_before,
concat_after=decoder_concat_after,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
)
# elif decoder_type == "conformer":
# self.decoder = ConformerEncoder(
# idim=0,
# attention_dim=adim,
# attention_heads=aheads,
# linear_units=dunits,
# num_blocks=dlayers,
# input_layer=None,
# dropout_rate=transformer_dec_dropout_rate,
# positional_dropout_rate=transformer_dec_positional_dropout_rate,
# attention_dropout_rate=transformer_dec_attn_dropout_rate,
# normalize_before=decoder_normalize_before,
# concat_after=decoder_concat_after,
# positionwise_layer_type=positionwise_layer_type,
# positionwise_conv_kernel_size=positionwise_conv_kernel_size,
# macaron_style=use_macaron_style_in_conformer,
# pos_enc_layer_type=conformer_pos_enc_layer_type,
# selfattention_layer_type=conformer_self_attn_layer_type,
# activation_type=conformer_activation_type,
# use_cnn_module=use_cnn_in_conformer,
# cnn_module_kernel=conformer_dec_kernel_size,
# )
else:
raise ValueError(f"{decoder_type} is not supported.")
# define final projection
self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
# define postnet
self.postnet = (
None
if postnet_layers == 0
else Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate,
)
)
# initialize parameters
self._reset_parameters(
init_type=init_type,
init_enc_alpha=init_enc_alpha,
init_dec_alpha=init_dec_alpha,
)
# define criterions
self.criterion = FastSpeechLoss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking
)
def forward(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
durations: torch.Tensor,
durations_lengths: torch.Tensor,
spembs: torch.Tensor = None,
train_ar_prior: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Calculate forward propagation.
Args:
text (LongTensor): Batch of padded token ids (B, Tmax).
text_lengths (LongTensor): Batch of lengths of each input (B,).
speech (Tensor): Batch of padded target features (B, Lmax, odim).
speech_lengths (LongTensor): Batch of the lengths of each target (B,).
durations (LongTensor): Batch of padded durations (B, Tmax + 1).
durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1).
spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).
Returns:
Tensor: Loss scalar value.
Dict: Statistics to be monitored.
Tensor: Weight value.
"""
# train_ar_prior = True # TC marker
text = text[:, : text_lengths.max()] # for data-parallel
speech = speech[:, : speech_lengths.max()] # for data-parallel
durations = durations[:, : durations_lengths.max()] # for data-parallel
batch_size = text.size(0)
# Add eos at the last of sequence
xs = F.pad(text, [0, 1], "constant", self.padding_idx)
for i, l in enumerate(text_lengths):
xs[i, l] = self.eos
ilens = text_lengths + 1
ys, ds = speech, durations
olens = speech_lengths
# forward propagation
before_outs, after_outs, d_outs, ref_embs, \
vq_loss, ar_prior_loss, perplexity = self._forward(
xs,
ilens,
ys,
olens,
ds,
spembs=spembs,
is_inference=False,
train_ar_prior=train_ar_prior
)
# modify mod part of groundtruth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
if self.postnet is None:
after_outs = None
# calculate loss TODO: refactor if freezing works
l1_loss, duration_loss = self.criterion(
after_outs, before_outs, d_outs, ys, ds, ilens, olens
)
if train_ar_prior:
loss = ar_prior_loss
stats = dict(
l1_loss=l1_loss.item(),
duration_loss=duration_loss.item(),
vq_loss=vq_loss.item(),
ar_prior_loss=ar_prior_loss.item(),
loss=loss.item(),
perplexity=perplexity.item(),
)
else :
loss = l1_loss + duration_loss + vq_loss
stats = dict(
l1_loss=l1_loss.item(),
duration_loss=duration_loss.item(),
vq_loss=vq_loss.item(),
loss=loss.item(),
perplexity=perplexity.item()
)
# report extra information
if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
stats.update(
encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
)
if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
stats.update(
decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
)
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def _forward(
self,
xs: torch.Tensor,
ilens: torch.Tensor,
ys: torch.Tensor = None,
olens: torch.Tensor = None,
ds: torch.Tensor = None,
spembs: torch.Tensor = None,
ref_embs: torch.Tensor = None,
is_inference: bool = False,
train_ar_prior: bool = False,
ar_prior_inference: bool = False,
alpha: float = 1.0,
fg_inds: torch.Tensor = None,
) -> Sequence[torch.Tensor]:
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim)
# # integrate speaker embedding
# if self.spk_embed_dim is not None:
# hs = self._integrate_with_spk_embed(hs, spembs)
# integrate with prosody encoder
# (B, Tmax, adim)
p_embs, vq_loss, ar_prior_loss, perplexity, ref_embs = self.prosody_encoder(
ys,
ds,
hs,
global_embs=ref_embs,
train_ar_prior=train_ar_prior,
ar_prior_inference=ar_prior_inference,
fg_inds=fg_inds,
)
hs = self._integrate_with_prosody_embs(hs, p_embs)
# forward duration predictor
d_masks = make_pad_mask(ilens).to(xs.device)
if is_inference:
print('predicted durations')
d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax)
hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim)
else:
d_outs = self.duration_predictor(hs, d_masks)
# use groundtruth in training
hs = self.length_regulator(hs, ds) # (B, Lmax, adim)
# forward decoder
if olens is not None and not is_inference:
if self.reduction_factor > 1:
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
olens_in = olens
h_masks = self._source_mask(olens_in)
else:
h_masks = None
zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim)
before_outs = self.feat_out(zs).view(
zs.size(0), -1, self.odim
) # (B, Lmax, odim)
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(
before_outs.transpose(1, 2)
).transpose(1, 2)
return before_outs, after_outs, d_outs, ref_embs, vq_loss, ar_prior_loss, \
perplexity
def inference(
self,
text: torch.Tensor,
speech: torch.Tensor = None,
spembs: torch.Tensor = None,
durations: torch.Tensor = None,
ref_embs: torch.Tensor = None,
alpha: float = 1.0,
use_teacher_forcing: bool = False,
ar_prior_inference: bool = False,
fg_inds: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate the sequence of features given the sequences of characters.
Args:
text (LongTensor): Input sequence of characters (T,).
speech (Tensor, optional): Feature sequence to extract style (B, idim).
spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,).
durations (LongTensor, optional): Groundtruth of duration (T + 1,).
ref_embs (Tensor, optional): Reference embedding vector (B, gru_units).
alpha (float, optional): Alpha to control the speed.
use_teacher_forcing (bool, optional): Whether to use teacher forcing.
If true, groundtruth of duration will be used.
Returns:
Tensor: Output sequence of features (L, odim).
None: Dummy for compatibility.
None: Dummy for compatibility.
"""
x, y = text, speech
spemb, d = spembs, durations
# add eos at the last of sequence
x = F.pad(x, [0, 1], "constant", self.eos)
# setup batch axis
ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
xs, ys = x.unsqueeze(0), None
if y is not None:
ys = y.unsqueeze(0)
if spemb is not None:
spembs = spemb.unsqueeze(0)
if ref_embs is not None:
ref_embs = ref_embs.unsqueeze(0)
if use_teacher_forcing:
# use groundtruth of duration
ds = d.unsqueeze(0)
_, after_outs, _, ref_embs, _, ar_prior_loss, _ = self._forward(
xs,
ilens,
ys,
ds=ds,
spembs=spembs,
ref_embs=ref_embs,
ar_prior_inference=ar_prior_inference,
) # (1, L, odim)
else:
_, after_outs, _, ref_embs, _, ar_prior_loss, _ = self._forward(
xs,
ilens,
ys,
spembs=spembs,
ref_embs=ref_embs,
is_inference=True,
alpha=alpha,
ar_prior_inference=ar_prior_inference,
fg_inds=fg_inds,
) # (1, L, odim)
return after_outs[0], None, None, ref_embs, ar_prior_loss
# def _integrate_with_spk_embed(
# self, hs: torch.Tensor, spembs: torch.Tensor
# ) -> torch.Tensor:
# """Integrate speaker embedding with hidden states.
# Args:
# hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
# spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
# Returns:
# Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
# """
# if self.spk_embed_integration_type == "add":
# # apply projection and then add to hidden states
# spembs = self.projection(F.normalize(spembs))
# hs = hs + spembs.unsqueeze(1)
# elif self.spk_embed_integration_type == "concat":
# # concat hidden states with spk embeds and then apply projection
# spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
# hs = self.projection(torch.cat([hs, spembs], dim=-1))
# else:
# raise NotImplementedError("support only add or concat.")
# return hs
def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
"""Make masks for self-attention.
Args:
ilens (LongTensor): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for self-attention.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2)
def _integrate_with_prosody_embs(
self, hs: torch.Tensor, p_embs: torch.Tensor
) -> torch.Tensor:
"""Integrate prosody embeddings with hidden states.
Args:
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
p_embs (Tensor): Batch of prosody embeddings (B, Tmax, adim).
Returns:
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
"""
if self.prosody_emb_integration_type == "add":
# apply projection and then add to hidden states
# (B, Tmax, adim)
hs = hs + p_embs
elif self.prosody_emb_integration_type == "concat":
# concat hidden states with prosody embeds and then apply projection
# (B, Tmax, adim)
hs = self.prosody_projection(torch.cat([hs, p_embs], dim=-1))
else:
raise NotImplementedError("support only add or concat.")
return hs
def _reset_parameters(
self, init_type: str, init_enc_alpha: float, init_dec_alpha: float
):
# initialize parameters
if init_type != "pytorch":
initialize(self, init_type)
# initialize alpha in scaled positional encoding
if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)