Init
Browse files- .gitignore +3 -0
- README.md +2 -1
- app.py +92 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
pyproject.toml
|
2 |
+
.venv/
|
3 |
+
flagged/
|
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
---
|
2 |
title: Question Group Generator
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.4
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: Question Group Generator
|
3 |
+
emoji: 🧑‍🏫
|
4 |
colorFrom: blue
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.4
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
python_version: 3.8.9
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr
from transformers import BartTokenizerFast, BartForConditionalGeneration
import torch
import re
from qgg_utils.optim import GAOptimizer # https://github.com/p208p2002/qgg-utils.git

# Hard cap on the token-id sequence length fed to the BART model
# (used both for chunking the context and truncating generation input).
MAX_LENGTH=512

# Example passage pre-filled in the Gradio textbox so the demo works out of the box.
default_context = "Facebook is an online social media and social networking service owned by American company Meta Platforms. Founded in 2004 by Mark Zuckerberg with fellow Harvard College students and roommates Eduardo Saverin, Andrew McCollum, Dustin Moskovitz, and Chris Hughes, its name comes from the face book directories often given to American university students. Membership was initially limited to Harvard students, gradually expanding to other North American universities and, since 2006, anyone over 13 years old. As of July 2022, Facebook claimed 2.93 billion monthly active users,[6] and ranked third worldwide among the most visited websites as of July 2022. It was the most downloaded mobile app of the 2010s."

# Question-group-generation checkpoint and its matching tokenizer,
# loaded once at import time (CPU inference — see feedback_generation).
model=BartForConditionalGeneration.from_pretrained("p208p2002/qmst-qgg")
tokenizer=BartTokenizerFast.from_pretrained("p208p2002/qmst-qgg")
def feedback_generation(model, tokenizer, input_ids, feedback_times = 3):
    """Generate `feedback_times` questions for one context chunk.

    Each round prepends one BOS token per question generated so far to the
    context ids (a "feedback" signal telling the model how many questions
    precede this one), truncates to MAX_LENGTH, and samples one question.

    Args:
        model: a seq2seq LM exposing `.generate()` (BART here).
        tokenizer: the tokenizer matching `model`.
        input_ids: list[int] token ids of the context chunk (no special tokens).
        feedback_times: number of questions to generate.

    Returns:
        list[str] of `feedback_times` decoded, cleaned question strings.
    """
    outputs = []
    device = 'cpu'
    for _ in range(feedback_times):
        # One BOS marker per already-generated question, plus one for this round.
        # Guard against a tokenizer with no BOS token (original crashed here
        # on None * int while guarding bos_token only at decode time).
        bos_token = tokenizer.bos_token or ''
        gened_text = bos_token * (len(outputs)+1)
        gened_ids = tokenizer(gened_text,add_special_tokens=False)['input_ids']
        input_ids = gened_ids + input_ids
        input_ids = input_ids[:MAX_LENGTH]

        sample_outputs = model.generate(
            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device),
            attention_mask=torch.LongTensor([1]*len(input_ids)).unsqueeze(0).to(device),
            max_length=50,
            early_stopping=True,
            temperature=1.0,
            do_sample=True,
            top_p=0.9,
            top_k=10,
            num_beams=1,
            no_repeat_ngram_size=5,
            num_return_sequences=1,
        )
        sample_output = sample_outputs[0]
        # Decode with special tokens kept, then strip them explicitly; all three
        # special tokens are now guarded uniformly against being None, and plain
        # str.replace is used since the tokens are literal strings, not patterns.
        decode_question = tokenizer.decode(sample_output, skip_special_tokens=False)
        for special_token in (tokenizer.pad_token, tokenizer.eos_token, tokenizer.bos_token):
            if special_token is not None:
                decode_question = decode_question.replace(special_token, '')
        decode_question = decode_question.strip()
        # Drop the task-specific question marker emitted by the model.
        decode_question = decode_question.replace("[Q:]","")
        outputs.append(decode_question)
    return outputs
|
46 |
+
|
47 |
+
def gen_quesion_group(context,question_group_size):
    """Generate a formatted group of questions for `context`.

    Pipeline: chunk the context with a sliding window, over-generate a pool of
    candidate questions (2x the requested size) per chunk, then shrink the pool
    down to `question_group_size` with the genetic-algorithm optimizer.

    Args:
        context: source passage (str) to generate questions from.
        question_group_size: desired number of questions (Gradio slider value,
            may arrive as float — coerced to int).

    Returns:
        str: the selected questions, one per line, each prefixed with " - ".
    """
    question_group_size = int(question_group_size)
    # Over-generate so the optimizer has candidates to choose from.
    candidate_pool_size = question_group_size*2

    # Sliding-window chunking: ~70% overlap stride, hard-capped at MAX_LENGTH.
    tokenize_result = tokenizer.batch_encode_plus(
        [context],
        stride=MAX_LENGTH - int(MAX_LENGTH*0.7),
        max_length=MAX_LENGTH,
        truncation=True,
        add_special_tokens=False,
        return_overflowing_tokens=True,
        return_length=True,
    )
    candidate_questions = []

    # Cap the number of chunks to bound demo latency on very long contexts.
    if len(tokenize_result.input_ids)>=10:
        tokenize_result.input_ids = tokenize_result.input_ids[:10]

    for input_ids in tokenize_result.input_ids:
        candidate_questions += feedback_generation(
            model=model,
            tokenizer=tokenizer,
            input_ids=input_ids,
            feedback_times=candidate_pool_size
        )

    # Shrink the candidate pool down to the requested group size.
    while len(candidate_questions) > question_group_size:
        pool_size = len(candidate_questions)
        qgg_optim = GAOptimizer(pool_size,question_group_size)
        candidate_questions = qgg_optim.optimize(candidate_questions,context)
        # Safety guard: bail out if the optimizer made no progress, so the
        # loop cannot spin forever (the original looped unconditionally).
        if len(candidate_questions) >= pool_size:
            candidate_questions = candidate_questions[:question_group_size]
            break

    # format
    candidate_questions = [f" - {q}" for q in candidate_questions]
    return '\n'.join(candidate_questions)
|
80 |
+
|
81 |
+
# Gradio UI: paste a context, pick a group size, read the generated questions.
context_box = gr.Textbox(
    lines=10,
    value=default_context,
    label="Context",
    placeholder="Paste some context here",
)
group_size_slider = gr.Slider(3, 8, step=1, label="Group Size")
questions_box = gr.Textbox(lines=8, label="Generation Question Group")

demo = gr.Interface(
    fn=gen_quesion_group,
    inputs=[context_box, group_size_slider],
    outputs=questions_box,
)
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==3.4
|
2 |
+
torch==1.12.1
|
3 |
+
transformers==4.22.2
|
4 |
+
git+https://github.com/p208p2002/qgg-utils.git
|
5 |
+
git+https://github.com/voidful/nlg-eval.git@master
|
6 |
+
stanza
|