Commit
·
2e6a359
1
Parent(s):
b9562f5
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
from transformers import pipeline
|
4 |
+
import ast
|
5 |
+
|
6 |
+
translation_task_names = {
|
7 |
+
'English to French': 'translation_en_to_fr',
|
8 |
+
'French to English': 'translation_fr_to_en',
|
9 |
+
'English to Spanish': 'translation_en_to_es',
|
10 |
+
'Spanish to English': 'translation_es_to_en',
|
11 |
+
'English to German': 'translation_en_to_de',
|
12 |
+
'German to English': 'translation_de_to_en',
|
13 |
+
'English to Italian': 'translation_en_to_it',
|
14 |
+
'Italian to English': 'translation_it_to_en',
|
15 |
+
'English to Dutch': 'translation_en_to_nl',
|
16 |
+
'Dutch to English': 'translation_nl_to_en',
|
17 |
+
'English to Portuguese': 'translation_en_to_pt',
|
18 |
+
'Portuguese to English': 'translation_pt_to_en',
|
19 |
+
'English to Russian': 'translation_en_to_ru',
|
20 |
+
'Russian to English': 'translation_ru_to_en',
|
21 |
+
'English to Chinese': 'translation_en_to_zh',
|
22 |
+
'Chinese to English': 'translation_zh_to_en',
|
23 |
+
'English to Japanese': 'translation_en_to_ja',
|
24 |
+
'Japanese to English': 'translation_ja_to_en',
|
25 |
+
}
|
26 |
+
|
27 |
+
# Create a dictionary to store loaded models
|
28 |
+
loaded_models = {}
|
29 |
+
|
30 |
+
# Simple translation function
|
31 |
+
def translate_text(task_choice, text_input, load_in_8bit, device):
|
32 |
+
model_key = (task_choice, load_in_8bit) # Create a tuple to represent the unique combination of task and 8bit loading
|
33 |
+
|
34 |
+
# Check if the model is already loaded
|
35 |
+
if model_key in loaded_models:
|
36 |
+
translator = loaded_models[model_key]
|
37 |
+
else:
|
38 |
+
model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
|
39 |
+
dtype = torch.float16 if load_in_8bit else torch.float32 # Set dtype based on the value of load_in_8bit
|
40 |
+
translator = pipeline(task=translation_task_names[task_choice],
|
41 |
+
device=device, # Use selected device
|
42 |
+
model_kwargs=model_kwargs,
|
43 |
+
torch_dtype=dtype, # Set the floating point
|
44 |
+
use_fast=True
|
45 |
+
)
|
46 |
+
# Store the loaded model
|
47 |
+
loaded_models[model_key] = translator
|
48 |
+
|
49 |
+
translation = translator(text_input)[0]['translation_text']
|
50 |
+
return str(translation).strip()
|
51 |
+
|
52 |
+
def launch(task_choice, text_input, load_in_8bit, device):
|
53 |
+
return translate_text(task_choice, text_input, load_in_8bit, device)
|
54 |
+
|
55 |
+
task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
|
56 |
+
text_input = gr.Textbox(label="Input Text") # Single line text input
|
57 |
+
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
|
58 |
+
device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
|
59 |
+
|
60 |
+
iface = gr.Interface(launch, inputs=[task_dropdown, text_input, load_in_8bit, device],
|
61 |
+
outputs=gr.outputs.Textbox(type="text", label="Translation"))
|
62 |
+
iface.launch()
|