Spaces: Sleeping
import spaces | |
import argparse, os, sys, glob | |
import pathlib | |
# Put the current working directory on the import path so the project-local
# packages imported further down (``ldm``, ``vocoder``) can be found when the
# app is started from the repository root.
directory = pathlib.Path.cwd()
print(directory)
sys.path.append(os.fspath(directory))
import torch | |
import numpy as np | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from tqdm import tqdm, trange | |
from ldm.util import instantiate_from_config | |
from ldm.models.diffusion.scheduling_lcm import LCMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
import pandas as pd | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from icecream import ic | |
from pathlib import Path | |
import soundfile as sf | |
import yaml | |
import datetime | |
from vocoder.bigvgan.models import VocoderBigVGAN | |
import soundfile | |
# from pytorch_memlab import LineProfiler,profile | |
import gradio | |
import gradio as gr | |
def load_model_from_config(config, ckpt=None, verbose=True, device=None):
    """Instantiate the model described by ``config.model`` and optionally load weights.

    Args:
        config: Config object whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Optional checkpoint path. When given, the file is loaded on the CPU
            and its ``"state_dict"`` entry (or the loaded object itself when it
            has no such key) is loaded into the model with ``strict=False``.
        verbose: When true, print the missing / unexpected keys reported by
            ``load_state_dict``.
        device: Device to move the model to. Defaults to CUDA when available,
            otherwise CPU (the model used to be forced onto CUDA, which crashed
            on CPU-only hosts).

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Lightning checkpoints nest the weights under "state_dict"; otherwise
        # treat the loaded object as a plain state dict.
        if isinstance(pl_sd, dict) and "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # Drop the raw checkpoint (it may hold optimizer state etc.) before the
        # model is moved to the target device.
        del pl_sd, sd
        if missing and verbose:
            print("missing keys:")
            print(missing)
        if unexpected and verbose:
            print("unexpected keys:")
            print(unexpected)
    else:
        print("Note chat no ckpt is loaded !!!")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    return model
class GenSamples:
    """Sample spectrograms for a text prompt and optionally save them and their audio.

    Wraps a ``sampler``/``model`` pair. Decoded spectrograms can be written as
    ``.npy`` files and/or vocoded with ``vocoder`` and written as ``.wav`` files
    under ``outpath``.
    """

    def __init__(self, sampler, model, outpath, vocoder=None, save_mel=True, save_wav=True,
                 original_inference_steps=None, ddim_steps=2, scale=5, num_samples=1,
                 sample_rate=16000) -> None:
        """Store the sampling configuration.

        Args:
            sampler: Object whose ``sample(...)`` produces latents.
            model: Model providing ``get_learned_conditioning``,
                ``decode_first_stage`` and a ``channels`` attribute.
            outpath: Directory where output files are written.
            vocoder: Object with ``vocode(spec)``; required when ``save_wav``.
            save_mel: Write each decoded spectrogram as ``<mel_name>_<idx>.npy``.
            save_wav: Vocode and write ``<wav_name>_<idx>.wav``.
            original_inference_steps: Forwarded to ``sampler.sample``.
            ddim_steps: Number of sampling steps (``S`` for the sampler).
            scale: Guidance scale forwarded to the sampler.
            num_samples: Number of samples generated per prompt.
            sample_rate: Sample rate used when writing WAV files.
        """
        self.sampler = sampler
        self.model = model
        self.outpath = outpath
        if save_wav:
            assert vocoder is not None, "save_wav=True requires a vocoder"
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.original_inference_steps = original_inference_steps
        self.ddim_steps = ddim_steps
        self.scale = scale
        self.num_samples = num_samples
        self.sample_rate = sample_rate

    def gen_test_sample(self, prompt, mel_name=None, wav_name=None):
        """Generate ``self.num_samples`` outputs for one prompt and save them.

        Args:
            prompt: Dict such as ``{'ori_caption': 'xxx', 'struct_caption': 'xxx'}``.
                It is NOT modified; a per-sample copy is built internally.
            mel_name: File stem for ``.npy`` spectrograms; defaults to ``wav_name``.
            wav_name: File stem for ``.wav`` files; defaults to ``mel_name``.
                If both are None, a stem is derived from the caption
                (spaces replaced by ``-``).

        Returns:
            A list with one dict per sample, holding ``'caption'`` plus
            ``'mel_path'`` and/or ``'audio_path'`` depending on the save flags.
        """
        caption = prompt['ori_caption']
        if mel_name is None:
            mel_name = wav_name
        if wav_name is None:
            wav_name = mel_name
        if mel_name is None:
            # Neither stem was given; previously this crashed with str + None.
            mel_name = wav_name = str(caption).strip().replace(" ", "-")
        if self.save_mel or self.save_wav:
            os.makedirs(self.outpath, exist_ok=True)

        uc = None
        if self.scale != 1.0:
            emptycap = {'ori_caption': self.num_samples * [""],
                        'struct_caption': self.num_samples * [""]}
            # NOTE(review): uc (empty-caption conditioning) is computed but never
            # passed to the sampler below; kept so the model calls are unchanged.
            # Confirm whether the sampler is meant to receive it.
            uc = self.model.get_learned_conditioning(emptycap)

        # Repeat every caption once per sample without touching the caller's dict
        # (the old in-place update re-wrapped the values each time the dict was reused).
        batch_prompt = {key: self.num_samples * [value] for key, value in prompt.items()}
        # Per the original author: shape [1, 77, 1280], i.e. per-token embeddings
        # rather than a pooled sentence embedding.
        c = self.model.get_learned_conditioning(batch_prompt)
        if self.channel_dim > 0:
            shape = [self.channel_dim, 20, 312]  # (z_dim, 80//2^x, 848//2^x)
        else:
            shape = [20, 312]
        samples_ddim, _ = self.sampler.sample(S=self.ddim_steps,
                                              conditioning=c,
                                              batch_size=self.num_samples,
                                              shape=shape,
                                              verbose=False,
                                              guidance_scale=self.scale,
                                              original_inference_steps=self.original_inference_steps)
        x_samples_ddim = self.model.decode_first_stage(samples_ddim)

        record_dicts = []
        for idx, spec in enumerate(x_samples_ddim):
            spec = spec.squeeze(0).cpu().numpy()
            record_dict = {'caption': caption}
            if self.save_mel:
                mel_path = os.path.join(self.outpath, f"{mel_name}_{idx}.npy")
                np.save(mel_path, spec)
                record_dict['mel_path'] = mel_path
            if self.save_wav:
                wav = self.vocoder.vocode(spec)
                wav_path = os.path.join(self.outpath, f"{wav_name}_{idx}.wav")
                soundfile.write(wav_path, wav, self.sample_rate)
                record_dict['audio_path'] = wav_path
            record_dicts.append(record_dict)
        return record_dicts
import threading

# Loaded (config, model, sampler, vocoder) tuples, keyed by their paths and the
# device, so requests after the first reuse them instead of reloading the
# checkpoint from disk every time.
_PIPELINE_CACHE = {}
# Serializes loading plus generation: the cached model is shared, ema_scope()
# swaps its weights during sampling, and seeding touches the global RNGs.
_INFER_LOCK = threading.Lock()


def _load_pipeline(config_path="configs/audiolcm.yaml", ckpt_path="./model/000184.ckpt",
                   vocoder_dir="./model/vocoder"):
    """Return a cached ``(config, model, sampler, vocoder)`` tuple, loading it on first use."""
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    key = (config_path, ckpt_path, vocoder_dir, str(device))
    if key not in _PIPELINE_CACHE:
        config = OmegaConf.load(config_path)
        model = load_model_from_config(config, ckpt_path)
        model = model.to(device)
        sampler = LCMSampler(model)
        vocoder = VocoderBigVGAN(vocoder_dir, device)
        _PIPELINE_CACHE[key] = (config, model, sampler, vocoder)
    return _PIPELINE_CACHE[key]


def _safe_file_stem(text, max_len=50):
    """Turn free-form prompt text into a short filename stem with no path separators or dots."""
    import re
    stem = re.sub(r"[^\w-]+", "-", str(text).strip().replace(" ", "-")).strip("-")
    # Truncate so "<stem>_<idx>.wav" stays under common 255-byte filename limits.
    return stem[:max_len] or "sample"


def infer(ori_prompt, ddim_steps, num_samples, scale, seed):
    """Generate audio for ``ori_prompt`` and return the path of the first WAV file.

    Args:
        ori_prompt: User-supplied text prompt (untrusted; it is sanitized before
            being used in a filename).
        ddim_steps: Number of sampling steps.
        num_samples: Number of candidates to generate; only the first is returned.
        scale: Guidance scale.
        seed: Seed for the NumPy and torch RNGs.

    Returns:
        Path of ``<stem>_0.wav`` under ``results/test``.
    """
    # Gradio sliders may deliver floats, but these values are used as counts,
    # list repetition factors and RNG seeds.
    ddim_steps = int(ddim_steps)
    num_samples = int(num_samples)
    seed = int(seed)
    scale = float(scale)
    prompt = dict(ori_caption=ori_prompt, struct_caption=f'<{ori_prompt}& all>')
    outpath = "results/test"
    wav_name = _safe_file_stem(ori_prompt)
    with _INFER_LOCK:
        config, model, sampler, vocoder = _load_pipeline()
        # Seed after loading so model construction cannot consume RNG state; the
        # same seed then gives the same result whether or not the model was cached.
        np.random.seed(seed)
        torch.manual_seed(seed)
        os.makedirs(outpath, exist_ok=True)
        generator = GenSamples(sampler, model, outpath, vocoder, save_mel=False, save_wav=True,
                               original_inference_steps=config.model.params.num_ddim_timesteps,
                               ddim_steps=ddim_steps, scale=scale, num_samples=num_samples)
        with torch.no_grad():
            with model.ema_scope():
                records = generator.gen_test_sample(prompt, wav_name=wav_name)
    print("Your samples are ready and waiting four you here: \nresults/test \nEnjoy.")
    return records[0]["audio_path"]
def my_inference_function(text_prompt, ddim_steps, num_samples, scale, seed):
    """Gradio click handler: forward all arguments to ``infer`` and return its result.

    The result is the path of the generated WAV file, which the ``gr.Audio``
    output component displays.
    """
    return infer(text_prompt, ddim_steps, num_samples, scale, seed)
# Gradio UI: prompt box, run button and advanced sampling options on the left,
# generated audio on the right, and clickable examples below. The button calls
# my_inference_function with the five inputs in the order listed.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## AudioLCM:Text-to-Audio Generation with Latent Consistency Models")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt: Input your text here. ")
            run_button = gr.Button()
            with gr.Accordion("Advanced options", open=False):
                # Implicit concatenation instead of backslash continuations inside the
                # literal, which copied the source indentation into the label. The label
                # also no longer claims the best candidate is picked: the first is returned.
                num_samples = gr.Slider(
                    label=(
                        "Number of candidate audios to generate (only the first one is "
                        "returned). Larger values need more computation."
                    ),
                    minimum=1, maximum=10, value=1, step=1,
                )
                ddim_steps = gr.Slider(label="ddim_steps", minimum=1,
                                       maximum=50, value=2, step=1)
                scale = gr.Slider(
                    label="Guidance Scale:(Large => more relevant to text but the quality may drop)",
                    minimum=0.1, maximum=8.0, value=5.0, step=0.1,
                )
                seed = gr.Slider(
                    label="Seed:Change this value (any integer number) will lead to a different generation result.",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    value=44,
                )
        with gr.Column():
            outaudio = gr.Audio()
    run_button.click(fn=my_inference_function, inputs=[
        prompt, ddim_steps, num_samples, scale, seed], outputs=[outaudio])
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[['An engine revving and then tires squealing', 2, 1, 5, 55],
                          ['A group of people laughing followed by farting', 2, 1, 5, 55],
                          ['Duck quacking repeatedly', 2, 1, 5, 88],
                          ['A man speaks as birds chirp and dogs bark', 2, 1, 5, 55],
                          ['Continuous snoring of a person', 2, 1, 5, 55]],
                inputs=[prompt, ddim_steps, num_samples, scale, seed],
                outputs=[outaudio],
            )
        with gr.Column():
            pass

# Only start the server when run as a script (as `python app.py` does), so
# importing this module, e.g. from Gradio's reload mode, has no side effects.
if __name__ == "__main__":
    demo.launch(show_error=True)
# gradio_interface = gradio.Interface( | |
# fn = my_inference_function, | |
# inputs = "text", | |
# outputs = "audio" | |
# ) | |
# gradio_interface.launch() |