ivan-savchuk committed on
Commit
b951bdb
β€’
1 Parent(s): 2992997

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import time
4
+ import faiss
5
+ from sentence_transformers import SentenceTransformer
6
+ from sentence_transformers.cross_encoder import CrossEncoder
7
+
8
+
9
class DocumentSearch:
    """Semantic document search over a pre-built FAISS index.

    Combines three pre-trained artifacts:
      * a FAISS index for fast candidate retrieval,
      * an SBERT bi-encoder to embed the query,
      * an SBERT cross-encoder to re-rank the candidates.
    """

    def __init__(self, labels_path: str, encoder_path: str,
                 index_path: str, cross_encoder_path: str):
        """Load documents, bi-encoder, FAISS index, and cross-encoder.

        Args:
            labels_path: JSON file holding [text, url] pairs, indexable by
                the integer ids stored in the FAISS index.
            encoder_path: path/name of the SentenceTransformer bi-encoder.
            index_path: serialized FAISS index file.
            cross_encoder_path: path/name of the CrossEncoder re-ranker.
        """
        # docs and corresponding urls, aligned with the FAISS vector ids
        with open(labels_path, 'r', encoding='utf-8') as json_file:
            self.docs = json.load(json_file)
        # sbert bi-encoder used to embed incoming queries
        self.encoder = SentenceTransformer(encoder_path)
        # pre-built faiss index over the document embeddings
        self.index = faiss.read_index(index_path)
        # sbert cross-encoder used for re-ranking
        self.cross_encoder = CrossEncoder(cross_encoder_path)

    def search(self, query: str, k: int) -> list:
        """Return the top-k documents for *query*, re-ranked by the cross-encoder.

        Over-retrieves 10*k candidates from the FAISS index, scores each
        (query, document) pair with the cross-encoder, and keeps the best k.

        Returns:
            List of dicts {'doc': str, 'url': str, 'score': float},
            sorted by descending score. Empty list if nothing was retrieved.
        """
        # vector representation of the text query
        query_vector = self.encoder.encode([query])
        # over-retrieve so the re-ranker has a candidate pool to choose from
        _, indices = self.index.search(query_vector, k * 10)
        # FAISS pads with -1 when the index holds fewer than k*10 vectors;
        # -1 would silently wrap to the last document, so drop those ids
        answers = [self.docs[i] for i in indices[0] if i >= 0]
        if not answers:
            return []
        # prepare (query, document) pairs for the cross-encoder
        model_inputs = [[query, pair[0]] for pair in answers]
        urls = [pair[1] for pair in answers]
        # similarity score between the query and each candidate document
        scores = self.cross_encoder.predict(model_inputs, batch_size=1)
        results = [
            {'doc': doc[1], 'url': url, 'score': score}
            for doc, url, score in zip(model_inputs, urls, scores)
        ]
        # best matches first, truncated to the requested k
        return sorted(results, key=lambda x: x['score'], reverse=True)[:k]
46
+
47
+
48
# Artifact locations: fine-tuned bi-encoder, FAISS index,
# fine-tuned cross-encoder, and the [text, url] document pairs.
enc_path = "msmarco-distilbert-dot-v5-tuned-full-v1"
idx_path = "idx_vectors.index"
cross_enc_path = "cross-encoder-ms-marco-MiniLM-L-12-v2-tuned_mediqa-v1"
docs_path = "docs.json"


@st.cache_resource
def _load_search_engine() -> DocumentSearch:
    """Build the DocumentSearch instance once per server process.

    Streamlit re-executes this script on every widget interaction;
    caching the resource avoids reloading both transformer models and
    the FAISS index on each rerun.
    """
    return DocumentSearch(
        labels_path=docs_path,
        encoder_path=enc_path,
        index_path=idx_path,
        cross_encoder_path=cross_enc_path,
    )


# get instance of DocumentSearch class (cached across reruns)
surfer = _load_search_engine()
59
+
60
+
61
if __name__ == "__main__":
    # Streamlit UI: page title, query box, and a search button.
    st.title('Medical Search')
    query = st.text_input("Enter any query about our data",
                          placeholder="Type query here...")
    # run the search only when the button is pressed
    if st.button("Search"):
        # time the retrieval so users see how long the search took
        start_time = time.time()
        results = surfer.search(query, k=5)
        elapsed_time = round(time.time() - start_time, 2)

        # container groups the answer list under one header line
        with st.container():
            # echo the query together with the elapsed search time
            st.write(f"**Results Related to:** {query} ({elapsed_time} sec.)")
            for rank, answer in enumerate(results, start=1):
                st.subheader(f"Answer {rank}")
                # 150-character preview plus a link to the full source
                preview = answer["doc"][:150] + "..."
                url = answer["url"]
                st.markdown(f"{preview}\n[**Read More**]({url})\n")