|
import math |
|
from dataclasses import dataclass |
|
from typing import Union, Tuple, Literal |
|
|
|
import torch as T |
|
import torch.nn as nn |
|
from torch.nn.utils.parametrizations import weight_norm |
|
|
|
from utils import load_ckpt |
|
from utils.interp import print_colored |
|
from utils import si_module, get_activation |
|
|
|
|
|
|
|
|
|
|
|
def Conv1d1x1(in_channels, out_channels, bias=True): |
|
return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias) |
|
|
|
|
|
class NonCausalConv1d(nn.Module): |
|
"""1D noncausal convolution w/ 2-sides padding.""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride=1, |
|
padding=-1, |
|
dilation=1, |
|
groups=1, |
|
bias=True): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.kernel_size = kernel_size |
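        # Default to "same" padding: (kernel_size - 1) // 2 * dilation keeps T unchanged at stride 1 (odd kernels assumed).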
|
if padding < 0: |
|
padding = (kernel_size - 1) // 2 * dilation |
|
self.dilation = dilation |
|
self.conv = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
dilation=dilation, |
|
groups=groups, |
|
bias=bias, |
|
) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x (Tensor): Float tensor variable with the shape (B, C, T). |
|
Returns: |
|
            Tensor: Float tensor variable with the shape (B, C', T').
|
""" |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class NonCausalConvTranspose1d(nn.Module): |
|
"""1D noncausal transpose convolution.""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
padding=-1, |
|
output_padding=-1, |
|
groups=1, |
|
bias=True, |
|
): |
|
super().__init__() |
|
if padding < 0: |
|
padding = (stride+1) // 2 |
|
if output_padding < 0: |
|
output_padding = 1 if stride % 2 else 0 |
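        # With kernel_size == 2 * stride (as used by ResNetBlock), these defaults give an output length of exactly stride * T.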
|
self.deconv = nn.ConvTranspose1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
output_padding=output_padding, |
|
groups=groups, |
|
bias=bias, |
|
) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x (Tensor): Float tensor variable with the shape (B, C, T). |
|
Returns: |
|
Tensor: Float tensor variable with the shape (B, C', T'). |
|
""" |
|
x = self.deconv(x) |
|
return x |
|
|
|
|
|
class CausalConv1d(NonCausalConv1d): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride=1, |
|
dilation=1, |
|
groups=1, |
|
bias=True |
|
): |
|
super(CausalConv1d, self).__init__( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=0, |
|
dilation=dilation, |
|
groups=groups, |
|
bias=bias, |
|
) |
|
        self.stride = stride
        self.pad_length = (kernel_size - 1) * dilation
        # Cache of the trailing pad_length samples, used as left context during streaming inference.
        self.register_buffer("pad_buffer", T.zeros(1, in_channels, self.pad_length))

    def forward(self, x):
        # Left-pad with zeros so each output sample depends only on current and past inputs.
        pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
        x = pad(x)
        return self.conv(x)

    def inference(self, x):
        # Streaming variant: prepend the cached context from the previous chunk instead of zeros.
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.conv(x)

    def reset_buffer(self):
        self.pad_buffer.zero_()
|
|
|
|
|
class CausalConvTranspose1d(NonCausalConvTranspose1d): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
stride, |
|
bias=True, |
|
pad_buffer=None, |
|
): |
|
super(CausalConvTranspose1d, self).__init__( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=0, |
|
output_padding=0, |
|
bias=bias, |
|
) |
|
self.stride = stride |
|
self.pad_length = (math.ceil(kernel_size/stride) - 1) |
|
if pad_buffer is None: |
|
pad_buffer = T.zeros(1, in_channels, self.pad_length) |
|
self.register_buffer("pad_buffer", pad_buffer) |
|
|
|
def forward(self, x): |
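        # Replicate-pad on the left for causal context, then trim stride samples from both ends of the transposed convolution's output.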
|
pad = nn.ReplicationPad1d((self.pad_length, 0)) |
|
x = pad(x) |
|
return self.deconv(x)[:, :, self.stride : -self.stride] |
|
|
|
def inference(self, x): |
|
x = T.cat((self.pad_buffer, x), -1) |
|
self.pad_buffer = x[:, :, -self.pad_length:] |
|
return self.deconv(x)[:, :, self.stride : -self.stride] |
|
|
|
def reset_buffer(self): |
|
self.pad_buffer.zero_() |
|
|
|
|
|
class NonCausalResUnit(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size=7, |
|
dilation=1, |
|
bias=False, |
|
): |
|
super().__init__() |
|
self.activation = nn.ELU() |
|
self.conv1 = NonCausalConv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=1, |
|
dilation=dilation, |
|
bias=bias, |
|
) |
|
self.conv2 = Conv1d1x1(out_channels, out_channels, bias) |
|
|
|
def forward(self, x): |
|
y = self.conv1(self.activation(x)) |
|
y = self.conv2(self.activation(y)) |
|
return x + y |
|
|
|
|
|
class CausalResUnit(NonCausalResUnit): |
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size=7, |
|
dilation=1, |
|
bias=False, |
|
): |
|
super(CausalResUnit, self).__init__( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
dilation=dilation, |
|
bias=bias, |
|
) |
|
self.conv1 = CausalConv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
stride=1, |
|
dilation=dilation, |
|
bias=bias, |
|
) |
|
|
|
def inference(self, x): |
|
y = self.conv1.inference(self.activation(x)) |
|
y = self.conv2(self.activation(y)) |
|
return x + y |
|
|
|
|
|
class ResNetBlock(nn.Module): |
|
def __init__(self, |
|
in_channels, |
|
out_channels, |
|
stride, |
|
kernel_size=7, |
|
dilations=(1, 3, 9), |
|
bias=True, |
|
mode='encoder', |
|
): |
|
super().__init__() |
|
assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!" |
|
|
|
self.mode = mode |
|
self.stride = stride |
|
|
|
ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d |
|
|
|
res_channels = in_channels if mode == 'encoder' else out_channels |
|
|
|
res_units = [CausalResUnit( |
|
res_channels, |
|
res_channels, |
|
kernel_size=kernel_size, |
|
dilation=dilation, |
|
) for dilation in dilations] |
|
|
|
if in_channels == out_channels: |
|
if mode == 'encoder': |
|
self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride) |
|
if mode == 'decoder': |
|
self.upsample = nn.Upsample(scale_factor=stride, mode='nearest') |
|
conv_unit = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
bias=bias, |
|
) if in_channels != out_channels else nn.Identity() |
|
else: |
|
conv_unit = ConvUnit( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=(2 * stride), |
|
stride=stride, |
|
bias=bias, |
|
) |
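        # Encoder blocks run the residual units first and downsample last; decoder blocks upsample first, then run the residual units.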
|
|
|
if mode == 'encoder': |
|
if in_channels == out_channels: |
|
self.res_block = nn.Sequential(*res_units, self.pool, conv_unit) |
|
else: |
|
self.res_block = nn.Sequential(*res_units, conv_unit) |
|
elif mode == 'decoder': |
|
if in_channels == out_channels: |
|
self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units) |
|
else: |
|
self.res_block = nn.Sequential(conv_unit, *res_units) |
|
|
|
def forward(self, x): |
|
out = x |
|
for unit in self.res_block: |
|
out = unit(out) |
|
return out |
|
|
|
def inference(self, x): |
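        # Streaming path: every unit in res_block must implement inference(); the pooling/upsampling/identity units used when in_channels == out_channels do not.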
|
for unit in self.res_block: |
|
x = unit.inference(x) |
|
return x |
|
|
|
|
|
|
|
|
|
@si_module |
|
class ResNetStack(nn.Module): |
|
""" |
|
ResNet encoder or decoder stack. Channel ratios |
|
and strides take the default order of from |
|
data/io-layer, to the middle of the model. |
|
""" |
|
class Config: |
|
input_channels: int = 1 |
|
output_channels: int = 1 |
|
encode_channels: int = 32 |
|
decode_channel_multiplier: int = 1 |
|
latent_dim: int = None |
|
kernel_size: int = 7 |
|
bias: bool = True |
|
channel_ratios: Tuple[int, ...] = (2, 4, 8, 16) |
|
strides: Tuple[int, ...] = (3, 4, 5, 5) |
|
mode: Literal['encoder', 'decoder'] = 'encoder' |
|
|
|
def __init__(self, c: Config): |
|
super().__init__() |
|
assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!" |
|
|
|
self.mode = c.mode |
|
|
|
assert len(c.channel_ratios) == len(c.strides) |
|
channel_ratios = (1,) + c.channel_ratios |
|
strides = c.strides |
|
self.middle_channels = c.encode_channels * channel_ratios[-1] |
|
if c.mode == 'decoder': |
|
channel_ratios = tuple(reversed(channel_ratios)) |
|
strides = tuple(reversed(strides)) |
|
|
|
self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1 |
|
res_blocks = [ResNetBlock( |
|
c.encode_channels * channel_ratios[s_idx] * self.multiplier, |
|
c.encode_channels * channel_ratios[s_idx+1] * self.multiplier, |
|
stride, |
|
kernel_size=c.kernel_size, |
|
bias=c.bias, |
|
mode=c.mode, |
|
) for s_idx, stride in enumerate(strides)] |
|
|
|
data_conv = CausalConv1d( |
|
in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier, |
|
out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels, |
|
kernel_size=c.kernel_size, |
|
stride=1, |
|
bias=False, |
|
) |
|
|
|
if c.mode == 'encoder': |
|
self.res_stack = nn.Sequential(data_conv, *res_blocks) |
|
elif c.mode == 'decoder': |
|
self.res_stack = nn.Sequential(*res_blocks, data_conv) |
|
|
|
if c.latent_dim is not None: |
|
            self.latent_proj = (
                Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias)
                if c.mode == 'encoder'
                else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
            )
|
if self.multiplier != 1: |
|
self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias) |
|
|
|
def forward(self, x, return_feats=False): |
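        # When return_feats is True, the per-block outputs are collected and returned instead of only the final output (with the latent projection appended last in encoder mode, when one is configured).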
|
if self.c.latent_dim is not None and self.mode == 'decoder': |
|
x = self.latent_proj(x) |
|
if self.multiplier != 1: |
|
x = self.multiplier_proj(x) |
|
|
|
feats = [] |
|
for block in self.res_stack: |
|
x = block(x) |
|
if return_feats: |
|
feats.append(x) |
|
if self.c.latent_dim is not None and self.mode == 'encoder': |
|
x = self.latent_proj(x) |
|
if return_feats: |
|
feats.append(x) |
|
if return_feats: |
|
return feats |
|
return x |
|
|
|
def inference(self, x): |
|
for block in self.res_stack: |
|
x = block.inference(x) |
|
return x |
|
|
|
def reset_buffer(self): |
|
def _reset_buffer(m): |
|
if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d): |
|
m.reset_buffer() |
|
self.apply(_reset_buffer) |
|
|
|
def reset_parameters(self): |
|
def _reset_parameters(m): |
|
if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): |
|
m.weight.data.normal_(0.0, 0.01) |
|
|
|
self.apply(_reset_parameters) |
|
|
|
|
|
def apply_weight_norm(self): |
|
def _apply_weight_norm(m): |
|
if isinstance(m, nn.Conv1d) or isinstance( |
|
m, nn.ConvTranspose1d |
|
): |
|
nn.utils.parametrizations.weight_norm(m) |
|
|
|
self.apply(_apply_weight_norm) |
|
|
|
|
|
def remove_weight_norm(self): |
|
def _remove_weight_norm(m): |
|
try: |
|
                # Mirror apply_weight_norm above, which uses the parametrizations API.
                nn.utils.parametrize.remove_parametrizations(m, "weight")
|
except ValueError: |
|
return |
|
|
|
self.apply(_remove_weight_norm) |
|
|
|
|
|
|
|
@si_module |
|
class GaussianZ(nn.Module): |
|
class Config: |
|
dim: int |
|
latent_dim: int |
|
bias: bool = False |
|
use_weight_norm: bool = False |
|
|
|
def __init__(self, c: Config): |
|
super().__init__() |
|
|
|
self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias) |
|
self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias) |
|
|
|
if c.use_weight_norm: |
|
self.proj_in = weight_norm(self.proj_in) |
|
self.proj_out = weight_norm(self.proj_out) |
|
|
|
def reparam(self, mu, logvar): |
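        # Reparameterization trick: z = mu + sigma * eps with sigma = exp(logvar / 2) and eps ~ N(0, I).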
|
std = T.exp(logvar / 2) |
|
eps = T.randn_like(std) |
|
return mu + eps * std |
|
|
|
def kl_divergence(self, mu, logvar): |
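        # Closed-form KL(N(mu, sigma^2) || N(0, I)): -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), summed over time and latent dims, averaged over the batch.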
|
return T.mean(-0.5 * T.sum( |
|
1 + logvar - mu.pow(2) - logvar.exp(), |
|
dim=(1, 2)) |
|
) |
|
|
|
def repr_from_latent(self, latent: Union[dict, T.Tensor]): |
|
if isinstance(latent, T.Tensor): |
|
z = latent |
|
else: |
|
z = self.reparam(latent['mu'], latent['logvar']) |
|
l = self.proj_out(z) |
|
return l |
|
|
|
def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]: |
|
mu, logvar = self.proj_in(x).chunk(2, dim=-1) |
|
kl_div = self.kl_divergence(mu, logvar) |
|
z = self.reparam(mu, logvar) |
|
xhat = self.proj_out(z) |
|
latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div} |
|
return xhat, latent |
|
|
|
|
|
|
|
@si_module |
|
class WaveCodec(nn.Module): |
|
class Config: |
|
resnet_config: ResNetStack.Config = None |
|
sample_rate: int = 16_000 |
|
use_weight_norm: bool = False |
|
|
|
compressor_config: dataclass = None |
|
|
|
norm_stddev: float = 1.0 |
|
|
|
def __init__(self, c: Config): |
|
super().__init__() |
|
self.norm_stddev = c.norm_stddev |
|
self.encoder = c.resnet_config(mode='encoder') |
|
self.sample_rate = c.sample_rate |
|
|
|
self.total_stride = 1 |
|
for stride in c.resnet_config.strides: |
|
self.total_stride *= stride |
|
self.tokens_per_second = self.sample_rate / self.total_stride |
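        # e.g. with the strides used in make_tokenizer (2, 2, 4, 5, 5, 5), the total stride is 2000, i.e. 16000 / 2000 = 8 latent frames per second.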
|
|
|
self.compressor = c.compressor_config(dim=self.encoder.middle_channels) |
|
|
|
self.decoder = c.resnet_config(mode='decoder') |
|
|
|
if c.use_weight_norm: |
|
self.encoder.apply_weight_norm() |
|
self.decoder.apply_weight_norm() |
|
self.encoder.reset_parameters() |
|
self.decoder.reset_parameters() |
|
|
|
def encode(self, data): |
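        # Expects a channel-first waveform (B, input_channels, T); inputs are scaled by 1 / norm_stddev before encoding.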
|
return self.encoder(data/self.norm_stddev) |
|
|
|
def decode(self, latent): |
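        # latent is (B, T', D); transpose to channel-first for the decoder, then undo the input scaling.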
|
return self.decoder(latent.transpose(1, 2))*self.norm_stddev |
|
|
|
@T.no_grad() |
|
def latent_from_data(self, data, get_parameters=False): |
|
x = self.encode(data) |
|
l_in = x.transpose(1, 2) |
|
l, latent = self.compressor(l_in) |
|
return latent['z'] if not get_parameters else { |
|
'mu': latent['mu'], |
|
'logvar': latent['logvar'], |
|
'z': latent['z'], |
|
} |
|
|
|
@T.no_grad() |
|
def data_from_latent(self, latent): |
|
l = self.compressor.repr_from_latent(latent) |
|
x = self.decode(l) |
|
return x |
|
|
|
def process(self, x): |
|
return self.latent_from_data(x) |
|
|
|
def unprocess(self, latent): |
|
return self.data_from_latent(latent) |
|
|
|
def forward(self, audio_input): |
|
x = self.encode(audio_input) |
|
|
|
l_in = x.transpose(1, 2) |
|
l, latent = self.compressor(l_in) |
|
|
|
xhat = self.decode(l) |
|
return xhat, latent |
|
|
|
|
|
|
|
def make_tokenizer(device='cuda'): |
|
generator_config = WaveCodec.Config( |
|
resnet_config=ResNetStack.Config( |
|
input_channels=1, |
|
output_channels=1, |
|
encode_channels=16, |
|
decode_channel_multiplier=4, |
|
kernel_size=7, |
|
bias=True, |
|
channel_ratios=(4, 8, 16, 16, 16, 16), |
|
strides=(2, 2, 4, 5, 5, 5), |
|
mode=None, |
|
), |
|
use_weight_norm=True, |
|
|
|
compressor_config=GaussianZ.Config( |
|
dim=None, |
|
latent_dim=32, |
|
|
|
bias=True, |
|
use_weight_norm=True |
|
), |
|
|
|
norm_stddev=0.05, |
|
) |
|
checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97") |
|
|
|
tokenizer = generator_config() |
|
|
|
load_result = tokenizer.load_state_dict(checkpoint, strict=False) |
|
print_colored(f"Loaded tokenizer state dict: {load_result}", "grey") |
|
|
|
tokenizer = tokenizer.eval() |
|
|
|
if device == 'cuda': |
|
tokenizer = tokenizer.bfloat16() |
|
tokenizer = tokenizer.to(device) |
|
    tokenizer.requires_grad_(False)
|
return tokenizer |
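

# Minimal usage sketch (an illustration, not part of the original module): encode one
# second of dummy audio to latents and decode it back. Assumes the checkpoint referenced
# by make_tokenizer is available locally.
if __name__ == "__main__":
    device = 'cuda' if T.cuda.is_available() else 'cpu'
    tokenizer = make_tokenizer(device=device)

    audio = T.randn(1, 1, tokenizer.sample_rate)  # (B, channels, T) at 16 kHz
    param = next(tokenizer.parameters())
    audio = audio.to(device=param.device, dtype=param.dtype)

    latents = tokenizer.latent_from_data(audio)   # (B, T // total_stride, latent_dim)
    recon = tokenizer.data_from_latent(latents)   # (B, 1, ~T)
    print(latents.shape, recon.shape)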
|
|
|
|