import argparse
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from models.tacotron2.tacotron2_ms import Tacotron2MS
from utils import get_config
from utils.data import ArabDataset, text_mel_collate_fn
from utils.logging import TBLogger
from utils.training import *

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default="configs/nawar.yaml",
                    help="Path to yaml config file")


@torch.inference_mode()
def validate(model, test_loader, writer, device, n_iter):
    loss_sum = 0
    n_test_sum = 0

    model.eval()

    for batch in test_loader:
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch_to_device(batch, device)

        y_pred = model(text_padded, input_lengths,
                       mel_padded, output_lengths,
                       torch.zeros_like(output_lengths))
        mel_out, mel_out_postnet, gate_pred, alignments = y_pred

        # Tacotron2 loss: MSE on the decoder and postnet mel outputs,
        # plus BCE on the stop-token (gate) logits
        mel_loss = F.mse_loss(mel_out, mel_padded) + \
            F.mse_loss(mel_out_postnet, mel_padded)
        gate_loss = F.binary_cross_entropy_with_logits(gate_pred, gate_padded)
        loss = mel_loss + gate_loss

        # weight by batch size so the final average is per-sample
        loss_sum += mel_padded.size(0)*loss.item()
        n_test_sum += mel_padded.size(0)

    val_loss = loss_sum / n_test_sum

    # run free-running inference on one random sample from the last batch;
    # the zeroed tensor mirrors the all-zero speaker ids passed to forward
    idx = random.randint(0, mel_padded.size(0) - 1)
    mel_infer, *_ = model.infer(text_padded[idx:idx+1],
                                input_lengths[idx:idx+1]*0,
                                input_lengths[idx:idx+1])

    writer.add_sample(alignments[idx, :, :input_lengths[idx].item()],
                      mel_out[idx], mel_padded[idx], mel_infer[0],
                      output_lengths[idx], n_iter)
    writer.add_scalar('loss/val_loss', val_loss, n_iter)

    model.train()

    return val_loss


def training_loop(model, optimizer, train_loader, test_loader,
                  writer, device, config, n_epoch, n_iter):
    model.train()

    for epoch in range(n_epoch, config.epochs):
        print(f"Epoch: {epoch}")

        for batch in train_loader:
            text_padded, input_lengths, mel_padded, gate_padded, \
                output_lengths = batch_to_device(batch, device)

            y_pred = model(text_padded, input_lengths,
                           mel_padded, output_lengths,
                           torch.zeros_like(output_lengths))
            mel_out, mel_out_postnet, gate_out, _ = y_pred

            optimizer.zero_grad()

            # LOSS
            mel_loss = F.mse_loss(mel_out, mel_padded) + \
                F.mse_loss(mel_out_postnet, mel_padded)
            gate_loss = F.binary_cross_entropy_with_logits(
                gate_out, gate_padded)
            loss = mel_loss + gate_loss

            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.grad_clip_thresh)
            optimizer.step()

            # LOGGING
            print(f"loss: {loss.item()}, grad_norm: {grad_norm.item()}")

            writer.add_training_data(loss.item(), grad_norm.item(),
                                     config.learning_rate, n_iter)

            if n_iter % config.n_save_states_iter == 0:
                save_states('states.pth', model, optimizer,
                            n_iter, epoch, config)

            if n_iter % config.n_save_backup_iter == 0 and n_iter > 0:
                save_states(f'states_{n_iter}.pth', model, optimizer,
                            n_iter, epoch, config)

            n_iter += 1

        # VALIDATE
        val_loss = validate(model, test_loader, writer, device, n_iter)
        print(f"Validation loss: {val_loss}")

    # final checkpoint after the last epoch
    save_states(f'states_{n_iter}.pth', model, optimizer,
                n_iter, epoch, config)


def main():
    args = parser.parse_args()
    config = get_config(args.config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set random seed (`is not False` so that a seed of 0 is not skipped)
    if config.random_seed is not False:
        torch.manual_seed(config.random_seed)
        torch.cuda.manual_seed_all(config.random_seed)
        np.random.seed(config.random_seed)

    # make checkpoint folder if nonexistent
    if not os.path.isdir(config.checkpoint_dir):
        os.makedirs(os.path.abspath(config.checkpoint_dir))
        print(f"Created checkpoint_dir folder: {config.checkpoint_dir}")

    # datasets
    if config.cache_dataset:
        print('Caching datasets ...')
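    # ArabDataset is assumed to yield (token_ids, mel) training pairs built
    # from a labels file and a wav folder; with cache=True it presumably
    # precomputes the mel spectrograms once up front (inferred from the flag
    # name and the message above, not from the dataset's source).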
    train_dataset = ArabDataset(config.train_labels,
                                config.train_wavs_path,
                                cache=config.cache_dataset)
    test_dataset = ArabDataset(config.test_labels,
                               config.test_wavs_path,
                               cache=config.cache_dataset)

    # optional: balanced sampling
    sampler, shuffle, drop_last = None, True, True
    if config.balanced_sampling:
        weights = torch.load(config.sampler_weights_file)
        sampler = torch.utils.data.WeightedRandomSampler(
            weights, len(weights), replacement=False)
        shuffle, drop_last = False, False

    # dataloaders
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              collate_fn=text_mel_collate_fn,
                              shuffle=shuffle, drop_last=drop_last,
                              sampler=sampler)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             drop_last=False, shuffle=False,
                             collate_fn=text_mel_collate_fn)

    # construct model
    model = Tacotron2MS(n_symbol=40)
    model = model.to(device)
    model.decoder.decoder_max_step = config.decoder_max_step

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config.learning_rate,
                                  weight_decay=config.weight_decay)

    # resume from existing checkpoint
    n_epoch, n_iter = 0, 0
    if config.restore_model != '':
        state_dicts = torch.load(config.restore_model)
        model.load_state_dict(state_dicts['model'])
        if 'optim' in state_dicts:
            optimizer.load_state_dict(state_dicts['optim'])
        if 'epoch' in state_dicts:
            n_epoch = state_dicts['epoch']
        if 'iter' in state_dicts:
            n_iter = state_dicts['iter']

    # tensorboard writer
    writer = TBLogger(config.log_dir)

    # start training
    training_loop(model, optimizer,
                  train_loader, test_loader,
                  writer, device, config,
                  n_epoch, n_iter)


if __name__ == '__main__':
    main()
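
# Usage sketch (assuming this file is saved as train.py; the actual file
# name is not given in the source):
#
#   python train.py --config configs/nawar.yaml
#
# The YAML config must define every key read via `config.<name>` above.
# The field names below are exactly those this script accesses; the values
# are illustrative placeholders only, not the project's defaults:
#
#   epochs: 500
#   batch_size: 32
#   learning_rate: 1.0e-3
#   weight_decay: 1.0e-6
#   grad_clip_thresh: 1.0
#   decoder_max_step: 2000
#   random_seed: 42
#   cache_dataset: true
#   balanced_sampling: false
#   sampler_weights_file: ''
#   train_labels: data/train.txt
#   train_wavs_path: data/wavs
#   test_labels: data/test.txt
#   test_wavs_path: data/wavs
#   checkpoint_dir: checkpoints
#   log_dir: logs
#   restore_model: ''
#   n_save_states_iter: 1000
#   n_save_backup_iter: 10000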