use recommended setting for use_reentrant w gradient checkpointing (#1021)
Browse files* use recommended setting for use_reentrant w gradient checkpointing
* add doc for gradient_checkpointing_kwargs
- README.md +3 -0
- src/axolotl/core/trainer_builder.py +8 -0
README.md
CHANGED
@@ -741,6 +741,9 @@ group_by_length: false
|
|
741 |
|
742 |
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
743 |
gradient_checkpointing: false
|
|
|
|
|
|
|
744 |
|
745 |
# Stop training after this many evaluation losses have increased in a row
|
746 |
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
|
|
741 |
|
742 |
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
743 |
gradient_checkpointing: false
|
744 |
+
# additional kwargs to pass to the trainer for gradient checkpointing
|
745 |
+
# gradient_checkpointing_kwargs:
|
746 |
+
# use_reentrant: false
|
747 |
|
748 |
# Stop training after this many evaluation losses have increased in a row
|
749 |
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -566,6 +566,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
566 |
training_arguments_kwargs[
|
567 |
"gradient_checkpointing"
|
568 |
] = self.cfg.gradient_checkpointing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
if self.cfg.fsdp:
|
570 |
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
571 |
if self.cfg.fsdp_config:
|
|
|
566 |
training_arguments_kwargs[
|
567 |
"gradient_checkpointing"
|
568 |
] = self.cfg.gradient_checkpointing
|
569 |
+
if self.cfg.gradient_checkpointing_kwargs:
|
570 |
+
training_arguments_kwargs[
|
571 |
+
"gradient_checkpointing_kwargs"
|
572 |
+
] = self.cfg.gradient_checkpointing_kwargs
|
573 |
+
else:
|
574 |
+
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
575 |
+
"use_reentrant": False
|
576 |
+
}
|
577 |
if self.cfg.fsdp:
|
578 |
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
579 |
if self.cfg.fsdp_config:
|