from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from ...configuration_utils import ConfigMixin, register_to_config
from ...image_processor import PipelineImageInput
from ...utils import CONFIG_NAME, logging
from ...utils.import_utils import is_matplotlib_available


logger = logging.get_logger(__name__)


class MarigoldImageProcessor(ConfigMixin):
    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        vae_scale_factor: int = 8,
        do_normalize: bool = True,
        do_range_check: bool = True,
    ):
        super().__init__()

    @staticmethod
    def expand_tensor_or_array(images: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Expand a 2- or 3-dimensional tensor or array into a 4-dimensional batch of images: NumPy arrays become
        `[N,H,W,C]`, and PyTorch tensors become `[N,C,H,W]`, with singleton batch and channel dimensions added as
        needed.
        """
        if isinstance(images, np.ndarray):
            if images.ndim == 2:  # [H,W] -> [1,H,W,1]
                images = images[None, ..., None]
            if images.ndim == 3:  # [H,W,C] -> [1,H,W,C]
                images = images[None]
        elif isinstance(images, torch.Tensor):
            if images.ndim == 2:  # [H,W] -> [1,1,H,W]
                images = images[None, None]
            elif images.ndim == 3:  # [C,H,W] -> [1,C,H,W]
                images = images[None]
        else:
            raise ValueError(f"Unexpected input type: {type(images)}")
        return images
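
    # A minimal shape sketch (illustrative only): a [H,W] NumPy map becomes [1,H,W,1], while a
    # [C,H,W] tensor becomes [1,C,H,W]:
    #
    #     expand_tensor_or_array(np.zeros((480, 640))).shape      # (1, 480, 640, 1)
    #     expand_tensor_or_array(torch.zeros(3, 480, 640)).shape  # torch.Size([1, 3, 480, 640])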

    @staticmethod
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a batch of PyTorch images `[N,C,H,W]` into a NumPy array `[N,H,W,C]`.
        """
        images = images.cpu().permute(0, 2, 3, 1).float().numpy()
        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a batch of NumPy images `[N,H,W,C]` into a PyTorch tensor `[N,C,H,W]`.
        """
        if np.issubdtype(images.dtype, np.integer) and not np.issubdtype(images.dtype, np.unsignedinteger):
            raise ValueError(f"Input image dtype={images.dtype} cannot be a signed integer.")
        if np.issubdtype(images.dtype, np.complexfloating):
            raise ValueError(f"Input image dtype={images.dtype} cannot be complex.")
        if np.issubdtype(images.dtype, bool):
            raise ValueError(f"Input image dtype={images.dtype} cannot be boolean.")

        images = torch.from_numpy(images.transpose(0, 3, 1, 2))
        return images
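
    # Layout sketch (illustrative shapes): an [N,H,W,C] float array maps to an [N,C,H,W] tensor:
    #
    #     numpy_to_pt(np.zeros((2, 480, 640, 3), dtype=np.float32)).shape  # torch.Size([2, 3, 480, 640])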

    @staticmethod
    def resize_antialias(
        image: torch.Tensor, size: Tuple[int, int], mode: str, is_aa: Optional[bool] = None
    ) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        # `F.interpolate` supports antialiasing only for the bilinear and bicubic modes.
        antialias = is_aa and mode in ("bilinear", "bicubic")
        image = F.interpolate(image, size, mode=mode, antialias=antialias)

        return image

    @staticmethod
    def resize_to_max_edge(image: torch.Tensor, max_edge_sz: int, mode: str) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        # Scale the longer edge to `max_edge_sz` while preserving the aspect ratio.
        h, w = image.shape[-2:]
        max_orig = max(h, w)
        new_h = h * max_edge_sz // max_orig
        new_w = w * max_edge_sz // max_orig

        if new_h == 0 or new_w == 0:
            raise ValueError(f"Extreme aspect ratio of the input image: [{w} x {h}]")

        image = MarigoldImageProcessor.resize_antialias(image, (new_h, new_w), mode, is_aa=True)

        return image
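
    # Sizing sketch (illustrative values): for a [1,3,1024,1536] input and max_edge_sz=768, the longer
    # edge 1536 maps to 768 and the shorter edge maps to 1024 * 768 // 1536 = 512, giving [1,3,512,768].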

    @staticmethod
    def pad_image(image: torch.Tensor, align: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        # `-h % align` is the smallest non-negative padding that makes `h` a multiple of `align`.
        h, w = image.shape[-2:]
        ph, pw = -h % align, -w % align

        image = F.pad(image, (0, pw, 0, ph), mode="replicate")

        return image, (ph, pw)

    @staticmethod
    def unpad_image(image: torch.Tensor, padding: Tuple[int, int]) -> torch.Tensor:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.dim() != 4:
            raise ValueError(f"Invalid input dimensions; shape={image.shape}.")

        # A zero padding means "keep the full extent"; slicing with `None` preserves that case.
        ph, pw = padding
        uh = None if ph == 0 else -ph
        uw = None if pw == 0 else -pw

        image = image[:, :, :uh, :uw]

        return image
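
    # A minimal pad/unpad round-trip sketch (tensor values are illustrative):
    #
    #     x = torch.rand(1, 3, 480, 641)
    #     padded, padding = MarigoldImageProcessor.pad_image(x, align=8)  # width 641 -> 648
    #     restored = MarigoldImageProcessor.unpad_image(padded, padding)
    #     assert restored.shape == x.shape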

    @staticmethod
    def load_image_canonical(
        image: Union[torch.Tensor, np.ndarray, Image.Image],
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        if isinstance(image, Image.Image):
            image = np.array(image)

        image_dtype_max = None
        if isinstance(image, (np.ndarray, torch.Tensor)):
            image = MarigoldImageProcessor.expand_tensor_or_array(image)
            if image.ndim != 4:
                raise ValueError("Input image is not 2-, 3-, or 4-dimensional.")
        if isinstance(image, np.ndarray):
            if np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.unsignedinteger):
                raise ValueError(f"Input image dtype={image.dtype} cannot be a signed integer.")
            if np.issubdtype(image.dtype, np.complexfloating):
                raise ValueError(f"Input image dtype={image.dtype} cannot be complex.")
            if np.issubdtype(image.dtype, bool):
                raise ValueError(f"Input image dtype={image.dtype} cannot be boolean.")
            if np.issubdtype(image.dtype, np.unsignedinteger):
                image_dtype_max = np.iinfo(image.dtype).max
                image = image.astype(np.float32)
            image = MarigoldImageProcessor.numpy_to_pt(image)

        if torch.is_tensor(image) and not torch.is_floating_point(image) and image_dtype_max is None:
            if image.dtype != torch.uint8:
                raise ValueError(f"Image dtype={image.dtype} is not supported.")
            image_dtype_max = 255

        if not torch.is_tensor(image):
            raise ValueError(f"Input type unsupported: {type(image)}.")

        # Grayscale images are replicated into 3 channels to match the pipelines' RGB input.
        if image.shape[1] == 1:
            image = image.repeat(1, 3, 1, 1)
        if image.shape[1] != 3:
            raise ValueError(f"Input image is not 1- or 3-channel: {image.shape}.")

        image = image.to(device=device, dtype=dtype)

        # Integer inputs are normalized into [0,1] using the maximum value of their original dtype.
        if image_dtype_max is not None:
            image = image / image_dtype_max

        return image

    @staticmethod
    def check_image_values_range(image: torch.Tensor) -> None:
        if not torch.is_tensor(image):
            raise ValueError(f"Invalid input type={type(image)}.")
        if not torch.is_floating_point(image):
            raise ValueError(f"Invalid input dtype={image.dtype}.")
        if image.min().item() < 0.0 or image.max().item() > 1.0:
            raise ValueError("Input image data is partially outside of the [0,1] range.")

    def preprocess(
        self,
        image: PipelineImageInput,
        processing_resolution: Optional[int] = None,
        resample_method_input: str = "bilinear",
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        # Lists of images are loaded one by one and stacked along the batch dimension;
        # all images must share the same spatial dimensions.
        if isinstance(image, list):
            images = None
            for i, img in enumerate(image):
                img = self.load_image_canonical(img, device, dtype)  # [N,3,H,W]
                if images is None:
                    images = img
                else:
                    if images.shape[2:] != img.shape[2:]:
                        raise ValueError(
                            f"Input image[{i}] has incompatible dimensions {img.shape[2:]} with the previous images "
                            f"{images.shape[2:]}"
                        )
                    images = torch.cat((images, img), dim=0)
            image = images
            del images
        else:
            image = self.load_image_canonical(image, device, dtype)  # [N,3,H,W]

        original_resolution = image.shape[2:]

        if self.config.do_range_check:
            self.check_image_values_range(image)

        if self.config.do_normalize:
            image = image * 2.0 - 1.0  # map [0,1] to [-1,1]

        if processing_resolution is not None and processing_resolution > 0:
            image = self.resize_to_max_edge(image, processing_resolution, resample_method_input)

        image, padding = self.pad_image(image, self.config.vae_scale_factor)

        return image, padding, original_resolution
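
    # A minimal preprocessing sketch (the file name and resolution are illustrative):
    #
    #     proc = MarigoldImageProcessor(vae_scale_factor=8)
    #     image, padding, original_resolution = proc.preprocess(
    #         Image.open("input.png"), processing_resolution=768
    #     )
    #     # `image` is [N,3,H,W] in [-1,1] with H and W multiples of 8; `padding` and
    #     # `original_resolution` allow undoing the padding and resizing after inference.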

    @staticmethod
    def colormap(
        image: Union[np.ndarray, torch.Tensor],
        cmap: str = "Spectral",
        bytes: bool = False,
        _force_method: Optional[str] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Converts a monochrome image into an RGB image by applying the specified colormap. This function mimics the
        behavior of matplotlib.colormaps, but allows the user to use the most discriminative color maps ("Spectral",
        "binary") without having to install or import matplotlib. For all other cases, the function will attempt to
        use the native implementation.

        Args:
            image: 2D tensor of values between 0 and 1, either as np.ndarray or torch.Tensor.
            cmap: Colormap name.
            bytes: Whether to return the output as uint8 or floating point image.
            _force_method:
                Can be used to specify whether to use the native implementation (`"matplotlib"`), the efficient custom
                implementation of the select color maps (`"custom"`), or rely on autodetection (`None`, default).

        Returns:
            An RGB-colorized tensor corresponding to the input image.
        """
        if not (torch.is_tensor(image) or isinstance(image, np.ndarray)):
            raise ValueError("Argument must be a numpy array or torch tensor.")
        if _force_method not in (None, "matplotlib", "custom"):
            raise ValueError("_force_method must be either `None`, `'matplotlib'` or `'custom'`.")

        supported_cmaps = {
            "binary": [
                (1.0, 1.0, 1.0),
                (0.0, 0.0, 0.0),
            ],
            "Spectral": [
                (0.61960784313725492, 0.003921568627450980, 0.25882352941176473),
                (0.83529411764705885, 0.24313725490196078, 0.30980392156862746),
                (0.95686274509803926, 0.42745098039215684, 0.2627450980392157),
                (0.99215686274509807, 0.68235294117647061, 0.38039215686274508),
                (0.99607843137254903, 0.8784313725490196, 0.54509803921568623),
                (1.0, 1.0, 0.74901960784313726),
                (0.90196078431372551, 0.96078431372549022, 0.59607843137254901),
                (0.6705882352941176, 0.8666666666666667, 0.64313725490196083),
                (0.4, 0.76078431372549016, 0.6470588235294118),
                (0.19607843137254902, 0.53333333333333333, 0.74117647058823533),
                (0.36862745098039218, 0.30980392156862746, 0.63529411764705879),
            ],
        }

        def method_matplotlib(image, cmap, bytes=False):
            if is_matplotlib_available():
                import matplotlib
            else:
                return None  # signal the caller to fall back to the custom implementation

            arg_is_pt, device = torch.is_tensor(image), None
            if arg_is_pt:
                image, device = image.cpu().numpy(), image.device

            if cmap not in matplotlib.colormaps:
                raise ValueError(
                    f"Unexpected color map {cmap}; available options are: {', '.join(list(matplotlib.colormaps.keys()))}"
                )

            cmap = matplotlib.colormaps[cmap]
            out = cmap(image, bytes=bytes)  # [?,4]
            out = out[..., :3]  # [?,3] -- drop the alpha channel

            if arg_is_pt:
                out = torch.tensor(out, device=device)

            return out

        def method_custom(image, cmap, bytes=False):
            arg_is_np = isinstance(image, np.ndarray)
            if arg_is_np:
                image = torch.tensor(image)
            if image.dtype == torch.uint8:
                image = image.float() / 255
            else:
                image = image.float()

            is_cmap_reversed = cmap.endswith("_r")
            if is_cmap_reversed:
                cmap = cmap[:-2]

            if cmap not in supported_cmaps:
                raise ValueError(
                    f"Only {list(supported_cmaps.keys())} color maps are available without installing matplotlib."
                )

            cmap = supported_cmaps[cmap]
            if is_cmap_reversed:
                cmap = cmap[::-1]
            cmap = torch.tensor(cmap, dtype=torch.float, device=image.device)  # [K,3]
            K = cmap.shape[0]

            # Linear interpolation between the two nearest anchor colors of the color map.
            pos = image.clamp(min=0, max=1) * (K - 1)
            left = pos.long()
            right = (left + 1).clamp(max=K - 1)

            d = (pos - left.float()).unsqueeze(-1)
            left_colors = cmap[left]
            right_colors = cmap[right]

            out = (1 - d) * left_colors + d * right_colors

            if bytes:
                out = (out * 255).to(torch.uint8)

            if arg_is_np:
                out = out.numpy()

            return out

        # Fast path: the custom implementation covers the default "Spectral" case without matplotlib.
        if _force_method is None and torch.is_tensor(image) and cmap == "Spectral":
            return method_custom(image, cmap, bytes)

        out = None
        if _force_method != "custom":
            out = method_matplotlib(image, cmap, bytes)

        if _force_method == "matplotlib" and out is None:
            raise ImportError("Make sure to install matplotlib if you want to use a color map other than 'Spectral'.")

        if out is None:
            out = method_custom(image, cmap, bytes)

        return out
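
    # A minimal colorization sketch (values illustrative): a [H,W] map in [0,1] becomes an
    # RGB uint8 image without requiring matplotlib:
    #
    #     rgb = MarigoldImageProcessor.colormap(torch.rand(480, 640), cmap="Spectral", bytes=True)
    #     rgb.shape  # torch.Size([480, 640, 3])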

    @staticmethod
    def visualize_depth(
        depth: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.Tensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.Tensor],
        ],
        val_min: float = 0.0,
        val_max: float = 1.0,
        color_map: str = "Spectral",
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes depth maps, such as predictions of the `MarigoldDepthPipeline`.

        Args:
            depth (`Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray],
                List[torch.Tensor]]`): Depth maps.
            val_min (`float`, *optional*, defaults to `0.0`): Minimum value of the visualized depth range.
            val_max (`float`, *optional*, defaults to `1.0`): Maximum value of the visualized depth range.
            color_map (`str`, *optional*, defaults to `"Spectral"`): Color map used to convert a single-channel
                depth prediction into colored representation.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with depth maps visualization.
        """
        if val_max <= val_min:
            raise ValueError(f"Invalid values range: [{val_min}, {val_max}].")

        def visualize_depth_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if isinstance(img, PIL.Image.Image):
                if img.mode != "I;16":
                    raise ValueError(f"{prefix}: invalid PIL mode={img.mode}.")
                img = np.array(img).astype(np.float32) / (2**16 - 1)
            if isinstance(img, np.ndarray) or torch.is_tensor(img):
                if img.ndim != 2:
                    raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
                if isinstance(img, np.ndarray):
                    img = torch.from_numpy(img)
                if not torch.is_floating_point(img):
                    raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            else:
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            img = MarigoldImageProcessor.colormap(img, cmap=color_map, bytes=True)
            img = PIL.Image.fromarray(img.cpu().numpy())
            return img

        if depth is None or (isinstance(depth, list) and any(o is None for o in depth)):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [visualize_depth_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [visualize_depth_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
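
    # A minimal visualization sketch (values illustrative):
    #
    #     vis = MarigoldImageProcessor.visualize_depth(torch.rand(480, 640))[0]  # PIL.Image.Image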

    @staticmethod
    def export_depth_to_16bit_png(
        depth: Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]],
        val_min: float = 0.0,
        val_max: float = 1.0,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        def export_depth_to_16bit_png_one(img, idx=None):
            prefix = "Depth" + (f"[{idx}]" if idx is not None else "")
            if not isinstance(img, np.ndarray) and not torch.is_tensor(img):
                raise ValueError(f"{prefix}: unexpected type={type(img)}.")
            if img.ndim != 2:
                raise ValueError(f"{prefix}: unexpected shape={img.shape}.")
            if torch.is_tensor(img):
                img = img.cpu().numpy()
            if not np.issubdtype(img.dtype, np.floating):
                raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
            if val_min != 0.0 or val_max != 1.0:
                img = (img - val_min) / (val_max - val_min)
            # Quantize [0,1] into the full 16-bit range and store as a single-channel image.
            img = (img * (2**16 - 1)).astype(np.uint16)
            img = PIL.Image.fromarray(img, mode="I;16")
            return img

        if depth is None or (isinstance(depth, list) and any(o is None for o in depth)):
            raise ValueError("Input depth is `None`")
        if isinstance(depth, (np.ndarray, torch.Tensor)):
            depth = MarigoldImageProcessor.expand_tensor_or_array(depth)
            if isinstance(depth, np.ndarray):
                depth = MarigoldImageProcessor.numpy_to_pt(depth)  # [N,H,W,1] -> [N,1,H,W]
            if not (depth.ndim == 4 and depth.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={depth.shape}, expecting [N,1,H,W].")
            return [export_depth_to_16bit_png_one(img[0], idx) for idx, img in enumerate(depth)]
        elif isinstance(depth, list):
            return [export_depth_to_16bit_png_one(img, idx) for idx, img in enumerate(depth)]
        else:
            raise ValueError(f"Unexpected input type: {type(depth)}")
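
    # A minimal export sketch (the output path is illustrative): the 16-bit PNG produced here is
    # compatible with `visualize_depth`, which accepts PIL images in mode "I;16":
    #
    #     png = MarigoldImageProcessor.export_depth_to_16bit_png(torch.rand(480, 640))[0]
    #     png.save("depth_16bit.png")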

    @staticmethod
    def visualize_normals(
        normals: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        flip_x: bool = False,
        flip_y: bool = False,
        flip_z: bool = False,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes surface normals, such as predictions of the `MarigoldNormalsPipeline`.

        Args:
            normals (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Surface normals.
            flip_x (`bool`, *optional*, defaults to `False`): Flips the X axis of the normals frame of reference.
                Default direction is right.
            flip_y (`bool`, *optional*, defaults to `False`): Flips the Y axis of the normals frame of reference.
                Default direction is top.
            flip_z (`bool`, *optional*, defaults to `False`): Flips the Z axis of the normals frame of reference.
                Default direction is facing the observer.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with surface normals visualization.
        """
        flip_vec = None
        if any((flip_x, flip_y, flip_z)):
            flip_vec = torch.tensor(
                [
                    (-1) ** flip_x,
                    (-1) ** flip_y,
                    (-1) ** flip_z,
                ],
                dtype=torch.float32,
            )

        def visualize_normals_one(img, idx=None):
            img = img.permute(1, 2, 0)  # [3,H,W] -> [H,W,3]
            if flip_vec is not None:
                img = img * flip_vec.to(img.device)  # out-of-place to avoid mutating the caller's tensor
            img = (img + 1.0) * 0.5  # map [-1,1] to [0,1]
            img = (img * 255).to(dtype=torch.uint8, device="cpu").numpy()
            img = PIL.Image.fromarray(img)
            return img

        if normals is None or (isinstance(normals, list) and any(o is None for o in normals)):
            raise ValueError("Input normals is `None`")
        if isinstance(normals, (np.ndarray, torch.Tensor)):
            normals = MarigoldImageProcessor.expand_tensor_or_array(normals)
            if isinstance(normals, np.ndarray):
                normals = MarigoldImageProcessor.numpy_to_pt(normals)  # [N,H,W,3] -> [N,3,H,W]
            if not (normals.ndim == 4 and normals.shape[1] == 3):
                raise ValueError(f"Unexpected input shape={normals.shape}, expecting [N,3,H,W].")
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        elif isinstance(normals, list):
            return [visualize_normals_one(img, idx) for idx, img in enumerate(normals)]
        else:
            raise ValueError(f"Unexpected input type: {type(normals)}")
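
    # A minimal normals visualization sketch (values illustrative): unit normals in [-1,1]
    # map to an RGB image:
    #
    #     n = torch.nn.functional.normalize(torch.randn(1, 3, 480, 640), dim=1)
    #     img = MarigoldImageProcessor.visualize_normals(n)[0]  # PIL.Image.Image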

    @staticmethod
    def visualize_uncertainty(
        uncertainty: Union[
            np.ndarray,
            torch.Tensor,
            List[np.ndarray],
            List[torch.Tensor],
        ],
        saturation_percentile: int = 95,
    ) -> Union[PIL.Image.Image, List[PIL.Image.Image]]:
        """
        Visualizes dense uncertainties, such as produced by `MarigoldDepthPipeline` or `MarigoldNormalsPipeline`.

        Args:
            uncertainty (`Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]`):
                Uncertainty maps.
            saturation_percentile (`int`, *optional*, defaults to `95`):
                Specifies the percentile uncertainty value visualized with maximum intensity.

        Returns: `PIL.Image.Image` or `List[PIL.Image.Image]` with uncertainty visualization.
        """

        def visualize_uncertainty_one(img, idx=None):
            prefix = "Uncertainty" + (f"[{idx}]" if idx is not None else "")
            if img.min() < 0:
                raise ValueError(f"{prefix}: unexpected data range, min={img.min()}.")
            img = img.squeeze(0).cpu().numpy()  # [1,H,W] -> [H,W]
            # Saturate at the given percentile so that a few outliers do not wash out the visualization.
            saturation_value = np.percentile(img, saturation_percentile)
            img = np.clip(img * 255 / saturation_value, 0, 255)
            img = img.astype(np.uint8)
            img = PIL.Image.fromarray(img)
            return img

        if uncertainty is None or (isinstance(uncertainty, list) and any(o is None for o in uncertainty)):
            raise ValueError("Input uncertainty is `None`")
        if isinstance(uncertainty, (np.ndarray, torch.Tensor)):
            uncertainty = MarigoldImageProcessor.expand_tensor_or_array(uncertainty)
            if isinstance(uncertainty, np.ndarray):
                uncertainty = MarigoldImageProcessor.numpy_to_pt(uncertainty)  # [N,H,W,1] -> [N,1,H,W]
            if not (uncertainty.ndim == 4 and uncertainty.shape[1] == 1):
                raise ValueError(f"Unexpected input shape={uncertainty.shape}, expecting [N,1,H,W].")
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        elif isinstance(uncertainty, list):
            return [visualize_uncertainty_one(img, idx) for idx, img in enumerate(uncertainty)]
        else:
            raise ValueError(f"Unexpected input type: {type(uncertainty)}")