feature: loss watchdog for terminating training runs that are failing (#899)
Files changed:
- README.md +3 -0
- examples/mistral/qlora.yml +3 -0
- src/axolotl/core/trainer_builder.py +4 -0
- src/axolotl/utils/callbacks.py +30 -0
README.md

@@ -694,6 +694,9 @@ max_steps:
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 
+loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
+loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
+
 # Save model as safetensors (require safetensors package)
 save_safetensors:
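A quick illustration of the "~2 times the loss at the start of training" guidance above (purely a sketch with made-up numbers, not part of this PR); the Mistral QLoRA example that follows uses the resulting 5.0:

```python
# Illustrative only: derive a watchdog threshold from the first logged training loss.
initial_loss = 2.5  # hypothetical first logged loss for your run
loss_watchdog_threshold = round(2 * initial_loss, 1)
print(loss_watchdog_threshold)  # 5.0, the value used in examples/mistral/qlora.yml
```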
examples/mistral/qlora.yml

@@ -62,6 +62,9 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
 
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
 warmup_steps: 10
 eval_steps: 0.05
 eval_table_size:
src/axolotl/core/trainer_builder.py

@@ -25,6 +25,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
+    LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -430,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
 
+        if self.cfg.loss_watchdog_threshold is not None:
+            callbacks.append(LossWatchDogCallback(self.cfg))
+
         return callbacks
 
     def get_post_trainer_create_callbacks(self, trainer):
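The watchdog is opt-in: the builder only registers it when loss_watchdog_threshold is set. A minimal sketch of that guard in isolation (maybe_add_loss_watchdog and the SimpleNamespace configs are hypothetical stand-ins, not axolotl APIs; the callback itself only reads the two loss_watchdog_* fields):

```python
# Hypothetical helper mirroring the guard added to get_callbacks(); not from the PR.
from types import SimpleNamespace

from axolotl.utils.callbacks import LossWatchDogCallback


def maybe_add_loss_watchdog(callbacks, cfg):
    # Same condition as the builder: do nothing unless a threshold is configured.
    if cfg.loss_watchdog_threshold is not None:
        callbacks.append(LossWatchDogCallback(cfg))
    return callbacks


disabled = SimpleNamespace(loss_watchdog_threshold=None, loss_watchdog_patience=None)
enabled = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)

print(len(maybe_add_loss_watchdog([], disabled)))  # 0: watchdog not registered
print(len(maybe_add_loss_watchdog([], enabled)))   # 1: watchdog registered
```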
src/axolotl/utils/callbacks.py

@@ -124,6 +124,36 @@ class GPUStatsCallback(
         return control
 
 
+class LossWatchDogCallback(TrainerCallback):
+    """Callback to track loss and stop training if loss is too high"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+        self.violations = 0
+        self.threshold = cfg.loss_watchdog_threshold
+        self.patience = cfg.loss_watchdog_patience or 3
+
+    def on_step_end(
+        self,
+        _args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ):
+        if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
+            if state.log_history[-1]["loss"] > self.threshold:
+                self.violations += 1
+                if self.violations >= self.patience:
+                    LOG.warning(
+                        "Loss is too high, stopping training (loss_watchdog_threshold)"
+                    )
+                    control.should_training_stop = True
+            else:
+                self.violations = 0
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
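To sanity-check the patience/reset behaviour without launching a training run, a standalone sketch like the one below drives the callback by hand (the SimpleNamespace config and the loss values are made up; only the two watchdog fields are read by the callback):

```python
# Standalone check of LossWatchDogCallback; not part of the PR.
from types import SimpleNamespace

from transformers import TrainerControl, TrainerState

from axolotl.utils.callbacks import LossWatchDogCallback

cfg = SimpleNamespace(loss_watchdog_threshold=5.0, loss_watchdog_patience=3)
watchdog = LossWatchDogCallback(cfg)

state = TrainerState()      # log_history starts empty
control = TrainerControl()  # should_training_stop starts False

# Three logged losses above the threshold in a row exhaust the patience of 3.
for step, loss in enumerate([6.1, 7.3, 9.0], start=1):
    state.log_history.append({"loss": loss, "step": step})
    # The callback never reads the TrainingArguments (note the _args name),
    # so passing None is fine for this offline check.
    watchdog.on_step_end(None, state, control)

print(control.should_training_stop)  # True: the run would be aborted
```

The watchdog only inspects losses that reach state.log_history; the example config keeps logging_steps: 1, so every optimizer step contributes a fresh loss entry for it to check.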