import random
from tempfile import NamedTemporaryFile

import gradio as gr
import numpy as np
import torch
import torchaudio
from torch import inference_mode

from models import load_model
import utils
from inversion_utils import inversion_forward_process, inversion_reverse_process

# Checkpoint identifiers offered in the UI.
LDM2 = "cvssp/audioldm2"
MUSIC = "cvssp/audioldm2-music"
LDM2_LARGE = "cvssp/audioldm2-large"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All three checkpoints are loaded once at import time; `edit` only selects among them.
ldm2 = load_model(model_id=LDM2, device=device)
ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
ldm2_music = load_model(model_id=MUSIC, device=device)


def randomize_seed_fn(seed, randomize_seed):
    """Optionally draw a new seed, then seed torch's global RNG with it.

    Args:
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: If true, ``seed`` is replaced by a random value in
            ``[0, int32 max]``.

    Returns:
        The seed that was passed to ``torch.manual_seed``.
    """
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    return seed


def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src):
    """Run the forward (inversion) process on the input audio.

    Args:
        ldm_stable: Loaded pipeline wrapper (from ``load_model``).
        x0: Input as returned by ``utils.load_audio``; it is VAE-encoded here.
        prompt_src: Source prompt used to condition the inversion.
        num_diffusion_steps: Number of scheduler timesteps.
        cfg_scale_src: Guidance scale used during inversion.

    Returns:
        ``(zs, wts)``, the second and third outputs of
        ``inversion_forward_process`` (its first output is discarded).
    """
    ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)

    with inference_mode():
        w0 = ldm_stable.vae_encode(x0)

    # find Zs and wts - forward process
    _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1,
                                           prompts=[prompt_src],
                                           cfg_scales=[cfg_scale_src],
                                           prog_bar=True,
                                           num_inference_steps=num_diffusion_steps,
                                           numerical_fix=True)
    return zs, wts


def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar):
    """Run the reverse (editing) process and write the result to a temporary .wav.

    Args:
        ldm_stable: Loaded pipeline wrapper (from ``load_model``).
        zs: Noise maps from ``invert``; only the first ``tstart`` entries are used.
        wts: Latents from ``invert``, passed as ``xT`` to the reverse process.
        steps: Number of diffusion steps used for the inversion.
        prompt_tar: Target prompt describing the desired edit.
        tstart: Edit strength (number of reverse steps taken from ``zs``).
        cfg_scale_tar: Guidance scale for the target prompt.

    Returns:
        Path of a newly created 16 kHz .wav file (not deleted automatically).
    """
    # reverse process (via Zs and wT)
    tstart = torch.tensor(tstart, dtype=torch.int)
    skip = steps - tstart
    # `steps - skip` equals tstart; the tensor form is passed through unchanged.
    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip,
                                      etas=1., prompts=[prompt_tar],
                                      neg_prompts=[""], cfg_scales=[cfg_scale_tar],
                                      prog_bar=True,
                                      zs=zs[:int(steps - skip)])

    # vae decode image
    with inference_mode():
        x0_dec = ldm_stable.vae_decode(w0)
    if x0_dec.dim() < 4:
        x0_dec = x0_dec[None, :, :, :]

    with torch.no_grad():
        audio = ldm_stable.decode_to_mel(x0_dec)

    # Reserve a unique .wav path and close the handle immediately. torchaudio
    # writes by path, so no descriptor stays open per request, and the file can
    # be reopened for writing on every platform.
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    torchaudio.save(out_path, audio, sample_rate=16000)

    return out_path


def change_tstart_range(t_start, steps):
    """Clamp ``t_start`` to ``[int(0.15 * steps), int(0.8 * steps)]``.

    An in-range value is returned unchanged.
    """
    maximum = int(0.8 * steps)
    minimum = int(0.15 * steps)
    if t_start > maximum:
        t_start = maximum
    elif t_start < minimum:
        t_start = minimum
    return t_start


