import torch
import torch.nn as nn
from einops import rearrange

from encodec import EncodecModel

from models.tts.naturalspeech2.diffusion import Diffusion
from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
from models.tts.naturalspeech2.prior_encoder import PriorEncoder

# NB: "naturalpseech2" (sic) matches the actual directory name in this repo.
from modules.naturalpseech2.transformers import TransformerEncoder


class NaturalSpeech2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.latent_dim = cfg.latent_dim
        self.query_emb_num = cfg.query_emb.query_token_num

        # Phoneme prior (with duration/pitch predictors) whose output
        # conditions the diffusion model.
        self.prior_encoder = PriorEncoder(cfg.prior_encoder)
        if cfg.diffusion.diffusion_type == "diffusion":
            self.diffusion = Diffusion(cfg.diffusion)
        elif cfg.diffusion.diffusion_type == "flow":
            self.diffusion = DiffusionFlow(cfg.diffusion)
        else:
            raise ValueError(
                f"Unsupported diffusion_type: {cfg.diffusion.diffusion_type!r}"
            )

        # Prompt (reference-speech) encoder; project codec latents into its
        # hidden size only when the two dimensions differ.
        self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
        if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
            self.prompt_lin = nn.Linear(
                self.latent_dim, cfg.prompt_encoder.encoder_hidden
            )
            self.prompt_lin.weight.data.normal_(0.0, 0.02)
        else:
            self.prompt_lin = None

        # Learnable query tokens that cross-attend to the prompt encoding,
        # producing a fixed-length speaker/style summary.
        self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
        self.query_attn = nn.MultiheadAttention(
            cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
        )

        # Frozen EnCodec (24 kHz; 12 kbps corresponds to 16 codebooks, matching
        # the nq=16 default below). Only its quantizer is kept, to map discrete
        # codes <-> continuous latents.
        codec_model = EncodecModel.encodec_model_24khz()
        codec_model.set_target_bandwidth(12.0)
        codec_model.requires_grad_(False)
        self.quantizer = codec_model.quantizer

    @torch.no_grad()
    def code_to_latent(self, code):
        # code: (B, n_q, T) discrete EnCodec indices -> (B, latent_dim, T)
        # continuous latents; quantizer.decode expects (n_q, B, T).
        latent = self.quantizer.decode(code.transpose(0, 1))
        return latent

    def latent_to_code(self, latent, nq=16):
        # Re-quantize continuous latents back to discrete codes by running the
        # residual VQ layers manually, keeping the negative squared distances
        # to every codebook entry as per-layer logits.
        residual = latent
        all_indices = []
        all_dist = []
        for i in range(nq):
            layer = self.quantizer.vq.layers[i]
            x = rearrange(residual, "b d n -> b n d")
            x = layer.project_in(x)
            shape = x.shape
            x = layer._codebook.preprocess(x)
            embed = layer._codebook.embed.t()
            # -||x - e||^2 = -(||x||^2 - 2 x.e + ||e||^2); larger is closer.
            dist = -(
                x.pow(2).sum(1, keepdim=True)
                - 2 * x @ embed
                + embed.pow(2).sum(0, keepdim=True)
            )
            indices = dist.max(dim=-1).indices
            indices = layer._codebook.postprocess_emb(indices, shape)
            dist = dist.reshape(*shape[:-1], dist.shape[-1])
            quantized = layer.decode(indices)
            # Subtract this layer's reconstruction; the next layer quantizes
            # what remains (residual vector quantization).
            residual = residual - quantized
            all_indices.append(indices)
            all_dist.append(dist)

        out_indices = torch.stack(all_indices)
        out_dist = torch.stack(all_dist)

        return out_indices, out_dist
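
    # The stacked `out_dist` act as logits: a softmax over the codebook axis
    # of -||x - e||^2 gives, per quantizer layer, a distribution over code
    # indices, so a cross-entropy against ground-truth codes is one natural
    # use. How callers actually consume `out_dist` is outside this file.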

    @torch.no_grad()
    def latent_to_latent(self, latent, nq=16):
        # Round-trip latents through the first `nq` quantizer layers
        # (quantize, then decode), projecting them onto the codec's
        # discrete lattice.
        codes, _ = self.latent_to_code(latent, nq)
        latent = self.quantizer.vq.decode(codes)
        return latent

    def forward(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        phone_id_frame=None,
        frame_nums=None,
        ref_code=None,
        ref_frame_nums=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
    ):
        ref_latent = self.code_to_latent(ref_code)
        latent = self.code_to_latent(code)

        # (B, D, T_ref) -> (B, T_ref, H) for the prompt encoder; project only
        # when the dimensions differ (prompt_lin is None otherwise). The
        # original guard tested self.latent_dim, which is never None and
        # would crash when prompt_lin is None.
        ref_latent = ref_latent.transpose(1, 2)
        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, H, T_ref)

        # Fixed-length speaker summary: query tokens attend over the prompt,
        # with padded prompt frames masked out.
        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(latent.shape[0], 1, 1)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )

        # Training-mode prior: ground-truth duration and pitch are provided.
        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]

        diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)

        return diff_out, prior_out
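
    # forward() is the training path: `diff_out` carries the diffusion/flow
    # predictions for `latent`, while `prior_out` typically carries the
    # duration/pitch predictions alongside the prior condition. The loss
    # terms that combine them live in the trainer, not in this module.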

    @torch.no_grad()
    def inference(
        self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
    ):
        ref_latent = self.code_to_latent(ref_code)

        # Same prompt pathway as forward(): transpose, optionally project.
        ref_latent = ref_latent.transpose(1, 2)
        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(ref_latent.device)
        ).repeat(ref_latent.shape[0], 1, 1)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )

        # Inference-mode prior: duration and pitch are predicted rather than
        # taken from ground truth.
        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=None,
            pitch=None,
            phone_mask=None,
            mask=None,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=True,
        )
        prior_condition = prior_out["prior_out"]

        # Start from Gaussian noise scaled by 1/1.2 (a sampling-temperature
        # trick) and run the full reverse process.
        z = torch.randn(
            prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
        ).to(ref_latent.device) / (1.20)
        x0 = self.diffusion.reverse_diffusion(
            z, None, prior_condition, inference_steps, spk_query_emb
        )

        return x0, prior_out
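
    # `x0` is a continuous latent of shape (B, latent_dim, T). This module
    # does not decode it to audio; one route is latent_to_code() followed by
    # the codec decoder, another is passing the latent straight to EnCodec's
    # decoder, since `x0` lives in the quantizer's latent space.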

    @torch.no_grad()
    def reverse_diffusion_from_t(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        ref_code=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
        n_timesteps=None,
        t=None,
    ):
        # Diffuse the ground-truth latent forward to time t, then denoise it
        # back; useful for partial resynthesis and for probing the model.
        ref_latent = self.code_to_latent(ref_code)
        latent = self.code_to_latent(code)

        ref_latent = ref_latent.transpose(1, 2)
        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(latent.shape[0], 1, 1)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]

        # Broadcast the scalar t to a per-sample step tensor, clamped away
        # from the endpoints for numerical stability.
        diffusion_step = (
            torch.ones(
                latent.shape[0],
                dtype=latent.dtype,
                device=latent.device,
                requires_grad=False,
            )
            * t
        )
        diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
        xt, _ = self.diffusion.forward_diffusion(
            x0=latent, diffusion_step=diffusion_step
        )

        x0 = self.diffusion.reverse_diffusion_from_t(
            xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
        )

        return x0, prior_out, xt
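

# Usage sketch (illustrative, not from the training recipes): the config and
# tensors below are hypothetical stand-ins, assuming `cfg` exposes the fields
# this class reads (latent_dim, query_emb.*, prior_encoder, prompt_encoder,
# diffusion). Shapes follow from how the methods above consume their inputs.
#
#   model = NaturalSpeech2(cfg).eval()
#   ref_code = torch.randint(0, 1024, (1, 16, 150))  # (B, n_q, T_ref) codes
#   phone_id = torch.randint(1, 100, (1, 50))        # (B, T_phone) phonemes
#   ref_mask = torch.ones(1, 150)                    # 1 = valid prompt frame
#   x0, prior_out = model.inference(
#       ref_code=ref_code, phone_id=phone_id, ref_mask=ref_mask,
#       inference_steps=200,
#   )
#   # x0: (B, latent_dim, T) continuous latents, decodable with EnCodec.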