from typing import Tuple

import torch.nn as nn
from torch import Tensor

from modules.transformer_embedding import TransformerEmbedding
from modules.positional_encoding import PositionalEncoding

from model.encoder import Encoder
from model.decoder import Decoder
from layers.projection_layer import ProjectionLayer


class Transformer(nn.Module):
    """
    Encoder-decoder Transformer model.

    Args:
    - src_vocab_size (int): source vocabulary size
    - tgt_vocab_size (int): target vocabulary size
    - src_max_seq_len (int): maximum source sequence length
    - tgt_max_seq_len (int): maximum target sequence length
    - d_model (int): dimension of the model
    - num_heads (int): number of attention heads
    - d_ff (int): dimension of the feed-forward hidden layer
    - dropout_p (float): dropout probability
    - num_encoder_layers (int): number of encoder layers
    - num_decoder_layers (int): number of decoder layers
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_max_seq_len: int,
        tgt_max_seq_len: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout_p: float = 0.1,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
    ) -> None:
        super().__init__()

        # Token embeddings for the source and target vocabularies
        self.src_embedding = TransformerEmbedding(
            d_model=d_model,
            num_embeddings=src_vocab_size
        )
        self.tgt_embedding = TransformerEmbedding(
            d_model=d_model,
            num_embeddings=tgt_vocab_size
        )

        # Positional encodings (with dropout), sized to each side's max sequence length
        self.src_positional_encoding = PositionalEncoding(
            d_model=d_model,
            dropout_p=dropout_p,
            max_length=src_max_seq_len
        )
        self.tgt_positional_encoding = PositionalEncoding(
            d_model=d_model,
            dropout_p=dropout_p,
            max_length=tgt_max_seq_len
        )

        # Encoder and decoder stacks
        self.encoder = Encoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_encoder_layers
        )
        self.decoder = Decoder(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout_p=dropout_p,
            num_layers=num_decoder_layers
        )

        # Final linear projection from d_model to the target vocabulary
        self.projection_layer = ProjectionLayer(
            d_model=d_model,
            vocab_size=tgt_vocab_size
        )

    def encode(
        self,
        src: Tensor,
        src_mask: Tensor
    ) -> Tensor:
        """
        Embed the source sequence, add positional encodings, and return the encoder outputs.
        """
        src = self.src_embedding(src)
        src = self.src_positional_encoding(src)
        return self.encoder(src, src_mask)

    def decode(
        self,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Embed the target sequence, add positional encodings, and return the decoder
        outputs along with the attention weights.
        """
        tgt = self.tgt_embedding(tgt)
        tgt = self.tgt_positional_encoding(tgt)
        return self.decoder(
            x=tgt,
            encoder_output=encoder_output,
            src_mask=src_mask,
            tgt_mask=tgt_mask
        )

    def project(self, decoder_output: Tensor) -> Tensor:
        """
        Project decoder outputs to the target vocabulary.
        """
        return self.projection_layer(decoder_output)

    def forward(
        self,
        src: Tensor,
        src_mask: Tensor,
        tgt: Tensor,
        tgt_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Run the full encode-decode-project pipeline and return the projected
        target-vocabulary outputs along with the decoder attention weights.
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output, attn = self.decode(
            encoder_output, src_mask, tgt, tgt_mask
        )
        output = self.project(decoder_output)
        return output, attn

    def count_parameters(self) -> int:
        """
        Count the model's trainable parameters.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


def get_model(config, src_vocab_size: int, tgt_vocab_size: int) -> Transformer:
    """
    Build a `Transformer` for the given config and vocabulary sizes.
    """
    return Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
        d_model=config['model']['d_model'],
        num_heads=config['model']['num_heads'],
        d_ff=config['model']['d_ff'],
        dropout_p=config['model']['dropout_p'],
        num_encoder_layers=config['model']['num_encoder_layers'],
        num_decoder_layers=config['model']['num_decoder_layers'],
    )
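

# Minimal usage sketch (illustrative only): the nested config keys mirror the ones
# read by `get_model` above, but the concrete values and vocabulary sizes below are
# assumptions for demonstration, not settings taken from this repository's config files.
if __name__ == "__main__":
    example_config = {
        "dataset": {"src_max_seq_len": 128, "tgt_max_seq_len": 128},
        "model": {
            "d_model": 512,
            "num_heads": 8,
            "d_ff": 2048,
            "dropout_p": 0.1,
            "num_encoder_layers": 6,
            "num_decoder_layers": 6,
        },
    }
    # Hypothetical vocabulary sizes; in practice these come from the tokenizers.
    model = get_model(example_config, src_vocab_size=32000, tgt_vocab_size=32000)
    print(f"trainable parameters: {model.count_parameters():,}")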