Spaces:
Running
Running
import os | |
from dotenv import load_dotenv | |
import gradio as gr | |
from AinaTheme import theme | |
from huggingface_hub import snapshot_download | |
import subprocess | |
import os | |
from translate import translate_nos | |
# Pull variables from a local .env file into os.environ before reading settings below.
load_dotenv()

# Local directories for downloaded model snapshots and the Hugging Face cache.
MODELS_PATH = "./models"
HF_CACHE_DIR = "./hf_cache"

# Maximum input length used by the UI; overridable through the environment.
MAX_INPUT_CHARACTERS = int(os.getenv("MAX_INPUT_CHARACTERS", "1000"))

# Source languages that translate() routes to translate_without_subwording.
LANGS_WITHOUT_SUBWORDING = ["English", "Spanish", "Galician"]
# Source languages that translate() routes to translate_with_subwording.
LANGS_WITH_SUBWORDING = ["Catalan", "Basque"]
# Model paths and available languages -----------------------------------------------------------
def download_model(repo_id, revision="main"):
    """Download a snapshot of a Hugging Face Hub repository into MODELS_PATH.

    The snapshot is placed under ``MODELS_PATH/<repo_id>``, and HF_CACHE_DIR is
    used as the hub cache.

    Args:
        repo_id: Hub repository id, e.g. "owner/name".
        revision: Branch, tag or commit to fetch (defaults to "main").

    Returns:
        The value returned by ``snapshot_download`` for the local copy.
    """
    target_dir = os.path.join(MODELS_PATH, repo_id)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision,
        local_dir=target_dir,
        cache_dir=HF_CACHE_DIR,
    )
def write_text_to_file(filename, text):
    """Write *text* to *filename*, replacing any existing content.

    The file is always encoded as UTF-8. The input text is user-supplied and may
    contain accented characters, so the encoding must not depend on the
    platform's default locale.
    """
    with open(filename, "w", encoding="utf-8") as file:
        file.write(text)
# Former download list, kept for reference only (not executed):
#   print("Downloading model gl-es...")
#   model_dir_gl_es = download_model("proxectonos/Nos_MT-OpenNMT-gl-es", revision="main")
#   print("Downloading model es-gl...")
#   model_dir_es_gl = download_model("proxectonos/Nos_MT-OpenNMT-es-gl", revision="main")
#   print("Downloading model gl-en...")
#   model_dir_gl_en = download_model("proxectonos/Nos_MT-OpenNMT-gl-en", revision="main")
#   print("Downloading model en-gl...")
#   model_dir_en_gl = download_model("proxectonos/Nos_MT-OpenNMT-en-gl", revision="main")
#   model_dir_gl_ca = ""
#   print("Downloading model ca-gl...")
#   model_dir_ca_gl = download_model("proxectonos/Nos_MT-OpenNMT-ca-gl", revision="main")

# Single-space placeholders for every model directory; the ones downloaded
# below are overwritten with real paths.
model_dir_gl_es = model_dir_es_gl = model_dir_gl_en = model_dir_en_gl = model_dir_gl_ca = model_dir_gl_eu = " "


def _download_pair(pair):
    """Announce and download the proxectonos OpenNMT model for *pair* (e.g. "ca-gl")."""
    print(f"Downloading model {pair}...")
    return download_model(f"proxectonos/Nos_MT-OpenNMT-{pair}", revision="main")


model_dir_ca_gl = _download_pair("ca-gl")
model_dir_eu_gl = _download_pair("eu-gl")
model_dir_gl_en = _download_pair("gl-en")
model_dir_en_gl = _download_pair("en-gl")
print("Models downloaded correctly!")

# Debug output: the Basque->Galician model directory and its contents.
_eu_gl_path = os.path.join(MODELS_PATH, model_dir_eu_gl)
print(_eu_gl_path)
print(os.listdir(_eu_gl_path))
def _reduced_entry(model_dir, code_file, model_subdir):
    """Build a {"model": (codes_path, model_path)} entry rooted at MODELS_PATH/model_dir."""
    base = os.path.join(MODELS_PATH, model_dir)
    return {"model": (f"{base}/{code_file}", f"{base}/{model_subdir}")}


# Catalan/Basque -> Galician directions only.
# NOTE(review): not referenced anywhere else in this file; `directions` is what the app uses.
directions_reduced = {
    "Catalan": {
        "target": {
            "Galician": _reduced_entry(model_dir_ca_gl, "ca-detok10k.code", "ct2_detok-ca-gl_sint_10k"),
        }
    },
    "Basque": {
        "target": {
            "Galician": _reduced_entry(model_dir_eu_gl, "gl-detok10k.code", "eu_gl.ct2_10k"),
        }
    },
}
def _direction_entry(model_dir, code_file, model_subdir=None, **extra):
    """Build one ``directions`` entry, or return None if the model is not available.

    Args:
        model_dir: Directory returned by download_model, or a blank placeholder
            (such as " ") for a model that was not downloaded.
        code_file: Subword codes file, relative to the model directory.
        model_subdir: Optional sub-path of the translation model inside the model
            directory. When omitted, the model directory itself is used.
        **extra: Additional keys (e.g. "src"/"tgt") placed before "model".

    Returns:
        ``{**extra, "model": (codes_path, model_path)}``, or None when
        *model_dir* is empty or whitespace-only.
    """
    if not model_dir or not model_dir.strip():
        return None
    base = os.path.join(MODELS_PATH, model_dir)
    model_path = f"{base}/{model_subdir}" if model_subdir else f"{base}"
    return {**extra, "model": (f"{base}/{code_file}", model_path)}


