Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
# Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH) | |
# Modified by Scott H. Hawley for masking, ZeroGPU, etc.
"""Samples from k-diffusion models.""" | |
import argparse
import math
from pathlib import Path

import accelerate
import safetensors.torch as safetorch
import torch
from PIL import Image
from torchvision import transforms
from tqdm import trange, tqdm

import k_diffusion as K
from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
#CHORD_BORDER = 8 # chord border size in pixels | |
from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder | |
# ---- my mangled sampler that includes repaint | |
import torchsde | |
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    One ``torchsde.BrownianTree`` is built per seed. With a single (scalar) seed the
    wrapper is unbatched and ``__call__`` returns that tree's output directly; with a
    sequence of seeds (one per batch item of ``x``) the per-tree outputs are stacked
    along dim 0.

    Args:
        x (Tensor): reference tensor; ``torch.zeros_like(x)`` is the default ``w0``.
        t0, t1: interval endpoints, accepted in either order.
        seed (int or sequence of int, optional): entropy for the tree(s). A random
            scalar seed is drawn when omitted.
        **kwargs: forwarded to ``torchsde.BrownianTree``; ``w0`` is consumed here.

    Raises:
        ValueError: if a sequence of seeds does not have one entry per batch item.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # pop (not get): w0 is passed positionally below, so leaving it in kwargs
        # would hand it to torchsde.BrownianTree twice.
        w0 = kwargs.pop('w0', None)
        if w0 is None:
            w0 = torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        try:
            n_seeds = len(seed)
        except TypeError:
            # Scalar seed: a single tree, unbatched output.
            seed = [seed]
            self.batched = False
        else:
            if n_seeds != x.shape[0]:
                raise ValueError(f'got {n_seeds} seeds for a batch of {x.shape[0]}')
            # Each per-item tree works on one batch item's shape.
            w0 = w0[0]
            self.batched = True
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)`` where ``sign`` is 1 if ``a < b`` else -1."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Query every tree on the (sorted) interval and combine the results."""
        t0, t1, sign = self.sort(t0, t1)
        # Flip the sign when the query interval and/or the construction interval
        # were given in descending order.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
    """Noise sampler that draws its noise from a (possibly batched) Brownian tree.

    Args:
        x (Tensor): tensor whose shape, device and dtype the generated noise follows.
        sigma_min (float): low end of the valid interval.
        sigma_max (float): high end of the valid interval.
        seed (int or List[int]): random seed; a list gives one BrownianTree per
            batch item, each with its own seed.
        transform (callable): maps a sigma to the sampler's internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        lo = self.transform(torch.as_tensor(sigma_min))
        hi = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lo, hi, seed)

    def __call__(self, sigma, sigma_next):
        start = self.transform(torch.as_tensor(sigma))
        end = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(start, end)
        # Divide by sqrt(|end - start|) so the returned noise does not scale with
        # the length of the queried interval.
        return increment / (end - start).abs().sqrt()
def append_dims(x, target_dims):
    """Return ``x`` with trailing singleton axes added until it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        return x[(...,) + (None,) * missing]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
