gartajackhats1985's picture
Upload 45 files
028694a verified
raw
history blame
12.4 kB
from typing import Callable, Union
import comfy.sample
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
from comfy.ldm.modules.attention import BasicTransformerBlock
from .control import convert_all_to_advanced, restore_all_controlnet_conns
from .control_reference import (ReferenceAdvanced, ReferenceInjections,
RefBasicTransformerBlock, RefTimestepEmbedSequential,
InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
_forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel,
handle_context_ref_setup,
REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC)
from .control_lllite import (ControlLLLiteAdvanced)
from .utils import torch_dfs
def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
# convert to advanced, with report if anything was actually modified
modified, new_conds = convert_all_to_advanced([positive, negative])
positive, negative = new_conds
return modified, positive, negative
def has_sliding_context_windows(model):
motion_injection_params = getattr(model, "motion_injection_params", None)
if motion_injection_params is None:
return False
context_options = getattr(motion_injection_params, "context_options")
return context_options.context_length is not None
def get_contextref_obj(model):
motion_injection_params = getattr(model, "motion_injection_params", None)
if motion_injection_params is None:
return None
context_options = getattr(motion_injection_params, "context_options")
extras = getattr(context_options, "extras", None)
if extras is None:
return None
return getattr(extras, "context_ref", None)
def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
def get_refcn(control: ControlBase, order: int=-1):
ref_set: set[ReferenceAdvanced] = set()
if control is None:
return ref_set
if type(control) == ReferenceAdvanced and not control.is_context_ref:
control.order = order
order -= 1
ref_set.add(control)
ref_set.update(get_refcn(control.previous_controlnet, order=order))
return ref_set
def get_lllitecn(control: ControlBase):
cn_dict: dict[ControlLLLiteAdvanced,None] = {}
if control is None:
return cn_dict
if type(control) == ControlLLLiteAdvanced:
cn_dict[control] = None
cn_dict.update(get_lllitecn(control.previous_controlnet))
return cn_dict
def acn_sample(model: ModelPatcher, *args, **kwargs):
controlnets_modified = False
orig_positive = args[-3]
orig_negative = args[-2]
try:
orig_model_options = model.model_options
# check if positive or negative conds contain ref cn
positive = args[-3]
negative = args[-2]
# if context options present, perform some special actions that may be required
context_refs = []
if has_sliding_context_windows(model):
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
# convert all CNs to Advanced if needed
controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
if controlnets_modified:
args = list(args)
args[-3] = positive
args[-2] = negative
args = tuple(args)
# enable ContextRef, if requested
existing_contextref_obj = get_contextref_obj(model)
if existing_contextref_obj is not None:
context_refs = handle_context_ref_setup(existing_contextref_obj, model.model_options["transformer_options"], positive, negative)
controlnets_modified = True
# look for Advanced ControlNets that will require intervention to work
ref_set = set()
lllite_dict: dict[ControlLLLiteAdvanced, None] = {} # dicts preserve insertion order since py3.7
if positive is not None:
for cond in positive:
if "control" in cond[1]:
ref_set.update(get_refcn(cond[1]["control"]))
lllite_dict.update(get_lllitecn(cond[1]["control"]))
if negative is not None:
for cond in negative:
if "control" in cond[1]:
ref_set.update(get_refcn(cond[1]["control"]))
lllite_dict.update(get_lllitecn(cond[1]["control"]))
# if lllite found, apply patches to a cloned model_options, and continue
if len(lllite_dict) > 0:
lllite_list = list(lllite_dict.keys())
model.model_options = model.model_options.copy()
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
lllite_list.reverse() # reverse so that patches will be applied in expected order
for lll in lllite_list:
lll.live_model_patches(model.model_options)
# if no ref cn found, do original function immediately
if len(ref_set) == 0 and len(context_refs) == 0:
return orig_comfy_sample(model, *args, **kwargs)
# otherwise, injection time
try:
# inject
# storage for all Reference-related injections
reference_injections = ReferenceInjections()
# first, handle attn module injection
all_modules = torch_dfs(model.model)
attn_modules: list[RefBasicTransformerBlock] = []
for module in all_modules:
if isinstance(module, BasicTransformerBlock):
attn_modules.append(module)
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
for i, module in enumerate(attn_modules):
injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
injection_holder.attn_weight = float(i) / float(len(attn_modules))
if hasattr(module, "_forward"): # backward compatibility
module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
else:
module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
module.injection_holder = injection_holder
reference_injections.attn_modules.append(module)
# figure out which module is middle block
if hasattr(model.model.diffusion_model, "middle_block"):
mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
for module in mid_attn_modules:
module.injection_holder.is_middle = True
# next, handle gn module injection (TimestepEmbedSequential)
# TODO: figure out the logic behind these hardcoded indexes
if type(model.model).__name__ == "SDXL":
input_block_indices = [4, 5, 7, 8]
output_block_indices = [0, 1, 2, 3, 4, 5]
else:
input_block_indices = [4, 5, 7, 8, 10, 11]
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
if hasattr(model.model.diffusion_model, "middle_block"):
module = model.model.diffusion_model.middle_block
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
injection_holder.gn_weight = 0.0
module.injection_holder = injection_holder
reference_injections.gn_modules.append(module)
for w, i in enumerate(input_block_indices):
module = model.model.diffusion_model.input_blocks[i]
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
module.injection_holder = injection_holder
reference_injections.gn_modules.append(module)
for w, i in enumerate(output_block_indices):
module = model.model.diffusion_model.output_blocks[i]
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
injection_holder.gn_weight = float(w) / float(len(output_block_indices))
module.injection_holder = injection_holder
reference_injections.gn_modules.append(module)
# hack gn_module forwards and update weights
for i, module in enumerate(reference_injections.gn_modules):
module.injection_holder.gn_weight *= 2
# handle diffusion_model forward injection
reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
# store ordered ref cns in model's transformer options
new_model_options = model.model_options.copy()
new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
ref_list: list[ReferenceAdvanced] = list(ref_set)
new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
model.model_options = new_model_options
# continue with original function
return orig_comfy_sample(model, *args, **kwargs)
finally:
# cleanup injections
# restore attn modules
attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
for module in attn_modules:
module.injection_holder.restore(module)
module.injection_holder.clean_all()
del module.injection_holder
del attn_modules
# restore gn modules
gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
for module in gn_modules:
module.injection_holder.restore(module)
module.injection_holder.clean_all()
del module.injection_holder
del gn_modules
# restore diffusion_model forward function
model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
# cleanup
reference_injections.cleanup()
finally:
# restore model_options
model.model_options = orig_model_options
# restore controlnets in conds, if needed
if controlnets_modified:
restore_all_controlnet_conns([orig_positive, orig_negative])
return acn_sample