Jan Philipp Harries committed
Commit 396a7a7
Parent: b21e4a2

Added advanced DDP args (#515)


* add ddp_config

* add advanced ddp config

* add ddp_config

* add advanced ddp config

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>

Files changed (2)
  1. README.md +5 -0
  2. src/axolotl/utils/trainer.py +9 -0
README.md CHANGED
@@ -623,6 +623,11 @@ fsdp_config:
 # Deepspeed config path
 deepspeed:
 
+# Advanced DDP Arguments
+ddp_timeout:
+ddp_bucket_cap_mb:
+ddp_broadcast_buffers:
+
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:
 
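The three new keys correspond to standard PyTorch DDP knobs. Below is a minimal sketch (my illustration, not part of the commit; the gloo backend, single-process group, and all values are placeholder assumptions) of where each option ends up at the torch level, following the DistributedDataParallel docs referenced in the code change: the timeout applies to the process group, while the bucket size and buffer broadcast are DistributedDataParallel constructor arguments.

# Runnable single-process sketch with placeholder values (not from the commit).
import os
from datetime import timedelta

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# ddp_timeout: how long collectives may block before the process group errors out.
dist.init_process_group(backend="gloo", rank=0, world_size=1,
                        timeout=timedelta(seconds=18000))

model = DDP(
    torch.nn.Linear(8, 8),
    bucket_cap_mb=25,          # ddp_bucket_cap_mb: gradient bucket size in MB
    broadcast_buffers=False,   # ddp_broadcast_buffers: skip buffer sync each forward
)
dist.destroy_process_group()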
src/axolotl/utils/trainer.py CHANGED
@@ -579,6 +579,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.bench_dataset:
         training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
 
+    # DDP Config
+    if cfg.ddp_timeout:
+        training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
+    # see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+    if cfg.ddp_bucket_cap_mb:
+        training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
+    if cfg.ddp_broadcast_buffers is not None:
+        training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
+
     training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
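The new kwargs need no special handling beyond this pass-through, because they are ordinary TrainingArguments fields in reasonably recent transformers releases (ddp_broadcast_buffers being the newest of the three), and the Trainer forwards them to DistributedDataParallel when training runs distributed. A short sketch with made-up values:

# Sketch only: mirrors the kwargs assembled above, using placeholder values.
from transformers import TrainingArguments

training_arguments_kwargs = {
    "ddp_timeout": 18000,            # seconds before distributed collectives time out
    "ddp_bucket_cap_mb": 25,         # passed through as DDP bucket_cap_mb
    "ddp_broadcast_buffers": False,  # passed through as DDP broadcast_buffers
}

args = TrainingArguments(output_dir="./out", **training_arguments_kwargs)
print(args.ddp_timeout, args.ddp_bucket_cap_mb, args.ddp_broadcast_buffers)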