File size: 3,018 Bytes
6b43c86
 
 
 
 
 
 
40084ba
 
 
 
 
6b43c86
40084ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b43c86
 
40084ba
 
 
6b43c86
 
 
 
40084ba
6b43c86
40084ba
6b43c86
 
40084ba
6b43c86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

from typing import List

# class GraphRetriever(BaseRetriever):
#     vectorstore:VectorStore
#     sources:list = ["OWID"] # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
#     threshold:float = 0.5
#     k_total:int = 10

#     def _get_relevant_documents(
#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
#     ) -> List[Document]:

#         # Check if all elements in the list are IEA or OWID
#         assert isinstance(self.sources,list)
#         assert self.sources
#         assert any([x in ["OWID"] for x in self.sources])

#         # Prepare base search kwargs
#         filters = {}

#         filters["source"] = {"$in": self.sources}

#         docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
        
#         # Filter if scores are below threshold
#         docs = [x for x in docs if x[1] > self.threshold]

#         # Remove duplicate documents
#         unique_docs = []
#         seen_docs = []
#         for i, doc in enumerate(docs):
#             if doc[0].page_content not in seen_docs:
#                 unique_docs.append(doc)
#                 seen_docs.append(doc[0].page_content)

#         # Add score to metadata
#         results = []
#         for i,(doc,score) in enumerate(unique_docs):
#             doc.metadata["similarity_score"] = score
#             doc.metadata["content"] = doc.page_content
#             results.append(doc)

#         return results
    
async def retrieve_graphs(
    query: str,
    vectorstore:VectorStore,
    sources:list = ["OWID"], # plus tard ajouter OurWorldInData # faudra integrate avec l'autre retriever
    threshold:float = 0.5,
    k_total:int = 10,
)-> List[Document]:

        # Check if all elements in the list are IEA or OWID
        assert isinstance(sources,list)
        assert sources
        assert any([x in ["OWID"] for x in sources])

        # Prepare base search kwargs
        filters = {}

        filters["source"] = {"$in": sources}

        docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
        
        # Filter if scores are below threshold
        docs = [x for x in docs if x[1] > threshold]

        # Remove duplicate documents
        unique_docs = []
        seen_docs = []
        for i, doc in enumerate(docs):
            if doc[0].page_content not in seen_docs:
                unique_docs.append(doc)
                seen_docs.append(doc[0].page_content)

        # Add score to metadata
        results = []
        for i,(doc,score) in enumerate(unique_docs):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            results.append(doc)

        return results