import os import types from typing import Tuple import torch import torchvision.transforms as T import torch.nn.functional as F from accelerate import init_empty_weights, load_checkpoint_and_dispatch import sys import comfy.sd import comfy.utils import comfy.model_management import comfy.sd1_clip import comfy.ldm.models.autoencoder import comfy.supported_models import folder_paths from .model_patch import add_model_patch_option, patch_model_function_wrapper from .brushnet.brushnet import BrushNetModel from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens current_directory = os.path.dirname(os.path.abspath(__file__)) brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json') brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json') powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json') sd15_scaling_factor = 0.18215 sdxl_scaling_factor = 0.13025 print(sys.path) ModelsToUnload = [comfy.sd1_clip.SD1ClipModel, comfy.ldm.models.autoencoder.AutoencoderKL ] class BrushNetLoader: @classmethod def INPUT_TYPES(self): self.inpaint_files = get_files_with_extension('inpaint') return {"required": { "brushnet": ([file for file in self.inpaint_files], ), "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ), }, } CATEGORY = "inpaint" RETURN_TYPES = ("BRMODEL",) RETURN_NAMES = ("brushnet",) FUNCTION = "brushnet_loading" def brushnet_loading(self, brushnet, dtype): brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet) print('BrushNet model file:', brushnet_file) is_SDXL = False is_PP = False sd = comfy.utils.load_torch_file(brushnet_file) brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd) del sd if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30: is_SDXL = False if keys == 322: is_PP = False print('BrushNet model type: SD1.5') else: is_PP = True print('PowerPaint model type: SD1.5') elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22: print('BrushNet model type: Loading SDXL') is_SDXL = True is_PP = False else: raise Exception("Unknown BrushNet model") with init_empty_weights(): if is_SDXL: brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file) brushnet_model = BrushNetModel.from_config(brushnet_config) elif is_PP: brushnet_config = PowerPaintModel.load_config(powerpaint_config_file) brushnet_model = PowerPaintModel.from_config(brushnet_config) else: brushnet_config = BrushNetModel.load_config(brushnet_config_file) brushnet_model = BrushNetModel.from_config(brushnet_config) if is_PP: print("PowerPaint model file:", brushnet_file) else: print("BrushNet model file:", brushnet_file) if dtype == 'float16': torch_dtype = torch.float16 elif dtype == 'bfloat16': torch_dtype = torch.bfloat16 elif dtype == 'float32': torch_dtype = torch.float32 else: torch_dtype = torch.float64 brushnet_model = load_checkpoint_and_dispatch( brushnet_model, brushnet_file, device_map="sequential", max_memory=None, offload_folder=None, offload_state_dict=False, dtype=torch_dtype, force_hooks=False, ) if is_PP: print("PowerPaint model is loaded") elif is_SDXL: print("BrushNet SDXL model is loaded") else: print("BrushNet SD1.5 model is loaded") return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, ) class PowerPaintCLIPLoader: @classmethod def INPUT_TYPES(self): self.inpaint_files = get_files_with_extension('inpaint', ['.bin']) self.clip_files = get_files_with_extension('clip') return {"required": { "base": ([file for file in self.clip_files], ), "powerpaint": ([file for file in self.inpaint_files], ), }, } CATEGORY = "inpaint" RETURN_TYPES = ("CLIP",) RETURN_NAMES = ("clip",) FUNCTION = "ppclip_loading" def ppclip_loading(self, base, powerpaint): base_CLIP_file = os.path.join(self.clip_files[base], base) pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint) pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file]) print('PowerPaint base CLIP file: ', base_CLIP_file) pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer) pp_text_encoder = pp_clip.patcher.model.clip_l.transformer add_tokens( tokenizer = pp_tokenizer, text_encoder = pp_text_encoder, placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"], initialize_tokens = ["a", "a", "a"], num_vectors_per_token = 10, ) pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False) print('PowerPaint CLIP file: ', pp_CLIP_file) pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer pp_clip.patcher.model.clip_l.transformer = pp_text_encoder return (pp_clip,) class PowerPaint: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "vae": ("VAE", ), "image": ("IMAGE",), "mask": ("MASK",), "powerpaint": ("BRMODEL", ), "clip": ("CLIP", ), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}), "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ), "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}), "save_memory": (['none', 'auto', 'max'], ), }, } CATEGORY = "inpaint" RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",) RETURN_NAMES = ("model","positive","negative","latent",) FUNCTION = "model_update" def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory): is_SDXL, is_PP = check_compatibilty(model, powerpaint) if not is_PP: raise Exception("BrushNet model was loaded, please use BrushNet node") # Make a copy of the model so that we're not patching it everywhere in the workflow. model = model.clone() # prepare image and mask # no batches for original image and mask masked_image, mask = prepare_image(image, mask) batch = masked_image.shape[0] #width = masked_image.shape[2] #height = masked_image.shape[1] if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'): scaling_factor = model.model.model_config.latent_format.scale_factor else: scaling_factor = sd15_scaling_factor torch_dtype = powerpaint['dtype'] # prepare conditioning latents conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor) conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device) conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device) # prepare embeddings if function == "object removal": promptA = "P_ctxt" promptB = "P_ctxt" negative_promptA = "P_obj" negative_promptB = "P_obj" print('You should add to positive prompt: "empty scene blur"') #positive = positive + " empty scene blur" elif function == "context aware": promptA = "P_ctxt" promptB = "P_ctxt" negative_promptA = "" negative_promptB = "" #positive = positive + " empty scene" print('You should add to positive prompt: "empty scene"') elif function == "shape guided": promptA = "P_shape" promptB = "P_ctxt" negative_promptA = "P_shape" negative_promptB = "P_ctxt" elif function == "image outpainting": promptA = "P_ctxt" promptB = "P_ctxt" negative_promptA = "P_obj" negative_promptB = "P_obj" #positive = positive + " empty scene" print('You should add to positive prompt: "empty scene"') else: promptA = "P_obj" promptB = "P_obj" negative_promptA = "P_obj" negative_promptB = "P_obj" tokens = clip.tokenize(promptA) prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False) tokens = clip.tokenize(negative_promptA) negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False) tokens = clip.tokenize(promptB) prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False) tokens = clip.tokenize(negative_promptB) negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False) prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device) negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device) # unload vae and CLIPs del vae del clip for loaded_model in comfy.model_management.current_loaded_models: if type(loaded_model.model.model) in ModelsToUnload: comfy.model_management.current_loaded_models.remove(loaded_model) loaded_model.model_unload() del loaded_model # apply patch to model brushnet_conditioning_scale = scale control_guidance_start = start_at control_guidance_end = end_at if save_memory != 'none': powerpaint['brushnet'].set_attention_slice(save_memory) add_brushnet_patch(model, powerpaint['brushnet'], torch_dtype, conditioning_latents, (brushnet_conditioning_scale, control_guidance_start, control_guidance_end), negative_prompt_embeds_pp, prompt_embeds_pp, None, None, None, False) latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device) return (model, positive, negative, {"samples":latent},) class BrushNet: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "vae": ("VAE", ), "image": ("IMAGE",), "mask": ("MASK",), "brushnet": ("BRMODEL", ), "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}), }, } CATEGORY = "inpaint" RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",) RETURN_NAMES = ("model","positive","negative","latent",) FUNCTION = "model_update" def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at): is_SDXL, is_PP = check_compatibilty(model, brushnet) if is_PP: raise Exception("PowerPaint model was loaded, please use PowerPaint node") # Make a copy of the model so that we're not patching it everywhere in the workflow. model = model.clone() # prepare image and mask # no batches for original image and mask masked_image, mask = prepare_image(image, mask) batch = masked_image.shape[0] width = masked_image.shape[2] height = masked_image.shape[1] if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'): scaling_factor = model.model.model_config.latent_format.scale_factor elif is_SDXL: scaling_factor = sdxl_scaling_factor else: scaling_factor = sd15_scaling_factor torch_dtype = brushnet['dtype'] # prepare conditioning latents conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor) conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device) conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device) # unload vae del vae for loaded_model in comfy.model_management.current_loaded_models: if type(loaded_model.model.model) in ModelsToUnload: comfy.model_management.current_loaded_models.remove(loaded_model) loaded_model.model_unload() del loaded_model # prepare embeddings prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device) negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device) max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) if prompt_embeds.shape[1] < max_tokens: multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77 prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1) print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds') if negative_prompt_embeds.shape[1] < max_tokens: multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77 negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1) print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds') if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None: pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device) else: print('BrushNet: positive conditioning has not pooled_output') if is_SDXL: print('BrushNet will not produce correct results') pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype) if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None: negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device) else: print('BrushNet: negative conditioning has not pooled_output') if is_SDXL: print('BrushNet will not produce correct results') negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype) time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device) if not is_SDXL: pooled_prompt_embeds = None negative_pooled_prompt_embeds = None time_ids = None # apply patch to model brushnet_conditioning_scale = scale control_guidance_start = start_at control_guidance_end = end_at add_brushnet_patch(model, brushnet['brushnet'], torch_dtype, conditioning_latents, (brushnet_conditioning_scale, control_guidance_start, control_guidance_end), prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids, False) latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device) return (model, positive, negative, {"samples":latent},) class BlendInpaint: @classmethod def INPUT_TYPES(s): return {"required": { "inpaint": ("IMAGE",), "original": ("IMAGE",), "mask": ("MASK",), "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}), "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}), }, "optional": { "origin": ("VECTOR",), }, } CATEGORY = "inpaint" RETURN_TYPES = ("IMAGE","MASK",) RETURN_NAMES = ("image","MASK",) FUNCTION = "blend_inpaint" def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]: original, mask = check_image_mask(original, mask, 'Blend Inpaint') if len(inpaint.shape) < 4: # image tensor shape should be [B, H, W, C], but batch somehow is missing inpaint = inpaint[None,:,:,:] if inpaint.shape[0] < original.shape[0]: print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0])) original= original[:inpaint.shape[0],:,:] mask = mask[:inpaint.shape[0],:,:] if inpaint.shape[0] > original.shape[0]: # batch over inpaint count = 0 original_list = [] mask_list = [] origin_list = [] while (count < inpaint.shape[0]): for i in range(original.shape[0]): original_list.append(original[i][None,:,:,:]) mask_list.append(mask[i][None,:,:]) if origin is not None: origin_list.append(origin[i][None,:]) count += 1 if count >= inpaint.shape[0]: break original = torch.concat(original_list, dim=0) mask = torch.concat(mask_list, dim=0) if origin is not None: origin = torch.concat(origin_list, dim=0) if kernel % 2 == 0: kernel += 1 transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma)) ret = [] blurred = [] for i in range(inpaint.shape[0]): if origin is None: blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype) blurred.append(blurred_mask[0]) result = torch.nn.functional.interpolate( inpaint[i][None,:,:,:].permute(0, 3, 1, 2), size=( original[i].shape[0], original[i].shape[1], ) ).permute(0, 2, 3, 1).to(original.device).to(original.dtype) else: # got mask from CutForInpaint height, width, _ = original[i].shape x0 = origin[i][0].item() y0 = origin[i][1].item() if mask[i].shape[0] < height or mask[i].shape[1] < width: padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1], y0, height-y0-mask[i].shape[0]), mode='constant', value=0) else: padded_mask = mask[i] blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype) blurred.append(blurred_mask[0][0]) result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1], y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0) result = result[None,:,:,:].to(original.device).to(original.dtype) ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None]) return (torch.stack(ret), torch.stack(blurred), ) class CutForInpaint: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), "mask": ("MASK",), "width": ("INT", {"default": 512, "min": 64, "max": 2048}), "height": ("INT", {"default": 512, "min": 64, "max": 2048}), }, } CATEGORY = "inpaint" RETURN_TYPES = ("IMAGE","MASK","VECTOR",) RETURN_NAMES = ("image","mask","origin",) FUNCTION = "cut_for_inpaint" def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int): image, mask = check_image_mask(image, mask, 'BrushNet') ret = [] msk = [] org = [] for i in range(image.shape[0]): x0, y0, w, h = cut_with_mask(mask[i], width, height) ret.append((image[i][y0:y0+h,x0:x0+w,:])) msk.append((mask[i][y0:y0+h,x0:x0+w])) org.append(torch.IntTensor([x0,y0])) return (torch.stack(ret), torch.stack(msk), torch.stack(org), ) #### Utility function def get_files_with_extension(folder_name, extension=['.safetensors']): try: folders = folder_paths.get_folder_paths(folder_name) except: folders = [] if not folders: folders = [os.path.join(folder_paths.models_dir, folder_name)] if not os.path.isdir(folders[0]): folders = [os.path.join(folder_paths.base_path, folder_name)] if not os.path.isdir(folders[0]): return {} filtered_folders = [] for x in folders: if not os.path.isdir(x): continue the_same = False for y in filtered_folders: if os.path.samefile(x, y): the_same = True break if not the_same: filtered_folders.append(x) if not filtered_folders: return {} output = {} for x in filtered_folders: files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"]) filtered_files = folder_paths.filter_files_extensions(files, extension) for f in filtered_files: output[f] = x return output # get blocks from state_dict so we could know which model it is def brushnet_blocks(sd): brushnet_down_block = 0 brushnet_mid_block = 0 brushnet_up_block = 0 for key in sd: if 'brushnet_down_block' in key: brushnet_down_block += 1 if 'brushnet_mid_block' in key: brushnet_mid_block += 1 if 'brushnet_up_block' in key: brushnet_up_block += 1 return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd)) # Check models compatibility def check_compatibilty(model, brushnet): is_SDXL = False is_PP = False if isinstance(model.model.model_config, comfy.supported_models.SD15): print('Base model type: SD1.5') is_SDXL = False if brushnet["SDXL"]: raise Exception("Base model is SD15, but BrushNet is SDXL type") if brushnet["PP"]: is_PP = True elif isinstance(model.model.model_config, comfy.supported_models.SDXL): print('Base model type: SDXL') is_SDXL = True if not brushnet["SDXL"]: raise Exception("Base model is SDXL, but BrushNet is SD15 type") else: print('Base model type: ', type(model.model.model_config)) raise Exception("Unsupported model type: " + str(type(model.model.model_config))) return (is_SDXL, is_PP) def check_image_mask(image, mask, name): if len(image.shape) < 4: # image tensor shape should be [B, H, W, C], but batch somehow is missing image = image[None,:,:,:] if len(mask.shape) > 3: # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be? # take first mask, red channel mask = (mask[:,:,:,0])[:,:,:] elif len(mask.shape) < 3: # mask tensor shape should be [B, H, W] but batch somehow is missing mask = mask[None,:,:] if image.shape[0] > mask.shape[0]: print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0])) if mask.shape[0] == 1: print(name, "will copy the mask to fill batch") mask = torch.cat([mask] * image.shape[0], dim=0) else: print(name, "will add empty masks to fill batch") empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]]) mask = torch.cat([mask, empty_mask], dim=0) elif image.shape[0] < mask.shape[0]: print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0])) mask = mask[:image.shape[0],:,:] return (image, mask) # Prepare image and mask def prepare_image(image, mask): image, mask = check_image_mask(image, mask, 'BrushNet') print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape) if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]: raise Exception("Image and mask should be the same size") # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64) mask = mask.round() masked_image = image * (1.0 - mask[:,:,:,None]) return (masked_image, mask) # Get origin of the mask def cut_with_mask(mask, width, height): iy, ix = (mask == 1).nonzero(as_tuple=True) h0, w0 = mask.shape if iy.numel() == 0: x_c = w0 / 2.0 y_c = h0 / 2.0 else: x_min = ix.min().item() x_max = ix.max().item() y_min = iy.min().item() y_max = iy.max().item() if x_max - x_min > width or y_max - y_min > height: raise Exception("Masked area is bigger than provided dimensions") x_c = (x_min + x_max) / 2.0 y_c = (y_min + y_max) / 2.0 width2 = width / 2.0 height2 = height / 2.0 if w0 <= width: x0 = 0 w = w0 else: x0 = max(0, x_c - width2) w = width if x0 + width > w0: x0 = w0 - width if h0 <= height: y0 = 0 h = h0 else: y0 = max(0, y_c - height2) h = height if y0 + height > h0: y0 = h0 - height return (int(x0), int(y0), int(w), int(h)) # Prepare conditioning_latents @torch.inference_mode() def get_image_latents(masked_image, mask, vae, scaling_factor): processed_image = masked_image.to(vae.device) image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor processed_mask = 1. - mask[:,None,:,:] interpolated_mask = torch.nn.functional.interpolate( processed_mask, size=( image_latents.shape[-2], image_latents.shape[-1] ) ) interpolated_mask = interpolated_mask.to(image_latents.device) conditioning_latents = [image_latents, interpolated_mask] print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape) return conditioning_latents # Main function where magic happens @torch.inference_mode() def brushnet_inference(x, timesteps, transformer_options, debug): if 'model_patch' not in transformer_options: print('BrushNet inference: there is no model_patch key in transformer_options') return ([], 0, []) mp = transformer_options['model_patch'] if 'brushnet' not in mp: print('BrushNet inference: there is no brushnet key in mdel_patch') return ([], 0, []) bo = mp['brushnet'] if 'model' not in bo: print('BrushNet inference: there is no model key in brushnet') return ([], 0, []) brushnet = bo['model'] if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)): print('BrushNet model is not a BrushNetModel class') return ([], 0, []) torch_dtype = bo['dtype'] cl_list = bo['latents'] brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls'] pe = bo['prompt_embeds'] npe = bo['negative_prompt_embeds'] ppe, nppe, time_ids = bo['add_embeds'] #do_classifier_free_guidance = mp['free_guidance'] do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1 x = x.detach().clone() x = x.to(torch_dtype).to(brushnet.device) timesteps = timesteps.detach().clone() timesteps = timesteps.to(torch_dtype).to(brushnet.device) total_steps = mp['total_steps'] step = mp['step'] added_cond_kwargs = {} if do_classifier_free_guidance and step == 0: print('BrushNet inference: do_classifier_free_guidance is True') sub_idx = None if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']: sub_idx = transformer_options['ad_params']['sub_idxs'] # we have batch input images batch = cl_list[0].shape[0] # we have incoming latents latents_incoming = x.shape[0] # and we already got some latents_got = bo['latent_id'] if step == 0 or batch > 1: print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \ % (step, batch, latents_incoming, latents_got)) image_latents = [] masks = [] prompt_embeds = [] negative_prompt_embeds = [] pooled_prompt_embeds = [] negative_pooled_prompt_embeds = [] if sub_idx: # AnimateDiff indexes detected if step == 0: print('BrushNet inference: AnimateDiff indexes detected and applied') batch = len(sub_idx) if do_classifier_free_guidance: for i in sub_idx: image_latents.append(cl_list[0][i][None,:,:,:]) masks.append(cl_list[1][i][None,:,:,:]) prompt_embeds.append(pe) negative_prompt_embeds.append(npe) pooled_prompt_embeds.append(ppe) negative_pooled_prompt_embeds.append(nppe) for i in sub_idx: image_latents.append(cl_list[0][i][None,:,:,:]) masks.append(cl_list[1][i][None,:,:,:]) else: for i in sub_idx: image_latents.append(cl_list[0][i][None,:,:,:]) masks.append(cl_list[1][i][None,:,:,:]) prompt_embeds.append(pe) pooled_prompt_embeds.append(ppe) else: # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond continue_batch = True for i in range(latents_incoming): number = latents_got + i if number < batch: # 1st pass, cond image_latents.append(cl_list[0][number][None,:,:,:]) masks.append(cl_list[1][number][None,:,:,:]) prompt_embeds.append(pe) pooled_prompt_embeds.append(ppe) elif do_classifier_free_guidance and number < batch * 2: # 2nd pass, uncond image_latents.append(cl_list[0][number-batch][None,:,:,:]) masks.append(cl_list[1][number-batch][None,:,:,:]) negative_prompt_embeds.append(npe) negative_pooled_prompt_embeds.append(nppe) else: # latent batch image_latents.append(cl_list[0][0][None,:,:,:]) masks.append(cl_list[1][0][None,:,:,:]) prompt_embeds.append(pe) pooled_prompt_embeds.append(ppe) latents_got = -i continue_batch = False if continue_batch: # we don't have full batch yet if do_classifier_free_guidance: if number < batch * 2 - 1: bo['latent_id'] = number + 1 else: bo['latent_id'] = 0 else: if number < batch - 1: bo['latent_id'] = number + 1 else: bo['latent_id'] = 0 else: bo['latent_id'] = 0 cl = [] for il, m in zip(image_latents, masks): cl.append(torch.concat([il, m], dim=1)) cl2apply = torch.concat(cl, dim=0) conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device) # print("BrushNet CL: conditioning_latents shape =", conditioning_latents.shape) # print("BrushNet CL: x shape =", x.shape) prompt_embeds.extend(negative_prompt_embeds) prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device) if ppe is not None: added_cond_kwargs = {} added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device) pooled_prompt_embeds.extend(negative_pooled_prompt_embeds) pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device) added_cond_kwargs['text_embeds'] = pooled_prompt_embeds else: added_cond_kwargs = None if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]: if step == 0: print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image') conditioning_latents = torch.nn.functional.interpolate( conditioning_latents, size=( x.shape[2], x.shape[3], ), mode='bicubic', ).to(torch_dtype).to(brushnet.device) if step == 0: print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype) if debug: print('BrushNet: step =', step) if step < control_guidance_start or step > control_guidance_end: cond_scale = 0.0 else: cond_scale = brushnet_conditioning_scale return brushnet(x, encoder_hidden_states=prompt_embeds, brushnet_cond=conditioning_latents, timestep = timesteps, conditioning_scale=cond_scale, guess_mode=False, added_cond_kwargs=added_cond_kwargs, return_dict=False, debug=debug, ) # This is main patch function def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents, controls, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids, debug): is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL) if is_SDXL: input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d], [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample], [4, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.attention.SpatialTransformer], [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample], [7, comfy.ldm.modules.attention.SpatialTransformer], [8, comfy.ldm.modules.attention.SpatialTransformer]] middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock] output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer], [1, comfy.ldm.modules.attention.SpatialTransformer], [2, comfy.ldm.modules.attention.SpatialTransformer], [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample], [3, comfy.ldm.modules.attention.SpatialTransformer], [4, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample], [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]] else: input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d], [1, comfy.ldm.modules.attention.SpatialTransformer], [2, comfy.ldm.modules.attention.SpatialTransformer], [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample], [4, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.attention.SpatialTransformer], [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample], [7, comfy.ldm.modules.attention.SpatialTransformer], [8, comfy.ldm.modules.attention.SpatialTransformer], [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample], [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]] middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock] output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock], [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample], [3, comfy.ldm.modules.attention.SpatialTransformer], [4, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.attention.SpatialTransformer], [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample], [6, comfy.ldm.modules.attention.SpatialTransformer], [7, comfy.ldm.modules.attention.SpatialTransformer], [8, comfy.ldm.modules.attention.SpatialTransformer], [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample], [9, comfy.ldm.modules.attention.SpatialTransformer], [10, comfy.ldm.modules.attention.SpatialTransformer], [11, comfy.ldm.modules.attention.SpatialTransformer]] def last_layer_index(block, tp): layer_list = [] for layer in block: layer_list.append(type(layer)) layer_list.reverse() if tp not in layer_list: return -1, layer_list.reverse() return len(layer_list) - 1 - layer_list.index(tp), layer_list def brushnet_forward(model, x, timesteps, transformer_options, control): if 'brushnet' not in transformer_options['model_patch']: input_samples = [] mid_sample = 0 output_samples = [] else: # brushnet inference input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug) # give additional samples to blocks for i, tp in input_blocks: idx, layer_list = last_layer_index(model.input_blocks[i], tp) if idx < 0: print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list) continue model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0 idx, layer_list = last_layer_index(model.middle_block, middle_block[1]) if idx < 0: print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list) model.middle_block[idx].add_sample_after = mid_sample for i, tp in output_blocks: idx, layer_list = last_layer_index(model.output_blocks[i], tp) if idx < 0: print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list) continue model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0 patch_model_function_wrapper(model, brushnet_forward) to = add_model_patch_option(model) mp = to['model_patch'] if 'brushnet' not in mp: mp['brushnet'] = {} bo = mp['brushnet'] bo['model'] = brushnet bo['dtype'] = torch_dtype bo['latents'] = conditioning_latents bo['controls'] = controls bo['prompt_embeds'] = prompt_embeds bo['negative_prompt_embeds'] = negative_prompt_embeds bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids) bo['latent_id'] = 0 # patch layers `forward` so we can apply brushnet def forward_patched_by_brushnet(self, x, *args, **kwargs): h = self.original_forward(x, *args, **kwargs) if hasattr(self, 'add_sample_after') and type(self): to_add = self.add_sample_after if torch.is_tensor(to_add): # interpolate due to RAUNet if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]: to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic') h += to_add.to(h.dtype).to(h.device) else: h += self.add_sample_after self.add_sample_after = 0 return h for i, block in enumerate(model.model.diffusion_model.input_blocks): for j, layer in enumerate(block): if not hasattr(layer, 'original_forward'): layer.original_forward = layer.forward layer.forward = types.MethodType(forward_patched_by_brushnet, layer) layer.add_sample_after = 0 for j, layer in enumerate(model.model.diffusion_model.middle_block): if not hasattr(layer, 'original_forward'): layer.original_forward = layer.forward layer.forward = types.MethodType(forward_patched_by_brushnet, layer) layer.add_sample_after = 0 for i, block in enumerate(model.model.diffusion_model.output_blocks): for j, layer in enumerate(block): if not hasattr(layer, 'original_forward'): layer.original_forward = layer.forward layer.forward = types.MethodType(forward_patched_by_brushnet, layer) layer.add_sample_after = 0