Achyut Tiwari commited on
Commit
e49e418
1 Parent(s): 880018c

Add files via upload

Browse files
training/run_retriever_no_trainer.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import logging
4
+ import math
5
+ from random import choice, randint
6
+
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from accelerate.utils import set_seed
10
+ from datasets import load_dataset
11
+ from torch.utils import checkpoint
12
+ from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
13
+ from tqdm.auto import tqdm
14
+ from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSequenceClassification
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def get_parser():
20
+ parser = argparse.ArgumentParser(description="Train ELI5 retriever")
21
+ parser.add_argument(
22
+ "--dataset_name",
23
+ type=str,
24
+ default="vblagoje/lfqa",
25
+ help="The name of the dataset to use (via the datasets library).",
26
+ )
27
+
28
+ parser.add_argument(
29
+ "--per_device_train_batch_size",
30
+ type=int,
31
+ default=1024,
32
+ )
33
+
34
+ parser.add_argument(
35
+ "--per_device_eval_batch_size",
36
+ type=int,
37
+ default=1024,
38
+ help="Batch size (per device) for the evaluation dataloader.",
39
+ )
40
+
41
+ parser.add_argument(
42
+ "--max_length",
43
+ type=int,
44
+ default=128,
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--checkpoint_batch_size",
49
+ type=int,
50
+ default=32,
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--pretrained_model_name",
55
+ type=str,
56
+ default="google/bert_uncased_L-8_H-768_A-12",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--model_save_name",
61
+ type=str,
62
+ default="eli5_retriever_model_l-12_h-768_b-512-512",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--learning_rate",
67
+ type=float,
68
+ default=2e-4,
69
+ )
70
+
71
+ parser.add_argument(
72
+ "--weight_decay",
73
+ type=float,
74
+ default=0.2,
75
+ )
76
+
77
+ parser.add_argument(
78
+ "--log_freq",
79
+ type=int,
80
+ default=500,
81
+ help="Log train/validation loss every log_freq update steps"
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--num_train_epochs",
86
+ type=int,
87
+ default=4,
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--max_train_steps",
92
+ type=int,
93
+ default=None,
94
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
95
+ )
96
+
97
+ parser.add_argument(
98
+ "--gradient_accumulation_steps",
99
+ type=int,
100
+ default=1,
101
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--lr_scheduler_type",
106
+ type=SchedulerType,
107
+ default="linear", # this is linear with warmup
108
+ help="The scheduler type to use.",
109
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
110
+ )
111
+
112
+ parser.add_argument(
113
+ "--num_warmup_steps",
114
+ type=int,
115
+ default=100,
116
+ help="Number of steps for the warmup in the lr scheduler."
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--warmup_percentage",
121
+ type=float,
122
+ default=0.08,
123
+ help="Number of steps for the warmup in the lr scheduler."
124
+ )
125
+ return parser
126
+
127
+
128
+ class RetrievalQAEmbedder(torch.nn.Module):
129
+ def __init__(self, sent_encoder):
130
+ super(RetrievalQAEmbedder, self).__init__()
131
+ dim = sent_encoder.config.hidden_size
132
+ self.bert_query = sent_encoder
133
+ self.output_dim = 128
134
+ self.project_query = torch.nn.Linear(dim, self.output_dim, bias=False)
135
+ self.project_doc = torch.nn.Linear(dim, self.output_dim, bias=False)
136
+ self.ce_loss = torch.nn.CrossEntropyLoss(reduction="mean")
137
+
138
+ def embed_sentences_checkpointed(self, input_ids, attention_mask, checkpoint_batch_size=-1):
139
+ # reproduces BERT forward pass with checkpointing
140
+ if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
141
+ return self.bert_query(input_ids, attention_mask=attention_mask)[1]
142
+ else:
143
+ # prepare implicit variables
144
+ device = input_ids.device
145
+ input_shape = input_ids.size()
146
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
147
+ head_mask = [None] * self.bert_query.config.num_hidden_layers
148
+ extended_attention_mask: torch.Tensor = self.bert_query.get_extended_attention_mask(
149
+ attention_mask, input_shape, device
150
+ )
151
+
152
+ # define function for checkpointing
153
+ def partial_encode(*inputs):
154
+ encoder_outputs = self.bert_query.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask, )
155
+ sequence_output = encoder_outputs[0]
156
+ pooled_output = self.bert_query.pooler(sequence_output)
157
+ return pooled_output
158
+
159
+ # run embedding layer on everything at once
160
+ embedding_output = self.bert_query.embeddings(
161
+ input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
162
+ )
163
+ # run encoding and pooling on one mini-batch at a time
164
+ pooled_output_list = []
165
+ for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
166
+ b_embedding_output = embedding_output[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
167
+ b_attention_mask = extended_attention_mask[b * checkpoint_batch_size: (b + 1) * checkpoint_batch_size]
168
+ pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
169
+ pooled_output_list.append(pooled_output)
170
+ return torch.cat(pooled_output_list, dim=0)
171
+
172
+ def embed_questions(self, q_ids, q_mask, checkpoint_batch_size=-1):
173
+ q_reps = self.embed_sentences_checkpointed(q_ids, q_mask, checkpoint_batch_size)
174
+ return self.project_query(q_reps)
175
+
176
+ def embed_answers(self, a_ids, a_mask, checkpoint_batch_size=-1):
177
+ a_reps = self.embed_sentences_checkpointed(a_ids, a_mask, checkpoint_batch_size)
178
+ return self.project_doc(a_reps)
179
+
180
+ def forward(self, q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=-1):
181
+ device = q_ids.device
182
+ q_reps = self.embed_questions(q_ids, q_mask, checkpoint_batch_size)
183
+ a_reps = self.embed_answers(a_ids, a_mask, checkpoint_batch_size)
184
+ compare_scores = torch.mm(q_reps, a_reps.t())
185
+ loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
186
+ loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
187
+ loss = (loss_qa + loss_aq) / 2
188
+ return loss
189
+
190
+
191
+ class ELI5DatasetQARetriever(Dataset):
192
+ def __init__(self, examples_array, extra_answer_threshold=3, min_answer_length=64, training=True, n_samples=None):
193
+ self.data = examples_array
194
+ self.answer_thres = extra_answer_threshold
195
+ self.min_length = min_answer_length
196
+ self.training = training
197
+ self.n_samples = self.data.num_rows if n_samples is None else n_samples
198
+
199
+ def __len__(self):
200
+ return self.n_samples
201
+
202
+ def make_example(self, idx):
203
+ example = self.data[idx]
204
+ question = example["title"]
205
+ if self.training:
206
+ answers = [a for i, (a, sc) in enumerate(zip(example["answers"]["text"], example["answers"]["score"]))]
207
+ answer_tab = choice(answers).split(" ")
208
+ start_idx = randint(0, max(0, len(answer_tab) - self.min_length))
209
+ answer_span = " ".join(answer_tab[start_idx:])
210
+ else:
211
+ answer_span = example["answers"]["text"][0]
212
+ return question, answer_span
213
+
214
+ def __getitem__(self, idx):
215
+ return self.make_example(idx % self.data.num_rows)
216
+
217
+
218
+ def make_qa_retriever_batch(qa_list, tokenizer, max_len=64):
219
+ q_ls = [q for q, a in qa_list]
220
+ a_ls = [a for q, a in qa_list]
221
+ q_toks = tokenizer(q_ls, padding="max_length", max_length=max_len, truncation=True)
222
+ q_ids, q_mask = (
223
+ torch.LongTensor(q_toks["input_ids"]),
224
+ torch.LongTensor(q_toks["attention_mask"])
225
+ )
226
+ a_toks = tokenizer(a_ls, padding="max_length", max_length=max_len, truncation=True)
227
+ a_ids, a_mask = (
228
+ torch.LongTensor(a_toks["input_ids"]),
229
+ torch.LongTensor(a_toks["attention_mask"]),
230
+ )
231
+ return q_ids, q_mask, a_ids, a_mask
232
+
233
+
234
+ def evaluate_qa_retriever(model, data_loader):
235
+ # make iterator
236
+ epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
237
+ tot_loss = 0.0
238
+ with torch.no_grad():
239
+ for step, batch in enumerate(epoch_iterator):
240
+ q_ids, q_mask, a_ids, a_mask = batch
241
+ loss = model(q_ids, q_mask, a_ids, a_mask)
242
+ tot_loss += loss.item()
243
+ return tot_loss / (step + 1)
244
+
245
+
246
+ def train(config):
247
+ set_seed(42)
248
+ args = config["args"]
249
+ data_files = {"train": "train.json", "validation": "validation.json", "test": "test.json"}
250
+ eli5 = load_dataset(args.dataset_name, data_files=data_files)
251
+
252
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
253
+ accelerator = Accelerator()
254
+ # Make one log on every process with the configuration for debugging.
255
+ logging.basicConfig(
256
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
257
+ datefmt="%m/%d/%Y %H:%M:%S",
258
+ level=logging.INFO,
259
+ )
260
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
261
+ logger.info(accelerator.state)
262
+
263
+ # prepare torch Dataset objects
264
+ train_dataset = ELI5DatasetQARetriever(eli5['train'], training=True)
265
+ valid_dataset = ELI5DatasetQARetriever(eli5['validation'], training=False)
266
+
267
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
268
+ base_model = AutoModel.from_pretrained(args.pretrained_model_name)
269
+
270
+ model = RetrievalQAEmbedder(base_model)
271
+ no_decay = ['bias', 'LayerNorm.weight']
272
+ optimizer_grouped_parameters = [
273
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
274
+ 'weight_decay': args.weight_decay},
275
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
276
+ ]
277
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
278
+
279
+ model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
280
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
281
+ sampler=RandomSampler(train_dataset), collate_fn=model_collate_fn)
282
+
283
+ model_collate_fn = functools.partial(make_qa_retriever_batch, tokenizer=tokenizer, max_len=args.max_length)
284
+ eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
285
+ sampler=SequentialSampler(valid_dataset), collate_fn=model_collate_fn)
286
+
287
+ # train the model
288
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
289
+ train_dataloader, eval_dataloader)
290
+ # Scheduler and math around the number of training steps.
291
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
292
+ if args.max_train_steps is None:
293
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
294
+ else:
295
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
296
+
297
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
298
+ args.warmup_percentage)
299
+ scheduler = get_scheduler(
300
+ name=args.lr_scheduler_type,
301
+ optimizer=optimizer,
302
+ num_warmup_steps=args.num_warmup_steps,
303
+ num_training_steps=args.max_train_steps,
304
+ )
305
+
306
+ # Train!
307
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
308
+
309
+ logger.info("***** Running training *****")
310
+ logger.info(f" Num examples = {len(train_dataset)}")
311
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
312
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
313
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
314
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
315
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
316
+ logger.info(f" Warmup steps = {num_warmup_steps}")
317
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
318
+
319
+ loc_loss = 0.0
320
+ current_loss = 0.0
321
+ checkpoint_step = 0
322
+
323
+ completed_steps = checkpoint_step
324
+ progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
325
+ disable=not accelerator.is_local_main_process)
326
+ for epoch in range(args.num_train_epochs):
327
+ model.train()
328
+ batch = next(iter(train_dataloader))
329
+ for step in range(1000):
330
+ #for step, batch in enumerate(train_dataloader, start=checkpoint_step):
331
+ # model inputs
332
+ q_ids, q_mask, a_ids, a_mask = batch
333
+ pre_loss = model(q_ids, q_mask, a_ids, a_mask, checkpoint_batch_size=args.checkpoint_batch_size)
334
+ loss = pre_loss.sum() / args.gradient_accumulation_steps
335
+ accelerator.backward(loss)
336
+ loc_loss += loss.item()
337
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
338
+ current_loss = loc_loss
339
+ optimizer.step()
340
+ scheduler.step()
341
+ optimizer.zero_grad()
342
+ progress_bar.update(1)
343
+ progress_bar.set_postfix(loss=loc_loss)
344
+ loc_loss = 0
345
+ completed_steps += 1
346
+
347
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
348
+ accelerator.wait_for_everyone()
349
+ unwrapped_model = accelerator.unwrap_model(model)
350
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
351
+ logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
352
+ if args.wandb and accelerator.is_local_main_process:
353
+ import wandb
354
+ wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})
355
+
356
+ if completed_steps >= args.max_train_steps:
357
+ break
358
+
359
+ logger.info("Saving model {}".format(args.model_save_name))
360
+ accelerator.wait_for_everyone()
361
+ unwrapped_model = accelerator.unwrap_model(model)
362
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
363
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
364
+ logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))
365
+
366
+
367
+ if __name__ == "__main__":
368
+ parser = get_parser()
369
+ parser.add_argument(
370
+ "--wandb",
371
+ action="store_true",
372
+ help="Whether to use W&B logging",
373
+ )
374
+ main_args, _ = parser.parse_known_args()
375
+ config = {"args": main_args}
376
+ if main_args.wandb:
377
+ import wandb
378
+ wandb.init(project="Retriever")
379
+
380
+ train(config=config)
381
+
training/run_retriever_no_trainer_gpl.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import List, Any, Union, Optional
6
+
7
+ import torch
8
+ import ujson
9
+ from accelerate import Accelerator
10
+ from accelerate.utils import set_seed
11
+ from torch import nn, Tensor
12
+ from torch.nn import functional as F
13
+ from torch.utils.data import Dataset, RandomSampler, DataLoader, SequentialSampler
14
+ from tqdm.auto import tqdm
15
+ from transformers import get_scheduler, AutoTokenizer, AutoModel, AdamW, SchedulerType, PreTrainedTokenizerBase, AutoModelForSequenceClassification, BatchEncoding
16
+ from transformers.file_utils import PaddingStrategy
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def get_parser():
22
+ parser = argparse.ArgumentParser(description="Train LFQA retriever")
23
+ parser.add_argument(
24
+ "--dpr_input_file",
25
+ type=str,
26
+ help="DPR formatted input file with question/positive/negative pairs in a JSONL file",
27
+ )
28
+
29
+ parser.add_argument(
30
+ "--per_device_train_batch_size",
31
+ type=int,
32
+ default=32,
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--per_device_eval_batch_size",
37
+ type=int,
38
+ default=32,
39
+ help="Batch size (per device) for the evaluation dataloader.",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--max_length",
44
+ type=int,
45
+ default=128,
46
+ )
47
+
48
+
49
+ parser.add_argument(
50
+ "--pretrained_model_name",
51
+ type=str,
52
+ default="sentence-transformers/all-MiniLM-L6-v2",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--ce_model_name",
57
+ type=str,
58
+ default="cross-encoder/ms-marco-MiniLM-L-6-v2",
59
+ )
60
+
61
+ parser.add_argument(
62
+ "--model_save_name",
63
+ type=str,
64
+ default="eli5_retriever_model_l-12_h-768_b-512-512",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--learning_rate",
69
+ type=float,
70
+ default=2e-5,
71
+ )
72
+
73
+ parser.add_argument(
74
+ "--weight_decay",
75
+ type=float,
76
+ default=0.01,
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--log_freq",
81
+ type=int,
82
+ default=500,
83
+ help="Log train/validation loss every log_freq update steps"
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--num_train_epochs",
88
+ type=int,
89
+ default=4,
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--max_train_steps",
94
+ type=int,
95
+ default=None,
96
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--gradient_accumulation_steps",
101
+ type=int,
102
+ default=1,
103
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--lr_scheduler_type",
108
+ type=SchedulerType,
109
+ default="linear", # this is linear with warmup
110
+ help="The scheduler type to use.",
111
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--num_warmup_steps",
116
+ type=int,
117
+ default=100,
118
+ help="Number of steps for the warmup in the lr scheduler."
119
+ )
120
+
121
+ parser.add_argument(
122
+ "--warmup_percentage",
123
+ type=float,
124
+ default=0.08,
125
+ help="Number of steps for the warmup in the lr scheduler."
126
+ )
127
+ return parser
128
+
129
+
130
+ @dataclass
131
+ class InputExample:
132
+ guid: str = ""
133
+ texts: List[str] = None
134
+ label: Union[int, float] = 0
135
+
136
+
137
+ class DPRDataset(Dataset):
138
+ """
139
+ Dataset DPR format of question, answers, positive, negative, and hard negative passages
140
+ See https://github.com/facebookresearch/DPR#retriever-input-data-format for more details
141
+ """
142
+
143
+ def __init__(self, file_path: str, include_all_positive: bool = False) -> None:
144
+ super().__init__()
145
+ with open(file_path, "r") as fp:
146
+ self.data = []
147
+
148
+ def dpr_example_to_input_example(idx, dpr_item):
149
+ examples = []
150
+ for p_idx, p_item in enumerate(dpr_item["positive_ctxs"]):
151
+ for n_idx, n_item in enumerate(dpr_item["negative_ctxs"]):
152
+ examples.append(InputExample(guid=[idx, p_idx, n_idx], texts=[dpr_item["question"],
153
+ p_item["text"],
154
+ n_item["text"]]))
155
+ if not include_all_positive:
156
+ break
157
+ return examples
158
+
159
+ for idx, line in enumerate(fp):
160
+ self.data.extend(dpr_example_to_input_example(idx, ujson.loads(line)))
161
+
162
+ def __len__(self):
163
+ return len(self.data)
164
+
165
+ def __getitem__(self, index):
166
+ return self.data[index]
167
+
168
+
169
+ def dpr_collate_fn(batch):
170
+ query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
171
+ query, pos, neg = zip(*[example.texts for example in batch])
172
+ return (query_id, pos_id, neg_id), (query, pos, neg)
173
+
174
+
175
+ # Mean Pooling - Take attention mask into account for correct averaging
176
+ def mean_pooling(model_output, attention_mask):
177
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
178
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
179
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
180
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
181
+ return sum_embeddings / sum_mask
182
+
183
+
184
+ @dataclass
185
+ class CrossEncoderCollator:
186
+ tokenizer: PreTrainedTokenizerBase
187
+ model: Any
188
+ target_tokenizer: PreTrainedTokenizerBase
189
+ padding: Union[bool, str, PaddingStrategy] = True
190
+ max_length: Optional[int] = None
191
+ pad_to_multiple_of: Optional[int] = None
192
+ return_tensors: str = "pt"
193
+
194
+ def __call__(self, batch):
195
+ query_id, pos_id, neg_id = zip(*[example.guid for example in batch])
196
+ query, pos_passage, neg_passage = zip(*[example.texts for example in batch])
197
+ batch_input: List[List[str]] = list(zip(query, pos_passage)) + list(zip(query, neg_passage))
198
+ features = self.tokenizer(batch_input, padding=self.padding, truncation=True,
199
+ return_tensors=self.return_tensors)
200
+ with torch.no_grad():
201
+ scores = self.model(**features).logits
202
+
203
+ labels = scores[:len(query)] - scores[len(query):]
204
+ batch_input: List[str] = list(query) + list(pos_passage) + list(neg_passage)
205
+ #breakpoint()
206
+ encoded_input = self.target_tokenizer(batch_input, padding=True, truncation=True,
207
+ max_length=256, return_tensors='pt')
208
+
209
+ encoded_input["labels"] = labels
210
+
211
+ return encoded_input
212
+
213
+
214
+ class RetrievalQAEmbedder(torch.nn.Module):
215
+ def __init__(self, sent_encoder, sent_tokenizer, batch_size:int = 32):
216
+ super(RetrievalQAEmbedder, self).__init__()
217
+ dim = sent_encoder.config.hidden_size
218
+ self.model = sent_encoder
219
+ self.tokenizer = sent_tokenizer
220
+ self.scale = 1
221
+ self.similarity_fct = 'dot'
222
+ self.batch_size = 32
223
+ self.loss_fct = nn.MSELoss()
224
+
225
+ def forward(self, examples: BatchEncoding):
226
+ # Tokenize sentences
227
+ labels = examples.pop("labels")
228
+ # Compute token embeddings
229
+ model_output = self.model(**examples)
230
+
231
+ examples["labels"] = labels
232
+
233
+ # Perform pooling. In this case, mean pooling
234
+ sentence_embeddings = mean_pooling(model_output, examples['attention_mask'])
235
+ target_shape = (3, self.batch_size, sentence_embeddings.shape[-1])
236
+ sentence_embeddings_reshaped = torch.reshape(sentence_embeddings, target_shape)
237
+
238
+ #breakpoint()
239
+
240
+ embeddings_query = sentence_embeddings_reshaped[0]
241
+ embeddings_pos = sentence_embeddings_reshaped[1]
242
+ embeddings_neg = sentence_embeddings_reshaped[2]
243
+
244
+ if self.similarity_fct == 'cosine':
245
+ embeddings_query = F.normalize(embeddings_query, p=2, dim=1)
246
+ embeddings_pos = F.normalize(embeddings_pos, p=2, dim=1)
247
+ embeddings_neg = F.normalize(embeddings_neg, p=2, dim=1)
248
+
249
+ scores_pos = (embeddings_query * embeddings_pos).sum(dim=-1) * self.scale
250
+ scores_neg = (embeddings_query * embeddings_neg).sum(dim=-1) * self.scale
251
+ margin_pred = scores_pos - scores_neg
252
+ #breakpoint()
253
+ return self.loss_fct(margin_pred, labels.squeeze())
254
+
255
+
256
+ def evaluate_qa_retriever(model, data_loader):
257
+ # make iterator
258
+ epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
259
+ tot_loss = 0.0
260
+ with torch.no_grad():
261
+ for step, batch in enumerate(epoch_iterator):
262
+ q_ids, q_mask, a_ids, a_mask = batch
263
+ loss = model(q_ids, q_mask, a_ids, a_mask)
264
+ tot_loss += loss.item()
265
+ return tot_loss / (step + 1)
266
+
267
+
268
+ def train(config):
269
+ set_seed(42)
270
+ args = config["args"]
271
+
272
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
273
+ accelerator = Accelerator()
274
+ # Make one log on every process with the configuration for debugging.
275
+ logging.basicConfig(
276
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
277
+ datefmt="%m/%d/%Y %H:%M:%S",
278
+ level=logging.INFO,
279
+ )
280
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
281
+ logger.info(accelerator.state)
282
+
283
+ # prepare torch Dataset objects
284
+ train_dataset = DPRDataset(file_path=args.dpr_input_file)
285
+ valid_dataset = Dataset()
286
+
287
+ base_tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
288
+ base_model = AutoModel.from_pretrained(args.pretrained_model_name)
289
+
290
+ ce_tokenizer = AutoTokenizer.from_pretrained(args.ce_model_name)
291
+ ce_model = AutoModelForSequenceClassification.from_pretrained(args.ce_model_name)
292
+ _ = ce_model.eval()
293
+
294
+ model = RetrievalQAEmbedder(base_model, base_tokenizer)
295
+ no_decay = ['bias', 'LayerNorm.weight']
296
+ optimizer_grouped_parameters = [
297
+ {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
298
+ 'weight_decay': args.weight_decay},
299
+ {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
300
+ ]
301
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
302
+
303
+ cec = CrossEncoderCollator(model=ce_model, tokenizer=ce_tokenizer, target_tokenizer=base_tokenizer)
304
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
305
+ sampler=RandomSampler(train_dataset), collate_fn=cec)
306
+
307
+ eval_dataloader = DataLoader(valid_dataset, batch_size=args.per_device_eval_batch_size,
308
+ sampler=SequentialSampler(valid_dataset), collate_fn=cec)
309
+
310
+ # train the model
311
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer,
312
+ train_dataloader, eval_dataloader)
313
+ # Scheduler and math around the number of training steps.
314
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
315
+ if args.max_train_steps is None:
316
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
317
+ else:
318
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
319
+
320
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
321
+ args.warmup_percentage)
322
+ scheduler = get_scheduler(
323
+ name=args.lr_scheduler_type,
324
+ optimizer=optimizer,
325
+ num_warmup_steps=args.num_warmup_steps,
326
+ num_training_steps=args.max_train_steps,
327
+ )
328
+
329
+ # Train!
330
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
331
+
332
+ logger.info("***** Running training *****")
333
+ logger.info(f" Num examples = {len(train_dataset)}")
334
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
335
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
336
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
337
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
338
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
339
+ logger.info(f" Warmup steps = {num_warmup_steps}")
340
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
341
+
342
+ loc_loss = 0.0
343
+ current_loss = 0.0
344
+ checkpoint_step = 0
345
+
346
+ completed_steps = checkpoint_step
347
+ progress_bar = tqdm(range(args.max_train_steps), initial=checkpoint_step,
348
+ disable=not accelerator.is_local_main_process)
349
+ for epoch in range(args.num_train_epochs):
350
+ model.train()
351
+ for step, batch in enumerate(train_dataloader, start=checkpoint_step):
352
+ # model inputs
353
+ pre_loss = model(batch)
354
+ loss = pre_loss / args.gradient_accumulation_steps
355
+ accelerator.backward(loss)
356
+ loc_loss += loss.item()
357
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
358
+ current_loss = loc_loss
359
+ optimizer.step()
360
+ scheduler.step()
361
+ optimizer.zero_grad()
362
+ progress_bar.update(1)
363
+ progress_bar.set_postfix(loss=loc_loss)
364
+ loc_loss = 0
365
+ completed_steps += 1
366
+
367
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
368
+ # accelerator.wait_for_everyone()
369
+ # unwrapped_model = accelerator.unwrap_model(model)
370
+ # eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
371
+ eval_loss = 0
372
+ logger.info(f"Train loss {current_loss} , eval loss {eval_loss}")
373
+ if args.wandb and accelerator.is_local_main_process:
374
+ import wandb
375
+ wandb.log({"loss": current_loss, "eval_loss": eval_loss, "step": completed_steps})
376
+
377
+ if completed_steps >= args.max_train_steps:
378
+ break
379
+
380
+ logger.info("Saving model {}".format(args.model_save_name))
381
+ accelerator.wait_for_everyone()
382
+ unwrapped_model = accelerator.unwrap_model(model)
383
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
384
+ eval_loss = evaluate_qa_retriever(unwrapped_model, eval_dataloader)
385
+ logger.info("Evaluation loss epoch {:4d}: {:.3f}".format(epoch, eval_loss))
386
+
387
+
388
+ if __name__ == "__main__":
389
+ parser = get_parser()
390
+ parser.add_argument(
391
+ "--wandb",
392
+ action="store_true",
393
+ help="Whether to use W&B logging",
394
+ )
395
+ main_args, _ = parser.parse_known_args()
396
+ config = {"args": main_args}
397
+ if main_args.wandb:
398
+ import wandb
399
+
400
+ wandb.init(project="Retriever")
401
+
402
+ train(config=config)
403
+
training/run_seq2seq_no_trainer.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import re
5
+
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from accelerate.utils import set_seed
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.auto import tqdm
12
+ from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSeq2SeqLM, \
13
+ DataCollatorWithPadding
14
+
15
+ from datasets import load_dataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def get_parser():
21
+ parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
22
+ parser.add_argument(
23
+ "--dataset_name",
24
+ type=str,
25
+ default="vblagoje/lfqa",
26
+ help="The name of the dataset to use (via the datasets library).",
27
+ )
28
+
29
+ parser.add_argument(
30
+ "--per_device_train_batch_size",
31
+ type=int,
32
+ default=4,
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--per_device_eval_batch_size",
37
+ type=int,
38
+ default=4,
39
+ help="Batch size (per device) for the evaluation dataloader.",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--pretrained_model_name",
44
+ type=str,
45
+ default="facebook/bart-large",
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--model_save_name",
50
+ type=str,
51
+ default="eli5_bart_model",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--learning_rate",
56
+ type=float,
57
+ default=2e-4,
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--weight_decay",
62
+ type=float,
63
+ default=0.0,
64
+ help="Weight decay to use."
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--log_freq",
69
+ type=int,
70
+ default=100,
71
+ help="Log train/validation loss every log_freq update steps"
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--ignore_pad_token_for_loss",
76
+ type=bool,
77
+ default=True,
78
+ help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
79
+ )
80
+
81
+ parser.add_argument(
82
+ "--num_train_epochs",
83
+ type=int,
84
+ default=3,
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--max_train_steps",
89
+ type=int,
90
+ default=None,
91
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--gradient_accumulation_steps",
96
+ type=int,
97
+ default=16,
98
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--pad_to_max_length",
103
+ action="store_true",
104
+ help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
105
+ )
106
+
107
+ parser.add_argument(
108
+ "--overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets"
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--max_source_length",
113
+ type=int,
114
+ default=1024,
115
+ help="The maximum total input sequence length after "
116
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--max_target_length",
121
+ type=int,
122
+ default=360,
123
+ help="The maximum total sequence length for target text after "
124
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--lr_scheduler_type",
129
+ type=SchedulerType,
130
+ default="linear", # this is linear with warmup
131
+ help="The scheduler type to use.",
132
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--num_warmup_steps",
137
+ type=int,
138
+ default=None,
139
+ help="Number of steps for the warmup in the lr scheduler."
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--warmup_percentage",
144
+ type=float,
145
+ default=0.08,
146
+ help="Number of steps for the warmup in the lr scheduler."
147
+ )
148
+ return parser
149
+
150
+
151
+ def cleanup_references(text):
152
+ # URL reference where we need to remove both the link text and URL
153
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal
154
+ # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
155
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
156
+ result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
157
+
158
+ # URL reference where we need to preserve link text but remove URL
159
+ # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
160
+ # At the outbreak of the Civil War, Leyburn left his church and joined the South.
161
+ result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
162
+
163
+ # lastly remove just dangling _URL_[0-9]_ URL references
164
+ result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
165
+ return result
166
+
167
+
168
+ def clean_answer(text):
169
+ result = cleanup_references(text)
170
+ result = result.replace("\n", " ")
171
+ result = re.sub(r"\s\s+", " ", result)
172
+ result = re.sub(r"BULLET::::-", "", result)
173
+ return result.strip()
174
+
175
+
176
+ def clean_question(text):
177
+ result = cleanup_references(text)
178
+ result = result.replace("\n", " ")
179
+ result = re.sub(r"\s\s+", " ", result)
180
+ result = result.replace("[deleted]", "")
181
+ return result.lower().strip()
182
+
183
+
184
+ def prepare_support_docs(example):
185
+ provenances = example["output"][-1]["provenance"]
186
+ context = "<P> " + " <P> ".join([p["text"] for p in provenances])
187
+ return {"context": context}
188
+
189
+
190
+ def preprocess_eli5(examples, **fn_kwargs):
191
+ document_cache = fn_kwargs["document_cache"]
192
+ training = fn_kwargs.get("training", True)
193
+ extra_answer_threshold = fn_kwargs.get("extra_answer_threshold", 3)
194
+ include_selftext = fn_kwargs.get("include_selftext", False)
195
+ exclude_answer_patterns = fn_kwargs.get("exclude_answer_patterns", [])
196
+
197
+ questions, contexts, answers = [], [], []
198
+ for q_id, question, selftext, answer in zip(examples["q_id"], examples["title"], examples["selftext"],
199
+ examples["answers"]):
200
+ accepted_answer_idx = []
201
+ if training:
202
+ accepted_answer_idx = [idx for idx, score in enumerate(answer["score"]) if
203
+ score > extra_answer_threshold]
204
+ if not training or not accepted_answer_idx:
205
+ accepted_answer_idx = [0]
206
+ document = document_cache[q_id]
207
+ for idx in accepted_answer_idx:
208
+ skip_answer = any([p.search(answer["text"][idx]) for p in exclude_answer_patterns])
209
+ if skip_answer:
210
+ continue
211
+ if include_selftext:
212
+ questions.append(clean_question(f"{question} {selftext}"))
213
+ else:
214
+ questions.append(clean_question(question))
215
+ contexts.append(document.lower().strip())
216
+ answers.append(clean_answer(answer["text"][idx]))
217
+
218
+ return {"question": questions, "context": contexts, "answer": answers}
219
+
220
+
221
+ def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
222
+ model.eval()
223
+ num_eval_steps = math.ceil(len(dataloader))
224
+ progress_bar = tqdm(range(num_eval_steps), disable=not accelerator.is_local_main_process)
225
+ total_loss = 0.
226
+ with torch.no_grad():
227
+ for step, batch in enumerate(dataloader):
228
+ outputs = model(**batch)
229
+ loss = outputs.loss
230
+ total_loss += loss.item()
231
+ progress_bar.update(1)
232
+ progress_bar.set_postfix(loss=round((total_loss / (step + 1)), 3))
233
+ return total_loss / (step + 1)
234
+
235
+
236
+ def train(config):
237
+ set_seed(42)
238
+ args = config["args"]
239
+ eli5 = load_dataset(args.dataset_name)
240
+
241
+ support_docs = load_dataset("vblagoje/lfqa_support_docs")
242
+
243
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
244
+ accelerator = Accelerator()
245
+ # Make one log on every process with the configuration for debugging.
246
+ logging.basicConfig(
247
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
248
+ datefmt="%m/%d/%Y %H:%M:%S",
249
+ level=logging.INFO,
250
+ )
251
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
252
+ logger.info(accelerator.state)
253
+
254
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
255
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)
256
+
257
+ # Optimizer
258
+ # Split weights in two groups, one with weight decay and the other not.
259
+ no_decay = ["bias", "LayerNorm.weight"]
260
+ optimizer_grouped_parameters = [
261
+ {
262
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
263
+ "weight_decay": args.weight_decay,
264
+ },
265
+ {
266
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
267
+ "weight_decay": 0.0,
268
+ },
269
+ ]
270
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
271
+
272
+ processed_datasets = {}
273
+ support_docs_prepared = {}
274
+ with accelerator.main_process_first():
275
+ for split in ["train", "validation"]:
276
+ support_docs_prepared[split] = support_docs[split].map(prepare_support_docs,
277
+ batched=False,
278
+ cache_file_name=f"./support_docs_{split}.arrow",
279
+ load_from_cache_file=not args.overwrite_cache,
280
+ desc="Preparing support docs",
281
+ )
282
+ column_names = eli5["train"].column_names
283
+ for split in ["train", "validation"]:
284
+ d_cache = dict([(e["id"], e["context"]) for e in tqdm(support_docs_prepared[split],
285
+ desc=f"Adding support docs to LFQA {split}")])
286
+ processed_datasets[split] = eli5[split].map(preprocess_eli5,
287
+ batched=True,
288
+ remove_columns=column_names,
289
+ cache_file_name=f"./processed_datasets_{split}.arrow",
290
+ load_from_cache_file=not args.overwrite_cache,
291
+ desc="Preparing dataset for tokenization",
292
+ fn_kwargs={"document_cache": d_cache,
293
+ "training": split == "train",
294
+ "exclude_answer_patterns": [re.compile("not sure what you"),
295
+ re.compile("\n\n >")]}
296
+ )
297
+
298
+ padding = "max_length" if args.pad_to_max_length else False
299
+ # Temporarily set max_target_length for training.
300
+ max_target_length = args.max_target_length
301
+
302
+ label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
303
+
304
+ def tokenize_dataset(examples):
305
+ inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
306
+ targets = examples["answer"]
307
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
308
+
309
+ # Setup the tokenizer for targets
310
+ with tokenizer.as_target_tokenizer():
311
+ labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
312
+ return_tensors="np")
313
+
314
+ model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()
315
+ # replace pad_token_id with label_pad_token_id to avoid loss calculation on those tokens
316
+ labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
317
+ label_pad_token_id, labels["input_ids"])
318
+
319
+ model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
320
+ return model_inputs
321
+
322
+ tokenized_datasets = {}
323
+ with accelerator.main_process_first():
324
+ for split, dataset in processed_datasets.items():
325
+ tokenized_datasets[split] = dataset.map(
326
+ tokenize_dataset,
327
+ batched=True,
328
+ cache_file_name=f"./tokenized_dataset_{split}.arrow",
329
+ remove_columns=dataset.column_names,
330
+ load_from_cache_file=not args.overwrite_cache,
331
+ desc="Running tokenizer on dataset"
332
+ )
333
+
334
+ train_dataset = tokenized_datasets["train"]
335
+ eval_dataset = tokenized_datasets["validation"]
336
+ train_dataset.set_format(type='torch')
337
+ eval_dataset.set_format(type='torch')
338
+
339
+ data_collator = DataCollatorWithPadding(tokenizer, "max_length")
340
+
341
+ # first epoch we don't shuffle
342
+ train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
343
+ collate_fn=data_collator)
344
+ eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)
345
+
346
+ # train the model
347
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
348
+ eval_dataloader)
349
+ # Scheduler and math around the number of training steps.
350
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
351
+ if args.max_train_steps is None:
352
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
353
+ else:
354
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
355
+
356
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
357
+ args.warmup_percentage)
358
+ scheduler = get_scheduler(
359
+ name=args.lr_scheduler_type,
360
+ optimizer=optimizer,
361
+ num_warmup_steps=num_warmup_steps,
362
+ num_training_steps=args.max_train_steps,
363
+ )
364
+ # Train!
365
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
366
+
367
+ logger.info("***** Running training *****")
368
+ logger.info(f" Num examples = {len(train_dataset)}")
369
+ logger.info(f" Num eval examples = {len(eval_dataset)}")
370
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
371
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
372
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
373
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
374
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
375
+ logger.info(f" Warmup steps = {num_warmup_steps}")
376
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
377
+
378
+ # Only show the progress bar once on each machine.
379
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
380
+ completed_steps = 0
381
+ switched_train_dataloader = False
382
+ for epoch in range(args.num_train_epochs):
383
+ model.train()
384
+ if epoch > 0 and not switched_train_dataloader:
385
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
386
+ shuffle=True, collate_fn=data_collator)
387
+ train_dataloader = accelerator.prepare(train_dataloader)
388
+ switched_train_dataloader = True
389
+
390
+ for step, batch in enumerate(train_dataloader):
391
+ outputs = model(**batch)
392
+ loss = torch.mean(outputs.loss)
393
+ accelerator.backward(loss)
394
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
395
+ optimizer.step()
396
+ scheduler.step()
397
+ optimizer.zero_grad()
398
+ progress_bar.update(1)
399
+ progress_bar.set_postfix(loss=round(loss.item(), 3))
400
+ completed_steps += 1
401
+
402
+ if completed_steps >= args.max_train_steps:
403
+ break
404
+
405
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
406
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
407
+ model.train()
408
+ logger.info(f"Train loss {loss.item()} , validation loss {validation_loss}")
409
+ if args.wandb and accelerator.is_local_main_process:
410
+ import wandb
411
+ wandb.log({"loss": loss.item(),
412
+ "lr": scheduler.get_last_lr()[0],
413
+ "validation_loss": validation_loss,
414
+ "completed_steps": completed_steps})
415
+
416
+ logger.info("Saving model {}".format(args.model_save_name))
417
+ accelerator.wait_for_everyone()
418
+ unwrapped_model = accelerator.unwrap_model(model)
419
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
420
+
421
+ # Calculating the validation loss over epoch
422
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
423
+
424
+ logger.info("Epoch: {}".format(epoch))
425
+ logger.info("Validation loss: {}".format(validation_loss))
426
+
427
+
428
+ def main():
429
+ parser = get_parser()
430
+ parser.add_argument(
431
+ "--wandb",
432
+ action="store_true",
433
+ help="If true, use W&B logging",
434
+ )
435
+ main_args, _ = parser.parse_known_args()
436
+ config = {"args": main_args}
437
+ if main_args.wandb:
438
+ import wandb
439
+ wandb.init(project="Bart_ELI5")
440
+ train(config=config)
441
+
442
+
443
+ main()
444
+
445
+
446
+