move is_llama_derived_model into normalize_config (#524)
- scripts/finetune.py +1 -10
- src/axolotl/utils/config.py +11 -0
scripts/finetune.py
@@ -24,7 +24,7 @@ from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
-from axolotl.utils.models import load_model_config, load_tokenizer
+from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -216,15 +216,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
             else:
                 cfg[k] = kwargs[k]

-    model_config = load_model_config(cfg)
-
-    # figure out if the model is llama
-    cfg.is_llama_derived_model = (
-        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
-        or cfg.is_llama_derived_model
-        or "llama" in cfg.base_model
-        or (cfg.model_type and "llama" in cfg.model_type.lower())
-    )
     validate_config(cfg)

     normalize_config(cfg)
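The import that moves out of finetune.py here is load_model_config, the helper that fetches the Hugging Face model config whose model_type the detection reads. As a rough, hedged sketch only (the real helper lives in axolotl.utils.models and may differ), it can be thought of as a thin wrapper around transformers.AutoConfig.from_pretrained; the _sketch suffix and the base_model_config / trust_remote_code fields used below are assumptions for illustration, not confirmed by this diff:

from transformers import AutoConfig


def load_model_config_sketch(cfg):
    # Hedged approximation: resolve the HF config for the configured model.
    # Prefer an explicit config source if the cfg provides one (assumed field),
    # otherwise fall back to the base model repo/path itself.
    model_config_name = getattr(cfg, "base_model_config", None) or cfg.base_model
    return AutoConfig.from_pretrained(
        model_config_name,
        trust_remote_code=bool(getattr(cfg, "trust_remote_code", False)),
    )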
src/axolotl/utils/config.py
@@ -6,6 +6,7 @@ import os
 import torch

 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.models import load_model_config

 LOG = logging.getLogger("axolotl")

@@ -69,6 +70,16 @@ def normalize_config(cfg):
     else:
         cfg.torch_dtype = torch.float32

+    model_config = load_model_config(cfg)
+
+    # figure out if the model is llama
+    cfg.is_llama_derived_model = (
+        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
+        or cfg.is_llama_derived_model
+        or "llama" in cfg.base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

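The detection itself is a pure predicate over the loaded model config and the runtime cfg, so it can be exercised in isolation. Below is a minimal self-contained sketch of the same boolean logic; the is_llama_derived helper name, the SimpleNamespace stand-in for the HF config, and the example model names are illustrative and not part of the codebase:

from types import SimpleNamespace


def is_llama_derived(model_config, cfg):
    # Same boolean logic as the block added to normalize_config above.
    return bool(
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
        or cfg.get("is_llama_derived_model")
        or "llama" in cfg["base_model"]
        or (cfg.get("model_type") and "llama" in cfg["model_type"].lower())
    )


# A model whose HF config reports model_type == "llama" is detected even when
# the base_model string itself never mentions llama.
print(is_llama_derived(SimpleNamespace(model_type="llama"),
                       {"base_model": "some-org/custom-model"}))         # True

# A non-llama architecture with no llama hints in the cfg stays False.
print(is_llama_derived(SimpleNamespace(model_type="gpt_neox"),
                       {"base_model": "EleutherAI/pythia-1.4b"}))        # False

With the check living in normalize_config, any entry point that already runs validate_config(cfg) followed by normalize_config(cfg), as load_cfg does above, picks up cfg.is_llama_derived_model without duplicating this logic.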