winglian committed on
Commit
4d2e842
1 Parent(s): 3678a6c

use recommended setting for use_reentrant w gradient checkpointing (#1021)

Browse files

* use recommended setting for use_reentrant w gradient checkpointing

* add doc for gradient_checkpointing_kwargs

Files changed (2) hide show
  1. README.md +3 -0
  2. src/axolotl/core/trainer_builder.py +8 -0
README.md CHANGED
@@ -741,6 +741,9 @@ group_by_length: false
741
 
742
  # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
743
  gradient_checkpointing: false
 
 
 
744
 
745
  # Stop training after this many evaluation losses have increased in a row
746
  # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
 
741
 
742
  # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
743
  gradient_checkpointing: false
744
+ # additional kwargs to pass to the trainer for gradient checkpointing
745
+ # gradient_checkpointing_kwargs:
746
+ #   use_reentrant: false
747
 
748
  # Stop training after this many evaluation losses have increased in a row
749
  # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
src/axolotl/core/trainer_builder.py CHANGED
@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
566
  training_arguments_kwargs[
567
  "gradient_checkpointing"
568
  ] = self.cfg.gradient_checkpointing
 
 
 
 
 
 
 
 
569
  if self.cfg.fsdp:
570
  training_arguments_kwargs["fsdp"] = self.cfg.fsdp
571
  if self.cfg.fsdp_config:
 
566
  training_arguments_kwargs[
567
  "gradient_checkpointing"
568
  ] = self.cfg.gradient_checkpointing
569
+ if self.cfg.gradient_checkpointing_kwargs:
570
+ training_arguments_kwargs[
571
+ "gradient_checkpointing_kwargs"
572
+ ] = self.cfg.gradient_checkpointing_kwargs
573
+ else:
574
+ training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
575
+ "use_reentrant": False
576
+ }
577
  if self.cfg.fsdp:
578
  training_arguments_kwargs["fsdp"] = self.cfg.fsdp
579
  if self.cfg.fsdp_config: