# imports from gradio_demo.py
# NOTE: keep `spaces` imported before `torch` (order preserved from the original file).
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
from time import sleep
from subprocess import call
import pandas as pd

# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
from tqdm import trange, tqdm
from torchvision import transforms
import k_diffusion as K

# test natten import:
import natten

from sample import zero_wrapper
from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
# This rebinds square_to_rect, overriding the pom.pianoroll version imported just above.
from pom.square_to_rect import square_to_rect

CT_HOME = '.'


def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Parameters
    ----------
    img : PIL image or torch tensor
        Non-tensors are converted with ToTensor (values in 0..1).
        Tensors may be (C, H, W) or (H, W). A 2-D or single-channel image has
        its one channel reused for R, G and B.
    mask_with : {'blue', 'white', 'grey'}
        How masked pixels are marked:
        'white' - R, G and B all equal 1;
        'blue'  - B equals 1;
        'grey'  - R == G == B and R != 0 (this includes white).

    Returns
    -------
    torch.Tensor
        Float (H, W) tensor: 1.0 where masked, 0.0 elsewhere.

    Notes
    -----
    'white' and 'blue' test for exactly 1, so they work on 0..1 or -1..1
    input but not 0..255. 'grey' only excludes exact 0, so it assumes 0..1
    input; on -1..1 input black (-1) would be masked too.

    Raises
    ------
    ValueError
        If mask_with is unknown or img is not 2-D or 3-D.
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print(" img.shape: ", img.shape)

    # Normalize to (C, H, W) with at least 3 channels so the per-channel tests below
    # work; the original 2-D branch indexed img[0, :, :] and always raised.
    if img.dim() == 2:
        img = img.unsqueeze(0)
    if img.dim() != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {tuple(img.shape)}")
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)

    # shape of mask is the image shape without the channel dimension
    mask = torch.zeros(img.shape[-2:])
    print(" mask.shape: ", mask.shape)

    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    else:  # 'grey'
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask * 1.0


def count_notes_in_mask(img, mask):
    """Count masked pixels where the generated image has a note.

    A pixel contributes its mask value when the image's green channel
    (channel index 1 after ToTensor) is > 0. The total is returned as a
    Python number and is used as a density score for ranking generations.
    """
    img_t = ToTensor()(img)
    new_notes = (mask * (img_t[1, :, :] > 0)).sum()  # green channel
    return new_notes.item()


def grab_dense_gen(init_img, PREFIX, num_to_gen=64,
                   busyness=100,  # after ranking images by how many notes were in mask, which one should we grab?
                   ):
    """Choose one generated image by how many new notes it put inside the mask.

    Reads f'{PREFIX}_{i:05d}.png' for i in range(num_to_gen) and scores each one
    with count_notes_in_mask against the 'grey' mask of init_img. The images are
    ranked from fewest to most notes, with ties kept in generation order.

    Parameters
    ----------
    init_img : PIL image or tensor
        The masked input image, passed to infer_mask_from_init_img.
    PREFIX : str
        Filename prefix of the generated images.
    num_to_gen : int
        How many generated images to read; must be >= 1.
    busyness : int or float
        Percentile into the ranking, clamped to 0..100. 100 picks the image with
        the most notes and 0 the fewest.

    Returns
    -------
    str
        The filename of the chosen image.

    Raises
    ------
    ValueError
        If num_to_gen < 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    scored = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # the context manager closes the file handle PIL keeps open lazily
        with Image.open(filename) as gen_img:
            scored.append((count_notes_in_mask(gen_img, mask), filename))

    # sorted() is stable, so equal note counts keep their generation order
    ranked = sorted(scored, key=lambda item: item[0])
    busyness = min(max(busyness, 0), 100)
    # int() because Gradio sliders may deliver floats, which are not valid indices
    grab_index = int((len(ranked) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = ranked[grab_index][1]
    print("Grabbing filename = ", dense_filename)
    return dense_filename


# dummy class to make an args-like object
class Args:
    """Minimal argparse.Namespace look-alike: every keyword argument becomes an attribute.

    Used to hand sample.zero_wrapper an args object without going through its CLI.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        return f'Args({", ".join(f"{key}={value}" for key, value in self.__dict__.items())})'

    def __str__(self):
        return f'Args with attributes: {", ".join(f"{key}={value}" for key, value in self.__dict__.items())}'


