Spaces:
Sleeping
Sleeping
# src/model.py | |
import torch | |
import torch.nn as nn | |
class TransformerModel(nn.Module):
    """Encoder-only Transformer that maps token ids to per-position vocabulary logits.

    Token embeddings are summed with learned absolute position embeddings,
    passed through a stack of ``nn.TransformerEncoderLayer`` modules and
    projected to ``vocab_size`` logits by a linear head.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embed_size: Model width (``d_model``) of the embeddings and encoder layers.
        num_heads: Number of attention heads (PyTorch requires it to divide
            ``embed_size``).
        hidden_dim: Width of each encoder layer's feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability for the embeddings and the encoder layers.
        max_len: Largest sequence length the learned position table supports.
            Defaults to 5000, the value that used to be hard-coded.
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers,
                 dropout=0.1, max_len=5000):
        super().__init__()
        self.embed_size = embed_size
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned absolute positions: one row per index in [0, max_len).
        self.position_embedding = nn.Embedding(max_len, embed_size)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return logits of shape ``(batch, seq_len, vocab_size)``.

        Args:
            src: Integer token ids of shape ``(batch, seq_len)``.
            src_mask: Optional ``(seq_len, seq_len)`` attention mask, e.g. from
                :meth:`generate_square_subsequent_mask`. It is moved to
                ``src``'s device if it lives elsewhere.
            src_key_padding_mask: Optional ``(batch, seq_len)`` mask of key
                positions to ignore, forwarded to the encoder unchanged
                (apart from a device move).

        Raises:
            IndexError: If ``seq_len`` exceeds the position table size.
        """
        _, seq_length = src.size()
        # Read the limit from the table itself so models pickled before
        # ``max_len`` existed still work.
        max_len = self.position_embedding.num_embeddings
        if seq_length > max_len:
            raise IndexError(
                f"sequence length {seq_length} exceeds the maximum of {max_len} "
                "supported by the position embedding"
            )
        # Shape (1, seq_len): the position embeddings broadcast over the batch.
        positions = torch.arange(seq_length, device=src.device).unsqueeze(0)
        x = self.token_embedding(src) + self.position_embedding(positions)
        x = self.dropout(x)
        # The encoder layer is built without batch_first, so it expects
        # (seq_len, batch, embed_size).
        x = x.permute(1, 0, 2)
        if src_mask is not None:
            src_mask = src_mask.to(x.device)
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.to(x.device)
        transformer_out = self.transformer_encoder(
            x, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        transformer_out = transformer_out.permute(1, 0, 2)
        return self.fc_out(transformer_out)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Build a causal float32 ``(sz, sz)`` attention mask.

        Entry ``[i, j]`` is ``-inf`` when ``j > i`` (future positions are
        blocked) and ``0.0`` otherwise.

        Args:
            sz: Sequence length.
            device: Optional device to allocate the mask on (CPU by default).
        """
        return torch.triu(
            torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device),
            diagonal=1,
        )