def edit(input_audio,
         model_id: str,
         do_inversion: bool,
         wts: gr.State, zs: gr.State,
         saved_inv_model: str,
         source_prompt="",
         target_prompt="",
         steps=200,
         cfg_scale_src=3.5,
         cfg_scale_tar=12,
         t_start=90,
         randomize_seed=True):
    """Edit ``input_audio`` towards ``target_prompt`` with the selected checkpoint.

    The inversion (forward process) is cached between calls. It is recomputed
    only when ``do_inversion`` is truthy, when the cached inversion was computed
    for a different ``model_id``, or when ``randomize_seed`` is true.

    Args:
        input_audio: Input audio, presumably a file path (the UI component uses
            ``type="filepath"``); ``None`` when nothing is loaded.
        model_id: ``LDM2``, ``LDM2_LARGE`` or ``MUSIC``. Any other value falls
            back to the music checkpoint.
        do_inversion: Whether the cached inversion must be recomputed.
        wts, zs: Cached inversion results. When produced by a previous call of
            this function they are ``gr.State`` wrappers whose ``.value`` holds
            the data.
        saved_inv_model: Model id the cached inversion was computed with.
        source_prompt: Optional description of the input (used for inversion).
        target_prompt: Description of the desired output (used for sampling).
        steps: Number of diffusion steps.
        cfg_scale_src: Guidance scale used for inversion.
        cfg_scale_tar: Guidance scale used for the edit.
        t_start: Edit strength; clamped by ``change_tstart_range``.
        randomize_seed: If true, the inversion is always recomputed.

    Returns:
        ``(output_path, wts, zs, saved_inv_model, do_inversion)``: the edited
        audio file followed by the updated cache values.

    Raises:
        gr.Error: If an inversion is needed but no input audio was provided.
    """
    print(model_id)
    # Same selection as before: unknown ids fall back to the music checkpoint.
    ldm_stable = {LDM2: ldm2, LDM2_LARGE: ldm2_large}.get(model_id, ldm2_music)

    # If the inversion was done for a different model, we need to re-run the inversion
    if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
        do_inversion = True

    # Re-run the inversion when requested, or whenever a new random seed was drawn.
    if do_inversion or randomize_seed:
        if input_audio is None:
            raise gr.Error("Please provide an input audio to edit.")
        # The input is only needed for inversion; reusing a cached one skips loading it.
        x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
        zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
                                       num_diffusion_steps=steps,
                                       cfg_scale_src=cfg_scale_src)
        wts = gr.State(value=wts_tensor)
        zs = gr.State(value=zs_tensor)
        saved_inv_model = model_id
        do_inversion = False

    # make sure t_start is in the right limit
    t_start = change_tstart_range(t_start, steps)

    output = sample(ldm_stable, zs.value, wts.value, steps, prompt_tar=target_prompt,
                    tstart=t_start, cfg_scale_tar=cfg_scale_tar)

    return output, wts, zs, saved_inv_model, do_inversion


def get_example():
    """Return the example rows shown under the demo.

    Each row lists, in order: input audio path, source prompt, target prompt,
    T-start, model id, displayed length, and a pre-rendered output path.
    """
    case = [
        ['Examples/Beethoven.wav',
         '',
         'A recording of an arcade game soundtrack.',
         90,
         'cvssp/audioldm2-music',
         '27s',
         'Examples/Beethoven_arcade.wav',
         ],
        ['Examples/Beethoven.wav',
         'A high quality recording of wind instruments and strings playing.',
         'A high quality recording of a piano playing.',
         90,
         'cvssp/audioldm2-music',
         '27s',
         'Examples/Beethoven_piano.wav',
         ],
        ['Examples/ModalJazz.wav',
         'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
         'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
         90,
         'cvssp/audioldm2-music',
         '106s',
         'Examples/ModalJazz_banjo.wav',],
        ['Examples/Cat.wav',
         '',
         'A dog barking.',
         150,
         'cvssp/audioldm2-large',
         '10s',
         'Examples/Cat_dog.wav',]
    ]
    return case