def _build_directions():
    """Return {source: {"target": {target: entry}}} for directions that have a model.

    Directions whose model directory is a blank placeholder would point at a
    bogus path (e.g. "./models/ ") and fail at translation time, so they are
    dropped. So is any source language left with no targets.
    """
    candidates = {
        "Galician": {
            "Spanish": _direction_entry(model_dir_gl_es, "bpe/es.code", src="gl", tgt="es"),
            "English": _direction_entry(model_dir_gl_en, "bpe/en.code"),
            "Catalan": _direction_entry(model_dir_gl_ca, "bpe/ca.code"),
            "Basque": _direction_entry(model_dir_gl_eu, "bpe/eu.code"),
        },
        "Spanish": {
            "Galician": _direction_entry(model_dir_es_gl, "bpe/gl.code", src="es", tgt="gl"),
        },
        "English": {
            "Galician": _direction_entry(model_dir_en_gl, "bpe/gl.code"),
        },
        "Catalan": {
            "Galician": _direction_entry(model_dir_ca_gl, "ca-detok10k.code", "ct2_detok-ca-gl_sint_10k"),
        },
        "Basque": {
            "Galician": _direction_entry(model_dir_eu_gl, "gl-detok10k.code", "eu_gl.ct2_10k"),
        },
    }
    available = {}
    for source, targets in candidates.items():
        kept = {target: entry for target, entry in targets.items() if entry is not None}
        if kept:
            available[source] = {"target": kept}
    return available


# Translation directions offered by the UI (only those with a downloaded model).
directions = _build_directions()
# The first available source language is preselected in the UI.
DEFAULT_SOURCE_LANGUAGE = list(directions.keys())[0]
# Translation functions ------------------------------------------------------------------------------
def get_target_languages(source_language):
    """Return the target language names configured for *source_language*.

    An unknown source language yields an empty list.
    """
    source_entry = directions.get(source_language, {})
    targets = source_entry.get("target", {})
    return [name for name in targets]
def get_target_language_model(source_language, target_language):
    """Return the ``directions`` entry for source_language -> target_language.

    The entry is the dict stored under that pair (it carries a "model" tuple).
    An empty dict is returned when the pair is not configured.
    """
    targets = directions.get(source_language, {}).get("target", {})
    return targets.get(target_language, {})
def translate(input, source_language, target_language):
    """Translate *input* from *source_language* into *target_language*.

    Routing:
      * LANGS_WITHOUT_SUBWORDING sources go to translate_without_subwording.
      * LANGS_WITH_SUBWORDING sources go to translate_with_subwording.

    Returns:
        The translated text, or "" when *input* is None or blank.

    Raises:
        Exception: if *source_language* is not supported, or if the stripped
            input is longer than MAX_INPUT_CHARACTERS.
    """
    if source_language in LANGS_WITHOUT_SUBWORDING:  # ES, GL, EN
        backend = translate_without_subwording
    elif source_language in LANGS_WITH_SUBWORDING:  # CA, EU
        backend = translate_with_subwording
    else:
        raise Exception(f"Language {source_language} not supported")

    # Nothing to translate. None is what clear() writes into the input box.
    if input is None or not input.strip():
        return ""
    # The UI only disables the Submit button for over-long text. The same check is
    # enforced here because this function is also exposed as the "translate" API.
    if len(input.strip()) > MAX_INPUT_CHARACTERS:
        raise Exception(f"Input exceeds the maximum length of {MAX_INPUT_CHARACTERS} characters")
    return backend(input, source_language, target_language)
def translate_without_subwording(input, source_language, target_language):
    """Translate *input* by running the ``onmt_translate`` command-line tool.

    How it works:
      1. The text is written to ``input.txt``.
      2. It is translated with the model path stored as the second element of
         the direction's "model" tuple.
      3. The result is read back from ``./output_file.txt``.

    Raises:
        Exception: if no model is configured for the direction, or if
            onmt_translate exits with a non-zero status (message includes stderr).
    """
    target_language_model = get_target_language_model(source_language, target_language)
    model = target_language_model.get("model")
    if not model:
        # Fail clearly instead of crashing with a TypeError on None[1].
        raise Exception(f"No model available for {source_language} -> {target_language}")

    write_text_to_file("input.txt", input)
    command = [
        "onmt_translate",
        "-src", "input.txt",
        "-model", model[1],
        "--output", "./output_file.txt",
        "--replace_unk",
    ]
    print("Comando: ", " ".join(command))
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact instead of having the shell re-split them.
    process = subprocess.run(command, capture_output=True)
    if process.returncode != 0:
        raise Exception(f"Error occurred: {process.stderr.decode().strip()}")

    # NOTE(review): assumes onmt_translate writes its output as UTF-8 — confirm for the installed version.
    with open("./output_file.txt", "r", encoding="utf-8") as f:
        result = f.read()
    return result
