Spaces:
Running
on
L4
Running
on
L4
import torch | |
import comfy | |
def add_model_patch_option(model):
    """Ensure ``model.model_options['transformer_options']['model_patch']`` exists.

    Missing levels are created as empty dicts; existing entries are left
    untouched.

    Args:
        model: object exposing a dict-like ``model_options`` attribute.

    Returns:
        The (possibly newly created) ``transformer_options`` dict, stored by
        reference inside ``model.model_options``.
    """
    transformer_options = model.model_options.setdefault('transformer_options', {})
    transformer_options.setdefault("model_patch", {})
    return transformer_options
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Register or unregister a BrushNet forward patch on ``model``.

    Installs ``brushnet_model_function_wrapper`` as the model's UNet function
    wrapper and keeps BrushNet state in
    ``model.model_options['transformer_options']['model_patch']``. The keys
    used are 'SDXL', 'forward', 'unet', 'step', 'total_steps', plus
    'all_sigmas', which is written by ``modified_sample``. The first time it
    runs, it also replaces ``comfy.samplers.sample`` with ``modified_sample``
    and ``openaimodel.apply_control`` with ``modified_apply_control``.

    Args:
        model: model patcher to modify. This is presumably a ComfyUI
            ``ModelPatcher``; it must provide ``model_options``,
            ``set_model_unet_function_wrapper`` and ``model``.
        forward_patch: callable run before every UNet call as
            ``forward_patch(unet, xc, t, transformer_options, control)``.
        remove: if True, remove ``forward_patch`` (when present) instead of
            appending it.

    Raises:
        Exception: if the base model config is neither SD1.5 nor SDXL. The
            check runs before ``model`` is modified.
    """
    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        """Run the registered BrushNet patches, then the regular UNet call."""
        c = options_dict['c']
        to = c['transformer_options']
        control = c.get('control')
        x = options_dict['input']
        timestep = options_dict['timestep']

        # Nothing registered, or every patch was removed: skip the extra work
        # and call the UNet directly.
        mp = to.get('model_patch')
        if not mp or not mp.get('forward'):
            return apply_model_method(x, timestep, **c)

        unet = mp['unet']
        # Full sigma schedule recorded by modified_sample.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        # Current step = index of the scheduled sigma closest to this one.
        mp['step'] = torch.argmin((all_sigmas - sigma).abs()).item()
        mp['total_steps'] = all_sigmas.shape[0] - 1

        # Input preparation taken from comfy.model_base.apply_model, so the
        # patches see the same UNet input.
        model_sampling = model.model.model_sampling
        xc = model_sampling.calculate_input(timestep, x)
        c_concat = c.get('c_concat')
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)
        t = model_sampling.timestep(timestep).float()

        for method in mp['forward']:
            method(unet, xc, t, to, control)

        return apply_model_method(x, timestep, **c)

    # Validate first so an unsupported model is left completely unmodified.
    model_config = model.model.model_config
    if isinstance(model_config, comfy.supported_models.SD15):
        is_sdxl = False
    elif isinstance(model_config, comfy.supported_models.SDXL):
        is_sdxl = True
    else:
        print('Base model type: ', type(model_config))
        raise Exception(f"Unsupported model type: {type(model_config)}")

    existing_wrapper = model.model_options.get("model_function_wrapper")
    if existing_wrapper:
        print('BrushNet is going to replace existing model_function_wrapper:', existing_wrapper)
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    mp['SDXL'] = is_sdxl

    patches = mp.setdefault('forward', [])
    if remove:
        if forward_patch in patches:
            patches.remove(forward_patch)
    else:
        patches.append(forward_patch)

    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1

    # Install the module-level hooks only once. The replacement functions
    # carry 'BrushNet' in their docstrings, which marks them as installed.
    if 'BrushNet' not in (comfy.samplers.sample.__doc__ or ''):
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample

    openaimodel = comfy.ldm.modules.diffusionmodules.openaimodel
    if 'BrushNet' not in (openaimodel.apply_control.__doc__ or ''):
        openaimodel.original_apply_control = openaimodel.apply_control
        openaimodel.apply_control = modified_apply_control
# BrushNet needs the current step number at every UNet call. Rather than
# shipping a custom KSampler, ComfyUI's module-level comfy.samplers.sample is
# replaced by this function so the full sigma schedule can be recorded.
# (Earlier versions patched common_ksampler instead, which broke custom
# KSampler nodes.)
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options=None,
                    latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    '''
    Modified by BrushNet nodes

    Replacement for ``comfy.samplers.sample``, installed by
    ``patch_model_function_wrapper``. It builds a ``comfy.samplers.CFGGuider``
    for ``model`` with the given conds and cfg. It stores ``sigmas`` under
    ``model.model_options['transformer_options']['model_patch']['all_sigmas']``
    so the BrushNet UNet wrapper can derive the step index. It then returns
    the result of ``CFGGuider.sample``.

    ``device`` and ``model_options`` are accepted only for signature
    compatibility with the function being replaced; neither is used here.

    NOTE: the 'BrushNet' word in this docstring is checked at runtime to
    detect that the patch is already installed; do not remove it.
    '''
    cfg_guider = comfy.samplers.CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)

    # BrushNet addition: expose the whole sigma schedule to the UNet wrapper.
    to = add_model_patch_option(model)
    to['model_patch']['all_sigmas'] = sigmas

    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
# To use ControlNet together with RAUNet it is easiest to slightly modify
# apply_control so mismatched control residuals are resized to fit.
def modified_apply_control(h, control, name):
    '''
    Modified by BrushNet nodes

    Pop the last control residual stored under ``control[name]`` and add it
    to ``h`` in place. If its last two dimensions differ from ``h``'s, it is
    first resized with bicubic interpolation and cast to ``h``'s dtype and
    device. A residual that still cannot be added (e.g. a channel mismatch)
    is skipped with a warning instead of aborting sampling.

    NOTE: the 'BrushNet' word in this docstring is checked at runtime to
    detect that the patch is already installed; do not remove it.

    Args:
        h: feature tensor; indexing shape[2]/shape[3] implies at least 4 dims.
        control: dict mapping names to lists of residuals, or None.
        name: key of the residual list to consume from ``control``.

    Returns:
        ``h`` (modified in place when a residual was applied).
    '''
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
                ctrl = torch.nn.functional.interpolate(
                    ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic'
                ).to(device=h.device, dtype=h.dtype)
            try:
                h += ctrl
            except RuntimeError:
                # The original called print.warning(), which itself raised
                # AttributeError; emit the intended warning instead.
                print("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h