File size: 4,707 Bytes
6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba 6b43c86 40084ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import sys
import os
from contextlib import contextmanager
from ..reranker import rerank_docs
from ..graph_retriever import retrieve_graphs # GraphRetriever
from ...utils import remove_duplicates_keep_highest_score
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that are as equal as possible.

    The first ``target % parts`` entries receive one extra unit, so the
    entries always sum to ``target`` and differ from each other by at most
    one, with the larger entries first.

    Args:
        target: Integer total to distribute.
        parts: Number of entries to produce.

    Returns:
        A list of ``parts`` integers summing to ``target``. A negative
        ``parts`` yields an empty list because ``range(parts)`` is empty.

    Raises:
        ZeroDivisionError: If ``parts`` is 0.
    """
    # base: value every entry gets; remainder: how many entries get one more
    base, remainder = divmod(target, parts)
    return [base + 1 if index < remainder else base for index in range(parts)]
@contextmanager
def suppress_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of the block.

    Both streams are pointed at the null device while the ``with`` body runs
    and are restored afterwards, including when the body raises.

    Only the Python-level ``sys.stdout`` / ``sys.stderr`` objects are
    replaced; anything writing straight to the underlying file descriptors
    is not affected.

    Yields:
        None.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import redirect_stderr, redirect_stdout

    with open(os.devnull, "w") as devnull:
        # The redirect helpers remember the previous streams and put them
        # back on exit, replacing the manual save / try / finally dance.
        with redirect_stdout(devnull), redirect_stderr(devnull):
            yield
def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
    """Build the async node that retrieves graph documents for a query.

    Args:
        vectorstore: Store handed to ``retrieve_graphs`` for the search.
        reranker: Model handed to ``rerank_docs``. When ``None`` each
            document's ``similarity_score`` is reused as ``reranking_score``.
        rerank_by_question: When true, ``k_final`` is split across the
            questions and each question keeps only its share of documents.
        k_final: Maximum number of documents returned overall.
        k_before_reranking: ``k_total`` requested from ``retrieve_graphs``
            for each question.

    Returns:
        The ``node_retrieve_graphs`` coroutine function.
    """

    async def node_retrieve_graphs(state):
        """Retrieve, rerank and deduplicate graphs for each question in ``state``.

        Reads ``state["remaining_questions"]`` and falls back to
        ``state["query"]`` when that value is empty or ``None``. Returns
        ``{"recommended_content": docs}`` holding at most ``k_final``
        documents sorted by descending ``reranking_score``.
        """
        print("---- Retrieving graphs ----")

        POSSIBLE_SOURCES = ["IEA", "OWID"]

        # Any empty value (None, [], (), ...) falls back to the raw query, so
        # there is always at least one question to split k_final across.
        questions = state["remaining_questions"] or [state["query"]]

        # Source selection is currently pinned to auto mode; the state-driven
        # value below is disabled.
        # sources_input = state["sources_input"]
        sources_input = ["auto"]

        # In auto mode every known source is searched, otherwise the configured
        # ones. This does not depend on the question, so it is computed once.
        requested_sources = POSSIBLE_SOURCES if "auto" in sources_input else sources_input
        sources = [x for x in requested_sources if x in POSSIBLE_SOURCES]

        # There are several options to get the final top k
        # Option 1 - Get 100 documents by question and rerank by question
        # Option 2 - Get 100/n documents by question and rerank the total
        k_by_question = divide_into_parts(k_final, len(questions)) if rerank_by_question else None

        docs = []

        for i, q in enumerate(questions):
            question = q["question"] if isinstance(q, dict) else q

            # NOTE(review): this label is 0-based while the "graphs retrieved"
            # message below is 1-based; confirm which numbering is intended.
            print(f"Subquestion {i}: {question}")

            if not sources:
                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {requested_sources}. Sources available: {POSSIBLE_SOURCES}.")
                continue

            # Fresh list per question so the retriever call and the document
            # metadata never alias the shared, hoisted list.
            question_sources = list(sources)

            # Search the document store using the retriever
            docs_question = await retrieve_graphs(
                query=question,
                vectorstore=vectorstore,
                sources=question_sources,
                k_total=k_before_reranking,
                threshold=0.5,
            )

            if reranker is not None and docs_question:
                # Silence whatever the reranker prints while scoring
                with suppress_output():
                    docs_question = rerank_docs(reranker, docs_question, question)
            else:
                # Add a default reranking score
                for doc in docs_question:
                    doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

            # If rerank by question we select the top documents for each question
            if rerank_by_question:
                docs_question = docs_question[:k_by_question[i]]

            # Add sources used in the metadata
            for doc in docs_question:
                doc.metadata["sources_used"] = question_sources

            print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")

            docs.extend(docs_question)

        # Remove duplicates and keep the duplicate document with the highest reranking score
        docs = remove_duplicates_keep_highest_score(docs)

        # Sort in descending order of reranking score, then keep the top k
        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
        docs = docs[:k_final]

        return {"recommended_content": docs}

    return node_retrieve_graphs