fix(config): passing gradient_checkpoint_kwargs (#1412)
* fix(config): change default use_reentrant to true
* Update trainer_builder.py
* fix: make sure to pass kwargs to enable checkpoint
* chore: lint
- README.md +1 -1
- src/axolotl/core/trainer_builder.py +0 -4
- src/axolotl/utils/models.py +3 -1
README.md
@@ -859,7 +859,7 @@ group_by_length: false
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
 # gradient_checkpointing_kwargs:
-#   use_reentrant:
+#   use_reentrant: true
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
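For reference, a minimal sketch of the two documented keys once uncommented, parsed the way a YAML config would be (the rest of the config file is omitted, and the parsing snippet is illustrative, not axolotl code):

```python
import yaml  # PyYAML

# Sketch of the documented keys once enabled; only these two keys are shown.
cfg = yaml.safe_load(
    "gradient_checkpointing: true\n"
    "gradient_checkpointing_kwargs:\n"
    "  use_reentrant: true\n"
)
assert cfg["gradient_checkpointing_kwargs"] == {"use_reentrant": True}
```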
src/axolotl/core/trainer_builder.py
@@ -970,10 +970,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "gradient_checkpointing_kwargs"
             ] = self.cfg.gradient_checkpointing_kwargs
-        else:
-            training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
-                "use_reentrant": False
-            }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
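The removed `else` branch hard-coded `{"use_reentrant": False}` whenever the user left `gradient_checkpointing_kwargs` unset; now the kwargs are only forwarded when they were actually configured. A minimal sketch of the resulting behavior, assuming a `transformers` release whose `TrainingArguments` accepts `gradient_checkpointing_kwargs` (variable names below are illustrative, not axolotl's builder code):

```python
from transformers import TrainingArguments

# Illustrative stand-in for cfg.gradient_checkpointing_kwargs; may also be unset.
gradient_checkpointing_kwargs = {"use_reentrant": True}

training_arguments_kwargs = {"gradient_checkpointing": True}
if gradient_checkpointing_kwargs:
    # Only forward the kwargs when the user set them; otherwise the transformers
    # default applies instead of a hard-coded use_reentrant=False.
    training_arguments_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs

args = TrainingArguments(output_dir="./out", **training_arguments_kwargs)
```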
src/axolotl/utils/models.py
@@ -888,7 +888,9 @@ def load_model(
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
+            )
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
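With this change the configured kwargs also reach the model when it is set up for LoRA/QLoRA. A minimal sketch of the underlying `transformers` call (the model checkpoint is illustrative only):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# Same call the patch now issues: pass the configured kwargs straight through,
# so use_reentrant follows the YAML config rather than the library default.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)
```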