File size: 5,382 Bytes
99e91d8 49acaf1 99e91d8 eee8932 99e91d8 eee8932 99e91d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import sys
import os
from contextlib import contextmanager
from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries each receive one extra unit on top of
    ``target // parts``, so for a positive integer ``parts`` the shares add up
    to ``target``.

    Args:
        target: Total amount to distribute.
        parts: Number of shares to produce.

    Returns:
        A list with one share per part, the larger shares first.

    Raises:
        ZeroDivisionError: if ``parts`` is 0.
    """
    share, extra = divmod(target, parts)
    return [share + 1 if position < extra else share for position in range(parts)]
@contextmanager
def suppress_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of a ``with`` block.

    Both streams are pointed at ``os.devnull`` and are put back to their
    previous objects when the block exits, whether normally or via an
    exception (which is re-raised unchanged).
    """
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, 'w') as sink:
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            # Restore the original streams before the null device is closed.
            sys.stdout, sys.stderr = saved_streams
# Placeholder tool: it performs no retrieval and only echoes its input.
# NOTE(review): LangChain's @tool presumably derives the tool description from
# the docstring and the argument schema from the signature, so both are left
# untouched here on purpose — confirm before editing either.
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # Return the question unchanged.
    return question
def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Create an async retrieval step that handles one pending question per call.

    Each invocation takes the first entry of ``state["remaining_questions"]``,
    retrieves candidate documents for it from the index named in that entry
    (``"Vector"`` or ``"OpenAlex"``), reranks them (or reuses the similarity
    score), tags them with the question/sources/index used, and merges them
    with the documents already in the state, sorted by ``reranking_score``
    in descending order.

    Args:
        vectorstore: Store handed to ``ClimateQARetriever`` for ``"Vector"`` questions.
        reranker: Model passed to ``rerank_docs``. When ``None``, each document's
            ``similarity_score`` metadata is copied into ``reranking_score``.
        llm: Model used to build the keyword-extraction chain for OpenAlex queries.
        rerank_by_question: When True, keep only the top ``k_final // n_questions``
            documents (at least 1) retrieved for each question.
        k_final: Document budget that is divided across ``state["n_questions"]``.
        k_before_reranking: Number of candidates requested from the retriever.
        k_summary: Forwarded to ``ClimateQARetriever``.

    Returns:
        The ``@chain``-wrapped coroutine ``retrieve_documents(state, config)``.
        It returns ``{"documents": [...], "remaining_questions": [...]}`` and
        raises ``ValueError`` (an ``Exception`` subclass) for an unknown index.
    """
    # The keyword chain depends only on ``llm``; build it once per node rather
    # than on every call.
    keywords_extraction = make_keywords_extraction_chain(llm)

    # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
    @chain
    async def retrieve_documents(state, config):
        print("---- Retrieve documents ----")

        current_question = state["remaining_questions"][0]
        remaining_questions = state["remaining_questions"][1:]

        # Work on a copy so the incoming state's document list is not mutated.
        docs = list(state.get("documents") or [])

        # Per-question share of the final budget. Clamp to at least 1 so that
        # having more questions than ``k_final`` does not discard every document.
        k_by_question = max(1, k_final // max(1, state["n_questions"]))

        sources = current_question["sources"]
        question = current_question["question"]
        index = current_question["index"]

        await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

        if index == "Vector":
            # Search the document store; request many candidates so the
            # reranking step has enough to choose from.
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources=sources,
                min_size=200,
                k_summary=k_summary,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            docs_question = await retriever.ainvoke(question, config)
        elif index == "OpenAlex":
            # Async call so keyword extraction does not block the event loop.
            keywords = (await keywords_extraction.ainvoke(question))["keywords"]
            openalex_query = " AND ".join(keywords)
            print(f"... OpenAlex query: {openalex_query}")
            retriever_openalex = OpenAlexRetriever(
                min_year=state.get("min_year", 1960),
                max_year=state.get("max_year", None),
                k=k_before_reranking,
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query, config)
        else:
            raise ValueError(f"Index {index} not found in the routing index")

        if reranker is not None:
            # NOTE(review): rerank_docs is a synchronous call and runs on the
            # event loop; its console output is redirected to os.devnull.
            with suppress_output():
                docs_question = rerank_docs(reranker, docs_question, question)
        else:
            # No reranker: fall back to the retriever's similarity score.
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

        if rerank_by_question:
            docs_question = docs_question[:k_by_question]

        # Record how each document was obtained.
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        docs.extend(docs_question)
        # Highest reranking score first; ``docs`` is a local copy, so sort in place.
        docs.sort(key=lambda x: x.metadata["reranking_score"], reverse=True)

        return {"documents": docs, "remaining_questions": remaining_questions}

    return retrieve_documents
|