import math

import torch
from einops import rearrange

from grad.base import BaseModule
from grad.solver import NoiseScheduleVP, MaxLikelihood, GradRaw


class Mish(BaseModule):
    # Mish activation: x * tanh(softplus(x)).
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(BaseModule):
    # Doubles both spatial dimensions with a strided transposed convolution.
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(BaseModule):
    # Halves both spatial dimensions (rounding up) with a strided convolution.
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(BaseModule):
    # Scales the wrapped module's output by a learnable gain initialised at zero.
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(BaseModule):
    # Masked Conv2d -> GroupNorm -> Mish.
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            Mish())

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock(BaseModule):
    # Two Blocks with additive time-embedding conditioning and a residual path.
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(),
                                       torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        # Broadcast the time embedding over both spatial dimensions.
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class LinearAttention(BaseModule):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
                        heads=self.heads, h=h, w=w)
        return self.to_out(out)
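# Note (added for clarity): this is the linear-attention trick of normalising the
# keys with a softmax over positions and contracting keys with values first, so
# the cost scales linearly in the number of positions (h * w) rather than
# quadratically as in standard attention.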


class Residual(BaseModule):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class SinusoidalPosEmb(BaseModule):
    # Standard sinusoidal embedding of the (scaled) diffusion time.
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
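# Usage note (illustrative, not from the original file): for a batch of diffusion
# times t of shape (B,) and dim=64, the embedding has shape (B, 64), e.g.
#   SinusoidalPosEmb(64)(torch.tensor([0.1, 0.9]))  # -> shape (2, 64)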


class GradLogPEstimator2d(BaseModule):
    # U-Net style score estimator conditioned on the aligned encoder output `mu`,
    # a speaker embedding and the diffusion time.
    def __init__(self, dim, dim_mults=(1, 2, 4), emb_dim=64, n_mels=100,
                 groups=8, pe_scale=1000):
        super(GradLogPEstimator2d, self).__init__()
        self.dim = dim
        self.dim_mults = dim_mults
        self.emb_dim = emb_dim
        self.groups = groups
        self.pe_scale = pe_scale

        self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, emb_dim * 4), Mish(),
                                           torch.nn.Linear(emb_dim * 4, n_mels))
        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
                                       torch.nn.Linear(dim * 4, dim))

        dims = [2 + 1, *map(lambda m: dim * m, dim_mults)]  # 3 input channels: mu, x_t, speaker
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):  # down path (last stage keeps resolution)
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(torch.nn.ModuleList([
                ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                Residual(Rezero(LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else torch.nn.Identity()]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):  # up path
            self.ups.append(torch.nn.ModuleList([
                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                Residual(Rezero(LinearAttention(dim_in))),
                Upsample(dim_in)]))

        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, spk, x, mask, mu, t):
        s = self.spk_mlp(spk)
        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        # Stack the conditioning mean, the noisy input and the broadcast speaker
        # features as three 2-D channels.
        s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
        x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)
        return (output * mask).squeeze(1)
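# Expected shapes (inferred from the forward pass above, not documented in the
# original): spk (B, emb_dim), x and mu (B, n_mels, T), mask (B, 1, T), t (B,);
# the output is (B, n_mels, T). With two Downsample/Upsample stages the time
# axis is halved twice and restored, so T is assumed to be divisible by 4.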


def get_noise(t, beta_init, beta_term, cumulative=False):
    # Linear noise schedule beta(t) = beta_init + (beta_term - beta_init) * t;
    # with cumulative=True its integral from 0 to t is returned instead.
    if cumulative:
        noise = beta_init * t + 0.5 * (beta_term - beta_init) * (t**2)
    else:
        noise = beta_init + (beta_term - beta_init) * t
    return noise
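# Quick check (illustrative, not part of the original code): the cumulative value
# is the integral of the incremental one, e.g. for beta_init=0.05, beta_term=20, t=1
#     get_noise(1.0, 0.05, 20, cumulative=True) == 0.05 + 0.5 * 19.95 == 10.025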


class Diffusion(BaseModule):
    def __init__(self, n_mels, dim, emb_dim=64,
                 beta_min=0.05, beta_max=20, pe_scale=1000):
        super(Diffusion, self).__init__()
        self.n_mels = n_mels
        self.beta_min = beta_min
        self.beta_max = beta_max
        # self.solver = NoiseScheduleVP()
        self.solver = MaxLikelihood()
        # self.solver = GradRaw()
        self.estimator = GradLogPEstimator2d(dim,
                                             n_mels=n_mels,
                                             emb_dim=emb_dim,
                                             pe_scale=pe_scale)

    def forward_diffusion(self, mel, mask, mu, t):
        # Sample x_t from the forward process given the clean mel and the
        # encoder output mu, using the reparameterisation trick.
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        mean = mel * torch.exp(-0.5 * cum_noise) + mu * (1.0 - torch.exp(-0.5 * cum_noise))
        variance = 1.0 - torch.exp(-cum_noise)
        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device,
                        requires_grad=False)
        xt = mean + z * torch.sqrt(variance)
        return xt * mask, z * mask
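    # Note (added for clarity): the mean and variance above are the closed-form
    # moments of the mean-reverting VP forward process. With l(t) = integral of
    # beta from 0 to t,
    #     x_t | x_0 ~ N(x_0 * exp(-l(t) / 2) + mu * (1 - exp(-l(t) / 2)),
    #                   (1 - exp(-l(t))) * I).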

    def forward(self, spk, z, mask, mu, n_timesteps, stoc=False):
        return self.solver.reverse_diffusion(self.estimator, spk, z, mask, mu, n_timesteps, stoc)

    def loss_t(self, spk, mel, mask, mu, t):
        xt, z = self.forward_diffusion(mel, mask, mu, t)
        time = t.unsqueeze(-1).unsqueeze(-1)
        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
        noise_estimation = self.estimator(spk, xt, mask, mu, t)
        noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask) * self.n_mels)
        return loss, xt

    def compute_loss(self, spk, mel, mask, mu, offset=1e-5):
        # Sample one diffusion time per example, kept inside [offset, 1 - offset].
        t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False)
        t = torch.clamp(t, offset, 1.0 - offset)
        return self.loss_t(spk, mel, mask, mu, t)
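

if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # module). Tensor shapes are assumptions: spk is a speaker embedding of size
    # emb_dim, mel/mu are (batch, n_mels, frames) with frames divisible by 4 so the
    # two down/up-sampling stages line up, and mask is (batch, 1, frames).
    torch.manual_seed(0)
    batch, n_mels, frames, emb_dim = 2, 100, 128, 64
    diffusion = Diffusion(n_mels=n_mels, dim=64, emb_dim=emb_dim)
    spk = torch.randn(batch, emb_dim)
    mel = torch.randn(batch, n_mels, frames)
    mu = torch.randn(batch, n_mels, frames)
    mask = torch.ones(batch, 1, frames)
    loss, xt = diffusion.compute_loss(spk, mel, mask, mu)
    print(loss.item(), xt.shape)  # expected: a scalar loss and xt of shape (2, 100, 128)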