File size: 3,116 Bytes
38742d7
 
71ae380
38742d7
 
71ae380
38742d7
 
6bcde50
38742d7
 
 
 
 
 
 
3f23d73
38742d7
 
3f23d73
 
 
 
 
 
 
 
38742d7
 
 
71ae380
 
 
 
 
 
 
3f23d73
 
5cb6981
3f23d73
5cb6981
71ae380
 
 
 
 
 
5cb6981
71ae380
 
 
 
 
 
 
 
3f23d73
15ccfd9
 
38742d7
71ae380
7dc20b3
 
 
 
 
 
71ae380
1473813
38742d7
0badc10
 
 
 
 
38742d7
 
 
bd312ad
0badc10
38742d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from flores import code_mapping
import platform
import torch

# macOS runs on CPU; every other platform is pointed at CUDA.
if platform.system() == "Darwin":
    device = "cpu"
else:
    device = "cuda"

# Hugging Face model id used for both the model and its tokenizer.
MODEL_NAME = "facebook/nllb-200-3.3B"

# Rebuild the mapping ordered by its values; its keys (in that order) are
# what the language dropdowns offer.
code_mapping = {
    label: code
    for label, code in sorted(code_mapping.items(), key=lambda pair: pair[1])
}
flores_codes = list(code_mapping)


def load_model():
    """Load the ``MODEL_NAME`` seq2seq checkpoint and move it onto ``device``."""
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)


# Load the model once at import time; translate() reuses this module-level
# instance for every request.
model = load_model()


def load_tokenizer(src_lang, tgt_lang):
    """Return a ``MODEL_NAME`` tokenizer configured for one language pair.

    Both arguments are keys of ``code_mapping``; the values stored under
    them are what gets passed to the tokenizer as ``src_lang`` and
    ``tgt_lang``. A key that is missing from ``code_mapping`` raises
    ``KeyError``.
    """
    source_code = code_mapping[src_lang]
    target_code = code_mapping[tgt_lang]
    return AutoTokenizer.from_pretrained(
        MODEL_NAME, src_lang=source_code, tgt_lang=target_code
    )


@spaces.GPU
def translate(
    text: str,
    src_lang: str,
    tgt_lang: str,
    window_size: int = 800,
    overlap_size: int = 200,
):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang`` with a sliding window.

    The input is tokenized once and cut into windows of at most
    ``window_size`` token ids whose start positions are
    ``window_size - overlap_size`` apart. Each window is passed to
    ``model.generate`` and the decoded chunks are joined with single spaces.

    Args:
        text: Text to translate.
        src_lang: Key of ``code_mapping`` naming the source language.
        tgt_lang: Key of ``code_mapping`` naming the target language.
        window_size: Maximum number of input token ids per window; also
            passed to ``generate`` as ``max_length``.
        overlap_size: Number of token ids shared by consecutive windows.

    Returns:
        The translated text, or an empty string when ``text`` is blank.

    Raises:
        ValueError: If ``overlap_size`` is negative or is not smaller than
            ``window_size``.
        KeyError: If ``text`` is not blank and either language is not a key
            of ``code_mapping``.
    """
    # A zero step makes range() raise an opaque error, a negative step
    # silently yields no windows, and a negative overlap skips tokens
    # between windows, so reject all three up front.
    if not 0 <= overlap_size < window_size:
        raise ValueError(
            "overlap_size must be non-negative and smaller than window_size"
        )

    # Nothing to translate: skip loading the tokenizer and running the model.
    if not text or not text.strip():
        return ""

    tokenizer = load_tokenizer(src_lang, tgt_lang)

    # Loop-invariant, so resolve it once. convert_tokens_to_ids replaces the
    # deprecated tokenizer.lang_code_to_id mapping.
    target_token_id = tokenizer.convert_tokens_to_ids(code_mapping[tgt_lang])

    # Without return_tensors the tokenizer already returns a plain list of ids.
    # NOTE(review): the ids are sliced after tokenization, so any special
    # tokens the tokenizer adds at the ends appear only in the first/last
    # window - confirm whether each window should be wrapped individually.
    input_tokens = tokenizer(text)["input_ids"]
    token_count = len(input_tokens)
    step = window_size - overlap_size

    translated_chunks = []
    for start in range(0, token_count, step):
        window = input_tokens[start : start + window_size]
        generated = model.generate(
            input_ids=torch.tensor([window]).to(device),
            forced_bos_token_id=target_token_id,
            max_length=window_size,
            num_return_sequences=1,
        )
        translated_chunks.append(
            tokenizer.decode(generated[0], skip_special_tokens=True)
        )
        # This window already reached the end of the input; any further
        # window would lie entirely inside it and duplicate the tail of the
        # translation.
        if start + window_size >= token_count:
            break

    return " ".join(translated_chunks)


# Markdown rendered directly under the page title: what the demo is, plus
# caveats about translation quality and the sliding-window approach.
description = """
No Language Left Behind (NLLB) is a series of open-source models aiming to provide high-quality translations between 200 languages.
This demo application allows you to use the NLLB model to translate text between a source and target language.

## Notes 

- Whilst the model supports 200 languages, the quality of translations may vary between languages. 
- "Low Resource" languages (languages which are less present on the internet and have a lower amount of investment) may have lower quality translations.
- The demo uses a sliding window approach to handle longer texts.
"""

# Markdown rendered under the "Instructions" heading: numbered usage steps.
instructions = """
1. Select the source and target language from the dropdown menus.
2. Enter the text you would like to translate.
3. Click the 'Translate text' button.
"""
# Build the page: title and help text first, then one row each for the
# language pickers, the input box, the button and the output box.
with gr.Blocks() as demo:
    gr.Markdown("# No Language Left Behind (NLLB) Translation Demo")
    gr.Markdown(description)
    gr.Markdown("## Instructions")
    gr.Markdown(instructions)
    with gr.Row():
        # Both dropdowns offer flores_codes, i.e. the keys of code_mapping.
        # NOTE(review): no initial value is set - presumably translate()
        # receives None if the user never picks a language; confirm.
        src_lang = gr.Dropdown(label="Source Language", choices=flores_codes)
        target_lang = gr.Dropdown(label="Target Language", choices=flores_codes)
    with gr.Row():
        input_text = gr.Textbox(label="Input Text", lines=6)
    with gr.Row():
        btn = gr.Button("Translate text")
    with gr.Row():
        output = gr.Textbox(label="Output Text", lines=6)
    # Only three inputs are wired up, so translate() runs with its default
    # window_size and overlap_size.
    btn.click(
        translate,
        inputs=[input_text, src_lang, target_lang],
        outputs=output,
    )
# Launched unconditionally when the module is executed (no __main__ guard).
demo.launch()