multimodalart's picture
Squashing commit
4450790 verified
raw
history blame
4.57 kB
"""A set of constants and utilities for handling contexts.
Sets up the inputs and outputs for the Context going forward, with additional functions for
creating and exporting context objects.
"""
import comfy.samplers
import folder_paths
# Master table of every value a context can carry, keyed by context key.
# Each value is a 3-tuple consumed by _create_context_data:
#   [0] input name  -> key in the node's optional-inputs dict
#   [1] type        -> used both as the input's type spec and as the return type
#   [2] return name -> entry in the node's return names
# Insertion order matters: the helpers below iterate this dict (directly, or when
# no explicit key list is given), so it fixes the order of inputs, return
# types/names, and returned values.
_all_context_input_output_data = {
    "base_ctx": ("base_ctx", "RGTHREE_CONTEXT", "CONTEXT"),
    "model": ("model", "MODEL", "MODEL"),
    "clip": ("clip", "CLIP", "CLIP"),
    "vae": ("vae", "VAE", "VAE"),
    "positive": ("positive", "CONDITIONING", "POSITIVE"),
    "negative": ("negative", "CONDITIONING", "NEGATIVE"),
    "latent": ("latent", "LATENT", "LATENT"),
    "images": ("images", "IMAGE", "IMAGE"),
    "seed": ("seed", "INT", "SEED"),
    "steps": ("steps", "INT", "STEPS"),
    "step_refiner": ("step_refiner", "INT", "STEP_REFINER"),
    "cfg": ("cfg", "FLOAT", "CFG"),
    # For ckpt_name, sampler and scheduler the "type" is not a type-name string
    # but the value of the expression, evaluated once when this module is
    # imported. Presumably each is a list of choices (checkpoint filenames /
    # sampler names / scheduler names) -- confirm against folder_paths and
    # comfy.samplers.
    "ckpt_name": ("ckpt_name", folder_paths.get_filename_list("checkpoints"), "CKPT_NAME"),
    "sampler": ("sampler", comfy.samplers.KSampler.SAMPLERS, "SAMPLER"),
    "scheduler": ("scheduler", comfy.samplers.KSampler.SCHEDULERS, "SCHEDULER"),
    "clip_width": ("clip_width", "INT", "CLIP_WIDTH"),
    "clip_height": ("clip_height", "INT", "CLIP_HEIGHT"),
    "text_pos_g": ("text_pos_g", "STRING", "TEXT_POS_G"),
    "text_pos_l": ("text_pos_l", "STRING", "TEXT_POS_L"),
    "text_neg_g": ("text_neg_g", "STRING", "TEXT_NEG_G"),
    "text_neg_l": ("text_neg_l", "STRING", "TEXT_NEG_L"),
    "mask": ("mask", "MASK", "MASK"),
    "control_net": ("control_net", "CONTROL_NET", "CONTROL_NET"),
}
# An input whose type is in force_input_types, or whose input name is in
# force_input_names, gets a {"forceInput": True} options dict in
# _create_context_data (presumably so ComfyUI shows it as a connectable socket
# instead of an editable widget -- confirm).
force_input_types = ["INT", "STRING", "FLOAT"]
force_input_names = ["sampler", "scheduler", "ckpt_name"]
def _create_context_data(input_list=None):
    """Build the node-definition data for a set of context keys.

    Args:
      input_list: Iterable of keys from ``_all_context_input_output_data``.
        When None, every key in the table is used, in table order.

    Returns:
      A 3-tuple ``(optional_inputs, return_types, return_names)``:
      ``optional_inputs`` maps each input name to its input-spec tuple, and
      ``return_types`` / ``return_names`` are tuples in ``input_list`` order.
    """
    keys = _all_context_input_output_data.keys() if input_list is None else input_list
    optional_inputs = {}
    return_types = []
    return_names = []
    for key in keys:
        input_name, io_type, return_name = _all_context_input_output_data[key]
        return_types.append(io_type)
        return_names.append(return_name)
        # Primitive types and a few named inputs get a forceInput flag
        # (presumably so they must be wired from another node -- confirm).
        if io_type in force_input_types or input_name in force_input_names:
            optional_inputs[input_name] = (io_type, {"forceInput": True})
        else:
            optional_inputs[input_name] = (io_type,)
    return optional_inputs, tuple(return_types), tuple(return_names)
