# Spaces: Running
import html
import os
import tempfile

import azure.cognitiveservices.speech as speechsdk
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from transformers import pipeline, MarianMTModel, AutoTokenizer
# Single-letter tag for each supported dialect. translate_english prepends
# the tag to the English input before tokenising it.
dialects = {
    "Palestinian/Jordanian": "P",
    "Syrian": "S",
    "Lebanese": "L",
    "Egyptian": "E",
}

# The translation models are loaded directly (rather than through a
# `pipeline`) with output_attentions=True so that align_words can read the
# cross-attention weights from a forward pass.
translator_en2ar = MarianMTModel.from_pretrained(
    "guymorlan/English2Dialect",
    output_attentions=True,
)
tokenizer_en2ar = AutoTokenizer.from_pretrained("guymorlan/English2Dialect")

translator_ar2en = MarianMTModel.from_pretrained(
    "guymorlan/Shami2English",
    output_attentions=True,
)
tokenizer_ar2en = AutoTokenizer.from_pretrained("guymorlan/Shami2English")

# Transliteration only needs the generated text, so a pipeline is enough.
transliterator = pipeline(
    task="translation",
    model="guymorlan/DialectTransliterator",
)

# Azure Speech credentials are read from the environment.
speech_config = speechsdk.SpeechConfig(
    subscription=os.environ.get('SPEECH_KEY'),
    region=os.environ.get('SPEECH_REGION'),
)
def generate_diverging_colors(num_colors, palette='Set3'):  # courtesy of ChatGPT
    """Return ``num_colors`` hex colour codes sampled from a Matplotlib colormap.

    Args:
        num_colors: Number of colours to generate.
        palette: Name of the Matplotlib colormap to sample.

    Returns:
        A list of six-digit lowercase hex strings without a leading ``#``.
        The list is empty when ``num_colors`` is zero or negative.
    """
    # Nothing to sample; also avoids asking Matplotlib for a zero-entry map.
    if num_colors <= 0:
        return []

    # pyplot.get_cmap takes the same (name, lut) arguments as the
    # matplotlib.cm.get_cmap function used previously, which was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap(palette, num_colors)

    # Integer indices pick the individual entries of the resampled colormap.
    colors_rgb = cmap(np.arange(num_colors))

    # Pack the red, green and blue channels into one 24-bit value and format
    # it as hex; any further channel in each row is ignored.
    return [
        format(
            int(color[0] * 255) << 16 | int(color[1] * 255) << 8 | int(color[2] * 255),
            '06x',
        )
        for color in colors_rgb
    ]
# Special tokens that take no part in alignment and are not rendered.
_SPECIAL_TOKENS = frozenset(("</s>", "<pad>", "<unk>"))

# Marker found at the start of a token piece that begins a new word; it is
# rendered as a space. Written as an escape (U+2581) so that the literal
# survives re-encoding of this file.
_WORD_START = "\u2581"


def _group_alignment(attention, tokenizer, decoder_ids, threshold):
    """Group decoder positions into words with the encoder positions they attend to.

    Returns a list of ``[decoder_positions, encoder_positions]`` pairs.
    """
    groups = []
    for dec_pos, row in enumerate(attention):
        # Encoder positions whose attention weight exceeds the threshold.
        enc_positions = (row > threshold).nonzero().squeeze(-1).tolist()
        piece = tokenizer.convert_ids_to_tokens([int(decoder_ids[dec_pos])])[0]
        # Compare token *strings*; comparing against token ids never matches.
        if piece in _SPECIAL_TOKENS:
            continue

        merge = False
        if groups:
            # A piece without the word-start marker continues the previous
            # word; a piece sharing an attended encoder position with the
            # previous group is treated as part of the same unit.
            continues_word = not piece.startswith(_WORD_START)
            shares_source = any(pos in groups[-1][1] for pos in enc_positions)
            merge = continues_word or shares_source

        if merge:
            groups[-1][0].append(dec_pos)
            groups[-1][1].extend(enc_positions)
        else:
            groups.append([[dec_pos], enc_positions])
    return groups


def _assign_colors(groups):
    """Map ``("dec", pos)`` / ``("enc", pos)`` keys to hex colours, one colour per group."""
    color_ids = {}
    n_colors = 0
    for dec_positions, enc_positions in groups:
        keys = [("dec", pos) for pos in dec_positions]
        keys += [("enc", pos) for pos in enc_positions]
        # Reuse the colour of any position that was already coloured.
        color_id = next((color_ids[key] for key in keys if key in color_ids), None)
        # `is None` matters: colour id 0 is valid but falsy.
        if color_id is None:
            color_id = n_colors
            n_colors += 1
        for key in keys:
            color_ids.setdefault(key, color_id)

    palette = generate_diverging_colors(n_colors, palette="Set2") if n_colors else []
    return {key: palette[color_id] for key, color_id in color_ids.items()}


