"""Module for testing models utils file.""" | |
import unittest | |
from unittest.mock import patch | |
import pytest | |
from axolotl.utils.dict import DictDefault | |
from axolotl.utils.models import load_model | |
class ModelsUtilsTest(unittest.TestCase): | |
"""Testing module for models utils.""" | |
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self): | |
cfg = DictDefault( | |
{ | |
"s2_attention": True, | |
"sample_packing": True, | |
"base_model": "", | |
"model_type": "LlamaForCausalLM", | |
} | |
) | |
# Mock out call to HF hub | |
with patch( | |
"axolotl.utils.models.load_model_config" | |
) as mocked_load_model_config: | |
mocked_load_model_config.return_value = {} | |
with pytest.raises(ValueError) as exc: | |
# Should error before hitting tokenizer, so we pass in an empty str | |
load_model(cfg, tokenizer="") | |
assert ( | |
"shifted-sparse attention does not currently support sample packing" | |
in str(exc.value) | |
) | |
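

# A minimal, optional entry point (an assumption, not part of the original file):
# it lets the module be run directly with `python <this file>` in addition to
# being collected by pytest, using the standard unittest runner.
if __name__ == "__main__":
    unittest.main()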