Spaces:
Build error
Build error
import gradio as gr | |
import torch | |
import random | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelWithLMHead | |
from sentence_splitter import SentenceSplitter, split_text_into_sentences | |
# Run on the first CUDA device when one is available, otherwise on the CPU.
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# English sentence splitter instance.
splitter = SentenceSplitter(language='en')

# Tokenizer and seq2seq model used by get_answer(); both are loaded once at
# module import and the model is moved to the selected device.
ptokenizer = AutoTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
pmodel = AutoModelForSeq2SeqLM.from_pretrained("tuner007/pegasus_paraphrase")
pmodel = pmodel.to(torch_device)
def get_answer(input_text, num_return_sequences, num_beams):
    """Return candidate rewrites of ``input_text`` from the paraphrase model.

    Args:
        input_text: Text to rewrite. It is tokenized with truncation at
            60 tokens, so anything beyond that is dropped.
        num_return_sequences: Number of sequences requested from ``generate``.
        num_beams: Beam width passed to ``generate``.

    Returns:
        The decoded generated sequences as a list of strings, with special
        tokens removed.
    """
    encoded = ptokenizer(
        [input_text],
        truncation=True,
        padding='longest',
        max_length=60,
        return_tensors="pt",
    )
    encoded = encoded.to(torch_device)

    # NOTE(review): temperature is passed without enabling sampling here;
    # confirm whether it has any effect on the generated candidates.
    generated_ids = pmodel.generate(
        **encoded,
        max_length=60,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5,
    )
    return ptokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Tokenizer and model used by get_question() to turn an
# "answer: ... context: ..." prompt into a question.
qtokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap")
# AutoModelWithLMHead is deprecated in transformers; AutoModelForSeq2SeqLM is
# the documented replacement for encoder-decoder checkpoints such as this T5 one.
qmodel = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-question-generation-ap").to(torch_device)
def get_question(answer, context, max_length=64):
    """Generate a question whose answer is ``answer`` given ``context``.

    Args:
        answer: The answer text the question should lead to.
        context: Text the question is generated from.
        max_length: Maximum length passed to ``generate`` (default 64).

    Returns:
        The first generated sequence decoded to a string. Special tokens are
        NOT stripped here; callers remove them if they need clean text.
    """
    prompt = "answer: %s context: %s </s>" % (answer, context)
    encoded = qtokenizer([prompt], return_tensors='pt').to(torch_device)

    generated = qmodel.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    first_sequence = generated[0]
    return qtokenizer.decode(first_sequence)
def getqna(input):
    """Build a question/answer pair from free text.

    The text is split into sentences and each sentence is paraphrased. The
    paraphrases are joined and paraphrased once more to form the answer, and
    a question for that answer is then generated using the original text as
    context.

    Args:
        input: Source text entered by the user. The name shadows the builtin
            but is kept so the callable's signature does not change.

    Returns:
        A string of the form ``"<question> \\n answer:<answer>"``, or a short
        prompt message when ``input`` is empty or contains only whitespace.
    """
    empty_message = "Please enter some text."

    # Blank input has nothing to paraphrase; return before touching the models
    # instead of feeding an empty sentence list to the tokenizer.
    if not input or not input.strip():
        return empty_message

    sentences = split_text_into_sentences(text=input, language='en')
    if not sentences:
        return empty_message

    candidates = 10

    # Pick one random candidate per sentence; random.choice avoids assuming the
    # model always returns exactly `candidates` sequences.
    paraphrased = [
        random.choice(get_answer(sentence, candidates, candidates))
        for sentence in sentences
    ]

    # Second paraphrasing pass over the joined text. get_answer truncates its
    # input to 60 tokens, so long inputs are effectively shortened here.
    answer = random.choice(get_answer(" ".join(paraphrased), candidates, candidates))

    # Use the original text as context; passing the sentence list would embed
    # its Python repr (brackets and quotes) in the model prompt.
    raw_question = get_question(answer, input)
    question = raw_question.replace("<pad>", "").replace("</s>", "")
    return "%s \n answer:%s" % (question, answer)
# Single text box in, single text box out, backed by getqna().
app = gr.Interface(
    fn=getqna,
    inputs="text",
    outputs="text",
)
app.launch()