File size: 6,722 Bytes
caabda6
 
59a2caa
caabda6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500319a
 
 
59a2caa
 
 
0a08063
b1e308f
 
6dfde8b
d19a04f
6dfde8b
59a2caa
6dfde8b
 
 
 
3f93e88
6dfde8b
 
3f93e88
6dfde8b
 
 
 
 
3f93e88
6dfde8b
 
 
 
 
 
 
 
 
59a2caa
6dfde8b
 
 
 
 
 
 
59a2caa
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfde8b
3cf4680
 
 
 
 
 
 
0a08063
 
59a2caa
3cf4680
59a2caa
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46aa4b
d19a04f
3cf4680
 
 
c6d1aad
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288c3b4
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# imports from gradio_demo.py
import gradio as gr 
import spaces
import numpy as np
from PIL import Image
import torch 
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np 
from time import sleep
from subprocess import call
import pandas as pd

# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K

# test natten import:
import natten

# Build a one-element tensor and move it to CUDA at import time. In the code
# visible here it is only used for device printouts (below and at the top of
# process_image).
# NOTE(review): `.cuda()` raises on a host without CUDA support -- presumably
# this relies on the `spaces` runtime being present; confirm before running
# this module anywhere else.
zero = torch.Tensor([0]).cuda()
print("Zero Device = ",zero.device," <-- this probably says cpu") # <-- 'cpu' 🤔


from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect

# Presumably the directory that holds sample.py and configs/ (TODO confirm).
# NOTE(review): process_image assigns its own local CT_HOME, so no code
# visible in this file reads this module-level value.
CT_HOME = '.'

