ricdomolm and winglian committed
Commit
81d3845
1 Parent(s): 732851f

Efficiently get the length of the tokenized docs (#1063)


* Efficiently get the length of the tokenized docs

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
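
This change replaces the per-call-site pandas round trip (materializing the Arrow position_ids column and applying a Python lambda per row to recover each document's length as position_ids[-1] + 1) with a shared helper, get_dataset_lengths, which prefers a precomputed length column and falls back to the old computation otherwise. Below is a minimal sketch of why the fast path is cheaper; the toy pyarrow table is illustrative and not from the repo, and it assumes numpy and pyarrow are installed.

# Illustrative sketch (not part of the commit): why the fast path wins.
import numpy as np
import pyarrow as pa

table = pa.table(
    {
        "position_ids": [[0, 1, 2], [0, 1], [0, 1, 2, 3]],
        "length": [3, 2, 4],
    }
)

# Old approach: materialize a pandas Series of per-row arrays, then run a
# Python-level lambda on every row -- O(rows) interpreter overhead.
slow = table.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1).values

# New approach: a single vectorized read of a precomputed column.
fast = np.array(table.column("length"))

assert (slow == fast).all()  # both give [3, 2, 4]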

src/axolotl/core/trainer_builder.py CHANGED
@@ -37,7 +37,7 @@ from axolotl.utils.collators import (
     DataCollatorForSeq2Seq,
     MambaDataCollator,
 )
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
 try:
@@ -170,12 +170,7 @@ class AxolotlTrainer(Trainer):
                 self.args.train_batch_size,
                 drop_last=True,
                 batch_max_len=self._train_batch_size * self.args.max_seq_length,
-                lengths=(
-                    self.train_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(self.train_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_train_sampler()
@@ -189,12 +184,7 @@ class AxolotlTrainer(Trainer):
                 self.args.per_device_eval_batch_size,
                 drop_last=True,
                 batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
-                lengths=(
-                    eval_dataset.data.column("position_ids")
-                    .to_pandas()
-                    .apply(lambda x: x[-1] + 1)
-                    .values
-                ),
+                lengths=get_dataset_lengths(eval_dataset),
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
             )
         return super()._get_eval_sampler(eval_dataset)
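
For context, the lengths array tells MultipackBatchSampler how many tokens each document holds, so it can pack several short documents into one batch_max_len window. The following is a toy first-fit sketch of that idea, an illustration only; the real sampler's packing algorithm and constructor signature differ.

# Toy illustration of what the sampler uses `lengths` for: a simplified
# first-fit packer (the real MultipackBatchSampler is more involved).
def pack_first_fit(lengths, batch_max_len):
    bins = []  # each bin: [remaining capacity, [document indices]]
    for idx, length in enumerate(lengths):
        for bin_ in bins:
            if bin_[0] >= length:  # document fits in this bin
                bin_[0] -= length
                bin_[1].append(idx)
                break
        else:  # no bin had room; open a new one
            bins.append([batch_max_len - length, [idx]])
    return [bin_[1] for bin_ in bins]

print(pack_first_fit([3, 2, 4, 1], batch_max_len=5))  # -> [[0, 1], [2, 3]]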
src/axolotl/utils/data.py CHANGED
@@ -44,7 +44,7 @@ from axolotl.prompters import (
 from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
-from axolotl.utils.samplers.multipack import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
@@ -889,12 +889,7 @@ def encode_packed_pretraining(
         batch_size=batch_size,
         drop_last=True,
         batch_max_len=batch_size * max_seq_length,
-        lengths=(
-            train_dataset.data.column("position_ids")
-            .to_pandas()
-            .apply(lambda x: x[-1] + 1)
-            .values
-        ),
+        lengths=get_dataset_lengths(train_dataset),
     )
 
     chunked_data = defaultdict(list)
src/axolotl/utils/samplers/__init__.py CHANGED
@@ -2,3 +2,4 @@
 axolotl samplers module
 """
 from .multipack import MultipackBatchSampler  # noqa: F401
+from .utils import get_dataset_lengths  # noqa: F401
src/axolotl/utils/samplers/utils.py ADDED
@@ -0,0 +1,17 @@
+"""
+helper util to calculate dataset lengths
+"""
+import numpy as np
+
+
+def get_dataset_lengths(dataset):
+    if "length" in dataset.data.column_names:
+        lengths = np.array(dataset.data.column("length"))
+    else:
+        lengths = (
+            dataset.data.column("position_ids")
+            .to_pandas()
+            .apply(lambda x: x[-1] + 1)
+            .values
+        )
+    return lengths
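
A quick sanity check of both branches of the new helper. This usage example is illustrative, not part of the commit; the toy dataset assumes the Hugging Face datasets library is installed.

from datasets import Dataset

from axolotl.utils.samplers import get_dataset_lengths

ds = Dataset.from_dict({"position_ids": [[0, 1, 2], [0, 1]]})
print(get_dataset_lengths(ds))  # fallback path (pandas apply) -> [3 2]

ds = ds.add_column("length", [3, 2])
print(get_dataset_lengths(ds))  # fast path (vectorized read) -> [3 2]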
src/axolotl/utils/trainer.py CHANGED
@@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, RandomSampler
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
-from axolotl.utils.samplers import MultipackBatchSampler
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
 LOG = get_logger("axolotl")
 
@@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         drop_last=True,
         batch_max_len=cfg.micro_batch_size
         * (cfg.max_packed_sequence_len or cfg.sequence_len),
-        lengths=(
-            train_dataset.data.column("position_ids")
-            .to_pandas()
-            .apply(lambda x: x[-1] + 1)
-            .values
-        ),
+        lengths=get_dataset_lengths(train_dataset),
     )
 
     data_loader = DataLoader(
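
Note that the fast path in get_dataset_lengths only fires when the dataset already carries a length column. A hypothetical preprocessing step that would guarantee this (an assumption for illustration, not shown in this commit):

# Hypothetical: precompute `length` during tokenization so that
# get_dataset_lengths takes the vectorized fast path instead of the
# pandas fallback. Assumes `train_dataset` is a tokenized
# datasets.Dataset with an `input_ids` column.
train_dataset = train_dataset.map(
    lambda example: {"length": len(example["input_ids"])}
)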