# Code ported and modified from the diffusers ControlNetPlus repo by Qi Xin:
# https://github.com/xinsir6/ControlNetPlus/blob/main/models/controlnet_union.py
from typing import Union

import os
import torch
import torch.nn as nn
from torch import Tensor
from collections import OrderedDict


from comfy.ldm.modules.diffusionmodules.util import timestep_embedding

from comfy.cldm.cldm import ControlNet as ControlNetCLDM
from comfy.controlnet import ControlNet
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.model_management
import comfy.model_detection
import comfy.utils

from .utils import (AdvancedControlBase, ControlWeights, ControlWeightType, TimestepKeyframeGroup, AbstractPreprocWrapper,
                    extend_to_batch_size, broadcast_image_to_extend)
from .logger import logger


class PlusPlusType:
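    """String constants for the control types a ControlNet++ union model understands, plus their task-embedding indexes (e.g. to_idx(DEPTH) == 1)."""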
    OPENPOSE = "openpose"
    DEPTH = "depth"
    THICKLINE = "hed/pidi/scribble/ted"
    THINLINE = "canny/lineart/mlsd"
    NORMAL = "normal"
    SEGMENT = "segment"
    TILE = "tile"
    REPAINT = "inpaint/outpaint"
    NONE = "none"
    _LIST_WITH_NONE = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT, NONE]
    _LIST = [OPENPOSE, DEPTH, THICKLINE, THINLINE, NORMAL, SEGMENT, TILE, REPAINT]
    _DICT = {OPENPOSE: 0, DEPTH: 1, THICKLINE: 2, THINLINE: 3, NORMAL: 4, SEGMENT: 5, TILE: 6, REPAINT: 7, NONE: -1}

    @classmethod
    def to_idx(cls, control_type: str):
        try:
            return cls._DICT[control_type]
        except KeyError:
            raise Exception(f"Unknown control type '{control_type}'.")


class PlusPlusInput:
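    """A single conditioning image paired with its ControlNet++ control type and strength."""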
    def __init__(self, image: Tensor, control_type: str, strength: float):
        self.image = image
        self.control_type = control_type
        self.strength = strength

    def clone(self):
        return PlusPlusInput(self.image, self.control_type, self.strength)


class PlusPlusInputGroup:
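    """Collection of PlusPlusInput objects keyed by control type; at most one input per type is allowed.

    Minimal usage sketch (the image tensor is an assumption, not validated here):
        pp_group = PlusPlusInputGroup()
        pp_group.add(PlusPlusInput(depth_image, PlusPlusType.DEPTH, 1.0))
        wrapped = PlusPlusImageWrapper(pp_group)  # what Apply Advanced ControlNet expects
    """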
    def __init__(self):
        self.controls: dict[str, PlusPlusInput] = {}
    
    def add(self, pp_input: PlusPlusInput):
        if pp_input.control_type in self.controls:
            raise Exception(f"Control type '{pp_input.control_type}' is already present; ControlNet++ does not allow more than 1 of each type.")
        self.controls[pp_input.control_type] = pp_input
    
    def clone(self) -> 'PlusPlusInputGroup':
        cloned = PlusPlusInputGroup()
        for key, value in self.controls.items():
            cloned.controls[key] = value.clone()
        return cloned


class PlusPlusImageWrapper(AbstractPreprocWrapper):
    error_msg = "Invalid use of ControlNet++ Image Wrapper. The output of ControlNet++ Image Wrapper is NOT a usual image, but an object holding the images and extra info - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
    def __init__(self, condhint: PlusPlusInputGroup):
        super().__init__(condhint)
        # just an IDE type hint
        self.condhint: PlusPlusInputGroup

    def movedim(self, source: int, destination: int):
        condhint = self.condhint.clone()
        for pp_input in condhint.controls.values():
            pp_input.image = pp_input.image.movedim(source, destination)
        return PlusPlusImageWrapper(condhint)

# parts taken from comfy/cldm/cldm.py
class OptimizedAttention(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        x = self.in_proj(x)
        q, k, v = x.split(self.c, dim=2)
        out = optimized_attention(q, k, v, self.heads)
        return self.out_proj(out)

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class ResBlockUnionControlnet(nn.Module):
    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
        self.mlp = nn.Sequential(
            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)

    def attention(self, x: torch.Tensor):
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class ControlAddEmbeddingAdv(nn.Module):
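    """Embeds the per-type control strength vector with sinusoidal (timestep-style) embeddings, then projects it through a two-layer MLP to the time-embedding width."""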
    def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None):
        super().__init__()
        self.num_control_type = num_control_type
        self.in_dim = in_dim
        self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)

    def forward(self, control_type, dtype, device):
        if control_type is None:
            control_type = torch.zeros((self.num_control_type,), device=device)
        c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
        return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))


class ControlNetPlusPlus(ControlNetCLDM):
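    """Union ControlNet: extends comfy's cldm ControlNet with a control-type embedding added to the time embedding, fusing multiple hint images via task embeddings."""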
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        operations: comfy.ops.disable_weight_init = kwargs.get("operations", comfy.ops.disable_weight_init)
        device = kwargs.get("device", None)

        time_embed_dim = self.model_channels * 4
        control_add_embed_dim = 256

        self.control_add_embedding = ControlAddEmbeddingAdv(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)

    def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context):
        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
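        # indexes of the control types with nonzero strength (taken from the first batch element)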
        indexes = torch.nonzero(control_type[0])
        inputs = []
        condition_list = []

        for idx in range(indexes.shape[0]):
            controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context)
            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
            # idx is always < indexes.shape[0] inside this loop, so each active type gets its task embedding
            feat_seq += self.task_embedding[indexes[idx][0]].to(dtype=feat_seq.dtype, device=feat_seq.device)

            inputs.append(feat_seq.unsqueeze(1))
            condition_list.append(controlnet_cond)

        x = torch.cat(inputs, dim=1)
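        # let the pooled per-type features attend to each other ('transformer_layes' spelling is intentional; it matches the checkpoint keys)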
        x = self.transformer_layes(x)

        controlnet_cond_fuser = None
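        # project each refined feature to a per-channel bias and sum the biased condition maps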
        for idx in range(indexes.shape[0]):
            alpha = self.spatial_ch_projs(x[:, idx])
            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
            o = condition_list[idx] + alpha
            if controlnet_cond_fuser is None:
                controlnet_cond_fuser = o
            else:
                controlnet_cond_fuser += o
        return controlnet_cond_fuser

    def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=None, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = None
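        # union models fold the active control types into the time embedding before building the hint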
        if self.control_add_embedding is not None:
            control_type = kwargs.get("control_type", None)

            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
            if control_type is not None:
                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

        if guided_hint is None:
            guided_hint = self.input_hint_block(hint[0], emb, context)

        out_output = []
        out_middle = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x
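        # the fused hint is added only once, to the output of the first input block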
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            out_output.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        out_middle.append(self.middle_block_out(h, emb, context))

        return {"middle": out_middle, "output": out_output}


class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
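    """Advanced-ControlNet wrapper that feeds per-type hints and strengths to ControlNetPlusPlus; accepts a PlusPlusImageWrapper (multi) or a plain image plus single_control_type (single)."""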
    def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, load_device=None, manual_cast_dtype=None):
        super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
        self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)
        # for IDE type hint purposes
        self.control_model: ControlNetPlusPlus
        self.cond_hint_original: Union[PlusPlusImageWrapper, PlusPlusInputGroup]
        self.cond_hint: list[Union[Tensor, None]]
        self.cond_hint_shape: torch.Size = None
        self.cond_hint_types: Tensor = None
        # in case it is using the single loader
        self.single_control_type: str = None

    def get_universal_weights(self) -> ControlWeights:
        def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
            if key == "middle":
                return 1.0
            c_len = len(control[key])
            # geometric falloff: powers of base_multiplier from c_len down to 1
            raw_weights = [(self.weights.base_multiplier ** float(c_len - i)) for i in range(c_len)]
            if key == "input":
                raw_weights.reverse()
            return raw_weights[idx]
        return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)

    def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
        if pp_group is not None:
            for pp_input in pp_group.controls.values():
                if PlusPlusType.to_idx(pp_input.control_type) >= self.control_model.num_control_type:
                    raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{pp_input.control_type}'.")
        if self.single_control_type is not None:
            if PlusPlusType.to_idx(self.single_control_type) >= self.control_model.num_control_type:
                raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{self.single_control_type}'.")

    def set_cond_hint_inject(self, *args, **kwargs):
        to_return = super().set_cond_hint_inject(*args, **kwargs)
        # if not single_control_type, expect PlusPlusImageWrapper
        if self.single_control_type is None:
            # check that cond_hint is wrapped, and unwrap it
            if type(self.cond_hint_original) != PlusPlusImageWrapper:
                raise Exception("ControlNet++ (Multi) expects image input from the Load ControlNet++ Model node, NOT from anything else. Images are provided to that node via ControlNet++ Input nodes.")
            self.cond_hint_original = self.cond_hint_original.condhint.clone()
        # otherwise, expect single image input (AKA, usual controlnet input)
        else:
            # check that cond_hint is not a PlusPlusImageWrapper
            if type(self.cond_hint_original) == PlusPlusImageWrapper:
                raise Exception("ControlNet++ (Single) expects usual image input, NOT the image input from a Load ControlNet++ Model (Multi) node.")
            pp_group = PlusPlusInputGroup()
            pp_input = PlusPlusInput(self.cond_hint_original, self.single_control_type, 1.0)
            pp_group.add(pp_input)
            self.cond_hint_original = pp_group
        return to_return

    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return None

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype

        # make all cond_hints appropriate dimensions
        # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs is present
        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint_shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint_shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = [None] * self.control_model.num_control_type
            self.cond_hint_types = torch.tensor([0.0] * self.control_model.num_control_type)
            self.cond_hint_shape = None
            compression_ratio = self.compression_ratio
            # unlike normal controlnet, need to handle each input image tensor (for each type)
            for pp_type, pp_input in self.cond_hint_original.controls.items():
                pp_idx = PlusPlusType.to_idx(pp_type)
                # if negative, means no type should be selected (single only)
                if pp_idx < 0:
                    pp_idx = 0
                else:
                    self.cond_hint_types[pp_idx] = pp_input.strength
                # if sub_idxs are present, extend the hint to the full latent length and select only those frames
                if self.sub_idxs is not None:
                    actual_cond_hint_orig = pp_input.image
                    if pp_input.image.size(0) < self.full_latent_length:
                        actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
                else:
                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
                self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=x_noisy.device, dtype=dtype)
                self.cond_hint_shape = self.cond_hint[pp_idx].shape
            # if no control types are active, pass None; otherwise expand strengths to match the batch
            if self.cond_hint_types.count_nonzero() == 0:
                self.cond_hint_types = None
            else:
                self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=x_noisy.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
        for i in range(len(self.cond_hint)):
            if self.cond_hint[i] is not None:
                if x_noisy.shape[0] != self.cond_hint[i].shape[0]:
                    self.cond_hint[i] = broadcast_image_to_extend(self.cond_hint[i], x_noisy.shape[0], batched_number)
        if self.cond_hint_types is not None and x_noisy.shape[0] != self.cond_hint_types.shape[0]:
            self.cond_hint_types = broadcast_image_to_extend(self.cond_hint_types, x_noisy.shape[0], batched_number, False)
        
        # prepare mask_cond_hint
        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)

        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, control_type=self.cond_hint_types)
        return self.control_merge(control, control_prev, output_dtype)

    def copy(self):
        c = ControlNetPlusPlusAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        self.copy_to(c)
        self.copy_to_advanced(c)
        c.single_control_type = self.single_control_type
        return c


def load_controlnetplusplus(ckpt_path: str, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
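    """Load a ControlNet++ checkpoint (diffusers or .pth layout) into a ControlNetPlusPlusAdvanced instance."""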
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    # check that actually is ControlNet++ model
    if "task_embedding" not in controlnet_data:
        raise Exception(f"'{ckpt_path}' is not a valid ControlNet++ model.")

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        count = 0
        loop = True
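        # remap diffusers 'controlnet_down_blocks.{n}' keys to 'zero_convs.{n}.0' until they run out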
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        count = 0
        loop = True
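        # remap the cond embedding convs: conv_in, then blocks.{n-1}, and finally conv_out onto 'input_hint_block.{2n}'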
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
        
        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
            controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
            for k in list(controlnet_data.keys()):
                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
                new_sd[new_k] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            logger.warning("leftover ControlNet++ keys: {}".format(leftover_keys))
        controlnet_data = new_sd
    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
        raise Exception("Unexpected SD3 diffusers format for ControlNet++ model. Something is very wrong.")

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        raise Exception("Unexpected T2IAdapter format for ControlNet++ model. Something is very wrong.")

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = ControlNetPlusPlus(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
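            # 'difference' checkpoints store deltas from a base model; add the base weights back in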
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logger.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logger.warning("missing ControlNet++ keys: {}".format(missing))

    if len(unexpected) > 0:
        logger.debug("unexpected ControlNet++ keys: {}".format(unexpected))

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNetPlusPlusAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control