Spaces:
Runtime error
Runtime error
File size: 6,573 Bytes
5f898a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
import numpy as np
import gradio as gr
from PIL import Image
import matplotlib
from omegaconf import OmegaConf
from einops import repeat
import librosa
from ldm.models.diffusion.ddim import DDIMSampler
from vocoder.bigvgan.models import VocoderBigVGAN
from ldm.util import instantiate_from_config
from ldm.data.extract_mel_spectrogram import TRANSFORMS_16000
# Sample rate (Hz) that input audio is resampled to and that generated audio is returned at.
SAMPLE_RATE = 16000
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Colormap used to turn mel spectrograms into colour images for display.
cmap_transform = matplotlib.cm.viridis
# This app only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
def initialize_model(config, ckpt):
    """Build the diffusion model from a config and checkpoint and wrap it in a DDIM sampler.

    Args:
        config: path to an OmegaConf YAML file whose ``model`` section is passed
            to ``instantiate_from_config``.
        ckpt: path to a checkpoint file; the object returned by ``torch.load``
            must contain a ``"state_dict"`` entry.

    Returns:
        A ``DDIMSampler`` wrapping the model, which has been moved to the
        module-level ``device`` and put in eval mode.
    """
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    # Load onto the CPU first so the checkpoint does not need the device it was saved from.
    state_dict = torch.load(ckpt, map_location='cpu')["state_dict"]
    # strict=False tolerates key mismatches; report them rather than ignoring them silently.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"initialize_model: {len(missing)} missing and "
              f"{len(unexpected)} unexpected state_dict keys in {ckpt}")
    del state_dict  # free the CPU copy of the weights before moving the model
    model = model.to(device)
    # Inference only: disable training-time behaviour such as dropout.
    model.eval()
    print(model.device, device, model.cond_stage_model.device)
    sampler = DDIMSampler(model)
    return sampler
def make_batch_sd(mel, mask, device, num_samples=1):
    """Package a mel spectrogram and its inpainting mask as a model input batch.

    Each numpy array is lifted to a float32 tensor of shape ``(1, 1, *array.shape)``.
    A third tensor, ``masked_mel``, is the mel with the masked region (mask == 1)
    zeroed out. All three are then mapped from [0, 1] to [-1, 1] via ``x * 2 - 1``
    (assumes inputs in [0, 1] - TODO confirm for the mel), moved to ``device``
    and tiled ``num_samples`` times along the leading batch axis.

    Returns:
        dict with keys ``"mel"``, ``"mask"`` and ``"masked_mel"``.
    """
    mel_t = torch.from_numpy(mel)[None, None, ...].to(dtype=torch.float32)
    mask_t = torch.from_numpy(mask)[None, None, ...].to(dtype=torch.float32)
    # The masked mel is built while both tensors are still in [0, 1].
    unscaled = {
        "mel": mel_t,
        "mask": mask_t,
        "masked_mel": (1 - mask_t) * mel_t,
    }
    return {
        key: repeat((tensor * 2 - 1).to(device=device), "1 ... -> n ...", n=num_samples)
        for key, tensor in unscaled.items()
    }
