Spaces:
Running
Running
import torch | |
from transformers import MarianMTModel, AutoTokenizer | |
import ctranslate2 | |
from colorize import align_words | |
import logging | |
# Configure the root logger so that INFO-and-above records from this module
# (and any library that propagates to the root) are appended to app.log.
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # records below INFO are dropped
# The logged text is Hebrew/Arabic, so pin the file encoding to UTF-8 instead
# of relying on the platform default, which may be unable to encode it.
file_handler = logging.FileHandler('app.log', mode='a', encoding='utf-8')  # 'a' appends to the file
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# Resources for translating into Arabic (checkpoints under ./he_ar*).
# The Hugging Face model is loaded with attentions enabled; the CTranslate2
# translator is a separate object built from its own converted directory.
tokenizer_to_ar = AutoTokenizer.from_pretrained("./he_ar/")
model_to_ar = MarianMTModel.from_pretrained("./he_ar/", output_attentions=True)
model_to_ar_ct2 = ctranslate2.Translator("./he_ar_ct2/")

# Resources for translating from Arabic (checkpoints under ./ar_he*).
tokenizer_from_ar = AutoTokenizer.from_pretrained("./ar_he/")
model_from_ar = MarianMTModel.from_pretrained("./ar_he/", output_attentions=True)
model_from_ar_ct2 = ctranslate2.Translator("./ar_he_ct2/")

print("Done loading models")
# Dialect display names mapped to the one-letter tag that run_translate
# prepends to the source text. Both English and Hebrew names are accepted.
_english_dialects = {
    "Palestinian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
}
_hebrew_dialects = {
    "פלסטיני": "P",
    "סורי": "S",
    "לבנוני": "L",
    "מצרי": "E",
}
# English names first, then Hebrew, preserving the original key order.
dialect_map = {**_english_dialects, **_hebrew_dialects}
def translate(text, ct_model, hf_model, tokenizer, to_arabic=True,
              threshold=None, layer=2, head=6):
    """Translate ``text`` and build an HTML view of the source/target alignment.

    The translation is decoded by ``ct_model``. ``hf_model`` is then run once
    over the source tokens and the generated target tokens, and its output is
    handed to ``align_words`` to produce the colorized HTML for both sides.

    Args:
        text: Source sentence to translate.
        ct_model: Translator whose ``translate_batch`` does the decoding.
        hf_model: Model called with ``input_ids`` and ``decoder_input_ids``;
            its output is passed to ``align_words``.
        tokenizer: Tokenizer shared by both models.
        to_arabic: Forwarded to ``align_words`` as ``skip_first_src``
            (presumably to skip the dialect tag prepended to the source;
            confirm against ``align_words``).
        threshold: Alignment threshold forwarded to ``align_words``. Any
            falsy value (``None`` or ``0``) selects a default based on the
            number of source tokens.
        layer: Forwarded to ``align_words``.
        head: Forwarded to ``align_words``.

    Returns:
        A ``(html, arabic)`` tuple: the right-to-left HTML snippet, and the
        translation if it is mostly Arabic, otherwise the original ``text``.
    """
    logger.info("Translating: %s", text)

    # Decode with the CTranslate2 model, which works on token strings.
    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    # Re-encode source and target as batches of one for the forward pass.
    # The decoder input is the generated target wrapped in <pad> ... </s>.
    encoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(inp_tokens)).unsqueeze(0)
    decoder_input_ids = torch.tensor(
        tokenizer.convert_tokens_to_ids(["<pad>"] + out_tokens + ['</s>'])
    ).unsqueeze(0)

    # Inference only: nothing is back-propagated, so skip autograd tracking.
    with torch.no_grad():
        colorization_output = hf_model(input_ids=encoder_input_ids,
                                       decoder_input_ids=decoder_input_ids)

    if not threshold:
        # Only inputs of 10-19 tokens use 0.10; shorter and longer use 0.05.
        # NOTE(review): <10 and >=20 sharing 0.05 mirrors the original
        # branches; confirm that this is intended.
        threshold = 0.10 if 10 <= len(inp_tokens) < 20 else 0.05

    srchtml, tgthtml = align_words(colorization_output,
                                   tokenizer,
                                   encoder_input_ids,
                                   decoder_input_ids,
                                   threshold,
                                   skip_first_src=to_arabic,
                                   skip_second_src=False,
                                   layer=layer,
                                   head=head)
    html = f"<div style='direction: rtl'>{srchtml}<br><br>{tgthtml}</div>"

    # A blank translation is never Arabic; short-circuit so is_arabic is not
    # asked to classify an empty string.
    arabic = out_string if out_string.strip() and is_arabic(out_string) else text
    return html, arabic
#%% | |
def is_arabic(text):
    """Return True when more than half of the non-space characters are Arabic.

    A character counts as Arabic when it falls in the range U+0600-U+06FF.
    Only the ASCII space is ignored; other whitespace counts as a non-Arabic
    character. Empty or all-space input returns False instead of dividing
    by zero.
    """
    stripped = text.replace(" ", "")
    if not stripped:
        return False
    arabic_chars = sum(1 for c in stripped if "\u0600" <= c <= "\u06FF")
    return arabic_chars / len(stripped) > 0.5
def run_translate(text, dialect=None):
    """Detect the direction of ``text`` and translate it.

    Mostly-Arabic input is translated with the from-Arabic models. Any other
    input is translated to Arabic, optionally prefixed with a dialect tag.

    Args:
        text: Sentence to translate.
        dialect: Dialect name (a key of ``dialect_map``) or a raw tag. A
            name found in ``dialect_map`` is replaced by its tag; any other
            truthy value is prepended unchanged. Ignored for Arabic input.

    Returns:
        ``""`` for empty or whitespace-only input, otherwise the
        ``(html, arabic)`` tuple returned by ``translate``.
    """
    # Blank input has nothing to translate. Checking whitespace here also
    # keeps an all-space string from reaching the language detection.
    if not text or not text.strip():
        return ""

    if is_arabic(text):
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=1)

    # Map a known dialect name to its tag; leave unknown values as given.
    dialect = dialect_map.get(dialect, dialect)
    if dialect:
        text = f"{dialect} {text}"
    return translate(text, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=6)