winglian committed
Commit 2ce5c0d
1 Parent(s): 3db5f2f

Deprecate max packed sequence len (#1141)

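This removes `max_packed_sequence_len` and the legacy `ConstantLengthDataset` packing path it enabled. Setting the option now raises a hard `DeprecationWarning` in `validate_config` instead of logging a pending-deprecation message; `sample_packing` is the supported replacement. Along the way, `pretraining_dataset` is normalized to a list of dicts, packing-related dataset processing moves into `load_tokenized_prepared_datasets` so it is captured in the prepared-dataset cache, a warning nudges users to pre-process datasets before training, and the pretraining tokenization map now runs with an explicit `batch_size=10_000`.

A rough migration sketch (hypothetical config values; `DictDefault` and `validate_config` used the same way as in the tests below):

```python
import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# Old style: the deprecated key now hard-fails validation.
with pytest.raises(DeprecationWarning):
    validate_config(DictDefault({"max_packed_sequence_len": 1024}))

# New style: sample_packing, ideally with pad_to_sequence_len enabled.
validate_config(
    DictDefault(
        {
            "sequence_len": 2048,
            "sample_packing": True,
            "pad_to_sequence_len": True,
        }
    )
)
```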
README.md CHANGED
@@ -642,10 +642,6 @@ sequence_len: 2048
 # Pad inputs so each step uses constant sized buffers
 # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
 pad_to_sequence_len:
-# Max sequence length to concatenate training samples together up to
-# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
-# FutureWarning: This will soon be DEPRECATED
-max_packed_sequence_len: 1024
 # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
 sample_packing:
 # Set to 'false' if getting errors during eval with sample_packing on.
src/axolotl/utils/config.py CHANGED
@@ -157,6 +157,9 @@ def normalize_config(cfg):
     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)

+    if isinstance(cfg.pretraining_dataset, dict):
+        cfg.pretraining_dataset = [cfg.pretraining_dataset]
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)


@@ -192,18 +195,8 @@ def validate_config(cfg):
         raise ValueError(
             "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
         )
-    if cfg.max_packed_sequence_len and cfg.sample_packing:
-        raise ValueError(
-            "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
-        )
     if cfg.max_packed_sequence_len:
-        LOG.warning(
-            str(
-                PendingDeprecationWarning(
-                    "max_packed_sequence_len will be deprecated in favor of sample_packing"
-                )
-            )
-        )
+        raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")

     if cfg.sample_packing and not cfg.pad_to_sequence_len:
         LOG.warning(
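For reference, a minimal sketch of what the new `normalize_config` branch buys downstream code (the dataset path and name here are made-up placeholders):

```python
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"pretraining_dataset": {"path": "my/corpus", "name": "default"}})

# A bare dict is wrapped into a one-element list, so prepare_dataset can
# always read cfg.pretraining_dataset[0]["path"] / [0]["name"].
if isinstance(cfg.pretraining_dataset, dict):
    cfg.pretraining_dataset = [cfg.pretraining_dataset]

assert cfg.pretraining_dataset[0]["path"] == "my/corpus"
```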
src/axolotl/utils/data.py CHANGED
@@ -19,7 +19,7 @@ from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase

 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
+from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
@@ -71,9 +71,11 @@ def prepare_dataset(cfg, tokenizer):
     else:
         path = cfg.pretraining_dataset
         name = None
-        if isinstance(cfg.pretraining_dataset, dict):
-            path = cfg.pretraining_dataset["path"]
-            name = cfg.pretraining_dataset["name"]
+        if isinstance(cfg.pretraining_dataset, list) and isinstance(
+            cfg.pretraining_dataset[0], dict
+        ):
+            path = cfg.pretraining_dataset[0]["path"]
+            name = cfg.pretraining_dataset[0]["name"]

         train_dataset = load_pretraining_dataset(
             path,
@@ -88,11 +90,6 @@ def prepare_dataset(cfg, tokenizer):
         eval_dataset = None
         return train_dataset, eval_dataset, cfg.max_steps, prompters

-    with zero_first(is_main_process()):
-        train_dataset, eval_dataset = process_datasets_for_packing(
-            cfg, train_dataset, eval_dataset, tokenizer
-        )
-
     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
@@ -163,6 +160,10 @@ def load_tokenized_prepared_datasets(
     else:
         LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         LOG.info("Loading raw datasets...")
+        if not cfg.is_preprocess:
+            LOG.warning(
+                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
+            )

         if cfg.seed:
             seed = cfg.seed
@@ -382,6 +383,9 @@ def load_tokenized_prepared_datasets(
         if len(datasets) > 1:
             LOG.info("shuffle merged datasets")
             dataset = dataset.shuffle(seed=seed)
+
+        dataset, _ = process_datasets_for_packing(cfg, dataset, None, tokenizer)
+
         if cfg.local_rank == 0:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(prepared_ds_path)
@@ -419,119 +423,9 @@ def load_prepare_datasets(
     cfg,
     default_dataset_prepared_path,
 ) -> Tuple[Dataset, Dataset, List[Prompter]]:
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    dataset, prompters = load_tokenized_prepared_datasets(
+        tokenizer, cfg, default_dataset_prepared_path
     )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
-
-    tokenizer_name = tokenizer.__class__.__name__
-    prompters: List[Prompter] = []
-    if cfg.max_packed_sequence_len is not None:
-        # see if we can go ahead and load the stacked dataset
-        seed = f"@{str(cfg.seed)}" if cfg.seed else ""
-        ds_hash = str(
-            md5(
-                (
-                    str(cfg.sequence_len)
-                    + "@"
-                    + str(max_packed_sequence_len)
-                    + seed
-                    + "|".join(
-                        sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
-                    )
-                    + "|"
-                    + tokenizer_name
-                )
-            )
-        )
-        prepared_ds_path = (
-            Path(cfg.dataset_prepared_path) / ds_hash
-            if cfg.dataset_prepared_path
-            else Path(default_dataset_prepared_path) / ds_hash
-        )
-
-        dataset = None
-        use_auth_token = cfg.hf_use_auth_token
-        try:
-            if cfg.push_dataset_to_hub:
-                LOG.info(
-                    f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                )
-                dataset = load_dataset(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}",
-                    token=use_auth_token,
-                )
-                dataset = dataset["train"]
-        except Exception:  # pylint: disable=broad-except # nosec
-            pass
-
-        if dataset:
-            ...
-        elif (
-            cfg.dataset_prepared_path
-            and any(prepared_ds_path.glob("*"))
-            and not cfg.is_preprocess
-        ):
-            LOG.info(
-                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
-            )
-            dataset = load_from_disk(str(prepared_ds_path))
-            LOG.info("Prepared packed dataset loaded from disk...")
-            if cfg.push_dataset_to_hub:
-                LOG.info(
-                    f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                )
-                dataset.push_to_hub(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
-                )
-        else:
-            dataset, prompters = load_tokenized_prepared_datasets(
-                tokenizer, cfg, default_dataset_prepared_path
-            )
-
-            if cfg.seed:
-                dataset = dataset.shuffle(seed=cfg.seed)
-
-            constant_len_dataset = ConstantLengthDataset(
-                tokenizer,
-                [dataset],
-                seq_length=max_packed_sequence_len,
-            )
-            LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-            dataset = Dataset.from_list(list(constant_len_dataset))
-
-            # filter out bad data
-            # TODO convert to dataset.filter(...)
-            dataset = Dataset.from_list(
-                [
-                    d
-                    for d in dataset
-                    if len(d["input_ids"]) <= cfg.sequence_len
-                    and len(d["input_ids"]) > 0
-                    and len(d["input_ids"]) == len(d["attention_mask"])
-                    and len(d["input_ids"]) == len(d["labels"])
-                ]
-            )
-
-            if cfg.local_rank == 0:
-                LOG.info(
-                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
-                )
-                dataset.save_to_disk(prepared_ds_path)
-                if cfg.push_dataset_to_hub:
-                    LOG.info(
-                        f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
-                    )
-                    dataset.push_to_hub(
-                        f"{cfg.push_dataset_to_hub}/{ds_hash}",
-                        private=True,
-                    )
-    else:
-        dataset, prompters = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
-        )

     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         LOG.info(
@@ -877,6 +771,7 @@ def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, s
     dataset = dataset.map(
         encode,
         batched=True,
+        batch_size=10_000,
         input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
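On the `batch_size=10_000` addition: a self-contained sketch of how the batched map behaves with `input_columns` (toy `encode`; the real one tokenizes up to `max_tokens`):

```python
from datasets import Dataset

def encode(texts):
    # Stand-in for the real tokenizing closure. With input_columns="text",
    # the mapped function receives the column values directly.
    return {"input_ids": [[len(t)] for t in texts]}

ds = Dataset.from_dict({"text": ["hello", "world"]})
ds = ds.map(
    encode,
    batched=True,
    batch_size=10_000,  # bigger batches: fewer map calls, higher peak memory
    input_columns="text",
    # drop originals, whose length may not match the encoded rows
    remove_columns=ds.column_names,
)
print(ds[0]["input_ids"])  # [5]
```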
src/axolotl/utils/models.py CHANGED
@@ -329,11 +329,7 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

-    if (
-        cfg.is_llama_derived_model
-        and (cfg.max_packed_sequence_len or cfg.sample_packing)
-        and not inference
-    ):
+    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

         LOG.info("patching _expand_mask")
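Since `validate_config` now raises on any `max_packed_sequence_len`, the old `(cfg.max_packed_sequence_len or cfg.sample_packing)` condition could only ever be reached through `sample_packing`, so the simplified guard is behavior-preserving for validated configs.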
src/axolotl/utils/trainer.py CHANGED
@@ -81,6 +81,15 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
     return weighted_cross_entropy(logits, labels, weights)


+@contextmanager
+def disable_datasets_caching():
+    try:
+        set_caching_enabled(False)
+        yield
+    finally:
+        set_caching_enabled(True)
+
+
 def add_position_ids(sample):
     sample_len = len(sample["input_ids"])
     sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -97,15 +106,6 @@ def drop_long_seq(sample, sequence_len=2048):
     return len(sample["input_ids"]) <= sequence_len and len(sample["input_ids"]) > 0


-@contextmanager
-def disable_datasets_caching():
-    try:
-        set_caching_enabled(False)
-        yield
-    finally:
-        set_caching_enabled(True)
-
-
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
@@ -227,8 +227,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
             sampler=RandomSampler(train_dataset),
             batch_size=cfg.micro_batch_size,
             drop_last=True,
-            batch_max_len=cfg.micro_batch_size
-            * (cfg.max_packed_sequence_len or cfg.sequence_len),
+            batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
             lengths=get_dataset_lengths(train_dataset),
         )
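Usage sketch for the relocated `disable_datasets_caching` helper (illustrative dataset and map function; assumes the helper stays importable from `axolotl.utils.trainer`):

```python
from datasets import Dataset

from axolotl.utils.trainer import disable_datasets_caching

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})

with disable_datasets_caching():
    # map() calls inside the block skip the on-disk datasets cache;
    # the finally clause re-enables caching even if the map raises.
    ds = ds.map(lambda sample: {"length": len(sample["input_ids"])})

print(ds["length"])  # [3, 2]
```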
 
tests/test_validation.py CHANGED
@@ -324,20 +324,19 @@ class ValidationTest(BaseValidation):

         validate_config(cfg)

-    def test_packing(self):
+    def test_deprecated_packing(self):
         cfg = DictDefault(
             {
-                "max_packed_sequence_len": 2048,
+                "max_packed_sequence_len": 1024,
             }
         )
-        with self._caplog.at_level(logging.WARNING):
+        with pytest.raises(
+            DeprecationWarning,
+            match=r"`max_packed_sequence_len` is no longer supported",
+        ):
             validate_config(cfg)
-            assert any(
-                "max_packed_sequence_len will be deprecated in favor of sample_packing"
-                in record.message
-                for record in self._caplog.records
-            )

+    def test_packing(self):
         cfg = DictDefault(
             {
                 "sample_packing": True,
@@ -352,16 +351,6 @@ class ValidationTest(BaseValidation):
             for record in self._caplog.records
         )

-        cfg = DictDefault(
-            {
-                "max_packed_sequence_len": 2048,
-                "sample_packing": True,
-            }
-        )
-        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
-        with pytest.raises(ValueError, match=regex_exp):
-            validate_config(cfg)
-
     @pytest.mark.skipif(
         is_torch_bf16_gpu_available(),
         reason="test should only run on gpus w/o bf16 support",
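A standalone sketch of the assertion pattern `test_deprecated_packing` relies on: `validate_config` *raises* `DeprecationWarning` as an exception rather than emitting a warning, so `pytest.raises` (not `pytest.warns`) is the right assertion:

```python
import pytest

def check(cfg: dict) -> None:
    # mirrors the new validate_config branch
    if cfg.get("max_packed_sequence_len"):
        raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")

with pytest.raises(DeprecationWarning, match=r"no longer supported"):
    check({"max_packed_sequence_len": 1024})
```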