|
""" Organize some frequently used visualization functions. """ |
|
import cv2 |
|
import numpy as np |
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
import copy |
|
import seaborn as sns |
|
|
|
|
|
|
|
def plot_junctions(input_image, junctions, junc_size=3, color=None):
    """Draw junctions as circles on a copy of an image.

    Args:
        input_image: H x W, H x W x 1 or H x W x C image. Either uint8
            (0~255), or float: float images with max <= 2.0 are treated as
            0~1 and scaled by 255, float images with mean > 10.0 are cast
            to uint8 directly.
        junctions: Nx2 or 2xN array of (row, col) junction coordinates.
            Coordinates are rounded and clamped to the image bounds.
        junc_size: radius of the plotted circles.
        color: circle color passed to cv2.circle (default (0, 255, 0)).

    Returns:
        A uint8 copy of the image (expanded to 3 channels if it was
        single-channel) with the junctions drawn on it.

    Raises:
        ValueError: if the image dtype / value range is not supported, or
            if junctions is not a 2-D array with one dimension equal to 2.
    """
    image = copy.copy(input_image)

    # NOTE: np.float was removed in NumPy 1.24; check the dtype kind instead.
    is_float = np.issubdtype(image.dtype, np.floating)
    if image.dtype == np.uint8:
        pass
    elif is_float and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    elif is_float and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Grayscale -> 3 channels. reshape (not squeeze) so that an image with a
    # single row or column keeps its 2-D layout.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
        image = image.reshape(image.shape[:2])
        image = np.repeat(image[..., None], 3, axis=-1)
    # cv2 drawing functions need a C-contiguous buffer.
    image = np.ascontiguousarray(image)

    junctions = np.asarray(junctions)
    if junctions.ndim != 2:
        raise ValueError("[Error] junctions should be 2-dim array.")
    if junctions.shape[-1] != 2:
        if junctions.shape[0] == 2:
            junctions = junctions.T
        else:
            raise ValueError("[Error] At least one of the two dims should be 2.")

    # Round to pixels (astype returns a new array, so the caller's input is
    # untouched) and clamp rows to [0, H-1], cols to [0, W-1].
    H, W = image.shape[:2]
    junctions = np.round(junctions).astype(int)
    junctions[:, 0] = np.clip(junctions[:, 0], 0, H - 1)
    junctions[:, 1] = np.clip(junctions[:, 1], 0, W - 1)

    if color is None:
        color = (0, 255.0, 0)
    for row, col in junctions:
        # cv2 takes (x, y) == (col, row) points as plain Python ints.
        cv2.circle(
            image, (int(col), int(row)), radius=junc_size, color=color, thickness=3
        )

    return image
|
|
|
|
|
|
|
def _line_map_to_segments(junctions, line_map):
    """Return an Mx4 float array of (x1, y1, x2, y2) segments from a line map.

    junctions holds (row, col) points; each point is flipped to (col, row).
    Entries of line_map equal to 1 connect two junctions. Once a connection
    i-j is emitted, both [i, j] and [j, i] are cleared so it is emitted once.
    """
    remaining = np.array(line_map, copy=True)
    segments = []
    for idx in range(junctions.shape[0]):
        if remaining[idx, :].sum() == 0:
            continue
        # np.where is evaluated once, so clearing entries below is safe.
        for idx2 in np.where(remaining[idx, :] == 1)[0]:
            p1 = np.flip(junctions[idx, :])
            p2 = np.flip(junctions[idx2, :])
            segments.append([p1[0], p1[1], p2[0], p2[1]])
            remaining[idx, idx2] = 0
            remaining[idx2, idx] = 0
    return np.asarray(segments, dtype=np.float64).reshape(-1, 4)


def plot_line_segments(
    input_image,
    junctions,
    line_map,
    junc_size=3,
    color=(0, 255.0, 0),
    line_width=1,
    plot_survived_junc=True,
):
    """Draw the line segments encoded by a junction line map on an image copy.

    Args:
        input_image: H x W, H x W x 1 or H x W x C image. Either uint8
            (0~255), or float: float images with max <= 2.0 are scaled by
            255, float images with mean > 10.0 are cast to uint8 directly.
        junctions: Nx2 or 2xN array of (row, col) junction coordinates.
        line_map: NxN array; an entry equal to 1 at [i, j] draws a segment
            between junction i and junction j.
        junc_size: radius of the plotted junction circles.
        color: color of the line segments, a list/tuple of length 3, or the
            string "random" for an independent random color per segment.
        line_width: width of the drawn segments.
        plot_survived_junc: if True, only draw the junctions that are an
            endpoint of some segment; otherwise draw every junction.

    Returns:
        A uint8 copy of the image (expanded to 3 channels if it was
        single-channel) with the segments and junctions drawn on it.

    Raises:
        ValueError: on an unsupported image dtype / range, badly shaped
            junctions or line_map, or an invalid color.
    """
    image = copy.copy(input_image)

    # NOTE: np.float was removed in NumPy 1.24; check the dtype kind instead.
    is_float = np.issubdtype(image.dtype, np.floating)
    if image.dtype == np.uint8:
        pass
    elif is_float and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    elif is_float and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Grayscale -> 3 channels; reshape keeps size-1 H or W dimensions.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
        image = image.reshape(image.shape[:2])
        image = np.repeat(image[..., None], 3, axis=-1)
    # cv2 drawing functions need a C-contiguous buffer.
    image = np.ascontiguousarray(image)

    junctions = np.asarray(junctions)
    if junctions.ndim != 2:
        raise ValueError("[Error] junctions should be 2-dim array.")
    if junctions.shape[-1] != 2:
        if junctions.shape[0] == 2:
            junctions = junctions.T
        else:
            raise ValueError("[Error] At least one of the two dims should be 2.")

    line_map = np.asarray(line_map)
    if line_map.ndim != 2:
        raise ValueError("[Error] line_map should be 2-dim array.")

    random_color = isinstance(color, str) and color == "random"
    if not random_color:
        if not isinstance(color, (tuple, list)):
            raise ValueError("[Error] color should have type list or tuple.")
        if len(color) != 3:
            raise ValueError(
                "[Error] color should be a list or tuple with length 3."
            )

    segments = np.round(_line_map_to_segments(junctions, line_map)).astype(int)

    for x1, y1, x2, y2 in segments:
        if random_color:
            # A fresh color per segment, scaled to the 0~255 uint8 range.
            seg_color = tuple(float(v) for v in np.random.rand(3) * 255.0)
        else:
            seg_color = tuple(color)
        cv2.line(
            image,
            (int(x1), int(y1)),
            (int(x2), int(y2)),
            color=seg_color,
            thickness=line_width,
        )

    if not plot_survived_junc:
        # Junctions may be float; cv2 only accepts integer points.
        for row, col in np.round(junctions).astype(int):
            cv2.circle(
                image,
                (int(col), int(row)),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )
    else:
        for x1, y1, x2, y2 in segments:
            cv2.circle(
                image,
                (int(x1), int(y1)),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )
            cv2.circle(
                image,
                (int(x2), int(y2)),
                radius=junc_size,
                color=(0, 255.0, 0),
                thickness=3,
            )

    return image
