import torch
from torch.utils.data import DataLoader
from torchtext.data.metrics import bleu_score
from tqdm import tqdm

from decode_method import beam_search_decode
from transformer import Transformer
from tokenizers import Tokenizer


def calculate_bleu_score(
    model: Transformer,
    bleu_dataloader: DataLoader,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    device: torch.device = torch.device('cpu'),
    num_samples: int = 9999999,
) -> float:
    """Compute the corpus-level BLEU score (0-100) of `model` on `bleu_dataloader`.

    Each source sentence is decoded with beam search and the hypothesis is
    compared against the reference translation. The dataloader must yield
    batches of size 1; at most `num_samples` sentence pairs are evaluated.
    """
    model.eval()

    # Run inference over the evaluation set.
    count = 0
    expected = []   # reference translations, one list of tokens per sentence
    predicted = []  # model hypotheses, one list of tokens per sentence
    with torch.no_grad():
        batch_iterator = tqdm(bleu_dataloader)
        for batch in batch_iterator:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, "batch_size must be 1 for BLEU calculation"

            # Beam-search decode with a maximum output length of 300 tokens.
            model_out = beam_search_decode(
                model, encoder_input, encoder_mask,
                src_tokenizer, tgt_tokenizer, 300, device,
            )

            target_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            # bleu_score expects token lists; each reference is wrapped in an
            # extra list because a hypothesis may have multiple references.
            expected.append([target_text.split()])
            predicted.append(model_out_text.split())

            if count == num_samples:
                break

    # torchtext returns BLEU in [0, 1]; scale to the conventional 0-100 range.
    return bleu_score(predicted, expected) * 100.0
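

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this function might be wired up. The file paths
# and the model/dataloader construction below are hypothetical placeholders,
# not part of this repository's confirmed API; swap in the project's actual
# config, dataset, and checkpoint loading.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical tokenizer files previously saved with `Tokenizer.save(...)`.
    src_tokenizer = Tokenizer.from_file('tokenizer_src.json')
    tgt_tokenizer = Tokenizer.from_file('tokenizer_tgt.json')

    # Construct the model and load trained weights (project-specific).
    # model = Transformer(...)
    # model.load_state_dict(torch.load('model_checkpoint.pt', map_location=device))
    # model.to(device)

    # The dataloader must yield batch_size=1 dicts with 'encoder_input',
    # 'encoder_mask', and 'tgt_text' keys, matching the loop above.
    # bleu_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # score = calculate_bleu_score(model, bleu_dataloader, src_tokenizer,
    #                              tgt_tokenizer, device=device, num_samples=100)
    # print(f'BLEU = {score:.2f}')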