Pydantic 2.x cfg (#1239)
Browse files* WIP conversion to use pydantic for config validation
* wip, more fields, add capabilities
* wip
* update pydantic validation to match existing tests
* tweak requirements
* setup deprecated paams pydantic model
* more validations
* wrap up rest of the validations
* flesh out the rest of the options from the readme into pydantic
* fix model validators as class methods
remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning
* more missing attributes for cfg
* updates from PR feedback
* fix validation for datasets and pretrain datasets
* fix test for lora check
- .mypy.ini +1 -1
- .pre-commit-config.yaml +1 -0
- README.md +1 -7
- requirements.txt +2 -1
- src/axolotl/cli/__init__.py +17 -2
- src/axolotl/utils/{config.py → config/__init__.py} +15 -4
- src/axolotl/utils/config/models/__init__.py +0 -0
- src/axolotl/utils/config/models/input/__init__.py +0 -0
- src/axolotl/utils/config/models/input/next/__init__.py +0 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +931 -0
- src/axolotl/utils/config/models/internals/__init__.py +14 -0
- src/axolotl/utils/dict.py +1 -1
- src/axolotl/utils/models.py +2 -2
- tests/test_dict.py +3 -5
- tests/test_prompt_tokenizers.py +4 -4
- tests/test_validation.py +721 -386
.mypy.ini
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
[mypy]
|
2 |
-
|
3 |
exclude = venv
|
4 |
|
5 |
[mypy-alpaca_lora_4bit.*]
|
|
|
1 |
[mypy]
|
2 |
+
plugins = pydantic.mypy
|
3 |
exclude = venv
|
4 |
|
5 |
[mypy-alpaca_lora_4bit.*]
|
.pre-commit-config.yaml
CHANGED
@@ -31,6 +31,7 @@ repos:
|
|
31 |
additional_dependencies:
|
32 |
[
|
33 |
'types-PyYAML',
|
|
|
34 |
]
|
35 |
- repo: https://github.com/PyCQA/bandit
|
36 |
rev: 1.7.5
|
|
|
31 |
additional_dependencies:
|
32 |
[
|
33 |
'types-PyYAML',
|
34 |
+
'pydantic>=2.5.3',
|
35 |
]
|
36 |
- repo: https://github.com/PyCQA/bandit
|
37 |
rev: 1.7.5
|
README.md
CHANGED
@@ -543,7 +543,7 @@ is_mistral_derived_model:
|
|
543 |
is_qwen_derived_model:
|
544 |
|
545 |
# optional overrides to the base model configuration
|
546 |
-
|
547 |
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
548 |
rope_scaling:
|
549 |
type: # linear | dynamic
|
@@ -560,8 +560,6 @@ bnb_config_kwargs:
|
|
560 |
|
561 |
# Whether you are training a 4-bit GPTQ quantized model
|
562 |
gptq: true
|
563 |
-
gptq_groupsize: 128 # group size
|
564 |
-
gptq_model_v1: false # v1 or v2
|
565 |
|
566 |
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
567 |
load_in_8bit: true
|
@@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
|
|
819 |
# For one_cycle optim
|
820 |
lr_div_factor: # Learning rate div factor
|
821 |
|
822 |
-
# For log_sweep optim
|
823 |
-
log_sweep_min_lr:
|
824 |
-
log_sweep_max_lr:
|
825 |
-
|
826 |
# Specify optimizer
|
827 |
# Valid values are driven by the Transformers OptimizerNames class, see:
|
828 |
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
|
|
543 |
is_qwen_derived_model:
|
544 |
|
545 |
# optional overrides to the base model configuration
|
546 |
+
model_config_overrides:
|
547 |
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
548 |
rope_scaling:
|
549 |
type: # linear | dynamic
|
|
|
560 |
|
561 |
# Whether you are training a 4-bit GPTQ quantized model
|
562 |
gptq: true
|
|
|
|
|
563 |
|
564 |
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
565 |
load_in_8bit: true
|
|
|
817 |
# For one_cycle optim
|
818 |
lr_div_factor: # Learning rate div factor
|
819 |
|
|
|
|
|
|
|
|
|
820 |
# Specify optimizer
|
821 |
# Valid values are driven by the Transformers OptimizerNames class, see:
|
822 |
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
requirements.txt
CHANGED
@@ -6,6 +6,7 @@ tokenizers==0.15.0
|
|
6 |
bitsandbytes>=0.41.1
|
7 |
accelerate==0.26.1
|
8 |
deepspeed>=0.13.1
|
|
|
9 |
addict
|
10 |
fire
|
11 |
PyYAML>=6.0
|
@@ -27,7 +28,7 @@ scipy
|
|
27 |
scikit-learn==1.2.2
|
28 |
pynvml
|
29 |
art
|
30 |
-
fschat==0.2.
|
31 |
gradio==3.50.2
|
32 |
tensorboard
|
33 |
|
|
|
6 |
bitsandbytes>=0.41.1
|
7 |
accelerate==0.26.1
|
8 |
deepspeed>=0.13.1
|
9 |
+
pydantic>=2.5.3
|
10 |
addict
|
11 |
fire
|
12 |
PyYAML>=6.0
|
|
|
28 |
scikit-learn==1.2.2
|
29 |
pynvml
|
30 |
art
|
31 |
+
fschat==0.2.36
|
32 |
gradio==3.50.2
|
33 |
tensorboard
|
34 |
|
src/axolotl/cli/__init__.py
CHANGED
@@ -24,11 +24,13 @@ from art import text2art
|
|
24 |
from huggingface_hub import HfApi
|
25 |
from huggingface_hub.utils import LocalTokenNotFoundError
|
26 |
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
|
|
27 |
|
28 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
29 |
from axolotl.logging_config import configure_logging
|
30 |
from axolotl.train import TrainDatasetMeta
|
31 |
from axolotl.utils.config import (
|
|
|
32 |
normalize_cfg_datasets,
|
33 |
normalize_config,
|
34 |
validate_config,
|
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|
328 |
# load the config from the yaml file
|
329 |
with open(config, encoding="utf-8") as file:
|
330 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
331 |
-
cfg.axolotl_config_path = config
|
332 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
333 |
# then overwrite the value
|
334 |
cfg_keys = cfg.keys()
|
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|
341 |
else:
|
342 |
cfg[k] = kwargs[k]
|
343 |
|
344 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
prepare_optim_env(cfg)
|
347 |
|
|
|
24 |
from huggingface_hub import HfApi
|
25 |
from huggingface_hub.utils import LocalTokenNotFoundError
|
26 |
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
27 |
+
from transformers.utils import is_torch_bf16_gpu_available
|
28 |
|
29 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
30 |
from axolotl.logging_config import configure_logging
|
31 |
from axolotl.train import TrainDatasetMeta
|
32 |
from axolotl.utils.config import (
|
33 |
+
GPUCapabilities,
|
34 |
normalize_cfg_datasets,
|
35 |
normalize_config,
|
36 |
validate_config,
|
|
|
330 |
# load the config from the yaml file
|
331 |
with open(config, encoding="utf-8") as file:
|
332 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
|
333 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
334 |
# then overwrite the value
|
335 |
cfg_keys = cfg.keys()
|
|
|
342 |
else:
|
343 |
cfg[k] = kwargs[k]
|
344 |
|
345 |
+
cfg.axolotl_config_path = config
|
346 |
+
|
347 |
+
try:
|
348 |
+
device_props = torch.cuda.get_device_properties("cuda")
|
349 |
+
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
350 |
+
except: # pylint: disable=bare-except # noqa: E722
|
351 |
+
gpu_version = None
|
352 |
+
|
353 |
+
capabilities = GPUCapabilities(
|
354 |
+
bf16=is_torch_bf16_gpu_available(),
|
355 |
+
n_gpu=os.environ.get("WORLD_SIZE", 1),
|
356 |
+
compute_capability=gpu_version,
|
357 |
+
)
|
358 |
+
|
359 |
+
cfg = validate_config(cfg, capabilities=capabilities)
|
360 |
|
361 |
prepare_optim_env(cfg)
|
362 |
|
src/axolotl/utils/{config.py → config/__init__.py}
RENAMED
@@ -3,11 +3,17 @@ import json
|
|
3 |
import logging
|
4 |
import os
|
5 |
from pathlib import Path
|
|
|
6 |
|
7 |
import torch
|
8 |
from transformers.utils import is_torch_bf16_gpu_available
|
9 |
|
10 |
from axolotl.utils.bench import log_gpu_memory_usage
|
|
|
|
|
|
|
|
|
|
|
11 |
from axolotl.utils.dict import DictDefault
|
12 |
from axolotl.utils.models import load_model_config
|
13 |
|
@@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg):
|
|
191 |
cfg.datasets[idx].conversation = "chatml"
|
192 |
|
193 |
|
194 |
-
def validate_config(cfg):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
"""
|
196 |
This is a "pre-validation" step that handles the yaml configuration before we have any
|
197 |
information about the model architecture
|
@@ -480,9 +494,6 @@ def validate_config(cfg):
|
|
480 |
if cfg.rope_scaling:
|
481 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
482 |
|
483 |
-
if cfg.warmup_steps and cfg.warmup_ratio:
|
484 |
-
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
485 |
-
|
486 |
if cfg.wandb_run_id and not cfg.wandb_name:
|
487 |
cfg.wandb_name = cfg.wandb_run_id
|
488 |
|
|
|
3 |
import logging
|
4 |
import os
|
5 |
from pathlib import Path
|
6 |
+
from typing import Optional
|
7 |
|
8 |
import torch
|
9 |
from transformers.utils import is_torch_bf16_gpu_available
|
10 |
|
11 |
from axolotl.utils.bench import log_gpu_memory_usage
|
12 |
+
from axolotl.utils.config.models.input.v0_4_1 import (
|
13 |
+
AxolotlConfigWCapabilities,
|
14 |
+
AxolotlInputConfig,
|
15 |
+
)
|
16 |
+
from axolotl.utils.config.models.internals import GPUCapabilities
|
17 |
from axolotl.utils.dict import DictDefault
|
18 |
from axolotl.utils.models import load_model_config
|
19 |
|
|
|
197 |
cfg.datasets[idx].conversation = "chatml"
|
198 |
|
199 |
|
200 |
+
def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
|
201 |
+
if capabilities:
|
202 |
+
return DictDefault(
|
203 |
+
dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
|
204 |
+
)
|
205 |
+
return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict())))
|
206 |
+
|
207 |
+
|
208 |
+
def legacy_validate_config(cfg):
|
209 |
"""
|
210 |
This is a "pre-validation" step that handles the yaml configuration before we have any
|
211 |
information about the model architecture
|
|
|
494 |
if cfg.rope_scaling:
|
495 |
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
496 |
|
|
|
|
|
|
|
497 |
if cfg.wandb_run_id and not cfg.wandb_name:
|
498 |
cfg.wandb_name = cfg.wandb_run_id
|
499 |
|
src/axolotl/utils/config/models/__init__.py
ADDED
File without changes
|
src/axolotl/utils/config/models/input/__init__.py
ADDED
File without changes
|
src/axolotl/utils/config/models/input/next/__init__.py
ADDED
File without changes
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
ADDED
@@ -0,0 +1,931 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module for pydantic models for configuration
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
from enum import Enum
|
8 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
9 |
+
|
10 |
+
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
11 |
+
from transformers import SchedulerType
|
12 |
+
from transformers.training_args import OptimizerNames
|
13 |
+
|
14 |
+
from axolotl.utils.config.models.internals import GPUCapabilities
|
15 |
+
|
16 |
+
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
17 |
+
|
18 |
+
|
19 |
+
class DeprecatedParameters(BaseModel):
|
20 |
+
"""configurations that are deprecated"""
|
21 |
+
|
22 |
+
max_packed_sequence_len: Optional[int] = None
|
23 |
+
rope_scaling: Optional[Any] = None
|
24 |
+
noisy_embedding_alpha: Optional[float] = None
|
25 |
+
|
26 |
+
@field_validator("max_packed_sequence_len")
|
27 |
+
@classmethod
|
28 |
+
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
|
29 |
+
if max_packed_sequence_len:
|
30 |
+
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
31 |
+
return max_packed_sequence_len
|
32 |
+
|
33 |
+
@field_validator("rope_scaling")
|
34 |
+
@classmethod
|
35 |
+
def validate_rope_scaling(cls, rope_scaling):
|
36 |
+
if rope_scaling:
|
37 |
+
raise DeprecationWarning(
|
38 |
+
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
|
39 |
+
)
|
40 |
+
return rope_scaling
|
41 |
+
|
42 |
+
@field_validator("noisy_embedding_alpha")
|
43 |
+
@classmethod
|
44 |
+
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
|
45 |
+
if noisy_embedding_alpha:
|
46 |
+
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
47 |
+
return noisy_embedding_alpha
|
48 |
+
|
49 |
+
|
50 |
+
class PretrainingDataset(BaseModel):
|
51 |
+
"""pretraining dataset configuration subset"""
|
52 |
+
|
53 |
+
path: Optional[str] = None
|
54 |
+
|
55 |
+
|
56 |
+
class UserDefinedPrompterType(BaseModel):
|
57 |
+
"""structure for user defined prompt types"""
|
58 |
+
|
59 |
+
system_prompt: Optional[str] = None
|
60 |
+
system_format: Optional[str] = None
|
61 |
+
field_system: Optional[str] = None
|
62 |
+
field_instruction: Optional[str] = None
|
63 |
+
field_input: Optional[str] = None
|
64 |
+
field_output: Optional[str] = None
|
65 |
+
|
66 |
+
format: Optional[str] = None
|
67 |
+
no_input_format: Optional[str] = None
|
68 |
+
field: Optional[str] = None
|
69 |
+
|
70 |
+
|
71 |
+
class SFTDataset(BaseModel):
|
72 |
+
"""SFT configuration subset"""
|
73 |
+
|
74 |
+
path: Optional[str] = None
|
75 |
+
split: Optional[str] = None
|
76 |
+
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
77 |
+
shards: Optional[int] = None
|
78 |
+
conversation: Optional[str] = None
|
79 |
+
data_files: Optional[List[str]] = None
|
80 |
+
name: Optional[str] = None
|
81 |
+
ds_type: Optional[str] = None
|
82 |
+
train_on_split: Optional[str] = None
|
83 |
+
|
84 |
+
field_human: Optional[str] = None
|
85 |
+
field_model: Optional[str] = None
|
86 |
+
|
87 |
+
|
88 |
+
class DPODataset(BaseModel):
|
89 |
+
"""DPO configuration subset"""
|
90 |
+
|
91 |
+
path: Optional[str] = None
|
92 |
+
split: Optional[str] = None
|
93 |
+
type: Optional[str] = None
|
94 |
+
data_files: Optional[List[str]] = None
|
95 |
+
|
96 |
+
|
97 |
+
class RLType(str, Enum):
|
98 |
+
"""RL trainer type configuration subset"""
|
99 |
+
|
100 |
+
dpo = "dpo" # pylint: disable=invalid-name
|
101 |
+
ipo = "ipo" # pylint: disable=invalid-name
|
102 |
+
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
103 |
+
|
104 |
+
|
105 |
+
class ChatTemplate(str, Enum):
|
106 |
+
"""Chat templates configuration subset"""
|
107 |
+
|
108 |
+
chatml = "chatml" # pylint: disable=invalid-name
|
109 |
+
inst = "inst" # pylint: disable=invalid-name
|
110 |
+
|
111 |
+
|
112 |
+
class LoftQConfig(BaseModel):
|
113 |
+
"""LoftQ configuration subset"""
|
114 |
+
|
115 |
+
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
|
116 |
+
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
|
117 |
+
|
118 |
+
|
119 |
+
class PeftConfig(BaseModel):
|
120 |
+
"""peftq configuration subset"""
|
121 |
+
|
122 |
+
loftq_config: Optional[LoftQConfig] = None
|
123 |
+
|
124 |
+
|
125 |
+
class AutoType(str, Enum):
|
126 |
+
"""auto type string configuration subset - used for bf16"""
|
127 |
+
|
128 |
+
AUTO = "auto"
|
129 |
+
|
130 |
+
|
131 |
+
class SpecialTokensConfig(BaseModel):
|
132 |
+
"""Special tokens configuration subset"""
|
133 |
+
|
134 |
+
bos_token: Optional[str] = None
|
135 |
+
eos_token: Optional[str] = None
|
136 |
+
pad_token: Optional[str] = None
|
137 |
+
unk_token: Optional[str] = None
|
138 |
+
additional_special_tokens: Optional[List[str]] = None
|
139 |
+
|
140 |
+
|
141 |
+
class LoraConfig(BaseModel):
|
142 |
+
"""Peft / LoRA configuration subset"""
|
143 |
+
|
144 |
+
load_in_8bit: Optional[bool] = Field(default=False)
|
145 |
+
load_in_4bit: Optional[bool] = Field(default=False)
|
146 |
+
|
147 |
+
adapter: Optional[str] = None
|
148 |
+
lora_model_dir: Optional[str] = None
|
149 |
+
lora_rank: Optional[int] = None
|
150 |
+
lora_alpha: Optional[int] = None
|
151 |
+
lora_fan_in_fan_out: Optional[bool] = None
|
152 |
+
lora_target_modules: Optional[List[str]] = None
|
153 |
+
lora_target_linear: Optional[bool] = None
|
154 |
+
lora_modules_to_save: Optional[List[str]] = None
|
155 |
+
lora_dropout: Optional[float] = None
|
156 |
+
peft_layers_to_transform: Optional[List[int]] = None
|
157 |
+
peft: Optional[PeftConfig] = None
|
158 |
+
|
159 |
+
lora_on_cpu: Optional[bool] = None
|
160 |
+
gptq: Optional[bool] = None
|
161 |
+
bnb_config_kwargs: Optional[Dict[str, Any]] = None
|
162 |
+
|
163 |
+
merge_lora: Optional[bool] = None
|
164 |
+
|
165 |
+
@model_validator(mode="before")
|
166 |
+
@classmethod
|
167 |
+
def validate_adapter(cls, data):
|
168 |
+
if not data.get("adapter") and (
|
169 |
+
data.get("load_in_8bit") or data.get("load_in_4bit")
|
170 |
+
):
|
171 |
+
raise ValueError(
|
172 |
+
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
173 |
+
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
174 |
+
)
|
175 |
+
return data
|
176 |
+
|
177 |
+
@model_validator(mode="after")
|
178 |
+
def validate_qlora(self):
|
179 |
+
if self.adapter == "qlora":
|
180 |
+
if self.merge_lora:
|
181 |
+
# can't merge qlora if loaded in 8bit or 4bit
|
182 |
+
if self.load_in_8bit:
|
183 |
+
raise ValueError("Can't merge qlora if loaded in 8bit")
|
184 |
+
|
185 |
+
if self.gptq:
|
186 |
+
raise ValueError("Can't merge qlora if gptq")
|
187 |
+
|
188 |
+
if self.load_in_4bit:
|
189 |
+
raise ValueError("Can't merge qlora if loaded in 4bit")
|
190 |
+
|
191 |
+
else:
|
192 |
+
if self.load_in_8bit:
|
193 |
+
raise ValueError("Can't load qlora in 8bit")
|
194 |
+
|
195 |
+
if self.gptq:
|
196 |
+
raise ValueError("Can't load qlora if gptq")
|
197 |
+
|
198 |
+
if not self.load_in_4bit:
|
199 |
+
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
200 |
+
return self
|
201 |
+
|
202 |
+
|
203 |
+
class ReLoRAConfig(BaseModel):
|
204 |
+
"""ReLoRA configuration subset"""
|
205 |
+
|
206 |
+
relora_steps: Optional[int] = None
|
207 |
+
relora_warmup_steps: Optional[int] = None
|
208 |
+
relora_anneal_steps: Optional[int] = None
|
209 |
+
relora_prune_ratio: Optional[float] = None
|
210 |
+
relora_cpu_offload: Optional[bool] = None
|
211 |
+
|
212 |
+
|
213 |
+
class ModelInputConfig(BaseModel):
|
214 |
+
"""model to train on configuration subset"""
|
215 |
+
|
216 |
+
base_model: str
|
217 |
+
base_model_config: Optional[str] = None
|
218 |
+
tokenizer_config: Optional[str] = None
|
219 |
+
tokenizer_use_fast: Optional[bool] = None
|
220 |
+
tokenizer_legacy: Optional[bool] = None
|
221 |
+
tokenizer_type: Optional[str] = Field(
|
222 |
+
default=None, metadata={"help": "transformers tokenizer class"}
|
223 |
+
)
|
224 |
+
model_type: Optional[str] = Field(default=None)
|
225 |
+
model_revision: Optional[str] = None
|
226 |
+
trust_remote_code: Optional[bool] = None
|
227 |
+
|
228 |
+
model_config_overrides: Optional[Dict[str, Any]] = None
|
229 |
+
|
230 |
+
@field_validator("trust_remote_code")
|
231 |
+
@classmethod
|
232 |
+
def hint_trust_remote_code(cls, trust_remote_code):
|
233 |
+
if trust_remote_code:
|
234 |
+
LOG.warning(
|
235 |
+
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
236 |
+
)
|
237 |
+
return trust_remote_code
|
238 |
+
|
239 |
+
|
240 |
+
class HyperparametersConfig(BaseModel):
|
241 |
+
"""training hyperparams configuration subset"""
|
242 |
+
|
243 |
+
gradient_accumulation_steps: Optional[int] = Field(default=1)
|
244 |
+
micro_batch_size: Optional[int] = Field(
|
245 |
+
default=1,
|
246 |
+
metadata={"help": "per gpu micro batch size for training"},
|
247 |
+
)
|
248 |
+
batch_size: Optional[int] = Field(
|
249 |
+
default=None,
|
250 |
+
metadata={
|
251 |
+
"help": "Total batch size, we do not recommended setting this manually"
|
252 |
+
},
|
253 |
+
)
|
254 |
+
eval_batch_size: Optional[int] = Field(
|
255 |
+
default=None,
|
256 |
+
metadata={
|
257 |
+
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
258 |
+
},
|
259 |
+
)
|
260 |
+
|
261 |
+
train_on_inputs: Optional[bool] = None
|
262 |
+
group_by_length: Optional[bool] = None
|
263 |
+
|
264 |
+
learning_rate: Union[str, float]
|
265 |
+
weight_decay: Optional[float] = None
|
266 |
+
optimizer: Optional[OptimizerNames] = None
|
267 |
+
torchdistx_path: Optional[str] = None
|
268 |
+
lr_scheduler: Optional[SchedulerType] = None
|
269 |
+
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
270 |
+
lr_quadratic_warmup: Optional[bool] = None
|
271 |
+
cosine_min_lr_ratio: Optional[float] = None
|
272 |
+
cosine_constant_lr_ratio: Optional[float] = None
|
273 |
+
lr_div_factor: Optional[float] = None
|
274 |
+
|
275 |
+
adam_epsilon: Optional[float] = None
|
276 |
+
adam_beta1: Optional[float] = None
|
277 |
+
adam_beta2: Optional[float] = None
|
278 |
+
max_grad_norm: Optional[float] = None
|
279 |
+
num_epochs: int = Field(default=1)
|
280 |
+
|
281 |
+
@field_validator("batch_size")
|
282 |
+
@classmethod
|
283 |
+
def hint_batch_size_set(cls, batch_size):
|
284 |
+
if batch_size:
|
285 |
+
LOG.warning(
|
286 |
+
"%s\n%s",
|
287 |
+
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
288 |
+
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
289 |
+
)
|
290 |
+
return batch_size
|
291 |
+
|
292 |
+
|
293 |
+
class ModelOutputConfig(BaseModel):
|
294 |
+
"""model save configuration subset"""
|
295 |
+
|
296 |
+
output_dir: str = Field(default="./model-out")
|
297 |
+
hub_model_id: Optional[str] = None
|
298 |
+
hub_strategy: Optional[str] = None
|
299 |
+
save_safetensors: Optional[bool] = None
|
300 |
+
|
301 |
+
|
302 |
+
class MLFlowConfig(BaseModel):
|
303 |
+
"""mlflow configuration subset"""
|
304 |
+
|
305 |
+
use_mlflow: Optional[str] = None
|
306 |
+
mlflow_tracking_uri: Optional[str] = None
|
307 |
+
mlflow_experiment_name: Optional[str] = None
|
308 |
+
|
309 |
+
|
310 |
+
class WandbConfig(BaseModel):
|
311 |
+
"""wandb configuration subset"""
|
312 |
+
|
313 |
+
use_wandb: Optional[bool] = None
|
314 |
+
wandb_name: Optional[str] = None
|
315 |
+
wandb_run_id: Optional[str] = None
|
316 |
+
wandb_mode: Optional[str] = None
|
317 |
+
wandb_project: Optional[str] = None
|
318 |
+
wandb_entity: Optional[str] = None
|
319 |
+
wandb_watch: Optional[str] = None
|
320 |
+
wandb_log_model: Optional[str] = None
|
321 |
+
|
322 |
+
@model_validator(mode="before")
|
323 |
+
@classmethod
|
324 |
+
def check_wandb_run(cls, data):
|
325 |
+
if data.get("wandb_run_id") and not data.get("wandb_name"):
|
326 |
+
data["wandb_name"] = data.get("wandb_run_id")
|
327 |
+
|
328 |
+
LOG.warning(
|
329 |
+
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
330 |
+
)
|
331 |
+
|
332 |
+
return data
|
333 |
+
|
334 |
+
|
335 |
+
# pylint: disable=too-many-public-methods,too-many-ancestors
|
336 |
+
class AxolotlInputConfig(
|
337 |
+
ModelInputConfig,
|
338 |
+
LoraConfig,
|
339 |
+
ReLoRAConfig,
|
340 |
+
HyperparametersConfig,
|
341 |
+
WandbConfig,
|
342 |
+
MLFlowConfig,
|
343 |
+
DeprecatedParameters,
|
344 |
+
BaseModel,
|
345 |
+
):
|
346 |
+
"""wrapper of all config options"""
|
347 |
+
|
348 |
+
strict: Optional[bool] = Field(default=False)
|
349 |
+
resume_from_checkpoint: Optional[str] = None
|
350 |
+
auto_resume_from_checkpoints: Optional[bool] = None
|
351 |
+
resize_token_embeddings_to_32x: Optional[bool] = None
|
352 |
+
|
353 |
+
rl: Optional[RLType] = None
|
354 |
+
|
355 |
+
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
356 |
+
dataset_prepared_path: Optional[str] = None
|
357 |
+
dataset_shard_num: Optional[int] = None
|
358 |
+
dataset_shard_idx: Optional[int] = None
|
359 |
+
|
360 |
+
pretraining_dataset: Optional[ # type: ignore
|
361 |
+
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
|
362 |
+
] = Field(
|
363 |
+
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
364 |
+
)
|
365 |
+
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
366 |
+
dataset_keep_in_memory: Optional[bool] = None
|
367 |
+
dataloader_pin_memory: Optional[bool] = None
|
368 |
+
dataloader_num_workers: Optional[int] = None
|
369 |
+
dataloader_prefetch_factor: Optional[int] = None
|
370 |
+
dataloader_drop_last: Optional[bool] = None
|
371 |
+
|
372 |
+
push_dataset_to_hub: Optional[str] = None
|
373 |
+
hf_use_auth_token: Optional[bool] = None
|
374 |
+
|
375 |
+
device: Optional[Any] = None
|
376 |
+
device_map: Optional[Any] = None
|
377 |
+
world_size: Optional[int] = None
|
378 |
+
local_rank: Optional[int] = None
|
379 |
+
ddp: Optional[bool] = None
|
380 |
+
|
381 |
+
seed: Optional[int] = None
|
382 |
+
ddp_timeout: Optional[int] = None
|
383 |
+
ddp_bucket_cap_mb: Optional[int] = None
|
384 |
+
ddp_broadcast_buffers: Optional[bool] = None
|
385 |
+
ddp_find_unused_parameters: Optional[bool] = None
|
386 |
+
|
387 |
+
eval_table_size: Optional[int] = None
|
388 |
+
eval_max_new_tokens: Optional[int] = None
|
389 |
+
do_causal_lm_eval: Optional[bool] = None
|
390 |
+
eval_causal_lm_metrics: Optional[List[str]] = None
|
391 |
+
do_bench_eval: Optional[bool] = None
|
392 |
+
bench_dataset: Optional[str] = None
|
393 |
+
metric_for_best_model: Optional[str] = None
|
394 |
+
greater_is_better: Optional[bool] = None
|
395 |
+
|
396 |
+
loss_watchdog_threshold: Optional[float] = None
|
397 |
+
loss_watchdog_patience: Optional[int] = None
|
398 |
+
|
399 |
+
bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
|
400 |
+
fp16: Optional[bool] = None
|
401 |
+
bfloat16: Optional[bool] = None # for non-AMP cases
|
402 |
+
float16: Optional[bool] = None # for non-AMP cases
|
403 |
+
tf32: Optional[bool] = None
|
404 |
+
float32: Optional[bool] = None
|
405 |
+
|
406 |
+
# torch_dtype: Optional[torch.dtype]
|
407 |
+
|
408 |
+
gradient_checkpointing: Optional[bool] = Field(default=False)
|
409 |
+
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
|
410 |
+
|
411 |
+
unfrozen_parameters: Optional[List[str]] = None
|
412 |
+
|
413 |
+
sequence_len: int = Field(default=1024)
|
414 |
+
sample_packing: Optional[bool] = None
|
415 |
+
eval_sample_packing: Optional[bool] = None
|
416 |
+
pad_to_sequence_len: Optional[bool] = None
|
417 |
+
|
418 |
+
xformers_attention: Optional[bool] = None
|
419 |
+
sdp_attention: Optional[bool] = None
|
420 |
+
s2_attention: Optional[bool] = None
|
421 |
+
flash_attention: Optional[bool] = None
|
422 |
+
flash_attn_cross_entropy: Optional[bool] = None
|
423 |
+
flash_attn_rms_norm: Optional[bool] = None
|
424 |
+
flash_attn_fuse_qkv: Optional[bool] = None
|
425 |
+
flash_attn_fuse_mlp: Optional[bool] = None
|
426 |
+
flash_optimum: Optional[bool] = None
|
427 |
+
|
428 |
+
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
429 |
+
fsdp: Optional[List[str]] = None
|
430 |
+
fsdp_config: Optional[Dict[str, Any]] = None
|
431 |
+
|
432 |
+
val_set_size: Optional[float] = Field(default=0.0)
|
433 |
+
|
434 |
+
special_tokens: Optional[SpecialTokensConfig] = None
|
435 |
+
tokens: Optional[List[str]] = None
|
436 |
+
|
437 |
+
torch_compile: Optional[bool] = None
|
438 |
+
torch_compile_backend: Optional[str] = None
|
439 |
+
|
440 |
+
max_steps: Optional[int] = None
|
441 |
+
warmup_steps: Optional[int] = None
|
442 |
+
warmup_ratio: Optional[float] = None
|
443 |
+
eval_steps: Optional[int] = None
|
444 |
+
evaluation_strategy: Optional[str] = None
|
445 |
+
save_steps: Optional[int] = None
|
446 |
+
saves_per_epoch: Optional[int] = None
|
447 |
+
save_strategy: Optional[str] = None
|
448 |
+
save_total_limit: Optional[int] = None
|
449 |
+
logging_steps: Optional[int] = None
|
450 |
+
early_stopping_patience: Optional[int] = None
|
451 |
+
|
452 |
+
neftune_noise_alpha: Optional[float] = None
|
453 |
+
|
454 |
+
max_memory: Optional[Union[int, str]] = None
|
455 |
+
gpu_memory_limit: Optional[Union[int, str]] = None
|
456 |
+
|
457 |
+
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
|
458 |
+
default_system_message: Optional[str] = None
|
459 |
+
|
460 |
+
# INTERNALS - document for now, generally not set externally
|
461 |
+
is_preprocess: Optional[bool] = None
|
462 |
+
|
463 |
+
total_num_tokens: Optional[int] = None
|
464 |
+
total_supervised_tokens: Optional[int] = None
|
465 |
+
sample_packing_eff_est: Optional[float] = None
|
466 |
+
axolotl_config_path: Optional[str] = None
|
467 |
+
|
468 |
+
is_falcon_derived_model: Optional[bool] = Field(default=False)
|
469 |
+
is_llama_derived_model: Optional[bool] = Field(default=False)
|
470 |
+
is_mistral_derived_model: Optional[bool] = Field(default=False)
|
471 |
+
is_qwen_derived_model: Optional[bool] = Field(default=False)
|
472 |
+
|
473 |
+
@field_validator("datasets", mode="before")
|
474 |
+
@classmethod
|
475 |
+
def fix_sharegpt_datasets(cls, datasets):
|
476 |
+
for idx, ds_cfg in enumerate(datasets):
|
477 |
+
if not ds_cfg["type"]:
|
478 |
+
continue
|
479 |
+
if ds_cfg["type"] == "sharegpt:chat":
|
480 |
+
LOG.warning(
|
481 |
+
PendingDeprecationWarning(
|
482 |
+
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
483 |
+
)
|
484 |
+
)
|
485 |
+
datasets[idx]["type"] = "sharegpt"
|
486 |
+
if "sharegpt_simple" in ds_cfg["type"]:
|
487 |
+
LOG.warning(
|
488 |
+
PendingDeprecationWarning(
|
489 |
+
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
490 |
+
)
|
491 |
+
)
|
492 |
+
datasets[idx]["type"] = datasets[idx]["type"].replace(
|
493 |
+
"sharegpt_simple", "sharegpt"
|
494 |
+
)
|
495 |
+
return datasets
|
496 |
+
|
497 |
+
@model_validator(mode="before")
|
498 |
+
@classmethod
|
499 |
+
def check_batch_size_fields(cls, data):
|
500 |
+
fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
|
501 |
+
non_empty_count = sum(1 for field in fields if data.get(field))
|
502 |
+
|
503 |
+
if non_empty_count < 2:
|
504 |
+
raise ValueError(f"At least two of {', '.join(fields)} must be set")
|
505 |
+
return data
|
506 |
+
|
507 |
+
@model_validator(mode="before")
|
508 |
+
@classmethod
|
509 |
+
def check_pretraining_w_max_steps(cls, data):
|
510 |
+
if data.get("pretraining_dataset") and not data.get("max_steps"):
|
511 |
+
raise ValueError(
|
512 |
+
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
513 |
+
)
|
514 |
+
return data
|
515 |
+
|
516 |
+
@model_validator(mode="before")
|
517 |
+
@classmethod
|
518 |
+
def check_pretraining_w_group_by_length(cls, data):
|
519 |
+
if data.get("pretraining_dataset") and data.get("group_by_length"):
|
520 |
+
LOG.warning(
|
521 |
+
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
522 |
+
)
|
523 |
+
return data
|
524 |
+
|
525 |
+
@model_validator(mode="before")
|
526 |
+
@classmethod
|
527 |
+
def check_gptq_w_revision(cls, data):
|
528 |
+
if data.get("gptq") and data.get("model_revision"):
|
529 |
+
raise ValueError(
|
530 |
+
"model_revision is not supported for GPTQ models. "
|
531 |
+
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
532 |
+
+ "point to its path, and remove model_revision from the config."
|
533 |
+
)
|
534 |
+
return data
|
535 |
+
|
536 |
+
@model_validator(mode="before")
|
537 |
+
@classmethod
|
538 |
+
def check_sample_packing_w_xformers(cls, data):
|
539 |
+
if data.get("sample_packing") and data.get("xformers_attention"):
|
540 |
+
raise ValueError(
|
541 |
+
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
542 |
+
)
|
543 |
+
|
544 |
+
return data
|
545 |
+
|
546 |
+
@model_validator(mode="before")
|
547 |
+
@classmethod
|
548 |
+
def check_sample_packing_w_rl(cls, data):
|
549 |
+
if data.get("sample_packing") and data.get("rl"):
|
550 |
+
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
551 |
+
return data
|
552 |
+
|
553 |
+
@model_validator(mode="before")
|
554 |
+
@classmethod
|
555 |
+
def hint_sample_packing_padding(cls, data):
|
556 |
+
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
|
557 |
+
LOG.warning(
|
558 |
+
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
559 |
+
)
|
560 |
+
return data
|
561 |
+
|
562 |
+
@model_validator(mode="before")
|
563 |
+
@classmethod
|
564 |
+
def check_gas_bsz(cls, data):
|
565 |
+
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
|
566 |
+
raise ValueError(
|
567 |
+
"please set only one of gradient_accumulation_steps or batch_size"
|
568 |
+
)
|
569 |
+
return data
|
570 |
+
|
571 |
+
@model_validator(mode="before")
|
572 |
+
@classmethod
|
573 |
+
def hint_eval_train_mbsz(cls, data):
|
574 |
+
if (
|
575 |
+
data.get("eval_batch_size")
|
576 |
+
and data.get("micro_batch_size")
|
577 |
+
and data.get("eval_batch_size") != data.get("micro_batch_size")
|
578 |
+
):
|
579 |
+
LOG.warning(
|
580 |
+
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
581 |
+
)
|
582 |
+
return data
|
583 |
+
|
584 |
+
@model_validator(mode="before")
|
585 |
+
@classmethod
|
586 |
+
def check_push_ds_auth(cls, data):
|
587 |
+
if (
|
588 |
+
data.get("push_dataset_to_hub")
|
589 |
+
and data.get("hf_use_auth_token") is not True
|
590 |
+
):
|
591 |
+
raise ValueError(
|
592 |
+
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
593 |
+
)
|
594 |
+
return data
|
595 |
+
|
596 |
+
@model_validator(mode="after")
|
597 |
+
def check_falcon_fsdp(self):
|
598 |
+
if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp:
|
599 |
+
raise ValueError("FSDP is not supported for falcon models")
|
600 |
+
return self
|
601 |
+
|
602 |
+
@model_validator(mode="after")
|
603 |
+
def check_mpt_checkpointing(self):
|
604 |
+
if (
|
605 |
+
self.base_model and "mpt" in self.base_model.lower()
|
606 |
+
) and self.gradient_checkpointing:
|
607 |
+
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
608 |
+
return self
|
609 |
+
|
610 |
+
@model_validator(mode="after")
|
611 |
+
def check_better_transformers(self):
|
612 |
+
if self.flash_optimum is True:
|
613 |
+
if self.adapter:
|
614 |
+
LOG.warning(
|
615 |
+
"BetterTransformers probably doesn't work with PEFT adapters"
|
616 |
+
)
|
617 |
+
if self.fp16 or self.bf16:
|
618 |
+
raise ValueError("AMP is not supported with BetterTransformer")
|
619 |
+
if self.float16 is not True and self.bfloat16 is not True:
|
620 |
+
LOG.warning(
|
621 |
+
"You should probably set bfloat16 or float16 to true to "
|
622 |
+
"load the model in float16 for BetterTransformers"
|
623 |
+
)
|
624 |
+
return self
|
625 |
+
|
626 |
+
@model_validator(mode="after")
|
627 |
+
def check_adamw_optimizer_params(self):
|
628 |
+
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
|
629 |
+
not self.optimizer or "adamw" not in self.optimizer.value
|
630 |
+
):
|
631 |
+
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
632 |
+
return self
|
633 |
+
|
634 |
+
@model_validator(mode="before")
|
635 |
+
@classmethod
|
636 |
+
def check_saves(cls, data):
|
637 |
+
if (
|
638 |
+
data.get("save_strategy")
|
639 |
+
and data.get("save_steps")
|
640 |
+
and data.get("save_strategy") != "steps"
|
641 |
+
):
|
642 |
+
raise ValueError(
|
643 |
+
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
644 |
+
)
|
645 |
+
if data.get("saves_per_epoch") and data.get("save_steps"):
|
646 |
+
raise ValueError(
|
647 |
+
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
648 |
+
)
|
649 |
+
return data
|
650 |
+
|
651 |
+
@model_validator(mode="before")
|
652 |
+
@classmethod
|
653 |
+
def check_push_save(cls, data):
|
654 |
+
if data.get("hub_model_id") and not (
|
655 |
+
data.get("save_steps") or data.get("saves_per_epoch")
|
656 |
+
):
|
657 |
+
LOG.warning(
|
658 |
+
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
659 |
+
)
|
660 |
+
return data
|
661 |
+
|
662 |
+
@model_validator(mode="before")
|
663 |
+
@classmethod
|
664 |
+
def check_evals(cls, data):
|
665 |
+
if (
|
666 |
+
data.get("evaluation_strategy")
|
667 |
+
and data.get("eval_steps")
|
668 |
+
and data.get("evaluation_strategy") != "steps"
|
669 |
+
):
|
670 |
+
raise ValueError(
|
671 |
+
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
672 |
+
)
|
673 |
+
|
674 |
+
if (
|
675 |
+
data.get("val_set_size") == 0
|
676 |
+
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
677 |
+
and not data.get("test_datasets")
|
678 |
+
):
|
679 |
+
raise ValueError(
|
680 |
+
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
681 |
+
)
|
682 |
+
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
683 |
+
raise ValueError(
|
684 |
+
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
685 |
+
)
|
686 |
+
if (
|
687 |
+
data.get("evals_per_epoch")
|
688 |
+
and data.get("evaluation_strategy")
|
689 |
+
and data.get("evaluation_strategy") != "steps"
|
690 |
+
):
|
691 |
+
raise ValueError(
|
692 |
+
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
693 |
+
)
|
694 |
+
|
695 |
+
return data
|
696 |
+
|
697 |
+
@model_validator(mode="before")
|
698 |
+
@classmethod
|
699 |
+
def check_eval_packing(cls, data):
|
700 |
+
if (
|
701 |
+
data.get("sample_packing")
|
702 |
+
and data.get("eval_table_size")
|
703 |
+
and data.get("eval_sample_packing") is not False
|
704 |
+
):
|
705 |
+
raise ValueError(
|
706 |
+
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
707 |
+
)
|
708 |
+
return data
|
709 |
+
|
710 |
+
@model_validator(mode="before")
|
711 |
+
@classmethod
|
712 |
+
def check_warmup(cls, data):
|
713 |
+
if data.get("warmup_steps") and data.get("warmup_ratio"):
|
714 |
+
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
715 |
+
return data
|
716 |
+
|
717 |
+
@model_validator(mode="before")
|
718 |
+
@classmethod
|
719 |
+
def check_neftune(cls, data):
|
720 |
+
if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
|
721 |
+
data["neftune_noise_alpha"] = data["noisy_embedding_alpha"]
|
722 |
+
del data["noisy_embedding_alpha"]
|
723 |
+
elif data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
|
724 |
+
raise ValueError(
|
725 |
+
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
726 |
+
)
|
727 |
+
return data
|
728 |
+
|
729 |
+
@field_validator("neftune_noise_alpha")
|
730 |
+
@classmethod
|
731 |
+
def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
|
732 |
+
if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:
|
733 |
+
raise ValueError("neftune_noise_alpha must be > 0.0")
|
734 |
+
return neftune_noise_alpha
|
735 |
+
|
736 |
+
@model_validator(mode="before")
|
737 |
+
@classmethod
|
738 |
+
def check_frozen(cls, data):
|
739 |
+
if (
|
740 |
+
data.get("adapter")
|
741 |
+
and data.get("peft_layers_to_transform")
|
742 |
+
and data.get("unfrozen_parameters")
|
743 |
+
):
|
744 |
+
raise ValueError(
|
745 |
+
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
746 |
+
)
|
747 |
+
|
748 |
+
return data
|
749 |
+
|
750 |
+
@model_validator(mode="after")
|
751 |
+
def check_fft_possible_bad_config(self):
|
752 |
+
if (
|
753 |
+
# pylint: disable=too-many-boolean-expressions
|
754 |
+
not (self.bf16 or self.bfloat16)
|
755 |
+
and (self.fp16 or self.float16)
|
756 |
+
and not self.adapter
|
757 |
+
and not self.flash_attention
|
758 |
+
and self.sample_packing
|
759 |
+
):
|
760 |
+
LOG.warning(
|
761 |
+
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
762 |
+
)
|
763 |
+
# ValueError: Attempting to unscale FP16 gradients.
|
764 |
+
# OR
|
765 |
+
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
766 |
+
return self
|
767 |
+
|
768 |
+
@model_validator(mode="after")
|
769 |
+
def check_fused_lora(self):
|
770 |
+
if self.adapter in ["lora", "qlora"] and (
|
771 |
+
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
|
772 |
+
):
|
773 |
+
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
774 |
+
return self
|
775 |
+
|
776 |
+
@model_validator(mode="after")
|
777 |
+
def hint_lora_8bit(self):
|
778 |
+
loftq = (
|
779 |
+
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
|
780 |
+
)
|
781 |
+
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
|
782 |
+
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
783 |
+
return self
|
784 |
+
|
785 |
+
@model_validator(mode="after")
|
786 |
+
def check_early_stopping(self):
|
787 |
+
if self.early_stopping_patience:
|
788 |
+
if not self.save_steps or self.eval_steps:
|
789 |
+
raise ValueError(
|
790 |
+
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
791 |
+
)
|
792 |
+
if self.save_steps % self.eval_steps != 0:
|
793 |
+
raise ValueError(
|
794 |
+
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
795 |
+
)
|
796 |
+
return self
|
797 |
+
|
798 |
+
@model_validator(mode="after")
|
799 |
+
def check_relora(self):
|
800 |
+
if self.relora_steps:
|
801 |
+
if self.adapter not in ("lora", "qlora"):
|
802 |
+
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
803 |
+
|
804 |
+
if self.fsdp:
|
805 |
+
raise ValueError("fsdp not supported with ReLoRA")
|
806 |
+
|
807 |
+
if self.deepspeed:
|
808 |
+
raise ValueError("deepspeed not supported with ReLoRA")
|
809 |
+
|
810 |
+
if self.lr_scheduler == "one_cycle":
|
811 |
+
raise ValueError(
|
812 |
+
"ReLoRA is not compatible with the one_cycle scheduler"
|
813 |
+
)
|
814 |
+
|
815 |
+
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
|
816 |
+
raise ValueError("Fused modules are not supported with ReLoRA")
|
817 |
+
return self
|
818 |
+
|
819 |
+
@model_validator(mode="before")
|
820 |
+
@classmethod
|
821 |
+
def check_mem_mismatch(cls, data):
|
822 |
+
if (
|
823 |
+
data.get("max_memory") is not None
|
824 |
+
and data.get("gpu_memory_limit") is not None
|
825 |
+
):
|
826 |
+
raise ValueError(
|
827 |
+
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
828 |
+
)
|
829 |
+
return data
|
830 |
+
|
831 |
+
@model_validator(mode="before")
|
832 |
+
@classmethod
|
833 |
+
def check_use_reentrant_mismatch(cls, data):
|
834 |
+
if (
|
835 |
+
data.get("unfrozen_parameters")
|
836 |
+
and data.get("gradient_checkpointing_kwargs")
|
837 |
+
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
838 |
+
is True
|
839 |
+
):
|
840 |
+
# https://github.com/huggingface/transformers/issues/21381
|
841 |
+
raise ValueError(
|
842 |
+
"`use_reentrant` must be false when used with partially frozen model."
|
843 |
+
)
|
844 |
+
return data
|
845 |
+
|
846 |
+
@model_validator(mode="before")
|
847 |
+
@classmethod
|
848 |
+
def check_val_w_test_datasets(cls, data):
|
849 |
+
if data.get("test_datasets") and data.get("val_set_size"):
|
850 |
+
raise ValueError(
|
851 |
+
"non-zero val_set_size should not be used with test_datasets configuration"
|
852 |
+
)
|
853 |
+
return data
|
854 |
+
|
855 |
+
@model_validator(mode="before")
|
856 |
+
@classmethod
|
857 |
+
def check_fsdp_w_8bit_optimizer(cls, data):
|
858 |
+
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
|
859 |
+
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
|
860 |
+
return data
|
861 |
+
|
862 |
+
@model_validator(mode="before")
|
863 |
+
@classmethod
|
864 |
+
def check_causal_lm_evals(cls, data):
|
865 |
+
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
|
866 |
+
raise ValueError(
|
867 |
+
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
868 |
+
)
|
869 |
+
|
870 |
+
if data.get("eval_causal_lm_metrics"):
|
871 |
+
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
|
872 |
+
if not isinstance(data.get("eval_causal_lm_metrics"), list):
|
873 |
+
raise ValueError("eval_causal_lm_metrics must be a list")
|
874 |
+
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
875 |
+
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
|
876 |
+
raise ValueError(
|
877 |
+
f"eval_causal_lm_metrics must be one of {supported_metrics}"
|
878 |
+
)
|
879 |
+
return data
|
880 |
+
|
881 |
+
@model_validator(mode="before")
|
882 |
+
@classmethod
|
883 |
+
def check_dataset_or_pretraining_dataset(cls, data):
|
884 |
+
if data.get("datasets") is None and data.get("pretraining_dataset") is None:
|
885 |
+
raise ValueError("either datasets or pretraining_dataset is required")
|
886 |
+
return data
|
887 |
+
|
888 |
+
|
889 |
+
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
890 |
+
"""wrapper to valdiate gpu capabilities with the configured options"""
|
891 |
+
|
892 |
+
capabilities: GPUCapabilities
|
893 |
+
|
894 |
+
@model_validator(mode="after")
|
895 |
+
def check_bf16(self):
|
896 |
+
if self.capabilities.bf16:
|
897 |
+
if not self.bf16 and not self.bfloat16:
|
898 |
+
LOG.info(
|
899 |
+
"bf16 support detected, but not enabled for this configuration."
|
900 |
+
)
|
901 |
+
else:
|
902 |
+
if (
|
903 |
+
not self.merge_lora
|
904 |
+
and not self.is_preprocess
|
905 |
+
and (self.bf16 is True or self.bfloat16 is True)
|
906 |
+
):
|
907 |
+
raise ValueError(
|
908 |
+
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
909 |
+
)
|
910 |
+
return self
|
911 |
+
|
912 |
+
@model_validator(mode="before")
|
913 |
+
@classmethod
|
914 |
+
def check_sample_packing_w_sdpa_bf16(cls, data):
|
915 |
+
is_sm_90: bool = (
|
916 |
+
data["capabilities"]
|
917 |
+
and data["capabilities"].get("compute_capability") == "sm_90"
|
918 |
+
)
|
919 |
+
if (
|
920 |
+
data.get("sample_packing")
|
921 |
+
and data.get("sdp_attention")
|
922 |
+
and (data.get("bfloat16") or data.get("bf16"))
|
923 |
+
and not is_sm_90
|
924 |
+
):
|
925 |
+
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
926 |
+
LOG.warning(
|
927 |
+
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
928 |
+
"This may work on H100s."
|
929 |
+
)
|
930 |
+
|
931 |
+
return data
|
src/axolotl/utils/config/models/internals/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""module for gpu capabilities"""
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
|
7 |
+
class GPUCapabilities(BaseModel):
|
8 |
+
"""model to manage the gpu capabilities statically"""
|
9 |
+
|
10 |
+
bf16: bool = Field(default=False)
|
11 |
+
fp8: bool = Field(default=False)
|
12 |
+
n_gpu: int = Field(default=1)
|
13 |
+
n_node: int = Field(default=1)
|
14 |
+
compute_capability: Optional[str] = Field(default=None)
|
src/axolotl/utils/dict.py
CHANGED
@@ -12,4 +12,4 @@ class DictDefault(Dict):
|
|
12 |
return None
|
13 |
|
14 |
def __or__(self, other):
|
15 |
-
return DictDefault(super().
|
|
|
12 |
return None
|
13 |
|
14 |
def __or__(self, other):
|
15 |
+
return DictDefault(super().__ror__(other))
|
src/axolotl/utils/models.py
CHANGED
@@ -104,8 +104,8 @@ def load_model_config(cfg):
|
|
104 |
)
|
105 |
raise err
|
106 |
|
107 |
-
if cfg.
|
108 |
-
for key, val in cfg.
|
109 |
setattr(model_config, key, val)
|
110 |
|
111 |
check_model_config(cfg, model_config)
|
|
|
104 |
)
|
105 |
raise err
|
106 |
|
107 |
+
if cfg.model_config_overrides:
|
108 |
+
for key, val in cfg.model_config_overrides.items():
|
109 |
setattr(model_config, key, val)
|
110 |
|
111 |
check_model_config(cfg, model_config)
|
tests/test_dict.py
CHANGED
@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
|
|
39 |
), "DictDefault should support in operator for existing keys in list"
|
40 |
|
41 |
def test_dict_or_operator(self):
|
42 |
-
cfg = DictDefault(
|
|
|
|
|
43 |
{
|
44 |
"key_a": {"key_b": "value_a"},
|
45 |
"key_c": "value_c",
|
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
|
|
48 |
}
|
49 |
)
|
50 |
|
51 |
-
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
52 |
-
{"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
|
53 |
-
)
|
54 |
-
|
55 |
assert (
|
56 |
cfg.key_a.key_b == "value_b"
|
57 |
), "DictDefault should support OR operator for existing nested keys"
|
|
|
39 |
), "DictDefault should support in operator for existing keys in list"
|
40 |
|
41 |
def test_dict_or_operator(self):
|
42 |
+
cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
|
43 |
+
|
44 |
+
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
45 |
{
|
46 |
"key_a": {"key_b": "value_a"},
|
47 |
"key_c": "value_c",
|
|
|
50 |
}
|
51 |
)
|
52 |
|
|
|
|
|
|
|
|
|
53 |
assert (
|
54 |
cfg.key_a.key_b == "value_b"
|
55 |
), "DictDefault should support OR operator for existing nested keys"
|
tests/test_prompt_tokenizers.py
CHANGED
@@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
204 |
# fmt: off
|
205 |
# System message, multi-turn conversations
|
206 |
mt_ids = tokenize(test_data['multi_turn_sys'])
|
207 |
-
assert decode(mt_ids) == '<s> [INST]
|
208 |
-
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
209 |
|
210 |
# System message, single-turn conversations
|
211 |
st_ids = tokenize(test_data['single_turn_sys'])
|
212 |
-
assert decode(st_ids) == '<s> [INST]
|
213 |
-
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
214 |
|
215 |
# No system message, single-turn
|
216 |
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
|
|
204 |
# fmt: off
|
205 |
# System message, multi-turn conversations
|
206 |
mt_ids = tokenize(test_data['multi_turn_sys'])
|
207 |
+
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
208 |
+
assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
209 |
|
210 |
# System message, single-turn conversations
|
211 |
st_ids = tokenize(test_data['single_turn_sys'])
|
212 |
+
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
213 |
+
assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
214 |
|
215 |
# No system message, single-turn
|
216 |
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
tests/test_validation.py
CHANGED
@@ -1,20 +1,39 @@
@@ -27,199 +46,354 @@ class BaseValidation(unittest.TestCase):
@@ -230,10 +404,14 @@ class ValidationTest(BaseValidation):
@@ -243,34 +421,43 @@ class ValidationTest(BaseValidation):
@@ -281,11 +468,14 @@ class ValidationTest(BaseValidation):
@@ -296,30 +486,39 @@ class ValidationTest(BaseValidation):
@@ -327,12 +526,15 @@ class ValidationTest(BaseValidation):
@@ -342,62 +544,79 @@ class ValidationTest(BaseValidation):
@@ -405,11 +624,14 @@ class ValidationTest(BaseValidation):
@@ -417,45 +639,60 @@ class ValidationTest(BaseValidation):
@@ -463,11 +700,14 @@ class ValidationTest(BaseValidation):
@@ -475,44 +715,59 @@ class ValidationTest(BaseValidation):
@@ -521,11 +776,14 @@ class ValidationTest(BaseValidation):
@@ -534,38 +792,50 @@ class ValidationTest(BaseValidation):
@@ -573,39 +843,51 @@ class ValidationTest(BaseValidation):
@@ -614,10 +896,13 @@ class ValidationTest(BaseValidation):
@@ -626,30 +911,39 @@ class ValidationTest(BaseValidation):
@@ -658,29 +952,40 @@ class ValidationTest(BaseValidation):
@@ -689,8 +994,8 @@ class ValidationTest(BaseValidation):
@@ -698,22 +1003,25 @@ class ValidationTest(BaseValidation):
@@ -723,13 +1031,16 @@ class ValidationCheckModelConfig(BaseValidation):
@@ -738,20 +1049,26 @@ class ValidationCheckModelConfig(BaseValidation):
@@ -761,13 +1078,16 @@ class ValidationCheckModelConfig(BaseValidation):
@@ -776,66 +1096,78 @@ class ValidationCheckModelConfig(BaseValidation):
@@ -855,24 +1187,27 @@ class ValidationWandbTest(BaseValidation):
1 |
+
# pylint: disable=too-many-lines
|
2 |
"""Module for testing the validation module"""
|
3 |
|
4 |
import logging
|
5 |
import os
|
|
|
6 |
from typing import Optional
|
7 |
|
8 |
import pytest
|
9 |
+
from pydantic import ValidationError
|
10 |
|
11 |
from axolotl.utils.config import validate_config
|
12 |
+
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
13 |
from axolotl.utils.dict import DictDefault
|
14 |
from axolotl.utils.models import check_model_config
|
15 |
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
16 |
|
17 |
|
18 |
+
@pytest.fixture(name="minimal_cfg")
|
19 |
+
def fixture_cfg():
|
20 |
+
return DictDefault(
|
21 |
+
{
|
22 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
23 |
+
"learning_rate": 0.000001,
|
24 |
+
"datasets": [
|
25 |
+
{
|
26 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
27 |
+
"type": "alpaca",
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"micro_batch_size": 1,
|
31 |
+
"gradient_accumulation_steps": 1,
|
32 |
+
}
|
33 |
+
)
|
34 |
+
|
35 |
+
|
36 |
+
class BaseValidation:
|
37 |
"""
|
38 |
Base validation module to setup the log capture
|
39 |
"""
|
|
|
46 |
|
47 |
|
48 |
# pylint: disable=too-many-public-methods
|
49 |
+
class TestValidation(BaseValidation):
|
50 |
"""
|
51 |
Test the validation module
|
52 |
"""
|
53 |
|
54 |
+
def test_datasets_min_length(self):
|
55 |
cfg = DictDefault(
|
56 |
{
|
57 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
58 |
+
"learning_rate": 0.000001,
|
59 |
+
"datasets": [],
|
60 |
+
"micro_batch_size": 1,
|
61 |
+
"gradient_accumulation_steps": 1,
|
62 |
}
|
63 |
)
|
64 |
|
65 |
+
with pytest.raises(
|
66 |
+
ValidationError,
|
67 |
+
match=r".*List should have at least 1 item after validation*",
|
68 |
+
):
|
69 |
validate_config(cfg)
|
|
|
70 |
|
71 |
+
def test_datasets_min_length_empty(self):
|
72 |
+
cfg = DictDefault(
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
{
|
74 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
75 |
+
"learning_rate": 0.000001,
|
76 |
+
"micro_batch_size": 1,
|
77 |
+
"gradient_accumulation_steps": 1,
|
78 |
}
|
79 |
)
|
80 |
|
81 |
+
with pytest.raises(
|
82 |
+
ValueError, match=r".*either datasets or pretraining_dataset is required*"
|
83 |
+
):
|
84 |
validate_config(cfg)
|
85 |
|
86 |
+
def test_pretrain_dataset_min_length(self):
|
87 |
+
cfg = DictDefault(
|
88 |
{
|
89 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
90 |
+
"learning_rate": 0.000001,
|
91 |
+
"pretraining_dataset": [],
|
92 |
+
"micro_batch_size": 1,
|
93 |
+
"gradient_accumulation_steps": 1,
|
94 |
+
"max_steps": 100,
|
95 |
}
|
96 |
)
|
97 |
|
98 |
+
with pytest.raises(
|
99 |
+
ValidationError,
|
100 |
+
match=r".*List should have at least 1 item after validation*",
|
101 |
+
):
|
102 |
validate_config(cfg)
|
103 |
|
104 |
+
def test_valid_pretrain_dataset(self):
|
105 |
+
cfg = DictDefault(
|
106 |
{
|
107 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
108 |
+
"learning_rate": 0.000001,
|
109 |
+
"pretraining_dataset": [
|
110 |
+
{
|
111 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
112 |
+
"type": "alpaca",
|
113 |
+
}
|
114 |
+
],
|
115 |
+
"micro_batch_size": 1,
|
116 |
+
"gradient_accumulation_steps": 1,
|
117 |
+
"max_steps": 100,
|
118 |
}
|
119 |
)
|
120 |
|
121 |
+
validate_config(cfg)
|
|
|
122 |
|
123 |
+
def test_valid_sft_dataset(self):
|
124 |
+
cfg = DictDefault(
|
125 |
{
|
126 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
127 |
+
"learning_rate": 0.000001,
|
128 |
+
"datasets": [
|
129 |
+
{
|
130 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
131 |
+
"type": "alpaca",
|
132 |
+
}
|
133 |
+
],
|
134 |
+
"micro_batch_size": 1,
|
135 |
+
"gradient_accumulation_steps": 1,
|
136 |
}
|
137 |
)
|
138 |
|
139 |
validate_config(cfg)
|
140 |
|
141 |
+
def test_batch_size_unused_warning(self):
|
142 |
+
cfg = DictDefault(
|
143 |
{
|
144 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
145 |
+
"learning_rate": 0.000001,
|
146 |
+
"datasets": [
|
147 |
+
{
|
148 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
149 |
+
"type": "alpaca",
|
150 |
+
}
|
151 |
+
],
|
152 |
+
"micro_batch_size": 4,
|
153 |
+
"batch_size": 32,
|
154 |
}
|
155 |
)
|
156 |
|
157 |
+
with self._caplog.at_level(logging.WARNING):
|
158 |
+
validate_config(cfg)
|
159 |
+
assert "batch_size is not recommended" in self._caplog.records[0].message
|
160 |
+
|
161 |
+
def test_batch_size_more_params(self):
|
162 |
+
cfg = DictDefault(
|
163 |
{
|
164 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
165 |
+
"learning_rate": 0.000001,
|
166 |
+
"datasets": [
|
167 |
+
{
|
168 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
169 |
+
"type": "alpaca",
|
170 |
+
}
|
171 |
+
],
|
172 |
+
"batch_size": 32,
|
173 |
}
|
174 |
)
|
175 |
|
176 |
+
with pytest.raises(ValueError, match=r".*At least two of*"):
|
177 |
+
validate_config(cfg)
|
178 |
+
|
179 |
+
def test_qlora(self, minimal_cfg):
|
180 |
+
base_cfg = (
|
181 |
+
DictDefault(
|
182 |
+
{
|
183 |
+
"adapter": "qlora",
|
184 |
+
}
|
185 |
+
)
|
186 |
+
| minimal_cfg
|
187 |
+
)
|
188 |
+
|
189 |
+
cfg = (
|
190 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
191 |
+
{
|
192 |
+
"load_in_8bit": True,
|
193 |
+
}
|
194 |
+
)
|
195 |
+
| base_cfg
|
196 |
+
)
|
197 |
+
|
198 |
with pytest.raises(ValueError, match=r".*8bit.*"):
|
199 |
validate_config(cfg)
|
200 |
|
201 |
+
cfg = (
|
202 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
203 |
+
{
|
204 |
+
"gptq": True,
|
205 |
+
}
|
206 |
+
)
|
207 |
+
| base_cfg
|
208 |
)
|
209 |
|
210 |
with pytest.raises(ValueError, match=r".*gptq.*"):
|
211 |
validate_config(cfg)
|
212 |
|
213 |
+
cfg = (
|
214 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
215 |
+
{
|
216 |
+
"load_in_4bit": False,
|
217 |
+
}
|
218 |
+
)
|
219 |
+
| base_cfg
|
220 |
)
|
221 |
|
222 |
with pytest.raises(ValueError, match=r".*4bit.*"):
|
223 |
validate_config(cfg)
|
224 |
|
225 |
+
cfg = (
|
226 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
227 |
+
{
|
228 |
+
"load_in_4bit": True,
|
229 |
+
}
|
230 |
+
)
|
231 |
+
| base_cfg
|
232 |
)
|
233 |
|
234 |
+
validate_config(cfg)
|
235 |
+
|
236 |
+
def test_qlora_merge(self, minimal_cfg):
|
237 |
+
base_cfg = (
|
238 |
+
DictDefault(
|
239 |
+
{
|
240 |
+
"adapter": "qlora",
|
241 |
+
"merge_lora": True,
|
242 |
+
}
|
243 |
+
)
|
244 |
+
| minimal_cfg
|
245 |
+
)
|
246 |
+
|
247 |
+
cfg = (
|
248 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
249 |
+
{
|
250 |
+
"load_in_8bit": True,
|
251 |
+
}
|
252 |
+
)
|
253 |
+
| base_cfg
|
254 |
+
)
|
255 |
+
|
256 |
+
with pytest.raises(ValueError, match=r".*8bit.*"):
|
257 |
validate_config(cfg)
|
258 |
|
259 |
+
cfg = (
|
260 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
261 |
+
{
|
262 |
+
"gptq": True,
|
263 |
+
}
|
264 |
+
)
|
265 |
+
| base_cfg
|
266 |
)
|
|
|
267 |
|
268 |
+
with pytest.raises(ValueError, match=r".*gptq.*"):
|
269 |
+
validate_config(cfg)
|
270 |
+
|
271 |
+
cfg = (
|
272 |
+
DictDefault( # pylint: disable=unsupported-binary-operation
|
273 |
+
{
|
274 |
+
"load_in_4bit": True,
|
275 |
+
}
|
276 |
+
)
|
277 |
+
| base_cfg
|
278 |
)
|
279 |
|
280 |
+
with pytest.raises(ValueError, match=r".*4bit.*"):
|
|
|
|
|
281 |
validate_config(cfg)
|
282 |
|
283 |
+
def test_hf_use_auth_token(self, minimal_cfg):
|
284 |
+
cfg = (
|
285 |
+
DictDefault(
|
286 |
+
{
|
287 |
+
"push_dataset_to_hub": "namespace/repo",
|
288 |
+
}
|
289 |
+
)
|
290 |
+
| minimal_cfg
|
291 |
)
|
292 |
|
293 |
+
with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
|
294 |
+
validate_config(cfg)
|
295 |
+
|
296 |
+
cfg = (
|
297 |
+
DictDefault(
|
298 |
+
{
|
299 |
+
"push_dataset_to_hub": "namespace/repo",
|
300 |
+
"hf_use_auth_token": True,
|
301 |
+
}
|
302 |
+
)
|
303 |
+
| minimal_cfg
|
304 |
+
)
|
305 |
validate_config(cfg)
|
306 |
|
307 |
+
def test_gradient_accumulations_or_batch_size(self):
|
308 |
cfg = DictDefault(
|
309 |
{
|
310 |
+
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
311 |
+
"learning_rate": 0.000001,
|
312 |
+
"datasets": [
|
313 |
+
{
|
314 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
315 |
+
"type": "alpaca",
|
316 |
+
}
|
317 |
+
],
|
318 |
"gradient_accumulation_steps": 1,
|
319 |
+
"batch_size": 1,
|
320 |
}
|
321 |
)
|
322 |
|
323 |
+
with pytest.raises(
|
324 |
+
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
325 |
+
):
|
326 |
+
validate_config(cfg)
|
327 |
|
328 |
+
def test_falcon_fsdp(self, minimal_cfg):
|
329 |
regex_exp = r".*FSDP is not supported for falcon models.*"
|
330 |
|
331 |
# Check for lower-case
|
332 |
+
cfg = (
|
333 |
+
DictDefault(
|
334 |
+
{
|
335 |
+
"base_model": "tiiuae/falcon-7b",
|
336 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
337 |
+
}
|
338 |
+
)
|
339 |
+
| minimal_cfg
|
340 |
)
|
341 |
|
342 |
with pytest.raises(ValueError, match=regex_exp):
|
343 |
validate_config(cfg)
|
344 |
|
345 |
# Check for upper-case
|
346 |
+
cfg = (
|
347 |
+
DictDefault(
|
348 |
+
{
|
349 |
+
"base_model": "Falcon-7b",
|
350 |
+
"fsdp": ["full_shard", "auto_wrap"],
|
351 |
+
}
|
352 |
+
)
|
353 |
+
| minimal_cfg
|
354 |
)
|
355 |
|
356 |
with pytest.raises(ValueError, match=regex_exp):
|
357 |
validate_config(cfg)
|
358 |
|
359 |
+
cfg = (
|
360 |
+
DictDefault(
|
361 |
+
{
|
362 |
+
"base_model": "tiiuae/falcon-7b",
|
363 |
+
}
|
364 |
+
)
|
365 |
+
| minimal_cfg
|
366 |
)
|
367 |
|
368 |
validate_config(cfg)
|
369 |
|
370 |
+
def test_mpt_gradient_checkpointing(self, minimal_cfg):
|
371 |
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
|
372 |
|
373 |
# Check for lower-case
|
374 |
+
cfg = (
|
375 |
+
DictDefault(
|
376 |
+
{
|
377 |
+
"base_model": "mosaicml/mpt-7b",
|
378 |
+
"gradient_checkpointing": True,
|
379 |
+
}
|
380 |
+
)
|
381 |
+
| minimal_cfg
|
382 |
)
|
383 |
|
384 |
with pytest.raises(ValueError, match=regex_exp):
|
385 |
validate_config(cfg)
|
386 |
|
387 |
+
def test_flash_optimum(self, minimal_cfg):
|
388 |
+
cfg = (
|
389 |
+
DictDefault(
|
390 |
+
{
|
391 |
+
"flash_optimum": True,
|
392 |
+
"adapter": "lora",
|
393 |
+
"bf16": False,
|
394 |
+
}
|
395 |
+
)
|
396 |
+
| minimal_cfg
|
397 |
)
|
398 |
|
399 |
with self._caplog.at_level(logging.WARNING):
|
|
|
404 |
for record in self._caplog.records
|
405 |
)
|
406 |
|
407 |
+
cfg = (
|
408 |
+
DictDefault(
|
409 |
+
{
|
410 |
+
"flash_optimum": True,
|
411 |
+
"bf16": False,
|
412 |
+
}
|
413 |
+
)
|
414 |
+
| minimal_cfg
|
415 |
)
|
416 |
|
417 |
with self._caplog.at_level(logging.WARNING):
|
|
|
421 |
for record in self._caplog.records
|
422 |
)
|
423 |
|
424 |
+
cfg = (
|
425 |
+
DictDefault(
|
426 |
+
{
|
427 |
+
"flash_optimum": True,
|
428 |
+
"fp16": True,
|
429 |
+
}
|
430 |
+
)
|
431 |
+
| minimal_cfg
|
432 |
)
|
433 |
regex_exp = r".*AMP is not supported.*"
|
434 |
|
435 |
with pytest.raises(ValueError, match=regex_exp):
|
436 |
validate_config(cfg)
|
437 |
|
438 |
+
cfg = (
|
439 |
+
DictDefault(
|
440 |
+
{
|
441 |
+
"flash_optimum": True,
|
442 |
+
"bf16": True,
|
443 |
+
}
|
444 |
+
)
|
445 |
+
| minimal_cfg
|
446 |
)
|
447 |
regex_exp = r".*AMP is not supported.*"
|
448 |
|
449 |
with pytest.raises(ValueError, match=regex_exp):
|
450 |
validate_config(cfg)
|
451 |
|
452 |
+
def test_adamw_hyperparams(self, minimal_cfg):
|
453 |
+
cfg = (
|
454 |
+
DictDefault(
|
455 |
+
{
|
456 |
+
"optimizer": None,
|
457 |
+
"adam_epsilon": 0.0001,
|
458 |
+
}
|
459 |
+
)
|
460 |
+
| minimal_cfg
|
461 |
)
|
462 |
|
463 |
with self._caplog.at_level(logging.WARNING):
|
|
|
468 |
for record in self._caplog.records
|
469 |
)
|
470 |
|
471 |
+
cfg = (
|
472 |
+
DictDefault(
|
473 |
+
{
|
474 |
+
"optimizer": "adafactor",
|
475 |
+
"adam_beta1": 0.0001,
|
476 |
+
}
|
477 |
+
)
|
478 |
+
| minimal_cfg
|
479 |
)
|
480 |
|
481 |
with self._caplog.at_level(logging.WARNING):
|
|
|
486 |
for record in self._caplog.records
|
487 |
)
|
488 |
|
489 |
+
cfg = (
|
490 |
+
DictDefault(
|
491 |
+
{
|
492 |
+
"optimizer": "adamw_bnb_8bit",
|
493 |
+
"adam_beta1": 0.9,
|
494 |
+
"adam_beta2": 0.99,
|
495 |
+
"adam_epsilon": 0.0001,
|
496 |
+
}
|
497 |
+
)
|
498 |
+
| minimal_cfg
|
499 |
)
|
500 |
|
501 |
validate_config(cfg)
|
502 |
|
503 |
+
cfg = (
|
504 |
+
DictDefault(
|
505 |
+
{
|
506 |
+
"optimizer": "adafactor",
|
507 |
+
}
|
508 |
+
)
|
509 |
+
| minimal_cfg
|
510 |
)
|
511 |
|
512 |
validate_config(cfg)
|
513 |
|
514 |
+
def test_deprecated_packing(self, minimal_cfg):
|
515 |
+
cfg = (
|
516 |
+
DictDefault(
|
517 |
+
{
|
518 |
+
"max_packed_sequence_len": 1024,
|
519 |
+
}
|
520 |
+
)
|
521 |
+
| minimal_cfg
|
522 |
)
|
523 |
with pytest.raises(
|
524 |
DeprecationWarning,
|
|
|
526 |
):
|
527 |
validate_config(cfg)
|
528 |
|
529 |
+
def test_packing(self, minimal_cfg):
|
530 |
+
cfg = (
|
531 |
+
DictDefault(
|
532 |
+
{
|
533 |
+
"sample_packing": True,
|
534 |
+
"pad_to_sequence_len": None,
|
535 |
+
}
|
536 |
+
)
|
537 |
+
| minimal_cfg
|
538 |
)
|
539 |
with self._caplog.at_level(logging.WARNING):
|
540 |
validate_config(cfg)
|
|
|
544 |
for record in self._caplog.records
|
545 |
)
|
546 |
|
547 |
+
def test_merge_lora_no_bf16_fail(self, minimal_cfg):
|
|
|
|
|
|
|
|
|
548 |
"""
|
549 |
This is assumed to be run on a CPU machine, so bf16 is not supported.
|
550 |
"""
|
551 |
|
552 |
+
cfg = (
|
553 |
+
DictDefault(
|
554 |
+
{
|
555 |
+
"bf16": True,
|
556 |
+
"capabilities": {"bf16": False},
|
557 |
+
}
|
558 |
+
)
|
559 |
+
| minimal_cfg
|
560 |
)
|
561 |
|
562 |
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
|
563 |
+
AxolotlConfigWCapabilities(**cfg.to_dict())
|
564 |
+
|
565 |
+
cfg = (
|
566 |
+
DictDefault(
|
567 |
+
{
|
568 |
+
"bf16": True,
|
569 |
+
"merge_lora": True,
|
570 |
+
"capabilities": {"bf16": False},
|
571 |
+
}
|
572 |
+
)
|
573 |
+
| minimal_cfg
|
574 |
)
|
575 |
|
576 |
validate_config(cfg)
|
577 |
|
578 |
+
def test_sharegpt_deprecation(self, minimal_cfg):
|
579 |
+
cfg = (
|
580 |
+
DictDefault(
|
581 |
+
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
|
582 |
+
)
|
583 |
+
| minimal_cfg
|
584 |
)
|
585 |
with self._caplog.at_level(logging.WARNING):
|
586 |
+
new_cfg = validate_config(cfg)
|
587 |
assert any(
|
588 |
"`type: sharegpt:chat` will soon be deprecated." in record.message
|
589 |
for record in self._caplog.records
|
590 |
)
|
591 |
+
assert new_cfg.datasets[0].type == "sharegpt"
|
592 |
+
|
593 |
+
cfg = (
|
594 |
+
DictDefault(
|
595 |
+
{
|
596 |
+
"datasets": [
|
597 |
+
{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
|
598 |
+
]
|
599 |
+
}
|
600 |
+
)
|
601 |
+
| minimal_cfg
|
602 |
)
|
603 |
with self._caplog.at_level(logging.WARNING):
|
604 |
+
new_cfg = validate_config(cfg)
|
605 |
assert any(
|
606 |
"`type: sharegpt_simple` will soon be deprecated." in record.message
|
607 |
for record in self._caplog.records
|
608 |
)
|
609 |
+
assert new_cfg.datasets[0].type == "sharegpt:load_role"
|
610 |
+
|
611 |
+
def test_no_conflict_save_strategy(self, minimal_cfg):
|
612 |
+
cfg = (
|
613 |
+
DictDefault(
|
614 |
+
{
|
615 |
+
"save_strategy": "epoch",
|
616 |
+
"save_steps": 10,
|
617 |
+
}
|
618 |
+
)
|
619 |
+
| minimal_cfg
|
620 |
)
|
621 |
|
622 |
with pytest.raises(
|
|
|
624 |
):
|
625 |
validate_config(cfg)
|
626 |
|
627 |
+
cfg = (
|
628 |
+
DictDefault(
|
629 |
+
{
|
630 |
+
"save_strategy": "no",
|
631 |
+
"save_steps": 10,
|
632 |
+
}
|
633 |
+
)
|
634 |
+
| minimal_cfg
|
635 |
)
|
636 |
|
637 |
with pytest.raises(
|
|
|
639 |
):
|
640 |
validate_config(cfg)
|
641 |
|
642 |
+
cfg = (
|
643 |
+
DictDefault(
|
644 |
+
{
|
645 |
+
"save_strategy": "steps",
|
646 |
+
}
|
647 |
+
)
|
648 |
+
| minimal_cfg
|
649 |
)
|
650 |
|
651 |
validate_config(cfg)
|
652 |
|
653 |
+
cfg = (
|
654 |
+
DictDefault(
|
655 |
+
{
|
656 |
+
"save_strategy": "steps",
|
657 |
+
"save_steps": 10,
|
658 |
+
}
|
659 |
+
)
|
660 |
+
| minimal_cfg
|
661 |
)
|
662 |
|
663 |
validate_config(cfg)
|
664 |
|
665 |
+
cfg = (
|
666 |
+
DictDefault(
|
667 |
+
{
|
668 |
+
"save_steps": 10,
|
669 |
+
}
|
670 |
+
)
|
671 |
+
| minimal_cfg
|
672 |
)
|
673 |
|
674 |
validate_config(cfg)
|
675 |
|
676 |
+
cfg = (
|
677 |
+
DictDefault(
|
678 |
+
{
|
679 |
+
"save_strategy": "no",
|
680 |
+
}
|
681 |
+
)
|
682 |
+
| minimal_cfg
|
683 |
)
|
684 |
|
685 |
validate_config(cfg)
|
686 |
|
687 |
+
def test_no_conflict_eval_strategy(self, minimal_cfg):
|
688 |
+
cfg = (
|
689 |
+
DictDefault(
|
690 |
+
{
|
691 |
+
"evaluation_strategy": "epoch",
|
692 |
+
"eval_steps": 10,
|
693 |
+
}
|
694 |
+
)
|
695 |
+
| minimal_cfg
|
696 |
)
|
697 |
|
698 |
with pytest.raises(
|
|
|
700 |
):
|
701 |
validate_config(cfg)
|
702 |
|
703 |
+
cfg = (
|
704 |
+
DictDefault(
|
705 |
+
{
|
706 |
+
"evaluation_strategy": "no",
|
707 |
+
"eval_steps": 10,
|
708 |
+
}
|
709 |
+
)
|
710 |
+
| minimal_cfg
|
711 |
)
|
712 |
|
713 |
with pytest.raises(
|
|
|
715 |
):
|
716 |
validate_config(cfg)
|
717 |
|
718 |
+
cfg = (
|
719 |
+
DictDefault(
|
720 |
+
{
|
721 |
+
"evaluation_strategy": "steps",
|
722 |
+
}
|
723 |
+
)
|
724 |
+
| minimal_cfg
|
725 |
)
|
726 |
|
727 |
validate_config(cfg)
|
728 |
|
729 |
+
cfg = (
|
730 |
+
DictDefault(
|
731 |
+
{
|
732 |
+
"evaluation_strategy": "steps",
|
733 |
+
"eval_steps": 10,
|
734 |
+
}
|
735 |
+
)
|
736 |
+
| minimal_cfg
|
737 |
)
|
738 |
|
739 |
validate_config(cfg)
|
740 |
|
741 |
+
cfg = (
|
742 |
+
DictDefault(
|
743 |
+
{
|
744 |
+
"eval_steps": 10,
|
745 |
+
}
|
746 |
+
)
|
747 |
+
| minimal_cfg
|
748 |
)
|
749 |
|
750 |
validate_config(cfg)
|
751 |
|
752 |
+
cfg = (
|
753 |
+
DictDefault(
|
754 |
+
{
|
755 |
+
"evaluation_strategy": "no",
|
756 |
+
}
|
757 |
+
)
|
758 |
+
| minimal_cfg
|
759 |
)
|
760 |
|
761 |
validate_config(cfg)
|
762 |
|
763 |
+
cfg = (
|
764 |
+
DictDefault(
|
765 |
+
{
|
766 |
+
"evaluation_strategy": "epoch",
|
767 |
+
"val_set_size": 0,
|
768 |
+
}
|
769 |
+
)
|
770 |
+
| minimal_cfg
|
771 |
)
|
772 |
|
773 |
with pytest.raises(
|
|
|
776 |
):
|
777 |
validate_config(cfg)
|
778 |
|
779 |
+
cfg = (
|
780 |
+
DictDefault(
|
781 |
+
{
|
782 |
+
"eval_steps": 10,
|
783 |
+
"val_set_size": 0,
|
784 |
+
}
|
785 |
+
)
|
786 |
+
| minimal_cfg
|
787 |
)
|
788 |
|
789 |
with pytest.raises(
|
|
|
792 |
):
|
793 |
validate_config(cfg)
|
794 |
|
795 |
+
cfg = (
|
796 |
+
DictDefault(
|
797 |
+
{
|
798 |
+
"val_set_size": 0,
|
799 |
+
}
|
800 |
+
)
|
801 |
+
| minimal_cfg
|
802 |
)
|
803 |
|
804 |
validate_config(cfg)
|
805 |
|
806 |
+
cfg = (
|
807 |
+
DictDefault(
|
808 |
+
{
|
809 |
+
"eval_steps": 10,
|
810 |
+
"val_set_size": 0.01,
|
811 |
+
}
|
812 |
+
)
|
813 |
+
| minimal_cfg
|
814 |
)
|
815 |
|
816 |
validate_config(cfg)
|
817 |
|
818 |
+
cfg = (
|
819 |
+
DictDefault(
|
820 |
+
{
|
821 |
+
"evaluation_strategy": "epoch",
|
822 |
+
"val_set_size": 0.01,
|
823 |
+
}
|
824 |
+
)
|
825 |
+
| minimal_cfg
|
826 |
)
|
827 |
|
828 |
validate_config(cfg)
|
829 |
|
830 |
+
def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
|
831 |
+
cfg = (
|
832 |
+
DictDefault(
|
833 |
+
{
|
834 |
+
"sample_packing": True,
|
835 |
+
"eval_table_size": 100,
|
836 |
+
}
|
837 |
+
)
|
838 |
+
| minimal_cfg
|
839 |
)
|
840 |
|
841 |
with pytest.raises(
|
|
|
843 |
):
|
844 |
validate_config(cfg)
|
845 |
|
846 |
+
cfg = (
|
847 |
+
DictDefault(
|
848 |
+
{
|
849 |
+
"sample_packing": True,
|
850 |
+
"eval_sample_packing": False,
|
851 |
+
}
|
852 |
+
)
|
853 |
+
| minimal_cfg
|
854 |
)
|
855 |
|
856 |
validate_config(cfg)
|
857 |
|
858 |
+
cfg = (
|
859 |
+
DictDefault(
|
860 |
+
{
|
861 |
+
"sample_packing": False,
|
862 |
+
"eval_table_size": 100,
|
863 |
+
}
|
864 |
+
)
|
865 |
+
| minimal_cfg
|
866 |
)
|
867 |
|
868 |
validate_config(cfg)
|
869 |
|
870 |
+
cfg = (
|
871 |
+
DictDefault(
|
872 |
+
{
|
873 |
+
"sample_packing": True,
|
874 |
+
"eval_table_size": 100,
|
875 |
+
"eval_sample_packing": False,
|
876 |
+
}
|
877 |
+
)
|
878 |
+
| minimal_cfg
|
879 |
)
|
880 |
|
881 |
validate_config(cfg)
|
882 |
|
883 |
+
def test_load_in_x_bit_without_adapter(self, minimal_cfg):
|
884 |
+
cfg = (
|
885 |
+
DictDefault(
|
886 |
+
{
|
887 |
+
"load_in_4bit": True,
|
888 |
+
}
|
889 |
+
)
|
890 |
+
| minimal_cfg
|
891 |
)
|
892 |
|
893 |
with pytest.raises(
|
|
|
896 |
):
|
897 |
validate_config(cfg)
|
898 |
|
899 |
+
cfg = (
|
900 |
+
DictDefault(
|
901 |
+
{
|
902 |
+
"load_in_8bit": True,
|
903 |
+
}
|
904 |
+
)
|
905 |
+
| minimal_cfg
|
906 |
)
|
907 |
|
908 |
with pytest.raises(
|
|
|
911 |
):
|
912 |
validate_config(cfg)
|
913 |
|
914 |
+
cfg = (
|
915 |
+
DictDefault(
|
916 |
+
{
|
917 |
+
"load_in_4bit": True,
|
918 |
+
"adapter": "qlora",
|
919 |
+
}
|
920 |
+
)
|
921 |
+
| minimal_cfg
|
922 |
)
|
923 |
|
924 |
validate_config(cfg)
|
925 |
|
926 |
+
cfg = (
|
927 |
+
DictDefault(
|
928 |
+
{
|
929 |
+
"load_in_8bit": True,
|
930 |
+
"adapter": "lora",
|
931 |
+
}
|
932 |
+
)
|
933 |
+
| minimal_cfg
|
934 |
)
|
935 |
|
936 |
validate_config(cfg)
|
937 |
|
938 |
+
def test_warmup_step_no_conflict(self, minimal_cfg):
|
939 |
+
cfg = (
|
940 |
+
DictDefault(
|
941 |
+
{
|
942 |
+
"warmup_steps": 10,
|
943 |
+
"warmup_ratio": 0.1,
|
944 |
+
}
|
945 |
+
)
|
946 |
+
| minimal_cfg
|
947 |
)
|
948 |
|
949 |
with pytest.raises(
|
|
|
952 |
):
|
953 |
validate_config(cfg)
|
954 |
|
955 |
+
cfg = (
|
956 |
+
DictDefault(
|
957 |
+
{
|
958 |
+
"warmup_steps": 10,
|
959 |
+
}
|
960 |
+
)
|
961 |
+
| minimal_cfg
|
962 |
)
|
963 |
|
964 |
validate_config(cfg)
|
965 |
|
966 |
+
cfg = (
|
967 |
+
DictDefault(
|
968 |
+
{
|
969 |
+
"warmup_ratio": 0.1,
|
970 |
+
}
|
971 |
+
)
|
972 |
+
| minimal_cfg
|
973 |
)
|
974 |
|
975 |
validate_config(cfg)
|
976 |
|
977 |
+
def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
|
978 |
+
cfg = (
|
979 |
+
DictDefault(
|
980 |
+
{
|
981 |
+
"adapter": "lora",
|
982 |
+
"unfrozen_parameters": [
|
983 |
+
"model.layers.2[0-9]+.block_sparse_moe.gate.*"
|
984 |
+
],
|
985 |
+
"peft_layers_to_transform": [0, 1],
|
986 |
+
}
|
987 |
+
)
|
988 |
+
| minimal_cfg
|
989 |
)
|
990 |
|
991 |
with pytest.raises(
|
|
|
994 |
):
|
995 |
validate_config(cfg)
|
996 |
|
997 |
+
def test_hub_model_id_save_value_warns(self, minimal_cfg):
|
998 |
+
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
999 |
|
1000 |
with self._caplog.at_level(logging.WARNING):
|
1001 |
validate_config(cfg)
|
|
|
1003 |
"set without any models being saved" in self._caplog.records[0].message
|
1004 |
)
|
1005 |
|
1006 |
+
def test_hub_model_id_save_value(self, minimal_cfg):
|
1007 |
+
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
|
1008 |
|
1009 |
with self._caplog.at_level(logging.WARNING):
|
1010 |
validate_config(cfg)
|
1011 |
assert len(self._caplog.records) == 0
|
1012 |
|
1013 |
|
1014 |
+
class TestValidationCheckModelConfig(BaseValidation):
|
1015 |
"""
|
1016 |
Test the validation for the config when the model config is available
|
1017 |
"""
|
1018 |
|
1019 |
+
def test_llama_add_tokens_adapter(self, minimal_cfg):
|
1020 |
+
cfg = (
|
1021 |
+
DictDefault(
|
1022 |
+
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
1023 |
+
)
|
1024 |
+
| minimal_cfg
|
1025 |
)
|
1026 |
model_config = DictDefault({"model_type": "llama"})
|
1027 |
|
|
|
1031 |
):
|
1032 |
check_model_config(cfg, model_config)
|
1033 |
|
1034 |
+
cfg = (
|
1035 |
+
DictDefault(
|
1036 |
+
{
|
1037 |
+
"adapter": "qlora",
|
1038 |
+
"load_in_4bit": True,
|
1039 |
+
"tokens": ["<|imstart|>"],
|
1040 |
+
"lora_modules_to_save": ["embed_tokens"],
|
1041 |
+
}
|
1042 |
+
)
|
1043 |
+
| minimal_cfg
|
1044 |
)
|
1045 |
|
1046 |
with pytest.raises(
|
|
|
1049 |
):
|
1050 |
check_model_config(cfg, model_config)
|
1051 |
|
1052 |
+
cfg = (
|
1053 |
+
DictDefault(
|
1054 |
+
{
|
1055 |
+
"adapter": "qlora",
|
1056 |
+
"load_in_4bit": True,
|
1057 |
+
"tokens": ["<|imstart|>"],
|
1058 |
+
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
1059 |
+
}
|
1060 |
+
)
|
1061 |
+
| minimal_cfg
|
1062 |
)
|
1063 |
|
1064 |
check_model_config(cfg, model_config)
|
1065 |
|
1066 |
+
def test_phi_add_tokens_adapter(self, minimal_cfg):
|
1067 |
+
cfg = (
|
1068 |
+
DictDefault(
|
1069 |
+
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
1070 |
+
)
|
1071 |
+
| minimal_cfg
|
1072 |
)
|
1073 |
model_config = DictDefault({"model_type": "phi"})
|
1074 |
|
|
|
1078 |
):
|
1079 |
check_model_config(cfg, model_config)
|
1080 |
|
1081 |
+
cfg = (
|
1082 |
+
DictDefault(
|
1083 |
+
{
|
1084 |
+
"adapter": "qlora",
|
1085 |
+
"load_in_4bit": True,
|
1086 |
+
"tokens": ["<|imstart|>"],
|
1087 |
+
"lora_modules_to_save": ["embd.wte", "lm_head.linear"],
|
1088 |
+
}
|
1089 |
+
)
|
1090 |
+
| minimal_cfg
|
1091 |
)
|
1092 |
|
1093 |
with pytest.raises(
|
|
|
1096 |
):
|
1097 |
check_model_config(cfg, model_config)
|
1098 |
|
1099 |
+
cfg = (
|
1100 |
+
DictDefault(
|
1101 |
+
{
|
1102 |
+
"adapter": "qlora",
|
1103 |
+
"load_in_4bit": True,
|
1104 |
+
"tokens": ["<|imstart|>"],
|
1105 |
+
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
1106 |
+
}
|
1107 |
+
)
|
1108 |
+
| minimal_cfg
|
1109 |
)
|
1110 |
|
1111 |
check_model_config(cfg, model_config)
|
1112 |
|
1113 |
|
1114 |
+
class TestValidationWandb(BaseValidation):
|
1115 |
"""
|
1116 |
Validation test for wandb
|
1117 |
"""
|
1118 |
|
1119 |
+
def test_wandb_set_run_id_to_name(self, minimal_cfg):
|
1120 |
+
cfg = (
|
1121 |
+
DictDefault(
|
1122 |
+
{
|
1123 |
+
"wandb_run_id": "foo",
|
1124 |
+
}
|
1125 |
+
)
|
1126 |
+
| minimal_cfg
|
1127 |
)
|
1128 |
|
1129 |
with self._caplog.at_level(logging.WARNING):
|
1130 |
+
new_cfg = validate_config(cfg)
|
1131 |
assert any(
|
1132 |
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
1133 |
in record.message
|
1134 |
for record in self._caplog.records
|
1135 |
)
|
1136 |
|
1137 |
+
assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id == "foo"
|
1138 |
|
1139 |
+
cfg = (
|
1140 |
+
DictDefault(
|
1141 |
+
{
|
1142 |
+
"wandb_name": "foo",
|
1143 |
+
}
|
1144 |
+
)
|
1145 |
+
| minimal_cfg
|
1146 |
)
|
1147 |
|
1148 |
+
new_cfg = validate_config(cfg)
|
1149 |
|
1150 |
+
assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id is None
|
1151 |
|
1152 |
+
def test_wandb_sets_env(self, minimal_cfg):
|
1153 |
+
cfg = (
|
1154 |
+
DictDefault(
|
1155 |
+
{
|
1156 |
+
"wandb_project": "foo",
|
1157 |
+
"wandb_name": "bar",
|
1158 |
+
"wandb_run_id": "bat",
|
1159 |
+
"wandb_entity": "baz",
|
1160 |
+
"wandb_mode": "online",
|
1161 |
+
"wandb_watch": "false",
|
1162 |
+
"wandb_log_model": "checkpoint",
|
1163 |
+
}
|
1164 |
+
)
|
1165 |
+
| minimal_cfg
|
1166 |
)
|
1167 |
|
1168 |
+
new_cfg = validate_config(cfg)
|
1169 |
|
1170 |
+
setup_wandb_env_vars(new_cfg)
|
1171 |
|
1172 |
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
1173 |
assert os.environ.get("WANDB_NAME", "") == "bar"
|
|
|
1187 |
os.environ.pop("WANDB_LOG_MODEL", None)
|
1188 |
os.environ.pop("WANDB_DISABLED", None)
|
1189 |
|
1190 |
+
def test_wandb_set_disabled(self, minimal_cfg):
|
1191 |
+
cfg = DictDefault({}) | minimal_cfg
|
1192 |
|
1193 |
+
new_cfg = validate_config(cfg)
|
1194 |
|
1195 |
+
setup_wandb_env_vars(new_cfg)
|
1196 |
|
1197 |
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
1198 |
|
1199 |
+
cfg = (
|
1200 |
+
DictDefault(
|
1201 |
+
{
|
1202 |
+
"wandb_project": "foo",
|
1203 |
+
}
|
1204 |
+
)
|
1205 |
+
| minimal_cfg
|
1206 |
)
|
1207 |
|
1208 |
+
new_cfg = validate_config(cfg)
|
1209 |
|
1210 |
+
setup_wandb_env_vars(new_cfg)
|
1211 |
|
1212 |
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
1213 |
|
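The same pydantic-backed validation the tests above exercise can be called directly. A minimal sketch mirroring the minimal_cfg fixture from the tests (same example dataset and hyperparameters; nothing assumed beyond what the tests show):

# Call validate_config directly with a minimal config, mirroring the
# minimal_cfg fixture used by the tests above.
from pydantic import ValidationError

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        "learning_rate": 0.000001,
        "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
)

try:
    new_cfg = validate_config(cfg)  # returns the validated config object
    print(new_cfg.base_model)
except ValidationError as err:
    # raised when a field fails pydantic validation, e.g. an empty datasets list
    print(err)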