Spaces:
Running
Running
from typing import Callable, Union | |
import comfy.sample | |
from comfy.model_patcher import ModelPatcher | |
from comfy.controlnet import ControlBase | |
from comfy.ldm.modules.attention import BasicTransformerBlock | |
from .control import convert_all_to_advanced, restore_all_controlnet_conns | |
from .control_reference import (ReferenceAdvanced, ReferenceInjections, | |
RefBasicTransformerBlock, RefTimestepEmbedSequential, | |
InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder, | |
_forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel, | |
handle_context_ref_setup, | |
REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC) | |
from .control_lllite import (ControlLLLiteAdvanced) | |
from .utils import torch_dfs | |
def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
    """Run both conds through ``convert_all_to_advanced`` and report the outcome.

    Args:
        model: Accepted for the caller's convenience; not read by this function.
        positive: Positive conds, passed as the first element of the list given
            to ``convert_all_to_advanced``.
        negative: Negative conds, passed as the second element.

    Returns:
        ``(modified, positive, negative)`` where ``modified`` is the flag returned
        by ``convert_all_to_advanced`` and the two conds are the ones it returned.
    """
    # the helper hands back (flag, [new_positive, new_negative])
    was_modified, (positive, negative) = convert_all_to_advanced([positive, negative])
    return was_modified, positive, negative
def has_sliding_context_windows(model):
    """Tell whether ``model`` carries context options with a context length set.

    Returns ``False`` when ``model`` has no ``motion_injection_params`` attribute
    (or it is ``None``); otherwise returns whether
    ``motion_injection_params.context_options.context_length`` is not ``None``.

    Raises:
        AttributeError: if ``motion_injection_params`` is present but lacks
            ``context_options`` (or that lacks ``context_length``).
    """
    params = getattr(model, "motion_injection_params", None)
    if params is not None:
        return params.context_options.context_length is not None
    return False
def get_contextref_obj(model):
    """Fetch ``motion_injection_params.context_options.extras.context_ref`` from ``model``.

    Returns ``None`` when ``motion_injection_params`` is missing or ``None``, when
    ``extras`` is missing or ``None``, or when ``extras`` has no ``context_ref``.

    Raises:
        AttributeError: if ``motion_injection_params`` is present but has no
            ``context_options`` attribute.
    """
    params = getattr(model, "motion_injection_params", None)
    if params is None:
        return None
    # context_options is required once motion_injection_params exists
    extras = getattr(params.context_options, "extras", None)
    return None if extras is None else getattr(extras, "context_ref", None)
def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
    """Build a wrapper around a sampling function that sets up and tears down
    the model hooks needed by Reference, ContextRef and ControlLLLite controls.

    Args:
        orig_comfy_sample: The sampling callable to wrap. It is invoked as
            ``orig_comfy_sample(model, *args, **kwargs)``.
        is_custom: Accepted for interface compatibility; not read here.

    Returns:
        ``acn_sample``, which has the same calling convention as
        ``orig_comfy_sample`` and returns whatever it returns. The wrapper
        expects the positive conds at ``args[-3]`` and the negative conds at
        ``args[-2]``.
    """
    def get_refcn(control: ControlBase, order: int=-1):
        """Collect non-ContextRef ``ReferenceAdvanced`` controls from a control chain.

        Follows ``previous_controlnet`` links recursively. Each collected control
        gets ``control.order`` assigned, starting at ``order`` and decreasing by
        one per match along the chain.
        """
        ref_set: set[ReferenceAdvanced] = set()
        if control is None:
            return ref_set
        # exact type comparison is kept from the original (subclasses are not matched)
        if type(control) == ReferenceAdvanced and not control.is_context_ref:
            control.order = order
            order -= 1
            ref_set.add(control)
        ref_set.update(get_refcn(control.previous_controlnet, order=order))
        return ref_set

    def get_lllitecn(control: ControlBase):
        """Collect ``ControlLLLiteAdvanced`` controls from a control chain.

        A dict with ``None`` values is used as an insertion-ordered set.
        """
        cn_dict: dict[ControlLLLiteAdvanced, None] = {}
        if control is None:
            return cn_dict
        if type(control) == ControlLLLiteAdvanced:
            cn_dict[control] = None
        cn_dict.update(get_lllitecn(control.previous_controlnet))
        return cn_dict

    def acn_sample(model: ModelPatcher, *args, **kwargs):
        """Run ``orig_comfy_sample`` with temporary injections, restoring state afterwards."""
        controlnets_modified = False
        orig_positive = args[-3]
        orig_negative = args[-2]
        # captured before the try so the cleanup below can never hit an unbound name
        orig_model_options = model.model_options
        try:
            # check if positive or negative conds contain ref cn
            positive = orig_positive
            negative = orig_negative
            # if context options present, perform some special actions that may be required
            context_refs = []
            if has_sliding_context_windows(model):
                # work on copies so the caller's options dicts are not mutated
                model.model_options = model.model_options.copy()
                model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                # convert all CNs to Advanced if needed
                controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
                if controlnets_modified:
                    args = list(args)
                    args[-3] = positive
                    args[-2] = negative
                    args = tuple(args)
                # enable ContextRef, if requested
                existing_contextref_obj = get_contextref_obj(model)
                if existing_contextref_obj is not None:
                    context_refs = handle_context_ref_setup(existing_contextref_obj, model.model_options["transformer_options"], positive, negative)
                    controlnets_modified = True
            # look for Advanced ControlNets that will require intervention to work
            ref_set = set()
            lllite_dict: dict[ControlLLLiteAdvanced, None] = {}  # dicts preserve insertion order since py3.7
            for conds in (positive, negative):
                if conds is None:
                    continue
                for cond in conds:
                    if "control" in cond[1]:
                        ref_set.update(get_refcn(cond[1]["control"]))
                        lllite_dict.update(get_lllitecn(cond[1]["control"]))
            # if lllite found, apply patches to a cloned model_options, and continue
            if lllite_dict:
                model.model_options = model.model_options.copy()
                model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                # reversed so that patches will be applied in expected order
                for lll in reversed(list(lllite_dict)):
                    lll.live_model_patches(model.model_options)
            # if no ref cn found, do original function immediately
            if not ref_set and len(context_refs) == 0:
                return orig_comfy_sample(model, *args, **kwargs)
            # otherwise, injection time
            # storage for all Reference-related injections; created outside the try
            # so the cleanup in `finally` always has it available
            reference_injections = ReferenceInjections()
            # set only once diffusion_model.forward has actually been replaced
            forward_patched = False
            try:
                # first, handle attn module injection
                all_modules = torch_dfs(model.model)
                attn_modules: list[RefBasicTransformerBlock] = [
                    module for module in all_modules if isinstance(module, BasicTransformerBlock)
                ]
                # descending by norm1 width; sorted() is stable so ties keep traversal order
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
                for i, module in enumerate(attn_modules):
                    injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
                    injection_holder.attn_weight = float(i) / float(len(attn_modules))
                    # bind the replacement as a method of this particular module
                    bound_forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
                    if hasattr(module, "_forward"):  # backward compatibility
                        module._forward = bound_forward
                    else:
                        module.forward = bound_forward
                    module.injection_holder = injection_holder
                    reference_injections.attn_modules.append(module)
                # figure out which module is middle block
                diffusion_model = model.model.diffusion_model
                has_middle_block = hasattr(diffusion_model, "middle_block")
                if has_middle_block:
                    for module in torch_dfs(diffusion_model.middle_block):
                        if isinstance(module, BasicTransformerBlock):
                            module.injection_holder.is_middle = True
                # next, handle gn module injection (TimestepEmbedSequential)
                # TODO: figure out the logic behind these hardcoded indexes
                if type(model.model).__name__ == "SDXL":
                    input_block_indices = [4, 5, 7, 8]
                    output_block_indices = [0, 1, 2, 3, 4, 5]
                else:
                    input_block_indices = [4, 5, 7, 8, 10, 11]
                    output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
                if has_middle_block:
                    module = diffusion_model.middle_block
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
                    injection_holder.gn_weight = 0.0
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                for w, i in enumerate(input_block_indices):
                    module = diffusion_model.input_blocks[i]
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
                    # input weights fall from 1.0 towards 0 along the listed blocks
                    injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                for w, i in enumerate(output_block_indices):
                    module = diffusion_model.output_blocks[i]
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
                    # output weights rise from 0 along the listed blocks
                    injection_holder.gn_weight = float(w) / float(len(output_block_indices))
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                # double every gn weight computed above
                for module in reference_injections.gn_modules:
                    module.injection_holder.gn_weight *= 2
                # handle diffusion_model forward injection
                reference_injections.diffusion_model_orig_forward = diffusion_model.forward
                diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(diffusion_model, type(diffusion_model))
                forward_patched = True
                # store ordered ref cns in model's transformer options
                new_model_options = model.model_options.copy()
                new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_set, key=lambda x: x.order)
                new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
                model.model_options = new_model_options
                # continue with original function
                return orig_comfy_sample(model, *args, **kwargs)
            finally:
                # cleanup injections; only modules that were registered above are touched,
                # so this is also safe when injection failed part-way through
                # restore attn modules
                for module in reference_injections.attn_modules:
                    module.injection_holder.restore(module)
                    module.injection_holder.clean_all()
                    del module.injection_holder
                # restore gn modules
                for module in reference_injections.gn_modules:
                    module.injection_holder.restore(module)
                    module.injection_holder.clean_all()
                    del module.injection_holder
                # restore diffusion_model forward function, but only if it was replaced;
                # otherwise there is no saved forward to put back
                if forward_patched:
                    diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(diffusion_model, type(diffusion_model))
                # cleanup
                reference_injections.cleanup()
        finally:
            # restore model_options
            model.model_options = orig_model_options
            # restore controlnets in conds, if needed
            if controlnets_modified:
                restore_all_controlnet_conns([orig_positive, orig_negative])

    return acn_sample