timeki committed on
Commit
d562d38
1 Parent(s): 76603df

Rerank documents and force summary for policy makers

Browse files
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -57,107 +57,135 @@ def query_retriever(question):
57
  """Just a dummy tool to simulate the retriever query"""
58
  return question
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
-
63
-
64
-
65
-
66
- def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
67
-
68
- # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
69
- @chain
70
- async def retrieve_documents(state,config):
71
- print("---- Retrieve documents ----")
72
-
73
-
74
- keywords_extraction = make_keywords_extraction_chain(llm)
75
-
76
- current_question = state["remaining_questions"][0]
77
- remaining_questions = state["remaining_questions"][1:]
78
-
79
- # ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")
80
-
81
-
82
- # # There are several options to get the final top k
83
- # # Option 1 - Get 100 documents by question and rerank by question
84
- # # Option 2 - Get 100/n documents by question and rerank the total
85
- # if rerank_by_question:
86
- # k_by_question = divide_into_parts(k_final,len(questions))
87
- if "documents" in state and state["documents"] is not None:
88
- docs = state["documents"]
89
- else:
90
- docs = []
91
 
92
 
93
-
94
- k_by_question = k_final // state["n_questions"]
 
95
 
96
- sources = current_question["sources"]
97
- question = current_question["question"]
98
- index = current_question["index"]
99
-
100
-
101
- await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
102
 
 
103
 
104
- if index == "Vector":
105
-
106
- # Search the document store using the retriever
107
- # Configure high top k for further reranking step
108
- retriever = ClimateQARetriever(
109
- vectorstore=vectorstore,
110
- sources = sources,
111
- min_size = 200,
112
- k_summary = k_summary,
113
- k_total = k_before_reranking,
114
- threshold = 0.5,
115
- )
116
- docs_question = await retriever.ainvoke(question,config)
117
 
118
- elif index == "OpenAlex":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- keywords = keywords_extraction.invoke(question)["keywords"]
121
- openalex_query = " AND ".join(keywords)
122
 
123
- print(f"... OpenAlex query: {openalex_query}")
 
 
 
 
 
 
124
 
125
- retriever_openalex = OpenAlexRetriever(
126
- min_year = state.get("min_year",1960),
127
- max_year = state.get("max_year",None),
128
- k = k_before_reranking
129
- )
130
- docs_question = await retriever_openalex.ainvoke(openalex_query,config)
131
 
132
- else:
133
- raise Exception(f"Index {index} not found in the routing index")
134
-
135
- # Rerank
136
- if reranker is not None:
137
- with suppress_output():
138
- docs_question = rerank_docs(reranker,docs_question,question)
139
- else:
140
- # Add a default reranking score
141
- for doc in docs_question:
142
- doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
143
-
144
- # If rerank by question we select the top documents for each question
145
- if rerank_by_question:
146
- docs_question = docs_question[:k_by_question]
147
-
148
- # Add sources used in the metadata
149
- for doc in docs_question:
150
- doc.metadata["sources_used"] = sources
151
- doc.metadata["question_used"] = question
152
- doc.metadata["index_used"] = index
153
-
154
- # Add to the list of docs
155
- docs.extend(docs_question)
156
-
157
- # Sorting the list in descending order by rerank_score
158
- docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
159
- new_state = {"documents":docs,"remaining_questions":remaining_questions}
160
- return new_state
161
-
162
- return retrieve_documents
163
 
 
57
  """Just a dummy tool to simulate the retriever query"""
58
  return question
59
 
60
+ def _add_sources_used_in_metadata(docs,sources,question,index):
61
+ for doc in docs:
62
+ doc.metadata["sources_used"] = sources
63
+ doc.metadata["question_used"] = question
64
+ doc.metadata["index_used"] = index
65
+ return docs
66
+
67
+ def _get_k_summary_by_question(n_questions):
68
+ if n_questions == 0:
69
+ return 0
70
+ elif n_questions == 1:
71
+ return 5
72
+ elif n_questions == 2:
73
+ return 3
74
+ elif n_questions == 3:
75
+ return 2
76
+ else:
77
+ return 1
78
+
79
 
