Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union | |
import cv2 | |
import numpy as np | |
import torch | |
if TYPE_CHECKING: | |
from matplotlib.backends.backend_agg import FigureCanvasAgg | |
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return ``value`` as a NumPy array when it is a ``torch.Tensor``.

    A tensor is detached from the autograd graph and moved to CPU before
    being converted; anything that is not a tensor is passed through
    untouched (the very same object is returned).

    Args:
        value (np.ndarray, torch.Tensor): value to convert.

    Returns:
        np.ndarray: the converted array, or ``value`` itself when it was not
        a ``torch.Tensor``.
    """
    if not isinstance(value, torch.Tensor):
        return value
    # detach() so tensors that require grad can be converted; cpu() because
    # numpy() only works on host memory.
    return value.detach().cpu().numpy()
def value2list(value: Any, valid_type: Union[Type, Tuple[Type, ...]],
               expand_dim: int) -> List[Any]:
    """Broadcast a single ``valid_type`` value to a list of ``expand_dim``.

    When ``value`` is an instance of ``valid_type`` it is repeated
    ``expand_dim`` times (the list holds the same object, not copies);
    otherwise ``value`` is returned unchanged, on the assumption that it is
    already a per-item list.

    Args:
        value (Any): value to broadcast.
        valid_type (Union[Type, Tuple[Type, ...]]): type(s) that count as a
            single value.
        expand_dim (int): length of the resulting list.

    Returns:
        List[Any]: ``[value] * expand_dim`` or the original ``value``.
    """
    return [value] * expand_dim if isinstance(value, valid_type) else value
def check_type(name: str, value: Any,
               valid_type: Union[Type, Tuple[Type, ...]]) -> None:
    """Check whether the type of value is in ``valid_type``.

    Args:
        name (str): value name, used in the error message.
        value (Any): value to check.
        valid_type (Type, Tuple[Type, ...]): expected type(s).

    Raises:
        TypeError: if ``value`` is not an instance of ``valid_type``.
    """
    if not isinstance(value, valid_type):
        # Exactly one space separates the two message fragments.
        raise TypeError(f'`{name}` should be {valid_type} '
                        f'but got {type(value)}')
def check_length(name: str, value: Any, valid_length: int) -> None:
    """Ensure a list ``value`` has at least ``valid_length`` elements.

    Non-list values are accepted without any check.

    Args:
        name (str): value name, used in the error message.
        value (Any): value to check.
        valid_length (int): minimum required length.

    Raises:
        AssertionError: if ``value`` is a list shorter than ``valid_length``.
    """
    if not isinstance(value, list):
        return
    actual_length = len(value)
    if actual_length < valid_length:
        raise AssertionError(
            f'The length of {name} must equal with or '
            f'greater than {valid_length}, but got {actual_length}')
def check_type_and_length(name: str, value: Any,
                          valid_type: Union[Type, Tuple[Type, ...]],
                          valid_length: int) -> None:
    """Validate both the type and, for lists, the minimum length of a value.

    ``value`` must be an instance of ``valid_type``; if it is additionally a
    list it must hold at least ``valid_length`` items.

    Args:
        name (str): value name, used in error messages.
        value (Any): value to validate.
        valid_type (Type, Tuple[Type, ...]): accepted type(s).
        valid_length (int): minimum length required when ``value`` is a list.

    Raises:
        TypeError: if ``value`` is not an instance of ``valid_type``.
        AssertionError: if ``value`` is a list shorter than ``valid_length``.
    """
    # The type is validated first, so a wrong type is reported before any
    # length problem.
    check_type(name, value, valid_type)
    check_length(name, value, valid_length)
def color_val_matplotlib(
    colors: Union[str, tuple, List[Union[str, tuple]]]
) -> Union[str, tuple, List[Union[str, tuple]]]:
    """Convert RGB color input(s) into matplotlib-compatible colors.

    * A ``str`` is returned unchanged (matplotlib parses it itself).
    * A 3-tuple of channels in ``[0, 255]`` becomes a tuple of floats in
      ``[0, 1]``.
    * A ``list`` is converted element-wise (recursively).

    Args:
        colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs.

    Returns:
        Union[str, tuple, List[Union[str, tuple]]]: The converted color(s),
        mirroring the structure of the input.

    Raises:
        TypeError: if the input (or a list element) has another type.
        AssertionError: if a tuple does not have 3 channels in ``[0, 255]``.
    """
    if isinstance(colors, str):
        return colors
    if isinstance(colors, list):
        return [
            color_val_matplotlib(item)  # type:ignore
            for item in colors
        ]
    if not isinstance(colors, tuple):
        raise TypeError(f'Invalid type for color: {type(colors)}')
    assert len(colors) == 3
    for value in colors:
        assert 0 <= value <= 255
    return tuple(value / 255 for value in colors)
def color_str2rgb(color: str) -> tuple:
    """Convert Matplotlib str color to an RGB color which range is 0 to 255,
    silently dropping the alpha channel.

    Args:
        color (str): Matplotlib color string.

    Returns:
        tuple: RGB color as three ints.
    """
    # Import the submodule explicitly: ``import matplotlib`` alone does not
    # guarantee that the ``matplotlib.colors`` attribute has been loaded.
    from matplotlib import colors as mcolors

    rgb_float = mcolors.to_rgb(color)
    # int() truncates toward zero; kept for backward compatibility.
    return tuple(int(channel * 255) for channel in rgb_float)
def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
                            img: Optional[np.ndarray] = None,
                            alpha: float = 0.5) -> np.ndarray:
    """Convert feat_map to heatmap and overlay on image, if image is not None.

    The feature values are min-max stretched to [0, 255], colored with the
    JET colormap and returned in RGB order.

    Args:
        feat_map (np.ndarray, torch.Tensor): The feature map to convert, of
            shape (H, W), (1, H, W) or (3, H, W), where H is the image
            height and W is the image width. Half-precision inputs
            (float16 / bfloat16) are upcast to float32 first.
        img (np.ndarray, optional): The origin image. The format
            should be RGB. When given, the heatmap is blended onto it with
            ``cv2.addWeighted``, so it must match the heatmap's size and
            dtype (presumably (H, W, 3) uint8 - TODO confirm against
            callers). Defaults to None.
        alpha (float): The transparency of featmap, i.e. the heatmap's
            blending weight; ``img`` gets ``1 - alpha``. Defaults to 0.5.

    Returns:
        np.ndarray: The RGB heatmap, blended with ``img`` if it was given.
    """
    assert feat_map.ndim == 2 or (feat_map.ndim == 3
                                  and feat_map.shape[0] in [1, 3])
    if isinstance(feat_map, torch.Tensor):
        feat_map = feat_map.detach()
        # ``.numpy()`` cannot convert bfloat16 and OpenCV cannot min-max
        # normalize float16, so upcast half-precision tensors first.
        if feat_map.dtype in (torch.float16, torch.bfloat16):
            feat_map = feat_map.float()
        feat_map = feat_map.cpu().numpy()
    elif feat_map.dtype == np.float16:
        # Same OpenCV limitation for float16 NumPy arrays.
        feat_map = feat_map.astype(np.float32)
    if feat_map.ndim == 3:
        # (C, H, W) -> (H, W, C): OpenCV expects channels last.
        feat_map = feat_map.transpose(1, 2, 0)
    norm_img = np.zeros(feat_map.shape)
    # Stretch the feature values linearly to [0, 255] before color mapping.
    norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX)
    norm_img = np.asarray(norm_img, dtype=np.uint8)
    heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert so the result matches the RGB ``img``.
    heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
    if img is not None:
        heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0)
    return heat_img
def wait_continue(figure, timeout: float = 0, continue_key: str = ' ') -> int:
    """Show the image and wait for the user's input.

    This implementation refers to
    https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py

    Key presses other than ``continue_key`` are ignored and the function
    keeps waiting. If the active matplotlib backend name contains
    ``'inline'``, no interactive input is possible and 0 is returned
    immediately.

    Args:
        figure: The figure to show and wait on. It is used through its
            ``show()`` method and its ``canvas`` (``manager``,
            ``mpl_connect``/``mpl_disconnect``,
            ``start_event_loop``/``stop_event_loop``), i.e. presumably a
            matplotlib ``Figure``.
        timeout (float): If positive, continue after ``timeout`` seconds.
            Defaults to 0 (presumably "no timeout", following matplotlib's
            ``start_event_loop`` semantics - confirm).
        continue_key (str): The key for users to continue. Defaults to
            the space key.

    Returns:
        int: If zero, means time out or the user pressed ``continue_key``,
        and if one, means the user closed the show figure.
    """  # noqa: E501
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import CloseEvent

    is_inline = 'inline' in plt.get_backend()
    if is_inline:
        # If use inline backend, interactive input and timeout is no use.
        return 0

    if figure.canvas.manager:  # type: ignore
        # Ensure that the figure is shown
        figure.show()  # type: ignore

    # Each iteration runs one blocking event loop. It is left via a close
    # event, the continue key or an empty wait (e.g. timeout); any other key
    # press starts a new wait.
    while True:

        # Connect the events to the handler function call.
        event = None

        def handler(ev):
            # Set external event variable
            nonlocal event
            # Qt backend may fire two events at the same time,
            # use a condition to avoid missing close event.
            event = ev if not isinstance(event, CloseEvent) else event
            # Wake up ``start_event_loop`` below so the event gets inspected.
            figure.canvas.stop_event_loop()

        cids = [
            figure.canvas.mpl_connect(name, handler)  # type: ignore
            for name in ('key_press_event', 'close_event')
        ]

        try:
            figure.canvas.start_event_loop(timeout)  # type: ignore
        finally:  # Run even on exception like ctrl-c.
            # Disconnect the callbacks.
            for cid in cids:
                figure.canvas.mpl_disconnect(cid)  # type: ignore

        if isinstance(event, CloseEvent):
            return 1  # Quit for close.
        elif event is None or event.key == continue_key:
            # ``event is None``: no handler fired during the wait (e.g. the
            # event loop timed out).
            return 0  # Quit for continue.
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
    """Get RGB image from ``FigureCanvasAgg``.

    The canvas is rendered with ``print_to_buffer()``, whose RGBA bytes are
    reshaped to (height, width, 4); the alpha channel is discarded.

    Args:
        canvas (FigureCanvasAgg): The canvas to get image.

    Returns:
        np.ndarray: A new uint8 array of shape (height, width, 3) holding
        the image in RGB order.
    """
    buf, (width, height) = canvas.print_to_buffer()
    img_rgba = np.frombuffer(buf, dtype='uint8').reshape(height, width, 4)
    # ``np.frombuffer`` may return a read-only view (e.g. over ``bytes``);
    # ``astype`` copies the RGB channels into an independent, writable array.
    return img_rgba[..., :3].astype('uint8')