momenaca commited on
Commit
8ca00e0
1 Parent(s): d708cb9

feature/major backend update with agent

Browse files
.gitignore CHANGED
@@ -5,4 +5,5 @@ __pycache__/utils.cpython-38.pyc
5
 
6
  notebooks/
7
  *.pyc
8
- local_tests/
 
 
5
 
6
  notebooks/
7
  *.pyc
8
+ local_tests/
9
+ .vscode/
app.py CHANGED
@@ -64,9 +64,9 @@ async def chat(query, history):
64
  async for event in result:
65
  print(event)
66
  if event["event"] == "on_chat_model_stream":
67
- print("line 66")
68
  if start_streaming == False:
69
- print("line 68")
70
  start_streaming = True
71
  history[-1] = (query, "")
72
 
@@ -77,17 +77,26 @@ async def chat(query, history):
77
  answer_yet = parse_output_llm_with_sources(answer_yet)
78
  history[-1] = (query, answer_yet)
79
 
 
 
 
 
 
 
80
  elif (
81
  event["name"] == "retrieve_documents"
82
  and event["event"] == "on_chain_end"
83
  ):
84
  try:
85
- print("line 84")
 
86
  docs = event["data"]["output"]["documents"]
87
  docs_html = []
88
- for i, d in enumerate(docs, 1):
89
- docs_html.append(make_html_source(d, i))
 
90
  docs_html = "".join(docs_html)
 
91
  except Exception as e:
92
  print(f"Error getting documents: {e}")
93
  print(event)
@@ -97,9 +106,9 @@ async def chat(query, history):
97
  display_output,
98
  ) in steps_display.items():
99
  if event["name"] == event_name:
100
- print("line 99")
101
  if event["event"] == "on_chain_start":
102
- print("line 101")
103
  answer_yet = event_description
104
  history[-1] = (query, answer_yet)
105
 
 
64
  async for event in result:
65
  print(event)
66
  if event["event"] == "on_chat_model_stream":
67
+ # print("line 66")
68
  if start_streaming == False:
69
+ # print("line 68")
70
  start_streaming = True
71
  history[-1] = (query, "")
72
 
 
77
  answer_yet = parse_output_llm_with_sources(answer_yet)
78
  history[-1] = (query, answer_yet)
79
 
80
+ elif (
81
+ event["name"] == "answer_rag_wrong"
82
+ and event["event"] == "on_chain_stream"
83
+ ):
84
+ history[-1] = (query, event["data"]["chunk"]["answer"])
85
+
86
  elif (
87
  event["name"] == "retrieve_documents"
88
  and event["event"] == "on_chain_end"
89
  ):
90
  try:
91
+ # print(event)
92
+ # print("line 84")
93
  docs = event["data"]["output"]["documents"]
94
  docs_html = []
95
+ for i, doc in enumerate(docs, 1):
96
+ docs_html.append(make_html_source(i, doc))
97
+ # print(docs_html)
98
  docs_html = "".join(docs_html)
99
+ # print(docs_html)
100
  except Exception as e:
101
  print(f"Error getting documents: {e}")
102
  print(event)
 
106
  display_output,
107
  ) in steps_display.items():
108
  if event["name"] == event_name:
109
+ # print("line 99")
110
  if event["event"] == "on_chain_start":
111
+ # print("line 101")
112
  answer_yet = event_description
113
  history[-1] = (query, answer_yet)
114
 
celsius_csrd_chatbot/agent.py CHANGED
@@ -39,16 +39,12 @@ def route_intent(state):
39
  return "intent_esrs"
40
 
41
  elif esrs == "wrong_esrs":
42
- return "answer_rag"
43
 
44
  else:
45
  return "retrieve_documents"
46
 
47
 
48
- def make_id_dict(values):
49
- return {k: k for k in values}
50
-
51
-
52
  def make_graph_agent(llm, vectorstore):
53
  workflow = StateGraph(GraphState)
54
 
@@ -70,11 +66,7 @@ def make_graph_agent(llm, vectorstore):
70
  workflow.set_entry_point("categorize_esrs")
