secilozksen commited on
Commit
30dce9f
·
1 Parent(s): 18665b8

files updated

Browse files
basecamp-dpr-contriever-embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:413837017c7b17e8e44556d9ab0cc9d42c9b24d3d28b29a39f3e7e143bd9f482
3
+ size 856086
basecamp.csv CHANGED
The diff for this file is too large to render. See raw diff
 
demo_dpr.py CHANGED
@@ -7,7 +7,7 @@ from sentence_transformers.cross_encoder import CrossEncoder
7
  from st_aggrid import GridOptionsBuilder, AgGrid
8
  import pickle
9
  import torch
10
- from transformers import DPRQuestionEncoderTokenizer, AutoModel
11
  from pathlib import Path
12
  import base64
13
  import io
@@ -20,13 +20,11 @@ DATAFRAME_FILE_BSBS = 'basecamp.csv'
20
  selectbox_selections = {
21
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
22
  'Dense Passage Retrieval':2,
23
- 'Retrieve - Reranking with DPR':3,
24
  'Retrieve - Rerank':4
25
  }
26
  imagebox_selections = {
27
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
28
  'Dense Passage Retrieval': 'DPR_pipeline.png',
29
- 'Retrieve - Reranking with DPR': 'Retrieve-rerank-DPR.png',
30
  'Retrieve - Rerank': 'retrieve-rerank.png'
31
  }
32
 
@@ -63,7 +61,7 @@ class CPU_Unpickler(pickle.Unpickler):
63
  @st.cache(show_spinner=False, allow_output_mutation=True)
64
  def load_paragraphs(path):
65
  with open(path, "rb") as fIn:
66
- cache_data = CPU_Unpickler(fIn).load()
67
  corpus_sentences = cache_data['contexes']
68
  corpus_embeddings = cache_data['embeddings']
69
 
@@ -84,45 +82,25 @@ def dot_product(question_output, context_output):
84
  result = torch.dot(mat1, mat2)
85
  return result
86
 
87
- def retrieve_rerank_DPR(question):
88
- hits = retrieve(question)
89
- return rerank_with_DPR(hits, question)
90
-
91
- def DPR_reranking(question, selected_contexes, selected_embeddings):
92
- scores = []
93
- tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
94
- add_special_tokens=True)
95
- question_output = dpr_trained.model.question_model(**tokenized_question)
96
- question_output = question_output['pooler_output']
97
- for context_embedding in selected_embeddings:
98
- score = dot_product(question_output, context_embedding)
99
- scores.append(score.detach().cpu())
100
-
101
- scores_index = sorted(range(len(scores)), key=lambda x: scores[x], reverse=True)
102
- contexes_list = []
103
- scores_final = []
104
- for i, idx in enumerate(scores_index[:5]):
105
- scores_final.append(scores[idx])
106
- contexes_list.append(selected_contexes[idx])
107
- return scores_final, contexes_list
108
-
109
  def search_pipeline(question, search_method):
110
  if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
111
  return retrieve_rerank_with_trained_cross_encoder(question)
112
  if search_method == 2:
113
  return custom_dpr_pipeline(question) # DPR only
114
- if search_method == 3:
115
- return retrieve_rerank_DPR(question)
116
  if search_method == 4:
117
  return retrieve_rerank(question)
118
 
 
 
 
 
119
 
120
  def custom_dpr_pipeline(question):
121
  #paragraphs
122
- tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
123
- add_special_tokens=True)
124
  question_embedding = dpr_trained.model.question_model(**tokenized_question)
125
- question_embedding = question_embedding['pooler_output']
 
126
  results_list = []
127
  for i,context_embedding in enumerate(dpr_context_embeddings):
128
  score = dot_product(question_embedding, context_embedding)
@@ -145,35 +123,13 @@ def retrieve(question):
145
  hits = hits[0]
146
  return hits
147
 
148
- def retrieve_with_dpr_embeddings(question):
149
- # Semantic Search (Retrieve)
150
- question_tokens = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt",
151
- add_special_tokens=True)
152
-
153
- question_embedding = dpr_trained.model.question_model(**question_tokens)['pooler_output']
154
- question_embedding = torch.squeeze(question_embedding, dim=0)
155
- corpus_embeddings = torch.stack(dpr_context_embeddings)
156
- corpus_embeddings = torch.squeeze(corpus_embeddings, dim=1)
157
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=100, score_function=util.dot_score)
158
- if len(hits) == 0:
159
- return []
160
- hits = hits[0]
161
- return hits, question_embedding
162
-
163
- def rerank_with_DPR(hits, question_embedding):
164
- # Rerank - score all retrieved passages with cross-encoder
165
- selected_contexes = [dpr_contexes[hit['corpus_id']] for hit in hits]
166
- selected_embeddings = [dpr_context_embeddings[hit['corpus_id']] for hit in hits]
167
- top_5_scores, top_5_contexes = DPR_reranking(question_embedding, selected_contexes, selected_embeddings)
168
- return top_5_contexes, top_5_scores
169
-
170
  def retrieve_rerank_with_trained_cross_encoder(question):
