Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,722 Bytes
caabda6 59a2caa caabda6 500319a 59a2caa 0a08063 b1e308f 6dfde8b d19a04f 6dfde8b 59a2caa 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 3f93e88 6dfde8b 59a2caa 6dfde8b 59a2caa 3cf4680 6dfde8b 3cf4680 0a08063 59a2caa 3cf4680 59a2caa 3cf4680 b46aa4b d19a04f 3cf4680 c6d1aad 3cf4680 288c3b4 3cf4680 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
# imports from gradio_demo.py
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np
from time import sleep
from subprocess import call
import pandas as pd
# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K
# test natten import:
import natten
# Module-level CUDA sentinel tensor. NOTE(review): on a ZeroGPU Space a GPU is
# presumably only attached while an @spaces.GPU-decorated function is running, so
# this is expected to report 'cpu' at import time; process_image prints
# zero.device again from inside a GPU-decorated call for comparison.
zero = torch.Tensor([0]).cuda()
print("Zero Device = ",zero.device," <-- this probably says cpu") # <-- 'cpu' 🤔
# Piano-roll image/MIDI helpers.
# NOTE(review): the second import rebinds square_to_rect, so the function from
# pom.square_to_rect (not the one from pom.pianoroll) is the one in effect below.
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect
# Base directory containing sample.py and its configs/ folder ('.' = current working dir).
CT_HOME = '.'
@spaces.GPU
def infer_mask_from_init_img(img, mask_with='grey'):
    """Extract the inpainting mask from an image whose mask areas are painted in.

    The 'white' and 'blue' tests compare channel values against exactly 1, so
    they work for images normalized to 0..1 or -1..1 but not 0..255. The 'grey'
    test only excludes pixels whose red value is exactly 0; on a -1..1 image,
    black (-1, -1, -1) pixels would therefore also be marked.

    Args:
        img: PIL image, numpy array, or torch tensor. Non-tensors are converted
            with torchvision's ToTensor (channels-first, values in 0..1).
            Tensors may be (C, H, W) or (H, W). An image with fewer than three
            channels is treated as grey: its first channel stands in for R, G and B.
        mask_with: which painted colour marks the mask area:
            'white' -> R == G == B == 1
            'blue'  -> B == 1
            'grey'  -> R != 0 and R == G == B (this includes white)

    Returns:
        float32 tensor of shape (H, W), on the same device as the (converted)
        image: 1.0 inside the mask, 0.0 elsewhere.

    Raises:
        ValueError: if mask_with is not one of the options above, or the image
            is neither 2-D nor 3-D.
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print(" img.shape: ", img.shape)

    if img.dim() == 2:
        # Single plane (H, W): add a channel axis so the per-channel tests below apply.
        img = img.unsqueeze(0)
    elif img.dim() != 3:
        raise ValueError(f"expected a (H, W) or (C, H, W) image, got shape {tuple(img.shape)}")
    if img.shape[0] < 3:
        # Grey (or grey+alpha) image: reuse the first channel as R, G and B.
        img = img[:1].expand(3, -1, -1)

    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        region = (red == 1) & (green == 1) & (blue == 1)
    elif mask_with == 'blue':
        region = blue == 1
    else:  # 'grey'
        region = (red != 0) & (red == green) & (blue == green)

    # Mask has the image's spatial shape (no channel dimension).
    mask = region.to(torch.float32)
    print(" mask.shape: ", mask.shape)
    return mask
@spaces.GPU
def count_notes_in_mask(img, mask):
    """Score how much note activity an image has inside the masked region.

    Despite the name, this counts pixels rather than note events. It sums
    `mask` over the pixels whose green channel is non-zero, so with a 0/1 mask
    the result is the number of lit green pixels inside the mask.

    Args:
        img: PIL image / numpy array (converted with ToTensor) or a (C, H, W)
            tensor with at least 3 channels.
        mask: (H, W) tensor, e.g. from infer_mask_from_init_img.

    Returns:
        The score as a Python number.
    """
    # Accept tensors as well, consistent with infer_mask_from_init_img
    # (ToTensor itself rejects tensor input).
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    # Green channel; moved to the mask's device so CPU and GPU inputs can be mixed.
    green_on = (img_t[1, :, :] > 0).to(mask.device)
    new_notes = (mask * green_on).sum()
    return new_notes.item()
@spaces.GPU
def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Pick one generated image according to how busy its masked region is.

    Reads `{PREFIX}_00000.png` ... `{PREFIX}_{num_to_gen-1:05d}.png` and scores
    each one with count_notes_in_mask, using the grey/white mask painted in
    `init_img`. It then ranks the candidates by ascending score and returns the
    one at the `busyness` percentile: 0 picks the sparsest, 100 the densest.

    Args:
        init_img: the init image with the mask painted in (see infer_mask_from_init_img).
        PREFIX: filename prefix the sampler wrote its outputs with.
        num_to_gen: number of generated files to consider (must be >= 1).
        busyness: percentile in 0..100; values outside that range are clamped.

    Returns:
        Filename of the selected generated image.

    Raises:
        ValueError: if num_to_gen < 1.
        FileNotFoundError: if an expected generated file is missing.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    scored = []  # (score, filename) in generation order
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # Context manager so each image's file handle is released promptly.
        with Image.open(filename) as gen_img:
            scored.append((count_notes_in_mask(gen_img, mask), filename))

    # Stable ascending sort by score; ties keep generation order.
    scored.sort(key=lambda pair: pair[0])
    busyness = min(max(busyness, 0), 100)
    # int() because UI sliders may deliver floats.
    grab_index = int((len(scored) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = scored[grab_index][1]
    print("Grabbing filename = ", dense_filename)
    return dense_filename
@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio callback: inpaint the user's piano roll, then return image, player and audio.

    Steps:
    1. Crop the edited image to 512x128, convert it to square layout and save it.
    2. Run sample.py to generate `num` candidates from it.
    3. Pick one candidate with grab_dense_gen.
    4. Convert the pick to MIDI and render it with timidity.

    Args:
        image: value from gr.ImageEditor; its 'composite' entry is the edited
            image (numpy array or PIL image).
        repaint: RePaint setting forwarded to sample.py as --repaint.
        busyness: percentile (0..100) forwarded to grab_dense_gen to choose a
            candidate by note count in the masked region.

    Returns:
        A tuple (generated piano-roll image, HTML iframe holding a MIDI player,
        audio file path). The audio path is None if timidity failed or is missing.

    Raises:
        gr.Error: if sample.py exits with a non-zero status.
    """
    print("Process Image: Zero Device = ", zero.device)
    # get image ready and execute sampler
    print("image = ", image)
    image = image['composite']
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square(image)
    masked_img_file = 'gradio_masked_image.png'  # TODO: concurrent requests could clobber this file
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64  # number of candidates to generate; grab_dense_gen picks one by busyness
    ckpt = 'ckpt/256_chords_00130000.pth'
    PREFIX = 'gradiodemo'
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    # Build argv as a list (not str.split) so paths containing spaces survive.
    args = [
        sys.executable, f'{CT_HOME}/sample.py',
        '--batch-size', str(num),
        '--checkpoint', ckpt,
        '--config', f'{CT_HOME}/configs/config_pop909_256x256_chords.json',
        '-n', str(num),
        '--prefix', PREFIX,
        '--init-image', masked_img_file,
        '--steps=100',
        f'--repaint={repaint}',
    ]
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        # Continuing would silently rank stale images left over from an earlier run.
        raise gr.Error(f"sample.py failed with exit code {return_value}")

    # find gen'd image and convert to midi piano roll
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=busyness)
    with Image.open(gen_file) as opened:
        # copy() loads the pixels so the result outlives the closed file.
        gen_image = square_to_rect(opened.copy())
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # Swap double quotes so the markup can sit inside the double-quoted srcdoc attribute.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # Convert the MIDI to audio too. timidity's -Ow mode writes RIFF WAVE, hence .wav.
    audio_file = 'gradio_demo_out.wav'
    audio_args = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", audio_args)
    try:
        return_value = call(audio_args)
    except FileNotFoundError:
        return_value = None  # timidity is not installed
    print("Return value = ", return_value)
    if return_value != 0:
        print("Audio conversion failed; returning no audio")
        audio_file = None
    return gen_image, html, audio_file
