import streamlit as st
from functions import *
from langchain.chains import QAGenerationChain
import itertools
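# Explicit imports for names used below (AIMessage, HumanMessage, the Streamlit
# callback handler, pandas, annotated_text); some of these may already be
# re-exported by the star import from functions above.
import pandas as pd
from langchain.schema import AIMessage, HumanMessage
from langchain.callbacks import StreamlitCallbackHandler
from annotated_text import annotated_text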
st.set_page_config(page_title="Earnings Question/Answering", page_icon="💬")
st.sidebar.header("Semantic Search")
st.markdown("Earnings Semantic Search with LangChain, OpenAI & SBert")
starter_message = "Ask me anything about the Earnings Call!"
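# Inline CSS: hide Streamlit's default menu and footer, and style the generated
# question cards shown in the sidebar.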
st.markdown(
    """
    <style>
    #MainMenu {visibility: hidden;
    }
    footer {visibility: hidden;
    }
    .css-card {
        border-radius: 0px;
        padding: 30px 10px 10px 10px;
        background-color: black;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 10px;
        font-family: "IBM Plex Sans", sans-serif;
    }
    .card-tag {
        border-radius: 0px;
        padding: 1px 5px 1px 5px;
        margin-bottom: 10px;
        position: absolute;
        left: 0px;
        top: 0px;
        font-size: 0.6rem;
        font-family: "IBM Plex Sans", sans-serif;
        color: white;
        background-color: green;
    }
    .css-zt5igj {left: 0;
    }
    span.css-10trblm {margin-left: 0;
    }
    div.css-1kyxreq {margin-top: -40px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
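# Sidebar embedding-model choices: display name -> Hugging Face model id.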
bi_enc_dict = {
    'mpnet-base-v2': 'all-mpnet-base-v2',
    'instructor-base': 'hkunlp/instructor-base',
    'FlagEmbedding': 'BAAI/bge-base-en',
}
sbert_model_name = st.sidebar.selectbox("Embedding Model", options=list(bi_enc_dict.keys()), key='sbox')
st.sidebar.markdown('Earnings QnA Generator')
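# Text-splitting parameters (characters per chunk and overlap between chunks).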
chunk_size = 1000
overlap_size = 50
try:
    if "sen_df" in st.session_state and "earnings_passages" in st.session_state:

        ## Save to a dataframe for ease of visualization
        sen_df = st.session_state['sen_df']
        title = st.session_state['title']
        print(f'Earnings Call title: {title}')

        earnings_text = st.session_state['earnings_passages']
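        # generate_eval comes from functions.py; it is assumed to build question/answer
        # pairs from the transcript (text, number of pairs, chunk size in characters).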
        st.session_state.eval_set = generate_eval(earnings_text, 10, 3000)
        # Display the question-answer pairs in the sidebar with smaller text
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
                <div class="css-card">
                    <span class="card-tag">Question {i + 1}</span>
                    <p style="font-size: 12px;">{qa_pair['question']}</p>
                    <p style="font-size: 12px;">{qa_pair['answer']}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )
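        # Resolve the sidebar selection to its Hugging Face model id.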
        embedding_model = bi_enc_dict[sbert_model_name]

        with st.spinner(
            text=f"Loading {embedding_model} embedding model and creating vectorstore..."
        ):
            docsearch = create_vectorstore(earnings_text, title, embedding_model)

        memory, agent_executor = create_memory_and_agent(docsearch)
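        # Initialise (or reset) the chat history with the assistant's starter message.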
if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
st.session_state["messages"] = [AIMessage(content=starter_message)]
for msg in st.session_state.messages:
if isinstance(msg, AIMessage):
st.chat_message("assistant").write(msg.content)
elif isinstance(msg, HumanMessage):
st.chat_message("user").write(msg.content)
memory.chat_memory.add_message(msg)
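        # New question: run the agent, stream its intermediate steps via the callback,
        # and record the exchange in conversation memory.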
        if user_question := st.chat_input(placeholder=starter_message):
            st.chat_message("user").write(user_question)

            with st.chat_message("assistant"):
                st_callback = StreamlitCallbackHandler(st.container())

                response = agent_executor(
                    {"input": user_question, "history": st.session_state.messages},
                    callbacks=[st_callback],
                    include_run_info=True,
                )

                answer = response["output"]
                st.session_state.messages.append(AIMessage(content=answer))
                st.write(answer)

                memory.save_context({"input": user_question}, response)
                st.session_state["messages"] = memory.buffer

                run_id = response["__run"].run_id
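                # Feedback buttons pass the LangChain run id and a 1/0 score to
                # send_feedback (defined in functions.py).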
                col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
                with col_text:
                    st.text("Feedback:")
                with col1:
                    st.button("👍", on_click=send_feedback, args=(run_id, 1))
                with col2:
                    st.button("👎", on_click=send_feedback, args=(run_id, 0))
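                # Tag the answer with a sentiment label and render it as annotated text.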
                with st.expander(label='Query Result with Sentiment Tag', expanded=True):
                    sentiment_label = gen_sentiment(answer)
                    df = pd.DataFrame.from_dict({'Text': [answer], 'Sentiment': [sentiment_label]})
                    text_annotations = gen_annotated_text(df)[0]
                    annotated_text(text_annotations)

    else:
        st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')

except RuntimeError:
    st.write('Please ensure you have entered the YouTube URL or uploaded the Earnings Call file')