winglian daaave commited on
Commit
367b2e8
1 Parent(s): bbfed31

Switch to parallel FFD bin packing algorithm. (#1619)

Browse files

* Switch to parallel FFD bin packing algorithm.

Add support for packing in a distributed context.
Add packing efficiency estimate back.

* revert changes to distributed code

* chore: lint

* fix config w new params for packing test

* add sample_packing_group_size and sample_packing_bin_size to cfg schema

* fix lamdbda function

* fix sampler/dataloader calculations for packing

---------

Co-authored-by: dsesclei <dave@sescleifer.com>

docs/config.qmd CHANGED
@@ -186,6 +186,11 @@ eval_sample_packing:
186
  # The trainer will provide recommended values for these values.
187
  sample_packing_eff_est:
188
  total_num_tokens:
 
 
 
 
 
189
 
190
  # Passed through to transformers when loading the model when launched without accelerate
191
  # Use `sequential` when training w/ model parallelism to limit memory
 
186
  # The trainer will provide recommended values for these values.
187
  sample_packing_eff_est:
188
  total_num_tokens:
189
+ # Increasing the following values helps with packing, but usually only slightly (<%1.)
190
+ # The number of samples packed at a time.
191
+ sample_packing_group_size: 100000
192
+ # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
193
+ sample_packing_bin_size: 200
194
 
195
  # Passed through to transformers when loading the model when launched without accelerate
196
  # Use `sequential` when training w/ model parallelism to limit memory
src/axolotl/core/trainer_builder.py CHANGED
@@ -125,14 +125,22 @@ class AxolotlTrainingArguments(TrainingArguments):
125
  default=1.0,
126
  metadata={"help": "Sample packing efficiency for calculating batch length."},
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
128
  max_seq_length: int = field(
129
  default=2048,
130
  metadata={"help": "The maximum sequence length the model can handle"},
131
  )
132
- sample_packing_seq_len_multiplier: int = field(
133
- default=1,
134
- metadata={"help": "the multiplier for the max len for packed sequences"},
135
- )
136
  relora_steps: Optional[int] = field(
137
  default=None,
138
  metadata={"help": "how often to reset for ReLoRA"},
@@ -346,11 +354,11 @@ class AxolotlTrainer(Trainer):
346
  )
347
  return MultipackBatchSampler(
348
  RandomSampler(self.train_dataset),
349
- batch_size=batch_size,
350
- drop_last=True,
351
- batch_max_len=batch_max_len,
352
  lengths=get_dataset_lengths(self.train_dataset),
353
- packing_efficiency_estimate=self.args.sample_packing_efficiency,
 
 
 
354
  )
355
  if self.args.curriculum_sampling:
356
  return SequentialSampler(self.train_dataset)
@@ -370,11 +378,11 @@ class AxolotlTrainer(Trainer):
370
  )
371
  return MultipackBatchSampler(
372
  SequentialSampler(eval_dataset),
373
- batch_size=batch_size,
374
- drop_last=True,
375
  batch_max_len=batch_max_len,
376
- lengths=get_dataset_lengths(eval_dataset),
377
- packing_efficiency_estimate=self.args.sample_packing_efficiency,
 
378
  )
379
  return super()._get_eval_sampler(eval_dataset)
380
 
@@ -1113,11 +1121,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
1113
  if self.cfg.save_safetensors is not None:
1114
  training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
1115
 
1116
- if self.cfg.sample_packing_eff_est:
1117
- training_arguments_kwargs[
1118
- "sample_packing_efficiency"
1119
- ] = self.cfg.sample_packing_eff_est
1120
-
1121
  if self.cfg.dataloader_pin_memory is not None:
