# Spaces: Running on A10G
# Base configuration for training a model
# Output locations. `project` is not defined in this file; it is presumably
# supplied by the config that extends this one or on the command line.
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints
# Point Hydra's run directory at the same place the trainer writes to.
hydra:
  run:
    dir: ${paths.run_dir}
# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    # NCCL is not available on Windows; override this (e.g. with gloo) there.
    process_group_backend: nccl

  precision: bf16-mixed

  # Validate every `val_check_interval` training batches instead of at the
  # end of each epoch; a null `check_val_every_n_epoch` lets the interval
  # count across epoch boundaries.
  check_val_every_n_epoch: null
  val_check_interval: 5000

  # Written without `_` digit separators: YAML 1.1 loaders read 100_000 as
  # an int, but YAML 1.2 core-schema loaders read it as a string.
  max_steps: 100000

  # Enable torch.backends.cudnn.benchmark, which autotunes cuDNN kernels;
  # it helps most when input shapes stay the same between batches.
  benchmark: true
# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    # auto_insert_metric_name is off, so files are named e.g. step_000005000.ckpt
    filename: "step_{step:09d}"
    save_last: false  # do not additionally write a copy named last.ckpt
    # Rank checkpoints by global step (mode: max), so save_top_k keeps the
    # 5 most recent ones.
    save_top_k: 5
    monitor: step
    mode: max
    every_n_epochs: null  # don't save checkpoints at epoch end
    every_n_train_steps: 5000  # save a checkpoint every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # maximum depth of layer nesting shown in the summary

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  # Project-local callback; presumably logs the gradient norm (L2 here).
  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step
# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # Optional Weights & Biases logger; uncomment to enable.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: false
  #   id: null  # pass the correct id to resume an experiment
  #   anonymous: null  # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: false  # upload Lightning checkpoints
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to the name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""
# Loop: stage switches, presumably read by the training entry point
# (run training / run testing).
train: true
test: false