Spaces: Runtime error

Commit 62e9d65 by pengdaqian • 1 parent: d7659a0

fix

Files changed:
- Dockerfile +0 -1
- app.py +8 -6
- torchspleeter/.dockerignore +5 -0
- torchspleeter/.gitignore +135 -0
- torchspleeter/__init__.py +0 -4
- torchspleeter/checkpoints/2stems/audio_example.mp3 +0 -0
- torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt +0 -3
- torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt +0 -3
- torchspleeter/command_interface.py +0 -98
- torchspleeter/data.py +37 -0
- torchspleeter/estimator.py +0 -163
- torchspleeter/requirements.txt +8 -0
- torchspleeter/run.py +50 -0
- torchspleeter/splitter.py +140 -0
- torchspleeter/test/test_estimator.py +0 -45
- torchspleeter/tf2pytorch.py +122 -0
- torchspleeter/train.py +143 -0
- torchspleeter/unet.py +86 -86
- torchspleeter/utils.py +42 -0
- warm_up.py +1 -1
Dockerfile CHANGED
@@ -13,7 +13,6 @@ ENV PYTHONUNBUFFERED=1
 
 COPY requirements.txt requirements.txt
 RUN pip3 install --no-cache-dir -r requirements.txt
-RUN pip3 install --no-cache-dir --no-dependencies spleeter
 
 RUN mkdir -p $HOME/app/cache/transformers
 RUN chmod 777 $HOME/app/cache/transformers
app.py CHANGED
@@ -16,7 +16,9 @@ import numpy as np
 import soundfile
 from pydub import AudioSegment
 import uuid
-from torchspleeter import
+from torchspleeter.utils import sound_split
+from torchspleeter.splitter import Splitter
+
 import logging
 
 logging.getLogger('numba').setLevel(logging.WARNING)
@@ -84,6 +86,7 @@ load_svc_model("vits_pretrain/sovits5.0-48k-debug.pth", model)
 model.eval()
 model.to(device)
 whisper_model = whisper.inference.load_model(os.path.join("whisper_pretrain", "medium.pt"))
+splitter_model = Splitter.from_pretrained(os.path.join("torchspleeter/models/2stems", "spleeter.pth")).to(device).eval()
 
 
 # warm up
@@ -187,7 +190,7 @@ def svc_main(sid, input_audio):
     uuid_value = uuid.uuid4()
     uuid_string = str(uuid_value)
     input_audio_tmp_file = f'{uuid_string}.wav'
-    tmpfile_path = '/tmp'
+    tmpfile_path = f'/tmp/{uuid_string}'
 
     #
     # prediction = separator.separate(audio)
@@ -196,9 +199,9 @@ def svc_main(sid, input_audio):
     if not os.path.exists(tmpfile_path):
         os.makedirs(tmpfile_path)
 
-
+    sound_split(splitter_model, input_audio_tmp_file, tmpfile_path)
 
-    curr_tmp_path =
+    curr_tmp_path = tmpfile_path
     vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
     accompaniment_filepath = os.path.join(curr_tmp_path, 'accompaniment.wav')
 
@@ -344,10 +347,9 @@ def main():
         vc_search.click(auto_search, [vc_input2], [vc_output1, vc_input3])
         vc_submit.click(svc_main, [sid, vc_input3], [vc_output1, vc_output2])
 
-        app.queue(max_size=3, api_open=False).launch()
+        app.queue(max_size=3, api_open=False).launch(server_port=6006)
     except KeyboardInterrupt:
         app.close()
-        separator._get_session().close()
         sys.exit(0)
 
 
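For orientation, the new call path in app.py condenses to the following self-contained sketch (the model path mirrors the diff; the input filename and temp directory are illustrative):

import os

import torch

from torchspleeter.splitter import Splitter
from torchspleeter.utils import sound_split

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
splitter_model = (
    Splitter.from_pretrained("torchspleeter/models/2stems/spleeter.pth")
    .to(device)
    .eval()
)

# one temp directory per request, then read back the two stems
tmpfile_path = "/tmp/example-request"
os.makedirs(tmpfile_path, exist_ok=True)
sound_split(splitter_model, "input.wav", tmpfile_path)
vocals_filepath = os.path.join(tmpfile_path, "vocals.wav")
accompaniment_filepath = os.path.join(tmpfile_path, "accompaniment.wav")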
torchspleeter/.dockerignore ADDED
@@ -0,0 +1,5 @@
+data/
+models/
+output/
+docs/
+.vscode/
torchspleeter/.gitignore ADDED
@@ -0,0 +1,135 @@
+models/
+output/
+experiments/
+data/
+.vscode/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
torchspleeter/__init__.py CHANGED
@@ -1,4 +0,0 @@
-from torchspleeter.command_interface import *
-
-
-version="0.1.5"
torchspleeter/checkpoints/2stems/audio_example.mp3 DELETED
Binary file (263 kB)
torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:385ea3372c6a3ceee01f6ded5504bb7ee1e9f0101950ae58869dc18382deb75c
-size 59050239
torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3e4d6ede1ecad091468550773e77934aac3f1e039c0697fc9039aba9b935e344
-size 59033471
torchspleeter/command_interface.py DELETED
@@ -1,98 +0,0 @@
-"""
-
-This provides an interface to interact with the spleeter system on
-
-
-"""
-
-import os
-from torchspleeter.estimator import Estimator
-import argparse
-import uuid
-import numpy as np
-import librosa
-import soundfile
-import torch
-import pydub
-import os
-import shutil
-
-
-def split_to_parts(inputaudiofile, outputdir, instruments=2, models=[]):
-    filedata = pydub.AudioSegment.from_file(inputaudiofile)
-    sr = filedata.frame_rate
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    # es = Estimator(2, './checkpoints/2stems/model').to(device)
-    # es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
-    es = Estimator()
-    es.eval()
-
-    # load wav audio
-    testaudiofile = inputaudiofile
-    channels = filedata.channels
-    mono_selection = False
-    if not os.path.exists(outputdir):
-        os.makedirs(outputdir, exist_ok=True)
-    if channels == 1:
-        mono_selection = True
-        multichannel = pydub.AudioSegment.from_mono_audiosegments(filedata, filedata)
-        os.makedirs(os.path.join(outputdir, 'tmp'), exist_ok=True)
-        testaudiofile = os.path.join(outputdir, "tmp" + str(uuid.uuid4()) + "." + testaudiofile.split('.')[-1])
-        # testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
-        multichannel.export(out_f=testaudiofile, format=testaudiofile.split('.')[-1])
-    print(mono_selection)
-    print(channels)
-    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast', sr=sr)
-    wav = torch.Tensor(wav).to(device)
-    if mono_selection:
-        shutil.rmtree(os.path.join(outputdir, "tmp"))
-        # os.remove(testaudiofile)
-    wavs = es.separate(wav)
-    outputname = str(uuid.uuid4())
-    returnarray = []
-    for i in range(len(wavs)):
-        finaloutput = os.path.join(outputdir, outputname)
-        fname = '-out_{}.wav'.format(i)
-        fname = finaloutput + fname
-        print('Writing ', fname)
-        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
-        returnarray.append(fname)
-        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
-    return returnarray
-
-
-def get_file_list(dirname):
-    outputfilelist = []
-    for subdir, dirs, files in os.walk(dirname):
-        for file in files:
-            outputfilelist.append(os.path.join(subdir, file))
-
-    return outputfilelist
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description='torchspleeter allows you to separate instrumentals from audio (vocals, instruments, background noise, etc) in a simple, cross platform manner')
-    parser.add_argument('-i', '--inputfile', help='Input Audio File to split into instrumentals', required=True)
-    parser.add_argument('-o', '--output', help='Output directory to deposit split audio', required=True)
-    parser.add_argument('-n', '--number', help="Number of instruments in the model (default 2)", required=False,
-                        default=2, type=int)
-    parser.add_argument('-m', '--modeldir',
-                        help="directory containing number of pre-converted torch compatible model components",
-                        required=False)
-    args = vars(parser.parse_args())
-    print(args)
-    if args['modeldir'] is not None:
-        modelfiles = get_file_list(args['modeldir'])
-        if len(modelfiles) != args['number']:
-            raise ValueError("You must have the same number of models as you do number of instruments!")
-    else:
-        args['modeldir'] = []
-    outputfiles = split_to_parts(args['inputfile'], args['output'], args['number'], args['modeldir'])
-    print("Your output files are:")
-    for item in outputfiles:
-        print(item)
-
-
-if __name__ == "__main__":
-    main()
torchspleeter/data.py ADDED
@@ -0,0 +1,37 @@
+import random
+from pathlib import Path
+from typing import List
+
+import musdb
+import torch
+from torch.utils.data import Dataset
+
+
+class MusdbDataset(Dataset):
+    def __init__(
+        self, root="data/musdb18-wav", is_train: bool = True, targets: List[str] = None
+    ) -> None:
+        super().__init__()
+        root = Path(root)
+        assert root.exists(), f"Path does not exist: {root}"
+        self.mus = musdb.DB(
+            root=root,
+            subsets=["train" if is_train else "test"],
+            is_wav=True,
+        )
+        self.targets = [s for s in targets] if targets else ["vocals", "accompaniment"]
+
+    def __len__(self) -> int:
+        return len(self.mus)
+
+    def __getitem__(self, index):
+        track = self.mus.tracks[index]
+        track.chunk_duration = 5.0
+        track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)
+        x_wav = torch.tensor(track.audio.T, dtype=torch.float32)
+        y_target_wavs = {
+            name: torch.tensor(track.targets[name].audio.T, dtype=torch.float32)
+            for name in self.targets
+        }
+        # original audio (x) and stems (y == targets)
+        return x_wav, y_target_wavs
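A minimal sketch of how MusdbDataset is consumed (this mirrors the DataLoader setup in torchspleeter/train.py below; the dataset root is the default and assumes a local wav-format MUSDB18):

from torch.utils.data import DataLoader

from torchspleeter.data import MusdbDataset

dataset = MusdbDataset(root="data/musdb18-wav", is_train=True)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# each item is a random 5-second chunk: a stereo mixture plus one stereo
# waveform per target stem (vocals / accompaniment by default)
x_wav, y_target_wavs = next(iter(loader))
print(x_wav.shape)                               # 1 x 2 x samples
print({k: v.shape for k, v in y_target_wavs.items()})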
torchspleeter/estimator.py DELETED
@@ -1,163 +0,0 @@
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-import tqdm
-# from torchaudio.functional import istft
-
-from torchspleeter.unet import UNet
-# from .util import tf2pytorch
-
-import os
-
-dirname = os.path.dirname(__file__)
-defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
-defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def load_ckpt(model, ckpt):
-    state_dict = model.state_dict()
-    for k, v in ckpt.items():
-        if k in state_dict:
-            target_shape = state_dict[k].shape
-            assert target_shape == v.shape
-            state_dict.update({k: torch.from_numpy(v)})
-        else:
-            print('Ignore ', k)
-
-    model.load_state_dict(state_dict)
-    return model
-
-
-def pad_and_partition(tensor, T):
-    """
-    pads zero and partition tensor into segments of length T
-
-    Args:
-        tensor(Tensor): BxCxFxL
-
-    Returns:
-        tensor of size (B*[L/T] x C x F x T)
-    """
-    old_size = tensor.size(3)
-    new_size = math.ceil(old_size / T) * T
-    tensor = F.pad(tensor, [0, new_size - old_size])
-    [b, c, t, f] = tensor.shape
-    split = new_size // T
-    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
-
-
-class Estimator(nn.Module):
-    def __init__(self, num_instrumments=2, checkpoint_path=None):
-        super(Estimator, self).__init__()
-        if checkpoint_path is None:
-            checkpoint_path = [defaultmodel0, defaultmodel1]
-        else:
-            if len(checkpoint_path) < 1:
-                checkpoint_path = [defaultmodel0, defaultmodel1]
-        # stft config
-        self.F = 1024
-        self.T = 512
-        self.win_length = 4096
-        self.hop_length = 1024
-        self.win = nn.Parameter(
-            torch.hann_window(self.win_length, device=device),
-            requires_grad=False
-        )
-
-        ckpts = []
-        if len(checkpoint_path) != num_instrumments:
-            raise ValueError("You must submit as many models as there are instruments!")
-        for ckpt_path in checkpoint_path:
-            ckpts.append(torch.load(ckpt_path))
-
-        # self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
-
-        # ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
-
-        # filter
-        self.instruments = nn.ModuleList()
-        for i in range(num_instrumments):
-            print('Loading model for instrumment {}'.format(i))
-            net = UNet(2)
-            ckpt = ckpts[i]
-            net = load_ckpt(net, ckpt)
-            net.eval()  # change mode to eval
-            self.instruments.append(net)
-
-    def compute_stft(self, wav):
-        """
-        Computes stft feature from wav
-
-        Args:
-            wav (Tensor): B x L
-        """
-
-        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
-                          center=True, return_complex=False, pad_mode='constant')
-
-        # only keep freqs smaller than self.F
-        stft = stft[:, :self.F, :, :]
-        real = stft[:, :, :, 0]
-        im = stft[:, :, :, 1]
-        mag = torch.sqrt(real ** 2 + im ** 2)
-
-        return stft, mag
-
-    def inverse_stft(self, stft):
-        """Inverses stft to wave form"""
-
-        pad = self.win_length // 2 + 1 - stft.size(1)
-        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
-        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-                          window=self.win)
-        return wav.detach()
-
-    def separate(self, wav):
-        """
-        Separates stereo wav into different tracks corresponding to different instruments
-
-        Args:
-            wav (tensor): 2 x L
-        """
-
-        # stft - 2 X F x L x 2
-        # stft_mag - 2 X F x L
-        stft, stft_mag = self.compute_stft(wav)
-
-        L = stft.size(2)
-
-        # 1 x 2 x F x T
-        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
-        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
-        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
-
-        B = stft_mag.shape[0]
-
-        # compute instruments' mask
-        masks = []
-        for net in self.instruments:
-            mask = net(stft_mag)
-            masks.append(mask)
-
-        # compute denominator
-        mask_sum = sum([m ** 2 for m in masks])
-        mask_sum += 1e-10
-
-        wavs = []
-        for mask in tqdm.tqdm(masks):
-            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
-            mask = mask.transpose(2, 3)  # B x 2 X F x T
-
-            mask = torch.cat(
-                torch.split(mask, 1, dim=0), dim=3)
-
-            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
-            stft_masked = stft * mask
-            wav_masked = self.inverse_stft(stft_masked)
-
-            wavs.append(wav_masked)
-
-        return wavs
torchspleeter/requirements.txt ADDED
@@ -0,0 +1,8 @@
+accelerate==0.14.0
+librosa
+musdb==0.4.0
+soundfile
+tensorflow==2.10.0
+torch==1.13.0
+torchaudio
+typer
torchspleeter/run.py ADDED
@@ -0,0 +1,50 @@
+from pathlib import Path
+
+import torch
+import typer
+from utils import sound_split
+
+app = typer.Typer()
+
+
+@app.command()
+def model_summary() -> None:
+    from unet import UNet
+
+    net = UNet()
+    print(net)
+
+
+@app.command()
+def test() -> None:
+    from unet import UNet
+
+    batch_size = 5
+    n_channels = 2
+    x = torch.randn(batch_size, n_channels, 512, 128)
+    print(x.shape)
+    net = UNet(in_channels=n_channels)
+    y = net.forward(x)
+    print(y.shape)
+
+
+@app.command()
+def split(
+    model_path: str = "models/2stems/model",
+    input: str = "data/audio_example.mp3",
+    output_dir: str = "output",
+    offset: float = 0,
+    duration: float = 30,
+    write_src: bool = False,
+) -> None:
+
+    from splitter import Splitter
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    splitter_model = Splitter.from_pretrained(model_path).to(device).eval()
+
+    sound_split(splitter_model, input, output_dir, write_src)
+
+
+if __name__ == "__main__":
+    app()
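Assuming a checkpoint already converted to the torch format (see tf2pytorch.py below), the typer CLI would be invoked roughly as follows; option and command names follow typer's underscore-to-dash convention, and since run.py imports utils as a top-level module while utils.py uses a relative import, these commands assume they are run from a layout where both imports resolve:

python run.py model-summary
python run.py test
python run.py split --model-path models/2stems/spleeter.pth --input data/audio_example.mp3 --output-dir output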
torchspleeter/splitter.py ADDED
@@ -0,0 +1,140 @@
+import math
+from typing import Dict, List, Tuple
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from .unet import UNet
+
+
+def batchify(tensor: Tensor, T: int) -> Tensor:
+    """
+    partition tensor into segments of length T, zero pad any ragged samples
+    Args:
+        tensor(Tensor): BxCxFxL
+    Returns:
+        tensor of size (B*[L/T] x C x F x T)
+    """
+    # Zero pad the original tensor to an even multiple of T
+    orig_size = tensor.size(-1)
+    new_size = math.ceil(orig_size / T) * T
+    tensor = F.pad(tensor, [0, new_size - orig_size])
+    # Partition the tensor into multiple samples of length T and stack them into a batch
+    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)
+
+
+class Splitter(nn.Module):
+    def __init__(self, stem_names: List[str] = None):
+        super(Splitter, self).__init__()
+
+        # stft config
+        self.F = 1024
+        self.T = 512
+        self.win_length = 4096
+        self.hop_length = 1024
+        self.win = nn.Parameter(torch.hann_window(self.win_length), requires_grad=False)
+
+        self.stems = nn.ModuleDict({'vocals': UNet(in_channels=2),
+                                    'accompaniment': UNet(in_channels=2)})
+
+    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Computes stft feature from wav
+        Args:
+            wav (Tensor): B x L
+        """
+        stft = torch.stft(
+            wav,
+            n_fft=self.win_length,
+            hop_length=self.hop_length,
+            window=self.win,
+            center=True,
+            return_complex=False,
+            pad_mode="constant",
+        )
+
+        # only keep freqs smaller than self.F
+        stft = stft[:, : self.F, :, :]
+        real = stft[:, :, :, 0]
+        im = stft[:, :, :, 1]
+        mag = torch.sqrt(real ** 2 + im ** 2)
+
+        return stft, mag
+
+    def inverse_stft(self, stft: Tensor) -> Tensor:
+        """Inverses stft to wave form"""
+
+        pad = self.win_length // 2 + 1 - stft.size(1)
+        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
+        wav = torch.istft(
+            stft,
+            self.win_length,
+            hop_length=self.hop_length,
+            center=True,
+            window=self.win,
+        )
+        return wav.detach()
+
+    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
+        """
+        Separates stereo wav into different tracks (1 predicted track per stem)
+        Args:
+            wav (tensor): 2 x L
+        Returns:
+            masked stfts by track name
+        """
+        # stft - 2 X F x L x 2
+        # stft_mag - 2 X F x L
+        stft, stft_mag = self.compute_stft(wav.squeeze())
+
+        L = stft.size(2)
+
+        # 1 x 2 x F x T
+        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
+        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
+        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
+
+        # compute stems' mask
+        masks = {name: net(stft_mag) for name, net in self.stems.items()}
+
+        # compute denominator
+        mask_sum = sum([m ** 2 for m in masks.values()])
+        mask_sum += 1e-10
+
+        def apply_mask(mask):
+            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
+            mask = mask.transpose(2, 3)  # B x 2 X F x T
+
+            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
+
+            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
+            return stft_masked
+
+        return {name: apply_mask(m) for name, m in masks.items()}
+
+    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
+        """
+        Separates stereo wav into different tracks (1 predicted track per stem)
+        Args:
+            wav (tensor): 2 x L
+        Returns:
+            wavs by track name
+        """
+
+        stft_masks = self.forward(wav)
+
+        return {
+            name: self.inverse_stft(stft_masked)
+            for name, stft_masked in stft_masks.items()
+        }
+
+    @classmethod
+    def from_pretrained(cls, model_path: str):
+        checkpoint = torch.load(model_path)
+        model = cls()
+        model.load_state_dict(checkpoint)
+        return model
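A quick shape sanity check for the new Splitter, using untrained weights and random input (stem names and clip length are arbitrary; under the pinned torch==1.13 the return_complex=False STFT path used above still runs):

import torch

from torchspleeter.splitter import Splitter

model = Splitter(stem_names=["vocals", "accompaniment"]).eval()
wav = torch.randn(2, 44100 * 5)  # 5 seconds of stereo audio at 44.1 kHz

with torch.no_grad():
    stems = model.separate(wav)  # dict: stem name -> 2 x L waveform

for name, stem in stems.items():
    print(name, stem.shape)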
torchspleeter/test/test_estimator.py DELETED
@@ -1,45 +0,0 @@
-import numpy as np
-import librosa
-import soundfile
-import torch
-import pydub
-import os
-from torchspleeter.estimator import Estimator
-dirname = os.path.dirname(__file__)
-testfilename = os.path.join(dirname, 'checkpoints/2stems/audio_example.mp3')
-
-if __name__ == '__main__':
-    sr = 44100
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    #es = Estimator(2, './checkpoints/2stems/model').to(device)
-    #es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
-    es=Estimator()
-    es.eval()
-
-    # load wav audio
-    testaudiofile=testfilename
-    filestats=pydub.AudioSegment.from_file(testaudiofile)
-    channels=filestats.channels
-    mono_selection=False
-    if channels==1:
-        mono_selection=True
-        multichannel=pydub.AudioSegment.from_mono_audiosegments(filestats,filestats)
-        testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
-        multichannel.export(out_f=testaudiofile,format=testaudiofile.split('.')[-1])
-    print(mono_selection)
-    print(channels)
-    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast',sr=sr)
-    wav = torch.Tensor(wav).to(device)
-    if mono_selection:
-        os.remove(testaudiofile)
-
-
-    # normalize audio
-    # wav_torch = wav / (wav.max() + 1e-8)
-
-    wavs = es.separate(wav)
-    for i in range(len(wavs)):
-        fname = 'output/out_{}.wav'.format(i)
-        print('Writing ',fname)
-        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
-        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
torchspleeter/tf2pytorch.py ADDED
@@ -0,0 +1,122 @@
+from typing import Dict
+import numpy as np
+
+import os
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+import tensorflow as tf
+
+
+def parse_int_or_default(s: str, default: int = 0) -> int:
+    try:
+        return int(s)
+    except:
+        return default
+
+
+def tf2pytorch(checkpoint_path: str) -> Dict:
+    init_vars = tf.train.list_variables(checkpoint_path)
+
+    tf_vars = {}
+    for name, _ in init_vars:
+        try:
+            # print('Loading TF Weight {} with shape {}'.format(name, shape))
+            data = tf.train.load_variable(checkpoint_path, name)
+            tf_vars[name] = data
+        except Exception as e:
+            print(f"Load error: {name}")
+            raise
+
+    layer_idxs = set(
+        [
+            parse_int_or_default(name.split("/")[0].split("_")[-1], default=0)
+            for name in tf_vars.keys()
+            if "conv2d_transpose" in name
+        ]
+    )
+
+    n_layers_per_unet = 6
+    n_layers_in_chkpt = max(layer_idxs) + 1
+    assert (
+        n_layers_in_chkpt % 6 == 0
+    ), f"expected multiple of {n_layers_per_unet}... ie: {n_layers_per_unet} layers per unet & 1 unet per stem"
+    n_stems = n_layers_in_chkpt // n_layers_per_unet
+
+    stem_names = {
+        2: ["vocals", "accompaniment"],
+        4: ["vocals", "drums", "bass", "other"],
+        5: ["vocals", "piano", "drums", "bass", "other"],
+    }.get(n_stems, [])
+
+    assert stem_names, f"Unsupported stem count: {n_stems}"
+
+    state_dict = {}
+    tf_idx_conv = 0
+    tf_idx_tconv = 0
+    tf_idx_bn = 0
+
+    for stem_name in stem_names:
+        # Encoder Blocks (Down sampling)
+        for layer_idx in range(n_layers_per_unet):
+            prefix = f"stems.{stem_name}.encoder_layers.{layer_idx}"
+            conv_suffix = "" if tf_idx_conv == 0 else f"_{tf_idx_conv}"
+            bn_suffix = "" if tf_idx_bn == 0 else f"_{tf_idx_bn}"
+
+            state_dict[f"{prefix}.conv.weight"] = np.transpose(
+                tf_vars[f"conv2d{conv_suffix}/kernel"], (3, 2, 0, 1)
+            )
+            state_dict[f"{prefix}.conv.bias"] = tf_vars[f"conv2d{conv_suffix}/bias"]
+            tf_idx_conv += 1
+
+            state_dict[f"{prefix}.bn.weight"] = tf_vars[
+                f"batch_normalization{bn_suffix}/gamma"
+            ]
+            state_dict[f"{prefix}.bn.bias"] = tf_vars[
+                f"batch_normalization{bn_suffix}/beta"
+            ]
+            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_mean"
+            ]
+            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_variance"
+            ]
+            tf_idx_bn += 1
+
+        # Decoder Blocks (Up sampling)
+        for layer_idx in range(n_layers_per_unet):
+            prefix = f"stems.{stem_name}.decoder_layers.{layer_idx}"
+            tconv_suffix = "" if tf_idx_tconv == 0 else f"_{tf_idx_tconv}"
+            bn_suffix = f"_{tf_idx_bn}"
+
+            state_dict[f"{prefix}.tconv.weight"] = np.transpose(
+                tf_vars[f"conv2d_transpose{tconv_suffix}/kernel"], (3, 2, 0, 1)
+            )
+            state_dict[f"{prefix}.tconv.bias"] = tf_vars[
+                f"conv2d_transpose{tconv_suffix}/bias"
+            ]
+            tf_idx_tconv += 1
+
+            state_dict[f"{prefix}.bn.weight"] = tf_vars[
+                f"batch_normalization{bn_suffix}/gamma"
+            ]
+            state_dict[f"{prefix}.bn.bias"] = tf_vars[
+                f"batch_normalization{bn_suffix}/beta"
+            ]
+            state_dict[f"{prefix}.bn.running_mean"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_mean"
+            ]
+            state_dict[f"{prefix}.bn.running_var"] = tf_vars[
+                f"batch_normalization{bn_suffix}/moving_variance"
+            ]
+            tf_idx_bn += 1
+
+        # Final conv2d
+        state_dict[f"stems.{stem_name}.up_final.weight"] = np.transpose(
+            tf_vars[f"conv2d_{tf_idx_conv}/kernel"], (3, 2, 0, 1)
+        )
+        state_dict[f"stems.{stem_name}.up_final.bias"] = tf_vars[
+            f"conv2d_{tf_idx_conv}/bias"
+        ]
+        tf_idx_conv += 1
+
+    return state_dict
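A one-time conversion sketch tying tf2pytorch to Splitter.from_pretrained: the function returns numpy arrays, so they are converted to tensors before saving (the TF checkpoint prefix and output path are assumptions chosen to match the paths used elsewhere in this commit):

import torch

from torchspleeter.tf2pytorch import tf2pytorch

state_dict = tf2pytorch("models/2stems/model")  # TF checkpoint prefix

# load_state_dict expects tensors, not numpy arrays
state_dict = {k: torch.tensor(v) for k, v in state_dict.items()}
torch.save(state_dict, "torchspleeter/models/2stems/spleeter.pth")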
torchspleeter/train.py ADDED
@@ -0,0 +1,143 @@
+import os
+from datetime import datetime
+from pathlib import Path
+
+import torch
+import typer
+from accelerate import Accelerator
+from accelerate.utils import LoggerType
+from torch import Tensor
+from torch.optim import AdamW
+
+# from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from data import MusdbDataset
+from splitter import Splitter
+
+DISABLE_TQDM = os.environ.get("DISABLE_TQDM", False)
+app = typer.Typer(pretty_exceptions_show_locals=False)
+
+
+def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
+    """
+    masked_target (Tensor): a masked STFT generated by applying a net's
+        estimated mask for source S to the ground truth STFT for source S
+    original (Tensor): an original input mixture
+    """
+    square_difference = torch.square(masked_target - original)
+    loss_value = torch.mean(square_difference)
+    return loss_value
+
+
+@app.command()
+def train(
+    dataset: str = "data/musdb18-wav",
+    output_dir: str = None,
+    fp16: bool = False,
+    cpu: bool = True,
+    max_steps: int = 100,
+    num_train_epochs: int = 1,
+    per_device_train_batch_size: int = 1,
+    effective_batch_size: int = 4,
+    max_grad_norm: float = 0.0,
+) -> None:
+    if not output_dir:
+        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
+        output_dir = f"experiments/{now_str}"
+    output_dir = Path(output_dir)
+    logging_dir = output_dir / "tracker_logs"
+    accelerator = Accelerator(
+        fp16=fp16,
+        cpu=cpu,
+        logging_dir=logging_dir,
+        log_with=[LoggerType.TENSORBOARD],
+    )
+    accelerator.init_trackers(logging_dir / "run")
+
+    train_dataset = MusdbDataset(root=dataset, is_train=True)
+    train_dataloader = DataLoader(
+        train_dataset,
+        shuffle=True,
+        batch_size=per_device_train_batch_size,
+    )
+
+    model = Splitter(stem_names=[s for s in train_dataset.targets])
+    optimizer = AdamW(
+        model.parameters(),
+        lr=1e-3,
+        eps=1e-8,
+    )
+    model, optimizer, train_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader
+    )
+
+    num_train_steps = (
+        max_steps if max_steps > 0 else len(train_dataloader) * num_train_epochs
+    )
+    accelerator.print(f"Num train steps: {num_train_steps}")
+
+    step_batch_size = per_device_train_batch_size * accelerator.num_processes
+    gradient_accumulation_steps = max(
+        1,
+        effective_batch_size // step_batch_size,
+    )
+
+    accelerator.print(
+        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\nEffective Batch Size: {gradient_accumulation_steps * step_batch_size}"
+    )
+    global_step = 0
+    while global_step < num_train_steps:
+        accelerator.wait_for_everyone()
+        # accelerator.print(f"global step: {global_step}")
+        # accelerator.print("running train...")
+        model.train()
+        batch_iterator = tqdm(
+            train_dataloader,
+            desc="Batch",
+            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
+        )
+        for batch_idx, batch in enumerate(batch_iterator):
+            assert per_device_train_batch_size == 1, "For now limit to 1."
+            x_wav, y_target_wavs = batch
+            predictions = model(x_wav)
+            stem_losses = []
+            for name, masked_stft in predictions.items():
+                target_stft, _ = model.compute_stft(y_target_wavs[name].squeeze())
+                loss = spectrogram_loss(
+                    masked_target=masked_stft,
+                    original=target_stft,
+                )
+                stem_losses.append(loss)
+                accelerator.log({f"train-loss-{name}": 1.0 * loss}, step=global_step)
+
+            total_loss = (
+                torch.sum(torch.stack(stem_losses)) / gradient_accumulation_steps
+            )
+            accelerator.print(f"global step: {global_step}\tloss: {total_loss:.4f}")
+            accelerator.log({f"train-loss": 1.0 * total_loss}, step=global_step)
+            accelerator.backward(total_loss)
+            if (batch_idx + 1) % gradient_accumulation_steps == 0:
+                if max_grad_norm > 0:
+                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
+                optimizer.step()
+                optimizer.zero_grad()
+            global_step += 1
+
+    accelerator.wait_for_everyone()
+    accelerator.end_training()
+    accelerator.print(f"Saving model to {output_dir}...")
+    unwrapped_model = accelerator.unwrap_model(model)
+    unwrapped_model.save_pretrained(
+        output_dir,
+        save_function=accelerator.save,
+        state_dict=accelerator.get_state_dict(model),
+    )
+
+    accelerator.wait_for_everyone()
+    accelerator.print("DONE!")
+
+
+if __name__ == "__main__":
+    app()
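Note that train() ends by calling unwrapped_model.save_pretrained(...), which the Splitter class added in this commit does not define. A minimal sketch of a compatible method, with the signature inferred from the call site (the output filename is an assumption chosen to match the from_pretrained usage elsewhere in the commit):

from pathlib import Path

import torch

def save_pretrained(self, output_dir, save_function=torch.save, state_dict=None):
    # hypothetical counterpart to Splitter.from_pretrained; not part of this commit
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if state_dict is None:
        state_dict = self.state_dict()
    save_function(state_dict, output_dir / "spleeter.pth")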
torchspleeter/unet.py CHANGED
@@ -1,97 +1,97 @@
+from typing import Tuple
 import torch
-from torch import nn
+from torch import nn, Tensor
+from torch.nn import functional as F
-        self.padding_setting = padding_setting
-
-    def forward(self, x):
-        l,r,t,b = self.padding_setting
-        return x[:,:,l:-r,t:-b]
-
-    def
-        nn.Sequential(
-            nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
-            nn.LeakyReLU(0.2))
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=(2, 2))
+        self.bn = nn.BatchNorm2d(
+            num_features=out_channels,
+            track_running_stats=True,
+            eps=1e-3,
+            momentum=0.01,
+        )
+        self.relu = nn.LeakyReLU(negative_slope=0.2)
+
+    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
+        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
+        return down, self.relu(self.bn(down))
+
+
+class DecoderBlock(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
+    ) -> None:
+        super().__init__()
+        self.tconv = nn.ConvTranspose2d(
+            in_channels, out_channels, kernel_size=5, stride=2
+        )
+        self.relu = nn.ReLU()
+        self.bn = nn.BatchNorm2d(
+            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
+        )
+        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()
+
+    def forward(self, input: Tensor) -> Tensor:
+        up = self.tconv(input)
+        # reverse padding
+        l, r, t, b = 1, 2, 1, 2
+        up = up[:, :, l:-r, t:-b]
+        return self.dropout(self.bn(self.relu(up)))
 
 
 class UNet(nn.Module):
-    def __init__(
-        self.
-        self.up6 = up_block(32, 1)
-        self.up7 = nn.Sequential(
-            nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3),
-            nn.Sigmoid()
-        )
+    def __init__(
+        self,
+        n_layers: int = 6,
+        in_channels: int = 1,
+    ) -> None:
+        super().__init__()
+
+        # DownSample layers
+        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
+        self.encoder_layers = nn.ModuleList(
+            [
+                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
+                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
+            ]
+        )
+
+        # UpSample layers
+        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
+        up_set.reverse()
+        self.decoder_layers = nn.ModuleList(
+            [
+                DecoderBlock(
+                    # doubled for concatenated inputs (skip connections)
+                    in_channels=in_ch if i == 0 else in_ch * 2,
+                    out_channels=out_ch,
+                    # 50 % dropout... first 3 layers only
+                    dropout_prob=0.5 if i < 3 else 0,
+                )
+                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
+            ]
+        )
+
+        # reconstruct the final mask same as the original channels
+        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
+        self.sigmoid = nn.Sigmoid()
-
-        d6_conv = self.down6_conv(d5)
-        d6 = self.down6_act(d6_conv)
-
-        u1 = self.up1(d6_conv)
-        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
-        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
-        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
-        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
-        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
-        u7 = self.up7(u6)
-        return u7 * x
+
+    def forward(self, input: Tensor) -> Tensor:
+        encoder_outputs_pre_act = []
+        x = input
+        for down in self.encoder_layers:
+            conv, x = down(x)
+            encoder_outputs_pre_act.append(conv)
+
+        for i, up in enumerate(self.decoder_layers):
+            if i == 0:
+                x = up(encoder_outputs_pre_act.pop())
+            else:
+                # merge skip connection
+                x = up(torch.concat([encoder_outputs_pre_act.pop(), x], axis=1))
+        mask = self.sigmoid(self.up_final(x))
+        return mask * input
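The encoder and decoder widths come from the two 2 ** (i + 4) ladders in __init__; computed out for the 2-channel UNets that Splitter instantiates:

# channel ladders for UNet(n_layers=6, in_channels=2), computed exactly as above
n_layers, in_channels = 6, 2
down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
up_set = list(reversed([1] + [2 ** (i + 4) for i in range(n_layers)]))
print(down_set)  # [2, 16, 32, 64, 128, 256, 512] -> encoder in/out pairs
print(up_set)    # [512, 256, 128, 64, 32, 16, 1] -> decoder in/out pairs
# decoder inputs after the first layer are doubled by the skip concatenation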
torchspleeter/utils.py ADDED
@@ -0,0 +1,42 @@
+import librosa
+import soundfile
+from pathlib import Path
+
+import torch
+
+from .splitter import Splitter
+
+
+def sound_split(
+    model: Splitter,
+    input: str = "data/audio_example.mp3",
+    output_dir: str = "output",
+    write_src: bool = False,
+) -> None:
+    sr = 44100
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # splitter = Splitter.from_pretrained(model_path).to(device).eval()
+
+    # load wav audio
+    fpath_src = Path(input)
+    wav, _ = librosa.load(
+        fpath_src,
+        mono=False,
+        res_type="kaiser_fast",
+        sr=sr,
+    )
+    wav = torch.Tensor(wav).to(device)
+
+    # normalize audio
+    # wav_torch = wav / (wav.max() + 1e-8)
+
+    with torch.no_grad():
+        stems = model.separate(wav)
+
+    if write_src:
+        stems["input"] = wav
+    for name, stem in stems.items():
+        fpath_dst = Path(output_dir) / f"{name}.wav"
+        print(f"Writing {fpath_dst}")
+        fpath_dst.parent.mkdir(exist_ok=True)
+        soundfile.write(fpath_dst, stem.cpu().detach().numpy().T, sr, "PCM_16")
warm_up.py CHANGED
@@ -1,5 +1,5 @@
 from spleeter.separator import Separator
 
-separator = Separator('spleeter:2stems')
+separator = Separator('spleeter.pth:2stems')
 # warm up
 separator.separate_to_file('warm.wav', '/tmp/warm')