Spaces:
Running
on
L40S
Running
on
L40S
import itertools | |
import json | |
import math | |
import os | |
import comfy.model_management as model_management | |
import folder_paths | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image, ImageOps | |
from PIL.PngImagePlugin import PngInfo | |
from skimage.filters import gaussian | |
from skimage.util import compare_images | |
from ..log import log | |
from ..utils import np2tensor, pil2tensor, tensor2pil | |
# try: | |
# from cv2.ximgproc import guidedFilter | |
# except ImportError: | |
# log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python") | |
def gaussian_kernel(
    kernel_size: int, sigma_x: float, sigma_y: float, device=None
):
    """Build a normalized 2D Gaussian kernel of shape (kernel_size, kernel_size).

    Sample coordinates span [-1, 1] on both axes (not pixel units), so the
    sigmas are relative to the kernel half-width. ``sigma_x`` scales the first
    (row) axis and ``sigma_y`` the second (column) axis. The weights sum to 1.
    """
    axis = torch.linspace(-1, 1, kernel_size, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing="ij")
    exponent_rows = rows * rows / (2.0 * sigma_x * sigma_x)
    exponent_cols = cols * cols / (2.0 * sigma_y * sigma_y)
    weights = torch.exp(-(exponent_rows + exponent_cols))
    return weights / weights.sum()
