fabiogra commited on
Commit
b0a9f8f
·
1 Parent(s): 2e3ca25

feat: add separate examples, logs and improvements

Browse files
app/helpers.py CHANGED
@@ -1,5 +1,4 @@
1
  import json
2
- import logging
3
  import os
4
  import random
5
  from base64 import b64encode
@@ -8,7 +7,6 @@ from pathlib import Path
8
 
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
- import requests
12
  import streamlit as st
13
  from PIL import Image
14
  from pydub import AudioSegment
@@ -20,7 +18,7 @@ extensions = ["mp3", "wav", "ogg", "flac"] # we will look for all those file ty
20
 
21
 
22
  def check_file_availability(url):
23
- exit_status = os.system(f"wget --spider {url}")
24
  return exit_status == 0
25
 
26
 
@@ -33,18 +31,6 @@ def url_is_valid(url):
33
  st.error("Extension not supported.")
34
  return False
35
  try:
36
- r = requests.get(url)
37
- r.raise_for_status()
38
- return True
39
- except requests.exceptions.HTTPError as err:
40
- msg = (
41
- "requests get failed with status code "
42
- + str(err.response.status_code)
43
- + " for url "
44
- + url
45
- + ". Try wget spider."
46
- )
47
- logging.error(msg)
48
  return check_file_availability(url)
49
  except Exception:
50
  st.error("URL is not valid.")
@@ -79,12 +65,19 @@ def plot_audio(_audio_segment: AudioSegment, *args, **kwargs) -> Image.Image:
79
 
80
 
81
  @st.cache_data(show_spinner=False)
82
- def load_list_of_songs():
83
- return json.load(open("sample_songs.json"))
 
 
 
 
 
84
 
85
 
86
  def get_random_song():
87
  sample_songs = load_list_of_songs()
 
 
88
  name, url = random.choice(list(sample_songs.items()))
89
  return name, url
90
 
 
1
  import json
 
2
  import os
3
  import random
4
  from base64 import b64encode
 
7
 
8
  import matplotlib.pyplot as plt
9
  import numpy as np
 
10
  import streamlit as st
11
  from PIL import Image
12
  from pydub import AudioSegment
 
18
 
19
 
20
  def check_file_availability(url):
21
+ exit_status = os.system(f"wget -o --spider {url}")
22
  return exit_status == 0
23
 
24
 
 
31
  st.error("Extension not supported.")
32
  return False
33
  try:
 
 
 
 
 
 
 
 
 
 
 
 
34
  return check_file_availability(url)
35
  except Exception:
36
  st.error("URL is not valid.")
 
65
 
66
 
67
  @st.cache_data(show_spinner=False)
68
+ def load_list_of_songs(path="sample_songs.json"):
69
+ if os.environ.get("PREPARE_SAMPLES"):
70
+ return json.load(open(path))
71
+ else:
72
+ st.error(
73
+ "No examples available. You need to set the environment variable `PREPARE_SAMPLES=true`"
74
+ )
75
 
76
 
77
  def get_random_song():
78
  sample_songs = load_list_of_songs()
79
+ if sample_songs is None:
80
+ return None, None
81
  name, url = random.choice(list(sample_songs.items()))
82
  return name, url
83
 
app/pages/Separate.py CHANGED
@@ -1,21 +1,22 @@
1
  import os
2
  from pathlib import Path
 
 
3
 
4
  import streamlit as st
5
- from streamlit_option_menu import option_menu
6
-
7
- from service.demucs_runner import separator
8
  from helpers import (
9
  load_audio_segment,
 
10
  plot_audio,
11
  st_local_audio,
12
  url_is_valid,
13
  )
 
 
 
14
 
15
- from service.vocal_remover.runner import separate, load_model
16
-
17
- from footer import footer
18
- from header import header
19
 
