ZeyuXie committed on
Commit a89c362
1 Parent(s): ae95272

Update app.py

Files changed (1)
  1. app.py +134 -134
app.py CHANGED
@@ -1,135 +1,135 @@
-
- import os
- import json
- import numpy as np
- import torch
- import soundfile as sf
- import gradio as gr
- from diffusers import DDPMScheduler
- from pico_model import PicoDiffusion
- from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
- from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
- class dotdict(dict):
-     """dot.notation access to dictionary attributes"""
-     __getattr__ = dict.get
-     __setattr__ = dict.__setitem__
-     __delattr__ = dict.__delitem__
-
- class InferRunner:
-     def __init__(self, device):
-         vae_config = json.load(open("ckpts/ldm/vae_config.json"))
-         self.vae = AutoencoderKL(**vae_config).to(device)
-         vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
-         self.vae.load_state_dict(vae_weights)
-
-         train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
-         self.pico_model = PicoDiffusion(
-             scheduler_name=train_args.scheduler_name,
-             unet_model_config_path=train_args.unet_model_config,
-             snr_gamma=train_args.snr_gamma,
-             freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
-             diffusion_pt="ckpts/pico_model/diffusion.pt",
-         ).eval().to(device)
-         self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- runner = InferRunner(device)
- event_list = get_event()
- def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
-     with torch.no_grad():
-         latents = runner.pico_model.demo_inference(caption, runner.scheduler, num_steps=num_steps, guidance_scale=guidance_scale, num_samples_per_prompt=1, disable_progress=True)
-         mel = runner.vae.decode_first_stage(latents)
-         wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
-         outpath = f"output.wav"
-         sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
-     return outpath
-
- def preprocess(caption):
-     output = preprocess_gemini(caption)
-     return output, output
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         gr.Markdown("## PicoAudio")
-     with gr.Row():
-         description_text = f"18 events: {', '.join(event_list)}"
-         gr.Markdown(description_text)
-
-     with gr.Row():
-         gr.Markdown("## Step1")
-     with gr.Row():
-         preprocess_description_text = f"preprocess: free-text to timestamp caption via LLM"
-         gr.Markdown(preprocess_description_text)
-     with gr.Row():
-         with gr.Column():
-             freetext_prompt = gr.Textbox(label="Prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
-                                          value="a dog barks three times.",)
-             preprocess_run_button = gr.Button()
-             prompt = None
-         with gr.Column():
-             freetext_prompt_out = gr.Textbox(label="Preprocess output")
-     with gr.Row():
-         with gr.Column():
-             gr.Examples(
-                 examples = [["spraying two times then gunshot three times."],
-                             ["a dog barks three times."],
-                             ["cow mooing two times."],],
-                 inputs = [freetext_prompt],
-                 outputs = [prompt]
-             )
-         with gr.Column():
-             pass
-
-
-     with gr.Row():
-         gr.Markdown("## Step2")
-     with gr.Row():
-         with gr.Column():
-             prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
-                                 value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
-             generate_run_button = gr.Button()
-             with gr.Accordion("Advanced options", open=False):
-                 num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
-                 guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
-         with gr.Column():
-             outaudio = gr.Audio()
-     preprocess_run_button.click(fn=preprocess_gemini, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
-     generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])
-
-     with gr.Row():
-         with gr.Column():
-             gr.Examples(
-                 examples = [["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
-                             ["dog_barking at 0.562-2.562_4.25-6.25."],
-                             ["cow_mooing at 0.958-3.582_5.272-7.896."],],
-                 inputs = [prompt, num_steps, guidance_scale],
-                 outputs = [outaudio]
-             )
-         with gr.Column():
-             pass
-
-
- demo.launch()
-
-
- # description_text = f"18 events: {', '.join(event_list)}"
- # prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
- #                     value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
- # outaudio = gr.Audio()
- # num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
- # guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
- # gr_interface = gr.Interface(
- #     fn=infer,
- #     inputs=[prompt, num_steps, guidance_scale],
- #     outputs=[outaudio],
- #     title="PicoAudio",
- #     description=description_text,
- #     allow_flagging=False,
- #     examples=[
- #         ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
- #         ["dog_barking at 0.562-2.562_4.25-6.25."],
- #         ["cow_mooing at 0.958-3.582_5.272-7.896."],
- #     ],
- #     cache_examples="lazy", # Turn on to cache.
- # )
+
+ import os
+ import json
+ import numpy as np
+ import torch
+ import soundfile as sf
+ import gradio as gr
+ from diffusers import DDPMScheduler
+ from pico_model import PicoDiffusion
+ from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
+ from llm_preprocess import get_event, preprocess_gemini, preprocess_gpt
+ class dotdict(dict):
+     """dot.notation access to dictionary attributes"""
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+ class InferRunner:
+     def __init__(self, device):
+         vae_config = json.load(open("ckpts/ldm/vae_config.json"))
+         self.vae = AutoencoderKL(**vae_config).to(device)
+         vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
+         self.vae.load_state_dict(vae_weights)
+
+         train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
+         self.pico_model = PicoDiffusion(
+             scheduler_name=train_args.scheduler_name,
+             unet_model_config_path=train_args.unet_model_config,
+             snr_gamma=train_args.snr_gamma,
+             freeze_text_encoder_ckpt="ckpts/laion_clap/630k-audioset-best.pt",
+             diffusion_pt="ckpts/pico_model/diffusion.pt",
+         ).eval().to(device)
+         self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ runner = InferRunner(device)
+ event_list = get_event()
+ def infer(caption, num_steps=200, guidance_scale=3.0, audio_len=16000*10):
+     with torch.no_grad():
+         latents = runner.pico_model.demo_inference(caption, runner.scheduler, num_steps=num_steps, guidance_scale=guidance_scale, num_samples_per_prompt=1, disable_progress=True)
+         mel = runner.vae.decode_first_stage(latents)
+         wave = runner.vae.decode_to_waveform(mel)[0][:audio_len]
+         outpath = f"output.wav"
+         sf.write(outpath, wave, samplerate=16000, subtype='PCM_16')
+     return outpath
+
+ def preprocess(caption):
+     output = preprocess_gemini(caption)
+     return output, output
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         gr.Markdown("## PicoAudio")
+     with gr.Row():
+         description_text = f"18 events: {', '.join(event_list)}"
+         gr.Markdown(description_text)
+
+     with gr.Row():
+         gr.Markdown("## Step1")
+     with gr.Row():
+         preprocess_description_text = f"preprocess: free-text to timestamp caption via LLM"
+         gr.Markdown(preprocess_description_text)
+     with gr.Row():
+         with gr.Column():
+             freetext_prompt = gr.Textbox(label="Prompt: Input your free-text caption here. (e.g. a dog barks three times.)",
+                                          value="a dog barks three times.",)
+             preprocess_run_button = gr.Button()
+             prompt = None
+         with gr.Column():
+             freetext_prompt_out = gr.Textbox(label="Preprocess output")
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(
+                 examples = [["spraying two times then gunshot three times."],
+                             ["a dog barks three times."],
+                             ["cow mooing two times."],],
+                 inputs = [freetext_prompt],
+                 outputs = [prompt]
+             )
+         with gr.Column():
+             pass
+
+
+     with gr.Row():
+         gr.Markdown("## Step2")
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
+                                 value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
+             generate_run_button = gr.Button()
+             with gr.Accordion("Advanced options", open=False):
+                 num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
+                 guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
+         with gr.Column():
+             outaudio = gr.Audio()
+     preprocess_run_button.click(fn=preprocess_gemini, inputs=[freetext_prompt], outputs=[prompt, freetext_prompt_out])
+     generate_run_button.click(fn=infer, inputs=[prompt, num_steps, guidance_scale], outputs=[outaudio])
+
+     with gr.Row():
+         with gr.Column():
+             gr.Examples(
+                 examples = [["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
+                             ["dog_barking at 0.562-2.562_4.25-6.25."],
+                             ["cow_mooing at 0.958-3.582_5.272-7.896."],],
+                 inputs = [prompt, num_steps, guidance_scale],
+                 outputs = [outaudio]
+             )
+         with gr.Column():
+             pass
+
+
+ demo.launch()
+
+
+ # description_text = f"18 events: {', '.join(event_list)}"
+ # prompt = gr.Textbox(label="Prompt: Input your caption formatted as 'event1 at onset1-offset1_onset2-offset2 and event2 at onset1-offset1'.",
+ #                     value="spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031.",)
+ # outaudio = gr.Audio()
+ # num_steps = gr.Slider(label="num_steps", minimum=1, maximum=300, value=200, step=1)
+ # guidance_scale = gr.Slider(label="guidance_scale", minimum=0.1, maximum=8.0, value=3.0, step=0.1)
+ # gr_interface = gr.Interface(
+ #     fn=infer,
+ #     inputs=[prompt, num_steps, guidance_scale],
+ #     outputs=[outaudio],
+ #     title="PicoAudio",
+ #     description=description_text,
+ #     allow_flagging=False,
+ #     examples=[
+ #         ["spraying at 0.38-1.176_3.06-3.856 and gunshot at 1.729-3.729_4.367-6.367_7.031-9.031."],
+ #         ["dog_barking at 0.562-2.562_4.25-6.25."],
+ #         ["cow_mooing at 0.958-3.582_5.272-7.896."],
+ #     ],
+ #     cache_examples="lazy", # Turn on to cache.
+ # )
  # gr_interface.queue(10).launch()