multimodalart's picture
Squashing commit
4450790 verified
raw
history blame
10.9 kB
import csv
import shutil
from pathlib import Path
import folder_paths
import torch
from ..log import log
from ..utils import here
# One conditioning entry: the conditioning tensor plus a dict of extras
# (this file reads "pooled_output", "control" and "control_apply_to_uncond"
# from it; not every value is necessarily a tensor despite the annotation).
_ConditioningEntry = tuple[torch.Tensor, dict[str, torch.Tensor]]
Conditioning = list[_ConditioningEntry]
def check_condition(conditioning: Conditioning) -> bool:
    """Validate a conditioning input and report whether it carries a ControlNet.

    Only the first ``(cond, kwargs)`` entry is inspected; a warning is
    logged when more than one entry is supplied.

    Args:
        conditioning: A conditioning list of ``(cond, kwargs)`` pairs.

    Returns:
        True when the first entry's kwargs contain a ``"control"`` key.

    Raises:
        ValueError: If ``conditioning`` is empty, or if the first entry's
            kwargs lack the ``"pooled_output"`` key.
    """
    if not conditioning:
        # Without this guard, conditioning[0] below fails with a bare
        # IndexError that does not explain what is wrong.
        raise ValueError("Conditioning is not valid. It is empty.")
    if len(conditioning) > 1:
        log.warning(
            "More than one conditioning was provided. Only the first one will be used."
        )

    cond, kwargs = conditioning[0]
    log.debug("Conditioning Shape")
    log.debug(cond.shape)
    log.debug("Conditioning keys")
    log.debug([f"\t{k} - {type(v)}" for k, v in kwargs.items()])

    has_cn = "control" in kwargs
    if has_cn:
        log.debug("Conditioning contains a controlnet")
    if "pooled_output" not in kwargs:
        raise ValueError(
            "Conditioning is not valid. Missing 'pooled_output' key."
        )
    return has_cn