71
 
72
  # CONDITIONAL EDGES
73
- workflow.add_conditional_edges(
74
- "categorize_esrs",
75
- route_intent,
76
- make_id_dict(["intent_esrs", "retrieve_documents", "answer_rag_wrong"]),
77
- )
78
 
79
  # Define the edges
80
  workflow.add_edge("intent_esrs", "retrieve_documents")
 
39
  return "intent_esrs"
40
 
41
  elif esrs == "wrong_esrs":
42
+ return "answer_rag_wrong"
43
 
44
  else:
45
  return "retrieve_documents"
46
 
47
 
 
 
 
 
48
  def make_graph_agent(llm, vectorstore):
49
  workflow = StateGraph(GraphState)
50
 
 
66
  workflow.set_entry_point("categorize_esrs")
67
 
68
  # CONDITIONAL EDGES
69
+ workflow.add_conditional_edges("categorize_esrs", route_intent)
 
 
 
 
70
 
71
  # Define the edges
72
  workflow.add_edge("intent_esrs", "retrieve_documents")
celsius_csrd_chatbot/chains/answer_rag.py CHANGED
@@ -36,6 +36,7 @@ answering_template = """
36
  10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
37
  11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
38
  12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
 
39
 
40
  Question: {query}
41
  Answer:
 
36
  10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
37
  11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
38
  12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
39
+ 13. Never mention these guidelines as a source attribution in your response.
40
 
41
  Question: {query}
42
  Answer:
celsius_csrd_chatbot/chains/esrs_categorization.py CHANGED
@@ -5,7 +5,7 @@ def make_esrs_categorization_node():
5
 
6
  def categorize_message(state):
7
  query = state["query"]
8
- pattern = r"ESRS \d|ESRS [A-Z]\d|ESRS [A-Z] \d"
9
  esrs_truth = [
10
  "ESRS 1",
11
  "ESRS 2",
@@ -25,7 +25,6 @@ def make_esrs_categorization_node():
25
  if matches:
26
  true_matches = [match for match in matches if match in esrs_truth]
27
  output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
28
-
29
  else:
30
  output = {"esrs_type": "none"}
31
 
 
5
 
6
  def categorize_message(state):
7
  query = state["query"]
8
+ pattern = r"ESRS \d+[A-Z0-9]*"
9
  esrs_truth = [
10
  "ESRS 1",
11
  "ESRS 2",
 
25
  if matches:
26
  true_matches = [match for match in matches if match in esrs_truth]
27
  output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
 
28
  else:
29
  output = {"esrs_type": "none"}
30
 
celsius_csrd_chatbot/chains/esrs_intent.py CHANGED
@@ -23,51 +23,41 @@ class ESRSAnalysis(BaseModel):
23
  "ESRS S3",
24
  "ESRS S4",
25
  "ESRS G1",
26
- "none",
27
  ] = Field(
28
- description="""
29
- Given a user question choose which documents would be most relevant for answering their question :
30
-
31
- - ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
32
- - ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
33
- - ESRS E1 is for questions about climate change, global warming, GES and energy
34
- - ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
35
- - ESRS E3 is for questions about water and marine resources
36
- - ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
37
- - ESRS E5 is for questions about resource use and circular economy
38
- - ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
39
- - ESRS S2 is for questions about workers in the value chain, workers' treatment
40
- - SRS S3 is for questions about affected communities, impact on local communities
41
- - ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
42
- - ESRS G1 is for questions about governance, risk management, internal control, and business conduct
43
- - none is for questions that do not fit into any of the above categories
44
-
45
- Follow these guidelines :
46
-
47
- - Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
48
- - Remember, if the question is not related to any ESRS, the output should be 'none'.
49
- """,
50
  )
51
 
52
 
53
  def make_esrs_intent_chain(llm):
54
- parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
55
  prompt_template = """
56
- The following question is about ESRS related topics. Please analyze the question and indicate if it refers to a specific ESRS.
57
-
58
- {format_instructions}
59
-
60
- Please answer with the appropriate ESRS to answer the question.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  Question: '{query}'
63
  Answer:
