Spaces:
Running
Running
move zs wts to hdd instead of gpu memory, and auto delete after an hour
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import random
|
3 |
import torch
|
4 |
-
import
|
5 |
from torch import inference_mode
|
6 |
from tempfile import NamedTemporaryFile
|
7 |
import numpy as np
|
@@ -65,16 +65,16 @@ def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # ,
|
|
65 |
with torch.no_grad():
|
66 |
audio = ldm_stable.decode_to_mel(x0_dec)
|
67 |
|
68 |
-
|
69 |
-
torchaudio.save(f.name, audio, sample_rate=16000)
|
70 |
|
71 |
-
return f.name
|
72 |
|
73 |
-
|
74 |
-
|
75 |
model_id: str,
|
76 |
do_inversion: bool,
|
77 |
-
|
|
|
|
|
78 |
source_prompt="",
|
79 |
target_prompt="",
|
80 |
steps=200,
|
@@ -95,24 +95,41 @@ def edit(input_audio,
|
|
95 |
if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
|
96 |
do_inversion = True
|
97 |
|
|
|
|
|
98 |
x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
|
99 |
|
|
|
|
|
|
|
|
|
|
|
100 |
if do_inversion or randomize_seed: # always re-run inversion
|
101 |
zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
|
102 |
num_diffusion_steps=steps,
|
103 |
cfg_scale_src=cfg_scale_src)
|
104 |
-
|
105 |
-
zs
|
|
|
|
|
|
|
|
|
|
|
106 |
saved_inv_model = model_id
|
107 |
do_inversion = False
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
# make sure t_start is in the right limit
|
110 |
# t_start = change_tstart_range(t_start, steps)
|
111 |
|
112 |
-
output = sample(ldm_stable,
|
113 |
tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
|
114 |
|
115 |
-
return output,
|
116 |
|
117 |
|
118 |
def get_example():
|
@@ -170,27 +187,36 @@ For faster inference without waiting in queue, you may duplicate the space and u
|
|
170 |
"""
|
171 |
|
172 |
help = """
|
|
|
173 |
<b>Instructions:</b><br>
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
178 |
For example, use the music version for music and the large version for general audio.
|
179 |
-
</
|
180 |
-
<
|
181 |
-
|
182 |
-
</
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
"""
|
186 |
|
187 |
-
with gr.Blocks(css='style.css') as demo:
|
188 |
def reset_do_inversion():
|
189 |
do_inversion = gr.State(value=True)
|
190 |
return do_inversion
|
191 |
gr.HTML(intro)
|
192 |
-
wts = gr.State()
|
193 |
-
zs = gr.State()
|
|
|
|
|
194 |
saved_inv_model = gr.State()
|
195 |
# current_loaded_model = gr.State(value="cvssp/audioldm2-music")
|
196 |
# ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
|
@@ -198,7 +224,7 @@ with gr.Blocks(css='style.css') as demo:
|
|
198 |
do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
|
199 |
|
200 |
with gr.Group():
|
201 |
-
gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed (for unlimited input
|
202 |
with gr.Row():
|
203 |
input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio",
|
204 |
interactive=True, scale=1)
|
@@ -251,11 +277,14 @@ with gr.Blocks(css='style.css') as demo:
|
|
251 |
inputs=[seed, randomize_seed],
|
252 |
outputs=[seed], queue=False).then(
|
253 |
fn=edit,
|
254 |
-
inputs=[
|
|
|
255 |
model_id,
|
256 |
do_inversion,
|
257 |
# current_loaded_model, ldm_stable,
|
258 |
-
wts, zs,
|
|
|
|
|
259 |
src_prompt,
|
260 |
tar_prompt,
|
261 |
steps,
|
@@ -264,8 +293,12 @@ with gr.Blocks(css='style.css') as demo:
|
|
264 |
t_start,
|
265 |
randomize_seed
|
266 |
],
|
267 |
-
outputs=[output_audio,
|
268 |
-
|
|
|
|
|
|
|
|
|
269 |
|
270 |
# If sources changed we have to rerun inversion
|
271 |
input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
|
|
|
1 |
import gradio as gr
|
2 |
import random
|
3 |
import torch
|
4 |
+
import os
|
5 |
from torch import inference_mode
|
6 |
from tempfile import NamedTemporaryFile
|
7 |
import numpy as np
|
|
|
65 |
with torch.no_grad():
|
66 |
audio = ldm_stable.decode_to_mel(x0_dec)
|
67 |
|
68 |
+
return (16000, audio.squeeze().cpu().numpy())
|
|
|
69 |
|
|
|
70 |
|
71 |
+
def edit(cache_dir,
|
72 |
+
input_audio,
|
73 |
model_id: str,
|
74 |
do_inversion: bool,
|
75 |
+
wtszs_file: str,
|
76 |
+
# wts: gr.State, zs: gr.State,
|
77 |
+
saved_inv_model: str,
|
78 |
source_prompt="",
|
79 |
target_prompt="",
|
80 |
steps=200,
|
|
|
95 |
if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
|
96 |
do_inversion = True
|
97 |
|
98 |
+
if input_audio is None:
|
99 |
+
raise gr.Error('Input audio missing!')
|
100 |
x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
|
101 |
|
102 |
+
if not (do_inversion or randomize_seed):
|
103 |
+
if not os.path.exists(wtszs_file):
|
104 |
+
do_inversion = True
|
105 |
+
# Too much time has passed
|
106 |
+
|
107 |
if do_inversion or randomize_seed: # always re-run inversion
|
108 |
zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
|
109 |
num_diffusion_steps=steps,
|
110 |
cfg_scale_src=cfg_scale_src)
|
111 |
+
f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False)
|
112 |
+
torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name)
|
113 |
+
wtszs_file = f.name
|
114 |
+
# wtszs_file = gr.State(value=f.name)
|
115 |
+
# wts = gr.State(value=wts_tensor)
|
116 |
+
# zs = gr.State(value=zs_tensor)
|
117 |
+
# demo.move_resource_to_block_cache(f.name)
|
118 |
saved_inv_model = model_id
|
119 |
do_inversion = False
|
120 |
+
else:
|
121 |
+
wtszs = torch.load(wtszs_file, map_location=device)
|
122 |
+
# wtszs = torch.load(wtszs_file.f, map_location=device)
|
123 |
+
wts_tensor = wtszs['wts']
|
124 |
+
zs_tensor = wtszs['zs']
|
125 |
|
126 |
# make sure t_start is in the right limit
|
127 |
# t_start = change_tstart_range(t_start, steps)
|
128 |
|
129 |
+
output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
|
130 |
tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
|
131 |
|
132 |
+
return output, wtszs_file, saved_inv_model, do_inversion
|
133 |
|
134 |
|
135 |
def get_example():
|
|
|
187 |
"""
|
188 |
|
189 |
help = """
|
190 |
+
<div style="font-size:medium">
|
191 |
<b>Instructions:</b><br>
|
192 |
+
<ul style="line-height: normal">
|
193 |
+
<li>You must provide an input audio and a target prompt to edit the audio. </li>
|
194 |
+
<li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
|
195 |
+
Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
|
196 |
+
<li>Make sure that you use an AudioLDM2 version that is suitable for your input audio.
|
197 |
For example, use the music version for music and the large version for general audio.
|
198 |
+
</li>
|
199 |
+
<li>You can additionally provide a source prompt to guide even further the editing process.</li>
|
200 |
+
<li>Longer input will take more time.</li>
|
201 |
+
<li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
|
202 |
+
For unlimited length, duplicated the space, and remove the trimming by changing the code.
|
203 |
+
Specifically, in the <code style="display:inline; background-color: lightgrey; ">load_audio</code> function in the <code style="display:inline; background-color: lightgrey; ">utils.py</code> file,
|
204 |
+
change <code style="display:inline; background-color: lightgrey; ">duration = min(audioldm.utils.get_duration(audio_path), 30)</code> to
|
205 |
+
<code style="display:inline; background-color: lightgrey; ">duration = audioldm.utils.get_duration(audio_path)</code>.
|
206 |
+
</ul>
|
207 |
+
</div>
|
208 |
|
209 |
"""
|
210 |
|
211 |
+
with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
|
212 |
def reset_do_inversion():
|
213 |
do_inversion = gr.State(value=True)
|
214 |
return do_inversion
|
215 |
gr.HTML(intro)
|
216 |
+
# wts = gr.State()
|
217 |
+
# zs = gr.State()
|
218 |
+
wtszs = gr.State()
|
219 |
+
cache_dir = gr.State(demo.GRADIO_CACHE)
|
220 |
saved_inv_model = gr.State()
|
221 |
# current_loaded_model = gr.State(value="cvssp/audioldm2-music")
|
222 |
# ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
|
|
|
224 |
do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
|
225 |
|
226 |
with gr.Group():
|
227 |
+
gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed (for unlimited input, see the Help section below)")
|
228 |
with gr.Row():
|
229 |
input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio",
|
230 |
interactive=True, scale=1)
|
|
|
277 |
inputs=[seed, randomize_seed],
|
278 |
outputs=[seed], queue=False).then(
|
279 |
fn=edit,
|
280 |
+
inputs=[cache_dir,
|
281 |
+
input_audio,
|
282 |
model_id,
|
283 |
do_inversion,
|
284 |
# current_loaded_model, ldm_stable,
|
285 |
+
# wts, zs,
|
286 |
+
wtszs,
|
287 |
+
saved_inv_model,
|
288 |
src_prompt,
|
289 |
tar_prompt,
|
290 |
steps,
|
|
|
293 |
t_start,
|
294 |
randomize_seed
|
295 |
],
|
296 |
+
outputs=[output_audio, wtszs,
|
297 |
+
saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
|
298 |
+
).then(lambda x: demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None,
|
299 |
+
inputs=wtszs)
|
300 |
+
|
301 |
+
# demo.move_resource_to_block_cache(wtszs.value)
|
302 |
|
303 |
# If sources changed we have to rerun inversion
|
304 |
input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
|