nreimers committed
Commit 9bdbf4c · Parent: 6fb179e
README.md ADDED
@@ -0,0 +1,14 @@
+ # cross_encoder-msmarco-word2vec256k
+
+ This CrossEncoder was trained with MarginMSE loss, starting from the [nicoladecao/msmarco-word2vec256000-distilbert-base-uncased](https://hf.co/nicoladecao/msmarco-word2vec256000-distilbert-base-uncased) checkpoint. **The word embedding matrix was frozen during training.**
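+
+ With MarginMSE, the model is not fit to absolute relevance labels; instead, the difference between its scores for a (query, positive passage) pair and a (query, negative passage) pair is regressed onto the score margin of a teacher model (see the included `train_script.py`). A minimal sketch of the objective, with made-up score values:
+ ```python
+ import torch
+
+ loss_fct = torch.nn.MSELoss()
+ pred_pos = torch.tensor([8.1, 6.3])        # student scores for (query, positive) pairs
+ pred_neg = torch.tensor([2.0, 5.9])        # student scores for (query, negative) pairs
+ teacher_margin = torch.tensor([5.5, 0.2])  # teacher_pos_score - teacher_neg_score
+ loss = loss_fct(pred_pos - pred_neg, teacher_margin)
+ ```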
+
+ You can load the model with [sentence-transformers](https://sbert.net):
+ ```python
+ from sentence_transformers import CrossEncoder
+ from torch import nn
+
+ model_name = "cross_encoder-msmarco-word2vec256k"  # replace with this model's full Hugging Face repository id
+ model = CrossEncoder(model_name, default_activation_function=nn.Identity())
+ ```
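+
+ A short usage sketch (the query/passage pairs below are made-up examples); `predict` returns one raw relevance score per pair, since the activation is `nn.Identity()`:
+ ```python
+ scores = model.predict([
+     ["how many people live in berlin", "Berlin has a population of roughly 3.7 million."],
+     ["how many people live in berlin", "The capital of France is Paris."],
+ ])
+ print(scores)  # higher score = better query-passage match
+ ```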
+
+ Performance on TREC Deep Learning (nDCG@10):
+ - TREC-DL 19: 72.49
+ - TREC-DL 20: 72.71
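+
+ Numbers like these are obtained by reranking a top-1000 candidate list and scoring the resulting run with `pytrec_eval` (this is what the bundled `train_script.py` does for TREC-DL 19). A minimal sketch of the metric computation, with toy qrels and run dictionaries:
+ ```python
+ import numpy as np
+ import pytrec_eval
+
+ qrels = {"q1": {"d1": 3, "d2": 0}}     # graded relevance judgments per query
+ run = {"q1": {"d1": 12.4, "d2": 3.1}}  # cross-encoder scores per candidate
+
+ evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10"})
+ per_query = evaluator.evaluate(run)
+ print(np.mean([q["ndcg_cut_10"] for q in per_query.values()]))
+ ```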
config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "_name_or_path": "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased",
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForSequenceClassification"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "LABEL_0"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "LABEL_0": 0
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.11.3",
+   "vocab_size": 256000
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7eabe1045290bb03411ddfd73fb87a43f997a64c3bffcefd9939824eebe6b7c1
+ size 960528535
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"model_max_length": 512, "unk_token": "[UNK]", "cls_token": "[CLS]", "sep_token": "[SEP]", "pad_token": "[PAD]", "mask_token": "[MASK]", "model_input_names": ["input_ids", "attention_mask"], "special_tokens_map_file": "/root/.cache/huggingface/transformers/fe09c361189d8238b9e387f10a088e93f70620bfe74b82036baff1fed512a153.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "name_or_path": "nicoladecao/msmarco-word2vec256000-distilbert-base-uncased", "tokenizer_class": "PreTrainedTokenizerFast"}
train_script.py ADDED
@@ -0,0 +1,233 @@
+ import gzip
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
+ import sys
+ import torch
+ import transformers
+ from torch.utils.data import Dataset, DataLoader
+ from torch.cuda.amp import autocast
+ import tqdm
+ from datetime import datetime
+ from shutil import copyfile
+ import os
+ ####################################
+
+ import gzip
+ from collections import defaultdict
+ import logging
+ import tqdm
+ import numpy as np
+ import sys
+ import pytrec_eval
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
+ import torch
+
+
+ ######### Evaluation
+ queries_filepath = '/home/msmarco/data/trec2019/msmarco-test2019-queries.tsv.gz'
+ queries_eval = {}
+ with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
+     for line in fIn:
+         qid, query = line.strip().split("\t")[0:2]
+         queries_eval[qid] = query
+
+ rel = defaultdict(lambda: defaultdict(int))
+
+ with open('/home/msmarco/data/trec2019/2019qrels-pass.txt') as fIn:
+     for line in fIn:
+         qid, _, pid, score = line.strip().split()
+         score = int(score)
+         if score > 0:
+             rel[qid][pid] = score
+
+ relevant_qid = []
+ for qid in queries_eval:
+     if len(rel[qid]) > 0:
+         relevant_qid.append(qid)
+
+ # Read top 1k
+ passage_cand = {}
+
+ with gzip.open('/home/msmarco/data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
+     for line in fIn:
+         qid, pid, query, passage = line.strip().split("\t")
+         if qid not in passage_cand:
+             passage_cand[qid] = []
+
+         passage_cand[qid].append([pid, passage])
+
+
+
+ def eval_modal(model_path):
+     run = {}
+     model = CrossEncoder(model_path, max_length=512)
+
+     for qid in relevant_qid:
+         query = queries_eval[qid]
+
+         cand = passage_cand[qid]
+         pids = [c[0] for c in cand]
+         corpus_sentences = [c[1] for c in cand]
+
+         ## CrossEncoder
+         cross_inp = [[query, sent] for sent in corpus_sentences]
+         if model.config.num_labels > 1:
+             cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
+         else:
+             cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()
+
+         cross_scores_sparse = {}
+         for idx, pid in enumerate(pids):
+             cross_scores_sparse[pid] = cross_scores[idx]
+
+         sparse_scores = cross_scores_sparse
+         run[qid] = {}
+         for pid in sparse_scores:
+             run[qid][pid] = float(sparse_scores[pid])
+
+     evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
+     scores = evaluator.evaluate(run)
+     scores_mean = np.mean([ele["ndcg_cut_10"] for ele in scores.values()])
+
+     print("NDCG@10: {:.2f}".format(scores_mean * 100))
+     return scores_mean
+
+ ################################
+
+ model_name = sys.argv[1]
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ config = AutoConfig.from_pretrained(model_name)
+ config.num_labels = 1
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ ## Freeze embedding layer
+ model.distilbert.embeddings.word_embeddings.requires_grad_(False)
+
+
+
+
+
+ #######################
+
+ queries = {}
+ corpus = {}
+
+ output_save_path = 'output-ce-emb_frozen/{}-{}'.format(model_name.replace("/", "-"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+ output_save_path_latest = output_save_path+"-latest"
+ tokenizer.save_pretrained(output_save_path)
+ tokenizer.save_pretrained(output_save_path_latest)
+
+
+ # Write self to path
+ train_script_path = os.path.join(output_save_path, 'train_script.py')
+ copyfile(__file__, train_script_path)
+ with open(train_script_path, 'a') as fOut:
+     fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+
+ ####
+ train_script_path = os.path.join(output_save_path_latest, 'train_script.py')
+ copyfile(__file__, train_script_path)
+ with open(train_script_path, 'a') as fOut:
+     fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+
+
+ #### Read train file
+ with gzip.open('/home/msmarco/data/collection.tsv.gz', 'rt') as fIn:
+     for line in fIn:
+         pid, passage = line.strip().split("\t")
+         corpus[pid] = passage
+
+ with open('/home/msmarco/data/queries.train.tsv', 'r') as fIn:
+     for line in fIn:
+         qid, query = line.strip().split("\t")
+         queries[qid] = query
+
+
+ ############## Train Dataset
+ class MSEDataset(Dataset):
+     def __init__(self, filepath):
+         super().__init__()
+
+         self.examples = []
+         with open(filepath) as fIn:
+             for line in fIn:
+                 # Each line: pos_score \t neg_score \t qid \t pos_pid \t neg_pid (teacher scores)
+                 pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
+                 self.examples.append([qid, pid1, pid2, float(pos_score)-float(neg_score)])
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, item):
+         return self.examples[item]
+
+ train_batch_size = 32  # note: unused below; the DataLoader is created with a hardcoded batch_size of 16
+ train_dataset = MSEDataset('/home/msmarco/data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
+ train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=16)
+
+
+ ############## Optimizer
+
+ weight_decay = 0.01
+ max_grad_norm = 1
+ param_optimizer = list(model.named_parameters())
+
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+ optimizer_grouped_parameters = [
+     {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
+     {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+ ]
+
+ optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
+ scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
+ scaler = torch.cuda.amp.GradScaler()
+
+ loss_fct = torch.nn.MSELoss()
+ ### Start training
+ model.to(device)
+
+ auto_save = 10000
+ best_ndcg_score = 0
+ for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
+     batch_queries = [queries[qid] for qid in batch[0]]
+     pos = [corpus[cid] for cid in batch[1]]
+     neg = [corpus[cid] for cid in batch[2]]
+     scores = batch[3].float().to(device)  #torch.tensor(batch[3], dtype=torch.float, device=device)
+
+     with autocast():
+         inp_pos = tokenizer(batch_queries, pos, max_length=512, padding=True, truncation='longest_first', return_tensors='pt').to(device)
+         pred_pos = model(**inp_pos).logits.squeeze()
+
+         inp_neg = tokenizer(batch_queries, neg, max_length=512, padding=True, truncation='longest_first', return_tensors='pt').to(device)
+         pred_neg = model(**inp_neg).logits.squeeze()
+
+         # MarginMSE: regress the student's score margin (pos - neg) onto the teacher's score margin
+         pred_diff = pred_pos - pred_neg
+         loss_value = loss_fct(pred_diff, scores)
+
+
+     scaler.scale(loss_value).backward()
+     scaler.unscale_(optimizer)
+     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+     scaler.step(optimizer)
+     scaler.update()
+
+     optimizer.zero_grad()
+     scheduler.step()
+
+     if (step_idx+1) % auto_save == 0:
+         print("Step:", step_idx+1)
+         model.save_pretrained(output_save_path_latest)
+         ndcg_score = eval_modal(output_save_path_latest)
+
+         if ndcg_score >= best_ndcg_score:
+             best_ndcg_score = ndcg_score
+             print("Save to:", output_save_path)
+             model.save_pretrained(output_save_path)
+
+ model.save_pretrained(output_save_path)
+
+
+ # Script was called via:
+ #python train_ce_emb_frozen.py nicoladecao/msmarco-word2vec256000-distilbert-base-uncased