|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import os |
|
import json5 |
|
from tqdm import tqdm |
|
import json |
|
import shutil |
|
|
|
from models.svc.base import SVCTrainer |
|
from modules.encoder.condition_encoder import ConditionEncoder |
|
from models.svc.comosvc.comosvc import ComoSVC |
|
|
|
|
|
class ComoSVCTrainer(SVCTrainer): |
|
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and |
|
implements ``_build_model`` and ``_forward_step`` methods. |
|
""" |
|
|
|
def __init__(self, args=None, cfg=None): |
|
SVCTrainer.__init__(self, args, cfg) |
|
self.distill = cfg.model.comosvc.distill |
|
self.skip_diff = True |
|
|
|
|
|
|
|
def _load_teacher_model(self, model): |
|
r"""Load teacher model from checkpoint file.""" |
|
self.checkpoint_file = self.teacher_model_path |
|
self.logger.info( |
|
"Load teacher acoustic model from {}".format(self.checkpoint_file) |
|
) |
|
raw_dict = torch.load(self.checkpoint_file) |
|
model.load_state_dict(raw_dict) |
|
|
|
def _build_model(self): |
|
r"""Build the model for training. This function is called in ``__init__`` function.""" |
|
|
|
|
|
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min |
|
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max |
|
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) |
|
self.acoustic_mapper = ComoSVC(self.cfg) |
|
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) |
|
if self.cfg.model.comosvc.distill: |
|
if not self.args.resume: |
|
|
|
self.teacher_model_path = self.cfg.model.teacher_model_path |
|
self._load_teacher_model(model) |
|
|
|
self.acoustic_mapper.decoder.init_consistency_training() |
|
self.freeze_net(self.condition_encoder) |
|
self.freeze_net(self.acoustic_mapper.encoder) |
|
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_pretrained) |
|
self.freeze_net(self.acoustic_mapper.decoder.denoise_fn_ema) |
|
return model |
|
|
|
def freeze_net(self, model): |
|
r"""Freeze the model for training.""" |
|
for name, param in model.named_parameters(): |
|
param.requires_grad = False |
|
|
|
def __build_optimizer(self): |
|
r"""Build optimizer for training. This function is called in ``__init__`` function.""" |
|
|
|
if self.cfg.train.optimizer.lower() == "adamw": |
|
optimizer = torch.optim.AdamW( |
|
params=filter(lambda p: p.requires_grad, self.model.parameters()), |
|
**self.cfg.train.adamw, |
|
) |
|
|
|
else: |
|
raise NotImplementedError( |
|
"Not support optimizer: {}".format(self.cfg.train.optimizer) |
|
) |
|
|
|
return optimizer |
|
|
|
def _forward_step(self, batch): |
|
r"""Forward step for training and inference. This function is called |
|
in ``_train_step`` & ``_test_step`` function. |
|
""" |
|
loss = {} |
|
mask = batch["mask"] |
|
mel_input = batch["mel"] |
|
cond = self.condition_encoder(batch) |
|
if self.distill: |
|
cond = cond.detach() |
|
self.skip_diff = True if self.step < self.cfg.train.fast_steps else False |
|
ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss( |
|
mask, cond, mel_input, skip_diff=self.skip_diff |
|
) |
|
if self.distill: |
|
loss["distil_loss"] = diff_loss |
|
else: |
|
loss["ssim_loss_encoder"] = ssim_loss |
|
loss["prior_loss_encoder"] = prior_loss |
|
loss["diffusion_loss_decoder"] = diff_loss |
|
|
|
return loss |
|
|
|
def _train_epoch(self): |
|
r"""Training epoch. Should return average loss of a batch (sample) over |
|
one epoch. See ``train_loop`` for usage. |
|
""" |
|
self.model.train() |
|
epoch_sum_loss: float = 0.0 |
|
epoch_step: int = 0 |
|
for batch in tqdm( |
|
self.train_dataloader, |
|
desc=f"Training Epoch {self.epoch}", |
|
unit="batch", |
|
colour="GREEN", |
|
leave=False, |
|
dynamic_ncols=True, |
|
smoothing=0.04, |
|
disable=not self.accelerator.is_main_process, |
|
): |
|
|
|
with self.accelerator.accumulate(self.model): |
|
loss = self._train_step(batch) |
|
total_loss = 0 |
|
for k, v in loss.items(): |
|
total_loss += v |
|
self.accelerator.backward(total_loss) |
|
enc_grad_norm = torch.nn.utils.clip_grad_norm_( |
|
self.acoustic_mapper.encoder.parameters(), max_norm=1 |
|
) |
|
dec_grad_norm = torch.nn.utils.clip_grad_norm_( |
|
self.acoustic_mapper.decoder.parameters(), max_norm=1 |
|
) |
|
self.optimizer.step() |
|
self.optimizer.zero_grad() |
|
self.batch_count += 1 |
|
|
|
|
|
|
|
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: |
|
epoch_sum_loss += total_loss |
|
log_info = {} |
|
for k, v in loss.items(): |
|
key = "Step/Train Loss/{}".format(k) |
|
log_info[key] = v |
|
log_info["Step/Learning Rate"] = self.optimizer.param_groups[0]["lr"] |
|
self.accelerator.log( |
|
log_info, |
|
step=self.step, |
|
) |
|
self.step += 1 |
|
epoch_step += 1 |
|
|
|
self.accelerator.wait_for_everyone() |
|
return ( |
|
epoch_sum_loss |
|
/ len(self.train_dataloader) |
|
* self.cfg.train.gradient_accumulation_step, |
|
loss, |
|
) |
|
|
|
def train_loop(self): |
|
r"""Training loop. The public entry of training process.""" |
|
|
|
self.accelerator.wait_for_everyone() |
|
|
|
if self.accelerator.is_main_process: |
|
self.__dump_cfg(self.config_save_path) |
|
self.model.train() |
|
self.optimizer.zero_grad() |
|
|
|
self.accelerator.wait_for_everyone() |
|
while self.epoch < self.max_epoch: |
|
self.logger.info("\n") |
|
self.logger.info("-" * 32) |
|
self.logger.info("Epoch {}: ".format(self.epoch)) |
|
|
|
|
|
|
|
|
|
train_loss, loss = self._train_epoch() |
|
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss)) |
|
for k, v in loss.items(): |
|
self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v)) |
|
valid_loss = self._valid_epoch() |
|
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss)) |
|
self.accelerator.log( |
|
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss}, |
|
step=self.epoch, |
|
) |
|
|
|
self.accelerator.wait_for_everyone() |
|
|
|
self.scheduler.step(valid_loss) |
|
|
|
|
|
run_eval = False |
|
if self.accelerator.is_main_process: |
|
save_checkpoint = False |
|
hit_dix = [] |
|
for i, num in enumerate(self.save_checkpoint_stride): |
|
if self.epoch % num == 0: |
|
save_checkpoint = True |
|
hit_dix.append(i) |
|
run_eval |= self.run_eval[i] |
|
|
|
self.accelerator.wait_for_everyone() |
|
if ( |
|
self.accelerator.is_main_process |
|
and save_checkpoint |
|
and (self.distill or not self.skip_diff) |
|
): |
|
path = os.path.join( |
|
self.checkpoint_dir, |
|
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( |
|
self.epoch, self.step, train_loss |
|
), |
|
) |
|
self.tmp_checkpoint_save_path = path |
|
self.accelerator.save_state(path) |
|
print(f"save checkpoint in {path}") |
|
json.dump( |
|
self.checkpoints_path, |
|
open(os.path.join(path, "ckpts.json"), "w"), |
|
ensure_ascii=False, |
|
indent=4, |
|
) |
|
self._save_auxiliary_states() |
|
|
|
|
|
to_remove = [] |
|
for idx in hit_dix: |
|
self.checkpoints_path[idx].append(path) |
|
while len(self.checkpoints_path[idx]) > self.keep_last[idx]: |
|
to_remove.append((idx, self.checkpoints_path[idx].pop(0))) |
|
|
|
|
|
total = set() |
|
for i in self.checkpoints_path: |
|
total |= set(i) |
|
do_remove = set() |
|
for idx, path in to_remove[::-1]: |
|
if path in total: |
|
self.checkpoints_path[idx].insert(0, path) |
|
else: |
|
do_remove.add(path) |
|
|
|
|
|
for path in do_remove: |
|
shutil.rmtree(path, ignore_errors=True) |
|
self.logger.debug(f"Remove old checkpoint: {path}") |
|
|
|
self.accelerator.wait_for_everyone() |
|
if run_eval: |
|
|
|
pass |
|
|
|
|
|
self.epoch += 1 |
|
|
|
|
|
self.accelerator.wait_for_everyone() |
|
if self.accelerator.is_main_process: |
|
self.accelerator.save_state( |
|
os.path.join( |
|
self.checkpoint_dir, |
|
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( |
|
self.epoch, self.step, valid_loss |
|
), |
|
) |
|
) |
|
self._save_auxiliary_states() |
|
self.accelerator.end_training() |
|
|
|
@torch.inference_mode() |
|
def _valid_epoch(self): |
|
r"""Testing epoch. Should return average loss of a batch (sample) over |
|
one epoch. See ``train_loop`` for usage. |
|
""" |
|
self.model.eval() |
|
epoch_sum_loss = 0.0 |
|
for batch in tqdm( |
|
self.valid_dataloader, |
|
desc=f"Validating Epoch {self.epoch}", |
|
unit="batch", |
|
colour="GREEN", |
|
leave=False, |
|
dynamic_ncols=True, |
|
smoothing=0.04, |
|
disable=not self.accelerator.is_main_process, |
|
): |
|
batch_loss = self._valid_step(batch) |
|
for k, v in batch_loss.items(): |
|
epoch_sum_loss += v |
|
|
|
self.accelerator.wait_for_everyone() |
|
return epoch_sum_loss / len(self.valid_dataloader) |
|
|
|
@staticmethod |
|
def __count_parameters(model): |
|
model_param = 0.0 |
|
if isinstance(model, dict): |
|
for key, value in model.items(): |
|
model_param += sum(p.numel() for p in model[key].parameters()) |
|
else: |
|
model_param = sum(p.numel() for p in model.parameters()) |
|
return model_param |
|
|
|
def __dump_cfg(self, path): |
|
os.makedirs(os.path.dirname(path), exist_ok=True) |
|
json5.dump( |
|
self.cfg, |
|
open(path, "w"), |
|
indent=4, |
|
sort_keys=True, |
|
ensure_ascii=False, |
|
quote_keys=True, |
|
) |
|
|