File size: 12,372 Bytes
028694a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
from typing import Callable, Union

import comfy.sample
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
from comfy.ldm.modules.attention import BasicTransformerBlock


from .control import convert_all_to_advanced, restore_all_controlnet_conns
from .control_reference import (ReferenceAdvanced, ReferenceInjections,
                                RefBasicTransformerBlock, RefTimestepEmbedSequential,
                                InjectionBasicTransformerBlockHolder, InjectionTimestepEmbedSequentialHolder,
                                _forward_inject_BasicTransformerBlock, factory_forward_inject_UNetModel,
                                handle_context_ref_setup,
                                REF_CONTROL_LIST_ALL, CONTEXTREF_CLEAN_FUNC)
from .control_lllite import (ControlLLLiteAdvanced)
from .utils import torch_dfs


def support_sliding_context_windows(model, positive, negative) -> tuple[bool, dict, dict]:
    """Convert every ControlNet attached to the given conds to its Advanced form.

    ``model`` is accepted for interface compatibility but is not read here.

    Returns:
        ``(modified, positive, negative)`` where ``modified`` is the flag
        reported by ``convert_all_to_advanced`` (whether anything changed) and
        ``positive``/``negative`` are the conds it handed back.
    """
    was_modified, (new_positive, new_negative) = convert_all_to_advanced([positive, negative])
    return was_modified, new_positive, new_negative


def has_sliding_context_windows(model):
    """Report whether ``model`` has sliding context windows configured.

    Follows ``model.motion_injection_params.context_options.context_length``.
    Any missing or ``None`` link in that chain is treated as "not configured"
    instead of raising ``AttributeError``.

    Returns:
        True if ``context_length`` is present and not None, otherwise False.
    """
    motion_injection_params = getattr(model, "motion_injection_params", None)
    if motion_injection_params is None:
        return False
    context_options = getattr(motion_injection_params, "context_options", None)
    if context_options is None:
        return False
    return getattr(context_options, "context_length", None) is not None


def get_contextref_obj(model):
    """Return the ContextRef object configured on ``model``, if any.

    Walks ``model.motion_injection_params.context_options.extras.context_ref``
    and returns None as soon as any link in that chain is missing or None.
    """
    motion_injection_params = getattr(model, "motion_injection_params", None)
    if motion_injection_params is None:
        return None
    context_options = getattr(motion_injection_params, "context_options", None)
    if context_options is None:
        return None
    extras = getattr(context_options, "extras", None)
    if extras is None:
        return None
    return getattr(extras, "context_ref", None)


