File size: 1,366 Bytes
b8a6dde |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score
from tqdm import tqdm
from decode_method import beam_search_decode
from transformer import Transformer
from tokenizers import Tokenizer
def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device: torch.device = torch.device('cpu'),
    num_samples: int = 9999999,
    max_len: int = 300,
):
    """Compute the corpus BLEU score (scaled to 0-100) of `model` on a dataloader.

    Decodes each sample with beam search and compares the decoded text
    against the reference translation stored in the batch.

    Args:
        model: trained Transformer used for decoding.
        bleu_dataloader: yields dicts containing 'encoder_input',
            'encoder_mask' and 'tgt_text'; batch size must be 1.
        src_tokenizer: tokenizer for the source language.
        tgt_tokenizer: tokenizer used to decode model output tokens.
        device: device to run inference on (default: CPU).
        num_samples: stop after this many samples (default is large
            enough to mean "all samples").
        max_len: maximum decoded sequence length passed to beam search
            (default 300, the previously hard-coded value).

    Returns:
        Corpus BLEU score as a float in [0, 100].

    Raises:
        ValueError: if a batch has batch size other than 1.
    """
    model.eval()
    expected = []   # one list of reference token lists per sample
    predicted = []  # one hypothesis token list per sample
    with torch.no_grad():
        # enumerate replaces the manual counter; start=1 so num_samples
        # means "number of samples processed", as before.
        for count, batch in enumerate(tqdm(bleu_dataloader), start=1):
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            # Real validation, not `assert` (asserts vanish under -O).
            if encoder_input.size(0) != 1:
                raise ValueError("batch_size = 1 for bleu calculation")
            model_out = beam_search_decode(
                model, encoder_input, encoder_mask,
                src_tokenizer, tgt_tokenizer, max_len, device,
            )
            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())
            # bleu_score expects, per candidate, a LIST of references.
            expected.append([target_text.split()])
            predicted.append(model_out_text.split())
            if count == num_samples:
                break
    return bleu_score(predicted, expected) * 100.0
|