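"""Training script for the local-dataset translation Transformer.

Builds tokenizers and dataloaders from the local parallel corpus, then trains
the model with a linear warm-up schedule, a rolling best-state snapshot,
per-epoch validation, and per-epoch checkpointing.
"""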
import copy
import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from load_dataset import load_local_dataset
from transformer import get_model
from config import load_config, get_weights_file_path
from validate import run_validation
from tokenizer import get_or_build_local_tokenizer
from dataset import BilingualDataset

def get_local_dataset_tokenizer(config):
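    """Load the local train/validation corpora, build (or load) the source and
    target tokenizers, and wrap everything in DataLoaders."""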
    train_ds_raw = load_local_dataset(
        dataset_filename='datasets/' + config['dataset']['train_dataset'],
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang']
    )
    val_ds_raw = load_local_dataset(
        dataset_filename='datasets/' + config['dataset']['validate_dataset'],
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang']
    )

    # Tokenizers are fitted (or loaded) over the combined train + validation text.
    src_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=train_ds_raw + val_ds_raw,
        lang=config['dataset']['src_lang'],
        tokenizer_type=config['dataset']['src_tokenizer']
    )
    tgt_tokenizer = get_or_build_local_tokenizer(
        config=config,
        ds=train_ds_raw + val_ds_raw,
        lang=config['dataset']['tgt_lang'],
        tokenizer_type=config['dataset']['tgt_tokenizer']
    )

    train_ds = BilingualDataset(
        ds=train_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
    )
    val_ds = BilingualDataset(
        ds=val_ds_raw,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=config['dataset']['src_lang'],
        tgt_lang=config['dataset']['tgt_lang'],
        src_max_seq_len=config['dataset']['src_max_seq_len'],
        tgt_max_seq_len=config['dataset']['tgt_max_seq_len'],
    )

    # Report the longest tokenized sequence on each side, as a sanity check
    # against the configured src/tgt max_seq_len.
    src_max_seq_len = 0
    tgt_max_seq_len = 0
    for item in (train_ds_raw + val_ds_raw):
        src_ids = src_tokenizer.encode(item['translation'][config['dataset']['src_lang']]).ids
        tgt_ids = tgt_tokenizer.encode(item['translation'][config['dataset']['tgt_lang']]).ids
        src_max_seq_len = max(src_max_seq_len, len(src_ids))
        tgt_max_seq_len = max(tgt_max_seq_len, len(tgt_ids))
    print(f'Max length of source sequence: {src_max_seq_len}')
    print(f'Max length of target sequence: {tgt_max_seq_len}')

    train_dataloader = DataLoader(train_ds, batch_size=config['train']['batch_size'], shuffle=True)
    # Batch size 1 for validation so examples can be decoded and printed one at a time.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer


def train_model(config):
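    """Train the model end to end: data, optimizer, scheduler, resume logic,
    training loop, validation, and checkpointing."""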
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    Path(config['model']['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_local_dataset_tokenizer(config)
    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_tokenizer.get_vocab_size()).to(device)

    print(f'Vocab sizes: src={src_tokenizer.get_vocab_size()}, tgt={tgt_tokenizer.get_vocab_size()}')

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'], eps=1e-9)

    # Linear warm-up to the configured learning rate, then linear decay over training.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['train']['warm_up_steps'],
        num_training_steps=len(train_dataloader) * config['train']['num_epochs'] + 1
    )
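
    # Optionally resume model/optimizer/scheduler state from a saved checkpoint.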
    initial_epoch = 0
    global_step = 0
    if config['model']['preload']:
        model_filename = get_weights_file_path(config, config['model']['preload'])
        print(f'Preloading model from {model_filename}')
        state = torch.load(model_filename, map_location=device)

        initial_epoch = state['epoch'] + 1
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        scheduler.load_state_dict(state['scheduler_state_dict'])
        global_step = state['global_step']

    # Labels are target-language token ids, so the <pad> id must come from the
    # target tokenizer; padded positions are excluded from the loss.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tgt_tokenizer.token_to_id('<pad>'),
        label_smoothing=config['train']['label_smoothing'],
    ).to(device)

    print(f'Training model with {model.count_parameters()} params.')

    # Rolling best-state snapshot: every `patience` steps the current loss is
    # compared against the best seen so far, and the model either rolls back to
    # the snapshot or refreshes it. state_dict() returns references to the live
    # tensors, so deep-copy to get a snapshot that later steps cannot mutate.
    patience = config['train']['patience']
    best_state = {
        'model_state_dict': copy.deepcopy(model.state_dict()),
        'scheduler_state_dict': copy.deepcopy(scheduler.state_dict()),
        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
        'loss': float('inf')
    }
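
    # Main training loop: one progress bar per epoch; validation runs and a
    # checkpoint is written at the end of every epoch.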
    for epoch in range(initial_epoch, config['train']['num_epochs']):
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output, attn = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            label = batch['label'].to(device)

            # Flatten to (batch * seq_len, vocab_size) vs. (batch * seq_len,) for cross-entropy.
            loss = loss_fn(proj_output.view(-1, tgt_tokenizer.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({'loss': f'{loss.item():6.3f}'})

            writer.add_scalar('train_loss', loss.item(), global_step)
            writer.flush()

            global_step += 1
            if global_step % patience == 0:
                if loss.item() > best_state['loss']:
                    # Loss regressed: roll back to the best snapshot and skip this update step.
                    model.load_state_dict(best_state['model_state_dict'])
                    optimizer.load_state_dict(best_state['optimizer_state_dict'])
                    scheduler.load_state_dict(best_state['scheduler_state_dict'])
                    continue
                else:
                    # Loss improved: refresh the snapshot, remembering the loss it achieved.
                    best_state = {
                        'model_state_dict': copy.deepcopy(model.state_dict()),
                        'scheduler_state_dict': copy.deepcopy(scheduler.state_dict()),
                        'optimizer_state_dict': copy.deepcopy(optimizer.state_dict()),
                        'loss': loss.item()
                    }

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        run_validation(model, val_dataloader, src_tokenizer, tgt_tokenizer, device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Persist the best snapshot (not necessarily the current weights) once per epoch.
        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': best_state['model_state_dict'],
            'scheduler_state_dict': best_state['scheduler_state_dict'],
            'optimizer_state_dict': best_state['optimizer_state_dict'],
            'global_step': global_step,
        }, model_filename)

        # On Colab, zip the TensorBoard run directory so it can be downloaded.
        if config['train']['on_colab']:
            runs_zip_filename = f'runs_epoch_{epoch}.zip'
            os.system(f"zip -r {runs_zip_filename} /content/silver-spoon/{config['experiment_name']}")


if __name__ == '__main__':
    config = load_config(file_name='/config/config_final.yaml')
    train_model(config)