Spaces:
Sleeping
Sleeping
from abc import ABCMeta | |
import torch | |
import torch.nn as nn | |
from pytorch_lightning import LightningModule | |
from .modules import TFC_TDF | |
dim_s = 4 | |
class AbstractMDXNet(LightningModule): | |
__metaclass__ = ABCMeta | |
def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap): | |
super().__init__() | |
self.target_name = target_name | |
self.lr = lr | |
self.optimizer = optimizer | |
self.dim_c = dim_c | |
self.dim_f = dim_f | |
self.dim_t = dim_t | |
self.n_fft = n_fft | |
self.n_bins = n_fft // 2 + 1 | |
self.hop_length = hop_length | |
self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False) | |
self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False) | |
def configure_optimizers(self): | |
if self.optimizer == 'rmsprop': | |
return torch.optim.RMSprop(self.parameters(), self.lr) | |
if self.optimizer == 'adamw': | |
return torch.optim.AdamW(self.parameters(), self.lr) | |
class ConvTDFNet(AbstractMDXNet): | |
def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, | |
num_blocks, l, g, k, bn, bias, overlap): | |
super(ConvTDFNet, self).__init__( | |
target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap) | |
self.save_hyperparameters() | |
self.num_blocks = num_blocks | |
self.l = l | |
self.g = g | |
self.k = k | |
self.bn = bn | |
self.bias = bias | |
if optimizer == 'rmsprop': | |
norm = nn.BatchNorm2d | |
if optimizer == 'adamw': | |
norm = lambda input:nn.GroupNorm(2, input) | |
self.n = num_blocks // 2 | |
scale = (2, 2) | |
self.first_conv = nn.Sequential( | |
nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)), | |
norm(g), | |
nn.ReLU(), | |
) | |
f = self.dim_f | |
c = g | |
self.encoding_blocks = nn.ModuleList() | |
self.ds = nn.ModuleList() | |
for i in range(self.n): | |
self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) | |
self.ds.append( | |
nn.Sequential( | |
nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale), | |
norm(c + g), | |
nn.ReLU() | |
) | |
) | |
f = f // 2 | |
c += g | |
self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm) | |
self.decoding_blocks = nn.ModuleList() | |
self.us = nn.ModuleList() | |
for i in range(self.n): | |
self.us.append( | |
nn.Sequential( | |
nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale), | |
norm(c - g), | |
nn.ReLU() | |
) | |
) | |
f = f * 2 | |
c -= g | |
self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)) | |
self.final_conv = nn.Sequential( | |
nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)), | |
) | |
def forward(self, x): | |
x = self.first_conv(x) | |
x = x.transpose(-1, -2) | |
ds_outputs = [] | |
for i in range(self.n): | |
x = self.encoding_blocks[i](x) | |
ds_outputs.append(x) | |
x = self.ds[i](x) | |
x = self.bottleneck_block(x) | |
for i in range(self.n): | |
x = self.us[i](x) | |
x *= ds_outputs[-i - 1] | |
x = self.decoding_blocks[i](x) | |
x = x.transpose(-1, -2) | |
x = self.final_conv(x) | |
return x | |
class Mixer(nn.Module): | |
def __init__(self, device, mixer_path): | |
super(Mixer, self).__init__() | |
self.linear = nn.Linear((dim_s+1)*2, dim_s*2, bias=False) | |
self.load_state_dict( | |
torch.load(mixer_path, map_location=device) | |
) | |
def forward(self, x): | |
x = x.reshape(1,(dim_s+1)*2,-1).transpose(-1,-2) | |
x = self.linear(x) | |
return x.transpose(-1,-2).reshape(dim_s,2,-1) |