run eval on the first step to get a baseline (#617)
Browse files* run eval on the first step to get a baseline
* wandb keeps getting moved around by pre-commit ...
src/axolotl/utils/callbacks.py
CHANGED
@@ -66,6 +66,29 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|
66 |
return control
|
67 |
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
class SaveBetterTransformerModelCallback(
|
70 |
TrainerCallback
|
71 |
): # pylint: disable=too-few-public-methods
|
|
|
66 |
return control
|
67 |
|
68 |
|
69 |
+
class EvalFirstStepCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """
    Trainer callback that requests an evaluation pass once the first
    training step has finished, so a baseline metric is available early.
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Flag an evaluation at the end of step 1 when step-based eval is on.

        ``control.should_evaluate`` is set to ``True`` only if all of the
        following hold, checked in this order:

        * ``args.evaluation_strategy`` equals ``IntervalStrategy.STEPS``
        * ``args.eval_steps`` is below ``1.0``
          (NOTE(review): presumably the fractional-ratio form of
          ``eval_steps`` -- confirm against the TrainingArguments docs)
        * ``state.global_step`` equals ``1``

        The (possibly updated) ``control`` object is always returned.
        """
        # Guard clauses keep the original short-circuit order: eval_steps is
        # only inspected once the strategy is known to be step-based.
        if not args.evaluation_strategy == IntervalStrategy.STEPS:
            return control
        if not args.eval_steps < 1.0:
            return control

        if state.global_step == 1:
            control.should_evaluate = True
        return control
|
90 |
+
|
91 |
+
|
92 |
class SaveBetterTransformerModelCallback(
|
93 |
TrainerCallback
|
94 |
): # pylint: disable=too-few-public-methods
|
src/axolotl/utils/trainer.py
CHANGED
@@ -28,6 +28,7 @@ from transformers.trainer_pt_utils import SequentialDistributedSampler
|
|
28 |
|
29 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
30 |
from axolotl.utils.callbacks import (
|
|
|
31 |
GPUStatsCallback,
|
32 |
SaveBetterTransformerModelCallback,
|
33 |
SavePeftModelCallback,
|
@@ -704,6 +705,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
704 |
|
705 |
callbacks = []
|
706 |
callbacks.append(GPUStatsCallback(cfg))
|
|
|
707 |
|
708 |
if cfg.relora_steps:
|
709 |
callbacks.append(ReLoRACallback(cfg))
|
|
|
28 |
|
29 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
30 |
from axolotl.utils.callbacks import (
|
31 |
+
EvalFirstStepCallback,
|
32 |
GPUStatsCallback,
|
33 |
SaveBetterTransformerModelCallback,
|
34 |
SavePeftModelCallback,
|
|
|
705 |
|
706 |
callbacks = []
|
707 |
callbacks.append(GPUStatsCallback(cfg))
|
708 |
+
callbacks.append(EvalFirstStepCallback)
|
709 |
|
710 |
if cfg.relora_steps:
|
711 |
callbacks.append(ReLoRACallback(cfg))
|