petrojm committed on
Commit
3b758aa
1 Parent(s): e443083

changes to app

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -7,17 +7,17 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
7
 
8
  from src.document_retrieval import DocumentRetrieval
9
  from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
10
- from utils.parsing.sambaparse import parse_doc_universal # added Petro
11
  from utils.vectordb.vector_db import VectorDb
12
 
13
  CONFIG_PATH = os.path.join(current_dir,'config.yaml')
14
  PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
15
 
16
- def handle_userinput(user_question, conversation, history):
17
  if user_question:
18
  try:
19
  # Generate response
20
- response = conversation.invoke({"question": user_question})
21
 
22
  # Append user message and response to chat history
23
  history = history + [(user_question, response["answer"])]
@@ -28,27 +28,27 @@ def handle_userinput(user_question, conversation, history):
28
  else:
29
  return history, ""
30
 
31
def process_documents(files, document_retrieval, vectorstore, conversation, save_location=None):
    """Ingest uploaded documents and build the retrieval pipeline.

    Parses the given files into text chunks, embeds them into a vector
    store, and wires up a QA retrieval chain over that store.

    Args:
        files: Uploaded document(s) to parse (forwarded to parse_doc_universal).
        document_retrieval: Prior DocumentRetrieval state (replaced by a fresh
            instance on every call).
        vectorstore: Prior vector-store state (replaced on success).
        conversation: Prior conversation-chain state (replaced on success).
        save_location: Optional path where the vector DB is persisted.

    Returns:
        Tuple of (conversation, vectorstore, document_retrieval, status_message).
        On failure the most recently assigned values are returned together
        with an error message instead of raising.
    """
    try:
        document_retrieval = DocumentRetrieval()
        _, _, text_chunks = parse_doc_universal(doc=files)
        print(text_chunks)
        embeddings = document_retrieval.load_embedding_model()
        # In prod mode the backend picks the collection; otherwise use the default.
        if config['prod_mode']:
            collection_name = None
        else:
            collection_name = 'ekr_default_collection'
        vectorstore = document_retrieval.create_vector_store(
            text_chunks,
            embeddings,
            output_db=save_location,
            collection_name=collection_name,
        )
        document_retrieval.init_retriever(vectorstore)
        conversation = document_retrieval.get_qa_retrieval_chain()
        status = "Complete! You can now ask questions."
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app;
        # partially updated state (if any) is handed back as-is.
        status = f"An error occurred while processing: {str(e)}"
    return conversation, vectorstore, document_retrieval, status
45
 
46
  # Read config file
47
  with open(CONFIG_PATH, 'r') as yaml_file:
48
  config = yaml.safe_load(yaml_file)
49
 
50
  prod_mode = config.get('prod_mode', False)
51
- default_collection = 'ekr_default_collection'
52
 
53
  # Load env variables
54
  initialize_env_variables(prod_mode)
