Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,081 Bytes
caabda6 59a2caa caabda6 500319a 5e340e8 500319a f084bcc 0a08063 b1e308f 6dfde8b 8b092f7 d19a04f 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 3cf4680 6dfde8b 3cf4680 0a08063 f084bcc 28f8000 f084bcc 0a08063 59a2caa 3cf4680 5e340e8 3cf4680 fcc20d8 3cf4680 8a80eb5 3cf4680 b46aa4b d19a04f 3cf4680 1b4c6b5 3cf4680 40a67c8 288c3b4 867f599 40a67c8 867f599 3cea661 224d483 867f599 3cf4680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
# imports from gradio_demo.py
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
# test natten import:
import natten
import accelerate
from sample import zero_wrapper
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect
# Base directory that sample.py and configs/ are resolved against when the
# sampler is launched; "." means the current working directory.
CT_HOME = "."
def infer_mask_from_init_img(img, mask_with='grey'):
    """Extract the inpainting mask from an init image whose mask areas are painted in.

    Args:
        img: a PIL image (converted with ``ToTensor``, so values land in 0..1) or
            a tensor of shape (C, H, W) with C >= 3, or a single-channel image of
            shape (H, W) / (1, H, W), which is treated as grey (replicated to RGB).
        mask_with: which painted colour marks the mask:
            'white' -- pixels where R, G and B are all exactly 1;
            'blue'  -- pixels where B is exactly 1 (regardless of R and G);
            'grey'  -- pixels with R == G == B and R != 0, i.e. any non-black
                       grey, including white.

    Returns:
        Float tensor of shape (H, W): 1.0 inside the mask, 0.0 elsewhere.

    Raises:
        ValueError: if ``mask_with`` is not supported or ``img`` has an
            unsupported shape.

    Note: 'white' and 'blue' compare against 1, so they work for images
    normalized to 0..1 or -1..1 but not 0..255. 'grey' treats an exact 0 as
    black, so it assumes 0..1 data (in -1..1 data black pixels would be masked).
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print(" img.shape: ", img.shape)

    # Normalise to (C, H, W) with at least 3 channels so the colour tests below apply.
    if img.ndim == 2:
        img = img.unsqueeze(0)
    if img.ndim != 3:
        raise ValueError(f"expected an image of shape (C, H, W) or (H, W), got {tuple(img.shape)}")
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)  # grey image: R == G == B
    elif img.shape[0] < 3:
        raise ValueError(f"expected 1 or >= 3 channels, got {img.shape[0]}")

    # The mask has the image's spatial shape, without the channel dimension.
    mask = torch.zeros(img.shape[-2:])
    print(" mask.shape: ", mask.shape)

    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        selected = (red == 1) & (green == 1) & (blue == 1)
    elif mask_with == 'blue':
        selected = blue == 1
    else:  # 'grey'
        selected = (red != 0) & (red == green) & (blue == green)
    mask[selected] = 1
    return mask
def count_notes_in_mask(img, mask):
    """Count the new notes a generated image placed inside ``mask``.

    A pixel counts as a note when its green channel (index 1) is > 0.

    Args:
        img: a PIL image (converted with ``ToTensor``) or an already-converted
            tensor of shape (C, H, W) with C >= 2, as ``infer_mask_from_init_img``
            also accepts.
        mask: tensor of shape (H, W) with 1 inside the mask, e.g. the output of
            ``infer_mask_from_init_img``.

    Returns:
        The sum of ``mask`` over pixels whose green channel is > 0, as a Python
        number (a float when ``mask`` is a float tensor).
    """
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    new_notes = (mask * (img_t[1, :, :] > 0)).sum()  # green channel
    return new_notes.item()
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Pick one generated image according to how many new notes it put in the mask.

    The mask is inferred from ``init_img`` (grey/white painted regions). Each file
    ``f'{PREFIX}_{num:05d}.png'`` for ``num`` in ``range(num_to_gen)`` is scored
    with ``count_notes_in_mask``. Files are ranked from fewest to most new notes;
    ties keep generation order. The file at the ``busyness`` percentile is
    returned (100 = most notes, 0 = fewest).

    Args:
        init_img: the square init image with the mask regions painted in.
        PREFIX: filename prefix the sampler wrote its outputs with.
        num_to_gen: number of generated files to consider.
        busyness: percentile in 0..100 (int or float); out-of-range values are
            clamped.

    Returns:
        The filename of the selected generated image.

    Raises:
        ValueError: if ``num_to_gen`` < 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    scored = []  # (new_notes, filename) pairs, in generation order
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # Context manager closes the file handle once the pixels have been read.
        with Image.open(filename) as gen_img:
            scored.append((count_notes_in_mask(gen_img, mask), filename))

    # list.sort is stable, so equal scores keep generation order (deterministic pick).
    scored.sort(key=lambda pair: pair[0])
    busyness = min(max(busyness, 0), 100)
    grab_index = int((len(scored) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = scored[grab_index][1]
    print("Grabbing filename = ", dense_filename)
    return dense_filename
# Attribute-style stand-in for a parsed command-line args object, used to call
# sample.py's zero_wrapper in-process (presumably mirrors its argparse namespace).
class Args:
    """Attribute bag: each keyword given to the constructor becomes an attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        fields = ", ".join(f"{name}={val}" for name, val in vars(self).items())
        return f'Args({fields})'

    def __str__(self):
        fields = ", ".join(f"{name}={val}" for name, val in vars(self).items())
        return f'Args with attributes: {fields}'
