Feat(config): add max steps (#387)
Changed files:
- scripts/finetune.py +7 -1
- src/axolotl/utils/trainer.py +1 -1

scripts/finetune.py
@@ -209,7 +209,13 @@ def train(
         cfg, train_dataset, eval_dataset
     )
     barrier()
-
+    if cfg.max_steps:
+        total_num_steps = min(
+            calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
+        )
+        LOG.info(f"Maximum number of steps set at {total_num_steps}")
+    else:
+        total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)

     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
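The new branch caps the planned training length at cfg.max_steps (set in the YAML config, e.g. `max_steps: 500`; value illustrative) by taking the minimum of the estimated step count and the cap. `calculate_total_num_steps` itself is not shown in this diff; the sketch below uses a hypothetical `estimate_total_steps` with the standard steps-per-epoch formula purely to illustrate how the cap interacts with the estimate:

```python
import math

def estimate_total_steps(num_samples, micro_batch_size,
                         gradient_accumulation_steps, world_size, num_epochs):
    # Hypothetical stand-in for calculate_total_num_steps: samples consumed per
    # optimizer step = micro_batch_size * gradient_accumulation_steps * world_size.
    steps_per_epoch = math.ceil(
        num_samples / (micro_batch_size * gradient_accumulation_steps * world_size)
    )
    return steps_per_epoch * num_epochs

# 10_000 samples, micro batch 4, grad accum 4, 2 GPUs, 3 epochs -> 313 * 3 = 939 steps.
estimated = estimate_total_steps(10_000, 4, 4, 2, 3)

# With max_steps: 500 in the config the plan is capped at 500; without it,
# all 939 estimated steps are kept.
total_num_steps = min(estimated, 500)
```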

src/axolotl/utils/trainer.py
@@ -461,7 +461,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
         evaluation_strategy = "steps"

     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-
+        max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
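The `-1` fallback follows the Hugging Face `TrainingArguments` convention (which `AxolotlTrainingArguments` extends): `max_steps=-1` leaves the training length to `num_train_epochs`, while any positive value overrides the epoch count and stops after that many optimizer steps. A minimal standalone sketch against `transformers`, not axolotl code:

```python
from transformers import TrainingArguments

# max_steps=-1 (the default) means no step cap; length comes from num_train_epochs.
no_cap = TrainingArguments(output_dir="out", num_train_epochs=3, max_steps=-1)

# A positive max_steps overrides the epoch count.
capped = TrainingArguments(output_dir="out", num_train_epochs=3, max_steps=500)

print(no_cap.max_steps, capped.max_steps)  # -1 500
```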