Spaces:
Running
on
Zero
Running
on
Zero
# imports from gradio_demo.py | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torchvision.transforms import ToTensor, ToPILImage | |
import sys | |
import os | |
from midi_player import MIDIPlayer | |
from midi_player.stylers import basic, cifka_advanced, dark | |
import numpy as np | |
from time import sleep | |
from subprocess import call | |
import pandas as pd | |
# imports from sample.py | |
import argparse | |
from pathlib import Path | |
import accelerate | |
import safetensors.torch as safetorch | |
#import torch | |
from tqdm import trange, tqdm | |
#from PIL import Image | |
from torchvision import transforms | |
import k_diffusion as K | |
from .pianoroll import regroup_lines, img_file_2_midi_file, square_to_rect, rect_to_square | |
from .square_to_rect import square_to_rect | |
def infer_mask_from_init_img(img, mask_with='grey'):
    """Given an image with mask areas marked, extract the mask itself.

    Note: the marker tests compare against 1 and 0, so the image must be
    normalized (0..1 or -1..1), not 0..255.
    NOTE(review): in 'grey' mode any pixel whose three channels are equal and
    whose red value is non-zero is selected. That includes pure white, and in
    a -1..1 image it includes pure black (-1) while excluding exact 0 --
    confirm this is the intended behavior for -1..1 inputs.

    Args:
        img: a torch tensor shaped (C, H, W) or (H, W), or anything
            ``ToTensor`` accepts. A 2-D or single-channel image is treated as
            having identical red, green and blue channels.
        mask_with: which marker color denotes the masked area; one of
            'blue', 'white' or 'grey'.
            - 'white': all of the first three channels equal 1
            - 'blue':  third channel equals 1
            - 'grey':  first three channels equal and first channel non-zero

    Returns:
        A float tensor of shape (H, W) holding 1.0 at masked pixels and 0.0
        elsewhere.

    Raises:
        AssertionError: if ``mask_with`` is not a supported marker name.
        ValueError: if the image is neither 2-D nor 3-D.
    """
    assert mask_with in ['blue','white','grey']
    if not torch.is_tensor(img):
        img = ToTensor()(img)
    print("img.shape: ", img.shape)

    # Normalize the layout to (C, H, W) with at least three channels so the
    # per-channel tests below work for grayscale input as well; previously a
    # 2-D or single-channel image crashed on the channel indexing.
    if img.dim() == 2:
        img = img.unsqueeze(0)
    elif img.dim() != 3:
        raise ValueError(
            f"expected a 2-D (H, W) or 3-D (C, H, W) image, got shape {tuple(img.shape)}"
        )
    if img.shape[0] == 1:
        img = img.expand(3, -1, -1)  # view only: grey means R == G == B

    # shape of mask should be img shape without the channel dimension
    mask = torch.zeros(img.shape[-2:])
    print("mask.shape: ", mask.shape)

    red, green, blue = img[0], img[1], img[2]
    if mask_with == 'white':
        mask[(red == 1) & (green == 1) & (blue == 1)] = 1
    elif mask_with == 'blue':
        mask[blue == 1] = 1
    elif mask_with == 'grey':
        mask[(red != 0) & (red == green) & (blue == green)] = 1
    return mask*1.0
def count_notes_in_mask(img, mask):
    """Count the new-note pixels that fall inside the mask.

    Sums the mask values over every pixel whose green channel (index 1) is
    greater than zero. The result is a pixel count, so a note drawn across
    several pixels contributes once per pixel.
    NOTE(review): presumably the green channel marks notes in the piano-roll
    encoding -- confirm against the pianoroll module.

    Args:
        img: image as a (C, H, W) torch tensor, or anything ``ToTensor``
            accepts. Tensors are used as-is, matching
            ``infer_mask_from_init_img``.
        mask: tensor broadcastable against an (H, W) plane; non-zero where
            pixels should be counted.

    Returns:
        The masked sum as a Python number.
    """
    # ToTensor rejects inputs that are already tensors, so only convert
    # when needed.
    img_t = img if torch.is_tensor(img) else ToTensor()(img)
    new_notes = (mask * (img_t[1,:,:] > 0)).sum()  # green channel
    return new_notes.item()
def greet(name):
    """Return a greeting for `name`: "Hello ", the name, then "!!"."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
# Minimal text-in / text-out Gradio interface wrapping `greet`.
demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
# Start the Gradio app; this runs as soon as the module is executed.
demo.launch()