# wasm-ar-tts / wasq / utils / training.py
# Author: wasmdashai; commit 7694c84 ("first commit"); 2.79 kB
import random
import torch
import torch.nn.functional as F
def save_states(fname,
                model,
                optimizer,
                n_iter, epoch,
                net_config, config):
    """Serialize a training checkpoint with ``torch.save``.

    The checkpoint is a dict with keys ``'model'`` and ``'optim'`` (the
    ``state_dict()`` of ``model`` and ``optimizer``), ``'epoch'``, ``'iter'``
    (``n_iter``) and ``'config'`` (``net_config``, stored as-is).

    Args:
        fname: file name of the checkpoint inside ``config.checkpoint_dir``.
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        n_iter: iteration counter stored under ``'iter'``.
        epoch: epoch counter stored under ``'epoch'``.
        net_config: stored verbatim under ``'config'`` (presumably the
            network hyper-parameters; confirm with callers).
        config: object with a ``checkpoint_dir`` attribute.
    """
    checkpoint = dict(
        model=model.state_dict(),
        optim=optimizer.state_dict(),
        epoch=epoch,
        iter=n_iter,
        config=net_config,
    )
    # Plain '/'-joined path, exactly as the training scripts expect.
    destination = f'{config.checkpoint_dir}/{fname}'
    torch.save(checkpoint, destination)
def save_states_gan(fname,
                    model, model_d,
                    optimizer, optimizer_d,
                    n_iter, epoch,
                    net_config, config):
    """Serialize a generator/discriminator checkpoint with ``torch.save``.

    The checkpoint is a dict with the state dicts of ``model``, ``model_d``,
    ``optimizer`` and ``optimizer_d`` (keys ``'model'``, ``'model_d'``,
    ``'optim'``, ``'optim_d'``), plus ``'epoch'``, ``'iter'`` (``n_iter``)
    and ``'config'`` (``net_config``, stored as-is).

    Args:
        fname: file name of the checkpoint inside ``config.checkpoint_dir``.
        model, model_d: objects exposing ``state_dict()``.
        optimizer, optimizer_d: objects exposing ``state_dict()``.
        n_iter: iteration counter stored under ``'iter'``.
        epoch: epoch counter stored under ``'epoch'``.
        net_config: stored verbatim under ``'config'``.
        config: object with a ``checkpoint_dir`` attribute.
    """
    entries = (
        ('model', model.state_dict()),
        ('model_d', model_d.state_dict()),
        ('optim', optimizer.state_dict()),
        ('optim_d', optimizer_d.state_dict()),
        ('epoch', epoch),
        ('iter', n_iter),
        ('config', net_config),
    )
    target_path = f'{config.checkpoint_dir}/{fname}'
    torch.save(dict(entries), target_path)
def batch_to_device(batch, device):
    """Move the five tensors of a collated batch to ``device``.

    ``batch`` must unpack into exactly five tensors, in the order
    ``(text_padded, input_lengths, mel_padded, gate_padded, output_lengths)``;
    otherwise unpacking raises ``ValueError``. Each tensor is transferred
    with ``non_blocking=True``. The result is a tuple in the same order.
    """
    text, in_lens, mels, gates, out_lens = batch
    return tuple(
        tensor.to(device, non_blocking=True)
        for tensor in (text, in_lens, mels, gates, out_lens)
    )
@torch.inference_mode()
def validate(model, test_loader, writer, device, n_iter):
    """Evaluate ``model`` on ``test_loader`` and log the results to ``writer``.

    Every batch is moved to ``device`` and run through ``model``. The loss
    ``mse(mel_out) + mse(mel_out_postnet) + bce_with_logits(gate_pred)`` is
    averaged across batches, weighted by batch size. Afterwards one randomly
    chosen sample from the LAST batch is logged via ``writer.add_sample``,
    and the mean loss via ``writer.add_scalar('loss/val_loss', ...)``.

    Args:
        model: module whose forward returns
            ``(mel_out, mel_out_postnet, gate_pred, alignments)`` and which
            provides ``infer(text, input_lengths)``.
        test_loader: iterable of 5-tuples accepted by ``batch_to_device``.
        writer: logger exposing ``add_sample`` and ``add_scalar``.
        device: target device for the batch tensors.
        n_iter: global step used as the logging index.

    Returns:
        float: the batch-size-weighted mean validation loss.

    Raises:
        ValueError: if ``test_loader`` yields no samples at all.

    The model is switched to eval mode for the call. Its previous
    train/eval mode is restored afterwards, including when an error occurs.
    """
    # Remember the caller's mode. The default of True matches the old
    # behaviour of always calling model.train() at the end.
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        loss_sum = 0
        n_test_sum = 0
        for batch in test_loader:
            text_padded, input_lengths, mel_padded, gate_padded, \
                output_lengths = batch_to_device(batch, device)

            mel_out, mel_out_postnet, gate_pred, alignments = model(
                text_padded, input_lengths, mel_padded, output_lengths)

            mel_loss = F.mse_loss(mel_out, mel_padded) + \
                F.mse_loss(mel_out_postnet, mel_padded)
            gate_loss = F.binary_cross_entropy_with_logits(gate_pred,
                                                           gate_padded)
            loss = mel_loss + gate_loss

            # Weight by batch size so a short final batch does not skew
            # the mean.
            batch_size = mel_padded.size(0)
            loss_sum += batch_size * loss.item()
            n_test_sum += batch_size

        # Without this guard, an empty loader fails later with an obscure
        # ZeroDivisionError or an unbound local variable.
        if n_test_sum == 0:
            raise ValueError('test_loader yielded no validation samples')

        val_loss = loss_sum / n_test_sum

        # Log one random example from the last batch processed.
        idx = random.randint(0, mel_padded.size(0) - 1)
        mel_infer, *_ = model.infer(
            text_padded[idx:idx+1], input_lengths[idx:idx+1])

        writer.add_sample(
            alignments[idx, :, :input_lengths[idx].item()],
            mel_out[idx], mel_padded[idx], mel_infer[0],
            output_lengths[idx], n_iter)

        writer.add_scalar('loss/val_loss', val_loss, n_iter)
    finally:
        # If the model was in eval mode on entry, it is still in eval mode
        # here, so only train mode needs restoring.
        if was_training:
            model.train()

    return val_loss