@@ -58,8 +58,9 @@ caution_text = """⚠️ Note: depending on the size of your document, this coul
58
 
59
  with gr.Blocks() as demo:
60
  vectorstore = gr.State()
61
- conversation = gr.State()
62
  document_retrieval = gr.State()
 
63
 
64
  gr.Markdown("# Enterprise Knowledge Retriever",
65
  elem_id="title")
@@ -80,7 +81,7 @@ with gr.Blocks() as demo:
80
  gr.Markdown(caution_text)
81
 
82
  # Preprocessing events
83
- process_btn.click(process_documents, inputs=[docs, document_retrieval, vectorstore, conversation], outputs=[conversation, vectorstore, document_retrieval, setup_output], concurrency_limit=10)
84
 
85
  # Step 3: Chat with your data
86
  gr.Markdown("## 3️⃣ Chat with your document")
@@ -90,7 +91,7 @@ with gr.Blocks() as demo:
90
  sources_output = gr.Textbox(label="Sources", visible=False)
91
 
92
  # Chatbot events
93
- msg.submit(handle_userinput, inputs=[msg, conversation, chatbot], outputs=[chatbot, msg], queue=False)
94
  clear_btn.click(lambda: [None, ""], inputs=None, outputs=[chatbot, msg], queue=False)
95
 
96
  if __name__ == "__main__":
 
7
 
8
  from src.document_retrieval import DocumentRetrieval
9
  from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
10
+ from utils.parsing.sambaparse import parse_doc_universal # added
11
  from utils.vectordb.vector_db import VectorDb
12
 
13
  CONFIG_PATH = os.path.join(current_dir,'config.yaml')
14
  PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
15
 
16
+ def handle_userinput(user_question, conversation_chain, history):
17
  if user_question:
18
  try:
19
  # Generate response
20
+ response = conversation_chain.invoke({"question": user_question})
21
 
22
  # Append user message and response to chat history
23
  history = history + [(user_question, response["answer"])]
 
28
  else:
29
  return history, ""
30
 
31
def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, save_location=None):
    """Ingest uploaded documents and build the retrieval pipeline.

    Parses the given files into text chunks, embeds them into a vector
    store, and wires up a QA retrieval chain over that store.

    Args:
        files: Uploaded document(s) to parse (forwarded to parse_doc_universal).
        collection_name: Vector-DB collection to write into; when falsy
            (e.g. a fresh gr.State holding None) the default collection
            'ekr_default_collection' is used.
        document_retrieval: Prior DocumentRetrieval state (replaced by a
            fresh instance on every call).
        vectorstore: Prior vector-store state (replaced on success).
        conversation_chain: Prior conversation-chain state (replaced on success).
        save_location: Optional path where the vector DB is persisted.

    Returns:
        Tuple of (conversation_chain, vectorstore, document_retrieval,
        collection_name, status_message). On failure the most recently
        assigned values are returned with an error message instead of raising.
    """
    try:
        document_retrieval = DocumentRetrieval()
        _, _, text_chunks = parse_doc_universal(doc=files)
        print(text_chunks)
        embeddings = document_retrieval.load_embedding_model()
        # BUG FIX: the parameter was unconditionally overwritten with the
        # default, making the collection_name state dead on arrival.
        # Fall back to the default only when the caller supplied none.
        if not collection_name:
            collection_name = 'ekr_default_collection'
        vectorstore = document_retrieval.create_vector_store(
            text_chunks,
            embeddings,
            output_db=save_location,
            collection_name=collection_name,
        )
        document_retrieval.init_retriever(vectorstore)
        conversation_chain = document_retrieval.get_qa_retrieval_chain()
        return conversation_chain, vectorstore, document_retrieval, collection_name, "Complete! You can now ask questions."
    except Exception as e:
        # Surface the failure to the UI rather than crashing the app;
        # partially updated state (if any) is handed back as-is.
        return conversation_chain, vectorstore, document_retrieval, collection_name, f"An error occurred while processing: {str(e)}"
45
 
46
  # Read config file
47
  with open(CONFIG_PATH, 'r') as yaml_file:
48
  config = yaml.safe_load(yaml_file)
49
 
50
  prod_mode = config.get('prod_mode', False)
51
+ #default_collection = 'ekr_default_collection'
52
 
53
  # Load env variables
54
  initialize_env_variables(prod_mode)
 
58
 
59
  with gr.Blocks() as demo:
60
  vectorstore = gr.State()
61
+ conversation_chain = gr.State()
62
  document_retrieval = gr.State()
63
+ collection_name=gr.State()
64
 
65
  gr.Markdown("# Enterprise Knowledge Retriever",
66
  elem_id="title")
 
81
  gr.Markdown(caution_text)
82
 
83
  # Preprocessing events
84
+ process_btn.click(process_documents, inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain], outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output], concurrency_limit=10)
85
 
86
  # Step 3: Chat with your data
87
  gr.Markdown("## 3️⃣ Chat with your document")
 
91
  sources_output = gr.Textbox(label="Sources", visible=False)
92
 
93
  # Chatbot events
94
+ msg.submit(handle_userinput, inputs=[msg, conversation_chain, chatbot], outputs=[chatbot, msg], queue=False)
95
  clear_btn.click(lambda: [None, ""], inputs=None, outputs=[chatbot, msg], queue=False)
96
 
97
  if __name__ == "__main__":