UlanYisaev committed on
Commit
6349cfb
1 Parent(s): d2b47b7

Upload 5 files

Browse files
Files changed (4) hide show
  1. gitattributes +17 -0
  2. special_tokens_map.json +1 -51
  3. tokenizer_config.json +1 -55
  4. train_script.py +253 -0
gitattributes ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
special_tokens_map.json CHANGED
@@ -1,51 +1 @@
1
- {
2
- "bos_token": {
3
- "content": "<s>",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "cls_token": {
10
- "content": "<s>",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- },
16
- "eos_token": {
17
- "content": "</s>",
18
- "lstrip": false,
19
- "normalized": false,
20
- "rstrip": false,
21
- "single_word": false
22
- },
23
- "mask_token": {
24
- "content": "<mask>",
25
- "lstrip": false,
26
- "normalized": false,
27
- "rstrip": false,
28
- "single_word": false
29
- },
30
- "pad_token": {
31
- "content": "<pad>",
32
- "lstrip": false,
33
- "normalized": false,
34
- "rstrip": false,
35
- "single_word": false
36
- },
37
- "sep_token": {
38
- "content": "</s>",
39
- "lstrip": false,
40
- "normalized": false,
41
- "rstrip": false,
42
- "single_word": false
43
- },
44
- "unk_token": {
45
- "content": "<unk>",
46
- "lstrip": false,
47
- "normalized": false,
48
- "rstrip": false,
49
- "single_word": false
50
- }
51
- }
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tokenizer_config.json CHANGED
@@ -1,55 +1 @@
1
- {
2
- "added_tokens_decoder": {
3
- "0": {
4
- "content": "<s>",
5
- "lstrip": false,
6
- "normalized": false,
7
- "rstrip": false,
8
- "single_word": false,
9
- "special": true
10
- },
11
- "1": {
12
- "content": "<pad>",
13
- "lstrip": false,
14
- "normalized": false,
15
- "rstrip": false,
16
- "single_word": false,
17
- "special": true
18
- },
19
- "2": {
20
- "content": "</s>",
21
- "lstrip": false,
22
- "normalized": false,
23
- "rstrip": false,
24
- "single_word": false,
25
- "special": true
26
- },
27
- "3": {
28
- "content": "<unk>",
29
- "lstrip": false,
30
- "normalized": false,
31
- "rstrip": false,
32
- "single_word": false,
33
- "special": true
34
- },
35
- "250001": {
36
- "content": "<mask>",
37
- "lstrip": false,
38
- "normalized": false,
39
- "rstrip": false,
40
- "single_word": false,
41
- "special": true
42
- }
43
- },
44
- "bos_token": "<s>",
45
- "clean_up_tokenization_spaces": true,
46
- "cls_token": "<s>",
47
- "eos_token": "</s>",
48
- "mask_token": "<mask>",
49
- "model_max_length": 1000000000000000019884624838656,
50
- "pad_token": "<pad>",
51
- "sep_token": "</s>",
52
- "sp_model_kwargs": {},
53
- "tokenizer_class": "XLMRobertaTokenizer",
54
- "unk_token": "<unk>"
55
- }
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "special_tokens_map_file": "/root/.cache/huggingface/transformers/8ed73a1ab9ef4e90a9451497bf96cfc38d34354352838a143f2dda1c81aed5ca.0dc5b1041f62041ebbd23b1297f2f573769d5c97d8b7c28180ec86b8f6185aa8", "name_or_path": "microsoft/Multilingual-MiniLM-L12-H384", "sp_model_kwargs": {}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train_script.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import random
3
+
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
5
+ import sys
6
+ import torch
7
+ import transformers
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torch.cuda.amp import autocast
10
+ import tqdm
11
+ from datetime import datetime
12
+ from shutil import copyfile
13
+ import os
14
+ ####################################
15
+
16
+ import gzip
17
+ from collections import defaultdict
18
+ import logging
19
+ import tqdm
20
+ import numpy as np
21
+ import sys
22
+ import pytrec_eval
23
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
24
+ import torch
25
+
26
+
27
# Model identifier (hub name or local path), taken from the first CLI argument.
model_name = sys.argv[1]
# Token budget passed to the tokenizer for each (query, passage) training pair.
max_length = 350

######### Evaluation
# Evaluation queries: qid -> query text (first two tab-separated columns).
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
queries_eval = {}
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        columns = line.strip().split("\t")
        qid, query = columns[0:2]
        queries_eval[qid] = query

# Relevance judgments: rel[qid][pid] -> graded score; only positive grades are kept.
rel = defaultdict(lambda: defaultdict(int))

with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score <= 0:
            continue
        rel[qid][pid] = score

# Queries that have at least one positively judged passage.
# Indexing the defaultdict creates an empty entry for queries without judgments.
relevant_qid = [qid for qid in queries_eval if len(rel[qid]) > 0]

# Read top 1k
# Candidate passages to re-rank: qid -> list of [pid, passage text].
passage_cand = {}

with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])
62
+
63
+
64
+
65
def eval_modal(model_path):
    """Re-rank the loaded evaluation candidates with a cross-encoder and report NDCG@10.

    A ``CrossEncoder`` is loaded from ``model_path`` (max_length=512). For every
    query id in the module-level ``relevant_qid`` list, each candidate passage in
    ``passage_cand`` is scored against the query text from ``queries_eval``. The
    resulting run is evaluated against the module-level ``rel`` judgments with
    ``pytrec_eval``.

    Args:
        model_path: Path or name passed straight to ``CrossEncoder``.

    Returns:
        The mean ``ndcg_cut_10`` over all evaluated queries (as a fraction; the
        printed value is the same number multiplied by 100).
    """
    ranker = CrossEncoder(model_path, max_length=512)
    run = {}

    for qid in relevant_qid:
        query_text = queries_eval[qid]
        candidates = passage_cand[qid]  # list of [pid, passage]

        ## CrossEncoder
        pairs = [[query_text, passage] for _, passage in candidates]
        if ranker.config.num_labels > 1:
            # Multi-label head: use the softmax probability of label index 1.
            raw_scores = ranker.predict(pairs, apply_softmax=True)[:, 1].tolist()
        else:
            # Single-output head: use the raw logit (identity activation).
            raw_scores = ranker.predict(pairs, activation_fct=torch.nn.Identity()).tolist()

        # pid -> score; if a pid occurs more than once the last score wins.
        run[qid] = {pid: float(score) for (pid, _), score in zip(candidates, raw_scores)}

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    per_query = evaluator.evaluate(run)
    scores_mean = np.mean([metrics["ndcg_cut_10"] for metrics in per_query.values()])

    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean
98
+
99
+ ################################
100
+
101
# Pick the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the base model with a single-output classification head.
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

#######################

queries = {}
corpus = {}

# Two output folders: one for the best-scoring checkpoint, one ("-latest") for the most recent.
_safe_model_name = model_name.replace("/", "_")
_run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(_safe_model_name, _run_stamp)
output_save_path_latest = output_save_path+"-latest"
tokenizer.save_pretrained(output_save_path)
tokenizer.save_pretrained(output_save_path_latest)

# Write self to path
# Store a copy of this script, plus the command line used, in both output folders.
for _target_dir in (output_save_path, output_save_path_latest):
    train_script_path = os.path.join(_target_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
133
+
134
+
135
+
136
+ #### Read train files
137
class MultilingualDataset(Dataset):
    """Texts keyed by id, with one or more variants per language.

    ``examples[id][lang]`` is the list of texts loaded for that id and
    language. Indexing the dataset with an id returns one text picked at
    random: first a language is chosen uniformly, then one text of that
    language.
    """

    def __init__(self):
        self.examples = defaultdict(lambda: defaultdict(list))  # [id][lang] => [samples...]

    def add(self, lang, filepath):
        """Load a ``<id>\\t<text>`` file (plain or ``.gz``) under language ``lang``.

        Args:
            lang: Language key the loaded texts are stored under.
            filepath: Path to the TSV file; a ``.gz`` suffix selects gzip.

        Raises:
            ValueError: If a non-blank line contains no tab separator.
        """
        open_method = gzip.open if filepath.endswith('.gz') else open
        # Decode explicitly as UTF-8 instead of relying on the platform default,
        # matching the evaluation-file loaders in this script.
        with open_method(filepath, 'rt', encoding='utf8') as fIn:
            for line in fIn:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                # Split on the first tab only so a tab inside the text is kept.
                pid, passage = line.split("\t", 1)
                self.examples[pid][lang].append(passage)

    def __len__(self):
        """Return the number of distinct ids loaded."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return one random text for id ``item``.

        Raises:
            KeyError: If no text was loaded for ``item``.
        """
        # Use .get() so an unknown id does not insert an empty defaultdict
        # entry (which would silently change __len__).
        all_examples = self.examples.get(item)  # All examples in all languages
        if not all_examples:
            raise KeyError(item)
        lang_examples = random.choice(list(all_examples.values()))  # Examples in one specific language
        return random.choice(lang_examples)  # One random example
156
+
157
+
158
# Passages: English originals plus two German translations, all under the same passage ids.
train_corpus = MultilingualDataset()
for _lang, _path in (
    ('en', 'msmarco-data/collection.tsv'),
    ('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz'),
):
    train_corpus.add(_lang, _path)


# Training queries: same language layout as the passages.
train_queries = MultilingualDataset()
for _lang, _path in (
    ('en', 'msmarco-data/queries.train.tsv'),
    ('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz'),
):
    train_queries.add(_lang, _path)
168
+
169
+ ############## MSE Dataset
170
class MSEDataset(Dataset):
    """Training triples with a target score margin.

    Each line of the input file holds five tab-separated fields:
    ``pos_score, neg_score, qid, pid1, pid2``. Every example is stored as
    ``[qid, pid1, pid2, pos_score - neg_score]``.
    """

    def __init__(self, filepath):
        """Read all examples from ``filepath`` into memory.

        Args:
            filepath: Path to the tab-separated score file.

        Raises:
            ValueError: If a non-blank line does not have exactly five
                fields or a score is not a valid float.
        """
        super().__init__()

        self.examples = []
        # Explicit encoding so parsing does not depend on the platform default.
        with open(filepath, encoding='utf8') as fIn:
            for line in fIn:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                pos_score, neg_score, qid, pid1, pid2 = line.split("\t")
                margin = float(pos_score) - float(neg_score)
                self.examples.append([qid, pid1, pid2, margin])

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return the example ``[qid, pid1, pid2, margin]`` at index ``item``."""
        return self.examples[item]
185
+
186
train_batch_size = 16
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)


############## Optimizer

weight_decay = 0.01
max_grad_norm = 1
param_optimizer = list(model.named_parameters())

# Parameters whose name contains one of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# One pass over the dataloader; linear warmup for the first 1000 steps.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))

# Mixed precision only makes sense on CUDA; disable it explicitly on CPU so
# autocast/GradScaler become no-ops instead of emitting warnings.
use_amp = device == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

loss_fct = torch.nn.MSELoss()
### Start training
# NOTE(review): the model is never switched to train mode here. If
# from_pretrained returns it in eval mode (as the Transformers docs describe),
# dropout stays disabled during training - confirm this is intended.
model.to(device)

auto_save = 10000  # evaluate and checkpoint every `auto_save` steps
best_ndcg_score = 0
for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    # batch = [qids, positive pids, negative pids, target margins];
    # each lookup draws a random language variant of the text.
    batch_queries = [train_queries[qid] for qid in batch[0]]
    batch_pos = [train_corpus[cid] for cid in batch[1]]
    batch_neg = [train_corpus[cid] for cid in batch[2]]
    scores = batch[3].float().to(device)

    with autocast(enabled=use_amp):
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        # squeeze(-1) drops only the label dimension, so a batch of size 1 keeps its shape.
        pred_pos = model(**inp_pos).logits.squeeze(-1)

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze(-1)

        # Regress the predicted (positive - negative) margin onto the target margin.
        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)

    scaler.scale(loss_value).backward()
    # Unscale before clipping so the norm threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()
    scheduler.step()

    if (step_idx+1) % auto_save == 0:
        print("Step:", step_idx+1)
        model.save_pretrained(output_save_path_latest)
        ndcg_score = eval_modal(output_save_path_latest)

        if ndcg_score >= best_ndcg_score:
            best_ndcg_score = ndcg_score
            print("Save to:", output_save_path)
            model.save_pretrained(output_save_path)

# NOTE(review): this unconditional save replaces whatever checkpoint the loop
# stored in output_save_path as "best" with the final weights - confirm intended.
model.save_pretrained(output_save_path)
250
+
251
+
252
+ # Script was called via:
253
+ #python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384