Spaces:
Sleeping
Sleeping
Krishnan Palanisami
commited on
Upload 3 files
Browse files- main.py +473 -0
- requirements.txt +113 -0
- streamlit.py +511 -0
main.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %load questiongenerator.py
|
2 |
+
import en_core_web_sm
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import re
|
7 |
+
import torch
|
8 |
+
from transformers import (
|
9 |
+
AutoTokenizer,
|
10 |
+
AutoModelForSeq2SeqLM,
|
11 |
+
AutoModelForSequenceClassification,
|
12 |
+
)
|
13 |
+
from typing import Any, List, Mapping, Tuple
|
14 |
+
|
15 |
+
|
16 |
+
class QuestionGenerator:
|
17 |
+
"""A transformer-based NLP system for generating reading comprehension-style questions from
|
18 |
+
texts. It can generate full sentence questions, multiple choice questions, or a mix of the
|
19 |
+
two styles.
|
20 |
+
|
21 |
+
To filter out low quality questions, questions are assigned a score and ranked once they have
|
22 |
+
been generated. Only the top k questions will be returned. This behaviour can be turned off
|
23 |
+
by setting use_evaluator=False.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self) -> None:
|
27 |
+
|
28 |
+
QG_PRETRAINED = "iarfmoose/t5-base-question-generator"
|
29 |
+
self.ANSWER_TOKEN = "<answer>"
|
30 |
+
self.CONTEXT_TOKEN = "<context>"
|
31 |
+
self.SEQ_LENGTH = 512
|
32 |
+
|
33 |
+
self.device = torch.device(
|
34 |
+
"cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
|
36 |
+
self.qg_tokenizer = AutoTokenizer.from_pretrained(
|
37 |
+
QG_PRETRAINED, use_fast=False)
|
38 |
+
self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(QG_PRETRAINED)
|
39 |
+
self.qg_model.to(self.device)
|
40 |
+
self.qg_model.eval()
|
41 |
+
|
42 |
+
self.qa_evaluator = QAEvaluator()
|
43 |
+
|
44 |
+
def generate(
|
45 |
+
self,
|
46 |
+
article: str,
|
47 |
+
use_evaluator: bool = True,
|
48 |
+
num_questions: bool = None,
|
49 |
+
answer_style: str = "all"
|
50 |
+
) -> List:
|
51 |
+
"""Takes an article and generates a set of question and answer pairs. If use_evaluator
|
52 |
+
is True then QA pairs will be ranked and filtered based on their quality. answer_style
|
53 |
+
should selected from ["all", "sentences", "multiple_choice"].
|
54 |
+
"""
|
55 |
+
|
56 |
+
print("Generating questions...\n")
|
57 |
+
|
58 |
+
qg_inputs, qg_answers = self.generate_qg_inputs(article, answer_style)
|
59 |
+
generated_questions = self.generate_questions_from_inputs(qg_inputs)
|
60 |
+
|
61 |
+
message = "{} questions doesn't match {} answers".format(
|
62 |
+
len(generated_questions), len(qg_answers)
|
63 |
+
)
|
64 |
+
assert len(generated_questions) == len(qg_answers), message
|
65 |
+
|
66 |
+
if use_evaluator:
|
67 |
+
print("Evaluating QA pairs...\n")
|
68 |
+
encoded_qa_pairs = self.qa_evaluator.encode_qa_pairs(
|
69 |
+
generated_questions, qg_answers
|
70 |
+
)
|
71 |
+
scores = self.qa_evaluator.get_scores(encoded_qa_pairs)
|
72 |
+
|
73 |
+
if num_questions:
|
74 |
+
qa_list = self._get_ranked_qa_pairs(
|
75 |
+
generated_questions, qg_answers, scores, num_questions
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
qa_list = self._get_ranked_qa_pairs(
|
79 |
+
generated_questions, qg_answers, scores
|
80 |
+
)
|
81 |
+
|
82 |
+
else:
|
83 |
+
print("Skipping evaluation step.\n")
|
84 |
+
qa_list = self._get_all_qa_pairs(generated_questions, qg_answers)
|
85 |
+
|
86 |
+
return qa_list
|
87 |
+
|
88 |
+
def generate_qg_inputs(self, text: str, answer_style: str) -> Tuple[List[str], List[str]]:
|
89 |
+
"""Given a text, returns a list of model inputs and a list of corresponding answers.
|
90 |
+
Model inputs take the form "answer_token <answer text> context_token <context text>" where
|
91 |
+
the answer is a string extracted from the text, and the context is the wider text surrounding
|
92 |
+
the context.
|
93 |
+
"""
|
94 |
+
|
95 |
+
VALID_ANSWER_STYLES = ["all", "sentences", "multiple_choice"]
|
96 |
+
|
97 |
+
if answer_style not in VALID_ANSWER_STYLES:
|
98 |
+
raise ValueError(
|
99 |
+
"Invalid answer style {}. Please choose from {}".format(
|
100 |
+
answer_style, VALID_ANSWER_STYLES
|
101 |
+
)
|
102 |
+
)
|
103 |
+
|
104 |
+
inputs = []
|
105 |
+
answers = []
|
106 |
+
|
107 |
+
if answer_style == "sentences" or answer_style == "all":
|
108 |
+
segments = self._split_into_segments(text)
|
109 |
+
|
110 |
+
for segment in segments:
|
111 |
+
sentences = self._split_text(segment)
|
112 |
+
prepped_inputs, prepped_answers = self._prepare_qg_inputs(
|
113 |
+
sentences, segment
|
114 |
+
)
|
115 |
+
inputs.extend(prepped_inputs)
|
116 |
+
answers.extend(prepped_answers)
|
117 |
+
|
118 |
+
if answer_style == "multiple_choice" or answer_style == "all":
|
119 |
+
sentences = self._split_text(text)
|
120 |
+
prepped_inputs, prepped_answers = self._prepare_qg_inputs_MC(
|
121 |
+
sentences
|
122 |
+
)
|
123 |
+
inputs.extend(prepped_inputs)
|
124 |
+
answers.extend(prepped_answers)
|
125 |
+
|
126 |
+
return inputs, answers
|
127 |
+
|
128 |
+
def generate_questions_from_inputs(self, qg_inputs: List) -> List[str]:
|
129 |
+
"""Given a list of concatenated answers and contexts, with the form:
|
130 |
+
"answer_token <answer text> context_token <context text>", generates a list of
|
131 |
+
questions.
|
132 |
+
"""
|
133 |
+
generated_questions = []
|
134 |
+
|
135 |
+
for qg_input in qg_inputs:
|
136 |
+
question = self._generate_question(qg_input)
|
137 |
+
generated_questions.append(question)
|
138 |
+
|
139 |
+
return generated_questions
|
140 |
+
|
141 |
+
def _split_text(self, text: str) -> List[str]:
|
142 |
+
"""Splits the text into sentences, and attempts to split or truncate long sentences."""
|
143 |
+
MAX_SENTENCE_LEN = 128
|
144 |
+
sentences = re.findall(".*?[.!\?]", text)
|
145 |
+
cut_sentences = []
|
146 |
+
|
147 |
+
for sentence in sentences:
|
148 |
+
if len(sentence) > MAX_SENTENCE_LEN:
|
149 |
+
cut_sentences.extend(re.split("[,;:)]", sentence))
|
150 |
+
|
151 |
+
# remove useless post-quote sentence fragments
|
152 |
+
cut_sentences = [s for s in sentences if len(s.split(" ")) > 5]
|
153 |
+
sentences = sentences + cut_sentences
|
154 |
+
|
155 |
+
return list(set([s.strip(" ") for s in sentences]))
|
156 |
+
|
157 |
+
def _split_into_segments(self, text: str) -> List[str]:
|
158 |
+
"""Splits a long text into segments short enough to be input into the transformer network.
|
159 |
+
Segments are used as context for question generation.
|
160 |
+
"""
|
161 |
+
MAX_TOKENS = 490
|
162 |
+
paragraphs = text.split("\n")
|
163 |
+
tokenized_paragraphs = [
|
164 |
+
self.qg_tokenizer(p)["input_ids"] for p in paragraphs if len(p) > 0
|
165 |
+
]
|
166 |
+
segments = []
|
167 |
+
|
168 |
+
while len(tokenized_paragraphs) > 0:
|
169 |
+
segment = []
|
170 |
+
|
171 |
+
while len(segment) < MAX_TOKENS and len(tokenized_paragraphs) > 0:
|
172 |
+
paragraph = tokenized_paragraphs.pop(0)
|
173 |
+
segment.extend(paragraph)
|
174 |
+
segments.append(segment)
|
175 |
+
|
176 |
+
return [self.qg_tokenizer.decode(s, skip_special_tokens=True) for s in segments]
|
177 |
+
|
178 |
+
def _prepare_qg_inputs(
|
179 |
+
self,
|
180 |
+
sentences: List[str],
|
181 |
+
text: str
|
182 |
+
) -> Tuple[List[str], List[str]]:
|
183 |
+
"""Uses sentences as answers and the text as context. Returns a tuple of (model inputs, answers).
|
184 |
+
Model inputs are "answer_token <answer text> context_token <context text>"
|
185 |
+
"""
|
186 |
+
inputs = []
|
187 |
+
answers = []
|
188 |
+
|
189 |
+
for sentence in sentences:
|
190 |
+
qg_input = f"{self.ANSWER_TOKEN} {sentence} {self.CONTEXT_TOKEN} {text}"
|
191 |
+
inputs.append(qg_input)
|
192 |
+
answers.append(sentence)
|
193 |
+
|
194 |
+
return inputs, answers
|
195 |
+
|
196 |
+
def _prepare_qg_inputs_MC(self, sentences: List[str]) -> Tuple[List[str], List[str]]:
|
197 |
+
"""Performs NER on the text, and uses extracted entities are candidate answers for multiple-choice
|
198 |
+
questions. Sentences are used as context, and entities as answers. Returns a tuple of (model inputs, answers).
|
199 |
+
Model inputs are "answer_token <answer text> context_token <context text>"
|
200 |
+
"""
|
201 |
+
spacy_nlp = en_core_web_sm.load()
|
202 |
+
docs = list(spacy_nlp.pipe(sentences, disable=["parser"]))
|
203 |
+
inputs_from_text = []
|
204 |
+
answers_from_text = []
|
205 |
+
|
206 |
+
for doc, sentence in zip(docs, sentences):
|
207 |
+
entities = doc.ents
|
208 |
+
if entities:
|
209 |
+
|
210 |
+
for entity in entities:
|
211 |
+
qg_input = f"{self.ANSWER_TOKEN} {entity} {self.CONTEXT_TOKEN} {sentence}"
|
212 |
+
answers = self._get_MC_answers(entity, docs)
|
213 |
+
inputs_from_text.append(qg_input)
|
214 |
+
answers_from_text.append(answers)
|
215 |
+
|
216 |
+
return inputs_from_text, answers_from_text
|
217 |
+
|
218 |
+
def _get_MC_answers(self, correct_answer: Any, docs: Any) -> List[Mapping[str, Any]]:
|
219 |
+
"""Finds a set of alternative answers for a multiple-choice question. Will attempt to find
|
220 |
+
alternatives of the same entity type as correct_answer if possible.
|
221 |
+
"""
|
222 |
+
entities = []
|
223 |
+
|
224 |
+
for doc in docs:
|
225 |
+
entities.extend([{"text": e.text, "label_": e.label_} for e in doc.ents])
|
226 |
+
|
227 |
+
# Remove duplicate elements and convert to a list
|
228 |
+
entities_json = [json.dumps(kv) for kv in entities]
|
229 |
+
pool = sorted(set(entities_json)) # Convert pool to a sorted list
|
230 |
+
num_choices = min(4, len(pool)) - 1 # Number of choices to make
|
231 |
+
|
232 |
+
# Add the correct answer
|
233 |
+
final_choices = []
|
234 |
+
correct_label = correct_answer.label_
|
235 |
+
final_choices.append({"answer": correct_answer.text, "correct": True})
|
236 |
+
|
237 |
+
# Remove the correct answer from the pool
|
238 |
+
pool = [e for e in pool if e != json.dumps({"text": correct_answer.text, "label_": correct_answer.label_})]
|
239 |
+
|
240 |
+
# Find answers with the same NER label
|
241 |
+
matches = [e for e in pool if correct_label in e]
|
242 |
+
|
243 |
+
# If not enough matches, add other random answers
|
244 |
+
if len(matches) < num_choices:
|
245 |
+
choices = matches
|
246 |
+
remaining_choices = random.sample(sorted(pool), num_choices - len(choices))
|
247 |
+
choices.extend(remaining_choices)
|
248 |
+
else:
|
249 |
+
choices = random.sample(sorted(matches), num_choices)
|
250 |
+
|
251 |
+
choices = [json.loads(s) for s in choices]
|
252 |
+
|
253 |
+
for choice in choices:
|
254 |
+
final_choices.append({"answer": choice["text"], "correct": False})
|
255 |
+
|
256 |
+
random.shuffle(final_choices)
|
257 |
+
return final_choices
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
# def _get_MC_answers(self, correct_answer: Any, docs: Any) -> List[Mapping[str, Any]]:
|
262 |
+
# """Finds a set of alternative answers for a multiple-choice question. Will attempt to find
|
263 |
+
# alternatives of the same entity type as correct_answer if possible.
|
264 |
+
# """
|
265 |
+
# entities = []
|
266 |
+
|
267 |
+
# for doc in docs:
|
268 |
+
# entities.extend([{"text": e.text, "label_": e.label_}
|
269 |
+
# for e in doc.ents])
|
270 |
+
|
271 |
+
# # remove duplicate elements
|
272 |
+
# entities_json = [json.dumps(kv) for kv in entities]
|
273 |
+
# pool = set(entities_json)
|
274 |
+
# num_choices = (
|
275 |
+
# min(4, len(pool)) - 1
|
276 |
+
# ) # -1 because we already have the correct answer
|
277 |
+
|
278 |
+
# # add the correct answer
|
279 |
+
# final_choices = []
|
280 |
+
# correct_label = correct_answer.label_
|
281 |
+
# final_choices.append({"answer": correct_answer.text, "correct": True})
|
282 |
+
# pool.remove(
|
283 |
+
# json.dumps({"text": correct_answer.text,
|
284 |
+
# "label_": correct_answer.label_})
|
285 |
+
# )
|
286 |
+
|
287 |
+
# # find answers with the same NER label
|
288 |
+
# matches = [e for e in pool if correct_label in e]
|
289 |
+
|
290 |
+
# # if we don't have enough then add some other random answers
|
291 |
+
# if len(matches) < num_choices:
|
292 |
+
# choices = matches
|
293 |
+
# pool = pool.difference(set(choices))
|
294 |
+
# choices.extend(random.sample(pool, num_choices - len(choices)))
|
295 |
+
# else:
|
296 |
+
# choices = random.sample(matches, num_choices)
|
297 |
+
|
298 |
+
# choices = [json.loads(s) for s in choices]
|
299 |
+
|
300 |
+
# for choice in choices:
|
301 |
+
# final_choices.append({"answer": choice["text"], "correct": False})
|
302 |
+
|
303 |
+
# random.shuffle(final_choices)
|
304 |
+
# return final_choices
|
305 |
+
|
306 |
+
@torch.no_grad()
|
307 |
+
def _generate_question(self, qg_input: str) -> str:
|
308 |
+
"""Takes qg_input which is the concatenated answer and context, and uses it to generate
|
309 |
+
a question sentence. The generated question is decoded and then returned.
|
310 |
+
"""
|
311 |
+
encoded_input = self._encode_qg_input(qg_input)
|
312 |
+
output = self.qg_model.generate(input_ids=encoded_input["input_ids"])
|
313 |
+
question = self.qg_tokenizer.decode(
|
314 |
+
output[0],
|
315 |
+
skip_special_tokens=True
|
316 |
+
)
|
317 |
+
return question
|
318 |
+
|
319 |
+
def _encode_qg_input(self, qg_input: str) -> torch.tensor:
|
320 |
+
"""Tokenizes a string and returns a tensor of input ids corresponding to indices of tokens in
|
321 |
+
the vocab.
|
322 |
+
"""
|
323 |
+
return self.qg_tokenizer(
|
324 |
+
qg_input,
|
325 |
+
padding='max_length',
|
326 |
+
max_length=self.SEQ_LENGTH,
|
327 |
+
truncation=True,
|
328 |
+
return_tensors="pt",
|
329 |
+
).to(self.device)
|
330 |
+
|
331 |
+
def _get_ranked_qa_pairs(
|
332 |
+
self, generated_questions: List[str], qg_answers: List[str], scores, num_questions: int = 10
|
333 |
+
) -> List[Mapping[str, str]]:
|
334 |
+
"""Ranks generated questions according to scores, and returns the top num_questions examples.
|
335 |
+
"""
|
336 |
+
if num_questions > len(scores):
|
337 |
+
num_questions = len(scores)
|
338 |
+
print((
|
339 |
+
f"\nWas only able to generate {num_questions} questions.",
|
340 |
+
"For more questions, please input a longer text.")
|
341 |
+
)
|
342 |
+
|
343 |
+
qa_list = []
|
344 |
+
|
345 |
+
for i in range(num_questions):
|
346 |
+
index = scores[i]
|
347 |
+
qa = {
|
348 |
+
"question": generated_questions[index].split("?")[0] + "?",
|
349 |
+
"answer": qg_answers[index]
|
350 |
+
}
|
351 |
+
qa_list.append(qa)
|
352 |
+
|
353 |
+
return qa_list
|
354 |
+
|
355 |
+
def _get_all_qa_pairs(self, generated_questions: List[str], qg_answers: List[str]):
|
356 |
+
"""Formats question and answer pairs without ranking or filtering."""
|
357 |
+
qa_list = []
|
358 |
+
|
359 |
+
for question, answer in zip(generated_questions, qg_answers):
|
360 |
+
qa = {
|
361 |
+
"question": question.split("?")[0] + "?",
|
362 |
+
"answer": answer
|
363 |
+
}
|
364 |
+
qa_list.append(qa)
|
365 |
+
|
366 |
+
return qa_list
|
367 |
+
|
368 |
+
|
369 |
+
class QAEvaluator:
|
370 |
+
"""Wrapper for a transformer model which evaluates the quality of question-answer pairs.
|
371 |
+
Given a QA pair, the model will generate a score. Scores can be used to rank and filter
|
372 |
+
QA pairs.
|
373 |
+
"""
|
374 |
+
|
375 |
+
def __init__(self) -> None:
|
376 |
+
|
377 |
+
QAE_PRETRAINED = "iarfmoose/bert-base-cased-qa-evaluator"
|
378 |
+
self.SEQ_LENGTH = 512
|
379 |
+
|
380 |
+
self.device = torch.device(
|
381 |
+
"cuda" if torch.cuda.is_available() else "cpu")
|
382 |
+
|
383 |
+
self.qae_tokenizer = AutoTokenizer.from_pretrained(QAE_PRETRAINED)
|
384 |
+
self.qae_model = AutoModelForSequenceClassification.from_pretrained(
|
385 |
+
QAE_PRETRAINED
|
386 |
+
)
|
387 |
+
self.qae_model.to(self.device)
|
388 |
+
self.qae_model.eval()
|
389 |
+
|
390 |
+
def encode_qa_pairs(self, questions: List[str], answers: List[str]) -> List[torch.tensor]:
|
391 |
+
"""Takes a list of questions and a list of answers and encodes them as a list of tensors."""
|
392 |
+
encoded_pairs = []
|
393 |
+
|
394 |
+
for question, answer in zip(questions, answers):
|
395 |
+
encoded_qa = self._encode_qa(question, answer)
|
396 |
+
encoded_pairs.append(encoded_qa.to(self.device))
|
397 |
+
|
398 |
+
return encoded_pairs
|
399 |
+
|
400 |
+
def get_scores(self, encoded_qa_pairs: List[torch.tensor]) -> List[float]:
|
401 |
+
"""Generates scores for a list of encoded QA pairs."""
|
402 |
+
scores = {}
|
403 |
+
|
404 |
+
for i in range(len(encoded_qa_pairs)):
|
405 |
+
scores[i] = self._evaluate_qa(encoded_qa_pairs[i])
|
406 |
+
|
407 |
+
return [
|
408 |
+
k for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)
|
409 |
+
]
|
410 |
+
|
411 |
+
def _encode_qa(self, question: str, answer: str) -> torch.tensor:
|
412 |
+
"""Concatenates a question and answer, and then tokenizes them. Returns a tensor of
|
413 |
+
input ids corresponding to indices in the vocab.
|
414 |
+
"""
|
415 |
+
if type(answer) is list:
|
416 |
+
for a in answer:
|
417 |
+
if a["correct"]:
|
418 |
+
correct_answer = a["answer"]
|
419 |
+
else:
|
420 |
+
correct_answer = answer
|
421 |
+
|
422 |
+
return self.qae_tokenizer(
|
423 |
+
text=question,
|
424 |
+
text_pair=correct_answer,
|
425 |
+
padding="max_length",
|
426 |
+
max_length=self.SEQ_LENGTH,
|
427 |
+
truncation=True,
|
428 |
+
return_tensors="pt",
|
429 |
+
)
|
430 |
+
|
431 |
+
@torch.no_grad()
|
432 |
+
def _evaluate_qa(self, encoded_qa_pair: torch.tensor) -> float:
|
433 |
+
"""Takes an encoded QA pair and returns a score."""
|
434 |
+
output = self.qae_model(**encoded_qa_pair)
|
435 |
+
return output[0][0][1]
|
436 |
+
|
437 |
+
|
438 |
+
def print_qa(qa_list: List[Mapping[str, str]], show_answers: bool = True) -> None:
|
439 |
+
"""Formats and prints a list of generated questions and answers."""
|
440 |
+
|
441 |
+
for i in range(len(qa_list)):
|
442 |
+
# wider space for 2 digit q nums
|
443 |
+
space = " " * int(np.where(i < 9, 3, 4))
|
444 |
+
|
445 |
+
print(f"{i + 1}) Q: {qa_list[i]['question']}")
|
446 |
+
|
447 |
+
answer = qa_list[i]["answer"]
|
448 |
+
|
449 |
+
# print a list of multiple choice answers
|
450 |
+
if type(answer) is list:
|
451 |
+
|
452 |
+
if show_answers:
|
453 |
+
print(
|
454 |
+
f"{space}A: 1. {answer[0]['answer']} "
|
455 |
+
f"{np.where(answer[0]['correct'], '(correct)', '')}"
|
456 |
+
)
|
457 |
+
for j in range(1, len(answer)):
|
458 |
+
print(
|
459 |
+
f"{space + ' '}{j + 1}. {answer[j]['answer']} "
|
460 |
+
f"{np.where(answer[j]['correct']==True,'(correct)', '')}"
|
461 |
+
)
|
462 |
+
|
463 |
+
else:
|
464 |
+
print(f"{space}A: 1. {answer[0]['answer']}")
|
465 |
+
for j in range(1, len(answer)):
|
466 |
+
print(f"{space + ' '}{j + 1}. {answer[j]['answer']}")
|
467 |
+
|
468 |
+
print("")
|
469 |
+
|
470 |
+
# print full sentence answers
|
471 |
+
else:
|
472 |
+
if show_answers:
|
473 |
+
print(f"{space}A: {answer}\n")
|
requirements.txt
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
annotated-types==0.6.0
|
2 |
+
anyascii==0.3.2
|
3 |
+
anyio==4.3.0
|
4 |
+
appdirs==1.4.4
|
5 |
+
attrs==23.2.0
|
6 |
+
backoff==2.2.1
|
7 |
+
beautifulsoup4==4.12.3
|
8 |
+
blis==0.7.11
|
9 |
+
boilerpy3==1.0.7
|
10 |
+
catalogue==2.0.10
|
11 |
+
cattrs==23.2.3
|
12 |
+
certifi==2024.2.2
|
13 |
+
charset-normalizer==3.3.2
|
14 |
+
click==8.1.7
|
15 |
+
cloudpathlib==0.16.0
|
16 |
+
colorama==0.4.6
|
17 |
+
confection==0.1.4
|
18 |
+
contractions==0.1.73
|
19 |
+
cymem==2.0.8
|
20 |
+
docopt==0.6.2
|
21 |
+
editdistance==0.8.1
|
22 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
|
23 |
+
Events==0.5
|
24 |
+
farm-haystack==1.25.2
|
25 |
+
filelock==3.13.4
|
26 |
+
flashtext==2.7
|
27 |
+
fsspec==2024.3.1
|
28 |
+
future==1.0.0
|
29 |
+
h11==0.14.0
|
30 |
+
httpcore==1.0.5
|
31 |
+
httpx==0.27.0
|
32 |
+
huggingface-hub==0.22.2
|
33 |
+
idna==3.7
|
34 |
+
inflect==7.2.0
|
35 |
+
interaction==1.3
|
36 |
+
jellyfish==0.8.2
|
37 |
+
Jinja2==3.1.3
|
38 |
+
joblib==1.4.0
|
39 |
+
jsonschema==4.21.1
|
40 |
+
jsonschema-specifications==2023.12.1
|
41 |
+
langcodes==3.3.0
|
42 |
+
lazy-imports==0.3.1
|
43 |
+
logging==0.4.9.6
|
44 |
+
MarkupSafe==2.1.5
|
45 |
+
monotonic==1.6
|
46 |
+
more-itertools==10.2.0
|
47 |
+
mpmath==1.3.0
|
48 |
+
murmurhash==1.0.10
|
49 |
+
networkx==3.3
|
50 |
+
nltk==3.8.1
|
51 |
+
num2words==0.5.13
|
52 |
+
numpy==1.26.4
|
53 |
+
packaging==24.0
|
54 |
+
pandas==2.2.2
|
55 |
+
pathlib==1.0.1
|
56 |
+
pillow==10.3.0
|
57 |
+
pke @ git+https://github.com/boudinfl/pke.git@69871ffdb720b83df23684fea53ec8776fd87e63
|
58 |
+
platformdirs==4.2.0
|
59 |
+
posthog==3.5.0
|
60 |
+
preshed==3.0.9
|
61 |
+
prompthub-py==4.0.0
|
62 |
+
protobuf==5.26.1
|
63 |
+
pyahocorasick==2.1.0
|
64 |
+
pydantic==1.10.15
|
65 |
+
pydantic_core==2.18.1
|
66 |
+
python-dateutil==2.9.0.post0
|
67 |
+
pytz==2024.1
|
68 |
+
PyYAML==6.0.1
|
69 |
+
quantulum3==0.9.1
|
70 |
+
rank-bm25==0.2.2
|
71 |
+
referencing==0.34.0
|
72 |
+
regex==2024.4.16
|
73 |
+
requests==2.31.0
|
74 |
+
requests-cache==0.9.8
|
75 |
+
rpds-py==0.18.0
|
76 |
+
safetensors==0.4.3
|
77 |
+
scikit-learn==1.4.2
|
78 |
+
scipy==1.13.0
|
79 |
+
sense2vec==2.0.2
|
80 |
+
sentence-transformers==2.7.0
|
81 |
+
sentencepiece==0.2.0
|
82 |
+
similarity==0.0.1
|
83 |
+
six==1.16.0
|
84 |
+
smart-open==6.4.0
|
85 |
+
sniffio==1.3.1
|
86 |
+
soupsieve==2.5
|
87 |
+
spacy==3.7.4
|
88 |
+
spacy-legacy==3.0.12
|
89 |
+
spacy-loggers==1.0.5
|
90 |
+
srsly==2.4.8
|
91 |
+
sseclient-py==1.8.0
|
92 |
+
strsim==0.0.3
|
93 |
+
sympy==1.12
|
94 |
+
tenacity==8.2.3
|
95 |
+
textsearch==0.0.24
|
96 |
+
textwrap3==0.9.2
|
97 |
+
thinc==8.2.3
|
98 |
+
threadpoolctl==3.4.0
|
99 |
+
tiktoken==0.6.0
|
100 |
+
tokenizers==0.15.2
|
101 |
+
torch==2.2.2
|
102 |
+
tqdm==4.66.2
|
103 |
+
transformers==4.37.2
|
104 |
+
typeguard==4.2.1
|
105 |
+
typer==0.9.4
|
106 |
+
typing_extensions==4.11.0
|
107 |
+
tzdata==2024.1
|
108 |
+
Unidecode==1.3.8
|
109 |
+
url-normalize==1.4.3
|
110 |
+
urllib3==2.2.1
|
111 |
+
wasabi==1.1.2
|
112 |
+
weasel==0.3.4
|
113 |
+
wikipedia==1.4.0
|
streamlit.py
ADDED
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import wikipedia
|
3 |
+
from haystack.document_stores import InMemoryDocumentStore
|
4 |
+
from haystack.utils import clean_wiki_text, convert_files_to_docs
|
5 |
+
from haystack.nodes import TfidfRetriever, FARMReader
|
6 |
+
from haystack.pipelines import ExtractiveQAPipeline
|
7 |
+
from main import print_qa, QuestionGenerator
|
8 |
+
|
9 |
+
def main():
|
10 |
+
# Set the Streamlit app title
|
11 |
+
st.title("Question Generation using Haystack and Streamlit")
|
12 |
+
|
13 |
+
# Select the input type
|
14 |
+
inputs = ["Input Paragraph", "Wikipedia Examples"]
|
15 |
+
input_type = st.selectbox("Select an input type:", inputs)
|
16 |
+
|
17 |
+
# Initialize wiki_text as an empty string
|
18 |
+
wiki_text = ""
|
19 |
+
|
20 |
+
# Handle different input types
|
21 |
+
if input_type == "Input Paragraph":
|
22 |
+
# Allow user to input text paragraph
|
23 |
+
wiki_text = st.text_area("Input paragraph:", height=200)
|
24 |
+
|
25 |
+
elif input_type == "Wikipedia Examples":
|
26 |
+
# Define topics for selection
|
27 |
+
topics = ["Deep Learning", "Machine Learning"]
|
28 |
+
selected_topic = st.selectbox("Select a topic:", topics)
|
29 |
+
|
30 |
+
# Retrieve Wikipedia content based on the selected topic
|
31 |
+
if selected_topic:
|
32 |
+
wiki = wikipedia.page(selected_topic)
|
33 |
+
wiki_text = wiki.content
|
34 |
+
|
35 |
+
# Display the retrieved Wikipedia content (optional)
|
36 |
+
st.text_area("Retrieved Wikipedia content:", wiki_text, height=200)
|
37 |
+
|
38 |
+
# Preprocess the input text
|
39 |
+
wiki_text = clean_wiki_text(wiki_text)
|
40 |
+
|
41 |
+
# Allow user to specify the number of questions to generate
|
42 |
+
num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
|
43 |
+
|
44 |
+
# Allow user to specify the model to use
|
45 |
+
model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2", "deepset/flan-t5-xl-squad2"]
|
46 |
+
model_name = st.selectbox("Select model:", model_options)
|
47 |
+
|
48 |
+
# Button to generate questions
|
49 |
+
if st.button("Generate Questions"):
|
50 |
+
document_store = InMemoryDocumentStore()
|
51 |
+
|
52 |
+
# Convert the preprocessed text into a document
|
53 |
+
document = {"content": wiki_text}
|
54 |
+
document_store.write_documents([document])
|
55 |
+
|
56 |
+
# Initialize a TfidfRetriever
|
57 |
+
retriever = TfidfRetriever(document_store=document_store)
|
58 |
+
|
59 |
+
# Initialize a FARMReader with the selected model
|
60 |
+
reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
61 |
+
|
62 |
+
# Initialize the question generation pipeline
|
63 |
+
pipe = ExtractiveQAPipeline(reader, retriever)
|
64 |
+
|
65 |
+
# Initialize the QuestionGenerator
|
66 |
+
qg = QuestionGenerator()
|
67 |
+
|
68 |
+
# Generate multiple-choice questions
|
69 |
+
qa_list = qg.generate(
|
70 |
+
wiki_text,
|
71 |
+
num_questions=num_questions,
|
72 |
+
answer_style='multiple_choice'
|
73 |
+
)
|
74 |
+
|
75 |
+
# Display the generated questions and answers
|
76 |
+
st.header("Generated Questions and Answers:")
|
77 |
+
for idx, qa in enumerate(qa_list):
|
78 |
+
# Display the question
|
79 |
+
st.write(f"Question {idx + 1}: {qa['question']}")
|
80 |
+
|
81 |
+
# Display the answer options
|
82 |
+
if 'answer' in qa:
|
83 |
+
for i, option in enumerate(qa['answer']):
|
84 |
+
correct_marker = "(correct)" if option["correct"] else ""
|
85 |
+
st.write(f"Option {i + 1}: {option['answer']} {correct_marker}")
|
86 |
+
|
87 |
+
# Add a separator after each question-answer pair
|
88 |
+
st.write("-" * 40)
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
# Run the Streamlit app
|
97 |
+
if __name__ == "__main__":
|
98 |
+
main()
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
# import streamlit as st
|
103 |
+
# import wikipedia
|
104 |
+
# from haystack.document_stores import InMemoryDocumentStore
|
105 |
+
# from haystack.utils import clean_wiki_text, convert_files_to_docs
|
106 |
+
# from haystack.nodes import TfidfRetriever, FARMReader
|
107 |
+
# from haystack.pipelines import ExtractiveQAPipeline
|
108 |
+
# from main import print_qa, QuestionGenerator
|
109 |
+
# import torch
|
110 |
+
|
111 |
+
# def main():
|
112 |
+
# # Set the Streamlit app title
|
113 |
+
# st.title("Question Generation using Haystack and Streamlit")
|
114 |
+
|
115 |
+
# # Select the input type
|
116 |
+
# inputs = ["Input Paragraph", "Wikipedia Examples"]
|
117 |
+
# input_type = st.selectbox("Select an input type:", inputs, key="input_type")
|
118 |
+
|
119 |
+
# # Initialize wiki_text as an empty string (to avoid UnboundLocalError)
|
120 |
+
# wiki_text = """ Deep learning is the subset of machine learning methods based on artificial neural networks (ANNs) with representation learning. The adjective "deep" refers to the use of multiple layers in the network. Methods used can be either supervised, semi-supervised or unsupervised.Deep-learning architectures such as deep neural networks, deep belief networks, recurrent neural networks, convolutional neural networks and transformers have been applied to fields including computer vision, speech recognition, natural language processing, machine translation, bioinformatics, drug design, medical image analysis, climate science, material inspection and board game programs, where they have produced results comparable to and in some cases surpassing human expert performance.Artificial neural networks were inspired by information processing and distributed communication nodes in biological systems. ANNs have various differences from biological brains. Specifically, artificial neural networks tend to be static and symbolic, while the biological brain of most living organisms is dynamic (plastic) and analog. ANNs are generally seen as low quality models for brain function."""
|
121 |
+
|
122 |
+
# # Handle different input types
|
123 |
+
# if input_type == "Input Paragraph":
|
124 |
+
# # Allow user to input text paragraph
|
125 |
+
# wiki_text = st.text_area("Input paragraph:", height=200, key="input_paragraph")
|
126 |
+
|
127 |
+
# elif input_type == "Wikipedia Examples":
|
128 |
+
# # Define options for selecting the topic
|
129 |
+
# topics = ["Deep Learning", "Machine Learning"]
|
130 |
+
# selected_topic = st.selectbox("Select a topic:", topics, key="wiki_topic")
|
131 |
+
|
132 |
+
# # Retrieve Wikipedia content based on the selected topic
|
133 |
+
# if selected_topic:
|
134 |
+
# wiki = wikipedia.page(selected_topic)
|
135 |
+
# wiki_text = wiki.content
|
136 |
+
|
137 |
+
# # Display the retrieved Wikipedia content (optional)
|
138 |
+
# st.text_area("Retrieved Wikipedia content:", wiki_text, height=200, key="wiki_text")
|
139 |
+
|
140 |
+
# # Allow user to specify the number of questions to generate
|
141 |
+
# num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5, key="num_questions")
|
142 |
+
|
143 |
+
# # Allow user to specify the model to use
|
144 |
+
# model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2", "deepset/flan-t5-xl-squad2"]
|
145 |
+
# model_name = st.selectbox("Select model:", model_options, key="model_name")
|
146 |
+
|
147 |
+
# # Button to generate questions
|
148 |
+
# if st.button("Generate Questions", key="generate_button"):
|
149 |
+
# # Initialize the document store
|
150 |
+
# with open('wiki_txt.txt', 'w', encoding='utf-8') as f:
|
151 |
+
# f.write(wiki_text)
|
152 |
+
# document_store = InMemoryDocumentStore()
|
153 |
+
# doc_dir = "/content"
|
154 |
+
# docs = convert_files_to_docs(dir_path=doc_dir, clean_func=clean_wiki_text, split_paragraphs=True)
|
155 |
+
# document_store.write_documents(docs)
|
156 |
+
# retriever = TfidfRetriever(document_store=document_store)
|
157 |
+
|
158 |
+
# # # Convert the input text paragraph or Wikipedia content into a document
|
159 |
+
# # document = {"content": wiki_text}
|
160 |
+
# # document_store.write_documents([document])
|
161 |
+
|
162 |
+
# # Initialize a TfidfRetriever
|
163 |
+
# # retriever = TfidfRetriever(document_store=document_store)
|
164 |
+
|
165 |
+
# # Initialize a FARMReader with the selected model
|
166 |
+
# reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
167 |
+
|
168 |
+
# # Initialize the question generation pipeline
|
169 |
+
# pipe = ExtractiveQAPipeline(reader, retriever)
|
170 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
171 |
+
|
172 |
+
# # Initialize the QuestionGenerator
|
173 |
+
# qg = QuestionGenerator()
|
174 |
+
|
175 |
+
# # Generate multiple-choice questions
|
176 |
+
# qa_list = qg.generate(wiki_text, num_questions=num_questions, answer_style='multiple_choice')
|
177 |
+
|
178 |
+
# # Display the generated questions and answers
|
179 |
+
# st.header("Generated Questions and Answers:")
|
180 |
+
# for idx, qa in enumerate(qa_list):
|
181 |
+
# # Display the question
|
182 |
+
# st.write(f"Question {idx + 1}: {qa['question']}")
|
183 |
+
|
184 |
+
# # Display the answer options
|
185 |
+
# if 'answer' in qa:
|
186 |
+
# for i, option in enumerate(qa['answer']):
|
187 |
+
# correct_marker = "(correct)" if option["correct"] else ""
|
188 |
+
# st.write(f"Option {i + 1}: {option['answer']} {correct_marker}")
|
189 |
+
|
190 |
+
# # Add a separator after each question-answer pair
|
191 |
+
# st.write("-" * 40)
|
192 |
+
|
193 |
+
# # Run the Streamlit app
|
194 |
+
# if __name__ == "__main__":
|
195 |
+
# main()
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
# # import streamlit as st
|
224 |
+
# # import wikipedia
|
225 |
+
# # from haystack.document_stores import InMemoryDocumentStore
|
226 |
+
# # from haystack.utils import clean_wiki_text, convert_files_to_docs
|
227 |
+
# # from haystack.nodes import TfidfRetriever, FARMReader
|
228 |
+
# # from haystack.pipelines import ExtractiveQAPipeline
|
229 |
+
# # from main import print_qa, QuestionGenerator
|
230 |
+
|
231 |
+
# # def main():
|
232 |
+
# # # Set the Streamlit app title
|
233 |
+
# # st.title("Question Generation using Haystack and Streamlit")
|
234 |
+
# # # select the input type
|
235 |
+
# # inputs = ["Input Paragraph", "Wikipedia Examples"]
|
236 |
+
# # input=st.selectbox("Select a Input Type :", inputs)
|
237 |
+
# # if(input=="Input Paragraph"):
|
238 |
+
# # # Allow user to input text paragraph
|
239 |
+
# # wiki_text = st.text_area("Input paragraph:", height=200)
|
240 |
+
|
241 |
+
# # # # Allow user to specify the number of questions to generate
|
242 |
+
# # # num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
|
243 |
+
|
244 |
+
# # # # Allow user to specify the model to use
|
245 |
+
# # # model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2","deepset/flan-t5-xl-squad2"]
|
246 |
+
# # # model_name = st.selectbox("Select model:", model_options)
|
247 |
+
|
248 |
+
# # # # Button to generate questions
|
249 |
+
# # # if st.button("Generate Questions"):
|
250 |
+
# # # qno=0
|
251 |
+
|
252 |
+
# # # # Initialize the document store
|
253 |
+
# # # document_store = InMemoryDocumentStore()
|
254 |
+
|
255 |
+
# # # # Convert the input text paragraph into a document
|
256 |
+
# # # document = {"content": wiki_text}
|
257 |
+
# # # document_store.write_documents([document])
|
258 |
+
|
259 |
+
# # # # Initialize a TfidfRetriever
|
260 |
+
# # # retriever = TfidfRetriever(document_store=document_store)
|
261 |
+
|
262 |
+
# # # # Initialize a FARMReader with the selected model
|
263 |
+
# # # reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
264 |
+
|
265 |
+
# # # # Initialize the question generation pipeline
|
266 |
+
# # # pipe = ExtractiveQAPipeline(reader, retriever)
|
267 |
+
|
268 |
+
# # # # Initialize the QuestionGenerator
|
269 |
+
# # # qg = QuestionGenerator()
|
270 |
+
|
271 |
+
# # # # Generate multiple-choice questions
|
272 |
+
# # # qa_list = qg.generate(
|
273 |
+
# # # wiki_text,
|
274 |
+
# # # num_questions=num_questions,
|
275 |
+
# # # answer_style='multiple_choice')
|
276 |
+
# # # print("QA List Structure:")
|
277 |
+
# # # # Display the generated questions and answers
|
278 |
+
# # # st.header("Generated Questions and Answers:")
|
279 |
+
# # # for qa in qa_list:
|
280 |
+
# # # opno=0
|
281 |
+
|
282 |
+
# # # # Display the question
|
283 |
+
# # # st.write(f"Question: {qno+1}{qa['question']}")
|
284 |
+
|
285 |
+
# # # # Display the answer options
|
286 |
+
# # # if 'answer' in qa:
|
287 |
+
# # # for idx, option in enumerate(qa['answer']):
|
288 |
+
# # # # Indicate if the option is correct
|
289 |
+
# # # correct_marker = "(correct)" if option["correct"] else ""
|
290 |
+
# # # st.write(f"Option {idx + 1}: {option['answer']} {correct_marker}")
|
291 |
+
|
292 |
+
# # # # Add a separator after each question-answer pair
|
293 |
+
# # # st.write("-" * 40)
|
294 |
+
|
295 |
+
# # if(input == "Wikipedia Examples"):
|
296 |
+
# # # Define options for selecting the topic
|
297 |
+
# # topics = ["Deep Learning", "MachineLearning"]
|
298 |
+
# # selected_topic = st.selectbox("Select a topic:", topics)
|
299 |
+
|
300 |
+
# # # Retrieve Wikipedia content based on the selected topic
|
301 |
+
# # if selected_topic:
|
302 |
+
# # wiki = wikipedia.page(selected_topic)
|
303 |
+
# # wiki_text = wiki.content
|
304 |
+
|
305 |
+
# # # Display the retrieved Wikipedia content in a text area (optional)
|
306 |
+
# # st.text_area("Retrieved Wikipedia content:", wiki_text, height=200)
|
307 |
+
|
308 |
+
# # # # Allow user to specify the number of questions to generate
|
309 |
+
# # # num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
|
310 |
+
|
311 |
+
# # # # Allow user to specify the model to use
|
312 |
+
# # # model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2","deepset/flan-t5-xl-squad2"]
|
313 |
+
# # # model_name = st.selectbox("Select model:", model_options)
|
314 |
+
|
315 |
+
# # # # Button to generate questions
|
316 |
+
# # # if st.button("Generate Questions"):
|
317 |
+
# # # # Initialize the document store
|
318 |
+
# # # document_store = InMemoryDocumentStore()
|
319 |
+
|
320 |
+
# # # # Convert the retrieved Wikipedia content into a document
|
321 |
+
# # # document = {"content": wiki_text}
|
322 |
+
# # # document_store.write_documents([document])
|
323 |
+
|
324 |
+
# # # # Initialize a TfidfRetriever
|
325 |
+
# # # retriever = TfidfRetriever(document_store=document_store)
|
326 |
+
|
327 |
+
# # # # Initialize a FARMReader with the selected model
|
328 |
+
# # # reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
329 |
+
|
330 |
+
# # # # Initialize the ExtractiveQAPipeline
|
331 |
+
# # # pipeline = ExtractiveQAPipeline(reader, retriever)
|
332 |
+
|
333 |
+
# # # # Initialize the QuestionGenerator
|
334 |
+
# # # qg = QuestionGenerator()
|
335 |
+
|
336 |
+
# # # # Generate multiple-choice questions
|
337 |
+
# # # qa_list = qg.generate(
|
338 |
+
# # # wiki_text,
|
339 |
+
# # # num_questions=num_questions,
|
340 |
+
# # # answer_style='multiple_choice'
|
341 |
+
# # # )
|
342 |
+
|
343 |
+
# # # # Display the generated questions and answers
|
344 |
+
# # # st.header("Generated Questions and Answers:")
|
345 |
+
# # # for idx, qa in enumerate(qa_list):
|
346 |
+
# # # # Display the question
|
347 |
+
# # # st.write(f"Question {idx + 1}: {qa['question']}")
|
348 |
+
|
349 |
+
# # # # Display the answer options
|
350 |
+
# # # if 'answer' in qa:
|
351 |
+
# # # for i, option in enumerate(qa['answer']):
|
352 |
+
# # # correct_marker = "(correct)" if option["correct"] else ""
|
353 |
+
# # # st.write(f"Option {i + 1}: {option['answer']} {correct_marker}")
|
354 |
+
|
355 |
+
# # # # Add a separator after each question-answer pair
|
356 |
+
# # # st.write("-" * 40)
|
357 |
+
|
358 |
+
# # # Allow user to specify the number of questions to generate
|
359 |
+
# # num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
|
360 |
+
# # # Allow user to specify the model to use
|
361 |
+
# # model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2","deepset/flan-t5-xl-squad2"]
|
362 |
+
# # model_name = st.selectbox("Select model:", model_options)
|
363 |
+
|
364 |
+
# # # Button to generate questions
|
365 |
+
# # if st.button("Generate Questions"):
|
366 |
+
# # qno=0
|
367 |
+
|
368 |
+
# # # Initialize the document store
|
369 |
+
# # document_store = InMemoryDocumentStore()
|
370 |
+
|
371 |
+
# # # Convert the input text paragraph into a document
|
372 |
+
# # document = {"content": wiki_text}
|
373 |
+
# # document_store.write_documents([document])
|
374 |
+
|
375 |
+
# # # Initialize a TfidfRetriever
|
376 |
+
# # retriever = TfidfRetriever(document_store=document_store)
|
377 |
+
|
378 |
+
# # # Initialize a FARMReader with the selected model
|
379 |
+
# # reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
380 |
+
|
381 |
+
# # # Initialize the question generation pipeline
|
382 |
+
# # pipe = ExtractiveQAPipeline(reader, retriever)
|
383 |
+
|
384 |
+
# # # Initialize the QuestionGenerator
|
385 |
+
# # qg = QuestionGenerator()
|
386 |
+
|
387 |
+
# # # Generate multiple-choice questions
|
388 |
+
# # qa_list = qg.generate(
|
389 |
+
# # wiki_text,
|
390 |
+
# # num_questions=num_questions,
|
391 |
+
# # answer_style='multiple_choice')
|
392 |
+
# # print("QA List Structure:")
|
393 |
+
# # # Display the generated questions and answers
|
394 |
+
# # st.header("Generated Questions and Answers:")
|
395 |
+
# # for qa in qa_list:
|
396 |
+
# # opno=0
|
397 |
+
|
398 |
+
# # # Display the question
|
399 |
+
# # st.write(f"Question: {qno+1}{qa['question']}")
|
400 |
+
|
401 |
+
# # # Display the answer options
|
402 |
+
# # if 'answer' in qa:
|
403 |
+
# # for idx, option in enumerate(qa['answer']):
|
404 |
+
# # # Indicate if the option is correct
|
405 |
+
# # correct_marker = "(correct)" if option["correct"] else ""
|
406 |
+
# # st.write(f"Option {idx + 1}: {option['answer']} {correct_marker}")
|
407 |
+
|
408 |
+
# # # Add a separator after each question-answer pair
|
409 |
+
# # st.write("-" * 40)
|
410 |
+
|
411 |
+
# # # Run the Streamlit app
|
412 |
+
# # if __name__ == "__main__":
|
413 |
+
# # main()
|
414 |
+
|
415 |
+
|
416 |
+
|
417 |
+
|
418 |
+
# # # import streamlit as st
|
419 |
+
# # # import re
|
420 |
+
# # # import pke
|
421 |
+
# # # import contractions
|
422 |
+
# # # import wikipedia
|
423 |
+
# # # import logging
|
424 |
+
# # # from haystack.document_stores import InMemoryDocumentStore
|
425 |
+
# # # from haystack.utils import clean_wiki_text, convert_files_to_docs, fetch_archive_from_http
|
426 |
+
# # # from transformers.pipelines import question_answering
|
427 |
+
# # # from haystack.nodes import TfidfRetriever
|
428 |
+
# # # from haystack.pipelines import ExtractiveQAPipeline
|
429 |
+
# # # from haystack.nodes import FARMReader
|
430 |
+
# # # import torch
|
431 |
+
|
432 |
+
# # # from main import print_qa
|
433 |
+
# # # from main import QuestionGenerator
|
434 |
+
|
435 |
+
# # # def main():
|
436 |
+
# # # # Initialize Streamlit app
|
437 |
+
# # # st.title("Question Generation using Haystack and Streamlit")
|
438 |
+
|
439 |
+
# # # # Allow user to input text paragraph
|
440 |
+
# # # wiki_text = st.text_area("Input paragraph:", height=200)
|
441 |
+
|
442 |
+
# # # # Allow user to specify the number of questions to generate
|
443 |
+
# # # num_questions = st.slider("Number of questions to generate:", min_value=1, max_value=20, value=5)
|
444 |
+
|
445 |
+
# # # # Allow user to specify the model to use
|
446 |
+
# # # model_options = ["deepset/roberta-base-squad2", "deepset/roberta-base-squad2-distilled", "bert-large-uncased-whole-word-masking-squad2"]
|
447 |
+
# # # model_name = st.selectbox("Select model:", model_options)
|
448 |
+
|
449 |
+
# # # # Button to generate questions
|
450 |
+
# # # if st.button("Generate Questions"):
|
451 |
+
# # # # Initialize the document store
|
452 |
+
# # # document_store = InMemoryDocumentStore()
|
453 |
+
|
454 |
+
# # # # Convert the input text paragraph into a document
|
455 |
+
# # # document = {"content": wiki_text}
|
456 |
+
# # # document_store.write_documents([document])
|
457 |
+
|
458 |
+
# # # # Initialize a TfidfRetriever
|
459 |
+
# # # retriever = TfidfRetriever(document_store=document_store)
|
460 |
+
|
461 |
+
# # # # Initialize a FARMReader with the selected model
|
462 |
+
# # # reader = FARMReader(model_name_or_path=model_name, use_gpu=False)
|
463 |
+
|
464 |
+
# # # # Initialize the question generation pipeline
|
465 |
+
# # # pipe = ExtractiveQAPipeline(reader, retriever)
|
466 |
+
|
467 |
+
# # # # Initialize the QuestionGenerator
|
468 |
+
# # # qg = QuestionGenerator()
|
469 |
+
|
470 |
+
# # # # Generate multiple-choice questions
|
471 |
+
# # # qa_list = qg.generate(
|
472 |
+
# # # wiki_text,
|
473 |
+
# # # num_questions=num_questions,
|
474 |
+
# # # answer_style='multiple_choice')
|
475 |
+
# # # print("QA List Structure:")
|
476 |
+
# # # # Display the generated questions and answers
|
477 |
+
# # # st.header("Generated Questions and Answers:")
|
478 |
+
# # # for qa in qa_list:
|
479 |
+
# # # # Display the question
|
480 |
+
# # # st.write(f"Question: {qa['question']}")
|
481 |
+
|
482 |
+
# # # # Display the answer options
|
483 |
+
# # # if 'answer' in qa:
|
484 |
+
# # # for idx, option in enumerate(qa['answer']):
|
485 |
+
# # # # Indicate if the option is correct
|
486 |
+
# # # correct_marker = "(correct)" if option["correct"] else ""
|
487 |
+
# # # st.write(f"Option {idx + 1}: {option['answer']} {correct_marker}")
|
488 |
+
|
489 |
+
# # # # Add a separator after each question-answer pair
|
490 |
+
# # # st.write("-" * 40)
|
491 |
+
# # # # for qa in qa_list:
|
492 |
+
# # # # print(qa)
|
493 |
+
|
494 |
+
# # # # # Proceed with displaying the generated questions
|
495 |
+
# # # # st.header("Generated Questions:")
|
496 |
+
# # # # for qa in qa_list:
|
497 |
+
# # # # st.write(f"Question: {qa['question']}")
|
498 |
+
# # # # # Adjust the code to match the structure of the output
|
499 |
+
# # # # if 'answers' in qa:
|
500 |
+
# # # # for idx, answer in enumerate(qa['answers']):
|
501 |
+
# # # # prefix = f"Option {idx + 1}:"
|
502 |
+
# # # # if answer["correct"]:
|
503 |
+
# # # # prefix += " (correct)"
|
504 |
+
# # # # st.write(f"{prefix} {answer['text']}")
|
505 |
+
# # # # else:
|
506 |
+
# # # # st.write("No answers available for this question.")
|
507 |
+
# # # # st.write("") # Add an empty line between each question for better readability
|
508 |
+
|
509 |
+
# # # # Run the Streamlit app
|
510 |
+
# # # if __name__ == "__main__":
|
511 |
+
# # # main()
|