File size: 3,814 Bytes
f9d7028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import functools
import glob
import os
import sys

from google.cloud import translate
from tqdm import tqdm

# Expects a JSON file containing the API credentials (api_key.json) next to
# this script. setdefault() leaves a GOOGLE_APPLICATION_CREDENTIALS value the
# caller has already exported untouched instead of silently overriding it, and
# abspath() resolves the path once at import time so a later change of working
# directory does not invalidate it.
os.environ.setdefault(
    "GOOGLE_APPLICATION_CREDENTIALS",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), r"api_key.json"),
)

# FLORES-200 language code -> language code passed to the Translation API.
# Meitei is deliberately stored as "mni_Mtei" (with an underscore): the rest of
# the script filters languages on that exact spelling and converts it to the
# API's "mni-Mtei" only when building a request.
flores_to_iso = dict(
    asm_Beng="as",
    ben_Beng="bn",
    doi_Deva="doi",
    eng_Latn="en",
    gom_Deva="gom",
    guj_Gujr="gu",
    hin_Deva="hi",
    kan_Knda="kn",
    mai_Deva="mai",
    mal_Mlym="ml",
    mar_Deva="mr",
    mni_Mtei="mni_Mtei",
    npi_Deva="ne",
    ory_Orya="or",
    pan_Guru="pa",
    san_Deva="sa",
    sat_Olck="sat",
    snd_Arab="sd",
    tam_Taml="ta",
    tel_Telu="te",
    urd_Arab="ur",
)


# Copy the project id from the json file containing API credentials
def translate_text(text, src_lang, tgt_lang, project_id="project_id"):

    src_lang = flores_to_iso[src_lang]
    tgt_lang = flores_to_iso[tgt_lang]

    if src_lang == "mni_Mtei":
        src_lang = "mni-Mtei"

    if tgt_lang == "mni_Mtei":
        tgt_lang = "mni-Mtei"

    client = translate.TranslationServiceClient()

    location = "global"

    parent = f"projects/{project_id}/locations/{location}"

    response = client.translate_text(
        request={
            "parent": parent,
            "contents": [text],
            "mime_type": "text/plain",  # mime types: text/plain, text/html
            "source_language_code": src_lang,
            "target_language_code": tgt_lang,
        }
    )

    translated_text = ""
    for translation in response.translations:
        translated_text += translation.translated_text

    return translated_text


# API language codes (values of flores_to_iso) of the languages to evaluate.
# A set makes membership exact; `in` on a space-separated string would be a
# substring test.
_EVAL_LANGS = frozenset(
    "as bn doi gom gu hi kn mai ml mni_Mtei mr ne or pa sa sd ta te ur".split()
)


def _read_sentences(path):
    """Return the lines of a UTF-8 text file with surrounding whitespace stripped.

    Completely empty lines are skipped; a whitespace-only line is kept and
    becomes "".
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f.read().split("\n") if line]


def _translate_file(in_path, out_path, src_lang, tgt_lang):
    """Translate in_path sentence by sentence and write the results to out_path.

    Does nothing when in_path does not exist or out_path already exists, so a
    rerun skips directions that were already produced.

    Args:
        in_path: UTF-8 file with one source sentence per line.
        out_path: destination; one translation per line, no trailing newline.
        src_lang: FLORES-200 code of the language of in_path.
        tgt_lang: FLORES-200 code to translate into.
    """
    if not os.path.exists(in_path) or os.path.exists(out_path):
        return
    sentences = _read_sentences(in_path)
    translations = [
        translate_text(sent, src_lang, tgt_lang).strip() for sent in tqdm(sentences)
    ]
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(translations))


def _split_pair(dirname):
    """Split a "<src>-<tgt>" directory name into two FLORES-200 codes.

    Returns:
        (src_lang, tgt_lang), or None when the name does not contain exactly
        one "-" or either side is not a key of flores_to_iso.
    """
    src_lang, sep, tgt_lang = dirname.partition("-")
    if not sep or "-" in tgt_lang:
        return None
    if src_lang not in flores_to_iso or tgt_lang not in flores_to_iso:
        return None
    return src_lang, tgt_lang


def main(root_dir):
    """Write Google Translate predictions for each "<src>-<tgt>" entry of root_dir.

    For every pair whose non-English side (or source side, when neither is
    English) is in _EVAL_LANGS, test.<src> is translated into
    test.<tgt>.pred.google and test.<tgt> into test.<src>.pred.google.
    """
    for pair in sorted(glob.glob(os.path.join(root_dir, "*"))):
        print(pair)

        langs = _split_pair(os.path.basename(pair))
        if langs is None:
            continue
        src_lang, tgt_lang = langs

        # With English as the source, filter on the target; otherwise on the source.
        lang = flores_to_iso[tgt_lang if src_lang == "eng_Latn" else src_lang]
        if lang not in _EVAL_LANGS:
            continue

        print(f"{src_lang} - {tgt_lang}")

        # source -> target
        _translate_file(
            os.path.join(pair, f"test.{src_lang}"),
            os.path.join(pair, f"test.{tgt_lang}.pred.google"),
            src_lang,
            tgt_lang,
        )
        # target -> source
        _translate_file(
            os.path.join(pair, f"test.{tgt_lang}"),
            os.path.join(pair, f"test.{src_lang}.pred.google"),
            tgt_lang,
            src_lang,
        )


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} ROOT_DIR")
    main(sys.argv[1])