from dataclasses import dataclass | |
from pathlib import Path | |
@dataclass
class TrainingConfig:
    """Hyperparameters and output/Hub settings for a training run.

    Every field has a default, so ``TrainingConfig()`` yields the stock
    configuration. Individual fields can be overridden by keyword, e.g.
    ``TrainingConfig(num_epochs=10, seed=42)``. Because this is a dataclass,
    ``dataclasses.asdict(config)`` gives a plain dict suitable for logging.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 4
    eval_batch_size: int = 4  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 1  # presumably: save sample images every N epochs
    save_model_epochs: int = 3  # presumably: save the model every N epochs
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    # Directory containing this source file; evaluated once at import time.
    output_dir: str = str(Path(__file__).parent)
    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    hub_model_id: str = 'jmemon/ddpm-paintings-128-finetuned-celebahq'  # the name of the repository to create on the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0