import streamlit as st
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import nltk
import string
from streamlit.components.v1 import html
from sentence_transformers.cross_encoder import CrossEncoder as CE
import re
from typing import List, Tuple
import torch

SCITE_API_KEY = st.secrets["SCITE_API_KEY"]


class CrossEncoder:
    """Thin wrapper around sentence-transformers' CrossEncoder."""

    def __init__(self, model_path: str, **kwargs):
        self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = True) -> List[float]:
        return self.model.predict(
            sentences=sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar)


def remove_html(x):
    # Strip HTML markup and return plain text.
    soup = BeautifulSoup(x, 'html.parser')
    text = soup.get_text()
    return text.strip()


# 4 searches: strict y/n, supported y/n
# deduplicate
# search per query
# options are abstract search
# all search
def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
    term = clean_query(term, clean=clean, strict=strict)
    # heuristic: 2 searches, strict and not, and then merge?
    # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
    contexts, docs = [], []
    if not abstract_only:
        mode = 'all'
        if not all_mode:
            mode = 'citations'
        search = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            # Collect English citation snippets and the documents they came from.
            contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations'] if cite['lang'] == 'en']))
                         for doc in req.json()['hits']]
            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                     for doc in req.json()['hits']]
        except:
            pass

    if abstracts or abstract_only:
        search = f"https://api.scite.ai/search?mode=papers&abstract={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            contexts += [remove_html(doc['abstract'] or '') for doc in req.json()['hits']]
            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                     for doc in req.json()['hits']]
        except:
            pass

    return (
        contexts,
        docs
    )
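
# Typical use of search(), with the response shape the code above implies
# (inferred from the keys it reads, not from scite's API docs):
#   {"hits": [{"doi": ..., "title": ..., "abstract": ...,
#              "citations": [{"snippet": ..., "source": ..., "target": ..., "lang": ...}, ...]}, ...]}
#   contexts, docs = search("tanning beds cause cancer", limit=10)
#   contexts -> plain-text citation snippets and/or abstracts used for reranking and QA
#   docs     -> (doi, citations, title, abstract) tuples consumed by find_source()
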
def find_source(text, docs, matched):
    # Find the document and snippet that the answer text was drawn from.
    for doc in docs:
        for snippet in doc[1]:
            if text in remove_html(snippet.get('snippet', '')):
                if matched and remove_html(snippet.get('snippet', '')).strip() != matched.strip():
                    continue
                new_text = text
                for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                    if text in sent:
                        new_text = sent
                return {
                    'citation_statement': snippet['snippet'].replace('', '').replace('', ''),
                    'text': new_text,
                    'from': snippet['source'],
                    'supporting': snippet['target'],
                    'source_title': remove_html(doc[2] or ''),
                    'source_link': f"https://scite.ai/reports/{doc[0]}"
                }
        if text in remove_html(doc[3]):
            if matched and remove_html(doc[3]).strip() != matched.strip():
                continue
            new_text = text
            sent_loc = None
            sents = nltk.sent_tokenize(remove_html(doc[3]))
            for i, sent in enumerate(sents):
                if text in sent:
                    new_text = sent
                    sent_loc = i
            context = remove_html(doc[3]).replace('', '').replace('', '')
            if sent_loc is not None:
                # Use a window of sentences around the matched sentence as context.
                context_len = 2
                sent_beg = sent_loc - context_len
                if sent_beg <= 0:
                    sent_beg = 0
                sent_end = sent_loc + context_len
                if sent_end >= len(sents):
                    sent_end = len(sents)
                context = ' '.join(sents[sent_beg:sent_end])
            return {
                'citation_statement': context,
                'text': new_text,
                'from': doc[0],
                'supporting': doc[0],
                'source_title': remove_html(doc[2] or ''),
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }
    return None


@st.experimental_singleton
def init_models():
    nltk.download('stopwords')
    nltk.download('punkt')
    from nltk.corpus import stopwords
    stop = set(stopwords.words('english') + list(string.punctuation))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    question_answerer = pipeline(
        "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
        device=device
    )
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
    # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    return question_answerer, reranker, stop, device  # queryexp_model, queryexp_tokenizer


qa_model, reranker, stop, device = init_models()  # queryexp_model, queryexp_tokenizer


def clean_query(query, strict=True, clean=True):
    # Lowercase the query, optionally drop stopwords and punctuation, and
    # join the remaining terms with AND when strict.
    operator = ' '
    if strict:
        operator = ' AND '
    query = operator.join(
        [i for i in query.lower().split(' ') if not clean or i not in stop])
    if clean:
        query = query.translate(str.maketrans('', '', string.punctuation))
    return query
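
# Worked example of clean_query() (assuming NLTK's English stopword list in `stop`):
#   clean_query("Do tanning beds cause cancer?")                -> "tanning AND beds AND cause AND cancer"
#   clean_query("Do tanning beds cause cancer?", strict=False)  -> "tanning beds cause cancer"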

def card(title, context, score, link, supporting):
    # Render one answer: the citation statement, the model's confidence, and a
    # link back to the source document's report on scite.
    st.markdown(f"""
    {context} [Confidence: {score}%]<br>
    From <a href="{link}">{title}</a>
    """, unsafe_allow_html=True)
    # Assumption: embed scite's standard badge widget for the supporting DOI.
    html(f"""
    <div class="scite-badge" data-doi="{supporting}" data-layout="horizontal"></div>
    <script async type="application/javascript"
            src="https://cdn.scite.ai/badge/scite-badge-latest.min.js"></script>
    """, width=None, height=42, scrolling=False)
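
# Example call to card() (all values hypothetical):
#   card(title="Indoor tanning and risk of melanoma",
#        context="Use of indoor tanning beds was associated with increased risk of melanoma ...",
#        score=87.5,
#        link="https://scite.ai/reports/10.1234/example-doi",
#        supporting="10.1234/example-doi")
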
st.title("Scientific Question Answering with Citations")

st.write("""
Ask a scientific question and get an answer drawn from the [scite.ai](https://scite.ai) corpus of over 1.1 billion citation statements.
Answers are linked to the source documents containing the citations, so you can explore further evidence from the scientific literature.

For example, try: Do tanning beds cause cancer?
""")

st.markdown(""" """, unsafe_allow_html=True)

with st.expander("Settings (strictness, context limit, top hits)"):
    support_all = st.radio(
        "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
        ('yes', 'no'))
    support_abstracts = st.radio(
        "Use abstracts as a source document?",
        ('yes', 'no', 'abstract only'))
    strict_lenient_mix = st.radio(
        "Type of strict+lenient combination: fallback or mix? With fallback, the strict search runs first and, if it returns fewer results than the context limit, the lenient search runs as well. Mix runs both and lets reranking sort them out.",
        ('fallback', 'mix'))
    confidence_threshold = st.slider('Confidence threshold for answering questions? How confident (out of 100%) the model must be in an answer before showing it.', 0, 100, 1)
    use_reranking = st.radio(
        "Use reranking? Reranking reorders the top hits by the semantic similarity between the document and the query.",
        ('yes', 'no'))
    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality.', 10, 300, 50)
    context_lim = st.slider('Context limit? How many documents to answer from. Larger is slower but higher quality.', 10, 300, 10)


# def paraphrase(text, max_length=128):
#     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
#     generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
#     queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
#     preds = '\n * '.join(queries)
#     return preds


def group_results_by_context(results):
    # Merge answers that share the same citation context, keeping the best score.
    result_groups = {}
    for result in results:
        if result['context'] not in result_groups:
            result_groups[result['context']] = result
            result_groups[result['context']]['texts'] = []
        result_groups[result['context']]['texts'].append(
            result['answer']
        )
        if result['score'] > result_groups[result['context']]['score']:
            result_groups[result['context']]['score'] = result['score']
    return list(result_groups.values())


def matched_context(start_i, end_i, contexts_string, separator='---'):
    # Find the separators that mark document boundaries, then return the
    # document that contains the [start_i, end_i) answer span.
    doc_starts = [0]
    for match in re.finditer(separator, contexts_string):
        doc_starts.append(match.end())
    for i in range(len(doc_starts)):
        if i == len(doc_starts) - 1:
            if start_i >= doc_starts[i]:
                return contexts_string[doc_starts[i]:len(contexts_string)].replace(separator, '')
        if start_i >= doc_starts[i] and end_i <= doc_starts[i+1]:
            return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(separator, '')
    return None
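
# Worked example of matched_context() with the default '---' separator
# (indices are character offsets into the joined contexts string):
#   s = "first doc text---second doc text"
#   matched_context(0, 5, s)    -> "first doc text"
#   matched_context(17, 23, s)  -> "second doc text"
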
def run_query(query):
    # if use_query_exp == 'yes':
    #     query_exp = paraphrase(f"question2question: {query}")
    #     st.markdown(f"""
    #     If you are not getting good results try one of:
    #     * {query_exp}
    #     """)

    # address periods inside highlights, e.g. "avoidability. Risk factors"
    # address poor tokenization, e.g. "Deletions involving chromosome region 4p16.3 cause
    # Wolf-Hirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001]."
    # address highlight html
    # could also try fallback if there are no good answers by score...
    limit = top_hits_limit or 100
    context_limit = context_lim or 10
    contexts_strict, orig_docs_strict = search(
        query, limit=limit, strict=True,
        all_mode=support_all == 'yes',
        abstracts=support_abstracts == 'yes',
        abstract_only=support_abstracts == 'abstract only')
    if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
        # Fallback: the strict search came up short, so add lenient results too.
        contexts_lenient, orig_docs_lenient = search(
            query, limit=limit, strict=False,
            all_mode=support_all == 'yes',
            abstracts=support_abstracts == 'yes',
            abstract_only=support_abstracts == 'abstract only')
        contexts = list(
            set(contexts_strict + contexts_lenient)
        )
        orig_docs = orig_docs_strict + orig_docs_lenient
    elif strict_lenient_mix == 'mix':
        # Mix: always run both searches and let reranking sort the hits out.
        contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False)
        contexts = list(
            set(contexts_strict + contexts_lenient)
        )
        orig_docs = orig_docs_strict + orig_docs_lenient
    else:
        contexts = list(
            set(contexts_strict)
        )
        orig_docs = orig_docs_strict

    if len(contexts) == 0 or not ''.join(contexts).strip():
        return st.markdown("""
        Sorry... no results for that question! Try another...
        """, unsafe_allow_html=True)

    if use_reranking == 'yes':
        # Rerank with the cross-encoder and keep the highest-scoring contexts.
        sentence_pairs = [[query, context] for context in contexts]
        scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
        hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
        sorted_contexts = [k for k, v in sorted(hits.items(), key=lambda x: x[1], reverse=True)]
        context = '\n---'.join(sorted_contexts[:context_limit])
    else:
        context = '\n---'.join(contexts[:context_limit])

    results = []
    model_results = qa_model(question=query, context=context, top_k=10)
    for result in model_results:
        matched = matched_context(result['start'], result['end'], context)
        support = find_source(result['answer'], orig_docs, matched)
        if not support:
            continue
        results.append({
            "answer": support['text'],
            "title": support['source_title'],
            "link": support['source_link'],
            "context": support['citation_statement'],
            "score": result['score'],
            "doi": support["supporting"]
        })

    grouped_results = group_results_by_context(results)
    sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)

    if confidence_threshold == 0:
        threshold = 0
    else:
        threshold = (confidence_threshold or 10) / 100

    sorted_result = filter(
        lambda x: x['score'] > threshold,
        sorted_result
    )

    for r in sorted_result:
        ctx = remove_html(r["context"])
        for answer in r['texts']:
            ctx = ctx.replace(answer.strip(), f"{answer.strip()}")  # .replace( '