80
+ # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
81
+ # @chain
82
async def retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Retrieve, rerank and accumulate documents for the next pending question.

    The first entry of ``state["remaining_questions"]`` is consumed. Its
    documents are fetched with ``ClimateQARetriever``, optionally reranked,
    truncated to this question's share of ``k_final`` and appended to the
    documents already present in the state. Image chunks go through the same
    steps and are accumulated under ``"related_contents"``.

    Args:
        state: Graph state mapping. Reads ``remaining_questions``,
            ``n_questions`` and, when present, ``documents`` and
            ``related_contents``.
        config: Runnable config forwarded to ``log_event`` and the retriever.
        vectorstore: Vector store handed to ``ClimateQARetriever``.
        reranker: Reranker passed to ``rerank_docs``; ``None`` disables
            reranking and the similarity score is used instead.
        llm: Unused in this function; kept for interface compatibility
            (only the disabled OpenAlex path needed it).
        rerank_by_question: When true (and a reranker is given), documents
            are sorted by ``reranking_score`` in descending order.
        k_final: Total document budget, split evenly between questions.
        k_before_reranking: ``k_total`` requested from the retriever.
        k_summary: Unused in this function; the per-question summary count
            comes from ``_get_k_summary_by_question``.

    Returns:
        A dict with the updated ``documents``, ``related_contents`` and
        ``remaining_questions`` entries.

    Raises:
        ValueError: If the question's index is not ``"Vector"``.
    """
    print("---- Retrieve documents ----")

    def _reranking_score(doc):
        return doc.metadata["reranking_score"]

    # Reuse what earlier questions accumulated in the state.
    docs = state.get("documents")
    if docs is None:
        docs = []
    # Read the same key that is written below; reading "related_content"
    # (singular) silently discarded the images of previous questions.
    # NOTE(review): GraphState annotates this entry as Dict[str, Document]
    # while it is used as a list here - confirm the intended type.
    related_content = state.get("related_contents")
    if related_content is None:
        related_content = []

    # Pop the question handled by this call.
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    n_questions = state["n_questions"]
    # max(..., 1) avoids a ZeroDivisionError on an inconsistent state.
    k_by_question = k_final // max(n_questions, 1)
    k_summary_by_question = _get_k_summary_by_question(n_questions)

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

    # TODO: the OpenAlex branch (keyword extraction + OpenAlexRetriever) is
    # disabled in this revision. Fail explicitly instead of hitting an
    # unbound variable further down.
    if index != "Vector":
        raise ValueError(f"Index {index} not found in the routing index")

    # Search the document store using the retriever.
    # Configure a high top k for the reranking step that follows.
    retriever = ClimateQARetriever(
        vectorstore=vectorstore,
        sources=sources,
        min_size=200,
        k_summary=k_summary_by_question,
        k_total=k_before_reranking,
        threshold=0.5,
    )
    docs_question_dict = await retriever.ainvoke(question, config)

    summaries = docs_question_dict["docs_summaries"]
    full_texts = docs_question_dict["docs_full"]
    images = docs_question_dict["docs_images"]

    if reranker is not None:
        with suppress_output():
            summaries = rerank_docs(reranker, summaries, question)
            full_texts = rerank_docs(reranker, full_texts, question)
            images = rerank_docs(reranker, images, question)
        if rerank_by_question:
            summaries = sorted(summaries, key=_reranking_score, reverse=True)
            full_texts = sorted(full_texts, key=_reranking_score, reverse=True)
            images = sorted(images, key=_reranking_score, reverse=True)
    else:
        # No reranker: fall back to the similarity score so downstream code
        # can always rely on "reranking_score" being present.
        for doc in summaries + full_texts + images:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Summaries are placed first so they survive the truncation.
    docs_question = (summaries + full_texts)[:k_by_question]
    images_question = images[:k_by_question]

    if reranker is not None and rerank_by_question:
        docs_question = sorted(docs_question, key=_reranking_score, reverse=True)

    # Add sources used in the metadata.
    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)

    # Add to the accumulated lists.
    docs.extend(docs_question)
    related_content.extend(images_question)

    return {
        "documents": docs,
        "related_contents": related_content,
        "remaining_questions": remaining_questions,
    }
179
+
180
 
 
 
181
 
182
def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the graph node that runs ``retrieve_documents`` with fixed resources.

    Args:
        vectorstore: Forwarded to ``retrieve_documents``.
        reranker: Forwarded to ``retrieve_documents``.
        llm: Forwarded to ``retrieve_documents``.
        rerank_by_question: Forwarded to ``retrieve_documents``.
        k_final: Forwarded to ``retrieve_documents``.
        k_before_reranking: Forwarded to ``retrieve_documents``.
        k_summary: Forwarded to ``retrieve_documents``.

    Returns:
        The ``@chain``-wrapped coroutine taking ``(state, config)`` and
        returning the new state produced by ``retrieve_documents``.
    """

    # The chain wrapper propagates the langchain callbacks to the
    # astream_events logger so intermediate results can be displayed.
    @chain
    async def retrieve_docs(state, config):
        return await retrieve_documents(
            state,
            config,
            vectorstore,
            reranker,
            llm,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            k_summary=k_summary,
        )

    return retrieve_docs
