# src/model.py
import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    """Token-level Transformer encoder with learned positional embeddings.

    Token ids are embedded, summed with a learned position embedding, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks and projected
    back to vocabulary-sized logits at every position.

    Args:
        vocab_size: Number of distinct token ids (rows of the token
            embedding and width of the output logits).
        embed_size: Model width (``d_model``) shared by the embeddings, the
            encoder layers and the output projection.
        num_heads: Attention heads per encoder layer; PyTorch requires it to
            divide ``embed_size``.
        hidden_dim: Feed-forward width inside each encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside the encoder layers and on
            the summed input embeddings.
        max_len: Size of the learned position table, i.e. the longest
            sequence ``forward`` accepts. Defaults to 5000 (the value that
            was previously hard-coded).
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim,
                 num_layers, dropout=0.1, max_len=5000):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # One learned vector per absolute position 0 .. max_len - 1.
        self.position_embedding = nn.Embedding(max_len, embed_size)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        """Return per-position vocabulary logits for a batch of token ids.

        Args:
            src: Integer tensor of token ids shaped ``(batch, seq_len)``.
            src_mask: Optional attention mask forwarded to
                ``nn.TransformerEncoder`` as ``mask`` (for example the output
                of ``generate_square_subsequent_mask(seq_len)``). ``None``
                applies no attention mask.

        Returns:
            Logits shaped ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: If ``seq_len`` exceeds the size of the position
                table (``max_len``).
        """
        # Unpacking also enforces that ``src`` is 2-D (batch, seq_len).
        _, seq_length = src.shape
        max_len = self.position_embedding.num_embeddings
        if seq_length > max_len:
            # Fail with a readable message instead of an opaque embedding
            # lookup error (or a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_length} exceeds max_len {max_len} "
                "of the positional embedding"
            )

        # Build the (seq_len,) position ids directly on src's device; the
        # resulting (seq_len, embed_size) embeddings broadcast over the batch,
        # so no per-batch copy of the positions is needed.
        positions = torch.arange(seq_length, device=src.device)
        x = self.token_embedding(src) + self.position_embedding(positions)
        x = self.dropout(x)

        # The encoder layers use the default batch_first=False, so they
        # expect (seq_len, batch, embed_size).
        x = x.permute(1, 0, 2)
        transformer_out = self.transformer_encoder(x, mask=src_mask)
        transformer_out = transformer_out.permute(1, 0, 2)
        return self.fc_out(transformer_out)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Build an additive causal attention mask of shape ``(sz, sz)``.

        Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when
        ``j > i``, so once added to attention scores each position can only
        attend to itself and earlier positions.

        Args:
            sz: Sequence length.
            device: Device to allocate the mask on. ``None`` (the default)
                uses PyTorch's default device, as before; pass ``src.device``
                so the mask lives alongside inputs on another device.

        Returns:
            A float32 tensor of shape ``(sz, sz)``.
        """
        # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes
        # the diagonal and everything below it.
        return torch.triu(
            torch.full((sz, sz), float('-inf'), dtype=torch.float32,
                       device=device),
            diagonal=1,
        )