import math
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from .unet import UNet


def batchify(tensor: Tensor, T: int) -> Tensor:
    """
    Partitions a tensor into segments of length T along the last axis,
    zero-padding any ragged tail segment.

    Args:
        tensor (Tensor): B x C x F x L
        T (int): segment length

    Returns:
        Tensor of size (B * ceil(L / T)) x C x F x T
    """
    # Zero-pad the original tensor to an even multiple of T
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])
    # Partition the tensor into segments of length T and stack them into a batch
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
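
# Illustrative shape check (comment only, values chosen for the example): with
# B = 1, C = 2, F = 1024, L = 1200 and T = 512, batchify pads the last axis to
# ceil(1200 / 512) * 512 = 1536, splits it into 3 segments of 512 frames, and
# stacks them into a batch of shape 3 x 2 x 1024 x 512.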


class Splitter(nn.Module):
    def __init__(self, stem_names: List[str] = None):
        super().__init__()
        # STFT config: keep the lowest F frequency bins and feed the nets
        # T time frames at a time
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        # fixed (non-trainable) analysis/synthesis window
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)

        # one spectrogram-masking UNet per stem
        stem_names = stem_names or ["vocals", "accompaniment"]
        self.stems = nn.ModuleDict({name: UNet(in_channels=2) for name in stem_names})

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L

        Returns:
            stft (Tensor): B x F x T x 2, real/imaginary parts in the last dim
            mag (Tensor): B x F x T magnitude spectrogram
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=True,
            pad_mode="constant",
        )
        # keep real/imaginary parts as a trailing dim of size 2
        stft = torch.view_as_real(stft)
        # only keep freqs smaller than self.F
        stft = stft[:, : self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)
        return stft, mag
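
    # Shape note (illustrative): with win_length = 4096 the raw STFT has
    # n_fft // 2 + 1 = 2049 frequency bins per frame; only the lowest
    # self.F = 1024 are kept, so a 2 x L input yields stft of shape
    # 2 x 1024 x n_frames x 2 and mag of shape 2 x 1024 x n_frames.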

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverts a (masked) STFT back to a waveform."""
        # zero-fill the high-frequency bins cropped in compute_stft so the
        # tensor is back to the full win_length // 2 + 1 bins
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        # istft expects a complex-valued tensor in recent PyTorch releases
        wav = torch.istft(
            torch.view_as_complex(stft.contiguous()),
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates a stereo wav into different tracks (one predicted track per stem).

        Args:
            wav (Tensor): 2 x L

        Returns:
            masked STFTs keyed by track name
        """
        # stft     - 2 x F x L x 2
        # stft_mag - 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())

        L = stft.size(2)

        # 1 x 2 x F x L
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # compute each stem's mask from the magnitude spectrogram
        masks = {name: net(stft_mag) for name, net in self.stems.items()}

        # denominator of the soft mask: total energy across stems
        mask_sum = sum([m ** 2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # Wiener-like ratio mask: this stem's energy over the total,
            # eps-smoothed for numerical stability
            mask = (mask ** 2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T
            # un-batchify: stitch the T-frame segments back along the time axis
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            return stft * mask

        return {name: apply_mask(m) for name, m in masks.items()}
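
    # Mask arithmetic (illustrative, two-stem case): the applied masks are
    # (m_a^2 + eps/2) / (m_a^2 + m_b^2 + eps) and likewise for stem b, so the
    # masks sum to exactly 1 and the masked STFTs partition the mixture STFT.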

    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates a stereo wav into different tracks (one predicted track per stem).

        Args:
            wav (Tensor): 2 x L

        Returns:
            waveforms keyed by track name
        """
        stft_masks = self.forward(wav)
        return {
            name: self.inverse_stft(stft_masked)
            for name, stft_masked in stft_masks.items()
        }

    @classmethod
    def from_pretrained(cls, model_path: str) -> "Splitter":
        # load on CPU by default; callers can move the model with .to(device)
        checkpoint = torch.load(model_path, map_location="cpu")
        model = cls()
        model.load_state_dict(checkpoint)
        return model
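

# A minimal usage sketch, not part of the original module: it assumes torchaudio
# is installed, and the checkpoint and audio file names below are hypothetical.
if __name__ == "__main__":
    import torchaudio

    model = Splitter.from_pretrained("splitter.pt")  # hypothetical checkpoint path
    model.eval()

    # torchaudio.load returns a channels-first (2 x L) tensor for stereo files
    wav, sample_rate = torchaudio.load("mix.wav")  # hypothetical input file
    with torch.no_grad():
        stems = model.separate(wav)

    # each separated track is 2 x L, matching torchaudio.save's expected layout
    for name, track in stems.items():
        torchaudio.save(f"{name}.wav", track, sample_rate)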