# %%
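"""Adversarial training script for the Tacotron2MS acoustic model.

Alongside the standard Tacotron2 losses, a PatchDiscriminator (critic) is
trained on fixed-length mel-spectrogram chunks, and its LSGAN score and
feature-matching losses are added to the generator objective.
"""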
import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from models.tacotron2.tacotron2_ms import Tacotron2MS
from utils import get_config
from utils.data import ArabDataset, text_mel_collate_fn
from utils.logging import TBLogger
from utils.training import batch_to_device, save_states_gan as save_states
from models.common.loss import PatchDiscriminator, extract_chunks, calc_feature_match_loss
from models.tacotron2.loss import Tacotron2Loss
# %%
try:
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str,
                        default="configs/nawar_tc2_adv.yaml",
                        help="Path to yaml config file")
    args = parser.parse_args()
    config_path = args.config
except:
    # fall back to the default config when argument parsing fails
    # (e.g. when the script is run as interactive cells)
    config_path = './configs/nawar_tc2_adv.yaml'
# %%
config = get_config(config_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# set random seed
if config.random_seed != False:
    torch.manual_seed(config.random_seed)
    torch.cuda.manual_seed_all(config.random_seed)
    np.random.seed(config.random_seed)

# make checkpoint folder if nonexistent
if not os.path.isdir(config.checkpoint_dir):
    os.makedirs(os.path.abspath(config.checkpoint_dir))
    print(f"Created checkpoint_dir folder: {config.checkpoint_dir}")
# datasets
train_dataset = ArabDataset(txtpath=config.train_labels,
                            wavpath=config.train_wavs_path,
                            label_pattern=config.label_pattern)
# test_dataset = ArabDataset(config.test_labels, config.test_wavs_path)

# optional: balanced sampling
sampler, shuffle, drop_last = None, True, True
if config.balanced_sampling:
    weights = torch.load(config.sampler_weights_file)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, len(weights), replacement=False)
    shuffle, drop_last = False, False
# dataloaders
train_loader = DataLoader(train_dataset,
                          batch_size=config.batch_size,
                          collate_fn=text_mel_collate_fn,
                          shuffle=shuffle, drop_last=drop_last,
                          sampler=sampler)
# test_loader = DataLoader(test_dataset,
#                          batch_size=config.batch_size, drop_last=False,
#                          shuffle=False, collate_fn=text_mel_collate_fn)
# %% Generator
model = Tacotron2MS(n_symbol=40, num_speakers=40)
model = model.to(device)
model.decoder.decoder_max_step = config.decoder_max_step

optimizer = torch.optim.AdamW(model.parameters(),
                              lr=config.g_lr,
                              betas=(config.g_beta1, config.g_beta2),
                              weight_decay=config.weight_decay)
criterion = Tacotron2Loss(mel_loss_scale=1.0)
# %% Discriminator
critic = PatchDiscriminator(1, 32).to(device)
optimizer_d = torch.optim.AdamW(critic.parameters(),
                                lr=config.d_lr,
                                betas=(config.d_beta1, config.d_beta2),
                                weight_decay=config.weight_decay)

tar_len = 128
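# tar_len is the number of mel frames in each chunk fed to the patch critic;
# in the training loop it is further capped by the shortest target in a batch.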
# %%
# resume from existing checkpoint
n_epoch, n_iter = 0, 0

if config.restore_model != '':
    state_dicts = torch.load(config.restore_model)
    model.load_state_dict(state_dicts['model'])
    if 'model_d' in state_dicts:
        critic.load_state_dict(state_dicts['model_d'], strict=False)
    if 'optim' in state_dicts:
        optimizer.load_state_dict(state_dicts['optim'])
    if 'optim_d' in state_dicts:
        optimizer_d.load_state_dict(state_dicts['optim_d'])
    if 'epoch' in state_dicts:
        n_epoch = state_dicts['epoch']
    if 'iter' in state_dicts:
        n_iter = state_dicts['iter']
# %%
# tensorboard writer
writer = TBLogger(config.log_dir)
# %%
def trunc_batch(batch, N):
    """Keep only the first N examples of a (text, input_lengths, mel, gate, output_lengths) batch."""
    return (batch[0][:N], batch[1][:N], batch[2][:N],
            batch[3][:N], batch[4][:N])
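# Each iteration below: (1) Tacotron2 forward pass, (2) critic update on real
# vs. generated mel chunks, (3) generator update with the Tacotron2 loss plus
# adversarial and feature-matching terms.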
# %% TRAINING LOOP
model.train()

for epoch in range(n_epoch, config.epochs):
    print(f"Epoch: {epoch}")

    for batch in train_loader:
        # truncate oversized batches (first target length > 2000 frames) to 6 examples
        if batch[-1][0] > 2000:
            batch = trunc_batch(batch, 6)

        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch_to_device(batch, device)

        y_pred = model(text_padded, input_lengths,
                       mel_padded, output_lengths,
                       torch.zeros_like(output_lengths))
        mel_out, mel_out_postnet, gate_out, alignments = y_pred
        # extract random fixed-length mel chunks for the critic
        Nchunks = mel_out.size(0)
        tar_len_ = min(output_lengths.min().item(), tar_len)

        mel_ids = torch.randint(0, mel_out.size(0), (Nchunks,)).to(device, non_blocking=True)
        ofx_perc = torch.rand(Nchunks).to(device, non_blocking=True)
        out_lens = output_lengths[mel_ids]
        # random chunk start offsets, clamped so each chunk fits inside its mel
        ofx = (ofx_perc * (out_lens + tar_len_) - tar_len_/2) \
            .clamp(torch.zeros_like(out_lens), out_lens - tar_len_).long()

        chunks_org = extract_chunks(
            mel_padded, ofx, mel_ids, tar_len_)       # mel_padded: B F T
        chunks_gen = extract_chunks(
            mel_out_postnet, ofx, mel_ids, tar_len_)  # mel_out_postnet: B F T

        chunks_org_ = (chunks_org.unsqueeze(1) + 4.5) / 2.5
        chunks_gen_ = (chunks_gen.unsqueeze(1) + 4.5) / 2.5
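        # NOTE: the (x + 4.5) / 2.5 shift/scale above appears to map the log-mel
        # values into a roughly unit range before they reach the critic; the
        # constants are taken as-is from the original script.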
        # DISCRIMINATOR: least-squares GAN objective (real -> 1, fake -> 0)
        d_org, fmaps_org = critic(chunks_org_.requires_grad_(True))
        d_gen, _ = critic(chunks_gen_.detach())

        loss_d = 0.5*(d_org - 1).square().mean() + 0.5*d_gen.square().mean()

        critic.zero_grad()
        loss_d.backward()
        optimizer_d.step()
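        # The critic update above minimises the LSGAN loss
        #     L_D = 0.5 * E[(D(x_real) - 1)^2] + 0.5 * E[D(x_fake)^2],
        # with generated chunks detached so only the critic receives gradients.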
        # GENERATOR
        loss, meta = criterion(mel_out, mel_out_postnet, mel_padded,
                               gate_out, gate_padded)

        d_gen2, fmaps_gen = critic(chunks_gen_)
        loss_score = (d_gen2 - 1).square().mean()
        loss_fmatch = calc_feature_match_loss(fmaps_gen, fmaps_org)

        loss += config.gan_loss_weight * loss_score
        loss += config.feat_loss_weight * loss_fmatch

        optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), config.grad_clip_thresh)
        optimizer.step()
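        # The generator objective combines the Tacotron2 mel/gate losses with an
        # adversarial term (D(x_fake) - 1)^2 and a feature-matching term between
        # critic feature maps of real and generated chunks, weighted by
        # gan_loss_weight and feat_loss_weight from the config.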
        # LOGGING
        meta['score'] = loss_score.clone().detach()
        meta['fmatch'] = loss_fmatch.clone().detach()
        meta['loss'] = loss.clone().detach()

        print(f"loss: {loss.item()}, grad_norm: {grad_norm.item()}")
        writer.add_training_data(meta, grad_norm.item(),
                                 config.learning_rate, n_iter)

        if n_iter % config.n_save_states_iter == 0:
            save_states('states.pth', model, critic,
                        optimizer, optimizer_d, n_iter,
                        epoch, None, config)

        if n_iter % config.n_save_backup_iter == 0 and n_iter > 0:
            save_states(f'states_{n_iter}.pth', model, critic,
                        optimizer, optimizer_d, n_iter,
                        epoch, None, config)

        n_iter += 1
    # VALIDATE
    # val_loss = validate(model, test_loader, writer, device, n_iter)
    # print(f"Validation loss: {val_loss}")

    # end-of-epoch checkpoint
    save_states('states.pth', model, critic,
                optimizer, optimizer_d, n_iter,
                epoch, None, config)
# %%
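# Example invocation (assuming the script is launched from the repository root
# so that the default config path resolves):
#     python scripts/train_tc2_adv.py --config configs/nawar_tc2_adv.yaml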