Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,7 +10,8 @@ dialects = {"Palestinian/Jordanian": "P", "Syrian": "S", "Lebanese": "L", "Egypt
|
|
10 |
# translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
|
11 |
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
|
12 |
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
|
13 |
-
translator_ar2en =
|
|
|
14 |
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
|
15 |
|
16 |
speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
|
@@ -28,7 +29,7 @@ def generate_diverging_colors(num_colors, palette='Set3'): # courtesy of ChatGPT
|
|
28 |
return colors_hex
|
29 |
|
30 |
|
31 |
-
def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4):
|
32 |
alignment = []
|
33 |
for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
|
34 |
alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
|
@@ -93,7 +94,7 @@ def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, thresh
|
|
93 |
|
94 |
srchtml = []
|
95 |
for i, token in enumerate(encoder_input_ids[0]):
|
96 |
-
if i == 0:
|
97 |
continue
|
98 |
if f"trg_{i}" in colordict:
|
99 |
label = f"trg_{i}"
|
@@ -158,13 +159,42 @@ def translate_english(input_text, include):
|
|
158 |
|
159 |
return palhtml, pal_out, sy_out, lb_out, eg_out
|
160 |
|
161 |
-
def translate_arabic(input_text):
|
162 |
if not input_text:
|
163 |
return ""
|
164 |
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
def get_audio(input_text):
|
170 |
audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
|
@@ -244,6 +274,7 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
|
|
244 |
input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
|
245 |
pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
|
246 |
include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
|
|
|
247 |
with gr.Tab('Ar > En'):
|
248 |
with gr.Row():
|
249 |
with gr.Column():
|
@@ -252,8 +283,12 @@ with gr.Blocks(title = "Levantine Arabic Translator", css=css, theme="default")
|
|
252 |
btn = gr.Button("Translate", label="Translate")
|
253 |
gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
|
254 |
with gr.Column():
|
255 |
-
|
|
|
|
|
|
|
256 |
btn.click(translate_arabic,inputs=input_text, outputs=[eng])
|
|
|
257 |
with gr.Tab("Transliterate"):
|
258 |
with gr.Row():
|
259 |
with gr.Column():
|
|
|
10 |
# translator_en2ar = pipeline(task="translation", model="guymorlan/English2Dialect")
|
11 |
translator_en2ar = MarianMTModel.from_pretrained("guymorlan/English2Dialect", output_attentions=True)
|
12 |
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")
|
13 |
+
translator_ar2en = MarianMTModel.from_pretrained("guymorlan/Shami2English", output_attentions=True)
|
14 |
+
tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")
|
15 |
transliterator = pipeline(task="translation", model="guymorlan/DialectTransliterator")
|
16 |
|
17 |
speech_config = speechsdk.SpeechConfig(subscription=os.environ.get('SPEECH_KEY'), region=os.environ.get('SPEECH_REGION'))
|
|
|
29 |
return colors_hex
|
30 |
|
31 |
|
32 |
+
def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4, skip_first_src=True):
|
33 |
alignment = []
|
34 |
for i, tok in enumerate(outputs.cross_attentions[2][0][7]):
|
35 |
alignment.append([[i], (tok > threshold).nonzero().squeeze(-1).tolist()])
|
|
|
94 |
|
95 |
srchtml = []
|
96 |
for i, token in enumerate(encoder_input_ids[0]):
|
97 |
+
if skip_first_src and i == 0:
|
98 |
continue
|
99 |
if f"trg_{i}" in colordict:
|
100 |
label = f"trg_{i}"
|
|
|
159 |
|
160 |
return palhtml, pal_out, sy_out, lb_out, eg_out
|
161 |
|
162 |
+
def translate_arabic(input_text, include=["Colorize"]):
|
163 |
if not input_text:
|
164 |
return ""
|
165 |
|
166 |
+
input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
|
167 |
+
# print(input_tokens)
|
168 |
+
outputs = translator_ar2en.generate(input_tokens)
|
169 |
+
# print(outputs)
|
170 |
+
|
171 |
+
encoder_input_ids = input_tokens[0].unsqueeze(0)
|
172 |
+
decoder_input_ids = outputs[0].unsqueeze(0)
|
173 |
|
174 |
+
decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)
|
175 |
+
# print(decoded)
|
176 |
+
|
177 |
+
print(include)
|
178 |
+
if "Colorize" in include:
|
179 |
+
html_outputs = translator_ar2en(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)
|
180 |
+
|
181 |
+
# set dynamic threshold
|
182 |
+
# print(input_tokens, input_tokens.shape)
|
183 |
+
if input_tokens.shape[1] < 20:
|
184 |
+
threshold = 0.1
|
185 |
+
elif input_tokens.shape[1] < 30:
|
186 |
+
threshold = 0.01
|
187 |
+
else:
|
188 |
+
threshold = 0.05
|
189 |
+
|
190 |
+
print("threshold", threshold)
|
191 |
+
|
192 |
+
srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids, decoder_input_ids, threshold, skip_first_src=False)
|
193 |
+
enhtml = f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
|
194 |
+
else:
|
195 |
+
enhtml = f"<div style='font-size: 30px;'>{decoded[0]}</div>"
|
196 |
+
|
197 |
+
return enhtml
|
198 |
|
199 |
def get_audio(input_text):
|
200 |
audio_config = speechsdk.audio.AudioOutputConfig(filename=f"{input_text}.wav")
|
|
|
274 |
input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg],scroll_to_output=True)
|
275 |
pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit]);
|
276 |
include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])
|
277 |
+
|
278 |
with gr.Tab('Ar > En'):
|
279 |
with gr.Row():
|
280 |
with gr.Column():
|
|
|
283 |
btn = gr.Button("Translate", label="Translate")
|
284 |
gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")
|
285 |
with gr.Column():
|
286 |
+
with gr.Box(label = "English"):
|
287 |
+
gr.Markdown("English")
|
288 |
+
with gr.Box():
|
289 |
+
eng = gr.HTML("<br>", label="English", elem_id="main")
|
290 |
btn.click(translate_arabic,inputs=input_text, outputs=[eng])
|
291 |
+
|
292 |
with gr.Tab("Transliterate"):
|
293 |
with gr.Row():
|
294 |
with gr.Column():
|