20
  label_sources = {
21
  "no_vocals.mp3": "🎶 Instrumental",
@@ -27,28 +28,104 @@ label_sources = {
27
  "other.mp3": "🎶 Other",
28
  }
29
 
30
- extensions = ["mp3", "wav", "ogg", "flac"]
 
 
 
 
 
 
 
 
 
 
 
31
 
 
32
 
33
  out_path = Path("/tmp")
34
  in_path = Path("/tmp")
35
 
36
 
 
 
 
 
 
 
 
 
 
 
37
  def reset_execution():
38
  st.session_state.executed = False
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def body():
42
  filename = None
 
 
 
 
 
 
 
 
 
 
 
 
43
  cols = st.columns([1, 3, 2, 1])
44
  with cols[1]:
45
- with st.columns([1, 5, 1])[1]:
46
  option = option_menu(
47
  menu_title=None,
48
- options=["Upload File", "From URL"],
49
- icons=["cloud-upload-fill", "link-45deg"],
50
  orientation="horizontal",
51
- styles={"container": {"width": "100%", "margin": "0px", "padding": "0px"}},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  key="option_separate",
53
  )
54
  if option == "Upload File":
@@ -64,18 +141,32 @@ def body():
64
  filename = uploaded_file.name
65
  st_local_audio(in_path / filename, key="input_upload_file")
66
 
67
- elif option == "From URL": # TODO: show examples
68
  url = st.text_input(
69
  "Paste the URL of the audio file",
70
  key="url_input",
71
  help="Supported formats: mp3, wav, ogg, flac.",
72
  )
73
- if url != "":
74
- if url_is_valid(url):
75
- with st.spinner("Downloading audio..."):
76
- filename = url.split("/")[-1]
77
- os.system(f"wget -O {in_path / filename} {url}")
78
  st_local_audio(in_path / filename, key="input_from_url")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  with cols[2]:
80
  separation_mode = st.selectbox(
81
  "Choose the separation mode",
@@ -92,6 +183,7 @@ def body():
92
  max_duration = 30
93
  else:
94
  max_duration = 15
 
95
 
96
  if filename is not None:
97
  song = load_audio_segment(in_path / filename, filename.split(".")[-1])
@@ -124,10 +216,10 @@ def body():
124
  st.session_state.executed = False
125
 
126
  if not st.session_state.executed:
 
127
  song.export(in_path / filename, format=filename.split(".")[-1])
128
  with st.spinner("Separating source audio, it will take a while..."):
129
- if separation_mode == "Vocals & Instrumental (Faster)":
130
- model_name = "vocal_remover"
131
  model, device = load_model(pretrained_model="baseline.pth")
132
  separate(
133
  input=in_path / filename,
@@ -137,13 +229,7 @@ def body():
137
  )
138
  else:
139
  stem = None
140
- model_name = "htdemucs"
141
- if (
142
- separation_mode
143
- == "Vocal, Drums, Bass, Guitar, Piano & Other (Slowest)"
144
- ):
145
- model_name = "htdemucs_6s"
146
- elif separation_mode == "Vocals & Instrumental (High Quality, Slower)":
147
  stem = "vocals"
148
 
149
  separator(
@@ -162,39 +248,12 @@ def body():
162
  start_time=start_time,
163
  end_time=end_time,
164
  )
165
- last_dir = ".".join(filename.split(".")[:-1])
166
  filename = None
167
  st.session_state.executed = True
168
-
169
- def get_sources(path):
170
- sources = {}
171
- for file in [
172
- "no_vocals.mp3",
173
- "vocals.mp3",
174
- "drums.mp3",
175
- "bass.mp3",
176
- "guitar.mp3",
177
- "piano.mp3",
178
- "other.mp3",
179
- ]:
180
- fullpath = path / file
181
- if fullpath.exists():
182
- sources[file] = fullpath
183
- return sources
184
-
185
- sources = get_sources(out_path / Path(model_name) / last_dir)
186
- tab_sources = st.tabs([f"**{label_sources.get(k)}**" for k in sources.keys()])
187
- for i, (file, pathname) in enumerate(sources.items()):
188
- with tab_sources[i]:
189
- cols = st.columns(2)
190
- with cols[0]:
191
- auseg = load_audio_segment(pathname, "mp3")
192
- st.image(
193
- plot_audio(auseg, title="", file=file),
194
- use_column_width="always",
195
- )
196
- with cols[1]:
197
- st_local_audio(pathname, key=f"output_{file}")
198
 
199
 
200
  if __name__ == "__main__":
 
1
  import os
2
  from pathlib import Path
3
+ from typing import List
4
+ from loguru import logger as log
5
 
6
  import streamlit as st
7
+ from footer import footer
8
+ from header import header
 
9
  from helpers import (
10
  load_audio_segment,
11
+ load_list_of_songs,
12
  plot_audio,
13
  st_local_audio,
14
  url_is_valid,
15
  )
16
+ from service.demucs_runner import separator
17
+ from service.vocal_remover.runner import load_model, separate
18
+ from streamlit_option_menu import option_menu
19
 
 
 
 
 
20
 
21
  label_sources = {
22
  "no_vocals.mp3": "🎶 Instrumental",
 
28
  "other.mp3": "🎶 Other",
29
  }
30
 
31
+ separation_mode_to_model = {
32
+ "Vocals & Instrumental (Faster)": ("vocal_remover", ["vocals.mp3", "no_vocals.mp3"]),
33
+ "Vocals & Instrumental (High Quality, Slower)": ("htdemucs", ["vocals.mp3", "no_vocals.mp3"]),
34
+ "Vocals, Drums, Bass & Other (Slower)": (
35
+ "htdemucs",
36
+ ["vocals.mp3", "drums.mp3", "bass.mp3", "other.mp3"],
37
+ ),
38
+ "Vocal, Drums, Bass, Guitar, Piano & Other (Slowest)": (
39
+ "htdemucs_6s",
40
+ ["vocals.mp3", "drums.mp3", "bass.mp3", "guitar.mp3", "piano.mp3", "other.mp3"],
41
+ ),
42
+ }
43
 
44
+ extensions = ["mp3", "wav", "ogg", "flac"]
45
 
46
  out_path = Path("/tmp")
47
  in_path = Path("/tmp")
48
 
49
 
50
+ @st.cache_data(show_spinner=False)
51
+ def get_sources(path, file_sources):
52
+ sources = {}
53
+ for file in file_sources:
54
+ fullpath = path / file
55
+ if fullpath.exists():
56
+ sources[file] = fullpath
57
+ return sources
58
+
59
+
60
  def reset_execution():
61
  st.session_state.executed = False
62
 
63
 
64
+ def show_results(model_name: str, dir_name_output: str, file_sources: List):
65
+ sources = get_sources(out_path / Path(model_name) / dir_name_output, file_sources)
66
+ tab_sources = st.tabs([f"**{label_sources.get(k)}**" for k in sources.keys()])
67
+ for i, (file, pathname) in enumerate(sources.items()):
68
+ with tab_sources[i]:
69
+ cols = st.columns(2)
70
+ with cols[0]:
71
+ auseg = load_audio_segment(pathname, "mp3")
72
+ st.image(
73
+ plot_audio(
74
+ auseg,
75
+ title="",
76
+ file=file,
77
+ model_name=model_name,
78
+ dir_name_output=dir_name_output,
79
+ ),
80
+ use_column_width="always",
81
+ )
82
+ with cols[1]:
83
+ st_local_audio(pathname, key=f"output_{file}_{dir_name_output}")
84
+ log.info(f"Displaying results for {dir_name_output}")
85
+
86
+
87
  def body():
88
  filename = None
89
+ name_song = None
90
+ st.markdown(
91
+ """
92
+ <style>
93
+ div[data-baseweb="tab-list"] {
94
+ align-items: center !important;
95
+ justify-content: center !important;
96
+ }
97
+ </style>""",
98
+ unsafe_allow_html=True,
99
+ )
100
+
101
  cols = st.columns([1, 3, 2, 1])
102
  with cols[1]:
103
+ with st.columns([1, 8, 1])[1]:
104
  option = option_menu(
105
  menu_title=None,
106
+ options=["Upload File", "From URL", "Examples"],
107
+ icons=["cloud-upload-fill", "link-45deg", "music-note-list"],
108
  orientation="horizontal",
109
+ styles={
110
+ "container": {
111
+ "width": "100%",
112
+ "height": "3.5rem",
113
+ "margin": "0px",
114
+ "padding": "0px",
115
+ },
116
+ "icon": {"font-size": "1rem"},
117
+ "nav-link": {
118
+ "display": "flex",
119
+ "height": "3rem",
120
+ "justify-content": "center",
121
+ "align-items": "center",
122
+ "text-align": "center",
123
+ "flex-direction": "column",
124
+ "font-size": "1rem",
125
+ "padding-left": "0px",
126
+ "padding-right": "0px",
127
+ },
128
+ },
129
  key="option_separate",
130
  )
131
  if option == "Upload File":
 
141
  filename = uploaded_file.name
142
  st_local_audio(in_path / filename, key="input_upload_file")
143
 
144
+ elif option == "From URL":
145
  url = st.text_input(
146
  "Paste the URL of the audio file",
147
  key="url_input",
148
  help="Supported formats: mp3, wav, ogg, flac.",
149
  )
150
+ if url != "" and url_is_valid(url):
151
+ with st.spinner("Downloading audio..."):
152
+ filename = url.split("/")[-1]
153
+ os.system(f"wget -q -O {in_path / filename} {url}")
 
154
  st_local_audio(in_path / filename, key="input_from_url")
155
+ elif option == "Examples":
156
+ samples_song = load_list_of_songs(path="separate_songs.json")
157
+ if samples_song is not None:
158
+ name_song = st.selectbox(
159
+ label="Select a song",
160
+ options=list(samples_song.keys()),
161
+ format_func=lambda x: x.replace("_", " "),
162
+ index=1,
163
+ key="select_example",
164
+ )
165
+ if (Path("/tmp") / name_song).exists():
166
+ st_local_audio(Path("/tmp") / name_song, key=f"input_from_sample_{name_song}")
167
+ else:
168
+ name_song = None
169
+
170
  with cols[2]:
171
  separation_mode = st.selectbox(
172
  "Choose the separation mode",
 
183
  max_duration = 30
184
  else:
185
  max_duration = 15
186
+ model_name, file_sources = separation_mode_to_model[separation_mode]
187
 
188
  if filename is not None:
189
  song = load_audio_segment(in_path / filename, filename.split(".")[-1])
 
216
  st.session_state.executed = False
217
 
218
  if not st.session_state.executed:
219
+ log.info(f"{option} - Separating {filename} with {separation_mode}...")
220
  song.export(in_path / filename, format=filename.split(".")[-1])
221
  with st.spinner("Separating source audio, it will take a while..."):
222
+ if model_name == "vocal_remover":
 
223
  model, device = load_model(pretrained_model="baseline.pth")
224
  separate(
225
  input=in_path / filename,
 
229
  )
230
  else:
231
  stem = None
232
+ if separation_mode == "Vocals & Instrumental (High Quality, Slower)":
 
 
 
 
 
 
233
  stem = "vocals"
234
 
235
  separator(
 
248
  start_time=start_time,
249
  end_time=end_time,
250
  )
251
+ dir_name_output = ".".join(filename.split(".")[:-1])
252
  filename = None
253
  st.session_state.executed = True
254
+ show_results(model_name, dir_name_output, file_sources)
255
+ elif name_song is not None and option == "Examples":
256
+ show_results(model_name, name_song, file_sources)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
 
259
  if __name__ == "__main__":
app/style.py CHANGED
@@ -124,6 +124,12 @@ CSS = (
124
  gap: 0rem;
125
  }
126
 
 
 
 
 
 
 
127
 
128
  </style>
129
 
 
124
  gap: 0rem;
125
  }
126
 
127
+ /* center the audio player in Separate page */
128
+ .css-keje6w.e1tzin5v1 {
129
+ display: flex;
130
+ justify-content: center;
131
+ align-items: center;
132
+ }
133
 
134
  </style>
135
 
requirements.in CHANGED
@@ -14,3 +14,4 @@ resampy==0.4.2
14
  stqdm==0.0.5
15
  streamlit_option_menu==0.3.6
16
  htbuilder==0.6.1
 
 
14
  stqdm==0.0.5
15
  streamlit_option_menu==0.3.6
16
  htbuilder==0.6.1
17
+ loguru==0.7.0
requirements.txt CHANGED
@@ -38,7 +38,7 @@ contourpy==1.1.0
38
  # via matplotlib
39
  cycler==0.11.0
40
  # via matplotlib
41
- cython==0.29.35
42
  # via diffq
43
  decorator==5.1.1
44
  # via
@@ -91,14 +91,16 @@ kaleido==0.2.1
91
  # via -r requirements.in
92
  kiwisolver==1.4.4
93
  # via matplotlib
94
- lameenc==1.5.0
95
  # via demucs
96
- lazy-loader==0.2
97
  # via librosa
98
  librosa==0.10.0.post2
99
  # via -r requirements.in
100
  llvmlite==0.40.1
101
  # via numba
 
 
102
  markdown-it-py==3.0.0
103
  # via rich
104
  markupsafe==2.1.3
@@ -152,7 +154,7 @@ pandas==1.5.3
152
  # -r requirements.in
153
  # altair
154
  # streamlit
155
- pillow==9.5.0
156
  # via
157
  # matplotlib
158
  # streamlit
@@ -271,7 +273,7 @@ tqdm==4.65.0
271
  # stqdm
272
  treetable==0.2.5
273
  # via dora-search
274
- typing-extensions==4.7.0
275
  # via
276
  # librosa
277
  # rich
 
38
  # via matplotlib
39
  cycler==0.11.0
40
  # via matplotlib
41
+ cython==0.29.36
42
  # via diffq
43
  decorator==5.1.1
44
  # via
 
91
  # via -r requirements.in
92
  kiwisolver==1.4.4
93
  # via matplotlib
94
+ lameenc==1.5.1
95
  # via demucs
96
+ lazy-loader==0.3
97
  # via librosa
98
  librosa==0.10.0.post2
99
  # via -r requirements.in
100
  llvmlite==0.40.1
101
  # via numba
102
+ loguru==0.7.0
103
+ # via -r requirements.in
104
  markdown-it-py==3.0.0
105
  # via rich
106
  markupsafe==2.1.3
 
154
  # -r requirements.in
155
  # altair
156
  # streamlit
157
+ pillow==10.0.0
158
  # via
159
  # matplotlib
160
  # streamlit
 
273
  # stqdm
274
  treetable==0.2.5
275
  # via dora-search
276
+ typing-extensions==4.7.1
277
  # via
278
  # librosa
279
  # rich
scripts/inference.py CHANGED
@@ -1,7 +1,9 @@
1
  import argparse
 
2
 
3
  import warnings
4
  from app.service.vocal_remover.runner import load_model, separate
 
5
 
6
  warnings.simplefilter("ignore", UserWarning)
7
  warnings.simplefilter("ignore", FutureWarning)
@@ -14,16 +16,35 @@ def main():
14
  p.add_argument("--pretrained_model", "-P", type=str, default="baseline.pth")
15
  p.add_argument("--input", "-i", required=True)
16
  p.add_argument("--output_dir", "-o", type=str, default="")
 
17
  args = p.parse_args()
18
 
 
 
19
  model, device = load_model(pretrained_model=args.pretrained_model)
20
  separate(
21
- input=args.input,
22
  model=model,
23
  device=device,
24
  output_dir=args.output_dir,
25
- only_no_vocals=True,
26
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  if __name__ == "__main__":
 
1
  import argparse
2
+ from pathlib import Path
3
 
4
  import warnings
5
  from app.service.vocal_remover.runner import load_model, separate
6
+ from app.service.demucs_runner import separator
7
 
8
  warnings.simplefilter("ignore", UserWarning)
9
  warnings.simplefilter("ignore", FutureWarning)
 
16
  p.add_argument("--pretrained_model", "-P", type=str, default="baseline.pth")
17
  p.add_argument("--input", "-i", required=True)
18
  p.add_argument("--output_dir", "-o", type=str, default="")
19
+ p.add_argument("--only_no_vocals", "-n", action="store_true")
20
  args = p.parse_args()
21
 
22
+ input_file = args.input
23
+
24
  model, device = load_model(pretrained_model=args.pretrained_model)
25
  separate(
26
+ input=input_file,
27
  model=model,
28
  device=device,
29
  output_dir=args.output_dir,
30
+ only_no_vocals=args.only_no_vocals,
31
  )
32
+ if not args.only_no_vocals:
33
+ for stem, model_name in [("vocals", "htdemucs"), (None, "htdemucs"), (None, "htdemucs_6s")]:
34
+ separator(
35
+ tracks=[Path(input_file)],
36
+ out=Path(args.output_dir),
37
+ model=model_name,
38
+ shifts=1,
39
+ overlap=0.5,
40
+ stem=stem,
41
+ int24=False,
42
+ float32=False,
43
+ clip_mode="rescale",
44
+ mp3=True,
45
+ mp3_bitrate=320,
46
+ verbose=False,
47
+ )
48
 
49
 
50
  if __name__ == "__main__":
scripts/prepare_samples.sh CHANGED
@@ -22,3 +22,21 @@ for name in $(echo "${json}" | jq -r 'keys[]'); do
22
  python inference.py --input /tmp/${name} --output /tmp
23
  echo "Done separating ${name}"
24
  done
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  python inference.py --input /tmp/${name} --output /tmp
23
  echo "Done separating ${name}"
24
  done
25
+
26
+
27
+ # Read JSON file into a variable
28
+ json_separate=$(cat separate_songs.json)
29
+
30
+ # Iterate through keys and values
31
+ for name in $(echo "${json_separate}" | jq -r 'keys[]'); do
32
+ url=$(echo "${json_separate}" | jq -r --arg name "${name}" '.[$name]')
33
+ echo "Separating ${name} from ${url}"
34
+
35
+ # Download with pytube
36
+ yt-dlp ${url} -o "/tmp/${name}" --format "bestaudio/best" --download-sections "*45-110"
37
+ mkdir -p "/tmp/vocal_remover"
38
+
39
+ # Run inference
40
+ python inference.py --input /tmp/${name} --output /tmp --only_no_vocals false
41
+ echo "Done separating ${name}"
42
+ done
scripts/separate_songs.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ABBA_-_Dancing_Queen": "https://www.youtube.com/watch?v=3qiMJt-JBb4",
3
+ "Queen_–_Bohemian_Rhapsody": "https://www.youtube.com/watch?v=yk3prd8GER4",
4
+ "Backstreet_Boys_-_I_Want_It_That_Way": "https://www.youtube.com/watch?v=qjlVAsvQLM8",
5
+ "The_Beatles_-_Let_It_Be": "https://www.youtube.com/watch?v=FIV73iG_e5I",
6
+ "Coldplay_-_Viva_La_Vida": "https://www.youtube.com/watch?v=a1EYnngNHIA",
7
+ "The_Cranberries_-_Zombie": "https://www.youtube.com/watch?v=8sM-rm4lFZg"
8
+ }