import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

class ConvNeXtBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        kernel: int,
        dilation: int,
        layer_scale_init_value: float = 1e-6,
    ):
        # ConvNeXt Block copied from Vocos.
        super().__init__()
self.dwconv = nn.Conv1d(dim, dim,
kernel_size=kernel, padding=dilation*(kernel//2),
dilation=dilation, groups=dim
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
    def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
        # `cond` is accepted for interface compatibility but is not used here.
        residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x  # layer scale: learnable per-channel scaling
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
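
# A minimal shape sketch for ConvNeXtBlock (hypothetical sizes, not from the
# original file). The depthwise padding dilation*(kernel//2) above preserves
# the sequence length:
#   block = ConvNeXtBlock(dim=256, intermediate_dim=1024, kernel=7, dilation=2)
#   y = block(torch.randn(1, 256, 100))   # -> torch.Size([1, 256, 100])
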
class GFSQ(nn.Module):
    # Grouped residual FSQ wrapper: G groups, each with R residual quantizers.
    def __init__(
        self,
        dim: int, levels: list, G: int, R: int,
        eps: float = 1e-5, transpose: bool = True,
    ):
        super().__init__()
self.quantizer = GroupedResidualFSQ(
dim=dim,
levels=levels,
num_quantizers=R,
groups=G,
)
        self.n_ind = math.prod(levels)  # codebook size of one FSQ quantizer
self.eps = eps
self.transpose = transpose
self.G = G
self.R = R
    def _embed(self, x: torch.Tensor):
        # x holds discrete code indices: (B, G*R, T) if transpose else (B, T, G*R).
        if self.transpose:
            x = x.transpose(1, 2)
        x = rearrange(
            x, "b t (g r) -> g b t r", g=self.G, r=self.R,
        )
        feat = self.quantizer.get_output_from_indices(x)
        return feat.transpose(1, 2) if self.transpose else feat
    def forward(self, x):
        if self.transpose:
            x = x.transpose(1, 2)
        feat, ind = self.quantizer(x)  # ind: (g, b, t, r)
        ind = rearrange(
            ind, "g b t r -> b t (g r)",
        )
        embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
        # Mean code usage per quantizer; perplexity = exp(entropy) measures
        # how uniformly each codebook is used.
        e_mean = torch.mean(embed_onehot, dim=[0, 1])
        e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
        return (
            torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),  # loss placeholder
            feat.transpose(1, 2) if self.transpose else feat,
            perplexity,
            None,
            ind.transpose(1, 2) if self.transpose else ind,
        )
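
# Shape sketch for GFSQ (illustrative values, assumed rather than taken from
# any released config): with transpose=True,
#   q = GFSQ(dim=1024, levels=[5, 5, 5, 5], G=2, R=2)
#   _, feat, ppl, _, ind = q(torch.randn(1, 1024, 20))
# feat is (1, 1024, 20) quantized features and ind is (1, 4, 20) code indices
# in [0, 625), since each quantizer has prod(levels) = 5**4 = 625 codes.
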
class DVAEDecoder(nn.Module):
    def __init__(
        self, idim: int, odim: int,
        n_layer: int = 12, bn_dim: int = 64, hidden: int = 256,
        kernel: int = 7, dilation: int = 2, up: bool = False,
    ):
        super().__init__()
        self.up = up  # kept for config compatibility; not used in forward
        self.conv_in = nn.Sequential(
            nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
            nn.Conv1d(bn_dim, hidden, 3, 1, 1)
        )
        self.decoder_block = nn.ModuleList([
            ConvNeXtBlock(hidden, hidden * 4, kernel, dilation)
            for _ in range(n_layer)])
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
    def forward(self, input, conditioning=None):
        # input: (B, T, C) -> (B, C, T) for the 1D convolutions
        x = input.transpose(1, 2)
x = self.conv_in(x)
for f in self.decoder_block:
x = f(x, conditioning)
x = self.conv_out(x)
return x.transpose(1, 2)
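
# Shape sketch for DVAEDecoder (hypothetical config, not from the original
# file): all convolutions above preserve the time dimension, so
#   dec = DVAEDecoder(idim=512, odim=512, n_layer=2)
#   y = dec(torch.randn(1, 40, 512))   # (B, T, idim) -> (B, T, odim)
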
class DVAE(nn.Module):
    def __init__(
        self, decoder_config: dict, vq_config: dict, dim: int = 512
    ):
        super().__init__()
        # Per-mel-bin output scale. Randomly initialized here; the real values
        # are expected to be loaded from a pretrained checkpoint.
        self.register_buffer('coef', torch.randn(1, 100, 1))
        self.decoder = DVAEDecoder(**decoder_config)
        self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
if vq_config is not None:
self.vq_layer = GFSQ(**vq_config)
else:
self.vq_layer = None
    def forward(self, inp):
        # inp: discrete code indices when a VQ layer is present,
        # otherwise precomputed continuous features.
        if self.vq_layer is not None:
            vq_feats = self.vq_layer._embed(inp)
        else:
            vq_feats = inp.detach().clone()
        # "Flatten trick": split the channels in half and interleave the two
        # halves along time, reshaping (B, C, T) into (B, C/2, 2T). This doubles
        # the temporal resolution fed to the decoder while halving the channels.
        temp = torch.chunk(vq_feats, 2, dim=1)
        temp = torch.stack(temp, -1)
        vq_feats = temp.reshape(*temp.shape[:2], -1)
vq_feats = vq_feats.transpose(1, 2)
dec_out = self.decoder(input=vq_feats)
dec_out = self.out_conv(dec_out.transpose(1, 2))
        mel = dec_out * self.coef  # apply per-mel-bin output scale
return mel
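

# A hedged end-to-end smoke test. The config values below are illustrative
# assumptions chosen to be self-consistent with the shapes this file expects
# (decoder idim = GFSQ dim / 2 because of the flatten trick); they are not
# taken from any released checkpoint.
if __name__ == "__main__":
    decoder_config = dict(idim=512, odim=512, n_layer=2, bn_dim=64, hidden=256)
    vq_config = dict(dim=1024, levels=[5, 5, 5, 5], G=2, R=2)
    model = DVAE(decoder_config, vq_config, dim=512)
    codes = torch.randint(0, 5 ** 4, (1, 4, 20))  # (B, G*R, T) discrete indices
    with torch.no_grad():
        mel = model(codes)
    print(mel.shape)  # torch.Size([1, 100, 40]): 100 mel bins, time doubled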