timeki's picture
Merge branch 'bugfix/add_dummy_searchs' into feature/graph_recommandation
aa904c1
raw
history blame
2.97 kB
# https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from typing import List
from pydantic import Field
class ClimateQARetriever(BaseRetriever):
    """Retriever mixing summary (SPM) chunks with full-report chunks.

    Two similarity searches are run against ``vectorstore``:

    * one restricted to ``report_type == "SPM"`` returning at most
      ``k_summary`` hits, of which only those scoring above ``threshold``
      are kept;
    * one over every other report type, filling the remaining
      ``k_total - len(kept summaries)`` slots (no score threshold applied).

    Chunks whose text is not longer than ``min_size`` characters are dropped.
    """

    vectorstore: VectorStore
    # Source collections to search; each entry must be "IPCC", "IPBES" or "IPOS".
    sources: list = ["IPCC", "IPBES", "IPOS"]
    # When non-empty, restrict the search to these report short names
    # instead of filtering by ``sources``.
    reports: list = []
    # A summary hit is kept only if its score is strictly greater than this.
    threshold: float = 0.6
    k_summary: int = 3
    k_total: int = 10
    # No trailing commas here: ``= "vectors",`` / ``= 200,`` would make the
    # defaults the tuples ("vectors",) / (200,), and comparing an int length
    # against the tuple min_size raised TypeError.
    namespace: str = "vectors"
    # Minimum page_content length (exclusive) for a chunk to be returned.
    min_size: int = 200

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k_total`` documents relevant to ``query``.

        Args:
            query: Search query passed to the vectorstore.
            run_manager: Callback manager supplied by ``BaseRetriever``;
                not used by this implementation.

        Returns:
            Summary documents first, then full-report documents, in retrieval
            order. Each document's metadata gains ``similarity_score``,
            ``content`` (the cleaned text) and a ``page_number`` incremented
            by one.

        Raises:
            AssertionError: if ``sources`` is not a non-empty list drawn from
                IPCC/IPBES/IPOS, or if ``k_total`` is not greater than
                ``k_summary``.
        """
        # Validate configuration: sources must be a non-empty list whose
        # entries are all among IPCC, IPBES and IPOS.
        assert isinstance(self.sources, list)
        assert self.sources
        assert all(x in ("IPCC", "IPBES", "IPOS") for x in self.sources)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Explicit report short names take precedence over the source filter.
        if self.reports:
            filters = {"short_name": {"$in": self.reports}}
        else:
            filters = {"source": {"$in": self.sources}}

        # Summary (SPM) search: keep hits scoring above threshold.
        # NOTE(review): assumes a higher score means more similar — confirm
        # for the backing vectorstore's score convention.
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query,
            filter={**filters, "report_type": {"$in": ["SPM"]}},
            k=self.k_summary,
        )
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Full-report search fills whatever slots the kept summaries left free
        # (always >= 1 because k_total > k_summary >= len(docs_summaries)).
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query,
            filter={**filters, "report_type": {"$nin": ["SPM"]}},
            k=self.k_total - len(docs_summaries),
        )

        # Drop chunks that are too short (a length filter, not a score filter);
        # the length is measured before "\r\n" normalisation below.
        docs = [
            x
            for x in docs_summaries + docs_full
            if len(x[0].page_content) > self.min_size
        ]

        results = []
        for doc, score in docs:
            # Mutates the Document objects returned by the vectorstore in place.
            doc.page_content = doc.page_content.replace("\r\n", " ")
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            # Shift the stored page number by one (presumably 0-based to
            # 1-based for display — verify against the indexing pipeline).
            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
            results.append(doc)
        # Results keep retrieval order (summaries first); no re-sort by score.
        return results