from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    trainer_type: str = ""
    trainer: dict = field(default_factory=dict)


@dataclass
class ExprimentConfig:
    trainers: List[dict] = field(default_factory=list)
    init_config: dict = field(default_factory=dict)
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""

    # experiment-related parameters
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None
    seed: Optional[int] = None
    max_train_steps: int = 1_000_000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    use_8bit_adam: bool = False
    # Adam optimizer hyperparameters
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    mixed_precision: Optional[str] = None  # one of ["no", "fp16", "bf16", "fp8"]
    skip_training: bool = False
    debug: bool = False
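

# A minimal usage sketch, assuming the config is constructed directly in
# Python rather than loaded from a config file. Every concrete value below
# ("lora", the model path, the rank) is a hypothetical placeholder, not
# something defined in this module.
if __name__ == "__main__":
    from dataclasses import asdict

    sub = TrainerSubConfig(trainer_type="lora", trainer={"rank": 4})
    config = ExprimentConfig(
        pretrained_model_name_or_path="path/to/pretrained-model",
        trainers=[asdict(sub)],  # `trainers` holds plain dicts
        learning_rate=1e-5,
        mixed_precision="fp16",
        seed=42,
    )
    assert config.lr_scheduler == "constant"  # unset fields keep their defaults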