from src import config | |
import json | |
def test_trainer_config():
    """Validate a partial TrainerConfig JSON payload and check nested defaults.

    The payload sets only ``epochs`` and the nested
    ``_model_config.text_config.text_model``. The assertions then check that:

    * the explicitly supplied values round-trip;
    * ``max_len`` still exists on the text config even though it was not sent;
    * ``vision_config`` equals a freshly constructed ``TinyCLIPVisionConfig()``
      even though it was omitted from the payload.
    """
    payload = {
        "epochs": 21,
        "_model_config": {"text_config": {"text_model": "test"}},
    }
    parsed = config.TrainerConfig.model_validate_json(json.dumps(payload))

    assert parsed.epochs == 21

    text_cfg = parsed._model_config.text_config
    assert text_cfg.text_model == "test"
    # max_len was not part of the payload; it must be present anyway.
    assert hasattr(text_cfg, "max_len")

    # vision_config was omitted entirely; compare against a default instance.
    assert parsed._model_config.vision_config == config.TinyCLIPVisionConfig()