import os
import re
import sys
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

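# FLORES-200 language codes covered by this evaluation: English plus the
# Indic languages supported by NLLB.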
langs_supported = [
    "asm_Beng",
    "ben_Beng",
    "guj_Gujr",
    "eng_Latn",
    "hin_Deva",
    "kas_Deva",
    "kas_Arab",
    "kan_Knda",
    "mal_Mlym",
    "mai_Deva",
    "mar_Deva",
    "mni_Beng",
    "npi_Deva",
    "ory_Orya",
    "pan_Guru",
    "san_Deva",
    "snd_Arab",
    "sat_Olck",
    "tam_Taml",
    "tel_Telu",
    "urd_Arab",
]


def predict(batch, tokenizer, model, bos_token_id):
    """Translate a batch of sentences, forcing generation in the target language."""
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    # disable gradient tracking during generation
    with torch.inference_mode():
        generated_tokens = model.generate(
            **encoded_batch,
            num_beams=5,
            max_length=256,
            min_length=0,
            forced_bos_token_id=bos_token_id,
        )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis


def main(devtest_data_dir, batch_size):
    # load the pre-trained NLLB tokenizer and model
    model_name = "facebook/nllb-moe-54b"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    model.eval()
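    # NOTE (assumption): nllb-moe-54b is far larger than a single commodity GPU can hold;
    # loading with from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    # (requires the `accelerate` package) is a common way to shard it across devices.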

    # iterate over the language-pair subdirectories in `devtest_data_dir`
    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        # check if the source and target languages are supported
        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

        # -------------------------------------------------------------------
        #                   source to target evaluation
        # -------------------------------------------------------------------
        print(f"Evaluating {src_lang}-{tgt_lang} ...")

        infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
        outfname = os.path.join(
            devtest_data_dir, pair, f"test.{tgt_lang}.pred.nllb_moe"
        )

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        # set the source language for tokenization
        tokenizer.src_lang = src_lang

        # NLLB expects the sat_Beng code for Santali, so map sat_Olck accordingly
        if tgt_lang == "sat_Olck":
            bos_token_id = tokenizer.lang_code_to_id["sat_Beng"]
        else:
            bos_token_id = tokenizer.lang_code_to_id[tgt_lang]

        # process sentences in batches and generate predictions
        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        # restore the trailing newline if the source file ended with one
        if add_new_line:
            hypothesis.append("")

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))

        # -------------------------------------------------------------------
        #                   target to source evaluation
        # -------------------------------------------------------------------
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(
            devtest_data_dir, pair, f"test.{src_lang}.pred.nllb_moe"
        )

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        # set the source language for tokenization
        tokenizer.src_lang = "sat_Beng" if tgt_lang == "sat_Olck" else tgt_lang

        # generation now targets the original source language (typically eng_Latn)
        if src_lang == "sat_Olck":
            bos_token_id = tokenizer.lang_code_to_id["sat_Beng"]
        else:
            bos_token_id = tokenizer.lang_code_to_id[src_lang]

        # process sentences in batches and generate predictions
        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        # restore the trailing newline if the source file ended with one
        if add_new_line:
            hypothesis.append("")

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))


if __name__ == "__main__":
    # expects En-X language-pair subdirectories within the devtest data directory
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    main(devtest_data_dir, batch_size)
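
# Example invocation (the script name below is illustrative):
#   python nllb_moe_inference.py /path/to/devtest 8
# where /path/to/devtest contains pair subdirectories such as eng_Latn-hin_Deva/
# holding test.eng_Latn and test.hin_Deva files.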