Luca Foppiano committed on
Commit
77341b6
·
unverified ·
2 Parent(s): d635a1a 9c4d6ae

Merge pull request #7 from lfoppiano/move-settings-sidebar

Browse files
Files changed (1) hide show
  1. streamlit_app.py +120 -62
streamlit_app.py CHANGED
@@ -18,10 +18,13 @@ from document_qa.grobid_processors import GrobidAggregationProcessor, decorate_t
18
  from grobid_client_generic import GrobidClientGeneric
19
 
20
  if 'rqa' not in st.session_state:
21
- st.session_state['rqa'] = None
22
 
23
- if 'api_key' not in st.session_state:
24
- st.session_state['api_key'] = False
 
 
 
25
 
26
  if 'doc_id' not in st.session_state:
27
  st.session_state['doc_id'] = None
@@ -42,13 +45,31 @@ if 'git_rev' not in st.session_state:
42
  if "messages" not in st.session_state:
43
  st.session_state.messages = []
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def new_file():
47
  st.session_state['loaded_embeddings'] = None
48
  st.session_state['doc_id'] = None
 
49
 
50
 
51
- @st.cache_resource
52
  def init_qa(model):
53
  if model == 'chatgpt-3.5-turbo':
54
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
@@ -67,6 +88,7 @@ def init_qa(model):
67
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
68
  else:
69
  st.error("The model was not loaded properly. Try reloading. ")
 
70
 
71
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
72
 
@@ -94,7 +116,6 @@ def init_ner():
94
  grobid_quantities_client=quantities_client,
95
  grobid_superconductors_client=materials_client
96
  )
97
-
98
  return gqa
99
 
100
 
@@ -123,53 +144,70 @@ def play_old_messages():
123
  st.write(message['content'])
124
 
125
 
126
- is_api_key_provided = st.session_state['api_key']
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
129
- ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
130
- index=1,
131
- captions=[
132
- "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
133
- "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
134
- # "LLama2-70B-Chat + Sentence BERT (embeddings)",
135
- ],
136
- help="Select the model you want to use.",
137
- disabled=is_api_key_provided)
138
 
139
- if not st.session_state['api_key']:
140
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
141
- api_key = st.sidebar.text_input('Huggingface API Key',
142
- type="password") # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
 
 
 
 
 
 
143
  if api_key:
144
- st.session_state['api_key'] = is_api_key_provided = True
145
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
146
- st.session_state['rqa'] = init_qa(model)
 
 
 
 
 
147
  elif model == 'chatgpt-3.5-turbo':
148
- api_key = st.sidebar.text_input('OpenAI API Key',
149
- type="password") # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
 
 
 
 
 
150
  if api_key:
151
- st.session_state['api_key'] = is_api_key_provided = True
152
- os.environ['OPENAI_API_KEY'] = api_key
153
- st.session_state['rqa'] = init_qa(model)
154
- else:
155
- is_api_key_provided = st.session_state['api_key']
 
 
 
 
156
 
157
  st.title("📝 Scientific Document Insight Q&A")
158
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
159
 
160
- upload_col, radio_col, context_col = st.columns([7, 2, 2])
161
- with upload_col:
162
- uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
163
- disabled=not is_api_key_provided,
164
- help="The full-text is extracted using Grobid. ")
165
- with radio_col:
166
- mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0,
167
- help="LLM will respond the question, Embedding will show the "
168
- "paragraphs relevant to the question in the paper.")
169
- with context_col:
170
- context_size = st.slider("Context size", 3, 10, value=4,
171
- help="Number of paragraphs to consider when answering a question",
172
- disabled=not uploaded_file)
173
 
