#!/usr/bin/env python3

# Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
# Modified by Scott H. Hawley for masking, ZeroGPU, etc.

"""Samples from k-diffusion models."""


import argparse
import math
from pathlib import Path

import accelerate
import safetensors.torch as safetorch
import torch
from tqdm import trange, tqdm
from PIL import Image
from torchvision import transforms

import k_diffusion as K

from pom.v_diffusion import DDPM, LogSchedule, CrashSchedule
#CHORD_BORDER = 8   # chord border size in pixels
from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder

# ---- my mangled sampler that includes repaint 
import torchsde 

class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    Args:
        x (Tensor): template tensor; the default starting value ``w0`` is
            ``torch.zeros_like(x)`` (its first row when one tree per batch
            item is built).
        t0, t1: ends of the time interval, in either order (the orientation is
            remembered in ``self.sign``).
        seed (int or sequence of int, optional): one seed shared by the whole
            batch, or one seed per batch item (``len(seed)`` must equal
            ``x.shape[0]``). A random seed is drawn when ``None``.
        **kwargs: forwarded to ``torchsde.BrownianTree``. An optional ``w0``
            entry replaces the default starting value.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # pop (rather than get) so a caller-supplied w0 is not handed to
        # BrownianTree a second time through **kwargs, which would raise
        # "got multiple values for argument 'w0'".
        w0 = kwargs.pop('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            # len() of a plain int raises TypeError -> one shared tree.
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)`` where sign is 1 if ``a < b`` else -1."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Return the sign-corrected increment over [t0, t1], stacked across trees."""
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        # The unbatched case wrapped its single tree in a list; drop that axis.
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """Unit-variance noise source driven by a (batched) torchsde Brownian tree.

    Args:
        x (Tensor): tensor whose shape, device and dtype the generated noise uses.
        sigma_min (float): lower end of the interval the tree covers.
        sigma_max (float): upper end of the interval the tree covers.
        seed (int or List[int]): random seed; a list gives one independent
            BrownianTree per batch item, each with its own seed.
        transform (callable): maps a sigma to the tree's internal time axis.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        lo = self.transform(torch.as_tensor(sigma_min))
        hi = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, lo, hi, seed)

    def __call__(self, sigma, sigma_next):
        start = self.transform(torch.as_tensor(sigma))
        end = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(start, end)
        # Normalise the increment by sqrt(|dt|) so the returned noise has unit scale.
        return increment / (end - start).abs().sqrt()

def append_dims(x, target_dims):
    """Return ``x`` indexed with trailing singleton axes so it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        return x[(Ellipsis,) + (None,) * missing]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')


def to_d(x, sigma, denoised):
    """Return the Karras ODE derivative dx/dsigma = (x - denoised) / sigma.

    ``sigma`` is broadcast against ``x`` by appending trailing singleton dims.
    """
    residual = x - denoised
    return residual / append_dims(sigma, x.ndim)


