Spaces:
Runtime error
Runtime error
# -------------------------------------------------------- | |
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) | |
# Github source: https://github.com/mbzuai-nlp/ArTST | |
# Based on speecht5, fairseq and espnet code bases | |
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet | |
# -------------------------------------------------------- | |
import contextlib | |
import torch | |
import torch.nn as nn | |
from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet | |
class SpeechDecoderPostnet(nn.Module): | |
""" | |
Args: | |
in_channels (int): the number of input channels | |
mid_channels (int): the number of intermediate channels | |
out_channels (int): the number of output channels | |
kernel_sizes (List[int]): the kernel size for each convolutional layer | |
""" | |
def __init__( | |
self, | |
odim, | |
args, | |
): | |
super(SpeechDecoderPostnet, self).__init__() | |
# define decoder postnet | |
# define final projection | |
self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * args.reduction_factor) | |
self.prob_out = torch.nn.Linear(args.decoder_embed_dim, args.reduction_factor) | |
# define postnet | |
self.postnet = ( | |
None | |
if args.postnet_layers == 0 | |
else Postnet( | |
idim=0, | |
odim=odim, | |
n_layers=args.postnet_layers, | |
n_chans=args.postnet_chans, | |
n_filts=args.postnet_filts, | |
use_batch_norm=args.use_batch_norm, | |
dropout_rate=args.postnet_dropout_rate, | |
) | |
) | |
self.odim = odim | |
self.num_updates = 0 | |
self.freeze_decoder_updates = args.freeze_decoder_updates | |
def forward(self, zs): | |
ft = self.freeze_decoder_updates <= self.num_updates | |
with torch.no_grad() if not ft else contextlib.ExitStack(): | |
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) | |
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) | |
# (B, Lmax//r, r) -> (B, Lmax//r * r) | |
logits = self.prob_out(zs).view(zs.size(0), -1) | |
# postnet -> (B, Lmax//r * r, odim) | |
if self.postnet is None: | |
after_outs = before_outs | |
else: | |
after_outs = before_outs + self.postnet( | |
before_outs.transpose(1, 2) | |
).transpose(1, 2) | |
return before_outs, after_outs, logits | |
def set_num_updates(self, num_updates): | |
"""Set the number of parameters updates.""" | |
self.num_updates = num_updates | |