File size: 5,263 Bytes
99e91d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import sys
import os
from contextlib import contextmanager

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event



def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries receive one extra unit so that the
    entries always add up to ``target``.

    Args:
        target: Total amount to distribute.
        parts: Number of entries in the returned list.

    Returns:
        A list of ``parts`` integers, larger entries first.

    Raises:
        ZeroDivisionError: If ``parts`` is 0.
    """
    share, extra = divmod(target, parts)
    # The leading `extra` slots absorb the remainder, one unit each
    return [share + 1 if slot < extra else share for slot in range(parts)]


@contextmanager
def suppress_output():
    """Context manager that sends ``sys.stdout`` and ``sys.stderr`` to the null device.

    Both streams are swapped for a handle on ``os.devnull`` while the ``with``
    body runs, and the previous stream objects are put back afterwards, even
    when the body raises.
    """
    # Remember the current streams so they can be reinstated on exit
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, 'w') as sink:
        # Everything printed from here on is discarded
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams


# Placeholder tool: it performs no retrieval and simply returns the question unchanged.
# NOTE(review): the docstring below is presumably consumed by langchain's @tool as the
# tool description, so treat it as runtime text — confirm before rewording it.
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # Echo the input back as the tool result
    return question







def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the node that retrieves documents for the next pending question.

    Args:
        vectorstore: Vector store handed to ``ClimateQARetriever`` when the
            question is routed to the "Vector" index.
        reranker: Model passed to ``rerank_docs``. When ``None``, each
            document's ``similarity_score`` is reused as its ``reranking_score``.
        llm: Language model used to build the keywords extraction chain when
            the question is routed to the "OpenAlex" index.
        rerank_by_question: When true, only the first
            ``k_final // state["n_questions"]`` documents of the question are kept.
        k_final: Document budget that is divided between the questions.
        k_before_reranking: Number of documents requested from the retriever
            (``k_total`` for ``ClimateQARetriever``, ``k`` for ``OpenAlexRetriever``).
        k_summary: Forwarded to ``ClimateQARetriever`` as ``k_summary``.

    Returns:
        The ``retrieve_documents`` runnable. It takes ``(state, config)`` and
        returns a dict with the keys ``"documents"`` and ``"remaining_questions"``.
    """

    # The chain decorator is not strictly necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
    @chain
    async def retrieve_documents(state,config):
        """Retrieve, rerank and tag the documents for the first remaining question.

        Raises:
            ValueError: If the question's ``index`` is neither "Vector" nor "OpenAlex".
        """
        # Consume the first pending question; the others are handed back in the new state
        pending_questions = state["remaining_questions"]
        current_question = pending_questions[0]
        remaining_questions = pending_questions[1:]

        sources = current_question["sources"]
        question = current_question["question"]
        index = current_question["index"]

        await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

        if index == "Vector":

            # Search the document store using the retriever
            # Configure high top k for further reranking step
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources = sources,
                min_size = 200,
                k_summary = k_summary,
                k_total = k_before_reranking,
                threshold = 0.5,
            )
            docs_question = await retriever.ainvoke(question,config)

        elif index == "OpenAlex":

            # The extraction chain is only needed on this branch, so it is built here,
            # and it is awaited so the LLM call does not block the event loop
            keywords_extraction = make_keywords_extraction_chain(llm)
            extraction = await keywords_extraction.ainvoke(question)
            keywords = extraction["keywords"]
            openalex_query = " AND ".join(keywords)

            print(f"... OpenAlex query: {openalex_query}")

            retriever_openalex = OpenAlexRetriever(
                min_year = state.get("min_year",1960),
                max_year = state.get("max_year",None),
                k = k_before_reranking
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query,config)

        else:
            raise ValueError(f"Index {index} not found in the routing index")

        # Rerank
        if reranker is not None:
            # Silence anything the reranker writes to stdout/stderr
            with suppress_output():
                docs_question = rerank_docs(reranker,docs_question,question)
        else:
            # Add a default reranking score
            # NOTE(review): assumes every retrieved document carries "similarity_score" — confirm for OpenAlex results
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

        # If rerank by question we select the top documents for each question
        if rerank_by_question:
            # Each question gets an equal integer share of the final budget
            k_by_question = k_final // state["n_questions"]
            docs_question = docs_question[:k_by_question]

        # Record in the metadata how each document was obtained
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        # Sort in descending order of reranking score
        docs = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
        return {"documents":docs,"remaining_questions":remaining_questions}

    return retrieve_documents