# wasm-ar-tts / wasq / scripts/train_fp.py
# Author: wasmdashai
# Commit: 7694c84 ("first commit")
import argparse
import os
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models.fastpitch import net_config
from models.fastpitch.fastpitch.model import FastPitch
from models.fastpitch.fastpitch.data_function import (TTSCollate, batch_to_gpu)
from models.fastpitch.fastpitch.loss_function import FastPitchLoss
from models.fastpitch.fastpitch.attn_loss_function import AttentionBinarizationLoss
from utils.data import DynBatchDataset
from utils import get_config
from utils.training import save_states
# %%
# Read the path of the YAML training config from the command line.
# parse_known_args() ignores arguments this script does not define (e.g. the
# `-f kernel.json` an IPython/Jupyter kernel passes when the `# %%` cells are
# run interactively), which is presumably what the former bare try/except
# fallback was working around. Unlike the bare except, this no longer
# swallows KeyboardInterrupt, `--help` (which now exits as expected) or a
# malformed `--config`.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str,
                    default="configs/nawar_fp.yaml",
                    help="Path to yaml config file")
args, _unknown_args = parser.parse_known_args()
config_path = args.config
# %%
# Load the training configuration and choose where tensors live.
config = get_config(config_path)
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')

# Create the checkpoint folder (including parents) if it is missing.
_checkpoint_dir = config.checkpoint_dir
if os.path.isdir(_checkpoint_dir) is False:
    os.makedirs(os.path.abspath(_checkpoint_dir))
    print(f"Created checkpoint folder @ {config.checkpoint_dir}")
# Training set. DynBatchDataset receives per-length batch sizes
# (max_lengths / batch_sizes); judging by the collate wrapper below, each
# dataset item is presumably a whole pre-built batch of samples (TODO confirm).
_dataset_kwargs = dict(
    txtpath=config.train_labels,
    wavpath=config.train_wavs_path,
    label_pattern=config.label_pattern,
    f0_dict_path=config.f0_dict_path,
    f0_mean=config.f0_mean,
    f0_std=config.f0_std,
    max_lengths=config.max_lengths,
    batch_sizes=config.batch_sizes,
)
train_dataset = DynBatchDataset(**_dataset_kwargs)
# %%
collate_fn = TTSCollate()

# The DataLoader draws exactly one dataset item per step; the wrapper below
# unwraps that length-1 list and hands the single item to TTSCollate.
config.batch_size = 1
sampler = None
shuffle = True
drop_last = True


def _collate_single_item(items):
    """Collate the only element of a DataLoader batch of size 1."""
    return collate_fn(items[0])


train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    collate_fn=_collate_single_item,
    shuffle=shuffle,
    drop_last=drop_last,
    sampler=sampler,
)
# %% Generator
# FastPitch model built from the package-level net_config, moved to `device`.
_fastpitch = FastPitch(**net_config)
model = _fastpitch.to(device)

# AdamW optimizer; all hyper-parameters come from the training config.
_betas = (config.g_beta1, config.g_beta2)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.g_lr,
    betas=_betas,
    weight_decay=config.weight_decay,
)

# Main FastPitch loss and the attention binarization (KL) loss that the
# training loop adds on top of it.
criterion = FastPitchLoss()
attention_kl_loss = AttentionBinarizationLoss()
# %%
# resume from existing checkpoint
# Resume from an existing checkpoint when `restore_model` is set; otherwise
# initialise from NVIDIA's pretrained FastPitch weights, from
# https://catalog.ngc.nvidia.com/orgs/nvidia/teams/dle/models/fastpitch__pyt_ckpt
#
# map_location=device lets a checkpoint saved on a GPU be loaded when this
# script has fallen back to CPU (see the `device` selection above).
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if these
# checkpoints pickle non-tensor objects (e.g. a config), loading a *trusted*
# file may need weights_only=False -- confirm.
#
# The pretrained-weights path can be overridden with a `pretrained_model`
# entry in the config; the previously hard-coded path remains the default.
PRETRAINED_FASTPITCH_PATH = 'G:/models/fastpitch/nvidia_fastpitch_210824+cfg.pt'