1122
  training_arguments_kwargs[
1123
  "dataloader_pin_memory"
@@ -1293,20 +1296,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
1293
  training_arguments_kwargs["weight_decay"] = (
1294
  self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
1295
  )
1296
- training_arguments_kwargs["sample_packing"] = (
1297
- self.cfg.sample_packing if self.cfg.sample_packing else False
1298
- )
1299
- training_arguments_kwargs["multipack_real_batches"] = (
1300
- self.cfg.flash_attention is not True
1301
- )
1302
- training_arguments_kwargs["eval_sample_packing"] = (
1303
- self.cfg.sample_packing
1304
- if self.cfg.eval_sample_packing is not False
1305
- else False
1306
- )
1307
  training_arguments_kwargs[
1308
- "sample_packing_seq_len_multiplier"
1309
- ] = self.cfg.micro_batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1310
  if self.cfg.relora_steps:
1311
  training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
1312
  training_arguments_kwargs[
 
125
  default=1.0,
126
  metadata={"help": "Sample packing efficiency for calculating batch length."},
127
  )
128
+ sample_packing_bin_size: int = field(
129
+ default=200,
130
+ metadata={
131
+ "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
132
+ },
133
+ )
134
+ sample_packing_group_size: int = field(
135
+ default=100000,
136
+ metadata={
137
+ "help": "The number of samples to group together for packing. Increase for better packing."
138
+ },
139
+ )
140
  max_seq_length: int = field(
141
  default=2048,
142
  metadata={"help": "The maximum sequence length the model can handle"},
143
  )
 
 
 
 
144
  relora_steps: Optional[int] = field(
145
  default=None,
146
  metadata={"help": "how often to reset for ReLoRA"},
 
354
  )
355
  return MultipackBatchSampler(
356
  RandomSampler(self.train_dataset),
 
 
 
357
  lengths=get_dataset_lengths(self.train_dataset),
358
+ batch_max_len=batch_max_len,
359
+ batch_size=batch_size,
360
+ group_size=self.args.sample_packing_group_size,
361
+ bin_size=self.args.sample_packing_bin_size,
362
  )
363
  if self.args.curriculum_sampling:
364
  return SequentialSampler(self.train_dataset)
 
378
  )
379
  return MultipackBatchSampler(
380
  SequentialSampler(eval_dataset),
381
+ lengths=get_dataset_lengths(self.eval_dataset),
 
382
  batch_max_len=batch_max_len,
383
+ batch_size=batch_size,
384
+ group_size=self.args.sample_packing_group_size,
385
+ bin_size=self.args.sample_packing_bin_size,
386
  )
387
  return super()._get_eval_sampler(eval_dataset)
388
 
 
1121
  if self.cfg.save_safetensors is not None:
1122
  training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
1123
 
 
 
 
 
 
1124
  if self.cfg.dataloader_pin_memory is not None:
