# Spaces: Sleeping
# src/train.py
import json
import os
import subprocess
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from model import TransformerModel
from utils import load_vocab, tokenize
class TextDataset(Dataset):
    """Next-token-prediction dataset over pre-tokenized text.

    The JSON file at *data_path* is loaded into memory in full. Each item is
    iterated as a sequence of tokens (presumably strings produced by
    data_processing.py -- TODO confirm). An item is mapped to vocabulary ids,
    truncated or right-padded with the ``<PAD>`` id to ``seq_length + 1`` ids,
    and split into an (input, target) pair offset by one position.
    """

    def __init__(self, data_path, vocab, seq_length=50):
        """Load the token sequences.

        Args:
            data_path: path to a UTF-8 JSON file holding a list of token sequences.
            vocab: mapping token -> integer id. It needs a '<UNK>' entry
                (fallback for unknown tokens) and a '<PAD>' entry (padding).
            seq_length: length of each returned input and target tensor.
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.vocab = vocab
        self.seq_length = seq_length

    def __len__(self):
        """Return the number of token sequences loaded from *data_path*."""
        return len(self.data)

    def numericalize(self, tokens):
        """Map *tokens* to ids, using the '<UNK>' id for out-of-vocabulary tokens."""
        # Resolve the fallback id once. As a .get() default argument it was
        # re-evaluated for every token, found or not.
        unk_id = self.vocab['<UNK>']
        return [self.vocab.get(token, unk_id) for token in tokens]

    def __getitem__(self, idx):
        """Return ``(input_seq, target_seq)``: two LongTensors of length seq_length.

        target_seq is input_seq shifted left by one, so target position i
        holds the id that follows input position i. Short sequences are
        right-padded with the '<PAD>' id; long ones are truncated.
        """
        # One extra id so that input (ids[:-1]) and target (ids[1:]) each
        # have exactly seq_length elements.
        window = self.seq_length + 1
        ids = self.numericalize(self.data[idx])[:window]
        if len(ids) < window:
            ids.extend([self.vocab['<PAD>']] * (window - len(ids)))
        input_seq = torch.tensor(ids[:-1], dtype=torch.long)
        target_seq = torch.tensor(ids[1:], dtype=torch.long)
        return input_seq, target_seq
def collate_fn(batch):
    """Combine a list of ``(input, target)`` tensor pairs into a batch.

    Transposes the pairs into an input column and a target column, then
    stacks each column along a new leading batch dimension.

    Returns:
        ``(inputs, targets)``: two tensors with the batch size as dim 0.
    """
    stacked_inputs, stacked_targets = (torch.stack(column) for column in zip(*batch))
    return stacked_inputs, stacked_targets
def get_dataloader(data_path, vocab, batch_size=64, seq_length=50):
    """Create a shuffling DataLoader over a TextDataset read from *data_path*.

    Args:
        data_path: JSON file of token sequences passed to TextDataset.
        vocab: token -> id mapping passed to TextDataset.
        batch_size: number of sequences per batch.
        seq_length: per-sample sequence length passed to TextDataset.

    Returns:
        A DataLoader that reshuffles every epoch and batches with collate_fn.
    """
    return DataLoader(
        TextDataset(data_path, vocab, seq_length),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
def _ensure_vocab(vocab_path, script_path='src/data_processing.py'):
    """Run the preprocessing script when *vocab_path* does not exist yet.

    Args:
        vocab_path: where the vocabulary file is expected.
        script_path: preprocessing script expected to create *vocab_path*.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
        FileNotFoundError: if the script succeeds but *vocab_path* is still missing.
    """
    if os.path.exists(vocab_path):
        return
    print(f"{vocab_path} not found. Running data_processing.py...")
    # sys.executable re-uses the interpreter (and virtualenv) running this
    # file; a bare 'python' is resolved via PATH and may be a different one.
    subprocess.run([sys.executable, script_path], check=True)
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(
            f"{script_path} completed but did not create {vocab_path}"
        )


def _train_one_epoch(model, dataloader, criterion, optimizer, device, vocab_size, desc):
    """Run one optimisation pass over *dataloader*; return the mean batch loss.

    Returns float('nan') if *dataloader* yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    progress = tqdm(dataloader, desc=desc)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        src_mask = model.generate_square_subsequent_mask(inputs.size(1)).to(device)
        outputs = model(inputs, src_mask)
        # NOTE(review): assumes outputs' last dim is vocab_size and that their
        # flattened order matches targets (batch-first) -- confirm against
        # TransformerModel. reshape (unlike view) also accepts a
        # non-contiguous output.
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        # Read the scalar once: each .item() forces a device->host sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress.set_postfix(loss=batch_loss)
    return total_loss / num_batches if num_batches else float('nan')


def train_model(config):
    """Train a TransformerModel as configured, checkpointing after every epoch.

    Args:
        config: dict with the required keys 'vocab_path', 'data_path',
            'embed_size', 'num_heads', 'hidden_dim', 'num_layers', 'dropout',
            'learning_rate', 'batch_size', 'seq_length' and 'epochs'.
            Optional keys:
            - 'data_processing_script' (default 'src/data_processing.py'):
              run when the vocabulary file is missing.
            - 'model_dir' (default 'models'): checkpoint directory.
            - 'model_prefix' (default '3ed0k4_model'): checkpoint filename prefix.

    Raises:
        subprocess.CalledProcessError: if the preprocessing script fails.
        FileNotFoundError: if the vocabulary is still missing after preprocessing.
        ValueError: if the training data yields no batches.
    """
    _ensure_vocab(
        config['vocab_path'],
        config.get('data_processing_script', 'src/data_processing.py'),
    )
    vocab = load_vocab(config['vocab_path'])
    vocab_size = len(vocab)

    model = TransformerModel(
        vocab_size=vocab_size,
        embed_size=config['embed_size'],
        num_heads=config['num_heads'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        dropout=config['dropout'],
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<PAD>'])
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    dataloader = get_dataloader(
        data_path=config['data_path'],
        vocab=vocab,
        batch_size=config['batch_size'],
        seq_length=config['seq_length'],
    )
    # Fail fast: without this the first epoch's average divided by zero.
    if len(dataloader) == 0:
        raise ValueError(f"No training batches: {config['data_path']} appears to be empty")

    model_dir = config.get('model_dir', 'models')
    model_prefix = config.get('model_prefix', '3ed0k4_model')
    os.makedirs(model_dir, exist_ok=True)

    model.train()
    for epoch in range(1, config['epochs'] + 1):
        avg_loss = _train_one_epoch(
            model, dataloader, criterion, optimizer, device, vocab_size,
            desc=f"Epoch {epoch}/{config['epochs']}",
        )
        print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")
        checkpoint_path = f"{model_dir}/{model_prefix}_epoch{epoch}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved at {checkpoint_path}")
if __name__ == "__main__":
    # Hyperparameters for a standalone training run from the command line.
    config = dict(
        vocab_path='vocab.json',
        data_path='data/processed/tokenized_data.json',
        embed_size=256,
        num_heads=8,
        hidden_dim=512,
        num_layers=4,
        dropout=0.1,
        learning_rate=0.001,
        batch_size=64,
        seq_length=50,
        epochs=10,
    )
    train_model(config)