# PicturesOfMIDI / app.py
# Last commit by drscotthawley: c020f94 ("typo fix")
# imports from gradio_demo.py
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
# test natten import:
import natten
import accelerate
from sample import zero_wrapper
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect
# Base directory used to locate sample.py and configs/ when running the sampler.
CT_HOME = "."
def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Pixels painted in the marker colour become 1.0 in the returned mask and all
    others 0.0. The colour tests compare against exact values, so the image
    must be scaled to 0..1 (e.g. the output of ``ToTensor``), not 0..255.
    'white' and 'blue' also work for images normalized to -1..1, but 'grey'
    does not: it marks any pixel whose channels are all equal with a non-zero
    red channel, so -1..1 black (-1, -1, -1) would be marked too.

    Args:
        img: PIL image / numpy array (converted with ``ToTensor``) or a torch
            tensor shaped (C, H, W) with C >= 3, (1, H, W), or (H, W).
            Single-channel input is treated as grayscale, i.e. the same value
            in R, G and B.
        mask_with: marker colour, one of 'blue', 'white', 'grey'. 'grey' also
            matches white, since white has equal, non-zero channels.

    Returns:
        Float tensor of shape (H, W): 1.0 where the marker colour was found.

    Raises:
        ValueError: for an unknown ``mask_with``, an image that is not 2-D or
            3-D, or an image with exactly 2 channels.
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print(" img.shape: ", img.shape)
    if img.dim() == 2:
        # (H, W) grayscale -> (1, H, W) so the channel logic below applies
        img = img.unsqueeze(0)
    if img.dim() != 3:
        raise ValueError(f"expected a 2-D or 3-D image tensor, got shape {tuple(img.shape)}")
    if img.shape[0] == 1:
        # grayscale: replicate the single channel into R, G and B
        img = img.expand(3, -1, -1)
    if img.shape[0] < 3:
        raise ValueError(f"expected 1 or >= 3 channels, got {img.shape[0]}")
    # shape of mask is the image shape without the channel dimension
    mask = torch.zeros(img.shape[-2:])
    print(" mask.shape: ", mask.shape)
    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    else:  # 'grey': equal channels and not black (black is 0 in 0..1 images)
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask
def count_notes_in_mask(img, mask):
    """Count the generated note pixels that fall inside the mask.

    A pixel counts as a note when its green channel is > 0. Each such pixel
    is weighted by the mask value at that position (1/0 for a binary mask).

    Args:
        img: PIL image / numpy array (converted with ``ToTensor``) or a
            (C, H, W) torch tensor already scaled to 0..1, with C >= 2.
        mask: (H, W) tensor, e.g. the output of ``infer_mask_from_init_img``.

    Returns:
        The (weighted) pixel count as a Python number (a float for a float mask).
    """
    # accept tensors too, like infer_mask_from_init_img does
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    green_on = img_t[1, :, :] > 0  # green channel
    return (mask * green_on).sum().item()
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,
                   ):
    """Pick one generated image by how many notes it placed inside the mask.

    Reads ``{PREFIX}_00000.png`` ... ``{PREFIX}_{num_to_gen-1:05d}.png`` and counts
    the note pixels each one has inside the mask marked in ``init_img``. It
    ranks them in ascending order of that count and returns the filename at the
    ``busyness`` percentile of the ranking (100 = most notes, 0 = fewest).

    Args:
        init_img: the init image with mask areas marked; the mask is extracted
            with ``infer_mask_from_init_img(init_img, mask_with='grey')``.
        PREFIX: filename prefix of the generated PNGs.
        num_to_gen: number of generated images to consider; must be >= 1.
        busyness: percentile in 0..100 (values outside are clamped). After
            ranking the images by notes in the mask, this selects which one to grab.

    Returns:
        Filename (str) of the selected generated image.

    Raises:
        ValueError: if ``num_to_gen`` < 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')
    rows = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # context manager releases the file handle as soon as counting is done
        with Image.open(filename) as gen_img:
            new_notes = count_notes_in_mask(gen_img, mask)
        rows.append((filename, new_notes))
    # Build the DataFrame once instead of pd.concat per image (quadratic copying).
    # Use a stable sort so ties keep generation order and the pick is deterministic.
    df = pd.DataFrame(rows, columns=['filename', 'new_notes'])
    df = df.sort_values(by='new_notes', ascending=True, kind='stable')
    # clamp + int() so a float or out-of-range slider value is still a valid iloc position
    busyness = min(max(busyness, 0), 100)
    grab_index = int((len(df) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = df.iloc[grab_index]['filename']
    print("Grabbing filename = ", dense_filename)
    return dense_filename
class Args:
    """Args-like attribute bag built from keyword arguments.

    Used in place of parsed command-line arguments when the sampler is
    invoked in-process through ``zero_wrapper``.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        pairs = [f"{name}={val}" for name, val in vars(self).items()]
        return 'Args(' + ", ".join(pairs) + ')'

    def __str__(self):
        pairs = [f"{name}={val}" for name, val in vars(self).items()]
        return 'Args with attributes: ' + ", ".join(pairs)
