Spaces:
Sleeping
Sleeping
from typing import List, Optional

from torch import nn

from poetry_diacritizer.models.seq2seq import Seq2Seq, Decoder as Seq2SeqDecoder
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
class Tacotron(Seq2Seq):
    """Tacotron model name-alias; all behavior is inherited from Seq2Seq unchanged."""
class Encoder(nn.Module):
    """Tacotron-style text encoder: token embedding -> optional Prenet -> CBHG.

    The CBHG output (and hidden state, as returned by the project's CBHG
    module) is what `forward` returns.
    """

    def __init__(
        self,
        inp_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: Optional[List[int]] = None,
        cbhg_gru_units: int = 128,
        cbhg_filters: int = 16,
        cbhg_projections: Optional[List[int]] = None,
        padding_idx: int = 0,
    ):
        """Build the encoder layers.

        Args:
            inp_vocab_size: size of the input token vocabulary.
            embedding_dim: dimensionality of the token embeddings.
            use_prenet: whether to apply a Prenet before the CBHG stack.
            prenet_sizes: hidden sizes of the Prenet layers; defaults to
                [256, 128] when None.
            cbhg_gru_units: GRU units inside the CBHG module.
            cbhg_filters: number of conv-bank filter sets (K) in CBHG.
            cbhg_projections: CBHG projection sizes; defaults to [128, 128]
                when None.
            padding_idx: embedding index treated as padding (zero vector).
        """
        super().__init__()
        # None sentinels instead of mutable list defaults: a shared default
        # list would be the same object across every Encoder construction.
        if prenet_sizes is None:
            prenet_sizes = [256, 128]
        if cbhg_projections is None:
            cbhg_projections = [128, 128]
        self.use_prenet = use_prenet
        self.embedding = nn.Embedding(
            inp_vocab_size, embedding_dim, padding_idx=padding_idx
        )
        if use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
        # CBHG input width depends on whether the Prenet reshaped the features:
        # the Prenet's last layer size when enabled, else the raw embedding dim.
        self.cbhg = CBHG(
            prenet_sizes[-1] if use_prenet else embedding_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

    def forward(self, inputs, input_lengths=None):
        """Encode a batch of token-id sequences.

        Args:
            inputs: integer tensor of token ids (batch-first — TODO confirm
                against callers).
            input_lengths: optional per-sequence lengths forwarded to CBHG.

        Returns:
            Whatever the project's CBHG module returns for the processed
            embeddings.
        """
        outputs = self.embedding(inputs)
        if self.use_prenet:
            outputs = self.prenet(outputs)
        return self.cbhg(outputs, input_lengths)
class Decoder(Seq2SeqDecoder):
    """Name-alias for the generic seq2seq decoder; no behavior is overridden."""