winglian commited on
Commit
29cf15a
1 Parent(s): dde02fc

improve save callbacks (#1592)

Browse files
src/axolotl/core/trainer_builder.py CHANGED
@@ -43,6 +43,7 @@ from axolotl.utils.callbacks import (
43
  LossWatchDogCallback,
44
  SaveAxolotlConfigtoWandBCallback,
45
  SaveBetterTransformerModelCallback,
 
46
  bench_eval_callback_factory,
47
  causal_lm_bench_eval_callback_factory,
48
  log_prediction_callback_factory,
@@ -888,6 +889,14 @@ class TrainerBuilderBase(abc.ABC):
888
  callbacks.append(
889
  SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
890
  )
 
 
 
 
 
 
 
 
891
 
892
  return callbacks
893
 
@@ -933,18 +942,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
933
  ):
934
  callbacks.append(SaveBetterTransformerModelCallback())
935
 
936
- if self.cfg.use_mlflow and is_mlflow_available():
937
- from axolotl.utils.callbacks.mlflow_ import (
938
- SaveAxolotlConfigtoMlflowCallback,
939
- )
940
-
941
- callbacks.append(
942
- SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
943
- )
944
-
945
  if self.cfg.loss_watchdog_threshold is not None:
946
  callbacks.append(LossWatchDogCallback(self.cfg))
947
 
 
 
948
  return callbacks
949
 
950
  def get_post_trainer_create_callbacks(self, trainer):
@@ -1427,6 +1429,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1427
 
1428
  def get_callbacks(self):
1429
  callbacks = super().get_callbacks()
 
 
1430
  return callbacks
1431
 
1432
  def get_post_trainer_create_callbacks(self, trainer):
 
43
  LossWatchDogCallback,
44
  SaveAxolotlConfigtoWandBCallback,
45
  SaveBetterTransformerModelCallback,
46
+ SaveModelOnTrainEndCallback,
47
  bench_eval_callback_factory,
48
  causal_lm_bench_eval_callback_factory,
49
  log_prediction_callback_factory,
 
889
  callbacks.append(
890
  SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
891
  )
892
+ if self.cfg.use_mlflow and is_mlflow_available():
893
+ from axolotl.utils.callbacks.mlflow_ import (
894
+ SaveAxolotlConfigtoMlflowCallback,
895
+ )
896
+
897
+ callbacks.append(
898
+ SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
899
+ )
900
 
901
  return callbacks
902
 
 
942
  ):
943
  callbacks.append(SaveBetterTransformerModelCallback())
944
 
 
 
 
 
 
 
 
 
 
945
  if self.cfg.loss_watchdog_threshold is not None:
946
  callbacks.append(LossWatchDogCallback(self.cfg))
947
 
948
+ callbacks.append(SaveModelOnTrainEndCallback())
949
+
950
  return callbacks
951
 
952
  def get_post_trainer_create_callbacks(self, trainer):
 
1429
 
1430
  def get_callbacks(self):
1431
  callbacks = super().get_callbacks()
1432
+ callbacks.append(SaveModelOnTrainEndCallback())
1433
+
1434
  return callbacks
1435
 
1436
  def get_post_trainer_create_callbacks(self, trainer):
src/axolotl/utils/callbacks/__init__.py CHANGED
@@ -773,3 +773,13 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
773
  except (FileNotFoundError, ConnectionError) as err:
774
  LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
775
  return control
 
 
 
 
 
 
 
 
 
 
 
773
  except (FileNotFoundError, ConnectionError) as err:
774
  LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
775
  return control
776
+
777
+
778
+ class SaveModelOnTrainEndCallback(TrainerCallback):
779
+ """Callback to save model on train end"""
780
+
781
+ def on_train_end( # pylint: disable=unused-argument
782
+ self, args, state, control, **kwargs
783
+ ):
784
+ control.should_save = True
785
+ return control