# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
# basically, all of the LLLite core code is from there; I then combined it with
# Advanced-ControlNet features and QoL improvements
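# This file provides:
#   * LLLitePatch / LLLiteModule - the attn1/attn2 patches and the small adapter networks they apply
#   * ControlLLLiteAdvanced - an Advanced-ControlNet wrapper that behaves like a regular ControlNet object
#   * load_controllllite - loads an LLLite state dict and builds the objects above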
import math
from typing import Union
from torch import Tensor
import torch
import os
import comfy.utils
import comfy.ops
import comfy.model_management
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
from .logger import logger
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, ControlWeights, broadcast_image_to_extend, extend_to_batch_size,
deepcopy_with_sharing, prepare_mask_batch)
# based on set_model_patch code in comfy/model_patcher.py
def set_model_patch(model_options, patch, name):
to = model_options["transformer_options"]
# check if patch was already added
if "patches" in to:
current_patches = to["patches"].get(name, [])
if patch in current_patches:
return
if "patches" not in to:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_attn1_patch(model_options, patch):
set_model_patch(model_options, patch, "attn1_patch")
def set_model_attn2_patch(model_options, patch):
set_model_patch(model_options, patch, "attn2_patch")
def extra_options_to_module_prefix(extra_options):
# extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
# block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
# ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
# transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
    # block_index is: 0-1 or 0-9, depending on the block
    # (input 7 and 8, and the middle block, have 10 transformer blocks)
# make module name from extra_options
block = extra_options["block"]
block_index = extra_options["block_index"]
if block[0] == "input":
module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
elif block[0] == "middle":
module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
elif block[0] == "output":
module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
else:
raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
return module_pfx
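# LLLitePatch is registered as a ComfyUI attn1/attn2 patch: on each attention call it looks up the
# LLLite modules that match the current block (derived from extra_options) and adds their
# conditioned residuals to q, k, and v before attention is computed.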
class LLLitePatch:
ATTN1 = "attn1"
ATTN2 = "attn2"
def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None):
self.modules = modules
self.control = control
self.patch_type = patch_type
#logger.error(f"create LLLitePatch: {id(self)},{control}")
def __call__(self, q, k, v, extra_options):
#logger.error(f"in __call__: {id(self)}")
# determine if have anything to run
if self.control.timestep_range is not None:
# it turns out comparing single-value tensors to floats is extremely slow
# a: Tensor = extra_options["sigmas"][0]
if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
return q, k, v
module_pfx = extra_options_to_module_prefix(extra_options)
is_attn1 = q.shape[-1] == k.shape[-1] # self attention
if is_attn1:
module_pfx = module_pfx + "_attn1"
else:
module_pfx = module_pfx + "_attn2"
module_pfx_to_q = module_pfx + "_to_q"
module_pfx_to_k = module_pfx + "_to_k"
module_pfx_to_v = module_pfx + "_to_v"
if module_pfx_to_q in self.modules:
q = q + self.modules[module_pfx_to_q](q, self.control)
if module_pfx_to_k in self.modules:
k = k + self.modules[module_pfx_to_k](k, self.control)
if module_pfx_to_v in self.modules:
v = v + self.modules[module_pfx_to_v](v, self.control)
return q, k, v
def to(self, device):
#logger.info(f"to... has control? {self.control}")
for d in self.modules.keys():
self.modules[d] = self.modules[d].to(device)
return self
def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
self.control = control
return self
#logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}")
def clone_with_control(self, control: AdvancedControlBase):
#logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}")
return LLLitePatch(self.modules, self.patch_type, control)
def cleanup(self):
#total_cleaned = 0
for module in self.modules.values():
module.cleanup()
# total_cleaned += 1
#logger.info(f"cleaned modules: {total_cleaned}, {id(self)}")
#logger.error(f"cleanup LLLitePatch: {id(self)}")
# make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
# def __deepcopy__(self, memo):
# self.cleanup()
# to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
# #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
# try:
# if self.patch_type == self.ATTN1:
# to_return.control.patch_attn1 = to_return
# elif self.patch_type == self.ATTN2:
# to_return.control.patch_attn2 = to_return
# except Exception:
# pass
# return to_return
# TODO: use comfy.ops to support fp8 properly
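# LLLiteModule encodes the pixel-space cond hint with a small conv stack (conditioning1),
# concatenates that embedding with a down-projected copy of the attention input, and projects
# the result back up to the input dim; the output is the residual added to q/k/v above.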
class LLLiteModule(torch.nn.Module):
def __init__(
self,
name: str,
is_conv2d: bool,
in_dim: int,
depth: int,
cond_emb_dim: int,
mlp_dim: int,
):
super().__init__()
self.name = name
self.is_conv2d = is_conv2d
self.is_first = False
modules = []
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8 is too large, so set it to 4
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
self.conditioning1 = torch.nn.Sequential(*modules)
if self.is_conv2d:
self.down = torch.nn.Sequential(
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
)
else:
self.down = torch.nn.Sequential(
torch.nn.Linear(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Linear(mlp_dim, in_dim),
)
self.depth = depth
self.cond_emb = None
self.cx_shape = None
self.prev_batch = 0
self.prev_sub_idxs = None
def cleanup(self):
del self.cond_emb
self.cond_emb = None
self.cx_shape = None
self.prev_batch = 0
self.prev_sub_idxs = None
def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
mask = None
mask_tk = None
#logger.info(x.shape)
if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
# print(f"cond_emb is None, {self.name}")
cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
            # common_upscale expects (width, height); latent_dims_div2/div4 store (height, width)
            if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[1] * 8, control.latent_dims_div2[0] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
            elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
                cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[1] * 8, control.latent_dims_div4[0] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
cx = self.conditioning1(cond_hint)
self.cx_shape = cx.shape
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
self.cond_emb = cx
# save prev values
self.prev_batch = x.shape[0]
self.prev_sub_idxs = control.sub_idxs
cx: torch.Tensor = self.cond_emb
# print(f"forward {self.name}, {cx.shape}, {x.shape}")
# TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
# create masks
if not self.is_conv2d:
n, c, h, w = self.cx_shape
if control.mask_cond_hint is not None:
mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
            if control.tk_mask_cond_hint is not None:
                mask_tk = prepare_mask_batch(control.tk_mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
                mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
# x in uncond/cond doubles batch size
if x.shape[0] != cx.shape[0]:
if self.is_conv2d:
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
else:
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
if mask is not None:
mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
if mask_tk is not None:
mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)
if mask is None:
mask = 1.0
elif mask_tk is not None:
mask = mask * mask_tk
#logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}")
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
cx = self.mid(cx)
cx = self.up(cx)
if control.latent_keyframes is not None:
cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
if control.weights is not None and control.weights.has_uncond_multiplier:
cond_or_uncond = control.batched_number.cond_or_uncond
actual_length = cx.size(0) // control.batched_number
for idx, cond_type in enumerate(cond_or_uncond):
# if uncond, set to weight's uncond_multiplier
if cond_type == 1:
cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
return cx * mask * control.strength * control._current_timestep_keyframe.strength
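# Thin nn.Module container holding all LLLite modules so they can be wrapped in a ModelPatcher
# and have their weights loaded/offloaded like any other control model.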
class ControlLLLiteModules(torch.nn.Module):
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch):
super().__init__()
self.patch_attn1_modules = torch.nn.Sequential(*list(patch_attn1.modules.values()))
self.patch_attn2_modules = torch.nn.Sequential(*list(patch_attn2.modules.values()))
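# Note: get_control_advanced below only prepares the cond hint and masks each step; the actual
# conditioning is applied by the attn1/attn2 patches during the UNet forward pass.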
class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
# This ControlNet is more of an attention patch than a traditional controlnet
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device, ops: comfy.ops.disable_weight_init):
super().__init__()
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
self.device = device
self.ops = ops
self.patch_attn1 = patch_attn1.clone_with_control(self)
self.patch_attn2 = patch_attn2.clone_with_control(self)
self.control_model = ControlLLLiteModules(self.patch_attn1, self.patch_attn2)
self.control_model_wrapped = ModelPatcher(self.control_model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
self.latent_dims_div2 = None
self.latent_dims_div4 = None
def live_model_patches(self, model_options):
set_model_attn1_patch(model_options, self.patch_attn1.set_control(self))
set_model_attn2_patch(model_options, self.patch_attn2.set_control(self))
# def patch_model(self, model: ModelPatcher):
# model.set_model_attn1_patch(self.patch_attn1)
# model.set_model_attn2_patch(self.patch_attn2)
def set_cond_hint_inject(self, *args, **kwargs):
to_return = super().set_cond_hint_inject(*args, **kwargs)
# cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
return to_return
def pre_run_advanced(self, *args, **kwargs):
AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
#logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
self.patch_attn1.set_control(self)
self.patch_attn2.set_control(self)
#logger.warn(f"in pre_run_advanced: {id(self)}")
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
# normal ControlNet stuff
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
return control_prev
dtype = x_noisy.dtype
# prepare cond_hint
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
            # when sub_idxs are used, first extend the hint to the full latent length, then select only the sub_idxs before scaling
if self.sub_idxs is not None:
actual_cond_hint_orig = self.cond_hint_original
if self.cond_hint_original.size(0) < self.full_latent_length:
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
else:
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(x_noisy.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
        # some special logic here compared to other controlnets:
        # * The cond_emb in the attn patches divides the latent dims by 2 or 4 (integer division)
        # * Due to this rounding, the cond_emb can become smaller than the x input if the latent dims are not divisible by 2 or 4
divisible_by_2_h = x_noisy.shape[2]%2==0
divisible_by_2_w = x_noisy.shape[3]%2==0
if not (divisible_by_2_h and divisible_by_2_w):
#logger.warn(f"{x_noisy.shape} not divisible by 2!")
new_h = (x_noisy.shape[2]//2)*2
new_w = (x_noisy.shape[3]//2)*2
if not divisible_by_2_h:
new_h += 2
if not divisible_by_2_w:
new_w += 2
self.latent_dims_div2 = (new_h, new_w)
divisible_by_4_h = x_noisy.shape[2]%4==0
divisible_by_4_w = x_noisy.shape[3]%4==0
if not (divisible_by_4_h and divisible_by_4_w):
#logger.warn(f"{x_noisy.shape} not divisible by 4!")
new_h = (x_noisy.shape[2]//4)*4
new_w = (x_noisy.shape[3]//4)*4
if not divisible_by_4_h:
new_h += 4
if not divisible_by_4_w:
new_w += 4
self.latent_dims_div4 = (new_h, new_w)
# prepare mask
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
# done preparing; model patches will take care of everything now.
# return normal controlnet stuff
return control_prev
def get_models(self):
to_return: list = super().get_models()
to_return.append(self.control_model_wrapped)
return to_return
def cleanup_advanced(self):
super().cleanup_advanced()
self.patch_attn1.cleanup()
self.patch_attn2.cleanup()
self.latent_dims_div2 = None
self.latent_dims_div4 = None
def copy(self):
c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes, self.device, self.ops)
self.copy_to(c)
self.copy_to_advanced(c)
return c
# deepcopy needs to properly keep track of objects to work between model.clone calls!
# def __deepcopy__(self, *args, **kwargs):
# self.cleanup_advanced()
# return self
# def get_models(self):
# # get_models is called once at the start of every KSampler run - use to reset already_patched status
# out = super().get_models()
# logger.error(f"in get_models! {id(self)}")
# return out
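# Loader: splits the LLLite state dict into per-module weight dicts (keys are prefixed with the
# module name), infers each module's depth and dims from the weight shapes, then wraps the
# resulting modules in attn1/attn2 patches and a ControlLLLiteAdvanced instance.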
def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
if controlnet_data is None:
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
# first, split weights for each module
module_weights = {}
for key, value in controlnet_data.items():
fragments = key.split(".")
module_name = fragments[0]
weight_name = ".".join(fragments[1:])
if module_name not in module_weights:
module_weights[module_name] = {}
module_weights[module_name][weight_name] = value
unet_dtype = comfy.model_management.unet_dtype()
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
ops = comfy.ops.disable_weight_init
if manual_cast_dtype is not None:
ops = comfy.ops.manual_cast
# next, load each module
modules = {}
for module_name, weights in module_weights.items():
# kohya planned to do something about how these should be chosen, so I'm not touching this
# since I am not familiar with the logic for this
if "conditioning1.4.weight" in weights:
depth = 3
elif weights["conditioning1.2.weight"].shape[-1] == 4:
depth = 2
else:
depth = 1
module = LLLiteModule(
name=module_name,
is_conv2d=weights["down.0.weight"].ndim == 4,
in_dim=weights["down.0.weight"].shape[1],
depth=depth,
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
mlp_dim=weights["down.0.weight"].shape[0],
)
# load weights into module
module.load_state_dict(weights)
modules[module_name] = module.to(dtype=unet_dtype)
if len(modules) == 1:
module.is_first = True
#logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe, device=load_device, ops=ops)
return control
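# Example usage (a minimal sketch; the checkpoint path and the surrounding node code are
# assumptions, not part of this file):
#   control = load_controllllite("path/to/controllllite_sdxl_canny.safetensors")
#   control.set_cond_hint(hint_image.movedim(-1, 1), strength=1.0)  # hint_image: ComfyUI IMAGE tensor
#   # the control object is then attached to conditioning the same way as any other ControlNet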