Mjlehtim committed on
Commit
e2e4e28
·
verified ·
1 Parent(s): fc9f5e2

persist_directory (2) added

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -245,15 +245,18 @@ def create_vector_database_ESG():
245
  #len(docs)
246
  print(f"length of documents loaded: {len(documents)}")
247
  print(f"total number of document chunks generated :{len(docs)}")
 
248
  embed_model = HuggingFaceEmbeddings()
249
 
250
  vs = Chroma.from_documents(
251
  documents=docs,
252
  embedding=embed_model,
253
- collection_name="rag",
 
254
  )
 
255
  doc_retriever_ESG = vs.as_retriever()
256
-
257
  index = VectorStoreIndex.from_documents(llama_parse_documents)
258
  query_engine = index.as_query_engine()
259
 
@@ -274,19 +277,25 @@ def create_vector_database_financials():
274
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=15)
275
  docs = text_splitter.split_documents(documents)
276
 
 
 
277
  embed_model = HuggingFaceEmbeddings()
278
 
 
279
  vs = Chroma.from_documents(
280
  documents=docs,
281
  embedding=embed_model,
282
- collection_name="rag"
 
283
  )
 
284
  doc_retriever_financials = vs.as_retriever()
285
 
 
286
  index = VectorStoreIndex.from_documents(llama_parse_documents)
287
  query_engine_financials = index.as_query_engine()
288
 
289
- print('Vector DB created successfully !')
290
  return doc_retriever_financials, query_engine_financials
291
 
292
  #--------------
@@ -328,6 +337,7 @@ for uploaded_file in uploaded_files_financials:
328
  #---------------
329
  def ESG_strategy():
330
  doc_retriever_ESG, _ = create_vector_database_ESG()
 
331
  prompt_template = """<|system|>
332
  You are a seasoned specialist in environmental, social and governance matters. You write expert analyses for institutional investors. Always use figures, nemerical and statistical data when possible. Output must have sub-headings in bold font and be fluent.<|end|>
333
  <|user|>
@@ -505,15 +515,8 @@ with strategies_container:
505
  with mrow1_col2:
506
  if "ESG_analysis_button_key" in st.session_state.results and st.session_state.results["ESG_analysis_button_key"]:
507
 
508
- doc_retriever_ESG, query_engine = create_vector_database_ESG()
509
- # Define the file path
510
- file_path = os.path.join("data", "parsed_data_financials.pkl")
511
-
512
- # Check if the file exists before running the function
513
- if os.path.exists(file_path):
514
- doc_retriever_financials, query_engine_financials = create_vector_database_financials()
515
- else:
516
- print(f"The file {file_path} does not exist. Skipping vector database creation.")
517
 
518
  memory = ConversationBufferMemory(memory_key="chat_history", k=3, return_messages=True)
519
  search = SerpAPIWrapper()
@@ -548,19 +551,17 @@ with strategies_container:
548
  """
549
  )
550
 
551
- # LCEL Chains with memory integration
552
- if os.path.exists(file_path):
553
- financials_chain = (
554
- {
555
- "context": doc_retriever_financials,
556
- # Lambda function now accepts one argument (even if unused)
557
- "chat_history": lambda _: format_chat_history(memory.load_memory_variables({})["chat_history"]),
558
- "question": RunnablePassthrough(),
559
- }
560
- | prompt_financials
561
- | llm_tool
562
- | StrOutputParser()
563
- )
564
 