# Node-definition data covering every key in _all_context_input_output_data.
ALL_CTX_OPTIONAL_INPUTS, ALL_CTX_RETURN_TYPES, ALL_CTX_RETURN_NAMES = _create_context_data()
# Reduced key set used by get_orig_context_return_tuple and the ORIG_* data below.
# Judging by the name, presumably the keys of the original/earlier context
# layout -- confirm.
_original_ctx_inputs_list = [
    "base_ctx", "model", "clip", "vae", "positive", "negative", "latent", "images", "seed"
]
ORIG_CTX_OPTIONAL_INPUTS, ORIG_CTX_RETURN_TYPES, ORIG_CTX_RETURN_NAMES = _create_context_data(
    _original_ctx_inputs_list)
def new_context(base_ctx, **kwargs):
    """Create a new context dict, optionally layered over an existing one.

    Every key in ``_all_context_input_output_data`` except ``"base_ctx"`` gets
    an entry. A non-None keyword argument wins. Otherwise the value is copied
    from ``base_ctx`` when it has that key; otherwise the entry is None.

    Args:
      base_ctx: A context dict to inherit values from, or None.
      **kwargs: Values keyed by context key. Keys not in the table are ignored,
        and None values fall through to ``base_ctx``.

    Returns:
      A new dict; ``base_ctx`` is not modified.
    """
    new_ctx = {}
    for key in _all_context_input_output_data:
        if key == "base_ctx":
            continue
        value = kwargs.get(key)
        # Only inherit from the base when the caller did not supply a value.
        if value is None and base_ctx is not None and key in base_ctx:
            value = base_ctx[key]
        new_ctx[key] = value
    return new_ctx
def merge_new_context(*args):
    """Merge several contexts into a new one; later contexts override earlier ones.

    For each key except ``"base_ctx"``, the result takes the value from the last
    context in ``args`` that holds a non-None value for that key, or None if no
    context does. Contexts that are None, or that contain only None values
    (see ``is_context_empty``), are skipped.

    Args:
      *args: Context dicts (or None), lowest priority first.

    Returns:
      A new dict; the input contexts are not modified.
    """
    # Filter out empty contexts once, up front, in highest-priority-first order.
    # is_context_empty scans every value, and the result does not depend on the
    # key, so there is no need to repeat it inside the per-key loop.
    candidates = [ctx for ctx in reversed(args) if not is_context_empty(ctx)]
    new_ctx = {}
    for key in _all_context_input_output_data:
        if key == "base_ctx":
            continue
        value = None
        for ctx in candidates:
            if key in ctx and ctx[key] is not None:
                value = ctx[key]
                break
        new_ctx[key] = value
    return new_ctx
def get_context_return_tuple(ctx, inputs_list=None):
    """Build a node return tuple: the context itself, then its values in key order.

    Args:
      ctx: The context dict, or None.
      inputs_list: Keys whose values to return, in order. Defaults to every key
        in ``_all_context_input_output_data``. ``"base_ctx"`` is skipped
        wherever it appears, because the context itself always comes first.

    Returns:
      ``(ctx, value_1, value_2, ...)``. A value is None when ``ctx`` is None or
      lacks that key.
    """
    keys = _all_context_input_output_data.keys() if inputs_list is None else inputs_list
    has_ctx = ctx is not None
    values = tuple(
        ctx[key] if has_ctx and key in ctx else None
        for key in keys
        if key != "base_ctx"
    )
    return (ctx,) + values
def get_orig_context_return_tuple(ctx):
    """Return ``ctx`` followed by its values for the keys in ``_original_ctx_inputs_list``.

    Thin wrapper over ``get_context_return_tuple`` restricted to that reduced key set.
    """
    return get_context_return_tuple(ctx, inputs_list=_original_ctx_inputs_list)
def is_context_empty(ctx):
    """Return True when ``ctx`` is falsy (None or empty) or every value in it is None."""
    if not ctx:
        return True
    return not any(value is not None for value in ctx.values())