# Spaces: Running on L40S
# Running on L40S
import csv | |
import shutil | |
from pathlib import Path | |
import folder_paths | |
import torch | |
from ..log import log | |
from ..utils import here | |
# A conditioning as handled by this module: a list of ``(cond_tensor, extras)``
# pairs. ``extras`` holds mixed value types: this module reads and writes a
# tensor under "pooled_output", and also forwards "control" (a ControlNet) and
# "control_apply_to_uncond" entries. Its values are therefore typed as
# ``object`` rather than ``torch.Tensor``.
Conditioning = list[tuple[torch.Tensor, dict[str, object]]]
def check_condition(conditioning: "Conditioning") -> bool:
    """Validate a conditioning list and report whether it carries a ControlNet.

    Only the first ``(cond, kwargs)`` entry is inspected. Any further entries
    are ignored with a warning.

    Args:
        conditioning: list whose first item is a ``(cond_tensor, kwargs)`` pair.

    Returns:
        True if the first entry's kwargs contain a ``"control"`` key.

    Raises:
        ValueError: if ``conditioning`` is empty, or if its first entry has no
            ``"pooled_output"`` key.
    """
    # Fail with a clear message instead of an IndexError on conditioning[0].
    if not conditioning:
        raise ValueError("Conditioning is not valid. It is empty.")
    if len(conditioning) > 1:
        log.warning(
            "More than one conditioning was provided. Only the first one will be used."
        )

    cond, kwargs = conditioning[0]
    log.debug("Conditioning Shape")
    log.debug(cond.shape)
    log.debug("Conditioning keys")
    log.debug([f"\t{k} - {type(kwargs[k])}" for k in kwargs])

    has_cn = "control" in kwargs
    if has_cn:
        log.debug("Conditioning contains a controlnet")

    if "pooled_output" not in kwargs:
        raise ValueError(
            "Conditioning is not valid. Missing 'pooled_output' key."
        )
    return has_cn