class MTB_CoordinatesToString:
    """Serialize one frame of a BATCH_COORDINATES list to a JSON string."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "convert"
    CATEGORY = "mtb/coordinates"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "coordinates": ("BATCH_COORDINATES",),
                "frame": ("INT",),
            }
        }

    def convert(
        self, coordinates: list[list[tuple[int, int]]], frame: int
    ) -> tuple[str]:
        """Return the points of ``coordinates[frame]`` as a JSON list.

        ``frame`` is clamped into ``[0, len(coordinates) - 1]`` so an
        out-of-range frame selects the nearest valid frame.

        Returns:
            A 1-tuple holding e.g. ``'[{"x": 1, "y": 2}]'``.

        Raises:
            ValueError: if ``coordinates`` is empty.
        """
        if not coordinates:
            raise ValueError("No coordinates provided")
        # Clamp into the valid index range (max() alone lets frame overflow
        # and forces every in-range frame to the last one).
        frame = min(max(frame, 0), len(coordinates) - 1)
        output: list[dict[str, int]] = [
            {"x": x, "y": y} for x, y in coordinates[frame]
        ]
        return (json.dumps(output),)
class MTB_ExtractCoordinatesFromImage: | |
"""Extract 2D points from a batch of images based on a threshold.""" | |
RETURN_TYPES = ("BATCH_COORDINATES", "IMAGE") | |
FUNCTION = "extract" | |
CATEGORY = "mtb/coordinates" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"threshold": ("FLOAT",), | |
"max_points": ("INT", {"default": 50, "min": 0}), | |
}, | |
"optional": {"image": ("IMAGE",), "mask": ("MASK",)}, | |
} | |
def extract( | |
self, | |
threshold: float, | |
max_points: int, | |
image: torch.Tensor | None = None, | |
mask: torch.Tensor | None = None, | |
) -> tuple[list[list[tuple[int, int]]], torch.Tensor]: | |
if image is not None: | |
batch_count, height, width, channel_count = image.shape | |
imgs = image | |
else: | |
if mask is None: | |
raise ValueError("Must provide either image or mask") | |
batch_count, height, width = mask.shape | |
channel_count = 1 | |
imgs = mask | |
if channel_count not in [1, 2, 3, 4]: | |
raise ValueError(f"Incorrect channel count: {channel_count}") | |
all_points: list[list[tuple[int, int]]] = [] | |
debug_images = torch.zeros( | |
(batch_count, height, width, 3), | |
dtype=torch.uint8, | |
device=imgs.device, | |
) | |
for i, img in enumerate(imgs): | |
if channel_count == 1: | |
alpha_channel = img if len(img.shape) == 2 else img[:, :, 0] | |
elif channel_count == 2: | |
alpha_channel = img[:, :, 1] | |
elif channel_count == 4: | |
alpha_channel = img[:, :, 3] | |
else: | |
# get intensity | |
alpha_channel = img[:, :, :3].max(dim=2)[0] | |
points = (alpha_channel > threshold).nonzero(as_tuple=False) | |
if len(points) > max_points: | |
indices = torch.randperm(points.size(0), device=img.device)[ | |
:max_points | |
] | |
points = points[indices] | |
points = [(int(y.item()), int(x.item())) for x, y in points] | |
all_points.append(points) | |
for x, y in points: | |
self._draw_circle(debug_images[i], (x, y), 5) | |
return (all_points, debug_images) | |
def _draw_circle( | |
image: torch.Tensor, center: tuple[int, int], radius: int | |
): | |
"""Draw a 5px circle on the image.""" | |
x0, y0 = center | |
for x in range(-radius, radius + 1): | |
for y in range(-radius, radius + 1): | |
in_radius = x**2 + y**2 <= radius**2 | |
in_bounds = ( | |
0 <= x0 + x < image.shape[1] | |
and 0 <= y0 + y < image.shape[0] | |
) | |
if in_radius and in_bounds: | |
image[y0 + y, x0 + x] = torch.tensor( | |
[255, 255, 255], | |
dtype=torch.uint8, | |
device=image.device, | |
) | |
class MTB_ColorCorrectGPU: | |
"""Various color correction methods using only Torch.""" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"image": ("IMAGE",), | |
"force_gpu": ("BOOLEAN", {"default": True}), | |
"clamp": ([True, False], {"default": True}), | |
"gamma": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"contrast": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"exposure": ( | |
"FLOAT", | |
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, | |
), | |
"offset": ( | |
"FLOAT", | |
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, | |
), | |
"hue": ( | |
"FLOAT", | |
{"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, | |
), | |
"saturation": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"value": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
}, | |
"optional": {"mask": ("MASK",)}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "correct" | |
CATEGORY = "mtb/image processing" | |
def get_device(tensor: torch.Tensor, force_gpu: bool): | |
if force_gpu: | |
if torch.cuda.is_available(): | |
return torch.device("cuda") | |
elif ( | |
hasattr(torch.backends, "mps") | |
and torch.backends.mps.is_available() | |
): | |
return torch.device("mps") | |
elif hasattr(torch, "hip") and torch.hip.is_available(): | |
return torch.device("hip") | |
return ( | |
tensor.device | |
) # model_management.get_torch_device() # torch.device("cpu") | |
def rgb_to_hsv(image: torch.Tensor): | |
r, g, b = image.unbind(-1) | |
max_rgb, argmax_rgb = image.max(-1) | |
min_rgb, _ = image.min(-1) | |
diff = max_rgb - min_rgb | |
h = torch.empty_like(max_rgb) | |
s = diff / (max_rgb + 1e-7) | |
v = max_rgb | |
h[argmax_rgb == 0] = (g - b)[argmax_rgb == 0] / (diff + 1e-7)[ | |
argmax_rgb == 0 | |
] | |
h[argmax_rgb == 1] = ( | |
2.0 + (b - r)[argmax_rgb == 1] / (diff + 1e-7)[argmax_rgb == 1] | |
) | |
h[argmax_rgb == 2] = ( | |
4.0 + (r - g)[argmax_rgb == 2] / (diff + 1e-7)[argmax_rgb == 2] | |
) | |
h = (h / 6.0) % 1.0 | |
h = h.unsqueeze(-1) | |
s = s.unsqueeze(-1) | |
v = v.unsqueeze(-1) | |
return torch.cat((h, s, v), dim=-1) | |
def hsv_to_rgb(hsv: torch.Tensor): | |
h, s, v = hsv.unbind(-1) | |
h = h * 6.0 | |
i = torch.floor(h) | |
f = h - i | |
p = v * (1.0 - s) | |
q = v * (1.0 - s * f) | |
t = v * (1.0 - s * (1.0 - f)) | |
i = i.long() % 6 | |
mask = torch.stack( | |
(i == 0, i == 1, i == 2, i == 3, i == 4, i == 5), -1 | |
) | |
rgb = torch.stack( | |
( | |
torch.where( | |
mask[..., 0], | |
v, | |
torch.where( | |
mask[..., 1], | |
q, | |
torch.where( | |
mask[..., 2], | |
p, | |
torch.where( | |
mask[..., 3], | |
p, | |
torch.where(mask[..., 4], t, v), | |
), | |
), | |
), | |
), | |
torch.where( | |
mask[..., 0], | |
t, | |
torch.where( | |
mask[..., 1], | |
v, | |
torch.where( | |
mask[..., 2], | |
v, | |
torch.where( | |
mask[..., 3], | |
q, | |
torch.where(mask[..., 4], p, p), | |
), | |
), | |
), | |
), | |
torch.where( | |
mask[..., 0], | |
p, | |
torch.where( | |
mask[..., 1], | |
p, | |
torch.where( | |
mask[..., 2], | |
t, | |
torch.where( | |
mask[..., 3], | |
v, | |
torch.where(mask[..., 4], v, q), | |
), | |
), | |
), | |
), | |
), | |
dim=-1, | |
) | |
return rgb | |
def correct( | |
self, | |
image: torch.Tensor, | |
force_gpu: bool, | |
clamp: bool, | |
gamma: float = 1.0, | |
contrast: float = 1.0, | |
exposure: float = 0.0, | |
offset: float = 0.0, | |
hue: float = 0.0, | |
saturation: float = 1.0, | |
value: float = 1.0, | |
mask: torch.Tensor | None = None, | |
): | |
device = self.get_device(image, force_gpu) | |
image = image.to(device) | |
if mask is not None: | |
if mask.shape[0] != image.shape[0]: | |
mask = mask.expand(image.shape[0], -1, -1) | |
mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) | |
mask = mask.to(device) | |
model_management.throw_exception_if_processing_interrupted() | |
adjusted = image.pow(1 / gamma) * (2.0**exposure) * contrast + offset | |
model_management.throw_exception_if_processing_interrupted() | |
hsv = self.rgb_to_hsv(adjusted) | |
hsv[..., 0] = (hsv[..., 0] + hue) % 1.0 # Hue | |
hsv[..., 1] = hsv[..., 1] * saturation # Saturation | |
hsv[..., 2] = hsv[..., 2] * value # Value | |
adjusted = self.hsv_to_rgb(hsv) | |
model_management.throw_exception_if_processing_interrupted() | |
if clamp: | |
adjusted = torch.clamp(adjusted, 0.0, 1.0) | |
# apply mask | |
result = ( | |
adjusted | |
if mask is None | |
else torch.where(mask > 0, adjusted, image) | |
) | |
if not force_gpu: | |
result = result.cpu() | |
return (result,) | |
class MTB_ColorCorrect: | |
"""Various color correction methods""" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"image": ("IMAGE",), | |
"clamp": ([True, False], {"default": True}), | |
"gamma": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"contrast": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"exposure": ( | |
"FLOAT", | |
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, | |
), | |
"offset": ( | |
"FLOAT", | |
{"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, | |
), | |
"hue": ( | |
"FLOAT", | |
{"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, | |
), | |
"saturation": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
"value": ( | |
"FLOAT", | |
{"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, | |
), | |
}, | |
"optional": {"mask": ("MASK",)}, | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "correct" | |
CATEGORY = "mtb/image processing" | |
def gamma_correction_tensor(image, gamma): | |
gamma_inv = 1.0 / gamma | |
return image.pow(gamma_inv) | |
def contrast_adjustment_tensor(image, contrast): | |
r, g, b = image.unbind(-1) | |
# Using Adobe RGB luminance weights. | |
luminance_image = 0.33 * r + 0.71 * g + 0.06 * b | |
luminance_mean = torch.mean(luminance_image.unsqueeze(-1)) | |
# Blend original with mean luminance using contrast factor as blend ratio. | |
contrasted = image * contrast + (1.0 - contrast) * luminance_mean | |
return torch.clamp(contrasted, 0.0, 1.0) | |
def exposure_adjustment_tensor(image, exposure): | |
return image * (2.0**exposure) | |
def offset_adjustment_tensor(image, offset): | |
return image + offset | |
def hsv_adjustment(image: torch.Tensor, hue, saturation, value): | |
images = tensor2pil(image) | |
out = [] | |
for img in images: | |
hsv_image = img.convert("HSV") | |
h, s, v = hsv_image.split() | |
h = h.point(lambda x: (x + hue * 255) % 256) | |
s = s.point(lambda x: int(x * saturation)) | |
v = v.point(lambda x: int(x * value)) | |
hsv_image = Image.merge("HSV", (h, s, v)) | |
rgb_image = hsv_image.convert("RGB") | |
out.append(rgb_image) | |
return pil2tensor(out) | |
def hsv_adjustment_tensor_not_working( | |
image: torch.Tensor, hue, saturation, value | |
): | |
"""Abandonning for now""" | |
image = image.squeeze(0).permute(2, 0, 1) | |
max_val, _ = image.max(dim=0, keepdim=True) | |
min_val, _ = image.min(dim=0, keepdim=True) | |
delta = max_val - min_val | |
hue_image = torch.zeros_like(max_val) | |
mask = delta != 0.0 | |
r, g, b = image[0], image[1], image[2] | |
hue_image[mask & (max_val == r)] = ((g - b) / delta)[ | |
mask & (max_val == r) | |
] % 6.0 | |
hue_image[mask & (max_val == g)] = ((b - r) / delta)[ | |
mask & (max_val == g) | |
] + 2.0 | |
hue_image[mask & (max_val == b)] = ((r - g) / delta)[ | |
mask & (max_val == b) | |
] + 4.0 | |
saturation_image = delta / (max_val + 1e-7) | |
value_image = max_val | |
hue_image = (hue_image + hue) % 1.0 | |
saturation_image = torch.where( | |
mask, saturation * saturation_image, saturation_image | |
) | |
value_image = value * value_image | |
c = value_image * saturation_image | |
x = c * (1 - torch.abs((hue_image % 2) - 1)) | |
m = value_image - c | |
prime_image = torch.zeros_like(image) | |
prime_image[0] = torch.where( | |
max_val == r, c, torch.where(max_val == g, x, prime_image[0]) | |
) | |
prime_image[1] = torch.where( | |
max_val == r, x, torch.where(max_val == g, c, prime_image[1]) | |
) | |
prime_image[2] = torch.where( | |
max_val == g, x, torch.where(max_val == b, c, prime_image[2]) | |
) | |
rgb_image = prime_image + m | |
rgb_image = rgb_image.permute(1, 2, 0).unsqueeze(0) | |
return rgb_image | |
def correct( | |
self, | |
image: torch.Tensor, | |
clamp: bool, | |
gamma: float = 1.0, | |
contrast: float = 1.0, | |
exposure: float = 0.0, | |
offset: float = 0.0, | |
hue: float = 0.0, | |
saturation: float = 1.0, | |
value: float = 1.0, | |
mask: torch.Tensor | None = None, | |
): | |
if mask is not None: | |
if mask.shape[0] != image.shape[0]: | |
mask = mask.expand(image.shape[0], -1, -1) | |
mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) | |
# Apply color correction operations | |
adjusted = self.gamma_correction_tensor(image, gamma) | |
adjusted = self.contrast_adjustment_tensor(adjusted, contrast) | |
adjusted = self.exposure_adjustment_tensor(adjusted, exposure) | |
adjusted = self.offset_adjustment_tensor(adjusted, offset) | |
adjusted = self.hsv_adjustment(adjusted, hue, saturation, value) | |
if clamp: | |
adjusted = torch.clamp(image, 0.0, 1.0) | |
result = ( | |
adjusted | |
if mask is None | |
else torch.where(mask > 0, adjusted, image) | |
) | |
return (result,) | |
class MTB_ImageCompare: | |
"""Compare two images and return a difference image""" | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"imageA": ("IMAGE",), | |
"imageB": ("IMAGE",), | |
"mode": ( | |
["checkerboard", "diff", "blend"], | |
{"default": "checkerboard"}, | |
), | |
} | |
} | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "compare" | |
CATEGORY = "mtb/image" | |
def compare(self, imageA: torch.Tensor, imageB: torch.Tensor, mode): | |
if imageA.dim() == 4: | |
batch_count = imageA.size(0) | |
return ( | |
torch.cat( | |
tuple( | |
self.compare(imageA[i], imageB[i], mode)[0] | |
for i in range(batch_count) | |
), | |
dim=0, | |
), | |
) | |
num_channels_A = imageA.size(2) | |
num_channels_B = imageB.size(2) | |
# handle RGBA/RGB mismatch | |
if num_channels_A == 3 and num_channels_B == 4: | |
imageA = torch.cat( | |
(imageA, torch.ones_like(imageA[:, :, 0:1])), dim=2 | |
) | |
elif num_channels_B == 3 and num_channels_A == 4: | |
imageB = torch.cat( | |
(imageB, torch.ones_like(imageB[:, :, 0:1])), dim=2 | |
) | |
match mode: | |
case "diff": | |
compare_image = torch.abs(imageA - imageB) | |
case "blend": | |
compare_image = 0.5 * (imageA + imageB) | |
case "checkerboard": | |
imageA = imageA.numpy() | |
imageB = imageB.numpy() | |
compared_channels = [ | |
torch.from_numpy( | |
compare_images( | |
imageA[:, :, i], imageB[:, :, i], method=mode | |
) | |
) | |
for i in range(imageA.shape[2]) | |
] | |
compare_image = torch.stack(compared_channels, dim=2) | |
case _: | |
compare_image = None | |
raise ValueError(f"Unknown mode {mode}") | |
compare_image = compare_image.unsqueeze(0) | |
return (compare_image,) | |
import requests | |
class MTB_LoadImageFromUrl:
    """Load an image from the given URL"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "url": (
                    "STRING",
                    {
                        "default": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load"
    CATEGORY = "mtb/IO"

    def load(self, url):
        """Download ``url`` and return it as a single-image IMAGE batch.

        EXIF orientation is applied. Modes other than RGB/RGBA are converted:
        to RGBA when the image carries transparency, otherwise to RGB.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the server does not answer in time.
        """
        import io

        # A timeout keeps a dead server from hanging the queue forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Use the fully read (and content-decoded) body; `.raw` skips gzip
        # decoding.
        image = Image.open(io.BytesIO(response.content))
        image = ImageOps.exif_transpose(image)
        if image.mode not in ("RGB", "RGBA"):
            has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
            image = image.convert("RGBA" if has_alpha else "RGB")
        return (pil2tensor(image),)
class MTB_Blur:
    """Blur an image using a Gaussian filter."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "sigmaX": (
                    "FLOAT",
                    {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01},
                ),
                "sigmaY": (
                    "FLOAT",
                    {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01},
                ),
            },
            "optional": {"sigmasX": ("FLOATS",), "sigmasY": ("FLOATS",)},
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur"
    CATEGORY = "mtb/image processing"

    def blur(
        self, image: torch.Tensor, sigmaX, sigmaY, sigmasX=None, sigmasY=None
    ):
        """Gaussian-blur every frame of ``image``.

        ``sigmasX``/``sigmasY`` are optional per-frame sigmas. When given,
        they override ``sigmaX``/``sigmaY``, and ``sigmasY`` falls back to
        ``sigmasX``.

        NOTE(review): the sigmas are passed as ``(sigmaX, sigmaY, 0)`` over an
        (H, W, C) frame, so "X" acts on the row (vertical) axis. Kept as-is to
        avoid changing existing results.

        Raises:
            ValueError: if a per-frame sigma list does not match the batch size.
        """
        batch_size = image.size(0)
        if sigmasX is not None:
            if sigmasY is None:
                sigmasY = sigmasX
            if len(sigmasX) != batch_size:
                raise ValueError(
                    f"SigmasX must have same length as image, sigmasX is {len(sigmasX)} but the batch size is {image.size(0)}"
                )
            if len(sigmasY) != batch_size:
                raise ValueError(
                    f"SigmasY must have same length as image, sigmasY is {len(sigmasY)} but the batch size is {batch_size}"
                )
            sigmas = list(zip(sigmasX, sigmasY))
        else:
            sigmas = [(sigmaX, sigmaY)] * batch_size

        # Filter in a 0-255 float range; np2tensor is expected to rescale it
        # back to [0, 1]. .cpu() lets GPU tensors reach numpy.
        image_np = image.cpu().numpy() * 255
        blurred = np.array(
            [
                gaussian(frame, sigma=(sx, sy, 0), channel_axis=2)
                for frame, (sx, sy) in zip(image_np, sigmas)
            ]
        )
        return (np2tensor(blurred).squeeze(0),)
class MTB_Sharpen:
    """Sharpens an image using a Gaussian kernel."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": (
                    "INT",
                    {"default": 1, "min": 1, "max": 31, "step": 1},
                ),
                "sigma_x": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
                ),
                "sigma_y": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
                ),
                "alpha": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "do_sharp"
    CATEGORY = "mtb/image processing"

    def do_sharp(
        self,
        image: torch.Tensor,
        sharpen_radius: int,
        sigma_x: float,
        sigma_y: float,
        alpha: float,
    ):
        """Unsharp-style sharpening of a (B, H, W, C) image batch.

        The kernel is a negated, ``alpha``-scaled Gaussian whose centre tap is
        adjusted so the whole kernel sums to 1. The result is clamped to [0, 1].
        """
        if sharpen_radius == 0:
            return (image,)

        channels = image.shape[3]
        kernel_size = 2 * sharpen_radius + 1
        # Build the kernel where the image lives, in its dtype.
        kernel = gaussian_kernel(
            kernel_size, sigma_x, sigma_y, device=image.device
        ).to(image.dtype) * -(alpha * 10)
        center = kernel_size // 2
        # Re-weight the centre tap so the kernel sums to 1 (keeps brightness).
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
        # One depthwise filter per channel: (C, 1, k, k) used with groups=C.
        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

        tensor_image = image.permute(0, 3, 1, 2)
        height, width = tensor_image.shape[2], tensor_image.shape[3]
        # Reflect padding needs pad < dim; fall back for tiny images.
        pad_mode = (
            "reflect"
            if sharpen_radius < min(height, width)
            else "replicate"
        )
        tensor_image = F.pad(tensor_image, (sharpen_radius,) * 4, pad_mode)
        # A valid (unpadded) convolution over the padded image yields exactly H x W.
        sharpened = F.conv2d(tensor_image, kernel, groups=channels)
        sharpened = sharpened.permute(0, 2, 3, 1)
        return (torch.clamp(sharpened, 0, 1),)
