import os
import sys
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import List, Optional

from langchain_core.documents.base import Document
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import chain
from langchain_core.tools import tool
from langchain_core.vectorstores import VectorStore

from ..reranker import rerank_docs
# from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event


# Sources accepted by get_IPCC_relevant_documents (also its default selection).
_ALLOWED_SOURCES = ("IPCC", "IPBES", "IPOS")

# Per-question retrieval budget keyed by the number of questions; any other
# count (4+, or negative) falls back to 1.
_K_BY_N_QUESTIONS = {0: 0, 1: 5, 2: 3, 3: 2}


def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries receive the extra unit, e.g.
    ``divide_into_parts(10, 3) == [4, 3, 3]``.

    Raises:
        ZeroDivisionError: if ``parts`` is 0.
    """
    base, remainder = divmod(target, parts)
    return [base + 1 if i < remainder else base for i in range(parts)]


@contextmanager
def suppress_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to the null device.

    Both streams are restored on exit, even if the body raises.
    """
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull), redirect_stderr(devnull):
        yield


@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question


def _add_sources_used_in_metadata(docs, sources, question, index):
    """Record which sources, question and index produced each document (in place).

    Returns the same ``docs`` list for convenient chaining.
    """
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs


def _get_k_by_question(n_questions):
    """Return the per-question document budget for ``n_questions`` questions."""
    return _K_BY_N_QUESTIONS.get(n_questions, 1)


def _get_k_summary_by_question(n_questions):
    """Number of summary (SPM) chunks to retrieve per question."""
    return _get_k_by_question(n_questions)


def _get_k_images_by_question(n_questions):
    """Number of image chunks to retrieve per question."""
    return _get_k_by_question(n_questions)


def _add_metadata_and_score(docs: List) -> List[Document]:
    """Normalise ``(Document, score)`` pairs returned by a similarity search.

    Each document is mutated in place:
    - ``"\\r\\n"`` in ``page_content`` is replaced by a space;
    - ``metadata["similarity_score"]`` receives the search score;
    - ``metadata["content"]`` receives a copy of the cleaned text;
    - ``metadata["page_number"]`` is cast to int and incremented by one
      (presumably 0-based -> 1-based; TODO confirm against the indexer).

    Returns:
        The documents, without their scores, in the original order.
    """
    docs_with_metadata = []
    for doc, score in docs:
        doc.page_content = doc.page_content.replace("\r\n", " ")
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        docs_with_metadata.append(doc)
    return docs_with_metadata


async def get_IPCC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: Optional[list] = None,
    search_figures: bool = False,
    reports: Optional[list] = None,
    threshold: float = 0.6,
    k_summary: int = 3,
    k_total: int = 10,
    k_images: int = 5,
    namespace: str = "vectors",
    min_size: int = 200,
):
    """Retrieve summary (SPM) chunks, full-report chunks and optionally figures.

    Args:
        query: Text to search for.
        vectorstore: Store queried via ``similarity_search_with_score``.
        sources: Subset of ``_ALLOWED_SOURCES``; defaults to all of them.
        search_figures: Also search chunks whose ``chunk_type`` is ``"image"``.
        reports: If non-empty, filter by report ``short_name`` instead of by source.
        threshold: Minimum similarity score for summary chunks (only summaries
            are filtered by this threshold).
        k_summary: Number of summary chunks requested.
        k_total: Total text-chunk budget; full-report chunks fill what the kept
            summaries leave over.
        k_images: Number of image chunks requested when ``search_figures``.
        namespace: Currently unused.
        min_size: Text chunks whose content is not longer than this are dropped
            (images are not size-filtered).

    Returns:
        Dict with ``"docs_summaries"``, ``"docs_full"`` and ``"docs_images"``,
        each a list of Documents annotated by ``_add_metadata_and_score``.

    Raises:
        AssertionError: on invalid ``sources`` or if ``k_total <= k_summary``.
    """
    # None sentinels avoid shared mutable default arguments.
    if sources is None:
        sources = list(_ALLOWED_SOURCES)
    if reports is None:
        reports = []

    # Check that every requested source is one of IPCC, IPBES or IPOS.
    assert isinstance(sources, list)
    assert sources
    assert all(x in _ALLOWED_SOURCES for x in sources)
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Base filter: an explicit report selection takes precedence over sources.
    if reports:
        filters = {"short_name": {"$in": reports}}
    else:
        filters = {"source": {"$in": sources}}

    # Summaries for Policymakers, kept only above the similarity threshold.
    filters_summaries = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$in": ["SPM"]},
    }
    docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
    docs_summaries = [x for x in docs_summaries if x[1] > threshold]

    # Full reports fill the remaining budget.
    filters_full = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$nin": ["SPM"]},
    }
    k_full = k_total - len(docs_summaries)
    docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_full)

    docs_images = []
    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

    docs_summaries = _add_metadata_and_score(docs_summaries)
    docs_full = _add_metadata_and_score(docs_full)
    docs_images = _add_metadata_and_score(docs_images)

    # Drop text chunks that are too short to be useful.
    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
    docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries": docs_summaries,
        "docs_full": docs_full,
        "docs_images": docs_images,
    }


