Nanobit committed on
Commit b1e3e1b
1 Parent(s): 2ea70eb

fix(config): passing gradient_checkpoint_kwargs (#1412)


* fix(config): change default use_reentrant to true

* Update trainer_builder.py

* fix: make sure to pass kwargs to enable checkpoint

* chore: lint

README.md CHANGED
@@ -859,7 +859,7 @@ group_by_length: false
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
 # gradient_checkpointing_kwargs:
-# use_reentrant: false
+# use_reentrant: true
 
 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
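For reference, the commented-out YAML above is forwarded as-is to the Hugging Face trainer. A minimal sketch of what that corresponds to in Python, assuming transformers >= 4.35 (which added the gradient_checkpointing_kwargs parameter to TrainingArguments); output_dir and the literal values are placeholders:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder path
    gradient_checkpointing=True,
    # use_reentrant=True selects the legacy reentrant torch.utils.checkpoint
    # path, matching the value the README comment now suggests.
    gradient_checkpointing_kwargs={"use_reentrant": True},
)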
src/axolotl/core/trainer_builder.py CHANGED
@@ -970,10 +970,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
-            else:
-                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
-                    "use_reentrant": False
-                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
         if self.cfg.fsdp_config:
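The removed else branch used to force {"use_reentrant": False} whenever the config left gradient_checkpointing_kwargs unset; after this change the key is simply omitted and transformers applies its own default. A hypothetical helper (not axolotl code; cfg is a stand-in object) illustrating the before/after behaviour:

from types import SimpleNamespace

def build_gc_kwargs(cfg, *, old_behavior: bool) -> dict:
    # Collect the gradient-checkpointing entries passed on to TrainingArguments.
    kwargs = {}
    if cfg.gradient_checkpointing:
        kwargs["gradient_checkpointing"] = True
        if cfg.gradient_checkpointing_kwargs is not None:
            kwargs["gradient_checkpointing_kwargs"] = cfg.gradient_checkpointing_kwargs
        elif old_behavior:
            # Pre-patch: an unset config silently forced the non-reentrant path.
            kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}
        # Post-patch: leave the key out and let transformers pick its default.
    return kwargs

cfg = SimpleNamespace(gradient_checkpointing=True, gradient_checkpointing_kwargs=None)
print(build_gc_kwargs(cfg, old_behavior=True))   # forces use_reentrant=False
print(build_gc_kwargs(cfg, old_behavior=False))  # no gradient_checkpointing_kwargs key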
src/axolotl/utils/models.py CHANGED
@@ -888,7 +888,9 @@ def load_model(
 
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
+            )
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
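With this change the LoRA/QLoRA path forwards the same kwargs to the model itself. A minimal sketch of the underlying call, assuming transformers >= 4.35 where PreTrainedModel.gradient_checkpointing_enable accepts gradient_checkpointing_kwargs; the model name is only an example:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint
model.gradient_checkpointing_enable(
    # Passing None is also valid and lets transformers choose its own default.
    gradient_checkpointing_kwargs={"use_reentrant": True}
)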