|
|
|
|
|
|
|
def plot_line_segments_from_segments(
    input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
):
    """Draw explicit line segments and their endpoints on a copy of an image.

    Args:
        input_image: H x W, H x W x 1 or H x W x C image. Either uint8
            (0~255), or float: float images with max <= 2.0 are scaled by
            255, float images with mean > 10.0 are cast to uint8 directly.
        line_segments: Nx4 array of (r1, c1, r2, c2) or Nx2x2 array of
            (row, col) endpoints ("HW format"). Coordinates are rounded to
            the nearest pixel and clamped to the image bounds.
        junc_size: radius of the endpoint circles.
        color: line color, or the string "random" for an independent random
            color per segment.
        line_width: width of the drawn segments.

    Returns:
        A uint8 copy of the image (expanded to 3 channels if it was
        single-channel) with the segments drawn on it.

    Raises:
        ValueError: on an unsupported image dtype / range or badly shaped
            line_segments.
    """
    image = copy.copy(input_image)

    # NOTE: np.float was removed in NumPy 1.24; check the dtype kind instead.
    is_float = np.issubdtype(image.dtype, np.floating)
    if image.dtype == np.uint8:
        pass
    elif is_float and image.max() <= 2.0:
        image = (image * 255.0).astype(np.uint8)
    elif is_float and image.mean() > 10.0:
        image = image.astype(np.uint8)
    else:
        raise ValueError(
            "[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
        )

    # Grayscale -> 3 channels; reshape keeps size-1 H or W dimensions.
    if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
        image = image.reshape(image.shape[:2])
        image = np.repeat(image[..., None], 3, axis=-1)
    # cv2 drawing functions need a C-contiguous buffer.
    image = np.ascontiguousarray(image)

    H, W = image.shape[:2]

    line_segments = np.asarray(line_segments)
    if line_segments.ndim == 2 and line_segments.shape[-1] == 4:
        # (r1, c1, r2, c2) -> [[r1, c1], [r2, c2]]
        line_segments = line_segments.reshape(-1, 2, 2)
    elif not (line_segments.ndim == 3 and line_segments.shape[-1] == 2):
        raise ValueError(
            "[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
        )

    # Round to the nearest pixel before the integer cast (astype alone would
    # truncate), then clamp rows to [0, H-1] and cols to [0, W-1]. The cast
    # returns a new array, so the caller's input is not modified.
    line_segments = np.round(line_segments).astype(np.int32)
    line_segments[..., 0] = np.clip(line_segments[..., 0], 0, H - 1)
    line_segments[..., 1] = np.clip(line_segments[..., 1], 0, W - 1)

    random_color = isinstance(color, str) and color == "random"
    for seg in line_segments:
        # cv2 takes (x, y) == (col, row) points as plain Python ints.
        p1 = (int(seg[0, 1]), int(seg[0, 0]))
        p2 = (int(seg[1, 1]), int(seg[1, 0]))
        if random_color:
            # A fresh color per segment, scaled to the 0~255 uint8 range.
            seg_color = tuple(float(v) for v in np.random.rand(3) * 255.0)
        else:
            seg_color = tuple(color)
        cv2.line(image, p1, p2, color=seg_color, thickness=line_width)

        cv2.circle(image, p1, radius=junc_size, color=(0, 255.0, 0), thickness=3)
        cv2.circle(image, p2, radius=junc_size, color=(0, 255.0, 0), thickness=3)

    return image
|
|
|
|
|
|
|
|
|
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
    """Show a set of images side by side in a single row of subplots.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: optional list of strings, one title per image.
        cmaps: a colormap name, or a list/tuple of them (one per image); used
            for monochrome images.
        dpi: resolution of the created figure.
        size: width in inches of each subplot (height is 3/4 of it); None
            keeps matplotlib's default figure size.
        pad: padding passed to ``fig.tight_layout``.
    """
    num_images = len(imgs)
    if isinstance(cmaps, (list, tuple)):
        cmap_list = cmaps
    else:
        cmap_list = [cmaps] * num_images

    figsize = None if size is None else (size * num_images, size * 3 / 4)
    fig, axes = plt.subplots(1, num_images, figsize=figsize, dpi=dpi)
    # plt.subplots returns a bare Axes (not an array) when there is one column.
    if num_images == 1:
        axes = [axes]

    for i, img in enumerate(imgs):
        axis = axes[i]
        axis.imshow(img, cmap=plt.get_cmap(cmap_list[i]))
        # Hide ticks, the axis frame and every spine so only the image shows.
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.set_axis_off()
        for spine in axis.spines.values():
            spine.set_visible(False)
        if titles:
            axis.set_title(titles[i])

    fig.tight_layout(pad=pad)
|
|
|
|
|
def plot_keypoints(kpts, colors="lime", ps=4):
    """Scatter keypoints onto the axes of the current figure.

    The i-th keypoint array is drawn on the i-th axes of ``plt.gcf()``.

    Args:
        kpts: list of ndarrays of size (N, 2), holding (x, y) coordinates.
        colors: a single color used for every array, or a list with one
            color specification per keypoint array.
        ps: marker size of the keypoints as float.
    """
    if isinstance(colors, list):
        color_list = colors
    else:
        color_list = [colors] * len(kpts)

    current_axes = plt.gcf().axes
    for axis, points, point_color in zip(current_axes, kpts, color_list):
        xs, ys = points[:, 0], points[:, 1]
        axis.scatter(xs, ys, c=point_color, s=ps, linewidths=0)