class MTB_InterpolateCondition:
    """Blend across any number of conditionings with a single slider.

    The conditionings arrive as extra keyword inputs, taken in the order they
    are received. For ``N`` of them, ``blend`` in [0, 1] is split into
    ``N - 1`` equal segments. Inside a segment, the two neighbouring
    conditionings are linearly interpolated.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # NOTE(review): the conditioning inputs are not declared here;
        # presumably the frontend adds them dynamically (they arrive in
        # ``**kwargs``) - confirm.
        return {
            "required": {
                "blend": (
                    "FLOAT",
                    {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "mtb/conditioning"
    FUNCTION = "execute"

    @staticmethod
    def _locate_segment(blend: float, count: int) -> tuple[int, float]:
        """Map a global ``blend`` in [0, 1] over ``count`` items to a segment.

        Returns ``(index, local)``. The pair ``(index, index + 1)`` is blended
        and ``local`` in [0, 1] is the position between those two items.
        """
        segment_length = 1.0 / (count - 1)
        # Clamp so that blend == 1.0 stays in the last segment instead of
        # pointing one pair past the end.
        index = min(int(blend // segment_length), count - 2)
        local = (blend - index * segment_length) / segment_length
        return index, local

    @staticmethod
    def _pooled(extras: dict) -> torch.Tensor:
        """Return ``extras["pooled_output"]``, or zeros shaped like its first value."""
        return extras.get(
            "pooled_output",
            torch.zeros_like(next(iter(extras.values()), torch.tensor([]))),
        )

    def execute(
        self, blend: float, **kwargs: "Conditioning"
    ) -> "tuple[Conditioning]":
        """Interpolate the received conditionings at position ``blend``.

        Args:
            blend: global position in [0, 1]; values outside are clamped.
            **kwargs: the conditionings to blend across, in received order.

        Returns:
            A one-element tuple holding the interpolated conditioning. Any
            ControlNet on either endpoint is forwarded into it.

        Raises:
            ValueError: if fewer than two conditionings are given, or if both
                endpoints of the selected segment contain a ControlNet.
        """
        blend = max(0.0, min(1.0, blend))
        conditions = list(kwargs.values())
        if len(conditions) < 2:
            raise ValueError("At least two conditioning inputs are required.")

        segment_index, local_blend = self._locate_segment(
            blend, len(conditions)
        )
        cond_from = conditions[segment_index]
        cond_to = conditions[segment_index + 1]

        from_cn = check_condition(cond_from)
        to_cn = check_condition(cond_to)
        if from_cn and to_cn:
            raise ValueError(
                "Interpolating conditions cannot both contain ControlNets"
            )

        extras_from = cond_from[0][1]
        extras_to = cond_to[0][1]

        try:
            # Blend element by element along the first axis of the cond
            # tensors. zip truncates to the shorter one (strict=False).
            interpolated_condition = [
                (1.0 - local_blend) * c_from + local_blend * c_to
                for c_from, c_to in zip(
                    cond_from[0][0], cond_to[0][0], strict=False
                )
            ]
        except Exception as e:
            log.error(f"Error during interpolation: {e}")
            raise

        # Each side falls back to zeros derived from its OWN extras.
        pooled_from = self._pooled(extras_from)
        pooled_to = self._pooled(extras_to)
        interpolated_pooled = (
            1.0 - local_blend
        ) * pooled_from + local_blend * pooled_to

        res = {"pooled_output": interpolated_pooled}
        # At most one endpoint carries a ControlNet (checked above); forward it.
        if from_cn:
            cn_source = extras_from
        elif to_cn:
            cn_source = extras_to
        else:
            cn_source = None
        if cn_source is not None:
            res["control"] = cn_source["control"]
            res["control_apply_to_uncond"] = cn_source["control_apply_to_uncond"]

        return ([(torch.stack(interpolated_condition), res)],)
class MTB_InterpolateClipSequential:
    """Interpolate CLIP encodings of a prompt across a sequence of word swaps.

    Each extra keyword input is a replacement for ``text_to_replace`` inside
    ``base_text``. ``interpolation_strength`` in [0, 1] walks from the
    unmodified ``base_text`` through every replacement in order. At each point
    it linearly blends the encodings of the two neighbouring prompts.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "base_text": ("STRING", {"multiline": True}),
                "text_to_replace": ("STRING", {"default": ""}),
                "clip": ("CLIP",),
                "interpolation_strength": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01},
                ),
            }
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "interpolate_encodings_sequential"
    CATEGORY = "mtb/conditioning"

    @staticmethod
    def _encode(clip, text):
        """Tokenize ``text`` with ``clip``; return ``(cond, pooled)``."""
        tokens = clip.tokenize(text)
        return clip.encode_from_tokens(tokens, return_pooled=True)

    def interpolate_encodings_sequential(
        self,
        base_text,
        text_to_replace,
        clip,
        interpolation_strength,
        **replacements,
    ):
        """Return the conditioning interpolated at ``interpolation_strength``.

        The [0, 1] range is split into one segment per replacement. Segment 0
        blends ``base_text`` into the first replacement. Segment ``i`` blends
        replacement ``i - 1`` into replacement ``i``.

        Raises:
            ValueError: if no replacement is given, or if ``text_to_replace``
                is empty.
        """
        if not replacements:
            raise ValueError("At least one replacement should be provided.")
        if not text_to_replace:
            # str.replace("", x) inserts x between every character, which
            # would silently produce a garbage prompt.
            raise ValueError("text_to_replace must not be empty.")

        log.debug(f"Received interpolation_strength: {interpolation_strength}")
        # - Ensure interpolation strength is within [0, 1]
        interpolation_strength = max(0.0, min(1.0, interpolation_strength))

        values = list(replacements.values())
        num_replacements = len(values)
        log.debug(f"Number of replacements: {num_replacements}")
        segment_length = 1.0 / num_replacements
        log.debug(f"Calculated segment_length: {segment_length}")

        # - Find the segment; clamp so strength 1.0 stays in the last one
        segment_index = min(
            int(interpolation_strength // segment_length), num_replacements - 1
        )
        log.debug(f"Segment index: {segment_index}")
        # - Position inside that segment, in [0, 1]
        local_strength = (
            interpolation_strength - (segment_index * segment_length)
        ) / segment_length
        log.debug(f"Local strength: {local_strength}")

        if segment_index == 0:
            log.debug("Using the base text as the base blend")
            from_text = base_text
        else:
            base_replace = values[segment_index - 1]
            log.debug(f"Using {base_replace} as the base blend")
            from_text = base_text.replace(text_to_replace, base_replace)
        to_text = base_text.replace(text_to_replace, values[segment_index])

        cond_from, pooled_from = self._encode(clip, from_text)
        cond_to, pooled_to = self._encode(clip, to_text)

        # - Linearly interpolate between the two encodings
        interpolated_condition = (
            1.0 - local_strength
        ) * cond_from + local_strength * cond_to
        interpolated_pooled = (
            1.0 - local_strength
        ) * pooled_from + local_strength * pooled_to
        return (
            [[interpolated_condition, {"pooled_output": interpolated_pooled}]],
        )
class MTB_SmartStep:
    """Turn start/end percentages of a step count into absolute step indices.

    Intended for driving the start/end step inputs of KSamplerAdvanced in
    percentage terms.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "step": (
                    "INT",
                    {"default": 20, "min": 1, "max": 10000, "step": 1},
                ),
                "start_percent": (
                    "INT",
                    {"default": 0, "min": 0, "max": 100, "step": 1},
                ),
                "end_percent": (
                    "INT",
                    {"default": 0, "min": 0, "max": 100, "step": 1},
                ),
            }
        }

    RETURN_TYPES = ("INT", "INT", "INT")
    RETURN_NAMES = ("step", "start", "end")
    FUNCTION = "do_step"
    CATEGORY = "mtb/conditioning"

    def do_step(self, step, start_percent, end_percent):
        """Return ``(step, start, end)``.

        ``start`` and ``end`` are ``step * percent / 100``, truncated toward
        zero by ``int()``.
        """
        start = int(step * start_percent / 100)
        end = int(step * end_percent / 100)
        return (step, start, end)
def install_default_styles(force=False):
    """Copy the bundled ``styles.csv`` to ``<base_path>/styles/default.csv``.

    The ``styles`` directory is created if needed. An existing
    ``default.csv`` is only overwritten when ``force`` is true.

    Returns:
        The destination path (``Path``), whether or not a copy happened.
    """
    target_dir = Path(folder_paths.base_path) / "styles"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / "default.csv"

    should_copy = force or not target.exists()
    if should_copy:
        source = here / "styles.csv"
        log.debug(f"Copying default style to {target}")
        shutil.copy2(source.as_posix(), target.as_posix())
    return target
class MTB_StylesLoader:
    """Load A1111-style CSV files and expose their rows as a style dropdown.

    Each row is read as ``name, positive, negative``; missing columns become
    ``""``. Parsed styles are cached on the class in ``options``, and later
    calls reuse the cache instead of re-reading the files.
    """

    # Class-level cache shared by all instances: name -> (positive, negative).
    options: dict[str, tuple[str, str]] = {}

    @staticmethod
    def _parse_style_file(file) -> dict[str, tuple[str, str]]:
        """Parse one styles CSV into ``{name: (positive, negative)}``.

        A read or decode error is logged, and the rows parsed before it are
        still returned.
        """
        styles: dict[str, tuple[str, str]] = {}
        try:
            # newline="" is what the csv module requires, so that quoted
            # fields containing line breaks are read correctly.
            with open(file, encoding="utf8", newline="") as f:
                for i, row in enumerate(csv.reader(f)):
                    # Pad short rows so that missing columns become None.
                    name, positive, negative = (row + [None] * 3)[:3]
                    if name is None:
                        log.warning(f"Missing 'name' in row {i}.")
                        continue
                    styles[name] = (positive or "", negative or "")
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            # These are raised while iterating the reader/file, so the
            # per-file try must wrap the whole loop, not just each row.
            log.warning(
                f"There was an error while parsing {file}, make sure it respects A1111 format, i.e 3 columns name, positive, negative:\n{e}"
            )
        return styles

    @classmethod
    def INPUT_TYPES(cls):
        if not cls.options:
            input_dir = Path(folder_paths.base_path) / "styles"
            if not input_dir.exists():
                install_default_styles()
            # Sorted so the dropdown order (and which file wins on duplicate
            # names) does not depend on filesystem listing order.
            files = sorted(f for f in input_dir.iterdir() if f.suffix == ".csv")
            if not files:
                log.warning(
                    "No styles found in the styles folder, place at least one csv file in the styles folder at the root of ComfyUI (for instance ComfyUI/styles/mystyle.csv)"
                )
            for file in files:
                cls.options.update(cls._parse_style_file(file))
        else:
            log.debug(f"Using cached styles (count: {len(cls.options)})")
        return {
            "required": {
                "style_name": (list(cls.options.keys()),),
            }
        }

    CATEGORY = "mtb/conditioning"
    RETURN_TYPES = ("STRING", "STRING")
    RETURN_NAMES = ("positive", "negative")
    FUNCTION = "load_style"

    def load_style(self, style_name):
        """Return ``(positive, negative)`` for ``style_name``; KeyError if unknown."""
        positive, negative = self.options[style_name]
        return (positive, negative)
# Node classes defined in this module.
# NOTE(review): presumably collected by the package's node-registration code
# through the `__nodes__` name - confirm against the loader.
__nodes__ = [
    MTB_SmartStep,
    MTB_StylesLoader,
    MTB_InterpolateClipSequential,
    MTB_InterpolateCondition,
]