# Spaces: Runtime error
import torch | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import matplotlib | |
from omegaconf import OmegaConf | |
from einops import repeat | |
import librosa | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from vocoder.bigvgan.models import VocoderBigVGAN | |
from ldm.util import instantiate_from_config | |
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000 | |
# Sample rate (Hz) that input audio is resampled to and generated audio is returned at.
SAMPLE_RATE = 16000

# Colormap callable used to turn a mel array into a colour image for display.
cmap_transform = matplotlib.cm.viridis

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def initialize_model(config, ckpt):
    """Build the model described by a config file and wrap it in a DDIM sampler.

    Args:
        config: Path to a config file readable by ``OmegaConf.load``; its
            ``model`` entry is handed to ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load`` that holds a
            ``"state_dict"`` entry.

    Returns:
        A ``DDIMSampler`` wrapping the loaded model, which has been moved to
        the module-level ``device``.
    """
    cfg = OmegaConf.load(config)
    net = instantiate_from_config(cfg.model)

    # Deserialize on the CPU first; the move to the target device follows.
    checkpoint = torch.load(ckpt, map_location='cpu')
    # strict=False tolerates key mismatches between checkpoint and model.
    net.load_state_dict(checkpoint["state_dict"], strict=False)
    net = net.to(device)

    print(net.device, device, net.cond_stage_model.device)
    return DDIMSampler(net)
def make_batch_sd(
        mel,
        mask,
        device,
        num_samples=1):
    """Pack a mel array and its mask into the tensor batch used for inpainting.

    Args:
        mel: NumPy array holding the mel spectrogram.
        mask: NumPy array that multiplies element-wise with ``mel``; entries
            equal to 1 mark the region that is blanked out of ``masked_mel``.
        device: Device the resulting tensors are moved to.
        num_samples: Number of copies to create along the batch axis.

    Returns:
        Dict with keys ``"mel"``, ``"mask"`` and ``"masked_mel"``. Each value
        is a float32 tensor with two leading axes added (batch, channel),
        rescaled with ``x * 2 - 1`` and repeated ``num_samples`` times along
        the batch axis.
    """
    def _as_float_tensor(array):
        # (...) -> (1, 1, ...): prepend batch and channel axes.
        return torch.from_numpy(array)[None, None, ...].to(dtype=torch.float32)

    def _to_batch(tensor):
        # Move to the target device, then replicate along the batch axis.
        return repeat(tensor.to(device=device), "1 ... -> n ...", n=num_samples)

    mel_t = _as_float_tensor(mel)
    mask_t = _as_float_tensor(mask)
    # Zero the masked region before any rescaling takes place.
    masked_mel_t = (1 - mask_t) * mel_t

    return {
        "mel": _to_batch(mel_t * 2 - 1),
        "mask": _to_batch(mask_t * 2 - 1),
        "masked_mel": _to_batch(masked_mel_t * 2 - 1),
    }
def gen_mel(input_audio):
    """Turn an uploaded audio clip into a fixed-length mel spectrogram.

    The waveform is scaled to floats, down-mixed to mono when it has two
    axes, resampled to ``SAMPLE_RATE`` and then zero-padded or truncated to
    exactly ``848 * 256`` samples before ``TRANSFORMS_16000`` is applied.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple. ``samples`` is a NumPy
            array of shape ``(n,)`` or ``(n, channels)``; the ``/ 32768``
            scaling assumes 16-bit PCM values.

    Returns:
        Whatever ``TRANSFORMS_16000`` produces for the prepared waveform
        (callers index it as a 2-D array whose second axis is time).
    """
    sr, ori_wav = input_audio
    print(sr, ori_wav.shape, ori_wav)
    # Scale 16-bit PCM to floats; order='C' only requests a C-contiguous
    # (row-major) copy and has no other effect.
    ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
    if ori_wav.ndim == 2:  # stereo
        # gradio delivers (wav_len, channels) but librosa expects
        # (channels, wav_len), hence the transpose.
        ori_wav = librosa.to_mono(ori_wav.T)
    print(sr, ori_wav.shape, ori_wav)
    ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)

    mel_len, hop_size = 848, 256
    input_len = mel_len * hop_size
    if len(ori_wav) < input_len:
        # Pad only the shortfall so both branches yield exactly input_len
        # samples. Padding by the full input_len (as before) made the mel
        # transform process up to twice as much audio as any caller keeps.
        input_wav = np.pad(ori_wav, (0, input_len - len(ori_wav)), constant_values=0)
    else:
        input_wav = ori_wav[:input_len]
    mel = TRANSFORMS_16000(input_wav)
    return mel
def show_mel_fn(input_audio):
    """Render the first part of the clip's mel spectrogram as a colour image.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple, forwarded to ``gen_mel``.

    Returns:
        A PIL image built from the colormapped mel, limited to the first 500
        time frames.
    """
    # The full mel cannot be shown because of a gradio Image bug that appears
    # when tool='sketch' is used, so only a crop is displayed.
    crop_len = 500
    mel = gen_mel(input_audio)
    cropped = mel[:, :crop_len]
    colored = cmap_transform(cropped) * 255
    return Image.fromarray(colored.astype(np.uint8))
