Spaces:
Running
on
L40S
Running
on
L40S
import torch | |
from .tools import VariantSupport | |
from comfy_execution.graph_utils import GraphBuilder | |
class TestLazyMixImages: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"image1": ("IMAGE",{"lazy": True}), | |
"image2": ("IMAGE",{"lazy": True}), | |
"mask": ("MASK",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "mix" | |
CATEGORY = "Testing/Nodes" | |
def check_lazy_status(self, mask, image1, image2): | |
mask_min = mask.min() | |
mask_max = mask.max() | |
needed = [] | |
if image1 is None and (mask_min != 1.0 or mask_max != 1.0): | |
needed.append("image1") | |
if image2 is None and (mask_min != 0.0 or mask_max != 0.0): | |
needed.append("image2") | |
return needed | |
# Not trying to handle different batch sizes here just to keep the demo simple | |
def mix(self, mask, image1, image2): | |
mask_min = mask.min() | |
mask_max = mask.max() | |
if mask_min == 0.0 and mask_max == 0.0: | |
return (image1,) | |
elif mask_min == 1.0 and mask_max == 1.0: | |
return (image2,) | |
if len(mask.shape) == 2: | |
mask = mask.unsqueeze(0) | |
if len(mask.shape) == 3: | |
mask = mask.unsqueeze(3) | |
if mask.shape[3] < image1.shape[3]: | |
mask = mask.repeat(1, 1, 1, image1.shape[3]) | |
result = image1 * (1. - mask) + image2 * mask, | |
return (result[0],) | |
class TestVariadicAverage: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("IMAGE",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "variadic_average" | |
CATEGORY = "Testing/Nodes" | |
def variadic_average(self, input1, **kwargs): | |
inputs = [input1] | |
while 'input' + str(len(inputs) + 1) in kwargs: | |
inputs.append(kwargs['input' + str(len(inputs) + 1)]) | |
return (torch.stack(inputs).mean(dim=0),) | |
class TestCustomIsChanged: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"image": ("IMAGE",), | |
}, | |
"optional": { | |
"should_change": ("BOOL", {"default": False}), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_is_changed" | |
CATEGORY = "Testing/Nodes" | |
def custom_is_changed(self, image, should_change=False): | |
return (image,) | |
def IS_CHANGED(cls, should_change=False, *args, **kwargs): | |
if should_change: | |
return float("NaN") | |
else: | |
return False | |
class TestIsChangedWithConstants: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"image": ("IMAGE",), | |
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_is_changed" | |
CATEGORY = "Testing/Nodes" | |
def custom_is_changed(self, image, value): | |
return (image * value,) | |
def IS_CHANGED(cls, image, value): | |
if image is None: | |
return value | |
else: | |
return image.mean().item() * value | |
class TestCustomValidation1: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("IMAGE,FLOAT",), | |
"input2": ("IMAGE,FLOAT",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_validation1" | |
CATEGORY = "Testing/Nodes" | |
def custom_validation1(self, input1, input2): | |
if isinstance(input1, float) and isinstance(input2, float): | |
result = torch.ones([1, 512, 512, 3]) * input1 * input2 | |
else: | |
result = input1 * input2 | |
return (result,) | |
def VALIDATE_INPUTS(cls, input1=None, input2=None): | |
if input1 is not None: | |
if not isinstance(input1, (torch.Tensor, float)): | |
return f"Invalid type of input1: {type(input1)}" | |
if input2 is not None: | |
if not isinstance(input2, (torch.Tensor, float)): | |
return f"Invalid type of input2: {type(input2)}" | |
return True | |
class TestCustomValidation2: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("IMAGE,FLOAT",), | |
"input2": ("IMAGE,FLOAT",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_validation2" | |
CATEGORY = "Testing/Nodes" | |
def custom_validation2(self, input1, input2): | |
if isinstance(input1, float) and isinstance(input2, float): | |
result = torch.ones([1, 512, 512, 3]) * input1 * input2 | |
else: | |
result = input1 * input2 | |
return (result,) | |
def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): | |
if input1 is not None: | |
if not isinstance(input1, (torch.Tensor, float)): | |
return f"Invalid type of input1: {type(input1)}" | |
if input2 is not None: | |
if not isinstance(input2, (torch.Tensor, float)): | |
return f"Invalid type of input2: {type(input2)}" | |
if 'input1' in input_types: | |
if input_types['input1'] not in ["IMAGE", "FLOAT"]: | |
return f"Invalid type of input1: {input_types['input1']}" | |
if 'input2' in input_types: | |
if input_types['input2'] not in ["IMAGE", "FLOAT"]: | |
return f"Invalid type of input2: {input_types['input2']}" | |
return True | |
class TestCustomValidation3: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("IMAGE,FLOAT",), | |
"input2": ("IMAGE,FLOAT",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_validation3" | |
CATEGORY = "Testing/Nodes" | |
def custom_validation3(self, input1, input2): | |
if isinstance(input1, float) and isinstance(input2, float): | |
result = torch.ones([1, 512, 512, 3]) * input1 * input2 | |
else: | |
result = input1 * input2 | |
return (result,) | |
class TestCustomValidation4: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("FLOAT",), | |
"input2": ("FLOAT",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_validation4" | |
CATEGORY = "Testing/Nodes" | |
def custom_validation4(self, input1, input2): | |
result = torch.ones([1, 512, 512, 3]) * input1 * input2 | |
return (result,) | |
def VALIDATE_INPUTS(cls, input1, input2): | |
if input1 is not None: | |
if not isinstance(input1, float): | |
return f"Invalid type of input1: {type(input1)}" | |
if input2 is not None: | |
if not isinstance(input2, float): | |
return f"Invalid type of input2: {type(input2)}" | |
return True | |
class TestCustomValidation5: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("FLOAT", {"min": 0.0, "max": 1.0}), | |
"input2": ("FLOAT", {"min": 0.0, "max": 1.0}), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "custom_validation5" | |
CATEGORY = "Testing/Nodes" | |
def custom_validation5(self, input1, input2): | |
value = input1 * input2 | |
return (torch.ones([1, 512, 512, 3]) * value,) | |
def VALIDATE_INPUTS(cls, **kwargs): | |
if kwargs['input2'] == 7.0: | |
return "7s are not allowed. I've never liked 7s." | |
return True | |
class TestDynamicDependencyCycle: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("IMAGE",), | |
"input2": ("IMAGE",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "dynamic_dependency_cycle" | |
CATEGORY = "Testing/Nodes" | |
def dynamic_dependency_cycle(self, input1, input2): | |
g = GraphBuilder() | |
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) | |
mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0)) | |
mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0)) | |
# Create the cyle | |
mix1.set_input("image2", mix2.out(0)) | |
return { | |
"result": (mix2.out(0),), | |
"expand": g.finalize(), | |
} | |
class TestMixedExpansionReturns: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"input1": ("FLOAT",), | |
}, | |
} | |
RETURN_TYPES = ("IMAGE","IMAGE") | |
FUNCTION = "mixed_expansion_returns" | |
CATEGORY = "Testing/Nodes" | |
def mixed_expansion_returns(self, input1): | |
white_image = torch.ones([1, 512, 512, 3]) | |
if input1 <= 0.1: | |
return (torch.ones([1, 512, 512, 3]) * 0.1, white_image) | |
elif input1 <= 0.2: | |
return { | |
"result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image), | |
} | |
else: | |
g = GraphBuilder() | |
mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1) | |
black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0)) | |
return { | |
"result": (mix.out(0), white_image), | |
"expand": g.finalize(), | |
} | |
TEST_NODE_CLASS_MAPPINGS = { | |
"TestLazyMixImages": TestLazyMixImages, | |
"TestVariadicAverage": TestVariadicAverage, | |
"TestCustomIsChanged": TestCustomIsChanged, | |
"TestIsChangedWithConstants": TestIsChangedWithConstants, | |
"TestCustomValidation1": TestCustomValidation1, | |
"TestCustomValidation2": TestCustomValidation2, | |
"TestCustomValidation3": TestCustomValidation3, | |
"TestCustomValidation4": TestCustomValidation4, | |
"TestCustomValidation5": TestCustomValidation5, | |
"TestDynamicDependencyCycle": TestDynamicDependencyCycle, | |
"TestMixedExpansionReturns": TestMixedExpansionReturns, | |
} | |
TEST_NODE_DISPLAY_NAME_MAPPINGS = { | |
"TestLazyMixImages": "Lazy Mix Images", | |
"TestVariadicAverage": "Variadic Average", | |
"TestCustomIsChanged": "Custom IsChanged", | |
"TestIsChangedWithConstants": "IsChanged With Constants", | |
"TestCustomValidation1": "Custom Validation 1", | |
"TestCustomValidation2": "Custom Validation 2", | |
"TestCustomValidation3": "Custom Validation 3", | |
"TestCustomValidation4": "Custom Validation 4", | |
"TestCustomValidation5": "Custom Validation 5", | |
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle", | |
"TestMixedExpansionReturns": "Mixed Expansion Returns", | |
} | |