# Copyright (c) Facebook, Inc. and its affiliates.

"""
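Colormap utilities for visualizations.
Originally copied from Detectron; the fixed Detectron palette (kept below, commented
out) has been replaced by a palette of randomly generated, distinct, non-black colors.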
"""

import numpy as np
import random
# NOTE: this reseeds the *global* stdlib `random` generator at import time,
# presumably so that `random_colors` (which draws with `random.sample`) is
# repeatable. It also affects any other code in the process that uses `random`.
# It does NOT seed NumPy's global generator, which `gen_color` and
# `random_color` draw from.
random.seed(0)

# Public API of this module.
__all__ = ["colormap", "random_color", "random_colors"]

# fmt: off
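# Original Detectron palette (RGB, components in [0, 1]), kept for reference
# only; the active palette `_COLORS` is generated randomly further below.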
# _COLORS = np.array(
#     [
#         0.000, 0.447, 0.741,
#         0.850, 0.325, 0.098,
#         0.929, 0.694, 0.125,
#         0.494, 0.184, 0.556,
#         0.466, 0.674, 0.188,
#         0.301, 0.745, 0.933,
#         0.635, 0.078, 0.184,
#         0.300, 0.300, 0.300,
#         0.600, 0.600, 0.600,
#         1.000, 0.000, 0.000,
#         1.000, 0.500, 0.000,
#         0.749, 0.749, 0.000,
#         0.000, 1.000, 0.000,
#         0.000, 0.000, 1.000,
#         0.667, 0.000, 1.000,
#         0.333, 0.333, 0.000,
#         0.333, 0.667, 0.000,
#         0.333, 1.000, 0.000,
#         0.667, 0.333, 0.000,
#         0.667, 0.667, 0.000,
#         0.667, 1.000, 0.000,
#         1.000, 0.333, 0.000,
#         1.000, 0.667, 0.000,
#         1.000, 1.000, 0.000,
#         0.000, 0.333, 0.500,
#         0.000, 0.667, 0.500,
#         0.000, 1.000, 0.500,
#         0.333, 0.000, 0.500,
#         0.333, 0.333, 0.500,
#         0.333, 0.667, 0.500,
#         0.333, 1.000, 0.500,
#         0.667, 0.000, 0.500,
#         0.667, 0.333, 0.500,
#         0.667, 0.667, 0.500,
#         0.667, 1.000, 0.500,
#         1.000, 0.000, 0.500,
#         1.000, 0.333, 0.500,
#         1.000, 0.667, 0.500,
#         1.000, 1.000, 0.500,
#         0.000, 0.333, 1.000,
#         0.000, 0.667, 1.000,
#         0.000, 1.000, 1.000,
#         0.333, 0.000, 1.000,
#         0.333, 0.333, 1.000,
#         0.333, 0.667, 1.000,
#         0.333, 1.000, 1.000,
#         0.667, 0.000, 1.000,
#         0.667, 0.333, 1.000,
#         0.667, 0.667, 1.000,
#         0.667, 1.000, 1.000,
#         1.000, 0.000, 1.000,
#         1.000, 0.333, 1.000,
#         1.000, 0.667, 1.000,
#         0.333, 0.000, 0.000,
#         0.500, 0.000, 0.000,
#         0.667, 0.000, 0.000,
#         0.833, 0.000, 0.000,
#         1.000, 0.000, 0.000,
#         0.000, 0.167, 0.000,
#         0.000, 0.333, 0.000,
#         0.000, 0.500, 0.000,
#         0.000, 0.667, 0.000,
#         0.000, 0.833, 0.000,
#         0.000, 1.000, 0.000,
#         0.000, 0.000, 0.167,
#         0.000, 0.000, 0.333,
#         0.000, 0.000, 0.500,
#         0.000, 0.000, 0.667,
#         0.000, 0.000, 0.833,
#         0.000, 0.000, 1.000,
#         0.000, 0.000, 0.000,
#         0.143, 0.143, 0.143,
#         0.857, 0.857, 0.857,
#         1.000, 1.000, 1.000
#     ]
# ).astype(np.float32).reshape(-1, 3)
# fmt: on

# Palette of distinct, non-black RGB triples with components in [0, 1]
# (rounded to 3 decimals). Filled once at import time by the loop below.
# NOTE(review): colors are drawn from NumPy's global RNG, which this module
# never seeds, so the palette differs between runs — confirm that is intended.
_COLORS = []


def gen_color():
    """
    Append one new random color to ``_COLORS``.

    Draws RGB triples (each channel an integer in [0, 255], scaled to [0, 1]
    and rounded to 3 decimals) until it finds one that is not already in
    ``_COLORS`` and is not pure black, then appends it.

    Returns:
        None. ``_COLORS`` is mutated in place.
    """
    # Retry in a loop rather than by recursion so that a run of collisions
    # cannot grow the call stack toward the interpreter's recursion limit.
    while True:
        color = tuple(np.round(np.random.choice(range(256), size=3) / 255, 3))
        # All channels are >= 0, so a zero mean means pure black (0, 0, 0).
        if color not in _COLORS and np.mean(color) != 0.0:
            _COLORS.append(color)
            return


for _ in range(300):
    gen_color()


def colormap(rgb=False, maximum=255):
    """
    Return the whole palette as an array.

    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    # `_COLORS` may be a plain list of RGB tuples; multiplying a list by an int
    # repeats it, so convert to an (N, 3) float32 array first to get
    # element-wise scaling. reshape keeps the (0, 3) shape for an empty palette.
    c = np.asarray(_COLORS, dtype=np.float32).reshape(-1, 3) * maximum
    if not rgb:
        # Palette entries are RGB; reverse the channel axis to get BGR.
        c = c[:, ::-1]
    return c


def random_color(rgb=False, maximum=255):
    """
    Pick one palette color uniformly at random (via NumPy's global RNG).

    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a float32 vector of 3 numbers
    """
    idx = np.random.randint(0, len(_COLORS))
    # Palette entries may be tuples; multiplying a tuple by an int repeats it,
    # so convert to an array first to scale each channel.
    ret = np.asarray(_COLORS[idx], dtype=np.float32) * maximum
    if not rgb:
        # Palette entries are RGB; reverse to get BGR.
        ret = ret[::-1]
    return ret


def random_colors(N, rgb=False, maximum=255):
    """
    Pick N distinct palette entries at random (via the stdlib ``random`` module).

    Args:
        N (int): number of unique colors needed
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        list[ndarray]: N float32 vectors of 3 numbers, like ``random_color``
    Raises:
        ValueError: if N is negative or larger than the palette size
            (raised by ``random.sample``).
    """
    indices = random.sample(range(len(_COLORS)), N)
    # Convert each (possibly tuple) entry to an array so `* maximum` scales
    # the channels instead of repeating the tuple.
    ret = [np.asarray(_COLORS[i], dtype=np.float32) * maximum for i in indices]
    if not rgb:
        # Palette entries are RGB; reverse each to get BGR.
        ret = [x[::-1] for x in ret]
    return ret


if __name__ == "__main__":
    # Visual check: tile the first H*W palette entries into a grid and show it.
    import cv2

    size = 100  # side length, in pixels, of each color swatch
    H, W = 10, 10
    # Cells left unfilled (if the palette has fewer than H*W colors) keep this
    # random-noise background.
    canvas = np.random.rand(H * size, W * size, 3).astype("float32")
    for h in range(H):
        for w in range(W):
            idx = h * W + w
            if idx >= len(_COLORS):
                break
            # Palette entries are RGB, but cv2.imshow interprets channels as BGR,
            # so reverse them to display the true colors.
            canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx][::-1]
    cv2.imshow("a", canvas)
    cv2.waitKey(0)