thoristhor commited on
Commit
4eb87d1
1 Parent(s): 7d44e9b
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import random
4
+
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead
6
+ from sentence_splitter import SentenceSplitter, split_text_into_sentences
7
+ splitter = SentenceSplitter(language='en')
8
+
9
+ if torch.cuda.is_available():
10
+ torch_device="cuda:0"
11
+ else:
12
+ torch_device="cpu"
13
+
14
+ ptokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
15
+ pmodel = AutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase").to(torch_device)
16
+
17
+ def get_answer(input_text,num_return_sequences,num_beams):
18
+ batch = ptokenizer([input_text],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
19
+ translated = pmodel.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
20
+ tgt_text = ptokenizer.batch_decode(translated, skip_special_tokens=True)
21
+ return tgt_text
22
+
23
+ qtokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
24
+ qmodel = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap").to(torch_device)
25
+
26
+ def get_question(answer, context, max_length=64):
27
+ input_text = "answer: %s context: %s </s>" % (answer, context)
28
+ features = qtokenizer([input_text], return_tensors='pt').to(torch_device)
29
+
30
+ output = qmodel.generate(input_ids=features['input_ids'],
31
+ attention_mask=features['attention_mask'],
32
+ max_length=max_length)
33
+
34
+ return qtokenizer.decode(output[0])
35
+
36
+ def getqna(input):
37
+ input=split_text_into_sentences(text=input, language='en')
38
+ if len(input)==0:
39
+ answer= get_answer(input,10,10)[random.randint(0, 9)]
40
+ else:
41
+ sentences=[get_answer(sentence,10,10)[random.randint(0, 9)] for sentence in input]
42
+ answer= " ".join(sentences)
43
+ answer= get_answer(answer,10,10)[random.randint(0, 9)]
44
+ question= get_question(answer, input).replace("<pad>","").replace("</s>","")
45
+ return "%s \n answer:%s" % (question, answer)
46
+
47
+ app = gr.Interface(fn=getqna, inputs="text", outputs="text")
48
+ app.launch()