import math
from typing import List
from typing import Union

import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn

from .base import CodecMixin
from ..nn.layers import Snake1d
from ..nn.layers import WNConv1d
from ..nn.layers import WNConvTranspose1d
from ..nn.quantize import ResidualVectorQuantize
from .encodec import SConv1d, SConvTranspose1d, SLSTM


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        # Guard against convs built without a bias term, which would
        # otherwise make constant_ raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, causal: bool = False):
        super().__init__()
        conv1d_type = SConv1d
        # "Same" padding for a 7-tap kernel at this dilation.
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            conv1d_type(
                dim,
                dim,
                kernel_size=7,
                dilation=dilation,
                padding=pad,
                causal=causal,
                norm="weight_norm",
            ),
            Snake1d(dim),
            conv1d_type(dim, dim, kernel_size=1, causal=causal, norm="weight_norm"),
        )

    def forward(self, x):
        y = self.block(x)
        # Center-trim the skip path if the block shortened the sequence,
        # so the residual addition stays aligned.
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


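# Minimal sketch of ResidualUnit's shape behavior (non-causal case, assuming
# SConv1d applies the requested symmetric padding): the dilated 7-tap conv and
# the 1x1 conv both preserve sequence length, so no trimming occurs and the
# unit maps (B, dim, T) -> (B, dim, T), e.g.:
#
#     unit = ResidualUnit(dim=16, dilation=9)
#     assert unit(torch.randn(1, 16, 512)).shape == (1, 16, 512)

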
class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1, causal: bool = False):
        super().__init__()
        conv1d_type = SConv1d
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1, causal=causal),
            ResidualUnit(dim // 2, dilation=3, causal=causal),
            ResidualUnit(dim // 2, dilation=9, causal=causal),
            Snake1d(dim // 2),
            conv1d_type(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm="weight_norm",
            ),
        )

    def forward(self, x):
        return self.block(x)


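# Shape note (non-causal case): the strided conv in EncoderBlock pairs
# kernel_size = 2 * stride with padding = ceil(stride / 2), so for the strides
# used here (>= 2) an input whose length is a multiple of the stride is
# downsampled to exactly length // stride.

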
class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        conv1d_type = SConv1d

        self.block = [
            conv1d_type(
                1, d_model, kernel_size=7, padding=3, causal=causal, norm="weight_norm"
            )
        ]

        # Double the channel count at every downsampling stage.
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride, causal=causal)]

        self.use_lstm = lstm
        if lstm:
            self.block += [SLSTM(d_model, lstm)]

        self.block += [
            Snake1d(d_model),
            conv1d_type(
                d_model,
                d_latent,
                kernel_size=3,
                padding=1,
                causal=causal,
                norm="weight_norm",
            ),
        ]

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        causal: bool = False,
    ):
        super().__init__()
        conv1d_type = SConvTranspose1d
        self.block = nn.Sequential(
            Snake1d(input_dim),
            conv1d_type(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                causal=causal,
                norm="weight_norm",
            ),
            ResidualUnit(output_dim, dilation=1, causal=causal),
            ResidualUnit(output_dim, dilation=3, causal=causal),
            ResidualUnit(output_dim, dilation=9, causal=causal),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        causal: bool = False,
        lstm: int = 2,
    ):
        super().__init__()
        conv1d_type = SConv1d

        layers = [
            conv1d_type(
                input_channel,
                channels,
                kernel_size=7,
                padding=3,
                causal=causal,
                norm="weight_norm",
            )
        ]

        if lstm:
            layers += [SLSTM(channels, num_layers=lstm)]

        # Halve the channel count at every upsampling stage.
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride, causal=causal)]

        layers += [
            Snake1d(output_dim),
            conv1d_type(
                output_dim,
                d_out,
                kernel_size=7,
                padding=3,
                causal=causal,
                norm="weight_norm",
            ),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


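# Minimal Decoder shape sketch (assuming the defaults used by DAC below):
# channels halve while time upsamples by each rate, so with channels=1536 and
# rates=[8, 8, 4, 2] a (B, input_channel, T) latent becomes (B, 1, T * 512)
# audio, mirroring the Encoder's total stride of 2 * 4 * 8 * 8 = 512.

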
class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        lstm: int = 2,
        causal: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        # Default latent width matches the encoder's final channel count,
        # which doubles once per downsampling stage.
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        # One latent frame per hop_length input samples.
        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim, encoder_rates, latent_dim, causal=causal, lstm=lstm
        )

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            lstm=lstm,
            causal=causal,
        )
        self.apply(init_weights)

        self.delay = self.get_delay()

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        # Right-pad so the input length is a whole number of hops.
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

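    # Worked example of the padding arithmetic above (hypothetical numbers):
    # with hop_length = 512 and length = 44100, ceil(44100 / 512) * 512 = 44544,
    # so right_pad = 444 and the padded signal is exactly 87 hops long.
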
    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            A tuple with the following elements:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z = self.encoder(audio_data)
        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            z, n_quantizers
        )
        return z, codes, latents, commitment_loss, codebook_loss

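    # A minimal usage sketch for encode() (assumes the default 44.1 kHz
    # configuration; run the input through preprocess() first so T is a
    # multiple of the hop length):
    #
    #     model = DAC()
    #     audio = model.preprocess(torch.randn(1, 1, 44100), 44100)
    #     z, codes, latents, commit_loss, cb_loss = model.encode(audio)
    #     # codes: (1, n_codebooks, 44544 // 512) integer codebook indices
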
    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "audio" : Tensor[B x 1 x length]
                Decoded audio data, trimmed back to the input length.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = self.encode(
            audio_data, n_quantizers
        )

        x = self.decode(z)
        return {
            "audio": x[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }


if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    # Annotate every module's repr with its parameter count, in millions.
    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Forward pass: the reconstruction should match the input length.
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Probe the receptive field: backprop a single unit of gradient from the
    # middle output sample ...
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1
    out.backward(grad)

    # ... and count how many input samples received a non-zero gradient.
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across channels
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    # Exercise the CodecMixin compress/decompress round trip on a minute of noise.
    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
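
    # A minimal encode round-trip sketch (assumes the default 44.1 kHz
    # configuration): preprocess() right-pads one second of audio to a whole
    # number of hops (ceil(44100 / 512) * 512 = 44544 samples, i.e. 87 frames).
    signal = torch.randn(1, 1, 44100)
    padded = model.preprocess(signal, 44100)
    z, codes, latents, _, _ = model.encode(padded)
    print("Latent frames:", z.shape[-1])  # expected: 87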