Spaces:
Sleeping
Sleeping
File size: 1,319 Bytes
6faf7e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from typing import List, Optional

from torch import nn

from poetry_diacritizer.models.seq2seq import Seq2Seq, Decoder as Seq2SeqDecoder
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet


class Tacotron(Seq2Seq):
    """Tacotron-flavoured sequence-to-sequence model.

    Adds nothing on top of ``Seq2Seq``: every attribute and method is
    inherited unchanged. It presumably exists as a distinct name so the
    Tacotron-style ``Encoder``/``Decoder`` in this module can be selected by
    model type (TODO confirm against callers).
    """
class Encoder(nn.Module):
    """Input encoder: token embedding -> optional ``Prenet`` -> ``CBHG``.

    Args:
        inp_vocab_size: Number of rows in the input embedding table.
        embedding_dim: Size of each embedding vector.
        use_prenet: If True, embeddings pass through a ``Prenet`` before CBHG;
            otherwise they are fed to CBHG directly.
        prenet_sizes: Forwarded to ``Prenet`` as ``prenet_depth``. Its last
            entry is used as the CBHG input size when the prenet is enabled.
            Defaults to ``[256, 128]``.
        cbhg_gru_units: Forwarded as the second positional argument of
            ``CBHG`` (presumably the GRU hidden size -- confirm in
            ``tacotron_modules``).
        cbhg_filters: Forwarded to ``CBHG`` as ``K``.
        cbhg_projections: Forwarded to ``CBHG`` as ``projections``.
            Defaults to ``[128, 128]``.
        padding_idx: Passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(
        self,
        inp_vocab_size: int,
        embedding_dim: int = 512,
        use_prenet: bool = True,
        prenet_sizes: Optional[List[int]] = None,
        cbhg_gru_units: int = 128,
        cbhg_filters: int = 16,
        cbhg_projections: Optional[List[int]] = None,
        padding_idx: int = 0,
    ):
        super().__init__()
        # Build fresh default lists on every call. A list literal used as a
        # default argument is one object shared by all calls. Because these
        # lists are handed to Prenet/CBHG, any in-place change made there
        # would otherwise leak into every later Encoder.
        if prenet_sizes is None:
            prenet_sizes = [256, 128]
        if cbhg_projections is None:
            cbhg_projections = [128, 128]

        self.use_prenet = use_prenet
        self.embedding = nn.Embedding(
            inp_vocab_size, embedding_dim, padding_idx=padding_idx
        )
        if use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth=prenet_sizes)
        # CBHG's input size: prenet_sizes[-1] when the prenet is used (assumed
        # to be the prenet's output width), else the raw embedding width.
        cbhg_input_dim = prenet_sizes[-1] if use_prenet else embedding_dim
        self.cbhg = CBHG(
            cbhg_input_dim,
            cbhg_gru_units,
            K=cbhg_filters,
            projections=cbhg_projections,
        )

    def forward(self, inputs, input_lengths=None):
        """Encode a batch of token ids.

        Args:
            inputs: Integer index tensor accepted by ``nn.Embedding``
                (presumably ``(batch, seq_len)`` -- confirm with callers).
            input_lengths: Forwarded unchanged to ``CBHG``; may be None.

        Returns:
            Whatever ``CBHG`` returns for the (optionally prenet-processed)
            embeddings.
        """
        outputs = self.embedding(inputs)
        if self.use_prenet:
            outputs = self.prenet(outputs)
        return self.cbhg(outputs, input_lengths)
class Decoder(Seq2SeqDecoder):
    """Decoder for the Tacotron model.

    Reuses the generic ``poetry_diacritizer.models.seq2seq.Decoder`` (imported
    here as ``Seq2SeqDecoder``) with no overrides. This gives the module its
    own ``Decoder`` name alongside ``Encoder``.
    """
|