64
  """
65
-
66
- prompt = PromptTemplate(
67
- template=prompt_template,
68
- input_variables=["query"],
69
- partial_variables={"format_instructions": parser.get_format_instructions()},
70
- )
71
  chain = {"query": itemgetter("query")} | prompt | llm | parser
72
 
73
  return chain
@@ -78,7 +68,9 @@ def make_esrs_intent_node(llm):
78
  def intent_message(state):
79
  query = state["query"]
80
  categorization_chain = make_esrs_intent_chain(llm)
81
- output = categorization_chain.invoke(query)
 
 
82
 
83
  return output
84
 
 
23
  "ESRS S3",
24
  "ESRS S4",
25
  "ESRS G1",
26
+ "no_intent",
27
  ] = Field(
28
+ description="""The ESRS type that the user query refers to.""",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  )
30
 
31
 
32
  def make_esrs_intent_chain(llm):
 
33
  prompt_template = """
34
+ Please analyze the question and indicate if it refers to a specific ESRS.
35
+
36
+ Follow these definitions in order to choose the appropriate ESRS :
37
+ - ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
38
+ - ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
39
+ - ESRS E1 is for questions about climate change, global warming, GES and energy
40
+ - ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
41
+ - ESRS E3 is for questions about water and marine resources
42
+ - ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
43
+ - ESRS E5 is for questions about resource use and circular economy
44
+ - ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
45
+ - ESRS S2 is for questions about workers in the value chain, workers' treatment
46
+ - ESRS S3 is for questions about affected communities, impact on local communities
47
+ - ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
48
+ - ESRS G1 is for questions about governance, risk management, internal control, and business conduct
49
+ - no_intent is for questions that do not fit into any of the above categories
50
+
51
+ Keep in mind these guidelines :
52
+ - Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
53
+
54
+ The output needs to respect a JSON format with 'esrs_type' as the key and the appropriate ESRS as the value.
55
 
56
  Question: '{query}'
57
  Answer:
58
  """
59
+ parser = PydanticOutputParser(pydantic_object=ESRSAnalysis, method="json_mode")
60
+ prompt = PromptTemplate(template=prompt_template, input_variables=["query"])
 
 
 
 
61
  chain = {"query": itemgetter("query")} | prompt | llm | parser
62
 
63
  return chain
 
68
  def intent_message(state):
69
  query = state["query"]
70
  categorization_chain = make_esrs_intent_chain(llm)
71
+ output = {
72
+ "esrs_type": [categorization_chain.invoke({"query": query}).esrs_type]
73
+ }
74
 
75
  return output
76
 
celsius_csrd_chatbot/chains/retriever.py CHANGED
@@ -1,16 +1,15 @@
1
  def make_retriever_node(vectorstore, k=10):
2
-
3
  def retrieve_documents(state):
4
  sources = state["esrs_type"]
5
  query = state["query"]
6
- if sources == "none":
7
- filters_full = {}
8
  else:
9
- filters_full = {"ESRS_filter": {"$in": sources}}
 
 
 
10
  docs = []
11
- docs_retrieved = vectorstore.similarity_search_with_score(
12
- query=query, filter=filters_full, k=k
13
- )
14
  for doc in docs_retrieved:
15
  doc_append = doc[0]
16
  doc_append.metadata["similarity_score"] = doc[1]
 
1
  def make_retriever_node(vectorstore, k=10):
 
2
  def retrieve_documents(state):
3
  sources = state["esrs_type"]
4
  query = state["query"]
5
+ if sources == "none" or sources == "no_intent":
6
+ docs_retrieved = vectorstore.similarity_search_with_score(query=query, k=k)
7
  else:
8
+ filters = {"ESRS_filter": {"$in": sources}}
9
+ docs_retrieved = vectorstore.similarity_search_with_score(
10
+ query=query, filter=filters, k=k
11
+ )
12
  docs = []
 
 
 
13
  for doc in docs_retrieved:
14
  doc_append = doc[0]
15
  doc_append.metadata["similarity_score"] = doc[1]