Spaces:
Runtime error
Runtime error
File size: 5,550 Bytes
c021d8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import time
import pickle
import datetime
import itertools
import numpy as np
import torch
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask
from model_autopst import Generator_1 as Predictor
class Solver(object):
    """Training driver for the predictor model imported as ``Predictor``.

    Pulls batches from ``data_loader`` and runs a teacher-forced forward pass.
    It optimizes a masked MSE spectrogram loss plus a BCE stop-token loss with
    Adam. It prints losses every ``config.log_step`` iterations and writes a
    model/optimizer checkpoint to ``./assets`` every 10000 iterations.
    """

    def __init__(self, data_loader, config, hparams):
        """Initialize configurations.

        Args:
            data_loader: iterable of training batches; each batch must unpack
                into the 9 items listed in ``train``.
            config: object providing ``device_id``, ``num_iters`` and
                ``log_step``.
            hparams: model hyper-parameters; passed to ``Predictor`` and must
                provide ``gate_threshold``.
        """
        self.data_loader = data_loader
        self.hparams = hparams
        # Stored on the instance; not read by ``train`` itself.
        self.gate_threshold = hparams.gate_threshold
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device(
            'cuda:{}'.format(config.device_id) if self.use_cuda else 'cpu')
        self.num_iters = config.num_iters
        self.log_step = config.log_step

        # Build the model
        self.build_model()

    def build_model(self):
        """Create the predictor on ``self.device``, its Adam optimizer and the BCE loss."""
        self.P = Predictor(self.hparams)
        # Move the model first, then build the optimizer over its parameters:
        # this is the order recommended by the torch.optim documentation.
        self.P.to(self.device)
        self.optimizer = torch.optim.Adam(self.P.parameters(), lr=0.0001,
                                          betas=[0.9, 0.999])
        self.BCELoss = torch.nn.BCEWithLogitsLoss().to(self.device)

    def _next_batch(self, data_iter):
        """Return ``(batch, data_iter)``, restarting ``data_loader`` when exhausted.

        Only ``StopIteration`` triggers a restart. Any other error raised while
        fetching a batch propagates instead of being silently swallowed.

        Raises:
            RuntimeError: if a freshly created iterator yields no batch at all.
        """
        try:
            return next(data_iter), data_iter
        except StopIteration:
            pass
        # End of the current pass over the data: start a new one.
        data_iter = iter(self.data_loader)
        try:
            return next(data_iter), data_iter
        except StopIteration:
            raise RuntimeError('data_loader yielded no batches') from None

    @staticmethod
    def _format_elapsed(seconds):
        """Format ``seconds`` as ``H:MM:SS``, dropping the sub-second part.

        Truncating to whole seconds before building the timedelta (instead of
        slicing its string form) also works when there is no fractional part.
        """
        return str(datetime.timedelta(seconds=int(seconds)))

    def train(self):
        """Run ``self.num_iters`` optimization steps of the predictor.

        Each batch is unpacked as ``(sp_real, cep_real, cd_real, _,
        num_rep_sync, len_real, _, len_short_sync, spk_emb)``. ``cd_real`` and
        the two ``_`` slots are not used by the loss.
        """
        data_iter = iter(self.data_loader)

        # Print logs in specified order
        keys = ['P/loss_tx2sp', 'P/loss_stop_sp']

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(self.num_iters):
            batch, data_iter = self._next_batch(data_iter)
            (sp_real, cep_real, _cd_real, _, num_rep_sync, len_real, _,
             len_short_sync, spk_emb) = batch

            sp_real = sp_real.to(self.device)
            cep_real = cep_real.to(self.device)
            len_real = len_real.to(self.device)
            spk_emb = spk_emb.to(self.device)
            num_rep_sync = num_rep_sync.to(self.device)
            len_short_sync = len_short_sync.to(self.device)

            # Real spect masks. sequence_mask (onmt_modules.misc) presumably is
            # True at positions < length -- confirm; mask_sp_real is its negation.
            mask_sp_real = ~sequence_mask(len_real, sp_real.size(1))
            mask_long = (~mask_sp_real).float()
            # The spectrogram loss mask extends 10 frames past each length,
            # capped at the padded length sp_real.size(1).
            len_real_mask = torch.min(len_real + 10,
                                      torch.full_like(len_real, sp_real.size(1)))
            loss_tx2sp_mask = sequence_mask(
                len_real_mask, sp_real.size(1)).float().unsqueeze(-1)

            # Text input masks
            codes_mask = sequence_mask(len_short_sync, num_rep_sync.size(1)).float()

            # ---------------------------- Train ---------------------------- #
            self.P.train()

            # Teacher forcing: decoder input is sp_real shifted right by one
            # step along dim 1, with zeros in the first step.
            sp_real_sft = torch.zeros_like(sp_real)
            sp_real_sft[:, 1:, :] = sp_real[:, :-1, :]

            spect_pred, stop_pred_sp = self.P(cep_real.transpose(2, 1),
                                              mask_long,
                                              codes_mask,
                                              num_rep_sync,
                                              len_short_sync + 1,
                                              sp_real_sft.transpose(1, 0),
                                              len_real + 1,
                                              spk_emb)

            # spect_pred looks time-major (dims 0/1 swapped relative to
            # sp_real) -- confirm against Predictor.
            loss_tx2sp = (F.mse_loss(spect_pred.permute(1, 0, 2), sp_real,
                                     reduction='none')
                          * loss_tx2sp_mask).sum() / loss_tx2sp_mask.sum()
            # The stop-token target is mask_sp_real (1.0 where it is True).
            loss_stop_sp = self.BCELoss(stop_pred_sp.squeeze(-1).t(),
                                        mask_sp_real.float())
            loss_total = loss_tx2sp + loss_stop_sp

            # Backward and optimize
            self.optimizer.zero_grad()
            loss_total.backward()
            self.optimizer.step()

            # ------------------------ Miscellaneous ------------------------ #
            # Print out training information. .item() is only called here,
            # avoiding a device sync on steps that do not log.
            if (i + 1) % self.log_step == 0:
                loss = {'P/loss_tx2sp': loss_tx2sp.item(),
                        'P/loss_stop_sp': loss_stop_sp.item()}
                et = self._format_elapsed(time.time() - start_time)
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.8f}".format(tag, loss[tag])
                print(log)

            # Save model checkpoints.
            if (i + 1) % 10000 == 0:
                # torch.save does not create missing parent directories.
                os.makedirs('./assets', exist_ok=True)
                torch.save({'model': self.P.state_dict(),
                            'optimizer': self.optimizer.state_dict()},
                           f'./assets/{i+1}-A.ckpt')
                print('Saved model checkpoints into assets ...')