File size: 3,617 Bytes
2e6a359
 
 
 
 
 
 
1700804
 
 
2e6a359
1700804
 
 
0bafb6f
 
1700804
 
c633a4d
 
37854cd
1700804
 
 
 
0bafb6f
2e6a359
 
2bbca52
 
 
37854cd
0bafb6f
 
 
 
c633a4d
 
 
2bbca52
 
2e6a359
 
 
 
2bbca52
 
2e6a359
 
 
 
 
 
 
 
2bbca52
2e6a359
 
 
 
 
 
 
 
 
 
 
2bbca52
 
2e6a359
2bbca52
2e6a359
 
 
a760944
 
2e6a359
2bbca52
365f5c4
2e6a359
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
import gradio as gr
from transformers import pipeline
import ast

# Maps the label shown in the task dropdown to the task string handed to
# `pipeline(task=...)`. Entries that are commented out are currently disabled
# and therefore not offered in the UI. Insertion order is the dropdown order.
translation_task_names = {
    "English to French": "translation_en_to_fr",
    # "French to English": "translation_fr_to_en",
    # "English to Spanish": "translation_en_to_es",
    # "Spanish to English": "translation_es_to_en",
    "English to German": "translation_en_to_de",
    # "German to English": "translation_de_to_en",
    # "English to Italian": "translation_en_to_it",
    # "Italian to English": "translation_it_to_en",
    "English to Dutch": "translation_en_to_nl",
    "Dutch to English": "translation_nl_to_en",
    # "English to Portuguese": "translation_en_to_pt",
    # "Portuguese to English": "translation_pt_to_en",
    "English to Russian": "translation_en_to_ru",
    "Russian to English": "translation_ru_to_en",
    "English to Chinese": "translation_en_to_zh",
    # "Chinese to English": "translation_zh_to_en",
    # "English to Japanese": "translation_en_to_ja",
    # "Japanese to English": "translation_ja_to_en",
    "English to Romanian": "translation_en_to_ro",
    "Swedish to English": "translation_SV_to_EN",
}

# Maps the label shown in the model dropdown to the model identifier handed to
# `pipeline(model=...)`. Insertion order is the dropdown order.
model_names = {
    "T5-Base": "t5-base",
    "T5-Small": "t5-small",
    "T5-Large": "t5-large",
    "Opus-En-ZH": "liam168/trans-opus-mt-en-zh",
    "DDDSSS/translation_en-zh": "DDDSSS/translation_en-zh",
    "T5-Base-nl-en": "yhavinga/t5-base-36L-ccmatrix-multi",
    "T5-Small-nl-en": "yhavinga/t5-small-24L-ccmatrix-multi",
    "Opus-Sv-En": "Helsinki-NLP/opus-mt-sv-en",
    "Opus-En-Ru": "Helsinki-NLP/opus-mt-en-ru",
    "Opus-Ru-En": "Helsinki-NLP/opus-mt-ru-en",
}

# Pipelines that have already been constructed, kept for the lifetime of the
# process so repeated requests can reuse them instead of loading again.
loaded_models = {}

# Simple translation function
def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
    """Translate ``text_input`` with the selected model and translation task.

    Pipelines are built on first use and stored in ``loaded_models`` so a later
    request with the same settings reuses the existing pipeline.

    Args:
        model_choice: Key into ``model_names`` selecting the model.
        task_choice: Key into ``translation_task_names`` selecting the task.
        text_input: Text to translate.
        load_in_8bit: When truthy, ``load_in_8bit`` is forwarded to the model
            loader and float16 is used; otherwise float32 is used.
        device: Device value forwarded to ``pipeline``.

    Returns:
        The translated text with surrounding whitespace stripped, or ``""``
        when ``text_input`` is empty or contains only whitespace.

    Raises:
        KeyError: If a pipeline has to be built and ``model_choice`` or
            ``task_choice`` is not a known key.
    """
    # Nothing to translate: return early instead of loading a model and
    # running generation on an empty prompt.
    if not text_input or not text_input.strip():
        return ""

    # The device must be part of the cache key. Without it, a pipeline cached
    # for one device would be silently reused after the user picks another.
    use_8bit = bool(load_in_8bit)
    model_key = (model_choice, task_choice, use_8bit, device)

    translator = loaded_models.get(model_key)
    if translator is None:
        model_kwargs = {"load_in_8bit": True} if use_8bit else {}
        # 8-bit loading is paired with half precision; full precision otherwise.
        dtype = torch.float16 if use_8bit else torch.float32
        translator = pipeline(
            task=translation_task_names[task_choice],
            model=model_names[model_choice],
            device=device,
            model_kwargs=model_kwargs,
            torch_dtype=dtype,
            use_fast=True,
        )
        # Remember the pipeline so the next identical request skips loading.
        loaded_models[model_key] = translator

    translation = translator(text_input)[0]['translation_text']
    return str(translation).strip()

def launch(model_choice, task_choice, text_input, load_in_8bit, device):
    """Gradio callback: forward the widget values to ``translate_text``.

    The parameters arrive in the order of the ``inputs`` list given to the
    interface; the return value of ``translate_text`` is passed back unchanged.
    """
    translated = translate_text(
        model_choice,
        task_choice,
        text_input,
        load_in_8bit,
        device,
    )
    return translated

# Input widgets. Their order in `inputs` below matches the parameter order of
# `launch`.
model_dropdown = gr.Dropdown(
    label='Select Model',
    choices=list(model_names),
)
task_dropdown = gr.Dropdown(
    label='Select Translation Task',
    choices=list(translation_task_names),
)
text_input = gr.Textbox(label="Input Text")  # text to translate
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
# https://www.gradio.app/docs/radio
device = gr.Radio(
    choices=['cpu', 'cuda'],
    value='cpu',  # preselected option
    label='Select device',
)

# Wire the widgets to the callback and start the app.
iface = gr.Interface(
    fn=launch,
    inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
    outputs=gr.Textbox(type="text", label="Translation"),
)
iface.launch()