winglian tmm1 commited on
Commit
a4e1bb6
·
unverified ·
1 Parent(s): 36e53c7

let hf trainer handle torch compile (#516)

Browse files

* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>

README.md CHANGED
@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
519
  # where to save the finished model to
520
  output_dir: ./completed-model
521
 
 
 
 
 
522
  # training hyperparameters
523
  gradient_accumulation_steps: 1
524
  micro_batch_size: 2
 
519
  # where to save the finished model to
520
  output_dir: ./completed-model
521
 
522
+ # whether to use torch.compile and which backend to use
523
+ torch_compile: # bool
524
+ torch_compile_backend: # Optional[str]
525
+
526
  # training hyperparameters
527
  gradient_accumulation_steps: 1
528
  micro_batch_size: 2
src/axolotl/train.py CHANGED
@@ -80,10 +80,6 @@ def train(
80
 
81
  model.config.use_cache = False
82
 
83
- if torch.__version__ >= "2" and sys.platform != "win32":
84
- LOG.info("Compiling torch model")
85
- model = torch.compile(model)
86
-
87
  # go ahead and presave, so we have the adapter config available to inspect
88
  if peft_config:
89
  LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
 
80
 
81
  model.config.use_cache = False
82
 
 
 
 
 
83
  # go ahead and presave, so we have the adapter config available to inspect
84
  if peft_config:
85
  LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
src/axolotl/utils/trainer.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from typing import Optional, Union
12
 
13
  import numpy as np
 
14
  import torch.cuda
15
  import transformers
16
  from datasets import Dataset, set_caching_enabled
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
604
  if cfg.greater_is_better:
605
  training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  # DDP Config
608
  if cfg.ddp_timeout:
609
  training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
 
11
  from typing import Optional, Union
12
 
13
  import numpy as np
14
+ import torch
15
  import torch.cuda
16
  import transformers
17
  from datasets import Dataset, set_caching_enabled
 
605
  if cfg.greater_is_better:
606
  training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
607
 
608
+ if cfg.torch_compile:
609
+ if torch.__version__ < "2.1.0": # pylint: disable=protected-access
610
+ LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
611
+ else:
612
+ import torch._dynamo # pylint: disable=redefined-outer-name
613
+
614
+ torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
615
+ True
616
+ )
617
+ training_arguments_kwargs["torch_compile"] = cfg.torch_compile
618
+ if cfg.torch_compile_backend:
619
+ training_arguments_kwargs[
620
+ "torch_compile_backend"
621
+ ] = cfg.torch_compile_backend
622
+
623
  # DDP Config
624
  if cfg.ddp_timeout:
625
  training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout