winglian committed
Commit 7570446
1 Parent(s): ece0211

Preprocess dataset size fix (#1131)


* overwrite cache on preprocess step
* don't cache the TokenizedPromptDataset at all
* load_from_cache_file no longer needed
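
All three changes come down to how the Hugging Face `datasets` library caches `map` and `filter` results. A minimal sketch of the two knobs involved (toy data and a hypothetical `tokenize` helper, not code from this repo):

# Sketch only, not part of the commit: toy data plus a hypothetical
# `tokenize` helper, to show the two `datasets` caching knobs in play.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"]})

def tokenize(example):
    return {"input_ids": [ord(c) for c in example["text"]]}

# By default, map() fingerprints the call and silently reuses a matching
# cached result from an earlier run, which is presumably how a rerun of
# `preprocess` could report a stale dataset size.
cached = ds.map(tokenize)

# load_from_cache_file=False forces a fresh pass (overwriting the cache);
# keep_in_memory=True keeps the result out of the on-disk cache entirely.
fresh = ds.map(tokenize, load_from_cache_file=False, keep_in_memory=True)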

src/axolotl/cli/preprocess.py CHANGED
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((PreprocessCliArgs))
src/axolotl/datasets.py CHANGED
@@ -35,7 +35,10 @@ class TokenizedPromptDataset(Dataset):
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
-        super().__init__(self.process(dataset).data, **kwargs)
+        super().__init__(
+            self.process(dataset).data,
+            **kwargs,
+        )
 
     def process(self, dataset):
         features = dataset.features.keys()
@@ -52,6 +55,7 @@ class TokenizedPromptDataset(Dataset):
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
+            keep_in_memory=True,
             **map_kwargs,
         )
 
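`keep_in_memory=True` is a standard `datasets.Dataset.map` argument: the mapped result is held in memory instead of being written to an Arrow cache file, so the tokenized prompt dataset is never cached at all. A rough sketch under assumed names (`tokenize_prompt` below is a hypothetical stand-in for the strategy's `prompt_tokenizer.tokenize_prompt`, not the actual API):

# Rough sketch of what process() now does, with a stand-in tokenizer.
from datasets import Dataset

def tokenize_prompt(example):  # hypothetical stand-in
    return {"input_ids": [ord(c) for c in example["instruction"]]}

dataset = Dataset.from_dict({"instruction": ["say hi", "say bye"]})

# Mirrors TokenizedPromptDataset.process: drop the raw columns and keep
# the tokenized result in memory, so no cache file is created for this step.
tokenized = dataset.map(
    tokenize_prompt,
    remove_columns=list(dataset.features.keys()),
    keep_in_memory=True,
)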
src/axolotl/utils/data.py CHANGED
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
         )
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -610,7 +614,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -622,7 +628,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -634,7 +642,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -646,7 +656,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
        )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -658,7 +670,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -670,7 +684,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -682,7 +698,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -694,7 +712,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     else:
src/axolotl/utils/trainer.py CHANGED
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     with zero_first(is_main_process()):
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
-                add_length, num_proc=cfg.dataset_processes
+                add_length,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
 
         if cfg.sample_packing:
             train_dataset = train_dataset.map(
-                add_position_ids, num_proc=cfg.dataset_processes
+                add_position_ids,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=cfg.dataset_processes
+                        add_position_ids,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
                     )
 
         if cfg.group_by_length or cfg.sample_packing:
             max_input_len = np.max(get_dataset_lengths(train_dataset))
             LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
 
-        train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
+        train_dataset = train_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+        )
         if eval_dataset:
             eval_dataset = eval_dataset.filter(
-                drop_long, num_proc=cfg.dataset_processes
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
 
         # Phi doesn't want the attention_mask feature when training
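
Every `map`/`filter` call above follows the same pattern: a normal training run keeps reusing cached Arrow files, while `axolotl preprocess` (which sets `cfg.is_preprocess = True`) recomputes and overwrites them. A condensed sketch with illustrative stand-ins for the repo's helpers and config:

# Condensed sketch; `drop_long`, the length cutoff, and `is_preprocess`
# are illustrative stand-ins for the repo's actual helpers and cfg values.
from datasets import Dataset

def drop_long(example):
    return len(example["input_ids"]) <= 2048

train_dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], [4] * 5000]})

is_preprocess = True  # what `parsed_cfg.is_preprocess = True` enables in the CLI

# During preprocess this evaluates to load_from_cache_file=False, forcing a
# fresh pass that overwrites any stale cache; during training the flag is
# falsy (`not None` is True), so cached results are reused as before.
train_dataset = train_dataset.filter(
    drop_long,
    load_from_cache_file=not is_preprocess,
)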