Update pages/3_Earnings_Semantic_Search_π_.py
Browse files
pages/3_Earnings_Semantic_Search_π_.py
CHANGED
@@ -9,10 +9,24 @@ def gen_sentiment(text):
|
|
9 |
'''Generate sentiment of given text'''
|
10 |
return sent_pipe(text)[0]['label']
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
|
13 |
-
'e5-base':'intfloat/e5-base',
|
14 |
'instructor-base': 'hkunlp/instructor-base',
|
15 |
-
'mpnet-base-dot-v1':'multi-qa-mpnet-base-dot-v1',
|
16 |
'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification'}
|
17 |
|
18 |
search_input = st.text_input(
|
@@ -20,9 +34,9 @@ search_input = st.text_input(
|
|
20 |
|
21 |
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
try:
|
28 |
|
@@ -32,69 +46,36 @@ try:
|
|
32 |
|
33 |
## Save to a dataframe for ease of visualization
|
34 |
sen_df = st.session_state['sen_df']
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
with st.spinner(
|
39 |
-
text=f"Loading {bi_enc_dict[sbert_model_name]}
|
40 |
):
|
41 |
-
sbert = load_sbert(bi_enc_dict[sbert_model_name])
|
42 |
-
|
43 |
-
|
44 |
-
##### Sematic Search #####
|
45 |
-
# Encode the query using the bi-encoder and find potentially relevant passages
|
46 |
-
# corpus_embeddings = sbert.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
47 |
-
# question_embedding = sbert.encode(search_input, convert_to_tensor=True)
|
48 |
-
# question_embedding = question_embedding.cpu()
|
49 |
-
# hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
50 |
-
# hits = hits[0] # Get the hits for the first query
|
51 |
-
|
52 |
-
# ##### Re-Ranking #####
|
53 |
-
# # Now, score all retrieved passages with the cross_encoder
|
54 |
-
# cross_inp = [[search_input, passages[hit['corpus_id']]] for hit in hits]
|
55 |
-
# cross_scores = cross_encoder.predict(cross_inp)
|
56 |
-
|
57 |
-
# # Sort results by the cross-encoder scores
|
58 |
-
# for idx in range(len(cross_scores)):
|
59 |
-
# hits[idx]['cross-score'] = cross_scores[idx]
|
60 |
-
|
61 |
-
# # Output of top-3 hits from re-ranker
|
62 |
-
# hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
63 |
-
embedding_model = bi_enc_dict[sbert_model_name]
|
64 |
-
|
65 |
-
hits = embed_text(search_input,passages,embedding_model)
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
74 |
|
75 |
-
|
76 |
-
for row in df.itertuples():
|
77 |
-
label = row[3]
|
78 |
-
text = row[2]
|
79 |
-
if label == 'Positive':
|
80 |
-
tag_list.append((text,label,'#8fce00'))
|
81 |
-
elif label == 'Negative':
|
82 |
-
tag_list.append((text,label,'#f44336'))
|
83 |
-
else:
|
84 |
-
tag_list.append((text,label,'#000000'))
|
85 |
-
|
86 |
-
return tag_list
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
first, second = text_annotations[0], text_annotations[1]
|
91 |
|
|
|
92 |
|
93 |
-
with st.expander(label='
|
94 |
-
annotated_text(
|
95 |
|
96 |
-
with st.expander(label='
|
97 |
-
|
|
|
98 |
|
99 |
else:
|
100 |
|
|
|
9 |
'''Generate sentiment of given text'''
|
10 |
return sent_pipe(text)[0]['label']
|
11 |
|
12 |
+
def gen_annotated_text(df):
|
13 |
+
'''Generate annotated text'''
|
14 |
+
|
15 |
+
tag_list=[]
|
16 |
+
for row in df.itertuples():
|
17 |
+
label = row[3]
|
18 |
+
text = row[2]
|
19 |
+
if label == 'Positive':
|
20 |
+
tag_list.append((text,label,'#8fce00'))
|
21 |
+
elif label == 'Negative':
|
22 |
+
tag_list.append((text,label,'#f44336'))
|
23 |
+
else:
|
24 |
+
tag_list.append((text,label,'#000000'))
|
25 |
+
|
26 |
+
return tag_list
|
27 |
+
|
28 |
bi_enc_dict = {'mpnet-base-v2':"all-mpnet-base-v2",
|
|
|
29 |
'instructor-base': 'hkunlp/instructor-base',
|
|
|
30 |
'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification'}
|
31 |
|
32 |
search_input = st.text_input(
|
|
|
34 |
|
35 |
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
|
36 |
|
37 |
+
chunk_size = st.sidebar.slider("Number of Words per Chunk of Text",min_value=100,max_value=250,value=200)
|
38 |
+
overlap_size = st.sidebar.slider("Number of Overlap Words in Search Response",min_value=30,max_value=100,value=50)
|
39 |
+
chain_type = st.sidebar.radio("Langchain Chain Type",options = ['Normal','Refined'])
|
40 |
|
41 |
try:
|
42 |
|
|
|
46 |
|
47 |
## Save to a dataframe for ease of visualization
|
48 |
sen_df = st.session_state['sen_df']
|
49 |
+
|
50 |
+
title = st.session_state['title']
|
51 |
+
|
52 |
with st.spinner(
|
53 |
+
text=f"Loading {bi_enc_dict[sbert_model_name]} embedding model and Generating Response..."
|
54 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
result = embed_text(search_input,st.session_state['earnings_passages'],title,
|
57 |
+
bi_enc_dict[sbert_model_name],
|
58 |
+
emb_tokenizer,chain_type=chain_type)
|
59 |
+
|
60 |
+
references = [doc.page_content for doc in result['input_documents']]
|
61 |
+
|
62 |
+
answer = result['output_text']
|
63 |
+
|
64 |
+
sentiment_label = gen_sentiment(answer)
|
65 |
|
66 |
+
##### Sematic Search #####
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
df = pd.DataFrame([(num,res,lab) for num, res, lab in zip(1,answer,sentiment_label)],columns=['Index','Text','Sentiment'])
|
69 |
+
|
|
|
70 |
|
71 |
+
text_annotations = gen_annotated_text(df)[0]
|
72 |
|
73 |
+
with st.expander(label='Query Result', expanded=True):
|
74 |
+
annotated_text(text_annotations)
|
75 |
|
76 |
+
with st.expander(label='References from Corpus used to Generate Result'):
|
77 |
+
for ref in references:
|
78 |
+
st.write(ref)
|
79 |
|
80 |
else:
|
81 |
|