import os
import re
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# FLORES-200 language codes supported for evaluation
langs_supported = [
    "asm_Beng",
    "ben_Beng",
    "guj_Gujr",
    "eng_Latn",
    "hin_Deva",
    "kas_Deva",
    "kas_Arab",
    "kan_Knda",
    "mal_Mlym",
    "mai_Deva",
    "mar_Deva",
    "mni_Beng",
    "npi_Deva",
    "ory_Orya",
    "pan_Guru",
    "san_Deva",
    "snd_Arab",
    "sat_Olck",
    "tam_Taml",
    "tel_Telu",
    "urd_Arab",
]


@torch.inference_mode()
def predict(batch, tokenizer, model, bos_token_id):
    # tokenize the batch, generate translations with beam search, and decode them
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    generated_tokens = model.generate(
        **encoded_batch,
        num_beams=5,
        max_length=256,
        min_length=0,
        forced_bos_token_id=bos_token_id,
    )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis


def main(devtest_data_dir, batch_size):
    # load the pre-trained NLLB MoE tokenizer and model
    model_name = "facebook/nllb-moe-54b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()

    # iterate over the language pairs found in `devtest_data_dir`
    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        # check if the source and target languages are supported
        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

        # -------------------------------------------------------------------
        # source to target evaluation
        # -------------------------------------------------------------------
        print(f"Evaluating {src_lang}-{tgt_lang} ...")

        infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
        outfname = os.path.join(
            devtest_data_dir, pair, f"test.{tgt_lang}.pred.nllb_moe"
        )

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        # set the source language for tokenization
        tokenizer.src_lang = src_lang

        # process sentences in batches and generate predictions
        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            # NLLB uses `sat_Beng` for Santali, whereas FLORES uses `sat_Olck`
            if tgt_lang == "sat_Olck":
                bos_token_id = tokenizer.lang_code_to_id["sat_Beng"]
            else:
                bos_token_id = tokenizer.lang_code_to_id[tgt_lang]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # collapse stray whitespace so each prediction stays on a single line
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))
            if add_new_line:
                f.write("\n")

        # -------------------------------------------------------------------
        # target to source evaluation
        # -------------------------------------------------------------------
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(
            devtest_data_dir, pair, f"test.{src_lang}.pred.nllb_moe"
        )

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        # set the source language for tokenization
        tokenizer.src_lang = "sat_Beng" if tgt_lang == "sat_Olck" else tgt_lang

        # process sentences in batches and generate predictions
        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            bos_token_id = tokenizer.lang_code_to_id[
                "sat_Beng" if src_lang == "sat_Olck" else src_lang
            ]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # collapse stray whitespace so each prediction stays on a single line
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))
            if add_new_line:
                f.write("\n")


if __name__ == "__main__":
    # expects FLORES-style En-X pair subdirectories within the devtest data directory
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    main(devtest_data_dir, batch_size)
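
# Example invocation (the script filename below is an assumption; adjust it to
# wherever this file lives). The devtest directory is expected to contain pair
# subdirectories such as eng_Latn-hin_Deva, each holding test.<src_lang> and
# test.<tgt_lang> files, one sentence per line:
#
#   python nllb_moe_inference.py /path/to/devtest 32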