# homemade_lo_vi/bleu.py
import torch
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score
from tqdm import tqdm
from decode_method import beam_search_decode
from transformer import Transformer
from tokenizers import Tokenizer
def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device = torch.device('cpu'),
    num_samples: int = 9999999,
):
    """Compute the corpus-level BLEU score of ``model`` on ``bleu_dataloader``.

    Each batch is decoded with beam search (max length 300) and the decoded
    text is compared against the batch's single reference translation.

    Args:
        model: Trained Transformer to evaluate (switched to eval mode here).
        bleu_dataloader: DataLoader yielding batches with keys
            'encoder_input', 'encoder_mask', and 'tgt_text'; batch size
            must be 1 (asserted below).
        src_tokenizer: Tokenizer for the source language, passed to the decoder.
        tgt_tokenizer: Tokenizer used to decode model output ids to text.
        device: Device to run inference on (default CPU).
        num_samples: Stop after this many sentences; the large default
            effectively means "use the whole dataloader".

    Returns:
        BLEU score as a percentage in [0.0, 100.0].
    """
    model.eval()
    # Inference: collect reference and candidate token lists per sentence.
    expected = []   # references: one [token list] wrapper per sentence
    predicted = []  # candidates: one token list per sentence
    with torch.no_grad():
        batch_iterator = tqdm(bleu_dataloader)
        # enumerate from 1 so `count == num_samples` matches sentences seen.
        for count, batch in enumerate(batch_iterator, start=1):
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            # Beam-search decoding below handles one sentence at a time.
            assert encoder_input.size(0) == 1, "batch_size = 1 for bleu calculation"
            model_out = beam_search_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
            # torchtext's bleu_score expects candidates as token lists and
            # references as a list of token lists per sentence (single
            # reference here, hence the extra nesting on `expected`).
            expected.append([target_text.split()])
            predicted.append(model_out_text.split())
            if count == num_samples:
                break
    return bleu_score(predicted, expected) * 100.0