1125
  training_arguments_kwargs[
1126
  "dataloader_pin_memory"
 
1296
  training_arguments_kwargs["weight_decay"] = (
1297
  self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
1298
  )
1299
+
1300
+ training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
 
 
 
 
 
 
 
 
 
1301
  training_arguments_kwargs[
1302
+ "multipack_real_batches"
1303
+ ] = not self.cfg.flash_attention
1304
+ training_arguments_kwargs["eval_sample_packing"] = bool(
1305
+ self.cfg.eval_sample_packing
1306
+ )
1307
+ if self.cfg.sample_packing_bin_size is not None:
1308
+ training_arguments_kwargs[
1309
+ "sample_packing_bin_size"
1310
+ ] = self.cfg.sample_packing_bin_size
1311
+ if self.cfg.sample_packing_group_size is not None:
1312
+ training_arguments_kwargs[
1313
+ "sample_packing_group_size"
1314
+ ] = self.cfg.sample_packing_group_size
1315
+ if self.cfg.sample_packing_eff_est:
1316
+ training_arguments_kwargs[
1317
+ "sample_packing_efficiency"
1318
+ ] = self.cfg.sample_packing_eff_est
1319
+
1320
  if self.cfg.relora_steps:
1321
  training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
1322
  training_arguments_kwargs[
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -551,6 +551,8 @@ class AxolotlInputConfig(
551
  default=512, metadata={"help": "maximum prompt length for RL training"}
552
  )
553
  sample_packing: Optional[bool] = None
 
 
554
  eval_sample_packing: Optional[bool] = None
555
  pad_to_sequence_len: Optional[bool] = None
556
  curriculum_sampling: Optional[bool] = None
 
551
  default=512, metadata={"help": "maximum prompt length for RL training"}
552
  )
553
  sample_packing: Optional[bool] = None
554
+ sample_packing_group_size: Optional[int] = 100_000
555
+ sample_packing_bin_size: Optional[int] = 200
556
  eval_sample_packing: Optional[bool] = None
557
  pad_to_sequence_len: Optional[bool] = None
558
  curriculum_sampling: Optional[bool] = None
src/axolotl/utils/data/pretraining.py CHANGED
@@ -150,6 +150,8 @@ def wrap_pretraining_dataset(
150
  max_seq_length=max_tokens,
151
  batch_size=batch_size,
152
  multipack_attn=cfg.pretrain_multipack_attn,
 
 
153
  )
154
  # set this to 1 so downstream data_loader doesn't try to increase the batch again
155
  cfg.micro_batch_size = 1
@@ -189,6 +191,8 @@ def encode_packed_pretraining(
189
  max_seq_length: int = 2048,
190
  batch_size: int = 4,
191
  multipack_attn: Optional[bool] = False,
 
 
192
  ) -> Dict[str, List]:
193
  # pylint: disable=duplicate-code
194
  # tokenize all the examples
@@ -202,11 +206,13 @@ def encode_packed_pretraining(
202
  )
203
 
204
  sampler = MultipackBatchSampler(
205
- RandomSampler(train_dataset),
 
206
  batch_size=1,
207
- drop_last=True,
208
  batch_max_len=batch_size * max_seq_length,
209
- lengths=get_dataset_lengths(train_dataset),
 
 
210
  )
211
 
212
  chunked_data = defaultdict(list)
 
150
  max_seq_length=max_tokens,
151
  batch_size=batch_size,
152
  multipack_attn=cfg.pretrain_multipack_attn,
153
+ group_size=cfg.sample_packing_group_size,
154
+ bin_size=cfg.sample_packing_bin_size,
155
  )
156
  # set this to 1 so downstream data_loader doesn't try to increase the batch again
157
  cfg.micro_batch_size = 1
 
191
  max_seq_length: int = 2048,
192
  batch_size: int = 4,
193
  multipack_attn: Optional[bool] = False,
194
+ group_size: int = 100000,
195
+ bin_size: int = 200,
196
  ) -> Dict[str, List]:
197
  # pylint: disable=duplicate-code
198
  # tokenize all the examples
 
206
  )
207
 
208
  sampler = MultipackBatchSampler(
209
+ sampler=RandomSampler(train_dataset),
210
+ lengths=get_dataset_lengths(train_dataset),
211
  batch_size=1,
 
212
  batch_max_len=batch_size * max_seq_length,
213
+ group_size=group_size,
214
+ bin_size=bin_size,
215
+ drop_last=True,
216
  )
217
 
218
  chunked_data = defaultdict(list)
src/axolotl/utils/samplers/multipack.py CHANGED
@@ -1,105 +1,64 @@
1
- # pylint: skip-file
2
  """
3
  Multipack Batch Sampler
4
  """
5
  import logging
6
- import math
7
- import os
8
- from typing import Any, Iterable, List, Union
9
 
10
  import numba
11
  import numpy as np
12
- from torch.utils.data import BatchSampler, Sampler
13
 
14
  LOG = logging.getLogger("axolotl.utils.samplers.multipack")
15
 
16
 
 
17
  @numba.njit
18
- def ffd_check(a: np.ndarray, c: int, n: int):
19
- # First-fit-decreasing bin packing
20
- # Check if a[] could fit in n bins with capacity c
21
- # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
22
-
23
- a = np.sort(a)[::-1]
24
- bins = np.full((n,), c, dtype=a.dtype)
25
- for size in a:
26
- not_found = True
27
- for idx in range(n):
28
- if bins[idx] >= size:
29
- bins[idx] -= size
30
- not_found = False
 
 
 
 
 