def to_d(x, sigma, denoised):
    """Return the Karras ODE derivative ``(x - denoised) / sigma`` for a denoiser output.

    ``sigma`` is broadcast against ``x`` by appending trailing singleton axes.
    """
    sigma_b = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022), with repaint.

    Every time step is run ``repaint`` times. After each pass except the last, ``x``
    is pushed back toward noise with ``x = sqrt(1 - beta) * x + sqrt(beta) * eps``,
    where ``beta = (sigmas[i + 1] / sigmas.max()) ** 2``, and the step is redone.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule (iterated as ``sigmas[i] -> sigmas[i + 1]``).
        extra_args: extra keyword arguments for ``model``.
        callback: optional; called after each model evaluation with a dict of
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: forwarded to ``trange`` to hide the progress bar.
        s_churn, s_tmin, s_tmax, s_noise: Karras stochastic-churn settings.
        repaint: number of passes per time step (1 disables resampling).

    Returns:
        The final sample.

    Raises:
        RuntimeError: if ``x`` becomes NaN.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    # Normalize the repaint variance by the largest sigma (as my_dpmpp_2m_sde does).
    # The schedule normally ends at sigma == 0, so the previous normalization by
    # sigmas[-1] divided by zero and turned x into inf/NaN whenever repaint > 1.
    sigma_max = sigmas.max()
    for i in trange(n_steps, disable=disable):
        for u in range(repaint):
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            denoised = model(x, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
            if x.isnan().any():
                # Explicit raise: an `assert` would be stripped under `python -O`.
                raise RuntimeError(f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}")
            if u < repaint - 1:
                # RePaint-style re-noising before redoing this time step.
                beta = (sigmas[i + 1] / sigma_max) ** 2
                x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)
    return x
def get_scalings(sigma, sigma_data=0.5):
    """Return the denoiser preconditioning factors ``(c_skip, c_out, c_in)``.

    With ``v = sigma**2 + sigma_data**2``: ``c_skip = sigma_data**2 / v``,
    ``c_out = sigma * sigma_data / sqrt(v)`` and ``c_in = 1 / sqrt(v)``.
    """
    total_var = sigma ** 2 + sigma_data ** 2
    total_std = total_var ** 0.5
    return (
        sigma_data ** 2 / total_var,
        sigma * sigma_data / total_std,
        1 / total_std,
    )
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
                    disable=None, eta=1., s_noise=1., noise_sampler=None,
                    solver_type='midpoint',
                    repaint=4):
    """DPM-Solver++(2M) SDE, with RePaint-style resampling added.

    Each time step ``sigmas[i] -> sigmas[i + 1]`` is run ``repaint`` times. After
    every pass except the last, ``x`` is re-noised (``repaint_choice`` below selects
    the formula; only ``'orig'`` is active) and the step is redone from ``sigmas[i]``.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; a next sigma of 0 triggers a plain denoising step.
        extra_args: extra keyword arguments for ``model``.
        callback: optional; called with a dict of ``x``, ``i``, ``sigma``,
            ``sigma_hat`` and ``denoised`` twice per pass: right after the model
            call and again after the update.
        disable: forwarded to ``trange`` to hide the progress bar.
        eta, s_noise: scale of the SDE noise term (``eta=0`` adds no noise).
        noise_sampler: callable ``(sigma, sigma_next) -> noise``; defaults to a
            ``BrownianTreeNoiseSampler`` over the schedule's positive range.
        solver_type: ``'midpoint'`` or ``'heun'`` second-order correction.
        repaint: number of passes per time step (1 disables resampling).

    Returns:
        The final sample.

    Raises:
        ValueError: for an unknown ``solver_type``.
        RuntimeError: if ``x`` becomes NaN.
    """
    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    old_denoised = None
    h_last = None
    # Pre-bind h: if the very first step already goes to sigma == 0 the solver branch
    # never assigns it, and the history update below would hit an unbound name.
    h = None
    old_x = None
    for i in trange(len(sigmas) - 1, disable=disable):  # time loop
        for u in range(repaint):  # repaint passes over the same time step
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h
                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
                if old_denoised is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
                if eta:
                    # NOTE(review): every repaint pass queries the noise sampler on the
                    # same (sigmas[i], sigmas[i + 1]) interval; if the sampler is
                    # deterministic per interval (as a Brownian tree presumably is),
                    # all passes reuse identical noise here -- confirm intended.
                    x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
            # NOTE(review): this second callback also runs on the final (sigma -> 0)
            # step; a callback that re-noises part of x at sigma=sigmas[i] would leave
            # that noise in the returned sample -- confirm intended.
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if x.isnan().any():
                # Explicit raise: an `assert` would be stripped under `python -O`.
                raise RuntimeError(f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}")
            if u < repaint - 1:
                # RePaint: go "back" in integration via the "forward" process, by adding
                # a little noise to x -- but scaled properly. Several conversions from
                # the original (beta-based) RePaint to k-diffusion sigmas were tried;
                # only the selected repaint_choice runs.
                repaint_choice = 'orig'  # ['orig','var1','var2', etc...]
                sigma_diff = (sigmas[i] - sigmas[i + 1]).abs()
                sigma_ratio = (sigmas[i + 1] / sigma_max)  # use i+1 or i?
                if repaint_choice == 'orig':  # attempt at original RePaint algorithm, which used betas
                    # sigma_ratio <= 1 (sigma_max is the schedule maximum), so beta is a
                    # variance fraction in [0, 1], as RePaint's betas are.
                    beta = sigma_ratio ** 2
                    x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)  # this is from RePaint Paper
                elif repaint_choice == 'var1':  # worse than orig
                    x = x + sigma_diff * torch.randn_like(x)
                elif repaint_choice == 'var2':  # yields NaNs
                    x = (1 - sigma_diff) * x + sigma_diff * torch.randn_like(x)
                elif repaint_choice == 'var3':  # results similar to var1
                    x = (1.0 - sigma_ratio) * x + sigmas[i + 1] * torch.randn_like(x)
                elif repaint_choice == 'var4':  # yields NaNs; experimental
                    # Invert this: target = (input - c_skip * noised_input) / c_out, where target = model_output
                    x_tm1, x_t = x, old_x
                    # x_tm1 = ( x_0 - c_skip * noised_x0 ) / c_out
                    # So x_tm1*c_out = x_0 - c_skip * noised_x0
                    x_in, noise = x_tm1, torch.randn_like(x)
                    noised_input = x_in + noise * append_dims(sigma_diff, x_in.ndim)
                    c_skip, c_out, c_in = [append_dims(c, x_in.ndim) for c in get_scalings(sigmas[i])]
                    model_output = x_tm1
                    renoised_x = c_out * model_output + c_skip * noised_input
                    x = renoised_x
                elif repaint_choice == 'var5':
                    x = torch.sqrt((1 - (sigma_diff / sigma_max) ** 2)) * x + sigma_diff * torch.randn_like(x)
                # include this? guessing no.
                #old_denoised = denoised
                #h_last = h
        # The second-order history advances once per time step, after all repaint passes.
        old_denoised = denoised
        h_last = h
        old_x = x
    return x
