rollingdepth / colorize.py
toshas's picture
initial commit
a45988a
# Author: Bingxin Ke
# Last modified: 2024-11-25
import concurrent.futures
from typing import Union
import matplotlib
import numpy as np
from tqdm import tqdm
def colorize_depth(
    depth: np.ndarray,
    min_depth: float,
    max_depth: float,
    cmap: str = "Spectral_r",
    valid_mask: Union[np.ndarray, None] = None,
) -> np.ndarray:
    """Map depth values to RGB colors with a matplotlib colormap.

    Depth is normalized linearly so that ``min_depth`` maps to 0 and
    ``max_depth`` maps to 1, clipped to that interval, and then passed
    through the colormap.

    Args:
        depth: Depth map of shape ``[H, W]`` or ``[B, H, W]``. A 2-D input
            gets a leading batch axis of length 1.
        min_depth: Depth value mapped to the low end of the colormap.
        max_depth: Depth value mapped to the high end of the colormap.
        cmap: Name looked up in ``matplotlib.colormaps``.
        valid_mask: Optional mask that is true where depth is valid. Any
            layout that reshapes to ``[B, H, W]`` or ``[1, H, W]`` is
            accepted (e.g. ``[H, W]``, ``[B, H, W]``, ``[B, 1, H, W]``).
            Pixels where the mask is false are set to 0 in every channel.

    Returns:
        Channels-last float array of shape ``[B, H, W, 3]`` with values
        from 0 to 1.

    NOTE(review): ``max_depth == min_depth`` makes the denominator zero;
    that case is not guarded here.
    """
    assert len(depth.shape) >= 2, "Invalid dimension"

    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    # The colormap returns RGBA in the last axis; keep RGB only.
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1

    if valid_mask is not None:
        batch, height, width = img_colored_np.shape[:3]
        # The image is channels-last, so the mask must stay [B, H, W]:
        # boolean-indexing the first three axes zeroes all three channels.
        # Cast to bool so that `~` is a logical (not bitwise) negation.
        mask = np.asarray(valid_mask, dtype=bool).reshape(-1, height, width)
        # A single [H, W] mask is shared by every image in the batch.
        mask = np.broadcast_to(mask, (batch, height, width))
        img_colored_np[~mask] = 0

    return img_colored_np
def colorize_depth_multi_thread(
    depth: np.ndarray,
    valid_mask: Union[np.ndarray, None] = None,
    chunk_size: int = 4,
    num_threads: int = 4,
    color_map: str = "Spectral",
    verbose: bool = False,
) -> np.ndarray:
    """Colorize a stack of depth frames chunk by chunk on a thread pool.

    All frames share one color range, spanning the minimum and maximum of
    ``depth`` (restricted to ``valid_mask`` when one is given).
    ``valid_mask`` only affects that range: it is not forwarded to
    ``colorize_depth``, so pixels outside the mask are still colorized.

    Args:
        depth: Array with a length-1 axis at position 1 that is 3-D once
            that axis is removed, e.g. ``[N, 1, H, W]``.
        valid_mask: Optional mask used to index ``depth`` when computing
            the color range. It may match the squeezed depth
            (``[N, H, W]``) or carry the same length-1 axis 1
            (``[N, 1, H, W]``).
        chunk_size: Number of frames colorized per submitted task.
        num_threads: Maximum number of worker threads.
        color_map: Colormap name passed to ``colorize_depth`` as ``cmap``.
        verbose: Show a tqdm progress bar over the chunks.

    Returns:
        ``uint8`` array of shape ``[N, H, W, 3]``.

    Raises:
        ValueError: If no depth value is available to derive the color
            range from (empty input or an all-false mask).
    """
    depth = depth.squeeze(1)
    assert 3 == depth.ndim
    n_frame = depth.shape[0]

    if valid_mask is None:
        valid_depth = depth
    else:
        valid_mask = np.asarray(valid_mask)
        # depth just lost its axis 1; drop the same axis from a mask given
        # in the original 4-D layout so the boolean index lines up.
        if valid_mask.ndim == depth.ndim + 1:
            valid_mask = valid_mask.squeeze(1)
        valid_depth = depth[valid_mask]

    if valid_depth.size == 0:
        raise ValueError("No valid depth values to compute the color range from")
    min_depth = valid_depth.min()
    max_depth = valid_depth.max()

    def process_chunk(chunk):
        # Colorize to floats, then quantize to 8-bit.
        colored_chunk = colorize_depth(
            chunk, min_depth=min_depth, max_depth=max_depth, cmap=color_map
        )
        return (colored_chunk * 255).astype(np.uint8)

    # Pre-allocate the full array
    colored = np.empty((*depth.shape[:3], 3), dtype=np.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks, remembering the first frame index of each chunk
        future_to_start = {
            executor.submit(process_chunk, depth[start : start + chunk_size]): start
            for start in range(0, n_frame, chunk_size)
        }

        # as_completed yields futures as they finish, not in submission
        # order; the stored start index puts each chunk in the right slot.
        chunk_iterable = concurrent.futures.as_completed(future_to_start)
        if verbose:
            chunk_iterable = tqdm(
                chunk_iterable,
                desc=" colorizing",
                leave=False,
                total=len(future_to_start),
            )

        for future in chunk_iterable:
            start = future_to_start[future]
            end = min(start + chunk_size, n_frame)
            colored[start:end] = future.result()

    return colored