def _render_tokens(tokenizer, token_ids, colors, side, skip_first=False):
    """Render token pieces as ``<span>`` elements inside a 30px wrapper span."""
    spans = []
    for pos, token_id in enumerate(token_ids):
        if skip_first and pos == 0:
            continue
        piece = tokenizer.convert_ids_to_tokens([int(token_id)])[0]
        # Special tokens look like HTML tags; drop them instead of printing
        # them now that the token text is escaped.
        if piece in _SPECIAL_TOKENS:
            continue
        # Token text derives from user input, so escape it before embedding.
        text = html.escape(piece)
        color = colors.get((side, pos))
        if color is not None:
            spans.append(f"<span style='color: #{color}'>{text}</span>")
        else:
            # NOTE(review): `--color-text-body` is not wrapped in var(), so
            # this is not a valid colour value and the span presumably just
            # inherits its colour. Kept unchanged to preserve rendering.
            spans.append(f"<span style='color: --color-text-body'>{text}</span>")
    body = "".join(spans).replace(_WORD_START, " ")
    return f"<span style='font-size: 30px'>{body}</span>"


def align_words(outputs, tokenizer, encoder_input_ids, decoder_input_ids, threshold=0.4, skip_first_src=True):
    """Colour-code source and target tokens that attend to each other.

    Args:
        outputs: Model forward-pass output; ``cross_attentions[2][0][7]`` is
            read as a matrix with one row per decoder position (presumably
            layer 2, batch item 0, head 7 - TODO confirm).
        tokenizer: Tokenizer used to turn ids back into token pieces.
        encoder_input_ids: Source ids; only row 0 is used.
        decoder_input_ids: Target ids; only row 0 is used.
        threshold: Attention weight above which a source position counts as
            aligned with a target position.
        skip_first_src: When true, the first source token is not rendered.

    Returns:
        ``(srchtml, tgthtml)``: HTML strings for the source and the target.
    """
    attention = outputs.cross_attentions[2][0][7]
    groups = _group_alignment(attention, tokenizer, decoder_input_ids[0], threshold)
    colors = _assign_colors(groups)

    tgthtml = _render_tokens(tokenizer, decoder_input_ids[0], colors, "dec")
    srchtml = _render_tokens(
        tokenizer, encoder_input_ids[0], colors, "enc", skip_first=skip_first_src
    )
    return srchtml, tgthtml
