""" Organize some frequently used visualization functions. """
import cv2
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import copy
import seaborn as sns


# Plot junctions onto the image (return a separate copy)
def plot_junctions(input_image, junctions, junc_size=3, color=None):
    """
    input_image: can be 0~1 float or 0~255 uint8.
    junctions: Nx2 or 2xN np array.
    junc_size: the size of the plotted circles.
    """
    # Create image copy
    image = copy.copy(input_image)
    # Make sure the image is converted to 255 uint8
    if image.dtype == np.uint8:
        pass
    # A float type image ranging from 0~1
    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    # A float type image ranging from 0.~255.
    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Check whether the image is single channel
    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
        # Squeeze to H*W first
        image = image.squeeze()

        # Stack to channle 3
        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)

    # Junction dimensions should be N*2
    if not len(junctions.shape) == 2:
        raise ValueError("[Error] junctions should be 2-dim array.")

    # Always convert to N*2
    if junctions.shape[-1] != 2:
        if junctions.shape[0] == 2:
            junctions = junctions.T
        else:
            raise ValueError("[Error] At least one of the two dims should be 2.")

    # Round and convert junctions to int (and check the boundary)
    H, W = image.shape[:2]
    junctions = (np.round(junctions)).astype(np.int)
    junctions[junctions < 0] = 0
    junctions[junctions[:, 0] >= H, 0] = H - 1  # (first dim) max bounded by H-1
    junctions[junctions[:, 1] >= W, 1] = W - 1  # (second dim) max bounded by W-1

    # Iterate through all the junctions
    num_junc = junctions.shape[0]
    if color is None:
        color = (0, 255.0, 0)
    for idx in range(num_junc):
        # Fetch one junction
        junc = junctions[idx, :]
        cv2.circle(
            image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
        )

    return image


# Plot line segements given junctions and line adjecent map
def plot_line_segments(
    input_image,
    junctions,
    line_map,
    junc_size=3,
    color=(0, 255.0, 0),
    line_width=1,
    plot_survived_junc=True,
):
    """
    input_image: can be 0~1 float or 0~255 uint8.
    junctions: Nx2 or 2xN np array.
    line_map: NxN np array
    junc_size: the size of the plotted circles.
    color: color of the line segments (can be string "random")
    line_width: width of the drawn segments.
    plot_survived_junc: whether we only plot the survived junctions.
    """
    # Create image copy
    image = copy.copy(input_image)
    # Make sure the image is converted to 255 uint8
    if image.dtype == np.uint8:
        pass
    # A float type image ranging from 0~1
    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    # A float type image ranging from 0.~255.
    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Check whether the image is single channel
    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
        # Squeeze to H*W first
        image = image.squeeze()

        # Stack to channle 3
        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)

    # Junction dimensions should be 2
    if not len(junctions.shape) == 2:
        raise ValueError("[Error] junctions should be 2-dim array.")

    # Always convert to N*2
    if junctions.shape[-1] != 2:
        if junctions.shape[0] == 2:
            junctions = junctions.T
        else:
            raise ValueError("[Error] At least one of the two dims should be 2.")

    # line_map dimension should be 2
    if not len(line_map.shape) == 2:
        raise ValueError("[Error] line_map should be 2-dim array.")

    # Color should be "random" or a list or tuple with length 3
    if color != "random":
        if not (isinstance(color, tuple) or isinstance(color, list)):
            raise ValueError("[Error] color should have type list or tuple.")
        else:
            if len(color) != 3:
                raise ValueError(
                    "[Error] color should be a list or tuple with length 3."
                )

    # Make a copy of the line_map
    line_map_tmp = copy.copy(line_map)

    # Parse line_map back to segment pairs
    segments = np.zeros([0, 4])
    for idx in range(junctions.shape[0]):
        # if no connectivity, just skip it
        if line_map_tmp[idx, :].sum() == 0:
            continue
        # record the line segment
        else:
            for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
                p1 = np.flip(junctions[idx, :])  # Convert to xy format
                p2 = np.flip(junctions[idx2, :])  # Convert to xy format
                segments = np.concatenate(
                    (segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
                    axis=0,
                )

                # Update line_map
                line_map_tmp[idx, idx2] = 0
                line_map_tmp[idx2, idx] = 0

    # Draw segment pairs
    for idx in range(segments.shape[0]):
        seg = np.round(segments[idx, :]).astype(np.int)
        # Decide the color
        if color != "random":
            color = tuple(color)
        else:
            color = tuple(
                np.random.rand(
                    3,
                )
            )
        cv2.line(
            image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
        )

    # Also draw the junctions
    if not plot_survived_junc:
        num_junc = junctions.shape[0]
        for idx in range(num_junc):
            # Fetch one junction
            junc = junctions[idx, :]
            cv2.circle(
                image,
                tuple(np.flip(junc)),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )
    # Only plot the junctions which are part of a line segment
    else:
        for idx in range(segments.shape[0]):
            seg = np.round(segments[idx, :]).astype(np.int)  # Already in HW format.
            cv2.circle(
                image,
                tuple(seg[:2]),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )
            cv2.circle(
                image,
                tuple(seg[2:]),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )

    return image


# Plot line segments given Nx4 or Nx2x2 line segments
def plot_line_segments_from_segments(
    input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
):
    # Create image copy
    image = copy.copy(input_image)
    # Make sure the image is converted to 255 uint8
    if image.dtype == np.uint8:
        pass
    # A float type image ranging from 0~1
    elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    # A float type image ranging from 0.~255.
    elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Check whether the image is single channel
    if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
        # Squeeze to H*W first
        image = image.squeeze()

        # Stack to channle 3
        image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)

    # Check the if line_segments are in (1) Nx4, or (2) Nx2x2.
    H, W, _ = image.shape
    # (1) Nx4 format
    if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
        # Round to int32
        line_segments = line_segments.astype(np.int32)

        # Clip H dimension
        line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
        line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)

        # Clip W dimension
        line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
        line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)

        # Convert to Nx2x2 format
        line_segments = np.concatenate(
            [
                np.expand_dims(line_segments[:, :2], axis=1),
                np.expand_dims(line_segments[:, 2:], axis=1),
            ],
            axis=1,
        )

    # (2) Nx2x2 format
    elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
        # Round to int32
        line_segments = line_segments.astype(np.int32)

        # Clip H dimension
        line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
        line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)

    else:
        raise ValueError(
            "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
        )

    # Draw segment pairs (all segments should be in HW format)
    image = image.copy()
    for idx in range(line_segments.shape[0]):
        seg = np.round(line_segments[idx, :, :]).astype(np.int32)
        # Decide the color
        if color != "random":
            color = tuple(color)
        else:
            color = tuple(
                np.random.rand(
                    3,
                )
            )
        cv2.line(
            image,
            tuple(np.flip(seg[0, :])),
            tuple(np.flip(seg[1, :])),
            color=color,
            thickness=line_width,
        )

        # Also draw the junctions
        cv2.circle(
            image,
            tuple(np.flip(seg[0, :])),
            radius=junc_size,
            color=(0, 255.0, 0),
            thickness=3,
        )
        cv2.circle(
            image,
            tuple(np.flip(seg[1, :])),
            radius=junc_size,
            color=(0, 255.0, 0),
            thickness=3,
        )

    return image


