import itertools import json import math import os import comfy.model_management as model_management import folder_paths import numpy as np import torch import torch.nn.functional as F from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo from skimage.filters import gaussian from skimage.util import compare_images from ..log import log from ..utils import np2tensor, pil2tensor, tensor2pil # try: # from cv2.ximgproc import guidedFilter # except ImportError: # log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python") def gaussian_kernel( kernel_size: int, sigma_x: float, sigma_y: float, device=None ): x, y = torch.meshgrid( torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij", ) d_x = x * x / (2.0 * sigma_x * sigma_x) d_y = y * y / (2.0 * sigma_y * sigma_y) g = torch.exp(-(d_x + d_y)) return g / g.sum() class MTB_CoordinatesToString: RETURN_TYPES = ("STRING",) FUNCTION = "convert" CATEGORY = "mtb/coordinates" @classmethod def INPUT_TYPES(cls): return { "required": { "coordinates": ("BATCH_COORDINATES",), "frame": ("INT",), } } def convert( self, coordinates: list[list[tuple[int, int]]], frame: int ) -> tuple[str]: frame = max(frame, len(coordinates) - 1) coords = coordinates[frame] output: list[dict[str, int]] = [] for x, y in coords: output.append({"x": x, "y": y}) return (json.dumps(output),) class MTB_ExtractCoordinatesFromImage: """Extract 2D points from a batch of images based on a threshold.""" RETURN_TYPES = ("BATCH_COORDINATES", "IMAGE") FUNCTION = "extract" CATEGORY = "mtb/coordinates" @classmethod def INPUT_TYPES(cls): return { "required": { "threshold": ("FLOAT",), "max_points": ("INT", {"default": 50, "min": 0}), }, "optional": {"image": ("IMAGE",), "mask": ("MASK",)}, } def extract( self, threshold: float, max_points: int, image: torch.Tensor | None = None, mask: torch.Tensor | None = None, ) -> tuple[list[list[tuple[int, int]]], torch.Tensor]: if image is not None: batch_count, height, width, channel_count = image.shape imgs = image else: if mask is None: raise ValueError("Must provide either image or mask") batch_count, height, width = mask.shape channel_count = 1 imgs = mask if channel_count not in [1, 2, 3, 4]: raise ValueError(f"Incorrect channel count: {channel_count}") all_points: list[list[tuple[int, int]]] = [] debug_images = torch.zeros( (batch_count, height, width, 3), dtype=torch.uint8, device=imgs.device, ) for i, img in enumerate(imgs): if channel_count == 1: alpha_channel = img if len(img.shape) == 2 else img[:, :, 0] elif channel_count == 2: alpha_channel = img[:, :, 1] elif channel_count == 4: alpha_channel = img[:, :, 3] else: # get intensity alpha_channel = img[:, :, :3].max(dim=2)[0] points = (alpha_channel > threshold).nonzero(as_tuple=False) if len(points) > max_points: indices = torch.randperm(points.size(0), device=img.device)[ :max_points ] points = points[indices] points = [(int(y.item()), int(x.item())) for x, y in points] all_points.append(points) for x, y in points: self._draw_circle(debug_images[i], (x, y), 5) return (all_points, debug_images) @staticmethod def _draw_circle( image: torch.Tensor, center: tuple[int, int], radius: int ): """Draw a 5px circle on the image.""" x0, y0 = center for x in range(-radius, radius + 1): for y in range(-radius, radius + 1): in_radius = x**2 + y**2 <= radius**2 in_bounds = ( 0 <= x0 + x < image.shape[1] and 0 <= y0 + y < image.shape[0] ) if in_radius and in_bounds: image[y0 + y, x0 + x] = torch.tensor( [255, 255, 255], dtype=torch.uint8, device=image.device, ) class MTB_ColorCorrectGPU: """Various color correction methods using only Torch.""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "force_gpu": ("BOOLEAN", {"default": True}), "clamp": ([True, False], {"default": True}), "gamma": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "contrast": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "exposure": ( "FLOAT", {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, ), "offset": ( "FLOAT", {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, ), "hue": ( "FLOAT", {"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, ), "saturation": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "value": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), }, "optional": {"mask": ("MASK",)}, } RETURN_TYPES = ("IMAGE",) FUNCTION = "correct" CATEGORY = "mtb/image processing" @staticmethod def get_device(tensor: torch.Tensor, force_gpu: bool): if force_gpu: if torch.cuda.is_available(): return torch.device("cuda") elif ( hasattr(torch.backends, "mps") and torch.backends.mps.is_available() ): return torch.device("mps") elif hasattr(torch, "hip") and torch.hip.is_available(): return torch.device("hip") return ( tensor.device ) # model_management.get_torch_device() # torch.device("cpu") @staticmethod def rgb_to_hsv(image: torch.Tensor): r, g, b = image.unbind(-1) max_rgb, argmax_rgb = image.max(-1) min_rgb, _ = image.min(-1) diff = max_rgb - min_rgb h = torch.empty_like(max_rgb) s = diff / (max_rgb + 1e-7) v = max_rgb h[argmax_rgb == 0] = (g - b)[argmax_rgb == 0] / (diff + 1e-7)[ argmax_rgb == 0 ] h[argmax_rgb == 1] = ( 2.0 + (b - r)[argmax_rgb == 1] / (diff + 1e-7)[argmax_rgb == 1] ) h[argmax_rgb == 2] = ( 4.0 + (r - g)[argmax_rgb == 2] / (diff + 1e-7)[argmax_rgb == 2] ) h = (h / 6.0) % 1.0 h = h.unsqueeze(-1) s = s.unsqueeze(-1) v = v.unsqueeze(-1) return torch.cat((h, s, v), dim=-1) @staticmethod def hsv_to_rgb(hsv: torch.Tensor): h, s, v = hsv.unbind(-1) h = h * 6.0 i = torch.floor(h) f = h - i p = v * (1.0 - s) q = v * (1.0 - s * f) t = v * (1.0 - s * (1.0 - f)) i = i.long() % 6 mask = torch.stack( (i == 0, i == 1, i == 2, i == 3, i == 4, i == 5), -1 ) rgb = torch.stack( ( torch.where( mask[..., 0], v, torch.where( mask[..., 1], q, torch.where( mask[..., 2], p, torch.where( mask[..., 3], p, torch.where(mask[..., 4], t, v), ), ), ), ), torch.where( mask[..., 0], t, torch.where( mask[..., 1], v, torch.where( mask[..., 2], v, torch.where( mask[..., 3], q, torch.where(mask[..., 4], p, p), ), ), ), ), torch.where( mask[..., 0], p, torch.where( mask[..., 1], p, torch.where( mask[..., 2], t, torch.where( mask[..., 3], v, torch.where(mask[..., 4], v, q), ), ), ), ), ), dim=-1, ) return rgb def correct( self, image: torch.Tensor, force_gpu: bool, clamp: bool, gamma: float = 1.0, contrast: float = 1.0, exposure: float = 0.0, offset: float = 0.0, hue: float = 0.0, saturation: float = 1.0, value: float = 1.0, mask: torch.Tensor | None = None, ): device = self.get_device(image, force_gpu) image = image.to(device) if mask is not None: if mask.shape[0] != image.shape[0]: mask = mask.expand(image.shape[0], -1, -1) mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) mask = mask.to(device) model_management.throw_exception_if_processing_interrupted() adjusted = image.pow(1 / gamma) * (2.0**exposure) * contrast + offset model_management.throw_exception_if_processing_interrupted() hsv = self.rgb_to_hsv(adjusted) hsv[..., 0] = (hsv[..., 0] + hue) % 1.0 # Hue hsv[..., 1] = hsv[..., 1] * saturation # Saturation hsv[..., 2] = hsv[..., 2] * value # Value adjusted = self.hsv_to_rgb(hsv) model_management.throw_exception_if_processing_interrupted() if clamp: adjusted = torch.clamp(adjusted, 0.0, 1.0) # apply mask result = ( adjusted if mask is None else torch.where(mask > 0, adjusted, image) ) if not force_gpu: result = result.cpu() return (result,) class MTB_ColorCorrect: """Various color correction methods""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "clamp": ([True, False], {"default": True}), "gamma": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "contrast": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "exposure": ( "FLOAT", {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, ), "offset": ( "FLOAT", {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01}, ), "hue": ( "FLOAT", {"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01}, ), "saturation": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), "value": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01}, ), }, "optional": {"mask": ("MASK",)}, } RETURN_TYPES = ("IMAGE",) FUNCTION = "correct" CATEGORY = "mtb/image processing" @staticmethod def gamma_correction_tensor(image, gamma): gamma_inv = 1.0 / gamma return image.pow(gamma_inv) @staticmethod def contrast_adjustment_tensor(image, contrast): r, g, b = image.unbind(-1) # Using Adobe RGB luminance weights. luminance_image = 0.33 * r + 0.71 * g + 0.06 * b luminance_mean = torch.mean(luminance_image.unsqueeze(-1)) # Blend original with mean luminance using contrast factor as blend ratio. contrasted = image * contrast + (1.0 - contrast) * luminance_mean return torch.clamp(contrasted, 0.0, 1.0) @staticmethod def exposure_adjustment_tensor(image, exposure): return image * (2.0**exposure) @staticmethod def offset_adjustment_tensor(image, offset): return image + offset @staticmethod def hsv_adjustment(image: torch.Tensor, hue, saturation, value): images = tensor2pil(image) out = [] for img in images: hsv_image = img.convert("HSV") h, s, v = hsv_image.split() h = h.point(lambda x: (x + hue * 255) % 256) s = s.point(lambda x: int(x * saturation)) v = v.point(lambda x: int(x * value)) hsv_image = Image.merge("HSV", (h, s, v)) rgb_image = hsv_image.convert("RGB") out.append(rgb_image) return pil2tensor(out) @staticmethod def hsv_adjustment_tensor_not_working( image: torch.Tensor, hue, saturation, value ): """Abandonning for now""" image = image.squeeze(0).permute(2, 0, 1) max_val, _ = image.max(dim=0, keepdim=True) min_val, _ = image.min(dim=0, keepdim=True) delta = max_val - min_val hue_image = torch.zeros_like(max_val) mask = delta != 0.0 r, g, b = image[0], image[1], image[2] hue_image[mask & (max_val == r)] = ((g - b) / delta)[ mask & (max_val == r) ] % 6.0 hue_image[mask & (max_val == g)] = ((b - r) / delta)[ mask & (max_val == g) ] + 2.0 hue_image[mask & (max_val == b)] = ((r - g) / delta)[ mask & (max_val == b) ] + 4.0 saturation_image = delta / (max_val + 1e-7) value_image = max_val hue_image = (hue_image + hue) % 1.0 saturation_image = torch.where( mask, saturation * saturation_image, saturation_image ) value_image = value * value_image c = value_image * saturation_image x = c * (1 - torch.abs((hue_image % 2) - 1)) m = value_image - c prime_image = torch.zeros_like(image) prime_image[0] = torch.where( max_val == r, c, torch.where(max_val == g, x, prime_image[0]) ) prime_image[1] = torch.where( max_val == r, x, torch.where(max_val == g, c, prime_image[1]) ) prime_image[2] = torch.where( max_val == g, x, torch.where(max_val == b, c, prime_image[2]) ) rgb_image = prime_image + m rgb_image = rgb_image.permute(1, 2, 0).unsqueeze(0) return rgb_image def correct( self, image: torch.Tensor, clamp: bool, gamma: float = 1.0, contrast: float = 1.0, exposure: float = 0.0, offset: float = 0.0, hue: float = 0.0, saturation: float = 1.0, value: float = 1.0, mask: torch.Tensor | None = None, ): if mask is not None: if mask.shape[0] != image.shape[0]: mask = mask.expand(image.shape[0], -1, -1) mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3) # Apply color correction operations adjusted = self.gamma_correction_tensor(image, gamma) adjusted = self.contrast_adjustment_tensor(adjusted, contrast) adjusted = self.exposure_adjustment_tensor(adjusted, exposure) adjusted = self.offset_adjustment_tensor(adjusted, offset) adjusted = self.hsv_adjustment(adjusted, hue, saturation, value) if clamp: adjusted = torch.clamp(image, 0.0, 1.0) result = ( adjusted if mask is None else torch.where(mask > 0, adjusted, image) ) return (result,) class MTB_ImageCompare: """Compare two images and return a difference image""" @classmethod def INPUT_TYPES(cls): return { "required": { "imageA": ("IMAGE",), "imageB": ("IMAGE",), "mode": ( ["checkerboard", "diff", "blend"], {"default": "checkerboard"}, ), } } RETURN_TYPES = ("IMAGE",) FUNCTION = "compare" CATEGORY = "mtb/image" def compare(self, imageA: torch.Tensor, imageB: torch.Tensor, mode): if imageA.dim() == 4: batch_count = imageA.size(0) return ( torch.cat( tuple( self.compare(imageA[i], imageB[i], mode)[0] for i in range(batch_count) ), dim=0, ), ) num_channels_A = imageA.size(2) num_channels_B = imageB.size(2) # handle RGBA/RGB mismatch if num_channels_A == 3 and num_channels_B == 4: imageA = torch.cat( (imageA, torch.ones_like(imageA[:, :, 0:1])), dim=2 ) elif num_channels_B == 3 and num_channels_A == 4: imageB = torch.cat( (imageB, torch.ones_like(imageB[:, :, 0:1])), dim=2 ) match mode: case "diff": compare_image = torch.abs(imageA - imageB) case "blend": compare_image = 0.5 * (imageA + imageB) case "checkerboard": imageA = imageA.numpy() imageB = imageB.numpy() compared_channels = [ torch.from_numpy( compare_images( imageA[:, :, i], imageB[:, :, i], method=mode ) ) for i in range(imageA.shape[2]) ] compare_image = torch.stack(compared_channels, dim=2) case _: compare_image = None raise ValueError(f"Unknown mode {mode}") compare_image = compare_image.unsqueeze(0) return (compare_image,) import requests class MTB_LoadImageFromUrl: """Load an image from the given URL""" @classmethod def INPUT_TYPES(cls): return { "required": { "url": ( "STRING", { "default": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg" }, ), } } RETURN_TYPES = ("IMAGE",) FUNCTION = "load" CATEGORY = "mtb/IO" def load(self, url): # get the image from the url image = Image.open(requests.get(url, stream=True).raw) image = ImageOps.exif_transpose(image) return (pil2tensor(image),) class MTB_Blur: """Blur an image using a Gaussian filter.""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "sigmaX": ( "FLOAT", {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01}, ), "sigmaY": ( "FLOAT", {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01}, ), }, "optional": {"sigmasX": ("FLOATS",), "sigmasY": ("FLOATS",)}, } RETURN_TYPES = ("IMAGE",) FUNCTION = "blur" CATEGORY = "mtb/image processing" def blur( self, image: torch.Tensor, sigmaX, sigmaY, sigmasX=None, sigmasY=None ): image_np = image.numpy() * 255 blurred_images = [] if sigmasX is not None: if sigmasY is None: sigmasY = sigmasX if len(sigmasX) != image.size(0): raise ValueError( f"SigmasX must have same length as image, sigmasX is {len(sigmasX)} but the batch size is {image.size(0)}" ) for i in range(image.size(0)): blurred = gaussian( image_np[i], sigma=(sigmasX[i], sigmasY[i], 0), channel_axis=2, ) blurred_images.append(blurred) image_np = np.array(blurred_images) else: for i in range(image.size(0)): blurred = gaussian( image_np[i], sigma=(sigmaX, sigmaY, 0), channel_axis=2 ) blurred_images.append(blurred) image_np = np.array(blurred_images) return (np2tensor(image_np).squeeze(0),) class MTB_Sharpen: """Sharpens an image using a Gaussian kernel.""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "sharpen_radius": ( "INT", {"default": 1, "min": 1, "max": 31, "step": 1}, ), "sigma_x": ( "FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}, ), "sigma_y": ( "FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}, ), "alpha": ( "FLOAT", {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1}, ), }, } RETURN_TYPES = ("IMAGE",) FUNCTION = "do_sharp" CATEGORY = "mtb/image processing" def do_sharp( self, image: torch.Tensor, sharpen_radius: int, sigma_x: float, sigma_y: float, alpha: float, ): if sharpen_radius == 0: return (image,) channels = image.shape[3] kernel_size = 2 * sharpen_radius + 1 kernel = gaussian_kernel(kernel_size, sigma_x, sigma_y) * -(alpha * 10) # Modify center of kernel to make it a sharpening kernel center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) tensor_image = F.pad( tensor_image, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), "reflect", ) sharpened = F.conv2d( tensor_image, kernel, padding=center, groups=channels ) # Remove padding sharpened = sharpened[ :, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius, ] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) return (result,) # https://github.com/lllyasviel/AdverseCleaner/blob/main/clean.py # def deglaze_np_img(np_img): # y = np_img.copy() # for _ in range(64): # y = cv2.bilateralFilter(y, 5, 8, 8) # for _ in range(4): # y = guidedFilter(np_img, y, 4, 16) # return y # class DeglazeImage: # """Remove adversarial noise from images""" # @classmethod # def INPUT_TYPES(cls): # return {"required": {"image": ("IMAGE",)}} # CATEGORY = "mtb/image processing" # RETURN_TYPES = ("IMAGE",) # FUNCTION = "deglaze_image" # def deglaze_image(self, image): # return (np2tensor(deglaze_np_img(tensor2np(image))),) class MTB_MaskToImage: """Converts a mask (alpha) to an RGB image with a color and background""" @classmethod def INPUT_TYPES(cls): return { "required": { "mask": ("MASK",), "color": ("COLOR",), "background": ("COLOR", {"default": "#000000"}), }, "optional": { "invert": ("BOOLEAN", {"default": False}), }, } CATEGORY = "mtb/generate" RETURN_TYPES = ("IMAGE",) FUNCTION = "render_mask" def render_mask(self, mask, color, background, invert=False): masks = tensor2pil(1.0 - mask) if invert else tensor2pil(mask) images = [] for m in masks: _mask = m.convert("L") log.debug( f"Converted mask to PIL Image format, size: {_mask.size}" ) image = Image.new("RGBA", _mask.size, color=color) # apply the mask image = Image.composite( image, Image.new("RGBA", _mask.size, color=background), _mask ) # image = ImageChops.multiply(image, mask) # apply over background # image = Image.alpha_composite(Image.new("RGBA", image.size, color=background), image) images.append(image.convert("RGB")) return (pil2tensor(images),) class MTB_ColoredImage: """Constant color image of given size.""" def __init__(self) -> None: pass @classmethod def INPUT_TYPES(cls): return { "required": { "color": ("COLOR",), "width": ("INT", {"default": 512, "min": 16, "max": 8160}), "height": ("INT", {"default": 512, "min": 16, "max": 8160}), }, "optional": { "foreground_image": ("IMAGE",), "foreground_mask": ("MASK",), "invert": ("BOOLEAN", {"default": False}), "mask_opacity": ( "FLOAT", {"default": 1.0, "step": 0.1, "min": 0}, ), }, } CATEGORY = "mtb/generate" RETURN_TYPES = ("IMAGE",) FUNCTION = "render_img" def resize_and_crop(self, img: Image.Image, target_size: tuple[int, int]): scale = max(target_size[0] / img.width, target_size[1] / img.height) new_size = (int(img.width * scale), int(img.height * scale)) img = img.resize(new_size, Image.LANCZOS) left = (img.width - target_size[0]) // 2 top = (img.height - target_size[1]) // 2 return img.crop( (left, top, left + target_size[0], top + target_size[1]) ) def resize_and_crop_thumbnails( self, img: Image.Image, target_size: tuple[int, int] ): img.thumbnail(target_size, Image.LANCZOS) left = (img.width - target_size[0]) / 2 top = (img.height - target_size[1]) / 2 right = (img.width + target_size[0]) / 2 bottom = (img.height + target_size[1]) / 2 return img.crop((left, top, right, bottom)) @staticmethod def process_mask( mask: torch.Tensor | None, invert: bool, # opacity: float, batch_size: int, ) -> list[Image.Image] | None: if mask is None: return [None] * batch_size masks = tensor2pil(mask if not invert else 1.0 - mask) if len(masks) == 1 and batch_size > 1: masks = masks * batch_size if len(masks) != batch_size: raise ValueError( "Foreground image and mask must have the same batch size" ) return masks def render_img( self, color: str, width: int, height: int, foreground_image: torch.Tensor | None = None, foreground_mask: torch.Tensor | None = None, invert: bool = False, mask_opacity: float = 1.0, ) -> tuple[torch.Tensor]: background = Image.new("RGBA", (width, height), color=color) if foreground_image is None: return (pil2tensor([background.convert("RGB")]),) fg_images = tensor2pil(foreground_image) fg_masks = self.process_mask(foreground_mask, invert, len(fg_images)) output: list[Image.Image] = [] for fg_image, fg_mask in zip(fg_images, fg_masks, strict=False): fg_image = self.resize_and_crop(fg_image, background.size) if fg_mask: fg_mask = self.resize_and_crop(fg_mask, background.size) fg_mask_array = np.array(fg_mask) fg_mask_array = (fg_mask_array * mask_opacity).astype(np.uint8) fg_mask = Image.fromarray(fg_mask_array) output.append( Image.composite( fg_image.convert("RGBA"), background, fg_mask ).convert("RGB") ) else: if fg_image.mode != "RGBA": raise ValueError( f"Foreground image must be in 'RGBA' mode when no mask is provided, got {fg_image.mode}" ) output.append( Image.alpha_composite(background, fg_image).convert("RGB") ) return (pil2tensor(output),) class MTB_ImagePremultiply: """Premultiply image with mask""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "mask": ("MASK",), "invert": ("BOOLEAN", {"default": False}), } } CATEGORY = "mtb/image" RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("RGBA",) FUNCTION = "premultiply" def premultiply(self, image, mask, invert): images = tensor2pil(image) masks = tensor2pil(mask) if invert else tensor2pil(1.0 - mask) single = len(mask) == 1 masks = [x.convert("L") for x in masks] out = [] for i, img in enumerate(images): cur_mask = masks[0] if single else masks[i] img.putalpha(cur_mask) out.append(img) # if invert: # image = Image.composite(image,Image.new("RGBA", image.size, color=(0,0,0,0)), mask) # else: # image = Image.composite(Image.new("RGBA", image.size, color=(0,0,0,0)), image, mask) return (pil2tensor(out),) class MTB_ImageResizeFactor: """Extracted mostly from WAS Node Suite, with a few edits (most notably multiple image support) and less features.""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "factor": ( "FLOAT", {"default": 2, "min": 0.01, "max": 16.0, "step": 0.01}, ), "supersample": ("BOOLEAN", {"default": True}), "resampling": ( [ "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact", ], {"default": "nearest"}, ), }, "optional": { "mask": ("MASK",), }, } CATEGORY = "mtb/image" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "resize" def resize( self, image: torch.Tensor, factor: float, supersample: bool, resampling: str, mask=None, ): # Check if the tensor has the correct dimension if len(image.shape) not in [3, 4]: # HxWxC or BxHxWxC raise ValueError( "Expected image tensor of shape (H, W, C) or (B, H, W, C)" ) # Transpose to CxHxW or BxCxHxW for PyTorch if len(image.shape) == 3: image = image.permute(2, 0, 1).unsqueeze(0) # CxHxW else: image = image.permute(0, 3, 1, 2) # BxCxHxW # Compute new dimensions B, C, H, W = image.shape new_H, new_W = int(H * factor), int(W * factor) align_corner_filters = ("linear", "bilinear", "bicubic", "trilinear") # Resize the image resized_image = F.interpolate( image, size=(new_H, new_W), mode=resampling, align_corners=resampling in align_corner_filters, ) # Optionally supersample if supersample: resized_image = F.interpolate( resized_image, scale_factor=2, mode=resampling, align_corners=resampling in align_corner_filters, ) # Transpose back to the original format: BxHxWxC or HxWxC if len(image.shape) == 4: resized_image = resized_image.permute(0, 2, 3, 1) else: resized_image = resized_image.squeeze(0).permute(1, 2, 0) # Apply mask if provided if mask is not None: if len(mask.shape) != len(resized_image.shape): raise ValueError( "Mask tensor should have the same dimensions as the image tensor" ) resized_image = resized_image * mask return (resized_image,) class MTB_SaveImageGrid: """Save all the images in the input batch as a grid of images.""" def __init__(self): self.output_dir = folder_paths.get_output_directory() self.type = "output" @classmethod def INPUT_TYPES(cls): return { "required": { "images": ("IMAGE",), "filename_prefix": ("STRING", {"default": "ComfyUI"}), "save_intermediate": ("BOOLEAN", {"default": False}), }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = () FUNCTION = "save_images" OUTPUT_NODE = True CATEGORY = "mtb/IO" def create_image_grid(self, image_list): total_images = len(image_list) # Calculate the grid size based on the square root of the total number of images grid_size = ( int(math.sqrt(total_images)), int(math.ceil(math.sqrt(total_images))), ) # Get the size of the first image to determine the grid size image_width, image_height = image_list[0].size # Create a new blank image to hold the grid grid_width = grid_size[0] * image_width grid_height = grid_size[1] * image_height grid_image = Image.new("RGB", (grid_width, grid_height)) # Iterate over the images and paste them onto the grid for i, image in enumerate(image_list): x = (i % grid_size[0]) * image_width y = (i // grid_size[0]) * image_height grid_image.paste(image, (x, y, x + image_width, y + image_height)) return grid_image def save_images( self, images, filename_prefix="Grid", save_intermediate=False, prompt=None, extra_pnginfo=None, ): ( full_output_folder, filename, counter, subfolder, filename_prefix, ) = folder_paths.get_save_image_path( filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0], ) image_list = [] batch_counter = counter metadata = PngInfo() if prompt is not None: metadata.add_text("prompt", json.dumps(prompt)) if extra_pnginfo is not None: for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) for idx, image in enumerate(images): i = 255.0 * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) image_list.append(img) if save_intermediate: file = f"{filename}_batch-{idx:03}_{batch_counter:05}_.png" img.save( os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4, ) batch_counter += 1 file = f"{filename}_{counter:05}_.png" grid = self.create_image_grid(image_list) grid.save( os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4, ) results = [ {"filename": file, "subfolder": subfolder, "type": self.type} ] return {"ui": {"images": results}} class MTB_ImageTileOffset: """Mimics an old photoshop technique to check for seamless textures""" @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), "tilesX": ("INT", {"default": 2, "min": 1}), "tilesY": ("INT", {"default": 2, "min": 1}), } } CATEGORY = "mtb/generate" RETURN_TYPES = ("IMAGE",) FUNCTION = "tile_image" def tile_image( self, image: torch.Tensor, tilesX: int = 2, tilesY: int = 2 ): if tilesX < 1 or tilesY < 1: raise ValueError("The number of tiles must be at least 1.") batch_size, height, width, channels = image.shape tile_height = height // tilesY tile_width = width // tilesX output_image = torch.zeros_like(image) for i, j in itertools.product(range(tilesY), range(tilesX)): start_h = i * tile_height end_h = start_h + tile_height start_w = j * tile_width end_w = start_w + tile_width tile = image[:, start_h:end_h, start_w:end_w, :] output_start_h = (i + 1) % tilesY * tile_height output_start_w = (j + 1) % tilesX * tile_width output_end_h = output_start_h + tile_height output_end_w = output_start_w + tile_width output_image[ :, output_start_h:output_end_h, output_start_w:output_end_w, : ] = tile return (output_image,) __nodes__ = [ MTB_ColorCorrect, MTB_ColorCorrectGPU, MTB_ImageCompare, MTB_ImageTileOffset, MTB_Blur, # DeglazeImage, MTB_MaskToImage, MTB_ColoredImage, MTB_ImagePremultiply, MTB_ImageResizeFactor, MTB_SaveImageGrid, MTB_LoadImageFromUrl, MTB_Sharpen, MTB_ExtractCoordinatesFromImage, MTB_CoordinatesToString, ]