def gen_mel(input_audio):
    """Turn a Gradio audio value into a mel spectrogram of a fixed-length clip.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple as delivered by a Gradio
            ``Audio`` component in numpy mode. ``samples`` is ``(n,)`` for mono
            or ``(n, channels)`` for multi-channel audio.
            - Signed integer PCM is scaled by its full-scale magnitude (32768 for int16).
            - Unsigned integer PCM is centred first.
            - Floating-point audio is assumed to already lie in [-1, 1].

    Returns:
        ``TRANSFORMS_16000`` applied to exactly ``848 * 256`` samples at
        ``SAMPLE_RATE``: the audio is zero-padded or truncated to that length.
        Callers index the result as ``[:, :n]``, so it is 2-D. It is presumably
        (mel_bins, frames) - TODO confirm.
    """
    sr, ori_wav = input_audio
    ori_wav = np.asarray(ori_wav)
    # order='C' just requests a C-contiguous float copy.
    if np.issubdtype(ori_wav.dtype, np.signedinteger):
        full_scale = -float(np.iinfo(ori_wav.dtype).min)  # 32768.0 for int16
        ori_wav = ori_wav.astype(np.float32, order='C') / full_scale
    elif np.issubdtype(ori_wav.dtype, np.unsignedinteger):
        # Unsigned PCM (e.g. 8-bit WAV) is offset-binary: centre before scaling.
        mid = (float(np.iinfo(ori_wav.dtype).max) + 1.0) / 2.0
        ori_wav = (ori_wav.astype(np.float32, order='C') - mid) / mid
    else:
        ori_wav = ori_wav.astype(np.float32, order='C')
    if ori_wav.ndim == 2:
        # Gradio yields (n, channels) but librosa.to_mono expects (channels, n).
        ori_wav = librosa.to_mono(ori_wav.T)
    ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)
    mel_len, hop_size = 848, 256
    input_len = mel_len * hop_size
    if len(ori_wav) < input_len:
        # Pad only the shortfall so short clips end up exactly input_len long.
        input_wav = np.pad(ori_wav, (0, input_len - len(ori_wav)), constant_values=0)
    else:
        input_wav = ori_wav[:input_len]
    mel = TRANSFORMS_16000(input_wav)
    return mel
def show_mel_fn(input_audio):
    """Render the start of the input's mel spectrogram as an image for the sketch widget.

    Only the first 500 frames are shown. The full mel cannot be displayed
    because of a Gradio Image bug when ``tool='sketch'`` is used.

    Returns:
        A ``PIL.Image`` built from the colormapped, 0-255 uint8 values.
    """
    full_mel = gen_mel(input_audio)
    visible = full_mel[:, :500]
    colored = cmap_transform(visible)
    pixels = (colored * 255).astype(np.uint8)
    return Image.fromarray(pixels)
def inpaint(sampler, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
    """Fill the masked region of a mel spectrogram with the diffusion model and vocode it.

    Args:
        sampler: DDIM sampler whose ``.model`` is the inpainting diffusion model.
        batch: dict from ``make_batch_sd`` holding ``"mel"``, ``"mask"`` and
            ``"masked_mel"`` tensors scaled to [-1, 1].
        seed: seed for the initial latent noise, which is passed to the sampler
            as ``x_T``. Equal seeds therefore give equal starting noise.
            Full reproducibility presumably also relies on DDIM's default
            eta=0 - verify.
        ddim_steps: number of DDIM sampling steps.
        num_samples: kept for backward compatibility. The noise batch size
            follows the batch itself (``make_batch_sd`` already tiles it).
        W, H: kept for backward compatibility. The latent noise shape is taken
            from the encoder output instead of being guessed as ``H // 8, W // 8``.

    Returns:
        ``(inpainted, inpaint_wav)``:
        - ``inpainted`` is the squeezed numpy mel in [0, 1]. It keeps the
          original content where mask == 0 and uses the model's prediction
          where mask == 1.
        - ``inpaint_wav`` is the output of the module-level ``vocoder.vocode``.
    """
    model = sampler.model
    # Conditioning: latent of the masked mel plus the mask resized to the latent grid.
    c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)  # (b, c+1, h, w); the extra channel is the mask
    shape = (c.shape[1] - 1,) + tuple(c.shape[2:])
    # Seeded starting noise with exactly the shape the sampler generates.
    # Passing it as x_T is what lets ``seed`` control the result.
    prng = np.random.RandomState(seed)
    start_code = prng.randn(c.shape[0], *shape)
    start_code = torch.from_numpy(start_code).to(device=device, dtype=torch.float32)
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=c.shape[0],
                                     shape=shape,
                                     verbose=False,
                                     x_T=start_code)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

    # Map inputs and prediction back from [-1, 1] to [0, 1].
    mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
    mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
    predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    # Keep the original outside the mask; use the prediction inside it.
    inpainted = (1 - mask) * mel + mask * predicted_mel
    inpainted = inpainted.cpu().numpy().squeeze()
    inpaint_wav = vocoder.vocode(inpainted)
    return inpainted, inpaint_wav