# -----from stable-audio-tools | |
# Define the noise schedule and sampling loop | |
def get_alphas_sigmas(t):
    """Return ``(alpha, sigma)`` for timestep(s) ``t``.

    ``alpha = cos(t * pi / 2)`` scales the clean image and ``sigma = sin(t * pi / 2)``
    scales the noise. ``t`` may be a tensor or a plain Python number.
    """
    t = torch.as_tensor(t)  # no-op for tensors; lets callers pass plain floats
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
def alpha_sigma_to_t(alpha, sigma):
    """Return the timestep for the given clean-image (alpha) and noise (sigma) scalings.

    Inverse of ``t_to_alpha_sigma``: ``t = atan2(sigma, alpha) * 2 / pi``.
    Accepts tensors or plain Python numbers.
    """
    return torch.atan2(torch.as_tensor(sigma), torch.as_tensor(alpha)) / math.pi * 2
def t_to_alpha_sigma(t):
    """Return ``(alpha, sigma)`` = ``(cos(t * pi / 2), sin(t * pi / 2))`` for timestep(s) ``t``.

    ``alpha`` scales the clean image and ``sigma`` the noise. ``t`` may be a tensor
    or a plain Python number.
    """
    t = torch.as_tensor(t)  # no-op for tensors; lets callers pass plain floats
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
def sample(model, x, steps, eta, **extra_args):
    """Draw samples from a v-prediction model, starting from the noise ``x``.

    The model is queried at ``steps`` timesteps going from 1 toward 0. At each one
    the predicted clean image ``pred`` and noise ``eps`` are recombined for the next
    timestep; ``eta`` controls how much fresh noise is injected (0 = none).
    Returns the clean-image prediction from the last timestep.
    """
    batch_ones = x.new_ones([x.shape[0]])
    # Timesteps 1 -> 0 in `steps` increments; the trailing 0 itself is dropped.
    timesteps = torch.linspace(1, 0, steps + 1)[:-1]
    alphas, sigmas = get_alphas_sigmas(timesteps)
    last = steps - 1
    for i in trange(steps):
        with torch.cuda.amp.autocast():
            v = model(x, batch_ones * timesteps[i], **extra_args).float()
        a_cur, s_cur = alphas[i], sigmas[i]
        pred = x * a_cur - v * s_cur
        eps = x * s_cur + v * a_cur
        if i >= last:
            # Final timestep: nothing left to re-noise; `pred` is returned.
            continue
        a_next, s_next = alphas[i + 1], sigmas[i + 1]
        # With eta > 0, part of the next noise level comes from fresh noise, so the
        # share carried by the predicted noise shrinks accordingly.
        ddim_sigma = eta * (s_next**2 / s_cur**2).sqrt() * (1 - a_cur**2 / a_next**2).sqrt()
        adjusted_sigma = (s_next**2 - ddim_sigma**2).sqrt()
        x = pred * a_next + eps * adjusted_sigma
        if eta:
            x += torch.randn_like(x) * ddim_sigma
    return pred
