Update app.py
Browse files
app.py
CHANGED
@@ -108,103 +108,13 @@ def build_experimental_ui():
|
|
108 |
|
109 |
button_query = st.button('Submit', disabled=False)
|
110 |
|
111 |
-
|
112 |
if button_query:
|
113 |
|
114 |
for question in questions_df['question']:
|
115 |
instruction = f'{prompt}.Question:{question}'
|
116 |
-
|
117 |
-
print('---- run query ----')
|
118 |
-
print(f'model: {selected_model} embeddings: {selected_embeddings}')
|
119 |
-
if selected_embeddings!=st.session_state['selected_embeddings']:
|
120 |
-
st.session_state['selected_embeddings'] = selected_embeddings
|
121 |
-
texts = load_pdf_document(pdf_docs)
|
122 |
-
st.session_state['retriever'] = get_retriever_from_text(texts, embeddings[selected_embeddings])
|
123 |
-
# qa = RetrievalQA.from_chain_type(llm=models[selected_model], chain_type="stuff",
|
124 |
-
# retriever=st.session_state['retriever'], return_source_documents=True)
|
125 |
-
st.session_state['docs'] = st.session_state['retriever'].get_relevant_documents(st.session_state.query)
|
126 |
-
context = '\n\n'.join([doc.page_content for doc in st.session_state['docs']])
|
127 |
-
st.session_state['context'] = context
|
128 |
-
source_files = get_pdf_file_names(st.session_state['pdf_file'])
|
129 |
-
#st.session_state['conversation']= get_conversation_chain(st.session_state['retriever'])
|
130 |
-
|
131 |
-
if strategy=='Without Chain-of-Thought':
|
132 |
-
user_token = model_configs[selected_model]['USER_TOKEN']
|
133 |
-
end_token = model_configs[selected_model]['END_TOKEN']
|
134 |
-
assistant_token = model_configs[selected_model]['ASSISTANT_TOKEN']
|
135 |
-
prompt_pattern, prompt = create_prompt(user_token, instruction, st.session_state.query, end_token, assistant_token, context)
|
136 |
-
updated_context = truncate_context(prompt_pattern, context,
|
137 |
-
max_token_len=model_configs[selected_model]['MAX_TOKENS'],
|
138 |
-
max_new_token_length=model_configs[selected_model]['MAX_NEW_TOKEN_LENGTH'])
|
139 |
-
updated_prompt = prompt_pattern.replace('{context}', updated_context)
|
140 |
-
print(updated_prompt)
|
141 |
-
|
142 |
-
|
143 |
-
with st.spinner():
|
144 |
-
answer = models[selected_model].generate([updated_prompt]).generations[0][0].text.strip()
|
145 |
-
st.write(answer)
|
146 |
-
chat_content['answer'] = answer
|
147 |
-
chat_content['source'] = source_files
|
148 |
-
chat_content['context']=st.session_state['context']
|
149 |
-
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
|
150 |
-
if st.session_state['chat_history']:
|
151 |
-
st.session_state['chat_history'].append(chat_content)
|
152 |
-
else:
|
153 |
-
st.session_state['chat_history']=[chat_content]
|
154 |
-
print('------chat history-----',st.session_state['chat_history'])
|
155 |
-
if updated_prompt!=prompt:
|
156 |
-
st.caption(f"Note: The context has been truncated to fit model max tokens of {model_configs[selected_model]['MAX_TOKENS']}. Original context contains {len(context.split())} words. Truncated context contains {len(updated_context.split())} words.")
|
157 |
|
158 |
-
|
159 |
-
chain = PDSCoverageChain()
|
160 |
-
with st.spinner():
|
161 |
-
answer = chain.generate(models[selected_model], model_configs[selected_model], st.session_state.query, context)
|
162 |
-
st.write(answer)
|
163 |
-
chat_content['answer'] = answer
|
164 |
-
chat_content['source'] = source_files
|
165 |
-
chat_content['context']=st.session_state['context']
|
166 |
-
chat_content['time']=datetime.now().strftime("%d-%m-%Y %H:%M:%S")
|
167 |
-
if st.session_state['chat_history']:
|
168 |
-
st.session_state['chat_history'].append(chat_content)
|
169 |
-
else:
|
170 |
-
st.session_state['chat_history']=[chat_content]
|
171 |
-
print('------chat history-----',st.session_state['chat_history'])
|
172 |
-
|
173 |
-
|
174 |
-
if st.session_state['docs']:
|
175 |
-
|
176 |
-
docs = st.session_state['docs']
|
177 |
|
178 |
-
col3, col4, col5, col6 = st.columns([0.2,0.35, 0.65, 3.8])
|
179 |
-
if st.session_state.query is None:
|
180 |
-
disable_query = True
|
181 |
-
else:
|
182 |
-
disable_query = False
|
183 |
-
chat_history = st.session_state['chat_history']
|
184 |
-
with col3:
|
185 |
-
st.button(":thumbsup:", on_click = get_feedback,disabled=disable_query,
|
186 |
-
kwargs=dict(upvote=True, downvote=False,
|
187 |
-
button='upvote'))
|
188 |
-
with col4:
|
189 |
-
st.button(":thumbsdown:", on_click = get_feedback,disabled=disable_query,
|
190 |
-
kwargs=dict(upvote=False, downvote=True,
|
191 |
-
button='downvote'))
|
192 |
-
|
193 |
-
|
194 |
-
with st.expander("References"):
|
195 |
-
for doc in docs:
|
196 |
-
print('-------',doc)
|
197 |
-
#st.markdown('###### Page {}'.format(doc.metadata['page']))
|
198 |
-
st.write(doc.page_content.replace('\n','\n\n').replace('$','\$').replace('**',''))
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
st.button("End Chat", on_click = get_feedback,
|
203 |
-
kwargs=dict(button='end-chat',
|
204 |
-
chat_history=chat_history))
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
|
209 |
else:
|
210 |
st.info("Under Development")
|
|
|
108 |
|
109 |
button_query = st.button('Submit', disabled=False)
|
110 |
|
|
|
111 |
if button_query:
|
112 |
|
113 |
for question in questions_df['question']:
|
114 |
instruction = f'{prompt}.Question:{question}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
else:
|
120 |
st.info("Under Development")
|