hilamanor committed
Commit 7c56def (parent: 6daa825)

Stable Audio Open + progbars + mp3 + batched forward + cleanup
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
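The added rule is the standard Git LFS filter line for MP3 files, matching the existing *.wav entry. A minimal sketch of producing the same entry locally (assuming the git-lfs CLI is installed; written in Python only to stay in the repo's language):

import subprocess

# `git lfs track "*.mp3"` appends the same filter/diff/merge line to .gitattributes
subprocess.run(["git", "lfs", "track", "*.mp3"], check=True)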
Examples/{Beethoven.wav → Beethoven.mp3} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:30a6a087a9e0eb87422aa3b48ad966eabb1dfe105d73a25d356b71d3aee31493
- size 4828972
+ oid sha256:3dcc79fe071d118df3caaeeb85d7944f93a5df40bbdb72a26b67bd57da2af7c5
+ size 1097142
Examples/{Cat_dog.wav → Beethoven_arcade.mp3} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:90a97dc229eeccef307dd40db97fb09cc439ce0b45a320fd84b2ea6b03d0deb2
- size 327822
+ oid sha256:542bd61d9cc1723ccfd9bfc06b0818e77fc763013827ff1f9289e2ac6a912904
+ size 563040
Examples/{Beethoven_arcade.wav → Beethoven_piano.mp3} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ccd929b93c15706f2102a27973d490a84ce0eb97faba6a92ece0c6d81ed2c26e
- size 1794746
+ oid sha256:000d82c39d8c41b10188d328e29cb1baa948232bacd693f22e297cc54f4bb707
+ size 563040
Examples/{Beethoven_piano.wav → Beethoven_rock.mp3} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:5787c31b0b3c78dec33d651d437364785713042e7cfce2290cf4baf01f65ac6f
- size 1794746
+ oid sha256:c51d75c9094a50c7892449a013b32ffde266a5abd6dad9f00bf3aeec0ee935ee
+ size 1097142
Examples/{Cat.wav → Cat.mp3} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:27b43763a8d9ac90dc78285ed9817b16524f24b4f4d1aa399616f1a04d4a9fd9
- size 1920508
+ oid sha256:cff7010e5fb12a57508c7a0941663f1a12bfc8b3b3d01d0973359cd42ae5eb1e
+ size 402542
Examples/Cat_dog.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:72ff727243606215c934552e946f7d97b5e2e39c4d6263f7f36659e3f39f3008
+ size 207403
Examples/ModalJazz.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34cf145b84b6b4669050ca42932fb74ac0f28aabbe6c665f12a877c9809fa9c6
+ size 4153468
Examples/ModalJazz.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:846a77046d21ebc3996841404eede9d56797c82b3414025e1ccafe586eaf2959
- size 9153322
Examples/ModalJazz_banjo.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11680068427556981aa6304e6c11bd05debc820ca581c248954c1ffe3cd94569
+ size 2128320
Examples/ModalJazz_banjo.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:122e0078c0bf2fc96425071706fe0e8674c93cc1d2787fd02c0e2c0f12de5cc5
- size 6802106
Examples/Shadows.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e0cab2ebda4507641d6a1b5d9b2d888a7526581b7de48540ebf86ce00579908
+ size 1342693
Examples/Shadows_arcade.mp3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:68c84805ea17d0697cd79bc85394754d70fb02f740db4bee4c6ccbb5269a5d84
+ size 1342693
README.md CHANGED
@@ -9,7 +9,10 @@ app_file: app.py
  pinned: false
  license: cc-by-sa-4.0
  short_description: Edit audios with text prompts
+ hf_oauth: true
+ hf_oauth_scopes:
+ - read-repos
  ---

  The 30-second limit was introduced to ensure that queue wait times remain reasonable, especially when there are a lot of users.
- For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
+ For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
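The new `hf_oauth` front matter turns on Hugging Face OAuth for the Space; the `read-repos` scope is what later lets app.py check, with the visitor's token, whether they have access to the gated Stable Audio Open weights. A minimal sketch (not this Space's code) of how a Gradio app consumes that setting: a parameter annotated as `gr.OAuthToken` is filled in automatically once the visitor signs in through `gr.LoginButton`.

import gradio as gr
from typing import Optional

def whoami(oauth_token: Optional[gr.OAuthToken] = None) -> str:
    # stays None until the visitor logs in, so gated features can be refused gracefully
    return "logged in" if oauth_token is not None else "please log in"

with gr.Blocks() as demo:
    gr.LoginButton()
    status = gr.Textbox(label="Status")
    gr.Button("Check").click(whoami, inputs=None, outputs=status)

if __name__ == "__main__":
    demo.launch()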
app.py CHANGED
@@ -6,27 +6,26 @@ if os.getenv('SPACES_ZERO_GPU') == "true":
6
  import gradio as gr
7
  import random
8
  import torch
 
9
  from torch import inference_mode
10
- # from tempfile import NamedTemporaryFile
11
- from typing import Optional
12
  import numpy as np
13
  from models import load_model
14
  import utils
15
  import spaces
 
16
  from inversion_utils import inversion_forward_process, inversion_reverse_process
17
 
18
 
19
- # current_loaded_model = "cvssp/audioldm2-music"
20
- # # current_loaded_model = "cvssp/audioldm2-music"
21
-
22
- # ldm_stable = load_model(current_loaded_model, device, 200) # deafult model
23
  LDM2 = "cvssp/audioldm2"
24
  MUSIC = "cvssp/audioldm2-music"
25
  LDM2_LARGE = "cvssp/audioldm2-large"
 
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  ldm2 = load_model(model_id=LDM2, device=device)
28
  ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
29
  ldm2_music = load_model(model_id=MUSIC, device=device)
 
30
 
31
 
32
  def randomize_seed_fn(seed, randomize_seed):
@@ -36,89 +35,136 @@ def randomize_seed_fn(seed, randomize_seed):
36
  return seed
37
 
38
 
39
- def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
40
  # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
41
 
42
  with inference_mode():
43
  w0 = ldm_stable.vae_encode(x0)
44
 
45
  # find Zs and wts - forward process
46
- _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1,
47
- prompts=[prompt_src],
48
- cfg_scales=[cfg_scale_src],
49
- prog_bar=True,
50
- num_inference_steps=num_diffusion_steps,
51
- numerical_fix=True)
52
- return zs, wts
 
53
 
54
 
55
- def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable):
56
  # reverse process (via Zs and wT)
57
  tstart = torch.tensor(tstart, dtype=torch.int)
58
- skip = steps - tstart
59
- w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip,
60
  etas=1., prompts=[prompt_tar],
61
  neg_prompts=[""], cfg_scales=[cfg_scale_tar],
62
- prog_bar=True,
63
- zs=zs[:int(steps - skip)])
 
 
64
 
65
  # vae decode image
66
  with inference_mode():
67
  x0_dec = ldm_stable.vae_decode(w0)
68
- if x0_dec.dim() < 4:
69
- x0_dec = x0_dec[None, :, :, :]
70
 
71
- with torch.no_grad():
72
- audio = ldm_stable.decode_to_mel(x0_dec)
 
73
 
74
- return (16000, audio.squeeze().cpu().numpy())
75
-
76
- def get_duration(input_audio, model_id: str, do_inversion: bool,
77
- wts: Optional[torch.Tensor], zs: Optional[torch.Tensor],
78
- saved_inv_model: str, source_prompt="", target_prompt="",
79
- steps=200, cfg_scale_src=3.5, cfg_scale_tar=12, t_start=45, randomize_seed=True):
 
80
  if model_id == LDM2:
81
- factor = 0.8
82
  elif model_id == LDM2_LARGE:
83
- factor = 1.5
 
 
84
  else: # MUSIC
85
  factor = 1
86
 
87
- mult = 0
88
  if do_inversion or randomize_seed:
89
- mult = steps
90
 
 
 
91
  if input_audio is None:
92
  raise gr.Error('Input audio missing!')
93
- duration = min(utils.get_duration(input_audio), 30)
94
 
95
- time_per_iter_of_full = factor * ((t_start /100 * steps)*2 + mult) * 0.25
96
- print('expected time:', time_per_iter_of_full / 30 * duration)
97
- return max(15, time_per_iter_of_full / 30 * duration)
98
-
99
 
100
- @spaces.GPU(duration=get_duration)
101
- def edit(
102
- # cache_dir,
103
- input_audio,
104
- model_id: str,
105
- do_inversion: bool,
106
- # wtszs_file: str,
107
- wts: Optional[torch.Tensor], zs: Optional[torch.Tensor],
108
- saved_inv_model: str,
109
- source_prompt="",
110
- target_prompt="",
111
- steps=200,
112
- cfg_scale_src=3.5,
113
- cfg_scale_tar=12,
114
- t_start=45,
115
- randomize_seed=True):
117
  print(model_id)
118
  if model_id == LDM2:
119
  ldm_stable = ldm2
120
  elif model_id == LDM2_LARGE:
121
  ldm_stable = ldm2_large
 
 
122
  else: # MUSIC
123
  ldm_stable = ldm2_music
124
 
@@ -130,102 +176,126 @@ def edit(
130
 
131
  if input_audio is None:
132
  raise gr.Error('Input audio missing!')
133
- x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
134
-
135
- # if not (do_inversion or randomize_seed):
136
- # if not os.path.exists(wtszs_file):
137
- # do_inversion = True
138
- # Too much time has passed
139
  if wts is None or zs is None:
140
  do_inversion = True
141
 
142
  if do_inversion or randomize_seed: # always re-run inversion
143
- zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
144
- num_diffusion_steps=steps,
145
- cfg_scale_src=cfg_scale_src)
146
- # f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False)
147
- # torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name)
148
- # wtszs_file = f.name
149
- # wtszs_file = gr.State(value=f.name)
150
- # wts = gr.State(value=wts_tensor)
151
  wts = wts_tensor
152
  zs = zs_tensor
153
- # zs = gr.State(value=zs_tensor)
154
- # demo.move_resource_to_block_cache(f.name)
155
  saved_inv_model = model_id
156
  do_inversion = False
157
  else:
158
- # wtszs = torch.load(wtszs_file, map_location=device)
159
- # # wtszs = torch.load(wtszs_file.f, map_location=device)
160
- # wts_tensor = wtszs['wts']
161
- # zs_tensor = wtszs['zs']
162
  wts_tensor = wts.to(device)
163
  zs_tensor = zs.to(device)
 
164
 
165
- # make sure t_start is in the right limit
166
- # t_start = change_tstart_range(t_start, steps)
167
-
168
- output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
169
- tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
170
 
171
- return output, wts.cpu(), zs.cpu(), saved_inv_model, do_inversion
172
  # return output, wtszs_file, saved_inv_model, do_inversion
173
 
174
 
175
  def get_example():
176
  case = [
177
- ['Examples/Beethoven.wav',
178
  '',
179
  'A recording of an arcade game soundtrack.',
180
  45,
181
  'cvssp/audioldm2-music',
182
  '27s',
183
- 'Examples/Beethoven_arcade.wav',
184
  ],
185
- ['Examples/Beethoven.wav',
186
  'A high quality recording of wind instruments and strings playing.',
187
  'A high quality recording of a piano playing.',
188
  45,
189
  'cvssp/audioldm2-music',
190
  '27s',
191
- 'Examples/Beethoven_piano.wav',
192
  ],
193
- ['Examples/ModalJazz.wav',
194
  'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
195
  'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
196
  45,
197
  'cvssp/audioldm2-music',
198
  '106s',
199
- 'Examples/ModalJazz_banjo.wav',],
200
- ['Examples/Cat.wav',
201
  '',
202
  'A dog barking.',
203
  75,
204
  'cvssp/audioldm2-large',
205
  '10s',
206
- 'Examples/Cat_dog.wav',]
207
  ]
208
  return case
209
 
210
 
211
  intro = """
212
- <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;"> ZETA Editing 🎧 </h1>
213
- <h2 style="font-weight: 1400; text-align: center; margin-bottom: 7px;"> Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
214
- <h3 style="margin-bottom: 10px; text-align: center;">
 
215
  <a href="https://arxiv.org/abs/2402.10009">[Paper]</a>&nbsp;|&nbsp;
216
  <a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a>&nbsp;|&nbsp;
217
  <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
218
  </h3>
219
 
220
-
221
- <p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
222
  For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
223
  <a href="https://huggingface.co/spaces/hilamanor/audioEditing?duplicate=true">
224
- <img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" ></a>
225
  </p>
226
-
227
  """
228
 
 
229
  help = """
230
  <div style="font-size:medium">
231
  <b>Instructions:</b><br>
@@ -233,22 +303,27 @@ help = """
233
  <li>You must provide an input audio and a target prompt to edit the audio. </li>
234
  <li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
235
  Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
236
- <li>Make sure that you use an AudioLDM2 version that is suitable for your input audio.
237
- For example, use the music version for music and the large version for general audio.
238
  </li>
239
  <li>You can additionally provide a source prompt to guide even further the editing process.</li>
240
  <li>Longer input will take more time.</li>
241
  <li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
242
- For unlimited length, duplicated the space, and remove the trimming by changing the code.
243
- Specifically, in the <code style="display:inline; background-color: lightgrey; ">load_audio</code> function in the <code style="display:inline; background-color: lightgrey; ">utils.py</code> file,
244
- change <code style="display:inline; background-color: lightgrey; ">duration = min(audioldm.utils.get_duration(audio_path), 30)</code> to
245
- <code style="display:inline; background-color: lightgrey; ">duration = audioldm.utils.get_duration(audio_path)</code>.
 
246
  </ul>
247
  </div>
248
 
249
  """
250
 
251
- with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
 
 
 
 
252
  def reset_do_inversion(do_inversion_user, do_inversion):
253
  # do_inversion = gr.State(value=True)
254
  do_inversion = True
@@ -267,23 +342,22 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
267
  return do_inversion_user, do_inversion
268
 
269
  gr.HTML(intro)
 
270
  wts = gr.State()
271
  zs = gr.State()
272
- wtszs = gr.State()
273
- # cache_dir = gr.State(demo.GRADIO_CACHE)
274
  saved_inv_model = gr.State()
275
- # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
276
- # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
277
- # ldm_stable = gr.State(value=ldm_stable)
278
  do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
279
  do_inversion_user = gr.State(value=False)
280
 
281
  with gr.Group():
282
- gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed (for unlimited input, see the Help section below)")
283
- with gr.Row():
284
- input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", editable=True, label="Input Audio",
285
- interactive=True, scale=1)
286
- output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1)
 
 
287
 
288
  with gr.Row():
289
  tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",
@@ -293,17 +367,16 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
293
  with gr.Row():
294
  t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
295
  info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
296
- # model_id = gr.Radio(label="AudioLDM2 Version",
297
- model_id = gr.Dropdown(label="AudioLDM2 Version",
298
- choices=["cvssp/audioldm2",
299
- "cvssp/audioldm2-large",
300
- "cvssp/audioldm2-music"],
301
- info="Choose a checkpoint suitable for your intended audio and edit",
302
  value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
303
-
304
  with gr.Row():
305
- with gr.Column():
306
- submit = gr.Button("Edit")
307
 
308
  with gr.Accordion("More Options", open=False):
309
  with gr.Row():
@@ -311,58 +384,62 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
311
  info="Optional: Describe the original audio input",
312
  placeholder="A recording of a happy upbeat classical music piece",)
313
 
314
- with gr.Row():
315
  cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
316
  label="Source Guidance Scale", interactive=True, scale=1)
317
  cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
318
  label="Target Guidance Scale", interactive=True, scale=1)
319
- steps = gr.Number(value=50, step=1, minimum=20, maximum=300,
320
  info="Higher values (e.g. 200) yield higher-quality generation.",
321
- label="Num Diffusion Steps", interactive=True, scale=1)
322
- with gr.Row():
323
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
324
  randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
 
325
  length = gr.Number(label="Length", interactive=False, visible=False)
326
 
327
  with gr.Accordion("Help💡", open=False):
328
  gr.HTML(help)
329
 
330
  submit.click(
331
- fn=randomize_seed_fn,
332
- inputs=[seed, randomize_seed],
333
- outputs=[seed], queue=False).then(
334
- fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]).then(
335
- fn=edit,
336
- inputs=[#cache_dir,
337
- input_audio,
338
- model_id,
339
- do_inversion,
340
- # current_loaded_model, ldm_stable,
341
- wts, zs,
342
- # wtszs,
343
- saved_inv_model,
344
- src_prompt,
345
- tar_prompt,
346
- steps,
347
- cfg_scale_src,
348
- cfg_scale_tar,
349
- t_start,
350
- randomize_seed
351
- ],
352
- outputs=[output_audio, wts, zs, # wtszs,
353
- saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
354
- ).then(post_match_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion]
355
- ).then(lambda x: (demo.temp_file_sets.append(set([str(gr.utils.abspath(x))])) if type(x) is str else None),
356
- inputs=wtszs)
357
-
358
- # demo.move_resource_to_block_cache(wtszs.value)
 
359
 
360
  # If sources changed we have to rerun inversion
361
- input_audio.change(fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion])
362
- src_prompt.change(fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion])
363
- model_id.change(fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion])
364
- cfg_scale_src.change(fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion])
365
- steps.change(fn=reset_do_inversion, inputs=[do_inversion_user, do_inversion], outputs=[do_inversion_user, do_inversion])
 
 
366
 
367
  gr.Examples(
368
  label="Examples",
 
6
  import gradio as gr
7
  import random
8
  import torch
9
+ import os
10
  from torch import inference_mode
11
+ from typing import Optional, List
 
12
  import numpy as np
13
  from models import load_model
14
  import utils
15
  import spaces
16
+ import huggingface_hub
17
  from inversion_utils import inversion_forward_process, inversion_reverse_process
18
 
19
 
 
 
 
 
20
  LDM2 = "cvssp/audioldm2"
21
  MUSIC = "cvssp/audioldm2-music"
22
  LDM2_LARGE = "cvssp/audioldm2-large"
23
+ STABLEAUD = "stabilityai/stable-audio-open-1.0"
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  ldm2 = load_model(model_id=LDM2, device=device)
26
  ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
27
  ldm2_music = load_model(model_id=MUSIC, device=device)
28
+ ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN'))
29
 
30
 
31
  def randomize_seed_fn(seed, randomize_seed):
 
35
  return seed
36
 
37
 
38
+ def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
39
  # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
40
 
41
  with inference_mode():
42
  w0 = ldm_stable.vae_encode(x0)
43
 
44
  # find Zs and wts - forward process
45
+ _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1,
46
+ prompts=[prompt_src],
47
+ cfg_scales=[cfg_scale_src],
48
+ num_inference_steps=num_diffusion_steps,
49
+ numerical_fix=True,
50
+ duration=duration,
51
+ save_compute=save_compute)
52
+ return zs, wts, extra_info
53
 
54
 
55
+ def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute):
56
  # reverse process (via Zs and wT)
57
  tstart = torch.tensor(tstart, dtype=torch.int)
58
+ w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart,
 
59
  etas=1., prompts=[prompt_tar],
60
  neg_prompts=[""], cfg_scales=[cfg_scale_tar],
61
+ zs=zs[:int(tstart)],
62
+ duration=duration,
63
+ extra_info=extra_info,
64
+ save_compute=save_compute)
65
 
66
  # vae decode image
67
  with inference_mode():
68
  x0_dec = ldm_stable.vae_decode(w0)
 
 
69
 
70
+ if 'stable-audio' not in ldm_stable.model_id:
71
+ if x0_dec.dim() < 4:
72
+ x0_dec = x0_dec[None, :, :, :]
73
 
74
+ with torch.no_grad():
75
+ audio = ldm_stable.decode_to_mel(x0_dec)
76
+ else:
77
+ audio = x0_dec.squeeze(0).T
78
+
79
+ return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy())
80
+
81
+
82
+ def get_duration(input_audio,
83
+ model_id: str,
84
+ do_inversion: bool,
85
+ wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
86
+ saved_inv_model: str,
87
+ source_prompt: str = "",
88
+ target_prompt: str = "",
89
+ steps: int = 200,
90
+ cfg_scale_src: float = 3.5,
91
+ cfg_scale_tar: float = 12,
92
+ t_start: int = 45,
93
+ randomize_seed: bool = True,
94
+ save_compute: bool = True,
95
+ oauth_token: Optional[gr.OAuthToken] = None):
96
  if model_id == LDM2:
97
+ factor = 1
98
  elif model_id == LDM2_LARGE:
99
+ factor = 2.5
100
+ elif model_id == STABLEAUD:
101
+ factor = 3.2
102
  else: # MUSIC
103
  factor = 1
104
 
105
+ forwards = 0
106
  if do_inversion or randomize_seed:
107
+ forwards = steps if source_prompt == "" else steps * 2 # x2 when there is a prompt text
108
+ forwards += int(t_start / 100 * steps) * 2
109
+
110
+ duration = min(utils.get_duration(input_audio), utils.MAX_DURATION)
111
+ time_for_maxlength = factor * forwards * 0.15  # ~0.15 s per forward pass at max input length
112
+ print('expected time:', time_for_maxlength / utils.MAX_DURATION * duration)
113
+
114
+ spare_time = 5
115
+ return max(10, time_for_maxlength / utils.MAX_DURATION * duration + spare_time)
116
+
117
 
118
+ def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float,
119
+ oauth_token: gr.OAuthToken | None):
120
  if input_audio is None:
121
  raise gr.Error('Input audio missing!')
 
122
 
123
+ if tar_prompt == "":
124
+ raise gr.Error("Please provide a target prompt to edit the audio.")
125
+
126
+ if src_prompt != "":
127
+ if model_id == STABLEAUD and cfg_scale_src != 1:
128
+ gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.")
129
+ elif model_id != STABLEAUD and cfg_scale_src != 3:
130
+ gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")
131
+
132
+ if model_id == STABLEAUD:
133
+ if oauth_token is None:
134
+ raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.")
135
+ try:
136
+ huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'),
137
+ token=oauth_token.token)
138
+ print('Has Access')
139
+ # except huggingface_hub.utils._errors.GatedRepoError:
140
+ except huggingface_hub.errors.GatedRepoError:
141
+ raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. "
142
+ "Visit the <a href='https://huggingface.co/stabilityai/stable-audio-open-1.0'>"
143
+ "model page</a> to get access.")
144
 
145
 
146
+ @spaces.GPU(duration=get_duration)
147
+ def edit(input_audio,
148
+ model_id: str,
149
+ do_inversion: bool,
150
+ wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
151
+ saved_inv_model: str,
152
+ source_prompt: str = "",
153
+ target_prompt: str = "",
154
+ steps: int = 200,
155
+ cfg_scale_src: float = 3.5,
156
+ cfg_scale_tar: float = 12,
157
+ t_start: int = 45,
158
+ randomize_seed: bool = True,
159
+ save_compute: bool = True,
160
+ oauth_token: Optional[gr.OAuthToken] = None):
161
  print(model_id)
162
  if model_id == LDM2:
163
  ldm_stable = ldm2
164
  elif model_id == LDM2_LARGE:
165
  ldm_stable = ldm2_large
166
+ elif model_id == STABLEAUD:
167
+ ldm_stable = ldm_stableaud
168
  else: # MUSIC
169
  ldm_stable = ldm2_music
170
 
 
176
 
177
  if input_audio is None:
178
  raise gr.Error('Input audio missing!')
179
+ x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device,
180
+ stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr())
 
 
 
 
181
  if wts is None or zs is None:
182
  do_inversion = True
183
 
184
  if do_inversion or randomize_seed: # always re-run inversion
185
+ zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
186
+ num_diffusion_steps=steps,
187
+ cfg_scale_src=cfg_scale_src,
188
+ duration=duration,
189
+ save_compute=save_compute)
 
 
 
190
  wts = wts_tensor
191
  zs = zs_tensor
192
+ extra_info = extra_info_list
 
193
  saved_inv_model = model_id
194
  do_inversion = False
195
  else:
 
 
 
 
196
  wts_tensor = wts.to(device)
197
  zs_tensor = zs.to(device)
198
+ extra_info_list = [e.to(device) for e in extra_info if e is not None]
199
 
200
+ output = sample(ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt,
201
+ tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration,
202
+ save_compute=save_compute)
 
 
203
 
204
+ return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion
205
  # return output, wtszs_file, saved_inv_model, do_inversion
206
 
207
 
208
  def get_example():
209
  case = [
210
+ ['Examples/Beethoven.mp3',
211
  '',
212
  'A recording of an arcade game soundtrack.',
213
  45,
214
  'cvssp/audioldm2-music',
215
  '27s',
216
+ 'Examples/Beethoven_arcade.mp3',
217
  ],
218
+ ['Examples/Beethoven.mp3',
219
  'A high quality recording of wind instruments and strings playing.',
220
  'A high quality recording of a piano playing.',
221
  45,
222
  'cvssp/audioldm2-music',
223
  '27s',
224
+ 'Examples/Beethoven_piano.mp3',
225
+ ],
226
+ ['Examples/Beethoven.mp3',
227
+ '',
228
+ 'Heavy Rock.',
229
+ 40,
230
+ 'stabilityai/stable-audio-open-1.0',
231
+ '27s',
232
+ 'Examples/Beethoven_rock.mp3',
233
  ],
234
+ ['Examples/ModalJazz.mp3',
235
  'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
236
  'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
237
  45,
238
  'cvssp/audioldm2-music',
239
  '106s',
240
+ 'Examples/ModalJazz_banjo.mp3',],
241
+ ['Examples/Shadows.mp3',
242
+ '',
243
+ '8-bit arcade game soundtrack.',
244
+ 40,
245
+ 'stabilityai/stable-audio-open-1.0',
246
+ '34s',
247
+ 'Examples/Shadows_arcade.mp3',],
248
+ ['Examples/Cat.mp3',
249
  '',
250
  'A dog barking.',
251
  75,
252
  'cvssp/audioldm2-large',
253
  '10s',
254
+ 'Examples/Cat_dog.mp3',]
255
  ]
256
  return case
257
 
258
 
259
  intro = """
260
+ <h1 style="font-weight: 1000; text-align: center; margin: 0px;"> ZETA Editing 🎧 </h1>
261
+ <h2 style="font-weight: 1000; text-align: center; margin: 0px;">
262
+ Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
263
+ <h3 style="margin-top: 0px; margin-bottom: 10px; text-align: center;">
264
  <a href="https://arxiv.org/abs/2402.10009">[Paper]</a>&nbsp;|&nbsp;
265
  <a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a>&nbsp;|&nbsp;
266
  <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
267
  </h3>
268
 
269
+ <p style="font-size: 1rem; line-height: 1.2em;">
 
270
  For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
271
  <a href="https://huggingface.co/spaces/hilamanor/audioEditing?duplicate=true">
272
+ <img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" >
273
+ </a>
274
+ </p>
275
+ <p style="margin: 0px;">
276
+ <b>NEW - 15.10.24:</b> You can now edit using <b>Stable Audio Open 1.0</b>.
277
+ You must be <b>logged in</b> after accepting the
278
+ <b><a href="https://huggingface.co/stabilityai/stable-audio-open-1.0">license agreement</a></b> to use it.</br>
279
+ </p>
280
+ <ul style="padding-left:40px; line-height:normal;">
281
+ <li style="margin: 0px;">Prompts behave differently - e.g.,
282
+ try "8-bit arcade" directly instead of "a recording of...". Check out the new examples below!</li>
283
+ <li style="margin: 0px;">Try to play around <code>T-start=40%</code>.</li>
284
+ <li style="margin: 0px;">Under "More Options": Use <code>Source Guidance Scale=1</code>,
285
+ and you can try fewer timesteps (even 20!).</li>
286
+ <li style="margin: 0px;">Stable Audio Open is a general-audio model.
287
+ For better music editing, duplicate the space and change to a
288
+ <a href="https://huggingface.co/models?other=base_model:finetune:stabilityai/stable-audio-open-1.0">
289
+ fine-tuned model for music</a>.</li>
290
+ </ul>
291
+ <p>
292
+ <b>NEW - 15.10.24:</b> Parallel editing is enabled by default and saves
293
+ a bit of time. To disable it, uncheck <code>Efficient editing</code>
294
+ under "More Options".
295
  </p>
 
296
  """
297
 
298
+
299
  help = """
300
  <div style="font-size:medium">
301
  <b>Instructions:</b><br>
 
303
  <li>You must provide an input audio and a target prompt to edit the audio. </li>
304
  <li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
305
  Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
306
+ <li>Make sure that you use a model version that is suitable for your input audio.
307
+ For example, use AudioLDM2-music for music and AudioLDM2-large for general audio.
308
  </li>
309
  <li>You can additionally provide a source prompt to guide even further the editing process.</li>
310
  <li>Longer input will take more time.</li>
311
  <li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
312
+ For unlimited length, duplicate the space and change the
313
+ <code style="display:inline; background-color: lightgrey;">MAX_DURATION</code> parameter
314
+ inside <code style="display:inline; background-color: lightgrey;">utils.py</code>
315
+ to <code style="display:inline; background-color: lightgrey;">None</code>.
316
+ </li>
317
  </ul>
318
  </div>
319
 
320
  """
321
 
322
+ css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \
323
+ '.audio-upload .wrap {min-height: 0px;}'
324
+
325
+ # with gr.Blocks(css='style.css') as demo:
326
+ with gr.Blocks(css=css) as demo:
327
  def reset_do_inversion(do_inversion_user, do_inversion):
328
  # do_inversion = gr.State(value=True)
329
  do_inversion = True
 
342
  return do_inversion_user, do_inversion
343
 
344
  gr.HTML(intro)
345
+
346
  wts = gr.State()
347
  zs = gr.State()
348
+ extra_info = gr.State()
 
349
  saved_inv_model = gr.State()
 
 
 
350
  do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
351
  do_inversion_user = gr.State(value=False)
352
 
353
  with gr.Group():
354
+ gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed "
355
+ "(for unlimited input, see the Help section below)")
356
+ with gr.Row(equal_height=True):
357
+ input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
358
+ editable=True, label="Input Audio", interactive=True, scale=1, format='wav',
359
+ elem_classes=['audio-upload'])
360
+ output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1, format='wav')
361
 
362
  with gr.Row():
363
  tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",
 
367
  with gr.Row():
368
  t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
369
  info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
370
+ model_id = gr.Dropdown(label="Model Version",
371
+ choices=[LDM2,
372
+ LDM2_LARGE,
373
+ MUSIC,
374
+ STABLEAUD],
375
+ info="Choose a checkpoint suitable for your audio and edit",
376
  value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
 
377
  with gr.Row():
378
+ submit = gr.Button("Edit", variant="primary", scale=3)
379
+ gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1)
380
 
381
  with gr.Accordion("More Options", open=False):
382
  with gr.Row():
 
384
  info="Optional: Describe the original audio input",
385
  placeholder="A recording of a happy upbeat classical music piece",)
386
 
387
+ with gr.Row(equal_height=True):
388
  cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
389
  label="Source Guidance Scale", interactive=True, scale=1)
390
  cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
391
  label="Target Guidance Scale", interactive=True, scale=1)
392
+ steps = gr.Number(value=50, step=1, minimum=10, maximum=300,
393
  info="Higher values (e.g. 200) yield higher-quality generation.",
394
+ label="Num Diffusion Steps", interactive=True, scale=2)
395
+ with gr.Row(equal_height=True):
396
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
397
  randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
398
+ save_compute = gr.Checkbox(label='Efficient editing', value=True)
399
  length = gr.Number(label="Length", interactive=False, visible=False)
400
 
401
  with gr.Accordion("Help💡", open=False):
402
  gr.HTML(help)
403
 
404
  submit.click(
405
+ fn=verify_model_params,
406
+ inputs=[model_id, input_audio, src_prompt, tar_prompt, cfg_scale_src],
407
+ outputs=[]
408
+ ).success(
409
+ fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False
410
+ ).then(
411
+ fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]
412
+ ).then(
413
+ fn=edit,
414
+ inputs=[input_audio,
415
+ model_id,
416
+ do_inversion,
417
+ wts, zs, extra_info,
418
+ saved_inv_model,
419
+ src_prompt,
420
+ tar_prompt,
421
+ steps,
422
+ cfg_scale_src,
423
+ cfg_scale_tar,
424
+ t_start,
425
+ randomize_seed,
426
+ save_compute,
427
+ ],
428
+ outputs=[output_audio, wts, zs, extra_info, saved_inv_model, do_inversion]
429
+ ).success(
430
+ fn=post_match_do_inversion,
431
+ inputs=[do_inversion_user, do_inversion],
432
+ outputs=[do_inversion_user, do_inversion]
433
+ )
434
 
435
  # If sources changed we have to rerun inversion
436
+ gr.on(
437
+ triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change,
438
+ steps.change, save_compute.change],
439
+ fn=reset_do_inversion,
440
+ inputs=[do_inversion_user, do_inversion],
441
+ outputs=[do_inversion_user, do_inversion]
442
+ )
443
 
444
  gr.Examples(
445
  label="Examples",
inversion_utils.py CHANGED
@@ -1,341 +1,135 @@
1
  import torch
2
  from tqdm import tqdm
3
- # from torchvision import transforms as T
4
- from typing import List, Optional, Dict, Union
5
  from models import PipelineWrapper
6
-
7
-
8
- def mu_tilde(model, xt, x0, timestep):
9
- "mu_tilde(x_t, x_0) DDPM paper eq. 7"
10
- prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
11
- alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
12
- else model.scheduler.final_alpha_cumprod
13
- alpha_t = model.scheduler.alphas[timestep]
14
- beta_t = 1 - alpha_t
15
- alpha_bar = model.scheduler.alphas_cumprod[timestep]
16
- return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + \
17
- ((alpha_t**0.5 * (1-alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
18
-
19
-
20
- def sample_xts_from_x0(model, x0, num_inference_steps=50, x_prev_mode=False):
21
- """
22
- Samples from P(x_1:T|x_0)
23
- """
24
- # torch.manual_seed(43256465436)
25
- alpha_bar = model.model.scheduler.alphas_cumprod
26
- sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
27
- alphas = model.model.scheduler.alphas
28
- # betas = 1 - alphas
29
- variance_noise_shape = (
30
- num_inference_steps + 1,
31
- model.model.unet.config.in_channels,
32
- # model.unet.sample_size,
33
- # model.unet.sample_size)
34
- x0.shape[-2],
35
- x0.shape[-1])
36
-
37
- timesteps = model.model.scheduler.timesteps.to(model.device)
38
- t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
39
- xts = torch.zeros(variance_noise_shape).to(x0.device)
40
- xts[0] = x0
41
- x_prev = x0
42
- for t in reversed(timesteps):
43
- # idx = t_to_idx[int(t)]
44
- idx = num_inference_steps-t_to_idx[int(t)]
45
- if x_prev_mode:
46
- xts[idx] = x_prev * (alphas[t] ** 0.5) + torch.randn_like(x0) * ((1-alphas[t]) ** 0.5)
47
- x_prev = xts[idx].clone()
48
- else:
49
- xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
50
- # xts = torch.cat([xts, x0 ],dim = 0)
51
-
52
- return xts
53
-
54
-
55
- def forward_step(model, model_output, timestep, sample):
56
- next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
57
- timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
58
-
59
- # 2. compute alphas, betas
60
- alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
61
- # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 \
62
- # else self.scheduler.final_alpha_cumprod
63
-
64
- beta_prod_t = 1 - alpha_prod_t
65
-
66
- # 3. compute predicted original sample from predicted noise also called
67
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
68
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
69
-
70
- # 5. TODO: simple noising implementatiom
71
- next_sample = model.scheduler.add_noise(pred_original_sample, model_output, torch.LongTensor([next_timestep]))
72
- return next_sample
73
 
74
 
75
  def inversion_forward_process(model: PipelineWrapper,
76
  x0: torch.Tensor,
77
  etas: Optional[float] = None,
78
- prog_bar: bool = False,
79
  prompts: List[str] = [""],
80
  cfg_scales: List[float] = [3.5],
81
  num_inference_steps: int = 50,
82
- eps: Optional[float] = None,
83
- cutoff_points: Optional[List[float]] = None,
84
  numerical_fix: bool = False,
85
- extract_h_space: bool = False,
86
- extract_skipconns: bool = False,
87
- x_prev_mode: bool = False):
88
- if len(prompts) > 1 and extract_h_space:
89
- raise NotImplementedError("How do you split cfg_scales for hspace? TODO")
90
-
91
  if len(prompts) > 1 or prompts[0] != "":
92
  text_embeddings_hidden_states, text_embeddings_class_labels, \
93
  text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
94
- # text_embeddings = encode_text(model, prompt)
95
-
96
- # # classifier free guidance
97
- batch_size = len(prompts)
98
- cfg_scales_tensor = torch.ones((batch_size, *x0.shape[1:]), device=model.device, dtype=x0.dtype)
99
-
100
- # if len(prompts) > 1:
101
- # if cutoff_points is None:
102
- # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
103
- # if len(cfg_scales) == 1:
104
- # cfg_scales *= batch_size
105
- # elif len(cfg_scales) < batch_size:
106
- # raise ValueError("Not enough target CFG scales")
107
-
108
- # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
109
- # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
110
 
111
- # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])):
112
- # cfg_scales_tensor[i, :, end:] = 0
113
- # cfg_scales_tensor[i, :, :start] = 0
114
- # cfg_scales_tensor[i] *= cfg_scales[i]
115
- # if prompts[i] == "":
116
- # cfg_scales_tensor[i] = 0
117
- # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
118
- # else:
119
- cfg_scales_tensor *= cfg_scales[0]
120
 
121
- uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = model.encode_text([""])
122
- # uncond_embedding = encode_text(model, "")
123
  timesteps = model.model.scheduler.timesteps.to(model.device)
124
- variance_noise_shape = (
125
- num_inference_steps,
126
- model.model.unet.config.in_channels,
127
- # model.unet.sample_size,
128
- # model.unet.sample_size)
129
- x0.shape[-2],
130
- x0.shape[-1])
131
 
132
- if etas is None or (type(etas) in [int, float] and etas == 0):
133
- eta_is_zero = True
134
- zs = None
135
- else:
136
- eta_is_zero = False
137
- if type(etas) in [int, float]:
138
- etas = [etas]*model.model.scheduler.num_inference_steps
139
- xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps, x_prev_mode=x_prev_mode)
140
- alpha_bar = model.model.scheduler.alphas_cumprod
141
- zs = torch.zeros(size=variance_noise_shape, device=model.device)
142
- hspaces = []
143
- skipconns = []
144
- t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
145
  xt = x0
146
- # op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
147
- op = tqdm(timesteps) if prog_bar else timesteps
 
 
 
 
148
 
149
- for t in op:
150
- # idx = t_to_idx[int(t)]
151
- idx = num_inference_steps - t_to_idx[int(t)] - 1
152
  # 1. predict noise residual
153
- if not eta_is_zero:
154
- xt = xts[idx+1][None]
155
 
156
  with torch.no_grad():
157
- out, out_hspace, out_skipconns = model.unet_forward(xt, timestep=t,
158
- encoder_hidden_states=uncond_embedding_hidden_states,
159
- class_labels=uncond_embedding_class_lables,
160
- encoder_attention_mask=uncond_boolean_prompt_mask)
161
- # out = model.unet.forward(xt, timestep= t, encoder_hidden_states=uncond_embedding)
162
- if len(prompts) > 1 or prompts[0] != "":
163
- cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward(
164
- xt.expand(len(prompts), -1, -1, -1), timestep=t,
165
- encoder_hidden_states=text_embeddings_hidden_states,
166
- class_labels=text_embeddings_class_labels,
167
- encoder_attention_mask=text_embeddings_boolean_prompt_mask)
168
- # cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
 
169
 
170
  if len(prompts) > 1 or prompts[0] != "":
171
  # # classifier free guidance
172
- noise_pred = out.sample + \
173
- (cfg_scales_tensor * (cond_out.sample - out.sample.expand(batch_size, -1, -1, -1))
174
- ).sum(axis=0).unsqueeze(0)
175
- if extract_h_space or extract_skipconns:
176
- noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
177
- if extract_skipconns:
178
- noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
179
- (cond_out_skipconns[k][j] - out_skipconns[k][j])
180
- for j in range(len(out_skipconns[k]))]
181
- for k in out_skipconns}
182
- else:
183
- noise_pred = out.sample
184
- if extract_h_space or extract_skipconns:
185
- noise_h_space = out_hspace
186
- if extract_skipconns:
187
- noise_skipconns = out_skipconns
188
- if extract_h_space or extract_skipconns:
189
- hspaces.append(noise_h_space)
190
- if extract_skipconns:
191
- skipconns.append(noise_skipconns)
192
-
193
- if eta_is_zero:
194
- # 2. compute more noisy image and set x_t -> x_t+1
195
- xt = forward_step(model.model, noise_pred, t, xt)
196
  else:
197
- # xtm1 = xts[idx+1][None]
198
- xtm1 = xts[idx][None]
199
- # pred of x0
200
- if model.model.scheduler.config.prediction_type == 'epsilon':
201
- pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
202
- elif model.model.scheduler.config.prediction_type == 'v_prediction':
203
- pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
204
-
205
- # direction to xt
206
- prev_timestep = t - model.model.scheduler.config.num_train_timesteps // \
207
- model.model.scheduler.num_inference_steps
208
-
209
- alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
210
- variance = model.get_variance(t, prev_timestep)
211
-
212
- if model.model.scheduler.config.prediction_type == 'epsilon':
213
- radom_noise_pred = noise_pred
214
- elif model.model.scheduler.config.prediction_type == 'v_prediction':
215
- radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
216
-
217
- pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * radom_noise_pred
218
-
219
- mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
220
-
221
- z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
222
-
223
- zs[idx] = z
224
-
225
- # correction to avoid error accumulation
226
- if numerical_fix:
227
- xtm1 = mu_xt + (etas[idx] * variance ** 0.5)*z
228
- xts[idx] = xtm1
229
 
230
  if zs is not None:
231
  # zs[-1] = torch.zeros_like(zs[-1])
232
  zs[0] = torch.zeros_like(zs[0])
233
  # zs_cycle[0] = torch.zeros_like(zs[0])
234
 
235
- if extract_h_space:
236
- hspaces = torch.concat(hspaces, axis=0)
237
- return xt, zs, xts, hspaces
238
-
239
- if extract_skipconns:
240
- hspaces = torch.concat(hspaces, axis=0)
241
- return xt, zs, xts, hspaces, skipconns
242
-
243
- return xt, zs, xts
244
-
245
-
246
- def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
247
- # 1. get previous step value (=t-1)
248
- prev_timestep = timestep - model.model.scheduler.config.num_train_timesteps // \
249
- model.model.scheduler.num_inference_steps
250
- # 2. compute alphas, betas
251
- alpha_prod_t = model.model.scheduler.alphas_cumprod[timestep]
252
- alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
253
- beta_prod_t = 1 - alpha_prod_t
254
- # 3. compute predicted original sample from predicted noise also called
255
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
256
- if model.model.scheduler.config.prediction_type == 'epsilon':
257
- pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
258
- elif model.model.scheduler.config.prediction_type == 'v_prediction':
259
- pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
260
-
261
- # 5. compute variance: "sigma_t(η)" -> see formula (16)
262
- # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
263
- # variance = self.scheduler._get_variance(timestep, prev_timestep)
264
- variance = model.get_variance(timestep, prev_timestep)
265
- # std_dev_t = eta * variance ** (0.5)
266
- # Take care of asymetric reverse process (asyrp)
267
- if model.model.scheduler.config.prediction_type == 'epsilon':
268
- model_output_direction = model_output
269
- elif model.model.scheduler.config.prediction_type == 'v_prediction':
270
- model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
271
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
272
- # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
273
- pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
274
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
275
- prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
276
- # 8. Add noice if eta > 0
277
- if eta > 0:
278
- if variance_noise is None:
279
- variance_noise = torch.randn(model_output.shape, device=model.device)
280
- sigma_z = eta * variance ** (0.5) * variance_noise
281
- prev_sample = prev_sample + sigma_z
282
-
283
- return prev_sample
284
 
285
 
286
  def inversion_reverse_process(model: PipelineWrapper,
287
  xT: torch.Tensor,
288
- skips: torch.Tensor,
289
- fix_alpha: float = 0.1,
290
  etas: float = 0,
291
  prompts: List[str] = [""],
292
  neg_prompts: List[str] = [""],
293
  cfg_scales: Optional[List[float]] = None,
294
- prog_bar: bool = False,
295
  zs: Optional[List[torch.Tensor]] = None,
296
- # controller=None,
297
- cutoff_points: Optional[List[float]] = None,
298
- hspace_add: Optional[torch.Tensor] = None,
299
- hspace_replace: Optional[torch.Tensor] = None,
300
- skipconns_replace: Optional[Dict[int, torch.Tensor]] = None,
301
- zero_out_resconns: Optional[Union[int, List]] = None,
302
- asyrp: bool = False,
303
- extract_h_space: bool = False,
304
- extract_skipconns: bool = False):
305
-
306
- batch_size = len(prompts)
307
 
308
  text_embeddings_hidden_states, text_embeddings_class_labels, \
309
  text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
310
- uncond_embedding_hidden_states, uncond_embedding_class_lables, \
311
- uncond_boolean_prompt_mask = model.encode_text(neg_prompts)
312
- # text_embeddings = encode_text(model, prompts)
313
- # uncond_embedding = encode_text(model, [""] * batch_size)
314
-
315
- masks = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
316
- cfg_scales_tensor = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
317
-
318
- # if batch_size > 1:
319
- # if cutoff_points is None:
320
- # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
321
- # if len(cfg_scales) == 1:
322
- # cfg_scales *= batch_size
323
- # elif len(cfg_scales) < batch_size:
324
- # raise ValueError("Not enough target CFG scales")
325
-
326
- # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
327
- # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
328
 
329
- # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])):
330
- # cfg_scales_tensor[i, :, end:] = 0
331
- # cfg_scales_tensor[i, :, :start] = 0
332
- # masks[i, :, end:] = 0
333
- # masks[i, :, :start] = 0
334
- # cfg_scales_tensor[i] *= cfg_scales[i]
335
- # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
336
- # masks = T.functional.gaussian_blur(masks, kernel_size=15, sigma=1)
337
- # else:
338
- cfg_scales_tensor *= cfg_scales[0]
339
 
340
  if etas is None:
341
  etas = 0
@@ -344,107 +138,71 @@ def inversion_reverse_process(model: PipelineWrapper,
344
  assert len(etas) == model.model.scheduler.num_inference_steps
345
  timesteps = model.model.scheduler.timesteps.to(model.device)
346
 
347
- # xt = xT.expand(1, -1, -1, -1)
348
- xt = xT[skips.max()].unsqueeze(0)
349
- op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
350
 
351
- t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
352
- hspaces = []
353
- skipconns = []
354
-
355
- for it, t in enumerate(op):
356
- # idx = t_to_idx[int(t)]
357
- idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t)] - \
358
- (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
359
  # # Unconditional embedding
360
  with torch.no_grad():
361
- uncond_out, out_hspace, out_skipconns = model.unet_forward(
362
- xt, timestep=t,
363
- encoder_hidden_states=uncond_embedding_hidden_states,
364
- class_labels=uncond_embedding_class_lables,
365
- encoder_attention_mask=uncond_boolean_prompt_mask,
366
- mid_block_additional_residual=(None if hspace_add is None else
367
- (1 / (cfg_scales[0] + 1)) *
368
- (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
369
- else hspace_add)),
370
- replace_h_space=(None if hspace_replace is None else
371
- (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
372
- else hspace_replace)),
373
- zero_out_resconns=zero_out_resconns,
374
- replace_skip_conns=(None if skipconns_replace is None else
375
- (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
376
- else skipconns_replace))
377
- ) # encoder_hidden_states = uncond_embedding)
378
-
379
- # # Conditional embedding
380
- if prompts:
381
- with torch.no_grad():
382
- cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward(
383
- xt.expand(batch_size, -1, -1, -1),
384
  timestep=t,
385
  encoder_hidden_states=text_embeddings_hidden_states,
386
  class_labels=text_embeddings_class_labels,
387
  encoder_attention_mask=text_embeddings_boolean_prompt_mask,
388
- mid_block_additional_residual=(None if hspace_add is None else
389
- (cfg_scales[0] / (cfg_scales[0] + 1)) *
390
- (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
391
- else hspace_add)),
392
- replace_h_space=(None if hspace_replace is None else
393
- (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
394
- else hspace_replace)),
395
- zero_out_resconns=zero_out_resconns,
396
- replace_skip_conns=(None if skipconns_replace is None else
397
- (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
398
- else skipconns_replace))
399
- ) # encoder_hidden_states = text_embeddings)
400
 
401
  z = zs[idx] if zs is not None else None
402
- # print(f'idx: {idx}')
403
- # print(f't: {t}')
404
  z = z.unsqueeze(0)
405
- # z = z.expand(batch_size, -1, -1, -1)
406
- if prompts:
407
- # # classifier free guidance
408
- # noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
409
- noise_pred = uncond_out.sample + \
410
- (cfg_scales_tensor * (cond_out.sample - uncond_out.sample.expand(batch_size, -1, -1, -1))
411
- ).sum(axis=0).unsqueeze(0)
412
- if extract_h_space or extract_skipconns:
413
- noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
414
- if extract_skipconns:
415
- noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
416
- (cond_out_skipconns[k][j] - out_skipconns[k][j])
417
- for j in range(len(out_skipconns[k]))]
418
- for k in out_skipconns}
419
- else:
420
- noise_pred = uncond_out.sample
421
- if extract_h_space or extract_skipconns:
422
- noise_h_space = out_hspace
423
- if extract_skipconns:
424
- noise_skipconns = out_skipconns
425
-
426
- if extract_h_space or extract_skipconns:
427
- hspaces.append(noise_h_space)
428
- if extract_skipconns:
429
- skipconns.append(noise_skipconns)
430
 
431
  # 2. compute less noisy image and set x_t -> x_t-1
432
- xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
433
- # if controller is not None:
434
- # xt = controller.step_callback(xt)
435
-
436
- # "fix" xt
437
- apply_fix = ((skips.max() - skips) > it)
438
- if apply_fix.any():
439
- apply_fix = (apply_fix * fix_alpha).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(xT.device)
440
- xt = (masks * (xt.expand(batch_size, -1, -1, -1) * (1 - apply_fix) +
441
- apply_fix * xT[skips.max() - it - 1].expand(batch_size, -1, -1, -1))
442
- ).sum(axis=0).unsqueeze(0)
443
-
444
- if extract_h_space:
445
- return xt, zs, torch.concat(hspaces, axis=0)
446
-
447
- if extract_skipconns:
448
- return xt, zs, torch.concat(hspaces, axis=0), skipconns
449
 
 
450
  return xt, zs
 
1
  import torch
2
  from tqdm import tqdm
3
+ from typing import List, Optional, Tuple
 
4
  from models import PipelineWrapper
5
+ import gradio as gr
 
6
 
7
 
8
  def inversion_forward_process(model: PipelineWrapper,
9
  x0: torch.Tensor,
10
  etas: Optional[float] = None,
 
11
  prompts: List[str] = [""],
12
  cfg_scales: List[float] = [3.5],
13
  num_inference_steps: int = 50,
 
 
14
  numerical_fix: bool = False,
15
+ duration: Optional[float] = None,
16
+ first_order: bool = False,
17
+ save_compute: bool = True,
18
+ progress=gr.Progress()) -> Tuple:
 
 
19
  if len(prompts) > 1 or prompts[0] != "":
20
  text_embeddings_hidden_states, text_embeddings_class_labels, \
21
  text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
22
 
23
+ # Negative prompts are not currently supported in the forward process (TODO)
24
+ uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
25
+ [""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1]
26
+ if text_embeddings_class_labels is not None else None)
27
+ else:
28
+ uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
29
+ [""], negative=True, save_compute=False)
 
 
30
 
 
 
31
  timesteps = model.model.scheduler.timesteps.to(model.device)
32
+ variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
 
 
 
 
 
 
33
 
34
+ if type(etas) in [int, float]:
35
+ etas = [etas]*model.model.scheduler.num_inference_steps
36
+ xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
37
+ zs = torch.zeros(size=variance_noise_shape, device=model.device)
38
+ extra_info = [None] * len(zs)
39
+
40
+ if timesteps[0].dtype == torch.int64:
41
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
42
+ elif timesteps[0].dtype == torch.float32:
43
+ t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
 
 
 
44
  xt = x0
45
+ op = tqdm(timesteps, desc="Inverting")
46
+ model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
47
+ save_compute=save_compute and prompts[0] != "")
48
+ app_op = progress.tqdm(timesteps, desc="Inverting")
49
+ for t, _ in zip(op, app_op):
50
+ idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
51
 
 
 
 
52
  # 1. predict noise residual
53
+ xt = xts[idx+1][None]
54
+ xt_inp = model.model.scheduler.scale_model_input(xt, t)
55
 
56
  with torch.no_grad():
57
+ if save_compute and prompts[0] != "":
58
+ comb_out, _, _ = model.unet_forward(
59
+ xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
60
+ timestep=t,
61
+ encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
62
+ ], dim=0)
63
+ if uncond_embeddings_hidden_states is not None else None,
64
+ class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
65
+ if uncond_embeddings_class_lables is not None else None,
66
+ encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
67
+ ], dim=0)
68
+ if uncond_boolean_prompt_mask is not None else None,
69
+ )
70
+ out, cond_out = comb_out.sample.chunk(2, dim=0)
71
+ else:
72
+ out = model.unet_forward(xt_inp, timestep=t,
73
+ encoder_hidden_states=uncond_embeddings_hidden_states,
74
+ class_labels=uncond_embeddings_class_lables,
75
+ encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
76
+ if len(prompts) > 1 or prompts[0] != "":
77
+ cond_out = model.unet_forward(
78
+ xt_inp,
79
+ timestep=t,
80
+ encoder_hidden_states=text_embeddings_hidden_states,
81
+ class_labels=text_embeddings_class_labels,
82
+ encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
83
 
84
  if len(prompts) > 1 or prompts[0] != "":
85
  # # classifier free guidance
86
+ noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
 
 
87
  else:
88
+ noise_pred = out
89
+
90
+ # xtm1 = xts[idx+1][None]
91
+ xtm1 = xts[idx][None]
92
+ z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
93
+ eta=etas[idx], numerical_fix=numerical_fix,
94
+ first_order=first_order)
95
+ zs[idx] = z
96
+ # print(f"Fix Xt-1 distance - NORM:{torch.norm(xts[idx] - xtm1):.4g}, MSE:{((xts[idx] - xtm1)**2).mean():.4g}")
97
+ xts[idx] = xtm1
98
+ extra_info[idx] = extra
 
 
99
 
100
  if zs is not None:
101
  # zs[-1] = torch.zeros_like(zs[-1])
102
  zs[0] = torch.zeros_like(zs[0])
103
  # zs_cycle[0] = torch.zeros_like(zs[0])
104
 
105
+ del app_op.iterables[0]
106
+ return xt, zs, xts, extra_info
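A minimal driver sketch (not part of this commit) of how the forward inversion above is typically invoked. The model id, prompt, and source latent w0 are assumptions, and the default progress=gr.Progress() expects to run inside a gradio event handler:

    # Hypothetical usage, for illustration only.
    import torch
    from models import load_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ldm = load_model("cvssp/audioldm2-music", device)       # assumed model id
    ldm.model.scheduler.set_timesteps(50, device=device)    # timesteps must be set beforehand
    # w0: VAE latent of the source audio, e.g. obtained via ldm.vae_encode(...)
    xT, zs, xts, extra = inversion_forward_process(
        ldm, w0, etas=1.0, prompts=["a recording of a piano"],
        cfg_scales=[3.5], num_inference_steps=50, numerical_fix=True)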
 
 
107
 
108
 
109
  def inversion_reverse_process(model: PipelineWrapper,
110
  xT: torch.Tensor,
111
+ tstart: torch.Tensor,
 
112
  etas: float = 0,
113
  prompts: List[str] = [""],
114
  neg_prompts: List[str] = [""],
115
  cfg_scales: Optional[List[float]] = None,
 
116
  zs: Optional[List[torch.Tensor]] = None,
117
+ duration: Optional[float] = None,
118
+ first_order: bool = False,
119
+ extra_info: Optional[List] = None,
120
+ save_compute: bool = True,
121
+ progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
122
 
123
  text_embeddings_hidden_states, text_embeddings_class_labels, \
124
  text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
125
+ uncond_embeddings_hidden_states, uncond_embeddings_class_lables, \
126
+ uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
127
+ negative=True,
128
+ save_compute=save_compute,
129
+ cond_length=text_embeddings_class_labels.shape[1]
130
+ if text_embeddings_class_labels is not None else None)
 
 
131
 
132
+ xt = xT[tstart.max()].unsqueeze(0)
 
 
133
 
134
  if etas is None:
135
  etas = 0
 
138
  assert len(etas) == model.model.scheduler.num_inference_steps
139
  timesteps = model.model.scheduler.timesteps.to(model.device)
140
 
141
+ op = tqdm(timesteps[-zs.shape[0]:], desc="Editing")
142
+ if timesteps[0].dtype == torch.int64:
143
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
144
+ elif timesteps[0].dtype == torch.float32:
145
+ t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
146
+ model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]],
147
+ audio_end_in_s=duration, save_compute=save_compute)
148
+ app_op = progress.tqdm(timesteps[-zs.shape[0]:], desc="Editing")
149
+ for it, (t, _) in enumerate(zip(op, app_op)):
150
+ idx = model.model.scheduler.num_inference_steps - t_to_idx[
151
+ int(t) if timesteps[0].dtype == torch.int64 else float(t)] - \
152
+ (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
153
+
154
+ xt_inp = model.model.scheduler.scale_model_input(xt, t)
155
 
 
156
  # # Unconditional embedding
157
  with torch.no_grad():
158
+ # print(f'xt_inp.shape: {xt_inp.shape}')
159
+ # print(f't.shape: {t.shape}')
160
+ # print(f'uncond_embeddings_hidden_states.shape: {uncond_embeddings_hidden_states.shape}')
161
+ # print(f'uncond_embeddings_class_lables.shape: {uncond_embeddings_class_lables.shape}')
162
+ # print(f'uncond_boolean_prompt_mask.shape: {uncond_boolean_prompt_mask.shape}')
163
+ # print(f'text_embeddings_hidden_states.shape: {text_embeddings_hidden_states.shape}')
164
+ # print(f'text_embeddings_class_labels.shape: {text_embeddings_class_labels.shape}')
165
+ # print(f'text_embeddings_boolean_prompt_mask.shape: {text_embeddings_boolean_prompt_mask.shape}')
166
+
167
+ if save_compute:
168
+ comb_out, _, _ = model.unet_forward(
169
+ xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
170
+ timestep=t,
171
+ encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
172
+ ], dim=0)
173
+ if uncond_embeddings_hidden_states is not None else None,
174
+ class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
175
+ if uncond_embeddings_class_lables is not None else None,
176
+ encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
177
+ ], dim=0)
178
+ if uncond_boolean_prompt_mask is not None else None,
179
+ )
180
+ uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
181
+ else:
182
+ uncond_out = model.unet_forward(
183
+ xt_inp, timestep=t,
184
+ encoder_hidden_states=uncond_embeddings_hidden_states,
185
+ class_labels=uncond_embeddings_class_lables,
186
+ encoder_attention_mask=uncond_boolean_prompt_mask,
187
+ )[0].sample
188
+
189
+ # Conditional embedding
190
+ cond_out = model.unet_forward(
191
+ xt_inp,
192
  timestep=t,
193
  encoder_hidden_states=text_embeddings_hidden_states,
194
  class_labels=text_embeddings_class_labels,
195
  encoder_attention_mask=text_embeddings_boolean_prompt_mask,
196
+ )[0].sample
 
 
197
 
198
  z = zs[idx] if zs is not None else None
 
 
199
  z = z.unsqueeze(0)
200
+ # classifier free guidance
201
+ noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
 
 
202
 
203
  # 2. compute less noisy image and set x_t -> x_t-1
204
+ xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
205
+ eta=etas[idx], first_order=first_order)
 
 
206
 
207
+ del app_op.iterables[0]
208
  return xt, zs
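Continuing the sketch above (illustrative only), an edit replays the recorded noise maps from an intermediate timestep under the target prompt and then decodes the result:

    # Hypothetical usage, for illustration only.
    t_start = 35                                   # assumed number of steps to re-run
    edited, _ = inversion_reverse_process(
        ldm, xT=xts, tstart=torch.tensor([t_start]), etas=1.0,
        prompts=["a recording of an arcade game"], neg_prompts=[""],
        cfg_scales=[7.5], zs=zs[:t_start], extra_info=extra)
    decoded = ldm.vae_decode(edited)   # spectrogram for AudioLDM2, waveform for Stable Audio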
models.py CHANGED
@@ -1,46 +1,160 @@
1
  import torch
2
- from diffusers import DDIMScheduler
3
- from diffusers import AudioLDM2Pipeline
4
- from transformers import RobertaTokenizer, RobertaTokenizerFast
 
5
  from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
 
6
  from typing import Any, Dict, List, Optional, Tuple, Union
 
7
 
8
 
9
  class PipelineWrapper(torch.nn.Module):
10
- def __init__(self, model_id, device, double_precision=False, *args, **kwargs) -> None:
 
 
 
11
  super().__init__(*args, **kwargs)
12
  self.model_id = model_id
13
  self.device = device
14
  self.double_precision = double_precision
 
15
 
16
- def get_sigma(self, timestep) -> float:
17
  sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
18
  return sqrt_recipm1_alphas_cumprod[timestep]
19
 
20
- def load_scheduler(self):
21
  pass
22
 
23
- def get_fn_STFT(self):
24
  pass
25
 
26
- def vae_encode(self, x: torch.Tensor):
 
 
 
 
 
 
27
  pass
28
 
29
- def vae_decode(self, x: torch.Tensor):
30
  pass
31
 
32
- def decode_to_mel(self, x: torch.Tensor):
33
  pass
34
 
35
- def encode_text(self, prompts: List[str]) -> Tuple:
 
36
  pass
37
 
38
- def get_variance(self, timestep, prev_timestep):
39
  pass
40
 
41
- def get_alpha_prod_t_prev(self, prev_timestep):
42
  pass
43
 
 
 
44
  def unet_forward(self,
45
  sample: torch.FloatTensor,
46
  timestep: Union[torch.Tensor, float, int],
@@ -57,244 +171,27 @@ class PipelineWrapper(torch.nn.Module):
57
  replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
58
  return_dict: bool = True,
59
  zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
60
-
61
- # By default samples have to be AT least a multiple of the overall upsampling factor.
62
- # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
63
- # However, the upsampling interpolation output size can be forced to fit any upsampling size
64
- # on the fly if necessary.
65
- default_overall_up_factor = 2**self.model.unet.num_upsamplers
66
-
67
- # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
68
- forward_upsample_size = False
69
- upsample_size = None
70
-
71
- if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
72
- # logger.info("Forward upsample size to force interpolation output size.")
73
- forward_upsample_size = True
74
-
75
- # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
76
- # expects mask of shape:
77
- # [batch, key_tokens]
78
- # adds singleton query_tokens dimension:
79
- # [batch, 1, key_tokens]
80
- # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
81
- # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
82
- # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
83
- if attention_mask is not None:
84
- # assume that mask is expressed as:
85
- # (1 = keep, 0 = discard)
86
- # convert mask into a bias that can be added to attention scores:
87
- # (keep = +0, discard = -10000.0)
88
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
89
- attention_mask = attention_mask.unsqueeze(1)
90
-
91
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
92
- if encoder_attention_mask is not None:
93
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
94
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
95
-
96
- # 0. center input if necessary
97
- if self.model.unet.config.center_input_sample:
98
- sample = 2 * sample - 1.0
99
-
100
- # 1. time
101
- timesteps = timestep
102
- if not torch.is_tensor(timesteps):
103
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
104
- # This would be a good case for the `match` statement (Python 3.10+)
105
- is_mps = sample.device.type == "mps"
106
- if isinstance(timestep, float):
107
- dtype = torch.float32 if is_mps else torch.float64
108
- else:
109
- dtype = torch.int32 if is_mps else torch.int64
110
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
111
- elif len(timesteps.shape) == 0:
112
- timesteps = timesteps[None].to(sample.device)
113
-
114
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
115
- timesteps = timesteps.expand(sample.shape[0])
116
-
117
- t_emb = self.model.unet.time_proj(timesteps)
118
-
119
- # `Timesteps` does not contain any weights and will always return f32 tensors
120
- # but time_embedding might actually be running in fp16. so we need to cast here.
121
- # there might be better ways to encapsulate this.
122
- t_emb = t_emb.to(dtype=sample.dtype)
123
-
124
- emb = self.model.unet.time_embedding(t_emb, timestep_cond)
125
-
126
- if self.model.unet.class_embedding is not None:
127
- if class_labels is None:
128
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
129
-
130
- if self.model.unet.config.class_embed_type == "timestep":
131
- class_labels = self.model.unet.time_proj(class_labels)
132
-
133
- # `Timesteps` does not contain any weights and will always return f32 tensors
134
- # there might be better ways to encapsulate this.
135
- class_labels = class_labels.to(dtype=sample.dtype)
136
-
137
- class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
138
-
139
- if self.model.unet.config.class_embeddings_concat:
140
- emb = torch.cat([emb, class_emb], dim=-1)
141
- else:
142
- emb = emb + class_emb
143
-
144
- if self.model.unet.config.addition_embed_type == "text":
145
- aug_emb = self.model.unet.add_embedding(encoder_hidden_states)
146
- emb = emb + aug_emb
147
- elif self.model.unet.config.addition_embed_type == "text_image":
148
- # Kadinsky 2.1 - style
149
- if "image_embeds" not in added_cond_kwargs:
150
- raise ValueError(
151
- f"{self.model.unet.__class__} has the config param `addition_embed_type` set to 'text_image' "
152
- f"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
153
- )
154
-
155
- image_embs = added_cond_kwargs.get("image_embeds")
156
- text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
157
-
158
- aug_emb = self.model.unet.add_embedding(text_embs, image_embs)
159
- emb = emb + aug_emb
160
-
161
- if self.model.unet.time_embed_act is not None:
162
- emb = self.model.unet.time_embed_act(emb)
163
-
164
- if self.model.unet.encoder_hid_proj is not None and self.model.unet.config.encoder_hid_dim_type == "text_proj":
165
- encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states)
166
- elif self.model.unet.encoder_hid_proj is not None and \
167
- self.model.unet.config.encoder_hid_dim_type == "text_image_proj":
168
- # Kadinsky 2.1 - style
169
- if "image_embeds" not in added_cond_kwargs:
170
- raise ValueError(
171
- f"{self.model.unet.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
172
- f"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
173
- )
174
-
175
- image_embeds = added_cond_kwargs.get("image_embeds")
176
- encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states, image_embeds)
177
-
178
- # 2. pre-process
179
- sample = self.model.unet.conv_in(sample)
180
-
181
- # 3. down
182
- down_block_res_samples = (sample,)
183
- for downsample_block in self.model.unet.down_blocks:
184
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
185
- sample, res_samples = downsample_block(
186
- hidden_states=sample,
187
- temb=emb,
188
- encoder_hidden_states=encoder_hidden_states,
189
- attention_mask=attention_mask,
190
- cross_attention_kwargs=cross_attention_kwargs,
191
- encoder_attention_mask=encoder_attention_mask,
192
- )
193
- else:
194
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
195
-
196
- down_block_res_samples += res_samples
197
-
198
- if down_block_additional_residuals is not None:
199
- new_down_block_res_samples = ()
200
-
201
- for down_block_res_sample, down_block_additional_residual in zip(
202
- down_block_res_samples, down_block_additional_residuals
203
- ):
204
- down_block_res_sample = down_block_res_sample + down_block_additional_residual
205
- new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
206
-
207
- down_block_res_samples = new_down_block_res_samples
208
-
209
- # 4. mid
210
- if self.model.unet.mid_block is not None:
211
- sample = self.model.unet.mid_block(
212
- sample,
213
- emb,
214
- encoder_hidden_states=encoder_hidden_states,
215
- attention_mask=attention_mask,
216
- cross_attention_kwargs=cross_attention_kwargs,
217
- encoder_attention_mask=encoder_attention_mask,
218
- )
219
-
220
- # print(sample.shape)
221
-
222
- if replace_h_space is None:
223
- h_space = sample.clone()
224
- else:
225
- h_space = replace_h_space
226
- sample = replace_h_space.clone()
227
-
228
- if mid_block_additional_residual is not None:
229
- sample = sample + mid_block_additional_residual
230
-
231
- extracted_res_conns = {}
232
- # 5. up
233
- for i, upsample_block in enumerate(self.model.unet.up_blocks):
234
- is_final_block = i == len(self.model.unet.up_blocks) - 1
235
-
236
- res_samples = down_block_res_samples[-len(upsample_block.resnets):]
237
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
238
- if replace_skip_conns is not None and replace_skip_conns.get(i):
239
- res_samples = replace_skip_conns.get(i)
240
-
241
- if zero_out_resconns is not None:
242
- if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \
243
- type(zero_out_resconns) is list and i in zero_out_resconns:
244
- res_samples = [torch.zeros_like(x) for x in res_samples]
245
- # down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples]
246
-
247
- extracted_res_conns[i] = res_samples
248
-
249
- # if we have not reached the final block and need to forward the
250
- # upsample size, we do it here
251
- if not is_final_block and forward_upsample_size:
252
- upsample_size = down_block_res_samples[-1].shape[2:]
253
-
254
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
255
- sample = upsample_block(
256
- hidden_states=sample,
257
- temb=emb,
258
- res_hidden_states_tuple=res_samples,
259
- encoder_hidden_states=encoder_hidden_states,
260
- cross_attention_kwargs=cross_attention_kwargs,
261
- upsample_size=upsample_size,
262
- attention_mask=attention_mask,
263
- encoder_attention_mask=encoder_attention_mask,
264
- )
265
- else:
266
- sample = upsample_block(
267
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
268
- )
269
-
270
- # 6. post-process
271
- if self.model.unet.conv_norm_out:
272
- sample = self.model.unet.conv_norm_out(sample)
273
- sample = self.model.unet.conv_act(sample)
274
- sample = self.model.unet.conv_out(sample)
275
-
276
- if not return_dict:
277
- return (sample,)
278
-
279
- return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
280
 
281
 
282
  class AudioLDM2Wrapper(PipelineWrapper):
283
  def __init__(self, *args, **kwargs) -> None:
284
  super().__init__(*args, **kwargs)
285
  if self.double_precision:
286
- self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64).to(self.device)
 
287
  else:
288
  try:
289
- self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True).to(self.device)
 
290
  except FileNotFoundError:
291
- self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False).to(self.device)
 
292
 
293
- def load_scheduler(self):
294
- # self.model.scheduler = DDIMScheduler.from_config(self.model_id, subfolder="scheduler")
295
  self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
296
 
297
- def get_fn_STFT(self):
298
  from audioldm.audio import TacotronSTFT
299
  return TacotronSTFT(
300
  filter_length=1024,
@@ -306,17 +203,17 @@ class AudioLDM2Wrapper(PipelineWrapper):
306
  mel_fmax=8000,
307
  )
308
 
309
- def vae_encode(self, x):
310
  # self.model.vae.disable_tiling()
311
  if x.shape[2] % 4:
312
  x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
313
  return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
314
  # return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
315
 
316
- def vae_decode(self, x):
317
  return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
318
 
319
- def decode_to_mel(self, x):
320
  if self.double_precision:
321
  tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
322
  tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
@@ -324,7 +221,9 @@ class AudioLDM2Wrapper(PipelineWrapper):
324
  tmp = tmp.unsqueeze(0)
325
  return tmp
326
 
327
- def encode_text(self, prompts: List[str]):
 
 
328
  tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
329
  text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
330
  prompt_embeds_list = []
@@ -333,8 +232,11 @@ class AudioLDM2Wrapper(PipelineWrapper):
333
  for tokenizer, text_encoder in zip(tokenizers, text_encoders):
334
  text_inputs = tokenizer(
335
  prompts,
336
- padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
337
- max_length=tokenizer.model_max_length,
 
 
 
338
  truncation=True,
339
  return_tensors="pt",
340
  )
@@ -404,7 +306,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
404
 
405
  return generated_prompt_embeds, prompt_embeds, attention_mask
406
 
407
- def get_variance(self, timestep, prev_timestep):
408
  alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
409
  alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
410
  beta_prod_t = 1 - alpha_prod_t
@@ -412,7 +314,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
412
  variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
413
  return variance
414
 
415
- def get_alpha_prod_t_prev(self, prev_timestep):
416
  return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
417
  else self.model.scheduler.final_alpha_cumprod
418
 
@@ -485,8 +387,6 @@ class AudioLDM2Wrapper(PipelineWrapper):
485
  # 1. time
486
  timesteps = timestep
487
  if not torch.is_tensor(timesteps):
488
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
489
- # This would be a good case for the `match` statement (Python 3.10+)
490
  is_mps = sample.device.type == "mps"
491
  if isinstance(timestep, float):
492
  dtype = torch.float32 if is_mps else torch.float64
@@ -628,12 +528,328 @@ class AudioLDM2Wrapper(PipelineWrapper):
628
 
629
  return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
630
 
631
- def forward(self, *args, **kwargs):
632
- return self
 
 
633
 
634
 
635
- def load_model(model_id, device, double_precision=False):
636
- ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision)
 
 
 
 
637
  ldm_stable.load_scheduler()
638
  torch.cuda.empty_cache()
639
  return ldm_stable
 
1
  import torch
2
+ from diffusers import DDIMScheduler, CosineDPMSolverMultistepScheduler
3
+ from diffusers.schedulers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
4
+ from diffusers import AudioLDM2Pipeline, StableAudioPipeline
5
+ from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
6
  from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
7
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
8
  from typing import Any, Dict, List, Optional, Tuple, Union
9
+ import gradio as gr
10
 
11
 
12
  class PipelineWrapper(torch.nn.Module):
13
+ def __init__(self, model_id: str,
14
+ device: torch.device,
15
+ double_precision: bool = False,
16
+ token: Optional[str] = None, *args, **kwargs) -> None:
17
  super().__init__(*args, **kwargs)
18
  self.model_id = model_id
19
  self.device = device
20
  self.double_precision = double_precision
21
+ self.token = token
22
 
23
+ def get_sigma(self, timestep: int) -> float:
24
  sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
25
  return sqrt_recipm1_alphas_cumprod[timestep]
26
 
27
+ def load_scheduler(self) -> None:
28
  pass
29
 
30
+ def get_fn_STFT(self) -> torch.nn.Module:
31
  pass
32
 
33
+ def get_sr(self) -> int:
34
+ return 16000
35
+
36
+ def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
37
+ pass
38
+
39
+ def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
40
  pass
41
 
42
+ def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
43
  pass
44
 
45
+ def setup_extra_inputs(self, *args, **kwargs) -> None:
46
  pass
47
 
48
+ def encode_text(self, prompts: List[str], **kwargs
49
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
50
  pass
51
 
52
+ def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
53
  pass
54
 
55
+ def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
56
  pass
57
 
58
+ def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, ...]:
59
+ variance_noise_shape = (num_steps,
60
+ self.model.unet.config.in_channels,
61
+ x0.shape[-2],
62
+ x0.shape[-1])
63
+ return variance_noise_shape
64
+
65
+ def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
66
+ """
67
+ Samples from P(x_1:T|x_0)
68
+ """
69
+ alpha_bar = self.model.scheduler.alphas_cumprod
70
+ sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
71
+
72
+ variance_noise_shape = self.get_noise_shape(x0, num_inference_steps + 1)
73
+ timesteps = self.model.scheduler.timesteps.to(self.device)
74
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
75
+ xts = torch.zeros(variance_noise_shape).to(x0.device)
76
+ xts[0] = x0
77
+ for t in reversed(timesteps):
78
+ idx = num_inference_steps - t_to_idx[int(t)]
79
+ xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
80
+
81
+ return xts
82
+
83
+ def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, noise_pred: torch.Tensor,
84
+ t: torch.Tensor, eta: float = 0, numerical_fix: bool = True, **kwargs
85
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
86
+ # pred of x0
87
+ alpha_bar = self.model.scheduler.alphas_cumprod
88
+ if self.model.scheduler.config.prediction_type == 'epsilon':
89
+ pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
90
+ elif self.model.scheduler.config.prediction_type == 'v_prediction':
91
+ pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
92
+
93
+ # direction to xt
94
+ prev_timestep = t - self.model.scheduler.config.num_train_timesteps // \
95
+ self.model.scheduler.num_inference_steps
96
+
97
+ alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
98
+ variance = self.get_variance(t, prev_timestep)
99
+
100
+ if self.model.scheduler.config.prediction_type == 'epsilon':
101
+ random_noise_pred = noise_pred
102
+ elif self.model.scheduler.config.prediction_type == 'v_prediction':
103
+ random_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
104
+
105
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * random_noise_pred
106
+
107
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
108
+
109
+ z = (xtm1 - mu_xt) / (eta * variance ** 0.5)
110
+
111
+ # correction to avoid error accumulation
112
+ if numerical_fix:
113
+ xtm1 = mu_xt + (eta * variance ** 0.5)*z
114
+
115
+ return z, xtm1, None
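Put differently, the z returned above is defined so that re-injecting it in reverse_step_with_custom_noise with the same eta reproduces the observed x_{t-1} exactly; with numerical_fix enabled, x_{t-1} is snapped back onto that relation to stop discretization error from accumulating:

    # z_t = (x_{t-1} - mu_t(x_t)) / (eta * sqrt(var_t))    -- exactly the expression computed above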
116
+
117
+ def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
118
+ variance_noise: Optional[torch.Tensor] = None, eta: float = 0, **kwargs
119
+ ) -> torch.Tensor:
120
+ # 1. get previous step value (=t-1)
121
+ prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // \
122
+ self.model.scheduler.num_inference_steps
123
+ # 2. compute alphas, betas
124
+ alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
125
+ alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
126
+ beta_prod_t = 1 - alpha_prod_t
127
+ # 3. compute predicted original sample from predicted noise also called
128
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
129
+ if self.model.scheduler.config.prediction_type == 'epsilon':
130
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
131
+ elif self.model.scheduler.config.prediction_type == 'v_prediction':
132
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
133
+
134
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
135
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
136
+ # variance = self.scheduler._get_variance(timestep, prev_timestep)
137
+ variance = self.get_variance(timestep, prev_timestep)
138
+ # std_dev_t = eta * variance ** (0.5)
139
+ # Take care of the asymmetric reverse process (Asyrp)
140
+ if self.model.scheduler.config.prediction_type == 'epsilon':
141
+ model_output_direction = model_output
142
+ elif self.model.scheduler.config.prediction_type == 'v_prediction':
143
+ model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
144
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
145
+ # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
146
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
147
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
148
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
149
+ # 8. Add noise if eta > 0
150
+ if eta > 0:
151
+ if variance_noise is None:
152
+ variance_noise = torch.randn(model_output.shape, device=self.device)
153
+ sigma_z = eta * variance ** (0.5) * variance_noise
154
+ prev_sample = prev_sample + sigma_z
155
+
156
+ return prev_sample
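For orientation, the update implemented above is the DDIM step of formula (12) cited in the comments, written with the eta-scaled variance convention used throughout this file (a restatement of the code, not new behaviour):

    # x_{t-1} = sqrt(abar_{t-1}) * x0_pred
    #         + sqrt(1 - abar_{t-1} - eta * var_t) * eps_hat
    #         + eta * sqrt(var_t) * variance_noise
    # with abar_s = alphas_cumprod[s], var_t = get_variance(t, t_prev), and eps_hat the model
    # output converted to an epsilon prediction when prediction_type is 'v_prediction'.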
157
+
158
  def unet_forward(self,
159
  sample: torch.FloatTensor,
160
  timestep: Union[torch.Tensor, float, int],
 
171
  replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
172
  return_dict: bool = True,
173
  zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
174
+ pass
 
 
175
 
176
 
177
  class AudioLDM2Wrapper(PipelineWrapper):
178
  def __init__(self, *args, **kwargs) -> None:
179
  super().__init__(*args, **kwargs)
180
  if self.double_precision:
181
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64, token=self.token
182
+ ).to(self.device)
183
  else:
184
  try:
185
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, token=self.token
186
+ ).to(self.device)
187
  except FileNotFoundError:
188
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False, token=self.token
189
+ ).to(self.device)
190
 
191
+ def load_scheduler(self) -> None:
 
192
  self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
193
 
194
+ def get_fn_STFT(self) -> torch.nn.Module:
195
  from audioldm.audio import TacotronSTFT
196
  return TacotronSTFT(
197
  filter_length=1024,
 
203
  mel_fmax=8000,
204
  )
205
 
206
+ def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
207
  # self.model.vae.disable_tiling()
208
  if x.shape[2] % 4:
209
  x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
210
  return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
211
  # return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
212
 
213
+ def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
214
  return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
215
 
216
+ def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
217
  if self.double_precision:
218
  tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
219
  tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
 
221
  tmp = tmp.unsqueeze(0)
222
  return tmp
223
 
224
+ def encode_text(self, prompts: List[str], negative: bool = False,
225
+ save_compute: bool = False, cond_length: int = 0, **kwargs
226
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
227
  tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
228
  text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
229
  prompt_embeds_list = []
 
232
  for tokenizer, text_encoder in zip(tokenizers, text_encoders):
233
  text_inputs = tokenizer(
234
  prompts,
235
+ padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
236
+ else True,
237
+ max_length=tokenizer.model_max_length
238
+ if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)))
239
+ else cond_length,
240
  truncation=True,
241
  return_tensors="pt",
242
  )
 
306
 
307
  return generated_prompt_embeds, prompt_embeds, attention_mask
308
 
309
+ def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
310
  alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
311
  alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
312
  beta_prod_t = 1 - alpha_prod_t
 
314
  variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
315
  return variance
316
 
317
+ def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
318
  return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
319
  else self.model.scheduler.final_alpha_cumprod
320
 
 
387
  # 1. time
388
  timesteps = timestep
389
  if not torch.is_tensor(timesteps):
 
 
390
  is_mps = sample.device.type == "mps"
391
  if isinstance(timestep, float):
392
  dtype = torch.float32 if is_mps else torch.float64
 
528
 
529
  return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
530
 
531
+
532
+ class StableAudWrapper(PipelineWrapper):
533
+ def __init__(self, *args, **kwargs) -> None:
534
+ super().__init__(*args, **kwargs)
535
+ try:
536
+ self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=True
537
+ ).to(self.device)
538
+ except FileNotFoundError:
539
+ self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=False
540
+ ).to(self.device)
541
+ self.model.transformer.eval()
542
+ self.model.vae.eval()
543
+
544
+ if self.double_precision:
545
+ self.model = self.model.to(torch.float64)
546
+
547
+ def load_scheduler(self) -> None:
548
+ self.model.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(
549
+ self.model_id, subfolder="scheduler", token=self.token)
550
+
551
+ def encode_text(self, prompts: List[str], negative: bool = False, **kwargs) -> Tuple[torch.Tensor, None, torch.Tensor]:
552
+ text_inputs = self.model.tokenizer(
553
+ prompts,
554
+ padding="max_length",
555
+ max_length=self.model.tokenizer.model_max_length,
556
+ truncation=True,
557
+ return_tensors="pt",
558
+ )
559
+
560
+ text_input_ids = text_inputs.input_ids.to(self.device)
561
+ attention_mask = text_inputs.attention_mask.to(self.device)
562
+
563
+ self.model.text_encoder.eval()
564
+ with torch.no_grad():
565
+ prompt_embeds = self.model.text_encoder(text_input_ids, attention_mask=attention_mask)[0]
566
+
567
+ if negative and attention_mask is not None: # set the masked tokens to the null embed
568
+ prompt_embeds = torch.where(attention_mask.to(torch.bool).unsqueeze(2), prompt_embeds, 0.0)
569
+
570
+ prompt_embeds = self.model.projection_model(text_hidden_states=prompt_embeds).text_hidden_states
571
+
572
+ if attention_mask is None:
573
+ raise gr.Error("Shouldn't reach here. Please raise an issue if you do.")
574
+ """prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
575
+ if attention_mask is not None and negative_attention_mask is None:
576
+ negative_attention_mask = torch.ones_like(attention_mask)
577
+ elif attention_mask is None and negative_attention_mask is not None:
578
+ attention_mask = torch.ones_like(negative_attention_mask)"""
579
+
580
+ if prompts == [""]: # empty
581
+ return torch.zeros_like(prompt_embeds, device=prompt_embeds.device), None, None
582
+
583
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
584
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
585
+ return prompt_embeds, None, attention_mask
586
+
587
+ def get_fn_STFT(self) -> torch.nn.Module:
588
+ from audioldm.audio import TacotronSTFT
589
+ return TacotronSTFT(
590
+ filter_length=1024,
591
+ hop_length=160,
592
+ win_length=1024,
593
+ n_mel_channels=64,
594
+ sampling_rate=44100,
595
+ mel_fmin=0,
596
+ mel_fmax=22050,
597
+ )
598
+
599
+ def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
600
+ x = x.unsqueeze(0)
601
+
602
+ audio_vae_length = int(self.model.transformer.config.sample_size * self.model.vae.hop_length)
603
+ audio_shape = (1, self.model.vae.config.audio_channels, audio_vae_length)
604
+
605
+ # check num_channels
606
+ if x.shape[1] == 1 and self.model.vae.config.audio_channels == 2:
607
+ x = x.repeat(1, 2, 1)
608
+
609
+ audio_length = x.shape[-1]
610
+ audio = x.new_zeros(audio_shape)
611
+ audio[:, :, : min(audio_length, audio_vae_length)] = x[:, :, :audio_vae_length]
612
+
613
+ encoded_audio = self.model.vae.encode(audio.to(self.device)).latent_dist
614
+ encoded_audio = encoded_audio.sample()
615
+ return encoded_audio
616
+
617
+ def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
618
+ torch.cuda.empty_cache()
619
+ # return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
620
+ aud = self.model.vae.decode(x).sample
621
+ return aud[:, :, self.waveform_start:self.waveform_end]
622
+
623
+ def setup_extra_inputs(self, x: torch.Tensor, init_timestep: torch.Tensor,
624
+ extra_info: Optional[Any] = None,
625
+ audio_start_in_s: float = 0, audio_end_in_s: Optional[float] = None,
626
+ save_compute: bool = False) -> None:
627
+ max_audio_length_in_s = self.model.transformer.config.sample_size * self.model.vae.hop_length / \
628
+ self.model.vae.config.sampling_rate
629
+ if audio_end_in_s is None:
630
+ audio_end_in_s = max_audio_length_in_s
631
+
632
+ if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
633
+ raise ValueError(
634
+ f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer "
635
+ f"than the model maximum possible length ({max_audio_length_in_s}). "
636
+ f"Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
637
+ )
638
+
639
+ self.waveform_start = int(audio_start_in_s * self.model.vae.config.sampling_rate)
640
+ self.waveform_end = int(audio_end_in_s * self.model.vae.config.sampling_rate)
641
+
642
+ self.seconds_start_hidden_states, self.seconds_end_hidden_states = self.model.encode_duration(
643
+ audio_start_in_s, audio_end_in_s, self.device, False, 1)
644
+
645
+ if save_compute:
646
+ self.seconds_start_hidden_states = torch.cat([self.seconds_start_hidden_states, self.seconds_start_hidden_states], dim=0)
647
+ self.seconds_end_hidden_states = torch.cat([self.seconds_end_hidden_states, self.seconds_end_hidden_states], dim=0)
648
+
649
+ self.audio_duration_embeds = torch.cat([self.seconds_start_hidden_states,
650
+ self.seconds_end_hidden_states], dim=2)
651
+
652
+ # 7. Prepare rotary positional embedding
653
+ self.rotary_embedding = get_1d_rotary_pos_embed(
654
+ self.model.rotary_embed_dim,
655
+ x.shape[2] + self.audio_duration_embeds.shape[1],
656
+ use_real=True,
657
+ repeat_interleave_real=False,
658
+ )
659
+
660
+ self.model.scheduler._init_step_index(init_timestep)
661
+
662
+ # fix lower_order_nums for the reverse step - Option 1: only start from first order
663
+ # self.model.scheduler.lower_order_nums = 0
664
+ # self.model.scheduler.model_outputs = [None] * self.model.scheduler.config.solver_order
665
+ # fix lower_order_nums for the reverse step - Option 2: start from the correct order with history
666
+ t_to_idx = {float(v): k for k, v in enumerate(self.model.scheduler.timesteps)}
667
+ idx = len(self.model.scheduler.timesteps) - t_to_idx[float(init_timestep)] - 1
668
+ self.model.scheduler.model_outputs = [None, extra_info[idx] if extra_info is not None else None]
669
+ self.model.scheduler.lower_order_nums = min(self.model.scheduler.step_index,
670
+ self.model.scheduler.config.solver_order)
671
+
672
+ # if rand check:
673
+ # x *= self.model.scheduler.init_noise_sigma
674
+ # return x
675
+
676
+ def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
677
+ """
678
+ Samples from P(x_1:T|x_0)
679
+ """
680
+
681
+ sigmas = self.model.scheduler.sigmas
682
+ shapes = self.get_noise_shape(x0, num_inference_steps + 1)
683
+ xts = torch.zeros(shapes).to(x0.device)
684
+ xts[0] = x0
685
+
686
+ timesteps = self.model.scheduler.timesteps.to(self.device)
687
+ t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
688
+ for t in reversed(timesteps):
689
+ # idx = t_to_idx[int(t)]
690
+ idx = num_inference_steps - t_to_idx[float(t)]
691
+ n = torch.randn_like(x0)
692
+ xts[idx] = x0 + n * sigmas[t_to_idx[float(t)]]
693
+ return xts
694
+
695
+ def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, data_pred: torch.Tensor,
696
+ t: torch.Tensor, numerical_fix: bool = True, first_order: bool = False, **kwargs
697
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
698
+ # pred of x0
699
+ sigmas = self.model.scheduler.sigmas
700
+ timesteps = self.model.scheduler.timesteps
701
+ solver_order = self.model.scheduler.config.solver_order
702
+
703
+ if self.model.scheduler.step_index is None:
704
+ self.model.scheduler._init_step_index(t)
705
+ curr_step_index = self.model.scheduler.step_index
706
+
707
+ # Improve numerical stability for small number of steps
708
+ lower_order_final = (curr_step_index == len(timesteps) - 1) and (
709
+ self.model.scheduler.config.euler_at_final
710
+ or (self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
711
+ or self.model.scheduler.config.final_sigmas_type == "zero")
712
+ lower_order_second = ((curr_step_index == len(timesteps) - 2) and
713
+ self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
714
+
715
+ data_pred = self.model.scheduler.convert_model_output(data_pred, sample=xt)
716
+ for i in range(solver_order - 1):
717
+ self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
718
+ self.model.scheduler.model_outputs[-1] = data_pred
719
+
720
+ # Instead of Brownian noise, we compute the noise ourselves here
721
+ if (curr_step_index == len(timesteps) - 1) and self.model.scheduler.config.final_sigmas_type == "zero":
722
+ z = torch.zeros_like(xt)
723
+ elif first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
724
+ sigma_t, sigma_s = sigmas[curr_step_index + 1], sigmas[curr_step_index]
725
+ h = torch.log(sigma_s) - torch.log(sigma_t)
726
+ z = (xtm1 - (sigma_t / sigma_s * torch.exp(-h)) * xt - (1 - torch.exp(-2.0 * h)) * data_pred) \
727
+ / (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
728
+ elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
729
+ sigma_t = sigmas[curr_step_index + 1]
730
+ sigma_s0 = sigmas[curr_step_index]
731
+ sigma_s1 = sigmas[curr_step_index - 1]
732
+ m0, m1 = self.model.scheduler.model_outputs[-1], self.model.scheduler.model_outputs[-2]
733
+ h, h_0 = torch.log(sigma_s0) - torch.log(sigma_t), torch.log(sigma_s1) - torch.log(sigma_s0)
734
+ r0 = h_0 / h
735
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
736
+
737
+ # sde-dpmsolver++
738
+ z = (xtm1 - (sigma_t / sigma_s0 * torch.exp(-h)) * xt
739
+ - (1 - torch.exp(-2.0 * h)) * D0
740
+ - 0.5 * (1 - torch.exp(-2.0 * h)) * D1) \
741
+ / (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
742
+
743
+ # correction to avoid error accumulation
744
+ if numerical_fix:
745
+ if first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
746
+ xtm1 = self.model.scheduler.dpm_solver_first_order_update(data_pred, sample=xt, noise=z)
747
+ elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
748
+ xtm1 = self.model.scheduler.multistep_dpm_solver_second_order_update(
749
+ self.model.scheduler.model_outputs, sample=xt, noise=z)
750
+ # If reconstruction is not perfect - TODO: possibly fix self.model.scheduler.model_outputs as well
751
+
752
+ if self.model.scheduler.lower_order_nums < solver_order:
753
+ self.model.scheduler.lower_order_nums += 1
754
+ # upon completion increase step index by one
755
+ self.model.scheduler._step_index += 1
756
+
757
+ return z, xtm1, self.model.scheduler.model_outputs[-2]
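The last value returned here (the solver's previous converted model output) is what inversion_forward_process stores in extra_info; setup_extra_inputs later re-seeds self.model.scheduler.model_outputs with it, so that editing can resume the second-order SDE-DPM-Solver++ update from an intermediate step instead of falling back to a first-order step.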
758
+
759
+ def get_sr(self) -> int:
760
+ return self.model.vae.config.sampling_rate
761
+
762
+ def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, int, int]:
763
+ variance_noise_shape = (num_steps,
764
+ self.model.transformer.config.in_channels,
765
+ int(self.model.transformer.config.sample_size))
766
+ return variance_noise_shape
767
+
768
+ def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
769
+ variance_noise: Optional[torch.Tensor] = None,
770
+ first_order: bool = False, **kwargs
771
+ ) -> torch.Tensor:
772
+ if self.model.scheduler.step_index is None:
773
+ self.model.scheduler._init_step_index(timestep)
774
+
775
+ # Improve numerical stability for small number of steps
776
+ lower_order_final = (self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 1) and (
777
+ self.model.scheduler.config.euler_at_final
778
+ or (self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15)
779
+ or self.model.scheduler.config.final_sigmas_type == "zero"
780
+ )
781
+ lower_order_second = (
782
+ (self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 2) and
783
+ self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15
784
+ )
785
+
786
+ model_output = self.model.scheduler.convert_model_output(model_output, sample=sample)
787
+ for i in range(self.model.scheduler.config.solver_order - 1):
788
+ self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
789
+ self.model.scheduler.model_outputs[-1] = model_output
790
+
791
+ if variance_noise is None:
792
+ if self.model.scheduler.noise_sampler is None:
793
+ self.model.scheduler.noise_sampler = BrownianTreeNoiseSampler(
794
+ model_output, sigma_min=self.model.scheduler.config.sigma_min,
795
+ sigma_max=self.model.scheduler.config.sigma_max, seed=None)
796
+ variance_noise = self.model.scheduler.noise_sampler(
797
+ self.model.scheduler.sigmas[self.model.scheduler.step_index],
798
+ self.model.scheduler.sigmas[self.model.scheduler.step_index + 1]).to(model_output.device)
799
+
800
+ if first_order or self.model.scheduler.config.solver_order == 1 or \
801
+ self.model.scheduler.lower_order_nums < 1 or lower_order_final:
802
+ prev_sample = self.model.scheduler.dpm_solver_first_order_update(
803
+ model_output, sample=sample, noise=variance_noise)
804
+ elif self.model.scheduler.config.solver_order == 2 or \
805
+ self.model.scheduler.lower_order_nums < 2 or lower_order_second:
806
+ prev_sample = self.model.scheduler.multistep_dpm_solver_second_order_update(
807
+ self.model.scheduler.model_outputs, sample=sample, noise=variance_noise)
808
+
809
+ if self.model.scheduler.lower_order_nums < self.model.scheduler.config.solver_order:
810
+ self.model.scheduler.lower_order_nums += 1
811
+
812
+ # upon completion increase step index by one
813
+ self.model.scheduler._step_index += 1
814
+
815
+ return prev_sample
816
+
817
+ def unet_forward(self,
818
+ sample: torch.FloatTensor,
819
+ timestep: Union[torch.Tensor, float, int],
820
+ encoder_hidden_states: torch.Tensor,
821
+ encoder_attention_mask: Optional[torch.Tensor] = None,
822
+ return_dict: bool = True,
823
+ **kwargs) -> Tuple:
824
+
825
+ # Create text_audio_duration_embeds and audio_duration_embeds
826
+ embeds = torch.cat([encoder_hidden_states, self.seconds_start_hidden_states, self.seconds_end_hidden_states],
827
+ dim=1)
828
+ if encoder_attention_mask is None:
829
+ # handle the batched case
830
+ if embeds.shape[0] > 1:
831
+ embeds[0] = torch.zeros_like(embeds[0], device=embeds.device)
832
+ else:
833
+ embeds = torch.zeros_like(embeds, device=embeds.device)
834
+
835
+ noise_pred = self.model.transformer(sample,
836
+ timestep.unsqueeze(0),
837
+ encoder_hidden_states=embeds,
838
+ global_hidden_states=self.audio_duration_embeds,
839
+ rotary_embedding=self.rotary_embedding)
840
+
841
+ if not return_dict:
842
+ return (noise_pred.sample,)
843
+
844
+ return noise_pred, None, None
845
 
846
 
847
+ def load_model(model_id: str, device: torch.device,
848
+ double_precision: bool = False, token: Optional[str] = None) -> PipelineWrapper:
849
+ if 'audioldm2' in model_id:
850
+ ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
851
+ elif 'stable-audio' in model_id:
852
+ ldm_stable = StableAudWrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
853
  ldm_stable.load_scheduler()
854
  torch.cuda.empty_cache()
855
  return ldm_stable
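Hypothetical calls to the dispatcher above (the model ids are illustrative; any id containing 'audioldm2' or 'stable-audio' is routed to the matching wrapper, and Stable Audio Open expects a Hugging Face access token):

    # ldm = load_model("cvssp/audioldm2-music", device=torch.device("cuda"))
    # ldm = load_model("stabilityai/stable-audio-open-1.0", device=torch.device("cuda"), token=hf_token)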
requirements.txt CHANGED
@@ -1,8 +1,9 @@
1
- torch
2
- numpy<2
3
  torchaudio
4
  diffusers
5
  accelerate
 
6
  transformers
7
  tqdm
8
  soundfile
 
1
+ torch>2.2.0
2
+ numpy<2.0.0
3
  torchaudio
4
  diffusers
5
  accelerate
6
+ torchsde
7
  transformers
8
  tqdm
9
  soundfile
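(torchsde is added for the BrownianTreeNoiseSampler fallback in StableAudWrapper.reverse_step_with_custom_noise; the tightened torch pin is presumably what the Stable Audio pipeline in diffusers requires.)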
utils.py CHANGED
@@ -2,8 +2,11 @@ import numpy as np
2
  import torch
3
  from typing import Optional, List, Tuple, NamedTuple, Union
4
  from models import PipelineWrapper
 
5
  from audioldm.utils import get_duration
6
 
 
 
7
 
8
  class PromptEmbeddings(NamedTuple):
9
  embedding_hidden_states: torch.Tensor
@@ -11,26 +14,57 @@ class PromptEmbeddings(NamedTuple):
11
  boolean_prompt_mask: torch.Tensor
12
 
13
 
14
- def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0, device: Optional[torch.device] = None
15
- ) -> torch.tensor:
16
- if type(audio_path) is str:
17
- import audioldm
18
- import audioldm.audio
 
 
19
 
20
- duration = min(get_duration(audio_path), 30)
 
 
21
 
22
- mel, _, _ = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
23
- mel = mel.unsqueeze(0)
24
- else:
25
- mel = audio_path
26
 
27
- c, h, w = mel.shape
28
- left = min(left, w-1)
29
- right = min(right, w - left - 1)
30
- mel = mel[:, :, left:w-right]
31
- mel = mel.unsqueeze(0).to(device)
32
 
33
- return mel
 
 
 
34
 
35
 
36
  def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
 
2
  import torch
3
  from typing import Optional, List, Tuple, NamedTuple, Union
4
  from models import PipelineWrapper
5
+ import torchaudio
6
  from audioldm.utils import get_duration
7
 
8
+ MAX_DURATION = 30
9
+
10
 
11
  class PromptEmbeddings(NamedTuple):
12
  embedding_hidden_states: torch.Tensor
 
14
  boolean_prompt_mask: torch.Tensor
15
 
16
 
17
+ def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0,
18
+ device: Optional[torch.device] = None,
19
+ return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None) -> torch.Tensor:
20
+ if stft: # AudioLDM/tango loading to spectrogram
21
+ if type(audio_path) is str:
22
+ import audioldm
23
+ import audioldm.audio
24
 
25
+ duration = get_duration(audio_path)
26
+ if MAX_DURATION is not None:
27
+ duration = min(duration, MAX_DURATION)
28
 
29
+ mel, _, wav = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
30
+ mel = mel.unsqueeze(0)
31
+ else:
32
+ mel = audio_path
33
 
34
+ c, h, w = mel.shape
35
+ left = min(left, w-1)
36
+ right = min(right, w - left - 1)
37
+ mel = mel[:, :, left:w-right]
38
+ mel = mel.unsqueeze(0).to(device)
39
 
40
+ if return_wav:
41
+ return mel, 16000, duration, wav
42
+
43
+ return mel, model_sr, duration
44
+ else:
45
+ waveform, sr = torchaudio.load(audio_path)
46
+ if sr != model_sr:
47
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)
48
+ # waveform = waveform.numpy()[0, ...]
49
+
50
+ def normalize_wav(waveform):
51
+ waveform = waveform - torch.mean(waveform)
52
+ waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
53
+ return waveform * 0.5
54
+
55
+ waveform = normalize_wav(waveform)
56
+ # waveform = waveform[None, ...]
57
+ # waveform = pad_wav(waveform, segment_length)
58
+
59
+ # waveform = waveform[0, ...]
60
+ waveform = torch.FloatTensor(waveform)
61
+ if MAX_DURATION is not None:
62
+ duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
63
+ waveform = waveform[:, :int(duration * model_sr)]
64
+
65
+ # cut waveform
66
+ duration = waveform.shape[-1] / model_sr
67
+ return waveform, model_sr, duration
68
 
69
 
70
  def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int: