import torch
from ..log import log


class MTB_StackImages:
    """Stack the input images horizontally or vertically."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"vertical": ("BOOLEAN", {"default": False})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stack"
    CATEGORY = "mtb/image utils"

    def stack(self, vertical, **kwargs):
if not kwargs:
raise ValueError("At least one tensor must be provided.")
tensors = list(kwargs.values())
log.debug(
f"Stacking {len(tensors)} tensors "
f"{'vertically' if vertical else 'horizontally'}"
)
        # IMAGE tensors are (batch, height, width, channels); promote every
        # input to RGBA so channel counts always match.
        normalized_tensors = [
            self.normalize_to_rgba(tensor) for tensor in tensors
        ]
        # Pad smaller batches by repeating frames so batch sizes match too.
        max_batch_size = max(tensor.shape[0] for tensor in normalized_tensors)
        normalized_tensors = [
            self.duplicate_frames(tensor, max_batch_size)
            for tensor in normalized_tensors
        ]
if vertical:
width = normalized_tensors[0].shape[2]
if any(tensor.shape[2] != width for tensor in normalized_tensors):
raise ValueError(
"All tensors must have the same width "
"for vertical stacking."
)
dim = 1
else:
height = normalized_tensors[0].shape[1]
if any(tensor.shape[1] != height for tensor in normalized_tensors):
raise ValueError(
"All tensors must have the same height "
"for horizontal stacking."
)
dim = 2
stacked_tensor = torch.cat(normalized_tensors, dim=dim)
return (stacked_tensor,)

    def normalize_to_rgba(self, tensor):
        """Normalize tensor to have 4 channels (RGBA)."""
        _, _, _, channels = tensor.shape
        # already RGBA
        if channels == 4:
            return tensor
        # RGB to RGBA: append a fully opaque alpha channel, keeping the
        # input's dtype so torch.cat never sees mismatched tensors
        elif channels == 3:
            alpha_channel = torch.ones(
                tensor.shape[:-1] + (1,),
                dtype=tensor.dtype,
                device=tensor.device,
            )
            return torch.cat((tensor, alpha_channel), dim=-1)
        else:
            raise ValueError(
                f"Tensor has an unsupported number of channels ({channels}): "
                "expected 3 (RGB) or 4 (RGBA)."
            )

    def duplicate_frames(self, tensor, target_batch_size):
        """Duplicate frames in tensor to match the target batch size."""
        current_batch_size = tensor.shape[0]
        if current_batch_size >= target_batch_size:
            return tensor
        # Repeat the whole clip as many times as it fits, then top up with
        # the first few frames.
        duplication_factor = target_batch_size // current_batch_size
        duplicated_tensor = tensor.repeat(duplication_factor, 1, 1, 1)
        remaining_frames = target_batch_size % current_batch_size
        if remaining_frames > 0:
            duplicated_tensor = torch.cat(
                (duplicated_tensor, tensor[:remaining_frames]), dim=0
            )
        return duplicated_tensor
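

# A minimal usage sketch (not part of the original module): it shows how
# MTB_StackImages normalizes mismatched inputs before concatenating. The
# `image_a`/`image_b` input names and the helper itself are hypothetical;
# ComfyUI IMAGE tensors are (batch, height, width, channels) floats in [0, 1].
def _demo_stack_images():
    image_a = torch.rand(1, 64, 128, 3)  # one RGB frame
    image_b = torch.rand(2, 64, 128, 3)  # two RGB frames, same height
    (stacked,) = MTB_StackImages().stack(
        vertical=False, image_a=image_a, image_b=image_b
    )
    # Both inputs become RGBA and image_a is repeated to two frames, so the
    # horizontal stack has shape (2, 64, 256, 4).
    assert stacked.shape == (2, 64, 256, 4)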


class MTB_PickFromBatch:
    """Pick a specific number of images from a batch,
    either from the start or the end.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "from_direction": (["end", "start"], {"default": "start"}),
                "count": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "pick_from_batch"
    CATEGORY = "mtb/image utils"

    def pick_from_batch(self, image, from_direction, count):
        batch_size = image.size(0)
        # Warn before clamping, so the shortfall is actually reported.
        if count > batch_size:
            log.warning(
                f"Requested {count} images, "
                f"but only {batch_size} are available."
            )
        # Limit count to the available number of images in the batch.
        count = min(count, batch_size)
if from_direction == "end":
selected_tensors = image[-count:]
else:
selected_tensors = image[:count]
return (selected_tensors,)


__nodes__ = [MTB_StackImages, MTB_PickFromBatch]
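

# A minimal usage sketch (not part of the original module): picking the last
# two frames of a four-frame batch with MTB_PickFromBatch. The helper name
# and tensor values are hypothetical; it is meant to be called from within
# the package, since the relative `..log` import prevents running this file
# directly.
def _demo_pick_from_batch():
    batch = torch.rand(4, 32, 32, 3)
    (picked,) = MTB_PickFromBatch().pick_from_batch(
        batch, from_direction="end", count=2
    )
    # image[-2:] keeps the final two frames of the batch.
    assert picked.shape == (2, 32, 32, 3)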