winglian committed
Commit 797f3dd
1 Parent(s): 0de1457

don't train if eval split is too small (#873)


* allow zero-length dataset

* better handling and warning of small eval splits

* raise error if eval split is too small

* don't interfere with the total num steps calculation in a distributed context

* fix eval_sample_packing training args logic

src/axolotl/core/trainer_builder.py CHANGED
@@ -658,7 +658,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             self.cfg.sample_packing if self.cfg.sample_packing else False
         )
         training_arguments_kwargs["eval_sample_packing"] = (
-            self.cfg.sample_packing if self.cfg.sample_packing else False
+            self.cfg.sample_packing
+            if self.cfg.eval_sample_packing is not False
+            else False
         )
         training_arguments_kwargs[
             "sample_packing_seq_len_multiplier"
src/axolotl/utils/data.py CHANGED
@@ -79,6 +79,14 @@ def prepare_dataset(cfg, tokenizer):
     train_dataset, eval_dataset = process_datasets_for_packing(
         cfg, train_dataset, eval_dataset, tokenizer
     )
+
+    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
+        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
+        if total_eval_steps == 0:
+            raise ValueError(
+                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
+            )
+
     if cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
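
Why a tiny split trips this guard: with packing on, eval batch counts are measured in packed sequences, and a short split can round down to zero. A back-of-envelope sketch with assumed numbers (not axolotl's exact formula):

# Hypothetical illustration: all constants below are made up to show the
# rounding, not taken from axolotl's actual step math.
seq_len = 4096
micro_batch_size = 2
world_size = 2
eval_tokens = 6_000                                       # tiny eval split
packed_seqs = eval_tokens // seq_len                      # 1 packed sequence
batches = packed_seqs // (micro_batch_size * world_size)  # 0 batches
print(batches)  # 0 -> the new guard raises before training begins

The error message points at the remedy: setting `eval_sample_packing: False` makes eval batches unpacked, so even a handful of rows still yields eval steps.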
src/axolotl/utils/samplers/multipack.py CHANGED
@@ -182,7 +182,7 @@ class MultipackBatchSampler(BatchSampler):
 
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
        return max(
-            1,
+            0,
             (
                 world_size
                 * math.floor(
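
This one-character change matters because the estimate was previously clamped to at least 1, so an effectively empty eval split still reported a batch. A hypothetical illustration of the estimate's shape (the constants and exact expression are assumptions abbreviated from the real sampler):

import math

# Assumed shape of MultipackBatchSampler's length estimate: shave ~1% + 1
# for packing variance, then clamp. Not the sampler's exact formula.
def estimated_len(packed_batches: int, world_size: int) -> int:
    shaved = world_size * math.floor(0.99 * packed_batches / world_size) - 1
    return max(0, shaved)  # was max(1, ...), which masked empty eval splits

print(estimated_len(1, 2))     # 0 -> surfaces as the new ValueError upstream
print(estimated_len(1000, 2))  # 989, unchanged for realistically sized data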
src/axolotl/utils/trainer.py CHANGED
@@ -141,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     return train_dataset, eval_dataset
 
 
-def calculate_total_num_steps(cfg, train_dataset):
+def calculate_total_num_steps(cfg, train_dataset, update=True):
     if not cfg.total_num_tokens:
         total_num_tokens = np.sum(
             train_dataset.data.column("input_ids")
@@ -150,7 +150,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             .values
         )
         LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-        cfg.total_num_tokens = total_num_tokens
+        if update:
+            cfg.total_num_tokens = total_num_tokens
 
     if not cfg.total_supervised_tokens:
         total_supervised_tokens = (
@@ -163,7 +164,8 @@ def calculate_total_num_steps(cfg, train_dataset):
             f"`total_supervised_tokens: {total_supervised_tokens}`",
             main_process_only=True,
         )
-        cfg.total_supervised_tokens = total_supervised_tokens
+        if update:
+            cfg.total_supervised_tokens = total_supervised_tokens
 
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
@@ -232,7 +234,8 @@ def calculate_total_num_steps(cfg, train_dataset):
         sample_packing_eff_est = (
             math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
         )
-        cfg.sample_packing_eff_est = sample_packing_eff_est
+        if update:
+            cfg.sample_packing_eff_est = sample_packing_eff_est
         LOG.debug(
             f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
             main_process_only=True,
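
The new `update` keyword makes the function usable as a pure probe: `prepare_dataset` can size the eval split without caching eval-side token totals onto `cfg`, which would otherwise skew the later train-side step calculation (the distributed-context concern in the commit message). A minimal runnable sketch, with a simplified cfg and token counting standing in for the real dataset plumbing:

from types import SimpleNamespace

# Simplified stand-in: the real function also gates total_supervised_tokens
# and sample_packing_eff_est behind the same `update` flag.
def calculate_total_num_steps(cfg, dataset_token_lens, update=True):
    total_num_tokens = sum(dataset_token_lens)
    if update:
        # default path: cache the total onto cfg for later reuse
        cfg.total_num_tokens = total_num_tokens
    return total_num_tokens  # stand-in for the derived step count

cfg = SimpleNamespace(total_num_tokens=None)
calculate_total_num_steps(cfg, [4096, 2048], update=False)  # probe eval split
assert cfg.total_num_tokens is None   # cfg untouched, train math unaffected
calculate_total_num_steps(cfg, [8192, 8192])                # train path
assert cfg.total_num_tokens == 16384  # cached as before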