189
 
 
 
 
 
 
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
climateqa/engine/chains/retriever.py CHANGED
@@ -1,126 +1,126 @@
1
- import sys
2
- import os
3
- from contextlib import contextmanager
4
 
5
- from ..reranker import rerank_docs
6
- from ...knowledge.retriever import ClimateQARetriever
7
 
8
 
9
 
10
 
11
- def divide_into_parts(target, parts):
12
- # Base value for each part
13
- base = target // parts
14
- # Remainder to distribute
15
- remainder = target % parts
16
- # List to hold the result
17
- result = []
18
 
19
- for i in range(parts):
20
- if i < remainder:
21
- # These parts get base value + 1
22
- result.append(base + 1)
23
- else:
24
- # The rest get the base value
25
- result.append(base)
26
 
27
- return result
28
 
29
 
30
- @contextmanager
31
- def suppress_output():
32
- # Open a null device
33
- with open(os.devnull, 'w') as devnull:
34
- # Store the original stdout and stderr
35
- old_stdout = sys.stdout
36
- old_stderr = sys.stderr
37
- # Redirect stdout and stderr to the null device
38
- sys.stdout = devnull
39
- sys.stderr = devnull
40
- try:
41
- yield
42
- finally:
43
- # Restore stdout and stderr
44
- sys.stdout = old_stdout
45
- sys.stderr = old_stderr
46
 
47
 
48
 
49
- def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
50
 
51
- def retrieve_documents(state):
52
 
53
- POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
54
- questions = state["questions"]
55
 
56
- # Use sources from the user input or from the LLM detection
57
- if "sources_input" not in state or state["sources_input"] is None:
58
- sources_input = ["auto"]
59
- else:
60
- sources_input = state["sources_input"]
61
- auto_mode = "auto" in sources_input
62
 
63
- # There are several options to get the final top k
64
- # Option 1 - Get 100 documents by question and rerank by question
65
- # Option 2 - Get 100/n documents by question and rerank the total
66
- if rerank_by_question:
67
- k_by_question = divide_into_parts(k_final,len(questions))
68
 
69
- docs = []
70
 
71
- for i,q in enumerate(questions):
72
 
73
- sources = q["sources"]
74
- question = q["question"]
75
 
76
- # If auto mode, we use the sources detected by the LLM
77
- if auto_mode:
78
- sources = [x for x in sources if x in POSSIBLE_SOURCES]
79
 
80
- # Otherwise, we use the config
81
- else:
82
- sources = sources_input
83
 
84
- # Search the document store using the retriever
85
- # Configure high top k for further reranking step
86
- retriever = ClimateQARetriever(
87
- vectorstore=vectorstore,
88
- sources = sources,
89
- # reports = ias_reports,
90
- min_size = 200,
91
- k_summary = k_summary,
92
- k_total = k_before_reranking,
93
- threshold = 0.5,
94
- )
95
- docs_question = retriever.get_relevant_documents(question)
96
 
97
- # Rerank
98
- if reranker is not None:
99
- with suppress_output():
100
- docs_question = rerank_docs(reranker,docs_question,question)
101
- else:
102
- # Add a default reranking score
103
- for doc in docs_question:
104
- doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
105
 
106
- # If rerank by question we select the top documents for each question
107
- if rerank_by_question:
108
- docs_question = docs_question[:k_by_question[i]]
109
 
110
- # Add sources used in the metadata
111
- for doc in docs_question:
112
- doc.metadata["sources_used"] = sources
113
 
114
- # Add to the list of docs
115
- docs.extend(docs_question)
116
 
117
- # Sorting the list in descending order by rerank_score
118
- # Then select the top k
119
- docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
120
- docs = docs[:k_final]
121
 
122
- new_state = {"documents":docs}
123
- return new_state
124
 
125
- return retrieve_documents
126
 
 
1
+ # import sys
2
+ # import os
3
+ # from contextlib import contextmanager
4
 
5
+ # from ..reranker import rerank_docs
6
+ # from ...knowledge.retriever import ClimateQARetriever
7
 
