Spaces:
Runtime error
Runtime error
File size: 3,468 Bytes
b951bdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import streamlit as st
import json
import time
import faiss
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
class DocumentSearch:
    """Semantic document search: faiss retrieval, then cross-encoder re-ranking.

    This class relies on previously trained artifacts:
        faiss: index
        sbert: encoder
        sbert: cross_encoder
    """

    def __init__(self, labels_path: str, encoder_path: str,
                 index_path: str, cross_encoder_path: str):
        """Load the documents, the index and both models.

        Args:
            labels_path: JSON file with the documents. ``search`` indexes the
                loaded object by the ids returned from the faiss index and
                reads element ``[0]`` of an entry as the document text and
                element ``[1]`` as its url.
            encoder_path: name or path passed to ``SentenceTransformer``.
            index_path: path passed to ``faiss.read_index``.
            cross_encoder_path: name or path passed to ``CrossEncoder``.
        """
        # loading docs and corresponding urls; the encoding is explicit so
        # the result does not depend on the platform's default encoding
        with open(labels_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        # loading sbert encoder model
        self.encoder = SentenceTransformer(encoder_path)
        # loading faiss index
        self.index = faiss.read_index(index_path)
        # loading sbert cross_encoder
        self.cross_encoder = CrossEncoder(cross_encoder_path)

    def search(self, query: str, k: int) -> list:
        """Return up to ``k`` documents for ``query``, best match first.

        ``k * 10`` candidates are requested from the faiss index, every
        candidate is scored against the query by the cross encoder, and the
        ``k`` highest-scoring candidates are returned.

        Args:
            query: free-text query.
            k: maximum number of results; values <= 0 yield an empty list.

        Returns:
            A list of dicts with keys ``'doc'`` (document text), ``'url'``
            and ``'score'`` (cross-encoder score), sorted by descending score.
        """
        if k <= 0:
            return []
        # get vector representation of text query
        query_vector = self.encoder.encode([query])
        # perform search via faiss index, over-fetching for the re-ranker
        _, indices = self.index.search(query_vector, k * 10)
        # faiss pads the result with -1 when the index holds fewer vectors
        # than requested; a negative id would otherwise silently pick a
        # document from the end of self.docs
        answers = [self.docs[i] for i in indices[0] if i >= 0]
        if not answers:
            return []
        # prepare inputs for cross encoder
        model_inputs = [[query, pairs[0]] for pairs in answers]
        # get similarity score between query and documents
        scores = self.cross_encoder.predict(model_inputs, batch_size=1)
        # compose results into list of dicts
        results = [
            {'doc': pairs[0], 'url': pairs[1], 'score': score}
            for pairs, score in zip(answers, scores)
        ]
        # return results sorted by similarity scores
        results.sort(key=lambda item: item['score'], reverse=True)
        return results[:k]
enc_path = "msmarco-distilbert-dot-v5-tuned-full-v1"
idx_path = "idx_vectors.index"
cross_enc_path = "cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
docs_path = "docs.json"

# Streamlit re-executes this script on every widget interaction. Without
# caching, both models and the index would be reloaded on each button click.
# Use the resource cache when this Streamlit version provides one; otherwise
# fall back to plain (uncached) loading.
_cache_resource = (getattr(st, "cache_resource", None)
                   or getattr(st, "experimental_singleton", None))


def _load_surfer():
    """Build the DocumentSearch instance from the artifact paths above."""
    return DocumentSearch(
        labels_path=docs_path,
        encoder_path=enc_path,
        index_path=idx_path,
        cross_encoder_path=cross_enc_path
    )


if _cache_resource is not None:
    _load_surfer = _cache_resource(_load_surfer)

# get instance of DocumentSearch class
surfer = _load_surfer()

if __name__ == "__main__":
    # streamlit part starts here with title
    st.title('Medical Search')
    # here we have input space
    query = st.text_input("Enter any query about our data",
                          placeholder="Type query here...")
    # on submit we execute search
    if st.button("Search"):
        if not query.strip():
            # a blank query carries no information to search with
            st.warning("Please enter a query first.")
        else:
            # perf_counter is the clock intended for measuring intervals
            start = time.perf_counter()
            # retrieve top 5 documents
            results = surfer.search(query, k=5)
            # measure resulting time
            elapsed_time = round(time.perf_counter() - start, 2)
            # define container for answers
            with st.container():
                # show which query was entered, and what was searching time
                st.write(f"**Results Related to:** {query} ({elapsed_time} sec.)")
                # then we use loop to show results
                for i, answer in enumerate(results, start=1):
                    # answer starts with header
                    st.subheader(f"Answer {i}")
                    # cropped answer; ellipsis only when text was really cut
                    text = answer["doc"]
                    doc = text[:150] + "..." if len(text) > 150 else text
                    # and url to the full answer
                    url = answer["url"]
                    # then we display it
                    st.markdown(f"{doc}\n[**Read More**]({url})\n")
|