def _run_sampler(init_image_file, repaint, num, prefix, ckpt, use_subprocess=True):
    """Run the diffusion sampler on ``init_image_file`` to generate ``num`` images.

    The outputs are presumably written as ``f'{prefix}_{i:05d}.png'``, the naming
    ``grab_dense_gen`` reads back. TODO: confirm against sample.py.

    Raises:
        gr.Error: if the sampler subprocess exits non-zero. Raising here stops
            stale images from an earlier run being mistaken for fresh output.
    """
    config = f'{CT_HOME}/configs/config_pop909_256x256_chords.json'
    if use_subprocess:
        # Build argv directly instead of str.split(' ') so paths with spaces survive.
        args = [sys.executable, f'{CT_HOME}/sample.py',
                '--batch-size', str(num),
                '--checkpoint', ckpt,
                '--config', config,
                '-n', str(num),
                '--prefix', prefix,
                '--init-image', init_image_file,
                '--steps=100',
                f'--repaint={repaint}']
        print("Will run command: ", ' '.join(args))
        print("Calling subprocess with args = ", args, "\n")
        return_value = call(args)
        print("Subprocess finished. Return value = ", return_value)
        if return_value != 0:
            raise gr.Error(f"Sampler failed (exit status {return_value}); please try Submit again.")
    else:
        accelerator = accelerate.Accelerator()
        device = accelerator.device
        print("Accelerator device = ", device)
        args = Args(batch_size=num, checkpoint=ckpt, config=config, n=num, prefix=prefix,
                    init_image=init_image_file, steps=100, seed_scale=0.0, repaint=repaint)
        print(" Now calling zero_wrapper with args = ", args, "\n")
        zero_wrapper(args, accelerator, device)


def _midi_to_audio(midi_file, audio_file='gradio_demo_out.wav'):
    """Render ``midi_file`` with the ``timidity`` CLI.

    Returns:
        ``audio_file`` on success, or None if timidity exits non-zero. Returning
        None stops Gradio from serving a stale file left over from an earlier
        request.
    """
    # -Ow makes timidity write RIFF WAVE output, hence the .wav default name.
    cmd = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(cmd))
    return_value = call(cmd)
    print("Return value = ", return_value)
    return audio_file if return_value == 0 else None


