winglian committed on
Commit
802f966
·
unverified ·
1 Parent(s): b8e5603

improve vram use w gradient checkpointing (#1167) [skip ci]

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/config.py +7 -0
src/axolotl/utils/config.py CHANGED
@@ -159,6 +159,13 @@ def normalize_config(cfg):
159
  if isinstance(cfg.pretraining_dataset, dict):
160
  cfg.pretraining_dataset = [cfg.pretraining_dataset]
161
 
 
 
 
 
 
 
 
162
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
163
 
164
 
 
159
  if isinstance(cfg.pretraining_dataset, dict):
160
  cfg.pretraining_dataset = [cfg.pretraining_dataset]
161
 
162
+ if (
163
+ cfg.gradient_checkpointing
164
+ and cfg.unfrozen_parameters is None
165
+ and cfg.gradient_checkpointing_kwargs is None
166
+ ):
167
+ cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
168
+
169
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
170
 
171