31
  break
32
 
33
- if not_found:
34
- return False
 
 
35
 
36
- return True
37
 
38
 
39
- @numba.njit
40
- def ffd_with_result(a: np.ndarray, c: int, start_index: int):
41
- # First-fit-decreasing bin packing (with result return)
42
-
43
- indices = np.argsort(a)[::-1]
44
- a = a[indices]
45
-
46
- bins: List[Any] = []
47
- bins_result: List[Any] = []
48
- for a_id, size in enumerate(a):
49
- add_new = True
50
- for idx in range(len(bins)):
51
- if bins[idx] >= size:
52
- bins[idx] -= size
53
- bins_result[idx].append(indices[a_id] + start_index)
54
- add_new = False
55
- break
56
-
57
- if add_new:
58
- bins.append(c - size)
59
- bins_result.append([indices[a_id] + start_index])
60
-
61
- return bins_result
62
-
63
-
64
- @numba.njit
65
- def allocate(
66
- lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
67
- ):
68
- # Dynamic batch allocator, similar to Multifit
69
- # https://en.wikipedia.org/wiki/Multifit_algorithm
70
- # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
71
-
72
- s = 0
73
- start_index = 0
74
- result = []
75
-
76
- while True:
77
- # binary search [l, r)
78
- left = 1
79
- right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
80
-
81
- while right - left > 1:
82
- mid = (left + right) // 2
83
- if ffd_check(lengths[start_index : start_index + mid], c, n):
84
- left = mid
85
- else:
86
- right = mid
87
-
88
- # use length l
89
- batch = ffd_with_result(
90
- lengths[start_index : start_index + left], c, start_index
91
- )
92
- assert len(batch) <= n
93
- if len(batch) < n:
94
- break
95
-
96
- start_index += left
97
- s = lengths_cumsum[start_index - 1]
98
 
99
- # add local rank
100
- result.append(batch[rank])
 
 
 
 
 
101
 
102
- return result, s, len(result) * c * n
103
 
104
 
105
  class MultipackBatchSampler(BatchSampler):
@@ -109,94 +68,63 @@ class MultipackBatchSampler(BatchSampler):
109
 
110
  def __init__(
111
  self,
112
- sampler: Union[Sampler[int], Iterable[int]],
113
- batch_size: int,
114
- drop_last: bool,
115
- batch_max_len: int,
116
- lengths: np.ndarray,
117
- packing_efficiency_estimate: float = 1.0,
 
118
  ):
119
- super().__init__(sampler, batch_size, drop_last)
120
- self.batch_size = batch_size
121
  self.batch_max_len = batch_max_len
122
- self.lengths: np.ndarray = lengths
123
- self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
124
-
125
- assert isinstance(self.lengths, np.ndarray)
126
-
127
- self.epoch = 0
128
-
129
- # statistics
130
- self.eff_total_used = 0
131
- self.eff_total_slots = 0
132
-
133
- def set_epoch(self, epoch: int):
134
- self.epoch = epoch
135
-
136
- def generate_batches(self, set_stats=False):
137
- indices = [idx for idx in self.sampler]
138
 
139
- lengths = self.lengths[indices]
140
- lengths_cumsum = np.cumsum(lengths)
141
 
142
- batches, total_used, total_slots = allocate(
143
- lengths=lengths,
144
- lengths_cumsum=lengths_cumsum,
145
- rank=0,
146
- c=self.batch_max_len,
147
- n=1,
 
 
 
 
 
 
 
 
 
148
  )
149
 
150
- batches = [
151
- [
152
- [indices[b_idx] for b_idx in batch]
153
- for batch in batches[i : i + self.batch_size]
154
- ]
155
- for i in range(0, len(batches), self.batch_size)
 
 
156
  ]
157
 
158
- # statistics
159
- if set_stats:
160
- self.eff_total_used += total_used
161
- self.eff_total_slots += total_slots
162
 
163
- return batches
164
 
165
  def __iter__(self):
166
- batches = self.generate_batches(set_stats=True)
167
- return iter(batches)
168
-
169
- def num_batches(self):
170
- batches = self.generate_batches(set_stats=True)
171
- return len(batches)
172
-
173
- def efficiency(self):
174
- return self.eff_total_used / self.eff_total_slots
175
 
176
  def __len__(self):
177
- self.num_batches()
178
- return self._len_est()
179
-
180
- def _len_est(self):
181
- world_size = int(os.getenv("WORLD_SIZE", "1"))
182
- lengths_sum = np.sum(self.lengths)
183
- lengths_sum_per_device = lengths_sum // world_size
184
- LOG.info(
185
- f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
186
- f"total_num_tokens per device: {lengths_sum_per_device}"
187
- )
188
-
189
- # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
190
- return max(
191
- 0,
192
- (
193
- world_size
194
- * math.floor(
195
- 0.99
196
- * lengths_sum_per_device
197
- / self.packing_efficiency_estimate
198
- // (self.batch_max_len * self.batch_size)
199
- )
200
- - 1
201
- ),
202
- )
 
 
1
  """
2
  Multipack Batch Sampler
3
  """
4
  import logging
5
+ from concurrent.futures import ProcessPoolExecutor
6
+ from multiprocessing import cpu_count
 
7
 
8
  import numba
9
  import numpy as np
10
+ from torch.utils.data import BatchSampler
11
 
12
  LOG = logging.getLogger("axolotl.utils.samplers.multipack")
13
 
14
 
15
+ # First-fit-decreasing bin packing.
16
  @numba.njit
17
+ def pack_group(items, group_offset, bin_capacity, max_items_per_bin):
18
+ idxs = np.argsort(items)[::-1]
19
+ sorted_items = items[idxs]
20
+ num_bins = len(items)
21
+ bins = np.full(num_bins, bin_capacity, dtype=np.int32)
22
+ bin_counts = np.zeros(num_bins, dtype=np.int32)
23
+ group_packing = np.full((num_bins, max_items_per_bin), -1, dtype=np.int32)
24
+
25
+ for idx, item in enumerate(sorted_items):
26
+ global_idx = idxs[idx] + group_offset
27
+
28
+ placed = False
29
+ for i in range(num_bins):
30
+ if bins[i] >= item and bin_counts[i] < max_items_per_bin:
31
+ bins[i] -= item
32
+ group_packing[i, bin_counts[i]] = global_idx
33
+ bin_counts[i] += 1
34
+ placed = True
35
  break
36
 
37
+ if not placed:
38
+ raise ValueError(
39
+ f"Item could not be packed. Try increasing cfg.sample_packing_bin_size ({max_items_per_bin})."
40
+ )
41
 
42
+ return group_packing
43
 
44
 
