File size: 6,524 Bytes
caabda6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a08063
b1e308f
 
6dfde8b
d19a04f
6dfde8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfde8b
3cf4680
 
 
 
 
 
 
0a08063
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b46aa4b
d19a04f
3cf4680
 
 
d19a04f
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288c3b4
 
 
3cf4680
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# imports from gradio_demo.py
import gradio as gr 
import numpy as np
from PIL import Image
import torch 
from torchvision.transforms import ToTensor, ToPILImage
import sys
import os
from midi_player import MIDIPlayer
from midi_player.stylers import basic, cifka_advanced, dark
import numpy as np 
from time import sleep
from subprocess import call
import pandas as pd

# imports from sample.py
import argparse
from pathlib import Path
import accelerate
import safetensors.torch as safetorch
#import torch
from tqdm import trange, tqdm
#from PIL import Image
from torchvision import transforms
import k_diffusion as K


from pom.pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square
from pom.square_to_rect import square_to_rect

# Directory that holds sample.py and the configs/ folder; the sampler command
# line in process_image is built from paths of the form f'{CT_HOME}/...'.
# '.' means the current working directory.
CT_HOME = '.'

def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Args:
        img: Either a tensor or an object that ``ToTensor()`` can convert.
            After conversion it must be (C, H, W) with C == 1 or C >= 3, or a
            plain (H, W) tensor. Single-channel / 2-D input is treated as a
            grayscale image, i.e. as if all three colour channels were equal.
            Only the first three channels are read.
        mask_with: Which marking convention to detect:
            ``'white'`` - all three channels are exactly 1;
            ``'blue'``  - the third channel is exactly 1;
            ``'grey'``  - first channel is non-zero and all three channels
            are equal (this also matches white pixels).

    Returns:
        A float tensor of shape (H, W) holding 1.0 at masked pixels and 0.0
        elsewhere.

    Raises:
        ValueError: If ``mask_with`` is not one of the supported names or the
            image does not have a supported shape.

    Note:
        The comparisons against 1 mean pixel values are expected to have a
        maximum of 1 (not 0..255).
        NOTE(review): the original comment said -1..1 input also works; with
        the 'grey' rule a -1..1 black pixel (-1, -1, -1) would be selected,
        so confirm before relying on that.
    """
    if mask_with not in ('blue', 'white', 'grey'):
        raise ValueError(
            f"mask_with must be 'blue', 'white' or 'grey', got {mask_with!r}")
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("img.shape: ", img.shape)

    # Normalise the layout to (C, H, W) so the channel indexing below is valid.
    if img.dim() == 2:
        img = img.unsqueeze(0)
    if img.dim() != 3:
        raise ValueError(
            f"expected an (H, W) or (C, H, W) image, got shape {tuple(img.shape)}")
    if img.shape[0] == 1:
        # Grayscale: behave as if R == G == B.
        img = img.expand(3, -1, -1)
    elif img.shape[0] < 3:
        raise ValueError(
            f"expected 1 or at least 3 channels, got {img.shape[0]}")

    red, green, blue = img[0, :, :], img[1, :, :], img[2, :, :]

    # shape of mask is the img shape without the channel dimension
    mask = torch.zeros(img.shape[-2:])
    print("mask.shape: ", mask.shape)

    if mask_with == 'white':
        selected = (red == 1) & (green == 1) & (blue == 1)
    elif mask_with == 'blue':
        selected = blue == 1
    else:  # 'grey'
        selected = (red != 0) & (red == green) & (blue == green)
    mask[selected] = 1
    return mask * 1.0


def count_notes_in_mask(img, mask):
    """Count the number of new notes in the mask.

    A pixel counts as a note when its second (green) channel is > 0; each such
    pixel contributes the mask value at that position to the total.

    Args:
        img: A tensor indexed as [channel, row, col], or any object that
            ``ToTensor()`` can convert (same convention as
            ``infer_mask_from_init_img``).
        mask: Tensor that multiplies the per-pixel "green > 0" map; presumably
            the (H, W) 0/1 mask from ``infer_mask_from_init_img``.

    Returns:
        The summed result as a Python number (``Tensor.item()``).
    """
    # Accept tensors as-is; ToTensor() only handles PIL images / ndarrays.
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    new_notes = (mask * (img_t[1, :, :] > 0)).sum()  # green channel
    return new_notes.item()


def grab_dense_gen(init_img,
                   PREFIX,
                   num_to_gen=64,
                   busyness=100,
                   ):
    """Pick one generated image by how many notes it placed inside the mask.

    The mask is inferred from ``init_img`` (grey marking). Files named
    ``f'{PREFIX}_{num:05d}.png'`` for ``num`` in ``range(num_to_gen)`` are
    opened, their in-mask note counts are computed with
    ``count_notes_in_mask``, and the files are ranked from fewest to most
    notes.

    Args:
        init_img: Image whose grey-marked region defines the mask.
        PREFIX: Filename prefix of the generated images.
        num_to_gen: How many generated files to inspect (must be >= 1).
        busyness: Percentile (0..100) of the ranking to return; 100 selects
            the file with the most notes, 0 the one with the fewest. Values
            outside that range are clamped and non-integers are truncated.

    Returns:
        The filename of the selected generated image.

    Raises:
        ValueError: If ``num_to_gen`` is less than 1.
    """
    if num_to_gen < 1:
        raise ValueError(f"num_to_gen must be at least 1, got {num_to_gen}")

    mask = infer_mask_from_init_img(init_img, mask_with='grey')

    note_counts = []  # (new_notes, filename) per generated image
    for num in range(num_to_gen):
        filename = f'{PREFIX}_{num:05d}.png'
        # Context manager so the file handle is released after counting.
        with Image.open(filename) as gen_img:
            note_counts.append((count_notes_in_mask(gen_img, mask), filename))

    # Stable ascending sort: ties keep generation order, so the pick is
    # deterministic.
    note_counts.sort(key=lambda entry: entry[0])

    # Integer, in-range percentile so the computed index is always valid.
    percentile = min(max(int(busyness), 0), 100)
    grab_index = (len(note_counts) - 1) * percentile // 100
    print("grab_index = ", grab_index)
    dense_filename = note_counts[grab_index][1]
    print("Grabbing filename = ", dense_filename)
    return dense_filename



def process_image(image, repaint, busyness):
    """Gradio callback: inpaint the drawn piano roll and return image, player, audio.

    Steps:
      1. Take the editor's ``'composite'`` image, convert to RGB, crop to
         512x128, fold it with ``rect_to_square`` and save it as the init image.
      2. Run ``sample.py`` as a subprocess to generate a batch of images.
      3. Choose one generated image with ``grab_dense_gen`` according to
         ``busyness``.
      4. Convert that image to a MIDI file, wrap a MIDI player in an iframe,
         and render the MIDI to audio with ``timidity``.

    Args:
        image: Mapping with a ``'composite'`` entry holding either a numpy
            array or an object with PIL's ``convert``/``crop`` API.
        repaint: Value forwarded to the sampler as ``--repaint``.
        busyness: Percentile forwarded to ``grab_dense_gen`` to choose among
            the generated images.

    Returns:
        Tuple ``(gen_image, html, audio_file)``: the selected image passed
        through ``square_to_rect``, an HTML iframe string embedding the MIDI
        player, and the path of the rendered audio file.

    Raises:
        gr.Error: If the sampler subprocess exits with a non-zero status.
    """
    # get image ready and execute sampler
    print("image = ",image)
    image = image['composite']
    # if image is a numpy array convert to PIL
    if isinstance(image, np.ndarray):
        image = ToPILImage()(image)
    image = image.convert("RGB").crop((0, 0, 512, 128))
    image = rect_to_square( image )
    masked_img_file = 'gradio_masked_image.png' # TODO: could allow for clobber at scale
    print("Saving masked image file to ", masked_img_file)
    image.save(masked_img_file)

    num = 64 # number of images to generate; one is picked by notes in the masked region
    bs = num
    CKPT = 'ckpt/256_chords_00130000.pth'
    PREFIX = 'gradiodemo'
    print("Reading init image from ", masked_img_file,", repaint = ",repaint)

    # Build the argument vector directly (no shell, no string splitting).
    args = [
        'python', f'{CT_HOME}/sample.py',
        '--batch-size', str(bs),
        '--checkpoint', CKPT,
        '--config', f'{CT_HOME}/configs/config_pop909_256x256_chords.json',
        '-n', str(num),
        '--prefix', PREFIX,
        '--init-image', masked_img_file,
        '--steps=100',
        f'--repaint={repaint}',
    ]
    print("Will run command: ", ' '.join(args))
    print("Calling: ", args)
    return_value = call(args)
    print("Return value = ", return_value)
    if return_value != 0:
        # Without this, a failed run would silently reuse images left over
        # from an earlier request (or fail later on a missing file).
        raise gr.Error(f"sample.py failed with exit status {return_value}")

    # find gen'd image and convert to midi piano roll
    # Forward the slider value; int() because the ranking index must be integral.
    gen_file = grab_dense_gen(image, PREFIX, num_to_gen=num, busyness=int(busyness))
    gen_image = square_to_rect(Image.open(gen_file))
    midi_file = img_file_2_midi_file(gen_file)
    srcdoc = MIDIPlayer(midi_file, 300, styler=dark).html
    # The player markup is embedded in a double-quoted srcdoc attribute, so
    # its own double quotes are swapped for single quotes.
    srcdoc = srcdoc.replace("\"", "'")
    html = f'''<iframe srcdoc="{srcdoc}" height="500" width="100%" title="Iframe Example"></iframe>'''

    # convert the midi to audio too
    # NOTE(review): '-Ow' appears to request WAV output while the file is
    # named .mp3 - confirm the intended format.
    audio_file = 'gradio_demo_out.mp3'
    audio_args = ['timidity', midi_file, '-Ow', '-o', audio_file]
    print("Converting midi to audio with: ", ' '.join(audio_args))
    return_value = call(audio_args)
    print("Return value = ", return_value)

    return gen_image, html, audio_file

# def greet(name):
#     return "Hello " + name + "!!"

# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# demo.launch()


def make_dict(x):
    """Return a dict using *x* as 'background', 'composite' and the sole entry of 'layers'."""
    return dict(background=x, composite=x, layers=[x])



# Gradio UI: an image editor and two sliders feed process_image, whose three
# return values fill the image, HTML and audio outputs. Each example row is
# [editor value, RePaint value, busy-ness percentile].
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.ImageEditor(
            sources=["upload", 'clipboard'],
            label="Input Piano Roll Image (White = Gen Notes Here)",
            value=make_dict('all_black.png'),
            brush=gr.Brush(colors=["#FFFFFF", "#000000"]),
        ),
        gr.Slider(
            minimum=1, maximum=10, step=1, value=2,
            label="RePaint (Larger = More Notes, But Crazier. Also Slower.)",
        ),
        gr.Slider(
            minimum=1, maximum=100, step=1, value=100,
            label="Busy-ness Percentile (Based on Notes Generated)",
        ),
    ],
    outputs=[
        gr.Image(width=512, height=128, label='Generated Piano Roll Image'),
        gr.HTML(label="MIDI Player"),
        gr.Audio(label="MIDI as Audio"),
    ],
    examples=[
        [make_dict(example_file), repaint_value, 100]
        for repaint_value, example_files in (
            (1, ('all_white.png', 'all_black.png', 'init_img_melody.png',
                 'init_img_accomp.png', 'init_img_cont.png')),
            (2, ('584_TOTAL_crop.png', '780_TOTAL_crop_bg.png',
                 '780_TOTAL_crop_draw.png', 'loop_middle_2.png')),
            (3, ('584_TOTAL_crop_draw.png', 'loop_middle.png')),
            (6, ('ismir_mask_2.png',)),
        )
        for example_file in example_files
    ],
)
demo.queue().launch()