hilamanor commited on
Commit
a0738ba
·
1 Parent(s): 19c39fb

move zs wts to hdd instead of gpu memory, and auto delete after an hour

Browse files
Files changed (1) hide show
  1. app.py +61 -28
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import random
3
  import torch
4
- import torchaudio
5
  from torch import inference_mode
6
  from tempfile import NamedTemporaryFile
7
  import numpy as np
@@ -65,16 +65,16 @@ def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # ,
65
  with torch.no_grad():
66
  audio = ldm_stable.decode_to_mel(x0_dec)
67
 
68
- f = NamedTemporaryFile("wb", suffix=".wav", delete=False)
69
- torchaudio.save(f.name, audio, sample_rate=16000)
70
 
71
- return f.name
72
 
73
-
74
- def edit(input_audio,
75
  model_id: str,
76
  do_inversion: bool,
77
- wts: gr.State, zs: gr.State, saved_inv_model: str,
 
 
78
  source_prompt="",
79
  target_prompt="",
80
  steps=200,
@@ -95,24 +95,41 @@ def edit(input_audio,
95
  if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
96
  do_inversion = True
97
 
 
 
98
  x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
99
 
 
 
 
 
 
100
  if do_inversion or randomize_seed: # always re-run inversion
101
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
102
  num_diffusion_steps=steps,
103
  cfg_scale_src=cfg_scale_src)
104
- wts = gr.State(value=wts_tensor)
105
- zs = gr.State(value=zs_tensor)
 
 
 
 
 
106
  saved_inv_model = model_id
107
  do_inversion = False
 
 
 
 
 
108
 
109
  # make sure t_start is in the right limit
110
  # t_start = change_tstart_range(t_start, steps)
111
 
112
- output = sample(ldm_stable, zs.value, wts.value, steps, prompt_tar=target_prompt,
113
  tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
114
 
115
- return output, wts, zs, saved_inv_model, do_inversion
116
 
117
 
118
  def get_example():
@@ -170,27 +187,36 @@ For faster inference without waiting in queue, you may duplicate the space and u
170
  """
171
 
172
  help = """
 
173
  <b>Instructions:</b><br>
174
- Provide an input audio and a target prompt to edit the audio. <br>
175
- T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
176
- Lower value -> favor fidelity. Higher value -> apply a stronger edit.<br>
177
- Make sure that you use an AudioLDM2 version that is suitable for your input audio.
 
178
  For example, use the music version for music and the large version for general audio.
179
- </p>
180
- <p style="font-size:larger">
181
- You can additionally provide a source prompt to guide even further the editing process.
182
- </p>
183
- <p style="font-size:larger">Longer input will take more time.</p>
 
 
 
 
 
184
 
185
  """
186
 
187
- with gr.Blocks(css='style.css') as demo:
188
  def reset_do_inversion():
189
  do_inversion = gr.State(value=True)
190
  return do_inversion
191
  gr.HTML(intro)
192
- wts = gr.State()
193
- zs = gr.State()
 
 
194
  saved_inv_model = gr.State()
195
  # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
196
  # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
@@ -198,7 +224,7 @@ with gr.Blocks(css='style.css') as demo:
198
  do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
199
 
200
  with gr.Group():
201
- gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed (for unlimited input you may duplicate the space)")
202
  with gr.Row():
203
  input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio",
204
  interactive=True, scale=1)
@@ -251,11 +277,14 @@ with gr.Blocks(css='style.css') as demo:
251
  inputs=[seed, randomize_seed],
252
  outputs=[seed], queue=False).then(
253
  fn=edit,
254
- inputs=[input_audio,
 
255
  model_id,
256
  do_inversion,
257
  # current_loaded_model, ldm_stable,
258
- wts, zs, saved_inv_model,
 
 
259
  src_prompt,
260
  tar_prompt,
261
  steps,
@@ -264,8 +293,12 @@ with gr.Blocks(css='style.css') as demo:
264
  t_start,
265
  randomize_seed
266
  ],
267
- outputs=[output_audio, wts, zs, saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
268
- )
 
 
 
 
269
 
270
  # If sources changed we have to rerun inversion
271
  input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
 
1
  import gradio as gr
2
  import random
3
  import torch
4
+ import os
5
  from torch import inference_mode
6
  from tempfile import NamedTemporaryFile
7
  import numpy as np
 
65
  with torch.no_grad():
66
  audio = ldm_stable.decode_to_mel(x0_dec)
67
 
68
+ return (16000, audio.squeeze().cpu().numpy())
 
69
 
 
70
 
71
+ def edit(cache_dir,
72
+ input_audio,
73
  model_id: str,
74
  do_inversion: bool,
75
+ wtszs_file: str,
76
+ # wts: gr.State, zs: gr.State,
77
+ saved_inv_model: str,
78
  source_prompt="",
79
  target_prompt="",
80
  steps=200,
 
95
  if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
96
  do_inversion = True
97
 
98
+ if input_audio is None:
99
+ raise gr.Error('Input audio missing!')
100
  x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
101
 
102
+ if not (do_inversion or randomize_seed):
103
+ if not os.path.exists(wtszs_file):
104
+ do_inversion = True
105
+ # Too much time has passed
106
+
107
  if do_inversion or randomize_seed: # always re-run inversion
108
  zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
109
  num_diffusion_steps=steps,
110
  cfg_scale_src=cfg_scale_src)
111
+ f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False)
112
+ torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name)
113
+ wtszs_file = f.name
114
+ # wtszs_file = gr.State(value=f.name)
115
+ # wts = gr.State(value=wts_tensor)
116
+ # zs = gr.State(value=zs_tensor)
117
+ # demo.move_resource_to_block_cache(f.name)
118
  saved_inv_model = model_id
119
  do_inversion = False
120
+ else:
121
+ wtszs = torch.load(wtszs_file, map_location=device)
122
+ # wtszs = torch.load(wtszs_file.f, map_location=device)
123
+ wts_tensor = wtszs['wts']
124
+ zs_tensor = wtszs['zs']
125
 
126
  # make sure t_start is in the right limit
127
  # t_start = change_tstart_range(t_start, steps)
128
 
129
+ output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
130
  tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
131
 
132
+ return output, wtszs_file, saved_inv_model, do_inversion
133
 
134
 
135
  def get_example():
 
187
  """
188
 
189
  help = """
190
+ <div style="font-size:medium">
191
  <b>Instructions:</b><br>
192
+ <ul style="line-height: normal">
193
+ <li>You must provide an input audio and a target prompt to edit the audio. </li>
194
+ <li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
195
+ Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
196
+ <li>Make sure that you use an AudioLDM2 version that is suitable for your input audio.
197
  For example, use the music version for music and the large version for general audio.
198
+ </li>
199
+ <li>You can additionally provide a source prompt to guide even further the editing process.</li>
200
+ <li>Longer input will take more time.</li>
201
+ <li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
202
+ For unlimited length, duplicated the space, and remove the trimming by changing the code.
203
+ Specifically, in the <code style="display:inline; background-color: lightgrey; ">load_audio</code> function in the <code style="display:inline; background-color: lightgrey; ">utils.py</code> file,
204
+ change <code style="display:inline; background-color: lightgrey; ">duration = min(audioldm.utils.get_duration(audio_path), 30)</code> to
205
+ <code style="display:inline; background-color: lightgrey; ">duration = audioldm.utils.get_duration(audio_path)</code>.
206
+ </ul>
207
+ </div>
208
 
209
  """
210
 
211
+ with gr.Blocks(css='style.css', delete_cache=(3600, 3600)) as demo:
212
  def reset_do_inversion():
213
  do_inversion = gr.State(value=True)
214
  return do_inversion
215
  gr.HTML(intro)
216
+ # wts = gr.State()
217
+ # zs = gr.State()
218
+ wtszs = gr.State()
219
+ cache_dir = gr.State(demo.GRADIO_CACHE)
220
  saved_inv_model = gr.State()
221
  # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
222
  # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
 
224
  do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
225
 
226
  with gr.Group():
227
+ gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed (for unlimited input, see the Help section below)")
228
  with gr.Row():
229
  input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio",
230
  interactive=True, scale=1)
 
277
  inputs=[seed, randomize_seed],
278
  outputs=[seed], queue=False).then(
279
  fn=edit,
280
+ inputs=[cache_dir,
281
+ input_audio,
282
  model_id,
283
  do_inversion,
284
  # current_loaded_model, ldm_stable,
285
+ # wts, zs,
286
+ wtszs,
287
+ saved_inv_model,
288
  src_prompt,
289
  tar_prompt,
290
  steps,
 
293
  t_start,
294
  randomize_seed
295
  ],
296
+ outputs=[output_audio, wtszs,
297
+ saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
298
+ ).then(lambda x: demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None,
299
+ inputs=wtszs)
300
+
301
+ # demo.move_resource_to_block_cache(wtszs.value)
302
 
303
  # If sources changed we have to rerun inversion
304
  input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])