don't train if eval split is too small (#873)
Browse files* allow zero len dataset
* better handling and warning of small eval splits
* raise error if eval split is too small
* don't mess with calculating total num steps in distributed context
* fix eval_sample_packing training args logic
src/axolotl/core/trainer_builder.py
CHANGED
@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
658 |
self.cfg.sample_packing if self.cfg.sample_packing else False
|
659 |
)
|
660 |
training_arguments_kwargs["eval_sample_packing"] = (
|
661 |
-
self.cfg.sample_packing
|
|
|
|
|
662 |
)
|
663 |
training_arguments_kwargs[
|
664 |
"sample_packing_seq_len_multiplier"
|
|
|
658 |
self.cfg.sample_packing if self.cfg.sample_packing else False
|
659 |
)
|
660 |
training_arguments_kwargs["eval_sample_packing"] = (
|
661 |
+
self.cfg.sample_packing
|
662 |
+
if self.cfg.eval_sample_packing is not False
|
663 |
+
else False
|
664 |
)
|
665 |
training_arguments_kwargs[
|
666 |
"sample_packing_seq_len_multiplier"
|
src/axolotl/utils/data.py
CHANGED
@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
|
|
79 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
80 |
cfg, train_dataset, eval_dataset, tokenizer
|
81 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
if cfg.max_steps:
|
83 |
total_num_steps = min(
|
84 |
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
|
|
79 |
train_dataset, eval_dataset = process_datasets_for_packing(
|
80 |
cfg, train_dataset, eval_dataset, tokenizer
|
81 |
)
|
82 |
+
|
83 |
+
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
84 |
+
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
85 |
+
if total_eval_steps == 0:
|
86 |
+
raise ValueError(
|
87 |
+
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
88 |
+
)
|
89 |
+
|
90 |
if cfg.max_steps:
|
91 |
total_num_steps = min(
|
92 |
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
src/axolotl/utils/samplers/multipack.py
CHANGED
@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
|
|
182 |
|
183 |
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
184 |
return max(
|
185 |
-
|
186 |
(
|
187 |
world_size
|
188 |
* math.floor(
|
|
|
182 |
|
183 |
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
184 |
return max(
|
185 |
+
0,
|
186 |
(
|
187 |
world_size
|
188 |
* math.floor(
|
src/axolotl/utils/trainer.py
CHANGED
@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|
141 |
return train_dataset, eval_dataset
|
142 |
|
143 |
|
144 |
-
def calculate_total_num_steps(cfg, train_dataset):
|
145 |
if not cfg.total_num_tokens:
|
146 |
total_num_tokens = np.sum(
|
147 |
train_dataset.data.column("input_ids")
|
@@ -150,7 +150,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|
150 |
.values
|
151 |
)
|
152 |
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
153 |
-
|
|
|
154 |
|
155 |
if not cfg.total_supervised_tokens:
|
156 |
total_supervised_tokens = (
|
@@ -163,7 +164,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|
163 |
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
164 |
main_process_only=True,
|
165 |
)
|
166 |
-
|
|
|
167 |
|
168 |
if cfg.sample_packing:
|
169 |
# we have to drop anything longer then sequence len otherwise
|
@@ -232,7 +234,8 @@ def calculate_total_num_steps(cfg, train_dataset):
|
|
232 |
sample_packing_eff_est = (
|
233 |
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
234 |
)
|
235 |
-
|
|
|
236 |
LOG.debug(
|
237 |
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
238 |
main_process_only=True,
|
|
|
141 |
return train_dataset, eval_dataset
|
142 |
|
143 |
|
144 |
+
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
145 |
if not cfg.total_num_tokens:
|
146 |
total_num_tokens = np.sum(
|
147 |
train_dataset.data.column("input_ids")
|
|
|
150 |
.values
|
151 |
)
|
152 |
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
153 |
+
if update:
|
154 |
+
cfg.total_num_tokens = total_num_tokens
|
155 |
|
156 |
if not cfg.total_supervised_tokens:
|
157 |
total_supervised_tokens = (
|
|
|
164 |
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
165 |
main_process_only=True,
|
166 |
)
|
167 |
+
if update:
|
168 |
+
cfg.total_supervised_tokens = total_supervised_tokens
|
169 |
|
170 |
if cfg.sample_packing:
|
171 |
# we have to drop anything longer then sequence len otherwise
|
|
|
234 |
sample_packing_eff_est = (
|
235 |
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
236 |
)
|
237 |
+
if update:
|
238 |
+
cfg.sample_packing_eff_est = sample_packing_eff_est
|
239 |
LOG.debug(
|
240 |
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
241 |
main_process_only=True,
|