# Provenance (from page header, unverified): revision 69e8a46, "update to 1.2", by PoTaTo721
# Base configuration for training a model
# Output locations. `${project}` is interpolated from a key defined outside
# this section; the other entries are derived from `run_dir`.
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints
# Point Hydra's per-run output directory at the same run directory
hydra:
  run:
    dir: ${paths.run_dir}
# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    process_group_backend: nccl  # This should be overridden when training on Windows

  precision: bf16-mixed

  # Disable validation at epoch end; validate on a step interval instead
  check_val_every_n_epoch: null
  val_check_interval: 5000
  max_steps: 100_000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true
# Callbacks (each entry is instantiated from its own `_target_`)
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: false  # if true, additionally always save an exact copy of the last checkpoint to last.ckpt
    save_top_k: 5  # keep the 5 latest checkpoints
    monitor: step  # use step to monitor checkpoints
    mode: max  # keep the checkpoints with the highest global_step
    every_n_epochs: null  # don't save checkpoints at epoch end
    every_n_train_steps: 5000  # save a checkpoint every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step
# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # Optional Weights & Biases logger; uncomment the block below to enable it.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: false
  #   id: null  # pass correct id to resume experiment!
  #   anonymous: null  # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: false  # upload lightning ckpts
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""
# Loop
# NOTE(review): these flags are presumably read by the training entry point to
# toggle the fit and test stages — confirm against the consuming script.
train: true
test: false