import random
from typing import List, Optional

import torch
from torch import nn

from poetry_diacritizer.modules.attention import AttentionWrapper
from poetry_diacritizer.modules.layers import ConvNorm
from poetry_diacritizer.modules.tacotron_modules import CBHG, Prenet
from poetry_diacritizer.options import AttentionType
from poetry_diacritizer.util.utils import get_mask_from_lengths


class Seq2Seq(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(
        self,
        src: torch.Tensor,
        lengths: torch.Tensor,
        target: Optional[torch.Tensor] = None,
    ):
        encoder_outputs = self.encoder(src, lengths)
        mask = get_mask_from_lengths(encoder_outputs, lengths)
        outputs, alignments = self.decoder(encoder_outputs, target, mask)
        output = {"diacritics": outputs, "attention": alignments}
        return output


class Encoder(nn.Module):
    def __init__(
        self,
        inp_vocab_size: int,
        embedding_dim: int = 512,
        layers_units: List[int] = [256, 256, 256],
        use_batch_norm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(inp_vocab_size, embedding_dim)

        # Prepend embedding_dim // 2 so that the first LSTM's input size
        # (layers_units[0] * 2 below) equals embedding_dim.
        layers_units = [embedding_dim // 2] + layers_units

        layers = []
        for i in range(1, len(layers_units)):
            layers.append(
                nn.LSTM(
                    layers_units[i - 1] * 2,
                    layers_units[i],
                    bidirectional=True,
                    batch_first=True,
                )
            )
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(layers_units[i] * 2))

        self.layers = nn.ModuleList(layers)
        self.layers_units = layers_units
        self.use_batch_norm = use_batch_norm

    def forward(self, inputs: torch.Tensor, inputs_lengths: torch.Tensor):
        outputs = self.embedding(inputs)
        # outputs: [batch_size, src_len, embedding_dim]

        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.BatchNorm1d):
                # BatchNorm1d expects [batch_size, channels, src_len]
                outputs = layer(outputs.permute(0, 2, 1))
                outputs = outputs.permute(0, 2, 1)
                continue
            if i > 0:
                # Reuse the previous layer's final (hidden, cell) states.
                outputs, (hn, cn) = layer(outputs, (hn, cn))
            else:
                outputs, (hn, cn) = layer(outputs)

        return outputs


class Decoder(nn.Module):
    """A seq2seq decoder that decodes one diacritic at a time.

    Args:
        encoder_dim (int): the encoder output dim
        decoder_units (int): the number of neurons for each decoder layer
        decoder_layers (int): number of decoder layers
    """

    def __init__(
        self,
        trg_vocab_size: int,
        start_symbol_id: int,
        encoder_dim: int = 256,
        embedding_dim: int = 256,
        decoder_units: int = 256,
        decoder_layers: int = 2,
        attention_units: int = 256,
        attention_type: AttentionType = AttentionType.LocationSensitive,
        is_attention_accumulative: bool = False,
        prenet_depth: List[int] = [256, 128],
        use_prenet: bool = True,
        teacher_forcing_probability: float = 0.0,
    ):
        super().__init__()
        self.output_dim: int = trg_vocab_size
        self.start_symbol_id = start_symbol_id
        self.attention_units = attention_units
        self.decoder_units = decoder_units
        self.encoder_dim = encoder_dim
        self.use_prenet = use_prenet
        self.teacher_forcing_probability = teacher_forcing_probability
        self.is_attention_accumulative = is_attention_accumulative
        self.embedding = nn.Embedding(trg_vocab_size, embedding_dim, padding_idx=0)

        attention_in = embedding_dim
        if use_prenet:
            self.prenet = Prenet(embedding_dim, prenet_depth)
            attention_in = prenet_depth[-1]

        self.attention_layer = nn.GRUCell(encoder_dim + attention_in, attention_units)
        self.attention_wrapper = AttentionWrapper(attention_type, attention_units)
        self.keys_layer = nn.Linear(encoder_dim, attention_units, bias=False)
        self.project_to_decoder_in = nn.Linear(
            attention_units + encoder_dim,
            decoder_units,
        )
        self.decoder_rnns = nn.ModuleList(
            [nn.GRUCell(decoder_units, decoder_units) for _ in range(decoder_layers)]
        )

        self.diacritics_layer = nn.Linear(decoder_units, trg_vocab_size)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def decode(self, diacritic: torch.Tensor):
        """Decode one time-step.

        Args:
            diacritic (Tensor): the previous diacritic ids (batch_size,)

        Returns:
            output (Tensor): diacritic logits (batch_size, trg_vocab_size)
            alignment (Tensor): attention weights (batch_size, src_len)
        """
        diacritic = self.embedding(diacritic)
        if self.use_prenet:
            prenet_out = self.prenet(diacritic)
        else:
            prenet_out = diacritic

        cell_input = torch.cat((prenet_out, self.prev_attention), -1)

        self.attention_hidden = self.attention_layer(cell_input, self.attention_hidden)
        output = self.attention_hidden

        # The query is the hidden state of the attention RNN layer
        attention, alignment = self.attention_wrapper(
            query=self.attention_hidden,
            values=self.encoder_outputs,
            keys=self.keys,
            mask=self.mask,
            prev_alignment=self.prev_alignment,
        )

        decoder_input = torch.cat((output, attention), -1)
        decoder_input = self.project_to_decoder_in(decoder_input)

        # Stack of GRU cells with residual connections
        for idx in range(len(self.decoder_rnns)):
            self.decoder_hiddens[idx] = self.decoder_rnns[idx](
                decoder_input, self.decoder_hiddens[idx]
            )
            decoder_input = self.decoder_hiddens[idx] + decoder_input

        output = decoder_input
        output = self.diacritics_layer(output)

        if self.is_attention_accumulative:
            self.prev_alignment = self.prev_alignment + alignment
        else:
            self.prev_alignment = alignment

        self.prev_attention = attention

        return output, alignment

    def inference(self):
        """Generate diacritics one at a time"""
        batch_size = self.encoder_outputs.size(0)
        trg_len = self.encoder_outputs.size(1)
        diacritic = torch.full(
            (batch_size,), self.start_symbol_id, dtype=torch.long, device=self.device
        )
        outputs, alignments = [], []
        self.initialize()

        for _ in range(trg_len):
            output, alignment = self.decode(diacritic=diacritic)
            outputs.append(output)
            alignments.append(alignment)
            diacritic = torch.max(output, 1).indices

        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        return outputs, alignments

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        diacritics: Optional[torch.Tensor] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        """Calculate forward propagation.

        Args:
            encoder_outputs (Tensor): the output of the encoder
                (batch_size, Tx, encoder_units * 2)
            diacritics (Tensor): target sequence
            input_mask (Tensor): the inputs mask (batch_size, Tx)
        """
        self.mask = input_mask
        self.encoder_outputs = encoder_outputs
        self.keys = self.keys_layer(encoder_outputs)

        if diacritics is None:
            return self.inference()

        batch_size = diacritics.size(0)
        trg_len = diacritics.size(1)

        # Init decoder states
        outputs = []
        alignments = []
        self.initialize()

        diacritic = torch.full(
            (batch_size,), self.start_symbol_id, dtype=torch.long, device=self.device
        )

        for time in range(trg_len):
            output, alignment = self.decode(diacritic=diacritic)
            outputs += [output]
            alignments += [alignment]
            if random.random() > self.teacher_forcing_probability:
                diacritic = diacritics[:, time]  # use training input
            else:
                diacritic = torch.max(output, 1).indices  # use last output

        alignments = torch.stack(alignments).transpose(0, 1)
        outputs = torch.stack(outputs).transpose(0, 1).contiguous()
        return outputs, alignments

    def initialize(self):
        """Initialize the first step variables"""
        batch_size = self.encoder_outputs.size(0)
        src_len = self.encoder_outputs.size(1)
        self.attention_hidden = torch.zeros(
            batch_size, self.attention_units, device=self.device
        )
        self.decoder_hiddens = [
            torch.zeros(batch_size, self.decoder_units, device=self.device)
            for _ in range(len(self.decoder_rnns))
        ]
        self.prev_attention = torch.zeros(
            batch_size, self.encoder_dim, device=self.device
        )
        self.prev_alignment = torch.zeros(batch_size, src_len, device=self.device)
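

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline).  It assumes
# the rest of the poetry_diacritizer package is importable and that
# AttentionType.LocationSensitive is a valid choice, as in the Decoder
# defaults above.  The vocabulary sizes, batch size, sequence length, and
# start_symbol_id below are arbitrary placeholders chosen only so the shapes
# line up: the default Encoder emits 256 * 2 = 512 features, so the Decoder
# is built with encoder_dim=512.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    batch_size, src_len = 2, 12
    inp_vocab_size, trg_vocab_size = 40, 17  # placeholder sizes

    encoder = Encoder(inp_vocab_size=inp_vocab_size)
    decoder = Decoder(
        trg_vocab_size=trg_vocab_size,
        start_symbol_id=1,  # assumed id of the start-of-sequence symbol
        encoder_dim=512,
    )
    model = Seq2Seq(encoder, decoder).to(device)

    src = torch.randint(1, inp_vocab_size, (batch_size, src_len), device=device)
    lengths = torch.full((batch_size,), src_len, dtype=torch.long)
    target = torch.randint(1, trg_vocab_size, (batch_size, src_len), device=device)

    out = model(src, lengths, target)
    print(out["diacritics"].shape)  # expected: (batch_size, src_len, trg_vocab_size)
    print(out["attention"].shape)   # expected: (batch_size, src_len, src_len)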