171
  hits = retrieve(question)
172
  cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
173
  cross_scores = trained_cross_encoder.predict(cross_inp)
174
  # Sort results by the cross-encoder scores
175
  for idx in range(len(cross_scores)):
176
- hits[idx]['cross-score'] = cross_scores[idx][1]
177
 
178
  # Output of top-5 hits from re-ranker
179
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
@@ -263,19 +219,22 @@ def qa_main_widgetsv2():
263
 
264
  @st.cache(show_spinner=False, allow_output_mutation = True)
265
  def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
 
266
  dpr_trained = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
267
  trust_remote_code=True)
268
  bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
269
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
270
  bi_encoder.max_seq_length = 500
271
  trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
272
- question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
273
  return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
274
 
275
  context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
276
- dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-context-embeddings.pkl')
277
  dataframe_bsbs = load_dataframes()
278
  dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
279
-
280
  qa_main_widgetsv2()
281
 
 
 
 
 
7
  from st_aggrid import GridOptionsBuilder, AgGrid
8
  import pickle
9
  import torch
10
+ from transformers import AutoTokenizer, AutoModel
11
  from pathlib import Path
12
  import base64
13
  import io
 
20
  selectbox_selections = {
21
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
22
  'Dense Passage Retrieval':2,
 
23
  'Retrieve - Rerank':4
24
  }
25
  imagebox_selections = {
26
  'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
27
  'Dense Passage Retrieval': 'DPR_pipeline.png',
 
28
  'Retrieve - Rerank': 'retrieve-rerank.png'
29
  }
30
 
 
61
  @st.cache(show_spinner=False, allow_output_mutation=True)
62
  def load_paragraphs(path):
63
  with open(path, "rb") as fIn:
64
+ cache_data = pickle.load(fIn)
65
  corpus_sentences = cache_data['contexes']
66
  corpus_embeddings = cache_data['embeddings']
67
 
 
82
  result = torch.dot(mat1, mat2)
83
  return result
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def search_pipeline(question, search_method):
86
  if search_method == 1: #Retrieve - rerank with fine-tuned cross encoder
87
  return retrieve_rerank_with_trained_cross_encoder(question)
88
  if search_method == 2:
89
  return custom_dpr_pipeline(question) # DPR only
 
 
90
  if search_method == 4:
91
  return retrieve_rerank(question)
92
 
93
+ def mean_pooling(token_embeddings, mask):
94
+ token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
95
+ sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
96
+ return sentence_embeddings
97
 
98
  def custom_dpr_pipeline(question):
99
  #paragraphs
100
+ tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
 
101
  question_embedding = dpr_trained.model.question_model(**tokenized_question)
102
+ question_embedding = mean_pooling(question_embedding[0], tokenized_question['attention_mask'])
103
+ # question_embedding = question_embedding['pooler_output']
104
  results_list = []
105
  for i,context_embedding in enumerate(dpr_context_embeddings):
106
  score = dot_product(question_embedding, context_embedding)
 
123
  hits = hits[0]
124
  return hits
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def retrieve_rerank_with_trained_cross_encoder(question):
127
  hits = retrieve(question)
128
  cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
129
  cross_scores = trained_cross_encoder.predict(cross_inp)
130
  # Sort results by the cross-encoder scores
131
  for idx in range(len(cross_scores)):
132
+ hits[idx]['cross-score'] = cross_scores[idx]
133
 
134
  # Output of top-5 hits from re-ranker
135
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
 
219
 
220
  @st.cache(show_spinner=False, allow_output_mutation = True)
221
  def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
222
+
223
  dpr_trained = AutoModel.from_pretrained(dpr_model_path, use_auth_token=auth_token,
224
  trust_remote_code=True)
225
  bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
226
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
227
  bi_encoder.max_seq_length = 500
228
  trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
229
+ question_tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
230
  return dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer
231
 
232
  context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
233
+ dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
234
  dataframe_bsbs = load_dataframes()
235
  dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder, question_tokenizer = copy.deepcopy(load_models(st.secrets["DPR_MODEL_PATH"], st.secrets["AUTH_TOKEN"], st.secrets["CROSS_ENCODER_MODEL_PATH"]))
 
236
  qa_main_widgetsv2()
237
 
238
+ #if __name__ == '__main__':
239
+ # top_5_contexes, top_5_scores = search_pipeline('What are the benefits of 37Signals Visa Card?', 1)
240
+
st-context-embeddings.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd65fe793062375df1efd50218e9a7c35253fe06a24e5527de7855671a4f958c
3
- size 468299
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79e231244e12074d5e22f46cf3da70f4f1dd43cc6e82f36959d2c6817f2e2bf2
3
+ size 441107