import sys
import os
from contextlib import contextmanager
from typing import List

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from langchain_core.vectorstores import VectorStore
from langchain_core.documents.base import Document

from ..reranker import rerank_docs
# from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event

def divide_into_parts(target, parts):
    """Split `target` into `parts` integers that differ by at most 1 and sum to `target`."""
    base = target // parts  # base value for each part
    remainder = target % parts  # remainder to distribute
    result = []
    for i in range(parts):
        if i < remainder:
            # The first `remainder` parts get the base value + 1
            result.append(base + 1)
        else:
            # The rest get the base value
            result.append(base)
    return result

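# Illustrative: divide_into_parts(10, 3) returns [4, 3, 3]; the remainder of 1
# goes to the first part so the values still sum to the target.
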
@contextmanager
def suppress_output():
    """Temporarily silence stdout and stderr (e.g. around noisy third-party calls)."""
    with open(os.devnull, "w") as devnull:
        # Store the original stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        # Redirect stdout and stderr to the null device
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            # Restore stdout and stderr
            sys.stdout = old_stdout
            sys.stderr = old_stderr

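# Illustrative: wrap a chatty call so its prints do not pollute the logs
# (output is restored even if the wrapped code raises). `noisy_call` is a
# hypothetical stand-in:
#
#     with suppress_output():
#         noisy_call()
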
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question

def _add_sources_used_in_metadata(docs, sources, question, index):
    """Tag each document with the sources, question and index used to retrieve it."""
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs

def _get_k_summary_by_question(n_questions):
    """Number of summary chunks to retrieve per question; the budget shrinks as questions multiply."""
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


def _get_k_images_by_question(n_questions):
    """Number of image chunks to retrieve per question (same heuristic as for summaries)."""
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1

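# Illustrative: for a query decomposed into 2 sub-questions, each one gets a
# budget of 3 summary chunks and 3 images (_get_k_summary_by_question(2) == 3),
# keeping the total context size roughly constant as the question count grows.
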
def _add_metadata_and_score(docs: List) -> List[Document]:
    """Unpack (Document, score) tuples, storing the score and cleaned content in the metadata."""
    docs_with_metadata = []
    for i, (doc, score) in enumerate(docs):
        doc.page_content = doc.page_content.replace("\r\n", " ")
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
        # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
        docs_with_metadata.append(doc)
    return docs_with_metadata

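# Illustrative: an input item like (Document(page_content="...", metadata={"page_number": "3"}), 0.82)
# comes back as a plain Document carrying similarity_score=0.82 and a 1-indexed page number (4).
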
async def get_IPCC_relevant_documents(
    query: str,
    vectorstore: VectorStore,
    sources: list = ["IPCC", "IPBES", "IPOS"],
    search_figures: bool = False,
    reports: list = [],
    threshold: float = 0.6,
    k_summary: int = 3,
    k_total: int = 10,
    k_images: int = 5,
    namespace: str = "vectors",
    min_size: int = 200,
):
    # Check that the requested sources are supported (IPCC, IPBES or IPOS)
    assert isinstance(sources, list)
    assert sources
    assert all([x in ["IPCC", "IPBES", "IPOS"] for x in sources])
    assert k_total > k_summary, "k_total should be greater than k_summary"

    # Prepare base search kwargs
    filters = {}
    if len(reports) > 0:
        filters["short_name"] = {"$in": reports}
    else:
        filters["source"] = {"$in": sources}

    # INIT
    docs_summaries = []
    docs_full = []
    docs_images = []

    # Search for k_summary documents in the summaries dataset
    filters_summaries = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$in": ["SPM"]},
    }
    docs_summaries = vectorstore.similarity_search_with_score(query=query, filter=filters_summaries, k=k_summary)
    docs_summaries = [x for x in docs_summaries if x[1] > threshold]

    # Search for k_total - k_summary documents in the full reports dataset
    filters_full = {
        **filters,
        "chunk_type": "text",
        "report_type": {"$nin": ["SPM"]},
    }
    k_full = k_total - len(docs_summaries)
    docs_full = vectorstore.similarity_search_with_score(query=query, filter=filters_full, k=k_full)

    if search_figures:
        # Images
        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = vectorstore.similarity_search_with_score(query=query, filter=filters_image, k=k_images)

    docs_summaries, docs_full, docs_images = (
        _add_metadata_and_score(docs_summaries),
        _add_metadata_and_score(docs_full),
        _add_metadata_and_score(docs_images),
    )

    # Filter out documents whose content is shorter than min_size characters
    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
    docs_full = [x for x in docs_full if len(x.page_content) > min_size]

    return {
        "docs_summaries": docs_summaries,
        "docs_full": docs_full,
        "docs_images": docs_images,
    }

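# Minimal usage sketch (assumes an already-populated `vectorstore`; the query
# string is made up for illustration):
#
#     docs_dict = await get_IPCC_relevant_documents(
#         query="What are the impacts of climate change on biodiversity?",
#         vectorstore=vectorstore,
#         sources=["IPCC"],
#         search_figures=True,
#     )
#     # docs_dict["docs_summaries"] holds SPM chunks scoring above the threshold,
#     # docs_dict["docs_full"] fills the rest of the k_total budget from full
#     # reports, and docs_dict["docs_images"] holds image chunks.
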
# The @chain decorator is not strictly necessary, but it propagates the langchain
# callbacks to the astream_events logger so intermediate results are displayed.
# @chain
async def retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
print("---- Retrieve documents ----")
# Get the documents from the state
if "documents" in state and state["documents"] is not None:
docs = state["documents"]
else:
docs = []
# Get the related_content from the state
if "related_content" in state and state["related_content"] is not None:
related_content = state["related_content"]
else:
related_content = []
search_figures = "IPCC figures" in state["relevant_content_sources"]
# Get the current question
current_question = state["remaining_questions"][0]
remaining_questions = state["remaining_questions"][1:]
k_by_question = k_final // state["n_questions"]
k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
k_images_by_question = _get_k_images_by_question(state["n_questions"])
sources = current_question["sources"]
question = current_question["question"]
index = current_question["index"]
print(f"Retrieve documents for question: {question}")
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
if index == "Vector": # always true for now
docs_question_dict = await get_IPCC_relevant_documents(
query = question,
vectorstore=vectorstore,
search_figures = search_figures,
sources = sources,
min_size = 200,
k_summary = k_summary_by_question,
k_total = k_before_reranking,
k_images = k_images_by_question,
threshold = 0.5,
)
# Rerank
if reranker is not None:
with suppress_output():
docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
if rerank_by_question:
docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
else:
docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
# Add a default reranking score
for doc in docs_question:
doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
docs_question = docs_question[:k_by_question]
images_question = docs_question_images_reranked[:k_images]
if reranker is not None and rerank_by_question:
docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
# Add sources used in the metadata
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
# Add to the list of docs
docs.extend(docs_question)
related_content.extend(images_question)
new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
return new_state
def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    @chain
    async def retrieve_docs(state, config):
        state = await retrieve_documents(state, config, vectorstore, reranker, llm, rerank_by_question, k_final, k_before_reranking, k_summary)
        return state

    return retrieve_docs

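# Minimal usage sketch (`vectorstore`, `reranker` and `llm` are assumed to be
# configured elsewhere; the graph wiring below is illustrative):
#
#     retriever_node = make_retriever_node(vectorstore, reranker, llm)
#     # e.g. as a node in a langgraph StateGraph:
#     # workflow.add_node("retrieve_documents", retriever_node)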