@torch.no_grad()
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
    """Euler sampler (Algorithm 2 of Karras et al., 2022) with RePaint-style resampling.

    Each step sigma_i -> sigma_{i+1} is taken ``repaint`` times. Between passes
    (all but the last) the result is partially re-noised,
    x <- sqrt(1 - beta) * x + sqrt(beta) * z with beta = (sigma_{i+1} / sigma_max)**2,
    so the step can be resampled.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent (expected to be at noise level ``sigmas[0]``).
        sigmas: 1-D tensor of noise levels for the schedule.
        extra_args: extra keyword arguments for ``model``.
        callback: called with a dict (x, i, sigma, sigma_hat, denoised) after
            every model call.
        disable: passed to ``trange`` to disable the progress bar.
        s_churn, s_tmin, s_tmax, s_noise: Karras "churn" settings; extra noise
            is injected only when ``s_tmin <= sigma_i <= s_tmax`` and gamma > 0.
        repaint: passes per step (1 = plain Euler sampling).

    Returns:
        The final sample.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    # Normalise the RePaint re-noising by the largest sigma (as my_dpmpp_2m_sde
    # does). Dividing by sigmas[-1] -- normally the final sigma of 0 -- produced
    # inf/NaN as soon as repaint > 1.
    sigma_max = sigmas.max()
    for i in trange(n_steps, disable=disable):
        # gamma and sigma_hat depend only on the step, not on the repaint pass.
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        for u in range(repaint):
            eps = torch.randn_like(x) * s_noise
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            denoised = model(x, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
            if x.isnan().any():
                assert False, f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}"
            if u < repaint - 1:
                # RePaint: step back "up" the schedule by partially re-noising x.
                beta = (sigmas[i + 1] / sigma_max) ** 2
                x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)

    return x

def get_scalings(sigma, sigma_data=0.5):
    """Return the preconditioning coefficients ``(c_skip, c_out, c_in)`` for noise level ``sigma``.

    All three share the total variance ``sigma**2 + sigma_data**2``.
    """
    total_var = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total_var
    c_out = sigma * sigma_data / total_var ** 0.5
    c_in = 1 / total_var ** 0.5
    return c_skip, c_out, c_in


@torch.no_grad()
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, 
                    disable=None, eta=1., s_noise=1., noise_sampler=None, 
                    solver_type='midpoint',
                    repaint=4):
    """DPM-Solver++(2M) SDE sampler with experimental RePaint-style resampling.

    For every step i (sigma_i -> sigma_{i+1}) the solver update is run
    ``repaint`` times. Between passes (all but the last) x is partially
    re-noised so that the step can be resampled; see ``repaint_choice`` below
    for the variants that were tried.

    Args:
        model: denoiser called as ``model(x, sigma * s_in, **extra_args)``,
            returning the denoised estimate.
        x: initial latent; presumably already at noise level ``sigmas[0]``
            (sample_k scales its noise by ``sigmas[0]``).
        sigmas: 1-D tensor of noise levels; a step into sigma == 0 is a pure
            denoising step (x = denoised).
        extra_args: extra keyword arguments for ``model``.
        callback: called with a dict (x, i, sigma, sigma_hat, denoised) after
            each model call and again after each solver update.
        disable: passed to ``trange`` to disable the progress bar.
        eta: SDE noise scale (0 skips noise injection).
        s_noise: additional multiplier on the injected noise.
        noise_sampler: ``noise_sampler(sigma, sigma_next)`` -> noise; defaults
            to a BrownianTreeNoiseSampler over the schedule's nonzero range.
        solver_type: 'midpoint' or 'heun' second-order correction.
        repaint: solver passes per step (1 disables resampling).

    Returns:
        The final sample.

    Raises:
        ValueError: if ``solver_type`` is not 'heun' or 'midpoint'.
        AssertionError: if x becomes NaN (only while assertions are enabled).
    """

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None
    old_x = None
    # h is only assigned on non-final steps; initialise it so a schedule whose
    # first step already goes to sigma == 0 (e.g. steps=1) does not hit an
    # UnboundLocalError at `h_last = h` below.
    h = None

    for i in trange(len(sigmas) - 1, disable=disable):  # time loop

        for u in range(repaint):
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            if sigmas[i + 1] == 0:
                # Denoising step.
                # NOTE(review): with repaint > 1 this branch repeats, re-running the
                # denoiser on an already-denoised x -- confirm that is intended.
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

                if old_denoised is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

                if eta:
                    # NOTE(review): a Brownian-tree sampler presumably returns the same
                    # increment for the same (sigma, sigma_next) interval, so every
                    # repaint pass may reuse identical SDE noise -- confirm.
                    x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

                # NOTE(review): x is now at level sigmas[i + 1], but the callback is
                # told sigma = sigmas[i] (the inpainting callback noises init_data
                # with that sigma).
                if callback is not None:
                    callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})   

                if x.isnan().any():
                    assert False, f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}"

                if u < repaint - 1:
                    # RePaint: go "back" in integration via the "forward" process, by adding a little noise to x
                    #  ...but scaled properly!
                    # How to convert from original RePaint to k-diffusion is an open question;
                    # several variants are kept here for experimentation.
                    repaint_choice = 'orig' # ['orig','var1','var2', etc...]

                    sigma_diff = (sigmas[i] - sigmas[i+1]).abs()
                    sigma_ratio = ( sigmas[i+1] / sigma_max ) # use i+1 or i?
                    if repaint_choice == 'orig': # attempt at original RePaint algorithm, which used betas
                        # if sigmas are the std devs, then betas are variances?  but beta_max = 1, so how to get that? ratio?
                        beta = sigma_ratio**2
                        x = torch.sqrt(1-beta)*x +  torch.sqrt(beta)*torch.randn_like(x) # this is from RePaint Paper
                    elif repaint_choice == 'var1': # worse than orig
                        x = x + sigma_diff*torch.randn_like(x)
                    elif repaint_choice == 'var2':  # yields NaNs
                        x = (1-sigma_diff)*x + sigma_diff*torch.randn_like(x)
                    elif repaint_choice == 'var3':            # results similar to var1
                        x = (1.0-sigma_ratio)*x + sigmas[i+1]*torch.randn_like(x)
                    elif repaint_choice == 'var4':   # NaNs; adapted from elsewhere, experimental
                        #Invert this: target = (input - c_skip * noised_input) / c_out, where target = model_output
                        x_tm1, x_t = x, old_x 
                        #              x_tm1 = ( x_0  - c_skip * noised_x0 ) / c_out
                        #       So     x_tm1*c_out = x_0 - c_skip * noised_x0
                        input, noise = x_tm1, torch.randn_like(x)
                        noised_input = input + noise * append_dims(sigma_diff, input.ndim)
                        c_skip, c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigmas[i])]
                        model_output = x_tm1
                        renoised_x = c_out * model_output + c_skip * noised_input 
                        x = renoised_x
                    elif repaint_choice == 'var5':
                        x = torch.sqrt((1-(sigma_diff/sigma_max)**2))*x + sigma_diff*torch.randn_like(x)

                    # Open question: should the solver history be updated per pass too?
                    # Currently it is updated only once per time step (below).
                    #old_denoised = denoised
                    #h_last = h

        # Second-order history, updated once per time step.
        old_denoised = denoised
        h_last = h
        old_x = x
    return x




# -----from stable-audio-tools

# Define the noise schedule and sampling loop
def get_alphas_sigmas(t):
    """Return the scaling factors ``(alpha, sigma)`` for the clean image and the noise at timestep ``t``.

    Cosine schedule: alpha = cos(pi * t / 2), sigma = sin(pi * t / 2), so t = 0 is
    clean data and t = 1 is pure noise. Requires the module-level ``import math``.
    """
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def alpha_sigma_to_t(alpha, sigma):
    """Return the timestep ``t`` in [0, 1] for the given clean-image and noise scaling factors.

    Inverse of the cosine schedule: t = atan2(sigma, alpha) * 2 / pi.
    Requires the module-level ``import math``.
    """
    return torch.atan2(sigma, alpha) / math.pi * 2

def t_to_alpha_sigma(t):
    """Return the scaling factors ``(alpha, sigma)`` for the clean image and the noise at timestep ``t``.

    Same cosine schedule as get_alphas_sigmas. Requires the module-level ``import math``.
    """
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draw samples with the v-diffusion (DDIM-style) sampler, starting from noise ``x``.

    Args:
        model: callable ``model(x, t, **extra_args)`` returning the predicted
            velocity v.
        x: starting noise tensor; its first dimension is the batch.
        steps: number of sampling steps; must be >= 1.
        eta: stochasticity; 0 adds no fresh noise, > 0 adds noise each step.
        **extra_args: forwarded to ``model``.

    Returns:
        The predicted clean sample from the final step.

    Raises:
        ValueError: if ``steps < 1`` (there would be no prediction to return).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule: t from 1 (pure noise) down to, but excluding, 0
    t = torch.linspace(1, 0, steps + 1)[:-1]

    alphas, sigmas = get_alphas_sigmas(t)

    # The sampling loop
    for i in trange(steps):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * t[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # If we are not on the last timestep, compute the noisy image for the
        # next timestep.
        if i < steps - 1:
            # If eta > 0, adjust the scaling factor for the predicted noise
            # downward according to the amount of additional noise to add
            ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
                (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
            adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

            # Recombine the predicted noise and predicted denoised image in the
            # correct proportions for the next step
            x = pred * alphas[i + 1] + eps * adjusted_sigma

            # Add the correct amount of fresh noise
            if eta:
                x += torch.randn_like(x) * ddim_sigma

    # If we are on the last timestep, output the denoised image
    return pred

# Soft mask inpainting is just shrinking hard (binary) mask inpainting
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
def get_bmask(i, steps, mask):
    """Binarise a soft mask for step ``i`` of ``steps``.

    The threshold grows linearly as (i + 1) / steps, so the region where the
    result is 1 (``mask <= threshold``) expands step by step. Returns an
    integer tensor of 1s and 0s.
    """
    threshold = (i + 1) / steps
    return torch.where(mask <= threshold, 1, 0)

def make_cond_model_fn(model, cond_fn):
    """Wrap a denoiser so its output is shifted by a guidance gradient.

    The returned function computes ``denoised + cond_fn(...) * sigma**2``, where
    ``cond_fn(x, sigma, denoised=denoised, **kwargs)`` is evaluated with
    gradients enabled on a detached copy of ``x``.
    """
    def cond_model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x_req = x.detach().requires_grad_()
            denoised = model(x_req, sigma, **kwargs)
            grad = cond_fn(x_req, sigma, denoised=denoised, **kwargs).detach()
            scale = K.utils.append_dims(sigma**2, x_req.ndim)
            guided = denoised.detach() + grad * scale
        return guided
    return cond_model_fn

# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data 
# For inpainting, set both init_data & mask 
def sample_k(
        model_fn, 
        noise, 
        init_data=None,
        mask=None,
        steps=100, 
        sampler_type="dpmpp-2m-sde", 
        sigma_min=0.5, 
        sigma_max=50, 
        rho=1.0, device="cuda", 
        callback=None, 
        cond_fn=None,
        model_config=None,
        repaint=1,
        **extra_args
    ):
    """Run a k-diffusion sampler for plain sampling, variations, or inpainting.

    The mode depends on which of ``init_data`` / ``mask`` are given:
      * neither (or only ``mask``): sample starting from ``noise``;
      * ``init_data`` only: variation, starting from ``init_data + noise``;
      * both: inpainting. After each model call, wherever the step's binary
        mask (get_bmask) is 1, x is replaced by ``init_data`` noised to that
        step's sigma.

    Args:
        model_fn: raw network, wrapped here in ``K.Denoiser``.
        noise: initial noise with the sample's shape; multiplied by ``sigmas[0]``.
        init_data: optional starting data (variations / inpainting).
        mask: optional soft mask for inpainting.
        steps: number of sampling steps.
        sampler_type: one of "k-heun", "k-lms", "k-dpmpp-2s-ancestral",
            "k-dpm-2", "k-dpm-fast", "k-dpm-adaptive", "dpmpp-2m-sde",
            "my-dpmpp-2m-sde", "dpmpp-3m-sde", "my-sample-euler".
        sigma_min, sigma_max: range of the noise schedule.
        rho: currently unused -- the Karras schedule below always uses rho=7.
        device: device for the sigma schedule.
        callback: optional user callback, chained after the inpainting callback.
        cond_fn: optional guidance function (see make_cond_model_fn).
        model_config: model config dict; must provide ``'sigma_data'``.
        repaint: passes per step for the "my-*" samplers.
        **extra_args: forwarded to the model.

    Returns:
        The sampled tensor.

    Raises:
        ValueError: if ``sampler_type`` is not one of the names above.
    """

    #denoiser = K.external.VDenoiser(model_fn)
    denoiser = K.Denoiser(model_fn, sigma_data=model_config['sigma_data'])

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has.
    # NOTE: `rho` is not used; the schedule is always Karras with rho=7.
    #sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
    print("sigmas[0] = ", sigmas[0])
    # Scale the initial noise by sigma 
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # VARIATION (no inpainting)
        # set the initial latent to the init_data, and noise it with initial sigma
        x = init_data + noise 
    elif mask is not None and init_data is not None:
        # INPAINTING
        bmask = get_bmask(0, steps, mask)
        # initial noising
        input_noised = init_data + noise
        # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
        x = input_noised * bmask + noise * (1-bmask)
        # define the inpainting callback function (Note: side effects, it mutates x)
        # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
        # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # noise the init_data input with this step's appropriate amount of noise
            input_noised = init_data + torch.randn_like(init_data) * sigma
            # shrinking hard mask
            bmask = get_bmask(i, steps, mask)
            # mix input_noise with x, using binary mask
            new_x = input_noised * bmask + x * (1-bmask)
            # mutate x in place so the sampler sees the mixed latent
            x[:,:,:] = new_x[:,:,:]
        # wrap together the inpainting callback and the user-submitted callback. 
        if callback is None: 
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise

    print("sample_k: x.min, x.max = ", x.min(), x.max())
    print("sample_k: key, val.dtype = ",[ (key, val.dtype if val is not None else val) for key,val in extra_args.items()])
    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "my-dpmpp-2m-sde":
            return my_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "my-sample-euler":
            return my_sample_euler(denoiser, x, sigmas, disable=False, callback=wrapped_callback, repaint=repaint, extra_args=extra_args)
        else:
            # Previously fell through and returned None, so callers crashed later
            # with an unrelated AttributeError.
            raise ValueError(f"sample_k: unknown sampler_type {sampler_type!r}")


## ---- end stable-audio-tools
#@spaces.GPU
def infer_mask_from_init_img(img, mask_with='white'):
    """Extract the inpainting mask from an image whose mask region is painted in.

    A pixel is in the mask when it is exactly 1 in all three channels
    (``mask_with='white'``) or in the blue channel (``mask_with='blue'``).
    This works whether the image is normalised to 0..1 or -1..1, but not 0..255.

    Args:
        img: image tensor indexed as (channel, H, W), or a PIL image, which is
            converted with ``transforms.ToTensor``.
        mask_with: 'white' or 'blue'.

    Returns:
        A float (H, W) mask of 0s and 1s on the same device as ``img``.
    """
    print("Inferring mask from init_img")
    assert mask_with in ['blue','white']
    if not torch.is_tensor(img):
        # Only `transforms` is imported in this module; a bare ToTensor is undefined.
        img = transforms.ToTensor()(img)
    # Allocate on img's device so the boolean indexing below also works for GPU tensors.
    mask = torch.zeros(img.shape[-2:], device=img.device)
    if mask_with == 'white':
        mask[ (img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
    elif mask_with == 'blue':
        mask[img[2,:,:]==1] = 1  # blue
    return mask*1.0


def grow_mask(init_mask, grow_by=2):
    """Return a copy of a 2-D mask dilated ``grow_by`` times with a 4-neighbour cross.

    The outermost rows and columns are never modified. ``grow_by=0`` returns an
    unchanged copy; the input mask is not mutated.
    """
    grown = init_mask.clone()
    for _ in range(grow_by):
        centre = grown[1:-1, 1:-1]
        up = grown[:-2, 1:-1]
        down = grown[2:, 1:-1]
        left = grown[1:-1, :-2]
        right = grown[1:-1, 2:]
        # The sum is built in full before assignment, so every neighbour is read
        # from the previous iteration's mask.
        grown[1:-1, 1:-1] = (centre + up + down + left + right) > 0
    return grown


def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
    """Add random "seed" noise inside the mask to encourage new content to nucleate.

    Args:
        init_image: image tensor indexed as (channel, H, W) -- presumably on the
            -1..1 scale produced by get_init_image_and_mask -- or a PIL image,
            converted with ``transforms.ToTensor``.
        init_mask: (H, W) mask; 1 marks the region to seed.
        grow_by: dilate the mask by this many pixels first (via grow_mask).
        seed_scale: standard deviation of the added Gaussian noise.

    Returns:
        A new image tensor; the input image is not modified.
    """
    init_mask = grow_mask(init_mask, grow_by=grow_by)  # make the mask bigger
    if not torch.is_tensor(init_image):
        # Only `transforms` is imported in this module; a bare ToTensor is undefined.
        init_image = transforms.ToTensor()(init_image)
    init_image = init_image.clone()
    # inside the mask, first set init_image to its minimum value...
    init_image[:,init_mask == 1] = init_image.min()
    # ...then add Gaussian noise there
    init_image = init_image + seed_scale*torch.randn_like(init_image) * (init_mask)
    # wherever the mask is 1, set the blue channel to -1.0, otherwise leave it alone
    init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
    return init_image


def get_init_image_and_mask(args, device, inpaint_task='infer'):
    """Load the init image, build the inpainting mask, and batch both for sampling.

    Args:
        args: parsed options; reads ``init_image`` (path), ``seed_scale`` and
            ``batch_size``.
        device: torch device the returned tensors are moved to.
        inpaint_task: which mask recipe to use. The default ``'infer'`` reads the
            mask from pure-white pixels of the image. Also accepts 'accomp',
            'chords', 'melody', 'nucleation', 'notes' and 'continue'.

    Returns:
        ``(init_image, init_mask)``, both of shape (batch_size, 3, H, W). The image
        is normalised to -1..1 (seeding noise may push pixels outside that range).
        In the mask, 1 marks pixels to regenerate and 0 marks pixels to keep.

    Raises:
        ValueError: if ``inpaint_task`` is not one of the names above.

    Side effects:
        Writes ``init_mask_debug.png`` and ``init_image_debug.png`` to the
        current directory.
    """
    valid_tasks = ['accomp', 'chords', 'melody', 'nucleation', 'notes', 'continue', 'infer']
    if inpaint_task not in valid_tasks:
        raise ValueError(f"inpaint_task must be one of {valid_tasks}, got {inpaint_task!r}")

    # Decode inside a context manager so the underlying file handle is closed promptly.
    with Image.open(args.init_image) as pil_image:
        init_image = transforms.ToTensor()(pil_image.convert('RGB'))
    # normalize image from 0..1 to -1..1
    init_image = (2.0 * init_image) - 1.0
    init_mask = torch.ones(init_image.shape[-2:])  # ones are where stuff will change, zeros will stay the same

    # NOTE(review): the fixed rows below (70, 128) assume a 256x256 image made of
    # two 128-row halves -- TODO confirm before using other input sizes.
    if inpaint_task in ['melody','accomp']:
        init_mask[0:70,:] = 0 # zero out a melody strip of image near top
        init_mask[128+0:128+70,:] = 0 # zero out a melody strip of the bottom half
        if inpaint_task == 'melody':
            init_mask = 1 - init_mask 
    elif inpaint_task in ['notes','chords']:
        # keep chords only
        init_mask[0:CHORD_BORDER,:] = 0  # top rows of 256x256
        init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0  # middle rows of 256x256
        init_mask[-CHORD_BORDER:,:] = 0  # bottom rows of 256x256
        if inpaint_task == 'chords':
            init_mask = 1 - init_mask # inverse: generate chords given notes
    elif inpaint_task == 'continue': 
        init_mask[0:128,:] = 0     # keep the top half, regenerate the bottom half
    elif inpaint_task == 'nucleation':
        # regenerate wherever the blue channel is above mid-level (> 0 on the -1..1 scale)
        init_mask = (init_image[2,:,:] > 0.0)*1.0
        # zero out init mask in top, bottom and middle chord borders
        init_mask[0:CHORD_BORDER,:] = 0
        init_mask[-CHORD_BORDER:,:] = 0
        init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0

        # remove all blue in init_image between the borders
        init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
        init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0

        # grow the sides of the mask by one pixel:
        # wherever mask is zero but is bordered by a 1, set it to 1
        init_mask[1:-1,1:-1] = (init_mask[1:-1,1:-1] + init_mask[0:-2,1:-1] + init_mask[2:,1:-1] + init_mask[1:-1,0:-2] + init_mask[1:-1,2:]) > 0 
    elif inpaint_task == 'infer':
        init_mask = infer_mask_from_init_img(init_image, mask_with='white')

    # Also black out init_image wherever init mask is 1 
    init_image[:,init_mask == 1] = init_image.min()

    if args.seed_scale > 0: # driving nucleation
        print("Seeding nucleation, seed_scale = ", args.seed_scale)
        init_image = add_seeding(init_image, init_mask, grow_by=0, seed_scale=args.seed_scale)

    # remove any blue in middle of init image
    print("init_image.shape = ", init_image.shape)
    init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
    init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0

    # Debugging: output some images so we can see what's going on
    init_mask_img_numpy = (init_mask.float()*255).byte().cpu().numpy()  # mask as 0..255
    Image.fromarray(init_mask_img_numpy).save("init_mask_debug.png")
    # Clamp before the uint8 cast: seeding noise can push pixels outside -1..1,
    # and out-of-range float -> uint8 conversion is not guaranteed to saturate.
    init_image_img_numpy = (init_image*127.5+127.5).clamp(0, 255).byte().cpu().numpy().transpose(1,2,0)
    Image.fromarray(init_image_img_numpy).save("init_image_debug.png")

    # reshape image and mask to be 4D tensors: (batch_size, 3, H, W)
    init_image = init_image.unsqueeze(0).repeat(args.batch_size, 1, 1, 1)
    init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
    return init_image.to(device), init_mask.to(device)




# wrapper compatible with ZeroGPU+Gradio, callable from outside
def zero_wrapper(args, accelerator, device): 
    """Load the model described by ``args`` and write ``args.n`` sampled images to disk.

    Wrapper compatible with ZeroGPU + Gradio, callable from outside this script.

    Args:
        args: namespace with the fields defined by main's argument parser
            (checkpoint, config, n, batch_size, prefix, repaint, steps,
            seed_scale, init_image).
        accelerator: ``accelerate.Accelerator`` used for distributed feature
            collection and main-process checks.
        device: torch device to run on.

    Side effects:
        Sets the module globals ``init_image`` / ``init_mask`` (to None when no
        init image is given). Writes ``{args.prefix}_{i:05}.png`` files. When
        an init image is used, it also writes the mask/image debug PNGs.
    """
    global init_image, init_mask
    print("zero_wrapper: Using device:", device, flush=True)

    config = K.config.load_config(args.config if args.config else args.checkpoint)
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    cse = None # ChordSeqEncoder().eval().requires_grad_(False).to(device)  # add chord embedding-maker to main model
    if cse is not None:
        inner_model.cse = cse
    try:
        inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
    except Exception:
        # safetensors loading failed (e.g. a pickled torch checkpoint): fall back
        # to torch.load. Narrowed from a bare `except:` so Ctrl-C / SystemExit
        # are not swallowed here.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(args.checkpoint, map_location='cpu')
        inner_model.load_state_dict(ckpt['model'])

    print('Parameters:', K.utils.n_params(inner_model))
    # Used only by the K.utils.eval_mode decorator on run() below.
    model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']
    torch.set_float32_matmul_precision('high')
    extra_args = {}
    init_image, init_mask = None, None
    if args.init_image is not None:
        # Sets the module globals (and writes the debug PNGs). The tensors are
        # already on `device`; sample_fn reloads its own copy for every batch.
        init_image, init_mask = get_init_image_and_mask(args, device)

    @torch.no_grad()
    @K.utils.eval_mode(model)
    def run():
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')

        def sample_fn(n, debug=True):
            # Scaled noise, used for logging and for nucleation seeding of init_data.
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
            print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())

            if args.init_image is not None:
                init_data, mask = get_init_image_and_mask(args, device)
                init_data = args.seed_scale*x*mask + (1-mask)*init_data  # extra nucleation?
                if cse is not None: 
                    chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
                else: 
                    chord_cond = None
            else:
                init_data, mask, chord_cond = None, None, None
            # chord conditioning doesn't work yet, so it is forced off:
            chord_cond = None

            print("chord_cond = ", chord_cond)
            if chord_cond is not None: 
                extra_args['chord_cond'] = chord_cond

            noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) 

            # Alternatives: "my-sample-euler", "dpmpp-2m-sde", "dpmpp-3m-sde",
            # "k-dpmpp-2s-ancestral", "k-lms".
            sampler_type = "my-dpmpp-2m-sde"
            print("dtypes:", [t.dtype if t is not None else None for t in [noise, init_data, mask, chord_cond]])
            x_0 = sample_k(inner_model, noise, sampler_type=sampler_type, 
                           init_data=init_data, mask=mask, steps=args.steps, 
                           sigma_min=sigma_min, sigma_max=sigma_max, rho=7., 
                           device=device, model_config=model_config, repaint=args.repaint, 
                           **extra_args)
            print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
            if x_0.isnan().any():
                assert False, "x_0 has NaNs"

            # Release cached GPU memory before the next batch.
            torch.cuda.empty_cache()
            return x_0

        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass






def main():
    """Command-line entry point: parse the sampling options and run zero_wrapper.

    Model loading, optional init-image/mask preparation and sampling are all done
    by zero_wrapper, which this function used to duplicate line for line.
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size')
    p.add_argument('--checkpoint', type=Path, required=True,
                   help='the checkpoint to use')
    p.add_argument('--config', type=Path,
                   help='the model config')
    p.add_argument('-n', type=int, default=64,
                   help='the number of images to sample')
    p.add_argument('--prefix', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--repaint', type=int, default=1,
                   help='number of (re)paint steps')
    p.add_argument('--steps', type=int, default=50,
                   help='the number of denoising steps')
    p.add_argument('--seed-scale', type=float, default=0.0, help='strength of nucleation seeding')
    p.add_argument('--init-image', type=Path, default=None, help='the initial image')
    # NOTE: --init-strength is accepted but not read by the code in this file.
    p.add_argument('--init-strength', type=float, default=1., help='strength of init image')
    args = p.parse_args()
    print("args =", args, flush=True)

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    # zero_wrapper also sets the module globals init_image / init_mask.
    zero_wrapper(args, accelerator, device)


# Run the CLI only when executed as a script, not when imported (e.g. by a
# ZeroGPU/Gradio app that calls zero_wrapper directly).
if __name__ == '__main__':
    main()