@spaces.GPU
def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Note: this works whether the image is normalized on 0..1 or -1..1, but
    not 0..255.

    Args:
        img: A tensor, or any object `ToTensor()` accepts (it is converted
            when `torch.is_tensor(img)` is false). After conversion it must
            be 3-dimensional with at least three planes along the first
            axis; planes 0, 1 and 2 are read as red, green and blue.
        mask_with: How the masked region is marked in `img`:
            'white' -- all three channels equal 1;
            'blue'  -- the third channel equals 1;
            'grey'  -- the first channel is nonzero and all three channels
                       are equal (so pure white also counts as grey).

    Returns:
        A float tensor shaped like `img` without its leading channel axis,
        holding 1 where a pixel is marked and 0 elsewhere. It is allocated
        on the same device as `img`.

    Raises:
        ValueError: If `mask_with` is not one of the supported markers, or
            if `img` is not a 3-D image with at least three channels.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(
            f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    print("\n in infer_mask_from_init_img: ")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("    img.shape: ", img.shape)
    # Every marker test below indexes channels 0..2, so anything that is not
    # channel-first with >= 3 channels is rejected up front rather than
    # failing later with an unbound `mask` or an out-of-range index.
    if img.dim() != 3 or img.shape[0] < 3:
        raise ValueError(
            "img must be a 3-D (channels, height, width) image with at least "
            f"3 channels, got shape {tuple(img.shape)}")
    red, green, blue = img[0, :, :], img[1, :, :], img[2, :, :]
    # shape of mask should be img shape without the channel dimension;
    # keep it on img's device so the boolean indexing below is valid there.
    mask = torch.zeros(img.shape[-2:], device=img.device)
    print("    mask.shape: ", mask.shape)
    if mask_with == 'white':
        marked = (red == 1) & (green == 1) & (blue == 1)
    elif mask_with == 'blue':
        marked = blue == 1
    else:  # 'grey'
        marked = (red != 0) & (red == green) & (blue == green)
    mask[marked] = 1
    return mask


@spaces.GPU
def count_notes_in_mask(img, mask):
    """Count the new-note pixels of `img` that fall inside `mask`.

    `img` is converted with `ToTensor()`; a pixel counts when its green
    channel (index 1) is greater than zero. Each such pixel is weighted by
    the corresponding `mask` value and the weights are summed.

    Returns:
        The sum as a plain Python number (via `.item()`).
    """
    green = ToTensor()(img)[1, :, :]  # green channel
    has_note = green > 0
    total = (mask * has_note).sum()
    return total.item()


@spaces.GPU
def grab_dense_gen(init_img,
                PREFIX,
                num_to_gen=64,
                busyness=100, # after ranking images by how many notes were in mask, which one should we grab?
                ):
    """Pick one generated image by how many notes it placed in the mask.

    The grey-marked mask is inferred from `init_img`, then each file
    `{PREFIX}_{num:05d}.png` for `num` in `range(num_to_gen)` is opened and
    scored with `count_notes_in_mask`. The files are sorted by ascending
    score and the one at the `busyness` percentile of that ranking is chosen.

    Args:
        init_img: The initial image with the masked region marked in grey.
        PREFIX: Filename prefix of the generated images.
        num_to_gen: How many generated images to examine (must be >= 1).
        busyness: Percentile of the ranking to grab; 0 is the fewest notes,
            100 the most. Values outside 0..100 are clamped, and
            non-integer values are accepted.

    Returns:
        The filename of the selected generated image.

    Raises:
        ValueError: If `num_to_gen` is less than 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be at least 1, got {num_to_gen}")
    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    # Collect plain rows and build the frame once; concatenating a new
    # DataFrame per image copies the whole frame on every iteration.
    rows = []
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # `with` closes the file handle as soon as the image has been scored.
        with Image.open(filename) as gen_img:
            rows.append((filename, count_notes_in_mask(gen_img, mask)))
    df = pd.DataFrame(rows, columns=['filename', 'new_notes'])

    # sort df by new_notes column,
    df = df.sort_values(by='new_notes', ascending=True)
    # Clamp the percentile and force an int so `iloc` always gets a valid
    # positional index, even for float or out-of-range slider values.
    percentile = min(max(float(busyness), 0.0), 100.0)
    grab_index = int((len(df) - 1) * percentile // 100)
    print("grab_index = ", grab_index)
    dense_filename = df.iloc[grab_index]['filename']
    print("Grabbing filename = ", dense_filename)
    return dense_filename


@spaces.GPU
def process_image(image, repaint, busyness):
    """Gradio callback: fill in the masked piano roll and render the result.

    Steps:
      1. Take the editor's composite image, convert it to RGB, crop it to
         512x128, fold it with `rect_to_square`, and save it as the init
         image.
      2. Run `sample.py` in a subprocess to generate 64 candidate images.
      3. Choose one candidate with `grab_dense_gen` at the requested
         `busyness` percentile.
      4. Convert that image to a MIDI file, build an embedded MIDI player,
         and render the MIDI to audio with `timidity`.

    Args:
        image: Image-editor value; only its 'composite' entry is read, which
            may be a numpy array or a PIL image.
        repaint: Forwarded to `sample.py` as `--repaint`.
        busyness: Percentile passed to `grab_dense_gen` to choose among the
            generated candidates.

    Returns:
        A tuple `(gen_image, html, audio_file)`: the chosen image unfolded
        with `square_to_rect`, an `<iframe>` HTML string wrapping the MIDI
        player, and the path of the rendered audio file (`None` if
        `timidity` exited with a nonzero status).

    Raises:
        ValueError: If there is no composite image to process.
        RuntimeError: If `sample.py` exits with a nonzero status.
    """
    print("Process Image: Zero Device = ",zero.device)
    # get image ready and execute sampler
    print("image = ",image)
    composite = image['composite'] if image is not None else None
    if composite is None:
        raise ValueError("No input image: upload or draw a piano roll image first.")
    image = composite
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square( image )
    masked_img_file = 'gradio_masked_image.png' # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)
    num = 64 # number of images to generate; we'll take the one with the most notes in the masked region
    bs = num
    CT_HOME = '.'
    CKPT = 'ckpt/256_chords_00130000.pth'
    PREFIX = 'gradiodemo'
    print("Reading init image from ", masked_img_file,", repaint = ",repaint)
    # Build the argument vector directly: splitting a formatted string on
    # spaces would break on any path that itself contains a space.
    args = [
        sys.executable, f'{CT_HOME}/sample.py',
        '--batch-size', str(bs),
        '--checkpoint', CKPT,
        '--config', f'{CT_HOME}/configs/config_pop909_256x256_chords.json',
        '-n', str(num),
        '--prefix', PREFIX,
        '--init-image', masked_img_file,
        '--steps=100',
        f'--repaint={repaint}',
    ]
    print("Will run command: ", ' '.join(args))
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        # Without this check a failed run would silently reuse whatever
        # PNGs an earlier run left behind (or fail later on a missing file).
        raise RuntimeError(f"sample.py exited with status {return_value}")

    # find gen'd image and convert to midi piano roll
    # The busyness slider value must be forwarded, otherwise the default
    # (100) is always used. int() keeps the percentile arithmetic integral
    # even if the UI hands over a float.
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=int(busyness))
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # The player HTML is embedded in a double-quoted attribute below.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # convert the midi to audio too
    # `-Ow` makes timidity write RIFF WAVE data, so name the file accordingly.
    audio_file = 'gradio_demo_out.wav'
    audio_args = ['timidity', str(midi_file), '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(audio_args))
    return_value = call(audio_args)
    print("Return value = ", return_value)
    if return_value != 0:
        # Still return the image and the MIDI player; just omit the audio
        # instead of pointing at a stale or missing file.
        audio_file = None

    return gen_image, html, audio_file



def make_dict(x):
    """Wrap `x` in the dict layout used for image-editor values.

    Returns a new dict with 'background' and 'composite' both set to `x`
    and 'layers' set to a fresh one-element list `[x]`.
    """
    return {'background': x, 'composite': x, 'layers': [x]}



# Input components, in the positional order of process_image(image, repaint, busyness).
_input_components = [
    gr.ImageEditor(
        sources=["upload", 'clipboard'],
        label="Input Piano Roll Image (White = Gen Notes Here)",
        value=make_dict('all_black.png'),
        brush=gr.Brush(colors=["#FFFFFF", "#000000"]),
    ),
    gr.Slider(minimum=1, maximum=10, step=1, value=2,
              label="RePaint (Larger = More Notes, But Crazier. Also Slower.)"),
    gr.Slider(minimum=1, maximum=100, step=1, value=100,
              label="Busy-ness Percentile (Based on Notes Generated)"),
]

# Output components, matching process_image's (gen_image, html, audio_file) return.
_output_components = [
    gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
    gr.HTML(label="MIDI Player"),
    gr.Audio(label="MIDI as Audio"),
]

# Example images grouped by the RePaint value they are shown with;
# every example uses a busy-ness percentile of 100.
_example_files_by_repaint = [
    (1, ['all_white.png', 'all_black.png', 'init_img_melody.png', 'init_img_accomp.png', 'init_img_cont.png']),
    (2, ['584_TOTAL_crop.png', '780_TOTAL_crop_bg.png', '780_TOTAL_crop_draw.png', 'loop_middle_2.png']),
    (3, ['584_TOTAL_crop_draw.png', 'loop_middle.png']),
    (6, ['ismir_mask_2.png']),
]
_example_rows = [
    [make_dict(example_file), repaint_value, 100]
    for repaint_value, example_files in _example_files_by_repaint
    for example_file in example_files
]

demo = gr.Interface(
    fn=process_image,
    inputs=_input_components,
    outputs=_output_components,
    examples=_example_rows,
)
demo.queue().launch()