tmm1 committed
Commit 44454ae · unverified · 1 Parent(s): 09f1543

move is_llama_derived_model into normalize_config (#524)

Files changed (2)
  1. scripts/finetune.py +1 -10
  2. src/axolotl/utils/config.py +11 -0
scripts/finetune.py CHANGED
@@ -24,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model_config, load_tokenizer
+from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -216,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
         else:
             cfg[k] = kwargs[k]

-    model_config = load_model_config(cfg)
-
-    # figure out if the model is llama
-    cfg.is_llama_derived_model = (
-        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
-        or cfg.is_llama_derived_model
-        or "llama" in cfg.base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-    )
     validate_config(cfg)

     normalize_config(cfg)
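
With this change, load_cfg no longer probes the model config itself: it merges CLI overrides, validates, and normalizes, and the llama detection runs inside normalize_config. A minimal usage sketch, assuming axolotl is installed and scripts/finetune.py is importable as `finetune` (the import path and YAML path are illustrative, not part of this commit):

from pathlib import Path

from finetune import load_cfg  # scripts/finetune.py; import path is an assumption

# Illustrative config path -- point this at any axolotl YAML config.
cfg = load_cfg(Path("examples/openllama-3b/lora.yml"))

# normalize_config (called at the end of load_cfg) now sets this flag,
# so it is already populated by the time load_cfg returns.
print(cfg.is_llama_derived_model)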
src/axolotl/utils/config.py CHANGED
@@ -6,6 +6,7 @@ import os
 import torch

 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.models import load_model_config

 LOG = logging.getLogger("axolotl")

@@ -69,6 +70,16 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32

+    model_config = load_model_config(cfg)
+
+    # figure out if the model is llama
+    cfg.is_llama_derived_model = (
+        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        or cfg.is_llama_derived_model
+        or "llama" in cfg.base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
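
For reference, the detection chain moved here treats any one signal as sufficient: the HF model config's model_type, an explicit is_llama_derived_model flag, the base_model name, or a user-supplied model_type string. A self-contained sketch of the same truthiness logic, using SimpleNamespace stand-ins for the HF config and axolotl's DictDefault cfg (the base_model value is illustrative):

from types import SimpleNamespace

# Stand-ins for the real objects; in axolotl, cfg is a DictDefault whose
# missing keys resolve to None, which these defaults mimic.
model_config = SimpleNamespace(model_type="llama")
cfg = SimpleNamespace(
    is_llama_derived_model=None,
    base_model="NousResearch/Llama-2-7b-hf",  # illustrative model name
    model_type=None,
)

# Same chain as in normalize_config: the first truthy clause wins.
cfg.is_llama_derived_model = (
    (hasattr(model_config, "model_type") and model_config.model_type == "llama")
    or cfg.is_llama_derived_model
    or "llama" in cfg.base_model
    or (cfg.model_type and "llama" in cfg.model_type.lower())
)
print(cfg.is_llama_derived_model)  # True: model_config.model_type == "llama"

Since normalize_config runs in every entry point that builds a cfg, moving the probe here means the flag is computed once, in one place, rather than in each script's config loader.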