def acn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
    """Wrap a comfy sampling function so Advanced-ControlNet features work.

    Before delegating to ``orig_comfy_sample``, the returned ``acn_sample``:
      * converts attached ControlNets to Advanced ones and sets up ContextRef
        when the model has sliding context windows;
      * applies ControlLLLite patches to a copy of ``model.model_options``;
      * injects Reference hooks into the UNet's BasicTransformerBlocks,
        selected TimestepEmbedSequential blocks and ``diffusion_model.forward``
        when any ReferenceAdvanced (or ContextRef) is present.
    Injections, ``model.model_options`` and modified conds are restored in
    ``finally`` clauses, so they are undone even if sampling raises.

    Args:
        orig_comfy_sample: the sampling function being wrapped.
        is_custom: not read by this factory.

    Returns:
        The wrapper ``acn_sample(model, *args, **kwargs)``.
    """
    def get_refcn(control: ControlBase, order: int=-1):
        """Collect ReferenceAdvanced controls along a ``previous_controlnet`` chain.

        Only exact ``ReferenceAdvanced`` instances (subclasses are not matched)
        whose ``is_context_ref`` is false are collected. Each one found has
        ``control.order`` set to -1, -2, ... in encounter order, so sorting
        ascending by ``order`` puts the most deeply nested one first.
        """
        ref_set: set[ReferenceAdvanced] = set()
        if control is None:
            return ref_set
        if type(control) == ReferenceAdvanced and not control.is_context_ref:
            control.order = order
            order -= 1
            ref_set.add(control)
        ref_set.update(get_refcn(control.previous_controlnet, order=order))
        return ref_set

    def get_lllitecn(control: ControlBase):
        """Collect ControlLLLiteAdvanced controls along a ``previous_controlnet`` chain.

        Returns a dict used as an insertion-ordered set (values are None),
        preserving the order in which controls were encountered.
        """
        cn_dict: dict[ControlLLLiteAdvanced,None] = {}
        if control is None:
            return cn_dict
        if type(control) == ControlLLLiteAdvanced:
            cn_dict[control] = None
        cn_dict.update(get_lllitecn(control.previous_controlnet))
        return cn_dict

    def acn_sample(model: ModelPatcher, *args, **kwargs):
        """Run ``orig_comfy_sample`` with Advanced-ControlNet support enabled.

        ``args[-3]`` and ``args[-2]`` are treated as the positive and negative
        conds; each entry's ``cond[1]`` is a dict that may hold a ``"control"``.
        """
        orig_positive = args[-3]
        orig_negative = args[-2]
        # captured before the try so the outer finally can always restore it
        orig_model_options = model.model_options
        controlnets_modified = False
        try:
            positive = orig_positive
            negative = orig_negative
            # if context options present, perform some special actions that may be required
            context_refs = []
            if has_sliding_context_windows(model):
                # copy options (one level into transformer_options) before writing to them
                model.model_options = model.model_options.copy()
                model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                # convert all CNs to Advanced if needed
                controlnets_modified, positive, negative = support_sliding_context_windows(model, positive, negative)
                if controlnets_modified:
                    args = list(args)
                    args[-3] = positive
                    args[-2] = negative
                    args = tuple(args)
                # enable ContextRef, if requested
                existing_contextref_obj = get_contextref_obj(model)
                if existing_contextref_obj is not None:
                    context_refs = handle_context_ref_setup(existing_contextref_obj, model.model_options["transformer_options"], positive, negative)
                    # makes the outer finally call restore_all_controlnet_conns
                    controlnets_modified = True
            # look for Advanced ControlNets that will require intervention to work
            ref_set = set()
            lllite_dict: dict[ControlLLLiteAdvanced, None] = {}  # dicts preserve insertion order since py3.7
            for conds in (positive, negative):
                if conds is None:
                    continue
                for cond in conds:
                    if "control" in cond[1]:
                        ref_set.update(get_refcn(cond[1]["control"]))
                        lllite_dict.update(get_lllitecn(cond[1]["control"]))
            # if lllite found, apply patches to a cloned model_options, and continue
            if len(lllite_dict) > 0:
                lllite_list = list(lllite_dict.keys())
                model.model_options = model.model_options.copy()
                model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                lllite_list.reverse()  # reverse so that patches will be applied in expected order
                for lll in lllite_list:
                    lll.live_model_patches(model.model_options)
            # if no ref cn found, do original function immediately
            if len(ref_set) == 0 and len(context_refs) == 0:
                return orig_comfy_sample(model, *args, **kwargs)
            # otherwise, injection time
            # storage for all Reference-related injections; created before the try so
            # its finally never references an unbound name if construction fails
            reference_injections = ReferenceInjections()
            try:
                # first, handle attn module injection
                all_modules = torch_dfs(model.model)
                attn_modules: list[RefBasicTransformerBlock] = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
                # larger norm1 width first; attn_weight below grows with the index
                attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
                for i, module in enumerate(attn_modules):
                    injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
                    injection_holder.attn_weight = float(i) / float(len(attn_modules))
                    if hasattr(module, "_forward"):  # backward compatibility
                        module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
                    else:
                        module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
                    module.injection_holder = injection_holder
                    reference_injections.attn_modules.append(module)
                # figure out which module is middle block
                if hasattr(model.model.diffusion_model, "middle_block"):
                    mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
                    mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
                    for module in mid_attn_modules:
                        module.injection_holder.is_middle = True

                # next, handle gn module injection (TimestepEmbedSequential)
                # TODO: figure out the logic behind these hardcoded indexes
                if type(model.model).__name__ == "SDXL":
                    input_block_indices = [4, 5, 7, 8]
                    output_block_indices = [0, 1, 2, 3, 4, 5]
                else:
                    input_block_indices = [4, 5, 7, 8, 10, 11]
                    output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
                if hasattr(model.model.diffusion_model, "middle_block"):
                    module = model.model.diffusion_model.middle_block
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
                    injection_holder.gn_weight = 0.0
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                for w, i in enumerate(input_block_indices):
                    module = model.model.diffusion_model.input_blocks[i]
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
                    injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                for w, i in enumerate(output_block_indices):
                    module = model.model.diffusion_model.output_blocks[i]
                    injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
                    injection_holder.gn_weight = float(w) / float(len(output_block_indices))
                    module.injection_holder = injection_holder
                    reference_injections.gn_modules.append(module)
                # hack gn_module forwards and update weights
                for module in reference_injections.gn_modules:
                    module.injection_holder.gn_weight *= 2

                # handle diffusion_model forward injection
                reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
                model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
                # store ordered ref cns in model's transformer options
                new_model_options = model.model_options.copy()
                new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
                ref_list: list[ReferenceAdvanced] = list(ref_set)
                new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
                new_model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC] = reference_injections.clean_contextref_module_mem
                model.model_options = new_model_options
                # continue with original function
                return orig_comfy_sample(model, *args, **kwargs)
            finally:
                # cleanup injections
                # restore attn modules
                for module in reference_injections.attn_modules:
                    module.injection_holder.restore(module)
                    module.injection_holder.clean_all()
                    del module.injection_holder
                # restore gn modules
                for module in reference_injections.gn_modules:
                    module.injection_holder.restore(module)
                    module.injection_holder.clean_all()
                    del module.injection_holder
                # restore diffusion_model forward function, but only if it was actually
                # replaced above (an earlier failure leaves it unset / None — assumes
                # ReferenceInjections does not give it another default; TODO confirm)
                orig_forward = getattr(reference_injections, "diffusion_model_orig_forward", None)
                if orig_forward is not None:
                    model.model.diffusion_model.forward = orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
                # cleanup
                reference_injections.cleanup()
        finally:
            # restore model_options
            model.model_options = orig_model_options
            # restore controlnets in conds, if needed
            if controlnets_modified:
                restore_all_controlnet_conns([orig_positive, orig_negative])

    return acn_sample