# src/train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import json
from model import TransformerModel
from utils import load_vocab, tokenize
from tqdm import tqdm
import os
import subprocess


class TextDataset(Dataset):
    def __init__(self, data_path, vocab, seq_length=50):
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.vocab = vocab
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data)

    def numericalize(self, tokens):
        # Map tokens to ids, falling back to the unknown-token id for
        # out-of-vocabulary tokens.
        return [self.vocab.get(token, self.vocab['<unk>']) for token in tokens]

    def __getitem__(self, idx):
        tokens = self.data[idx]
        numericalized = self.numericalize(tokens)
        # Pad or truncate to seq_length + 1 so we can form shifted input/target pairs.
        if len(numericalized) < self.seq_length + 1:
            numericalized += [self.vocab['<pad>']] * (self.seq_length + 1 - len(numericalized))
        else:
            numericalized = numericalized[:self.seq_length + 1]
        input_seq = torch.tensor(numericalized[:-1], dtype=torch.long)
        target_seq = torch.tensor(numericalized[1:], dtype=torch.long)
        return input_seq, target_seq


def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = torch.stack(inputs)
    targets = torch.stack(targets)
    return inputs, targets


def get_dataloader(data_path, vocab, batch_size=64, seq_length=50):
    dataset = TextDataset(data_path, vocab, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader


def train_model(config):
    # Run the preprocessing script if the vocabulary has not been built yet.
    if not os.path.exists(config['vocab_path']):
        print("vocab.json not found. Running data_processing.py...")
        subprocess.run(['python', 'src/data_processing.py'], check=True)

    # Load vocabulary
    vocab = load_vocab(config['vocab_path'])
    vocab_size = len(vocab)

    # Initialize model
    model = TransformerModel(
        vocab_size=vocab_size,
        embed_size=config['embed_size'],
        num_heads=config['num_heads'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Loss and optimizer; padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<pad>'])
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # DataLoader
    dataloader = get_dataloader(
        data_path=config['data_path'],
        vocab=vocab,
        batch_size=config['batch_size'],
        seq_length=config['seq_length']
    )

    # Training loop
    model.train()
    for epoch in range(1, config['epochs'] + 1):
        epoch_loss = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch}/{config['epochs']}")
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            # Causal mask so each position can only attend to earlier positions.
            src_mask = model.generate_square_subsequent_mask(inputs.size(1)).to(device)
            outputs = model(inputs, src_mask)
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

        # Save model after each epoch
        os.makedirs('models', exist_ok=True)
        torch.save(model.state_dict(), f"models/3ed0k4_model_epoch{epoch}.pth")
        print(f"Model saved at models/3ed0k4_model_epoch{epoch}.pth")


if __name__ == "__main__":
    config = {
        'vocab_path': 'vocab.json',
        'data_path': 'data/processed/tokenized_data.json',
        'embed_size': 256,
        'num_heads': 8,
        'hidden_dim': 512,
        'num_layers': 4,
        'dropout': 0.1,
        'learning_rate': 0.001,
        'batch_size': 64,
        'seq_length': 50,
        'epochs': 10
    }
    train_model(config)
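
# The sketch below is a minimal, illustrative example of how one of the per-epoch
# checkpoints saved above could be reloaded (e.g. for inference or to resume
# training). It reuses the config and classes defined in this file; the epoch
# number in the checkpoint path is only an example. It is left commented out so
# that running this script is unchanged.
#
#   vocab = load_vocab(config['vocab_path'])
#   model = TransformerModel(
#       vocab_size=len(vocab),
#       embed_size=config['embed_size'],
#       num_heads=config['num_heads'],
#       hidden_dim=config['hidden_dim'],
#       num_layers=config['num_layers'],
#       dropout=config['dropout']
#   )
#   state_dict = torch.load("models/3ed0k4_model_epoch10.pth", map_location='cpu')
#   model.load_state_dict(state_dict)
#   model.eval()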