|
|
|
|
|
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
    """Plot matches for a pair of existing images.

    Match lines are added to the figure in figure coordinates so that they
    can span the two axes.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2), as (x, y) data
            coordinates of axes ``indices[0]`` and ``indices[1]``.
        color: color of each match. Either a single color (name string or
            RGB(A) tuple), a sequence / (N, 3|4) array of per-match colors,
            or None for random HSV colors.
        lw: width of the lines (no lines if lw <= 0).
        ps: size of the end points (no endpoints if ps <= 0).
        indices: indices of the axes of the current figure to draw on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    ax0, ax1 = ax[indices[0]], ax[indices[1]]
    # Draw first so the data->display transforms used below are up to date.
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list, np.ndarray)):
        # A single color (string or RGB(A) tuple): repeat it for every match.
        # Per-match colors (including rows of an ndarray) are used as given.
        color = [color] * len(kpts0)

    if lw > 0:
        # Map each axes' data coordinates to shared figure coordinates.
        to_figure = fig.transFigure.inverted()
        fkpts0 = to_figure.transform(ax0.transData.transform(kpts0))
        fkpts1 = to_figure.transform(ax1.transData.transform(kpts1))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=color[i],
                linewidth=lw,
                alpha=a,
            )
            for i in range(len(kpts0))
        ]

    # Freeze the view limits so the scatter below does not rescale the images.
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
|
|
|
|
|
def plot_lines(
    lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
):
    """Draw line segments and their endpoints on existing axes.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per target axes; each
            segment is two endpoints given as (x, y) data coordinates.
        line_colors: a single color for all segments, or a list with one
            color per entry of ``lines``.
        point_colors: a single color for all endpoints, or a list with one
            color per entry of ``lines``.
        ps: size of the endpoints as float pixels.
        lw: line width as float pixels.
        indices: indices of the axes of the current figure to draw on.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)

    fig = plt.gcf()
    figure_axes = fig.axes
    assert len(figure_axes) > max(indices)
    targets = [figure_axes[i] for i in indices]
    fig.canvas.draw()

    for axis, segs, seg_color, end_color in zip(
        targets, lines, line_colors, point_colors
    ):
        for seg in segs:
            xs = (seg[0, 0], seg[1, 0])
            ys = (seg[0, 1], seg[1, 1])
            axis.add_line(
                matplotlib.lines.Line2D(xs, ys, zorder=1, c=seg_color, linewidth=lw)
            )
        endpoints = segs.reshape(-1, 2)
        axis.scatter(
            endpoints[:, 0], endpoints[:, 1], c=end_color, s=ps, linewidths=0, zorder=2
        )
|
|
|
|
|
def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
    """Plot matches for a pair of existing images, parametrized by their middle point.

    Match lines are added to the figure in figure coordinates so that they
    can span the two axes.

    Args:
        kpts0, kpts1: corresponding middle points of the lines of size (N, 2),
            as (x, y) data coordinates of axes ``indices[0]`` and ``indices[1]``.
        color: color of each match. Either a single color (name string or
            RGB(A) tuple), a sequence / (N, 3|4) array of per-match colors,
            or None for random HSV colors.
        lw: width of the lines (nothing is drawn if lw <= 0).
        indices: indices of the axes of the current figure to draw on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    ax0, ax1 = ax[indices[0]], ax[indices[1]]
    # Draw first so the data->display transforms used below are up to date.
    fig.canvas.draw()

    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list, np.ndarray)):
        # A single color (string or RGB(A) tuple): repeat it for every match.
        # Per-match colors (including rows of an ndarray) are used as given.
        color = [color] * len(kpts0)

    if lw > 0:
        # Map each axes' data coordinates to shared figure coordinates.
        to_figure = fig.transFigure.inverted()
        fkpts0 = to_figure.transform(ax0.transData.transform(kpts0))
        fkpts1 = to_figure.transform(ax1.transData.transform(kpts1))
        fig.lines += [
            matplotlib.lines.Line2D(
                (fkpts0[i, 0], fkpts1[i, 0]),
                (fkpts0[i, 1], fkpts1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=color[i],
                linewidth=lw,
                alpha=a,
            )
            for i in range(len(kpts0))
        ]

    # Freeze the view limits of both axes.
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)
|
|
|
|
|
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.

    Match i gets the same (randomly shuffled) "husl" palette color in every
    target axes; incorrect matches are drawn with alpha 0.2.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per target axes, with
            the two (x, y) endpoints of each line in data coordinates.
        correct_matches: optional bool array of size (N,) indicating correct
            matches (0/1 integers are accepted too).
        lw: line width as float pixels.
        indices: indices of the axes of the current figure to draw on.
    """
    n_lines = len(lines[0])
    colors = sns.color_palette("husl", n_colors=n_lines)
    np.random.shuffle(colors)
    alphas = np.ones(n_lines)

    if correct_matches is not None:
        # Cast to bool so `~` is a logical NOT; on 0/1 integers it would be a
        # bitwise NOT producing negative indices.
        alphas[~np.asarray(correct_matches, dtype=bool)] = 0.2

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    # Draw first so the data->display transforms used below are up to date.
    fig.canvas.draw()

    # Inverse figure transform (display -> figure coords); loop-invariant.
    to_figure = fig.transFigure.inverted()
    for a, l in zip(axes, lines):
        endpoint0 = to_figure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = to_figure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=colors[i],
                alpha=alphas[i],
                linewidth=lw,
            )
            for i in range(n_lines)
        ]
