File size: 9,644 Bytes
99e91d8
 
 
 
 
 
 
 
 
 
40084ba
99e91d8
 
 
40084ba
 
 
99e91d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40084ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
d562d38
 
40084ba
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
40084ba
 
d562d38
 
 
 
 
 
40084ba
d562d38
 
 
 
 
9609df9
d562d38
 
 
40084ba
 
 
d562d38
40084ba
d562d38
 
 
 
40084ba
d562d38
 
99e91d8
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40084ba
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
 
d562d38
 
 
 
 
 
 
99e91d8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import sys
import os
from contextlib import contextmanager

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from ..reranker import rerank_docs
# from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
from langchain_core.vectorstores import VectorStore
from typing import List
from langchain_core.documents.base import Document



def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` values that are as equal as possible.

    Every entry gets ``target // parts``; the first ``target % parts``
    entries get one extra unit. For integer inputs with ``parts > 0`` the
    returned values therefore sum to ``target``.
    """
    quotient, leftover = divmod(target, parts)
    return [
        quotient + 1 if position < leftover else quotient
        for position in range(parts)
    ]


@contextmanager
def suppress_output():
    """Context manager that silences ``sys.stdout`` and ``sys.stderr``.

    Inside the ``with`` body both streams point at the OS null device; the
    previous stream objects are put back on exit, including when the body
    raises.
    """
    # Remember the streams currently installed so they can be reinstated
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, 'w') as sink:
        # Both streams share the same null-device handle
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams


# Placeholder tool: it returns the question it receives, unchanged.
# NOTE(review): with langchain's @tool decorator the docstring below presumably
# doubles as the tool description, so rewording it would change what the tool
# exposes at runtime — confirm before editing it.
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question

def _add_sources_used_in_metadata(docs,sources,question,index):
    """Record on each document which sources, question and index produced it.

    Each document's ``metadata`` is modified in place with the keys
    ``sources_used``, ``question_used`` and ``index_used``; the same ``docs``
    object that was passed in is returned.
    """
    provenance = {
        "sources_used": sources,
        "question_used": question,
        "index_used": index,
    }
    for document in docs:
        document.metadata.update(provenance)
    return docs

def _get_k_summary_by_question(n_questions):
    """Return how many summary documents to fetch per question.

    The budget shrinks as the number of questions grows:
    0 -> 0, 1 -> 5, 2 -> 3, 3 -> 2, anything else -> 1.
    """
    budget_by_count = ((0, 0), (1, 5), (2, 3), (3, 2))
    for count, budget in budget_by_count:
        if n_questions == count:
            return budget
    return 1
    
def _get_k_images_by_question(n_questions):
    """Return how many images to fetch per question.

    Mapping: 0 questions -> 0, 1 -> 5, 2 -> 3, 3 -> 2, any other value -> 1.
    """
    image_budgets = ((0, 0), (1, 5), (2, 3), (3, 2))
    matching = (budget for count, budget in image_budgets if n_questions == count)
    return next(matching, 1)
    
def _add_metadata_and_score(docs: List) -> "List[Document]":
    """Unpack ``(document, score)`` pairs and enrich each document in place.

    For every pair:
      * ``"\\r\\n"`` sequences in ``page_content`` are replaced by a space;
      * ``metadata["similarity_score"]`` is set to the paired score;
      * ``metadata["content"]`` is set to the cleaned ``page_content``;
      * ``metadata["page_number"]`` is converted to ``int`` and incremented by 1.

    Args:
        docs: iterable of ``(document, score)`` pairs.

    Returns:
        The documents (without their scores), in input order.

    Raises:
        KeyError: if a document has no ``"page_number"`` metadata entry.
        ValueError / TypeError: if ``page_number`` cannot be converted to ``int``.
    """
    docs_with_metadata = []
    for doc, score in docs:
        doc.page_content = doc.page_content.replace("\r\n"," ")
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        # NOTE(review): presumably the stored page index is zero-based and the
        # +1 makes it one-based for display — confirm against the index builder.
        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        docs_with_metadata.append(doc)
    return docs_with_metadata

async def get_IPCC_relevant_documents(
    query: str,
    vectorstore:VectorStore,
    sources:list = None,
    search_figures:bool = False,
    reports:list = None,
    threshold:float = 0.6,
    k_summary:int = 3,
    k_total:int = 10,
    k_images: int = 5,
    namespace:str = "vectors",
    min_size:int = 200,
) :
    """Query the vector store for summary, full-report and (optionally) image chunks.

    Three similarity searches are issued against ``vectorstore``:
      1. text chunks whose ``report_type`` is ``"SPM"`` (``k_summary`` results),
         keeping only those whose score is strictly above ``threshold``;
      2. text chunks whose ``report_type`` is not ``"SPM"``, asking for
         ``k_total`` minus the number of summaries kept (no score threshold);
      3. image chunks (``k_images`` results), only when ``search_figures`` is true.

    Args:
        query: text to search for.
        vectorstore: store exposing ``similarity_search_with_score``.
        sources: non-empty list drawn from ``"IPCC"``, ``"IPBES"``, ``"IPOS"``;
            defaults to all three. Used as the filter only when ``reports`` is empty.
        search_figures: whether to also search image chunks.
        reports: report short names; when non-empty the search is filtered on
            ``short_name`` instead of ``source``. Defaults to no restriction.
        threshold: minimum (exclusive) score for summary chunks.
        k_summary: number of summary chunks requested.
        k_total: overall budget shared by summary and full-report chunks.
        k_images: number of image chunks requested.
        namespace: accepted for interface compatibility; not used in this function.
        min_size: summary and full-report chunks whose ``page_content`` is not
            longer than this many characters are dropped (images are not filtered).

    Returns:
        dict with keys ``"docs_summaries"``, ``"docs_full"`` and ``"docs_images"``,
        each a list of documents processed by ``_add_metadata_and_score``.

    Raises:
        AssertionError: if ``sources`` is not a non-empty list of known sources,
            or if ``k_total`` is not greater than ``k_summary``.
    """
    # Build the defaults per call rather than sharing mutable default arguments
    if sources is None:
        sources = ["IPCC","IPBES","IPOS"]
    if reports is None:
        reports = []

    # Check that every requested source is one of the supported ones
    assert isinstance(sources,list)
    assert sources
    assert all(x in ("IPCC","IPBES","IPOS") for x in sources)
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Prepare base search kwargs: explicit reports take precedence over sources
    filters = {}

    if reports:
        filters["short_name"] = {"$in":reports}
    else:
        filters["source"] = { "$in": sources}

    # Stays empty unless figures are requested
    docs_images = []

    # Search for k_summary documents in the summaries dataset
    filters_summaries = {
        **filters,
        "chunk_type":"text",
        "report_type": { "$in":["SPM"]},
    }

    docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
    docs_summaries = [x for x in docs_summaries if x[1] > threshold]

    # Spend the remaining budget (k_total minus kept summaries) on the full reports dataset
    filters_full = {
        **filters,
        "chunk_type":"text",
        "report_type": { "$nin":["SPM"]},
    }
    k_full = k_total - len(docs_summaries)
    docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)

    if search_figures:
        # Images
        filters_image = {
            **filters,
            "chunk_type":"image"
        }
        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)

    # Turn (document, score) pairs into documents carrying their score in metadata
    docs_summaries = _add_metadata_and_score(docs_summaries)
    docs_full = _add_metadata_and_score(docs_full)
    docs_images = _add_metadata_and_score(docs_images)

    # Drop text chunks that are too short to be useful
    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
    docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries" : docs_summaries,
        "docs_full" : docs_full,
        "docs_images" : docs_images,
    }



# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
# @chain
async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
    """Retrieve, optionally rerank, and accumulate documents for the next pending question.

    The first entry of ``state["remaining_questions"]`` is consumed: documents
    are fetched with ``get_IPCC_relevant_documents``, reranked when a
    ``reranker`` is given, truncated, tagged with the question that produced
    them, and appended to the documents already present in the state.

    Args:
        state: graph state; this function reads ``documents``, ``related_content``,
            ``relevant_content_sources``, ``remaining_questions`` and ``n_questions``.
        config: passed through to ``log_event``.
        vectorstore: store handed to ``get_IPCC_relevant_documents``.
        reranker: reranker handed to ``rerank_docs``; when ``None`` the
            similarity score is used as the reranking score and retrieval
            order is kept.
        llm: accepted for interface compatibility; not used in this function.
        rerank_by_question: when true (and a reranker is given), results are
            sorted by descending ``reranking_score``.
        k_final: total document budget, divided evenly across ``n_questions``.
        k_before_reranking: number of candidates requested before reranking.
        k_summary: accepted for interface compatibility; the per-question
            summary budget comes from ``_get_k_summary_by_question`` instead.
        k_images: maximum number of images kept for this question.

    Returns:
        dict with keys ``"documents"``, ``"related_contents"`` and
        ``"remaining_questions"`` (the pending questions minus the one just handled).

    Raises:
        ValueError: if the question's ``index`` is not ``"Vector"``.
    """
    print("---- Retrieve documents ----")

    # Get the documents from the state
    if "documents" in state and state["documents"] is not None:
        docs = state["documents"]
    else:
        docs = []
    # Get the related_content from the state
    # NOTE(review): this reads the key "related_content" while the returned
    # state below writes "related_contents". Left unchanged because the state
    # schema is not visible here — confirm which spelling is intended.
    if "related_content" in state and state["related_content"] is not None:
        related_content = state["related_content"]
    else:
        related_content = []

    search_figures = "IPCC figures" in state["relevant_content_sources"]

    # Get the current question
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    k_by_question = k_final // state["n_questions"]
    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
    k_images_by_question = _get_k_images_by_question(state["n_questions"])

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

    # Only the vector index is supported; fail clearly instead of hitting an
    # unbound variable further down.
    if index != "Vector":
        raise ValueError(f"Unsupported index: {index!r}")

    docs_question_dict = await get_IPCC_relevant_documents(
        query  = question,
        vectorstore=vectorstore,
        search_figures = search_figures,
        sources = sources,
        min_size = 200,
        k_summary = k_summary_by_question,
        k_total = k_before_reranking,
        k_images = k_images_by_question,
        threshold = 0.5,
    )

    # Rerank
    if reranker is not None:
        with suppress_output():
            docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
            docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
            docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
            if rerank_by_question:
                docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
                docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
                docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
    else:
        # No reranker: keep the retrieval order and use the similarity score as
        # the default reranking score, so the names used below are always defined.
        docs_question_summary_reranked = docs_question_dict["docs_summaries"]
        docs_question_fulltext_reranked = docs_question_dict["docs_full"]
        docs_question_images_reranked = docs_question_dict["docs_images"]
        for doc in docs_question_summary_reranked + docs_question_fulltext_reranked + docs_question_images_reranked:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Summaries come first, so they are favoured when truncating to the per-question budget
    docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
    docs_question = docs_question[:k_by_question]
    images_question = docs_question_images_reranked[:k_images]

    if reranker is not None and rerank_by_question:
        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

    # Add sources used in the metadata
    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)

    # Add to the list of docs (extends the lists taken from the state in place)
    docs.extend(docs_question)
    related_content.extend(images_question)
    new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
    return new_state
    


def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the retriever node: a ``@chain`` runnable bound to the given settings.

    The returned runnable takes ``(state, config)`` and awaits
    ``retrieve_documents`` with the vector store, reranker, llm and ``k``
    settings captured here, returning the state update it produces.
    """
    @chain
    async def retrieve_docs(state, config):
        # Forward the captured settings by name and hand back the state update
        return await retrieve_documents(
            state,
            config,
            vectorstore=vectorstore,
            reranker=reranker,
            llm=llm,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            k_summary=k_summary,
        )

    return retrieve_docs