timeki's picture
Add content recommendation (#17)
bcc8503 verified
history blame
11.5 kB
import sys
import os
from contextlib import contextmanager
from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from ..reranker import rerank_docs
# from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
from langchain_core.vectorstores import VectorStore
from typing import List
from langchain_core.documents.base import Document
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integer shares that differ by at most one.

    The shares always sum to ``target``; any remainder is handed out one unit
    at a time to the leading shares.

    Args:
        target: Total amount to distribute.
        parts: Number of shares to produce.

    Returns:
        A list of length ``parts`` (empty when ``parts`` is negative).

    Raises:
        ZeroDivisionError: If ``parts`` is 0.
    """
    # Base value for each part and the remainder still to distribute
    base, remainder = divmod(target, parts)
    # The first `remainder` parts get base + 1; the rest get the base value
    return [base + 1 if i < remainder else base for i in range(parts)]
@contextmanager
def suppress_output():
    """Context manager that silences ``sys.stdout`` and ``sys.stderr``.

    While the ``with`` body runs, both names are rebound to a handle on the
    null device. The original objects are put back when the body exits, even
    if it raises. Only the Python-level ``sys.stdout`` / ``sys.stderr``
    bindings are swapped.
    """
    # Open a null device
    with open(os.devnull, 'w') as devnull:
        # Store the original stdout and stderr
        old_stdout, old_stderr = sys.stdout, sys.stderr
        # Redirect stdout and stderr to the null device
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            # Restore stdout and stderr, whatever happened in the body
            sys.stdout = old_stdout
            sys.stderr = old_stderr
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # Identity pass-through: the question comes back unchanged.
    return question
def _add_sources_used_in_metadata(docs,sources,question,index):
for doc in docs:
doc.metadata["sources_used"] = sources
doc.metadata["question_used"] = question
doc.metadata["index_used"] = index
return docs
def _get_k_summary_by_question(n_questions):
if n_questions == 0:
return 0
elif n_questions == 1:
return 5
elif n_questions == 2:
return 3
elif n_questions == 3:
return 2
return 1
def _get_k_images_by_question(n_questions):
if n_questions == 0:
return 0
elif n_questions == 1:
return 7
elif n_questions == 2:
return 5
elif n_questions == 3:
return 2
return 1
def _add_metadata_and_score(docs: List) -> Document:
# Add score to metadata
docs_with_metadata = []
for i,(doc,score) in enumerate(docs):
doc.page_content = doc.page_content.replace("\r\n"," ")
doc.metadata["similarity_score"] = score
doc.metadata["content"] = doc.page_content
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
return docs_with_metadata
# Query a vector store for summary chunks, full-report chunks and (optionally) figures
# matching `query`, and return them in a dict keyed "docs_summaries", "docs_full", "docs_images".
# NOTE(review): this block appears to have lost its indentation, and some structural lines
# (closing brackets, parts of the filter dicts, possibly `else:` branches) are not visible here.
# The notes below flag each spot — confirm against the original file before relying on them.
# NOTE(review): `vectorstore` is used in the body but is not among the parameters shown.
# NOTE(review): `sources` and `reports` use mutable list defaults; the visible body does not mutate them.
# NOTE(review): `namespace` is not referenced anywhere in the visible body.
async def get_IPCC_relevant_documents(
query: str,
sources:list = ["IPCC","IPBES","IPOS"],
search_figures:bool = False,
reports:list = [],
threshold:float = 0.6,
k_summary:int = 3,
k_total:int = 10,
k_images: int = 5,
namespace:str = "vectors",
min_size:int = 200,
search_only:bool = False,
) :
# Check that `sources` is a non-empty list containing only "IPCC", "IPBES" or "IPOS"
# NOTE(review): these are `assert` statements, so the checks vanish when Python runs with -O.
assert isinstance(sources,list)
assert sources
assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
assert k_total > k_summary, "k_total should be greater than k_summary"
# Prepare base search kwargs
filters = {}
# NOTE(review): unclear from this view whether the `source` filter is an `else` of the
# `reports` check or always applied — verify.
if len(reports) > 0:
filters["short_name"] = {"$in":reports}
filters["source"] = { "$in": sources}
# Result holders; each stays empty unless the matching search below runs
docs_summaries = []
docs_full = []
docs_images = []
if search_only:
# Only search for images if search_only is True
if search_figures:
# NOTE(review): the dict opened on the next line has no visible contents or closing brace.
filters_image = {
docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
docs_images = _add_metadata_and_score(docs_images)
# NOTE(review): presumably an `else:` of `if search_only` begins here — confirm.
# Regular search flow for text and optionally images
# Search for k_summary documents in the summaries dataset
# NOTE(review): only the "report_type" entry of this filter is visible; no closing brace is shown.
filters_summaries = {
"report_type": { "$in":["SPM"]},
docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
# Keep only hits whose second tuple element (the returned score) is strictly above `threshold`
docs_summaries = [x for x in docs_summaries if x[1] > threshold]
# Search the non-SPM chunks for the remaining budget: k_total minus the summaries kept above
filters_full = {
"report_type": { "$nin":["SPM"]},
k_full = k_total - len(docs_summaries)
docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
if search_figures:
# Images
# NOTE(review): same truncated dict literal as in the search_only branch.
filters_image = {
docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
# Turn each list of (doc, score) pairs into documents via _add_metadata_and_score
docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
# Drop text chunks whose page_content is not longer than `min_size` characters (images are not filtered)
docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
docs_full = [x for x in docs_full if len(x.page_content) > min_size]
# NOTE(review): the closing brace of the returned dict is not visible in this view.
return {
"docs_summaries" : docs_summaries,
"docs_full" : docs_full,
"docs_images" : docs_images,
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
# @chain
# Retrieve, optionally rerank, and tag the documents for the first remaining question in `state`,
# then return a new state dict with "documents", "related_contents" and "remaining_questions".
# NOTE(review): this block appears to have lost its indentation, and several structural lines
# (docstring delimiters, `else:` branches, a closing parenthesis) are not visible here.
# The notes below flag each spot — confirm against the original file before relying on them.
# NOTE(review): the `llm` and `k_summary` parameters are not read anywhere in the visible body.
async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
# NOTE(review): the next twelve lines read like the body of a docstring whose triple-quote
# delimiters and section headers are missing from this view.
Retrieve and rerank documents based on the current question in the state.
state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
config (dict): Configuration settings for logging and other purposes.
vectorstore (object): The vector store used to retrieve relevant documents.
reranker (object): The reranker used to rerank the retrieved documents.
llm (object): The language model used for processing.
rerank_by_question (bool, optional): Whether to rerank documents by question. Defaults to True.
k_final (int, optional): The final number of documents to retrieve. Defaults to 15.
k_before_reranking (int, optional): The number of documents to retrieve before reranking. Defaults to 100.
k_summary (int, optional): The number of summary documents to retrieve. Defaults to 5.
k_images (int, optional): The number of image documents to retrieve. Defaults to 5.
dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
print("---- Retrieve documents ----")
# Get the documents from the state
if "documents" in state and state["documents"] is not None:
docs = state["documents"]
# NOTE(review): presumably the fallback of a missing `else:`; as flat text this always resets `docs`.
docs = []
# Get the related_content from the state
if "related_content" in state and state["related_content"] is not None:
related_content = state["related_content"]
# NOTE(review): same pattern — presumably the `else:` fallback.
related_content = []
# Figures are searched only when "IPCC figures" is among the requested content sources
search_figures = "IPCC figures" in state["relevant_content_sources"]
search_only = state["search_only"]
# Get the current question; the rest of the queue goes back out in the returned state
current_question = state["remaining_questions"][0]
remaining_questions = state["remaining_questions"][1:]
# Per-question share of the final budget (integer division; n_questions == 0 raises ZeroDivisionError)
k_by_question = k_final // state["n_questions"]
# Per-question summary and image counts come from the lookup helpers
k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
k_images_by_question = _get_k_images_by_question(state["n_questions"])
sources = current_question["sources"]
question = current_question["question"]
index = current_question["index"]
print(f"Retrieve documents for question: {question}")
# Hand the question, sources and index to log_event (defined elsewhere) under the "log_retriever" label
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
# `docs_question_dict` is only assigned inside this branch
if index == "Vector": # always true for now
# NOTE(review): no closing parenthesis and no `vectorstore` argument are visible for this call,
# although the callee's body uses `vectorstore` — verify.
# threshold=0.5 here overrides the callee's default of 0.6.
docs_question_dict = await get_IPCC_relevant_documents(
query = question,
search_figures = search_figures,
sources = sources,
min_size = 200,
k_summary = k_summary_by_question,
k_total = k_before_reranking,
k_images = k_images_by_question,
threshold = 0.5,
search_only = search_only,
# Rerank
if reranker is not None:
# suppress_output() is used as a context manager around the rerank_docs calls
with suppress_output():
docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
# Order each group by descending reranking score
if rerank_by_question:
docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
# NOTE(review): presumably the next four lines are the `else:` branch for `reranker is None`.
# If so, the `*_reranked` names used just below are never assigned on that path — verify.
docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
# Add a default reranking score
for doc in docs_question:
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
# Summaries first, then full-text chunks, truncated to this question's share
docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
docs_question = docs_question[:k_by_question]
# Images are capped by the `k_images` argument, not by `k_images_by_question`
images_question = docs_question_images_reranked[:k_images]
# Re-sort the merged text documents by descending reranking score
if reranker is not None and rerank_by_question:
docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
# Add sources used in the metadata
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
# Add to the list of docs
# NOTE(review): no visible statement adds `docs_question` to `docs` or `images_question`
# to `related_content` — the lines doing so may be missing from this view.
# NOTE(review): the state is read under "related_content" but written back as "related_contents" — confirm.
new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
return new_state
def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
    """Build an async ``(state, config)`` node bound to the given retrieval settings.

    Every argument is captured by the returned coroutine function and passed
    on to ``retrieve_documents`` each time the node runs.

    Args:
        vectorstore: Vector store handed to ``retrieve_documents``.
        reranker: Reranker handed to ``retrieve_documents`` (may be ``None``).
        llm: Language model handed to ``retrieve_documents``.
        rerank_by_question: Forwarded as-is.
        k_final: Forwarded as-is.
        k_before_reranking: Forwarded as-is.
        k_summary: Forwarded as-is.
        k_images: Forwarded as-is; defaults to 5, which is also the
            ``retrieve_documents`` default, so existing callers are unaffected.

    Returns:
        An async function taking ``(state, config)`` and returning whatever
        ``retrieve_documents`` returns.
    """
    async def retrieve_docs(state, config):
        return await retrieve_documents(
            state, config, vectorstore, reranker, llm,
            rerank_by_question, k_final, k_before_reranking, k_summary, k_images,
        )

    return retrieve_docs