Update pages/3_Earnings_Semantic_Search_π_.py
Browse files
pages/3_Earnings_Semantic_Search_π_.py
CHANGED
@@ -8,24 +8,36 @@ st.markdown("## Earnings Semantic Search with SBert")
|
|
8 |
def gen_sentiment(text):
|
9 |
'''Generate sentiment of given text'''
|
10 |
return sent_pipe(text)[0]['label']
|
|
|
|
|
11 |
|
12 |
search_input = st.text_input(
|
13 |
label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
|
14 |
|
|
|
|
|
15 |
top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5,value=2)
|
16 |
|
17 |
window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
|
18 |
|
19 |
-
if search_input:
|
20 |
|
21 |
|
22 |
if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
|
23 |
|
|
|
|
|
24 |
## Save to a dataframe for ease of visualization
|
25 |
sen_df = st.session_state['sen_df']
|
26 |
|
27 |
passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
##### Sematic Search #####
|
30 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
31 |
corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
|
|
8 |
def gen_sentiment(text):
|
9 |
'''Generate sentiment of given text'''
|
10 |
return sent_pipe(text)[0]['label']
|
11 |
+
|
12 |
+
bi_enc_options = ["multi-qa-mpnet-base-dot-v1","all-mpnet-base-v2","multi-qa-MiniLM-L6-cos-v1","neeva/query2query"]
|
13 |
|
14 |
search_input = st.text_input(
|
15 |
label='Enter Your Search Query',value= "What key challenges did the business face?", key='search')
|
16 |
|
17 |
+
sbert_model_name = st.sidebar.selectbox("Encoder Model", options=bi_enc_options, key='sbox')
|
18 |
+
|
19 |
top_k = st.sidebar.slider("Number of Top Hits Generated",min_value=1,max_value=5,value=2)
|
20 |
|
21 |
window_size = st.sidebar.slider("Number of Sentences Generated in Search Response",min_value=1,max_value=7,value=3)
|
22 |
|
23 |
+
if search_input and sbert_model_name:
|
24 |
|
25 |
|
26 |
if "sen_df" in st.session_state and "earnings_passages" in st.session_state:
|
27 |
|
28 |
+
|
29 |
+
|
30 |
## Save to a dataframe for ease of visualization
|
31 |
sen_df = st.session_state['sen_df']
|
32 |
|
33 |
passages = preprocess_plain_text(st.session_state['earnings_passages'],window_size=window_size)
|
34 |
|
35 |
+
with st.spinner(
|
36 |
+
text=f"Loading {sbert_model_name} encoder and embedding text into vector space. This might take a few seconds depending on the length of text..."
|
37 |
+
):
|
38 |
+
sbert = load_sbert(sbert_model_name)
|
39 |
+
|
40 |
+
|
41 |
##### Sematic Search #####
|
42 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
43 |
corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|