n_epoch, n_iter = 0, 0
if config.restore_model:  # also treats a missing/None value as "not set"
    state_dicts = torch.load(config.restore_model, map_location=device)
    model.load_state_dict(state_dicts['model'])
    if 'optim' in state_dicts:
        optimizer.load_state_dict(state_dicts['optim'])
    # Training counters (presumably written by save_states); keep 0 when a
    # checkpoint does not carry them.
    n_epoch = state_dicts.get('epoch', n_epoch)
    n_iter = state_dicts.get('iter', n_iter)
else:
    pretrained_path = (getattr(config, 'pretrained_model', None)
                       or PRETRAINED_FASTPITCH_PATH)
    model_sd = torch.load(pretrained_path, map_location=device)
    # The NVIDIA checkpoint's keys carry a 'module.' prefix (typical of a
    # DataParallel/DDP-wrapped model) that the bare model does not expect.
    model.load_state_dict(
        {k.removeprefix('module.'): v for k, v in model_sd['state_dict'].items()})
# %%
# TensorBoard writer for the per-iteration training scalars.
writer = SummaryWriter(config.log_dir)
# %% TRAINING LOOP
# Optional config overrides; the defaults reproduce the previously
# hard-coded values (KL weight 1.0, gradient-norm clip threshold 1000).
kl_loss_weight = getattr(config, 'kl_loss_weight', 1.0)
grad_clip_thresh = getattr(config, 'grad_clip_thresh', 1000.)

model.train()
try:
    for epoch in range(n_epoch, config.epochs):
        train_dataset.shuffle()
        for batch in train_loader:
            x, y, _ = batch_to_gpu(batch)
            y_pred = model(x)
            # The last four model outputs are
            # (attn_soft, attn_hard, attn_dur, attn_logprob).
            *_, attn_soft, attn_hard, _, _ = y_pred

            # Generator step: FastPitch loss plus the weighted attention
            # binarization (KL) term.
            loss, meta = criterion(y_pred, y)
            binarization_loss = attention_kl_loss(attn_hard, attn_soft)
            loss += kl_loss_weight * binarization_loss

            optimizer.zero_grad()
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), grad_clip_thresh)
            optimizer.step()

            # LOGGING (meta['kl_loss'] is the unweighted KL term)
            meta['kl_loss'] = binarization_loss.clone().detach()
            print(f"loss: {meta['loss'].item()} gnorm: {float(grad_norm)}")
            for k, v in meta.items():
                writer.add_scalar(f'train/{k}', v.item(), n_iter)

            # Overwrite the rolling checkpoint every n_save_states_iter
            # iterations and keep a numbered backup every n_save_backup_iter.
            if n_iter % config.n_save_states_iter == 0:
                save_states('states.pth', model,
                            optimizer, n_iter,
                            epoch, net_config, config)
            if n_iter % config.n_save_backup_iter == 0 and n_iter > 0:
                save_states(f'states_{n_iter}.pth', model,
                            optimizer, n_iter,
                            epoch, net_config, config)
            n_iter += 1

        # NOTE(review): indentation was lost in this copy of the script; this
        # save is placed at the end of each epoch (where `epoch` is always
        # bound) -- confirm against the intended layout.
        save_states('states.pth', model,
                    optimizer, n_iter,
                    epoch, net_config, config)
finally:
    # Flush pending TensorBoard events even if training stops early
    # (e.g. on KeyboardInterrupt or an exception).
    writer.close()
# %%
# Reference: layout of the FastPitch forward output (model_out = model(x))
# and of the x / y batch lists returned by batch_to_gpu.
# model_out = (
#     mel_out,        # 0
#     dec_mask,       # 1
#     dur_pred,       # 2
#     log_dur_pred,   # 3
#     pitch_pred,     # 4
#     pitch_tgt,      # 5
#     energy_pred,    # 6
#     energy_tgt,     # 7
#     attn_soft,      # 8
#     attn_hard,      # 9
#     attn_dur,       # 10
#     attn_logprob,   # 11
# )
# x = [text_padded, input_lengths, mel_padded, output_lengths,
#      pitch_padded, energy_padded, speaker, attn_prior, audiopaths]
# y = [mel_padded, input_lengths, output_lengths]