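"""Evaluate mBART-50 baselines on FLORES-style devtest data.

For each language-pair directory under ``devtest_data_dir`` (e.g.
``eng_Latn-hin_Deva``; English is assumed to be on the source side of
the pair name), the test set is translated in both directions: en->xx
with the one-to-many model and xx->en with the many-to-one model.
Hypotheses are written next to the references as ``test.<lang>.pred.mbart50``.
"""
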
import os
import re
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

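# FLORES-style language codes mapped to the corresponding mBART-50 codes.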
langs_supported = {
    "eng_Latn": "en_XX",
    "guj_Gujr": "gu_IN",
    "hin_Deva": "hi_IN",
    "npi_Deva": "ne_NP",
    "ben_Beng": "bn_IN",
    "mal_Mlym": "ml_IN",
    "mar_Deva": "mr_IN",
    "tam_Taml": "ta_IN",
    "tel_Telu": "te_IN",
    "urd_Arab": "ur_PK",
}


def predict(batch, tokenizer, model, bos_token_id):
    """Translate a batch of sentences and return the decoded hypotheses."""
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    # Inference only: disable gradient tracking to save memory.
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded_batch,
            num_beams=5,
            max_length=256,
            min_length=0,
            forced_bos_token_id=bos_token_id,
        )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis

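
# A minimal sketch of calling predict() directly, assuming the en->xx model
# loaded in main() below (the sentence and target language are illustrative):
#
#   tok = AutoTokenizer.from_pretrained("facebook/mbart-large-50-one-to-many-mmt")
#   mdl = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50-one-to-many-mmt").cuda()
#   tok.src_lang = "en_XX"
#   predict(["How are you?"], tok, mdl, tok.lang_code_to_id["hi_IN"])

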
def main(devtest_data_dir, batch_size):
    # Load the en->xx (one-to-many) and xx->en (many-to-one) mBART-50 models.
    enxx_model_name = "facebook/mbart-large-50-one-to-many-mmt"
    xxen_model_name = "facebook/mbart-large-50-many-to-one-mmt"
    tokenizers = {
        "enxx": AutoTokenizer.from_pretrained(enxx_model_name),
        "xxen": AutoTokenizer.from_pretrained(xxen_model_name),
    }
    models = {
        "enxx": AutoModelForSeq2SeqLM.from_pretrained(enxx_model_name).cuda(),
        "xxen": AutoModelForSeq2SeqLM.from_pretrained(xxen_model_name).cuda(),
    }

    for model_name in models:
        models[model_name].eval()

    # Each subdirectory names a language pair, e.g. "eng_Latn-hin_Deva".
    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        # Skip pairs involving languages that mBART-50 does not cover.
        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

        print(f"Evaluating {src_lang}-{tgt_lang} ...")

        # Forward direction: translate test.{src_lang} with the en->xx model
        # (pair directories are expected to have English on the source side).
        infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.mbart50")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        # Track whether the input file ends with a newline so the output
        # file can preserve the same convention.
        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizers["enxx"].src_lang = langs_supported[src_lang]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            # Force the decoder to start with the target-language code.
            bos_token_id = tokenizers["enxx"].lang_code_to_id[langs_supported[tgt_lang]]
            hypothesis += predict(batch, tokenizers["enxx"], models["enxx"], bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # Normalize whitespace so each hypothesis occupies exactly one line.
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis += [""]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))

        # Reverse direction: translate test.{tgt_lang} back with the xx->en model.
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.mbart50")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizers["xxen"].src_lang = langs_supported[tgt_lang]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            bos_token_id = tokenizers["xxen"].lang_code_to_id[langs_supported[src_lang]]
            hypothesis += predict(batch, tokenizers["xxen"], models["xxen"], bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis += [""]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))


if __name__ == "__main__":
    # Usage: python <script> <devtest_data_dir> <batch_size>
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    if not torch.cuda.is_available():
        print("No GPU available")
        sys.exit(1)

    main(devtest_data_dir, batch_size)
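
# Example invocation (the script and directory names are illustrative; the
# directory should contain pair subdirectories such as eng_Latn-hin_Deva):
#
#   python3 mbart50_eval.py devtest_data/ 32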