thesnak committed
Commit a3bd89c
1 Parent(s): a2d1637

Update app.py

Files changed (1): app.py (+27 -37)
app.py CHANGED
@@ -34,6 +34,7 @@ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instru
 ]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
+
 # Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     # Processing for one document only
@@ -120,7 +121,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
             top_k = top_k,
         )
     elif llm_model == "microsoft/phi-2":
-        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+        # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
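For context, a minimal sketch of the code path this hunk re-enables, assuming langchain's HuggingFaceEndpoint (which accepts generation parameters as direct keyword arguments, echoing the commented-out model_kwargs). The diff truncates the actual call, so the values below are illustrative, not the commit's exact code:

    # Hedged sketch: with the gr.Error guard commented out, phi-2 flows
    # through the same HuggingFaceEndpoint path as the other models.
    from langchain_community.llms import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id="microsoft/phi-2",
        temperature=0.7,       # Temperature slider default in this UI
        max_new_tokens=1024,   # Max Tokens slider default in this UI
        top_k=3,               # assumed; the top-k slider is outside this diff
    )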
@@ -267,12 +268,7 @@ def conversation(qa_chain, message, history):
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-def download_csv(chat_history):
-    df = pd.DataFrame(chat_history, columns=['User Message', 'Assistant Response', 'Rating'])
-    df.to_csv('conversation_history.csv', index=False)
-    return 'conversation_history.csv'
 
-
 def upload_file(file_obj):
     list_file_path = []
     for idx, file in enumerate(file_obj):
@@ -290,16 +286,21 @@ def demo():
         collection_name = gr.State()
 
         gr.Markdown(
-        """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
-        <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
-        <b>Note:</b> This AI assistant performs retrieval-augmented generation from your PDF documents. \
-        When generating answers, it takes past questions into account (via conversational memory), and includes document references for clarity purposes.</i>
-        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate an output.<br>
+        """<center><h2>PDF-based chatbot</center></h2>
+        <h3>Ask any questions about your PDF documents</h3>""")
+        gr.Markdown(
+        """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
+        The user interface explicitly shows multiple steps to help understand the RAG workflow.
+        This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity purposes.<br>
+        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
         """)
-        with gr.Tab("Step 1 - Document pre-processing"):
+
+        with gr.Tab("Step 1 - Upload PDF"):
             with gr.Row():
                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+
+        with gr.Tab("Step 2 - Process document"):
             with gr.Row():
                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
             with gr.Accordion("Advanced options - Document text splitter", open=False):
@@ -310,15 +311,15 @@ def demo():
             with gr.Row():
                 db_progress = gr.Textbox(label="Vector database initialization", value="None")
             with gr.Row():
-                db_btn = gr.Button("Generate vector database...")
+                db_btn = gr.Button("Generate vector database")
 
-        with gr.Tab("Step 2 - QA chain initialization"):
+        with gr.Tab("Step 3 - Initialize QA chain"):
             with gr.Row():
                 llm_btn = gr.Radio(list_llm_simple, \
                     label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
             with gr.Accordion("Advanced options - LLM model", open=False):
                 with gr.Row():
-                    slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                 with gr.Row():
                     slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                 with gr.Row():
@@ -326,14 +327,10 @@ def demo():
             with gr.Row():
                 llm_progress = gr.Textbox(value="None",label="QA chain initialization")
             with gr.Row():
-                qachain_btn = gr.Button("Initialize question-answering chain...")
+                qachain_btn = gr.Button("Initialize Question Answering chain")
 
-        with gr.Tab("Step 3 - Conversation with chatbot"):
+        with gr.Tab("Step 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
-            # Add rating slider
-            with gr.Row():
-                rating_slider = gr.Slider(minimum=1, maximum=5, label="Rate the message:", interactive=True)
-
             with gr.Accordion("Advanced - Document references", open=False):
                 with gr.Row():
                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
@@ -345,14 +342,11 @@ def demo():
                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                     source3_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-                msg = gr.Textbox(placeholder="Type message", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit")
-                clear_btn = gr.ClearButton([msg, chatbot])
-
+                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
             with gr.Row():
-                download_btn = gr.Button("Download Conversation History as CSV")
-
+                submit_btn = gr.Button("Submit message")
+                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
         # Preprocessing events
         #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
         db_btn.click(initialize_database, \
@@ -365,25 +359,21 @@ def demo():
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
 
-        # Conversation events
+        # Chatbot events
         msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot, rating_slider], \
+            inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
         submit_btn.click(conversation, \
-            inputs=[qa_chain, msg, chatbot, rating_slider], \
+            inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
         clear_btn.click(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
-
-        # Add download button event
-        download_btn.click(download_csv, inputs=[chatbot], outputs=None)
-
-    demo.queue().launch(debug=True)
+    demo.queue().launch(debug=True,share=True)
 
 
 if __name__ == "__main__":
-    demo()
+    demo()
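For reference, the handler shape the updated wiring expects. The def line and return statement appear verbatim in the hunks above; the body is elided because this commit does not touch it:

    def conversation(qa_chain, message, history):
        ...  # body unchanged by this commit
        return qa_chain, gr.update(value=""), new_history, \
            response_source1, response_source1_page, \
            response_source2, response_source2_page, \
            response_source3, response_source3_page

The nine returned values map positionally onto the nine components in the outputs list of msg.submit and submit_btn.click, which is why dropping rating_slider from inputs leaves the outputs unchanged.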
 