|
"""Module for testing the validation module""" |
|
|
|
import logging |
|
import unittest |
|
from typing import Optional |
|
|
|
import pytest |
|
|
|
from axolotl.utils.config import validate_config |
|
from axolotl.utils.dict import DictDefault |
|
|
|
|
|
class ValidationTest(unittest.TestCase): |
|
""" |
|
Test the validation module |
|
""" |
|
|
|
_caplog: Optional[pytest.LogCaptureFixture] = None |
|
|
|
@pytest.fixture(autouse=True) |
|
def inject_fixtures(self, caplog): |
|
self._caplog = caplog |
|
|
|
def test_load_4bit_deprecate(self): |
|
cfg = DictDefault( |
|
{ |
|
"load_4bit": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError): |
|
validate_config(cfg) |
|
|
|
def test_batch_size_unused_warning(self): |
|
cfg = DictDefault( |
|
{ |
|
"batch_size": 32, |
|
} |
|
) |
|
|
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert "batch_size is not recommended" in self._caplog.records[0].message |
|
|
|
def test_qlora(self): |
|
base_cfg = DictDefault( |
|
{ |
|
"adapter": "qlora", |
|
} |
|
) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"load_in_8bit": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*8bit.*"): |
|
validate_config(cfg) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"gptq": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*gptq.*"): |
|
validate_config(cfg) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"load_in_4bit": False, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*4bit.*"): |
|
validate_config(cfg) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"load_in_4bit": True, |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
def test_qlora_merge(self): |
|
base_cfg = DictDefault( |
|
{ |
|
"adapter": "qlora", |
|
"merge_lora": True, |
|
} |
|
) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"load_in_8bit": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*8bit.*"): |
|
validate_config(cfg) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"gptq": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*gptq.*"): |
|
validate_config(cfg) |
|
|
|
cfg = base_cfg | DictDefault( |
|
{ |
|
"load_in_4bit": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*4bit.*"): |
|
validate_config(cfg) |
|
|
|
def test_hf_use_auth_token(self): |
|
cfg = DictDefault( |
|
{ |
|
"push_dataset_to_hub": "namespace/repo", |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"): |
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"push_dataset_to_hub": "namespace/repo", |
|
"hf_use_auth_token": True, |
|
} |
|
) |
|
validate_config(cfg) |
|
|
|
def test_gradient_accumulations_or_batch_size(self): |
|
cfg = DictDefault( |
|
{ |
|
"gradient_accumulation_steps": 1, |
|
"batch_size": 1, |
|
} |
|
) |
|
|
|
with pytest.raises( |
|
ValueError, match=r".*gradient_accumulation_steps or batch_size.*" |
|
): |
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"batch_size": 1, |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"gradient_accumulation_steps": 1, |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
def test_falcon_fsdp(self): |
|
regex_exp = r".*FSDP is not supported for falcon models.*" |
|
|
|
|
|
cfg = DictDefault( |
|
{ |
|
"base_model": "tiiuae/falcon-7b", |
|
"fsdp": ["full_shard", "auto_wrap"], |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|
|
|
|
cfg = DictDefault( |
|
{ |
|
"base_model": "Falcon-7b", |
|
"fsdp": ["full_shard", "auto_wrap"], |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"base_model": "tiiuae/falcon-7b", |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
def test_mpt_gradient_checkpointing(self): |
|
regex_exp = r".*gradient_checkpointing is not supported for MPT models*" |
|
|
|
|
|
cfg = DictDefault( |
|
{ |
|
"base_model": "mosaicml/mpt-7b", |
|
"gradient_checkpointing": True, |
|
} |
|
) |
|
|
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|
|
def test_flash_optimum(self): |
|
cfg = DictDefault( |
|
{ |
|
"flash_optimum": True, |
|
"adapter": "lora", |
|
} |
|
) |
|
|
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"BetterTransformers probably doesn't work with PEFT adapters" |
|
in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"flash_optimum": True, |
|
} |
|
) |
|
|
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"probably set bfloat16 or float16" in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"flash_optimum": True, |
|
"fp16": True, |
|
} |
|
) |
|
regex_exp = r".*AMP is not supported.*" |
|
|
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"flash_optimum": True, |
|
"bf16": True, |
|
} |
|
) |
|
regex_exp = r".*AMP is not supported.*" |
|
|
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|
|
def test_adamw_hyperparams(self): |
|
cfg = DictDefault( |
|
{ |
|
"optimizer": None, |
|
"adam_epsilon": 0.0001, |
|
} |
|
) |
|
|
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"adamw hyperparameters found, but no adamw optimizer set" |
|
in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"optimizer": "adafactor", |
|
"adam_beta1": 0.0001, |
|
} |
|
) |
|
|
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"adamw hyperparameters found, but no adamw optimizer set" |
|
in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"optimizer": "adamw_bnb_8bit", |
|
"adam_beta1": 0.9, |
|
"adam_beta2": 0.99, |
|
"adam_epsilon": 0.0001, |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"optimizer": "adafactor", |
|
} |
|
) |
|
|
|
validate_config(cfg) |
|
|
|
def test_packing(self): |
|
cfg = DictDefault( |
|
{ |
|
"max_packed_sequence_len": 2048, |
|
} |
|
) |
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"max_packed_sequence_len will be deprecated in favor of sample_packing" |
|
in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"sample_packing": True, |
|
"pad_to_sequence_len": None, |
|
} |
|
) |
|
with self._caplog.at_level(logging.WARNING): |
|
validate_config(cfg) |
|
assert any( |
|
"`pad_to_sequence_len: true` is recommended when using sample_packing" |
|
in record.message |
|
for record in self._caplog.records |
|
) |
|
|
|
cfg = DictDefault( |
|
{ |
|
"max_packed_sequence_len": 2048, |
|
"sample_packing": True, |
|
} |
|
) |
|
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*" |
|
with pytest.raises(ValueError, match=regex_exp): |
|
validate_config(cfg) |
|
|