make sure to capture non-null defaults from config validation (#1415)
src/axolotl/utils/config/__init__.py
CHANGED
@@ -208,11 +208,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
             dict(
                 AxolotlConfigWCapabilities(
                     **cfg.to_dict(), capabilities=capabilities
-                ).model_dump(
+                ).model_dump(exclude_none=True)
             )
         )
     return DictDefault(
-        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(
+        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
 
 
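The switch to model_dump(exclude_none=True) is what lets the non-null defaults added in the model classes below reach the validated config: fields whose value is None are dropped from the dump, while fields carrying a real default are kept even if the user never set them. A minimal sketch of that pydantic v2 behaviour, using a made-up ToyConfig rather than axolotl's actual config models:

    # Sketch only (not axolotl code): pydantic v2's model_dump(exclude_none=True)
    # keeps non-None defaults and drops None-valued fields.
    from typing import Optional

    from pydantic import BaseModel


    class ToyConfig(BaseModel):
        train_on_inputs: Optional[bool] = False   # non-null default -> kept
        weight_decay: Optional[float] = 0.0       # non-null default -> kept
        group_by_length: Optional[bool] = None    # None default -> dropped


    dumped = ToyConfig().model_dump(exclude_none=True)
    print(dumped)  # {'train_on_inputs': False, 'weight_decay': 0.0}
    assert "group_by_length" not in dumped
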
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -151,12 +151,6 @@ class PeftConfig(BaseModel):
     loftq_config: Optional[LoftQConfig] = None
 
 
-class AutoType(str, Enum):
-    """auto type string configuration subset - used for bf16"""
-
-    AUTO = "auto"
-
-
 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""
 
@@ -307,12 +301,14 @@ class HyperparametersConfig(BaseModel):
         },
     )
 
-    train_on_inputs: Optional[bool] =
+    train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None
 
     learning_rate: Union[str, float]
-    weight_decay: Optional[float] =
-    optimizer: Optional[
+    weight_decay: Optional[float] = 0.0
+    optimizer: Optional[
+        Union[OptimizerNames, Literal["lion_pytorch"]]
+    ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
     )
@@ -323,7 +319,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] =
+    lr_scheduler: Optional[SchedulerType] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -473,7 +469,7 @@ class AxolotlInputConfig(
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
-    bf16: Optional[Union[
+    bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
     float16: Optional[bool] = None  # for non-AMP cases
@@ -487,7 +483,7 @@ class AxolotlInputConfig(
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    sequence_len: int = Field(default=
+    sequence_len: int = Field(default=512)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -548,10 +544,10 @@ class AxolotlInputConfig(
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None
 
-    is_falcon_derived_model: Optional[bool] = Field(default=
-    is_llama_derived_model: Optional[bool] = Field(default=
-    is_mistral_derived_model: Optional[bool] = Field(default=
-    is_qwen_derived_model: Optional[bool] = Field(default=
+    is_falcon_derived_model: Optional[bool] = Field(default=None)
+    is_llama_derived_model: Optional[bool] = Field(default=None)
+    is_mistral_derived_model: Optional[bool] = Field(default=None)
+    is_qwen_derived_model: Optional[bool] = Field(default=None)
 
     @field_validator("datasets", mode="before")
     @classmethod
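Two of the new defaults above are worth a note. bf16 now accepts either the literal "auto" or a plain bool and defaults to "auto" (which lines up with deleting the AutoType helper enum that its docstring said was "used for bf16"), and the optimizer default is spelled OptimizerNames.ADAMW_HF.value, presumably so the value that ends up in the dumped config is already a plain string. A hedged sketch of both behaviours, with ToyOptimizerNames standing in for transformers' OptimizerNames (an illustration, not axolotl code):

    # Sketch only: how a Literal["auto"]/bool union and an enum .value default
    # behave under pydantic v2.
    from enum import Enum
    from typing import Literal, Optional, Union

    from pydantic import BaseModel


    class ToyOptimizerNames(str, Enum):
        # stand-in for transformers' OptimizerNames
        ADAMW_HF = "adamw_hf"


    class ToyConfig(BaseModel):
        # accepts "auto", True, False, or None; defaults to "auto"
        bf16: Optional[Union[Literal["auto"], bool]] = "auto"
        # pydantic v2 does not validate defaults, so the .value default stays a str
        optimizer: Optional[ToyOptimizerNames] = ToyOptimizerNames.ADAMW_HF.value


    print(ToyConfig().model_dump(exclude_none=True))
    # {'bf16': 'auto', 'optimizer': 'adamw_hf'}
    assert ToyConfig(bf16=True).bf16 is True
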
tests/test_validation.py
CHANGED
@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
     Test the validation module
     """
 
+    def test_defaults(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "weight_decay": None,
+            }
+            | minimal_cfg
+        )
+        cfg = validate_config(test_cfg)
+
+        assert cfg.train_on_inputs is False
+        assert cfg.weight_decay is None
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {