import os

import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from grad_extend.utils import plot_tensor, save_plot, load_model, print_error
from grad.utils import fix_len_compatibility
from grad.model import GradTTS

# Train on fixed segments of 200 mel frames; fix_len_compatibility rounds
# the length up so it stays divisible through the decoder's downsampling.
out_size = fix_len_compatibility(200)


def train(hps, chkpt_path=None):
    print('Initializing logger...')
    logger = SummaryWriter(log_dir=hps.train.log_dir)

    print('Initializing data loaders...')
    train_dataset = TextMelSpeakerDataset(hps.train.train_files)
    batch_collate = TextMelSpeakerBatchCollate()
    loader = DataLoader(dataset=train_dataset,
                        batch_size=hps.train.batch_size,
                        collate_fn=batch_collate,
                        drop_last=True,
                        num_workers=8,
                        shuffle=True)
    test_dataset = TextMelSpeakerDataset(hps.train.valid_files)

    print('Initializing model...')
    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits,
                    hps.grad.n_spks, hps.grad.n_embs,
                    hps.grad.n_enc_channels, hps.grad.filter_channels,
                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max,
                    hps.grad.pe_scale).cuda()
    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams / 1e6))
    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams / 1e6))

    # Warm-start from a Grad_SVC pretrained model if one is configured:
    # load its weights, drop the learning rate, and switch to fine-tuning.
    if os.path.isfile(hps.train.pretrain):
        print("Start from Grad_SVC pretrain model: %s" % hps.train.pretrain)
        checkpoint = torch.load(hps.train.pretrain, map_location='cpu')
        load_model(model, checkpoint['model'])
        hps.train.learning_rate = 2e-5
        model.fine_tune()
    else:
        print_error(10 * '~' + "No Pretrain Model" + 10 * '~')

    print('Initializing optimizer...')
    optim = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)

    initepoch = 1
    iteration = 0

    # Resume a full training state (weights, optimizer state, epoch and
    # step counters) from an explicit checkpoint path.
    if chkpt_path is not None:
        print("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        initepoch = checkpoint['epoch']
        iteration = checkpoint['steps']

    print('Logging test batch...')
    test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
    for i, item in enumerate(test_batch):
        mel = item['mel']
        logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
                         global_step=0, dataformats='HWC')
        save_plot(mel.squeeze(), f'{hps.train.log_dir}/original_{i}.png')

    print('Start training...')
    # During the first hps.train.fast_epochs epochs only the non-diffusion
    # losses are trained; the diffusion loss is enabled afterwards.
    skip_diff_train = True
    if initepoch >= hps.train.fast_epochs:
        skip_diff_train = False

    for epoch in range(initepoch, hps.train.full_epochs + 1):
        if epoch % hps.train.test_step == 0:
            model.eval()
            print('Synthesis...')
            with torch.no_grad():
                for i, item in enumerate(test_batch):
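                    # Zero-pad the content vector and pitch up to a
                    # decoder-compatible length; only the first l_vec
                    # frames hold real data.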
                    l_vec = item['vec'].shape[0]
                    d_vec = item['vec'].shape[1]
                    lengths_fix = fix_len_compatibility(l_vec)
                    lengths = torch.LongTensor([l_vec]).cuda()
                    vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).cuda()
                    pit = torch.zeros((1, lengths_fix), dtype=torch.float32).cuda()
                    spk = item['spk'].to(torch.float32).unsqueeze(0).cuda()
                    vec[0, :l_vec, :] = item['vec']
                    pit[0, :l_vec] = item['pit']
                    y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=50)
                    logger.add_image(f'image_{i}/generated_enc',
                                     plot_tensor(y_enc.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    logger.add_image(f'image_{i}/generated_dec',
                                     plot_tensor(y_dec.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    save_plot(y_enc.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_enc_{i}.png')
                    save_plot(y_dec.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_dec_{i}.png')
            model.train()

        prior_losses = []
        diff_losses = []
        mel_losses = []
        spk_losses = []
        with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
            for batch in progress_bar:
                model.zero_grad()
                lengths = batch['lengths'].cuda()
                vec = batch['vec'].cuda()
                pit = batch['pit'].cuda()
                spk = batch['spk'].cuda()
                mel = batch['mel'].cuda()
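                # Total objective: encoder prior loss + decoder diffusion
                # loss + mel loss + speaker loss; the diffusion term is
                # skipped while skip_diff_train is set.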
                prior_loss, diff_loss, mel_loss, spk_loss = model.compute_loss(
                    lengths, vec, pit, spk,
                    mel, out_size=out_size,
                    skip_diff=skip_diff_train)
                loss = sum([prior_loss, diff_loss, mel_loss, spk_loss])
                loss.backward()
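                # Clip encoder and decoder gradients separately so each
                # norm can be logged on its own below.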
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.encoder.parameters(), max_norm=1)
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.decoder.parameters(), max_norm=1)
                optim.step()

                logger.add_scalar('training/mel_loss', mel_loss,
                                  global_step=iteration)
                logger.add_scalar('training/prior_loss', prior_loss,
                                  global_step=iteration)
                logger.add_scalar('training/diffusion_loss', diff_loss,
                                  global_step=iteration)
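                # Speaker loss, matching the other per-step scalars.
                logger.add_scalar('training/spk_loss', spk_loss,
                                  global_step=iteration)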
                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
                                  global_step=iteration)
                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
                                  global_step=iteration)

                msg = f'Epoch: {epoch}, iteration: {iteration} | '
                msg = msg + f'prior_loss: {prior_loss.item():.3f}, '
                msg = msg + f'diff_loss: {diff_loss.item():.3f}, '
                msg = msg + f'mel_loss: {mel_loss.item():.3f}, '
                msg = msg + f'spk_loss: {spk_loss.item():.3f}'
                progress_bar.set_description(msg)

                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())
                mel_losses.append(mel_loss.item())
                spk_losses.append(spk_loss.item())
                iteration += 1

        msg = 'Epoch %d: ' % epoch
        msg += '| spk loss = %.3f ' % np.mean(spk_losses)
        msg += '| mel loss = %.3f ' % np.mean(mel_losses)
        msg += '| prior loss = %.3f ' % np.mean(prior_losses)
        msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
        with open(f'{hps.train.log_dir}/train.log', 'a') as f:
            f.write(msg)

        # Alternative trigger kept for reference: start diffusion training
        # once the mean prior loss drops below a threshold.
        # if np.mean(prior_losses) < 1.05:
        #     skip_diff_train = False
        if epoch > hps.train.fast_epochs:
            skip_diff_train = False

        # Checkpoint only every hps.train.save_step epochs.
        if epoch % hps.train.save_step > 0:
            continue

        save_path = f"{hps.train.log_dir}/grad_svc_{epoch}.pt"
        torch.save({
            'model': model.state_dict(),
            'optim': optim.state_dict(),
            'epoch': epoch,
            'steps': iteration,
        }, save_path)
        print("Saved checkpoint to: %s" % save_path)