Preprocess dataset size fix (#1131)

* overwrite cache on preprocess step
* don't cache the TokenizedPromptDataset at all
* load_from_cache_file no longer needed

Files changed:

- src/axolotl/cli/preprocess.py +1 -0
- src/axolotl/datasets.py +5 -1
- src/axolotl/utils/data.py +30 -10
- src/axolotl/utils/trainer.py +17 -5
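
For context, the fix leans on two knobs of the Hugging Face `datasets` library. The snippet below is a standalone sketch (toy data, not taken from this PR) of what each knob does:

```python
# Standalone sketch of the two `datasets` knobs this PR relies on;
# the toy dataset here is illustrative only.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"]})

# load_from_cache_file=False makes map() recompute instead of reusing a
# previously cached result; for disk-backed datasets this rewrites the
# cached arrow file.
fresh = ds.map(lambda ex: {"n": len(ex["text"])}, load_from_cache_file=False)

# keep_in_memory=True keeps the mapped result in RAM, so no intermediate
# arrow file is written to the on-disk cache at all.
in_mem = ds.map(lambda ex: {"n": len(ex["text"])}, keep_in_memory=True)
```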
src/axolotl/cli/preprocess.py

```diff
@@ -25,6 +25,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((PreprocessCliArgs))
```
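
The `is_preprocess` flag set here is consumed in src/axolotl/utils/trainer.py below: each `map()`/`filter()` call passes `load_from_cache_file=not cfg.is_preprocess`, so a preprocess run recomputes and overwrites the cached results instead of reusing stale ones.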
src/axolotl/datasets.py

```diff
@@ -35,7 +35,10 @@ class TokenizedPromptDataset(Dataset):
     ):
         self.prompt_tokenizer = prompt_tokenizer
         self.process_count = process_count
-        super().__init__(self.process(dataset).data, **kwargs)
+        super().__init__(
+            self.process(dataset).data,
+            **kwargs,
+        )
 
     def process(self, dataset):
         features = dataset.features.keys()
@@ -52,6 +55,7 @@ class TokenizedPromptDataset(Dataset):
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=num_proc,
             remove_columns=features,
+            keep_in_memory=True,
             **map_kwargs,
         )
 
```
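
With `keep_in_memory=True`, the tokenized intermediate is never written to the `datasets` cache; the wrapper is built directly on the resulting Arrow table. Below is a minimal sketch of the same pattern, using an illustrative stand-in class (the class name and `tokenize_fn` parameter are ours, not axolotl's):

```python
from datasets import Dataset

class InMemoryTokenizedDataset(Dataset):
    """Illustrative stand-in for TokenizedPromptDataset."""

    def __init__(self, tokenize_fn, dataset, process_count=1, **kwargs):
        processed = dataset.map(
            tokenize_fn,
            num_proc=process_count,
            remove_columns=list(dataset.features.keys()),
            keep_in_memory=True,  # tokenized rows stay in RAM, never cached to disk
        )
        # A Dataset can be constructed directly from the backing Arrow table.
        super().__init__(processed.data, **kwargs)
```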
src/axolotl/utils/data.py

```diff
@@ -594,12 +594,16 @@ def get_dataset_wrapper(
         )
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
         dataset_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -610,7 +614,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
@@ -622,7 +628,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
@@ -634,7 +642,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
@@ -646,7 +656,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
@@ -658,7 +670,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
@@ -670,7 +684,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
@@ -682,7 +698,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
         )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
@@ -694,7 +712,9 @@ def get_dataset_wrapper(
             cfg.sequence_len,
         )
         ds_wrapper = TokenizedPromptDataset(
-            ds_strategy, dataset, process_count=cfg.dataset_processes
+            ds_strategy,
+            dataset,
+            process_count=cfg.dataset_processes,
        )
         dataset_wrapper = ds_wrapper
     else:
```
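
The ten hunks in data.py apply the same mechanical change at every `TokenizedPromptDataset(...)` call site, each ending up as the identical three-argument call (`ds_strategy`, `dataset`, `process_count=cfg.dataset_processes`) matching the constructor shown in datasets.py above.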
src/axolotl/utils/trainer.py

```diff
@@ -111,27 +111,39 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     with zero_first(is_main_process()):
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
-                add_length, num_proc=cfg.dataset_processes
+                add_length,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
 
         if cfg.sample_packing:
             train_dataset = train_dataset.map(
-                add_position_ids, num_proc=cfg.dataset_processes
+                add_position_ids,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
             )
             if cfg.eval_sample_packing is not False:
                 if eval_dataset:
                     eval_dataset = eval_dataset.map(
-                        add_position_ids, num_proc=cfg.dataset_processes
+                        add_position_ids,
+                        num_proc=cfg.dataset_processes,
+                        load_from_cache_file=not cfg.is_preprocess,
                     )
 
         if cfg.group_by_length or cfg.sample_packing:
             max_input_len = np.max(get_dataset_lengths(train_dataset))
             LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
 
-            train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
+            train_dataset = train_dataset.filter(
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+            )
             if eval_dataset:
                 eval_dataset = eval_dataset.filter(
-                    drop_long, num_proc=cfg.dataset_processes
+                    drop_long,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
                 )
 
     # Phi doesn't want the attention_mask feature when training
```
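
Every `map()`/`filter()` call in this hunk now repeats the same two keyword arguments. A hedged sketch of that pattern factored into a helper (the helper name is ours, not axolotl's):

```python
def map_with_cache_policy(dataset, fn, cfg):
    """map() with the cache policy the diff above applies everywhere."""
    return dataset.map(
        fn,
        num_proc=cfg.dataset_processes,
        # True while training (reuse the existing cache); False during
        # `axolotl preprocess` (recompute and overwrite the cached arrow files).
        load_from_cache_file=not cfg.is_preprocess,
    )
```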