def translate_with_subwording(input, source_language, target_language):
    """Translate *input* with translate_nos, using the direction's "model" tuple.

    For the configured Catalan/Basque directions the tuple is
    (subword codes file, model path).

    Raises:
        Exception: if no model is configured for the source/target pair.
    """
    target_language_model = get_target_language_model(source_language, target_language)
    model = target_language_model.get("model")
    if not model:
        # Fail clearly instead of handing None to translate_nos.
        raise Exception(f"No model available for {source_language} -> {target_language}")
    return translate_nos(input, model)
# Gradio UI ------------------------------------------------------------------------------ | |
def clear():
    """Reset the UI: return None for both the input and the output textboxes."""
    empty_input, empty_output = None, None
    return empty_input, empty_output
def change_interactive(text):
    """Return Gradio updates for the (clear, submit) buttons based on *text*.

    The clear button always stays interactive. The submit button is disabled
    when the stripped text is longer than MAX_INPUT_CHARACTERS. A None value
    (which clear() writes into the input box) is treated as empty text rather
    than raising AttributeError.
    """
    too_long = len((text or "").strip()) > MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=not too_long)
def update_target_languages_dropdown(source_language):
    """Refresh the target-language selector after the source language changes.

    Offers the targets configured for *source_language* and preselects the
    first one. If there are none, the selection is cleared (value None)
    instead of raising IndexError.
    """
    output_languages = get_target_languages(source_language)
    default_target = output_languages[0] if output_languages else None
    return gr.update(choices=output_languages, value=default_target, interactive=True)
# Build the two-column translator UI, wire its events, and launch the app.
with gr.Blocks(theme=theme) as app:
    with gr.Row(variant="panel"):
        # Left column: source language, input text and character counter.
        with gr.Column(scale=2):
            # Hidden textbox that holds MAX_INPUT_CHARACTERS so the JS counter
            # handler below can receive the limit as one of its inputs.
            placeholder_max_token = gr.Textbox(
                visible=False,
                interactive=False,
                value= MAX_INPUT_CHARACTERS
            )
            source_language = gr.Dropdown(label="Source Language", choices=list(directions.keys()), value=DEFAULT_SOURCE_LANGUAGE)
            input = gr.Textbox(placeholder="Enter a text here to translate.", max_lines=100, lines=12, show_label=False, interactive=True)
            with gr.Row(variant="panel", equal_height=True):
                # Warning text, filled by a JS handler when the input is too long.
                gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                # "<current length> / <max>" counter. The id is spelled 'inputlenght'
                # here and in the JS handler below; keep them in sync.
                gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span> / {MAX_INPUT_CHARACTERS}</span>""")
        # Right column: target language, output text and action buttons.
        with gr.Column(scale=2):
            # Initial target choices for the preselected source language.
            target_outputs = get_target_languages(DEFAULT_SOURCE_LANGUAGE)
            #target_language = gr.Dropdown(choices=target_outputs, label="Target Language", value=target_outputs[0])
            target_language = gr.Radio(choices=target_outputs, label="Target Language", value=target_outputs[0])
            output = gr.Textbox(max_lines=100, lines=12, show_label=False, interactive=False, show_copy_button=True)
            with gr.Row(variant="panel"):
                clear_btn = gr.Button(
                    "Clear",
                )
                submit_btn = gr.Button(
                    "Submit",
                    variant="primary",
                )
    # Changing the source language refreshes the target choices.
    source_language.change(fn=update_target_languages_dropdown, inputs=[source_language], outputs=target_language)
    # Enable/disable the buttons depending on the input length.
    input.change(
        fn=change_interactive,
        inputs=[input],
        outputs=[clear_btn, submit_btn],
        api_name=False
    )
    # Client-side only (fn=None): show the max-length warning when over the limit.
    # NOTE(review): if the textbox value arrives as null (e.g. after Clear), i.length would throw — verify.
    input.change(
        fn=None,
        inputs=[input],
        js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """,
        api_name=False
    )
    # Client-side only: update the length counter and turn it red above the limit
    # (m is the value of the hidden placeholder_max_token textbox).
    input.change(
        fn=None,
        inputs=[input, placeholder_max_token],
        js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""",
        api_name=False
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[input, output],
        queue=False,
        api_name=False
    )
    # Translation is also exposed as the "translate" API endpoint. With
    # concurrency_limit=1 only one request runs at a time; presumably this also
    # protects the fixed input.txt/output_file.txt files used by
    # translate_without_subwording.
    submit_btn.click(
        fn=translate,
        inputs=[input, source_language, target_language],
        outputs=[output],
        api_name="translate",
        concurrency_limit=1,
    )
app.launch(show_api=True)