File size: 4,707 Bytes
6b43c86
 
 
 
 
40084ba
6b43c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40084ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b43c86
40084ba
6b43c86
40084ba
6b43c86
40084ba
 
 
 
 
 
 
 
 
 
6b43c86
40084ba
 
 
 
 
 
 
 
 
6b43c86
40084ba
 
 
 
6b43c86
40084ba
6b43c86
40084ba
6b43c86
40084ba
 
 
 
 
 
 
 
 
6b43c86
40084ba
6b43c86
40084ba
 
 
 
 
6b43c86
40084ba
 
 
 
6b43c86
40084ba
6b43c86
40084ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import sys
from contextlib import contextmanager, redirect_stdout, redirect_stderr

from ..reranker import rerank_docs
from ..graph_retriever import retrieve_graphs # GraphRetriever
from ...utils import remove_duplicates_keep_highest_score


def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integer shares differing by at most 1.

    The first ``target % parts`` shares each get one extra unit, so the
    shares always sum back to ``target``. Example: ``divide_into_parts(10, 3)``
    returns ``[4, 3, 3]``.
    """
    share, leftover = divmod(target, parts)
    # The first `leftover` positions absorb the remainder one unit each.
    return [share + 1 if pos < leftover else share for pos in range(parts)]


@contextmanager
def suppress_output():
    """Temporarily silence stdout and stderr inside a ``with`` block.

    Both streams are redirected to ``os.devnull`` for the duration of the
    block. Uses the stdlib ``redirect_stdout``/``redirect_stderr`` context
    managers, which guarantee the original streams are restored even if the
    body raises, and which nest correctly (the manual save/assign/restore
    this replaces did the same job by hand).
    """
    with open(os.devnull, "w") as devnull, \
            redirect_stdout(devnull), redirect_stderr(devnull):
        yield


def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
    """Build an async langgraph node that retrieves graph/figure documents.

    Args:
        vectorstore: vector store passed through to ``retrieve_graphs``.
        reranker: optional reranking model; when ``None``, each document's
            ``similarity_score`` is copied into ``reranking_score`` so the
            final sort still works.
        rerank_by_question: when True, the ``k_final`` budget is divided
            across the (sub)questions and each question keeps only its own
            share of top documents; when False, all retrieved documents
            compete together for the final top ``k_final``.
        k_final: number of documents returned in ``recommended_content``.
        k_before_reranking: number of candidates fetched per question
            before reranking.

    Returns:
        An async function mapping a graph ``state`` (expects the keys
        ``remaining_questions`` and ``query``) to
        ``{"recommended_content": [top documents]}``.
    """

    async def node_retrieve_graphs(state):
        print("---- Retrieving graphs ----")

        POSSIBLE_SOURCES = ["IEA", "OWID"]
        # Fall back to the raw query when there are no remaining subquestions.
        remaining = state["remaining_questions"]
        questions = remaining if remaining else [state["query"]]
        # sources_input = state["sources_input"]
        sources_input = ["auto"]

        auto_mode = "auto" in sources_input

        # There are several options to get the final top k
        # Option 1 - Get 100 documents by question and rerank by question
        # Option 2 - Get 100/n documents by question and rerank the total
        if rerank_by_question:
            k_by_question = divide_into_parts(k_final, len(questions))

        docs = []

        for i, q in enumerate(questions):

            # Subquestions may arrive as dicts ({"question": ...}) or plain strings.
            question = q["question"] if isinstance(q, dict) else q

            print(f"Subquestion {i + 1}: {question}")

            # If auto mode, we use all sources; otherwise, we use the config.
            sources = POSSIBLE_SOURCES if auto_mode else sources_input

            if any(x in POSSIBLE_SOURCES for x in sources):

                sources = [x for x in sources if x in POSSIBLE_SOURCES]

                # Search the document store using the retriever
                docs_question = await retrieve_graphs(
                    query = question,
                    vectorstore = vectorstore,
                    sources = sources,
                    k_total = k_before_reranking,
                    threshold = 0.5,
                    )

                # Rerank (the reranker prints noisily, hence suppress_output)
                if reranker is not None and docs_question != []:
                    with suppress_output():
                        docs_question = rerank_docs(reranker, docs_question, question)
                else:
                    # No reranker: reuse the similarity score so the final
                    # sort on "reranking_score" still has a value to use.
                    for doc in docs_question:
                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

                # If rerank by question we select the top documents for each question
                if rerank_by_question:
                    docs_question = docs_question[:k_by_question[i]]

                # Add sources used in the metadata
                for doc in docs_question:
                    doc.metadata["sources_used"] = sources

                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")

                docs.extend(docs_question)

            else:
                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")

        # Deduplicate / sort / truncate once, after all questions have been
        # processed (doing this inside the loop redid the same work every
        # iteration with an identical end result).
        # Keep the duplicate document with the highest reranking score.
        docs = remove_duplicates_keep_highest_score(docs)
        # Sort descending by reranking score, then keep the top k.
        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
        docs = docs[:k_final]

        return {"recommended_content": docs}

    return node_retrieve_graphs