Nanobit commited on
Commit
3db5f2f
·
unverified ·
1 Parent(s): cbecf3e

feat(dataset): add config to keep processed dataset in memory (#1152)

Browse files
Files changed (3) hide show
  1. README.md +3 -0
  2. src/axolotl/datasets.py +7 -6
  3. src/axolotl/utils/data.py +15 -10
README.md CHANGED
@@ -618,6 +618,9 @@ push_dataset_to_hub: # repo path
618
  # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
619
  # if not set.
620
  dataset_processes: # defaults to os.cpu_count() if not set
 
 
 
621
  # push checkpoints to hub
622
  hub_model_id: # repo path to push finetuned model
623
  # how to push checkpoints to hub
 
618
  # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
619
  # if not set.
620
  dataset_processes: # defaults to os.cpu_count() if not set
621
+ # Keep dataset in memory while preprocessing
622
+ # Only needed if cached dataset is taking too much storage
623
+ dataset_keep_in_memory:
624
  # push checkpoints to hub
625
  hub_model_id: # repo path to push finetuned model
626
  # how to push checkpoints to hub
src/axolotl/datasets.py CHANGED
@@ -24,6 +24,8 @@ class TokenizedPromptDataset(Dataset):
24
  Args:
25
  prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
26
  dataset (dataset.Dataset): Dataset with text files.
 
 
27
  """
28
 
29
  def __init__( # pylint: disable=super-init-not-called
@@ -31,10 +33,12 @@ class TokenizedPromptDataset(Dataset):
31
  prompt_tokenizer: PromptTokenizingStrategy,
32
  dataset: IterableDataset,
33
  process_count: Optional[int] = None,
 
34
  **kwargs,
35
  ):
36
  self.prompt_tokenizer = prompt_tokenizer
37
  self.process_count = process_count
 
38
  super().__init__(
39
  self.process(dataset).data,
40
  **kwargs,
@@ -42,11 +46,8 @@ class TokenizedPromptDataset(Dataset):
42
 
43
  def process(self, dataset):
44
  features = dataset.features.keys()
45
- num_proc = (
46
- min(64, self.process_count)
47
- if self.process_count
48
- else min(64, os.cpu_count())
49
- )
50
  map_kwargs = {}
51
  if self.prompt_tokenizer.supports_batched:
52
  map_kwargs["batched"] = True
@@ -55,7 +56,7 @@ class TokenizedPromptDataset(Dataset):
55
  self.prompt_tokenizer.tokenize_prompt,
56
  num_proc=num_proc,
57
  remove_columns=features,
58
- keep_in_memory=True,
59
  **map_kwargs,
60
  )
61
 
 
24
  Args:
25
  prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
26
  dataset (dataset.Dataset): Dataset with text files.
27
+ process_count (int): Number of processes to use for tokenizing.
28
+ keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
29
  """
30
 
31
  def __init__( # pylint: disable=super-init-not-called
 
33
  prompt_tokenizer: PromptTokenizingStrategy,
34
  dataset: IterableDataset,
35
  process_count: Optional[int] = None,
36
+ keep_in_memory: Optional[bool] = False,
37
  **kwargs,
38
  ):
39
  self.prompt_tokenizer = prompt_tokenizer
40
  self.process_count = process_count
41
+ self.keep_in_memory = keep_in_memory
42
  super().__init__(
43
  self.process(dataset).data,
44
  **kwargs,
 
46
 
47
  def process(self, dataset):
48
  features = dataset.features.keys()
49
+ num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
50
+
 
 
 
51
  map_kwargs = {}
52
  if self.prompt_tokenizer.supports_batched:
53
  map_kwargs["batched"] = True
 
56
  self.prompt_tokenizer.tokenize_prompt,
57
  num_proc=num_proc,
58
  remove_columns=features,
59
+ keep_in_memory=self.keep_in_memory,
60
  **map_kwargs,
61
  )
62
 
src/axolotl/utils/data.py CHANGED
@@ -588,6 +588,11 @@ def get_dataset_wrapper(
588
  dataset_wrapper = None
589
  dataset_prompter = None
590
 
 
 
 
 
 
591
  if (
592
  "input_ids" in dataset.features
593
  and "attention_mask" in dataset.features
@@ -604,14 +609,14 @@ def get_dataset_wrapper(
604
  dataset_wrapper = TokenizedPromptDataset(
605
  ds_strategy,
606
  dataset,
607
- process_count=cfg.dataset_processes,
608
  )
609
  elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
610
  dataset_prompter = UnsupportedPrompter()
611
  dataset_wrapper = TokenizedPromptDataset(
612
  ds_strategy,
613
  dataset,
614
- process_count=cfg.dataset_processes,
615
  )
616
  elif d_base_type == "alpaca":
617
  dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -624,7 +629,7 @@ def get_dataset_wrapper(
624
  ds_wrapper = TokenizedPromptDataset(
625
  ds_strategy,
626
  dataset,
627
- process_count=cfg.dataset_processes,
628
  )
629
  dataset_wrapper = ds_wrapper
630
  elif d_base_type == "explainchoice":
@@ -638,7 +643,7 @@ def get_dataset_wrapper(
638
  ds_wrapper = TokenizedPromptDataset(
639
  ds_strategy,
640
  dataset,
641
- process_count=cfg.dataset_processes,
642
  )
643
  dataset_wrapper = ds_wrapper
644
  elif d_base_type == "concisechoice":
@@ -652,7 +657,7 @@ def get_dataset_wrapper(
652
  ds_wrapper = TokenizedPromptDataset(
653
  ds_strategy,
654
  dataset,
655
- process_count=cfg.dataset_processes,
656
  )
657
  dataset_wrapper = ds_wrapper
658
  elif d_base_type == "summarizetldr":
@@ -666,7 +671,7 @@ def get_dataset_wrapper(
666
  ds_wrapper = TokenizedPromptDataset(
667
  ds_strategy,
668
  dataset,
669
- process_count=cfg.dataset_processes,
670
  )
671
  dataset_wrapper = ds_wrapper
672
  elif d_base_type == "jeopardy":
@@ -680,7 +685,7 @@ def get_dataset_wrapper(
680
  ds_wrapper = TokenizedPromptDataset(
681
  ds_strategy,
682
  dataset,
683
- process_count=cfg.dataset_processes,
684
  )
685
  dataset_wrapper = ds_wrapper
686
  elif d_base_type == "oasst":
@@ -694,7 +699,7 @@ def get_dataset_wrapper(
694
  ds_wrapper = TokenizedPromptDataset(
695
  ds_strategy,
696
  dataset,
697
- process_count=cfg.dataset_processes,
698
  )
699
  dataset_wrapper = ds_wrapper
700
  elif d_base_type == "gpteacher":
@@ -708,7 +713,7 @@ def get_dataset_wrapper(
708
  ds_wrapper = TokenizedPromptDataset(
709
  ds_strategy,
710
  dataset,
711
- process_count=cfg.dataset_processes,
712
  )
713
  dataset_wrapper = ds_wrapper
714
  elif d_base_type == "reflection":
@@ -722,7 +727,7 @@ def get_dataset_wrapper(
722
  ds_wrapper = TokenizedPromptDataset(
723
  ds_strategy,
724
  dataset,
725
- process_count=cfg.dataset_processes,
726
  )
727
  dataset_wrapper = ds_wrapper
728
  else:
 
588
  dataset_wrapper = None
589
  dataset_prompter = None
590
 
591
+ ds_kwargs = {
592
+ "process_count": cfg.dataset_processes,
593
+ "keep_in_memory": cfg.dataset_keep_in_memory is True,
594
+ }
595
+
596
  if (
597
  "input_ids" in dataset.features
598
  and "attention_mask" in dataset.features
 
609
  dataset_wrapper = TokenizedPromptDataset(
610
  ds_strategy,
611
  dataset,
612
+ **ds_kwargs,
613
  )
614
  elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
615
  dataset_prompter = UnsupportedPrompter()
616
  dataset_wrapper = TokenizedPromptDataset(
617
  ds_strategy,
618
  dataset,
619
+ **ds_kwargs,
620
  )
621
  elif d_base_type == "alpaca":
622
  dataset_prompter = AlpacaPrompter(d_prompt_style)
 
629
  ds_wrapper = TokenizedPromptDataset(
630
  ds_strategy,
631
  dataset,
632
+ **ds_kwargs,
633
  )
634
  dataset_wrapper = ds_wrapper
635
  elif d_base_type == "explainchoice":
 
643
  ds_wrapper = TokenizedPromptDataset(
644
  ds_strategy,
645
  dataset,
646
+ **ds_kwargs,
647
  )
648
  dataset_wrapper = ds_wrapper
649
  elif d_base_type == "concisechoice":
 
657
  ds_wrapper = TokenizedPromptDataset(
658
  ds_strategy,
659
  dataset,
660
+ **ds_kwargs,
661
  )
662
  dataset_wrapper = ds_wrapper
663
  elif d_base_type == "summarizetldr":
 
671
  ds_wrapper = TokenizedPromptDataset(
672
  ds_strategy,
673
  dataset,
674
+ **ds_kwargs,
675
  )
676
  dataset_wrapper = ds_wrapper
677
  elif d_base_type == "jeopardy":
 
685
  ds_wrapper = TokenizedPromptDataset(
686
  ds_strategy,
687
  dataset,
688
+ **ds_kwargs,
689
  )
690
  dataset_wrapper = ds_wrapper
691
  elif d_base_type == "oasst":
 
699
  ds_wrapper = TokenizedPromptDataset(
700
  ds_strategy,
701
  dataset,
702
+ **ds_kwargs,
703
  )
704
  dataset_wrapper = ds_wrapper
705
  elif d_base_type == "gpteacher":
 
713
  ds_wrapper = TokenizedPromptDataset(
714
  ds_strategy,
715
  dataset,
716
+ **ds_kwargs,
717
  )
718
  dataset_wrapper = ds_wrapper
719
  elif d_base_type == "reflection":
 
727
  ds_wrapper = TokenizedPromptDataset(
728
  ds_strategy,
729
  dataset,
730
+ **ds_kwargs,
731
  )
732
  dataset_wrapper = ds_wrapper
733
  else: