File size: 2,514 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from models.vocoders.gan.gan_vocoder_trainer import GANVocoderTrainer
from models.vocoders.diffusion.diffusion_vocoder_trainer import DiffusionVocoderTrainer
from utils.util import load_config
def build_trainer(args, cfg):
    """Instantiate the vocoder trainer selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments, passed unchanged to the trainer
            constructor.
        cfg: Loaded experiment configuration; its ``model_type`` attribute
            picks the trainer class.

    Returns:
        A new ``GANVocoderTrainer`` or ``DiffusionVocoderTrainer`` built as
        ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported trainer type.
    """
    supported_trainer = {
        "GANVocoder": GANVocoderTrainer,
        "DiffusionVocoder": DiffusionVocoderTrainer,
    }
    model_type = cfg.model_type
    try:
        trainer_class = supported_trainer[model_type]
    except KeyError:
        # Keep KeyError so existing callers still catch it, but report the
        # offending value and the valid choices instead of just the key.
        raise KeyError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(supported_trainer)}"
        ) from None
    return trainer_class(args, cfg)
def cuda_relevant(deterministic=False):
    """Set global PyTorch CUDA/cuDNN backend flags before training.

    Releases cached CUDA memory, enables TF32 for matmul and cuDNN, and
    switches between deterministic and autotuned (benchmark) execution.

    Args:
        deterministic: If True, force deterministic cuDNN kernels and
            deterministic PyTorch algorithms and turn off cuDNN benchmarking.
            If False (default), enable cuDNN benchmarking and allow
            non-deterministic algorithms.
    """
    import os

    if deterministic:
        # Deterministic cuBLAS (CUDA >= 10.2) needs a fixed workspace config;
        # without it, torch.use_deterministic_algorithms(True) makes cuBLAS
        # calls raise RuntimeError. setdefault keeps any user-exported value.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic mode and cuDNN autotuning are mutually exclusive.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
def _augmented_dataset_names(datasets):
    """Return ``<name>_equalizer`` and ``<name>_time_stretch`` for each name, in order."""
    # Pitch-shift and formant-shift variants are currently not generated;
    # add "pitch_shift" / "formant_shift" here to include them.
    suffixes = ("equalizer", "time_stretch")
    return [f"{dataset}_{suffix}" for dataset in datasets for suffix in suffixes]


def main():
    """Parse CLI arguments, load the config, and run vocoder training.

    When ``cfg.preprocess.data_augment`` is non-empty, augmented variants of
    those datasets are appended to ``cfg.dataset``. CUDA backend flags are
    then applied, and the trainer chosen by the config runs ``train_loop``.
    """
    parser = argparse.ArgumentParser()
    # --config and --exp_name are required, so no default is declared.
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="json files for configurations.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="A specific name to note the experiment",
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        help="resume for continue to train, finetune for finetuning",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="checkpoint to resume",
    )
    # Not read directly here; presumably consumed by the trainer via `args`.
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    args = parser.parse_args()
    cfg = load_config(args.config)

    # Data Augmentation: append augmented copies of the listed datasets.
    if cfg.preprocess.data_augment:
        cfg.dataset.extend(_augmented_dataset_names(cfg.preprocess.data_augment))

    # CUDA settings
    cuda_relevant()

    # Build trainer and run training
    trainer = build_trainer(args, cfg)
    trainer.train_loop()


if __name__ == "__main__":
    main()
|