import gradio as gr
from transformers import BartTokenizerFast, BartForConditionalGeneration
import torch
import re
from qgg_utils.optim import GAOptimizer  # https://github.com/p208p2002/qgg-utils.git

MAX_LENGTH = 512

default_context = "Facebook is an online social media and social networking service owned by American company Meta Platforms. Founded in 2004 by Mark Zuckerberg with fellow Harvard College students and roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz, and Chris Hughes, its name comes from the face book directories often given to American university students. Membership was initially limited to Harvard students, gradually expanding to other North American universities and, since 2006, anyone over 13 years old. As of July 2022, Facebook claimed 2.93 billion monthly active users and ranked third worldwide among the most visited websites. It was the most downloaded mobile app of the 2010s."

model = BartForConditionalGeneration.from_pretrained("p208p2002/qmst-qgg")
tokenizer = BartTokenizerFast.from_pretrained("p208p2002/qmst-qgg")


def feedback_generation(model, tokenizer, input_ids, feedback_times=3):
    outputs = []
    device = 'cpu'
    for i in range(feedback_times):
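        # Build a feedback prefix: one BOS token per previously generated
        # question, plus one for the question about to be generated.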
        gened_text = tokenizer.bos_token * (len(outputs) + 1)
        gened_ids = tokenizer(gened_text, add_special_tokens=False)['input_ids']
        input_ids = gened_ids + input_ids
        input_ids = input_ids[:MAX_LENGTH]
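
        # Sample one candidate question with top-p/top-k sampling.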
        sample_outputs = model.generate(
            input_ids=torch.LongTensor(input_ids).unsqueeze(0).to(device),
            attention_mask=torch.LongTensor([1] * len(input_ids)).unsqueeze(0).to(device),
            max_length=50,
            early_stopping=True,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            top_k=10,
            num_beams=1,
            no_repeat_ngram_size=5,
            num_return_sequences=1,
        )
        sample_output = sample_outputs[0]
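
        # Decode, then strip pad/eos/bos tokens and the "[Q:]" tag manually.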
        decode_question = tokenizer.decode(sample_output, skip_special_tokens=False)
        decode_question = re.sub(re.escape(tokenizer.pad_token), '', decode_question)
        decode_question = re.sub(re.escape(tokenizer.eos_token), '', decode_question)
        if tokenizer.bos_token is not None:
            decode_question = re.sub(re.escape(tokenizer.bos_token), '', decode_question)
        decode_question = decode_question.strip()
        decode_question = decode_question.replace("[Q:]", "")
        outputs.append(decode_question)
    return outputs


def gen_question_group(context, question_group_size):
    question_group_size = int(question_group_size)
    print(context, question_group_size)
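
    # Generate twice as many candidate questions as requested so the
    # optimizer has a pool to select the final group from.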
    candidate_pool_size = question_group_size * 2
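
    # Split a long context into overlapping MAX_LENGTH windows; the stride
    # keeps roughly 30% of the tokens shared between neighbouring windows.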
    tokenize_result = tokenizer.batch_encode_plus(
        [context],
        stride=MAX_LENGTH - int(MAX_LENGTH * 0.7),
        max_length=MAX_LENGTH,
        truncation=True,
        add_special_tokens=False,
        return_overflowing_tokens=True,
        return_length=True,
    )
    candidate_questions = []
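
    # Bound generation time by processing at most the first 10 windows.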
    if len(tokenize_result.input_ids) >= 10:
        tokenize_result.input_ids = tokenize_result.input_ids[:10]
    for input_ids in tokenize_result.input_ids:
        candidate_questions += feedback_generation(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            feedback_times=candidate_pool_size,
        )
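
    # Use the genetic-algorithm optimizer from qgg-utils to shrink the
    # candidate pool down to the requested group size.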
    while len(candidate_questions) > question_group_size:
        qgg_optim = GAOptimizer(len(candidate_questions), question_group_size)
        candidate_questions = qgg_optim.optimize(candidate_questions, context)

    # format
    candidate_questions = [f" - {q}" for q in candidate_questions]
    return '\n'.join(candidate_questions)


demo = gr.Interface(
    fn=gen_question_group,
    inputs=[
        gr.Textbox(lines=10, value=default_context, label="Context", placeholder="Paste some context here"),
        gr.Slider(3, 8, step=1, label="Group Size"),
    ],
    outputs=gr.Textbox(
        lines=8,
        label="Generated Question Group",
    ),
)

demo.launch()