|
from torch.utils.data import DataLoader |
|
from bleu import calculate_bleu_score |
|
from load_dataset import load_local_bleu_dataset |
|
from dataset import BilingualDataset |
|
from config import load_config |
|
from load_and_save_model import load_model_tokenizer |
|
|
|
|
|
def get_bleu_of_model(config) -> float:
    """Return the BLEU score of the model described by ``config``.

    Loads the model and both tokenizers via ``load_model_tokenizer``, reads
    the local BLEU corpus named by ``config['dataset']['bleu_dataset']``
    (one file per language under ``datasets/``, suffixed with the language
    code), wraps it in a ``BilingualDataset`` and hands a single-example
    ``DataLoader`` to ``calculate_bleu_score``.

    Args:
        config: Mapping with a ``'dataset'`` section providing
            ``bleu_dataset``, ``src_lang``, ``tgt_lang``,
            ``src_max_seq_len`` and ``tgt_max_seq_len``.

    Returns:
        Whatever ``calculate_bleu_score`` returns for the loaded model.
    """
    model, src_tokenizer, tgt_tokenizer = load_model_tokenizer(config)

    dataset_cfg = config['dataset']
    # Both corpus files share the same stem and differ only in the
    # language-code suffix.
    path_stem = 'datasets/' + dataset_cfg['bleu_dataset'] + '.'
    src_lang = dataset_cfg['src_lang']
    tgt_lang = dataset_cfg['tgt_lang']

    raw_pairs = load_local_bleu_dataset(
        src_dataset_filename=path_stem + src_lang,
        tgt_dataset_filename=path_stem + tgt_lang,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
    )

    eval_dataset = BilingualDataset(
        ds=raw_pairs,
        src_tokenizer=src_tokenizer,
        tgt_tokenizer=tgt_tokenizer,
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        src_max_seq_len=dataset_cfg['src_max_seq_len'],
        tgt_max_seq_len=dataset_cfg['tgt_max_seq_len'],
    )

    # NOTE(review): shuffle=True randomizes the evaluation order; presumably
    # this does not affect the final score -- confirm against
    # calculate_bleu_score.
    eval_loader = DataLoader(eval_dataset, batch_size=1, shuffle=True)

    return calculate_bleu_score(
        model,
        eval_loader,
        src_tokenizer,
        tgt_tokenizer,
    )
|
|
|
if __name__ == '__main__':
    # Evaluate each configuration in turn and print its BLEU score.
    # A tuple (rather than a set) is used so the configs are evaluated and
    # reported in the same order on every run; iteration order of a set of
    # strings changes between interpreter runs due to hash randomization.
    config_files = (
        'config_final.yaml',
        'config_huge.yaml',
        'config_big.yaml',
        'config_small.yaml',
    )
    for file_name in config_files:
        config = load_config(file_name)
        print(get_bleu_of_model(config), f" is the BLEU of {file_name}", sep='')