# https://github.com/lllyasviel/AdverseCleaner/blob/main/clean.py | |
# def deglaze_np_img(np_img): | |
# y = np_img.copy() | |
# for _ in range(64): | |
# y = cv2.bilateralFilter(y, 5, 8, 8) | |
# for _ in range(4): | |
# y = guidedFilter(np_img, y, 4, 16) | |
# return y | |
# class DeglazeImage: | |
# """Remove adversarial noise from images""" | |
# @classmethod | |
# def INPUT_TYPES(cls): | |
# return {"required": {"image": ("IMAGE",)}} | |
# CATEGORY = "mtb/image processing" | |
# RETURN_TYPES = ("IMAGE",) | |
# FUNCTION = "deglaze_image" | |
# def deglaze_image(self, image): | |
# return (np2tensor(deglaze_np_img(tensor2np(image))),) | |
class MTB_MaskToImage:
    """Converts a mask (alpha) to an RGB image with a color and background"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "color": ("COLOR",),
                "background": ("COLOR", {"default": "#000000"}),
            },
            "optional": {
                "invert": ("BOOLEAN", {"default": False}),
            },
        }

    CATEGORY = "mtb/generate"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "render_mask"

    def render_mask(self, mask, color, background, invert=False):
        """Paint ``color`` over ``background`` using each mask as the weight.

        The mask (``1 - mask`` when ``invert``) is the composite weight:
        white pixels take ``color`` and black pixels take ``background``.
        Returns an RGB IMAGE batch.
        """
        masks = tensor2pil(1.0 - mask if invert else mask)
        images = []
        for m in masks:
            weight = m.convert("L")
            log.debug(
                f"Converted mask to PIL Image format, size: {weight.size}"
            )
            foreground = Image.new("RGBA", weight.size, color=color)
            backdrop = Image.new("RGBA", weight.size, color=background)
            images.append(
                Image.composite(foreground, backdrop, weight).convert("RGB")
            )
        return (pil2tensor(images),)
class MTB_ColoredImage: | |
"""Constant color image of given size.""" | |
def __init__(self) -> None: | |
pass | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"color": ("COLOR",), | |
"width": ("INT", {"default": 512, "min": 16, "max": 8160}), | |
"height": ("INT", {"default": 512, "min": 16, "max": 8160}), | |
}, | |
"optional": { | |
"foreground_image": ("IMAGE",), | |
"foreground_mask": ("MASK",), | |
"invert": ("BOOLEAN", {"default": False}), | |
"mask_opacity": ( | |
"FLOAT", | |
{"default": 1.0, "step": 0.1, "min": 0}, | |
), | |
}, | |
} | |
CATEGORY = "mtb/generate" | |
RETURN_TYPES = ("IMAGE",) | |
FUNCTION = "render_img" | |
def resize_and_crop(self, img: Image.Image, target_size: tuple[int, int]): | |
scale = max(target_size[0] / img.width, target_size[1] / img.height) | |
new_size = (int(img.width * scale), int(img.height * scale)) | |
img = img.resize(new_size, Image.LANCZOS) | |
left = (img.width - target_size[0]) // 2 | |
top = (img.height - target_size[1]) // 2 | |
return img.crop( | |
(left, top, left + target_size[0], top + target_size[1]) | |
) | |
def resize_and_crop_thumbnails( | |
self, img: Image.Image, target_size: tuple[int, int] | |
): | |
img.thumbnail(target_size, Image.LANCZOS) | |
left = (img.width - target_size[0]) / 2 | |
top = (img.height - target_size[1]) / 2 | |
right = (img.width + target_size[0]) / 2 | |
bottom = (img.height + target_size[1]) / 2 | |
return img.crop((left, top, right, bottom)) | |
def process_mask( | |
mask: torch.Tensor | None, | |
invert: bool, | |
# opacity: float, | |
batch_size: int, | |
) -> list[Image.Image] | None: | |
if mask is None: | |
return [None] * batch_size | |
masks = tensor2pil(mask if not invert else 1.0 - mask) | |
if len(masks) == 1 and batch_size > 1: | |
masks = masks * batch_size | |
if len(masks) != batch_size: | |
raise ValueError( | |
"Foreground image and mask must have the same batch size" | |
) | |
return masks | |
def render_img( | |
self, | |
color: str, | |
width: int, | |
height: int, | |
foreground_image: torch.Tensor | None = None, | |
foreground_mask: torch.Tensor | None = None, | |
invert: bool = False, | |
mask_opacity: float = 1.0, | |
) -> tuple[torch.Tensor]: | |
background = Image.new("RGBA", (width, height), color=color) | |
if foreground_image is None: | |
return (pil2tensor([background.convert("RGB")]),) | |
fg_images = tensor2pil(foreground_image) | |
fg_masks = self.process_mask(foreground_mask, invert, len(fg_images)) | |
output: list[Image.Image] = [] | |
for fg_image, fg_mask in zip(fg_images, fg_masks, strict=False): | |
fg_image = self.resize_and_crop(fg_image, background.size) | |
if fg_mask: | |
fg_mask = self.resize_and_crop(fg_mask, background.size) | |
fg_mask_array = np.array(fg_mask) | |
fg_mask_array = (fg_mask_array * mask_opacity).astype(np.uint8) | |
fg_mask = Image.fromarray(fg_mask_array) | |
output.append( | |
Image.composite( | |
fg_image.convert("RGBA"), background, fg_mask | |
).convert("RGB") | |
) | |
else: | |
if fg_image.mode != "RGBA": | |
raise ValueError( | |
f"Foreground image must be in 'RGBA' mode when no mask is provided, got {fg_image.mode}" | |
) | |
output.append( | |
Image.alpha_composite(background, fg_image).convert("RGB") | |
) | |
return (pil2tensor(output),) | |
class MTB_ImagePremultiply:
    """Premultiply image with mask"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "invert": ("BOOLEAN", {"default": False}),
            }
        }

    CATEGORY = "mtb/image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("RGBA",)
    FUNCTION = "premultiply"

    def premultiply(self, image, mask, invert):
        """Attach the mask to each image as an alpha channel.

        By default the alpha is ``1 - mask``. Presumably this matches ComfyUI
        masks, where 1 marks the transparent area (verify against callers).
        With ``invert``, the mask is used directly as alpha. A single mask is
        reused for every image.

        Raises:
            ValueError: if the mask batch is neither 1 nor the image batch size.
        """
        images = tensor2pil(image)
        alphas = [
            m.convert("L")
            for m in tensor2pil(mask if invert else 1.0 - mask)
        ]
        if len(alphas) not in (1, len(images)):
            raise ValueError(
                f"Mask batch ({len(alphas)}) must be 1 or match the image batch ({len(images)})"
            )
        out = []
        for i, img in enumerate(images):
            img.putalpha(alphas[0] if len(alphas) == 1 else alphas[i])
            out.append(img)
        return (pil2tensor(out),)
class MTB_ImageResizeFactor:
    """Extracted mostly from WAS Node Suite, with a few edits (most notably multiple image support) and less features."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "factor": (
                    "FLOAT",
                    {"default": 2, "min": 0.01, "max": 16.0, "step": 0.01},
                ),
                "supersample": ("BOOLEAN", {"default": True}),
                "resampling": (
                    [
                        "nearest",
                        "linear",
                        "bilinear",
                        "bicubic",
                        "trilinear",
                        "area",
                        "nearest-exact",
                    ],
                    {"default": "nearest"},
                ),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    CATEGORY = "mtb/image"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "resize"

    @staticmethod
    def _resample(tensor, size, resampling, supersample):
        """Resize a (B, C, H, W) tensor to ``size``, then optionally double it."""
        # "linear"/"trilinear" only accept 3D/5D input; use the 2D equivalent.
        mode = {"linear": "bilinear", "trilinear": "bilinear"}.get(
            resampling, resampling
        )
        # F.interpolate rejects any align_corners value (even False) for
        # nearest/area/nearest-exact, so pass None for those modes.
        align_corners = True if mode in ("bilinear", "bicubic") else None
        out = F.interpolate(
            tensor, size=size, mode=mode, align_corners=align_corners
        )
        if supersample:
            # NOTE(review): "supersample" only doubles the output size (there
            # is no downsample back to `size`). Kept to preserve output sizes.
            out = F.interpolate(
                out, scale_factor=2, mode=mode, align_corners=align_corners
            )
        return out

    def resize(
        self,
        image: torch.Tensor,
        factor: float,
        supersample: bool,
        resampling: str,
        mask=None,
    ):
        """Resize ``image`` (and ``mask``) by ``factor``.

        Args:
            image: (H, W, C) or (B, H, W, C) tensor.
            mask: optional (H, W) or (B, H, W) mask. It is resized like the
                image and then multiplied into it. Its batch must be 1 or
                match the image batch.

        Returns:
            ``(image, mask)``. The image keeps the caller's layout. The mask
            is (B, H', W'), and all ones when no mask is given.

        Raises:
            ValueError: on unsupported image/mask shapes.
        """
        # Check if the tensor has the correct dimension
        if len(image.shape) not in [3, 4]:  # HxWxC or BxHxWxC
            raise ValueError(
                "Expected image tensor of shape (H, W, C) or (B, H, W, C)"
            )
        # Remember the caller's layout before reshaping.
        unbatched = image.dim() == 3
        batch = (image.unsqueeze(0) if unbatched else image).permute(0, 3, 1, 2)

        _, _, height, width = batch.shape
        new_size = (max(1, int(height * factor)), max(1, int(width * factor)))

        resized = self._resample(batch, new_size, resampling, supersample)
        out_batch, _, out_h, out_w = resized.shape

        if mask is None:
            resized_mask = torch.ones(
                (out_batch, out_h, out_w),
                dtype=resized.dtype,
                device=resized.device,
            )
        else:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
            if mask.dim() != 3:
                raise ValueError("Mask tensor should have shape (H, W) or (B, H, W)")
            if mask.shape[0] not in (1, out_batch):
                raise ValueError(
                    f"Mask batch ({mask.shape[0]}) must be 1 or match the image batch ({out_batch})"
                )
            resized_mask = self._resample(
                mask.unsqueeze(1).to(device=resized.device, dtype=resized.dtype),
                new_size,
                resampling,
                supersample,
            ).squeeze(1)
            # Both went through the same resize, so their H/W match; broadcast
            # the mask over channels.
            resized = resized * resized_mask.unsqueeze(1)

        resized = resized.permute(0, 2, 3, 1)
        if unbatched:
            resized = resized.squeeze(0)
        return (resized, resized_mask)
class MTB_SaveImageGrid:
    """Save all the images in the input batch as a grid of images."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                "save_intermediate": ("BOOLEAN", {"default": False}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "mtb/IO"

    def create_image_grid(self, image_list):
        """Paste same-sized PIL images row by row onto one RGB canvas.

        Uses ``floor(sqrt(n))`` columns and as many rows as needed, so every
        image gets a cell. (The previous ``floor x ceil`` layout dropped
        images, e.g. for n = 3, 7, 8.)

        Raises:
            ValueError: if ``image_list`` is empty.
        """
        if not image_list:
            raise ValueError("No images to put in the grid")
        total_images = len(image_list)
        columns = max(1, math.isqrt(total_images))
        rows = math.ceil(total_images / columns)

        # All cells take the size of the first image.
        image_width, image_height = image_list[0].size
        grid_image = Image.new(
            "RGB", (columns * image_width, rows * image_height)
        )
        for i, image in enumerate(image_list):
            x = (i % columns) * image_width
            y = (i // columns) * image_height
            grid_image.paste(image, (x, y, x + image_width, y + image_height))
        return grid_image

    def save_images(
        self,
        images,
        filename_prefix="Grid",
        save_intermediate=False,
        prompt=None,
        extra_pnginfo=None,
    ):
        """Save the batch as one grid PNG (plus per-image PNGs if requested).

        The prompt and extra PNG info are embedded as text chunks. Returns the
        ComfyUI UI payload that references the grid file.
        """
        (
            full_output_folder,
            filename,
            counter,
            subfolder,
            filename_prefix,
        ) = folder_paths.get_save_image_path(
            filename_prefix,
            self.output_dir,
            images[0].shape[1],
            images[0].shape[0],
        )

        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            for key in extra_pnginfo:
                metadata.add_text(key, json.dumps(extra_pnginfo[key]))

        image_list = []
        batch_counter = counter
        for idx, image in enumerate(images):
            pixels = 255.0 * image.cpu().numpy()
            img = Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
            image_list.append(img)
            if save_intermediate:
                file = f"{filename}_batch-{idx:03}_{batch_counter:05}_.png"
                img.save(
                    os.path.join(full_output_folder, file),
                    pnginfo=metadata,
                    compress_level=4,
                )
                batch_counter += 1

        file = f"{filename}_{counter:05}_.png"
        grid = self.create_image_grid(image_list)
        grid.save(
            os.path.join(full_output_folder, file),
            pnginfo=metadata,
            compress_level=4,
        )
        results = [
            {"filename": file, "subfolder": subfolder, "type": self.type}
        ]
        return {"ui": {"images": results}}
class MTB_ImageTileOffset:
    """Mimics an old photoshop technique to check for seamless textures"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "tilesX": ("INT", {"default": 2, "min": 1}),
                "tilesY": ("INT", {"default": 2, "min": 1}),
            }
        }

    CATEGORY = "mtb/generate"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "tile_image"

    def tile_image(
        self, image: torch.Tensor, tilesX: int = 2, tilesY: int = 2
    ):
        """Cyclically shift a (B, H, W, C) batch by one tile down and right.

        The tile size is ``H // tilesY`` by ``W // tilesX``. Moving every tile
        to the next slot (with wrap-around) is a ``torch.roll`` by one tile.
        Unlike copying tile by tile, this keeps the leftover rows/columns when
        the size is not divisible by the tile count, instead of leaving them
        black.

        Raises:
            ValueError: if either tile count is below 1.
        """
        if tilesX < 1 or tilesY < 1:
            raise ValueError("The number of tiles must be at least 1.")
        _, height, width, _ = image.shape
        tile_height = height // tilesY
        tile_width = width // tilesX
        shifted = torch.roll(image, shifts=(tile_height, tile_width), dims=(1, 2))
        return (shifted,)
# Node classes exported by this module, in registration order.
# DeglazeImage is intentionally left out: its implementation is commented out
# because it needs cv2.ximgproc (opencv-contrib-python).
__nodes__ = [
    MTB_ColorCorrect,
    MTB_ColorCorrectGPU,
    MTB_ImageCompare,
    MTB_ImageTileOffset,
    MTB_Blur,
    MTB_MaskToImage,
    MTB_ColoredImage,
    MTB_ImagePremultiply,
    MTB_ImageResizeFactor,
    MTB_SaveImageGrid,
    MTB_LoadImageFromUrl,
    MTB_Sharpen,
    MTB_ExtractCoordinatesFromImage,
    MTB_CoordinatesToString,
]