pengdaqian committed
Commit 62e9d65
1 Parent(s): d7659a0
Dockerfile CHANGED
@@ -13,7 +13,6 @@ ENV PYTHONUNBUFFERED=1
 
 COPY requirements.txt requirements.txt
 RUN pip3 install --no-cache-dir -r requirements.txt
-RUN pip3 install --no-cache-dir --no-dependencies spleeter
 
 RUN mkdir -p $HOME/app/cache/transformers
 RUN chmod 777 $HOME/app/cache/transformers
app.py CHANGED
@@ -16,7 +16,9 @@ import numpy as np
 import soundfile
 from pydub import AudioSegment
 import uuid
-from torchspleeter import split_to_parts
+from torchspleeter.utils import sound_split
+from torchspleeter.splitter import Splitter
+
 import logging
 
 logging.getLogger('numba').setLevel(logging.WARNING)
@@ -84,6 +86,7 @@ load_svc_model("vits_pretrain/sovits5.0-48k-debug.pth", model)
 model.eval()
 model.to(device)
 whisper_model = whisper.inference.load_model(os.path.join("whisper_pretrain", "medium.pt"))
+splitter_model = Splitter.from_pretrained(os.path.join("torchspleeter/models/2stems", "spleeter.pth")).to(device).eval()
 
 
 # warm up
@@ -187,7 +190,7 @@ def svc_main(sid, input_audio):
 uuid_value = uuid.uuid4()
 uuid_string = str(uuid_value)
 input_audio_tmp_file = f'{uuid_string}.wav'
-tmpfile_path = '/tmp'
+tmpfile_path = f'/tmp/{uuid_string}'
 
 #
 # prediction = separator.separate(audio)
@@ -196,9 +199,9 @@ def svc_main(sid, input_audio):
 if not os.path.exists(tmpfile_path):
     os.makedirs(tmpfile_path)
 
-split_to_parts(input_audio_tmp_file, tmpfile_path)
+sound_split(splitter_model, input_audio_tmp_file, tmpfile_path)
 
-curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
+curr_tmp_path = tmpfile_path
 vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
 accompaniment_filepath = os.path.join(curr_tmp_path, 'accompaniment.wav')
 
@@ -344,10 +347,9 @@ def main():
 vc_search.click(auto_search, [vc_input2], [vc_output1, vc_input3])
 vc_submit.click(svc_main, [sid, vc_input3], [vc_output1, vc_output2])
 
-app.queue(max_size=3, api_open=False).launch()
+app.queue(max_size=3, api_open=False).launch(server_port=6006)
 except KeyboardInterrupt:
 app.close()
-separator._get_session().close()
 sys.exit(0)
 
 
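In short, the app.py change swaps the external spleeter call for the vendored torchspleeter port: a Splitter checkpoint is loaded once at startup and each request gets its own /tmp/<uuid> directory into which sound_split() writes the stems. A minimal sketch of that flow, using only the calls shown in the diff (the checkpoint path is carried over from it and assumed to exist in the Space image):

import os
import uuid

import torch

from torchspleeter.splitter import Splitter
from torchspleeter.utils import sound_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the converted 2-stems checkpoint once at startup, as app.py now does
splitter_model = Splitter.from_pretrained("torchspleeter/models/2stems/spleeter.pth").to(device).eval()

uid = str(uuid.uuid4())
tmpfile_path = f"/tmp/{uid}"
os.makedirs(tmpfile_path, exist_ok=True)

# writes <tmpfile_path>/vocals.wav and <tmpfile_path>/accompaniment.wav
sound_split(splitter_model, f"{uid}.wav", tmpfile_path)
vocals_filepath = os.path.join(tmpfile_path, "vocals.wav")
accompaniment_filepath = os.path.join(tmpfile_path, "accompaniment.wav")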
 
torchspleeter/.dockerignore ADDED
@@ -0,0 +1,5 @@
+data/
+models/
+output/
+docs/
+.vscode/
torchspleeter/.gitignore ADDED
@@ -0,0 +1,135 @@
+models/
+output/
+experiments/
+data/
+.vscode/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
torchspleeter/__init__.py CHANGED
@@ -1,4 +0,0 @@
-from torchspleeter.command_interface import *
-
-
-version="0.1.5"
torchspleeter/checkpoints/2stems/audio_example.mp3 DELETED
Binary file (263 kB)
 
torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:385ea3372c6a3ceee01f6ded5504bb7ee1e9f0101950ae58869dc18382deb75c
-size 59050239
torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3e4d6ede1ecad091468550773e77934aac3f1e039c0697fc9039aba9b935e344
-size 59033471
torchspleeter/command_interface.py DELETED
@@ -1,98 +0,0 @@
-"""
-
-This provides an interface to interact with the spleeter system on
-
-
-"""
-
-import os
-from torchspleeter.estimator import Estimator
-import argparse
-import uuid
-import numpy as np
-import librosa
-import soundfile
-import torch
-import pydub
-import os
-import shutil
-
-
-def split_to_parts(inputaudiofile, outputdir, instruments=2, models=[]):
-    filedata = pydub.AudioSegment.from_file(inputaudiofile)
-    sr = filedata.frame_rate
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    # es = Estimator(2, './checkpoints/2stems/model').to(device)
-    # es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
-    es = Estimator()
-    es.eval()
-
-    # load wav audio
-    testaudiofile = inputaudiofile
-    channels = filedata.channels
-    mono_selection = False
-    if not os.path.exists(outputdir):
-        os.makedirs(outputdir, exist_ok=True)
-    if channels == 1:
-        mono_selection = True
-        multichannel = pydub.AudioSegment.from_mono_audiosegments(filedata, filedata)
-        os.makedirs(os.path.join(outputdir, 'tmp'), exist_ok=True)
-        testaudiofile = os.path.join(outputdir, "tmp" + str(uuid.uuid4()) + "." + testaudiofile.split('.')[-1])
-        # testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
-        multichannel.export(out_f=testaudiofile, format=testaudiofile.split('.')[-1])
-    print(mono_selection)
-    print(channels)
-    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast', sr=sr)
-    wav = torch.Tensor(wav).to(device)
-    if mono_selection:
-        shutil.rmtree(os.path.join(outputdir, "tmp"))
-        # os.remove(testaudiofile)
-    wavs = es.separate(wav)
-    outputname = str(uuid.uuid4())
-    returnarray = []
-    for i in range(len(wavs)):
-        finaloutput = os.path.join(outputdir, outputname)
-        fname = '-out_{}.wav'.format(i)
-        fname = finaloutput + fname
-        print('Writing ', fname)
-        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
-        returnarray.append(fname)
-        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
-    return returnarray
-
-
-def get_file_list(dirname):
-    outputfilelist = []
-    for subdir, dirs, files in os.walk(dirname):
-        for file in files:
-            outputfilelist.append(os.path.join(subdir, file))
-
-    return outputfilelist
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description='torchspleeter allows you to separate instrumentals from audio (vocals, instruments, background noise, etc) in a simple, cross platform manner')
-    parser.add_argument('-i', '--inputfile', help='Input Audio File to split into instrumentals', required=True)
-    parser.add_argument('-o', '--output', help='Output directory to deposit split audio', required=True)
-    parser.add_argument('-n', '--number', help="Number of instruments in the model (default 2)", required=False,
-                        default=2, type=int)
-    parser.add_argument('-m', '--modeldir',
-                        help="directory containing number of pre-converted torch compatible model components",
-                        required=False)
-    args = vars(parser.parse_args())
-    print(args)
-    if args['modeldir'] is not None:
-        modelfiles = get_file_list(args['modeldir'])
-        if len(modelfiles) != args['number']:
-            raise ValueError("You must have the same number of models as you do number of instruments!")
-    else:
-        args['modeldir'] = []
-    outputfiles = split_to_parts(args['inputfile'], args['output'], args['number'], args['modeldir'])
-    print("Your output files are:")
-    for item in outputfiles:
-        print(item)
-
-
-if __name__ == "__main__":
-    main()
torchspleeter/data.py ADDED
@@ -0,0 +1,37 @@
+import random
+from pathlib import Path
+from typing import List
+
+import musdb
+import torch
+from torch.utils.data import Dataset
+
+
+class MusdbDataset(Dataset):
+    def __init__(
+        self, root="data/musdb18-wav", is_train: bool = True, targets: List[str] = None
+    ) -> None:
+        super().__init__()
+        root = Path(root)
+        assert root.exists(), f"Path does not exist: {root}"
+        self.mus = musdb.DB(
+            root=root,
+            subsets=["train" if is_train else "test"],
+            is_wav=True,
+        )
+        self.targets = [s for s in targets] if targets else ["vocals", "accompaniment"]
+
+    def __len__(self) -> int:
+        return len(self.mus)
+
+    def __getitem__(self, index):
+        track = self.mus.tracks[index]
+        track.chunk_duration = 5.0
+        track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)
+        x_wav = torch.torch.tensor(track.audio.T, dtype=torch.float32)
+        y_target_wavs = {
+            name: torch.tensor(track.targets[name].audio.T, dtype=torch.float32)
+            for name in self.targets
+        }
+        # original audio (x) and stems (y == targets)
+        return x_wav, y_target_wavs
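MusdbDataset yields a random 5-second stereo chunk plus the matching target stems, so it plugs straight into a standard DataLoader. A minimal sketch (assumes musdb18 has been decoded to wav under data/musdb18-wav and that the repo root is on sys.path so the package import resolves):

from torch.utils.data import DataLoader

from torchspleeter.data import MusdbDataset

dataset = MusdbDataset(root="data/musdb18-wav", is_train=True)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

x_wav, y_target_wavs = next(iter(loader))
print(x_wav.shape)  # 1 x 2 x (5 s of samples)
print({name: wav.shape for name, wav in y_target_wavs.items()})  # vocals / accompaniment chunks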
torchspleeter/estimator.py DELETED
@@ -1,163 +0,0 @@
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-import tqdm
-# from torchaudio.functional import istft
-
-from torchspleeter.unet import UNet
-# from .util import tf2pytorch
-
-import os
-
-dirname = os.path.dirname(__file__)
-defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
-defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def load_ckpt(model, ckpt):
-    state_dict = model.state_dict()
-    for k, v in ckpt.items():
-        if k in state_dict:
-            target_shape = state_dict[k].shape
-            assert target_shape == v.shape
-            state_dict.update({k: torch.from_numpy(v)})
-        else:
-            print('Ignore ', k)
-
-    model.load_state_dict(state_dict)
-    return model
-
-
-def pad_and_partition(tensor, T):
-    """
-    pads zero and partition tensor into segments of length T
-
-    Args:
-        tensor(Tensor): BxCxFxL
-
-    Returns:
-        tensor of size (B*[L/T] x C x F x T)
-    """
-    old_size = tensor.size(3)
-    new_size = math.ceil(old_size / T) * T
-    tensor = F.pad(tensor, [0, new_size - old_size])
-    [b, c, t, f] = tensor.shape
-    split = new_size // T
-    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
-
-
-class Estimator(nn.Module):
-    def __init__(self, num_instrumments=2, checkpoint_path=None):
-        super(Estimator, self).__init__()
-        if checkpoint_path is None:
-            checkpoint_path = [defaultmodel0, defaultmodel1]
-        else:
-            if len(checkpoint_path) < 1:
-                checkpoint_path = [defaultmodel0, defaultmodel1]
-        # stft config
-        self.F = 1024
-        self.T = 512
-        self.win_length = 4096
-        self.hop_length = 1024
-        self.win = nn.Parameter(
-            torch.hann_window(self.win_length, device=device),
-            requires_grad=False
-        )
-
-        ckpts = []
-        if len(checkpoint_path) != num_instrumments:
-            raise ValueError("You must submit as many models as there are instruments!")
-        for ckpt_path in checkpoint_path:
-            ckpts.append(torch.load(ckpt_path))
-
-        # self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
-
-        # ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
-
-        # filter
-        self.instruments = nn.ModuleList()
-        for i in range(num_instrumments):
-            print('Loading model for instrumment {}'.format(i))
-            net = UNet(2)
-            ckpt = ckpts[i]
-            net = load_ckpt(net, ckpt)
-            net.eval()  # change mode to eval
-            self.instruments.append(net)
-
-    def compute_stft(self, wav):
-        """
-        Computes stft feature from wav
-
-        Args:
-            wav (Tensor): B x L
-        """
-
-        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
-                          center=True, return_complex=False, pad_mode='constant')
-
-        # only keep freqs smaller than self.F
-        stft = stft[:, :self.F, :, :]
-        real = stft[:, :, :, 0]
-        im = stft[:, :, :, 1]
-        mag = torch.sqrt(real ** 2 + im ** 2)
-
-        return stft, mag
-
-    def inverse_stft(self, stft):
-        """Inverses stft to wave form"""
-
-        pad = self.win_length // 2 + 1 - stft.size(1)
-        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
-        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-                          window=self.win)
-        return wav.detach()
-
-    def separate(self, wav):
-        """
-        Separates stereo wav into different tracks corresponding to different instruments
-
-        Args:
-            wav (tensor): 2 x L
-        """
-
-        # stft - 2 X F x L x 2
-        # stft_mag - 2 X F x L
-        stft, stft_mag = self.compute_stft(wav)
-
-        L = stft.size(2)
-
-        # 1 x 2 x F x T
-        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
-        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
-        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
-
-        B = stft_mag.shape[0]
-
-        # compute instruments' mask
-        masks = []
-        for net in self.instruments:
-            mask = net(stft_mag)
-            masks.append(mask)
-
-        # compute denominator
-        mask_sum = sum([m ** 2 for m in masks])
-        mask_sum += 1e-10
-
-        wavs = []
-        for mask in tqdm.tqdm(masks):
-            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
-            mask = mask.transpose(2, 3)  # B x 2 X F x T
-
-            mask = torch.cat(
-                torch.split(mask, 1, dim=0), dim=3)
-
-            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
-            stft_masked = stft * mask
-            wav_masked = self.inverse_stft(stft_masked)
-
-            wavs.append(wav_masked)
-
-        return wavs
torchspleeter/requirements.txt ADDED
@@ -0,0 +1,8 @@
+accelerate==0.14.0
+librosa
+musdb==0.4.0
+soundfile
+tensorflow==2.10.0
+torch==1.13.0
+torchaudio
+typer
torchspleeter/run.py ADDED
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+import torch
+import typer
+from utils import sound_split
+
+app = typer.Typer()
+
+
+@app.command()
+def model_summary() -> None:
+    from unet import UNet
+
+    net = UNet()
+    print(net)
+
+
+@app.command()
+def test() -> None:
+    from unet import UNet
+
+    batch_size = 5
+    n_channels = 2
+    x = torch.randn(batch_size, n_channels, 512, 128)
+    print(x.shape)
+    net = UNet(in_channels=n_channels)
+    y = net.forward(x)
+    print(y.shape)
+
+
+@app.command()
+def split(
+    model_path: str = "models/2stems/model",
+    input: str = "data/audio_example.mp3",
+    output_dir: str = "output",
+    offset: float = 0,
+    duration: float = 30,
+    write_src: bool = False,
+) -> None:
+
+    from splitter import Splitter
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    splitter_model = Splitter.from_pretrained(model_path).to(device).eval()
+
+    sound_split(splitter_model, input, output_dir, write_src)
+
+
+if __name__ == "__main__":
+    app()
torchspleeter/splitter.py ADDED
@@ -0,0 +1,140 @@
+import math
+from typing import Dict, List, Tuple
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from .unet import UNet
+
+
+def batchify(tensor: Tensor, T: int) -> Tensor:
+    """
+    partition tensor into segments of length T, zero pad any ragged samples
+    Args:
+        tensor(Tensor): BxCxFxL
+    Returns:
+        tensor of size (B*[L/T] x C x F x T)
+    """
+    # Zero pad the original tensor to an even multiple of T
+    orig_size = tensor.size(-1)
+    new_size = math.ceil(orig_size / T) * T
+    tensor = F.pad(tensor, [0, new_size - orig_size])
+    # Partition the tensor into multiple samples of length T and stack them into a batch
+    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
+
+
+class Splitter(nn.Module):
+    def __init__(self, stem_names: List[str] = None):
+        super(Splitter, self).__init__()
+
+        # stft config
+        self.F = 1024
+        self.T = 512
+        self.win_length = 4096
+        self.hop_length = 1024
+        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
+
+        self.stems = nn.ModuleDict({'vocals': UNet(in_channels=2),
+                                    'accompaniment': UNet(in_channels=2)})
+
+    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Computes stft feature from wav
+        Args:
+            wav (Tensor): B x L
+        """
+        stft = torch.stft(
+            wav,
+            n_fft=self.win_length,
+            hop_length=self.hop_length,
+            window=self.win,
+            center=True,
+            return_complex=False,
+            pad_mode="constant",
+        )
+
+        # only keep freqs smaller than self.F
+        stft = stft[:, : self.F, :, :]
+        real = stft[:, :, :, 0]
+        im = stft[:, :, :, 1]
+        mag = torch.sqrt(real ** 2 + im ** 2)
+
+        return stft, mag
+
+    def inverse_stft(self, stft: Tensor) -> Tensor:
+        """Inverses stft to wave form"""
+
+        pad = self.win_length // 2 + 1 - stft.size(1)
+        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
+        wav = torch.istft(
+            stft,
+            self.win_length,
+            hop_length=self.hop_length,
+            center=True,
+            window=self.win,
+        )
+        return wav.detach()
+
+    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
+        """
+        Separates stereo wav into different tracks (1 predicted track per stem)
+        Args:
+            wav (tensor): 2 x L
+        Returns:
+            masked stfts by track name
+        """
+        # stft - 2 X F x L x 2
+        # stft_mag - 2 X F x L
+        stft, stft_mag = self.compute_stft(wav.squeeze())
+
+        L = stft.size(2)
+
+        # 1 x 2 x F x T
+        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
+        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
+        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
+
+        # compute stems' mask
+        masks = {name: net(stft_mag) for name, net in self.stems.items()}
+
+        # compute denominator
+        mask_sum = sum([m ** 2 for m in masks.values()])
+        mask_sum += 1e-10
+
+        def apply_mask(mask):
+            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
+            mask = mask.transpose(2, 3)  # B x 2 X F x T
+
+            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
+
+            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
+            return stft_masked
+
+        return {name: apply_mask(m) for name, m in masks.items()}
+
+    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
+        """
+        Separates stereo wav into different tracks (1 predicted track per stem)
+        Args:
+            wav (tensor): 2 x L
+        Returns:
+            wavs by track name
+        """
+
+        stft_masks = self.forward(wav)
+
+        return {
+            name: self.inverse_stft(stft_masked)
+            for name, stft_masked in stft_masks.items()
+        }
+
+    @classmethod
+    def from_pretrained(cls, model_path: str):
+        checkpoint = torch.load(model_path)
+        model = cls()
+        model.load_state_dict(checkpoint)
+        return model
+
+
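For orientation, a quick shape check of the Splitter separation path with randomly initialised weights (a real run would load a converted checkpoint via Splitter.from_pretrained, as run.py and app.py do; torch is pinned to 1.13 in the requirements above):

import torch

from torchspleeter.splitter import Splitter

model = Splitter().eval()        # vocals + accompaniment UNets, random weights
wav = torch.randn(2, 44100 * 5)  # 5 seconds of stereo audio at 44.1 kHz
with torch.no_grad():
    stems = model.separate(wav)  # {'vocals': Tensor, 'accompaniment': Tensor}
for name, stem in stems.items():
    print(name, stem.shape)      # each stem comes back as a 2-channel waveform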
torchspleeter/test/test_estimator.py DELETED
@@ -1,45 +0,0 @@
-import numpy as np
-import librosa
-import soundfile
-import torch
-import pydub
-import os
-from torchspleeter.estimator import Estimator
-dirname = os.path.dirname(__file__)
-testfilename = os.path.join(dirname, 'checkpoints/2stems/audio_example.mp3')
-
-if __name__ == '__main__':
-    sr = 44100
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    #es = Estimator(2, './checkpoints/2stems/model').to(device)
-    #es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
-    es=Estimator()
-    es.eval()
-
-    # load wav audio
-    testaudiofile=testfilename
-    filestats=pydub.AudioSegment.from_file(testaudiofile)
-    channels=filestats.channels
-    mono_selection=False
-    if channels==1:
-        mono_selection=True
-        multichannel=pydub.AudioSegment.from_mono_audiosegments(filestats,filestats)
-        testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
-        multichannel.export(out_f=testaudiofile,format=testaudiofile.split('.')[-1])
-    print(mono_selection)
-    print(channels)
-    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast',sr=sr)
-    wav = torch.Tensor(wav).to(device)
-    if mono_selection:
-        os.remove(testaudiofile)
-
-
-    # normalize audio
-    # wav_torch = wav / (wav.max() + 1e-8)
-
-    wavs = es.separate(wav)
-    for i in range(len(wavs)):
-        fname = 'output/out_{}.wav'.format(i)
-        print('Writing ',fname)
-        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
-    # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
torchspleeter/tf2pytorch.py ADDED
@@ -0,0 +1,122 @@
+from typing import Dict
+import numpy as np
+
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+
+def parse_int_or_default(s: str, default: int = 0) -> int:
+    try:
+        return int(s)
+    except:
+        return default
+
+
+def tf2pytorch(checkpoint_path: str) -> Dict:
+    init_vars = tf.train.list_variables(checkpoint_path)
+
+    tf_vars = {}
+    for name, _ in init_vars:
+        try:
+            # print('Loading TF Weight {} with shape {}'.format(name, shape))
+            data = tf.train.load_variable(checkpoint_path, name)
+            tf_vars[name] = data
+        except Exception as e:
+            print(f"Load error: {name}")
+            raise
+
+    layer_idxs = set(
+        [
+            parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
+            for name in tf_vars.keys()
+            if "conv2d_transpose" in name
+        ]
+    )
+
+    n_layers_per_unet = 6
+    n_layers_in_chkpt = max(layer_idxs) + 1
+    assert (
+        n_layers_in_chkpt % 6 == 0
+    ), f"expected multiple of {n_layers_per_unet}... ie: {n_layers_per_unet} layers per unet & 1 unet per stem"
+    n_stems = n_layers_in_chkpt // n_layers_per_unet
+
+    stem_names = {
+        2: ["vocals", "accompaniment"],
+        4: ["vocals", "drums", "bass", "other"],
+        5: ["vocals", "piano", "drums", "bass", "other"],
+    }.get(n_stems, [])
+
+    assert stem_names, f"Unsupported stem count: {n_stems}"
+
+    state_dict = {}
+    tf_idx_conv = 0
+    tf_idx_tconv = 0
+    tf_idx_bn = 0
+
+    for stem_name in stem_names:
+        # Encoder Blocks (Down sampling)
+        for layer_idx in range(n_layers_per_unet):
+            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
+            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
+            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"
+
+            state_dict[f"{prefix}.conv.weight"] = np.transpose(
+                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
+            )
+            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
+            tf_idx_conv += 1
+
+            state_dict[f"{prefix}.bn.weight"] = tf_vars[
+                f"batch_normalization{bn_suffix}/gamma"
+            ]
+            state_dict[f"{prefix}.bn.bias"] = tf_vars[
+                f"batch_normalization{bn_suffix}/beta"
+            ]
+            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_mean"
+            ]
+            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_variance"
+            ]
+            tf_idx_bn += 1
+
+        # Decoder Blocks (Up sampling)
+        for layer_idx in range(n_layers_per_unet):
+            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
+            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
+            bn_suffix = f"_{tf_idx_bn}"
+
+            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
+                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
+            )
+            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
+                f"conv2d_transpose{tconv_suffix}/bias"
+            ]
+            tf_idx_tconv += 1
+
+            state_dict[f"{prefix}.bn.weight"] = tf_vars[
+                f"batch_normalization{bn_suffix}/gamma"
+            ]
+            state_dict[f"{prefix}.bn.bias"] = tf_vars[
+                f"batch_normalization{bn_suffix}/beta"
+            ]
+            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_mean"
+            ]
+            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_variance"
+            ]
+            tf_idx_bn += 1
+
+        # Final conv2d
+        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
+            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
+        )
+        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
+            f"conv2d_{tf_idx_conv}/bias"
+        ]
+        tf_idx_conv += 1
+
+    return state_dict
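tf2pytorch() renames and transposes the TensorFlow variables into a state_dict keyed to match Splitter's module names (stems.<name>.encoder_layers.N..., stems.<name>.decoder_layers.N..., stems.<name>.up_final). The commit does not include the conversion script itself, but a one-off conversion of a stock spleeter "2stems" TF checkpoint might look like the sketch below; the paths are placeholders, and strict=False is an assumption needed because the converter does not emit the BatchNorm num_batches_tracked buffers:

import torch

from torchspleeter.splitter import Splitter
from torchspleeter.tf2pytorch import tf2pytorch

# directory containing the original TF "2stems" checkpoint (placeholder path)
converted = tf2pytorch("pretrained_models/2stems/model")

model = Splitter()
# values come back as numpy arrays, so wrap them as tensors before loading
model.load_state_dict({k: torch.as_tensor(v) for k, v in converted.items()}, strict=False)

# save a checkpoint in the format Splitter.from_pretrained() expects
torch.save(model.state_dict(), "torchspleeter/models/2stems/spleeter.pth")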
torchspleeter/train.py ADDED
@@ -0,0 +1,143 @@
+import os
+from datetime import datetime
+from pathlib import Path
+
+import torch
+import typer
+from accelerate import Accelerator
+from accelerate.utils import LoggerType
+from torch import Tensor
+from torch.optim import AdamW
+
+# from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from data import MusdbDataset
+from splitter import Splitter
+
+DISABLE_TQDM = os.environ.get("DISABLE_TQDM", False)
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+
+def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
+    """
+    masked_target (Tensor): a masked STFT generated by applying a net's
+    estimated mask for source S to the ground truth STFT for source S
+    original (Tensor): an original input mixture
+    """
+    square_difference = torch.square(masked_target - original)
+    loss_value = torch.mean(square_difference)
+    return loss_value
+
+
+@app.command()
+def train(
+    dataset: str = "data/musdb18-wav",
+    output_dir: str = None,
+    fp16: bool = False,
+    cpu: bool = True,
+    max_steps: int = 100,
+    num_train_epochs: int = 1,
+    per_device_train_batch_size: int = 1,
+    effective_batch_size: int = 4,
+    max_grad_norm: float = 0.0,
+) -> None:
+    if not output_dir:
+        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
+        output_dir = f"experiments/{now_str}"
+    output_dir = Path(output_dir)
+    logging_dir = output_dir / "tracker_logs"
+    accelerator = Accelerator(
+        fp16=fp16,
+        cpu=cpu,
+        logging_dir=logging_dir,
+        log_with=[LoggerType.TENSORBOARD],
+    )
+    accelerator.init_trackers(logging_dir / "run")
+
+    train_dataset = MusdbDataset(root=dataset, is_train=True)
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        batch_size=per_device_train_batch_size,
+    )
+
+    model = Splitter(stem_names=[s for s in train_dataset.targets])
+    optimizer = AdamW(
+        model.parameters(),
+        lr=1e-3,
+        eps=1e-8,
+    )
+    model, optimizer, train_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader
+    )
+
+    num_train_steps = (
+        max_steps if max_steps > 0 else len(train_dataloader) * num_train_epochs
+    )
+    accelerator.print(f"Num train steps: {num_train_steps}")
+
+    step_batch_size = per_device_train_batch_size * accelerator.num_processes
+    gradient_accumulation_steps = max(
+        1,
+        effective_batch_size // step_batch_size,
+    )
+
+    accelerator.print(
+        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\nEffective Batch Size: {gradient_accumulation_steps * step_batch_size}"
+    )
+    global_step = 0
+    while global_step < num_train_steps:
+        accelerator.wait_for_everyone()
+        # accelerator.print(f"global step: {global_step}")
+        # accelerator.print("running train...")
+        model.train()
+        batch_iterator = tqdm(
+            train_dataloader,
+            desc="Batch",
+            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
+        )
+        for batch_idx, batch in enumerate(batch_iterator):
+            assert per_device_train_batch_size == 1, "For now limit to 1."
+            x_wav, y_target_wavs = batch
+            predictions = model(x_wav)
+            stem_losses = []
+            for name, masked_stft in predictions.items():
+                target_stft, _ = model.compute_stft(y_target_wavs[name].squeeze())
+                loss = spectrogram_loss(
+                    masked_target=masked_stft,
+                    original=target_stft,
+                )
+                stem_losses.append(loss)
+                accelerator.log({f"train-loss-{name}": 1.0 * loss}, step=global_step)
+
+            total_loss = (
+                torch.sum(torch.stack(stem_losses)) / gradient_accumulation_steps
+            )
+            accelerator.print(f"global step: {global_step}\tloss: {total_loss:.4f}")
+            accelerator.log({f"train-loss": 1.0 * total_loss}, step=global_step)
+            accelerator.backward(total_loss)
+            if (batch_idx + 1) % gradient_accumulation_steps == 0:
+                if max_grad_norm > 0:
+                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
+                optimizer.step()
+                optimizer.zero_grad()
+            global_step += 1
+
+    accelerator.wait_for_everyone()
+    accelerator.end_training()
+    accelerator.print(f"Saving model to {output_dir}...")
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_model.save_pretrained(
+        output_dir,
+        save_function=accelerator.save,
+        state_dict=accelerator.get_state_dict(model),
+    )
+
+    accelerator.wait_for_everyone()
+    accelerator.print("DONE!")
+
+
+if __name__ == "__main__":
+    app()
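For the defaults above, the accumulation bookkeeping works out as follows: with per_device_train_batch_size=1 on a single process, step_batch_size is 1 * 1 = 1, so gradient_accumulation_steps = max(1, 4 // 1) = 4 and the optimizer steps once every 4 batches, which reproduces the requested effective batch size of 4 * 1 = 4.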
torchspleeter/unet.py CHANGED
@@ -1,97 +1,97 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-class CustomPad(nn.Module):
-    def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomPad, self).__init__()
-        self.padding_setting = padding_setting
-
-    def forward(self, x):
-        return F.pad(x, self.padding_setting, "constant", 0)
-
-
-class CustomTransposedPad(nn.Module):
-    def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomTransposedPad, self).__init__()
-        self.padding_setting = padding_setting
-
-    def forward(self, x):
-        l,r,t,b = self.padding_setting
-        return x[:,:,l:-r,t:-b]
-
-
-def down_block(in_filters, out_filters):
-    return nn.Sequential(CustomPad(),
-                         nn.Conv2d(in_filters, out_filters, kernel_size=5, stride=2,padding=0)), \
-           nn.Sequential(
-               nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
-               nn.LeakyReLU(0.2))
-
-
-def up_block(in_filters, out_filters, dropout=False):
-    layers = [
-        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,stride=2),
-        CustomTransposedPad(),
-        nn.ReLU(),
-        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
-    ]
-    if dropout:
-        layers.append(nn.Dropout(0.5))
-
-    return nn.Sequential(*layers)
-
-
-class UNet(nn.Module):
-    def __init__(self, in_channels=2):
-        super(UNet, self).__init__()
-        self.down1_conv, self.down1_act = down_block(in_channels, 16)
-        self.down2_conv, self.down2_act = down_block(16, 32)
-        self.down3_conv, self.down3_act = down_block(32, 64)
-        self.down4_conv, self.down4_act = down_block(64, 128)
-        self.down5_conv, self.down5_act = down_block(128, 256)
-        self.down6_conv, self.down6_act = down_block(256, 512)
-
-        self.up1 = up_block(512, 256, dropout=True)
-        self.up2 = up_block(512, 128, dropout=True)
-        self.up3 = up_block(256, 64, dropout=True)
-        self.up4 = up_block(128, 32)
-        self.up5 = up_block(64, 16)
-        self.up6 = up_block(32, 1)
-        self.up7 = nn.Sequential(
-            nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        d1_conv = self.down1_conv(x)
-        d1 = self.down1_act(d1_conv)
-
-        d2_conv = self.down2_conv(d1)
-        d2 = self.down2_act(d2_conv)
-
-        d3_conv = self.down3_conv(d2)
-        d3 = self.down3_act(d3_conv)
-
-        d4_conv = self.down4_conv(d3)
-        d4 = self.down4_act(d4_conv)
-
-        d5_conv = self.down5_conv(d4)
-        d5 = self.down5_act(d5_conv)
-
-        d6_conv = self.down6_conv(d5)
-        d6 = self.down6_act(d6_conv)
-
-        u1 = self.up1(d6_conv)
-        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
-        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
-        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
-        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
-        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
-        u7 = self.up7(u6)
-        return u7 * x
-
-
-if __name__ == '__main__':
-    net = UNet(14)
-    print(net(torch.rand(1, 14, 20, 48)).shape)
+from typing import Tuple
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
+        self.bn = nn.BatchNorm2d(
+            num_features=out_channels,
+            track_running_stats=True,
+            eps=1e-3,
+            momentum=0.01,
+        )
+        self.relu = nn.LeakyReLU(negative_slope=0.2)
+
+    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
+        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
+        return down, self.relu(self.bn(down))
+
+
+class DecoderBlock(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.tconv = nn.ConvTranspose2d(
+            in_channels, out_channels, kernel_size=5, stride=2
+        )
+        self.relu = nn.ReLU()
+        self.bn = nn.BatchNorm2d(
+            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
+        )
+        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()
+
+    def forward(self, input: Tensor) -> Tensor:
+        up = self.tconv(input)
+        # reverse padding
+        l, r, t, b = 1, 2, 1, 2
+        up = up[:, :, l:-r, t:-b]
+        return self.dropout(self.bn(self.relu(up)))
+
+
+class UNet(nn.Module):
+    def __init__(
+        self,
+        n_layers: int = 6,
+        in_channels: int = 1,
+    ) -> None:
+        super().__init__()
+
+        # DownSample layers
+        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
+        self.encoder_layers = nn.ModuleList(
+            [
+                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
+                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
+            ]
+        )
+
+        # UpSample layers
+        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
+        up_set.reverse()
+        self.decoder_layers = nn.ModuleList(
+            [
+                DecoderBlock(
+                    # doubled for concatenated inputs (skip connections)
+                    in_channels=in_ch if i == 0 else in_ch * 2,
+                    out_channels=out_ch,
+                    # 50 % dropout... first 3 layers only
+                    dropout_prob=0.5 if i < 3 else 0,
+                )
+                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
+            ]
+        )
+
+        # reconstruct the final mask same as the original channels
+        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input: Tensor) -> Tensor:
+        encoder_outputs_pre_act = []
+        x = input
+        for down in self.encoder_layers:
+            conv, x = down(x)
+            encoder_outputs_pre_act.append(conv)
+
+        for i, up in enumerate(self.decoder_layers):
+            if i == 0:
+                x = up(encoder_outputs_pre_act.pop())
+            else:
+                # merge skip connection
+                x = up(torch.concat([encoder_outputs_pre_act.pop(), x], axis=1))
+        mask = self.sigmoid(self.up_final(x))
+        return mask * input
torchspleeter/utils.py ADDED
@@ -0,0 +1,42 @@
+import librosa
+import soundfile
+from pathlib import Path
+
+import torch
+
+from .splitter import Splitter
+
+
+def sound_split(
+    model: Splitter,
+    input: str = "data/audio_example.mp3",
+    output_dir: str = "output",
+    write_src: bool = False,
+) -> None:
+    sr = 44100
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # splitter = Splitter.from_pretrained(model_path).to(device).eval()
+
+    # load wav audio
+    fpath_src = Path(input)
+    wav, _ = librosa.load(
+        fpath_src,
+        mono=False,
+        res_type="kaiser_fast",
+        sr=sr,
+    )
+    wav = torch.Tensor(wav).to(device)
+
+    # normalize audio
+    # wav_torch = wav / (wav.max() + 1e-8)
+
+    with torch.no_grad():
+        stems = model.separate(wav)
+
+    if write_src:
+        stems["input"] = wav
+    for name, stem in stems.items():
+        fpath_dst = Path(output_dir) / f"{name}.wav"
+        print(f"Writing {fpath_dst}")
+        fpath_dst.parent.mkdir(exist_ok=True)
+        soundfile.write(fpath_dst, stem.cpu().detach().numpy().T, sr, "PCM_16")
warm_up.py CHANGED
@@ -1,5 +1,5 @@
 from spleeter.separator import Separator
 
-separator = Separator('spleeter:2stems')
+separator = Separator('spleeter.pth:2stems')
 # warm up
 separator.separate_to_file('warm.wav', '/tmp/warm')