@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio handler: inpaint notes into the drawn piano roll and return the results.

    Args:
        image: value of the ``gr.ImageEditor``; a dict whose ``'composite'`` entry
            is a PIL image or numpy array of the edited piano roll.
        repaint: RePaint strength forwarded to the sampler.
        busyness: percentile (0..100) used to choose among the generated
            candidates by how many notes they placed in the masked region.

    Returns:
        ``(gen_image, html, audio_file)``:
        - the chosen piano roll converted back to a rectangle;
        - an ``<iframe>`` embedding a MIDI player for it;
        - the rendered audio path, or None if rendering failed.

    Raises:
        gr.Error: if no image was supplied or the sampler fails.
    """
    composite = image['composite'] if image is not None else None
    if composite is None:
        raise gr.Error("Please provide an input piano-roll image.")
    if isinstance(composite, np.ndarray):
        composite = ToPILImage()(composite)
    # Keep the top-left 512x128 piano-roll area, then reshape it to the square
    # layout the sampler is configured for.
    image = rect_to_square(composite.convert("RGB").crop((0, 0, 512, 128)))

    masked_img_file = 'gradio_masked_image.png'  # TODO: fixed name; concurrent requests could clobber each other
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of candidates to generate; grab_dense_gen picks one by note density
    ckpt = 'ckpt/256_chords_00130000.pth'
    prefix = 'gradiodemo'
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    # NOTE(review): an earlier comment said "HF ZeroGPU+Gradio doesn't seem to work with
    # subprocesses", yet the subprocess path is the one enabled. Confirm which is intended.
    _run_sampler(masked_img_file, repaint, num, prefix, ckpt, use_subprocess=True)

    # Forward the busy-ness slider; it was previously ignored (the default of 100 was always used).
    gen_file = grab_dense_gen(image, prefix, num_to_gen=num, busyness=int(busyness))
    gen_image = square_to_rect(Image.open(gen_file))

    # Convert the chosen image to MIDI and embed a player for it.
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # srcdoc goes inside a double-quoted attribute, so swap its double quotes for single ones.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    audio_file = _midi_to_audio(midi_file)
    return gen_image, html, audio_file
def make_dict(x):
    """Wrap ``x`` in the value dict used for the ``gr.ImageEditor`` in this app.

    The same object serves as the background, the composite and the single layer.
    A def replaces the original lambda assignment (PEP 8, E731).
    """
    return {'background': x, 'composite': x, 'layers': [x]}
# Example inputs listed under the demo, grouped by RePaint strength.
# Busy-ness is 100 for every example.
_EXAMPLE_GROUPS = (
    (1, ('all_white.png', 'all_black.png', 'init_img_melody.png', 'init_img_accomp.png', 'init_img_cont.png')),
    (2, ('584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png', 'loop_middle_2.png')),
    (3, ('584_TOTAL_crop_draw.png', 'loop_middle.png')),
    (6, ('pom_mask_shrunk.png',)),
)
_EXAMPLES = [
    [make_dict(png), strength, 100]
    for strength, pngs in _EXAMPLE_GROUPS
    for png in pngs
]

_DESCRIPTION = """
# Pictures of MIDI
Spaces demo of "Pictures Of MIDI: Controlled Music Generation via Graphical Prompts for Image-Based Diffusion Inpainting" by Scott H. Hawley
Paper: https://arxiv.org/abs/2407.01499
Website with examples & more discussion: https://picturesofmidi.github.io/PicturesOfMIDI/
## Instructions
Choose from the examples at the bottom, and/or select the 'draw' tool (pen with a squiggle on it) to draw shapes to inpaint with notes.
White denotes regions to inpaint. (In the paper we used blue.)
## Issues
If you get "`Error`", then try pressing Submit again. It seems that Spaces/Gradio *intermittently* kills this demo with a "`GPU task aborted`" error.
"""

# Layout: the description, then a row with inputs (left) and outputs (right),
# then the Submit button and the examples.
# NOTE(review): the original indentation was lost; Submit/Examples are assumed to
# sit at Blocks level below the row. Confirm.
with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)
    with gr.Row():
        with gr.Column():
            in_img = gr.ImageEditor(sources=["upload", 'clipboard'],
                                    label="Input Piano Roll Image (White = Gen Notes Here)",
                                    value=make_dict('all_black.png'),
                                    brush=gr.Brush(colors=["#FFFFFF", "#000000"]))
            repaint = gr.Slider(minimum=1, maximum=10, step=1, value=2,
                                label="RePaint (Larger = More Notes, But Crazier. Also Slower.)")
            busyness = gr.Slider(minimum=1, maximum=100, step=1, value=100,
                                 label="Busy-ness Percentile (Based on Notes Generated)")
        with gr.Column():
            out_img = gr.Image(width=512, height=128, label='Generated Piano Roll Image')
            out_midi = gr.HTML(label="MIDI Player")
            out_audio = gr.Audio(label="MIDI as Audio")
    inp = [in_img, repaint, busyness]
    out = [out_img, out_midi, out_audio]
    btn = gr.Button("Submit")
    btn.click(fn=process_image, inputs=inp, outputs=out)
    exam = gr.Examples(
        examples=_EXAMPLES,
        fn=process_image,
        inputs=inp,
        outputs=out,
        examples_per_page=50,
        run_on_click=False,
        cache_examples='lazy',
    )
# The app was originally a single gr.Interface; the gr.Blocks layout above
# replaces it. Enable request queueing on the app, then start the server.
demo.queue()
demo.launch()