intro = """
Demo for the method introduced in: Zero-Shot Unsupervised and Text-Based Audio Editing Using DDPM Inversion
"""
# NOTE(review): `help` is not rendered anywhere in this file (only `intro` is
# passed to gr.HTML), and the name shadows the builtin help(). Consider showing
# it in the UI (e.g. a collapsed accordion) or renaming it.
help = """
Instructions:
Provide an input audio and a target prompt to edit the audio.
Tstart is used to control the tradeoff between fidelity to the original signal and text-adherence.
Lower value -> favor fidelity. Higher value -> apply a stronger edit.
Make sure that you use an AudioLDM2 version that is suitable for your input audio.
For example, use the music version for music and the large version for general audio.
You can additionally provide a source prompt to further guide the editing process.
Longer input will take more time.
For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
"""

with gr.Blocks(css='style.css') as demo:
    def reset_do_inversion():
        """Mark the cached inversion as stale so the next edit recomputes it.

        Returns the new value of the ``do_inversion`` State. Gradio stores
        whatever plain value a handler returns into a State output.
        """
        return True

    gr.HTML(intro)

    # Per-session cache of the last inversion, so repeated edits of the same
    # input can skip the forward process.
    wts = gr.State()
    zs = gr.State()
    saved_inv_model = gr.State()
    do_inversion = gr.State(value=True)  # To save some runtime when editing the same thing over and over

    with gr.Row():
        input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input Audio",
                               interactive=True, scale=1)
        output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1)

    with gr.Row():
        tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",
                                placeholder="a recording of a happy upbeat arcade game soundtrack",
                                lines=2, interactive=True)

    with gr.Row():
        with gr.Column():
            submit = gr.Button("Edit")

    with gr.Row():
        t_start = gr.Slider(minimum=10, maximum=240, value=30, step=1, label="T-start", interactive=True, scale=3,
                            info="Higher T-start -> stronger edit. Lower T-start -> closer to original audio")
        model_id = gr.Dropdown(label="AudioLDM2 Version",
                               choices=["cvssp/audioldm2",
                                        "cvssp/audioldm2-large",
                                        "cvssp/audioldm2-music"],
                               info="Choose a checkpoint suitable for your intended audio and edit",
                               value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)

    with gr.Accordion("More Options", open=False):
        with gr.Row():
            src_prompt = gr.Textbox(label="Source Prompt", lines=2, interactive=True,
                                    info="Optional: Describe the original audio input",
                                    placeholder="A recording of a happy upbeat classical music piece",)

        with gr.Row():
            cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
                                      label="Source Guidance Scale", interactive=True, scale=1)
            cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
                                      label="Target Guidance Scale", interactive=True, scale=1)
            # precision=0 makes Gradio pass an int step count to the handler.
            steps = gr.Number(value=50, step=1, minimum=20, maximum=300, precision=0,
                              label="Num Diffusion Steps", interactive=True, scale=1)
        with gr.Row():
            seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
            randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
            length = gr.Number(label="Length", interactive=False, visible=False)

    submit.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=[seed], queue=False).then(
        fn=edit,
        inputs=[input_audio,
                model_id,
                do_inversion,
                wts, zs, saved_inv_model,
                src_prompt,
                tar_prompt,
                steps,
                cfg_scale_src,
                cfg_scale_tar,
                t_start,
                randomize_seed
                ],
        outputs=[output_audio, wts, zs, saved_inv_model, do_inversion]
    )

    # If any input of the inversion changes, the cached inversion must be rerun.
    input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
    src_prompt.change(fn=reset_do_inversion, outputs=[do_inversion])
    model_id.change(fn=reset_do_inversion, outputs=[do_inversion])
    steps.change(fn=reset_do_inversion, outputs=[do_inversion])
    # The source guidance scale is also passed to the inversion by `edit`.
    cfg_scale_src.change(fn=reset_do_inversion, outputs=[do_inversion])

    gr.Examples(
        label="Examples",
        examples=get_example(),
        inputs=[input_audio, src_prompt, tar_prompt, t_start, model_id, length, output_audio],
        outputs=[output_audio]
    )

demo.queue()
demo.launch()