# Spaces:
# Runtime error
# Runtime error
import re
import string
# NOTE: only ``escape`` is taken from the stdlib ``html`` module because the
# name ``html`` itself is bound to the Streamlit component imported below.
from html import escape
from typing import List, Tuple

import nltk
import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from sentence_transformers.cross_encoder import CrossEncoder as CE
from streamlit.components.v1 import html
from transformers import pipeline
# Bearer token sent in the Authorization header of every scite.ai request;
# read from Streamlit's secrets store so it never appears in the source.
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
class CrossEncoder:
    """Small adapter around the sentence-transformers cross-encoder (``CE``)."""

    def __init__(self, model_path: str, **kwargs):
        """Load the model at ``model_path``; extra keyword arguments go to ``CE``."""
        self.model = CE(model_path, **kwargs)

    def predict(
        self,
        sentences: List[Tuple[str, str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
    ) -> List[float]:
        """Score each pair in ``sentences`` by delegating to the wrapped model."""
        predict_options = {
            "batch_size": batch_size,
            "show_progress_bar": show_progress_bar,
        }
        return self.model.predict(sentences=sentences, **predict_options)
def remove_html(x):
    """Return the text content of ``x`` with markup removed and outer whitespace trimmed."""
    return BeautifulSoup(x, 'html.parser').get_text().strip()
# 4 searches: strict y/n, supported y/n | |
# deduplicate | |
# search per query | |
# options are abstract search | |
# all search | |
def _scite_hits(query_params, limit):
    """Run one scite search request and return its ``hits`` list.

    Returns an empty list when the request fails, the body is not JSON, or the
    payload carries no hits, so a failed search degrades to "no results"
    instead of crashing the app.
    """
    params = dict(query_params)
    params.update({
        'limit': limit,
        'offset': 0,
        'user_slug': 'domenic-rosati-keW5',
        'compute_aggregations': 'false',
    })
    try:
        # Passing ``params`` lets requests URL-encode the term instead of
        # interpolating it raw into the query string.
        response = requests.get(
            'https://api.scite.ai/search',
            params=params,
            headers={'Authorization': f'Bearer {SCITE_API_KEY}'},
            timeout=60,
        )
        payload = response.json()
    except (requests.RequestException, ValueError) as err:
        st.warning(f'scite search request failed: {err}')
        return []
    hits = payload.get('hits') if isinstance(payload, dict) else None
    return hits or []


def _doc_record(doc):
    """Build the ``(doi, citations, title, abstract)`` tuple consumed by ``find_source``."""
    # NOTE(review): the exact scite response schema is not visible here; missing
    # fields are defaulted so a context is never stored without its document.
    return (
        doc.get('doi'),
        doc.get('citations') or [],
        doc.get('title'),
        doc.get('abstract') or '',
    )


def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
    """Query scite.ai and return ``(contexts, docs)`` for ``term``.

    Args:
        term: Raw user query; normalised with ``clean_query`` first.
        limit: Maximum number of hits requested per API call.
        clean: Forwarded to ``clean_query``.
        strict: Forwarded to ``clean_query``.
        all_mode: Use search mode ``'all'`` when true, ``'citations'`` otherwise.
        abstracts: Also run the abstract (``mode=papers``) search.
        abstract_only: Skip the citation search and run only the abstract search.

    Returns:
        A tuple ``(contexts, docs)``. ``contexts`` holds one plain-text string
        per hit (the joined English citation snippets, or the abstract) and
        ``docs`` holds the matching ``(doi, citations, title, abstract)``
        tuples in the same order.
    """
    term = clean_query(term, clean=clean, strict=strict)
    contexts, docs = [], []

    if not abstract_only:
        mode = 'all' if all_mode else 'citations'
        for doc in _scite_hits({'mode': mode, 'term': term}, limit):
            snippets = [
                cite.get('snippet') or ''
                for cite in (doc.get('citations') or [])
                if cite.get('lang') == 'en'
            ]
            contexts.append(remove_html('\n'.join(snippets)))
            docs.append(_doc_record(doc))

    if abstracts or abstract_only:
        for doc in _scite_hits({'mode': 'papers', 'abstract': term}, limit):
            contexts.append(remove_html(doc.get('abstract') or ''))
            docs.append(_doc_record(doc))

    return contexts, docs
def find_source(text, docs, matched):
    """Locate the document that ``text`` was extracted from.

    Each document's citation snippets are checked first, then its abstract.

    Args:
        text: Answer span to look for.
        docs: Sequence of ``(doi, citations, title, abstract)`` tuples.
        matched: Context the answer was taken from, or a falsy value. When
            given, only a snippet/abstract whose cleaned text equals it
            (ignoring outer whitespace) is accepted.

    Returns:
        A dict with keys ``citation_statement``, ``text`` (the sentence that
        contains the answer), ``from``, ``supporting``, ``source_title`` and
        ``source_link``; ``None`` when nothing matches.
    """
    if not text:
        # An empty string is "in" every string and would match the first doc.
        return None
    wanted = matched.strip() if matched else None

    for doc in docs:
        doi, citations, title, abstract = doc[0], doc[1], doc[2], doc[3]

        for snippet in citations:
            raw_snippet = snippet.get('snippet', '')
            snippet_text = remove_html(raw_snippet)
            if text not in snippet_text:
                continue
            if wanted is not None and snippet_text.strip() != wanted:
                continue
            sentence = text
            # The last sentence containing the answer wins.
            for sent in nltk.sent_tokenize(snippet_text):
                if text in sent:
                    sentence = sent
            return {
                'citation_statement': raw_snippet.replace('<strong class="highlight">', '').replace('</strong>', ''),
                'text': sentence,
                'from': snippet['source'],
                'supporting': snippet['target'],
                'source_title': remove_html(title or ''),
                'source_link': f"https://scite.ai/reports/{doi}"
            }

        abstract_text = remove_html(abstract or '')
        if text not in abstract_text:
            continue
        if wanted is not None and abstract_text.strip() != wanted:
            continue

        sents = nltk.sent_tokenize(abstract_text)
        sentence, sent_loc = text, None
        for i, sent in enumerate(sents):
            if text in sent:
                sentence, sent_loc = sent, i

        context = abstract_text.replace('<strong class="highlight">', '').replace('</strong>', '')
        # ``is not None`` so that a match in the first sentence (index 0) is
        # windowed like any other match.
        if sent_loc is not None:
            context_len = 2
            sent_beg = max(sent_loc - context_len, 0)
            # NOTE(review): the end bound is exclusive, so the window is two
            # sentences before the match but only one after - confirm intent.
            sent_end = min(sent_loc + context_len, len(sents))
            context = ' '.join(sents[sent_beg:sent_end])
        return {
            'citation_statement': context,
            'text': sentence,
            'from': doi,
            'supporting': doi,
            'source_title': remove_html(title or ''),
            'source_link': f"https://scite.ai/reports/{doi}"
        }
    return None
def _cache_models(loader):
    """Wrap ``loader`` in whichever Streamlit resource cache this version offers."""
    cache_resource = getattr(st, 'cache_resource', None)
    if cache_resource is not None:
        return cache_resource(loader)
    # Older Streamlit releases only provide ``st.cache``.
    return st.cache(allow_output_mutation=True)(loader)


@_cache_models
def init_models():
    """Download NLTK data and load the QA and reranking models.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached to avoid reloading the models on each rerun.

    Returns:
        A tuple ``(question_answerer, reranker, stop, device)`` where ``stop``
        is the set of English stopwords plus single punctuation characters.
    """
    # NOTE(review): recent NLTK releases look up ``punkt_tab`` in
    # ``sent_tokenize``; requesting both tokenizer packages is presumably
    # harmless on older releases - confirm against the pinned NLTK version.
    for resource in ('stopwords', 'punkt', 'punkt_tab'):
        nltk.download(resource)
    from nltk.corpus import stopwords
    stop = set(stopwords.words('english') + list(string.punctuation))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    question_answerer = pipeline(
        "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
        device=device
    )
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
    # Query expansion is currently disabled:
    # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    return question_answerer, reranker, stop, device  # queryexp_model, queryexp_tokenizer


qa_model, reranker, stop, device = init_models()  # queryexp_model, queryexp_tokenizer
def clean_query(query, strict=True, clean=True):
    """Normalise a user question into a search term.

    Args:
        query: Raw question text.
        strict: Join the remaining words with ``' AND '`` when true, with a
            single space otherwise.
        clean: Remove stopwords (the module-level ``stop`` set) and
            punctuation when true; when false the words are only lower-cased.

    Returns:
        The lower-cased words joined by the chosen operator.
    """
    # ``split()`` (no argument) collapses runs of whitespace, so repeated
    # spaces cannot produce empty operands such as ``a AND  AND b``.
    tokens = query.lower().split()
    if clean:
        strip_punctuation = str.maketrans('', '', string.punctuation)
        # Filter on the raw token first (stopwords such as "don't" contain
        # punctuation), then again once punctuation is gone ("the?" -> "the").
        stripped = (tok.translate(strip_punctuation) for tok in tokens if tok not in stop)
        tokens = [tok for tok in stripped if tok and tok not in stop]
    operator = ' AND ' if strict else ' '
    return operator.join(tokens)
def card(title, context, score, link, supporting):
    """Render one answer card followed by the scite badge for its DOI.

    Args:
        title: Title of the source document (plain text).
        context: Passage to display. Inserted as HTML on purpose because it
            carries the ``<mark>`` highlighting added by the caller.
        score: Confidence value shown as a percentage.
        link: URL the title links to.
        supporting: DOI placed in the badge's ``data-doi`` attribute.
    """
    # Title, link and DOI come from the search API and are written into raw
    # HTML, so escape them; ``context`` is left as-is (see above).
    safe_title = escape(str(title))
    safe_link = escape(str(link), quote=True)
    safe_doi = escape(str(supporting), quote=True)
    st.markdown(f"""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
<br>
<span>
{context}
[<b>Confidence: </b>{score}%]
</span>
<br>
<b>From <a href="{safe_link}">{safe_title}</a></b>
</div>
</div>
</div>
""", unsafe_allow_html=True)
    html(f"""
<div
class="scite-badge"
data-doi="{safe_doi}"
data-layout="horizontal"
data-show-zero="false"
data-show-labels="false"
data-tally-show="true"
/>
<script
async
type="application/javascript"
src="https://cdn.scite.ai/badge/scite-badge-latest.min.js">
</script>
""", width=None, height=42, scrolling=False)
# ---- Page header -----------------------------------------------------------
st.title("Scientific Question Answering with Citations")
st.write("""
Ask a scientific question and get an answer drawn from [scite.ai](https://scite.ai) corpus of over 1.1bn citation statements.
Answers are linked to source documents containing citations where users can explore further evidence from scientific literature for the answer.
For example try: Do tanning beds cause cancer?
""")
# Bootstrap stylesheet used by the answer cards.
st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.0.0/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)

# ---- User-tunable settings (read as module globals by run_query) -----------
YES_NO = ('yes', 'no')
with st.expander("Settings (strictness, context limit, top hits)"):
    support_all = st.radio(
        label="Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
        options=YES_NO,
    )
    support_abstracts = st.radio(
        label="Use abstracts as a source document?",
        options=('yes', 'no', 'abstract only'),
    )
    strict_lenient_mix = st.radio(
        label="Type of strict+lenient combination: Fallback or Mix? If fallback, strict is run first then if the results are less than context_lim we also search lenient. Mix will search them both and let reranking sort em out",
        options=('fallback', 'mix'),
    )
    confidence_threshold = st.slider(
        label='Confidence threshold for answering questions? This number represents how confident the model should be in the answers it gives. The number is out of 100%',
        min_value=0,
        max_value=100,
        value=1,
    )
    use_reranking = st.radio(
        label="Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
        options=YES_NO,
    )
    top_hits_limit = st.slider(
        label='Top hits? How many documents to use for reranking. Larger is slower but higher quality',
        min_value=10,
        max_value=300,
        value=50,
    )
    context_lim = st.slider(
        label='Context limit? How many documents to use for answering from. Larger is slower but higher quality',
        min_value=10,
        max_value=300,
        value=10,
    )
# def paraphrase(text, max_length=128): | |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True) | |
# generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length) | |
# queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]) | |
# preds = '\n * '.join(queries) | |
# return preds | |
def group_results_by_context(results):
    """Merge results that share the same ``context`` into one entry each.

    Args:
        results: Iterable of dicts with at least ``context``, ``answer`` and
            ``score`` keys.

    Returns:
        A list with one dict per distinct context, in first-seen order. Each
        is a copy of the first result for that context plus a ``texts`` list
        of its distinct answers (first-seen order), with ``score`` raised to
        the highest score in the group. The input dicts are left unmodified.
    """
    groups = {}
    for result in results:
        key = result['context']
        group = groups.get(key)
        if group is None:
            # Copy so the caller's result dict is not mutated.
            group = dict(result)
            group['texts'] = []
            groups[key] = group
        answer = result['answer']
        # Repeated answers would later be highlighted more than once.
        if answer not in group['texts']:
            group['texts'].append(answer)
        if result['score'] > group['score']:
            group['score'] = result['score']
    return list(groups.values())
def matched_context(start_i, end_i, contexts_string, seperator='---'):
    """Return the separator-delimited segment that contains ``[start_i, end_i]``.

    Args:
        start_i: Start offset of the span within ``contexts_string``.
        end_i: End offset of the span within ``contexts_string``.
        contexts_string: Several contexts joined with ``seperator``.
        seperator: Literal delimiter between contexts.

    Returns:
        The segment text with the separator removed, or ``None`` when the
        span does not fit inside a single segment (for example when it
        straddles a separator).
    """
    # Each segment starts right after a separator; the first starts at 0.
    # ``re.escape`` keeps the separator literal even if it has regex metachars.
    doc_starts = [0]
    doc_starts.extend(
        match.end() for match in re.finditer(re.escape(seperator), contexts_string)
    )
    last = len(doc_starts) - 1

    for i, seg_start in enumerate(doc_starts):
        if start_i < seg_start:
            # Starts are ascending, so no later segment can contain the span.
            break
        if i == last:
            return contexts_string[seg_start:].replace(seperator, '')
        if end_i <= doc_starts[i + 1]:
            return contexts_string[seg_start:doc_starts[i + 1]].replace(seperator, '')
    return None
def run_query(query):
    """Answer ``query`` and render one card per supported answer.

    Steps: search scite (strict, plus lenient per the "fallback"/"mix"
    setting), optionally rerank the contexts with the cross-encoder, run the
    QA model over the top contexts, trace each answer back to its source
    document, then group, threshold and display the results.

    Reads the settings widgets and models defined at module level
    (``top_hits_limit``, ``context_lim``, ``support_all``,
    ``support_abstracts``, ``strict_lenient_mix``, ``use_reranking``,
    ``confidence_threshold``, ``reranker``, ``qa_model``).
    """
    # Query expansion is currently disabled:
    # if use_query_exp == 'yes':
    #     query_exp = paraphrase(f"question2question: {query}")
    #     st.markdown(f"""
    #     If you are not getting good results try one of:
    #     * {query_exp}
    #     """)
    # TODO: address period in highlight avoidability. Risk factors
    # TODO: address poor tokenization, e.g. "Deletions involving chromosome region
    #       4p16.3 cause WolfHirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001]."
    # TODO: address highlight html
    # TODO: could also try fallback if there are no good answers by score...
    limit = top_hits_limit or 100
    context_limit = context_lim or 10
    search_options = {
        'limit': limit,
        'all_mode': support_all == 'yes',
        'abstracts': support_abstracts == 'yes',
        'abstract_only': support_abstracts == 'abstract only',
    }

    contexts, orig_docs = search(query, strict=True, **search_options)
    needs_lenient = strict_lenient_mix == 'mix' or (
        strict_lenient_mix == 'fallback' and len(contexts) < context_limit
    )
    if needs_lenient:
        # Same search settings in both modes, so "mix" honours the user's
        # abstract / ranking-signal choices too.
        contexts_lenient, docs_lenient = search(query, strict=False, **search_options)
        contexts = contexts + contexts_lenient
        orig_docs = orig_docs + docs_lenient

    # Order-preserving de-duplication keeps the search ranking for the
    # no-reranking path; blank contexts are dropped as they carry no evidence.
    contexts = [c for c in dict.fromkeys(contexts) if c.strip()]

    if not contexts:
        return st.markdown("""
<div class="container-fluid">
<div class="row align-items-start">
<div class="col-md-12 col-sm-12">
Sorry... no results for that question! Try another...
</div>
</div>
</div>
""", unsafe_allow_html=True)

    if use_reranking == 'yes':
        sentence_pairs = [[query, candidate] for candidate in contexts]
        scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
        # Sort by reranker score, highest first.
        ranked = sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)
        contexts = [candidate for candidate, _ in ranked]
    context = '\n---'.join(contexts[:context_limit])

    model_results = qa_model(question=query, context=context, top_k=10)
    # The pipeline's documented return type is a dict or a list of dicts.
    if isinstance(model_results, dict):
        model_results = [model_results]

    results = []
    for result in model_results:
        matched = matched_context(result['start'], result['end'], context)
        support = find_source(result['answer'], orig_docs, matched)
        if not support:
            continue
        results.append({
            "answer": support['text'],
            "title": support['source_title'],
            "link": support['source_link'],
            "context": support['citation_statement'],
            "score": result['score'],
            "doi": support["supporting"]
        })

    grouped_results = group_results_by_context(results)
    sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)

    if confidence_threshold == 0:
        threshold = 0
    else:
        threshold = (confidence_threshold or 10) / 100

    for r in sorted_result:
        if not r['score'] > threshold:
            continue
        ctx = remove_html(r["context"])
        # Highlight each distinct, non-empty answer once; repeats would nest
        # <mark> tags and an empty answer would wrap every character.
        for answer in dict.fromkeys(text.strip() for text in r['texts']):
            if answer:
                ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
        # .replace( '<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
        title = r.get("title", '')
        score = round(round(r["score"], 4) * 100, 2)
        card(title, ctx, score, r['link'], r['doi'])
# Entry point: read the question and answer it as soon as one is entered.
query = st.text_input(label="Ask scientific literature a question", value="")
if query:
    with st.spinner('Loading...'):
        run_query(query)