def translate_english(input_text, include):
    """Translate English text into the requested Arabic dialects.

    Args:
        input_text: English text to translate.
        include: Collection of enabled feature flags. ``"SYR"``, ``"LEB"``
            and ``"EGY"`` request the extra dialects; ``"Colorize"`` requests
            the attention-coloured HTML view of the Palestinian output.

    Returns:
        ``(palhtml, pal_out, sy_out, lb_out, eg_out)``. Dialects that were
        not requested are empty strings; all five are empty strings when
        ``input_text`` is empty.
    """
    if not input_text:
        return "", "", "", "", ""

    # Palestinian is always translated; other dialects only when requested.
    # Selecting by name keeps every output paired with its own input. The
    # previous pop()-from-the-end approach dropped the wrong input whenever
    # a later dialect was enabled while an earlier one was disabled.
    wanted = ["Palestinian/Jordanian"]
    for flag, name in (("SYR", "Syrian"), ("LEB", "Lebanese"), ("EGY", "Egyptian")):
        if flag in include:
            wanted.append(name)
    inputs = [f"{dialects[name]} {input_text}" for name in wanted]

    input_tokens = tokenizer_en2ar(inputs, return_tensors="pt").input_ids
    outputs = translator_en2ar.generate(input_tokens)

    decoded = tokenizer_en2ar.batch_decode(outputs, skip_special_tokens=True)
    by_dialect = dict(zip(wanted, decoded))
    pal_out = by_dialect["Palestinian/Jordanian"]
    sy_out = by_dialect.get("Syrian", "")
    lb_out = by_dialect.get("Lebanese", "")
    eg_out = by_dialect.get("Egyptian", "")

    if "Colorize" in include:
        # Only the first (Palestinian) row is colourised. A forward pass over
        # the generated sequence provides the cross-attentions.
        encoder_input_ids = input_tokens[0].unsqueeze(0)
        decoder_input_ids = outputs[0].unsqueeze(0)
        html_outputs = translator_en2ar(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

        # Dynamic threshold: longer inputs use a lower attention cut-off.
        if input_tokens.shape[1] < 10:
            threshold = 0.4
        elif input_tokens.shape[1] < 20:
            threshold = 0.10
        else:
            threshold = 0.05
        print("threshold", threshold)

        srchtml, tgthtml = align_words(html_outputs, tokenizer_en2ar, encoder_input_ids, decoder_input_ids, threshold)
        palhtml = f"{srchtml}<br><br><div style='direction: rtl'>{tgthtml}</div>"
    else:
        # Escape the model output before embedding it in HTML.
        palhtml = f"<div style='font-size: 30px; direction: rtl'>{html.escape(pal_out)}</div>"

    return palhtml, pal_out, sy_out, lb_out, eg_out
def translate_arabic(input_text, include=None):
    """Translate Levantine Arabic text into English and return it as HTML.

    Args:
        input_text: Arabic text to translate.
        include: Collection of enabled feature flags; defaults to
            ``["Colorize"]``. With ``"Colorize"`` the result is the
            attention-coloured source/target view, otherwise the plain
            translation.

    Returns:
        An HTML string, or ``""`` when ``input_text`` is empty.
    """
    # None stands in for the default to avoid a shared mutable default list.
    if include is None:
        include = ["Colorize"]

    if not input_text:
        return ""

    input_tokens = tokenizer_ar2en(input_text, return_tensors="pt").input_ids
    outputs = translator_ar2en.generate(input_tokens)

    encoder_input_ids = input_tokens[0].unsqueeze(0)
    decoder_input_ids = outputs[0].unsqueeze(0)

    print(include)
    if "Colorize" in include:
        # A forward pass over the generated sequence provides the
        # cross-attentions that align_words reads.
        html_outputs = translator_ar2en(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

        # Dynamic threshold based on input length.
        # NOTE(review): the middle band (0.01) is lower than the band for
        # the longest inputs (0.05); confirm this is intentional.
        if input_tokens.shape[1] < 20:
            threshold = 0.1
        elif input_tokens.shape[1] < 30:
            threshold = 0.01
        else:
            threshold = 0.05
        print("threshold", threshold)

        srchtml, tgthtml = align_words(html_outputs, tokenizer_ar2en, encoder_input_ids, decoder_input_ids, threshold, skip_first_src=False)
        enhtml = f"<div style='direction: rtl'>{srchtml}</div><br><br><div>{tgthtml}</div>"
    else:
        # Decode with the tokenizer that belongs to the Arabic->English
        # model (the English->Arabic tokenizer was used here before).
        decoded = tokenizer_ar2en.batch_decode(outputs, skip_special_tokens=True)
        # Escape the model output before embedding it in HTML.
        enhtml = f"<div style='font-size: 30px;'>{html.escape(decoded[0])}</div>"

    return enhtml
def get_audio(input_text):
    """Synthesise ``input_text`` with Azure text-to-speech and return a WAV path.

    Args:
        input_text: Text to speak.

    Returns:
        Path of the generated ``.wav`` file, or ``None`` when there is no
        text or the synthesis did not complete.
    """
    if not input_text:
        return None

    # Use a unique temporary file instead of naming the file after the user
    # text, which may contain path separators or be too long for a filename.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name

    audio_config = speechsdk.audio.AudioOutputConfig(filename=wav_path)
    speech_config.speech_synthesis_voice_name = 'ar-SY-AmanyNeural'
    speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
    speech_synthesis_result = speech_synthesizer.speak_text_async(input_text).get()

    # Do not hand back a file when the service did not finish synthesising.
    if speech_synthesis_result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
        return None
    return wav_path
def get_transliteration(input_text, include=None):
    """Transliterate Arabic text with the transliteration pipeline.

    Args:
        input_text: Arabic text to transliterate.
        include: Collection of enabled feature flags; defaults to
            ``["Translit."]``. Transliteration runs only when
            ``"Translit."`` is present.

    Returns:
        The transliterated text, or ``""`` when transliteration is disabled
        or there is no text.
    """
    # None stands in for the default to avoid a shared mutable default list.
    if include is None:
        include = ["Translit."]

    # Skip the model call when the feature is off or there is nothing to do
    # (this handler also fires when the source textbox is cleared).
    if "Translit." not in include or not input_text:
        return ""

    result = transliterator([input_text])
    return result[0]["translation_text"]
# Placeholder string; not referenced anywhere else in the visible file.
bla = """
"""

# Custom stylesheet passed to gr.Blocks(css=...).
# - #liter / #trans target the textboxes created with those elem_ids
#   (larger font; right-to-left text for #trans).
# - The :root and .dark rules override button colour variables.
# NOTE(review): :root sets --button-secondary-background-focus twice; the
# later rgb(...) declaration is the one that takes effect.
css = """
#liter textarea, #trans textarea { font-size: 25px;}
#trans textarea { direction: rtl; }
#check { border-style: none !important; }
:root {--button-secondary-background-focus: #2563eb !important;
--button-secondary-background-base: #2563eb !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2);
--button-secondary-text-color-base: white !important;
--button-secondary-text-color-hover: white !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-text-color-focus: white !important}
.dark {--button-secondary-background-base: #2563eb !important;
--button-secondary-background-focus: rgb(51 122 216 / 70%) !important;
--button-secondary-background-hover: linear-gradient(to bottom right, #0692e8, #5859c2)}
.feather-music { stroke: #2563eb; }
"""
def toggle_visibility(include):
    """Build visibility updates for the four optional output textboxes.

    The updates are returned in the order transliteration, Syrian, Lebanese,
    Egyptian; each textbox is visible exactly when its flag is in ``include``.
    """
    flags = ("Translit.", "SYR", "LEB", "EGY")
    return [gr.Textbox.update(visible=(flag in include)) for flag in flags]
# Gradio UI: three tabs (English -> Arabic, Arabic -> English, transliteration).
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of the file that had lost its indentation - confirm the layout.
# NOTE(review): the Arabic example sentences were restored from mis-encoded
# text - confirm them against the original file.
with gr.Blocks(title="Levantine Arabic Translator", css=css, theme="default") as demo:
    gr.HTML("<h2><span style='color: #2563eb'>Levantine Arabic</span> Translator</h2>")

    with gr.Tab('En > Ar'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter English text", lines=2)
                gr.Examples(["I wanted to go to the store yesterday, but it rained", "How are you feeling today?"], input_text)
                btn = gr.Button("Translate", label="Translate")
                with gr.Row():
                    include = gr.CheckboxGroup(["Translit.", "SYR", "LEB", "EGY", "Colorize"],
                                               label="Disable features to speed up translation",
                                               value=["Translit.", "EGY", "Colorize"])
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il). Pronunciation model is specifically tailored to urban Palestinian Arabic. Text-to-speech uses Microsoft Azure's API and may provide different result from the transliterated pronunciation.")

            with gr.Column():
                with gr.Box(label="Palestinian"):
                    gr.Markdown("Palestinian")
                    with gr.Box():
                        pal_html = gr.HTML("<br>", visible=True, label="Palestinian", elem_id="main")
                    # Hidden textbox holding the plain Palestinian text; it
                    # feeds the transliteration and the audio handlers.
                    pal = gr.Textbox(lines=1, label="Palestinian", elem_id="trans", visible=False)
                    pal_translit = gr.Textbox(lines=1, label="Palestinian Pronunciation (Urban)", elem_id="liter")
                sy = gr.Textbox(lines=1, label="Syrian", elem_id="trans", visible=False)
                lb = gr.Textbox(lines=1, label="Lebanese", elem_id="trans", visible=False)
                eg = gr.Textbox(lines=1, label="Egyptian", elem_id="trans")
                audio = gr.Audio(label="Audio - Palestinian", interactive=False)
                audio_button = gr.Button("Get Audio", label="Click Here to Get Audio")

        audio_button.click(get_audio, inputs=[pal], outputs=[audio])
        # The _js hook scrolls the result into view and passes the inputs through.
        btn.click(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], _js="function jump(x, y){document.getElementById('main').scrollIntoView(); return [x, y];}")
        input_text.submit(translate_english, inputs=[input_text, include], outputs=[pal_html, pal, sy, lb, eg], scroll_to_output=True)
        # Transliteration is refreshed whenever the hidden Palestinian text changes.
        pal.change(get_transliteration, inputs=[pal, include], outputs=[pal_translit])
        include.change(toggle_visibility, inputs=[include], outputs=[pal_translit, sy, lb, eg])

    with gr.Tab('Ar > En'):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1, elem_id="trans")
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Translate", label="Translate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il).")

            with gr.Column():
                with gr.Box(label="English"):
                    gr.Markdown("English")
                    with gr.Box():
                        eng = gr.HTML("<br>", label="English", elem_id="main")

        # Only the text is passed, so translate_arabic uses its default flags.
        btn.click(translate_arabic, inputs=input_text, outputs=[eng])

    with gr.Tab("Transliterate"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(label="Input", placeholder="Enter Levantine Arabic text", lines=1)
                gr.Examples(["خلينا ندور على مطعم تاني", "قديش حق البندورة؟"], input_text)
                btn = gr.Button("Transliterate", label="Transliterate")
                gr.Markdown("Built by [Guy Mor-Lan](mailto:guy.mor@mail.huji.ac.il)")

            with gr.Column():
                translit = gr.Textbox(label="Transliteration", lines=1, elem_id="liter")

        btn.click(get_transliteration, inputs=input_text, outputs=[translit])

demo.launch()