from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.transformer_embedding import TransformerEmbedding
from modules.positional_encoding import PositionalEncoding
from model.encoder import Encoder
from model.decoder import Decoder
from layers.projection_layer import ProjectionLayer


class Transformer(nn.Module):
    """
    Transformer.

    Args:
    - src_vocab_size (int): source vocabulary size
    - tgt_vocab_size (int): target vocabulary size
    - src_max_seq_len (int): source max sequence length
    - tgt_max_seq_len (int): target max sequence length
    - d_model (int): dimension of the model
    - num_heads (int): number of attention heads
    - d_ff (int): dimension of the hidden feed-forward layer
    - dropout_p (float): dropout probability
    - num_encoder_layers (int): number of encoder layers
    - num_decoder_layers (int): number of decoder layers
    """
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout_p: float = 0.1,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
    ) -> None:
        super(Transformer, self).__init__()

        # Embedding layers
        self.src_embedding = TransformerEmbedding(
            d_model=d_model, num_embeddings=src_vocab_size
        )
        self.tgt_embedding = TransformerEmbedding(
            d_model=d_model, num_embeddings=tgt_vocab_size
        )

        # Positional encoding layers
        self.src_positional_encoding = PositionalEncoding(
            d_model=d_model, dropout_p=dropout_p, max_length=src_max_seq_len
        )
        self.tgt_positional_encoding = PositionalEncoding(
            d_model=d_model, dropout_p=dropout_p, max_length=tgt_max_seq_len
        )

        # Encoder
        self.encoder = Encoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_encoder_layers
        )

        # Decoder
        self.decoder = Decoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_decoder_layers
        )

        # Projects the decoder's output onto the target vocabulary.
        self.projection_layer = ProjectionLayer(
            d_model=d_model, vocab_size=tgt_vocab_size
        )

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """
        Get encoder outputs.
        """
        src = self.src_embedding(src)
        src = self.src_positional_encoding(src)
        return self.encoder(src, src_mask)

    def decode(
        self,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Get decoder outputs for a set of target inputs.
        """
        tgt = self.tgt_embedding(tgt)
        tgt = self.tgt_positional_encoding(tgt)
        return self.decoder(
            x=tgt,
            encoder_output=encoder_output,
            src_mask=src_mask,
            tgt_mask=tgt_mask
        )

    def project(self, decoder_output: Tensor) -> Tensor:
        """
        Project decoder outputs onto the target vocabulary.
        """
        return self.projection_layer(decoder_output)

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        encoder_output = self.encode(src, src_mask)
        decoder_output, attn = self.decode(
            encoder_output, src_mask, tgt, tgt_mask
        )
        output = self.project(decoder_output)
        return output, attn

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def get_model(config, src_vocab_size: int, tgt_vocab_size: int) -> Transformer:
    """
    Returns a `Transformer` model for the given config.
    """
    return Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
        d_model=config['model']['d_model'],
        num_heads=config['model']['num_heads'],
        d_ff=config['model']['d_ff'],
        dropout_p=config['model']['dropout_p'],
        num_encoder_layers=config['model']['num_encoder_layers'],
        num_decoder_layers=config['model']['num_decoder_layers'],
    )