Spaces:
Runtime error
Runtime error
import streamlit as st | |
import json | |
import time | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.cross_encoder import CrossEncoder | |
class DocumentSearch:
    """Semantic document search backed by previously trained models.

    Retrieval runs in two stages:

    1. A SentenceTransformer bi-encoder embeds the query, and a FAISS index
       returns the nearest candidate documents.
    2. A CrossEncoder re-scores every (query, document) pair, and the best
       ``k`` candidates are returned.

    Attributes:
        docs: entries loaded from the labels JSON file; each entry is read as
            ``entry[0]`` (document text) and ``entry[1]`` (URL). Its positions
            are assumed to line up with the FAISS index ids.
            NOTE(review): confirm against the script that built the index.
        encoder: SentenceTransformer used to embed queries.
        index: FAISS index searched with the query embedding.
        cross_encoder: CrossEncoder used to re-rank retrieved candidates.
    """

    def __init__(self, labels_path: str, encoder_path: str,
                 index_path: str, cross_encoder_path: str):
        """Load the documents, encoder, FAISS index and cross-encoder.

        Args:
            labels_path: path to the JSON file holding documents and URLs.
            encoder_path: name or path of the SentenceTransformer model.
            index_path: path of the serialized FAISS index.
            cross_encoder_path: name or path of the CrossEncoder model.
        """
        # Explicit UTF-8 so loading does not depend on the platform's
        # default locale encoding.
        with open(labels_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        self.encoder = SentenceTransformer(encoder_path)
        self.index = faiss.read_index(index_path)
        self.cross_encoder = CrossEncoder(cross_encoder_path)

    def search(self, query: str, k: int, candidates_factor: int = 10) -> list:
        """Return the ``k`` documents most relevant to ``query``.

        Args:
            query: free-text search query.
            k: maximum number of results to return; ``k <= 0`` yields ``[]``.
            candidates_factor: number of index candidates fetched per
                requested result before cross-encoder re-ranking.

        Returns:
            Up to ``k`` dicts ``{'doc': text, 'url': url, 'score': score}``,
            sorted by descending cross-encoder score.
        """
        # Never ask FAISS for more neighbours than the index actually holds.
        n_candidates = min(k * candidates_factor, self.index.ntotal)
        if n_candidates <= 0:
            return []

        query_vector = self.encoder.encode([query])
        _, indices = self.index.search(query_vector, n_candidates)

        # FAISS pads missing neighbours with id -1. Indexing docs with -1
        # would silently return the last document, so drop those ids.
        answers = [self.docs[i] for i in indices[0] if i >= 0]
        if not answers:
            return []

        model_inputs = [[query, doc[0]] for doc in answers]
        urls = [doc[1] for doc in answers]
        scores = self.cross_encoder.predict(model_inputs, batch_size=1)

        results = [
            {'doc': text, 'url': url, 'score': score}
            for (_, text), url, score in zip(model_inputs, urls, scores)
        ]
        results.sort(key=lambda r: r['score'], reverse=True)
        return results[:k]
if __name__ == "__main__":
    enc_path = "ivan-savchuk/msmarco-distilbert-dot-v5-tuned-full-v1"
    idx_path = "idx_vectors.index"
    cross_enc_path = "ivan-savchuk/cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
    docs_path = "docs.json"
    # Maximum number of characters of each document shown in the results.
    preview_len = 150

    # Streamlit re-executes this whole script on every interaction, so
    # without caching all three models would be reloaded for each search.
    # st.cache_resource (newer Streamlit) keeps one shared instance; on
    # versions without it, fall back to no caching (the previous behaviour).
    cache_resource = getattr(st, "cache_resource", lambda func: func)

    @cache_resource
    def load_searcher():
        """Build the DocumentSearch instance used by the app."""
        return DocumentSearch(
            labels_path=docs_path,
            encoder_path=enc_path,
            index_path=idx_path,
            cross_encoder_path=cross_enc_path,
        )

    surfer = load_searcher()

    # Streamlit part starts here with the title.
    st.title('Medical Search')
    with st.form("my_form"):
        query = st.text_input("Enter any query about our medical data",
                              placeholder="Type query here...")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")

        if submitted and not query.strip():
            # Nothing to search for; don't run the models on an empty string.
            st.warning("Please enter a query.")
        elif submitted:
            start = time.perf_counter()
            results = surfer.search(query, k=5)
            elapsed_time = round(time.perf_counter() - start, 2)
            # Echo the query together with the time the search took.
            st.write(f"**Results Related to:** {query} ({elapsed_time} sec.)")
            for i, answer in enumerate(results, start=1):
                st.subheader(f"Answer {i}")
                text = answer["doc"]
                # Only add an ellipsis when the text was actually cut.
                if len(text) > preview_len:
                    doc = text[:preview_len] + "..."
                else:
                    doc = text
                url = answer["url"]
                # NOTE(review): HTML inside the document text is rendered
                # as-is; this is only safe while docs.json is trusted content.
                st.markdown(f'{doc}\n[**Read More**]({url})\n', unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("**Author:** Ivan Savchuk. 2022")