medical-search / app.py
ivan-savchuk's picture
Upload app.py
b951bdb
raw
history blame
3.47 kB
import streamlit as st
import json
import time
import faiss
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
class DocumentSearch:
    """Semantic document search using retrieve-then-rerank.

    The query is embedded with a sentence encoder, candidate documents are
    fetched from a faiss index, and a cross-encoder re-scores every
    candidate against the query to produce the final ranking.
    """

    # How many index candidates are fetched per requested result before
    # re-ranking (the search asks the index for ``k * _OVERSAMPLE`` hits).
    _OVERSAMPLE = 10

    def __init__(self, labels_path: str, encoder_path: str,
                 index_path: str, cross_encoder_path: str):
        """Load the documents, both models and the vector index.

        Args:
            labels_path: JSON file holding the documents. Each entry is
                indexed as ``entry[0]`` (document text) and ``entry[1]``
                (url), and its position must match the id stored in the
                faiss index.
            encoder_path: Name or path handed to ``SentenceTransformer``.
            index_path: File read with ``faiss.read_index``.
            cross_encoder_path: Name or path handed to ``CrossEncoder``.
        """
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(labels_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        # Bi-encoder used to embed the query.
        self.encoder = SentenceTransformer(encoder_path)
        # Pre-built vector index over the documents.
        self.index = faiss.read_index(index_path)
        # Cross-encoder used to re-rank the retrieved candidates.
        self.cross_encoder = CrossEncoder(cross_encoder_path)

    def search(self, query: str, k: int) -> list:
        """Return up to ``k`` documents ranked by cross-encoder score.

        Args:
            query: Free-text query.
            k: Maximum number of results to return.

        Returns:
            A list of dicts with keys ``'doc'`` (document text), ``'url'``
            and ``'score'``, sorted by descending score. The list is empty
            when ``k`` is not positive or the index yields no candidates.
        """
        if k <= 0:
            return []

        # Embed the query; the encoder receives a one-element batch.
        query_vector = self.encoder.encode([query])
        # Over-fetch candidates so the cross-encoder has room to re-order.
        _, indices = self.index.search(query_vector, k * self._OVERSAMPLE)
        # faiss pads the id row with -1 when fewer neighbours exist than were
        # requested; a negative id would otherwise wrap around to docs[-1].
        answers = [self.docs[i] for i in indices[0] if i >= 0]
        if not answers:
            return []

        # Pair the query with every candidate text for the cross-encoder.
        model_inputs = [[query, entry[0]] for entry in answers]
        urls = [entry[1] for entry in answers]
        # Relevance score of each (query, document) pair.
        # batch_size=1 is carried over unchanged from the original code.
        scores = self.cross_encoder.predict(model_inputs, batch_size=1)

        results = [
            {'doc': pair[1], 'url': url, 'score': score}
            for pair, url, score in zip(model_inputs, urls, scores)
        ]
        # Best-scoring candidates first, truncated to the requested size.
        return sorted(results, key=lambda item: item['score'], reverse=True)[:k]
# Locations of the search artefacts; each value is passed unchanged to the
# matching DocumentSearch constructor argument below.
docs_path = "docs.json"
idx_path = "idx_vectors.index"
enc_path = "msmarco-distilbert-dot-v5-tuned-full-v1"
cross_enc_path = "cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"

# Module-level search engine used by the UI code further down.
# NOTE(review): this runs at import/script time, so the documents, both models
# and the index are loaded every time the module body executes; under
# Streamlit that presumably means on every rerun -- confirm and consider
# caching if the installed Streamlit version offers it.
surfer = DocumentSearch(
    labels_path=docs_path,
    index_path=idx_path,
    encoder_path=enc_path,
    cross_encoder_path=cross_enc_path,
)
if __name__ == "__main__":
    # Page title.
    st.title('Medical Search')
    # Free-text query box.
    query = st.text_input("Enter any query about our data",
                          placeholder="Type query here...")
    # The search runs only when the button was pressed.
    if st.button("Search"):
        if not query.strip():
            # An empty query has nothing meaningful to embed or rank.
            st.warning("Please enter a query before searching.")
        else:
            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by system clock adjustments.
            started = time.perf_counter()
            # Retrieve the top 5 documents.
            results = surfer.search(query, k=5)
            elapsed_time = round(time.perf_counter() - started, 2)
            # Group all answers in one container.
            with st.container():
                # Echo the query together with the search duration.
                st.write(f"**Results Related to:** {query} ({elapsed_time} sec.)")
                for number, answer in enumerate(results, start=1):
                    # Each answer starts with a numbered header.
                    st.subheader(f"Answer {number}")
                    # Preview of the document; the ellipsis is added only
                    # when text was actually cut off.
                    text = answer["doc"]
                    doc = text[:150] + "..." if len(text) > 150 else text
                    # Link to the full answer.
                    url = answer["url"]
                    st.markdown(f"{doc}\n[**Read More**]({url})\n")