Update pages/3_Earnings_Semantic_Search_π_.py
Browse files
pages/3_Earnings_Semantic_Search_π_.py
CHANGED
@@ -8,13 +8,17 @@ st.markdown("## Earnings Semantic Search with SBert")
|
|
8 |
def gen_sentiment(text):
    """Run the sentiment pipeline on *text* and return the predicted label."""
    prediction = sent_pipe(text)
    top_result = prediction[0]
    return top_result['label']
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
search_input = st.text_input(
|
15 |
label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
|
16 |
|
17 |
-
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=
|
18 |
|
19 |
top_k = 2
|
20 |
|
@@ -32,31 +36,34 @@ try:
|
|
32 |
passages = chunk_long_text(st.session_state['earnings_passages'],150,window_size=window_size)
|
33 |
|
34 |
with st.spinner(
|
35 |
-
text=f"Loading {sbert_model_name} encoder..."
|
36 |
):
|
37 |
-
sbert = load_sbert(sbert_model_name)
|
38 |
|
39 |
|
40 |
##### Semantic Search #####
|
41 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
42 |
-
corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
43 |
-
question_embedding = sbert.encode(search_input, convert_to_tensor=True)
|
44 |
-
question_embedding = question_embedding.cpu()
|
45 |
-
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
46 |
-
hits = hits[0] # Get the hits for the first query
|
47 |
|
48 |
-
##### Re-Ranking #####
|
49 |
-
# Now, score all retrieved passages with the cross_encoder
|
50 |
-
cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
|
51 |
-
cross_scores = cross_encoder.predict(cross_inp)
|
52 |
|
53 |
-
# Sort results by the cross-encoder scores
|
54 |
-
for idx in range(len(cross_scores)):
|
55 |
-
|
56 |
|
57 |
-
# Output of top-3 hits from re-ranker
|
58 |
-
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
|
|
59 |
|
|
|
|
|
60 |
score='cross-score'
|
61 |
df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
|
62 |
df['Score'] = round(df['Score'],2)
|
|
|
def gen_sentiment(text):
    """Classify *text* with the sentiment pipeline and return its label."""
    results = sent_pipe(text)
    return results[0]['label']
|
11 |
+
|
12 |
+
bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
|
13 |
+
'e5-base':'intfloat/e5-base',
|
14 |
+
'instructor-base': 'hkunlp/instructor-base',
|
15 |
+
'mpnet-base-dot-v1':'multi-qa-mpnet-base-dot-v1',
|
16 |
+
'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification'}
|
17 |
|
18 |
search_input = st.text_input(
|
19 |
label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
|
20 |
|
21 |
+
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
|
22 |
|
23 |
top_k = 2
|
24 |
|
|
|
36 |
passages = chunk_long_text(st.session_state['earnings_passages'],150,window_size=window_size)
|
37 |
|
38 |
with st.spinner(
|
39 |
+
text=f"Loading {bi_enc_dict[sbert_model_name]} encoder model..."
|
40 |
):
|
41 |
+
sbert = load_sbert(bi_enc_dict[sbert_model_name])
|
42 |
|
43 |
|
44 |
##### Semantic Search #####
|
45 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
46 |
+
# corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
47 |
+
# question_embedding = sbert.encode(search_input, convert_to_tensor=True)
|
48 |
+
# question_embedding = question_embedding.cpu()
|
49 |
+
# hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
50 |
+
# hits = hits[0] # Get the hits for the first query
|
51 |
|
52 |
+
# ##### Re-Ranking #####
|
53 |
+
# # Now, score all retrieved passages with the cross_encoder
|
54 |
+
# cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
|
55 |
+
# cross_scores = cross_encoder.predict(cross_inp)
|
56 |
|
57 |
+
# # Sort results by the cross-encoder scores
|
58 |
+
# for idx in range(len(cross_scores)):
|
59 |
+
# hits[idx]['cross-score'] = cross_scores[idx]
|
60 |
|
61 |
+
# # Output of top-3 hits from re-ranker
|
62 |
+
# hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
63 |
+
embedding_model = bi_enc_dict[sbert_model_name]
|
64 |
|
65 |
+
hits = embed_text(search_input,passages,embedding_model)
|
66 |
+
|
67 |
score='cross-score'
|
68 |
df = pd.DataFrame([(hit[score],passages[hit['corpus_id']]) for hit in hits[0:int(top_k)]],columns=['Score','Text'])
|
69 |
df['Score'] = round(df['Score'],2)
|