model_3ed0k4 / src /model.py
3ed0k4's picture
Upload 12 files
65224b2 verified
raw
history blame
1.66 kB
# src/model.py
import torch
import torch.nn as nn
class TransformerModel(nn.Module):
    """Encoder-only Transformer mapping token ids to per-position vocabulary logits.

    Token ids are embedded, summed with a learned position embedding, passed
    through dropout and a stack of ``nn.TransformerEncoderLayer`` blocks, and
    finally projected to ``vocab_size`` logits by a linear layer.

    Args:
        vocab_size: Number of distinct token ids; size of the token-embedding
            table and of the output logits dimension.
        embed_size: Model width (``d_model``). PyTorch's multi-head attention
            requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads per encoder layer.
        hidden_dim: Width of each encoder layer's feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability applied after the summed embeddings and
            inside every encoder layer.
        max_seq_length: Number of rows in the learned position-embedding table,
            i.e. the longest sequence ``forward`` accepts. Defaults to 5000.
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers,
                 dropout=0.1, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size
        self.max_seq_length = max_seq_length
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_seq_length, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        """Compute vocabulary logits for every position of ``src``.

        Args:
            src: Integer token-id tensor of shape ``(batch, seq_len)``.
            src_mask: Attention mask forwarded to ``nn.TransformerEncoder`` as
                ``mask``, shape ``(seq_len, seq_len)``. One example is the
                additive causal mask from ``generate_square_subsequent_mask``.

        Returns:
            Logits tensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the position-embedding table size.
        """
        # Unpacking into two names also rejects inputs that are not 2-D.
        _batch_size, seq_length = src.size()
        max_len = self.position_embedding.num_embeddings
        if seq_length > max_len:
            # Without this check the embedding lookup fails with an opaque
            # index error (a device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length={max_len}"
            )
        # Allocate positions directly on the input's device. The resulting
        # (seq_len, embed) position embeddings broadcast over the batch
        # dimension, so no per-batch repeat or host-to-device copy is needed.
        positions = torch.arange(seq_length, device=src.device)
        x = self.token_embedding(src) + self.position_embedding(positions)
        x = self.dropout(x)
        # The encoder layers use the default batch_first=False layout, so
        # switch to (seq_len, batch, embed) and back again afterwards.
        x = x.permute(1, 0, 2)
        transformer_out = self.transformer_encoder(x, mask=src_mask)
        transformer_out = transformer_out.permute(1, 0, 2)
        return self.fc_out(transformer_out)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an additive causal attention mask of shape ``(sz, sz)``.

        Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` otherwise.
        Adding the mask to attention scores therefore lets each position attend
        only to itself and earlier positions.

        Args:
            sz: Sequence length.
            device: Optional device on which to allocate the mask. If omitted,
                the current default device is used.

        Returns:
            A float32 tensor usable as ``forward``'s ``src_mask``.
        """
        # Keep only the strictly-upper triangle (diagonal=1) as -inf; triu
        # zeroes everything on and below the diagonal.
        return torch.triu(
            torch.full((sz, sz), float("-inf"), dtype=torch.float32, device=device),
            diagonal=1,
        )