@spaces.GPU
def process_image(image, repaint, busyness):
    """Inpaint the drawn piano roll and return the result in three forms.

    1. Take the editor's composite, crop it to 512x128, convert it with
       ``rect_to_square`` and save it as the init image.
    2. Run ``sample.py`` to generate ``num`` candidates.
    3. Pick one with ``grab_dense_gen`` according to ``busyness``.
    4. Render it as an image, an embedded MIDI player and audio.

    Args:
        image: value of the ``gr.ImageEditor`` input; only ``image['composite']``
            (a PIL image or numpy array) is used.
        repaint: RePaint setting, forwarded to ``sample.py`` as ``--repaint``.
        busyness: percentile (0..100) of the candidates ranked by notes
            generated inside the mask; 100 picks the busiest.

    Returns:
        ``(gen_image, html, audio_file)``:
        - the chosen candidate passed through ``square_to_rect``;
        - an ``<iframe>`` HTML string embedding a MIDIPlayer;
        - the path of the WAV rendered by timidity (``None`` if timidity failed).

    Raises:
        gr.Error: if ``sample.py`` exits with a non-zero status.
    """
    # get image ready and execute sampler
    image = image['composite']
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square(image)
    masked_img_file = 'gradio_masked_image.png'  # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of images to generate; grab_dense_gen picks one of them by busyness
    bs = num
    CKPT = 'ckpt/256_chords_00130000.pth'
    CONFIG = f'{CT_HOME}/configs/config_pop909_256x256_chords.json'
    PREFIX = 'gradiodemo'
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    # NOTE(review): "HF ZeroGPU+Gradio doesn't seem to work with subprocesses" --
    # yet the subprocess path is the enabled one and the zero_wrapper branch
    # below is unreachable; confirm which path is intended.
    use_subprocess = True
    if use_subprocess:
        # argv built as a list: no shell, and paths containing spaces stay intact
        args = [sys.executable, f'{CT_HOME}/sample.py',
                '--batch-size', str(bs),
                '--checkpoint', CKPT,
                '--config', CONFIG,
                '-n', str(num),
                '--prefix', PREFIX,
                '--init-image', masked_img_file,
                '--steps=100',
                f'--repaint={repaint}']
        print("Will run command: ", ' '.join(args))
        print("Calling subprocess with args = ", args, "\n")
        return_value = call(args)
        print("Subprocess finished. Return value = ", return_value)
        if return_value != 0:
            # Without this check grab_dense_gen would silently reuse PNGs left
            # over from an earlier request (or fail with FileNotFoundError).
            raise gr.Error(f"Sampling failed (sample.py exited with code {return_value}). Please try again.")
    else:
        accelerator = accelerate.Accelerator()
        device = accelerator.device
        print("Accelerator device = ", device)
        args = Args(batch_size=bs, checkpoint=CKPT, config=CONFIG, n=num, prefix=PREFIX,
                    init_image=masked_img_file, steps=100, seed_scale=0.0, repaint=repaint)
        print(" Now calling zero_wrapper with args = ", args, "\n")
        zero_wrapper(args, accelerator, device)

    # find gen'd image (honouring the busyness slider) and convert to midi piano roll
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=int(busyness))
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # srcdoc is embedded in a double-quoted attribute, so swap its double quotes for single quotes
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # convert the midi to audio too; timidity's -Ow writes RIFF WAVE, hence .wav
    audio_file = 'gradio_demo_out.wav'
    timidity_args = ['timidity', midi_file, '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(map(str, timidity_args)))
    return_value = call(timidity_args)
    print("Return value = ", return_value)
    if return_value != 0:
        audio_file = None  # don't serve a stale file left by an earlier request
    return gen_image, html, audio_file
# for making dictionaries for gradio
def make_dict(x):
    """Build an ImageEditor value that uses *x* as background, composite and only layer."""
    return {'background': x, 'composite': x, 'layers': [x]}
# Build the Gradio UI; `demo` is queued and launched at the bottom of the file.
# NOTE(review): block nesting was restored from a copy with its indentation
# stripped -- confirm where the Submit button and Examples sit in the layout.
with gr.Blocks() as demo:
    # Intro / instructions shown above the editor.
    gr.Markdown(
    """
# Pictures of MIDI
Spaces demo of "Pictures Of MIDI: Controlled Music Generation via Graphical Prompts for Image-Based Diffusion Inpainting" by Scott H. Hawley
Paper: https://arxiv.org/abs/2407.01499
Website with examples & more discussion: https://picturesofmidi.github.io/PicturesOfMIDI/
## Instructions
Choose from the examples at the bottom, and/or select the 'draw' tool (pen with a squiggle on it) to draw shapes to inpaint with notes.
White denotes regions to inpaint. (In the paper I used blue.)
## Issues
If you get "`Error`", then try pressing Submit again. It seems that Spaces/Gradio *intermittently* kills this demo with a "`GPU task aborted`" error.
""")
    with gr.Row():
        # Left column: the three inputs passed to process_image, in order.
        with gr.Column():
            in_img = gr.ImageEditor(sources=["upload",'clipboard'], label="Input Piano Roll Image (White = Gen Notes Here)", value=make_dict('all_black.png'), brush=gr.Brush(colors=["#FFFFFF","#000000"]))
            repaint = gr.Slider(minimum=1, maximum=10, step=1, value=2, label="RePaint (Larger = More Notes, But Crazier. Also Slower.)")
            busyness = gr.Slider(minimum=1, maximum=100, step=1, value=100, label="Busy-ness Percentile (Based on Notes Generated)")
        # Right column: outputs, in the order process_image returns them.
        with gr.Column():
            out_img = gr.Image(width=512, height=128, label='Generated Piano Roll Image')
            out_midi = gr.HTML(label="MIDI Player")
            out_audio = gr.Audio(label="MIDI as Audio")
    inp = [in_img, repaint, busyness]
    out = [out_img, out_midi, out_audio]
    btn = gr.Button("Submit")
    btn.click(fn=process_image, inputs=inp, outputs=out)
    exam = gr.Examples(
        # Each example row is [ImageEditor value, repaint, busyness], matching `inp`;
        # rows are grouped by their repaint setting (1, 2, 3, then 6).
        examples=
        [[make_dict(y),1,100] for y in ['all_white.png','all_black.png','init_img_melody.png','init_img_accomp.png','init_img_cont.png',]]+
        [[make_dict(x),2,100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png','loop_middle_2.png']]+
        [[make_dict(z),3,100] for z in ['584_TOTAL_crop_draw.png','loop_middle.png']] +
        [[make_dict('pom_mask_shrunk.png'),6,100]]
        ,
        fn=process_image, inputs=inp, outputs=out, examples_per_page=50, run_on_click=False, cache_examples='lazy')
# old 'Interface' version
# demo = gr.Interface(fn=process_image,
# inputs=[gr.ImageEditor(sources=["upload",'clipboard'], label="Input Piano Roll Image (White = Gen Notes Here)", value=make_dict('all_black.png'), brush=gr.Brush(colors=["#FFFFFF","#000000"])),
# gr.Slider(minimum=1, maximum=10, step=1, value=2, label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
# gr.Slider(minimum=1, maximum=100, step=1, value=100, label="Busy-ness Percentile (Based on Notes Generated)")],
# outputs=[gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
# gr.HTML(label="MIDI Player"),
# gr.Audio(label="MIDI as Audio")],
# examples= [[make_dict(y),1,100] for y in ['all_white.png','all_black.png','init_img_melody.png','init_img_accomp.png','init_img_cont.png',]]+
# [[make_dict(x),2,100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png','loop_middle_2.png']]+
# [[make_dict(z),3,100] for z in ['584_TOTAL_crop_draw.png','loop_middle.png']] +
# [[make_dict('ismir_mask_2.png'),6,100]],
# )
# Enable the request queue, then start the Gradio server.
queued_demo = demo.queue()
queued_demo.launch()