"""Module for testing the validation module"""

import logging
import unittest
from typing import Optional

import pytest

from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config


class ValidationTest(unittest.TestCase):
    """
    Test the validation module

    Each test builds a minimal ``DictDefault`` config and passes it to
    ``validate_config``, then checks one of three outcomes: a ``ValueError``
    is raised, a specific warning is logged, or the call completes without
    raising.
    """

    # pytest's log-capture fixture; populated per test by ``inject_fixtures``
    # because unittest-style test methods cannot take fixtures as arguments.
    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        """Store pytest's ``caplog`` fixture on the instance for each test."""
        self._caplog = caplog

    def _assert_warning_logged(self, cfg: DictDefault, expected: str) -> None:
        """
        Validate ``cfg`` and assert that a captured log message contains ``expected``.

        Captured records are cleared first so that messages emitted by an
        earlier ``validate_config`` call in the same test cannot satisfy the
        assertion for this call.
        """
        caplog = self._caplog
        assert caplog is not None, "caplog fixture was not injected"

        caplog.clear()
        with caplog.at_level(logging.WARNING):
            validate_config(cfg)

        # Search every record rather than indexing the first one, so the check
        # does not depend on the order in which warnings are emitted.
        assert any(expected in record.message for record in caplog.records)

    def test_load_4bit_deprecate(self):
        """Setting ``load_4bit`` is expected to raise a ``ValueError``."""
        cfg = DictDefault(
            {
                "load_4bit": True,
            }
        )

        with pytest.raises(ValueError):
            validate_config(cfg)

    def test_batch_size_unused_warning(self):
        """Setting ``batch_size`` is expected to log a "not recommended" warning."""
        cfg = DictDefault(
            {
                "batch_size": 32,
            }
        )

        self._assert_warning_logged(cfg, "batch_size is not recommended")

    def test_qlora(self):
        """
        With the ``qlora`` adapter: 8bit loading, gptq, and ``load_in_4bit=False``
        are each expected to be rejected; ``load_in_4bit=True`` is expected to pass.
        """
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": False,
            }
        )

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )

        validate_config(cfg)

    def test_qlora_merge(self):
        """
        With the ``qlora`` adapter and ``merge_lora`` set: 8bit loading, gptq,
        and ``load_in_4bit=True`` are each expected to be rejected.
        """
        base_cfg = DictDefault(
            {
                "adapter": "qlora",
                "merge_lora": True,
            }
        )

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_8bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*8bit.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "gptq": True,
            }
        )

        with pytest.raises(ValueError, match=r".*gptq.*"):
            validate_config(cfg)

        cfg = base_cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
            {
                "load_in_4bit": True,
            }
        )

        with pytest.raises(ValueError, match=r".*4bit.*"):
            validate_config(cfg)

    def test_hf_use_auth_token(self):
        """
        ``push_dataset_to_hub`` without ``hf_use_auth_token`` is expected to be
        rejected; with ``hf_use_auth_token`` set it is expected to pass.
        """
        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
            }
        )

        with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "push_dataset_to_hub": "namespace/repo",
                "hf_use_auth_token": True,
            }
        )
        validate_config(cfg)

    def test_gradient_accumulations_or_batch_size(self):
        """
        Setting both ``gradient_accumulation_steps`` and ``batch_size`` is
        expected to be rejected; setting either one alone is expected to pass.
        """
        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
                "batch_size": 1,
            }
        )

        with pytest.raises(
            ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
        ):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "batch_size": 1,
            }
        )

        validate_config(cfg)

        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
            }
        )

        validate_config(cfg)

    def test_falcon_fsdp(self):
        """
        FSDP combined with a falcon ``base_model`` is expected to be rejected
        regardless of the model name's case; falcon without FSDP is expected
        to pass.
        """
        regex_exp = r".*FSDP is not supported for falcon models.*"

        # Check for lower-case
        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        # Check for upper-case
        cfg = DictDefault(
            {
                "base_model": "Falcon-7b",
                "fsdp": ["full_shard", "auto_wrap"],
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "base_model": "tiiuae/falcon-7b",
            }
        )

        validate_config(cfg)

    def test_mpt_gradient_checkpointing(self):
        """``gradient_checkpointing`` with an MPT ``base_model`` is expected to be rejected."""
        # Trailing ".*" (not "s*"): the previous pattern quantified the final
        # letter of "models" instead of matching any trailing text.
        regex_exp = r".*gradient_checkpointing is not supported for MPT models.*"

        # Check for lower-case
        cfg = DictDefault(
            {
                "base_model": "mosaicml/mpt-7b",
                "gradient_checkpointing": True,
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_flash_optimum(self):
        """
        ``flash_optimum`` is expected to warn when combined with an adapter,
        warn when no half-precision dtype is set, and be rejected when
        combined with ``fp16`` or ``bf16``.
        """
        cfg = DictDefault(
            {
                "flash_optimum": True,
                "adapter": "lora",
            }
        )

        self._assert_warning_logged(
            cfg, "BetterTransformers probably doesn't work with PEFT adapters"
        )

        cfg = DictDefault(
            {
                "flash_optimum": True,
            }
        )

        self._assert_warning_logged(cfg, "probably set bfloat16 or float16")

        regex_exp = r".*AMP is not supported.*"

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "fp16": True,
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

        cfg = DictDefault(
            {
                "flash_optimum": True,
                "bf16": True,
            }
        )

        with pytest.raises(ValueError, match=regex_exp):
            validate_config(cfg)

    def test_adamw_hyperparams(self):
        """
        Adam hyperparameters without an adamw optimizer are expected to log a
        warning; an adamw optimizer with hyperparameters, or a non-adamw
        optimizer without them, is expected to validate without raising.
        """
        warning_text = "adamw hyperparameters found, but no adamw optimizer set"

        cfg = DictDefault(
            {
                "optimizer": None,
                "adam_epsilon": 0.0001,
            }
        )

        self._assert_warning_logged(cfg, warning_text)

        cfg = DictDefault(
            {
                "optimizer": "adafactor",
                "adam_beta1": 0.0001,
            }
        )

        # The helper clears earlier records, so this only passes if this
        # config triggers the warning itself.
        self._assert_warning_logged(cfg, warning_text)

        cfg = DictDefault(
            {
                "optimizer": "adamw_bnb_8bit",
                "adam_beta1": 0.9,
                "adam_beta2": 0.99,
                "adam_epsilon": 0.0001,
            }
        )

        validate_config(cfg)

        cfg = DictDefault(
            {
                "optimizer": "adafactor",
            }
        )

        validate_config(cfg)