improve save callbacks (#1592)
Browse files
src/axolotl/core/trainer_builder.py
CHANGED
@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
|
|
43 |
LossWatchDogCallback,
|
44 |
SaveAxolotlConfigtoWandBCallback,
|
45 |
SaveBetterTransformerModelCallback,
|
|
|
46 |
bench_eval_callback_factory,
|
47 |
causal_lm_bench_eval_callback_factory,
|
48 |
log_prediction_callback_factory,
|
@@ -888,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
|
|
888 |
callbacks.append(
|
889 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
890 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
891 |
|
892 |
return callbacks
|
893 |
|
@@ -933,18 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
933 |
):
|
934 |
callbacks.append(SaveBetterTransformerModelCallback())
|
935 |
|
936 |
-
if self.cfg.use_mlflow and is_mlflow_available():
|
937 |
-
from axolotl.utils.callbacks.mlflow_ import (
|
938 |
-
SaveAxolotlConfigtoMlflowCallback,
|
939 |
-
)
|
940 |
-
|
941 |
-
callbacks.append(
|
942 |
-
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
943 |
-
)
|
944 |
-
|
945 |
if self.cfg.loss_watchdog_threshold is not None:
|
946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
947 |
|
|
|
|
|
948 |
return callbacks
|
949 |
|
950 |
def get_post_trainer_create_callbacks(self, trainer):
|
@@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1427 |
|
1428 |
def get_callbacks(self):
    """Return the callback list produced by the parent builder, unchanged."""
    return super().get_callbacks()
|
1431 |
|
1432 |
def get_post_trainer_create_callbacks(self, trainer):
|
|
|
43 |
LossWatchDogCallback,
|
44 |
SaveAxolotlConfigtoWandBCallback,
|
45 |
SaveBetterTransformerModelCallback,
|
46 |
+
SaveModelOnTrainEndCallback,
|
47 |
bench_eval_callback_factory,
|
48 |
causal_lm_bench_eval_callback_factory,
|
49 |
log_prediction_callback_factory,
|
|
|
889 |
callbacks.append(
|
890 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
891 |
)
|
892 |
+
if self.cfg.use_mlflow and is_mlflow_available():
|
893 |
+
from axolotl.utils.callbacks.mlflow_ import (
|
894 |
+
SaveAxolotlConfigtoMlflowCallback,
|
895 |
+
)
|
896 |
+
|
897 |
+
callbacks.append(
|
898 |
+
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
899 |
+
)
|
900 |
|
901 |
return callbacks
|
902 |
|
|
|
942 |
):
|
943 |
callbacks.append(SaveBetterTransformerModelCallback())
|
944 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
945 |
if self.cfg.loss_watchdog_threshold is not None:
|
946 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
947 |
|
948 |
+
callbacks.append(SaveModelOnTrainEndCallback())
|
949 |
+
|
950 |
return callbacks
|
951 |
|
952 |
def get_post_trainer_create_callbacks(self, trainer):
|
|
|
1429 |
|
1430 |
def get_callbacks(self):
    """Return the parent builder's callbacks plus a save-on-train-end callback."""
    rl_callbacks = super().get_callbacks()
    # Appended last, after everything the parent builder registered.
    rl_callbacks.append(SaveModelOnTrainEndCallback())
    return rl_callbacks
|
1435 |
|
1436 |
def get_post_trainer_create_callbacks(self, trainer):
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
773 |
except (FileNotFoundError, ConnectionError) as err:
|
774 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
775 |
return control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
773 |
except (FileNotFoundError, ConnectionError) as err:
|
774 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
775 |
return control
|
776 |
+
|
777 |
+
|
778 |
+
class SaveModelOnTrainEndCallback(TrainerCallback):
    """Trainer callback that flags a model save once training has ended."""

    def on_train_end(  # pylint: disable=unused-argument
        self,
        args,
        state,
        control,
        **kwargs,
    ):
        """Set ``control.should_save`` to ``True`` and return ``control``.

        ``args``, ``state`` and ``kwargs`` are accepted to match the callback
        hook signature but are not read here.
        """
        control.should_save = True
        return control
|