45
+ def pack(items, bin_capacity, group_size, max_items_per_bin):
46
+ num_items = len(items)
47
+ num_processes = max(1, min(num_items // group_size, cpu_count()))
48
+ tasks = [
49
+ (items[i : i + group_size], i, bin_capacity, max_items_per_bin)
50
+ for i in range(0, num_items, group_size)
51
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ packed_bins = []
54
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
55
+ for group_packing in executor.map(pack_group, *zip(*tasks)):
56
+ for bin_pack in group_packing:
57
+ filtered_pack = bin_pack[bin_pack != -1]
58
+ if filtered_pack.size > 0:
59
+ packed_bins.append(filtered_pack.tolist())
60
 
61
+ return packed_bins
62
 
63
 
64
  class MultipackBatchSampler(BatchSampler):
 
68
 
69
  def __init__(
70
  self,
71
+ sampler,
72
+ lengths,
73
+ batch_max_len,
74
+ batch_size,
75
+ group_size=100_000,
76
+ bin_size=200,
77
+ drop_last=False,
78
  ):
79
+ self.sampler = sampler
80
+ self.lengths = np.array(lengths, dtype=np.int32)
81
  self.batch_max_len = batch_max_len
82
+ self.batch_size = batch_size
83
+ self.group_size = group_size
84
+ self.bin_size = bin_size
85
+ self.drop_last = drop_last
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ self._efficiency = None
88
+ self._batches = None
89
 
90
+ def efficiency(self):
91
+ if self._efficiency is None:
92
+ self._batches = self._pack_batches()
93
+ return self._efficiency
94
+
95
+ def _pack_batches(self):
96
+ # Get possibly shuffled indices from sampler.
97
+ sample_idxs = np.arange(len(self.sampler))
98
+ lengths = self.lengths[sample_idxs]
99
+
100
+ pack_idxs = pack(
101
+ lengths,
102
+ self.batch_max_len,
103
+ self.group_size,
104
+ self.bin_size,
105
  )
106
 
107
+ used_tokens = self.lengths.sum()
108
+ available_tokens = len(pack_idxs) * self.batch_max_len
109
+ self._efficiency = used_tokens / available_tokens
110
+
111
+ # Wrap packs into batches.
112
+ batch_idxs = [
113
+ pack_idxs[i : i + self.batch_size]
114
+ for i in range(0, len(pack_idxs), self.batch_size)
115
  ]
116
 
117
+ # Drop last batch if needed.
118
+ if self.drop_last and len(batch_idxs[-1]) < self.batch_size:
119
+ batch_idxs = batch_idxs[:-1]
 
120
 
121
+ return batch_idxs
122
 
123
  def __iter__(self):
124
+ self._batches = self._pack_batches()
125
+ return iter(self._batches)
 
 
 
 
 
 
 
126
 
127
  def __len__(self):
128
+ if self._batches is None:
129
+ self._batches = self._pack_batches()
130
+ return len(self._batches)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/axolotl/utils/trainer.py CHANGED
@@ -341,27 +341,26 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
341
  )
342
  else:
343
  if cfg.flash_attention:
344
- batch_size = 1
345
  batch_max_len = cfg.micro_batch_size * cfg.sequence_len
346
  else:
347
- batch_size = cfg.micro_batch_size
348
  batch_max_len = cfg.sequence_len
349
  sampler = MultipackBatchSampler(
350
  sampler=RandomSampler(train_dataset),
351
- batch_size=batch_size,
352
- drop_last=True,
353
- batch_max_len=batch_max_len,
354
  lengths=get_dataset_lengths(train_dataset),
 
 
 
 
 
355
  )
356
 
357
  data_loader = DataLoader(
358
  train_dataset.remove_columns(["length"]),
359
  batch_sampler=sampler,
360
  )
361
- data_loader_len = len(data_loader) // (
362
- cfg.world_size * cfg.gradient_accumulation_steps
363
- )
364
- actual_eff = sampler.efficiency()
365
  LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
366
  # FIXME: is there a bug here somewhere? the total num steps depends
367
  # on the agreed on value for sample_packing_eff_est
@@ -372,7 +371,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
372
  return max(estimates)
373
 
374
  sample_packing_actual_eff_all = reduce_and_broadcast(
375
- lambda: actual_eff,
376
  calc_sample_packing_eff_est,
377
  )
378
  sample_packing_eff_est = (
 
341
  )
342
  else:
343
  if cfg.flash_attention:
344
+ sampler_batch_size = 1
345
  batch_max_len = cfg.micro_batch_size * cfg.sequence_len
346
  else:
347
+ sampler_batch_size = cfg.micro_batch_size
348
  batch_max_len = cfg.sequence_len
349
  sampler = MultipackBatchSampler(
350
  sampler=RandomSampler(train_dataset),
 
 
 
351
  lengths=get_dataset_lengths(train_dataset),
352
+ batch_size=sampler_batch_size,
353
+ batch_max_len=batch_max_len,
354
+ group_size=cfg.sample_packing_group_size,
355
+ bin_size=cfg.sample_packing_bin_size,
356
+ drop_last=True,
357
  )
358
 
359
  data_loader = DataLoader(
360
  train_dataset.remove_columns(["length"]),
361
  batch_sampler=sampler,
362
  )
363
+ data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
 
 
 
364
  LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
365
  # FIXME: is there a bug here somewhere? the total num steps depends
366
  # on the agreed on value for sample_packing_eff_est
 
371
  return max(estimates)
372
 
373
  sample_packing_actual_eff_all = reduce_and_broadcast(
374
+ lambda: sampler.efficiency(), # pylint: disable=unnecessary-lambda
375
  calc_sample_packing_eff_est,
376
  )
377
  sample_packing_eff_est = (
tests/test_packed_batch_sampler.py CHANGED
@@ -62,12 +62,14 @@ class TestBatchedSamplerPacking:
62
  dataset,
63
  )
64
  train_dataset = concatenate_datasets([dataset_wrapper])
 
65
  batch_sampler = MultipackBatchSampler(
66
  sampler=RandomSampler(train_dataset),
 
67
  batch_size=batch_size,
68
- drop_last=True,
69
  batch_max_len=max_seq_length,
70
- lengths=get_dataset_lengths(train_dataset),
 
71
  )
72
 
73
  loader = DataLoader(
@@ -81,19 +83,15 @@ class TestBatchedSamplerPacking:
81
  ),
82
  num_workers=num_workers,
83
  )
