File size: 7,650 Bytes
caabda6
 
59a2caa
caabda6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500319a
 
5e340e8
500319a
f084bcc
 
0a08063
b1e308f
 
6dfde8b
8b092f7
d19a04f
6dfde8b
 
 
 
 
3f93e88
6dfde8b
 
3f93e88
6dfde8b
 
 
 
 
3f93e88
6dfde8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfde8b
3cf4680
 
 
 
 
 
 
0a08063
f084bcc
 
 
 
 
28f8000
 
 
 
 
 
f084bcc
0a08063
59a2caa
3cf4680
5e340e8
3cf4680
fcc20d8
3cf4680
 
 
 
 
 
 
 
 
 
 
8a80eb5
3cf4680
 
b46aa4b
d19a04f
3cf4680
 
 
1b4c6b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288c3b4
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# imports from gradio_demo.py
import gradio as gr 
import spaces
import numpy as np
from PIL import Image
import torch 
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np 
from time import sleep
from subprocess import call
import pandas as pd

# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K

# test natten import:
import natten
import accelerate

from sample import zero_wrapper


from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect


# Root of the project checkout; the sampler script and config paths
# (e.g. "{CT_HOME}/sample.py", "{CT_HOME}/configs/...") are built from it.
CT_HOME = "."

def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Works whether the image is normalized on 0..1 or -1..1, but not 0..255
    (the 'white' and 'blue' tests compare channel values against exactly 1).

    Args:
        img: a PIL image / ndarray (converted with ``ToTensor``) or a tensor
            shaped (C, H, W) or (H, W). Single-channel input is treated as
            grey levels (replicated to three channels).
        mask_with: how the mask region is marked:
            'white' -> pixels whose first three channels are all == 1,
            'blue'  -> pixels whose third (blue) channel == 1,
            'grey'  -> pixels with R == G == B and R != 0.

    Returns:
        A float tensor of shape (H, W), on the same device as ``img``:
        1.0 inside the mask, 0.0 elsewhere.

    Raises:
        ValueError: unknown ``mask_with``, or an image whose rank/channel
            count cannot be interpreted.
    """
    # Real validation instead of `assert`, which disappears under `python -O`.
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("    img.shape: ", img.shape)

    # Normalize to (C, H, W) with at least three channels so the per-channel
    # tests below never index a missing channel.
    if img.dim() == 2:
        img = img.unsqueeze(0)
    if img.dim() != 3:
        raise ValueError(f"expected a (C, H, W) or (H, W) image, got shape {tuple(img.shape)}")
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)
    elif img.shape[0] < 3:
        raise ValueError(f"expected 1 or >= 3 channels, got {img.shape[0]}")

    # Shape of mask is the image shape without the channel dimension; keep it on
    # the image's device so boolean indexing works for GPU tensors too.
    mask = torch.zeros(img.shape[-2:], device=img.device)
    print("    mask.shape: ", mask.shape)

    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    else:  # 'grey'
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask


def count_notes_in_mask(img, mask):
    """Count the note pixels of ``img`` that fall inside ``mask``.

    This counts *pixels* (weighted by ``mask``) whose note channel is > 0, not
    distinct notes; it serves as a relative "how busy is the masked region" score.

    Args:
        img: the generated image, either a PIL image / ndarray (converted with
            ``ToTensor``) or a tensor shaped (C, H, W) or (H, W). Notes are read
            from the green channel, or from channel 0 for single-channel input.
        mask: an (H, W) tensor, 1 inside the region of interest and 0 outside,
            e.g. the output of ``infer_mask_from_init_img``.

    Returns:
        The masked count as a Python number (a float for a float mask).
    """
    # Accept tensors as well, consistent with infer_mask_from_init_img;
    # ToTensor itself rejects tensor input.
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    if img_t.dim() == 2:
        img_t = img_t.unsqueeze(0)
    notes_channel = img_t[1] if img_t.shape[0] > 1 else img_t[0]  # green channel
    new_notes = (mask * (notes_channel > 0)).sum()
    return new_notes.item()


def grab_dense_gen(init_img,
                PREFIX,
                num_to_gen=64,
                busyness=100, # after ranking images by how many notes were in mask, which one should we grab?
                ):
    """Return the filename of the generated image chosen by note density in the mask.

    Reads the ``num_to_gen`` images ``f'{PREFIX}_{i:05d}.png'`` and scores each
    with ``count_notes_in_mask`` over the grey-marked mask of ``init_img``. It
    then sorts the scores ascending and picks the one at the ``busyness``
    percentile (0 -> sparsest, 100 -> densest). Ties keep generation order.

    Args:
        init_img: the init image whose grey-marked region defines the mask.
        PREFIX: filename prefix the generated images were saved with.
        num_to_gen: how many generated images to consider (must be >= 1).
        busyness: percentile in [0, 100]; values outside are clamped, and
            floats (e.g. from a UI slider) are accepted.

    Returns:
        The chosen image's filename.

    Raises:
        ValueError: if ``num_to_gen`` < 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be >= 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    scored = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # Close each file promptly; otherwise every call leaks num_to_gen handles.
        with Image.open(filename) as gen_img:
            scored.append((filename, count_notes_in_mask(gen_img, mask)))

    # Python's sort is stable, so equal scores stay in generation order.
    scored.sort(key=lambda row: row[1])
    busyness = min(max(busyness, 0), 100)
    grab_index = int((len(scored) - 1) * busyness // 100)
    print("grab_index = ", grab_index)
    dense_filename = scored[grab_index][0]
    print("Grabbing filename = ", dense_filename)
    return dense_filename

# dummy class to make an args-like object
def _format_attrs(obj):
    """Render an object's instance attributes as 'name=value' pairs joined by ', '."""
    return ", ".join(f"{name}={val}" for name, val in vars(obj).items())


class Args:
    """Minimal stand-in for an argparse namespace.

    Every keyword argument passed to the constructor becomes an instance
    attribute, in the order given.
    """

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def __repr__(self):
        return f'Args({_format_attrs(self)})'

    def __str__(self):
        return f'Args with attributes: {_format_attrs(self)}'



def _editor_to_square_image(editor_value):
    """Turn a gr.ImageEditor value into the square RGB init image used by the sampler."""
    image = editor_value.get('composite') if isinstance(editor_value, dict) else editor_value
    if image is None:
        raise gr.Error("No input image: please upload or draw a piano roll first.")
    # The editor may hand back a numpy array instead of a PIL image.
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    # Only the top-left 512x128 region is used as the piano-roll canvas.
    image = image.convert("RGB").crop((0, 0, 512, 128))
    return rect_to_square(image)


def _run_sampler(masked_img_file, repaint, num, prefix, ckpt, config, use_subprocess=True):
    """Run the diffusion sampler on ``masked_img_file`` and generate ``num`` images.

    The outputs are later read back as f'{prefix}_{i:05d}.png' by grab_dense_gen.
    Raises gr.Error if the sampler subprocess exits with a non-zero code.
    """
    print("Reading init image from ", masked_img_file, ", repaint = ", repaint)
    # NOTE(review): an earlier comment said "HF ZeroGPU+Gradio doesn't seem to work
    # with subprocesses", yet the subprocess path is the enabled one -- confirm.
    if use_subprocess:
        # Build argv directly (no string split) so paths containing spaces survive.
        argv = [sys.executable, f'{CT_HOME}/sample.py',
                '--batch-size', str(num),
                '--checkpoint', ckpt,
                '--config', config,
                '-n', str(num),
                '--prefix', prefix,
                '--init-image', masked_img_file,
                '--steps=100',
                f'--repaint={repaint}']
        print("Will run command: ", ' '.join(argv))
        print("Calling subprocess with args = ", argv, "\n")
        return_value = call(argv)
        print("Subprocess finished.  Return value = ", return_value)
        if return_value != 0:
            # Without this check we would silently pick up stale images from an earlier run.
            raise gr.Error(f"Sampling failed (sample.py exited with code {return_value}).")
    else:
        accelerator = accelerate.Accelerator()
        device = accelerator.device
        print("Accelerator device = ", device)
        args = Args(batch_size=num, checkpoint=ckpt, config=config, n=num, prefix=prefix,
                    init_image=masked_img_file, steps=100, seed_scale=0.0, repaint=repaint)
        print(" Now calling zero_wrapper with args = ", args, "\n")
        zero_wrapper(args, accelerator, device)


def _midi_player_iframe(midi_file):
    """Return an <iframe> whose srcdoc holds a MIDIPlayer widget for ``midi_file``."""
    import html as html_lib
    player_doc = MIDIPlayer(midi_file, 300, styler=dark).html
    # Properly escape the whole document for the attribute; the browser unescapes it.
    # (Swapping " for ' would alter the embedded document itself.)
    srcdoc = html_lib.escape(player_doc, quote=True)
    return f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''


def _render_audio(midi_file, audio_file='gradio_demo_out.wav'):
    """Render ``midi_file`` with timidity; return the audio path, or None if rendering failed."""
    # timidity's -Ow writes RIFF WAVE data, so the output carries a .wav extension.
    argv = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(argv))
    try:
        return_value = call(argv)
    except OSError as err:  # e.g. timidity not installed: audio is optional
        print("Could not run timidity: ", err)
        return None
    print("Return value = ", return_value)
    return audio_file if return_value == 0 else None


