timeki's picture
Add content recommandation (#17)
bcc8503 verified
raw
history blame
3.02 kB
from typing import List, Optional

from langchain_core.retrievers import BaseRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
# class GraphRetriever(BaseRetriever):
# vectorstore:VectorStore
# sources:list = ["OWID"] # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
# threshold:float = 0.5
# k_total:int = 10
# def _get_relevant_documents(
# self, query: str, *, run_manager: CallbackManagerForRetrieverRun
# ) -> List[Document]:
# # Check if all elements in the list are IEA or OWID
# assert isinstance(self.sources,list)
# assert self.sources
# assert any([x in ["OWID"] for x in self.sources])
# # Prepare base search kwargs
# filters = {}
# filters["source"] = {"$in": self.sources}
# docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
# # Filter if scores are below threshold
# docs = [x for x in docs if x[1] > self.threshold]
# # Remove duplicate documents
# unique_docs = []
# seen_docs = []
# for i, doc in enumerate(docs):
# if doc[0].page_content not in seen_docs:
# unique_docs.append(doc)
# seen_docs.append(doc[0].page_content)
# # Add score to metadata
# results = []
# for i,(doc,score) in enumerate(unique_docs):
# doc.metadata["similarity_score"] = score
# doc.metadata["content"] = doc.page_content
# results.append(doc)
# return results
async def retrieve_graphs(
    query: str,
    vectorstore: VectorStore,
    # TODO: later add OurWorldInData; this will need to be integrated with the other retriever.
    sources: Optional[List[str]] = None,
    threshold: float = 0.5,
    k_total: int = 10,
) -> List[Document]:
    """Retrieve graph documents relevant to ``query`` from ``vectorstore``.

    Runs ``vectorstore.similarity_search_with_score`` restricted to documents
    whose ``source`` metadata is one of ``sources``. It then drops hits whose
    score is not strictly greater than ``threshold``. Next it removes
    duplicate hits (same ``page_content``), keeping the first occurrence.
    Finally it copies the score and page content into each surviving
    document's metadata.

    Args:
        query: Text to search for.
        vectorstore: Store queried via ``similarity_search_with_score``.
        sources: Source names used in the ``{"source": {"$in": sources}}``
            filter. Defaults to ``["OWID"]``. Must be a non-empty list that
            contains ``"OWID"``.
        threshold: Hits with ``score <= threshold`` are discarded. This
            assumes a higher score means a better match. NOTE(review): confirm
            this for the configured vectorstore, since some stores return
            distances instead of similarities.
        k_total: Number of hits requested from the vectorstore, before
            threshold filtering and de-duplication.

    Returns:
        The surviving documents, in the order the vectorstore returned them.
        Each has ``metadata["similarity_score"]`` and ``metadata["content"]``
        set. The Document objects returned by the vectorstore are mutated in
        place.

    Raises:
        TypeError: If ``sources`` is not a list.
        ValueError: If ``sources`` is empty or does not contain ``"OWID"``.
    """
    # None sentinel instead of a shared mutable list default.
    if sources is None:
        sources = ["OWID"]

    # Explicit raises rather than `assert`, which is stripped under `python -O`.
    if not isinstance(sources, list):
        raise TypeError(f"sources must be a list, got {type(sources).__name__}")
    if not sources:
        raise ValueError("sources must not be empty")
    # Only OWID graphs are supported for now: at least one requested source must be OWID.
    if "OWID" not in sources:
        raise ValueError(f"sources must include 'OWID', got {sources!r}")

    filters = {"source": {"$in": sources}}
    # NOTE(review): this is a synchronous call inside an async function, so it
    # blocks the event loop while it runs. Consider
    # `await vectorstore.asimilarity_search_with_score(...)` if the configured
    # store has a proper async implementation.
    docs_and_scores = vectorstore.similarity_search_with_score(
        query=query, filter=filters, k=k_total
    )

    results: List[Document] = []
    seen_contents = set()  # set membership keeps de-duplication O(n) overall
    for doc, score in docs_and_scores:
        # Written as `not >` (not `<=`) so NaN scores are dropped, as before.
        if not score > threshold:
            continue
        # De-duplicate on page content, keeping the first (best-ranked) hit.
        if doc.page_content in seen_contents:
            continue
        seen_contents.add(doc.page_content)
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        results.append(doc)
    return results