def predict(input_audio, mel_and_mask, ddim_steps, seed):
    """Gradio callback: inpaint the sketched part of the mel and return image + audio.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple from the Gradio Audio input.
        mel_and_mask: dict from the sketch-enabled Gradio Image with PIL images
            under ``"image"`` (the displayed mel crop) and ``"mask"`` (the
            user's strokes; white marks the region to regenerate).
        ddim_steps: number of DDIM steps (slider value; cast to int).
        seed: noise seed (slider value; cast to int).

    Returns:
        ``(PIL.Image, (SAMPLE_RATE, int16 waveform))``:
        - the image is the inpainted mel, cropped to the width of the displayed crop;
        - the waveform is trimmed to the input clip's duration.
    """
    # The displayed mel is only a crop, so just its width is used here.
    # The full mel is regenerated from the audio below.
    show_mel = np.array(mel_and_mask['image'].convert("L")) / 255
    mask = np.array(mel_and_mask["mask"].convert("L")) / 255
    mel_bins, mel_len = 80, 848
    input_mel = gen_mel(input_audio)[:, :mel_len]
    # Right-pad the mask with zeros (= keep original) up to the full mel length.
    mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])),
                  mode='constant', constant_values=0)
    with torch.no_grad():
        batch = make_batch_sd(input_mel, mask, device, num_samples=1)
        inpainted, gen_wav = inpaint(
            sampler=sampler,
            batch=batch,
            # Slider values may arrive as floats; the RNG and DDIM schedule need ints.
            seed=int(seed),
            ddim_steps=int(ddim_steps),
            num_samples=1,
            H=mel_bins, W=mel_len,
        )
    inpainted = inpainted[:, :show_mel.shape[1]]
    color_mel = cmap_transform(inpainted)
    # Length of the original clip once resampled to SAMPLE_RATE.
    input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
    # Clip before the int16 cast so full-scale samples (e.g. 1.0 -> 32768) don't wrap.
    gen_wav = np.clip(gen_wav * 32768, -32768, 32767).astype(np.int16)[:input_len]
    return Image.fromarray((color_mel * 255).astype(np.uint8)), (SAMPLE_RATE, gen_wav)
# Load the inpainting model and the vocoder once at startup.
# predict() and inpaint() read these module-level names.
sampler = initialize_model('./configs/inpaint/txt2audio_args.yaml', './useful_ckpts/inpaint7_epoch00047.ckpt')
vocoder = VocoderBigVGAN('./vocoder/logs/bigv16k53w', device=device)

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Make-An-Audio Inpainting")
    with gr.Row():
        with gr.Column():
            # gr.Audio() is the component behind the deprecated gr.inputs.Audio()
            # alias (same upload source and numpy output); the alias is gone in Gradio 4.
            input_audio = gr.Audio()
            show_button = gr.Button("Show Mel")
            run_button = gr.Button("Predict Masked Place")
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="Steps", minimum=1,
                                       maximum=150, value=100, step=1)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
        with gr.Column():
            show_inpainted = gr.Image(type="pil").style(width=848, height=80)
            outaudio = gr.Audio()
    # Sketchable mel view. It is intentionally not sized with
    # .style(width=848, height=80): with that, the full image cannot be displayed.
    show_mel = gr.Image(type="pil", tool='sketch')
    show_button.click(fn=show_mel_fn, inputs=[input_audio], outputs=show_mel)
    run_button.click(fn=predict, inputs=[input_audio, show_mel, ddim_steps, seed], outputs=[show_inpainted, outaudio])
block.launch()
|