|
|
|
|
|
from typing import Dict, List

import pandas as pd
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStoreRetriever
from pydantic import Field
|
|
|
def _add_metadata_and_score(docs: List) -> Document: |
|
|
|
docs_with_metadata = [] |
|
for i,(doc,score) in enumerate(docs): |
|
doc.page_content = doc.page_content.replace("\r\n"," ") |
|
doc.metadata["similarity_score"] = score |
|
doc.metadata["content"] = doc.page_content |
|
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1 |
|
|
|
docs_with_metadata.append(doc) |
|
return docs_with_metadata |
|
|
|
class ClimateQARetriever(BaseRetriever):
    """Retriever returning summary, full-report and image chunks for a query.

    A query is answered with three filtered similarity searches against
    ``vectorstore``:

    * summary text chunks (``report_type`` in ``["SPM"]``): at most
      ``k_summary`` hits, keeping only those whose score is strictly greater
      than ``threshold``;
    * full-report text chunks (``report_type`` not in ``["SPM"]``): as many as
      the ``k_total`` budget leaves after the kept summaries;
    * image chunks: requested with the same ``k`` as the full-report search.

    All three searches are restricted either to the report short names in
    ``reports`` (when non-empty) or to the organisations in ``sources``.
    """

    vectorstore: VectorStore
    # Organisations searched when ``reports`` is empty; must be a non-empty
    # subset of ("IPCC", "IPBES", "IPOS").
    sources: list = Field(default_factory=lambda: ["IPCC", "IPBES", "IPOS"])
    # Report short names; when non-empty this filter replaces ``sources``.
    reports: list = Field(default_factory=list)
    # Summary hits must score strictly above this value to be kept.
    threshold: float = 0.6
    # Number of summary (SPM) chunks requested from the vector store.
    k_summary: int = 3
    # Text-chunk budget shared by summaries and full-report chunks.
    k_total: int = 10
    # NOTE(review): not read by this class. The previous trailing comma made
    # the default the tuple ("vectors",) instead of a str.
    namespace: str = "vectors"
    # Text chunks must have strictly more characters than this to be
    # returned. The previous trailing comma made the default the tuple (200,),
    # which crashed the ``len(...) > min_size`` comparison.
    min_size: int = 200

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> Dict[str, List[Document]]:
        """Run the summary, full-report and image searches for ``query``.

        Args:
            query: Natural-language search query.
            run_manager: Callback manager passed in by ``BaseRetriever``;
                not used here.

        Returns:
            A dict with keys ``"docs_summaries"``, ``"docs_full"`` and
            ``"docs_images"``. Each value is a list of documents annotated by
            ``_add_metadata_and_score``. Note this is a dict, not a flat list
            of documents.

        Raises:
            AssertionError: if ``sources`` is not a non-empty list of known
                organisations, or if ``k_total <= k_summary``.
        """
        allowed_sources = ("IPCC", "IPBES", "IPOS")
        assert isinstance(self.sources, list)
        assert self.sources
        assert all(source in allowed_sources for source in self.sources)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # An explicit report selection takes precedence over the source filter.
        if self.reports:
            filters = {"short_name": {"$in": self.reports}}
        else:
            filters = {"source": {"$in": self.sources}}

        # Summary-for-policymakers chunks, thresholded on score.
        # NOTE(review): assumes a higher score means "more similar" for the
        # configured vector store — confirm.
        filters_summaries = {
            **filters,
            "chunk_type": "text",
            "report_type": {"$in": ["SPM"]},
        }
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_summaries, k=self.k_summary
        )
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Full-report chunks fill the rest of the k_total budget. The
        # assertion above guarantees k_full >= 1.
        filters_full = {
            **filters,
            "chunk_type": "text",
            "report_type": {"$nin": ["SPM"]},
        }
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_full, k=k_full
        )

        filters_image = {
            **filters,
            "chunk_type": "image",
        }
        docs_images = self.vectorstore.similarity_search_with_score(
            query=query, filter=filters_image, k=k_full
        )

        docs_summaries = _add_metadata_and_score(docs_summaries)
        docs_full = _add_metadata_and_score(docs_full)
        docs_images = _add_metadata_and_score(docs_images)

        # The length filter runs after CRLF normalisation and applies only to
        # text chunks; image chunks are returned unfiltered.
        docs_summaries = [d for d in docs_summaries if len(d.page_content) > self.min_size]
        docs_full = [d for d in docs_full if len(d.page_content) > self.min_size]

        return {
            "docs_summaries": docs_summaries,
            "docs_full": docs_full,
            "docs_images": docs_images,
        }
|
|
|
|
|
|