def inpaint(sampler, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
    """Fill the masked region of a mel spectrogram and vocode the result.

    Args:
        sampler: DDIM sampler whose ``model`` attribute provides
            ``encode_first_stage``, ``get_first_stage_encoding`` and
            ``decode_first_stage``.
        batch: Dict with ``"mel"``, ``"mask"`` and ``"masked_mel"`` tensors as
            produced by ``make_batch_sd``.
        seed: Seed for the initial latent noise; the same seed and inputs
            start sampling from the same noise.
        ddim_steps: Number of DDIM sampling steps.
        num_samples: Kept for backward compatibility. The batch size is taken
            from the encoded conditioning instead.
        W: Kept for backward compatibility; not used.
        H: Kept for backward compatibility; not used.

    Returns:
        Tuple ``(inpainted, wav)``: the blended mel as a squeezed NumPy array
        and the output of the module-level ``vocoder.vocode`` for it.
    """
    model = sampler.model

    # Encode the masked mel and append the mask, resized to the latent
    # resolution, as one extra conditioning channel.
    c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)  # (b, c+1, h, w); the extra channel is the mask
    shape = (c.shape[1] - 1,) + tuple(c.shape[2:])

    # Draw the starting noise from the seed and hand it to the sampler.
    # Previously the noise was generated but never passed on, so the seed had
    # no effect. Its shape is taken from the conditioning so it always matches
    # what the sampler is asked to produce.
    prng = np.random.RandomState(int(seed))
    start_code = prng.randn(c.shape[0], *shape)
    start_code = torch.from_numpy(start_code).to(device=c.device, dtype=torch.float32)

    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=c.shape[0],
                                     shape=shape,
                                     verbose=False,
                                     x_T=start_code)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

    # Undo the x * 2 - 1 rescaling and clamp everything to [0, 1].
    mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
    mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
    predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    # Keep the original mel outside the mask and the prediction inside it.
    inpainted = (1 - mask) * mel + mask * predicted_mel
    inpainted = inpainted.cpu().numpy().squeeze()
    inpaint_wav = vocoder.vocode(inpainted)
    return inpainted, inpaint_wav
def predict(input_audio, mel_and_mask, ddim_steps, seed):
    """Inpaint the sketched region of the clip and return the image and audio.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple for the uploaded clip.
        mel_and_mask: Dict with PIL images under ``"image"`` (the displayed
            mel crop) and ``"mask"`` (the region drawn by the user).
        ddim_steps: Number of DDIM sampling steps.
        seed: Seed forwarded to ``inpaint``.

    Returns:
        Tuple ``(image, (SAMPLE_RATE, wav))``: a PIL image of the inpainted
        mel, cropped to the width of the displayed mel, and the generated
        waveform as int16 samples trimmed to the duration of the input clip.
    """
    # The displayed mel is only a crop of the full one, so only its width is
    # needed; the full mel is regenerated from the audio below.
    shown_width = mel_and_mask['image'].width
    mask = np.array(mel_and_mask["mask"].convert("L")) / 255
    mel_bins, mel_len = 80, 848
    input_mel = gen_mel(input_audio)[:, :mel_len]

    # Bring the mask to the width of the full mel: drop any columns beyond
    # mel_len (otherwise the pad width below would be negative), then zero-pad
    # on the right.
    mask = mask[:, :mel_len]
    mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])),
                  mode='constant', constant_values=0)
    print(mask.shape, input_mel.shape)

    with torch.no_grad():
        batch = make_batch_sd(input_mel, mask, device, num_samples=1)
        inpainted, gen_wav = inpaint(
            sampler=sampler,
            batch=batch,
            seed=seed,
            ddim_steps=ddim_steps,
            num_samples=1,
            H=mel_bins, W=mel_len
        )

    inpainted = inpainted[:, :shown_width]
    color_mel = cmap_transform(inpainted)

    # Number of output samples that corresponds to the input clip's duration.
    input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
    # Clip before the int16 cast: a sample at or above 1.0 would otherwise
    # wrap around to a large negative value and produce an audible click.
    gen_wav = np.clip(gen_wav * 32768, -32768, 32767).astype(np.int16)[:input_len]
    return Image.fromarray((color_mel * 255).astype(np.uint8)), (SAMPLE_RATE, gen_wav)
# Load the inpainting model (wrapped in a DDIM sampler) and the vocoder once
# at start-up; `predict` and `inpaint` read both as module-level globals.
sampler = initialize_model('./configs/inpaint/txt2audio_args.yaml', './useful_ckpts/inpaint7_epoch00047.ckpt')
vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w', device=device)

# Gradio UI. NOTE(review): the nesting below is reconstructed from syntax
# because the indentation of this file was lost; confirm the layout.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Make-An-Audio Inpainting")
    with gr.Row():
        with gr.Column():
            # gr.Audio replaces the deprecated gr.inputs.Audio; type="numpy"
            # delivers the (sample_rate, samples) tuple that gen_mel unpacks.
            input_audio = gr.Audio(type="numpy")
            show_button = gr.Button("Show Mel")
            run_button = gr.Button("Predict Masked Place")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1,
                                       maximum=150, value=100, step=1)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            show_inpainted = gr.Image(type="pil").style(width=848, height=80)
            outaudio = gr.Audio()
            # Adding .style(width=848, height=80) here prevents the full image
            # from being displayed, so it is left off.
            show_mel = gr.Image(type="pil", tool='sketch')
    # "Show Mel" renders the clip's mel into the sketch canvas; "Predict
    # Masked Place" inpaints whatever region the user drew on it.
    show_button.click(fn=show_mel_fn, inputs=[input_audio], outputs=show_mel)
    run_button.click(fn=predict, inputs=[input_audio, show_mel, ddim_steps, seed], outputs=[show_inpainted, outaudio])

block.launch()