import os
from pathlib import Path

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torchmetrics
from tqdm import tqdm
from tokenizers import Tokenizer
from tokenizers.models import WordLevel, BPE
from tokenizers.trainers import WordLevelTrainer, BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# from datasets import load_dataset
from load_dataset import load_local_dataset
from transformer import get_model, Transformer
from config import load_config, get_weights_file_path
from dataset import BilingualDataset
from bleu import calculate_bleu_score
from decode_method import greedy_decode


def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device,
    print_msg,
    global_state,
    writer,
    num_examples: int = 2,
):
    """Greedy-decode a few validation sentences and print source, target, and prediction."""
    model.eval()

    # Run inference on up to `num_examples` validation sentences,
    # accumulating the texts as we go
    count = 0
    source_texts = []
    expected = []
    predicted = []
    console_width = 50
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            # Greedy decoding handles one sentence at a time
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            # Decode, capping the output at 300 generated tokens
            model_out = greedy_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Show source, reference, and prediction side by side
            print_msg("-" * console_width)
            print_msg(f"SOURCE: {source_text}")
            print_msg(f"TARGET: {target_text}")
            print_msg(f"PREDICTED: {model_out_text}")

            if count == num_examples:
                break
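
    # The function accepts `writer` (a TensorBoard SummaryWriter) and
    # `global_state`, but the loop above never uses them. Below is a minimal
    # sketch of how the collected `predicted`/`expected` pairs could be scored
    # and logged, assuming torchmetrics' BLEUScore API (preds: list of
    # candidate strings; target: one list of reference strings per candidate)
    # and assuming `global_state` is the current training step. The repo's own
    # `calculate_bleu_score` may expect a different signature, so the
    # already-imported torchmetrics is used here instead.
    if writer:
        bleu_metric = torchmetrics.BLEUScore()
        # Each candidate gets a single reference here
        bleu = bleu_metric(predicted, [[ref] for ref in expected])
        writer.add_scalar('validation/BLEU', bleu, global_state)
        writer.flush()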