Spaces:
Running
Running
File size: 5,228 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import pprint
import numpy as np
from . import viz2d
from .tools import RadioHideTool, ToggleTool, __plot_dict__
class FormatPrinter(pprint.PrettyPrinter):
    """PrettyPrinter that renders chosen types with ``%``-style format strings.

    ``formats`` maps an exact type to a format string, e.g.
    ``{np.float32: "%.4f"}``. Subclasses of a listed type are not matched,
    because the lookup uses ``type(obj)``. Any object whose type is not
    listed is rendered by the stock ``pprint.PrettyPrinter.format``.
    """

    def __init__(self, formats):
        super().__init__()
        # Exact type -> "%"-format string, consulted by ``format`` below.
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the pprint triple ``(text, readable, recursive)`` for ``obj``."""
        kind = type(obj)
        if kind not in self.formats:
            return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl)
        # Custom-formatted values are reported readable (1), never recursive (0).
        return self.formats[kind] % obj, 1, 0
class TwoViewFrame:
    """Interactive figure comparing several models' predictions on an image pair.

    ``data`` holds the two views (``data["view0"]`` and ``data["view1"]``).
    ``preds`` maps a model name to that model's prediction mapping.

    The image pair is repeated once per model. The overlay drawn on top is
    chosen from ``plot_dict`` through a "switch plot" radio tool. An optional
    per-model summary text can be shown or hidden with a "toggle summary" tool.
    """

    # NOTE(review): not read inside this class; presumably the defaults used
    # to build ``conf`` elsewhere — confirm against callers.
    default_conf = {
        "default": "matches",
        "summary_visible": False,
    }

    # Registry of plot name -> plot callable. Each entry exposes
    # ``required_keys`` and is called as ``fn(fig, axes, data, preds)``.
    plot_dict = __plot_dict__

    # NOTE(review): mutable class attribute shared by all instances; not used
    # in this class.
    childs = []

    # Maps the ``event`` constructor argument to the background image type.
    event_to_image = [None, "color", "depth", "color+depth"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the interactive tools and draw ``conf.default``.

        Args:
            conf: Attribute-style config. ``default`` (initial plot name) and
                ``summary_visible`` are read here and updated later by
                ``draw`` and ``set_summary_visible``.
            data: Mapping with ``"view0"`` and ``"view1"`` entries.
            preds: Mapping of model name -> prediction mapping.
            title: Optional window title.
            event: Index into ``event_to_image`` that selects the background.
            summaries: Optional mapping of model name -> object to pretty-print
                as that model's summary text. It must contain every model name.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # A plot is offered when all its required keys are either produced by
        # every model (intersection over preds) or present in ``data``.
        # Start from an empty set so an empty ``preds`` does not crash.
        pred_keys = [set(pred.keys()) for pred in preds.values()]
        keys = set.intersection(*pred_keys) if pred_keys else set()
        keys = keys.union(data.keys())
        self.options = [
            k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)
        ]

        self.handle = None
        toolmanager = self.fig.canvas.manager.toolmanager
        self.radios = toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=conf.default,
            keymap="R",
        )
        self.toggle_summary = toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )
        if self.fig.canvas.manager.toolbar is not None:
            self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(conf.default)

    def init_frame(self):
        """Create the image grid and its static artists.

        Returns:
            ``(fig, axes, summary_artists)``. ``axes[i]`` belongs to model
            ``self.names[i]``. ``summary_artists`` holds one text artist per
            model, or is empty when ``self.summaries`` is None.

        Raises:
            ValueError: if ``self.plot`` is not "color", "depth" or "color+depth".
        """
        view0, view1 = self.data["view0"], self.data["view1"]
        if self.plot in ("color", "color+depth"):
            # Presumably (B, C, H, W) tensors: take item 0 and permute to
            # channel-last for display — TODO confirm layout.
            imgs = [
                view0["image"][0].permute(1, 2, 0),
                view1["image"][0].permute(1, 2, 0),
            ]
        elif self.plot == "depth":
            imgs = [view0["depth"][0], view1["depth"][0]]
        else:
            raise ValueError(self.plot)
        imgs = [imgs for _ in self.names]  # repeat the pair for each model
        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for i, name in enumerate(self.names):
            viz2d.add_text(0, name, axes=axes[i])

        # Overlay depth heatmaps only when BOTH views carry a depth map;
        # checking view0 alone would fail on view1["depth"].
        has_depth = all(
            "depth" in view.keys() and view["depth"] is not None
            for view in (view0, view1)
        )
        if self.plot == "color+depth" and has_depth:
            for i in range(len(self.names)):
                viz2d.plot_heatmaps(
                    [view0["depth"][0], view1["depth"][0]],
                    axes=axes[i],
                    cmap="Spectral",
                )

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        toggle_artists = []
        if self.summaries is not None:
            formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
            toggle_artists = [
                viz2d.add_text(
                    0,
                    formatter.pformat(self.summaries[n]),
                    axes=axes[i],
                    pos=(0.01, 0.01),
                    va="bottom",
                    backgroundcolor=(0, 0, 0, 0.5),
                    visible=self.conf.summary_visible,
                )
                for i, n in enumerate(self.names)
            ]
        return fig, axes, toggle_artists

    def draw(self, value):
        """Clear the current overlay and draw the plot registered under ``value``.

        ``value`` is also stored as ``conf.default``.

        Returns:
            Whatever the plot callable returns. It is kept as ``self.handle``.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove the current overlay.

        Removes the plot handle's own state (best effort), plus every line,
        collection and figure-level artist.
        """
        if self.handle is not None:
            # Handles without a ``clear`` method are tolerated.
            try:
                self.handle.clear()
            except AttributeError:
                pass
            self.handle = None
        for row in self.axes:
            for ax in row:
                # Iterate over snapshots: ``remove()`` mutates the underlying
                # container, and iterating it live can skip artists.
                for line in list(ax.lines):
                    line.remove()
                for coll in list(ax.collections):
                    coll.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Handle a matplotlib pick event.

        If the picked artist has an arrow style, toggle it between a plain
        ``"-"`` line and a double-headed ``"<|-|>"`` arrow. A newly selected
        artist is moved to zorder 1. The event is then forwarded to the current
        plot handle when it defines ``click_artist``.
        """
        art = event.artist
        # Other pickable artists (without an arrow style) are skipped here
        # instead of raising, so the handle still receives the event.
        get_style = getattr(art, "get_arrowstyle", None)
        if get_style is not None:
            select = get_style().arrow == "-"
            art.set_arrowstyle("<|-|>" if select else "-")
            if select:
                art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide every per-model summary text and record the state in ``conf``."""
        self.conf.summary_visible = visible
        for art in self.summary_arts:
            art.set_visible(visible)
        self.fig.canvas.draw_idle()
|