class MTB_InterpolateCondition:
    """Blend between several conditioning inputs with a single slider.

    The conditionings are received as extra keyword inputs that are not
    declared in ``INPUT_TYPES`` (presumably added dynamically by the
    frontend — TODO confirm) and are used in the order they arrive.
    ``blend`` in [0, 1] is split into ``len(conditions) - 1`` equal
    segments; the two conditionings bounding the selected segment are
    linearly interpolated.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "blend": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "mtb/conditioning"
    FUNCTION = "execute"

    def execute(
        self, blend: float, **kwargs: Conditioning
    ) -> tuple[Conditioning]:
        """Interpolate the conditionings passed as keyword arguments.

        Args:
            blend: Position along the chain of conditionings; clamped to
                [0, 1].
            **kwargs: Two or more conditionings. Only the first entry of
                each one is used (see ``check_condition``).

        Returns:
            A one-element tuple holding a conditioning with the blended
            tensor, the blended ``pooled_output`` and, when exactly one of
            the two bounding inputs has a ControlNet, that input's
            ``control`` entries.

        Raises:
            ValueError: If fewer than two conditionings are given, if both
                bounding conditionings contain a ControlNet, or if
                ``check_condition`` rejects one of them.
        """
        blend = max(0.0, min(1.0, blend))
        conditions: list[Conditioning] = list(kwargs.values())
        num_conditions = len(conditions)
        if num_conditions < 2:
            raise ValueError("At least two conditioning inputs are required.")

        segment_length = 1.0 / (num_conditions - 1)
        # Clamp so blend == 1.0 maps onto the last pair (local_blend == 1.0)
        # instead of a nonexistent pair past the end.
        segment_index = min(int(blend // segment_length), num_conditions - 2)
        local_blend = (
            blend - (segment_index * segment_length)
        ) / segment_length

        cond_from = conditions[segment_index]
        cond_to = conditions[segment_index + 1]

        # check_condition raises when the first entry lacks "pooled_output",
        # which is what makes the direct lookups below safe.
        from_cn = check_condition(cond_from)
        to_cn = check_condition(cond_to)
        if from_cn and to_cn:
            raise ValueError(
                "Interpolating conditions cannot both contain ControlNets"
            )

        tensor_from, extra_from = cond_from[0]
        tensor_to, extra_to = cond_to[0]

        try:
            # Iterating a tensor walks its first dimension; zip stops at
            # the shorter of the two.
            interpolated_condition = [
                (1.0 - local_blend) * c_from + local_blend * c_to
                for c_from, c_to in zip(tensor_from, tensor_to, strict=False)
            ]
        except Exception as e:
            log.error(f"Error during interpolation: {e}")
            raise

        pooled_from = extra_from["pooled_output"]
        pooled_to = extra_to["pooled_output"]
        interpolated_pooled = (
            1.0 - local_blend
        ) * pooled_from + local_blend * pooled_to

        res = {"pooled_output": interpolated_pooled}
        # At most one side carries a ControlNet (enforced above); forward it.
        for has_cn, extra in ((from_cn, extra_from), (to_cn, extra_to)):
            if has_cn:
                res["control"] = extra["control"]
                if "control_apply_to_uncond" in extra:
                    res["control_apply_to_uncond"] = extra[
                        "control_apply_to_uncond"
                    ]
        return ([(torch.stack(interpolated_condition), res)],)
class MTB_InterpolateClipSequential:
    """Interpolate CLIP encodings of a prompt across a chain of word swaps.

    ``text_to_replace`` is substituted in ``base_text`` with each extra
    keyword input in turn (those inputs are not declared in
    ``INPUT_TYPES``; presumably added dynamically by the frontend — TODO
    confirm). This gives the sequence: base_text, replacement 1, ...,
    replacement N. ``interpolation_strength`` in [0, 1] is split into N
    equal segments, and the encodings bounding the selected segment are
    linearly blended.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "base_text": ("STRING", {"multiline": True}),
                "text_to_replace": ("STRING", {"default": ""}),
                "clip": ("CLIP",),
                "interpolation_strength": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "interpolate_encodings_sequential"
    CATEGORY = "mtb/conditioning"

    @staticmethod
    def _encode(clip, text):
        """Tokenize ``text`` with ``clip`` and return ``(cond, pooled)``."""
        tokens = clip.tokenize(text)
        return clip.encode_from_tokens(tokens, return_pooled=True)

    def interpolate_encodings_sequential(
        self,
        base_text,
        text_to_replace,
        clip,
        interpolation_strength,
        **replacements,
    ):
        """Return the blended conditioning for ``interpolation_strength``.

        Args:
            base_text: Prompt containing ``text_to_replace``.
            text_to_replace: Substring of ``base_text`` that is swapped
                for each replacement; must be non-empty.
            clip: Object providing ``tokenize`` and ``encode_from_tokens``.
            interpolation_strength: Position along the chain; clamped to
                [0, 1].
            **replacements: Replacement strings, used in the order given.

        Returns:
            A one-element tuple holding ``[[cond, {"pooled_output": ...}]]``.

        Raises:
            ValueError: If no replacement is given or ``text_to_replace``
                is empty.
        """
        log.debug(f"Received interpolation_strength: {interpolation_strength}")
        interpolation_strength = max(0.0, min(1.0, interpolation_strength))

        if not replacements:
            raise ValueError("At least one replacement should be provided.")
        if not text_to_replace:
            # str.replace("", r) inserts r between every character, which
            # would silently produce a garbage prompt.
            raise ValueError("text_to_replace must not be empty.")

        replacement_texts = list(replacements.values())
        num_replacements = len(replacement_texts)
        log.debug(f"Number of replacements: {num_replacements}")

        segment_length = 1.0 / num_replacements
        log.debug(f"Calculated segment_length: {segment_length}")

        # Clamp so strength == 1.0 lands in the last segment.
        segment_index = min(
            int(interpolation_strength // segment_length), num_replacements - 1
        )
        log.debug(f"Segment index: {segment_index}")

        local_strength = (
            interpolation_strength - (segment_index * segment_length)
        ) / segment_length
        log.debug(f"Local strength: {local_strength}")

        # The first segment starts from the untouched base text; later
        # segments start from the previous replacement.
        if segment_index == 0:
            log.debug("Using the base text a the base blend")
            from_text = base_text
        else:
            base_replace = replacement_texts[segment_index - 1]
            log.debug(f"Using {base_replace} a the base blend")
            from_text = base_text.replace(text_to_replace, base_replace)
        cond_from, pooled_from = self._encode(clip, from_text)

        to_text = base_text.replace(
            text_to_replace, replacement_texts[segment_index]
        )
        cond_to, pooled_to = self._encode(clip, to_text)

        interpolated_condition = (
            1.0 - local_strength
        ) * cond_from + local_strength * cond_to
        interpolated_pooled = (
            1.0 - local_strength
        ) * pooled_from + local_strength * pooled_to
        return (
            [[interpolated_condition, {"pooled_output": interpolated_pooled}]],
        )
class MTB_SmartStep:
    """Express an advanced KSampler's start/end steps as percentages.

    Given the total step count and two percentages, emits the step count
    together with the corresponding absolute start and end step numbers.
    """

    @classmethod
    def INPUT_TYPES(cls):
        def percent_input():
            # A fresh spec per input so the two entries never share a dict.
            return ("INT", {"default": 0, "min": 0, "max": 100, "step": 1})

        return {
            "required": {
                "step": (
                    "INT",
                    {"default": 20, "min": 1, "max": 10000, "step": 1},
                ),
                "start_percent": percent_input(),
                "end_percent": percent_input(),
            }
        }

    RETURN_TYPES = ("INT", "INT", "INT")
    RETURN_NAMES = ("step", "start", "end")
    FUNCTION = "do_step"
    CATEGORY = "mtb/conditioning"

    def do_step(self, step, start_percent, end_percent):
        """Return ``(step, start, end)``; start/end are truncated to ints."""
        start, end = (
            int(step * percent / 100) for percent in (start_percent, end_percent)
        )
        return (step, start, end)
def install_default_styles(force=False):
    """Make sure ``<base_path>/styles/default.csv`` exists.

    Creates the ``styles`` directory under ``folder_paths.base_path`` when
    missing, then copies the bundled ``styles.csv`` (next to ``here``) into
    it as ``default.csv``.

    Args:
        force: When truthy, overwrite an existing ``default.csv``.

    Returns:
        The path of ``default.csv``.
    """
    target_dir = Path(folder_paths.base_path) / "styles"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "default.csv"

    # Already installed and not asked to overwrite: nothing to do.
    if not force and target.exists():
        return target

    source = here / "styles.csv"
    log.debug(f"Copying default style to {target}")
    shutil.copy2(source.as_posix(), target.as_posix())
    return target
class MTB_StylesLoader:
    """Load csv files and populate a dropdown from the rows (à la A1111).

    Every ``*.csv`` file in ``<base_path>/styles`` is read on the first
    ``INPUT_TYPES`` call that finds any style. Each row is read as
    ``name, positive, negative``; missing trailing columns become "" and
    extra columns are ignored. Results are cached in ``options``.
    """

    # style name -> (positive, negative); a class-level cache shared by
    # every instance.
    options = {}

    @classmethod
    def _read_styles_file(cls, file):
        """Add the styles found in csv ``file`` to ``cls.options``."""
        try:
            # newline="" lets the csv module handle quoted fields that
            # contain line breaks correctly.
            with open(file, encoding="utf8", newline="") as f:
                for i, row in enumerate(csv.reader(f)):
                    if not row:
                        # csv.reader yields [] for blank lines.
                        continue
                    name, positive, negative = (row + [None] * 3)[:3]
                    if not name:
                        log.warning(f"Missing 'name' in row {i}.")
                        continue
                    cls.options[name] = (positive or "", negative or "")
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            # Decoding and csv errors surface while iterating the reader, so
            # the whole read is guarded; a bad file is skipped, not fatal.
            log.warning(
                f"There was an error while parsing {file}, make sure it respects A1111 format, i.e 3 columns name, positive, negative:\n{e}"
            )

    @classmethod
    def INPUT_TYPES(cls):
        if cls.options:
            log.debug(f"Using cached styles (count: {len(cls.options)})")
        else:
            input_dir = Path(folder_paths.base_path) / "styles"
            if not input_dir.exists():
                install_default_styles()
            # Sorted so that when several files define the same style name,
            # which one wins does not depend on filesystem ordering.
            files = sorted(
                f for f in input_dir.iterdir() if f.suffix == ".csv"
            )
            if not files:
                log.warning(
                    "No styles found in the styles folder, place at least one csv file in the styles folder at the root of ComfyUI (for instance ComfyUI/styles/mystyle.csv)"
                )
            for file in files:
                cls._read_styles_file(file)
        return {
            "required": {
                "style_name": (list(cls.options.keys()),),
            }
        }

    CATEGORY = "mtb/conditioning"
    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "load_style"

    def load_style(self, style_name):
        """Return the ``(positive, negative)`` prompts of ``style_name``."""
        positive, negative = self.options[style_name]
        return (positive, negative)
# Node classes exported by this module (kept as a list, in this order).
__nodes__ = [MTB_SmartStep, MTB_StylesLoader, MTB_InterpolateClipSequential, MTB_InterpolateCondition]