import logging

import ctranslate2
import torch
from transformers import AutoTokenizer, MarianMTModel

from colorize import align_words

# Configure the *root* logger so that messages from this module (and from any
# library logger that propagates to the root) are appended to app.log.
logger = logging.getLogger()
logger.setLevel(logging.INFO)  # switch to logging.DEBUG to capture all levels
file_handler = logging.FileHandler('app.log', mode='a')  # 'a' appends to the file
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# The CTranslate2 translators produce the translations; the Hugging Face
# models are only run afterwards (with output_attentions=True) so their
# attention weights can drive the word-alignment colouring in align_words.
model_to_ar = MarianMTModel.from_pretrained("./en_ar/", output_attentions=True)
model_from_ar = MarianMTModel.from_pretrained("./ar_en/", output_attentions=True)
model_to_ar_ct2 = ctranslate2.Translator("./en_ar_ct2/")
model_from_ar_ct2 = ctranslate2.Translator("./ar_en_ct2/")
tokenizer_to_ar = AutoTokenizer.from_pretrained("./en_ar/")
tokenizer_from_ar = AutoTokenizer.from_pretrained("./ar_en/")
print("Done loading models")

# UI dialect labels (English and Hebrew) -> single-letter code that
# run_translate prepends to English input before translating to Arabic.
dialect_map = {
    "Palestinian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
    "פלסטיני": "P",
    "סורי": "S",
    "לבנוני": "L",
    "מצרי": "E",
}


def translate(text, ct_model, hf_model, tokenizer, to_arabic=True, threshold=None, layer=2, head=6):
    """Translate *text* and build an HTML word-alignment view of the result.

    The translation comes from the CTranslate2 model. The Hugging Face model
    is then run once, teacher-forced on that translation, and its output
    (including attentions) is handed to ``align_words`` to colour aligned
    source/target words.

    Args:
        text: Source sentence.
        ct_model: CTranslate2 translator for this direction.
        hf_model: Matching MarianMTModel loaded with ``output_attentions=True``.
        tokenizer: Tokenizer shared by both models.
        to_arabic: True when translating into Arabic; forwarded to
            ``align_words`` as ``skip_first_src``.
        threshold: Alignment threshold for ``align_words``. ``None`` picks a
            value from the source length; an explicit 0 is now honoured.
        layer: Attention layer index forwarded to ``align_words``.
        head: Attention head index forwarded to ``align_words``.

    Returns:
        ``(html, arabic)``: ``html`` contains the source and target markup
        from ``align_words``; ``arabic`` is the translation if it is mostly
        Arabic, otherwise the input text.
    """
    logger.info("Translating: %s", text)

    inp_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
    out_tokens = ct_model.translate_batch([inp_tokens])[0].hypotheses[0]
    out_string = tokenizer.convert_tokens_to_string(out_tokens)

    encoder_input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(inp_tokens)).unsqueeze(0)

    # Teacher-force the HF decoder on the CTranslate2 hypothesis: it must start
    # with the decoder start token and end with EOS (empty-string tokens would
    # map to the unknown-token id instead).
    start_id = hf_model.config.decoder_start_token_id
    if start_id is None:
        start_id = tokenizer.pad_token_id
    decoder_ids = [start_id] + tokenizer.convert_tokens_to_ids(out_tokens) + [tokenizer.eos_token_id]
    decoder_input_ids = torch.tensor(decoder_ids).unsqueeze(0)

    # Only the forward outputs/attentions are needed, so skip autograd.
    with torch.no_grad():
        colorization_output = hf_model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

    if threshold is None:
        # NOTE(review): short (<10 tokens) and long (>=20 tokens) inputs both
        # get 0.05 while medium inputs get 0.10 -- confirm this is intended.
        threshold = 0.10 if 10 <= len(inp_tokens) < 20 else 0.05

    srchtml, tgthtml = align_words(
        colorization_output,
        tokenizer,
        encoder_input_ids,
        decoder_input_ids,
        threshold,
        skip_first_src=to_arabic,
        skip_second_src=False,
        layer=layer,
        head=head,
    )
    # NOTE(review): this template appears to have lost its HTML wrapper markup
    # (only the newlines survived) -- confirm the intended layout.
    html = f"\n{srchtml}\n\n{tgthtml}\n"
    arabic = out_string if is_arabic(out_string) else text
    return html, arabic


#%%
def is_arabic(text):
    """Return True if more than 50% of the non-whitespace characters are Arabic.

    "Arabic" means the Unicode Arabic block U+0600..U+06FF. Empty or
    whitespace-only text returns False rather than dividing by zero.
    """
    chars = "".join(text.split())
    if not chars:
        return False
    arabic_chars = sum(1 for c in chars if "\u0600" <= c <= "\u06FF")
    return arabic_chars / len(chars) > 0.5


def run_translate(text, dialect=None):
    """Translate *text* in whichever direction its script implies.

    Mostly-Arabic input is translated to English. Anything else is translated
    to Arabic, optionally prefixed with a dialect code: a label found in
    ``dialect_map`` is mapped to its code, and any other truthy value is
    prepended as-is.

    Returns:
        ``(html, arabic)`` as produced by ``translate``, or ``("", "")`` for
        empty/whitespace-only input so callers always get a 2-tuple.
    """
    if not text or not text.strip():
        return "", ""

    if is_arabic(text):
        return translate(text, model_from_ar_ct2, model_from_ar, tokenizer_from_ar,
                         to_arabic=False, threshold=None, layer=2, head=7)

    dialect = dialect_map.get(dialect, dialect)
    if dialect:
        text = f"{dialect} {text}"
    return translate(text, model_to_ar_ct2, model_to_ar, tokenizer_to_ar,
                     to_arabic=True, threshold=None, layer=2, head=7)