565
  ESG_chain = (
566
  {
@@ -581,12 +582,11 @@ with strategies_container:
581
  description="Useful for answering questions about specific ESG figures, data and statistics.",
582
  )
583
 
584
- if os.path.exists(file_path):
585
- vector_query_tool_financials = Tool(
586
- name="Vector Query Engine Financials",
587
- func=lambda query: query_engine_financials.query(query), # Use query_engine to query the vector database
588
- description="Useful for answering questions about specific financial figures, data and statistics.",
589
- )
590
 
591
  tools = [
592
  Tool(
@@ -594,23 +594,19 @@ with strategies_container:
594
  func=ESG_chain.invoke,
595
  description="Useful for answering general questions about environmental, social, and governance (ESG) matters related to the company. ",
596
  ),
 
 
 
 
 
597
  Tool(
598
  name="Search Tool",
599
  func=search.run,
600
  description="Useful when other tools do not provide the answer.",
601
  ),
602
  vector_query_tool_ESG,
603
- ]
604
-
605
- if os.path.exists(file_path):
606
- tools.append(
607
- Tool(
608
- name="Financials QA System",
609
- func=financials_chain.invoke,
610
- description="Useful for answering general questions about financial or operational information concerning the company.",
611
- ),
612
  vector_query_tool_financials,
613
- )
614
 
615
  # Initialize the agent with LCEL tools and memory
616
  agent = initialize_agent(
 
245
  #len(docs)
246
  print(f"length of documents loaded: {len(documents)}")
247
  print(f"total number of document chunks generated :{len(docs)}")
248
+ persist_directory = "./chroma_db_ESG" # Specify directory for Chroma persistence
249
  embed_model = HuggingFaceEmbeddings()
250
 
251
  vs = Chroma.from_documents(
252
  documents=docs,
253
  embedding=embed_model,
254
+ collection_name="rag_ESG",
255
+ persist_directory=persist_directory # Ensure persistence
256
  )
257
+
258
  doc_retriever_ESG = vs.as_retriever()
259
+
260
  index = VectorStoreIndex.from_documents(llama_parse_documents)
261
  query_engine = index.as_query_engine()
262
 
 
277
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=15)
278
  docs = text_splitter.split_documents(documents)
279
 
280
+ # Add a persist directory for Chroma DB
281
+ persist_directory = "./chroma_db_financials" # Specify directory for persistence
282
  embed_model = HuggingFaceEmbeddings()
283
 
284
+ # Initialize Chroma with persistence
285
  vs = Chroma.from_documents(
286
  documents=docs,
287
  embedding=embed_model,
288
+ collection_name="rag_financials", # Use a unique collection name
289
+ persist_directory=persist_directory # Persist the data
290
  )
291
+
292
  doc_retriever_financials = vs.as_retriever()
293
 
294
+ # Build a VectorStore index for querying
295
  index = VectorStoreIndex.from_documents(llama_parse_documents)
296
  query_engine_financials = index.as_query_engine()
297
 
298
+ print('Vector DB for financials created successfully!')
299
  return doc_retriever_financials, query_engine_financials
300
 
301
  #--------------
 
337
  #---------------
338
  def ESG_strategy():
339
  doc_retriever_ESG, _ = create_vector_database_ESG()
340
+
341
  prompt_template = """<|system|>
342
  You are a seasoned specialist in environmental, social and governance matters. You write expert analyses for institutional investors. Always use figures, nemerical and statistical data when possible. Output must have sub-headings in bold font and be fluent.<|end|>
343
  <|user|>
 
515
  with mrow1_col2:
516
  if "ESG_analysis_button_key" in st.session_state.results and st.session_state.results["ESG_analysis_button_key"]:
517
 
518
+ doc_retriever_ESG, query_engine = create_vector_database_ESG()
519
+ doc_retriever_financials, query_engine_financials = create_vector_database_financials()
 
 
 
 
 
 
 
520
 
521
  memory = ConversationBufferMemory(memory_key="chat_history", k=3, return_messages=True)
522
  search = SerpAPIWrapper()
 
551
  """
552
  )
553
 
554
+ financials_chain = (
555
+ {
556
+ "context": doc_retriever_financials,
557
+ # Lambda function now accepts one argument (even if unused)
558
+ "chat_history": lambda _: format_chat_history(memory.load_memory_variables({})["chat_history"]),
559
+ "question": RunnablePassthrough(),
560
+ }
561
+ | prompt_financials
562
+ | llm_tool
563
+ | StrOutputParser()
564
+ )
 
 
565
 
566
  ESG_chain = (
567
  {
 
582
  description="Useful for answering questions about specific ESG figures, data and statistics.",
583
  )
584
 
585
+ vector_query_tool_financials = Tool(
586
+ name="Vector Query Engine Financials",
587
+ func=lambda query: query_engine_financials.query(query), # Use query_engine to query the vector database
588
+ description="Useful for answering questions about specific financial figures, data and statistics.",
589
+ )
 
590
 
591
  tools = [
592
  Tool(
 
594
  func=ESG_chain.invoke,
595
  description="Useful for answering general questions about environmental, social, and governance (ESG) matters related to the company. ",
596
  ),
597
+ Tool(
598
+ name="Financials QA System",
599
+ func=financials_chain.invoke,
600
+ description="Useful for answering general questions about financial or operational information concerning the company.",
601
+ ),
602
  Tool(
603
  name="Search Tool",
604
  func=search.run,
605
  description="Useful when other tools do not provide the answer.",
606
  ),
607
  vector_query_tool_ESG,
 
 
 
 
 
 
 
 
 
608
  vector_query_tool_financials,
609
+ ]
610
 
611
  # Initialize the agent with LCEL tools and memory
612
  agent = initialize_agent(