# homemade_lo_vi/validate.py
import torch
from torch.utils.data import DataLoader

from tokenizers import Tokenizer

from transformer import Transformer
from bleu import calculate_bleu_score
from decode_method import greedy_decode


def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    print_msg,
    global_state,
    writer,
    num_examples: int = 2,
):
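    """Run greedy decoding on a few validation batches and print the results.

    Collects the source, reference, and model-predicted sentences, prints
    `num_examples` of them through `print_msg`, then logs a BLEU score if a
    TensorBoard `writer` is provided.
    """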
    model.eval()

    # Decode a handful of validation examples for qualitative inspection.
    count = 0
    source_texts = []
    expected = []
    predicted = []
    console_width = 50  # width of the separator line between printed examples
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            # Greedy decoding, capped at 300 generated target tokens.
            model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            print_msg("-" * console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break
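
    # A hedged sketch: score the sampled predictions with corpus-level BLEU and
    # log it to TensorBoard. The local `calculate_bleu_score` helper is assumed
    # here to take (predictions, references) and return a float; adjust the
    # call if bleu.py defines a different signature.
    if writer is not None:
        bleu = calculate_bleu_score(predicted, expected)
        writer.add_scalar("validation/BLEU", bleu, global_state)
        writer.flush()


# Typical call site, sketched; `val_dataloader`, `src_tok`, `tgt_tok`, `step`,
# and `tb_writer` are assumed to come from the training script:
#
#     run_validation(
#         model, val_dataloader, src_tok, tgt_tok, device,
#         print_msg=print, global_state=step, writer=tb_writer,
#     )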