# https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd
from typing import List
from pydantic import Field

from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun


class ClimateQARetriever(BaseRetriever):
    """Retriever that prioritizes passages from the IPCC/IPBES summaries (SPM)
    and completes the results with passages from the full reports."""

    vectorstore: VectorStore
    sources: list = ["IPCC", "IPBES"]
    reports: list = []
    threshold: float = 0.6
    k_summary: int = 3
    k_total: int = 10
    namespace: str = "vectors"

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:

        # Check that all requested sources are either IPCC or IPBES
        assert isinstance(self.sources, list)
        assert all(x in ["IPCC", "IPBES"] for x in self.sources)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Prepare base search filters: restrict to specific reports if any are
        # provided, otherwise filter by source (IPCC / IPBES)
        filters = {}
        if len(self.reports) > 0:
            filters["short_name"] = {"$in": self.reports}
        else:
            filters["source"] = {"$in": self.sources}

        # Search for k_summary documents in the summaries (SPM) dataset
        filters_summaries = {
            **filters,
            "report_type": {"$in": ["SPM"]},
        }
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_summaries, k=self.k_summary
        )
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Search for the remaining documents in the full reports dataset
        filters_full = {
            **filters,
            "report_type": {"$nin": ["SPM"]},
        }
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_full, k=k_full
        )

        # Concatenate documents
        docs = docs_summaries + docs_full

        # Filter out results whose scores are below the threshold
        docs = [x for x in docs if x[1] > self.threshold]

        # Add score and page content to the metadata
        results = []
        for i, (doc, score) in enumerate(docs):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            results.append(doc)

        # Sort by score
        # results = sorted(results, key=lambda x: x.metadata["similarity_score"], reverse=True)

        return results
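

# Example usage (illustrative sketch only, not wired into the app): assumes an
# already-built `vectorstore` (e.g. a Pinecone or Chroma index) whose documents
# carry the metadata fields used above ("source", "short_name", "report_type",
# "page_number"). Variable names below are hypothetical.
#
# retriever = ClimateQARetriever(
#     vectorstore=vectorstore,
#     sources=["IPCC"],
#     reports=[],
#     threshold=0.6,
#     k_summary=3,
#     k_total=10,
# )
# docs = retriever.get_relevant_documents("How will sea level rise affect coastal cities?")
# for doc in docs:
#     print(doc.metadata["short_name"], doc.metadata["similarity_score"])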


# Legacy retrieval helpers from an earlier version of the pipeline, kept commented out for reference.

# def filter_summaries(df, k_summary=3, k_total=10):
#     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"
#     # # Filter by source
#     # if source == "IPCC":
#     #     df = df.loc[df["source"]=="IPCC"]
#     # elif source == "IPBES":
#     #     df = df.loc[df["source"]=="IPBES"]
#     # else:
#     #     pass

#     # Separate summaries and full reports
#     df_summaries = df.loc[df["report_type"].isin(["SPM", "TS"])]
#     df_full = df.loc[~df["report_type"].isin(["SPM", "TS"])]

#     # Find passages from summaries dataset
#     passages_summaries = df_summaries.head(k_summary)

#     # Find passages from full reports dataset
#     passages_fullreports = df_full.head(k_total - len(passages_summaries))

#     # Concatenate passages
#     passages = pd.concat([passages_summaries, passages_fullreports], axis=0, ignore_index=True)
#     return passages


# def retrieve_with_summaries(query, retriever, k_summary=3, k_total=10, sources=["IPCC", "IPBES"], max_k=100, threshold=0.555, as_dict=True, min_length=300):
#     assert max_k > k_total

#     validated_sources = ["IPCC", "IPBES"]
#     sources = [x for x in sources if x in validated_sources]
#     filters = {
#         "source": {"$in": sources},
#     }
#     print(filters)

#     # Retrieve documents
#     docs = retriever.retrieve(query, top_k=max_k, filters=filters)

#     # Filter by score
#     docs = [{**x.meta, "score": x.score, "content": x.content} for x in docs if x.score > threshold]
#     if len(docs) == 0:
#         return []

#     res = pd.DataFrame(docs)
#     passages_df = filter_summaries(res, k_summary, k_total)

#     if as_dict:
#         contents = passages_df["content"].tolist()
#         meta = passages_df.drop(columns=["content"]).to_dict(orient="records")
#         passages = []
#         for i in range(len(contents)):
#             passages.append({"content": contents[i], "meta": meta[i]})
#         return passages
#     else:
#         return passages_df


# def retrieve(query, sources=["IPCC"], threshold=0.555, k=10):
#     # Reformulate queries
#     reformulated_query, language = reformulate(query)
#     print(reformulated_query)

#     # Retrieve documents
#     passages = retrieve_with_summaries(reformulated_query, retriever, k_total=k, k_summary=3, as_dict=True, sources=sources, threshold=threshold)
#     response = {
#         "query": query,
#         "reformulated_query": reformulated_query,
#         "language": language,
#         "sources": passages,
#         "prompts": {"init_prompt": init_prompt, "sources_prompt": sources_prompt},
#     }
#     return response