kusht55 committed on
Commit c781d20 · verified · 1 Parent(s): b828f1f

Create app.py

Files changed (1)
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from IndicTransToolkit import IndicProcessor
+ import gradio as gr
+
+ # Load the model, tokenizer, and preprocessor once at startup
+ model_name = "ai4bharat/indictrans2-indic-indic-1B"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)
+ ip = IndicProcessor(inference=True)
+
+ # Define the language codes
+ LANGUAGES = {
+     "Assamese (asm_Beng)": "asm_Beng",
+     "Kashmiri (kas_Arab)": "kas_Arab",
+     "Punjabi (pan_Guru)": "pan_Guru",
+     "Bengali (ben_Beng)": "ben_Beng",
+     "Kashmiri (kas_Deva)": "kas_Deva",
+     "Sanskrit (san_Deva)": "san_Deva",
+     "Bodo (brx_Deva)": "brx_Deva",
+     "Maithili (mai_Deva)": "mai_Deva",
+     "Santali (sat_Olck)": "sat_Olck",
+     "Dogri (doi_Deva)": "doi_Deva",
+     "Malayalam (mal_Mlym)": "mal_Mlym",
+     "Sindhi (snd_Arab)": "snd_Arab",
+     "English (eng_Latn)": "eng_Latn",
+     "Marathi (mar_Deva)": "mar_Deva",
+     "Sindhi (snd_Deva)": "snd_Deva",
+     "Konkani (gom_Deva)": "gom_Deva",
+     "Manipuri (mni_Beng)": "mni_Beng",
+     "Tamil (tam_Taml)": "tam_Taml",
+     "Gujarati (guj_Gujr)": "guj_Gujr",
+     "Manipuri (mni_Mtei)": "mni_Mtei",
+     "Telugu (tel_Telu)": "tel_Telu",
+     "Hindi (hin_Deva)": "hin_Deva",
+     "Nepali (npi_Deva)": "npi_Deva",
+     "Urdu (urd_Arab)": "urd_Arab",
+     "Kannada (kan_Knda)": "kan_Knda",
+     "Odia (ory_Orya)": "ory_Orya",
+ }
+
+ # Translate a single sentence from src_lang to tgt_lang
+ def translate(text, src_lang, tgt_lang):
+     batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
+     inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+     with torch.no_grad():
+         generated_tokens = model.generate(
+             **inputs,
+             use_cache=True,
+             min_length=0,
+             max_length=256,
+             num_beams=5,
+             num_return_sequences=1,
+         )
+     # Decode with the target-side tokenizer, then restore the target script and entities
+     with tokenizer.as_target_tokenizer():
+         decoded = tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True)
+     return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
+
+ # Create a Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("### Indic Translations")
+     input_text = gr.Textbox(label="Input Text", placeholder="Enter text to translate")
+     src_lang = gr.Dropdown(label="Source Language", choices=list(LANGUAGES.keys()))
+     tgt_lang = gr.Dropdown(label="Target Language", choices=list(LANGUAGES.keys()))
+     translate_button = gr.Button("Translate")
+     translation_output = gr.Textbox(label="Translation", interactive=False)
+
+     # Wire the button to the translation function; the return value fills the output box
+     @translate_button.click(inputs=[input_text, src_lang, tgt_lang], outputs=translation_output)
+     def on_translate(text, src_lang, tgt_lang):
+         return translate(text, LANGUAGES[src_lang], LANGUAGES[tgt_lang])
+
+ demo.launch()
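
For a quick local sanity check of the pipeline above without opening the UI, a minimal sketch like the following could be run in a Python session after the model setup in app.py has loaded (the sample sentence and the hin_Deva → tam_Taml pair are illustrative and not part of this commit):

# Illustrative smoke test (not in the commit): translate one Hindi sentence to Tamil
# using the translate() helper defined above, assuming model, tokenizer, and ip are loaded.
sample = "यह एक परीक्षण वाक्य है।"
print(translate(sample, "hin_Deva", "tam_Taml"))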