import os
import torch
import numpy as np

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from grad_extend.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from grad_extend.utils import plot_tensor, save_plot, load_model, print_error
from grad.utils import fix_len_compatibility
from grad.model import GradTTS

# 200 frames
out_size = fix_len_compatibility(200)
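# Note (assumption): fix_len_compatibility is the Grad-TTS utility that rounds a
# frame count so it divides evenly through the decoder's downsampling stages;
# out_size therefore caps the random mel segment used for the diffusion loss at
# roughly 200 frames.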

def train(hps, chkpt_path=None):
    print('Initializing logger...')
    logger = SummaryWriter(log_dir=hps.train.log_dir)

    print('Initializing data loaders...')
    train_dataset = TextMelSpeakerDataset(hps.train.train_files)
    batch_collate = TextMelSpeakerBatchCollate()
    loader = DataLoader(dataset=train_dataset,
                        batch_size=hps.train.batch_size,
                        collate_fn=batch_collate,
                        drop_last=True,
                        num_workers=8,
                        shuffle=True)
    test_dataset = TextMelSpeakerDataset(hps.train.valid_files)

    print('Initializing model...')
    model = GradTTS(hps.grad.n_mels, hps.grad.n_vecs, hps.grad.n_pits, hps.grad.n_spks, hps.grad.n_embs,
                    hps.grad.n_enc_channels, hps.grad.filter_channels,
                    hps.grad.dec_dim, hps.grad.beta_min, hps.grad.beta_max, hps.grad.pe_scale).cuda()
    print('Number of encoder parameters = %.2fm' % (model.encoder.nparams / 1e6))
    print('Number of decoder parameters = %.2fm' % (model.decoder.nparams / 1e6))

    # Optionally warm-start from a pretrained Grad-SVC checkpoint
    if os.path.isfile(hps.train.pretrain):
        print("Start from Grad_SVC pretrain model: %s" % hps.train.pretrain)
        checkpoint = torch.load(hps.train.pretrain, map_location='cpu')
        load_model(model, checkpoint['model'])
        # drop the learning rate and switch the model into fine-tuning mode
        hps.train.learning_rate = 2e-5
        model.fine_tune()
    else:
        print_error(10 * '~' + "No Pretrain Model" + 10 * '~')

    print('Initializing optimizer...')
    optim = torch.optim.Adam(params=model.parameters(), lr=hps.train.learning_rate)

    initepoch = 1
    iteration = 0

    # Resume model, optimizer, and step counters from an explicit checkpoint
    if chkpt_path is not None:
        print("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        initepoch = checkpoint['epoch']
        iteration = checkpoint['steps']

    print('Logging test batch...')
    test_batch = test_dataset.sample_test_batch(size=hps.train.test_size)
    for i, item in enumerate(test_batch):
        mel = item['mel']
        logger.add_image(f'image_{i}/ground_truth', plot_tensor(mel.squeeze()),
                         global_step=0, dataformats='HWC')
        save_plot(mel.squeeze(), f'{hps.train.log_dir}/original_{i}.png')

    print('Start training...')
    skip_diff_train = True
    if initepoch >= hps.train.fast_epochs:
        skip_diff_train = False
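    # Two-phase schedule: for the first `fast_epochs` epochs the diffusion term
    # is skipped (skip_diff=True), so only the prior/mel/speaker losses train;
    # after that, skip_diff_train flips and the diffusion decoder joins in.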

    for epoch in range(initepoch, hps.train.full_epochs + 1):
        if epoch % hps.train.test_step == 0:
            model.eval()
            print('Synthesis...')
            with torch.no_grad():
                for i, item in enumerate(test_batch):
                    l_vec = item['vec'].shape[0]
                    d_vec = item['vec'].shape[1]
                    lengths_fix = fix_len_compatibility(l_vec)
                    lengths = torch.LongTensor([l_vec]).cuda()
                    # zero-pad content vectors and pitch up to the fixed length
                    vec = torch.zeros((1, lengths_fix, d_vec), dtype=torch.float32).cuda()
                    pit = torch.zeros((1, lengths_fix), dtype=torch.float32).cuda()
                    spk = item['spk'].to(torch.float32).unsqueeze(0).cuda()
                    vec[0, :l_vec, :] = item['vec']
                    pit[0, :l_vec] = item['pit']

                    y_enc, y_dec = model(lengths, vec, pit, spk, n_timesteps=50)
                    logger.add_image(f'image_{i}/generated_enc',
                                     plot_tensor(y_enc.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    logger.add_image(f'image_{i}/generated_dec',
                                     plot_tensor(y_dec.squeeze().cpu()),
                                     global_step=iteration, dataformats='HWC')
                    save_plot(y_enc.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_enc_{i}.png')
                    save_plot(y_dec.squeeze().cpu(),
                              f'{hps.train.log_dir}/generated_dec_{i}.png')
            model.train()

        prior_losses = []
        diff_losses = []
        mel_losses = []
        spk_losses = []

        with tqdm(loader, total=len(train_dataset) // hps.train.batch_size) as progress_bar:
            for batch in progress_bar:
                model.zero_grad()
                lengths = batch['lengths'].cuda()
                vec = batch['vec'].cuda()
                pit = batch['pit'].cuda()
                spk = batch['spk'].cuda()
                mel = batch['mel'].cuda()

                prior_loss, diff_loss, mel_loss, spk_loss = model.compute_loss(
                    lengths, vec, pit, spk,
                    mel, out_size=out_size,
                    skip_diff=skip_diff_train)
                loss = sum([prior_loss, diff_loss, mel_loss, spk_loss])
                loss.backward()

                # clip encoder and decoder gradients separately before stepping
                enc_grad_norm = torch.nn.utils.clip_grad_norm_(model.encoder.parameters(),
                                                               max_norm=1)
                dec_grad_norm = torch.nn.utils.clip_grad_norm_(model.decoder.parameters(),
                                                               max_norm=1)
                optim.step()

                logger.add_scalar('training/mel_loss', mel_loss,
                                  global_step=iteration)
                logger.add_scalar('training/prior_loss', prior_loss,
                                  global_step=iteration)
                logger.add_scalar('training/diffusion_loss', diff_loss,
                                  global_step=iteration)
                logger.add_scalar('training/encoder_grad_norm', enc_grad_norm,
                                  global_step=iteration)
                logger.add_scalar('training/decoder_grad_norm', dec_grad_norm,
                                  global_step=iteration)

                msg = f'Epoch: {epoch}, iteration: {iteration} | '
                msg = msg + f'prior_loss: {prior_loss.item():.3f}, '
                msg = msg + f'diff_loss: {diff_loss.item():.3f}, '
                msg = msg + f'mel_loss: {mel_loss.item():.3f}, '
                msg = msg + f'spk_loss: {spk_loss.item():.3f}, '
                progress_bar.set_description(msg)

                prior_losses.append(prior_loss.item())
                diff_losses.append(diff_loss.item())
                mel_losses.append(mel_loss.item())
                spk_losses.append(spk_loss.item())
                iteration += 1

        msg = 'Epoch %d: ' % (epoch)
        msg += '| spk loss = %.3f ' % np.mean(spk_losses)
        msg += '| mel loss = %.3f ' % np.mean(mel_losses)
        msg += '| prior loss = %.3f ' % np.mean(prior_losses)
        msg += '| diffusion loss = %.3f\n' % np.mean(diff_losses)
        with open(f'{hps.train.log_dir}/train.log', 'a') as f:
            f.write(msg)

        # if (np.mean(prior_losses) < 1.05):
        #     skip_diff_train = False
        if epoch > hps.train.fast_epochs:
            skip_diff_train = False

        # save a checkpoint only every save_step epochs
        if epoch % hps.train.save_step > 0:
            continue

        save_path = f"{hps.train.log_dir}/grad_svc_{epoch}.pt"
        torch.save({
            'model': model.state_dict(),
            'optim': optim.state_dict(),
            'epoch': epoch,
            'steps': iteration,
        }, save_path)
        print("Saved checkpoint to: %s" % save_path)
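

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the real entry script lives outside this
# module, and `hps` is the project's hyperparameter object loaded from its
# YAML config; OmegaConf is shown purely for illustration).
#
#     from omegaconf import OmegaConf
#     hps = OmegaConf.load('configs/base.yaml')
#     train(hps)                                      # fresh run
#     train(hps, chkpt_path='logs/grad_svc_100.pt')   # resume from a checkpoint
# ---------------------------------------------------------------------------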