# Source: artst-tts-demo/artst/models/modules/speech_decoder_postnet.py
# Uploaded by herwoww in commit 1547a56 ("first upload"); file size 2.7 kB.
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import contextlib
import torch
import torch.nn as nn
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet
class SpeechDecoderPostnet(nn.Module):
    """Project decoder states to output feature frames and per-frame logits.

    Two linear heads map each decoder state to ``reduction_factor`` output
    frames. ``feat_out`` predicts an ``odim``-dimensional feature vector per
    frame. ``prob_out`` predicts one logit per frame (presumably a stop-token
    logit; confirm against the training loss). An optional espnet Tacotron2
    ``Postnet`` adds a residual refinement on top of the predicted features.

    Args:
        odim (int): Dimension of each output feature frame.
        args: Configuration namespace. The attributes read here are:
            ``decoder_embed_dim`` (input size of both linear heads),
            ``reduction_factor`` (frames predicted per decoder state),
            ``postnet_layers`` (0 disables the postnet),
            ``postnet_chans``, ``postnet_filts``, ``use_batch_norm`` and
            ``postnet_dropout_rate`` (forwarded to ``Postnet``), and
            ``freeze_decoder_updates``. While the recorded update count is
            below ``freeze_decoder_updates``, ``forward`` runs under
            ``torch.no_grad()``.
    """

    def __init__(
        self,
        odim,
        args,
    ):
        super().__init__()
        # Final projections. Each decoder state yields `reduction_factor`
        # frames of `odim` features, plus one logit per frame.
        self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor)
        # Optional convolutional postnet; postnet_layers == 0 disables it.
        # NOTE(review): idim=0 is passed because espnet's Postnet appears not
        # to use idim. TODO confirm against the installed espnet version.
        self.postnet = (
            None
            if args.postnet_layers == 0
            else Postnet(
                idim=0,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )
        )
        self.odim = odim
        # Updated externally through set_num_updates(); compared against
        # freeze_decoder_updates in forward().
        self.num_updates = 0
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, zs):
        """Run the projection heads and the optional postnet.

        Args:
            zs (Tensor): Decoder states. Per the shape comments below, this is
                expected to be ``(B, Lmax // r, decoder_embed_dim)``, where
                ``r`` is the reduction factor.

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple of three tensors.

            - ``before_outs``: shape ``(B, Lmax // r * r, odim)``.
            - ``after_outs``: same shape as ``before_outs``; equal to
              ``before_outs`` when the postnet is disabled, otherwise
              ``before_outs`` plus the postnet residual.
            - ``logits``: shape ``(B, Lmax // r * r)``.
        """
        # Before `freeze_decoder_updates` updates have been taken, run with
        # gradient tracking disabled. No gradients then flow through this
        # module, neither to its own parameters nor back into `zs`.
        finetune = self.freeze_decoder_updates <= self.num_updates
        grad_ctx = contextlib.nullcontext() if finetune else torch.no_grad()
        with grad_ctx:
            # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
            before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
            # (B, Lmax//r, r) -> (B, Lmax//r * r)
            logits = self.prob_out(zs).view(zs.size(0), -1)
            # postnet -> (B, Lmax//r * r, odim)
            if self.postnet is None:
                after_outs = before_outs
            else:
                # The postnet is applied channels-first, hence the
                # transposes; its output is added as a residual.
                after_outs = before_outs + self.postnet(
                    before_outs.transpose(1, 2)
                ).transpose(1, 2)
        return before_outs, after_outs, logits

    def set_num_updates(self, num_updates):
        """Record the number of parameter updates taken so far.

        ``forward`` compares this value with ``freeze_decoder_updates`` to
        decide whether to run under ``torch.no_grad()``.
        """
        self.num_updates = num_updates