""" | |
Test classes for checking functionality of the cfg normalization | |
""" | |
import unittest | |
from axolotl.utils.config import normalize_config | |
from axolotl.utils.dict import DictDefault | |
class NormalizeConfigTestCase(unittest.TestCase):
    """
    Tests for ``normalize_config``.

    Each test builds a small base config, adjusts it, calls
    ``normalize_config`` on it (the tests rely on the config being updated
    in place, since the return value is not used) and then checks the
    resulting fields.
    """

    def _get_base_cfg(self):
        """Return a fresh minimal config used as the starting point for tests.

        A new ``DictDefault`` is built on every call, so a test may mutate
        its copy without affecting other tests.
        """
        return DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "base_model_config": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

    def test_lr_as_float(self):
        """A learning rate supplied as a string is normalized to a float."""
        cfg = (
            self._get_base_cfg()
            | DictDefault(  # pylint: disable=unsupported-binary-operation
                {
                    "learning_rate": "5e-5",
                }
            )
        )

        normalize_config(cfg)

        # Check the type explicitly: the point of this test is the
        # str -> float conversion, not just numeric equality.
        self.assertIsInstance(cfg.learning_rate, float)
        # float("5e-5") parses to exactly the same double as the literal
        # 0.00005, so an exact comparison is safe here.
        self.assertEqual(cfg.learning_rate, 0.00005)

    def test_base_model_config_set_when_empty(self):
        """``base_model_config`` falls back to ``base_model`` when it is absent."""
        cfg = self._get_base_cfg()
        del cfg.base_model_config

        normalize_config(cfg)

        self.assertEqual(cfg.base_model_config, cfg.base_model)