Arthur-75 committed on
Commit
b2ba772
1 Parent(s): 37603bc

Update utils_st.py

Browse files
Files changed (1) hide show
  1. utils_st.py +22 -18
utils_st.py CHANGED
@@ -30,18 +30,11 @@ class SentenceTransformerRerank():
30
  self.top_n=top_n
31
 
32
 
33
- def predict(
34
- self,
35
- nodes,
36
- query_bundle = None,
37
  ) :
38
-
39
  query_and_nodes = [
40
- (
41
- query_bundle,
42
- node.text,
43
- )
44
- for node in nodes
45
  ]
46
  def predict_score(pair):
47
  return self.model.predict([pair])[0]
@@ -68,18 +61,19 @@ class SentenceTransformerRerank():
68
  return new_nodes
69
 
70
  @st.cache_resource(show_spinner=False)
71
- def load_data(_embed_model):
72
  with open('nodes_clean.pkl', 'rb') as file:
 
73
  nodes=pickle.load( file)
74
  d = 768
75
  faiss_index = faiss.IndexFlatL2(d)
76
  vector_store = FaissVectorStore(faiss_index=faiss_index )
77
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
78
  # use later nodes_clean
79
- index = VectorStoreIndex(nodes,embed_model=_embed_model,storage_context=storage_context)
80
- retriever_dense = index.as_retriever(similarity_top_k=20,embedding=True)
81
- retrieverBM25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=5)
82
- hybrid_retriever = HybridRetriever(retriever_dense, retrieverBM25)
83
  return hybrid_retriever
84
 
85
 
@@ -93,9 +87,10 @@ def load_models():
93
 
94
 
95
  class HybridRetriever(BaseRetriever):
96
- def __init__(self, vector_retriever, bm25_retriever):
97
  self.vector_retriever = vector_retriever
98
  self.bm25_retriever = bm25_retriever
 
99
  super().__init__()
100
  def _retrieve(self, query, **kwargs):
101
  with ThreadPoolExecutor() as executor:
@@ -105,13 +100,22 @@ class HybridRetriever(BaseRetriever):
105
  bm25_nodes = bm25_future.result()
106
  vector_nodes = vector_future.result()
107
  # combine the two lists of nodes
 
 
 
 
108
  all_nodes = []
109
  node_ids = set()
110
- for n in bm25_nodes + vector_nodes:
111
  if n.node.node_id not in node_ids:
112
  all_nodes.append(n)
113
  node_ids.add(n.node.node_id)
114
- return all_nodes
 
 
 
 
 
115
 
116
 
117
  import re
 
30
  self.top_n=top_n
31
 
32
 
33
+ def predict(self,nodes,query = None,
 
 
 
34
  ) :
 
35
  query_and_nodes = [
36
+ (str(query),str(nodes[i].text))
37
+ for i in range(len(nodes))
 
 
 
38
  ]
39
  def predict_score(pair):
40
  return self.model.predict([pair])[0]
 
61
  return new_nodes
62
 
63
  @st.cache_resource(show_spinner=False)
64
+ def load_data():
65
  with open('nodes_clean.pkl', 'rb') as file:
66
+ embed_model, reranker=load_models()
67
  nodes=pickle.load( file)
68
  d = 768
69
  faiss_index = faiss.IndexFlatL2(d)
70
  vector_store = FaissVectorStore(faiss_index=faiss_index )
71
  storage_context = StorageContext.from_defaults(vector_store=vector_store)
72
  # use later nodes_clean
73
+ index = VectorStoreIndex(nodes,embed_model=embed_model,storage_context=storage_context)
74
+ retriever_dense = index.as_retriever(similarity_top_k=40,embedding=True)
75
+ retrieverBM25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=15)
76
+ hybrid_retriever = HybridRetriever(retriever_dense, retrieverBM25,reranker)
77
  return hybrid_retriever
78
 
79
 
 
87
 
88
 
89
  class HybridRetriever(BaseRetriever):
90
+ def __init__(self, vector_retriever, bm25_retriever,reranker):
91
  self.vector_retriever = vector_retriever
92
  self.bm25_retriever = bm25_retriever
93
+ self.reranker = reranker
94
  super().__init__()
95
  def _retrieve(self, query, **kwargs):
96
  with ThreadPoolExecutor() as executor:
 
100
  bm25_nodes = bm25_future.result()
101
  vector_nodes = vector_future.result()
102
  # combine the two lists of nodes
103
+ dense_n=20
104
+ bm25_n=5
105
+ combined_nodes = vector_nodes[dense_n:] + bm25_nodes[bm25_n:]
106
+
107
  all_nodes = []
108
  node_ids = set()
109
+ for n in bm25_nodes.copy()[:bm25_n] + vector_nodes[:dense_n]:
110
  if n.node.node_id not in node_ids:
111
  all_nodes.append(n)
112
  node_ids.add(n.node.node_id)
113
+ #reRank only best of retrieved_nodes
114
+ reranked_nodes = self.reranker.predict(
115
+ all_nodes,query
116
+ )
117
+
118
+ return reranked_nodes+combined_nodes
119
 
120
 
121
  import re