Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
3 |
+
from IndicTransToolkit import IndicProcessor
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
# Define the model and tokenizer
|
7 |
+
model_name = "ai4bharat/indictrans2-indic-indic-1B"
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
9 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
|
10 |
+
ip = IndicProcessor(inference=True)
|
11 |
+
|
12 |
+
# Define the language codes
|
13 |
+
LANGUAGES = {
|
14 |
+
"Assamese (asm_Beng)": "asm_Beng",
|
15 |
+
"Kashmiri (kas_Arab)": "kas_Arab",
|
16 |
+
"Punjabi (pan_Guru)": "pan_Guru",
|
17 |
+
"Bengali (ben_Beng)": "ben_Beng",
|
18 |
+
"Kashmiri (kas_Deva)": "kas_Deva",
|
19 |
+
"Sanskrit (san_Deva)": "san_Deva",
|
20 |
+
"Bodo (brx_Deva)": "brx_Deva",
|
21 |
+
"Maithili (mai_Deva)": "mai_Deva",
|
22 |
+
"Santali (sat_Olck)": "sat_Olck",
|
23 |
+
"Dogri (doi_Deva)": "doi_Deva",
|
24 |
+
"Malayalam (mal_Mlym)": "mal_Mlym",
|
25 |
+
"Sindhi (snd_Arab)": "snd_Arab",
|
26 |
+
"English (eng_Latn)": "eng_Latn",
|
27 |
+
"Marathi (mar_Deva)": "mar_Deva",
|
28 |
+
"Sindhi (snd_Deva)": "snd_Deva",
|
29 |
+
"Konkani (gom_Deva)": "gom_Deva",
|
30 |
+
"Manipuri (mni_Beng)": "mni_Beng",
|
31 |
+
"Tamil (tam_Taml)": "tam_Taml",
|
32 |
+
"Gujarati (guj_Gujr)": "guj_Gujr",
|
33 |
+
"Manipuri (mni_Mtei)": "mni_Mtei",
|
34 |
+
"Telugu (tel_Telu)": "tel_Telu",
|
35 |
+
"Hindi (hin_Deva)": "hin_Deva",
|
36 |
+
"Nepali (npi_Deva)": "npi_Deva",
|
37 |
+
"Urdu (urd_Arab)": "urd_Arab",
|
38 |
+
"Kannada (kan_Knda)": "kan_Knda",
|
39 |
+
"Odia (ory_Orya)": "ory_Orya",
|
40 |
+
}
|
41 |
+
|
42 |
+
# Define the translation function
|
43 |
+
def translate(text, src_lang, tgt_lang):
|
44 |
+
batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
|
45 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
46 |
+
inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
|
47 |
+
with torch.no_grad():
|
48 |
+
generated_tokens = model.generate(
|
49 |
+
**inputs,
|
50 |
+
use_cache=True,
|
51 |
+
min_length=0,
|
52 |
+
max_length=256,
|
53 |
+
num_beams=5,
|
54 |
+
num_return_sequences=1,
|
55 |
+
)
|
56 |
+
with tokenizer.as_target_tokenizer():
|
57 |
+
generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
58 |
+
return generated_text
|
59 |
+
|
60 |
+
# Create a Gradio interface
|
61 |
+
with gr.Blocks() as demo:
|
62 |
+
gr.Markdown("### Indic Translations")
|
63 |
+
input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
|
64 |
+
src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
|
65 |
+
tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
|
66 |
+
translate_button = gr.Button("Translate")
|
67 |
+
translation_output = gr.Textbox(label="Translation", interactive=False)
|
68 |
+
|
69 |
+
@translate_button.click
|
70 |
+
def on_translate(text, src_lang, tgt_lang):
|
71 |
+
translation = translate(text, LANGUAGES[src_lang], LANGUAGES[tgt_lang])
|
72 |
+
translation_output.value = translation
|
73 |
+
|
74 |
+
demo.launch()
|