# Spaces:
# Runtime error
# Runtime error
import json
import time

import faiss
import streamlit as st
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder


class DocumentSearch:
    """Semantic document search over a previously built FAISS index.

    Queries are embedded with an SBERT bi-encoder (``enc_path``) and looked up
    in a FAISS index (``idx_path``). The matching texts and URLs come from
    ``docs_path``. An earlier version re-ranked candidates with an SBERT
    cross-encoder. Its checkpoint id is still kept in ``cross_enc_path``, but
    that model is currently not loaded or used.
    """

    # Paths / hub ids of every artifact needed to run the models and to
    # search over the data.
    enc_path = "ivan-savchuk/msmarco-distilbert-dot-v5-tuned-full-v1"
    idx_path = "idx_vectors.index"
    cross_enc_path = "ivan-savchuk/cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
    docs_path = "docs.json"

    def __init__(self):
        """Load the documents, the query encoder and the FAISS index."""
        # Each entry is indexed by FAISS id and read as (text, url) via
        # positions 0 and 1; presumably a [text, url] pair, so verify
        # against the file.
        with open(DocumentSearch.docs_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        # SBERT bi-encoder that embeds queries for the index lookup.
        self.encoder = SentenceTransformer(DocumentSearch.enc_path)
        # Pre-built FAISS index over the document embeddings.
        self.index = faiss.read_index(DocumentSearch.idx_path)
        # Cross-encoder re-ranking (cross_enc_path) is currently disabled.

    def search(self, query: str, k: int) -> list:
        """Return up to ``k`` documents most similar to ``query``.

        Args:
            query: free-text search query.
            k: maximum number of results to return; ``k <= 0`` yields ``[]``.

        Returns:
            A list of ``{'doc': text, 'url': url, 'score': float}`` dicts in
            the order FAISS returns its hits. Fewer than ``k`` entries are
            returned when the index holds fewer than ``k`` vectors.
        """
        if k <= 0:
            return []
        # encode() works on a batch, so wrap the single query.
        query_vector = self.encoder.encode([query])
        # With no re-ranking stage, fetch exactly k hits. The old k*10
        # over-fetch only existed to feed the cross-encoder.
        distances, indices = self.index.search(query_vector, k)

        results = []
        for doc_id, score in zip(indices[0], distances[0]):
            # FAISS pads missing hits with id -1. Used as a list index, -1
            # would silently return the last document.
            if doc_id < 0:
                continue
            entry = self.docs[doc_id]
            results.append({'doc': entry[0], 'url': entry[1], 'score': float(score)})
        return results


if __name__ == "__main__":
    # Streamlit re-runs this whole script on every widget interaction. Building
    # DocumentSearch inline would reload the encoder and FAISS index on every
    # search. Cache it when this Streamlit version provides a resource cache
    # (st.cache_resource in current releases, st.experimental_singleton in
    # older ones); otherwise build it directly as before.
    _resource_cache = (getattr(st, "cache_resource", None)
                       or getattr(st, "experimental_singleton", None))

    def _load_searcher():
        """Construct the DocumentSearch instance (docs, encoder, index)."""
        return DocumentSearch()

    if _resource_cache is not None:
        _load_searcher = _resource_cache(_load_searcher)
    surfer = _load_searcher()

    # Page title, centred and coloured via inline HTML.
    title = """
<h1 style='
text-align: center;
color: #3CB371'>
Medical Search
</h1>
"""
    st.markdown(title, unsafe_allow_html=True)

    # Input form: one text box plus the mandatory submit button.
    with st.form("my_form"):
        query = st.text_input("Enter any query about our medical data",
                              placeholder="Type query here...")
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")

    # Search only for non-blank queries. A blank submit falls through to the
    # example-query hint instead of embedding an empty string.
    if submitted and query.strip():
        start = time.perf_counter()
        # Retrieve the top 5 documents.
        results = surfer.search(query, k=5)
        elapsed_time = round(time.perf_counter() - start, 2)
        # Echo the query and how long the search took.
        st.write(f"**Results Related to:** \"{query}\" ({elapsed_time} sec.)")
        for i, answer in enumerate(results):
            st.subheader(f"Answer {i+1}")
            text = answer["doc"]
            # Preview: first 150 characters, with an ellipsis only when the
            # text was actually cut.
            doc = text[:150] + "..." if len(text) > 150 else text
            url = answer["url"]
            # doc/url come from docs.json rather than from the user; HTML
            # rendering is kept enabled as in the original.
            st.markdown(f'{doc}\n[**Read More**]({url})\n', unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("**Author:** Ivan Savchuk. 2022")
    else:
        st.markdown("Typical queries looks like this: _**\"What is flu?\"**_, "
                    "_**\"How to cure breast cancer?\"**_, "
                    "_**\"I have headache, what should I do?\"**_")