Spaces:
Runtime error
Runtime error
import math | |
from typing import Dict, List, Tuple | |
import torch | |
from torch import Tensor, nn | |
from torch.nn import functional as F | |
from .unet import UNet | |
def batchify(tensor: Tensor, T: int) -> Tensor:
    """Cut the last axis of ``tensor`` into length-``T`` segments stacked along dim 0.

    The last axis is right-padded with zeros up to the next multiple of ``T``
    so that every segment holds exactly ``T`` samples.

    Args:
        tensor (Tensor): B x C x F x L input.
        T (int): segment length.

    Returns:
        Tensor of size (B * ceil(L / T)) x C x F x T. Segments are ordered
        segment-major: row ``k * B + b`` holds segment ``k`` of sample ``b``.
    """
    length = tensor.size(-1)
    shortfall = -length % T  # zeros needed to reach the next multiple of T
    padded = F.pad(tensor, [0, shortfall])
    segments = padded.split(T, dim=-1)
    return torch.cat(segments, dim=0)
class Splitter(nn.Module):
    """Mask-based source separator with one UNet per output stem.

    Each UNet predicts a mask over the magnitude spectrogram of the input
    mixture. The masks are normalised against each other and multiplied onto
    the mixture's STFT, producing one masked STFT per stem.
    """

    def __init__(self, stem_names: List[str] = None):
        """
        Args:
            stem_names: names of the stems to separate, one UNet each.
                Defaults to ``["vocals", "accompaniment"]``. The names become
                ``state_dict`` keys (``stems.<name>...``), so they must match
                the checkpoint being loaded.
        """
        super().__init__()
        # stft config
        self.F = 1024  # number of lowest frequency bins kept by compute_stft
        self.T = 512  # frames per segment fed to the UNets (see batchify)
        self.win_length = 4096  # also used as n_fft
        self.hop_length = 1024
        # Frozen Parameter: stored in the state_dict and moved with the module.
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
        if stem_names is None:
            stem_names = ["vocals", "accompaniment"]
        self.stems = nn.ModuleDict(
            {name: UNet(in_channels=2) for name in stem_names}
        )

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the low-frequency STFT of ``wav`` and its magnitude.

        Args:
            wav (Tensor): B x L waveform.

        Returns:
            stft: B x F x frames x 2 real view of the STFT (last dim holds the
                real and imaginary parts), truncated to the lowest ``self.F``
                frequency bins.
            mag: B x F x frames magnitude of ``stft``.
        """
        # return_complex=False is deprecated; compute a complex STFT and view
        # it as real so the (..., 2) output layout stays the same.
        stft = torch.view_as_real(
            torch.stft(
                wav,
                n_fft=self.win_length,
                hop_length=self.hop_length,
                window=self.win,
                center=True,
                return_complex=True,
                pad_mode="constant",
            )
        )
        # only keep freqs smaller than self.F
        stft = stft[:, : self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)
        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Invert a (possibly truncated) real-view STFT back to a waveform.

        Args:
            stft (Tensor): B x F' x frames x 2, as produced by compute_stft
                (optionally masked). Frequency bins above F' are zero-filled.

        Returns:
            Detached waveform tensor.
        """
        # zero-fill the high-frequency bins dropped by compute_stft
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        # torch.istft requires complex input in recent PyTorch releases (real
        # (..., 2) tensors are rejected), so convert the real view back.
        wav = torch.istft(
            torch.view_as_complex(stft.contiguous()),
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem)

        Args:
            wav (tensor): 2 x L (any singleton dims are squeezed away)
        Returns:
            masked stfts by track name, each 2 x F x frames x 2
        """
        # stft - 2 x F x frames x 2
        # stft_mag - 2 x F x frames
        stft, stft_mag = self.compute_stft(wav.squeeze())
        L = stft.size(2)  # number of STFT frames
        # 1 x 2 x F x frames
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
        # compute stems' mask
        masks = {name: net(stft_mag) for name, net in self.stems.items()}
        # shared denominator: sum of squared masks (+eps to avoid 0/0)
        mask_sum = sum(m ** 2 for m in masks.values()) + 1e-10

        def apply_mask(mask):
            # this stem's share of the total squared mask
            mask = (mask ** 2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # undo batchify: re-join the segments along the time axis
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)  # 1 x 2 x F x B*T
            # drop the zero-padded tail frames added by batchify
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x frames x 1
            return stft * mask

        return {name: apply_mask(m) for name, m in masks.items()}

    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem)

        Args:
            wav (tensor): 2 x L
        Returns:
            wavs by track name
        """
        # inverse_stft detaches its output anyway, so building an autograd
        # graph through the UNets here would only waste memory.
        with torch.no_grad():
            stft_masks = self.forward(wav)
            return {
                name: self.inverse_stft(stft_masked)
                for name, stft_masked in stft_masks.items()
            }
def from_pretrained(cls, model_path: str, map_location="cpu"):
    """Instantiate ``cls`` with no arguments and load weights from ``model_path``.

    NOTE(review): takes the class explicitly as ``cls`` and is defined at
    module level here; presumably intended as a ``@classmethod`` of
    ``Splitter`` — confirm against callers.

    Args:
        cls: module class to build; must be constructible with no arguments.
        model_path: path to a file written by ``torch.save`` that holds a
            state dict accepted by ``cls().load_state_dict``.
        map_location: where ``torch.load`` places the stored tensors. The
            default ``"cpu"`` lets GPU-saved checkpoints load on CPU-only
            hosts. The loaded values are copied into the new module's own
            tensors either way.

    Returns:
        The constructed module with the checkpoint weights loaded.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=map_location)
    model = cls()
    model.load_state_dict(checkpoint)
    return model