Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,571 Bytes
caabda6 0a08063 6dfde8b 3cf4680 6dfde8b 3cf4680 0a08063 3cf4680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# imports from gradio_demo.py
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
from .pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from .square_to_rect import square_to_rect
def infer_mask_from_init_img(img, mask_with='grey'):
    """Extract the inpainting mask from an image whose mask areas are marked by colour.

    Note: this works whether the image is normalized on 0..1 or -1..1, but not on
    0..255, because the 'white' and 'blue' tests compare channel values against 1.

    Args:
        img: a torch tensor laid out channels-first (C, H, W) with C >= 3 or C == 1,
            or a single-channel (H, W) tensor; anything else (e.g. a PIL image) is
            first converted with ToTensor.
        mask_with: colour that marks the mask: 'blue' (blue channel == 1), 'white'
            (all three channels == 1) or 'grey' (first channel non-zero and all
            three channels equal -- this includes white).

    Returns:
        A tensor of shape (H, W) in the default float dtype, 1.0 where the mask is
        set and 0.0 elsewhere, on the same device as the image.

    Raises:
        ValueError: if mask_with is not a supported colour or img is not 2-D/3-D.
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("img.shape: ", img.shape)
    if img.dim() == 2:
        img = img.unsqueeze(0)  # treat (H, W) as a single-channel (1, H, W) image
    if img.dim() != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {tuple(img.shape)}")
    if img.shape[0] == 1:
        # Greyscale input: replicate the channel so the per-channel RGB tests apply.
        img = img.expand(3, -1, -1)
    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        hit = (red == 1) & (green == 1) & (blue == 1)
    elif mask_with == 'blue':
        hit = blue == 1
    else:  # 'grey': any non-black pixel whose channels are all equal
        hit = (red != 0) & (red == green) & (blue == green)
    mask = hit.to(torch.get_default_dtype())
    print("mask.shape: ", mask.shape)
    return mask
def count_notes_in_mask(img, mask):
    """Count the new notes of `img` that fall inside `mask`.

    A pixel counts as a note when its green channel (index 1) is > 0; each such
    pixel contributes its mask value to the total.

    Args:
        img: a channels-first (C, H, W) torch tensor with at least two channels, or
            anything ToTensor accepts (e.g. a PIL image), which is converted first.
        mask: tensor that multiplies element-wise with the (H, W) green plane, e.g.
            the output of infer_mask_from_init_img.

    Returns:
        The masked note count as a Python number.
    """
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    green_on = img_t[1, :, :] > 0  # green channel marks note pixels
    return (mask * green_on).sum().item()
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Pick one generated image according to how many new notes it put inside the mask.

    Reads `{PREFIX}_00000.png` .. `{PREFIX}_{num_to_gen - 1:05d}.png`, counts the notes
    each one has inside the mask inferred (mask_with='grey') from `init_img`, ranks the
    files from fewest to most notes and returns the one at the `busyness` percentile.

    Args:
        init_img: the init image with the mask areas marked.
        PREFIX: filename prefix the sampler wrote its outputs with.
        num_to_gen: how many generated files to consider.
        busyness: percentile into the ranking; 100 picks the densest, 0 the sparsest.
            Values outside 0..100 are clamped; floats are accepted.

    Returns:
        The filename (str) of the chosen image.

    Raises:
        ValueError: if num_to_gen < 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')
    scored = []  # (filename, new_notes) pairs, built once instead of concat-per-row
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        with Image.open(filename) as gen_img:  # close each file once it has been counted
            scored.append((filename, count_notes_in_mask(gen_img, mask)))
    scored.sort(key=lambda item: item[1])  # ascending note count; ties keep generation order
    busyness = min(max(busyness, 0), 100)
    grab_index = int((len(scored) - 1) * busyness // 100)  # int() so float percentiles index cleanly
    print("grab_index = ", grab_index)
    dense_filename = scored[grab_index][0]
    print("Grabbing filename = ", dense_filename)
    return dense_filename
def _prepare_init_image(editor_value):
    """Turn a gr.ImageEditor value into the square PIL init image for the sampler.

    Takes the editor's 'composite' entry, converts it to an RGB PIL image, crops
    the 512x128 piano-roll area and reshapes it with rect_to_square.
    Raises gr.Error if no image was supplied.
    """
    image = None if editor_value is None else editor_value.get('composite')
    if image is None:
        raise gr.Error("Please upload or draw a piano roll image first.")
    # the composite may arrive as a numpy array; convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    return rect_to_square(image)


def _run_sampler(init_img_file, prefix, num, repaint):
    """Run sample.py in a subprocess to write `num` images named `{prefix}_NNNNN.png`.

    The argument list is passed directly (no shell, no string splitting), so paths
    containing spaces survive. The subprocess inherits this process's environment
    (including CUDA_VISIBLE_DEVICES). Raises gr.Error on a non-zero exit status.
    """
    user = 'shawley'
    run_home = f'/runs/{user}/k-diffusion/pop909/full_chords'
    ckpt = f'{run_home}/256_chords_00130000.pth'
    # CT_HOME was referenced but never defined in this file (NameError on every call).
    # NOTE(review): it must be the directory holding sample.py and configs/; it defaults
    # to this file's directory unless the CT_HOME environment variable is set -- confirm.
    ct_home = os.environ.get('CT_HOME', os.path.dirname(os.path.abspath(__file__)))
    args = ['/home/shawley/envs/hs/bin/python', f'{ct_home}/sample.py',
            '--batch-size', str(num),
            '--checkpoint', ckpt,
            '--config', f'{ct_home}/configs/config_pop909_256x256_chords.json',
            '-n', str(num),
            '--prefix', prefix,
            '--init-image', init_img_file,
            '--steps=100',
            f'--repaint={repaint}']
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        raise gr.Error(f"Sampling failed (exit code {return_value}).")


def _render_midi(gen_file):
    """Convert a generated piano-roll image to MIDI, a player iframe and an audio file.

    Returns (player_html, audio_file); audio_file is None when timidity fails, so
    the image and the MIDI player can still be shown.
    """
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    srcdoc = srcdoc.replace("\"", "'")  # keep the double-quoted srcdoc attribute intact
    player_html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''
    # timidity's -Ow output mode writes a RIFF WAVE file, so give it a .wav name.
    audio_file = 'gradio_demo_out.wav'
    args = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(args))
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        audio_file = None  # best effort: skip audio rather than hand gr.Audio a missing file
    return player_html, audio_file


def process_image(image, repaint, busyness):
    """Gradio callback: inpaint the marked areas of a piano-roll image.

    Args:
        image: the gr.ImageEditor value (a dict whose 'composite' entry is the image).
        repaint: RePaint setting forwarded to sample.py as --repaint.
        busyness: percentile used to choose among the generated images by how many
            notes they placed in the masked region (see grab_dense_gen).

    Returns:
        (generated piano-roll image, MIDI player HTML, audio file path or None).

    Raises:
        gr.Error: if no image was supplied or the sampler fails.
    """
    print("image = ", image)
    init_img = _prepare_init_image(image)
    masked_img_file = 'gradio_masked_image.png'  # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    init_img.save(masked_img_file)
    num = 64  # number of images to generate; one is picked at the requested busyness percentile
    prefix = 'gradiodemo'
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    _run_sampler(masked_img_file, prefix, num, repaint)
    # the busyness slider was previously never forwarded, so it always picked the densest
    gen_file = grab_dense_gen(init_img, prefix, num_to_gen=num, busyness=busyness)
    with Image.open(gen_file) as gen_img:
        # copy() detaches the pixels from the file so the handle can be closed here
        gen_image = square_to_rect(gen_img.copy())
    player_html, audio_file = _render_midi(gen_file)
    return gen_image, player_html, audio_file
def make_dict(img_path):
    """Wrap an image file path as a gr.ImageEditor value.

    The same file serves as both the background and the flattened composite, with
    no extra drawing layers.
    """
    # Called below for the default value and the examples, but not defined or
    # imported anywhere else in this file (NameError at import without it).
    return {"background": img_path, "layers": [], "composite": img_path}


# Each example row matches the Interface inputs: [init image, RePaint, busy-ness percentile].
_EXAMPLES = (
    [[make_dict(y), 1, 100] for y in ['all_white.png', 'all_black.png', 'init_img_melody.png',
                                      'init_img_accomp.png', 'init_img_cont.png']]
    + [[make_dict(x), 2, 100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png',
                                        '780_TOTAL_crop_draw.png', 'loop_middle_2.png']]
    + [[make_dict(z), 3, 100] for z in ['584_TOTAL_crop_draw.png', 'loop_middle.png']]
    + [[make_dict('ismir_mask_2.png'), 6, 100]]
)

demo = gr.Interface(
    fn=process_image,
    inputs=[gr.ImageEditor(sources=["upload", 'clipboard'],
                           label="Input Piano Roll Image (White = Gen Notes Here)",
                           value=make_dict('all_black.png'),
                           brush=gr.Brush(colors=["#FFFFFF", "#000000"])),
            gr.Slider(minimum=1, maximum=10, step=1, value=2,
                      label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
            gr.Slider(minimum=1, maximum=100, step=1, value=100,
                      label="Busy-ness Percentile (Based on Notes Generated)")],
    outputs=[gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
             gr.HTML(label="MIDI Player"),
             gr.Audio(label="MIDI as Audio")],
    examples=_EXAMPLES,
)
demo.queue().launch()