|
"""
Unit tests for axolotl.core.trainer_builder.
"""
|
|
|
import pytest |
|
|
|
from axolotl.core.trainer_builder import HFRLTrainerBuilder |
|
from axolotl.utils.config import normalize_config |
|
from axolotl.utils.dict import DictDefault |
|
from axolotl.utils.models import load_model, load_tokenizer |
|
|
|
|
|
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Provide a normalized axolotl config for a TinyLlama model with ``rl`` set.

    The ``adam_*`` and ``dataloader_*`` entries are the values that the
    training-argument test asserts on after the builder has run.

    Returns:
        The ``DictDefault`` config after ``normalize_config`` has been applied
        to it in place.
    """
    settings = dict(
        base_model="TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        model_type="AutoModelForCausalLM",
        tokenizer_type="LlamaTokenizer",
        micro_batch_size=1,
        gradient_accumulation_steps=1,
        learning_rate=5e-5,
        save_steps=100,
        output_dir="./model-out",
        warmup_steps=10,
        gradient_checkpointing=False,
        optimizer="adamw_torch",
        sequence_len=2048,
        rl=True,
        adam_beta1=0.998,
        adam_beta2=0.9,
        adam_epsilon=1e-5,
        dataloader_num_workers=1,
        dataloader_pin_memory=True,
        model_config_type="llama",
    )
    config = DictDefault(settings)

    # normalize_config mutates the config in place; its return value is not used.
    normalize_config(config)
    return config
|
|
|
|
|
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
    """Provide the tokenizer that ``load_tokenizer`` builds from the ``cfg`` fixture."""
    tok = load_tokenizer(cfg)
    return tok
|
|
|
|
|
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
    """Provide whatever ``load_model`` returns for the ``cfg``/``tokenizer`` fixtures.

    NOTE(review): the result is handed straight to the trainer builder as its
    model argument; confirm whether ``load_model`` returns the bare model or a
    tuple that should be unpacked first.
    """
    loaded = load_model(cfg, tokenizer)
    return loaded
|
|
|
|
|
class TestHFRLTrainerBuilder:
    """Pytest tests for ``HFRLTrainerBuilder``, the RL (e.g. DPO) trainer builder."""

    def test_build_training_arguments(self, cfg, model, tokenizer):
        """Optimizer and dataloader settings from ``cfg`` reach the training arguments."""
        # 100 is presumably the total number of training steps -- confirm
        # against the signature of build_training_arguments.
        args = HFRLTrainerBuilder(cfg, model, tokenizer).build_training_arguments(100)

        expected_values = {
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "dataloader_num_workers": 1,
        }
        for attr_name, wanted in expected_values.items():
            assert getattr(args, attr_name) == wanted, attr_name

        # Identity check on purpose: the flag must be the bool True, not just truthy.
        assert args.dataloader_pin_memory is True
|
|