def _run_sampler(masked_img_file, repaint, num, prefix, ckpt, config):
    """Run the diffusion sampler that writes the candidate images grab_dense_gen reads.

    Raises gr.Error if the sampler subprocess exits with a nonzero code, so stale
    images from a previous run are never silently used.
    """
    # NOTE(review): an earlier comment said "HF ZeroGPU+Gradio doesn't seem to work
    # with subprocesses", yet the subprocess path is the enabled one -- confirm which
    # path is intended on ZeroGPU.
    use_subprocess = True
    if use_subprocess:
        # Build argv as a list; splitting a string on ' ' breaks paths containing spaces.
        args = [
            sys.executable, f'{CT_HOME}/sample.py',
            '--batch-size', str(num),
            '--checkpoint', ckpt,
            '--config', config,
            '-n', str(num),
            '--prefix', prefix,
            '--init-image', masked_img_file,
            '--steps=100',
            f'--repaint={repaint}',
        ]
        print("Calling subprocess with args = ", args, "\n")
        return_value = call(args)
        print("Subprocess finished. Return value = ", return_value)
        if return_value != 0:
            raise gr.Error(f"Sampling failed (exit code {return_value}); see server logs.")
    else:
        accelerator = accelerate.Accelerator()
        device = accelerator.device
        print("Accelerator device = ", device)
        args = Args(batch_size=num, checkpoint=ckpt, config=config, n=num,
                    prefix=prefix, init_image=masked_img_file, steps=100,
                    seed_scale=0.0, repaint=repaint)
        print(" Now calling zero_wrapper with args = ", args, "\n")
        zero_wrapper(args, accelerator, device)


def _midi_player_html(midi_file):
    """Return an <iframe> snippet embedding a MIDIPlayer page for midi_file."""
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # Swap double quotes for single quotes so the page fits inside srcdoc="...".
    srcdoc = srcdoc.replace("\"", "'")
    # The original f-string here was empty, so the player never appeared.
    # The iframe height is a display choice -- adjust as needed.
    return (f'<iframe srcdoc="{srcdoc}" height="400" width="100%" '
            f'title="MIDI Player"></iframe>')


def _render_audio(midi_file):
    """Render midi_file to audio with timidity; return the audio path, or None on failure."""
    # timidity's -Ow output mode writes RIFF WAVE, so use a .wav name (was .mp3).
    audio_file = 'gradio_demo_out.wav'
    timidity_args = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(timidity_args))
    return_value = call(timidity_args)
    print("Return value = ", return_value)
    if return_value != 0:
        print("WARNING: timidity failed; no audio will be returned.")
        return None
    return audio_file


@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio callback: inpaint the masked piano roll and return the chosen generation.

    Parameters
    ----------
    image : dict
        gr.ImageEditor value. Only its 'composite' entry (numpy array or PIL image)
        is used.
    repaint : int or float
        RePaint slider value, forwarded to the sampler.
    busyness : int or float
        Percentile (0-100), forwarded to grab_dense_gen to pick among generations.

    Returns
    -------
    tuple
        (gen_image, html, audio_file):
        - gen_image: rectangular piano-roll PIL image.
        - html: HTML embedding the MIDI player.
        - audio_file: path of the rendered audio, or None if timidity failed.
    """
    image = image['composite']
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    # Keep the 512x128 piano-roll area, then convert to the square layout
    # (presumably what the 256x256 model config expects).
    image = rect_to_square(image.convert("RGB").crop((0, 0, 512, 128)))

    masked_img_file = 'gradio_masked_image.png'  # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    # number of images to generate; grab_dense_gen then picks one by note count
    num = 64
    # The slider step is 1, so int() is lossless; it avoids passing e.g. "--repaint=2.0".
    repaint = int(repaint)
    ckpt = 'ckpt/256_chords_00130000.pth'
    prefix = 'gradiodemo'
    config = f'{CT_HOME}/configs/config_pop909_256x256_chords.json'
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)

    _run_sampler(masked_img_file, repaint, num, prefix, ckpt, config)

    # find gen'd image (busyness was previously never forwarded) and convert to MIDI
    gen_file = grab_dense_gen(image, prefix, num_to_gen=num, busyness=busyness)
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)

    html = _midi_player_html(midi_file)
    audio_file = _render_audio(midi_file)
    return gen_image, html, audio_file


def make_dict(x):
    """Wrap one image as a gr.ImageEditor value: background, composite and sole layer."""
    return {'background': x, 'composite': x, 'layers': [x]}


demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.ImageEditor(
            sources=["upload", 'clipboard'],
            label="Input Piano Roll Image (White = Gen Notes Here)",
            value=make_dict('all_black.png'),
            brush=gr.Brush(colors=["#FFFFFF", "#000000"]),
        ),
        gr.Slider(minimum=1, maximum=10, step=1, value=2,
                  label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
        gr.Slider(minimum=1, maximum=100, step=1, value=100,
                  label="Busy-ness Percentile (Based on Notes Generated)"),
    ],
    outputs=[
        gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
        gr.HTML(label="MIDI Player"),
        gr.Audio(label="MIDI as Audio"),
    ],
    examples=(
        [[make_dict(y), 1, 100] for y in ['all_white.png', 'all_black.png', 'init_img_melody.png',
                                           'init_img_accomp.png', 'init_img_cont.png']]
        + [[make_dict(x), 2, 100] for x in ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png',
                                             '780_TOTAL_crop_draw.png', 'loop_middle_2.png']]
        + [[make_dict(z), 3, 100] for z in ['584_TOTAL_crop_draw.png', 'loop_middle.png']]
        + [[make_dict('ismir_mask_2.png'), 6, 100]]
    ),
)

if __name__ == "__main__":
    demo.queue().launch()