"""
2D visualization primitives based on Matplotlib.

1) Plot images with `plot_images`.
2) Call `plot_keypoints` or `plot_matches` any number of times.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import numpy as np
import torch


def _to_numpy(x):
    """Return `x` as a NumPy array if it is a torch.Tensor, otherwise as is.

    `.detach()` is required because `Tensor.numpy()` raises on tensors that
    require grad; `.cpu()` moves GPU tensors to host memory first.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return x


def _is_single_color(color):
    """Return True if `color` is one color (a name or RGB(A) numbers)
    rather than a sequence holding one color per item."""
    if isinstance(color, str):
        return True
    return len(color) > 0 and isinstance(color[0], (int, float, np.number))


def cm_RdGn(x):
    """Custom colormap: red (0) -> yellow (0.5) -> green (1).

    Args:
        x: array of values; clipped to [0, 1].
    Returns:
        RGB values in [0, 1]; for an input of shape S (at least 1-D) the
        output has shape S + (3,).
    """
    x = np.clip(x, 0, 1)[..., None]*2
    c = x*np.array([[0, 1., 0]]) + (2-x)*np.array([[1., 0, 0]])
    return np.clip(c, 0, 1)


def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1).

    Args:
        x_: array (or anything `np.asarray` accepts) of values in [-1, 1];
            values outside this range are clipped.
    Returns:
        RGBA values in [0, 1] with alpha always 1; for an input of shape S
        (at least 1-D) the output has shape S + (4,).
    """
    # Lists do not support `[..., None]`; asarray is a no-op for ndarrays.
    x_ = np.asarray(x_)
    # Non-negative branch: red (0) -> yellow (0.5) -> green (1).
    x = np.clip(x_, 0, 1)[..., None]*2
    c = x*np.array([[0, 1., 0, 1.]]) + (2-x)*np.array([[1., 0, 0, 1.]])
    # Negative branch: red (0) -> blue (-1), driven by the magnitude of x_.
    xn = -np.clip(x_, -1, 0)[..., None]*2
    cn = xn*np.array([[0, 0.1, 1, 1.]]) + (2-xn)*np.array([[1., 0, 0, 1.]])
    out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
    return out


def cm_prune(x_):
    """Custom colormap to visualize pruning.

    Entries equal to the maximum of `x_` are colored blue (mapped to -1).
    All other entries are mapped through `(x - 1) / 9` onto the
    red -> green ramp of `cm_BlRdGn`. Presumably the values lie in [1, 10]
    so that this ramp spans [0, 1]; confirm against callers.

    Args:
        x_: NumPy array, torch tensor, or sequence of numbers.
    Returns:
        RGBA colors as returned by `cm_BlRdGn`.
    """
    x_ = np.asarray(_to_numpy(x_))
    # np.max works for any rank; builtin max() only worked for 1-D input.
    max_i = x_.max()
    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
    return cm_BlRdGn(norm_x)


def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
                adaptive=True):
    """Plot a set of images horizontally.

    Args:
        imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or
            mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images (one name, or one per image).
        dpi: resolution of the created figure.
        pad: padding passed to `fig.tight_layout`.
        adaptive: whether the figure size should fit the image aspect ratios.
    """
    def _prepare(img):
        # (3, H, W) tensors become (H, W, 3); every tensor becomes NumPy.
        if isinstance(img, torch.Tensor) and img.dim() == 3:
            img = img.permute(1, 2, 0)
        return _to_numpy(img)

    imgs = [_prepare(img) for img in imgs]
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
    else:
        ratios = [4/3] * n
    figsize = [sum(ratios)*4.5, 4.5]
    fig, ax = plt.subplots(
        1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)


def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
    """Plot keypoints for existing images.

    Args:
        kpts: list of ndarrays (or torch tensors) of size (N, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float.
        axes: axes to draw on, one per entry of `kpts`; defaults to the axes
            of the current figure.
        a: alpha opacity, a single float or a list with one per image.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    if not isinstance(a, list):
        a = [a] * len(kpts)
    if axes is None:
        axes = plt.gcf().axes
    for ax, k, c, alpha in zip(axes, kpts, colors, a):
        k = _to_numpy(k)
        ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1., labels=None,
                 axes=None):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2), as NumPy
            arrays or torch tensors.
        color: a single color (string or RGB(A) tuple) applied to every
            match, or a sequence/array holding one color per match.
            Random if not given.
        lw: width of the lines (no lines if lw=0).
        ps: size of the end points (no endpoint if ps=0).
        a: alpha opacity of the match lines.
        labels: optional sequence of N labels, one per match line.
        axes: pair (ax0, ax1) to draw on; defaults to the first two axes of
            the current figure.
    """
    fig = plt.gcf()
    if axes is None:
        ax0, ax1 = fig.axes[0], fig.axes[1]
    else:
        ax0, ax1 = axes
    kpts0 = _to_numpy(kpts0)
    kpts1 = _to_numpy(kpts1)
    assert len(kpts0) == len(kpts1), \
        'kpts0 and kpts1 must have the same length'
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif _is_single_color(color):
        # Only replicate a genuine single color; per-match arrays/lists
        # (including (N, 3) ndarrays and lists of names) are kept as is.
        color = [color] * len(kpts0)

    if lw > 0:
        for i, (p0, p1) in enumerate(zip(kpts0, kpts1)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(p0[0], p0[1]), xyB=(p1[0], p1[1]),
                coordsA=ax0.transData, coordsB=ax1.transData,
                axesA=ax0, axesB=ax1,
                zorder=1, color=color[i], linewidth=lw, clip_on=True,
                alpha=a, label=None if labels is None else labels[i],
                picker=5.0)  # pick-event tolerance for interactive use
            line.set_annotation_clip(True)
            fig.add_artist(line)

        # Freeze the axes so later autoscaling cannot change the data
        # transforms the connection lines were placed with.
        ax0.autoscale(enable=False)
        ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w',
             lcolor='k', lwidth=2, ha='left', va='top'):
    """Write `text` on the `idx`-th axes of the current figure.

    Args:
        idx: index into `plt.gcf().axes`.
        text: the string to draw.
        pos: (x, y) position in axes-relative coordinates (`ax.transAxes`).
        fs: font size.
        color: text color.
        lcolor: color of the outline stroke drawn around the text;
            None disables the outline.
        lwidth: width of the outline stroke.
        ha, va: horizontal and vertical alignment of the text.
    """
    ax = plt.gcf().axes[idx]
    t = ax.text(*pos, text, fontsize=fs, ha=ha, va=va,
                color=color, transform=ax.transAxes)
    if lcolor is not None:
        t.set_path_effects([
            path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
            path_effects.Normal()])


def save_plot(path, **kw):
    """Save the current figure without any white margin."""
    plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw)