Spaces:
Runtime error
Runtime error
domenicrosati
commited on
Commit
β’
3f1f616
1
Parent(s):
bdb2b00
update to use api
Browse files
app.py
CHANGED
@@ -12,15 +12,15 @@ import torch
|
|
12 |
|
13 |
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
|
14 |
|
15 |
-
class CrossEncoder:
|
16 |
-
|
17 |
-
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
25 |
|
26 |
def remove_html(x):
|
@@ -134,23 +134,23 @@ def find_source(text, docs, matched):
|
|
134 |
return None
|
135 |
|
136 |
|
137 |
-
@st.experimental_singleton
|
138 |
-
def init_models():
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
|
153 |
-
qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
|
154 |
|
155 |
|
156 |
def clean_query(query, strict=True, clean=True):
|
@@ -206,32 +206,32 @@ Answers are linked to source documents containing citations where users can expl
|
|
206 |
For example try: Do tanning beds cause cancer?
|
207 |
""")
|
208 |
|
209 |
-
st.markdown("""
|
210 |
-
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
211 |
-
""", unsafe_allow_html=True)
|
212 |
-
|
213 |
-
with st.expander("Settings (strictness, context limit, top hits)"):
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
|
236 |
# def paraphrase(text, max_length=128):
|
237 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
@@ -272,38 +272,120 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
|
|
272 |
return None
|
273 |
|
274 |
|
275 |
-
def
|
276 |
-
# if use_query_exp == 'yes':
|
277 |
-
# query_exp = paraphrase(f"question2question: {query}")
|
278 |
-
# st.markdown(f"""
|
279 |
-
# If you are not getting good results try one of:
|
280 |
-
# * {query_exp}
|
281 |
-
# """)
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
|
306 |
-
if len(
|
307 |
return st.markdown("""
|
308 |
<div class="container-fluid">
|
309 |
<div class="row align-items-start">
|
@@ -314,58 +396,7 @@ def run_query(query, progress_bar):
|
|
314 |
</div>
|
315 |
""", unsafe_allow_html=True)
|
316 |
|
317 |
-
|
318 |
-
sentence_pairs = [[query, context] for context in contexts]
|
319 |
-
scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
320 |
-
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
321 |
-
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
322 |
-
contexts = sorted_contexts[:context_limit]
|
323 |
-
else:
|
324 |
-
contexts = contexts[:context_limit]
|
325 |
-
|
326 |
-
progress_bar.progress(50)
|
327 |
-
if concat_passages == 'yes':
|
328 |
-
context = '\n---'.join(contexts)
|
329 |
-
model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
|
330 |
-
else:
|
331 |
-
context = ['\n---\n'+ctx for ctx in contexts]
|
332 |
-
model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
|
333 |
-
|
334 |
-
results = []
|
335 |
-
|
336 |
-
progress_bar.progress(75)
|
337 |
-
for i, result in enumerate(model_results):
|
338 |
-
if concat_passages == 'yes':
|
339 |
-
matched = matched_context(result['start'], result['end'], context)
|
340 |
-
else:
|
341 |
-
matched = matched_context(result['start'], result['end'], context[i])
|
342 |
-
support = find_source(result['answer'], orig_docs, matched)
|
343 |
-
if not support:
|
344 |
-
continue
|
345 |
-
results.append({
|
346 |
-
"answer": support['text'],
|
347 |
-
"title": support['source_title'],
|
348 |
-
"link": support['source_link'],
|
349 |
-
"context": support['citation_statement'],
|
350 |
-
"score": result['score'],
|
351 |
-
"doi": support["supporting"]
|
352 |
-
})
|
353 |
-
|
354 |
-
grouped_results = group_results_by_context(results)
|
355 |
-
sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
|
356 |
-
|
357 |
-
if confidence_threshold == 0:
|
358 |
-
threshold = 0
|
359 |
-
else:
|
360 |
-
threshold = (confidence_threshold or 10) / 100
|
361 |
-
|
362 |
-
sorted_result = list(filter(
|
363 |
-
lambda x: x['score'] > threshold,
|
364 |
-
sorted_result
|
365 |
-
))
|
366 |
-
|
367 |
-
progress_bar.progress(100)
|
368 |
-
for r in sorted_result:
|
369 |
ctx = remove_html(r["context"])
|
370 |
for answer in r['texts']:
|
371 |
ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
@@ -377,5 +408,4 @@ def run_query(query, progress_bar):
|
|
377 |
query = st.text_input("Ask scientific literature a question", "")
|
378 |
if query != "":
|
379 |
with st.spinner('Loading...'):
|
380 |
-
|
381 |
-
run_query(query, progress_bar)
|
|
|
12 |
|
13 |
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
|
14 |
|
15 |
+
# class CrossEncoder:
|
16 |
+
# def __init__(self, model_path: str, **kwargs):
|
17 |
+
# self.model = CE(model_path, **kwargs)
|
18 |
|
19 |
+
# def predict(self, sentences: List[Tuple[str,str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
|
20 |
+
# return self.model.predict(
|
21 |
+
# sentences=sentences,
|
22 |
+
# batch_size=batch_size,
|
23 |
+
# show_progress_bar=show_progress_bar)
|
24 |
|
25 |
|
26 |
def remove_html(x):
|
|
|
134 |
return None
|
135 |
|
136 |
|
137 |
+
# @st.experimental_singleton
|
138 |
+
# def init_models():
|
139 |
+
# nltk.download('stopwords')
|
140 |
+
# nltk.download('punkt')
|
141 |
+
# from nltk.corpus import stopwords
|
142 |
+
# stop = set(stopwords.words('english') + list(string.punctuation))
|
143 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
144 |
+
# question_answerer = pipeline(
|
145 |
+
# "question-answering", model='nlpconnect/roberta-base-squad2-nq',
|
146 |
+
# device=0 if torch.cuda.is_available() else -1, handle_impossible_answer=False,
|
147 |
+
# )
|
148 |
+
# reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
|
149 |
+
# # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
150 |
+
# # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
151 |
+
# return question_answerer, reranker, stop, device
|
152 |
|
153 |
+
# qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
|
154 |
|
155 |
|
156 |
def clean_query(query, strict=True, clean=True):
|
|
|
206 |
For example try: Do tanning beds cause cancer?
|
207 |
""")
|
208 |
|
209 |
+
# st.markdown("""
|
210 |
+
# <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
|
211 |
+
# """, unsafe_allow_html=True)
|
212 |
+
|
213 |
+
# with st.expander("Settings (strictness, context limit, top hits)"):
|
214 |
+
# concat_passages = st.radio(
|
215 |
+
# "Concatenate passages as one long context?",
|
216 |
+
# ('yes', 'no'))
|
217 |
+
# present_impossible = st.radio(
|
218 |
+
# "Present impossible answers? (if the model thinks its impossible to answer should it still try?)",
|
219 |
+
# ('yes', 'no'))
|
220 |
+
# support_all = st.radio(
|
221 |
+
# "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
|
222 |
+
# ('no', 'yes'))
|
223 |
+
# support_abstracts = st.radio(
|
224 |
+
# "Use abstracts as a source document?",
|
225 |
+
# ('yes', 'no', 'abstract only'))
|
226 |
+
# strict_lenient_mix = st.radio(
|
227 |
+
# "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
|
228 |
+
# ('mix', 'fallback'))
|
229 |
+
# confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%', 0, 100, 1)
|
230 |
+
# use_reranking = st.radio(
|
231 |
+
# "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
232 |
+
# ('yes', 'no'))
|
233 |
+
# top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 100)
|
234 |
+
# context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 25)
|
235 |
|
236 |
# def paraphrase(text, max_length=128):
|
237 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
|
272 |
return None
|
273 |
|
274 |
|
275 |
+
# def run_query_full(query, progress_bar):
|
276 |
+
# # if use_query_exp == 'yes':
|
277 |
+
# # query_exp = paraphrase(f"question2question: {query}")
|
278 |
+
# # st.markdown(f"""
|
279 |
+
# # If you are not getting good results try one of:
|
280 |
+
# # * {query_exp}
|
281 |
+
# # """)
|
282 |
+
|
283 |
+
# # could also try fallback if there are no good answers by score...
|
284 |
+
# limit = top_hits_limit or 100
|
285 |
+
# context_limit = context_lim or 10
|
286 |
+
# contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
|
287 |
+
# if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
|
288 |
+
# contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts= support_abstracts == 'yes', abstract_only= support_abstracts == 'abstract only')
|
289 |
+
# contexts = list(
|
290 |
+
# set(contexts_strict + contexts_lenient)
|
291 |
+
# )
|
292 |
+
# orig_docs = orig_docs_strict + orig_docs_lenient
|
293 |
+
# elif strict_lenient_mix == 'mix':
|
294 |
+
# contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
|
295 |
+
# contexts = list(
|
296 |
+
# set(contexts_strict + contexts_lenient)
|
297 |
+
# )
|
298 |
+
# orig_docs = orig_docs_strict + orig_docs_lenient
|
299 |
+
# else:
|
300 |
+
# contexts = list(
|
301 |
+
# set(contexts_strict)
|
302 |
+
# )
|
303 |
+
# orig_docs = orig_docs_strict
|
304 |
+
# progress_bar.progress(25)
|
305 |
+
|
306 |
+
# if len(contexts) == 0 or not ''.join(contexts).strip():
|
307 |
+
# return st.markdown("""
|
308 |
+
# <div class="container-fluid">
|
309 |
+
# <div class="row align-items-start">
|
310 |
+
# <div class="col-md-12 col-sm-12">
|
311 |
+
# Sorry... no results for that question! Try another...
|
312 |
+
# </div>
|
313 |
+
# </div>
|
314 |
+
# </div>
|
315 |
+
# """, unsafe_allow_html=True)
|
316 |
+
|
317 |
+
# if use_reranking == 'yes':
|
318 |
+
# sentence_pairs = [[query, context] for context in contexts]
|
319 |
+
# scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
320 |
+
# hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
321 |
+
# sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
322 |
+
# contexts = sorted_contexts[:context_limit]
|
323 |
+
# else:
|
324 |
+
# contexts = contexts[:context_limit]
|
325 |
+
|
326 |
+
# progress_bar.progress(50)
|
327 |
+
# if concat_passages == 'yes':
|
328 |
+
# context = '\n---'.join(contexts)
|
329 |
+
# model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible=='yes')
|
330 |
+
# else:
|
331 |
+
# context = ['\n---\n'+ctx for ctx in contexts]
|
332 |
+
# model_results = qa_model(question=[query]*len(contexts), context=context, handle_impossible_answer=present_impossible=='yes')
|
333 |
+
|
334 |
+
# results = []
|
335 |
+
|
336 |
+
# progress_bar.progress(75)
|
337 |
+
# for i, result in enumerate(model_results):
|
338 |
+
# if concat_passages == 'yes':
|
339 |
+
# matched = matched_context(result['start'], result['end'], context)
|
340 |
+
# else:
|
341 |
+
# matched = matched_context(result['start'], result['end'], context[i])
|
342 |
+
# support = find_source(result['answer'], orig_docs, matched)
|
343 |
+
# if not support:
|
344 |
+
# continue
|
345 |
+
# results.append({
|
346 |
+
# "answer": support['text'],
|
347 |
+
# "title": support['source_title'],
|
348 |
+
# "link": support['source_link'],
|
349 |
+
# "context": support['citation_statement'],
|
350 |
+
# "score": result['score'],
|
351 |
+
# "doi": support["supporting"]
|
352 |
+
# })
|
353 |
+
|
354 |
+
# grouped_results = group_results_by_context(results)
|
355 |
+
# sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
|
356 |
+
|
357 |
+
# if confidence_threshold == 0:
|
358 |
+
# threshold = 0
|
359 |
+
# else:
|
360 |
+
# threshold = (confidence_threshold or 10) / 100
|
361 |
+
|
362 |
+
# sorted_result = list(filter(
|
363 |
+
# lambda x: x['score'] > threshold,
|
364 |
+
# sorted_result
|
365 |
+
# ))
|
366 |
+
|
367 |
+
# progress_bar.progress(100)
|
368 |
+
# for r in sorted_result:
|
369 |
+
# ctx = remove_html(r["context"])
|
370 |
+
# for answer in r['texts']:
|
371 |
+
# ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
372 |
+
# # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
|
373 |
+
# title = r.get("title", '')
|
374 |
+
# score = round(round(r["score"], 4) * 100, 2)
|
375 |
+
# card(title, ctx, score, r['link'], r['doi'])
|
376 |
+
|
377 |
+
|
378 |
+
def run_query(query):
|
379 |
+
api_location = 'http://74.82.31.93'
|
380 |
+
resp_raw = requests.get(
|
381 |
+
f'{api_location}/question-answer?query={query}'
|
382 |
+
)
|
383 |
+
try:
|
384 |
+
resp = resp_raw.json()
|
385 |
+
except:
|
386 |
+
resp = {'results': []}
|
387 |
|
388 |
+
if len(resp.get('results', [])) == 0:
|
389 |
return st.markdown("""
|
390 |
<div class="container-fluid">
|
391 |
<div class="row align-items-start">
|
|
|
396 |
</div>
|
397 |
""", unsafe_allow_html=True)
|
398 |
|
399 |
+
for r in resp['results']:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
ctx = remove_html(r["context"])
|
401 |
for answer in r['texts']:
|
402 |
ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
|
|
|
408 |
query = st.text_input("Ask scientific literature a question", "")
|
409 |
if query != "":
|
410 |
with st.spinner('Loading...'):
|
411 |
+
run_query(query)
|
|