Update app.py
app.py CHANGED
@@ -34,6 +34,7 @@ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instru
 ]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
+
 # Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     # Processing for one document only
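The body of load_doc sits outside this diff; only its signature and leading comment are visible. For orientation, a minimal sketch of a LangChain loader/splitter with that signature — PyPDFLoader and RecursiveCharacterTextSplitter are assumptions, not confirmed by the diff:

# Hedged sketch: the real load_doc may differ; only its signature appears above.
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_doc(list_file_path, chunk_size, chunk_overlap):
    # Load every page of every PDF as a LangChain Document
    pages = []
    for file_path in list_file_path:
        pages.extend(PyPDFLoader(file_path).load())
    # Split pages into overlapping chunks sized for embedding
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(pages)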
@@ -120,7 +121,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
             top_k = top_k,
         )
     elif llm_model == "microsoft/phi-2":
-        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
+        # raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
         llm = HuggingFaceEndpoint(
             repo_id=llm_model,
             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
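With the gr.Error commented out, microsoft/phi-2 now falls through to the same HuggingFaceEndpoint construction as the other models. A hedged sketch of that call in isolation; the slider values are illustrative stand-ins, and a HUGGINGFACEHUB_API_TOKEN is assumed to be set in the environment:

from langchain_community.llms import HuggingFaceEndpoint

# Illustrative values standing in for the UI sliders
temperature, max_tokens, top_k = 0.7, 1024, 3
llm = HuggingFaceEndpoint(
    repo_id="microsoft/phi-2",
    temperature=temperature,
    max_new_tokens=max_tokens,
    top_k=top_k,
)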
@@ -267,12 +268,7 @@ def conversation(qa_chain, message, history):
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-def download_csv(chat_history):
-    df = pd.DataFrame(chat_history, columns=['User Message', 'Assistant Response', 'Rating'])
-    df.to_csv('conversation_history.csv', index=False)
-    return 'conversation_history.csv'
 
-
 def upload_file(file_obj):
     list_file_path = []
     for idx, file in enumerate(file_obj):
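This hunk drops the download_csv helper (the CSV-export half of the rating feature removed later in the commit). The upload_file body is cut off at the hunk boundary; a plausible completion, assuming Gradio file objects that expose their server-side temp path via .name:

def upload_file(file_obj):
    # Collect the temp-file paths Gradio assigns to the uploaded PDFs
    list_file_path = []
    for idx, file in enumerate(file_obj):
        list_file_path.append(file.name)  # .name is the temp path on disk
    return list_file_path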
@@ -290,16 +286,21 @@ def demo():
         collection_name = gr.State()
 
         gr.Markdown(
-        """<center><h2>PDF-based chatbot
-        <h3>Ask any questions about your PDF documents
-
-
-
+        """<center><h2>PDF-based chatbot</h2></center>
+        <h3>Ask any questions about your PDF documents</h3>""")
+        gr.Markdown(
+        """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
+        The user interface explicitly shows multiple steps to help understand the RAG workflow.
+        This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity.<br>
+        <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
         """)
-
+
+        with gr.Tab("Step 1 - Upload PDF"):
             with gr.Row():
                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+
+        with gr.Tab("Step 2 - Process document"):
             with gr.Row():
                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
             with gr.Accordion("Advanced options - Document text splitter", open=False):
@@ -310,15 +311,15 @@ def demo():
             with gr.Row():
                 db_progress = gr.Textbox(label="Vector database initialization", value="None")
             with gr.Row():
-                db_btn = gr.Button("Generate vector database
+                db_btn = gr.Button("Generate vector database")
 
-        with gr.Tab("Step
+        with gr.Tab("Step 3 - Initialize QA chain"):
             with gr.Row():
                 llm_btn = gr.Radio(list_llm_simple, \
                     label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
             with gr.Accordion("Advanced options - LLM model", open=False):
                 with gr.Row():
-                    slider_temperature = gr.Slider(minimum = 0.
+                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                 with gr.Row():
                     slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                 with gr.Row():
@@ -326,14 +327,10 @@ def demo():
             with gr.Row():
                 llm_progress = gr.Textbox(value="None",label="QA chain initialization")
             with gr.Row():
-                qachain_btn = gr.Button("Initialize
+                qachain_btn = gr.Button("Initialize Question Answering chain")
 
-        with gr.Tab("Step
+        with gr.Tab("Step 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
-            # Add rating slider
-            with gr.Row():
-                rating_slider = gr.Slider(minimum=1, maximum=5, label="Rate the message:", interactive=True)
-
             with gr.Accordion("Advanced - Document references", open=False):
                 with gr.Row():
                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
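Across the last three hunks the commit settles on four explicitly numbered tabs. A stripped-down sketch of the resulting skeleton — the gr.Blocks context and most component arguments are assumptions, since the diff only shows fragments:

import gradio as gr

def demo():
    with gr.Blocks() as demo:
        with gr.Tab("Step 1 - Upload PDF"):
            document = gr.Files(file_count="multiple", file_types=["pdf"])
        with gr.Tab("Step 2 - Process document"):
            db_btn = gr.Button("Generate vector database")
        with gr.Tab("Step 3 - Initialize QA chain"):
            qachain_btn = gr.Button("Initialize Question Answering chain")
        with gr.Tab("Step 4 - Chatbot"):
            chatbot = gr.Chatbot(height=300)
    demo.queue().launch(debug=True, share=True)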
@@ -345,14 +342,11 @@ def demo():
                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                     source3_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-                msg = gr.Textbox(placeholder="Type message", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Submit")
-                clear_btn = gr.ClearButton([msg, chatbot])
-
+                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
             with gr.Row():
-
-
+                submit_btn = gr.Button("Submit message")
+                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
         # Preprocessing events
         #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
         db_btn.click(initialize_database, \
@@ -365,25 +359,21 @@ def demo():
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
 
-        #
+        # Chatbot events
         msg.submit(conversation, \
-            inputs=[qa_chain, msg, chatbot
+            inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
         submit_btn.click(conversation, \
-            inputs=[qa_chain, msg, chatbot
+            inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
         clear_btn.click(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
-
-        # Add download button event
-        download_btn.click(download_csv, inputs=[chatbot], outputs=None)
-
-    demo.queue().launch(debug=True)
+    demo.queue().launch(debug=True, share=True)
 
 
 if __name__ == "__main__":
-    demo()
+    demo()
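One constraint worth keeping in mind when reading the wiring above: Gradio maps a handler's return values onto its outputs list positionally, so the clear lambda must return exactly seven values for the seven components it resets. Restated as plain code, with the component names taken from the diff:

# Reset the chat plus the three (reference, page) pairs in one click;
# the seven returned values map onto the seven outputs in order.
clear_btn.click(
    lambda: [None, "", 0, "", 0, "", 0],
    inputs=None,
    outputs=[chatbot,
             doc_source1, source1_page,
             doc_source2, source2_page,
             doc_source3, source3_page],
    queue=False,
)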