from copy import deepcopy
from typing import Callable, Union

import torch
from torch import Tensor
import torch.nn.functional
from einops import rearrange
import numpy as np
import math

import comfy.ops
import comfy.utils
import comfy.sample
import comfy.samplers
import comfy.model_base
from comfy.controlnet import ControlBase
from comfy.model_patcher import ModelPatcher
from comfy.sd import VAE

from .logger import logger


BIGMIN = -(2**53-1)
BIGMAX = (2**53-1)

ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"
CONTROL_INIT_BY_ACN = "_control_init_by_ACN"
def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
    def load_torch_file_with_dict(*args, **kwargs):
        # immediately restore load_torch_file to original version
        comfy.utils.load_torch_file = orig_load_torch_file
        return controlnet_data
    return load_torch_file_with_dict
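# NOTE: the helpers below monkey-patch ComfyUI internals (builtins len, comfy.samplers.cond_cat,
# BaseModel.apply_model, and comfy.sample.sample/sample_custom) so the cond/uncond ordering used
# during batching can be recovered while sampling; each patch restores the original as soon as it can.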
# wrapping len function so that it will save the thing len is trying to get the length of;
# this will be assumed to be the cond_or_uncond variable;
# automatically restores len to original function after running
def wrapper_len_factory(orig_len: Callable) -> Callable:
    def wrapper_len(*args, **kwargs):
        cond_or_uncond = args[0]
        real_length = orig_len(*args, **kwargs)
        if real_length > 0 and type(cond_or_uncond) == list and isinstance(cond_or_uncond[0], int) and (cond_or_uncond[0] in [0, 1]):
            try:
                to_return = IntWithCondOrUncond(real_length)
                setattr(to_return, "cond_or_uncond", cond_or_uncond)
                return to_return
            finally:
                __builtins__["len"] = orig_len
        else:
            return real_length
    return wrapper_len
# wrapping cond_cat function so that it will wrap around len function to get cond_or_uncond variable value
# from comfy.samplers.calc_conds_batch
def wrapper_cond_cat_factory(orig_cond_cat: Callable):
    def wrapper_cond_cat(*args, **kwargs):
        __builtins__["len"] = wrapper_len_factory(__builtins__["len"])
        return orig_cond_cat(*args, **kwargs)
    return wrapper_cond_cat

orig_cond_cat = comfy.samplers.cond_cat
comfy.samplers.cond_cat = wrapper_cond_cat_factory(orig_cond_cat)
# wrapping apply_model so that len function will be cleaned up fairly soon after being injected
def apply_model_uncond_cleanup_factory(orig_apply_model, orig_len):
    def apply_model_uncond_cleanup_wrapper(self, *args, **kwargs):
        __builtins__["len"] = orig_len
        return orig_apply_model(self, *args, **kwargs)
    return apply_model_uncond_cleanup_wrapper

global_orig_len = __builtins__["len"]
orig_apply_model = comfy.model_base.BaseModel.apply_model
comfy.model_base.BaseModel.apply_model = apply_model_uncond_cleanup_factory(orig_apply_model, global_orig_len)
def uncond_multiplier_check_cn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
    def contains_uncond_multiplier(control: Union[ControlBase, 'AdvancedControlBase']):
        if control is None:
            return False
        if not isinstance(control, AdvancedControlBase):
            return contains_uncond_multiplier(control.previous_controlnet)
        # check if weights_override has an uncond_multiplier
        if control.weights_override is not None and control.weights_override.has_uncond_multiplier:
            return True
        # check if any timestep_keyframes have an uncond_multiplier on their weights
        if control.timestep_keyframes is not None:
            for tk in control.timestep_keyframes.keyframes:
                if tk.has_control_weights() and tk.control_weights.has_uncond_multiplier:
                    return True
        return contains_uncond_multiplier(control.previous_controlnet)

    # check if positive or negative conds contain Adv. CNs that use multiply_negative on weights
    def uncond_multiplier_check_cn_sample(model: ModelPatcher, *args, **kwargs):
        positive = args[-3]
        negative = args[-2]
        has_uncond_multiplier = False
        if positive is not None:
            for cond in positive:
                if "control" in cond[1]:
                    has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
                    if has_uncond_multiplier:
                        break
        if negative is not None and not has_uncond_multiplier:
            for cond in negative:
                if "control" in cond[1]:
                    has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
                    if has_uncond_multiplier:
                        break
        try:
            # if uncond_multiplier found, continue to use wrapped version of function
            if has_uncond_multiplier:
                return orig_comfy_sample(model, *args, **kwargs)
            # otherwise, use original version of function to prevent even the smallest of slowdowns (0.XX%)
            try:
                wrapped_cond_cat = comfy.samplers.cond_cat
                comfy.samplers.cond_cat = orig_cond_cat
                return orig_comfy_sample(model, *args, **kwargs)
            finally:
                comfy.samplers.cond_cat = wrapped_cond_cat
        finally:
            # make sure len function is unwrapped by the time sampling is done, just in case
            __builtins__["len"] = global_orig_len
    return uncond_multiplier_check_cn_sample

# inject sample functions
comfy.sample.sample = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample)
comfy.sample.sample_custom = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample_custom, is_custom=True)
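# int subclass that behaves like a regular int (the batched_number value) while also carrying the
# cond_or_uncond list captured by the wrapped len, so it can be read later in
# apply_advanced_strengths_and_masks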
class IntWithCondOrUncond(int):
    def __new__(cls, *args, **kwargs):
        return super(IntWithCondOrUncond, cls).__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.cond_or_uncond = None
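# expands a list of 4 weights into 12 by repeating each value 3 times,
# e.g. [a, b, c, d] -> [a, a, a, b, b, b, c, c, c, d, d, d]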
def get_properly_arranged_t2i_weights(initial_weights: list[float]):
    new_weights = []
    new_weights.extend([initial_weights[0]]*3)
    new_weights.extend([initial_weights[1]]*3)
    new_weights.extend([initial_weights[2]]*3)
    new_weights.extend([initial_weights[3]]*3)
    return new_weights
class ControlWeightType:
    DEFAULT = "default"
    UNIVERSAL = "universal"
    T2IADAPTER = "t2iadapter"
    CONTROLNET = "controlnet"
    CONTROLNETPLUSPLUS = "controlnet++"
    CONTROLLORA = "controllora"
    CONTROLLLLITE = "controllllite"
    SVD_CONTROLNET = "svd_controlnet"
    SPARSECTRL = "sparsectrl"
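# container for per-layer control strengths: explicit weight lists for input/middle/output blocks,
# an optional weight_func or weight_mask, and an uncond_multiplier applied to the unconditioned
# half of a batched cond/uncond pass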
class ControlWeights:
    def __init__(self, weight_type: str, base_multiplier: float=1.0,
                 weights_input: list[float]=None, weights_middle: list[float]=None, weights_output: list[float]=None,
                 weight_func: Callable=None, weight_mask: Tensor=None,
                 uncond_multiplier=1.0, uncond_mask: Tensor=None, extras: dict[str]={}):
        self.weight_type = weight_type
        self.base_multiplier = base_multiplier
        self.weights_input = weights_input
        self.weights_middle = weights_middle
        self.weights_output = weights_output
        self.weight_func = weight_func
        self.weight_mask = weight_mask
        self.uncond_multiplier = float(uncond_multiplier)
        self.has_uncond_multiplier = not math.isclose(self.uncond_multiplier, 1.0)
        self.uncond_mask = uncond_mask if uncond_mask is not None else 1.0
        self.has_uncond_mask = uncond_mask is not None
        self.extras = extras
    def get(self, idx: int, control: dict[str, list[Tensor]], key: str, default=1.0) -> Union[float, Tensor]:
        # if weight_func present, use it
        if self.weight_func is not None:
            return self.weight_func(idx=idx, control=control, key=key)
        # if weights is not none, return index
        relevant_weights = None
        if key == "middle":
            relevant_weights = self.weights_middle
        elif key == "input":
            relevant_weights = self.weights_input
            if relevant_weights is not None:
                relevant_weights = list(reversed(relevant_weights))
        else:
            relevant_weights = self.weights_output
        if relevant_weights is None:
            return default
        elif idx >= len(relevant_weights):
            return default
        return relevant_weights[idx]
    def copy_with_new_weights(self, new_weights_input: list[float]=None, new_weights_middle: list[float]=None, new_weights_output: list[float]=None,
                              new_weight_func: Callable=None):
        return ControlWeights(weight_type=self.weight_type, base_multiplier=self.base_multiplier,
                              weights_input=new_weights_input, weights_middle=new_weights_middle, weights_output=new_weights_output,
                              weight_func=new_weight_func, weight_mask=self.weight_mask,
                              uncond_multiplier=self.uncond_multiplier, extras=self.extras)
    @classmethod
    def default(cls, extras: dict[str]={}):
        return cls(ControlWeightType.DEFAULT, extras=extras)

    @classmethod
    def universal(cls, base_multiplier: float, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, uncond_multiplier=uncond_multiplier, extras=extras)

    @classmethod
    def universal_mask(cls, weight_mask: Tensor, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask, uncond_multiplier=uncond_multiplier, extras=extras)

    @classmethod
    def t2iadapter(cls, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.T2IADAPTER, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)

    @classmethod
    def controlnet(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.CONTROLNET, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)

    @classmethod
    def controllora(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.CONTROLLORA, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)

    @classmethod
    def controllllite(cls, weights_output: list[float]=None, weights_middle: list[float]=None, weights_input: list[float]=None, uncond_multiplier: float=1.0, extras: dict[str]={}):
        return cls(ControlWeightType.CONTROLLLLITE, weights_output=weights_output, weights_middle=weights_middle, weights_input=weights_input, uncond_multiplier=uncond_multiplier, extras=extras)
class StrengthInterpolation:
    LINEAR = "linear"
    EASE_IN = "ease-in"
    EASE_OUT = "ease-out"
    EASE_IN_OUT = "ease-in-out"
    NONE = "none"

    _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
    _LIST_WITH_NONE = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT, NONE]

    @classmethod
    def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
        diff = num_to - num_from
        if method == cls.LINEAR:
            weights = torch.linspace(num_from, num_to, length)
        elif method == cls.EASE_IN:
            index = torch.linspace(0, 1, length)
            weights = diff * np.power(index, 2) + num_from
        elif method == cls.EASE_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * (1 - np.power(1 - index, 2)) + num_from
        elif method == cls.EASE_IN_OUT:
            index = torch.linspace(0, 1, length)
            weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
        else:
            raise ValueError(f"Unrecognized interpolation method '{method}'.")
        if reverse:
            weights = weights.flip(dims=(0,))
        return weights
class LatentKeyframe:
    def __init__(self, batch_index: int, strength: float) -> None:
        self.batch_index = batch_index
        self.strength = strength
# always maintain sorted state (by batch_index of LatentKeyframe)
class LatentKeyframeGroup:
    def __init__(self) -> None:
        self.keyframes: list[LatentKeyframe] = []

    def add(self, keyframe: LatentKeyframe) -> None:
        added = False
        # replace existing keyframe if same batch_index
        for i in range(len(self.keyframes)):
            if self.keyframes[i].batch_index == keyframe.batch_index:
                self.keyframes[i] = keyframe
                added = True
                break
        if not added:
            self.keyframes.append(keyframe)
        self.keyframes.sort(key=lambda k: k.batch_index)

    def get_index(self, index: int) -> Union[LatentKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def __getitem__(self, index) -> LatentKeyframe:
        return self.keyframes[index]

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'LatentKeyframeGroup':
        cloned = LatentKeyframeGroup()
        for tk in self.keyframes:
            cloned.add(tk)
        return cloned
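# a single scheduling entry: from start_percent onward (held for at least guarantee_steps steps),
# the attached control weights, latent keyframes, and mask hint take effect; missing values can be
# inherited from the previous keyframe via inherit_missing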
class TimestepKeyframe:
    def __init__(self,
                 start_percent: float = 0.0,
                 strength: float = 1.0,
                 control_weights: ControlWeights = None,
                 latent_keyframes: LatentKeyframeGroup = None,
                 null_latent_kf_strength: float = 0.0,
                 inherit_missing: bool = True,
                 guarantee_steps: int = 1,
                 mask_hint_orig: Tensor = None) -> None:
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9
        self.strength = strength
        self.control_weights = control_weights
        self.latent_keyframes = latent_keyframes
        self.null_latent_kf_strength = null_latent_kf_strength
        self.inherit_missing = inherit_missing
        self.guarantee_steps = guarantee_steps
        self.mask_hint_orig = mask_hint_orig

    def has_control_weights(self):
        return self.control_weights is not None

    def has_latent_keyframes(self):
        return self.latent_keyframes is not None

    def has_mask_hint(self):
        return self.mask_hint_orig is not None

    @staticmethod
    def default() -> 'TimestepKeyframe':
        return TimestepKeyframe(start_percent=0.0, guarantee_steps=0)
# always maintain sorted state (by start_percent of TimestepKeyframe)
class TimestepKeyframeGroup:
    def __init__(self) -> None:
        self.keyframes: list[TimestepKeyframe] = []
        self.keyframes.append(TimestepKeyframe.default())

    def add(self, keyframe: TimestepKeyframe) -> None:
        # add to end of list, then sort
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, attr="start_percent")

    def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
        try:
            return self.keyframes[index]
        except IndexError:
            return None

    def has_index(self, index: int) -> bool:
        return index >= 0 and index < len(self.keyframes)

    def __getitem__(self, index) -> TimestepKeyframe:
        return self.keyframes[index]

    def __len__(self) -> int:
        return len(self.keyframes)

    def is_empty(self) -> bool:
        return len(self.keyframes) == 0

    def clone(self) -> 'TimestepKeyframeGroup':
        cloned = TimestepKeyframeGroup()
        # already sorted, so don't use add function to make cloning quicker
        for tk in self.keyframes:
            cloned.keyframes.append(tk)
        return cloned

    @classmethod
    def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
        group = cls()
        group.keyframes[0] = keyframe
        return group
class AbstractPreprocWrapper:
    error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."

    def __init__(self, condhint):
        self.condhint = condhint

    def movedim(self, *args, **kwargs):
        return self

    def __getattr__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __setattr__(self, name, value):
        if name != "condhint":
            raise AttributeError(self.error_msg)
        super().__setattr__(name, value)

    def __iter__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __next__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __len__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __getitem__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)

    def __setitem__(self, *args, **kwargs):
        raise AttributeError(self.error_msg)
# depending on model, AnimateDiff may inject into GroupNorm, so make sure GroupNorm will be clean
class disable_weight_init_clean_groupnorm(comfy.ops.disable_weight_init):
    class GroupNorm(comfy.ops.disable_weight_init.GroupNorm):
        def forward_comfy_cast_weights(self, input):
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, input):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(input)
            else:
                return torch.nn.functional.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)

class manual_cast_clean_groupnorm(comfy.ops.manual_cast):
    class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
        comfy_cast_weights = True

# adapted from comfy/sample.py
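# resizes a mask to the latent's spatial size; when flux_shape is given, the mask is instead scaled
# against flux_shape at half the multiplier and flattened to (b, h*w, c), which appears to match
# Flux-style token layouts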
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False, match_shape=False, flux_shape=None):
    mask = mask.clone()
    if flux_shape is not None:
        multiplier = multiplier * 0.5
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(flux_shape[-2]*multiplier), round(flux_shape[-1]*multiplier)), mode="bilinear")
        mask = rearrange(mask, "b c h w -> b (h w) c")
    else:
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(round(shape[-2]*multiplier), round(shape[-1]*multiplier)), mode="bilinear")
    if match_dim1:
        if match_shape and len(shape) < 4:
            raise Exception(f"match_dim1 cannot be True if shape is under 4 dims; was {len(shape)}.")
        mask = torch.cat([mask] * shape[1], dim=1)
    if match_shape and len(shape) == 3 and len(mask.shape) != 3:
        mask = mask.squeeze(1)
    return mask
# applies min-max normalization, from:
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
def normalize_min_max(x: Tensor, new_min=0.0, new_max=1.0):
    x_min, x_max = x.min(), x.max()
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min

def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
    return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
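# truncates or pads a tensor along dim 0 to exactly batch_size entries, padding by repeating the last entry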
def extend_to_batch_size(tensor: Tensor, batch_size: int):
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        remainder = batch_size - tensor.shape[0]
        return torch.cat([tensor] + [tensor[-1:]]*remainder, dim=0)
    return tensor
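# repeats/extends a tensor so its batch dim matches target_batch_size across batched_number cond/uncond
# groups; a single-entry tensor is returned untouched when except_one is True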
def broadcast_image_to_extend(tensor, target_batch_size, batched_number, except_one=True):
    current_batch_size = tensor.shape[0]
    #print(current_batch_size, target_batch_size)
    if except_one and current_batch_size == 1:
        return tensor
    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]
    if per_batch > tensor.shape[0]:
        tensor = extend_to_batch_size(tensor=tensor, batch_size=per_batch)
    current_batch_size = tensor.shape[0]
    if current_batch_size == target_batch_size:
        return tensor
    else:
        return torch.cat([tensor] * batched_number, dim=0)
# from https://stackoverflow.com/a/24621200
def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
    '''
    Deepcopy an object, except for a given list of attributes, which should
    be shared between the original object and its copy.

    obj is some object
    shared_attribute_names: A list of strings identifying the attributes that
        should be shared between the original and its copy.
    memo is the dictionary passed into __deepcopy__. Ignore this argument if
        not calling from within __deepcopy__.
    '''
    assert isinstance(shared_attribute_names, (list, tuple))
    shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}

    if hasattr(obj, '__deepcopy__'):
        # Do hack to prevent infinite recursion in call to deepcopy
        deepcopy_method = obj.__deepcopy__
        obj.__deepcopy__ = None

    for attr in shared_attribute_names:
        del obj.__dict__[attr]

    clone = deepcopy(obj)

    for attr, val in shared_attributes.items():
        setattr(obj, attr, val)
        setattr(clone, attr, val)

    if hasattr(obj, '__deepcopy__'):
        # Undo hack
        obj.__deepcopy__ = deepcopy_method
        del clone.__deepcopy__

    return clone
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
    if not objects:
        return objects
    elif len(objects) <= 1:
        return [x for x in objects]
    # now that we know we have to sort, do it following these rules:
    # a) if objects have same value of attribute, maintain their relative order
    # b) perform sorting of the groups of objects with same attributes
    unique_attrs = {}
    for o in objects:
        val_attr = getattr(o, attr)
        attr_list: list = unique_attrs.get(val_attr, list())
        attr_list.append(o)
        if val_attr not in unique_attrs:
            unique_attrs[val_attr] = attr_list
    # now that we have the unique attr values grouped together in relative order, sort them by key
    sorted_attrs = dict(sorted(unique_attrs.items()))
    # now flatten out the dict into a list to return
    sorted_list = []
    for object_list in sorted_attrs.values():
        sorted_list.extend(object_list)
    return sorted_list
# DFS Search for Torch.nn.Module, Written by Lvmin
def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
class WeightTypeException(TypeError):
    """Raised when a weight is not compatible with an AdvancedControlBase object."""
    pass
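# Mixin that layers advanced behavior (timestep keyframes, latent keyframes, masks, weight schedules,
# uncond multiplier) onto a ComfyUI ControlBase subclass; __init__ rebinds the wrapped instance's
# get_control, control_merge, pre_run, cleanup, set_previous_controlnet, and set_cond_hint to the
# *_inject versions defined below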
class AdvancedControlBase:
    def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights, require_model=False, require_vae=False, allow_condhint_latents=False):
        self.base = base
        self.compatible_weights = [ControlWeightType.UNIVERSAL, ControlWeightType.DEFAULT]
        self.add_compatible_weight(weights_default.weight_type)
        # mask for which parts of controlnet output to keep
        self.mask_cond_hint_original = None
        self.mask_cond_hint = None
        self.tk_mask_cond_hint_original = None
        self.tk_mask_cond_hint = None
        self.weight_mask_cond_hint = None
        # actual index values
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        # timesteps
        self.t: float = None
        self.prev_t: float = None
        self.batched_number: Union[int, IntWithCondOrUncond] = None
        self.batch_size: int = 0
        # weights + override
        self.weights: ControlWeights = None
        self.weights_default: ControlWeights = weights_default
        self.weights_override: ControlWeights = None
        # latent keyframe + override
        self.latent_keyframes: LatentKeyframeGroup = None
        self.latent_keyframe_override: LatentKeyframeGroup = None
        # initialize timestep_keyframes
        self.set_timestep_keyframes(timestep_keyframes)
        # override some functions
        self.get_control = self.get_control_inject
        self.control_merge = self.control_merge_inject
        self.pre_run = self.pre_run_inject
        self.cleanup = self.cleanup_inject
        self.set_previous_controlnet = self.set_previous_controlnet_inject
        self.set_cond_hint = self.set_cond_hint_inject
        # vae to store
        self.adv_vae = None
        # require model/vae to be passed into Apply Advanced ControlNet 🛂🅐🅒🅝 node
        self.require_model = require_model
        self.require_vae = require_vae
        self.allow_condhint_latents = allow_condhint_latents
        # disarm - when set to False, used to force usage of Apply Advanced ControlNet 🛂🅐🅒🅝 node (which will set it to True)
        self.disarmed = not require_model
    def patch_model(self, model: ModelPatcher):
        pass

    def add_compatible_weight(self, control_weight_type: str):
        self.compatible_weights.append(control_weight_type)

    def verify_all_weights(self, throw_error=True):
        # first, check if override exists - if so, only need to check the override
        if self.weights_override is not None:
            if self.weights_override.weight_type not in self.compatible_weights:
                msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__} " + \
                    f"only supports {self.compatible_weights} weights."
                raise WeightTypeException(msg)
        # otherwise, check all timestep keyframe weights
        else:
            for tk in self.timestep_keyframes.keyframes:
                if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
                    msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type " + \
                        f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
                    raise WeightTypeException(msg)
    def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
        self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
        # prepare first timestep_keyframe related stuff
        self._current_timestep_keyframe = None
        self._current_timestep_index = -1
        self._current_used_steps = 0
        self.weights = None
        self.latent_keyframes = None
    def prepare_current_timestep(self, t: Tensor, batched_number: int=1):
        self.t = float(t[0])
        # check if t has changed (otherwise do nothing, as step already accounted for)
        if self.t == self.prev_t:
            return
        # get current step percent
        curr_t: float = self.t
        prev_index = self._current_timestep_index
        # if met guaranteed steps (or no current keyframe), look for next keyframe in case need to switch
        if self._current_timestep_keyframe is None or self._current_used_steps >= self._current_timestep_keyframe.guarantee_steps:
            # if has next index, loop through and see if need to switch
            if self.timestep_keyframes.has_index(self._current_timestep_index+1):
                for i in range(self._current_timestep_index+1, len(self.timestep_keyframes)):
                    eval_tk = self.timestep_keyframes[i]
                    # check if keyframe's start_t has been reached (start_t is derived from start_percent)
                    if eval_tk.start_t >= curr_t:
                        self._current_timestep_index = i
                        self._current_timestep_keyframe = eval_tk
                        self._current_used_steps = 0
                        # keep track of control weights, latent keyframes, and masks,
                        # accounting for inherit_missing
                        if self._current_timestep_keyframe.has_control_weights():
                            self.weights = self._current_timestep_keyframe.control_weights
                        elif not self._current_timestep_keyframe.inherit_missing:
                            self.weights = self.weights_default
                        if self._current_timestep_keyframe.has_latent_keyframes():
                            self.latent_keyframes = self._current_timestep_keyframe.latent_keyframes
                        elif not self._current_timestep_keyframe.inherit_missing:
                            self.latent_keyframes = None
                        if self._current_timestep_keyframe.has_mask_hint():
                            self.tk_mask_cond_hint_original = self._current_timestep_keyframe.mask_hint_orig
                        elif not self._current_timestep_keyframe.inherit_missing:
                            del self.tk_mask_cond_hint_original
                            self.tk_mask_cond_hint_original = None
                        # if guarantee_steps greater than zero, stop searching for other keyframes
                        if self._current_timestep_keyframe.guarantee_steps > 0:
                            break
                    # if eval_tk is outside of percent range, stop looking further
                    else:
                        break
        # update prev_t
        self.prev_t = self.t
        # update steps current keyframe is used
        self._current_used_steps += 1
        # if index changed, apply overrides
        if prev_index != self._current_timestep_index:
            if self.weights_override is not None:
                self.weights = self.weights_override
            if self.latent_keyframe_override is not None:
                self.latent_keyframes = self.latent_keyframe_override
        # make sure weights and latent_keyframes are in a workable state
        # NOTE: each AdvancedControlBase subclass should implement its own get_universal_weights
        self.prepare_weights()
    def prepare_weights(self):
        if self.weights is None:
            self.weights = self.weights_default
        elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
            # if universal and weight_mask present, no need to convert
            if self.weights.weight_mask is not None:
                return
            self.weights = self.get_universal_weights()

    def get_universal_weights(self) -> ControlWeights:
        return self.weights
    def set_cond_hint_mask(self, mask_hint):
        self.mask_cond_hint_original = mask_hint
        return self

    def set_cond_hint_inject(self, *args, **kwargs):
        to_return = self.base.set_cond_hint(*args, **kwargs)
        # if vae required, look in args and kwargs for it
        if self.require_vae:
            # check args first, as that's the default way vae param is used in ComfyUI
            for arg in args:
                if isinstance(arg, VAE):
                    self.adv_vae = arg
                    break
            # if not in args, check kwargs now
            if self.adv_vae is None:
                if 'vae' in kwargs:
                    self.adv_vae = kwargs['vae']
        return to_return
    def pre_run_inject(self, model, percent_to_timestep_function):
        self.base.pre_run(model, percent_to_timestep_function)
        self.pre_run_advanced(model, percent_to_timestep_function)

    def pre_run_advanced(self, model, percent_to_timestep_function):
        # for each timestep keyframe, calculate the start_t
        for tk in self.timestep_keyframes.keyframes:
            tk.start_t = percent_to_timestep_function(tk.start_percent)
        # clear variables
        self.cleanup_advanced()

    def set_previous_controlnet_inject(self, *args, **kwargs):
        to_return = self.base.set_previous_controlnet(*args, **kwargs)
        if not self.disarmed:
            raise Exception(f"Type '{type(self).__name__}' must be used with Apply Advanced ControlNet 🛂🅐🅒🅝 node (with model_optional passed in); otherwise, it will not work.")
        return to_return

    def disarm(self):
        self.disarmed = True
    def should_run(self):
        if math.isclose(self.strength, 0.0) or math.isclose(self._current_timestep_keyframe.strength, 0.0):
            return False
        if self.timestep_range is not None:
            if self.t > self.timestep_range[0] or self.t < self.timestep_range[1]:
                return False
        return True
    def get_control_inject(self, x_noisy, t, cond, batched_number):
        self.batched_number = batched_number
        self.batch_size = len(t)
        # prepare timestep and everything related
        self.prepare_current_timestep(t=t, batched_number=batched_number)
        # if should not perform any actions for the controlnet, exit without doing any work
        if self.strength == 0.0 or self._current_timestep_keyframe.strength == 0.0:
            return self.default_control_actions(x_noisy, t, cond, batched_number)
        # otherwise, perform normal function
        return self.get_control_advanced(x_noisy, t, cond, batched_number)

    def get_control_advanced(self, x_noisy, t, cond, batched_number):
        return self.default_control_actions(x_noisy, t, cond, batched_number)

    def default_control_actions(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
        return control_prev
    def calc_weight(self, idx: int, x: Tensor, control: dict[str, list[Tensor]], key: str) -> Union[float, Tensor]:
        if self.weights.weight_mask is not None:
            # prepare weight mask
            self.prepare_weight_mask_cond_hint(x, self.batched_number)
            # adjust mask for current layer and return
            return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, control=control, key=key))
        return self.weights.get(idx=idx, control=control, key=key)

    def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
        if key == "middle":
            return 0
        else:
            c_len = len(control[key])
            real_idx = c_len - idx
            if key == "input":
                real_idx = c_len - real_idx + 1
            return real_idx
    def calc_latent_keyframe_mults(self, x: Tensor, batched_number: int) -> Tensor:
        # apply strengths, and get batch indices to null out
        # AKA latents that should not be influenced by ControlNet
        final_mults = [1.0] * x.shape[0]
        if self.latent_keyframes:
            latent_count = x.shape[0] // batched_number
            indeces_to_null = set(range(latent_count))
            mapped_indeces = None
            # if expecting subdivision, will need to translate between subset and actual idx values
            if self.sub_idxs:
                mapped_indeces = {}
                for i, actual in enumerate(self.sub_idxs):
                    mapped_indeces[actual] = i
            for keyframe in self.latent_keyframes:
                real_index = keyframe.batch_index
                # if negative, count from end
                if real_index < 0:
                    real_index += latent_count if self.sub_idxs is None else self.full_latent_length
                # if not mapping indices, what you see is what you get
                if mapped_indeces is None:
                    if real_index in indeces_to_null:
                        indeces_to_null.remove(real_index)
                # otherwise, see if batch_index is even included in this set of latents
                else:
                    real_index = mapped_indeces.get(real_index, None)
                    if real_index is None:
                        continue
                    indeces_to_null.remove(real_index)
                # if real_index is outside the bounds of latents, don't apply
                if real_index >= latent_count or real_index < 0:
                    continue
                # apply strength for each batched cond/uncond
                for b in range(batched_number):
                    final_mults[(latent_count*b)+real_index] = keyframe.strength
            # null them out by multiplying by null_latent_kf_strength
            for batch_index in indeces_to_null:
                # apply null for each batched cond/uncond
                for b in range(batched_number):
                    final_mults[(latent_count*b)+batch_index] = self._current_timestep_keyframe.null_latent_kf_strength
        # convert final_mults into tensor and match expected dimension count
        final_tensor = torch.tensor(final_mults, dtype=x.dtype, device=x.device)
        while len(final_tensor.shape) < len(x.shape):
            final_tensor = final_tensor.unsqueeze(-1)
        return final_tensor
    def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape: tuple=None):
        # handle weight's uncond_multiplier, if applicable
        if self.weights.has_uncond_multiplier:
            cond_or_uncond = self.batched_number.cond_or_uncond
            actual_length = x.size(0) // batched_number
            for idx, cond_type in enumerate(cond_or_uncond):
                # if uncond, set to weight's uncond_multiplier
                if cond_type == 1:
                    x[actual_length*idx:actual_length*(idx+1)] *= self.weights.uncond_multiplier
        if self.weights.has_uncond_mask:
            pass
        if self.latent_keyframes is not None:
            x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
        # apply masks, resizing mask to required dims
        if self.mask_cond_hint is not None:
            masks = prepare_mask_batch(self.mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
            x[:] = x[:] * masks
        if self.tk_mask_cond_hint is not None:
            masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape, match_shape=True, flux_shape=flux_shape)
            x[:] = x[:] * masks
        # apply timestep keyframe strengths
        if self._current_timestep_keyframe.strength != 1.0:
            x[:] *= self._current_timestep_keyframe.strength
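    # merges this controlnet's outputs with any chained previous controlnet's outputs, applying
    # advanced strengths/masks and per-layer weights along the way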
    def control_merge_inject(self: 'AdvancedControlBase', control: dict[str, list[Tensor]], control_prev: dict, output_dtype):
        out = {'input': [], 'middle': [], 'output': []}
        for key in control:
            control_output = control[key]
            applied_to = set()
            for i in range(len(control_output)):
                x = control_output[i]
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
                    if x not in applied_to:  # memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                        applied_to.add(x)
                        self.apply_advanced_strengths_and_masks(x, self.batched_number)
                        x *= self.strength * self.calc_weight(i, x, control, key)
                    if output_dtype is not None and x.dtype != output_dtype:
                        x = x.to(output_dtype)
                out[key].append(x)
        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i in range(len(control_prev[x])):
                    prev_val = control_prev[x][i]
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            if o[i].shape[0] < prev_val.shape[0]:
                                o[i] = prev_val + o[i]
                            else:
                                o[i] = prev_val + o[i]  # TODO from base ComfyUI: change back to inplace add if shared tensors stop being an issue
        return out
    def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
        self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)

    def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        return self._prepare_mask("tk_mask_cond_hint", self._current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)

    def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
        return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)
    def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
        # make mask appropriate dimensions, if present
        if orig_mask is not None:
            out_mask = getattr(self, attr_name)
            multiplier = 1 if direct_attn else 8
            if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * multiplier != out_mask.shape[1] or x_noisy.shape[3] * multiplier != out_mask.shape[2]:
                self._reset_attr(attr_name)
                del out_mask
                # TODO: perform upscale on only the sub_idxs masks at a time instead of all to conserve RAM
                # resize mask and match batch count
                out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier, match_shape=True)
                actual_latent_length = x_noisy.shape[0] // batched_number
                out_mask = extend_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
                if self.sub_idxs is not None:
                    out_mask = out_mask[self.sub_idxs]
            # make cond_hint_mask length match x_noise
            if x_noisy.shape[0] != out_mask.shape[0]:
                out_mask = broadcast_image_to_extend(out_mask, x_noisy.shape[0], batched_number)
            # default dtype to be same as x_noisy
            if dtype is None:
                dtype = x_noisy.dtype
            setattr(self, attr_name, out_mask.to(dtype=dtype).to(x_noisy.device))
            del out_mask
    def _reset_attr(self, attr_name, new_value=None):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
        setattr(self, attr_name, new_value)
    def cleanup_inject(self):
        self.base.cleanup()
        self.cleanup_advanced()

    def cleanup_advanced(self):
        self.sub_idxs = None
        self.full_latent_length = 0
        self.context_length = 0
        self.t = None
        self.prev_t = None
        self.batched_number = None
        self.batch_size = 0
        self.weights = None
        self.latent_keyframes = None
        # timestep stuff
        self._current_timestep_keyframe = None
        self._current_timestep_index = -1
        self._current_used_steps = 0
        # clear mask hints
        if self.mask_cond_hint is not None:
            del self.mask_cond_hint
            self.mask_cond_hint = None
        if self.tk_mask_cond_hint_original is not None:
            del self.tk_mask_cond_hint_original
            self.tk_mask_cond_hint_original = None
        if self.tk_mask_cond_hint is not None:
            del self.tk_mask_cond_hint
            self.tk_mask_cond_hint = None
        if self.weight_mask_cond_hint is not None:
            del self.weight_mask_cond_hint
            self.weight_mask_cond_hint = None
    def copy_to_advanced(self, copied: 'AdvancedControlBase'):
        copied.mask_cond_hint_original = self.mask_cond_hint_original
        copied.weights_override = self.weights_override
        copied.latent_keyframe_override = self.latent_keyframe_override
        copied.adv_vae = self.adv_vae
        copied.require_vae = self.require_vae
        copied.allow_condhint_latents = self.allow_condhint_latents
        copied.disarmed = self.disarmed