84
- inputs = next(iter(loader))
85
 
86
- assert inputs["input_ids"].shape == (batch_size, max_seq_length)
87
- assert inputs["labels"].shape == (batch_size, max_seq_length)
88
- assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
 
89
 
90
- assert inputs["input_ids"].tolist()[0][0] == 2
91
- assert inputs["labels"].tolist()[0][0] == -100
92
- assert inputs["attention_mask"].tolist()[0][0] == 0
93
- assert inputs["attention_mask"].tolist()[0][-1] > 1
94
 
95
- if batch_size >= 2:
96
- assert inputs["input_ids"].tolist()[1][0] == 2
97
- assert inputs["labels"].tolist()[1][0] == -100
98
- assert inputs["attention_mask"].tolist()[1][0] == 0
99
- assert inputs["attention_mask"].tolist()[1][-1] > 1
 
62
  dataset,
63
  )
64
  train_dataset = concatenate_datasets([dataset_wrapper])
65
+ lengths = get_dataset_lengths(train_dataset)
66
  batch_sampler = MultipackBatchSampler(
67
  sampler=RandomSampler(train_dataset),
68
+ lengths=lengths,
69
  batch_size=batch_size,
 
70
  batch_max_len=max_seq_length,
71
+ group_size=100000,
72
+ bin_size=200,
73
  )
74
 
75
  loader = DataLoader(
 
83
  ),
84
  num_workers=num_workers,
85
  )
 
86
 
87
+ batch_idxs = []
88
+ for batch in batch_sampler:
89
+ for pack in batch:
90
+ batch_idxs.extend(pack)
91
 
92
+ for batch in loader:
93
+ assert len(batch["input_ids"]) <= batch_size * max_seq_length
94
+ assert batch["input_ids"].shape[1] == max_seq_length
 
95
 
96
+ original_idxs = set(range(len(train_dataset)))
97
+ assert original_idxs == set(batch_idxs)
 
 
 
tests/test_packed_pretraining.py CHANGED
@@ -42,6 +42,8 @@ class TestPretrainingPacking(unittest.TestCase):
42
  "pad_to_sequence_len": True,
43
  "sequence_len": 2048,
44
  "micro_batch_size": 2,
 
 
45
  }
46
  )
47
 
 
42
  "pad_to_sequence_len": True,
43
  "sequence_len": 2048,
44
  "micro_batch_size": 2,
45
+ "sample_packing_group_size": 100000,
46
+ "sample_packing_bin_size": 200,
47
  }
48
  )
49