# Commit bcc8503 by timeki (verified): Add content recommendation (#17).
# NOTE: the retriever code below is entirely commented out and is not executed.
# # https://github.com/langchain-ai/langchain/issues/8623
# import pandas as pd
# from langchain_core.retrievers import BaseRetriever
# from langchain_core.vectorstores import VectorStoreRetriever
# from langchain_core.documents.base import Document
# from langchain_core.vectorstores import VectorStore
# from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
# from typing import List
# from pydantic import Field
# NOTE(review): the helper below is commented out (not executed).
# Purpose: take the (Document, score) pairs returned by
# vectorstore.similarity_search_with_score, mutate each Document in place and
# return them as a plain list of Documents. For every pair it:
#   - replaces "\r\n" in page_content with a single space;
#   - stores the score in metadata["similarity_score"];
#   - copies the cleaned text into metadata["content"];
#   - casts metadata["page_number"] to int and adds 1
#     (presumably converting a 0-based page index to 1-based; TODO confirm
#     against the indexing code).
# NOTE(review) before re-enabling:
#   - the "-> Document" return annotation is wrong: a list is returned, so it
#     should read "-> List[Document]";
#   - Documents are mutated in place, so running this twice on the same
#     objects would add 1 to page_number twice;
#   - the loop index `i` is only used by the commented-out page_content line.
#
# def _add_metadata_and_score(docs: List) -> Document:
#     # Add score to metadata
#     docs_with_metadata = []
#     for i,(doc,score) in enumerate(docs):
#         doc.page_content = doc.page_content.replace("\r\n"," ")
#         doc.metadata["similarity_score"] = score
#         doc.metadata["content"] = doc.page_content
#         doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
#         # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
#         docs_with_metadata.append(doc)
#     return docs_with_metadata
# NOTE(review): ClimateQARetriever is commented out (not executed).
# Purpose: a LangChain BaseRetriever that runs three filtered similarity
# searches against `vectorstore` and returns the results grouped in a dict:
#   - "docs_summaries": up to k_summary text chunks with report_type "SPM",
#     keeping only hits whose score is > threshold, then dropping those whose
#     text length is <= min_size;
#   - "docs_full": text chunks whose report_type is not "SPM"; the search asks
#     for k_total minus the number of summaries kept by the threshold filter,
#     then drops results whose text length is <= min_size;
#   - "docs_images": image chunks, searched with the same k as "docs_full";
#     no threshold or length filtering is applied to them.
# Every search is restricted by short_name in `reports` when `reports` is
# non-empty, otherwise by source in `sources`.
# NOTE(review) before re-enabling:
#   - The trailing commas on `namespace` and `min_size` make their defaults the
#     tuples ("vectors",) and (200,). Pydantic does not validate defaults
#     unless configured to (verify for this model), so
#     `len(x.page_content) > self.min_size` would compare an int with a tuple
#     and raise TypeError. Remove the commas.
#   - `namespace` is never read by _get_relevant_documents.
#   - The return annotation is List[Document], but a dict is returned; callers
#     using the standard retriever interface presumably expect a list.
#   - Input checks use `assert`, which is stripped under `python -O`.
#   - `x[1] > self.threshold` assumes a higher score means more relevant; that
#     depends on the vector store's score semantics (similarity vs. distance).
#     Confirm for the store in use.
#
# class ClimateQARetriever(BaseRetriever):
#     vectorstore:VectorStore
#     sources:list = ["IPCC","IPBES","IPOS"]
#     reports:list = []
#     threshold:float = 0.6
#     k_summary:int = 3
#     k_total:int = 10
#     namespace:str = "vectors",
#     min_size:int = 200,
#
#     def _get_relevant_documents(
#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
#     ) -> List[Document]:
#         # Check that every element of `sources` is one of IPCC, IPBES or IPOS
#         assert isinstance(self.sources,list)
#         assert self.sources
#         assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
#         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
#
#         # Prepare base search kwargs: filter by report short_name if reports
#         # were given, otherwise by source
#         filters = {}
#         if len(self.reports) > 0:
#             filters["short_name"] = {"$in":self.reports}
#         else:
#             filters["source"] = { "$in":self.sources}
#
#         # Search for k_summary documents in the summaries (SPM) dataset
#         filters_summaries = {
#             **filters,
#             "chunk_type":"text",
#             "report_type": { "$in":["SPM"]},
#         }
#         docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
#         docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
#         # docs_summaries = []
#
#         # Search the full (non-SPM) reports for k_total minus the number of
#         # summaries kept above
#         filters_full = {
#             **filters,
#             "chunk_type":"text",
#             "report_type": { "$nin":["SPM"]},
#         }
#         k_full = self.k_total - len(docs_summaries)
#         docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
#
#         # Images (searched with the same k as the full reports)
#         filters_image = {
#             **filters,
#             "chunk_type":"image"
#         }
#         docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
#         # docs_images = []
#
#         # Concatenate documents
#         # docs = docs_summaries + docs_full + docs_images
#         # Filter if scores are below threshold
#         # docs = [x for x in docs if x[1] > self.threshold]
#
#         docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
#
#         # Drop text chunks whose length is not above min_size (images are kept as-is)
#         docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
#         docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
#
#         return {
#             "docs_summaries" : docs_summaries,
#             "docs_full" : docs_full,
#             "docs_images" : docs_images,
#         }