import json
import os

import torch
import torch.nn as nn

from models.base.new_trainer import BaseTrainer
from models.svc.base.svc_dataset import (
    SVCOfflineCollator,
    SVCOfflineDataset,
    SVCOnlineCollator,
    SVCOnlineDataset,
)
from processors.audio_features_extractor import AudioFeaturesExtractor
from processors.acoustic_extractor import load_mel_extrema

EPS = 1.0e-12

class SVCTrainer(BaseTrainer):
    r"""The base trainer for all SVC models. It inherits from ``BaseTrainer`` and
    implements the ``_build_criterion``, ``_build_dataset``, and ``_build_singer_lut``
    methods. Subclasses should implement ``_build_model`` and ``_forward_step``.
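
    Example (an illustrative sketch; ``MyAcousticModel`` and the exact
    ``_forward_step`` contract are assumptions, not defined in this file)::

        class MyTrainer(SVCTrainer):
            def _build_model(self):
                return MyAcousticModel(self.cfg)

            def _forward_step(self, batch):
                # Compute and return the training loss for one batch
                ...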
    """

    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg

        self._init_accelerator()

        # Build the singer look-up table before BaseTrainer.__init__ (which
        # builds the model). main_process_first() lets the main process merge
        # and write the table before the other ranks read it.
        with self.accelerator.main_process_first():
            self.singers = self._build_singer_lut()

        BaseTrainer.__init__(self, args, cfg)

        self.task_type = "SVC"
        self.logger.info("Task type: {}".format(self.task_type))

    def _build_dataset(self):
        self.online_features_extraction = (
            self.cfg.preprocess.features_extraction_mode == "online"
        )

        if not self.online_features_extraction:
            return SVCOfflineDataset, SVCOfflineCollator
        else:
            # In online mode, features are extracted on the fly during
            # training, so the extractor must be constructed here.
            self.audio_features_extractor = AudioFeaturesExtractor(self.cfg)
            return SVCOnlineDataset, SVCOnlineCollator
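
    # How the (Dataset, Collator) pair returned above is typically consumed
    # (a sketch only; the real wiring lives in BaseTrainer, and the DataLoader
    # arguments below are illustrative assumptions):
    #
    #     Dataset, Collator = self._build_dataset()
    #     loader = torch.utils.data.DataLoader(
    #         Dataset(self.cfg, is_valid=False),
    #         batch_size=self.cfg.train.batch_size,
    #         collate_fn=Collator(self.cfg),
    #     )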

    def _extract_svc_features(self, batch):
        """
        Extract acoustic and content features on the fly during training.

        Batch:
            wav: (B, T)
            wav_len: (B)
            target_len: (B)
            mask: (B, n_frames, 1)
            spk_id: (B, 1)

            wav_{sr}: (B, T)
            wav_{sr}_len: (B)

        Elements added to the output batch:
            mel: (B, n_frames, n_mels)
            frame_pitch: (B, n_frames)
            frame_uv: (B, n_frames)
            frame_energy: (B, n_frames)
            frame_{content}: (B, n_frames, D)
        """
        padded_n_frames = torch.max(batch["target_len"])
        final_n_frames = padded_n_frames

        ### Mel Spectrogram ###
        if self.cfg.preprocess.use_mel:
            # (B, n_mels, n_frames)
            raw_mel = self.audio_features_extractor.get_mel_spectrogram(batch["wav"])
            if self.cfg.preprocess.use_min_max_norm_mel:
                # Cache the per-bin mel extrema as (1, n_mels, 1) tensors so
                # they broadcast over the batch and time axes.
                # NOTE: the dataset name is hardcoded to "vctk" here.
                if not hasattr(self, "mel_extrema"):
                    # (n_mels,)
                    m, M = load_mel_extrema(self.cfg.preprocess, "vctk")
                    # (1, n_mels, 1)
                    m = (
                        torch.as_tensor(m, device=raw_mel.device)
                        .unsqueeze(0)
                        .unsqueeze(-1)
                    )
                    M = (
                        torch.as_tensor(M, device=raw_mel.device)
                        .unsqueeze(0)
                        .unsqueeze(-1)
                    )
                    self.mel_extrema = m, M

                m, M = self.mel_extrema
                mel = (raw_mel - m) / (M - m + EPS) * 2 - 1
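                # The line above min-max normalizes each mel bin to [-1, 1]:
                # a value at its minimum m maps to -1, one at its maximum M
                # maps to (almost exactly) +1; EPS only guards against M == m.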

            else:
                mel = raw_mel

            final_n_frames = min(final_n_frames, mel.size(-1))

            # (B, n_frames, n_mels)
            batch["mel"] = mel.transpose(1, 2)
        else:
            raw_mel = None

        ### F0 and U/V ###
        if self.cfg.preprocess.use_frame_pitch:
            # (B, n_frames)
            raw_f0, raw_uv = self.audio_features_extractor.get_f0(
                batch["wav"],
                wav_lens=batch["wav_len"],
                use_interpolate=self.cfg.preprocess.use_interpolation_for_uv,
                return_uv=True,
            )
            final_n_frames = min(final_n_frames, raw_f0.size(-1))
            batch["frame_pitch"] = raw_f0

            if self.cfg.preprocess.use_uv:
                batch["frame_uv"] = raw_uv

        ### Energy ###
        if self.cfg.preprocess.use_frame_energy:
            # (B, n_frames)
            raw_energy = self.audio_features_extractor.get_energy(
                batch["wav"], mel_spec=raw_mel
            )
            final_n_frames = min(final_n_frames, raw_energy.size(-1))
            batch["frame_energy"] = raw_energy

        ### Content Features ###
        if self.cfg.model.condition_encoder.use_whisper:
            # (B, n_frames, D)
            whisper_feats = self.audio_features_extractor.get_whisper_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.whisper_sample_rate)],
                target_frame_len=padded_n_frames,
            )
            final_n_frames = min(final_n_frames, whisper_feats.size(1))
            batch["whisper_feat"] = whisper_feats

        if self.cfg.model.condition_encoder.use_contentvec:
            # (B, n_frames, D)
            contentvec_feats = self.audio_features_extractor.get_contentvec_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.contentvec_sample_rate)],
                target_frame_len=padded_n_frames,
            )
            final_n_frames = min(final_n_frames, contentvec_feats.size(1))
            batch["contentvec_feat"] = contentvec_feats

        if self.cfg.model.condition_encoder.use_wenet:
            # (B, n_frames, D)
            wenet_feats = self.audio_features_extractor.get_wenet_features(
                wavs=batch["wav_{}".format(self.cfg.preprocess.wenet_sample_rate)],
                target_frame_len=padded_n_frames,
                wav_lens=batch[
                    "wav_{}_len".format(self.cfg.preprocess.wenet_sample_rate)
                ],
            )
            final_n_frames = min(final_n_frames, wenet_feats.size(1))
            batch["wenet_feat"] = wenet_feats
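
        # The extractors above can disagree on frame count by a frame or two
        # (e.g. hop-length rounding), so truncate every frame-level feature
        # to the shortest length observed before returning the batch.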
        frame_level_features = [
            "mask",
            "mel",
            "frame_pitch",
            "frame_uv",
            "frame_energy",
            "whisper_feat",
            "contentvec_feat",
            "wenet_feat",
        ]
        for k in frame_level_features:
            if k in batch:
                batch[k] = batch[k][:, :final_n_frames].contiguous()

        return batch

    @staticmethod
    def _build_criterion():
        # Element-wise MSE; the reduction is done manually in _compute_loss
        # so that padded frames can be masked out.
        criterion = nn.MSELoss(reduction="none")
        return criterion

    @staticmethod
    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
        """
        Args:
            criterion: MSELoss(reduction="none")
            y_pred, y_gt: (B, seq_len, D)
            loss_mask: (B, seq_len, 1)
        Returns:
            loss: scalar Tensor (shape [])
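
        Example (shapes only; the numbers are illustrative): for y_pred and
        y_gt of shape (2, 100, 80) and loss_mask of shape (2, 100, 1) holding
        1.0 for valid frames and 0.0 for padding, the mask is tiled across
        the 80 feature bins, so the result is the MSE averaged over the
        valid elements only.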
        """
        # (B, seq_len, D)
        loss = criterion(y_pred, y_gt)
        # Tile the mask across the feature dimension: (B, seq_len, D)
        loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
        # Mean over the valid (unmasked) elements only
        loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
        return loss

    def _save_auxiliary_states(self):
        """
        Save the singer look-up table into the checkpoint saving path.
        """
        with open(
            os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(self.singers, f, indent=4, ensure_ascii=False)

    def _build_singer_lut(self):
        # Prefer the singer LUT of a resumed checkpoint; a LUT already in the
        # experiment directory takes precedence over it.
        resumed_singer_path = None
        if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
            resumed_singer_path = os.path.join(
                self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
            )
        if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
            resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)

        if resumed_singer_path:
            with open(resumed_singer_path, "r") as f:
                singers = json.load(f)
        else:
            singers = dict()
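
        # Merge the spk2id tables of all configured datasets into one LUT,
        # assigning each unseen singer the next consecutive id. The merged
        # mapping looks like (names illustrative): {"singer_a": 0, "singer_b": 1}.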
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            with open(singer_lut_path, "r") as f:
                singer_lut = json.load(f)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = len(singers)

        with open(
            os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "The singer LUT has been dumped to {}".format(
                os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers