File size: 3,647 Bytes
46f657a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from transformers import MarianMTModel, AutoTokenizer
import ctranslate2
from colorize import align_words
import logging

# Append formatted log records to app.log.
# Note: logging.getLogger() with no name returns the *root* logger, so the level
# and handler configured here apply to every logger that propagates to the root,
# not just to this module.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

file_handler = logging.FileHandler('app.log', mode='a')  # 'a': keep earlier runs' logs
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)

logger = logging.getLogger()
logger.setLevel(logging.INFO)  # INFO and above are recorded; DEBUG is discarded
logger.addHandler(file_handler)

# Each translation direction is served by two models:
#   * a CTranslate2 translator (directories ending in "_ct2"), used by
#     translate() to decode the output tokens, and
#   * a Hugging Face Marian model loaded with attention outputs enabled, which
#     translate() re-runs on the decoded tokens to obtain attention weights
#     for word alignment.
# Presumably the *_ct2 directories are CTranslate2 conversions of the same
# checkpoints — TODO confirm.
model_to_ar, model_from_ar = (
    MarianMTModel.from_pretrained(path, output_attentions=True)
    for path in ("./he_ar/", "./ar_he/")
)
model_to_ar_ct2, model_from_ar_ct2 = (
    ctranslate2.Translator(path)
    for path in ("./he_ar_ct2/", "./ar_he_ct2/")
)

# Tokenizers come from the Hugging Face model directories and are shared by
# both model flavours of the same direction.
tokenizer_to_ar, tokenizer_from_ar = (
    AutoTokenizer.from_pretrained(path)
    for path in ("./he_ar/", "./ar_he/")
)
print("Done loading models")

# Dialect name -> one-letter code. run_translate() prepends the code to Hebrew
# input before translating it into Arabic. Names are accepted in English or Hebrew.
dialect_map = dict([
    # English names
    ("Palestinian", "P"),
    ("Syrian", "S"),
    ("Lebanese", "L"),
    ("Egyptian", "E"),
    # Hebrew names, in the same order: Palestinian, Syrian, Lebanese, Egyptian
    ("פלסטיני", "P"),
    ("סורי", "S"),
    ("לבנוני", "L"),
    ("מצרי", "E"),
])


def translate(text, ct_model, hf_model, tokenizer, to_arabic=True,
              threshold=None, layer=2, head=6):
    """Translate *text* and render an HTML word-alignment visualisation.

    Decoding is done with the CTranslate2 model. The Hugging Face model is then
    run once, teacher-forced on the decoded tokens, purely to obtain attention
    weights, which ``align_words`` turns into coloured source/target HTML.

    Args:
        text: Source sentence to translate.
        ct_model: CTranslate2 translator used to decode the translation.
        hf_model: Hugging Face seq2seq model used only for its attention
            outputs (the module loads it with ``output_attentions=True``).
        tokenizer: Tokenizer used for both the CTranslate2 input tokens and
            the Hugging Face input ids.
        to_arabic: Forwarded to ``align_words`` as ``skip_first_src``;
            presumably so the dialect-code token that ``run_translate``
            prepends is not coloured — TODO confirm.
        threshold: Alignment threshold forwarded to ``align_words``. When
            ``None``, a default is chosen from the input token count.
        layer: Attention layer index forwarded to ``align_words``.
        head: Attention head index forwarded to ``align_words``.

    Returns:
        ``(html, arabic)``:
            html: A right-to-left ``<div>`` holding the source markup, two
                line breaks, then the target markup.
            arabic: The decoded output if it is detected as Arabic, otherwise
                the input *text*.
    """
    logger.info("Translating: %s", text)

    # CTranslate2 consumes token strings, so go ids -> tokens first.
    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    # Teacher-force the HF model on exactly the CTranslate2 output, so its
    # attentions describe the translation the user sees. Marian models use
    # <pad> as the decoder start token; </s> closes the sequence.
    encoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(inp_tokens)
    ).unsqueeze(0)
    decoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(["<pad>"] + out_tokens + ['</s>'])
    ).unsqueeze(0)

    # Only the attention weights are needed; skip building the autograd graph.
    with torch.no_grad():
        colorization_output = hf_model(input_ids=encoder_input_ids,
                                       decoder_input_ids=decoder_input_ids)

    # `is None` (rather than falsiness) so an explicit threshold of 0 is honoured.
    # Default: 0.10 for inputs of 10-19 tokens, 0.05 otherwise.
    if threshold is None:
        threshold = 0.10 if 10 <= len(inp_tokens) < 20 else 0.05

    srchtml, tgthtml = align_words(colorization_output,
                                   tokenizer,
                                   encoder_input_ids,
                                   decoder_input_ids,
                                   threshold,
                                   skip_first_src=to_arabic,
                                   skip_second_src=False,
                                   layer=layer,
                                   head=head)

    html = f"<div style='direction: rtl'>{srchtml}<br><br>{tgthtml}</div>"

    # Report whichever side of the pair is Arabic.
    arabic = out_string if is_arabic(out_string) else text
    return html, arabic


#%%


def is_arabic(text):
    """Return True if more than half of *text*'s non-whitespace characters are Arabic.

    A character counts as Arabic when it lies in the basic Arabic Unicode
    block, U+0600..U+06FF. All whitespace (spaces, newlines, tabs, ...) is
    ignored, so line breaks do not dilute the ratio. Empty or whitespace-only
    text returns False instead of raising ZeroDivisionError.
    """
    chars = [c for c in text if not c.isspace()]
    if not chars:
        return False
    arabic_chars = sum("\u0600" <= c <= "\u06FF" for c in chars)
    return arabic_chars / len(chars) > 0.5

def run_translate(text, dialect=None):
    """Translate between Hebrew and Arabic, choosing the direction automatically.

    Text detected as Arabic is translated into Hebrew. Anything else is
    treated as Hebrew and translated into Arabic. In that case the text may be
    prefixed with a dialect code first.

    Args:
        text: Text to translate.
        dialect: Optional dialect for Hebrew->Arabic translation. A name found
            in ``dialect_map`` (English or Hebrew) is replaced by its code; any
            other truthy value is prepended as-is. Ignored for Arabic input.

    Returns:
        The ``(html, arabic)`` pair produced by ``translate``. Empty, ``None``
        or whitespace-only *text* yields ``("", "")``, the same 2-tuple shape,
        so callers can always unpack the result.
    """
    # Nothing to translate. Whitespace-only input is treated as empty too.
    if not text or not text.strip():
        return "", ""

    if is_arabic(text):
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=1)

    # Known dialect names map to their code; unknown values pass through.
    dialect = dialect_map.get(dialect, dialect)
    if dialect:
        text = f"{dialect} {text}"
    return translate(text, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=6)