# NOTE(review): removed a paste artifact that preceded the code — git object
# hashes (d665c2c, c9b7c29, b2ba772, 40455cc) and a detached line-number
# gutter (1..137) from a copy of a diff/blame view. None of it was valid
# Python and it prevented the module from importing.
from sentence_transformers import SentenceTransformer
import streamlit as st
from sentence_transformers import CrossEncoder 
from transformers import AutoTokenizer, AutoModel
from concurrent.futures import ThreadPoolExecutor, as_completed
import pickle
import faiss
from llama_index.core import VectorStoreIndex,StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core import VectorStoreIndex
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.retrievers import BaseRetriever
from llama_index.vector_stores.chroma import ChromaVectorStore
#%pip install llama-index-vector-stores-chroma
#pip install chromadb
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore

@st.cache_resource(show_spinner=False)
class SentenceTransformerRerank():
    """Cross-encoder reranker: scores (query, node-text) pairs, keeps top_n.

    Cached as a Streamlit resource so the underlying CrossEncoder model is
    loaded only once per session.
    """

    def __init__(
        self,
        top_n,
        model,
        device = "cpu",
    ):
        # max_length=512 truncates long query+passage pairs to the model limit.
        self.model = CrossEncoder(
            model, max_length=512, device=device
        )
        self.top_n=top_n

    def predict(self,nodes,query = None,
    ) :
        """Score every node against `query`, set `node.score` in place, and
        return the `top_n` highest-scoring nodes (descending)."""
        query_and_nodes = [
            (str(query),str(nodes[i].text))
            for i in range(len(nodes))
        ]

        # Single batched call: CrossEncoder.predict batches internally, which
        # is faster than one thread per pair (the previous ThreadPoolExecutor
        # version serialized on the GIL/model forward pass) and cannot silently
        # drop a node's score on a per-pair exception.
        scores = self.model.predict(query_and_nodes)

        # Assign scores back to nodes
        for node, score in zip(nodes, scores):
            node.score = score

        # Sort descending; explicit `is not None` so a legitimate score of 0
        # (falsy) is not confused with "no score".
        new_nodes = sorted(
            nodes, key=lambda x: -x.score if x.score is not None else 0
        )[: self.top_n]
        return new_nodes

@st.cache_resource(show_spinner=False)
def load_data():
    """Build and cache the hybrid (dense FAISS + BM25 + reranker) retriever.

    Loads pre-processed nodes from `nodes_clean.pkl`, indexes them in an
    in-memory FAISS flat-L2 store, and wires a dense retriever and a BM25
    retriever behind a cross-encoder reranker.
    """
    # Only the deserialization needs the file handle; everything else used to
    # run inside this `with`, keeping the file open for the whole (slow)
    # index build. NOTE(review): pickle.load — only use trusted local files.
    with open('nodes_clean.pkl', 'rb') as file:
        nodes = pickle.load(file)

    embed_model, reranker = load_models()

    # d must match the embedding dimension of the model returned by
    # load_models() (768 for BAAI/llm-embedder) — keep the two in sync.
    d = 768
    faiss_index = faiss.IndexFlatL2(d)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex(nodes, embed_model=embed_model, storage_context=storage_context)
    retriever_dense = index.as_retriever(similarity_top_k=35, embedding=True)
    retrieverBM25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)
    hybrid_retriever = HybridRetriever(retriever_dense, retrieverBM25, reranker)
    return hybrid_retriever


@st.cache_resource(show_spinner=False)
def load_models():
    """Load (once per session, via Streamlit caching) the embedding model
    and the cross-encoder reranker, both on CPU."""
    embedding_name = "BAAI/llm-embedder"
    rerank_name = "BAAI/bge-reranker-base"
    embedder = HuggingFaceEmbedding(embedding_name, device='cpu')
    cross_reranker = SentenceTransformerRerank(
        top_n=25, model=rerank_name, device='cpu'
    )
    return embedder, cross_reranker


class HybridRetriever(BaseRetriever):
    """Hybrid retriever: dense (vector) + sparse (BM25) with head reranking.

    The top `dense_n` dense hits and top `bm25_n` BM25 hits are deduplicated
    and passed through the cross-encoder reranker; the remaining tail hits
    are appended after the reranked head, unreranked.
    """

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Run both retrievers concurrently; leaving the `with` block waits
        # for both submitted futures to finish.
        with ThreadPoolExecutor() as executor:
            bm25_future = executor.submit(self.bm25_retriever.retrieve, query, **kwargs)
            vector_future = executor.submit(self.vector_retriever.retrieve, query, **kwargs)

        bm25_nodes = bm25_future.result()
        vector_nodes = vector_future.result()

        # Head of each ranking goes to the reranker; the rest forms the tail.
        dense_n = 20
        bm25_n = 2

        # Deduplicate the head by node id (a chunk can rank in both lists).
        head_nodes = []
        node_ids = set()
        for n in bm25_nodes[:bm25_n] + vector_nodes[:dense_n]:
            if n.node.node_id not in node_ids:
                head_nodes.append(n)
                node_ids.add(n.node.node_id)

        # reRank only best of retrieved_nodes
        reranked_nodes = self.reranker.predict(head_nodes, query)

        # Fix: previously the tail was returned verbatim, so nodes already in
        # the reranked head (or present in both tails) appeared twice in the
        # result. Dedupe the tail against everything seen so far.
        tail_nodes = []
        for n in vector_nodes[dense_n:] + bm25_nodes[bm25_n:]:
            if n.node.node_id not in node_ids:
                tail_nodes.append(n)
                node_ids.add(n.node.node_id)

        return reranked_nodes + tail_nodes


import re
def clean_whitespace(text, k=5):
    """Normalize noisy document text for retrieval.

    Among the first `k` lines, keeps only lines with more than 25 characters
    after stripping (drops short header/boilerplate lines); all later lines
    are kept unconditionally. Then removes the literal ".EU" artifact,
    collapses every whitespace run to a single space, and lowercases.
    """
    lines = text.strip().split("\n")  # split once (was computed twice)
    kept_head = [line for line in lines[:k] if len(line.strip()) > 25]
    text = " ".join(kept_head + lines[k:])
    text = re.sub(r"\.EU", "", text)  # crawler/extraction artifact
    text = re.sub(r"\s+", " ", text)  # collapse newlines/tabs/multi-spaces
    return text.lower()