from typing import Callable, Union from torch import Tensor import torch import os import comfy.ops import comfy.utils import comfy.model_management import comfy.model_detection import comfy.controlnet as comfy_cn from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, StrengthType from comfy.model_patcher import ModelPatcher from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException, manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory, broadcast_image_to_extend, extend_to_batch_size, ORIG_PREVIOUS_CONTROLNET, CONTROL_INIT_BY_ACN) from .logger import logger class ControlNetAdvanced(ControlNet, AdvancedControlBase): def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT): super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet()) self.is_flux = False self.x_noisy_shape = None def get_universal_weights(self) -> ControlWeights: def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str): if key == "middle": return 1.0 c_len = len(control[key]) raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)] raw_weights = raw_weights[:-1] if key == "input": raw_weights.reverse() return raw_weights[idx] return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func) def get_control_advanced(self, x_noisy, t, cond, batched_number): # perform special version of get_control that supports sliding context and masks return self.sliding_get_control(x_noisy, t, cond, batched_number) def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number): control_prev = None if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if control_prev is not None: return control_prev else: return None dtype = self.control_model.dtype if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype # make cond_hint appropriate dimensions # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: compression_ratio *= self.vae.downscale_ratio # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling if self.sub_idxs is not None: actual_cond_hint_orig = self.cond_hint_original if self.cond_hint_original.size(0) < self.full_latent_length: actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length) self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") else: self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1)) comfy.model_management.load_models_gpu(loaded_models) if self.latent_format is not None: self.cond_hint = self.latent_format.process_in(self.cond_hint) self.cond_hint = self.cond_hint.to(device=x_noisy.device, dtype=dtype) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number) # prepare mask_cond_hint self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype) context = cond.get('crossattn_controlnet', cond['c_crossattn']) extra = self.extra_args.copy() for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: extra[c] = temp.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) self.x_noisy_shape = x_noisy.shape control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) return self.control_merge(control, control_prev, output_dtype=None) def pre_run_advanced(self, *args, **kwargs): self.is_flux = "Flux" in str(type(self.control_model).__name__) return super().pre_run_advanced(*args, **kwargs) def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None): if self.is_flux: flux_shape = self.x_noisy_shape return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape) def copy(self): c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c.control_model = self.control_model c.control_model_wrapped = self.control_model_wrapped self.copy_to(c) self.copy_to_advanced(c) return c def cleanup_advanced(self): self.x_noisy_shape = None return super().cleanup_advanced() @staticmethod def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced': to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype) v.copy_to(to_return) return to_return class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase): def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, compression_ratio=8, upscale_algorithm="nearest_exact", device=None): super().__init__(t2i_model=t2i_model, channels_in=channels_in, compression_ratio=compression_ratio, upscale_algorithm=upscale_algorithm, device=device) AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter()) def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, output_dtype): # match batch_size # TODO: make this more efficient by modifying the cached self.control_input val instead of doing this every step for key in control: control_current = control[key] for i in range(len(control_current)): x = control_current[i] if x is not None and x.size(0) == 1 and x.size(0) != self.batch_size: control_current[i] = x.repeat(self.batch_size, 1, 1, 1)[:self.batch_size] return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype) def get_universal_weights(self) -> ControlWeights: def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str): if key == "middle": return 1.0 c_len = 8 #len(control[key]) raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)] raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]] raw_weights = get_properly_arranged_t2i_weights(raw_weights) if key == "input": raw_weights.reverse() return raw_weights[idx] return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func) def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int: if key == "middle": return 0 # match how T2IAdapterAdvanced deals with universal weights c_len = 8 #len(control[key]) indeces = [(c_len-1) - i for i in range(c_len)] indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]] indeces = get_properly_arranged_t2i_weights(indeces) if key == "input": indeces.reverse() # need to reverse to match recent ComfyUI changes return indeces[idx] def get_control_advanced(self, x_noisy, t, cond, batched_number): try: # if sub indexes present, replace original hint with subsection if self.sub_idxs is not None: # cond hints full_cond_hint_original = self.cond_hint_original actual_cond_hint_orig = full_cond_hint_original del self.cond_hint self.cond_hint = None if full_cond_hint_original.size(0) < self.full_latent_length: actual_cond_hint_orig = extend_to_batch_size(tensor=full_cond_hint_original, batch_size=full_cond_hint_original.size(0)) self.cond_hint_original = actual_cond_hint_orig[self.sub_idxs] # mask hints self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number) return super().get_control(x_noisy, t, cond, batched_number) finally: if self.sub_idxs is not None: # replace original cond hint self.cond_hint_original = full_cond_hint_original del full_cond_hint_original def copy(self): c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) self.copy_to_advanced(c) return c def cleanup(self): super().cleanup() self.cleanup_advanced() @staticmethod def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced': to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in, compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device) v.copy_to(to_return) return to_return class ControlLoraAdvanced(ControlLora, AdvancedControlBase): def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False): super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling) AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora()) # use some functions from ControlNetAdvanced self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self)) self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self)) def get_universal_weights(self) -> ControlWeights: raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)] return self.weights.copy_with_new_weights(raw_weights) def copy(self): c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling) self.copy_to(c) self.copy_to_advanced(c) return c def cleanup(self): super().cleanup() self.cleanup_advanced() @staticmethod def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced': to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe, global_average_pooling=v.global_average_pooling) v.copy_to(to_return) return to_return class SVDControlNetAdvanced(ControlNetAdvanced): def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None): super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) def set_cond_hint_inject(self, *args, **kwargs): to_return = super().set_cond_hint_inject(*args, **kwargs) # cond hint for SVD-ControlNet needs to be scaled between (-1, 1) instead of (0, 1) self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0 return to_return def get_control_advanced(self, x_noisy, t, cond, batched_number): control_prev = None if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if control_prev is not None: return control_prev else: return None dtype = self.control_model.dtype if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype # make cond_hint appropriate dimensions # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling if self.sub_idxs is not None: actual_cond_hint_orig = self.cond_hint_original if self.cond_hint_original.size(0) < self.full_latent_length: actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length) self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device) else: self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number) # prepare mask_cond_hint self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype) context = cond.get('crossattn_controlnet', cond['c_crossattn']) # uses 'y' in new ComfyUI update y = cond.get('y', None) if y is not None: y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) # concat c_concat if exists (should exist for SVD), doubling channels to 8 if cond.get('c_concat', None) is not None: x_noisy = torch.cat([x_noisy] + [cond['c_concat']], dim=1) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, cond=cond) return self.control_merge(control, control_prev, output_dtype) def copy(self): c = SVDControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) self.copy_to(c) self.copy_to_advanced(c) return c class SparseCtrlAdvanced(ControlNetAdvanced): def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, load_device=None, manual_cast_dtype=None): super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) self.control_model_wrapped = SparseModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.add_compatible_weight(ControlWeightType.SPARSECTRL) self.control_model: SparseControlNet = self.control_model # does nothing except help with IDE hints if self.control_model.use_simplified_conditioning_embedding: # TODO: allow vae_optional to be used instead of preprocessor #self.require_vae = True self.allow_condhint_latents = True self.sparse_settings = sparse_settings if sparse_settings is not None else SparseSettings.default() self.model_latent_format = None # latent format for active SD model, NOT controlnet self.preprocessed = False def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int): # normal ControlNet stuff control_prev = None if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: if control_prev is not None: return control_prev else: return None dtype = self.control_model.dtype if self.manual_cast_dtype is not None: dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype # set actual input length on motion model actual_length = x_noisy.size(0)//batched_number full_length = actual_length if self.sub_idxs is None else self.full_latent_length self.control_model.set_actual_length(actual_length=actual_length, full_length=full_length) # prepare cond_hint, if needed dim_mult = 1 if self.control_model.use_simplified_conditioning_embedding else 8 if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2]*dim_mult != self.cond_hint.shape[2] or x_noisy.shape[3]*dim_mult != self.cond_hint.shape[3]: # clear out cond_hint and conditioning_mask if self.cond_hint is not None: del self.cond_hint self.cond_hint = None # first, figure out which cond idxs are relevant, and where they fit in cond_idxs, hint_order = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length, sub_idxs=self.sub_idxs if self.sparse_settings.is_context_aware() else None) range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs hint_idxs = [] # idxs in cond_idxs local_idxs = [] # idx to put in final cond_hint for i,cond_idx in enumerate(cond_idxs): if cond_idx in range_idxs: hint_idxs.append(i) local_idxs.append(range_idxs.index(cond_idx)) # log_string = f"cond_idxs: {cond_idxs}, local_idxs: {local_idxs}, hint_idxs: {hint_idxs}, hint_order: {hint_order}" # if self.sub_idxs is not None: # log_string += f" sub_idxs: {self.sub_idxs[0]}-{self.sub_idxs[-1]}" # logger.warn(log_string) # determine cond/uncond indexes that will get masked self.local_sparse_idxs = [] self.local_sparse_idxs_inverse = list(range(x_noisy.size(0))) for batch_idx in range(batched_number): for i in local_idxs: actual_i = i+(batch_idx*actual_length) self.local_sparse_idxs.append(actual_i) if actual_i in self.local_sparse_idxs_inverse: self.local_sparse_idxs_inverse.remove(actual_i) # sub_cond_hint now contains the hints relevant to current x_noisy if hint_order is None: sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(x_noisy.device) else: sub_cond_hint = self.cond_hint_original[hint_order][hint_idxs].to(dtype).to(x_noisy.device) # scale cond_hints to match noisy input if self.control_model.use_simplified_conditioning_embedding: # RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts sub_cond_hint = self.model_latent_format.process_in(sub_cond_hint) # multiplies by model scale factor sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(x_noisy.device) else: # other SparseCtrl; inputs are typical images sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device) # prepare cond_hint (b, c, h ,w) cond_shape = list(sub_cond_hint.shape) cond_shape[0] = len(range_idxs) self.cond_hint = torch.zeros(cond_shape).to(dtype).to(x_noisy.device) self.cond_hint[local_idxs] = sub_cond_hint[:] # prepare cond_mask (b, 1, h, w) cond_shape[1] = 1 cond_mask = torch.zeros(cond_shape).to(dtype).to(x_noisy.device) cond_mask[local_idxs] = self.sparse_settings.sparse_mask_mult * self.weights.extras.get(SparseConst.MASK_MULT, 1.0) # combine cond_hint and cond_mask into (b, c+1, h, w) if not self.sparse_settings.merged: self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1) del sub_cond_hint del cond_mask # make cond_hint match x_noisy batch if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number) # prepare mask_cond_hint self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype) context = cond['c_crossattn'] y = cond.get('y', None) if y is not None: y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(control, control_prev, output_dtype) def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs): # apply mults to indexes with and without a direct condhint x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0) x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0) return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs) def pre_run_advanced(self, model, percent_to_timestep_function): super().pre_run_advanced(model, percent_to_timestep_function) if isinstance(self.cond_hint_original, AbstractPreprocWrapper): if not self.control_model.use_simplified_conditioning_embedding: raise ValueError("Any model besides RGB SparseCtrl should NOT have its images go through the RGB SparseCtrl preprocessor.") self.cond_hint_original = self.cond_hint_original.condhint self.model_latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint if self.control_model.motion_wrapper is not None: self.control_model.motion_wrapper.reset() self.control_model.motion_wrapper.set_strength(self.sparse_settings.motion_strength) self.control_model.motion_wrapper.set_scale_multiplier(self.sparse_settings.motion_scale) def cleanup_advanced(self): super().cleanup_advanced() if self.model_latent_format is not None: del self.model_latent_format self.model_latent_format = None self.local_sparse_idxs = None self.local_sparse_idxs_inverse = None def copy(self): c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.load_device, self.manual_cast_dtype) self.copy_to(c) self.copy_to_advanced(c) return c def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) # from pathlib import Path # log_name = ckpt_path.split('\\')[-1] # with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile: # for key, value in controlnet_data.items(): # afile.write(f"{key}:\t{value.shape}\n") control = None # check if a non-vanilla ControlNet controlnet_type = ControlWeightType.DEFAULT has_controlnet_key = False has_motion_modules_key = False has_temporal_res_block_key = False for key in controlnet_data: # LLLite check if "lllite" in key: controlnet_type = ControlWeightType.CONTROLLLLITE break # SparseCtrl check elif "motion_modules" in key: has_motion_modules_key = True elif "controlnet" in key: has_controlnet_key = True # SVD-ControlNet check elif "temporal_res_block" in key: has_temporal_res_block_key = True # ControlNet++ check elif "task_embedding" in key: pass if has_controlnet_key and has_motion_modules_key: controlnet_type = ControlWeightType.SPARSECTRL elif has_controlnet_key and has_temporal_res_block_key: controlnet_type = ControlWeightType.SVD_CONTROLNET if controlnet_type != ControlWeightType.DEFAULT: if controlnet_type == ControlWeightType.CONTROLLLLITE: control = load_controllllite(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe) elif controlnet_type == ControlWeightType.SPARSECTRL: control = load_sparsectrl(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe, model=model) elif controlnet_type == ControlWeightType.SVD_CONTROLNET: control = load_svdcontrolnet(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe) # otherwise, load vanilla ControlNet else: try: # hacky way of getting load_torch_file in load_controlnet to use already-present controlnet_data and not redo loading orig_load_torch_file = comfy.utils.load_torch_file comfy.utils.load_torch_file = load_torch_file_with_dict_factory(controlnet_data, orig_load_torch_file) control = comfy_cn.load_controlnet(ckpt_path, model=model) finally: comfy.utils.load_torch_file = orig_load_torch_file return convert_to_advanced(control, timestep_keyframe=timestep_keyframe) def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None): # if already advanced, leave it be if is_advanced_controlnet(control): return control # if exactly ControlNet returned, transform it into ControlNetAdvanced if type(control) == ControlNet: control = ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe) if is_sd3_advanced_controlnet(control): control.require_vae = True return control # if exactly ControlLora returned, transform it into ControlLoraAdvanced elif type(control) == ControlLora: return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe) # if T2IAdapter returned, transform it into T2IAdapterAdvanced elif isinstance(control, T2IAdapter): return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe) # otherwise, leave it be - might be something I am not supporting yet return control def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]: cache = {} modified = False new_conds = [] for cond in conds: converted_cond = None if cond is not None: need_to_convert = False # first, check if there is even a need to convert for sub_cond in cond: actual_cond = sub_cond[1] if "control" in actual_cond: if not are_all_advanced_controlnet(actual_cond["control"]): need_to_convert = True break if not need_to_convert: converted_cond = cond else: converted_cond = [] for sub_cond in cond: new_sub_cond: list = [] for actual_cond in sub_cond: if not type(actual_cond) == dict: new_sub_cond.append(actual_cond) continue if "control" not in actual_cond: new_sub_cond.append(actual_cond) elif are_all_advanced_controlnet(actual_cond["control"]): new_sub_cond.append(actual_cond) else: actual_cond = actual_cond.copy() actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache) new_sub_cond.append(actual_cond) modified = True converted_cond.append(new_sub_cond) new_conds.append(converted_cond) return modified, new_conds def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict): output_object = input_object # iteratively convert to advanced, if needed next_cn = None curr_cn = input_object iter = 0 while curr_cn is not None: if not is_advanced_controlnet(curr_cn): # if already in cache, then conversion was done before, so just link it and exit if curr_cn in cache: new_cn = cache[curr_cn] if next_cn is not None: setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet) next_cn.previous_controlnet = new_cn if iter == 0: # if was top-level controlnet, that's the new output output_object = new_cn break try: # convert to advanced, and assign previous_controlnet (convert doesn't transfer it) new_cn = convert_to_advanced(curr_cn) except Exception as e: raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e) new_cn.previous_controlnet = curr_cn.previous_controlnet if iter == 0: # if was top-level controlnet, that's the new output output_object = new_cn # if next_cn is present, then it needs to be pointed to new_cn if next_cn is not None: setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet) next_cn.previous_controlnet = new_cn # add to cache cache[curr_cn] = new_cn curr_cn = new_cn next_cn = curr_cn curr_cn = curr_cn.previous_controlnet iter += 1 return output_object def restore_all_controlnet_conns(conds: list[list[dict[str]]]): # if a cn has an _orig_previous_controlnet property, restore it and delete for main_cond in conds: if main_cond is not None: for cond in main_cond: if "control" in cond[1]: # if ACN is the one to have initialized it, delete it # TODO: maybe check if someone else did a similar hack, and carefully pluck out our stuff? if CONTROL_INIT_BY_ACN in cond[1]: cond[1].pop("control") cond[1].pop(CONTROL_INIT_BY_ACN) else: _restore_all_controlnet_conns(cond[1]["control"]) def _restore_all_controlnet_conns(input_object: ControlBase): # restore original previous_controlnet if needed curr_cn = input_object while curr_cn is not None: if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET): curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET) delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET) curr_cn = curr_cn.previous_controlnet def are_all_advanced_controlnet(input_object: ControlBase): # iteratively check if linked controlnets objects are all advanced curr_cn = input_object while curr_cn is not None: if not is_advanced_controlnet(curr_cn): return False curr_cn = curr_cn.previous_controlnet return True def is_advanced_controlnet(input_object): return hasattr(input_object, "sub_idxs") def is_sd3_advanced_controlnet(input_object: ControlNetAdvanced): return type(input_object) == ControlNetAdvanced and input_object.latent_format is not None def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced: if controlnet_data is None: controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) # first, separate out motion part from normal controlnet part and attempt to load that portion motion_data = {} for key in list(controlnet_data.keys()): if "temporal" in key: motion_data[key] = controlnet_data.pop(key) if len(motion_data) == 0: raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!") # now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function controlnet_config: dict[str] = None is_diffusers = False use_simplified_conditioning_embedding = False if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: is_diffusers = True if "controlnet_cond_embedding.weight" in controlnet_data: is_diffusers = True use_simplified_conditioning_embedding = True if is_diffusers: #diffusers format unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" count = 0 loop = True while loop: suffix = [".weight", ".bias"] for s in suffix: k_in = "controlnet_down_blocks.{}{}".format(count, s) k_out = "zero_convs.{}.0{}".format(count, s) if k_in not in controlnet_data: loop = False break diffusers_keys[k_in] = k_out count += 1 # normal conditioning embedding if not use_simplified_conditioning_embedding: count = 0 loop = True while loop: suffix = [".weight", ".bias"] for s in suffix: if count == 0: k_in = "controlnet_cond_embedding.conv_in{}".format(s) else: k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) k_out = "input_hint_block.{}{}".format(count * 2, s) if k_in not in controlnet_data: k_in = "controlnet_cond_embedding.conv_out{}".format(s) loop = False diffusers_keys[k_in] = k_out count += 1 # simplified conditioning embedding else: count = 0 suffix = [".weight", ".bias"] for s in suffix: k_in = "controlnet_cond_embedding{}".format(s) k_out = "input_hint_block.{}{}".format(count, s) diffusers_keys[k_in] = k_out new_sd = {} for k in diffusers_keys: if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: logger.info("leftover keys:", leftover_keys) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' pth = False key = 'zero_convs.0.0.weight' if pth_key in controlnet_data: pth = True key = pth_key prefix = "control_model." elif key in controlnet_data: prefix = "" else: raise ValueError("The provided model is not a valid SparseCtrl model! [ErrorCode: HORSERADISH]") if controlnet_config is None: unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = manual_cast_clean_groupnorm else: controlnet_config["operations"] = disable_weight_init_clean_groupnorm controlnet_config.pop("out_channels") # get proper hint channels if use_simplified_conditioning_embedding: controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding else: controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding control_model = SparseControlNet(**controlnet_config) if pth: if 'difference' in controlnet_data: if model is not None: comfy.model_management.load_models_gpu([model]) model_sd = model.model_state_dict() for x in controlnet_data: c_m = "control_model." if x.startswith(c_m): sd_key = "diffusion_model.{}".format(x[len(c_m):]) if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) else: logger.warning("WARNING: Loaded a diff SparseCtrl without a model. It will very likely not work.") class WeightsLoader(torch.nn.Module): pass w = WeightsLoader() w.control_model = control_model missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) if len(missing) > 0 or len(unexpected) > 0: logger.info(f"SparseCtrl ControlNet: {missing}, {unexpected}") global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True # actually load motion portion of model now motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data, ops=controlnet_config.get("operations", None)).to(comfy.model_management.unet_dtype()) missing, unexpected = motion_wrapper.load_state_dict(motion_data) if len(missing) > 0 or len(unexpected) > 0: logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}") # both motion portion and controlnet portions are loaded; bring them together if using motion model if sparse_settings.use_motion: motion_wrapper.inject(control_model) control = SparseCtrlAdvanced(control_model, timestep_keyframes=timestep_keyframe, sparse_settings=sparse_settings, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None): if controlnet_data is None: controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) controlnet_config = None if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format unet_dtype = comfy.model_management.unet_dtype() controlnet_config = svd_unet_config_from_diffusers_unet(controlnet_data, unet_dtype) diffusers_keys = svd_unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" count = 0 loop = True while loop: suffix = [".weight", ".bias"] for s in suffix: k_in = "controlnet_down_blocks.{}{}".format(count, s) k_out = "zero_convs.{}.0{}".format(count, s) if k_in not in controlnet_data: loop = False break diffusers_keys[k_in] = k_out count += 1 count = 0 loop = True while loop: suffix = [".weight", ".bias"] for s in suffix: if count == 0: k_in = "controlnet_cond_embedding.conv_in{}".format(s) else: k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) k_out = "input_hint_block.{}{}".format(count * 2, s) if k_in not in controlnet_data: k_in = "controlnet_cond_embedding.conv_out{}".format(s) loop = False diffusers_keys[k_in] = k_out count += 1 new_sd = {} for k in diffusers_keys: if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: spatial_leftover_keys = [] temporal_leftover_keys = [] other_leftover_keys = [] for key in leftover_keys: if "spatial" in key: spatial_leftover_keys.append(key) elif "temporal" in key: temporal_leftover_keys.append(key) else: other_leftover_keys.append(key) logger.warn(f"spatial_leftover_keys ({len(spatial_leftover_keys)}): {spatial_leftover_keys}") logger.warn(f"temporal_leftover_keys ({len(temporal_leftover_keys)}): {temporal_leftover_keys}") logger.warn(f"other_leftover_keys ({len(other_leftover_keys)}): {other_leftover_keys}") #print("leftover keys:", leftover_keys) controlnet_data = new_sd pth_key = 'control_model.zero_convs.0.0.weight' pth = False key = 'zero_convs.0.0.weight' if pth_key in controlnet_data: pth = True key = pth_key prefix = "control_model." elif key in controlnet_data: prefix = "" else: raise ValueError("The provided model is not a valid SVD-ControlNet model! [ErrorCode: MUSTARD]") if controlnet_config is None: unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config load_device = comfy.model_management.get_torch_device() manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = SVDControlNet(**controlnet_config) if pth: if 'difference' in controlnet_data: if model is not None: comfy.model_management.load_models_gpu([model]) model_sd = model.model_state_dict() for x in controlnet_data: c_m = "control_model." if x.startswith(c_m): sd_key = "diffusion_model.{}".format(x[len(c_m):]) if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) else: print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.") class WeightsLoader(torch.nn.Module): pass w = WeightsLoader() w.control_model = control_model missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) if len(missing) > 0 or len(unexpected) > 0: logger.info(f"SVD-ControlNet: {missing}, {unexpected}") global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True control = SVDControlNetAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control