import argparse import itertools import os import time import torch import torch.distributed as dist from academicodec.models.encodec.dataset import NSynthDataset from academicodec.models.encodec.loss import criterion_d from academicodec.models.encodec.loss import criterion_g from academicodec.models.encodec.loss import loss_dis from academicodec.models.encodec.loss import loss_g from academicodec.models.encodec.msstftd import MultiScaleSTFTDiscriminator from academicodec.models.encodec.net3 import SoundStream from academicodec.models.soundstream.models import MultiPeriodDiscriminator from academicodec.models.soundstream.models import MultiScaleDiscriminator from academicodec.utils import Logger from academicodec.utils import seed_everything from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm def getModelSize(model): param_size = 0 param_sum = 0 for param in model.parameters(): param_size += param.nelement() * param.element_size() param_sum += param.nelement() buffer_size = 0 buffer_sum = 0 for buffer in model.buffers(): buffer_size += buffer.nelement() * buffer.element_size() buffer_sum += buffer.nelement() all_size = (param_size + buffer_size) / 1024 / 1024 print('模型总大小为:{:.3f}MB'.format(all_size)) return (param_size, param_sum, buffer_size, buffer_sum, all_size) def get_args(): parser = argparse.ArgumentParser() parser.add_argument( '--local_rank', default=-1, type=int, help='node rank for distributed training') # args for random parser.add_argument( '--seed', type=int, default=6666, help='seed for initializing training. ') parser.add_argument( '--cudnn_deterministic', action='store_true', help='set cudnn.deterministic True') parser.add_argument( '--tensorboard', action='store_true', help='use tensorboard for logging') # args for training parser.add_argument( '--LAMBDA_WAV', type=float, default=100, help='hyper-parameter for wav time-domain loss') parser.add_argument( '--LAMBDA_ADV', type=float, default=1, help='hyper-parameter for adver loss') parser.add_argument( '--LAMBDA_FEAT', type=float, default=1, help='hyper-parameter for feat loss') parser.add_argument( '--LAMBDA_REC', type=float, default=1, help='hyper-parameter for rec loss') parser.add_argument( '--LAMBDA_COM', type=float, default=1000, help='hyper-parameter for commit loss') parser.add_argument( '--N_EPOCHS', type=int, default=100, help='Total training epoch') parser.add_argument( '--st_epoch', type=int, default=0, help='start training epoch') parser.add_argument( '--global_step', type=int, default=0, help='record the global step') parser.add_argument('--discriminator_iter_start', type=int, default=500) parser.add_argument('--BATCH_SIZE', type=int, default=10, help='batch size') parser.add_argument( '--PATH', type=str, default='model_path', help='model save path') parser.add_argument('--sr', type=int, default=16000, help='sample rate') parser.add_argument( '--print_freq', type=int, default=10, help='the print number') parser.add_argument( '--save_dir', type=str, default='log', help='log save path') parser.add_argument( '--train_data_path', type=str, # default='/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/code4/InstructTTS2/data_process/soundstream_data/train16k.lst', default="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/data/codec_data_24k/train_valid_lists/train.lst", help='training data') parser.add_argument( '--valid_data_path', type=str, # default='/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/dongchao/code4/InstructTTS2/data_process/soundstream_data/val16k.lst', default="/apdcephfs_cq2/share_1297902/speech_user/shaunxliu/data/codec_data_24k/train_valid_lists/valid_256.lst", help='validation data') parser.add_argument( '--resume', action='store_true', help='whether re-train model') parser.add_argument( '--resume_path', type=str, default=None, help='resume_path') parser.add_argument( '--ratios', type=int, nargs='+', # probs(ratios) = hop_size default=[8, 5, 4, 2], help='ratios of SoundStream, shoud be set for different hop_size (32d, 320, 240d, ...)' ) parser.add_argument( '--target_bandwidths', type=float, nargs='+', # default for 16k_320d default=[1, 1.5, 2, 4, 6, 12], help='target_bandwidths of net3.py') args = parser.parse_args() time_str = time.strftime('%Y-%m-%d-%H-%M') if args.resume: args.PATH = args.resume_path # direcly use the old model path else: args.PATH = os.path.join(args.PATH, time_str) args.save_dir = os.path.join(args.save_dir, time_str) os.makedirs(args.PATH, exist_ok=True) return args def get_input(x): x = x.to(memory_format=torch.contiguous_format) return x.float() def main(): args = get_args() if args.seed is not None or args.cudnn_deterministic: seed_everything(args.seed, args.cudnn_deterministic) args.ngpus_per_node = torch.cuda.device_count() main_worker(args.local_rank, args) def main_worker(local_rank, args): rank = local_rank args.local_rank = local_rank args.global_rank = local_rank args.distributed = args.ngpus_per_node > 1 if args.ngpus_per_node > 1: from torch.distributed import init_process_group torch.cuda.set_device(local_rank) init_process_group(backend='nccl') #CUDA_VISIBLE_DEVICES = int(args.local_rank) logger = Logger(args) soundstream = SoundStream( n_filters=32, D=512, ratios=args.ratios, sample_rate=args.sr, target_bandwidths=args.target_bandwidths) msd = MultiScaleDiscriminator() mpd = MultiPeriodDiscriminator() stft_disc = MultiScaleSTFTDiscriminator(filters=32) if logger.is_primary: getModelSize(soundstream) getModelSize(msd) getModelSize(mpd) getModelSize(stft_disc) if args.distributed: soundstream = torch.nn.SyncBatchNorm.convert_sync_batchnorm(soundstream) stft_disc = torch.nn.SyncBatchNorm.convert_sync_batchnorm(stft_disc) msd = torch.nn.SyncBatchNorm.convert_sync_batchnorm(msd) mpd = torch.nn.SyncBatchNorm.convert_sync_batchnorm(mpd) # torch.distributed.barrier() args.device = torch.device('cuda', args.local_rank) soundstream.to(args.device) stft_disc.to(args.device) msd.to(args.device) mpd.to(args.device) find_unused_parameters = False if args.distributed: soundstream = DDP( soundstream, device_ids=[args.local_rank], find_unused_parameters=find_unused_parameters ) # device_ids=[args.local_rank], output_device=args.local_rank stft_disc = DDP(stft_disc, device_ids=[args.local_rank], find_unused_parameters=find_unused_parameters) msd = DDP(msd, device_ids=[args.local_rank], find_unused_parameters=find_unused_parameters) mpd = DDP(mpd, device_ids=[args.local_rank], find_unused_parameters=find_unused_parameters) # 这里之后需要看下 sr 的问题,如果输入 wav 的 sr 和 `--sr` 不一致则会有问题 logger.log_info('Training set') train_dataset = NSynthDataset(audio_dir=args.train_data_path) logger.log_info('valid set') valid_dataset = NSynthDataset(audio_dir=args.valid_data_path) args.sr = train_dataset.sr if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, drop_last=True, shuffle=True) valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_dataset) else: train_sampler = None valid_sampler = None train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.BATCH_SIZE, num_workers=8, sampler=train_sampler) valid_loader = torch.utils.data.DataLoader( valid_dataset, batch_size=args.BATCH_SIZE, num_workers=8, sampler=valid_sampler) logger.log_info("Build optimizers and lr-schedulers") optimizer_g = torch.optim.AdamW( soundstream.parameters(), lr=3e-4, betas=(0.5, 0.9)) lr_scheduler_g = torch.optim.lr_scheduler.ExponentialLR( optimizer_g, gamma=0.999) optimizer_d = torch.optim.AdamW( itertools.chain(stft_disc.parameters(), msd.parameters(), mpd.parameters()), lr=3e-4, betas=(0.5, 0.9)) lr_scheduler_d = torch.optim.lr_scheduler.ExponentialLR( optimizer_d, gamma=0.999) if args.resume: latest_info = torch.load(args.resume_path + '/latest.pth') args.st_epoch = latest_info['epoch'] soundstream.load_state_dict(latest_info['soundstream']) stft_disc.load_state_dict(latest_info['stft_disc']) mpd.load_state_dict(latest_info['mpd']) msd.load_state_dict(latest_info['msd']) optimizer_g.load_state_dict(latest_info['optimizer_g']) lr_scheduler_g.load_state_dict(latest_info['lr_scheduler_g']) optimizer_d.load_state_dict(latest_info['optimizer_d']) lr_scheduler_d.load_state_dict(latest_info['lr_scheduler_d']) train(args, soundstream, stft_disc, msd, mpd, train_loader, valid_loader, optimizer_g, optimizer_d, lr_scheduler_g, lr_scheduler_d, logger) def train(args, soundstream, stft_disc, msd, mpd, train_loader, valid_loader, optimizer_g, optimizer_d, lr_scheduler_g, lr_scheduler_d, logger): print('args ', args.global_rank) best_val_loss = float("inf") best_val_epoch = -1 global_step = 0 for epoch in range(args.st_epoch, args.N_EPOCHS + 1): soundstream.train() stft_disc.train() msd.train() mpd.train() train_loss_d = 0.0 train_adv_g_loss = 0.0 train_feat_loss = 0.0 train_rec_loss = 0.0 train_loss_g = 0.0 train_commit_loss = 0.0 k_iter = 0 if args.distributed: train_loader.sampler.set_epoch(epoch) for x in tqdm(train_loader): x = x.to(args.device) k_iter += 1 global_step += 1 # record the global step for optimizer_idx in [0, 1]: # we have two optimizer x_wav = get_input(x) G_x, commit_loss, last_layer = soundstream(x_wav) if optimizer_idx == 0: # update generator y_disc_r, fmap_r = stft_disc(x_wav.contiguous()) y_disc_gen, fmap_gen = stft_disc(G_x.contiguous()) y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd( x_wav.contiguous(), G_x.contiguous()) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd( x_wav.contiguous(), G_x.contiguous()) total_loss_g, rec_loss, adv_g_loss, feat_loss, d_weight = loss_g( commit_loss, x_wav, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen, global_step, y_df_hat_r, y_df_hat_g, y_ds_hat_r, y_ds_hat_g, fmap_f_r, fmap_f_g, fmap_s_r, fmap_s_g, last_layer=last_layer, is_training=True, args=args) train_commit_loss += commit_loss.item() train_loss_g += total_loss_g.item() train_adv_g_loss += adv_g_loss.item() train_feat_loss += feat_loss.item() train_rec_loss += rec_loss.item() optimizer_g.zero_grad() total_loss_g.backward() optimizer_g.step() else: # update discriminator y_disc_r_det, fmap_r_det = stft_disc(x.detach()) y_disc_gen_det, fmap_gen_det = stft_disc(G_x.detach()) # MPD y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd( x.detach(), G_x.detach()) #MSD y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd( x.detach(), G_x.detach()) loss_d = loss_dis( y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g, global_step, args) train_loss_d += loss_d.item() optimizer_d.zero_grad() loss_d.backward() optimizer_d.step() message = ''.format( epoch, k_iter, total_loss_g.item(), adv_g_loss.item(), feat_loss.item(), rec_loss.item(), commit_loss.item(), loss_d.item(), d_weight.item()) if k_iter % args.print_freq == 0: logger.log_info(message) lr_scheduler_g.step() lr_scheduler_d.step() message = ''.format( epoch, train_loss_g / len(train_loader), train_rec_loss / len(train_loader), train_adv_g_loss / len(train_loader), train_feat_loss / len(train_loader), train_commit_loss / len(train_loader)) logger.log_info(message) with torch.no_grad(): soundstream.eval() stft_disc.eval() mpd.eval() msd.eval() valid_loss_d = 0.0 valid_loss_g = 0.0 valid_commit_loss = 0.0 valid_adv_g_loss = 0.0 valid_feat_loss = 0.0 valid_rec_loss = 0.0 if args.distributed: valid_loader.sampler.set_epoch(epoch) for x in tqdm(valid_loader): x = x.to(args.device) for optimizer_idx in [0, 1]: x_wav = get_input(x) G_x, commit_loss, _ = soundstream(x_wav) if optimizer_idx == 0: valid_commit_loss += commit_loss y_disc_r, fmap_r = stft_disc(x_wav.contiguous()) y_disc_gen, fmap_gen = stft_disc(G_x.contiguous()) y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd( x_wav.contiguous(), G_x.contiguous()) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd( x_wav.contiguous(), G_x.contiguous()) total_loss_g, adv_g_loss, feat_loss, rec_loss = criterion_g( commit_loss, x_wav, G_x, fmap_r, fmap_gen, y_disc_r, y_disc_gen, y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g, args=args) valid_loss_g += total_loss_g.item() valid_adv_g_loss += adv_g_loss.item() valid_feat_loss += feat_loss.item() valid_rec_loss += rec_loss.item() else: y_disc_r_det, fmap_r_det = stft_disc( x_wav.contiguous().detach()) y_disc_gen_det, fmap_gen_det = stft_disc( G_x.contiguous().detach()) y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd( x_wav.contiguous().detach(), G_x.contiguous().detach()) y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd( x_wav.contiguous().detach(), G_x.contiguous().detach()) loss_d = criterion_d(y_disc_r_det, y_disc_gen_det, fmap_r_det, fmap_gen_det, y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g, y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g) valid_loss_d += loss_d.item() if dist.get_rank() == 0: best_model = soundstream.state_dict().copy() latest_model_soundstream = soundstream.state_dict().copy() latest_model_dis = stft_disc.state_dict().copy() latest_mpd = mpd.state_dict().copy() latest_msd = msd.state_dict().copy() if valid_rec_loss < best_val_loss: best_val_loss = valid_rec_loss best_val_epoch = epoch torch.save(best_model, args.PATH + '/best_' + str(epoch) + '.pth') latest_save = {} latest_save['soundstream'] = latest_model_soundstream latest_save['stft_disc'] = latest_model_dis latest_save['mpd'] = latest_mpd latest_save['msd'] = latest_msd latest_save['epoch'] = epoch latest_save['optimizer_g'] = optimizer_g.state_dict() latest_save['optimizer_d'] = optimizer_d.state_dict() latest_save['lr_scheduler_g'] = lr_scheduler_g.state_dict() latest_save['lr_scheduler_d'] = lr_scheduler_d.state_dict() torch.save(latest_save, args.PATH + '/latest.pth') message = ''.format( epoch, valid_loss_g / len(valid_loader), valid_rec_loss / len(valid_loader), valid_adv_g_loss / len(valid_loader), valid_feat_loss / len(valid_loader), valid_commit_loss / len(valid_loader), valid_loss_d / len(valid_loader), best_val_epoch) logger.log_info(message) if __name__ == '__main__': main()