import torch
import torch.nn as nn
import copy
import numpy as np
import math
from tqdm.auto import tqdm

from utils.ssim import SSIM

from models.svc.transformer.conformer import Conformer, BaseModule
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper


class Consistency(nn.Module):
    def __init__(self, cfg, distill=False):
        super().__init__()
        self.cfg = cfg
        self.denoise_fn = DiffusionWrapper(self.cfg)
        # The denoiser wrapper needs the full config; everything below only
        # reads the ComoSVC sub-config, so rebind self.cfg to it.
        self.cfg = cfg.model.comosvc
        self.teacher = not distill
        self.P_mean = self.cfg.P_mean
        self.P_std = self.cfg.P_std
        self.sigma_data = self.cfg.sigma_data
        self.sigma_min = self.cfg.sigma_min
        self.sigma_max = self.cfg.sigma_max
        self.rho = self.cfg.rho
        self.N = self.cfg.n_timesteps
        self.ssim_loss = SSIM()

        # Noise levels spaced uniformly in sigma^(1/rho), ascending from
        # sigma_min to sigma_max, with sigma = 0 prepended as t_steps[0].
        step_indices = torch.arange(self.N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (self.N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho
        self.t_steps = torch.cat(
            [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
        )
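
    # For reference, this is the Karras et al. (2022) rho-schedule,
    #     sigma_i = (sigma_min^(1/rho) + i/(N-1) * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho,
    # which clusters noise levels near sigma_min, where the denoising ODE
    # changes fastest. Indexing t_steps[n] / t_steps[n + 1] then yields the
    # adjacent (t_n, t_{n+1}) pairs used by consistency distillation in
    # CTLoss_D below.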

    def init_consistency_training(self):
        # Two copies of the current denoiser: an EMA "target" network and the
        # frozen pretrained teacher used as the ODE solver in CTLoss_D.
        self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
        self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)

    def EDMPrecond(self, x, sigma, cond, denoise_fn):
        """
        Karras (EDM) preconditioned denoiser: one reverse-process evaluation.

        Args:
            x: noisy mel-spectrogram [B x n_mel x L]
            sigma: noise level [B x 1 x 1]
            cond: output of conformer encoder [B x n_mel x L]
            denoise_fn: denoiser neural network, e.g. DilatedCNN

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        sigma = sigma.reshape(-1, 1, 1)

        # EDM preconditioning coefficients; c_out carries an extra
        # (sigma - sigma_min) factor so the network output is bypassed at the
        # consistency-model boundary sigma = sigma_min.
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = (
            (sigma - self.sigma_min)
            * self.sigma_data
            / (sigma**2 + self.sigma_data**2).sqrt()
        )
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        # The denoiser expects [B x L x n_mel] inputs.
        x_in = c_in * x
        x_in = x_in.transpose(1, 2)
        x = x.transpose(1, 2)
        cond = cond.transpose(1, 2)
        c_noise = c_noise.squeeze()
        if c_noise.dim() == 0:
            c_noise = c_noise.unsqueeze(0)
        F_x = denoise_fn(x_in, c_noise, cond)
        D_x = c_skip * x + c_out * F_x
        D_x = D_x.transpose(1, 2)
        return D_x
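
    # In EDM notation the method above computes
    #     D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma)),
    # where F is the raw network. Compared to Karras et al. (2022), c_out here
    # is scaled by (sigma - sigma_min) instead of sigma -- the boundary
    # condition used by consistency models so that D(x; sigma_min) is
    # approximately x.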

    def EDMLoss(self, x_start, cond, mask):
        """
        Compute the EDM training loss (teacher training).

        Args:
            x_start: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]
        """
        # Sample noise levels log-normally: ln(sigma) ~ N(P_mean, P_std^2).
        rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

        # ComoSVC shifts the noise prior by the encoder output, so noisy
        # samples are drawn around cond rather than around zero.
        noise = (torch.randn_like(x_start) + cond) * sigma
        D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn)
        loss = weight * ((D_yn - x_start) ** 2)
        loss = torch.sum(loss * mask) / torch.sum(mask)
        return loss
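
    # The weighting above is the EDM lambda(sigma) =
    # (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, which equals
    # 1 / c_out(sigma)^2 for the standard EDM c_out, so every noise level
    # contributes at a comparable scale to the gradient.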

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

    def edm_sampler(
        self,
        latents,
        cond,
        nonpadding,
        num_steps=50,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        S_churn=0,
        S_min=0,
        S_max=float("inf"),
        S_noise=1,
    ):
        """
        Karras diffusion sampler (stochastic Euler steps with optional churn).

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            num_steps: number of steps for diffusion inference

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        # Descending noise schedule from sigma_max to sigma_min, with 0 appended.
        num_steps = num_steps + 1
        step_indices = torch.arange(num_steps, device=latents.device)
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])

        x_next = latents * t_steps[0]

        bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
        for i, (t_cur, t_next) in bar:
            x_cur = x_next

            # Optionally "churn": raise the noise level from t_cur to t_hat
            # before stepping (a no-op with the default S_churn = 0).
            gamma = (
                min(S_churn / num_steps, np.sqrt(2) - 1)
                if S_min <= t_cur <= S_max
                else 0
            )
            t_hat = self.round_sigma(t_cur + gamma * t_cur)
            t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
            t[:, 0, 0] = t_hat
            t_hat = t
            x_hat = x_cur + (
                t_hat**2 - t_cur**2
            ).sqrt() * S_noise * torch.randn_like(x_cur)

            # Euler step along the probability-flow ODE.
            denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

        return x_next
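
    # The update above integrates the EDM probability-flow ODE
    #     dx/dt = (x - D(x; t)) / t
    # with first-order (Euler) steps; a second-order Heun correction, like the
    # one used in CTLoss_D below, could be applied per step but is omitted here.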

    def CTLoss_D(self, y, cond, mask):
        """
        Compute the consistency-distillation loss.

        Args:
            y: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]
        """
        # Update the EMA target network before computing the loss.
        with torch.no_grad():
            mu = 0.95
            for p, ema_p in zip(
                self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
            ):
                ema_p.mul_(mu).add_(p, alpha=1 - mu)

        # Sample a timestep index and a cond-shifted noise sample.
        n = torch.randint(1, self.N, (y.shape[0],))
        z = torch.randn_like(y) + cond

        # Student prediction at the higher noise level t_{n+1}.
        tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
        f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn)

        with torch.no_grad():
            tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)

            # One solver step of the frozen teacher from t_{n+1} down to t_n:
            # Euler predictor ...
            x_hat = y + tn_1 * z
            denoised = self.EDMPrecond(x_hat, tn_1, cond, self.denoise_fn_pretrained)
            d_cur = (x_hat - denoised) / tn_1
            y_tn = x_hat + (tn - tn_1) * d_cur

            # ... followed by a Heun (second-order) corrector.
            denoised2 = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_pretrained)
            d_prime = (y_tn - denoised2) / tn
            y_tn = x_hat + (tn - tn_1) * (0.5 * d_cur + 0.5 * d_prime)

            # Target prediction from the EMA network at the lower noise level.
            f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema)

        loss = (f_theta - f_theta_ema.detach()) ** 2
        loss = torch.sum(loss * mask) / torch.sum(mask)

        if torch.any(torch.isnan(loss)):
            print("nan loss")
        if torch.any(torch.isnan(f_theta)):
            print("nan f_theta")
        if torch.any(torch.isnan(f_theta_ema)):
            print("nan f_theta_ema")

        return loss
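
    # This matches the consistency-distillation objective of Song et al. (2023):
    #     L = E_n || f_theta(x_{t_{n+1}}, t_{n+1}) - f_{theta^-}(x_hat_{t_n}, t_n) ||^2,
    # where x_hat_{t_n} comes from one Heun step of the frozen teacher and
    # theta^- is the EMA copy of the student, so both evaluations should map
    # to the same clean sample.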

    def get_t_steps(self, N):
        # N + 1 noise levels on the same rho-schedule as in __init__,
        # returned in descending order (sigma_max first) for sampling.
        N = N + 1
        step_indices = torch.arange(N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho

        return t_steps.flip(0)
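
    # As a worked example (assuming the common EDM defaults sigma_min = 0.002,
    # sigma_max = 80, rho = 7, purely for illustration), get_t_steps(2) returns
    # roughly [80.0, 2.5, 0.002]: one intermediate level between sigma_max and
    # sigma_min, spaced uniformly in sigma^(1/7).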

    def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
        """
        Multistep consistency sampler (single-step by default).

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            t_steps: number of steps for diffusion inference

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        # One-step sampling starts directly from the (hardcoded) sigma_max = 80;
        # multistep sampling walks down the descending rho-schedule.
        if t_steps == 1:
            t_steps = [80]
        else:
            t_steps = self.get_t_steps(t_steps)

        t_steps = torch.as_tensor(t_steps).to(latents.device)
        latents = latents * t_steps[0]
        _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
        _t[:, 0, 0] = t_steps[0]
        x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema)

        for t in t_steps[1:-1]:
            # Re-noise the current estimate up to level t, then denoise again.
            z = torch.randn_like(x) + cond
            x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
            _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
            _t[:, 0, 0] = t
            t = _t
            x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema)
        return x
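
    # This follows the multistep consistency sampling procedure of Song et al.
    # (2023): denoise once from sigma_max, then for each remaining level t
    # perturb the estimate back to noise level t and apply the consistency
    # function again, trading extra steps for sample quality.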

    def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
        """
        Calculate the training loss (infer=False) or sample a mel-spectrogram
        (infer=True).

        Args:
            x:
                training: ground truth mel-spectrogram [B x n_mel x L]
                inference: output of encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            t_steps: number of sampling steps (teacher inference only)
        """
        if self.teacher:  # teacher model: EDM training / sampling
            if not infer:
                loss = self.EDMLoss(x, cond, nonpadding)
                return loss
            else:
                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
                x = torch.randn(shape, device=x.device) + cond
                x = self.edm_sampler(x, cond, nonpadding, t_steps)
                return x
        else:  # distilled consistency model
            if not infer:
                loss = self.CTLoss_D(x, cond, nonpadding)
                return loss
            else:
                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
                x = torch.randn(shape, device=x.device) + cond
                # Note: sampling is hardcoded to a single step here,
                # regardless of the t_steps argument.
                x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
                return x


class ComoSVC(BaseModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
        self.distill = self.cfg.model.comosvc.distill
        self.encoder = Conformer(self.cfg.model.comosvc)
        self.decoder = Consistency(self.cfg, distill=self.distill)
        self.ssim_loss = SSIM()

    @torch.no_grad()
    def forward(self, x_mask, x, n_timesteps, temperature=1.0):
        """
        Generate a mel-spectrogram from pitch, content vector, and energy. Returns:
        1. encoder outputs (from conformer)
        2. decoder outputs (from diffusion-based decoder)

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x 1]
            x : output of encoder framework. [B x L x d_condition]
            n_timesteps : number of steps to use for reverse diffusion in decoder.
            temperature : controls variance of terminal distribution (currently unused).
        """
        mu_x = self.encoder(x, x_mask)
        encoder_outputs = mu_x

        # The decoder works on [B x n_mel x L] tensors.
        mu_x = mu_x.transpose(1, 2)
        x_mask = x_mask.transpose(1, 2)

        decoder_outputs = self.decoder(
            mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
        )
        decoder_outputs = decoder_outputs.transpose(1, 2)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, x_mask, x, mel, skip_diff=False):
        """
        Compute two losses:
        1. prior loss: loss between the mel-spectrogram and the encoder outputs
           (L2, plus an SSIM loss).
        2. diffusion loss: loss between Gaussian noise and its reconstruction
           by the diffusion-based decoder.

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x 1]
            x : output of encoder framework. [B x L x d_condition]
            mel : ground truth mel-spectrogram. [B x L x n_mel]
            skip_diff : if True (and the model is not in distillation mode),
                skip the diffusion loss and return zeros for it.
        """
        mu_x = self.encoder(x, x_mask)

        # Broadcast the frame mask across mel bins; the prior loss is the
        # unit-variance Gaussian NLL, 0.5 * ((mel - mu_x)^2 + log(2 * pi)).
        x_mask = x_mask.repeat(1, 1, mel.shape[-1])
        prior_loss = torch.sum(
            0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
        )
        prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)

        ssim_loss = self.ssim_loss(mu_x, mel)
        ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)

        # The decoder works on [B x n_mel x L] tensors.
        x_mask = x_mask.transpose(1, 2)
        mu_x = mu_x.transpose(1, 2)
        mel = mel.transpose(1, 2)
        if not self.distill and skip_diff:
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            mu_y = mu_x
            mask_y = x_mask

            diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)

        return ssim_loss, prior_loss, diff_loss
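

# A minimal usage sketch (hypothetical; it assumes a config object exposing
# the fields read above, e.g. cfg.model.comosvc.{distill, n_timesteps,
# P_mean, P_std, sigma_data, sigma_min, sigma_max, rho} and
# cfg.preprocess.n_mel, with tensor shapes as in the docstrings):
#
#     model = ComoSVC(cfg)
#     # For distillation, initialize the EMA/teacher copies first:
#     # model.decoder.init_consistency_training()
#
#     # Training: x [B, L, d_condition], x_mask [B, L, 1], mel [B, L, n_mel]
#     ssim_l, prior_l, diff_l = model.compute_loss(x_mask, x, mel)
#     (ssim_l + prior_l + diff_l).backward()
#
#     # Inference: 1 step for a distilled model, more for the teacher.
#     encoder_out, mel_hat = model(x_mask, x, n_timesteps=1)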