timeki's picture
edit prints for logs
9609df9
raw
history blame
6.95 kB
import sys
import os
from contextlib import contextmanager
from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
def divide_into_parts(target, parts):
    """Split the integer `target` into `parts` near-equal integer shares.

    The shares sum to `target`; the first `target % parts` shares are one
    larger than the rest, so the result is non-increasing.
    """
    share, leftover = divmod(target, parts)
    # The first `leftover` positions absorb the remainder, one unit each.
    return [share + 1 if position < leftover else share for position in range(parts)]
@contextmanager
def suppress_output():
    """Context manager that silences stdout and stderr for the duration of the block.

    Both streams are pointed at the platform null device and restored on exit,
    even if the body raises.
    """
    devnull = open(os.devnull, 'w')
    saved_streams = (sys.stdout, sys.stderr)
    sys.stdout = devnull
    sys.stderr = devnull
    try:
        yield
    finally:
        # Always restore the original streams and release the null device.
        sys.stdout, sys.stderr = saved_streams
        devnull.close()
# NOTE: the docstring below doubles as the tool description exposed through
# the @tool decorator, so it is part of runtime behavior — do not edit casually.
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # Echo the question unchanged; the real retrieval work happens in
    # retrieve_documents below.
    return question
def _add_sources_used_in_metadata(docs,sources,question,index):
for doc in docs:
doc.metadata["sources_used"] = sources
doc.metadata["question_used"] = question
doc.metadata["index_used"] = index
return docs
def _get_k_summary_by_question(n_questions):
if n_questions == 0:
return 0
elif n_questions == 1:
return 5
elif n_questions == 2:
return 3
elif n_questions == 3:
return 2
else:
return 1
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
# @chain
async def retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Retrieve, rerank and accumulate documents for the next remaining question.

    Pops the first entry of state["remaining_questions"], queries the vector
    store for summary / full-text / image documents, optionally reranks each
    group, keeps the per-question top-k, and appends the results to the
    running document lists.

    Args:
        state: graph state dict. Reads "remaining_questions" (non-empty list of
            {"question", "sources", "index"} dicts), "n_questions", and
            optionally "documents" / "related_content" accumulated so far.
        config: runnable config forwarded to the retriever and event logger.
        vectorstore: vector store wrapped by ClimateQARetriever.
        reranker: reranking model, or None to fall back to similarity scores.
        llm: language model (unused on the active "Vector" path; kept for the
            disabled OpenAlex branch below).
        rerank_by_question: when True and a reranker is set, sort documents by
            their reranking score.
        k_final: total document budget, divided evenly across n_questions.
        k_before_reranking: top-k fetched from the store before reranking.
        k_summary: unused; the per-question summary k is derived from
            n_questions via _get_k_summary_by_question.

    Returns:
        dict with updated "documents", "related_contents" and the shortened
        "remaining_questions".

    Raises:
        Exception: if the question's "index" is not a known routing index.
    """
    print("---- Retrieve documents ----")

    # Documents and related content accumulated by previous invocations.
    docs = state["documents"] if state.get("documents") is not None else []
    related_content = state["related_content"] if state.get("related_content") is not None else []

    # Pop the next question off the queue.
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    # Split the global budget evenly across questions.
    k_by_question = k_final // state["n_questions"]
    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

    if index == "Vector":
        # Search the document store using the retriever.
        # Configure a high top-k here; the reranking step below trims it.
        retriever = ClimateQARetriever(
            vectorstore=vectorstore,
            sources=sources,
            min_size=200,
            k_summary=k_summary_by_question,
            k_total=k_before_reranking,
            threshold=0.5,
        )
        docs_question_dict = await retriever.ainvoke(question, config)
    # elif index == "OpenAlex":
    #     # keyword extraction
    #     keywords_extraction = make_keywords_extraction_chain(llm)
    #     keywords = keywords_extraction.invoke(question)["keywords"]
    #     openalex_query = " AND ".join(keywords)
    #     print(f"... OpenAlex query: {openalex_query}")
    #     retriever_openalex = OpenAlexRetriever(
    #         min_year = state.get("min_year",1960),
    #         max_year = state.get("max_year",None),
    #         k = k_before_reranking
    #     )
    #     docs_question = await retriever_openalex.ainvoke(openalex_query,config)
    else:
        # Fail fast: without this, docs_question_dict would be undefined below
        # and surface as a confusing NameError.
        raise Exception(f"Index {index} not found in the routing index")

    # Rerank each document group, or fall back to similarity scores.
    if reranker is not None:
        with suppress_output():
            docs_question_summary_reranked = rerank_docs(reranker, docs_question_dict["docs_summaries"], question)
            docs_question_fulltext_reranked = rerank_docs(reranker, docs_question_dict["docs_full"], question)
            docs_question_images_reranked = rerank_docs(reranker, docs_question_dict["docs_images"], question)
        if rerank_by_question:
            docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
            docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
            docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
        docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
        images_question = docs_question_images_reranked
    else:
        # BUG FIX: this branch previously still referenced the *_reranked
        # variables afterwards, raising NameError whenever reranker was None.
        docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
        images_question = docs_question_dict["docs_images"]
        # Add a default reranking score so downstream sorting keys exist.
        for doc in docs_question:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Keep only the per-question budget.
    docs_question = docs_question[:k_by_question]
    images_question = images_question[:k_by_question]

    if reranker is not None and rerank_by_question:
        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

    # Record which sources/question/index produced each document.
    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)

    # Append to the accumulated lists.
    docs.extend(docs_question)
    related_content.extend(images_question)

    return {"documents": docs, "related_contents": related_content, "remaining_questions": remaining_questions}
def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build a graph node that runs retrieve_documents with the captured settings.

    The returned runnable takes (state, config) and delegates everything else
    to retrieve_documents via the closed-over configuration.
    """
    @chain
    async def retrieve_docs(state, config):
        # Forward the node inputs plus the captured retrieval configuration.
        return await retrieve_documents(
            state, config, vectorstore, reranker, llm,
            rerank_by_question, k_final, k_before_reranking, k_summary,
        )
    return retrieve_docs