174
  question = st.chat_input(
175
  "Ask something about the article",
@@ -178,14 +216,29 @@ question = st.chat_input(
178
  )
179
 
180
  with st.sidebar:
181
- st.header("Documentation")
182
- st.markdown("https://github.com/lfoppiano/document-qa")
183
- st.markdown(
184
- """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")
 
 
 
 
 
 
185
 
 
186
  st.markdown(
187
  '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
188
  unsafe_allow_html=True)
 
 
 
 
 
 
 
 
189
  if st.session_state['git_rev'] != "unknown":
190
  st.markdown("**Revision number**: [" + st.session_state[
191
  'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
@@ -198,14 +251,17 @@ with st.sidebar:
198
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
199
 
200
  if uploaded_file and not st.session_state.loaded_embeddings:
 
 
 
201
  with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
202
  binary = uploaded_file.getvalue()
203
  tmp_file = NamedTemporaryFile()
204
  tmp_file.write(bytearray(binary))
205
  # hash = get_file_hash(tmp_file.name)[:10]
206
- st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name,
207
- chunk_size=250,
208
- perc_overlap=0.1)
209
  st.session_state['loaded_embeddings'] = True
210
  st.session_state.messages = []
211
 
@@ -218,6 +274,9 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
218
  st.markdown(message["content"], unsafe_allow_html=True)
219
  elif message['mode'] == "Embeddings":
220
  st.write(message["content"])
 
 
 
221
 
222
  with st.chat_message("user"):
223
  st.markdown(question)
@@ -226,27 +285,26 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
226
  text_response = None
227
  if mode == "Embeddings":
228
  with st.spinner("Generating LLM response..."):
229
- text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
230
- context_size=context_size)
231
  elif mode == "LLM":
232
  with st.spinner("Generating response..."):
233
- _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
234
- context_size=context_size)
235
 
236
  if not text_response:
237
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
238
 
239
  with st.chat_message("assistant"):
240
  if mode == "LLM":
241
- with st.spinner("Processing NER on LLM response..."):
242
- entities = gqa.process_single_text(text_response)
243
- # for entity in entities:
244
- # entity
245
- decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
246
- decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
247
- decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
248
- st.markdown(decorated_text, unsafe_allow_html=True)
249
- text_response = decorated_text
250
  else:
251
  st.write(text_response)
252
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
 
18
  from grobid_client_generic import GrobidClientGeneric
19
 
20
  if 'rqa' not in st.session_state:
21
+ st.session_state['rqa'] = {}
22
 
23
+ if 'model' not in st.session_state:
24
+ st.session_state['model'] = None
25
+
26
+ if 'api_keys' not in st.session_state:
27
+ st.session_state['api_keys'] = {}
28
 
29
  if 'doc_id' not in st.session_state:
30
  st.session_state['doc_id'] = None
 
45
  if "messages" not in st.session_state:
46
  st.session_state.messages = []
47
 
48
+ if 'ner_processing' not in st.session_state:
49
+ st.session_state['ner_processing'] = False
50
+
51
+ if 'uploaded' not in st.session_state:
52
+ st.session_state['uploaded'] = False
53
+
54
+ st.set_page_config(
55
+ page_title="Document Insights QA",
56
+ page_icon="📝",
57
+ initial_sidebar_state="expanded",
58
+ menu_items={
59
+ 'Get Help': 'https://github.com/lfoppiano/document-qa',
60
+ 'Report a bug': "https://github.com/lfoppiano/document-qa/issues",
61
+ 'About': "Upload a scientific article in PDF, ask questions, get insights."
62
+ }
63
+ )
64
+
65
 
66
  def new_file():
67
  st.session_state['loaded_embeddings'] = None
68
  st.session_state['doc_id'] = None
69
+ st.session_state['uploaded'] = True
70
 
71
 
72
+ # @st.cache_resource
73
  def init_qa(model):
74
  if model == 'chatgpt-3.5-turbo':
75
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
 
88
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
89
  else:
90
  st.error("The model was not loaded properly. Try reloading. ")
91
+ st.stop()
92
 
93
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
94
 
 
116
  grobid_quantities_client=quantities_client,
117
  grobid_superconductors_client=materials_client
118
  )
 
119
  return gqa
120
 
121
 
 
144
  st.write(message['content'])
145
 
146
 
147
+ # is_api_key_provided = st.session_state['api_key']
148
+
149
+ with st.sidebar:
150
+ st.session_state['model'] = model = st.radio(
151
+ "Model",
152
+ ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"), # , "llama-2-70b-chat"),
153
+ index=1,
154
+ captions=[
155
+ "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
156
+ "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings) :free:"
157
+ # "LLama2-70B-Chat + Sentence BERT (embeddings) :free:",
158
+ ],
159
+ help="Select the LLM model and embeddings you want to use.",
160
+ disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
161
 
162
+ st.markdown(
163
+ ":warning: Mistral is free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
 
 
 
 
 
 
 
 
164
 
 
165
  if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
166
+ if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
167
+ api_key = st.text_input('Huggingface API Key', type="password")
168
+
169
+ st.markdown(
170
+ "Get it [here](https://huggingface.co/docs/hub/security-tokens)")
171
+ else:
172
+ api_key = os.environ['HUGGINGFACEHUB_API_TOKEN']
173
+
174
  if api_key:
175
+ # st.session_state['api_key'] = is_api_key_provided = True
176
+ if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
177
+ with st.spinner("Preparing environment"):
178
+ st.session_state['api_keys'][model] = api_key
179
+ if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
180
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
181
+ st.session_state['rqa'][model] = init_qa(model)
182
+
183
  elif model == 'chatgpt-3.5-turbo':
184
+ if 'OPENAI_API_KEY' not in os.environ:
185
+ api_key = st.text_input('OpenAI API Key', type="password")
186
+ st.markdown(
187
+ "Get it [here](https://platform.openai.com/account/api-keys)")
188
+ else:
189
+ api_key = os.environ['OPENAI_API_KEY']
190
+
191
  if api_key:
192
+ # st.session_state['api_key'] = is_api_key_provided = True
193
+ if model not in st.session_state['rqa'] or model not in st.session_state['api_keys']:
194
+ with st.spinner("Preparing environment"):
195
+ st.session_state['api_keys'][model] = api_key
196
+ if 'OPENAI_API_KEY' not in os.environ:
197
+ os.environ['OPENAI_API_KEY'] = api_key
198
+ st.session_state['rqa'][model] = init_qa(model)
199
+ # else:
200
+ # is_api_key_provided = st.session_state['api_key']
201
 
202
  st.title("📝 Scientific Document Insight Q&A")
203
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
204
 
205
+ st.markdown(":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
206
+
207
+ uploaded_file = st.file_uploader("Upload an article", type=("pdf", "txt"), on_change=new_file,
208
+ disabled=st.session_state['model'] is not None and st.session_state['model'] not in
209
+ st.session_state['api_keys'],
210
+ help="The full-text is extracted using Grobid. ")
 
 
 
 
 
 
 
211
 
212
  question = st.chat_input(
213
  "Ask something about the article",
 
216
  )
217
 
218
  with st.sidebar:
219
+ st.header("Settings")
220
+ mode = st.radio("Query mode", ("LLM", "Embeddings"), disabled=not uploaded_file, index=0, horizontal=True,
221
+ help="LLM will respond the question, Embedding will show the "
222
+ "paragraphs relevant to the question in the paper.")
223
+ chunk_size = st.slider("Chunks size", 100, 2000, value=250,
224
+ help="Size of chunks in which the document is partitioned",
225
+ disabled=uploaded_file is not None)
226
+ context_size = st.slider("Context size", 3, 10, value=4,
227
+ help="Number of chunks to consider when answering a question",
228
+ disabled=not uploaded_file)
229
 
230
+ st.session_state['ner_processing'] = st.checkbox("Named Entities Recognition (NER) processing on LLM response")
231
  st.markdown(
232
  '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
233
  unsafe_allow_html=True)
234
+
235
+ st.divider()
236
+
237
+ st.header("Documentation")
238
+ st.markdown("https://github.com/lfoppiano/document-qa")
239
+ st.markdown(
240
+ """Upload a scientific article as PDF document. Once the spinner stops, you can proceed to ask your questions.""")
241
+
242
  if st.session_state['git_rev'] != "unknown":
243
  st.markdown("**Revision number**: [" + st.session_state[
244
  'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
 
251
  """If you switch the mode to "Embedding," the system will return specific chunks from the document that are semantically related to your query. This mode helps to test why sometimes the answers are not satisfying or incomplete. """)
252
 
253
  if uploaded_file and not st.session_state.loaded_embeddings:
254
+ if model not in st.session_state['api_keys']:
255
+ st.error("Before uploading a document, you must enter the API key. ")
256
+ st.stop()
257
  with st.spinner('Reading file, calling Grobid, and creating memory embeddings...'):
258
  binary = uploaded_file.getvalue()
259
  tmp_file = NamedTemporaryFile()
260
  tmp_file.write(bytearray(binary))
261
  # hash = get_file_hash(tmp_file.name)[:10]
262
+ st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
263
+ chunk_size=chunk_size,
264
+ perc_overlap=0.1)
265
  st.session_state['loaded_embeddings'] = True
266
  st.session_state.messages = []
267
 
 
274
  st.markdown(message["content"], unsafe_allow_html=True)
275
  elif message['mode'] == "Embeddings":
276
  st.write(message["content"])
277
+ if model not in st.session_state['rqa']:
278
+ st.error("The API Key for the " + model + " is missing. Please add it before sending any query. `")
279
+ st.stop()
280
 
281
  with st.chat_message("user"):
282
  st.markdown(question)
 
285
  text_response = None
286
  if mode == "Embeddings":
287
  with st.spinner("Generating LLM response..."):
288
+ text_response = st.session_state['rqa'][model].query_storage(question, st.session_state.doc_id,
289
+ context_size=context_size)
290
  elif mode == "LLM":
291
  with st.spinner("Generating response..."):
292
+ _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
293
+ context_size=context_size)
294
 
295
  if not text_response:
296
  st.error("Something went wrong. Contact Luca Foppiano (Foppiano.Luca@nims.co.jp) to report the issue.")
297
 
298
  with st.chat_message("assistant"):
299
  if mode == "LLM":
300
+ if st.session_state['ner_processing']:
301
+ with st.spinner("Processing NER on LLM response..."):
302
+ entities = gqa.process_single_text(text_response)
303
+ decorated_text = decorate_text_with_annotations(text_response.strip(), entities)
304
+ decorated_text = decorated_text.replace('class="label material"', 'style="color:green"')
305
+ decorated_text = re.sub(r'class="label[^"]+"', 'style="color:orange"', decorated_text)
306
+ text_response = decorated_text
307
+ st.markdown(text_response, unsafe_allow_html=True)
 
308
  else:
309
  st.write(text_response)
310
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})