# Copyright (c) Meta Platforms, Inc. and affiliates.

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import PIL.Image

from utils.viz_2d import add_text

from .parser import Groups


class GeoPlotter:
    """Interactive map figure built on Plotly Mapbox traces.

    Each drawing method adds a trace or an image layer to ``self.fig``.
    ``points`` and ``bbox`` also re-center the map on what was just drawn.
    """

    def __init__(self, zoom=12, **kwargs):
        """Create an empty OpenStreetMap-styled figure.

        Args:
            zoom: initial Mapbox zoom level.
            **kwargs: extra keyword arguments forwarded to
                ``Figure.update_layout``.
        """
        self.fig = go.Figure()
        self.fig.update_layout(
            mapbox_style="open-street-map",
            autosize=True,
            mapbox_zoom=zoom,
            margin={"r": 0, "t": 0, "l": 0, "b": 0},
            showlegend=True,
            **kwargs,
        )

    def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
        """Add a marker trace and center the map on the mean of the points.

        Args:
            latlons: array-like whose last axis holds (lat, lon) pairs.
            color: marker color passed to Plotly.
            text: optional hover text for the markers.
            name: legend name of the trace.
            size: marker size.
            **kwargs: forwarded to ``go.Scattermapbox``.
        """
        latlons = np.asarray(latlons)
        self.fig.add_trace(
            go.Scattermapbox(
                lat=latlons[..., 0],
                lon=latlons[..., 1],
                mode="markers",
                text=text,
                marker_color=color,
                marker_size=size,
                name=name,
                **kwargs,
            )
        )
        # Only re-center when there is something to average. The mean of an
        # empty array is NaN (with a RuntimeWarning), which would corrupt the
        # map center.
        if latlons.size > 0:
            center = latlons.reshape(-1, 2).mean(0)
            self.fig.update_layout(
                mapbox_center=dict(zip(("lat", "lon"), center)),
            )

    def bbox(self, bbox, color, name=None, **kwargs):
        """Draw the outline of a bounding box and center the map on it.

        Args:
            bbox: box object exposing ``min_``, ``left_top``, ``max_``,
                ``right_bottom`` and ``center``. Each is indexed here as a
                (lat, lon) pair.
            color: line color passed to Plotly.
            name: legend name of the trace.
            **kwargs: forwarded to ``go.Scattermapbox``.
        """
        # Repeat the first corner so the polyline closes.
        corners = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
        )
        self.fig.add_trace(
            go.Scattermapbox(
                lat=corners[:, 0],
                lon=corners[:, 1],
                mode="lines",
                marker_color=color,
                name=name,
                **kwargs,
            )
        )
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
        )

    def raster(self, raster, bbox, below="traces", **kwargs):
        """Overlay an image on the map, stretched over the corners of ``bbox``.

        Args:
            raster: image array. Integer arrays are passed to PIL unchanged.
                Any other dtype is treated as intensities in [0, 1]; values
                are clipped to that range and converted to uint8.
            bbox: box whose corners (as (lat, lon) pairs) anchor the image.
            below: value of the Mapbox layer ``below`` property.
            **kwargs: extra properties of the image layer.
        """
        if not np.issubdtype(raster.dtype, np.integer):
            # Clip before casting. Otherwise out-of-range values wrap around
            # modulo 256 (e.g. 1.5 -> 126) instead of saturating.
            raster = (np.clip(raster, 0, 1) * 255).astype(np.uint8)
        raster = PIL.Image.fromarray(raster)
        # Reverse the corner order and swap each (lat, lon) into (lon, lat).
        # Mapbox image layers take their corner coordinates as lon/lat pairs.
        corners = np.stack(
            [
                bbox.min_,
                bbox.left_top,
                bbox.max_,
                bbox.right_bottom,
            ]
        )[::-1, ::-1]
        layers = [*self.fig.layout.mapbox.layers]
        layers.append(
            dict(
                sourcetype="image",
                source=raster,
                coordinates=corners,
                below=below,
                **kwargs,
            )
        )
        self.fig.layout.mapbox.layers = layers


# RGB colors (0-255 per channel) for each map class name.
map_colors = {
    "building": (84, 155, 255),
    "parking": (255, 229, 145),
    "playground": (150, 133, 125),
    "grass": (188, 255, 143),
    "park": (0, 158, 16),
    "forest": (0, 92, 9),
    "water": (184, 213, 255),
    "fence": (238, 0, 255),
    "wall": (0, 0, 0),
    "hedge": (107, 68, 48),
    "kerb": (255, 234, 0),
    "building_outline": (0, 0, 255),
    "cycleway": (0, 251, 255),
    "path": (8, 237, 0),
    "road": (255, 0, 0),
    "tree_row": (0, 92, 9),
    "busway": (255, 128, 0),
    "void": [int(255 * 0.9)] * 3,
}


class Colormap:
    """Lookup tables mapping rasterized class indices to RGB colors."""

    # Row 0 is the "void" color. Row i > 0 is the color of Groups.areas[i - 1]
    # (resp. Groups.ways[i - 1]), so raster value 0 means "no class".
    colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
    colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])

    @classmethod
    def apply(cls, rasters):
        """Convert stacked class-index rasters into an RGB image in [0, 1].

        Args:
            rasters: integer array where ``rasters[0]`` holds area indices
                and ``rasters[1]`` holds way indices (0 meaning none).

        Returns:
            Float array of shape ``rasters.shape[1:] + (3,)``. Pixels with a
            nonzero way index take the way color; all others take the area
            color.
        """
        return (
            np.where(
                rasters[1, ..., None] > 0,
                cls.colors_ways[rasters[1]],
                cls.colors_areas[rasters[0]],
            )
            / 255.0
        )

    @classmethod
    def add_colorbar(cls):
        """Add a vertical legend of area and way colors to the current figure.

        The colorbar axes is placed just outside the right edge of the
        current figure.
        """
        ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
        # Drop the "void" row (index 0) of both tables.
        color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
        # Reversed so the first class ends up at the top of the vertical bar.
        cmap = mpl.colors.ListedColormap(color_list[::-1])
        # One tick at the center of each of the len(color_list) equal bins.
        ticks = np.linspace(0, 1, len(color_list), endpoint=False)
        ticks += 1 / len(color_list) / 2
        cb = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            orientation="vertical",
            ticks=ticks,
        )
        cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
        ax2.tick_params(labelsize=15)


def plot_nodes(idx, raster, fontsize=8, size=15):
    """Scatter and label the nonzero pixels of a node raster, then show.

    Args:
        idx: index of the target axes in ``plt.gcf().axes``.
        raster: 2D integer array. A pixel value ``v > 0`` marks a node of
            class ``Groups.nodes[v - 1]``.
        fontsize: font size of the labels.
        size: marker size of the scatter points.
    """
    ax = plt.gcf().axes[idx]
    # Freeze the current axis limits so the scatter does not rescale them.
    ax.autoscale(enable=False)
    # np.where returns (rows, cols); reverse to get (x, y) = (col, row) pairs.
    nodes_xy = np.stack(np.where(raster > 0)[::-1], -1)
    nodes_val = raster[tuple(nodes_xy.T[::-1])] - 1
    ax.scatter(*nodes_xy.T, c="k", s=size)
    for xy, val in zip(nodes_xy, nodes_val):
        group = Groups.nodes[val]
        add_text(
            idx,
            group,
            xy + 2,  # shift the label slightly so it does not hide the marker
            lcolor=None,
            fs=fontsize,
            color="k",
            normalized=False,
            ha="center",
        )
    plt.show()