r3gm's picture
Upload 224 files
c2dad70
raw
history blame
4.26 kB
from abc import ABCMeta
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from .modules import TFC_TDF
# Size constant used by Mixer below: its linear layer maps (dim_s + 1) * 2
# input features to dim_s * 2 output features, and its output is reshaped to
# (dim_s, 2, -1). Presumably the number of separated sources -- TODO confirm.
dim_s = 4
class AbstractMDXNet(LightningModule):
    """Common base class for the MDX-Net models defined in this file.

    Stores the shared configuration values as attributes and creates two
    non-trainable tensors (``window`` and ``freq_pad``). Subclasses supply the
    network layers and ``forward``.
    """

    # Python 2 style metaclass declaration. Under Python 3 this is only a
    # plain class attribute and does not change how the class is created; it
    # is kept so the attribute stays available to any code that reads it.
    __metaclass__ = ABCMeta

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap):
        """Store the shared settings and build the fixed helper tensors.

        Args:
            target_name: Stored as ``self.target_name``; not otherwise used here.
            lr: Learning rate handed to the optimizer in
                ``configure_optimizers``.
            optimizer: Optimizer selector; ``'rmsprop'`` and ``'adamw'`` are
                the values ``configure_optimizers`` understands.
            dim_c: Channel count; used as the channel dimension of
                ``freq_pad``.
            dim_f: Number of frequency bins kept; ``freq_pad`` covers the
                remaining ``n_bins - dim_f`` bins.
            dim_t: Size of the last dimension of ``freq_pad``
                (presumably the number of time frames -- TODO confirm).
            n_fft: Window length; also defines ``n_bins = n_fft // 2 + 1``.
            hop_length: Stored as ``self.hop_length``; not otherwise used here.
            overlap: Accepted for signature compatibility with subclasses; it
                is not stored or used by this class.
        """
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer

        self.dim_c = dim_c
        self.dim_f = dim_f
        self.dim_t = dim_t

        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length

        # Both tensors are registered as parameters with requires_grad=False,
        # so they appear in the state dict but are never trained.
        self.window = nn.Parameter(
            torch.hann_window(window_length=self.n_fft, periodic=True),
            requires_grad=False,
        )
        # All-zero block spanning the n_bins - dim_f bins that dim_f leaves
        # out (presumably used to pad the frequency axis back to n_bins --
        # TODO confirm against the code that consumes it).
        self.freq_pad = nn.Parameter(
            torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]),
            requires_grad=False,
        )

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optimizer``.

        Returns:
            ``torch.optim.RMSprop`` for ``'rmsprop'`` or ``torch.optim.AdamW``
            for ``'adamw'``, created over ``self.parameters()`` with learning
            rate ``self.lr``.

        Raises:
            ValueError: If ``self.optimizer`` is neither of the two supported
                names. Previously the method fell through and returned
                ``None`` without any indication of the misconfiguration.
        """
        if self.optimizer == 'rmsprop':
            return torch.optim.RMSprop(self.parameters(), self.lr)
        if self.optimizer == 'adamw':
            return torch.optim.AdamW(self.parameters(), self.lr)
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'rmsprop' or 'adamw'.".format(self.optimizer)
        )
class ConvTDFNet(AbstractMDXNet):
    """Encoder / bottleneck / decoder network built from ``TFC_TDF`` blocks.

    The network has ``num_blocks // 2`` encoder levels, one bottleneck block
    and the same number of decoder levels. Each encoder level is followed by a
    stride-2 convolution that adds ``g`` channels; each decoder level is
    preceded by a stride-2 transposed convolution that removes ``g`` channels.
    Decoder inputs are multiplied element-wise with the matching encoder
    output.
    """

    def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
                 num_blocks, l, g, k, bn, bias, overlap):
        """Build all layers.

        Args:
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
            overlap: Forwarded unchanged to ``AbstractMDXNet.__init__``.
            optimizer: Additionally selects the normalisation layer:
                ``'rmsprop'`` uses ``nn.BatchNorm2d`` and ``'adamw'`` uses
                ``nn.GroupNorm`` with 2 groups.
            num_blocks: Total block budget; ``num_blocks // 2`` encoder and
                decoder levels are created around one bottleneck block.
            l, k, bn, bias: Passed straight through to every ``TFC_TDF``
                block; their meaning is defined by that class.
            g: Output channels of the first convolution and the number of
                channels added (encoder) or removed (decoder) per level.

        Raises:
            ValueError: If ``optimizer`` is neither ``'rmsprop'`` nor
                ``'adamw'``. Previously this surfaced as an
                ``UnboundLocalError`` on the normalisation factory.
        """
        super().__init__(
            target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap)
        # Must run inside __init__ so Lightning can capture the init arguments.
        self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        # `norm` is a factory: called with a channel count, it returns a
        # normalisation layer for that many channels.
        if optimizer == 'rmsprop':
            norm = nn.BatchNorm2d
        elif optimizer == 'adamw':
            def norm(num_channels):
                return nn.GroupNorm(2, num_channels)
        else:
            raise ValueError(
                "Unsupported optimizer {!r}; expected 'rmsprop' or 'adamw'.".format(optimizer)
            )

        self.n = num_blocks // 2
        scale = (2, 2)

        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
            norm(g),
            nn.ReLU(),
        )

        # f and c track the size passed to TFC_TDF and the channel count at
        # the current level while the layers are being created.
        f = self.dim_f
        c = g

        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()
        for _ in range(self.n):
            self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    norm(c + g),
                    nn.ReLU(),
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm)

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for _ in range(self.n):
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    norm(c - g),
                    nn.ReLU(),
                )
            )
            f = f * 2
            c -= g
            self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias, norm=norm))

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
        )

    def forward(self, x):
        """Run the network on ``x`` and return a tensor with ``dim_c`` channels.

        ``x`` is fed to a ``Conv2d`` expecting ``dim_c`` input channels. Its
        last two axes are swapped before the encoder and swapped back before
        the final 1x1 convolution (presumably frequency and time -- TODO
        confirm against the caller).
        """
        x = self.first_conv(x)
        x = x.transpose(-1, -2)

        # Encoder: keep each block's output for the decoder, then downsample.
        ds_outputs = []
        for encode, downsample in zip(self.encoding_blocks, self.ds):
            x = encode(x)
            ds_outputs.append(x)
            x = downsample(x)

        x = self.bottleneck_block(x)

        # Decoder: upsample, multiply by the encoder output of the same level
        # (deepest level first), then run the decoding block.
        for upsample, decode, skip in zip(self.us, self.decoding_blocks, reversed(ds_outputs)):
            x = upsample(x)
            x *= skip
            x = decode(x)

        x = x.transpose(-1, -2)
        x = self.final_conv(x)
        return x
class Mixer(nn.Module):
    """Single bias-free linear layer whose weights are loaded from a file.

    The layer maps ``(dim_s + 1) * 2`` features to ``dim_s * 2`` features
    (presumably the stereo channels of ``dim_s`` estimates plus one extra
    signal in, and ``dim_s`` stereo signals out -- TODO confirm).
    """

    def __init__(self, device, mixer_path):
        """Create the layer and load its state dict.

        Args:
            device: Passed to ``torch.load`` as ``map_location``.
            mixer_path: Path of the saved state dict for this module.
        """
        super().__init__()
        in_features = (dim_s + 1) * 2
        out_features = dim_s * 2
        self.linear = nn.Linear(in_features, out_features, bias=False)

        # NOTE: torch.load can unpickle arbitrary objects depending on the
        # torch version and its weights_only setting; only load trusted files.
        state = torch.load(mixer_path, map_location=device)
        self.load_state_dict(state)

    def forward(self, x):
        """Apply the linear layer across the trailing axis of ``x``.

        ``x`` is reshaped to ``(1, (dim_s + 1) * 2, N)``, the layer is applied
        to each of the ``N`` positions, and the result is returned with shape
        ``(dim_s, 2, N)``.
        """
        flat = x.reshape(1, (dim_s + 1) * 2, -1)
        # nn.Linear acts on the last axis, so move the feature axis there.
        mixed = self.linear(flat.transpose(-1, -2))
        restored = mixed.transpose(-1, -2)
        return restored.reshape(dim_s, 2, -1)