def get_bmask(i, steps, mask):
    """Binarize a soft inpainting mask for sampling step ``i`` out of ``steps``.

    Soft-mask inpainting is implemented as a shrinking hard (binary) mask: the
    threshold is ``(i + 1) / steps`` and the result is 1 where ``mask <= threshold``
    and 0 elsewhere, so the 1-region grows as ``i`` increases.
    """
    threshold = (i + 1) / steps
    return torch.where(mask <= threshold, 1, 0)
def make_cond_model_fn(model, cond_fn):
    """Wrap a denoiser so its output is shifted by a conditioning gradient.

    The returned function runs ``model`` on a detached, grad-enabled copy of ``x``,
    asks ``cond_fn(x, sigma, denoised=..., **kwargs)`` for a gradient, and returns
    ``denoised + grad * sigma**2`` (all detached).
    """
    def cond_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x_req = x.detach().requires_grad_()
            denoised = model(x_req, sigma, **kwargs)
            grad = cond_fn(x_req, sigma, denoised=denoised, **kwargs).detach()
            scale = K.utils.append_dims(sigma**2, x_req.ndim)
            result = denoised.detach() + grad * scale
        return result
    return cond_model_fn
def sample_k(
        model_fn,
        noise,
        init_data=None,
        mask=None,
        steps=100,
        sampler_type="dpmpp-2m-sde",
        sigma_min=0.5,
        sigma_max=50,
        rho=1.0, device="cuda",
        callback=None,
        cond_fn=None,
        model_config=None,
        repaint=1,
        **extra_args
):
    """Sample with a k-diffusion sampler (https://github.com/crowsonkb/k-diffusion).

    Modes, selected by which inputs are given:
      * sampling: ``init_data`` and ``mask`` both None -- start from pure noise;
      * variation: ``init_data`` only -- start from ``init_data`` plus noise;
      * inpainting: ``init_data`` and ``mask`` -- an extra callback re-imposes the
        noised ``init_data`` wherever the step's binary mask (``get_bmask``) is 1.
    ``init_data`` is the init input in the model's data space (latents, if this is
    latent diffusion).

    Args:
        model_fn: network wrapped in ``K.Denoiser``.
        noise: unit-variance noise; it is scaled by the first sigma here.
        steps: number of sigmas in the Karras schedule.
        sampler_type: one of the names handled below.
        sigma_min, sigma_max: schedule endpoints.
        rho: NOTE: currently ignored -- the schedule always uses ``rho=7``.
        device: device for the sigma schedule.
        callback: optional user callback, run after the inpainting callback.
        cond_fn: optional conditioning-gradient function (``make_cond_model_fn``).
        model_config: dict providing ``'sigma_data'``; required.
        repaint: passes per time step for the ``my-*`` samplers.
        **extra_args: forwarded to the model.

    Returns:
        The sampler's final output.

    Raises:
        ValueError: if ``model_config`` is missing or ``sampler_type`` is unknown.
    """
    if model_config is None:
        raise ValueError("sample_k requires model_config (it reads model_config['sigma_data'])")
    #denoiser = K.external.VDenoiser(model_fn)
    denoiser = K.Denoiser(model_fn, sigma_data=model_config['sigma_data'])
    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)
    # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
    sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
    print("sigmas[0] = ", sigmas[0])
    # Scale the initial noise by sigma
    noise = noise * sigmas[0]
    wrapped_callback = callback
    if mask is None and init_data is not None:
        # VARIATION (no inpainting): start from init_data noised with the initial sigma
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # INPAINTING
        bmask = get_bmask(0, steps, mask)
        input_noised = init_data + noise
        # Initial latent: noised init_data where step 0's binary mask is 1, noise elsewhere
        x = input_noised * bmask + noise * (1 - bmask)

        # k-diffusion samplers call the callback right after `denoised = model(x, ...)`
        # with a dict containing 'x', 'i', 'sigma', 'sigma_hat' and 'denoised'.
        def inpainting_callback(args):
            """Overwrite the sampler's x in place with noised init_data where bmask is 1."""
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # noise the init_data with this step's amount of noise
            input_noised = init_data + torch.randn_like(init_data) * sigma
            bmask = get_bmask(i, steps, mask)  # shrinking hard mask
            new_x = input_noised * bmask + x * (1 - bmask)
            # In-place update: the sampler keeps using this same tensor object.
            x.copy_(new_x)

        # wrap together the inpainting callback and the user-submitted callback.
        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # SAMPLING: start from pure noise
        x = noise
    print("sample_k: x.min, x.max = ", x.min(), x.max())
    print("sample_k: key, val.dtype = ", [(key, val.dtype if val is not None else val) for key, val in extra_args.items()])
    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "my-dpmpp-2m-sde":
            return my_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "my-sample-euler":
            return my_sample_euler(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
        else:
            # Previously an unknown name fell through and silently returned None.
            raise ValueError(f"unknown sampler_type: {sampler_type!r}")
## ---- end stable-audio-tools | |
#@spaces.GPU
def infer_mask_from_init_img(img, mask_with='white'):
    """Extract the inpainting mask that is painted into ``img``.

    ``mask_with='white'`` marks pixels whose first three channels all equal 1;
    ``mask_with='blue'`` marks pixels whose channel 2 equals 1. Works whether the
    image is normalized to 0..1 or -1..1, but not 0..255.

    Args:
        img: (C, H, W) tensor, or an image accepted by ``transforms.ToTensor``.
        mask_with: ``'white'`` or ``'blue'``.

    Returns:
        Float tensor of shape (H, W): 1.0 inside the mask, 0.0 elsewhere.

    Raises:
        ValueError: for any other ``mask_with``.
    """
    print("Inferring mask from init_img")
    if mask_with not in ('blue', 'white'):
        raise ValueError(f"mask_with must be 'blue' or 'white', got {mask_with!r}")
    if not torch.is_tensor(img):
        # `ToTensor` alone is not imported in this file; use it via `transforms`.
        img = transforms.ToTensor()(img)
    mask = torch.zeros(img.shape[-2:])
    if mask_with == 'white':
        mask[(img[0] == 1) & (img[1] == 1) & (img[2] == 1)] = 1
    else:  # 'blue'
        mask[img[2] == 1] = 1
    return mask * 1.0
def grow_mask(init_mask, grow_by=2):
    """Return a copy of a 2-D mask dilated ``grow_by`` times (``grow_by=0`` is a plain copy).

    Each pass sets an interior pixel to 1 when it or any of its four direct
    neighbours is positive. Pixels on the outer border of the array are never
    modified, and the input tensor is left untouched.
    """
    grown = init_mask.clone()
    for _ in range(grow_by):
        center = grown[1:-1, 1:-1]
        above = grown[0:-2, 1:-1]
        below = grown[2:, 1:-1]
        left = grown[1:-1, 0:-2]
        right = grown[1:-1, 2:]
        grown[1:-1, 1:-1] = (center + above + below + left + right) > 0
    return grown
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
    """Add random "seed" noise inside the (optionally grown) mask of ``init_image``.

    Inside the mask, pixels are first set to the image minimum and then get
    ``seed_scale * randn`` added; channel 2 is then set to -1.0 there. Pixels
    outside the mask are unchanged. The input image is not modified.

    Args:
        init_image: (C, H, W) tensor with C >= 3, or an image accepted by
            ``transforms.ToTensor``.
        init_mask: (H, W) mask, 1 where seeding happens.
        grow_by: dilate the mask this many times first (see ``grow_mask``).
        seed_scale: standard deviation of the added noise.

    Returns:
        The seeded image tensor.
    """
    init_mask = grow_mask(init_mask, grow_by=grow_by)  # make the mask bigger
    if not torch.is_tensor(init_image):
        # `ToTensor` alone is not imported in this file; use it via `transforms`.
        init_image = transforms.ToTensor()(init_image)
    init_image = init_image.clone()
    # wherever mask is 1, first set init_image to its min value
    init_image[:, init_mask == 1] = init_image.min()
    # add noise where mask is 1
    init_image = init_image + seed_scale * torch.randn_like(init_image) * init_mask
    # wherever the mask is 1, set the blue channel to -1.0, otherwise leave it alone
    init_image[2, :, :] = init_image[2, :, :] * (1 - init_mask) - 1.0 * init_mask
    return init_image
def get_init_image_and_mask(args, device):
    """Load ``args.init_image`` and build the inpainting mask, both batched for sampling.

    The image is loaded as RGB and scaled from 0..1 to -1..1. The mask is 1 where
    content will be generated and 0 where the image is kept. ``inpaint_task`` is
    currently hard-coded to ``'infer'`` (mask painted white into the image). The
    row offsets below assume a 256-row image made of two 128-row halves.

    Side effects: writes ``init_mask_debug.png`` and ``init_image_debug.png`` to the
    current working directory.

    Args:
        args: namespace providing ``init_image``, ``seed_scale`` and ``batch_size``.
        device: torch device for the returned tensors.

    Returns:
        ``(init_image, init_mask)``, both of shape ``(args.batch_size, 3, H, W)``.
    """
    convert_tensor = transforms.ToTensor()
    init_image = Image.open(args.init_image).convert('RGB')
    init_image = convert_tensor(init_image)
    # normalize image from 0..1 to -1..1
    init_image = (2.0 * init_image) - 1.0
    init_mask = torch.ones(init_image.shape[-2:])  # ones are where stuff will change, zeros will stay the same

    inpaint_task = 'infer'  # infer mask from init_image
    assert inpaint_task in ['accomp', 'chords', 'melody', 'nucleation', 'notes', 'continue', 'infer']
    if inpaint_task in ['melody', 'accomp']:
        init_mask[0:70, :] = 0  # zero out a melody strip of image near top
        init_mask[128 + 0:128 + 70, :] = 0  # zero out a melody strip of image along bottom row
        if inpaint_task == 'melody':
            init_mask = 1 - init_mask
    elif inpaint_task in ['notes', 'chords']:
        # keep chords only
        init_mask[0:CHORD_BORDER, :] = 0  # top row of 256x256
        init_mask[128 - CHORD_BORDER:128 + CHORD_BORDER, :] = 0  # middle rows of 256x256
        init_mask[-CHORD_BORDER:, :] = 0  # bottom row of 256x256
        if inpaint_task == 'chords':
            init_mask = 1 - init_mask  # inverse: generate chords given notes
    elif inpaint_task == 'continue':
        init_mask[0:128, :] = 0  # remember it's a square, so just mask out the bottom half
    elif inpaint_task == 'nucleation':
        # set mask to wherever the blue channel is > 0.0 (upper half of the -1..1 range)
        init_mask = (init_image[2, :, :] > 0.0) * 1.0
        # zero out init mask in top, bottom and middle chord borders
        init_mask[0:CHORD_BORDER, :] = 0
        init_mask[-CHORD_BORDER:, :] = 0
        init_mask[128 - CHORD_BORDER:128 + CHORD_BORDER, :] = 0
        # remove all blue in init_image between the borders
        init_image[2, CHORD_BORDER:128 - CHORD_BORDER, :] = -1.0
        init_image[2, 128 + CHORD_BORDER:-CHORD_BORDER, :] = -1.0
        # grow the sides of the mask by one pixel (same 4-neighbour rule as grow_mask)
        init_mask = grow_mask(init_mask, grow_by=1)
    elif inpaint_task == 'infer':
        init_mask = infer_mask_from_init_img(init_image, mask_with='white')

    # Also black out init_image wherever init mask is 1
    init_image[:, init_mask == 1] = init_image.min()

    if args.seed_scale > 0:  # driving nucleation
        print("Seeding nucleation, seed_scale = ", args.seed_scale)
        init_image = add_seeding(init_image, init_mask, grow_by=0, seed_scale=args.seed_scale)

    # remove any blue in middle of init image
    print("init_image.shape = ", init_image.shape)
    init_image[2, CHORD_BORDER:128 - CHORD_BORDER, :] = -1.0
    init_image[2, 128 + CHORD_BORDER:-CHORD_BORDER, :] = -1.0

    # Debugging: output some images so we can see what's going on
    init_mask_t = init_mask.float() * 255  # convert mask to 0..255 for writing as image
    init_mask_img_numpy = init_mask_t.byte().cpu().numpy()
    init_mask_debug_img = Image.fromarray(init_mask_img_numpy)
    init_mask_debug_img.save("init_mask_debug.png")
    # Clamp before .byte(): seeding noise can push pixels outside -1..1, and the
    # float -> uint8 cast does not saturate out-of-range values.
    init_image_u8 = (init_image * 127.5 + 127.5).clamp(0, 255).byte()
    init_image_debug_img = Image.fromarray(init_image_u8.cpu().numpy().transpose(1, 2, 0))
    init_image_debug_img.save("init_image_debug.png")

    # reshape image and mask to be 4D tensors
    init_image = init_image.unsqueeze(0).repeat(args.batch_size, 1, 1, 1)
    init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size, 3, 1, 1).float()
    return init_image.to(device), init_mask.to(device)