8
 
9
 
10
 
11
+ # def divide_into_parts(target, parts):
12
+ # # Base value for each part
13
+ # base = target // parts
14
+ # # Remainder to distribute
15
+ # remainder = target % parts
16
+ # # List to hold the result
17
+ # result = []
18
 
19
+ # for i in range(parts):
20
+ # if i < remainder:
21
+ # # These parts get base value + 1
22
+ # result.append(base + 1)
23
+ # else:
24
+ # # The rest get the base value
25
+ # result.append(base)
26
 
27
+ # return result
28
 
29
 
30
+ # @contextmanager
31
+ # def suppress_output():
32
+ # # Open a null device
33
+ # with open(os.devnull, 'w') as devnull:
34
+ # # Store the original stdout and stderr
35
+ # old_stdout = sys.stdout
36
+ # old_stderr = sys.stderr
37
+ # # Redirect stdout and stderr to the null device
38
+ # sys.stdout = devnull
39
+ # sys.stderr = devnull
40
+ # try:
41
+ # yield
42
+ # finally:
43
+ # # Restore stdout and stderr
44
+ # sys.stdout = old_stdout
45
+ # sys.stderr = old_stderr
46
 
47
 
48
 
49
+ # def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
50
 
51
+ # def retrieve_documents(state):
52
 
53
+ # POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
54
+ # questions = state["questions"]
55
 
56
+ # # Use sources from the user input or from the LLM detection
57
+ # if "sources_input" not in state or state["sources_input"] is None:
58
+ # sources_input = ["auto"]
59
+ # else:
60
+ # sources_input = state["sources_input"]
61
+ # auto_mode = "auto" in sources_input
62
 
63
+ # # There are several options to get the final top k
64
+ # # Option 1 - Get 100 documents by question and rerank by question
65
+ # # Option 2 - Get 100/n documents by question and rerank the total
66
+ # if rerank_by_question:
67
+ # k_by_question = divide_into_parts(k_final,len(questions))
68
 
69
+ # docs = []
70
 
71
+ # for i,q in enumerate(questions):
72
 
73
+ # sources = q["sources"]
74
+ # question = q["question"]
75
 
76
+ # # If auto mode, we use the sources detected by the LLM
77
+ # if auto_mode:
78
+ # sources = [x for x in sources if x in POSSIBLE_SOURCES]
79
 
80
+ # # Otherwise, we use the config
81
+ # else:
82
+ # sources = sources_input
83
 
84
+ # # Search the document store using the retriever
85
+ # # Configure high top k for further reranking step
86
+ # retriever = ClimateQARetriever(
87
+ # vectorstore=vectorstore,
88
+ # sources = sources,
89
+ # # reports = ias_reports,
90
+ # min_size = 200,
91
+ # k_summary = k_summary,
92
+ # k_total = k_before_reranking,
93
+ # threshold = 0.5,
94
+ # )
95
+ # docs_question = retriever.get_relevant_documents(question)
96
 
97
+ # # Rerank
98
+ # if reranker is not None:
99
+ # with suppress_output():
100
+ # docs_question = rerank_docs(reranker,docs_question,question)
101
+ # else:
102
+ # # Add a default reranking score
103
+ # for doc in docs_question:
104
+ # doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
105
 
106
+ # # If rerank by question we select the top documents for each question
107
+ # if rerank_by_question:
108
+ # docs_question = docs_question[:k_by_question[i]]
109
 
110
+ # # Add sources used in the metadata
111
+ # for doc in docs_question:
112
+ # doc.metadata["sources_used"] = sources
113
 
114
+ # # Add to the list of docs
115
+ # docs.extend(docs_question)
116
 
117
+ # # Sorting the list in descending order by rerank_score
118
+ # # Then select the top k
119
+ # docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
120
+ # docs = docs[:k_final]
121
 
122
+ # new_state = {"documents":docs}
123
+ # return new_state
124
 
125
+ # return retrieve_documents
126
 
climateqa/engine/graph.py CHANGED
@@ -40,6 +40,7 @@ class GraphState(TypedDict):
40
  min_year: int = 1960
41
  max_year: int = None
42
  documents: List[Document]
 
43
  recommended_content : List[Document]
44
  # graphs_returned: Dict[str,str]
45
 
 
40
  min_year: int = 1960
41
  max_year: int = None
42
  documents: List[Document]
43
+ related_contents : Dict[str,Document]
44
  recommended_content : List[Document]