@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio handler: inpaint the marked region of a piano roll and return the result.

    Args:
        image: the gr.ImageEditor value; its 'composite' entry is the drawn piano roll.
        repaint: forwarded to the sampler as ``--repaint``.
        busyness: percentile (0..100) used to pick among the generated candidates
            by how many notes they placed in the masked region.

    Returns:
        (generated piano-roll PIL image, HTML snippet with a MIDI player,
         path of the rendered audio file or None if rendering failed)
    """
    init_image = _editor_to_square_image(image)
    masked_img_file = 'gradio_masked_image.png'  # TODO: fixed name; concurrent requests could clobber it
    print("Saving masked image file to ", masked_img_file)
    init_image.save(masked_img_file)

    num = 64  # number of images to generate; one is picked by note density in the masked region
    ckpt = 'ckpt/256_chords_00130000.pth'
    config = f'{CT_HOME}/configs/config_pop909_256x256_chords.json'
    prefix = 'gradiodemo'
    _run_sampler(masked_img_file, repaint, num, prefix, ckpt, config)

    # Honor the busyness slider (previously ignored); sliders may deliver floats.
    gen_file = grab_dense_gen(init_image, prefix, num_to_gen=num, busyness=int(busyness))
    with Image.open(gen_file) as gen_src:
        gen_image = square_to_rect(gen_src.copy())  # copy() detaches pixels from the file
    midi_file = img_file_2_midi_file(gen_file)
    player_html = _midi_player_iframe(midi_file)
    audio_file = _render_audio(midi_file)
    return gen_image, player_html, audio_file



def make_dict(x):
    """Build a gr.ImageEditor value that uses ``x`` as background, composite and sole layer."""
    return {'background': x, 'composite': x, 'layers': [x]}



# Example table: (repaint level, input images). Each image yields one example
# row [editor value, repaint, busyness=100], in the order listed here.
_EXAMPLE_IMAGES_BY_REPAINT = [
    (1, ['all_white.png', 'all_black.png', 'init_img_melody.png', 'init_img_accomp.png', 'init_img_cont.png']),
    (2, ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png', 'loop_middle_2.png']),
    (3, ['584_TOTAL_crop_draw.png', 'loop_middle.png']),
    (6, ['ismir_mask_2.png']),
]

demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.ImageEditor(sources=["upload", 'clipboard'],
                       label="Input Piano Roll Image (White = Gen Notes Here)",
                       value=make_dict('all_black.png'),
                       brush=gr.Brush(colors=["#FFFFFF", "#000000"])),
        gr.Slider(minimum=1, maximum=10, step=1, value=2,
                  label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
        gr.Slider(minimum=1, maximum=100, step=1, value=100,
                  label="Busy-ness Percentile (Based on Notes Generated)"),
    ],
    outputs=[
        gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
        gr.HTML(label="MIDI Player"),
        gr.Audio(label="MIDI as Audio"),
    ],
    examples=[[make_dict(path), repaint, 100]
              for repaint, paths in _EXAMPLE_IMAGES_BY_REPAINT
              for path in paths],
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.queue().launch()