Spaces:
Build error
Build error
File size: 2,053 Bytes
c88f8ff 5c46280 c88f8ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import gradio as gr
import torch
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead
from sentence_splitter import SentenceSplitter, split_text_into_sentences
# English sentence splitter instance. The functions in this file call the
# functional split_text_into_sentences helper instead of this object; it is
# kept as a module-level name for compatibility.
splitter = SentenceSplitter(language='en')

# Device used for all model inference: the first CUDA GPU when available,
# otherwise the CPU.
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Pegasus paraphrasing checkpoint: tokenizer plus seq2seq weights, with the
# model moved onto torch_device for inference in get_answer.
_PARAPHRASE_CHECKPOINT = "tuner007/pegasus_paraphrase"
ptokenizer = AutoTokenizer.from_pretrained(_PARAPHRASE_CHECKPOINT)
pmodel = AutoModelForSeq2SeqLM.from_pretrained(_PARAPHRASE_CHECKPOINT).to(torch_device)
def get_answer(input_text, num_return_sequences, num_beams, max_length=60):
    """Return beam-search paraphrases of *input_text* from the Pegasus model.

    Args:
        input_text: Text to paraphrase. Input longer than ``max_length``
            tokens is truncated.
        num_return_sequences: Number of candidate paraphrases to return.
            transformers requires this to be <= ``num_beams`` for beam search.
        num_beams: Beam width used by ``generate``.
        max_length: Token limit applied both to the truncated input and to
            the generated output (default 60, the previously hard-coded value).

    Returns:
        A list of ``num_return_sequences`` decoded strings with special
        tokens removed.
    """
    batch = ptokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=max_length,
        return_tensors="pt",
    ).to(torch_device)
    # No ``temperature`` here. This call does not enable ``do_sample``, so
    # (unless the checkpoint's generation config turns sampling on, which
    # should be confirmed) decoding is plain beam search. Temperature has no
    # effect there and only makes transformers emit a warning.
    translated = pmodel.generate(
        **batch,
        max_length=max_length,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
    )
    return ptokenizer.batch_decode(translated, skip_special_tokens=True)
# T5 checkpoint fine-tuned for answer-aware question generation, used by
# get_question. AutoModelWithLMHead is deprecated in transformers. For an
# encoder-decoder (T5) checkpoint, AutoModelForSeq2SeqLM is the documented
# replacement and loads the same conditional-generation model class.
_QUESTION_CHECKPOINT = "mrm8488/t5-base-finetuned-question-generation-ap"
qtokenizer = AutoTokenizer.from_pretrained(_QUESTION_CHECKPOINT)
qmodel = AutoModelForSeq2SeqLM.from_pretrained(_QUESTION_CHECKPOINT).to(torch_device)
def get_question(answer, context, max_length=64):
    """Generate a question for which *answer* is the answer, given *context*.

    Builds the ``answer: ... context: ... </s>`` prompt, encodes it on
    ``torch_device``, and generates at most ``max_length`` tokens with the T5
    question model. Only the first generated sequence is decoded.

    Special tokens are not skipped during decoding, so the returned string
    can contain markers such as ``<pad>`` and ``</s>``. getqna strips these.
    """
    prompt = "answer: %s context: %s </s>" % (answer, context)
    encoded = qtokenizer([prompt], return_tensors='pt').to(torch_device)
    generated = qmodel.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    best = generated[0]
    return qtokenizer.decode(best)
def _paraphrase(text, candidates=10):
    """Return one paraphrase of *text*, picked uniformly at random.

    Asks get_answer for ``candidates`` paraphrases (beam width equal to the
    candidate count) and chooses one with ``random.choice``.
    """
    return random.choice(get_answer(text, candidates, candidates))


def getqna(input):
    """Paraphrase the input text and generate a question about it.

    Steps:
    1. Paraphrase each sentence independently.
    2. Join the paraphrases and paraphrase the result once more; this is the
       "answer".
    3. Generate a question for that answer, using the original text as the
       context.

    Args:
        input: English text from the Gradio textbox. The name shadows the
            builtin but is kept so keyword callers keep working.

    Returns:
        The question (with ``<pad>`` / ``</s>`` markers removed) and the
        answer, formatted as ``"%s \\n answer:%s"``. Returns ``""`` when the
        input is empty, ``None``, or whitespace-only.
    """
    text = input
    if not text or not text.strip():
        # Nothing to paraphrase. The old empty-split branch passed the empty
        # sentence *list* to get_answer instead of a string.
        return ""
    sentences = split_text_into_sentences(text=text, language='en')
    if sentences:
        rewritten = " ".join(_paraphrase(sentence) for sentence in sentences)
        answer = _paraphrase(rewritten)
    else:
        # The splitter found no sentences: paraphrase the raw text directly.
        answer = _paraphrase(text)
    # Pass the original string as context. The old code passed the sentence
    # list here, so the prompt contained its Python repr (brackets, quotes).
    question = get_question(answer, text).replace("<pad>", "").replace("</s>", "")
    return "%s \n answer:%s" % (question, answer)
# Expose the whole pipeline as a single text-in / text-out Gradio web UI and
# start serving it as soon as this module runs.
app = gr.Interface(
    fn=getqna,
    inputs="text",
    outputs="text",
)
app.launch()
|