"""Train the Transformer translation model on a local bilingual dataset."""
import copy
import os
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model
from config import load_config, get_weights_file_path
from validate import run_validation
from tokenizer import get_or_build_local_tokenizer
from pathlib import Path
from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode


def get_local_dataset_tokenizer(config):
    """Load the local train/validation splits, build tokenizers and wrap them in DataLoaders.

    Args:
        config: Parsed configuration mapping; reads the ``dataset`` section and
            ``config['train']['batch_size']``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer)``.
        The training loader uses the configured batch size; the validation loader
        uses ``batch_size=1``. Both shuffle.
    """
    ds_config = config['dataset']
    src_lang = ds_config['src_lang']
    tgt_lang = ds_config['tgt_lang']

    train_ds_raw = load_local_dataset(
        dataset_filename='datasets/' + ds_config['train_dataset'],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )
    val_ds_raw = load_local_dataset(
        dataset_filename='datasets/' + ds_config['validate_dataset'],
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )

    # Tokenizers are built (or loaded) from the combined train + validation text.
    all_ds_raw = train_ds_raw + val_ds_raw
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=all_ds_raw,
        lang=src_lang,
        tokenizer_type=ds_config['src_tokenizer'],
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=all_ds_raw,
        lang=tgt_lang,
        tokenizer_type=ds_config['tgt_tokenizer'],
    )

    dataset_kwargs = dict(
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        src_max_seq_len=ds_config['src_max_seq_len'],
        tgt_max_seq_len=ds_config['tgt_max_seq_len'],
    )
    train_ds = BilingualDataset(ds=train_ds_raw, **dataset_kwargs)
    val_ds = BilingualDataset(ds=val_ds_raw, **dataset_kwargs)

    # Diagnostic only: report the longest tokenized sentence on each side so the
    # configured max sequence lengths can be sanity-checked.
    src_max_seq_len = 0
    tgt_max_seq_len = 0
    for item in all_ds_raw:
        src_ids = src_tokenizer.encode(item['translation'][src_lang]).ids
        tgt_ids = tgt_tokenizer.encode(item['translation'][tgt_lang]).ids
        src_max_seq_len = max(src_max_seq_len, len(src_ids))
        tgt_max_seq_len = max(tgt_max_seq_len, len(tgt_ids))

    print(f'Max length of source sequence: {src_max_seq_len}')
    print(f'Max length of target sequence: {tgt_max_seq_len}')

    train_dataloader = DataLoader(train_ds, batch_size=config['train']['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer


def _snapshot_training_state(model, optimizer, scheduler, loss):
    """Return independent copies of model/optimizer/scheduler state plus the loss they were taken at.

    ``state_dict()`` returns references to the live tensors, so without a deep copy the
    snapshot would keep changing as training continues and restoring it would be a no-op.
    The copies live on the same device as the originals (extra memory cost).
    """
    return {
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'scheduler_state_dict': copy.deepcopy(scheduler.state_dict()),
        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
        'loss': loss,
    }


def _restore_training_state(model, optimizer, scheduler, snapshot):
    """Load a snapshot back into the live objects without aliasing the snapshot's tensors."""
    # Module.load_state_dict copies values into the existing parameters, but
    # Optimizer.load_state_dict may keep the given tensors; hand it a copy so later
    # in-place optimizer updates cannot corrupt the snapshot.
    model.load_state_dict(snapshot['model_state_dict'])
    optimizer.load_state_dict(copy.deepcopy(snapshot['optimizer_state_dict']))
    scheduler.load_state_dict(copy.deepcopy(snapshot['scheduler_state_dict']))


def train_model(config):
    """Train the model described by ``config``, validating and checkpointing after each epoch.

    Every ``config['train']['patience']`` steps the current batch loss is compared with
    the loss recorded at the last snapshot. If it is higher, the model, optimizer and
    scheduler are rolled back to that snapshot and the batch is skipped. Otherwise a new
    snapshot is taken. The per-epoch checkpoint stores the snapshot, which may lag the
    latest weights by up to ``patience`` steps. A non-positive ``patience`` disables the
    rollback check.

    Args:
        config: Parsed configuration mapping (``dataset``, ``train``, ``model`` sections
            and ``experiment_name``).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    Path(config['model']['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
    src_vocab_size = src_tokenizer.get_vocab_size()
    tgt_vocab_size = tgt_tokenizer.get_vocab_size()
    model = get_model(config, src_vocab_size, tgt_vocab_size).to(device)
    print(f'{src_vocab_size}, {tgt_vocab_size}')

    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'], eps=1e-9)

    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['train']['warm_up_steps'],
        num_training_steps=len(train_dataloader) * config['train']['num_epochs'] + 1,
    )

    initial_epoch = 0
    global_step = 0
    if config['model']['preload']:
        model_filename = get_weights_file_path(config, config['model']['preload'])
        print(f'Preloading model from {model_filename}')
        state = torch.load(model_filename, map_location=device)
        initial_epoch = state['epoch'] + 1
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        scheduler.load_state_dict(state['scheduler_state_dict'])
        global_step = state['global_step']

    # Labels are target-side token ids, so the padding id must come from the target
    # tokenizer (source and target tokenizers may be different types with different ids).
    # NOTE(review): the padding-token literal is '' in this copy of the file; confirm it
    # matches the padding token BilingualDataset writes into 'label'.
    tgt_pad_id = tgt_tokenizer.token_to_id('')
    if tgt_pad_id is None:
        raise ValueError('Padding token not found in the target tokenizer vocabulary')
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tgt_pad_id,
        label_smoothing=config['train']['label_smoothing'],
    ).to(device)

    print(f"Training model with {model.count_parameters()} params.")

    patience = config['train']['patience']
    # Start with +inf so the first patience check always records a real snapshot.
    best_state = _snapshot_training_state(model, optimizer, scheduler, float('inf'))

    for epoch in range(initial_epoch, config['train']['num_epochs']):
        # Re-enter training mode each epoch (validation at the end of the previous
        # epoch may have switched the model to eval mode).
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f'Proceesing epoch {epoch:02d}')
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)  # (batch, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (batch, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (batch, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (batch, 1, seq_len, seq_len)

            encoder_output = model.encode(encoder_input, encoder_mask)  # (batch, seq_len, d_model)
            decoder_output, attn = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (batch, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (batch, seq_len, tgt_vocab_size)

            label = batch['label'].to(device)  # (batch, seq_len)

            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
            loss_value = loss.item()
            batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})

            writer.add_scalar('train_loss', loss_value, global_step)
            writer.flush()

            global_step += 1

            if patience > 0 and global_step % patience == 0:
                # Compares this single batch's loss with the loss at the last snapshot.
                if loss_value > best_state['loss']:
                    _restore_training_state(model, optimizer, scheduler, best_state)
                    continue
                best_state = _snapshot_training_state(model, optimizer, scheduler, loss_value)

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        run_validation(model, val_dataloader, src_tokenizer, tgt_tokenizer, device,
                       lambda msg: batch_iterator.write(msg), global_step, writer)

        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_state['model_state_dict'],
            'scheduler_state_dict': best_state['scheduler_state_dict'],
            'optimizer_state_dict': best_state['optimizer_state_dict'],
            'global_step': global_step,
        }, model_filename)

        # print(f"Bleu score: {calculate_bleu_score(model, val_dataloader, src_tokenizer, tgt_tokenizer, device)}")

        if config['train']['on_colab']:
            # if (epoch % 5) == 0:
            #     model_zip_filename = f'model_epoch_{epoch}.zip'
            #     os.system(f'zip -r {model_zip_filename} /content/silver-spoon/weights')
            runs_zip_filename = f'runs_epoch_{epoch}.zip'
            os.system(f"zip -r {runs_zip_filename} /content/silver-spoon/{config['experiment_name']}")


if __name__ == '__main__':
    config = load_config(file_name='config.yaml')
    train_model(config)