|
|
|
|
|
|
|
|
|
|
|
import torch |
|
from torch.optim.lr_scheduler import ExponentialLR |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
import shutil |
|
import accelerate |
|
|
|
|
|
from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset |
|
from models.svc.vits.vits import * |
|
from models.svc.base import SVCTrainer |
|
|
|
from utils.mel import mel_spectrogram_torch |
|
import json |
|
|
|
from models.vocoders.gan.discriminator.mpd import ( |
|
MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator, |
|
) |
|
|
|
|
|
class VitsSVCTrainer(SVCTrainer):
    r"""Trainer for the VITS-based singing voice conversion model.

    ``self.model`` is a ``{"generator", "discriminator"}`` dict trained
    adversarially, with one AdamW optimizer and one ``ExponentialLR``
    scheduler per sub-network. Checkpoints are saved with
    ``accelerator.save_state`` into directories named
    ``epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX``; the checkpoint written when
    training ends carries an extra ``final_`` prefix.
    """

    def __init__(self, args, cfg):
        # Set before delegating: the base constructor presumably invokes the
        # ``_build_*`` hooks below, which read ``self.cfg``.
        self.args = args
        self.cfg = cfg
        SVCTrainer.__init__(self, args, cfg)

    # ------------------------------------------------------------------ #
    # Setup
    # ------------------------------------------------------------------ #

    def _accelerator_prepare(self):
        """Wrap dataloaders, models, optimizers and schedulers with ``accelerate``.

        Dict-valued attributes are prepared entry by entry in insertion order,
        so the generator is registered with the accelerator before the
        discriminator (``_load_model`` relies on that order for finetuning).
        """
        self.train_dataloader, self.valid_dataloader = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
        )
        self.model = self._prepare_each(self.model)
        self.optimizer = self._prepare_each(self.optimizer)
        self.scheduler = self._prepare_each(self.scheduler)

    def _prepare_each(self, obj):
        """Run ``accelerator.prepare`` on ``obj``, or on every value if it is a dict.

        A dict is updated in place and returned, so existing references to it
        stay valid.
        """
        if isinstance(obj, dict):
            for key in obj:
                obj[key] = self.accelerator.prepare(obj[key])
            return obj
        return self.accelerator.prepare(obj)

    @staticmethod
    def _parse_checkpoint_name(checkpoint_path):
        """Return ``(epoch, step)`` encoded in a checkpoint directory name.

        Names look like ``[final_]epoch-0010_step-0001234_loss-0.123456``. Only
        the final path component is parsed, so underscores in parent
        directories do not matter.

        Raises:
            IndexError / ValueError: the name does not follow the pattern.
        """
        fields = Path(checkpoint_path).name.split("_")
        epoch = int(fields[-3].split("-")[-1])
        step = int(fields[-2].split("-")[-1])
        return epoch, step

    def _latest_checkpoint_in(self, checkpoint_dir):
        """Return the entry of ``checkpoint_dir`` with the highest ``(epoch, step)``.

        Entries whose names do not follow the checkpoint naming pattern are
        skipped.

        Raises:
            FileNotFoundError: no checkpoint-named entry exists.
        """
        candidates = []
        for entry in Path(checkpoint_dir).glob("*"):
            try:
                candidates.append((self._parse_checkpoint_name(entry), str(entry)))
            except (IndexError, ValueError):
                continue
        if not candidates:
            raise FileNotFoundError(
                "No checkpoint found in {}".format(checkpoint_dir)
            )
        return max(candidates)[1]

    @staticmethod
    def _model_weights_file(checkpoint_path, index):
        """Locate the weights file ``accelerator.save_state`` wrote for model ``index``.

        accelerate names the first model ``pytorch_model.bin`` (or
        ``model.safetensors``) and suffixes later ones with ``_{index}``. The
        legacy ``.bin`` name is returned when neither file exists, so the
        loader's error points at a concrete path.
        """
        suffix = "" if index == 0 else "_{}".format(index)
        candidates = [
            os.path.join(checkpoint_path, "pytorch_model{}.bin".format(suffix)),
            os.path.join(checkpoint_path, "model{}.safetensors".format(suffix)),
        ]
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return candidates[0]

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        resume_type: str = "",
    ):
        r"""Restore training state or model weights from a checkpoint.

        **Only use this method after** ``accelerator.prepare()``.

        Args:
            checkpoint_dir: directory searched for the latest checkpoint
                (highest epoch, then step) when ``checkpoint_path`` is None.
            checkpoint_path: explicit checkpoint directory to load.
            resume_type: ``"resume"`` (or ``""``) restores the full accelerate
                state and continues from the saved epoch/step + 1;
                ``"finetune"`` loads only generator and discriminator weights.

        Returns:
            The checkpoint path that was loaded.

        Raises:
            FileNotFoundError: ``checkpoint_dir`` holds no checkpoint.
            ValueError: ``resume_type`` is not recognised.
        """
        if checkpoint_path is None:
            checkpoint_path = self._latest_checkpoint_in(checkpoint_dir)
        self.logger.info("Resume from {}...".format(checkpoint_path))

        if resume_type in ["resume", ""]:
            # Restores every object registered through accelerator.prepare.
            self.accelerator.load_state(input_dir=checkpoint_path)
            epoch, step = self._parse_checkpoint_name(checkpoint_path)
            # Continue with the epoch/step after the ones that were saved.
            self.epoch = epoch + 1
            self.step = step + 1
        elif resume_type == "finetune":
            # Weights only; optimizer/scheduler state starts fresh. Index
            # follows the registration order in ``_accelerator_prepare``:
            # generator (0), then discriminator (1).
            for index, key in enumerate(("generator", "discriminator")):
                accelerate.load_checkpoint_and_dispatch(
                    self.accelerator.unwrap_model(self.model[key]),
                    self._model_weights_file(checkpoint_path, index),
                )
            self.logger.info("Load model weights for finetune...")
        else:
            raise ValueError("Resume_type must be `resume` or `finetune`.")

        return checkpoint_path

    def _build_model(self):
        """Instantiate the VITS generator and the multi-period discriminator."""
        preprocess = self.cfg.preprocess
        net_g = SynthesizerTrn(
            # n_fft // 2 + 1 = number of one-sided STFT bins.
            preprocess.n_fft // 2 + 1,
            # Training segment length in frames (also used to slice targets).
            preprocess.segment_size // preprocess.hop_size,
            self.cfg,
        )
        net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
        return {"generator": net_g, "discriminator": net_d}

    def _build_dataset(self):
        """Return the (dataset class, collator class) pair used for SVC data."""
        return SVCOfflineDataset, SVCOfflineCollator

    def _build_optimizer(self):
        """Create one AdamW optimizer per sub-network with shared hyper-parameters."""
        train_cfg = self.cfg.train

        def make_adamw(module):
            return torch.optim.AdamW(
                module.parameters(),
                train_cfg.learning_rate,
                betas=train_cfg.AdamW.betas,
                eps=train_cfg.AdamW.eps,
            )

        return {
            "optimizer_g": make_adamw(self.model["generator"]),
            "optimizer_d": make_adamw(self.model["discriminator"]),
        }

    def _build_scheduler(self):
        """Create one per-epoch ``ExponentialLR`` (gamma = ``cfg.train.lr_decay``) per optimizer."""
        return {
            "scheduler_g": ExponentialLR(
                self.optimizer["optimizer_g"],
                gamma=self.cfg.train.lr_decay,
                last_epoch=self.epoch - 1,
            ),
            "scheduler_d": ExponentialLR(
                self.optimizer["optimizer_d"],
                gamma=self.cfg.train.lr_decay,
                last_epoch=self.epoch - 1,
            ),
        }

    def _build_criterion(self):
        """Build the generator and discriminator objectives (least-squares GAN)."""

        class GeneratorLoss(nn.Module):
            """Generator objective: weighted mel L1 + weighted KL + feature matching + adversarial."""

            def __init__(self, cfg):
                super().__init__()
                self.cfg = cfg
                self.l1_loss = nn.L1Loss()

            def generator_loss(self, disc_outputs):
                """Sum of mean((1 - D(G(x)))^2) over sub-discriminators; also returns each term."""
                loss = 0
                gen_losses = []
                for dg in disc_outputs:
                    term = torch.mean((1 - dg.float()) ** 2)
                    gen_losses.append(term)
                    loss += term
                return loss, gen_losses

            def feature_loss(self, fmap_r, fmap_g):
                """L1 distance between real and generated feature maps, scaled by 2."""
                loss = 0
                for dr, dg in zip(fmap_r, fmap_g):
                    for rl, gl in zip(dr, dg):
                        # Real features act as fixed targets: no gradient through them.
                        loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
                return loss * 2

            def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
                """Masked KL term, summed over masked positions and divided by ``sum(z_mask)``.

                z_p, logs_q: [b, h, t_t]
                m_p, logs_p: [b, h, t_t]
                """
                z_p = z_p.float()
                logs_q = logs_q.float()
                m_p = m_p.float()
                logs_p = logs_p.float()
                z_mask = z_mask.float()

                kl = logs_p - logs_q - 0.5
                kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
                kl = torch.sum(kl * z_mask)
                return kl / torch.sum(z_mask)

            def forward(self, outputs_g, outputs_d, y_mel, y_hat_mel):
                """Return every generator loss term plus their sum under ``loss_gen_all``."""
                loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
                loss_kl = (
                    self.kl_loss(
                        outputs_g["z_p"],
                        outputs_g["logs_q"],
                        outputs_g["m_p"],
                        outputs_g["logs_p"],
                        outputs_g["z_mask"],
                    )
                    * self.cfg.train.c_kl
                )
                loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
                loss_gen, _ = self.generator_loss(outputs_d["y_d_hat_g"])
                return {
                    "loss_mel": loss_mel,
                    "loss_kl": loss_kl,
                    "loss_fm": loss_fm,
                    "loss_gen": loss_gen,
                    "loss_gen_all": loss_mel + loss_kl + loss_fm + loss_gen,
                }

        class DiscriminatorLoss(nn.Module):
            """LSGAN discriminator objective summed over sub-discriminators."""

            def __init__(self, cfg):
                super().__init__()
                self.cfg = cfg

            def forward(self, disc_real_outputs, disc_generated_outputs):
                """Return ``{"loss_disc_all": sum(mean((1 - D(y))^2) + mean(D(y_hat)^2))}``."""
                loss = 0
                for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                    # Real outputs are pushed towards 1, generated ones towards 0.
                    loss += torch.mean((1 - dr.float()) ** 2) + torch.mean(dg.float() ** 2)
                return {"loss_disc_all": loss}

        return {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }

    # ------------------------------------------------------------------ #
    # Logging and state
    # ------------------------------------------------------------------ #

    def _write_media(self, images, audios, audio_sampling_rate):
        """Log optional image/audio dicts to ``self.sw`` at ``self.step``.

        ``self.step`` is used so media lines up with the scalars logged
        alongside them. Images are logged as HWC arrays.
        """
        for key, value in (images or {}).items():
            # SummaryWriter.add_image takes ``dataformats`` (there is no
            # ``batchformats`` keyword).
            self.sw.add_image(key, value, self.step, dataformats="HWC")
        for key, value in (audios or {}).items():
            self.sw.add_audio(key, value, self.step, audio_sampling_rate)

    def write_summary(
        self,
        losses,
        stats,
        images=None,
        audios=None,
        audio_sampling_rate=24000,
        tag="train",
    ):
        """Log training losses, the generator learning rate, and optional media.

        Args:
            losses: mapping of loss name to scalar, logged as ``{tag}/{name}``.
            stats: unused; kept for interface compatibility.
            images: optional mapping of tag to HWC image.
            audios: optional mapping of tag to waveform.
            audio_sampling_rate: sample rate passed to ``add_audio``.
            tag: prefix for the loss scalars.
        """
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )
        self._write_media(images, audios, audio_sampling_rate)

    def write_valid_summary(
        self, losses, stats, images=None, audios=None, audio_sampling_rate=24000, tag="val"
    ):
        """Log validation losses and optional media; same arguments as ``write_summary``."""
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self._write_media(images, audios, audio_sampling_rate)

    def _get_state_dict(self):
        """Collect model/optimizer/scheduler state plus step, epoch and batch size."""
        return {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }

    def get_state_dict(self):
        """Public alias of ``_get_state_dict``."""
        return self._get_state_dict()

    def load_model(self, checkpoint):
        """Restore state produced by ``get_state_dict`` (step, epoch, models, optimizers, schedulers)."""
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.model["generator"].load_state_dict(checkpoint["generator"])
        self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
        self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
        self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
        self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
        self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])

    # ------------------------------------------------------------------ #
    # Shared step helpers
    # ------------------------------------------------------------------ #

    def _set_models_mode(self, training):
        """Put every model in train (``training=True``) or eval mode."""
        models = self.model.values() if isinstance(self.model, dict) else [self.model]
        for model in models:
            model.train(training)

    def _progress_bar(self, dataloader, desc):
        """Wrap ``dataloader`` in a tqdm bar shown only on the main process."""
        return tqdm(
            dataloader,
            desc=desc,
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        )

    def _slice_targets(self, batch, outputs_g):
        """Return ``(y_mel, y_hat_mel, y)`` aligned with the generator's ``ids_slice``.

        ``y_mel`` is the ground-truth mel segment, ``y_hat_mel`` the mel of the
        generated waveform, and ``y`` the ground-truth waveform segment
        (frame offsets scaled by ``hop_size`` to samples).
        """
        preprocess = self.cfg.preprocess
        ids_slice = outputs_g["ids_slice"]
        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            ids_slice,
            preprocess.segment_size // preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(outputs_g["y_hat"].squeeze(1), preprocess)
        y = slice_segments(
            batch["audio"].unsqueeze(1),
            ids_slice * preprocess.hop_size,
            preprocess.segment_size,
        )
        return y_mel, y_hat_mel, y

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #

    @torch.inference_mode()
    def _valid_step(self, batch):
        r"""Compute generator and discriminator losses on one validation batch.

        Returns:
            ``(total_loss, losses, stats)``: ``total_loss`` is the float sum of
            ``loss_gen_all`` and ``loss_disc_all``, ``losses`` maps each loss
            name to a float, and ``stats`` is an empty dict.
        """
        outputs_g = self.model["generator"](batch)
        y_mel, y_hat_mel, y = self._slice_targets(batch, outputs_g)

        # No gradients are tracked here, so detaching is pointless and, with
        # the models in eval mode, one discriminator pass serves both losses.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        valid_losses = {key: value.item() for key, value in {**loss_d, **loss_g}.items()}
        return total_loss.item(), valid_losses, {}

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Run validation over ``self.valid_dataloader``.

        Switches every model to eval mode (``_train_epoch`` switches back).

        Returns:
            ``(mean_total_loss, mean_losses)`` averaged over batches.
        """
        self._set_models_mode(False)

        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in self._progress_bar(
            self.valid_dataloader, f"Validating Epoch {self.epoch}"
        ):
            total_loss, valid_losses, _ = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    epoch_losses[key] = epoch_losses.get(key, 0.0) + value

        # Guard against an empty validation loader.
        num_batches = max(len(self.valid_dataloader), 1)
        epoch_sum_loss = epoch_sum_loss / num_batches
        epoch_losses = {key: value / num_batches for key, value in epoch_losses.items()}

        self.accelerator.wait_for_everyone()
        return epoch_sum_loss, epoch_losses

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #

    def _log_epoch_losses(self, split, losses):
        """Log each epoch-average loss of ``split`` ("Train" or "Valid")."""
        if not isinstance(losses, dict):
            return
        for key, loss in losses.items():
            self.logger.info(" |- {}/{} Loss: {:.6f}".format(split, key, loss))
            self.accelerator.log(
                {"Epoch/{} {} Loss".format(split, key): loss},
                step=self.epoch,
            )

    def _step_lr_schedulers(self):
        """Advance every learning-rate scheduler by exactly one epoch.

        ``accelerator.prepare`` wraps schedulers in ``AcceleratedScheduler``,
        whose ``step`` is designed for per-batch use and may advance the
        wrapped scheduler once per process; the underlying torch scheduler is
        stepped directly instead.
        """
        schedulers = (
            self.scheduler.values() if isinstance(self.scheduler, dict) else [self.scheduler]
        )
        for scheduler in schedulers:
            getattr(scheduler, "scheduler", scheduler).step()

    def _save_checkpoint_to(self, path):
        """Save the accelerate state, ``ckpts.json`` and auxiliary states into ``path``."""
        self.tmp_checkpoint_save_path = path
        self.accelerator.save_state(path)
        with open(os.path.join(path, "ckpts.json"), "w") as f:
            json.dump(self.checkpoints_path, f, ensure_ascii=False, indent=4)
        self._save_auxiliary_states()

    def _prune_checkpoints(self, new_path, hit_idx):
        """Register ``new_path`` under every hit stride and delete unreferenced old checkpoints.

        Stride ``i`` keeps at most ``self.keep_last[i]`` paths. A path evicted
        from one stride but still listed under another is put back at the
        front of the evicting list instead of being deleted.
        """
        evicted = []
        for idx in hit_idx:
            self.checkpoints_path[idx].append(new_path)
            while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                evicted.append((idx, self.checkpoints_path[idx].pop(0)))

        still_listed = set()
        for paths in self.checkpoints_path:
            still_listed |= set(paths)

        to_delete = set()
        for idx, old_path in reversed(evicted):
            if old_path in still_listed:
                self.checkpoints_path[idx].insert(0, old_path)
            else:
                to_delete.add(old_path)

        for old_path in to_delete:
            shutil.rmtree(old_path, ignore_errors=True)
            self.logger.debug(f"Remove old checkpoint: {old_path}")

    def train_loop(self):
        r"""Training loop. The public entry of training process."""
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)
        self.accelerator.wait_for_everyone()

        # Defined up front so the final checkpoint name is valid even when no
        # epoch runs (e.g. resuming at or beyond ``max_epoch``).
        valid_total_loss = float("nan")
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            train_total_loss, train_losses = self._train_epoch()
            self._log_epoch_losses("Train", train_losses)

            valid_total_loss, valid_losses = self._valid_epoch()
            self._log_epoch_losses("Valid", valid_losses)

            self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
            self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            # Epoch-level LR decay; done before saving so a resumed run
            # (which starts at the next epoch) sees the decayed rate.
            self._step_lr_schedulers()

            self.accelerator.wait_for_everyone()

            # Only the main process decides whether this epoch hits a stride.
            save_checkpoint = False
            run_eval = False
            hit_idx = []
            if self.accelerator.is_main_process:
                for i, stride in enumerate(self.save_checkpoint_stride):
                    if self.epoch % stride == 0:
                        save_checkpoint = True
                        hit_idx.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                self._save_checkpoint_to(path)
                self._prune_checkpoints(path, hit_idx)

            self.accelerator.wait_for_everyone()
            if run_eval:
                # Placeholder: no extra evaluation runs at checkpoint time.
                pass

            self.epoch += 1

        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, valid_total_loss
                ),
            )
            self._save_checkpoint_to(path)

        self.accelerator.end_training()

    def _train_step(self, batch):
        r"""Run one adversarial update: discriminator first, then generator.

        Returns:
            ``(total_loss, losses, stats)`` with the same layout as
            ``_valid_step``.
        """
        outputs_g = self.model["generator"](batch)
        y_mel, y_hat_mel, y = self._slice_targets(batch, outputs_g)

        # Discriminator update; detached so no gradient reaches the generator.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        # Generator update: re-run the freshly updated discriminator without
        # detaching so adversarial/feature-matching gradients reach the generator.
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["loss_gen_all"])
        self.optimizer["optimizer_g"].step()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
        train_losses = {key: value.item() for key, value in {**loss_d, **loss_g}.items()}
        return total_loss.item(), train_losses, {}

    def _train_epoch(self):
        r"""Train for one epoch over ``self.train_dataloader``.

        Losses are accumulated, logged and ``self.step`` is advanced only on
        batches that complete a gradient-accumulation window.

        Returns:
            ``(mean_total_loss, mean_losses)`` averaged over those steps.
        """
        # ``_valid_epoch`` leaves the models in eval mode; restore train mode.
        self._set_models_mode(True)

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in self._progress_bar(
            self.train_dataloader, f"Training Epoch {self.epoch}"
        ):
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, _ = self._train_step(batch)
            self.batch_count += 1

            if self.batch_count % self.cfg.train.gradient_accumulation_step != 0:
                continue

            epoch_sum_loss += total_loss
            for key, value in train_losses.items():
                epoch_losses[key] = epoch_losses.get(key, 0.0) + value

            self.accelerator.log(
                {
                    "Step/Generator Loss": train_losses["loss_gen_all"],
                    "Step/Discriminator Loss": train_losses["loss_disc_all"],
                    "Step/Generator Learning Rate": self.optimizer[
                        "optimizer_g"
                    ].param_groups[0]["lr"],
                    "Step/Discriminator Learning Rate": self.optimizer[
                        "optimizer_d"
                    ].param_groups[0]["lr"],
                },
                step=self.step,
            )
            self.step += 1
            epoch_step += 1

        self.accelerator.wait_for_everyone()

        # Average over the steps actually accumulated (guards an empty loader).
        num_steps = max(epoch_step, 1)
        epoch_sum_loss = epoch_sum_loss / num_steps
        epoch_losses = {key: value / num_steps for key, value in epoch_losses.items()}
        return epoch_sum_loss, epoch_losses

    def __dump_cfg(self, path):
        """Write ``self.cfg`` to ``path`` as JSON5, creating the parent directory."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, "w") as f:
            json5.dump(
                self.cfg,
                f,
                indent=4,
                sort_keys=True,
                ensure_ascii=False,
                quote_keys=True,
            )
|
|