"""Module for testing the validation module""" import logging import os import unittest from typing import Optional import pytest from transformers.utils import is_torch_bf16_gpu_available from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import check_model_config from axolotl.utils.wandb_ import setup_wandb_env_vars class BaseValidation(unittest.TestCase): """ Base validation module to setup the log capture """ _caplog: Optional[pytest.LogCaptureFixture] = None @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog # pylint: disable=too-many-public-methods class ValidationTest(BaseValidation): """ Test the validation module """ def test_load_4bit_deprecate(self): cfg = DictDefault( { "load_4bit": True, } ) with pytest.raises(ValueError): validate_config(cfg) def test_batch_size_unused_warning(self): cfg = DictDefault( { "batch_size": 32, } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert "batch_size is not recommended" in self._caplog.records[0].message def test_qlora(self): base_cfg = DictDefault( { "adapter": "qlora", } ) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_8bit": True, } ) with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "gptq": True, } ) with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": False, } ) with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": True, } ) validate_config(cfg) def test_qlora_merge(self): base_cfg = DictDefault( { "adapter": "qlora", "merge_lora": True, } ) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_8bit": True, } ) with pytest.raises(ValueError, match=r".*8bit.*"): validate_config(cfg) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "gptq": True, } ) with pytest.raises(ValueError, match=r".*gptq.*"): validate_config(cfg) cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation { "load_in_4bit": True, } ) with pytest.raises(ValueError, match=r".*4bit.*"): validate_config(cfg) def test_hf_use_auth_token(self): cfg = DictDefault( { "push_dataset_to_hub": "namespace/repo", } ) with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"): validate_config(cfg) cfg = DictDefault( { "push_dataset_to_hub": "namespace/repo", "hf_use_auth_token": True, } ) validate_config(cfg) def test_gradient_accumulations_or_batch_size(self): cfg = DictDefault( { "gradient_accumulation_steps": 1, "batch_size": 1, } ) with pytest.raises( ValueError, match=r".*gradient_accumulation_steps or batch_size.*" ): validate_config(cfg) cfg = DictDefault( { "batch_size": 1, } ) validate_config(cfg) cfg = DictDefault( { "gradient_accumulation_steps": 1, } ) validate_config(cfg) def test_falcon_fsdp(self): regex_exp = r".*FSDP is not supported for falcon models.*" # Check for lower-case cfg = DictDefault( { "base_model": "tiiuae/falcon-7b", "fsdp": ["full_shard", "auto_wrap"], } ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) # Check for upper-case cfg = DictDefault( { "base_model": "Falcon-7b", "fsdp": ["full_shard", "auto_wrap"], } ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) cfg = DictDefault( { "base_model": "tiiuae/falcon-7b", } ) validate_config(cfg) def test_mpt_gradient_checkpointing(self): regex_exp = r".*gradient_checkpointing is not supported for MPT models*" # Check for lower-case cfg = DictDefault( { "base_model": "mosaicml/mpt-7b", "gradient_checkpointing": True, } ) with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) def test_flash_optimum(self): cfg = DictDefault( { "flash_optimum": True, "adapter": "lora", } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "BetterTransformers probably doesn't work with PEFT adapters" in record.message for record in self._caplog.records ) cfg = DictDefault( { "flash_optimum": True, } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "probably set bfloat16 or float16" in record.message for record in self._caplog.records ) cfg = DictDefault( { "flash_optimum": True, "fp16": True, } ) regex_exp = r".*AMP is not supported.*" with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) cfg = DictDefault( { "flash_optimum": True, "bf16": True, } ) regex_exp = r".*AMP is not supported.*" with pytest.raises(ValueError, match=regex_exp): validate_config(cfg) def test_adamw_hyperparams(self): cfg = DictDefault( { "optimizer": None, "adam_epsilon": 0.0001, } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" in record.message for record in self._caplog.records ) cfg = DictDefault( { "optimizer": "adafactor", "adam_beta1": 0.0001, } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" in record.message for record in self._caplog.records ) cfg = DictDefault( { "optimizer": "adamw_bnb_8bit", "adam_beta1": 0.9, "adam_beta2": 0.99, "adam_epsilon": 0.0001, } ) validate_config(cfg) cfg = DictDefault( { "optimizer": "adafactor", } ) validate_config(cfg) def test_deprecated_packing(self): cfg = DictDefault( { "max_packed_sequence_len": 1024, } ) with pytest.raises( DeprecationWarning, match=r"`max_packed_sequence_len` is no longer supported", ): validate_config(cfg) def test_packing(self): cfg = DictDefault( { "sample_packing": True, "pad_to_sequence_len": None, } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "`pad_to_sequence_len: true` is recommended when using sample_packing" in record.message for record in self._caplog.records ) @pytest.mark.skipif( is_torch_bf16_gpu_available(), reason="test should only run on gpus w/o bf16 support", ) def test_merge_lora_no_bf16_fail(self): """ This is assumed to be run on a CPU machine, so bf16 is not supported. """ cfg = DictDefault( { "bf16": True, } ) with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"): validate_config(cfg) cfg = DictDefault( { "bf16": True, "merge_lora": True, } ) validate_config(cfg) def test_sharegpt_deprecation(self): cfg = DictDefault( {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]} ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "`type: sharegpt:chat` will soon be deprecated." in record.message for record in self._caplog.records ) assert cfg.datasets[0].type == "sharegpt" cfg = DictDefault( {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]} ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "`type: sharegpt_simple` will soon be deprecated." in record.message for record in self._caplog.records ) assert cfg.datasets[0].type == "sharegpt:load_role" def test_no_conflict_save_strategy(self): cfg = DictDefault( { "save_strategy": "epoch", "save_steps": 10, } ) with pytest.raises( ValueError, match=r".*save_strategy and save_steps mismatch.*" ): validate_config(cfg) cfg = DictDefault( { "save_strategy": "no", "save_steps": 10, } ) with pytest.raises( ValueError, match=r".*save_strategy and save_steps mismatch.*" ): validate_config(cfg) cfg = DictDefault( { "save_strategy": "steps", } ) validate_config(cfg) cfg = DictDefault( { "save_strategy": "steps", "save_steps": 10, } ) validate_config(cfg) cfg = DictDefault( { "save_steps": 10, } ) validate_config(cfg) cfg = DictDefault( { "save_strategy": "no", } ) validate_config(cfg) def test_no_conflict_eval_strategy(self): cfg = DictDefault( { "evaluation_strategy": "epoch", "eval_steps": 10, } ) with pytest.raises( ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" ): validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "no", "eval_steps": 10, } ) with pytest.raises( ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*" ): validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "steps", } ) validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "steps", "eval_steps": 10, } ) validate_config(cfg) cfg = DictDefault( { "eval_steps": 10, } ) validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "no", } ) validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "epoch", "val_set_size": 0, } ) with pytest.raises( ValueError, match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", ): validate_config(cfg) cfg = DictDefault( { "eval_steps": 10, "val_set_size": 0, } ) with pytest.raises( ValueError, match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*", ): validate_config(cfg) cfg = DictDefault( { "val_set_size": 0, } ) validate_config(cfg) cfg = DictDefault( { "eval_steps": 10, "val_set_size": 0.01, } ) validate_config(cfg) cfg = DictDefault( { "evaluation_strategy": "epoch", "val_set_size": 0.01, } ) validate_config(cfg) def test_eval_table_size_conflict_eval_packing(self): cfg = DictDefault( { "sample_packing": True, "eval_table_size": 100, } ) with pytest.raises( ValueError, match=r".*Please set 'eval_sample_packing' to false.*" ): validate_config(cfg) cfg = DictDefault( { "sample_packing": True, "eval_sample_packing": False, } ) validate_config(cfg) cfg = DictDefault( { "sample_packing": False, "eval_table_size": 100, } ) validate_config(cfg) cfg = DictDefault( { "sample_packing": True, "eval_table_size": 100, "eval_sample_packing": False, } ) validate_config(cfg) def test_load_in_x_bit_without_adapter(self): cfg = DictDefault( { "load_in_4bit": True, } ) with pytest.raises( ValueError, match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", ): validate_config(cfg) cfg = DictDefault( { "load_in_8bit": True, } ) with pytest.raises( ValueError, match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", ): validate_config(cfg) cfg = DictDefault( { "load_in_4bit": True, "adapter": "qlora", } ) validate_config(cfg) cfg = DictDefault( { "load_in_8bit": True, "adapter": "lora", } ) validate_config(cfg) def test_warmup_step_no_conflict(self): cfg = DictDefault( { "warmup_steps": 10, "warmup_ratio": 0.1, } ) with pytest.raises( ValueError, match=r".*warmup_steps and warmup_ratio are mutually exclusive*", ): validate_config(cfg) cfg = DictDefault( { "warmup_steps": 10, } ) validate_config(cfg) cfg = DictDefault( { "warmup_ratio": 0.1, } ) validate_config(cfg) def test_unfrozen_parameters_w_peft_layers_to_transform(self): cfg = DictDefault( { "adapter": "lora", "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"], "peft_layers_to_transform": [0, 1], } ) with pytest.raises( ValueError, match=r".*can have unexpected behavior*", ): validate_config(cfg) def test_hub_model_id_save_value_warns(self): cfg = DictDefault({"hub_model_id": "test"}) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert ( "set without any models being saved" in self._caplog.records[0].message ) def test_hub_model_id_save_value(self): cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert len(self._caplog.records) == 0 class ValidationCheckModelConfig(BaseValidation): """ Test the validation for the config when the model config is available """ def test_llama_add_tokens_adapter(self): cfg = DictDefault( {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} ) model_config = DictDefault({"model_type": "llama"}) with pytest.raises( ValueError, match=r".*`lora_modules_to_save` not properly set when adding new tokens*", ): check_model_config(cfg, model_config) cfg = DictDefault( { "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], "lora_modules_to_save": ["embed_tokens"], } ) with pytest.raises( ValueError, match=r".*`lora_modules_to_save` not properly set when adding new tokens*", ): check_model_config(cfg, model_config) cfg = DictDefault( { "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], "lora_modules_to_save": ["embed_tokens", "lm_head"], } ) check_model_config(cfg, model_config) def test_phi_add_tokens_adapter(self): cfg = DictDefault( {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]} ) model_config = DictDefault({"model_type": "phi"}) with pytest.raises( ValueError, match=r".*`lora_modules_to_save` not properly set when adding new tokens*", ): check_model_config(cfg, model_config) cfg = DictDefault( { "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], "lora_modules_to_save": ["embd.wte", "lm_head.linear"], } ) with pytest.raises( ValueError, match=r".*`lora_modules_to_save` not properly set when adding new tokens*", ): check_model_config(cfg, model_config) cfg = DictDefault( { "adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"], "lora_modules_to_save": ["embed_tokens", "lm_head"], } ) check_model_config(cfg, model_config) class ValidationWandbTest(BaseValidation): """ Validation test for wandb """ def test_wandb_set_run_id_to_name(self): cfg = DictDefault( { "wandb_run_id": "foo", } ) with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." in record.message for record in self._caplog.records ) assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo" cfg = DictDefault( { "wandb_name": "foo", } ) validate_config(cfg) assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None def test_wandb_sets_env(self): cfg = DictDefault( { "wandb_project": "foo", "wandb_name": "bar", "wandb_run_id": "bat", "wandb_entity": "baz", "wandb_mode": "online", "wandb_watch": "false", "wandb_log_model": "checkpoint", } ) validate_config(cfg) setup_wandb_env_vars(cfg) assert os.environ.get("WANDB_PROJECT", "") == "foo" assert os.environ.get("WANDB_NAME", "") == "bar" assert os.environ.get("WANDB_RUN_ID", "") == "bat" assert os.environ.get("WANDB_ENTITY", "") == "baz" assert os.environ.get("WANDB_MODE", "") == "online" assert os.environ.get("WANDB_WATCH", "") == "false" assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint" assert os.environ.get("WANDB_DISABLED", "") != "true" os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_NAME", None) os.environ.pop("WANDB_RUN_ID", None) os.environ.pop("WANDB_ENTITY", None) os.environ.pop("WANDB_MODE", None) os.environ.pop("WANDB_WATCH", None) os.environ.pop("WANDB_LOG_MODEL", None) os.environ.pop("WANDB_DISABLED", None) def test_wandb_set_disabled(self): cfg = DictDefault({}) validate_config(cfg) setup_wandb_env_vars(cfg) assert os.environ.get("WANDB_DISABLED", "") == "true" cfg = DictDefault( { "wandb_project": "foo", } ) validate_config(cfg) setup_wandb_env_vars(cfg) assert os.environ.get("WANDB_DISABLED", "") != "true" os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_DISABLED", None)