import streamlit as st
from transformers import pipeline, AutoTokenizer, LEDForConditionalGeneration
import requests
from bs4 import BeautifulSoup
import nltk
import string
from streamlit.components.v1 import html
from sentence_transformers.cross_encoder import CrossEncoder as CE
import re
from typing import List, Tuple
import torch
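# The scite API key is read from Streamlit secrets; when running locally it is
# expected to be defined as SCITE_API_KEY in .streamlit/secrets.toml.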
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
# class CrossEncoder:
#     def __init__(self, model_path: str, **kwargs):
#         self.model = CE(model_path, **kwargs)
#
#     def predict(self, sentences: List[Tuple[str, str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
#         return self.model.predict(
#             sentences=sentences,
#             batch_size=batch_size,
#             show_progress_bar=show_progress_bar)
def remove_html(x):
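    """Strip HTML tags from a snippet and return the whitespace-trimmed text."""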
    soup = BeautifulSoup(x, 'html.parser')
    text = soup.get_text()
    return text.strip()

# Search plan: up to 4 searches (strict y/n, supported y/n), deduplicated,
# run per query; options are abstract search and all search.
def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
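    """Query the scite.ai search API and return (contexts, docs).

    contexts is a list of plain-text snippets (citation statements and/or
    abstracts); docs is a list of (doi, citations, title, abstract) tuples.
    """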
    term = clean_query(term, clean=clean, strict=strict)
    # heuristic: could run 2 searches (strict and not) and then merge, e.g.
    # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
    contexts, docs = [], []
    if not abstract_only:
        mode = 'all' if all_mode else 'citations'
        search_url = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search_url,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            hits = req.json()['hits']
        except (ValueError, KeyError):
            hits = []
        contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en'])) for doc in hits]
        docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                 for doc in hits]
    if abstracts or abstract_only:
        search_url = f"https://api.scite.ai/search?mode=papers&abstract={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search_url,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            hits = req.json()['hits']
            contexts += [remove_html(doc['abstract'] or '') for doc in hits]
            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                     for doc in hits]
        except (ValueError, KeyError):
            pass
    return (
        contexts,
        docs
    )

def find_source(text, docs, matched):
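    """Locate the snippet or abstract an extracted answer came from.

    Returns a dict describing the citation statement, source DOI, title, and
    scite report link, or None if the answer text cannot be traced back.
    """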
    for doc in docs:
        for snippet in doc[1]:
            if text in remove_html(snippet.get('snippet', '')):
                if matched and remove_html(snippet.get('snippet', '')).strip() != matched.strip():
                    continue
                new_text = text
                for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                    if text in sent:
                        new_text = sent
                return {
                    'citation_statement': snippet['snippet'].replace('<strong class="highlight">', '').replace('</strong>', ''),
                    'text': new_text,
                    'from': snippet['source'],
                    'supporting': snippet['target'],
                    'source_title': remove_html(doc[2] or ''),
                    'source_link': f"https://scite.ai/reports/{doc[0]}"
                }
        if text in remove_html(doc[3]):
            if matched and remove_html(doc[3]).strip() != matched.strip():
                continue
            new_text = text
            sent_loc = None
            sents = nltk.sent_tokenize(remove_html(doc[3]))
            for i, sent in enumerate(sents):
                if text in sent:
                    new_text = sent
                    sent_loc = i
            context = remove_html(doc[3]).replace('<strong class="highlight">', '').replace('</strong>', '')
            if sent_loc is not None:  # 0 is a valid sentence index
                context_len = 3
                sent_beg = max(sent_loc - context_len, 0)
                sent_end = min(sent_loc + context_len, len(sents))
                context = ' '.join(sents[sent_beg:sent_end])
            return {
                'citation_statement': context,
                'text': new_text,
                'from': doc[0],
                'supporting': doc[0],
                'source_title': remove_html(doc[2] or ''),
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }
    return None

# @st.experimental_singleton
# def init_models():
#     nltk.download('stopwords')
#     nltk.download('punkt')
#     from nltk.corpus import stopwords
#     stop = set(stopwords.words('english') + list(string.punctuation))
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     question_answerer = pipeline(
#         "question-answering", model='nlpconnect/roberta-base-squad2-nq',
#         device=0 if torch.cuda.is_available() else -1, handle_impossible_answer=False,
#     )
#     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
#     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
#     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
#     return question_answerer, reranker, stop, device
# qa_model, reranker, stop, device = init_models()  # queryexp_model, queryexp_tokenizer
# clean_query and find_source still need NLTK resources even though
# init_models above is commented out, so set them up at module level.
nltk.download('stopwords', quiet=True)
nltk.download('punkt', quiet=True)
from nltk.corpus import stopwords
stop = set(stopwords.words('english') + list(string.punctuation))

def clean_query(query, strict=True, clean=True):
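    """Normalise a user query for the scite search API.

    With strict=True the terms are joined with ' AND '; with clean=True
    stop words and punctuation are removed.
    """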
    operator = ' AND ' if strict else ' '
    terms = query.lower().split(' ')
    if clean:
        terms = [term for term in terms if term not in stop]
    query = operator.join(terms)
    if clean:
        query = query.translate(str.maketrans('', '', string.punctuation))
    return query

def card(title, context, score, link, supporting):
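    """Render a single answer card plus the scite badge for its DOI."""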
st.markdown(f"""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
<br>
<span>
{context}
[<b>Confidence: </b>{score}%]
</span>
<br>
<b>From <a href="{link}">{title}</a></b>
</div>
</div>
</div>
""", unsafe_allow_html=True)
html(f"""
<div
class="scite-badge"
data-doi="{supporting}"
data-layout="horizontal"
data-show-zero="false"
data-show-labels="false"
data-tally-show="true"
/>
<script
async
type="application/javascript"
src="https://cdn.scite.ai/badge/scite-badge-latest.min.js">
</script>
""", width=None, height=42, scrolling=False)
st.title("Scientific Question Answering with Citations")
st.write("""
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
For example try: Do tanning beds cause cancer?
""")
st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
# with st.expander("Settings (strictness, context limit, top hits)"):
#     concat_passages = st.radio(
#         "Concatenate passages as one long context?",
#         ('yes', 'no'))
#     present_impossible = st.radio(
#         "Present impossible answers? (if the model thinks it's impossible to answer, should it still try?)",
#         ('yes', 'no'))
#     support_all = st.radio(
#         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
#         ('no', 'yes'))
#     support_abstracts = st.radio(
#         "Use abstracts as a source document?",
#         ('yes', 'no', 'abstract only'))
#     strict_lenient_mix = st.radio(
#         "Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first; if the results are fewer than context_lim, we also search lenient. Mix searches both and lets reranking sort them out.",
#         ('mix', 'fallback'))
#     confidence_threshold = st.slider('Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives, out of 100%.', 0, 100, 1)
#     use_reranking = st.radio(
#         "Use reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
#         ('yes', 'no'))
#     top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality.', 10, 300, 100)
#     context_lim = st.slider('Context limit? How many documents to answer from. Larger is slower but higher quality.', 10, 300, 25)
# def paraphrase(text, max_length=128):
#     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
#     generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
#     queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
#     preds = '\n * '.join(queries)
#     return preds
def group_results_by_context(results):
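    """Group answers that share a citation context, keeping the best score."""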
    result_groups = {}
    for result in results:
        if result['context'] not in result_groups:
            result_groups[result['context']] = result
            result_groups[result['context']]['texts'] = []
        result_groups[result['context']]['texts'].append(
            result['answer']
        )
        if result['score'] > result_groups[result['context']]['score']:
            result_groups[result['context']]['score'] = result['score']
    return list(result_groups.values())

def matched_context(start_i, end_i, contexts_string, separator='---'):
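    """Given a model answer span (start_i, end_i) into a string of passages
    concatenated with `separator`, return the passage containing the span.

    For example, matched_context(5, 12, 'first passage---second passage')
    returns 'first passage'.
    """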
    # find separators to identify passage start and end offsets
    doc_starts = [0]
    for match in re.finditer(separator, contexts_string):
        doc_starts.append(match.end())
    for i in range(len(doc_starts)):
        if i == len(doc_starts) - 1:
            if start_i >= doc_starts[i]:
                return contexts_string[doc_starts[i]:len(contexts_string)].replace(separator, '')
        elif start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(separator, '')
    return None

# def run_query_full(query, progress_bar):
#     # if use_query_exp == 'yes':
#     #     query_exp = paraphrase(f"question2question: {query}")
#     #     st.markdown(f"""
#     #     If you are not getting good results try one of:
#     #     * {query_exp}
#     #     """)
#     # could also try fallback if there are no good answers by score...
#     limit = top_hits_limit or 100
#     context_limit = context_lim or 10
#     contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts=support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
#     if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
#         contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts=support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
#         contexts = list(
#             set(contexts_strict + contexts_lenient)
#         )
#         orig_docs = orig_docs_strict + orig_docs_lenient
#     elif strict_lenient_mix == 'mix':
#         contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
#         contexts = list(
#             set(contexts_strict + contexts_lenient)
#         )
#         orig_docs = orig_docs_strict + orig_docs_lenient
#     else:
#         contexts = list(
#             set(contexts_strict)
#         )
#         orig_docs = orig_docs_strict
#     progress_bar.progress(25)
#     if len(contexts) == 0 or not ''.join(contexts).strip():
#         return st.markdown("""
#         <div class="container-fluid">
#             <div class="row align-items-start">
#                 <div class="col-md-12 col-sm-12">
#                     Sorry... no results for that question! Try another...
#                 </div>
#             </div>
#         </div>
#         """, unsafe_allow_html=True)
#     if use_reranking == 'yes':
#         sentence_pairs = [[query, context] for context in contexts]
#         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
#         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
#         sorted_contexts = [k for k, v in sorted(hits.items(), key=lambda x: x[1], reverse=True)]
#         contexts = sorted_contexts[:context_limit]
#     else:
#         contexts = contexts[:context_limit]
#     progress_bar.progress(50)
#     if concat_passages == 'yes':
#         context = '\n---'.join(contexts)
#         model_results = qa_model(question=query, context=context, top_k=10, doc_stride=512 // 2, max_answer_len=128, max_seq_len=512, handle_impossible_answer=present_impossible == 'yes')
#     else:
#         context = ['\n---\n' + ctx for ctx in contexts]
#         model_results = qa_model(question=[query] * len(contexts), context=context, handle_impossible_answer=present_impossible == 'yes')
#     results = []
#     progress_bar.progress(75)
#     for i, result in enumerate(model_results):
#         if concat_passages == 'yes':
#             matched = matched_context(result['start'], result['end'], context)
#         else:
#             matched = matched_context(result['start'], result['end'], context[i])
#         support = find_source(result['answer'], orig_docs, matched)
#         if not support:
#             continue
#         results.append({
#             "answer": support['text'],
#             "title": support['source_title'],
#             "link": support['source_link'],
#             "context": support['citation_statement'],
#             "score": result['score'],
#             "doi": support["supporting"]
#         })
#     grouped_results = group_results_by_context(results)
#     sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)
#     if confidence_threshold == 0:
#         threshold = 0
#     else:
#         threshold = (confidence_threshold or 10) / 100
#     sorted_result = list(filter(
#         lambda x: x['score'] > threshold,
#         sorted_result
#     ))
#     progress_bar.progress(100)
#     for r in sorted_result:
#         ctx = remove_html(r["context"])
#         for answer in r['texts']:
#             ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
#             # .replace('<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
#         title = r.get("title", '')
#         score = round(round(r["score"], 4) * 100, 2)
#         card(title, ctx, score, r['link'], r['doi'])
def run_query(query):
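    """Fetch answers from the hosted question-answering API and render them."""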
    api_location = 'http://74.82.31.93'
    resp_raw = requests.get(
        f'{api_location}/question-answer',
        params={'query': query}  # let requests URL-encode the query
    )
    try:
        resp = resp_raw.json()
    except ValueError:
        resp = {'results': []}
    if len(resp.get('results', [])) == 0:
        return st.markdown("""
        <div class="container-fluid">
            <div class="row align-items-start">
                <div class="col-md-12 col-sm-12">
                    Sorry... no results for that question! Try another...
                </div>
            </div>
        </div>
        """, unsafe_allow_html=True)
    for r in resp['results']:
        ctx = remove_html(r["context"])
        for answer in r['texts']:
            ctx = ctx.replace(answer.strip(), f"<mark>{answer.strip()}</mark>")
            # .replace('<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
        title = r.get("title", '')
        score = round(round(r["score"], 4) * 100, 2)
        card(title, ctx, score, r['link'], r['doi'])

query = st.text_input("Ask scientific literature a question", "")
if query != "":
with st.spinner('Loading...'):
run_query(query)