45
  # graphs_returned: Dict[str,str]
46
 
climateqa/knowledge/retriever.py CHANGED
@@ -11,6 +11,18 @@ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
11
  from typing import List
12
  from pydantic import Field
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class ClimateQARetriever(BaseRetriever):
15
  vectorstore:VectorStore
16
  sources:list = ["IPCC","IPBES","IPOS"]
@@ -20,6 +32,7 @@ class ClimateQARetriever(BaseRetriever):
20
  k_total:int = 10
21
  namespace:str = "vectors",
22
  min_size:int = 200,
 
23
 
24
 
25
  def _get_relevant_documents(
@@ -43,6 +56,7 @@ class ClimateQARetriever(BaseRetriever):
43
  # Search for k_summary documents in the summaries dataset
44
  filters_summaries = {
45
  **filters,
 
46
  "report_type": { "$in":["SPM"]},
47
  }
48
 
@@ -52,31 +66,36 @@ class ClimateQARetriever(BaseRetriever):
52
  # Search for k_total - k_summary documents in the full reports dataset
53
  filters_full = {
54
  **filters,
 
55
  "report_type": { "$nin":["SPM"]},
56
  }
57
  k_full = self.k_total - len(docs_summaries)
58
  docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
 
 
 
 
 
 
 
59
 
60
  # Concatenate documents
61
- docs = docs_summaries + docs_full
62
 
63
  # Filter if scores are below threshold
64
  docs = [x for x in docs if len(x[0].page_content) > self.min_size]
65
  # docs = [x for x in docs if x[1] > self.threshold]
66
 
67
- # Add score to metadata
68
- results = []
69
- for i,(doc,score) in enumerate(docs):
70
- doc.page_content = doc.page_content.replace("\r\n"," ")
71
- doc.metadata["similarity_score"] = score
72
- doc.metadata["content"] = doc.page_content
73
- doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
74
- # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
75
- results.append(doc)
76
 
77
- # Sort by score
78
- # results = sorted(results,key = lambda x : x.metadata["similarity_score"],reverse = True)
79
-
80
- return results
 
81
 
82
 
 
11
  from typing import List
12
  from pydantic import Field
13
 
14
+ def _add_metadata_and_score(docs: List) -> Document:
15
+ # Add score to metadata
16
+ docs_with_metadata = []
17
+ for i,(doc,score) in enumerate(docs):
18
+ doc.page_content = doc.page_content.replace("\r\n"," ")
19
+ doc.metadata["similarity_score"] = score
20
+ doc.metadata["content"] = doc.page_content
21
+ doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
22
+ # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
23
+ docs_with_metadata.append(doc)
24
+ return docs_with_metadata
25
+
26
  class ClimateQARetriever(BaseRetriever):
27
  vectorstore:VectorStore
28
  sources:list = ["IPCC","IPBES","IPOS"]
 
32
  k_total:int = 10
33
  namespace:str = "vectors",
34
  min_size:int = 200,
35
+
36
 
37
 
38
  def _get_relevant_documents(
 
56
  # Search for k_summary documents in the summaries dataset
57
  filters_summaries = {
58
  **filters,
59
+ "chunk_type":"text",
60
  "report_type": { "$in":["SPM"]},
61
  }
62
 
 
66
  # Search for k_total - k_summary documents in the full reports dataset
67
  filters_full = {
68
  **filters,
69
+ "chunk_type":"text",
70
  "report_type": { "$nin":["SPM"]},
71
  }
72
  k_full = self.k_total - len(docs_summaries)
73
  docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
74
+
75
+ # Images
76
+ filters_image = {
77
+ **filters,
78
+ "chunk_type":"image"
79
+ }
80
+ docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
81
 
82
  # Concatenate documents
83
+ docs = docs_summaries + docs_full + docs_images
84
 
85
  # Filter if scores are below threshold
86
  docs = [x for x in docs if len(x[0].page_content) > self.min_size]
87
  # docs = [x for x in docs if x[1] > self.threshold]
88
 
89
+ docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
90
+
91
+ # Filter if length are below threshold
92
+ docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
93
+ docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
 
 
 
 
94
 
95
+ return {
96
+ "docs_summaries" : docs_summaries,
97
+ "docs_full" : docs_full,
98
+ "docs_images" : docs_images
99
+ }
100
 
101
 
sandbox/20241104 - CQA - StepByStep CQA.ipynb CHANGED
The diff for this file is too large to render. See raw diff