# Additional functions to visualize multiple images at the same time,
# e.g. for line matching
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
    """Plot a set of images horizontally.
    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 3 / 4) if size is not None else None
    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)


def plot_keypoints(kpts, colors="lime", ps=4):
    """Plot keypoints for existing images.
    Args:
        kpts: list of ndarrays of size (N, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    axes = plt.gcf().axes
    for a, k, c in zip(axes, kpts, colors):
        a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
    """Plot matches for a pair of existing images.
    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: color of each match, string or RGB tuple. Random if not given.
        lw: width of the lines.
        ps: size of the end points (no endpoint if ps=0)
        indices: indices of the images to draw the matches on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    ax0, ax1 = ax[indices[0]], ax[indices[1]]
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        color = [color] * len(kpts0)

    if lw > 0:
        # transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
        fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=color[i],
                linewidth=lw,
                alpha=a,
            )
            for i in range(len(kpts0))
        ]

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)


def plot_lines(
    lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
):
    """Plot lines and endpoints for existing images.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float pixels.
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines and junctions
    for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
        for i in range(len(l)):
            line = matplotlib.lines.Line2D(
                (l[i, 0, 0], l[i, 1, 0]),
                (l[i, 0, 1], l[i, 1, 1]),
                zorder=1,
                c=lc,
                linewidth=lw,
            )
            a.add_line(line)
        pts = l.reshape(-1, 2)
        a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)


def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
    """Plot matches for a pair of existing images, parametrized by their middle point.
    Args:
        kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
        color: color of each match, string or RGB tuple. Random if not given.
        lw: width of the lines.
        indices: indices of the images to draw the matches on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    ax0, ax1 = ax[indices[0]], ax[indices[1]]
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        color = [color] * len(kpts0)

    if lw > 0:
        # transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
        fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=color[i],
                linewidth=lw,
                alpha=a,
            )
            for i in range(len(kpts0))
        ]

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)


def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        correct_matches: bool array of size (N,) indicating correct matches.
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette("husl", n_colors=n_lines)
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)
    # If correct_matches is not None, display wrong matches with a low alpha
    if correct_matches is not None:
        alphas[~np.array(correct_matches)] = 0.2

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines
    for a, l in zip(axes, lines):
        # Transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=colors[i],
                alpha=alphas[i],
                linewidth=lw,
            )
            for i in range(n_lines)
        ]


def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors:
    green for correct matches, red for wrong ones, and blue for the rest.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        correct_matches: list of bool arrays of size N with correct matches.
        wrong_matches: list of bool arrays of size (N,) with correct matches.
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    # palette = sns.color_palette()
    palette = sns.color_palette("hls", 8)
    blue = palette[5]  # palette[0]
    red = palette[0]  # palette[3]
    green = palette[2]  # palette[2]
    colors = [np.array([blue] * len(l)) for l in lines]
    for i, c in enumerate(colors):
        c[np.array(correct_matches[i])] = green
        c[np.array(wrong_matches[i])] = red

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines
    for a, l, c in zip(axes, lines, colors):
        # Transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()
        endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=c[i],
                linewidth=lw,
            )
            for i in range(len(l))
        ]


def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors and
        highlight the actually matched subsegments.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        subsegments: list of ndarrays of size (N, 2, 2).
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    n_lines = len(lines[0])
    colors = sns.cubehelix_palette(
        start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
    )

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines
    for a, l, ss in zip(axes, lines, subsegments):
        # Transform the points into the figure coordinate system
        transFigure = fig.transFigure.inverted()

        # Draw full line
        endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c="red",
                alpha=0.7,
                linewidth=lw,
            )
            for i in range(n_lines)
        ]

        # Draw matched subsegment
        endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
        endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=colors[i],
                alpha=1,
                linewidth=lw,
            )
            for i in range(n_lines)
        ]