import sys
import os
from contextlib import contextmanager

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event


def divide_into_parts(target, parts):
    """Split the integer *target* into *parts* integers that sum to *target*.

    The first ``target % parts`` entries receive one extra unit so the
    distribution is as even as possible (e.g. 7 into 3 -> [3, 2, 2]).
    Raises ZeroDivisionError if *parts* is 0 (unchanged from original).
    """
    base = target // parts
    remainder = target % parts
    # Parts with index < remainder absorb the leftover units.
    return [base + 1 if i < remainder else base for i in range(parts)]


@contextmanager
def suppress_output():
    """Context manager that silences stdout and stderr for its duration.

    Used below to mute noisy reranker libraries. Streams are always
    restored, even if the wrapped code raises.
    """
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr


@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question


def _add_sources_used_in_metadata(docs, sources, question, index):
    """Tag each document's metadata with the retrieval context and return *docs*.

    Mutates the documents in place (metadata keys ``sources_used``,
    ``question_used``, ``index_used``).
    """
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs


def _get_k_summary_by_question(n_questions):
    """Number of summary documents to fetch per question.

    Fewer summaries per question as the number of questions grows, so the
    total stays bounded.
    """
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


# The chain callback is not necessary, but it propagates the langchain
# callbacks to the astream_events logger to display intermediate results
# @chain
async def retrieve_documents(state, config, vectorstore, reranker, llm,
                             rerank_by_question=True, k_final=15,
                             k_before_reranking=100, k_summary=5):
    """Retrieve (and optionally rerank) documents for the next question in the state.

    Pops the first entry of ``state["remaining_questions"]``, queries the
    configured index for it, reranks the results if *reranker* is given,
    annotates the selected documents' metadata and returns the updated
    state dict (``documents``, ``related_contents``, ``remaining_questions``).
    """
    print("---- Retrieve documents ----")

    # Get the documents accumulated so far from the state
    if "documents" in state and state["documents"] is not None:
        docs = state["documents"]
    else:
        docs = []

    # Get the related_content (images) accumulated so far from the state
    if "related_content" in state and state["related_content"] is not None:
        related_content = state["related_content"]
    else:
        related_content = []

    # Pop the current question off the queue
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    # NOTE(review): raises ZeroDivisionError if n_questions == 0, although
    # _get_k_summary_by_question handles that case — confirm n_questions >= 1
    # is guaranteed by the caller whenever remaining_questions is non-empty.
    k_by_question = k_final // state["n_questions"]
    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

    if index == "Vector":
        # Search the document store using the retriever
        # Configure high top k for further reranking step
        retriever = ClimateQARetriever(
            vectorstore=vectorstore,
            sources=sources,
            min_size=200,
            k_summary=k_summary_by_question,
            k_total=k_before_reranking,
            threshold=0.5,
        )
        docs_question_dict = await retriever.ainvoke(question, config)

    # elif index == "OpenAlex":
    #     # keyword extraction
    #     keywords_extraction = make_keywords_extraction_chain(llm)
    #     keywords = keywords_extraction.invoke(question)["keywords"]
    #     openalex_query = " AND ".join(keywords)
    #     print(f"... OpenAlex query: {openalex_query}")
    #     retriever_openalex = OpenAlexRetriever(
    #         min_year = state.get("min_year",1960),
    #         max_year = state.get("max_year",None),
    #         k = k_before_reranking
    #     )
    #     docs_question = await retriever_openalex.ainvoke(openalex_query,config)
    # else:
    #     raise Exception(f"Index {index} not found in the routing index")

    # Rerank
    if reranker is not None:
        with suppress_output():
            docs_question_summary_reranked = rerank_docs(reranker, docs_question_dict["docs_summaries"], question)
            docs_question_fulltext_reranked = rerank_docs(reranker, docs_question_dict["docs_full"], question)
            docs_question_images_reranked = rerank_docs(reranker, docs_question_dict["docs_images"], question)
            if rerank_by_question:
                docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
                docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
                docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
    else:
        # BUGFIX: the original only built a throwaway docs_question list here,
        # leaving docs_question_*_reranked undefined, so the concatenation
        # below raised NameError whenever reranker was None. Fall back to the
        # retriever's own ordering and reuse the similarity score as the
        # reranking score.
        docs_question_summary_reranked = docs_question_dict["docs_summaries"]
        docs_question_fulltext_reranked = docs_question_dict["docs_full"]
        docs_question_images_reranked = docs_question_dict["docs_images"]
        for doc in docs_question_summary_reranked + docs_question_fulltext_reranked:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Keep summaries first, then full-text chunks, truncated per question
    docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
    docs_question = docs_question[:k_by_question]
    images_question = docs_question_images_reranked[:k_by_question]

    if reranker is not None and rerank_by_question:
        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

    # Add sources used in the metadata
    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)

    # Add to the list of docs
    docs.extend(docs_question)
    related_content.extend(images_question)

    # NOTE(review): the state is read from key "related_content" (singular)
    # above but written to "related_contents" (plural) here — verify this
    # matches the graph's state schema; it may be intentional.
    new_state = {
        "documents": docs,
        "related_contents": related_content,
        "remaining_questions": remaining_questions,
    }
    return new_state


def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True,
                        k_final=15, k_before_reranking=100, k_summary=5):
    """Build a langchain ``@chain`` node that runs :func:`retrieve_documents`
    with the given retrieval configuration bound in."""
    @chain
    async def retrieve_docs(state, config):
        state = await retrieve_documents(state, config, vectorstore, reranker, llm,
                                         rerank_by_question, k_final,
                                         k_before_reranking, k_summary)
        return state

    return retrieve_docs