Sanzana Lora committed on
Commit
39a2488
1 Parent(s): 41e4666

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr


def _load_seq2seq(checkpoint):
    """Return a (model, tokenizer) pair for a fine-tuned seq2seq checkpoint."""
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# Fine-tuned T5-family checkpoints, one per supported task.
translation_model_en_bn, translation_tokenizer_en_bn = _load_seq2seq(
    "csebuetnlp/banglat5_nmt_en_bn"
)
translation_model_bn_en, translation_tokenizer_bn_en = _load_seq2seq(
    "csebuetnlp/banglat5_nmt_bn_en"
)
summarization_model, summarization_tokenizer = _load_seq2seq(
    "csebuetnlp/mT5_multilingual_XLSum"
)
paraphrase_model, paraphrase_tokenizer = _load_seq2seq(
    "csebuetnlp/banglat5_banglaparaphrase"
)
# Machine translation: English -> Bangla.
def translate_text(input_text):
    """Translate *input_text* from English to Bangla with the fine-tuned NMT model.

    Returns the decoded translation with special tokens stripped.
    """
    encoded = translation_tokenizer_en_bn("translate: " + input_text, return_tensors="pt")
    generated = translation_model_en_bn.generate(**encoded)
    return translation_tokenizer_en_bn.decode(generated[0], skip_special_tokens=True)
23
# Abstractive summarization via the multilingual XLSum mT5 model.
def summarize_text(input_text):
    """Summarize *input_text* and return the decoded summary string."""
    encoded = summarization_tokenizer("summarize: " + input_text, return_tensors="pt")
    generated = summarization_model.generate(**encoded)
    return summarization_tokenizer.decode(generated[0], skip_special_tokens=True)
30
# Paraphrase generation with the BanglaT5 paraphrase model.
def paraphrase_text(input_text):
    """Produce a paraphrase of *input_text* and return the decoded string."""
    encoded = paraphrase_tokenizer("paraphrase: " + input_text, return_tensors="pt")
    generated = paraphrase_model.generate(**encoded)
    return paraphrase_tokenizer.decode(generated[0], skip_special_tokens=True)
37
# Gradio UI.
#
# The original code tried to switch tasks by calling `task_selector.observe(...)`
# (an ipywidgets API that Gradio components do not have — AttributeError) and by
# mutating `iface.fn` at runtime (which does not rewire a Gradio app). It also
# passed `default=` to gr.Dropdown (the keyword is `value=` — TypeError) and
# passed junk positional values ("textarea", "auto") to gr.Textbox, pre-filling
# the boxes with those literal strings. Instead, the task dropdown is a regular
# input and a single dispatcher routes each request to the right model function.

def run_task(task, input_text):
    """Run the selected task ("Translate", "Summarize", or "Paraphrase") on *input_text*.

    Raises KeyError for an unknown task name (cannot happen via the dropdown UI).
    """
    dispatch = {
        "Translate": translate_text,
        "Summarize": summarize_text,
        "Paraphrase": paraphrase_text,
    }
    return dispatch[task](input_text)


iface = gr.Interface(
    fn=run_task,
    inputs=[
        gr.Dropdown(
            ["Translate", "Summarize", "Paraphrase"],
            value="Translate",
            label="Select Task",
        ),
        gr.Textbox(label="Input Text"),
    ],
    outputs=gr.Textbox(label="Output Text"),
)

# Launch the Gradio app.
iface.launch()