winglian committed
Commit 601b77b
1 Parent(s): ff939d8

make sure to capture non-null defaults from config validation (#1415)
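The substance of the change is how the validated pydantic model is dumped back into a DictDefault: exclude_unset=True drops every field the user did not set, including fields that have non-null schema defaults, while exclude_none=True only drops fields whose value is None. A minimal sketch of the difference (hypothetical field names, assuming pydantic v2), not taken from the diff itself:

    from typing import Optional
    from pydantic import BaseModel

    class Hyperparams(BaseModel):
        lr_scheduler: Optional[str] = "cosine"   # non-null default we want to keep
        weight_decay: Optional[float] = None     # null default we want to drop

    hp = Hyperparams()
    print(hp.model_dump(exclude_unset=True))  # {} -- the "cosine" default is lost
    print(hp.model_dump(exclude_none=True))   # {'lr_scheduler': 'cosine'}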

src/axolotl/utils/config/__init__.py CHANGED
@@ -208,11 +208,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
             dict(
                 AxolotlConfigWCapabilities(
                     **cfg.to_dict(), capabilities=capabilities
-                ).model_dump(exclude_unset=True)
+                ).model_dump(exclude_none=True)
             )
         )
     return DictDefault(
-        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
+        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )

src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -151,12 +151,6 @@ class PeftConfig(BaseModel):
     loftq_config: Optional[LoftQConfig] = None


-class AutoType(str, Enum):
-    """auto type string configuration subset - used for bf16"""
-
-    AUTO = "auto"
-
-
 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""

@@ -307,12 +301,14 @@ class HyperparametersConfig(BaseModel):
         },
     )

-    train_on_inputs: Optional[bool] = None
+    train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None

     learning_rate: Union[str, float]
-    weight_decay: Optional[float] = None
-    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
+    weight_decay: Optional[float] = 0.0
+    optimizer: Optional[
+        Union[OptimizerNames, Literal["lion_pytorch"]]
+    ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
     )
@@ -323,7 +319,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = None
+    lr_scheduler: Optional[SchedulerType] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -473,7 +469,7 @@ class AxolotlInputConfig(
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None

-    bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
+    bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
     float16: Optional[bool] = None  # for non-AMP cases
@@ -487,7 +483,7 @@ class AxolotlInputConfig(

     unfrozen_parameters: Optional[List[str]] = None

-    sequence_len: int = Field(default=1024)
+    sequence_len: int = Field(default=512)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -548,10 +544,10 @@ class AxolotlInputConfig(
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None

-    is_falcon_derived_model: Optional[bool] = Field(default=False)
-    is_llama_derived_model: Optional[bool] = Field(default=False)
-    is_mistral_derived_model: Optional[bool] = Field(default=False)
-    is_qwen_derived_model: Optional[bool] = Field(default=False)
+    is_falcon_derived_model: Optional[bool] = Field(default=None)
+    is_llama_derived_model: Optional[bool] = Field(default=None)
+    is_mistral_derived_model: Optional[bool] = Field(default=None)
+    is_qwen_derived_model: Optional[bool] = Field(default=None)

     @field_validator("datasets", mode="before")
     @classmethod
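Because exclude_none now decides what survives the dump, fields that should always reach downstream code get explicit non-null defaults (train_on_inputs=False, weight_decay=0.0, the adamw_hf optimizer, lr_scheduler="cosine", sequence_len=512, bf16="auto"), while the is_*_derived_model flags move to None so they are omitted unless the user sets them. A hedged sketch of the intended effect on a minimal config; the exact set of required fields and their values are assumptions, not taken from this diff:

    from axolotl.utils.config import validate_config
    from axolotl.utils.dict import DictDefault

    # Hypothetical minimal config; model and dataset values are placeholders.
    cfg = validate_config(
        DictDefault(
            {
                "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "learning_rate": 1e-4,
                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )
    )

    # Non-null defaults declared in the schema should now be present
    # in the validated config instead of being dropped by exclude_unset.
    assert cfg.train_on_inputs is False
    assert cfg.weight_decay == 0.0
    assert cfg.sequence_len == 512
    assert cfg.lr_scheduler == "cosine"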
tests/test_validation.py CHANGED
@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
     Test the validation module
     """

+    def test_defaults(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "weight_decay": None,
+            }
+            | minimal_cfg
+        )
+        cfg = validate_config(test_cfg)
+
+        assert cfg.train_on_inputs is False
+        assert cfg.weight_decay is None
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {
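The new test pushes an explicit weight_decay: None through validate_config and checks that the False default for train_on_inputs survives while the explicit None still reads back as None. The second assertion holds because exclude_none drops the key from the dump and DictDefault resolves missing keys to None; a small sketch of that behaviour (assuming DictDefault's missing-key handling in axolotl.utils.dict), not part of the commit:

    from axolotl.utils.dict import DictDefault

    # Roughly what the dump looks like after exclude_none: the explicit
    # weight_decay=None was dropped, the non-null default survived.
    dumped = {"train_on_inputs": False}
    cfg = DictDefault(dumped)

    assert cfg.train_on_inputs is False
    assert cfg.weight_decay is None  # missing key resolves to None, not a KeyError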