def _state_list(state, key):
    """Return a shallow copy of ``state[key]``, or ``[]`` if missing or None.

    Copying keeps this node from mutating lists owned by the incoming state.
    """
    if key in state and state[key] is not None:
        return list(state[key])
    return []


def _by_reranking_score(doc):
    """Sort key: the document's ``reranking_score`` metadata."""
    return doc.metadata["reranking_score"]


def _rank_documents(docs_dict, question, reranker, rerank_by_question):
    """Return ``(summaries, fulltext, images)``, reranked when a reranker is given.

    Without a reranker, ``metadata["reranking_score"]`` is set to the
    similarity score, so downstream code can always read it.
    """
    summaries = docs_dict["docs_summaries"]
    fulltext = docs_dict["docs_full"]
    images = docs_dict["docs_images"]

    if reranker is None:
        for doc in summaries + fulltext + images:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
        return summaries, fulltext, images

    # The reranker may print to stdout/stderr; keep that out of the logs.
    with suppress_output():
        summaries = rerank_docs(reranker, summaries, question)
        fulltext = rerank_docs(reranker, fulltext, question)
        images = rerank_docs(reranker, images, question)

    if rerank_by_question:
        summaries = sorted(summaries, key=_by_reranking_score, reverse=True)
        fulltext = sorted(fulltext, key=_by_reranking_score, reverse=True)
        images = sorted(images, key=_by_reranking_score, reverse=True)

    return summaries, fulltext, images


# Not decorated with @chain here: make_retriever_node wraps it in a @chain so the
# langchain callbacks propagate to the astream_events logger (intermediate results).
async def retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
    """Retrieve documents for the first remaining question in ``state``.

    Reads from ``state``: ``documents``, ``related_content`` (both optional),
    ``relevant_content_sources``, ``remaining_questions`` (each item a dict with
    ``question``, ``sources`` and ``index``) and ``n_questions``.

    Args:
        llm: Currently unused.
        rerank_by_question: Sort each category by reranking score (only
            relevant when ``reranker`` is not None).
        k_final: Total text-document budget, shared equally across questions.
        k_before_reranking: Text candidates fetched before reranking.
        k_summary: Currently unused; the summary budget is derived from
            ``n_questions``.
        k_images: Maximum number of images kept for this question.

    Returns:
        State update with ``documents``, ``related_contents`` and
        ``remaining_questions`` (the processed question removed).
    """
    print("---- Retrieve documents ----")

    docs = _state_list(state, "documents")
    # NOTE(review): this reads "related_content" but the update below writes
    # "related_contents"; unless the graph reconciles the two keys, images from
    # earlier questions are not carried over. Confirm against the state schema.
    related_content = _state_list(state, "related_content")

    search_figures = "IPCC figures" in state["relevant_content_sources"]

    # Pop the current question.
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    n_questions = state["n_questions"]
    # Guard against n_questions == 0 (previously a ZeroDivisionError).
    k_by_question = k_final // n_questions if n_questions else k_final
    k_summary_by_question = _get_k_summary_by_question(n_questions)
    k_images_by_question = _get_k_images_by_question(n_questions)

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

    # Defaults keep the state update well-formed for any other index value.
    docs_question = []
    images_question = []

    if index == "Vector":  # always true for now
        docs_question_dict = await get_IPCC_relevant_documents(
            query=question,
            vectorstore=vectorstore,
            search_figures=search_figures,
            sources=sources,
            min_size=200,
            k_summary=k_summary_by_question,
            k_total=k_before_reranking,
            k_images=k_images_by_question,
            threshold=0.5,
        )

        summaries, fulltext, images = _rank_documents(docs_question_dict, question, reranker, rerank_by_question)

        # Summaries come first, so they are favoured when truncating to the budget.
        docs_question = (summaries + fulltext)[:k_by_question]
        images_question = images[:k_images]

        if reranker is not None and rerank_by_question:
            docs_question = sorted(docs_question, key=_by_reranking_score, reverse=True)

    docs_question = _add_sources_used_in_metadata(docs_question, sources, question, index)
    images_question = _add_sources_used_in_metadata(images_question, sources, question, index)

    docs.extend(docs_question)
    related_content.extend(images_question)

    new_state = {"documents": docs, "related_contents": related_content, "remaining_questions": remaining_questions}
    return new_state


def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build a graph node that runs ``retrieve_documents`` with fixed dependencies.

    The node is wrapped with ``@chain`` so langchain callbacks propagate to the
    astream_events logger.
    """

    @chain
    async def retrieve_docs(state, config):
        state = await retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question, k_final, k_before_reranking, k_summary)
        return state

    return retrieve_docs