import sys
import os
from contextlib import contextmanager
from typing import List

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from langchain_core.vectorstores import VectorStore
from langchain_core.documents.base import Document

from ..reranker import rerank_docs
from ..utils import log_event
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain


def divide_into_parts(target, parts):
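    """Split `target` into `parts` integer chunks that sum back to `target`.

    The first `target % parts` chunks get one extra unit, e.g.
    divide_into_parts(10, 3) -> [4, 3, 3].
    """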
    base = target // parts
    remainder = target % parts
    result = []

    for i in range(parts):
        if i < remainder:
            result.append(base + 1)
        else:
            result.append(base)

    return result


@contextmanager
def suppress_output():
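    """Temporarily redirect stdout and stderr to os.devnull.

    Used below to silence noisy output while the reranker runs.
    """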
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr


@tool
def query_retriever(question):
    """Dummy tool that simply returns the question, used to simulate a retriever query."""
    return question


def _add_sources_used_in_metadata(docs, sources, question, index):
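    """Record in each document's metadata which sources, question and index were used to retrieve it."""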
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs


def _get_k_summary_by_question(n_questions):
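    """Number of summary (SPM) chunks to retrieve per question; shrinks as the number of questions grows."""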
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


def _get_k_images_by_question(n_questions):
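    """Number of image chunks to retrieve per question; shrinks as the number of questions grows."""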
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


def _add_metadata_and_score(docs: List) -> List[Document]:
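    """Unpack (Document, score) tuples: clean the page content and copy the similarity score into metadata."""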
    docs_with_metadata = []
    for doc, score in docs:
        doc.page_content = doc.page_content.replace("\r\n", " ")
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        docs_with_metadata.append(doc)
    return docs_with_metadata


async def get_IPCC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = ["IPCC", "IPBES", "IPOS"],
    search_figures: bool = False,
    reports: list = [],
    threshold: float = 0.6,
    k_summary: int = 3,
    k_total: int = 10,
    k_images: int = 5,
    namespace: str = "vectors",
    min_size: int = 200,
):
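    """Retrieve documents relevant to `query` from the vector store.

    Searches summary (SPM) chunks, full-report chunks and, optionally, image chunks,
    and returns a dict with the keys "docs_summaries", "docs_full" and "docs_images".
    """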
    # Check the arguments
    assert isinstance(sources, list)
    assert sources
    assert all(x in ["IPCC", "IPBES", "IPOS"] for x in sources)
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Base filters: restrict to specific reports if provided, otherwise to the selected sources
    filters = {}
    if len(reports) > 0:
        filters["short_name"] = {"$in": reports}
    else:
        filters["source"] = {"$in": sources}

    docs_summaries = []
    docs_full = []
    docs_images = []

    # Search for k_summary chunks in the summaries (SPM reports), keeping only those above the threshold
    filters_summaries = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$in": ["SPM"]},
    }
    docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
    docs_summaries = [x for x in docs_summaries if x[1] > threshold]

    # Fill the rest of the budget with chunks from the full reports
    filters_full = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$nin": ["SPM"]},
    }
    k_full = k_total - len(docs_summaries)
    docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_full)

    # Optionally search for image chunks (figures)
    if search_figures:
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

    # Unpack (doc, score) tuples and copy the scores into metadata
    docs_summaries = _add_metadata_and_score(docs_summaries)
    docs_full = _add_metadata_and_score(docs_full)
    docs_images = _add_metadata_and_score(docs_images)

    # Filter out chunks that are too short to be useful
    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
    docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries": docs_summaries,
        "docs_full": docs_full,
        "docs_images": docs_images,
    }


async def retrieve_documents(
    state,
    config,
    vectorstore,
    reranker,
    llm,
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
    k_images=5,
):
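    """Retrieve documents for the next remaining question in the state.

    Pops the first entry of state["remaining_questions"], queries the vector store,
    optionally reranks the results, and returns the updated state with the retrieved
    documents and related image content appended.
    """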
    print("---- Retrieve documents ----")

    # Get the documents and related content already stored in the state
    if "documents" in state and state["documents"] is not None:
        docs = state["documents"]
    else:
        docs = []

    if "related_content" in state and state["related_content"] is not None:
        related_content = state["related_content"]
    else:
        related_content = []

    search_figures = "IPCC figures" in state["relevant_content_sources"]

    # Pop the next question to process
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    # Spread the retrieval budget across the questions
    k_by_question = k_final // state["n_questions"]
    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
    k_images_by_question = _get_k_images_by_question(state["n_questions"])

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)
if index == "Vector": |
|
docs_question_dict = await get_IPCC_relevant_documents( |
|
query = question, |
|
vectorstore=vectorstore, |
|
search_figures = search_figures, |
|
sources = sources, |
|
min_size = 200, |
|
k_summary = k_summary_by_question, |
|
k_total = k_before_reranking, |
|
k_images = k_images_by_question, |
|
threshold = 0.5, |
|
) |
|
|
|
|
|
|
|
if reranker is not None: |
|
with suppress_output(): |
|
docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question) |
|
docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question) |
|
docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question) |
|
if rerank_by_question: |
|
docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True) |
|
docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True) |
|
docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True) |
|
else: |
|
docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"] |
|
|
|
for doc in docs_question: |
|
doc.metadata["reranking_score"] = doc.metadata["similarity_score"] |
|
|
|
docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked |
|
docs_question = docs_question[:k_by_question] |
|
images_question = docs_question_images_reranked[:k_images] |
|
|
|
if reranker is not None and rerank_by_question: |
|
docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True) |
|
|
|
|
|
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index) |
|
images_question = _add_sources_used_in_metadata(images_question,sources,question,index) |
|
|
|
|
|
docs.extend(docs_question) |
|
related_content.extend(images_question) |
|
new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions} |
|
return new_state |
|
|
|
|
|
|
|


def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
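    """Build the retriever node: a LangChain runnable that wraps retrieve_documents with the given settings."""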
    @chain
    async def retrieve_docs(state, config):
        state = await retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question, k_final, k_before_reranking, k_summary)
        return state

    return retrieve_docs
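

# Example usage (illustrative sketch; `vectorstore`, `reranker` and `llm` are assumed
# to be built elsewhere in the application):
#
#   retriever_node = make_retriever_node(vectorstore, reranker, llm)
#   new_state = await retriever_node.ainvoke(state, config)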