import os

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode


def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    print_msg,
    global_state,
    writer,
    num_examples: int = 2,
):
    """Greedy-decode a few validation batches and print source/target/prediction triples."""
    model.eval()  # inference mode: disables dropout and other training-only behaviour

    count = 0
    source_texts = []
    expected = []
    predicted = []

    console_width = 50

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "batch size must be 1 for validation"

            # Decode up to 300 target tokens with greedy search
            model_out = greedy_decode(
                model, encoder_input, encoder_mask,
                src_tokenizer, tgt_tokenizer, 300, device,
            )

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Route output through print_msg so it does not break tqdm progress bars
            print_msg("-" * console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break
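
# --- Hedged sketch (not in the original file): logging validation metrics ---
# The collected `predicted`/`expected` lists and the otherwise unused
# `writer`/`global_state` arguments suggest a metric-logging step after the
# loop. A minimal sketch, assuming `writer` is the SummaryWriter imported
# above, `global_state` is the current training step, and
# `calculate_bleu_score(predicted, expected)` is the signature of the
# imported helper (an assumption, not confirmed by this file):
#
#     if writer:
#         # Character error rate over the decoded validation samples
#         cer = torchmetrics.text.CharErrorRate()(predicted, expected)
#         writer.add_scalar('validation/cer', cer, global_state)
#
#         # Corpus-level BLEU via the repo's helper (assumed signature)
#         bleu = calculate_bleu_score(predicted, expected)
#         writer.add_scalar('validation/BLEU', bleu, global_state)
#         writer.flush()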