# wrapper compatible with ZeroGPU, callable from outside
#@spaces.GPU
def zero_wrapper(args, accelerator, device):
    """Load the model described by ``args`` and save ``args.n`` sampled images.

    Images are written by the main process as ``f'{args.prefix}_{i:05}.png'``. When
    ``args.init_image`` is set, every batch is inpainted from that image (see
    ``get_init_image_and_mask``). Also sets the module-level globals ``init_image``
    and ``init_mask``. A KeyboardInterrupt during sampling stops quietly.

    Args:
        args: parsed CLI namespace (the options defined in ``main``).
        accelerator: ``accelerate.Accelerator`` used to distribute sampling.
        device: torch device to run on.
    """
    global init_image, init_mask
    config = K.config.load_config(args.config if args.config else args.checkpoint)
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    print('zero_wrapper: Using device:', device, flush=True)
    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    cse = None  # ChordSeqEncoder().eval().requires_grad_(False).to(device) # add chord embedding-maker to main model
    if cse is not None:
        inner_model.cse = cse
    try:
        inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
    except Exception:
        # The safetensors path failed; fall back to a torch checkpoint with the
        # weights under 'model'. (Exception rather than a bare except, so Ctrl-C
        # is not swallowed here.)
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(args.checkpoint, map_location='cpu')
        inner_model.load_state_dict(ckpt['model'])
    print('Parameters:', K.utils.n_params(inner_model))

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']
    torch.set_float32_matmul_precision('high')
    extra_args = {}

    init_image, init_mask = None, None
    if args.init_image is not None:
        init_image, init_mask = get_init_image_and_mask(args, device)
        init_image = init_image.to(device)
        init_mask = init_mask.to(device)

    def sample_fn(n, debug=True):
        """Sample one batch of ``n`` images (``debug`` is unused)."""
        x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
        print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
        if args.init_image is not None:
            # The init image and mask are rebuilt for every batch.
            init_data, mask = get_init_image_and_mask(args, device)
            init_data = args.seed_scale * x * mask + (1 - mask) * init_data  # extra nucleation?
            if cse is not None:
                chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
            else:
                chord_cond = None
        else:
            init_data, mask, chord_cond = None, None, None
        # chord conditioning doesn't work yet, so it is disabled:
        chord_cond = None
        print("chord_cond = ", chord_cond)
        if chord_cond is not None:
            extra_args['chord_cond'] = chord_cond

        noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device)
        # other options: "k-lms", "my-sample-euler", "dpmpp-2m-sde", "dpmpp-3m-sde", "k-dpmpp-2s-ancestral"
        sampler_type = "my-dpmpp-2m-sde"
        print("dtypes:", [t.dtype if t is not None else None for t in [noise, init_data, mask, chord_cond]])
        x_0 = sample_k(inner_model, noise, sampler_type=sampler_type,
                       init_data=init_data, mask=mask, steps=args.steps,
                       sigma_min=sigma_min, sigma_max=sigma_max, rho=7.,
                       device=device, model_config=model_config, repaint=args.repaint,
                       **extra_args)
        print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
        if x_0.isnan().any():
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise RuntimeError("x_0 has NaNs")
        # do gpu garbage collection before proceeding
        torch.cuda.empty_cache()
        return x_0

    def run():
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')
        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass
def main():
    """Command-line entry point: parse options, set up Accelerate, run ``zero_wrapper``.

    Model loading, optional init-image inpainting and image writing are all done by
    ``zero_wrapper``, so the CLI and the ZeroGPU entry point share one code path.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size')
    p.add_argument('--checkpoint', type=Path, required=True,
                   help='the checkpoint to use')
    p.add_argument('--config', type=Path,
                   help='the model config')
    p.add_argument('-n', type=int, default=64,
                   help='the number of images to sample')
    p.add_argument('--prefix', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--repaint', type=int, default=1,
                   help='number of (re)paint steps')
    p.add_argument('--steps', type=int, default=50,
                   help='the number of denoising steps')
    p.add_argument('--seed-scale', type=float, default=0.0, help='strength of nucleation seeding')
    p.add_argument('--init-image', type=Path, default=None, help='the initial image')
    # NOTE: --init-strength is accepted but not read anywhere in this file.
    p.add_argument('--init-strength', type=float, default=1., help='strength of init image')
    args = p.parse_args()
    print("args =", args, flush=True)

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    zero_wrapper(args, accelerator, device)
if __name__ == "__main__":
    # Only run the CLI when this file is executed directly, not when imported.
    main()