File size: 4,159 Bytes
62e9d65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import math
from typing import Dict, List, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from .unet import UNet


def batchify(tensor: Tensor, T: int) -> Tensor:
    """
    Partition ``tensor`` into fixed-length segments along the last axis,
    zero-padding the ragged tail so every segment has exactly T steps.

    Args:
        tensor (Tensor): B x C x F x L
        T (int): segment length along the last (time) dimension

    Returns:
        Tensor of size (B * ceil(L / T)) x C x F x T, segments stacked
        into the batch dimension.
    """
    # Right-pad the time axis up to the next multiple of T.
    # (-L) % T is the amount of padding needed (0 when L is already a multiple).
    tail = -tensor.size(-1) % T
    padded = F.pad(tensor, [0, tail])
    # Slice the time axis into T-sized chunks, then stack them along batch.
    chunks = torch.split(padded, T, dim=-1)
    return torch.cat(chunks, dim=0)


class Splitter(nn.Module):
    """
    Spleeter-style source separator.

    Computes an STFT of the input waveform, predicts one soft magnitude mask
    per stem with a UNet, turns those into ratio masks, and applies them to
    the mixture STFT. ``separate`` additionally inverts the masked STFTs back
    to waveforms.
    """

    def __init__(self, stem_names: List[str] = None):
        """
        Args:
            stem_names: names of the stems to separate; one UNet is built per
                stem. Defaults to ``["vocals", "accompaniment"]``.
        """
        super(Splitter, self).__init__()

        # stft config
        self.F = 1024  # number of low-frequency STFT bins kept for the nets
        self.T = 512  # segment length (frames) fed to each UNet
        self.win_length = 4096
        self.hop_length = 1024
        # Non-trainable Parameter so the window follows .to(device)/.cuda().
        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)

        # BUGFIX: the original accepted `stem_names` but ignored it and always
        # built the two default stems. The argument is now honored; the default
        # (None) preserves the previous behavior exactly.
        stem_names = stem_names or ["vocals", "accompaniment"]
        self.stems = nn.ModuleDict(
            {name: UNet(in_channels=2) for name in stem_names}
        )

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L

        Returns:
            (stft, mag): the real-valued STFT of shape B x self.F x frames x 2
            (last axis = real/imag) and its magnitude B x self.F x frames.
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,
            pad_mode="constant",
        )

        # only keep freqs smaller than self.F
        stft = stft[:, : self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)

        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """
        Inverts a (real-valued, truncated) STFT back to a waveform.

        The frequency axis was cropped to self.F in compute_stft, so it is
        zero-padded back to the full n_fft//2 + 1 bins before istft.
        """
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        wav = torch.istft(
            stft,
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem)

        Args:
            wav (tensor): 2 x L
        Returns:
            masked stfts by track name
        """
        # stft - 2 X F x L x 2
        # stft_mag - 2 X F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())

        L = stft.size(2)

        # 1 x 2 x F x T — channels become the "image" channels of the UNet
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # compute stems' mask
        masks = {name: net(stft_mag) for name, net in self.stems.items()}

        # Ratio-mask denominator: sum of squared masks, with a small epsilon
        # to avoid division by zero where all masks are ~0.
        mask_sum = sum([m ** 2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            # NOTE: precedence makes this mask**2 + (1e-10 / 2), i.e. half the
            # epsilon in the numerator — matching the reference implementation.
            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
            mask = mask.transpose(2, 3)  # B x 2 X F x T

            # Undo batchify: re-concatenate the segments along the time axis.
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)

            # Trim the zero-padded tail back to the original frame count.
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            stft_masked = stft * mask
            return stft_masked

        return {name: apply_mask(m) for name, m in masks.items()}

    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates stereo wav into different tracks (1 predicted track per stem)

        Args:
            wav (tensor): 2 x L
        Returns:
            wavs by track name
        """
        stft_masks = self.forward(wav)

        return {
            name: self.inverse_stft(stft_masked)
            for name, stft_masked in stft_masks.items()
        }

    @classmethod
    def from_pretrained(cls, model_path: str):
        """
        Builds a Splitter (default stems) and loads weights from a checkpoint.

        Args:
            model_path: path to a state-dict checkpoint saved with torch.save.
        """
        checkpoint = torch.load(model_path)
        model = cls()
        model.load_state_dict(checkpoint)
        return model