# Author: Bingxin Ke
# Last modified: 2024-11-25

import concurrent.futures
from typing import Union

import matplotlib
import numpy as np
from tqdm import tqdm


def colorize_depth(
    depth: np.ndarray,
    min_depth: float,
    max_depth: float,
    cmap: str = "Spectral_r",
    valid_mask: Union[np.ndarray, None] = None,
) -> np.ndarray:
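    """Map depth values to RGB colors with a matplotlib colormap.

    Args:
        depth: Depth map of shape [H, W] or [B, H, W].
        min_depth: Depth mapped to the low end of the colormap.
        max_depth: Depth mapped to the high end of the colormap.
        cmap: Name of a registered matplotlib colormap.
        valid_mask: Optional boolean mask of shape [H, W] or [B, H, W]
            (singleton axes such as [B, 1, H, W] are squeezed); invalid
            pixels are set to black.

    Returns:
        Float array of shape [B, H, W, 3] with values in [0, 1].
    """
    assert depth.ndim >= 2, "Invalid dimension: expected [H, W] or [B, H, W]"

    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]  # add batch axis -> [1, H, W]

    # Normalize to [0, 1], clipping values outside [min_depth, max_depth]
    cm = matplotlib.colormaps[cmap]
    depth_range = max_depth - min_depth
    if depth_range == 0:
        depth_range = 1.0  # zero range: avoid 0 / 0 (a constant map lands on the low end)
    depth = ((depth - min_depth) / depth_range).clip(0, 1)
    # The colormap returns RGBA in [0, 1]; keep RGB -> [B, H, W, 3]
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]

    if valid_mask is not None:
        valid_mask = valid_mask.squeeze().astype(bool)  # [H, W] or [B, H, W]
        # Broadcast to [B, H, W]; since the output is channel-last, boolean
        # indexing zeroes all three color channels of each invalid pixel
        valid_mask = np.broadcast_to(valid_mask, depth.shape)
        img_colored_np[~valid_mask] = 0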

    return img_colored_np
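

# Worked example of the normalization in colorize_depth (illustrative numbers,
# not from the original project): with min_depth=1.0 and max_depth=5.0, a depth
# of 3.0 maps to (3.0 - 1.0) / (5.0 - 1.0) = 0.5, the midpoint of the colormap,
# while depths of 0.2 and 7.0 fall outside the range and are clipped to 0.0 and
# 1.0, i.e. the two ends of the colormap.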


def colorize_depth_multi_thread(
    depth: np.ndarray,
    valid_mask: Union[np.ndarray, None] = None,
    chunk_size: int = 4,
    num_threads: int = 4,
    color_map: str = "Spectral",
    verbose: bool = False,
) -> np.ndarray:
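    """Colorize a stack of depth maps in parallel chunks.

    Args:
        depth: Depth maps of shape [N, 1, H, W] or [N, H, W].
        valid_mask: Optional boolean mask matching ``depth``; it only selects
            the pixels used to compute the shared min/max range and does not
            blank any pixels in the output.
        chunk_size: Number of frames per task.
        num_threads: Maximum number of worker threads.
        color_map: Name of a registered matplotlib colormap.
        verbose: Show a tqdm progress bar.

    Returns:
        uint8 array of shape [N, H, W, 3].
    """
    if depth.ndim == 4:
        depth = depth.squeeze(1)  # [N, 1, H, W] -> [N, H, W]
    assert 3 == depth.ndim, "Expected depth of shape [N, 1, H, W] or [N, H, W]"

    n_frame = depth.shape[0]

    # One normalization range shared by all frames keeps colors consistent
    # across the sequence
    if valid_mask is None:
        valid_depth = depth
    else:
        valid_mask = valid_mask.reshape(depth.shape).astype(bool)
        valid_depth = depth[valid_mask]
    min_depth = valid_depth.min()
    max_depth = valid_depth.max()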

    def process_chunk(chunk):
        chunk = colorize_depth(
            chunk, min_depth=min_depth, max_depth=max_depth, cmap=color_map
        )
        chunk = (chunk * 255).astype(np.uint8)
        return chunk

    # Pre-allocate the full array
    colored = np.empty((*depth.shape[:3], 3), dtype=np.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks and store futures with their corresponding indices
        future_to_index = {
            executor.submit(process_chunk, depth[i : i + chunk_size]): i
            for i in range(0, n_frame, chunk_size)
        }

        # Collect results as they finish (completion order, not submission
        # order); each future's start index places its chunk in `colored`
        chunk_iterable = concurrent.futures.as_completed(future_to_index)
        if verbose:
            chunk_iterable = tqdm(
                chunk_iterable,
                desc=" colorizing",
                leave=False,
                total=len(future_to_index),
            )
        for future in chunk_iterable:
            index = future_to_index[future]
            start = index
            end = min(index + chunk_size, n_frame)
            result = future.result()
            colored[start:end] = result
    return colored
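

# Example usage: a minimal sketch on synthetic data. The shapes, value ranges,
# and the "depth > 1.0 is valid" rule are illustrative assumptions, not values
# from the original project.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Single depth map [H, W] -> float RGB image [1, H, W, 3] in [0, 1]
    single = rng.uniform(0.5, 10.0, size=(32, 48))
    single_rgb = colorize_depth(single, min_depth=0.5, max_depth=10.0)
    print(single_rgb.shape)  # (1, 32, 48, 3)

    # Sequence [N, 1, H, W]: 10 frames in chunks of 4 -> tasks of 4, 4, 2 frames.
    # The [N, H, W] mask only restricts which pixels define the shared range.
    video = rng.uniform(0.5, 10.0, size=(10, 1, 32, 48)).astype(np.float32)
    mask = video[:, 0] > 1.0
    video_rgb = colorize_depth_multi_thread(
        video, valid_mask=mask, chunk_size=4, num_threads=2
    )
    print(video_rgb.shape, video_rgb.dtype)  # (10, 32, 48, 3) uint8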