import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score
from tqdm import tqdm

from decode_method import beam_search_decode
from transformer import Transformer


def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device: torch.device = torch.device('cpu'),
    num_samples: int = 9999999,  # effectively "evaluate every sample"
) -> float:
    """Compute the corpus-level BLEU score of `model` on `bleu_dataloader`.

    Each source sentence is decoded with beam search, the detokenized output
    is compared against the reference `tgt_text`, and the BLEU score is
    returned as a percentage. Evaluation stops early once `num_samples`
    sentences have been processed.
    """
    model.eval()

    # Run inference sentence by sentence, collecting references and hypotheses.
    count = 0
    expected = []
    predicted = []

    with torch.no_grad():
        batch_iterator = tqdm(bleu_dataloader)
        for batch in batch_iterator:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for BLEU calculation"

            # Decode with beam search, capping the output at 300 tokens.
            model_out = beam_search_decode(model, encoder_input, encoder_mask, src_tokenizer, tgt_tokenizer, 300, device)

            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().tolist())

            # torchtext's bleu_score expects tokenized corpora: each reference
            # is wrapped in a list to support multiple references per sentence.
            expected.append([target_text.split()])
            predicted.append(model_out_text.split())

            if count == num_samples:
                break

    return bleu_score(predicted, expected) * 100.0
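

# --- Usage sketch ---
# A minimal, hypothetical example of calling calculate_bleu_score; the
# checkpoint and tokenizer paths and the dataset below are placeholders,
# not part of this repository.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   src_tokenizer = Tokenizer.from_file('tokenizer_src.json')  # hypothetical path
#   tgt_tokenizer = Tokenizer.from_file('tokenizer_tgt.json')  # hypothetical path
#   model = Transformer(...).to(device)                        # constructed as in training
#   model.load_state_dict(torch.load('model.pt', map_location=device))
#   bleu_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
#   bleu = calculate_bleu_score(model, bleu_dataloader, src_tokenizer,
#                               tgt_tokenizer, device, num_samples=500)
#   print(f'BLEU: {bleu:.2f}')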