nllb / app.py
davanstrien's picture
davanstrien HF staff
Fix newline formatting in translate function
126fd42
raw
history blame
2.41 kB
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores import code_mapping
import platform
# Use the GPU everywhere except macOS ("Darwin"), where the app runs on CPU.
device = "cuda" if platform.system() != "Darwin" else "cpu"

# Checkpoint loaded by load_model().
MODEL_NAME = "facebook/nllb-200-distilled-600M"

# Re-order the mapping by its values so the derived key list (used as the
# dropdown choices) follows that order.
code_mapping = {
    name: code
    for name, code in sorted(code_mapping.items(), key=lambda pair: pair[1])
}
flores_codes = [*code_mapping]
def load_model():
    """Load the ``MODEL_NAME`` checkpoint and its tokenizer.

    The model is moved to the module-level ``device`` before being returned.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    seq2seq_model = seq2seq_model.to(device)
    nllb_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return seq2seq_model, nllb_tokenizer
# Load the weights and tokenizer once at import time; _translate reuses these
# module-level objects on every call.
model, tokenizer = load_model()
@spaces.GPU
def _translate(text: str, src_lang: str, tgt_lang: str):
    """Translate one piece of text with the module-level model and tokenizer.

    Args:
        text: The text handed to the translation pipeline.
        src_lang: Key into ``code_mapping`` for the source language.
        tgt_lang: Key into ``code_mapping`` for the target language.

    Returns:
        The ``"translation_text"`` field of the first pipeline result.

    Raises:
        KeyError: If ``src_lang`` or ``tgt_lang`` is not a key of
            ``code_mapping``.
    """
    # Resolve both languages up front (source first) so a bad key fails
    # before any pipeline is built.
    pipeline_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "src_lang": code_mapping[src_lang],
        "tgt_lang": code_mapping[tgt_lang],
        "device": device,
    }
    translator = pipeline("translation", **pipeline_kwargs)

    # Generation is capped at 400 via max_length.
    results = translator(text, max_length=400)
    first_result = results[0]
    return first_result["translation_text"]
def translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` paragraph by paragraph and return the joined result.

    The input is split on newlines. Each paragraph of at most 500 characters
    is translated in a single ``_translate`` call; longer paragraphs are split
    on full stops and translated sentence by sentence, then re-joined with
    ``". "``. Every paragraph (including blank ones) is followed by a blank
    line in the output, so the input's paragraph structure is kept.

    Empty or whitespace-only paragraphs and sentence fragments are never sent
    to the model: there is nothing to translate, and they previously produced
    stray ``". "`` fragments in the output.

    Args:
        text: The text to translate.
        src_lang: Source-language name as offered by the dropdown.
        tgt_lang: Target-language name as offered by the dropdown.

    Returns:
        The translated text, each paragraph terminated by ``"\\n\\n"``.

    Raises:
        gr.Error: If either language has not been selected.
    """
    # An unselected dropdown passes a falsy value; report that to the user
    # instead of letting the language lookup fail with a bare KeyError.
    if not src_lang or not tgt_lang:
        raise gr.Error("Please select both a source and a target language.")

    translated_paragraphs = []
    for paragraph in text.split("\n"):
        if not paragraph.strip():
            # Blank line: keep it as an empty paragraph, skip the model.
            translated_paragraphs.append("")
            continue

        if len(paragraph) <= 500:
            translated_paragraphs.append(
                _translate(paragraph, src_lang, tgt_lang)
            )
            continue

        # Paragraph is too long for one call: translate sentence by sentence.
        pieces = paragraph.split(".")
        last_index = len(pieces) - 1
        translated_sentences = []
        for index, piece in enumerate(pieces):
            sentence = piece.strip()
            if not sentence:
                # Empty tail after the final full stop, or consecutive dots.
                continue
            # Every piece but the last was followed by a full stop in the
            # input; restore it. The last piece only had one if the paragraph
            # ended with ".", in which case that piece is empty and skipped.
            terminator = "." if index < last_index else ""
            translated_sentences.append(
                _translate(sentence, src_lang, tgt_lang) + terminator
            )
        translated_paragraphs.append(" ".join(translated_sentences))

    return "".join(f"{paragraph}\n\n" for paragraph in translated_paragraphs)
# Introductory text shown to the user as Markdown.
description = """
No Language Left Behind (NLLB) is a series of open-source models aiming to provide high-quality translations between 200 languages."""
# Page layout, top to bottom: title, description, language pickers, input box,
# translate button, output box.
with gr.Blocks() as demo:
    gr.Markdown(value="# No Language Left Behind (NLLB) Translation Demo")
    gr.Markdown(value=description)

    # Both dropdowns offer the same list of language names.
    with gr.Row():
        src_lang = gr.Dropdown(choices=flores_codes, label="Source Language")
        target_lang = gr.Dropdown(choices=flores_codes, label="Target Language")

    with gr.Row():
        input_text = gr.Textbox(lines=6, label="Input Text")

    with gr.Row():
        btn = gr.Button(value="Translate text")

    with gr.Row():
        output = gr.Textbox(lines=6, label="Output Text")

    # Clicking the button sends (text, source, target) to translate() and
    # writes its return value into the output box.
    translate_inputs = [input_text, src_lang, target_lang]
    btn.click(fn=translate, inputs=translate_inputs, outputs=output)

demo.launch()