import torch
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score

from tqdm import tqdm

from tokenizers import Tokenizer

from decode_method import beam_search_decode
from transformer import Transformer


def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device: torch.device = torch.device('cpu'),
    num_samples: int = 9999999,  # effectively "all samples" by default
):
    """Compute the corpus-level BLEU score of `model` on `bleu_dataloader`.

    Each source sentence is decoded with beam search and compared against its
    reference translation; the score is returned as a percentage in [0, 100].
    Evaluation stops early once `num_samples` sentences have been processed.
    """
    model.eval()

    count = 0
    expected = []   # one list of reference token lists per sentence
    predicted = []  # one hypothesis token list per sentence

    with torch.no_grad():
        batch_iterator = tqdm(bleu_dataloader)
        for batch in batch_iterator:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "batch size must be 1 for BLEU calculation"

            # Beam-search decode, capping the hypothesis length at 300 tokens.
            model_out = beam_search_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            # torchtext's bleu_score expects tokenized text: each hypothesis is
            # a list of tokens, and each reference entry is a list of reference
            # token lists (a single reference per sentence here).
            expected.append([target_text.split()])
            predicted.append(model_out_text.split())

            if count == num_samples:
                break

    # bleu_score returns a value in [0, 1]; scale it to a percentage.
    return bleu_score(predicted, expected) * 100.0
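

# Usage sketch (illustrative; not part of the original module). `build_model`
# and `get_bleu_dataloader` below are hypothetical stand-ins for this
# project's own setup code, and the tokenizer file paths are assumptions.
# The dataloader must yield batch-size-1 dicts with 'encoder_input',
# 'encoder_mask', and 'tgt_text' keys, as `calculate_bleu_score` expects.
#
# if __name__ == '__main__':
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     src_tokenizer = Tokenizer.from_file('tokenizer_src.json')  # assumed path
#     tgt_tokenizer = Tokenizer.from_file('tokenizer_tgt.json')  # assumed path
#     model = build_model().to(device)                 # hypothetical helper
#     dataloader = get_bleu_dataloader(batch_size=1)   # hypothetical helper
#     bleu = calculate_bleu_score(model, dataloader, src_tokenizer,
#                                 tgt_tokenizer, device, num_samples=500)
#     print(f'BLEU: {bleu:.2f}')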