def make_dict(x):
    """Wrap one image path as the value dict gr.ImageEditor expects (background, composite, one layer)."""
    return {'background': x, 'composite': x, 'layers': [x]}


# Example rows, grouped by their RePaint level; the busy-ness value is always 100.
_EXAMPLE_GROUPS = (
    (1, ('all_white.png', 'all_black.png', 'init_img_melody.png',
         'init_img_accomp.png', 'init_img_cont.png')),
    (2, ('584_TOTAL_crop.png', '780_TOTAL_crop_bg.png',
         '780_TOTAL_crop_draw.png', 'loop_middle_2.png')),
    (3, ('584_TOTAL_crop_draw.png', 'loop_middle.png')),
    (6, ('ismir_mask_2.png',)),
)
_examples = [
    [make_dict(path), repaint_level, 100]
    for repaint_level, paths in _EXAMPLE_GROUPS
    for path in paths
]

demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.ImageEditor(
            sources=["upload", 'clipboard'],
            label="Input Piano Roll Image (White = Gen Notes Here)",
            value=make_dict('all_black.png'),
            brush=gr.Brush(colors=["#FFFFFF", "#000000"]),
        ),
        gr.Slider(minimum=1, maximum=10, step=1, value=2,
                  label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
        gr.Slider(minimum=1, maximum=100, step=1, value=100,
                  label="Busy-ness Percentile (Based on Notes Generated)"),
    ],
    outputs=[
        gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
        gr.HTML(label="MIDI Player"),
        gr.Audio(label="MIDI as Audio"),
    ],
    examples=_examples,
)
demo.queue().launch()