|
|
|
|
|
def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors:
    green for correct matches, red for wrong ones, and blue for the rest.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per target axes, with
            the two (x, y) endpoints of each line in data coordinates.
        correct_matches: list of bool arrays of size (N,) with correct matches.
        wrong_matches: list of bool arrays of size (N,) with wrong matches.
            Where both masks are set, the line is drawn red.
        lw: line width as float pixels.
        indices: indices of the axes of the current figure to draw on.
    """
    palette = sns.color_palette("hls", 8)
    blue = palette[5]
    red = palette[0]
    green = palette[2]

    # One color row per line. np.tile keeps a 2-D (N, C) shape even when a
    # set of lines is empty, so the masked assignments below stay valid.
    colors = [np.tile(np.asarray(blue, dtype=float), (len(l), 1)) for l in lines]
    for i, c in enumerate(colors):
        # Cast to bool so 0/1 integer masks are not used as row indices.
        c[np.asarray(correct_matches[i], dtype=bool)] = green
        c[np.asarray(wrong_matches[i], dtype=bool)] = red

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    # Draw first so the data->display transforms used below are up to date.
    fig.canvas.draw()

    # Inverse figure transform (display -> figure coords); loop-invariant.
    to_figure = fig.transFigure.inverted()
    for a, l, c in zip(axes, lines, colors):
        endpoint0 = to_figure.transform(a.transData.transform(l[:, 0]))
        endpoint1 = to_figure.transform(a.transData.transform(l[:, 1]))
        fig.lines += [
            matplotlib.lines.Line2D(
                (endpoint0[i, 0], endpoint1[i, 0]),
                (endpoint0[i, 1], endpoint1[i, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=c[i],
                linewidth=lw,
            )
            for i in range(len(l))
        ]
|
|
|
|
|
def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
    """Draw matched lines on existing axes and highlight the matched parts.

    Every full line is drawn in red with alpha 0.7. Its subsegment is then
    drawn opaquely on top, using color i of a cubehelix palette for line i
    in every target axes.

    Args:
        lines: list of ndarrays of size (N, 2, 2), one per target axes, with
            the two endpoints of each line in data coordinates.
        subsegments: list of ndarrays of size (N, 2, 2) holding the matched
            part of each line, in the same layout.
        lw: line width as float pixels.
        indices: indices of the axes of the current figure to draw on.
    """
    n_lines = len(lines[0])
    palette = sns.cubehelix_palette(
        start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
    )

    fig = plt.gcf()
    figure_axes = fig.axes
    assert len(figure_axes) > max(indices)
    targets = [figure_axes[i] for i in indices]
    fig.canvas.draw()

    def _figure_lines(axis, segs, to_figure, line_colors, alpha):
        # Map both endpoints from the axes' data coordinates to figure
        # coordinates, then build one figure-level Line2D per segment.
        starts = to_figure.transform(axis.transData.transform(segs[:, 0]))
        ends = to_figure.transform(axis.transData.transform(segs[:, 1]))
        return [
            matplotlib.lines.Line2D(
                (starts[k, 0], ends[k, 0]),
                (starts[k, 1], ends[k, 1]),
                zorder=1,
                transform=fig.transFigure,
                c=line_colors[k],
                alpha=alpha,
                linewidth=lw,
            )
            for k in range(n_lines)
        ]

    for axis, full_segs, matched_segs in zip(targets, lines, subsegments):
        to_figure = fig.transFigure.inverted()
        fig.lines += _figure_lines(axis, full_segs, to_figure, ["red"] * n_lines, 0.7)
        fig.lines += _figure_lines(axis, matched_segs, to_figure, palette, 1)
|
|