wangerniu
maplocnet
629144d
# Copyright (c) Meta Platforms, Inc. and affiliates.
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import PIL.Image
from utils.viz_2d import add_text
from .parser import Groups
class GeoPlotter:
    """Build an interactive plotly map drawn over OpenStreetMap tiles.

    Every drawing method adds traces or image layers to ``self.fig`` in place,
    so several calls on one instance compose a single map.
    """

    def __init__(self, zoom=12, **kwargs):
        """Create an empty map figure.

        Args:
            zoom: Initial ``mapbox_zoom`` of the layout.
            **kwargs: Extra layout options forwarded to ``fig.update_layout``.
                They must not repeat the keys set explicitly below, or Python
                raises a "multiple values for keyword argument" TypeError.
        """
        self.fig = go.Figure()
        self.fig.update_layout(
            mapbox_style="open-street-map",
            autosize=True,
            mapbox_zoom=zoom,
            margin={"r": 0, "t": 0, "l": 0, "b": 0},
            showlegend=True,
            **kwargs,
        )

    def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
        """Add a marker trace and recenter the map on the mean of the points.

        Args:
            latlons: Array-like whose last axis holds (lat, lon) pairs.
            color: Marker color (plotly ``marker_color``).
            text: Optional per-point text, passed as the trace's ``text``.
            name: Legend name of the trace.
            size: Marker size.
            **kwargs: Extra arguments for ``go.Scattermapbox``.
        """
        latlons = np.asarray(latlons)
        self.fig.add_trace(
            go.Scattermapbox(
                lat=latlons[..., 0],
                lon=latlons[..., 1],
                mode="markers",
                text=text,
                marker_color=color,
                marker_size=size,
                name=name,
                **kwargs,
            )
        )
        flat = latlons.reshape(-1, 2)
        # The mean of zero points is NaN (with a RuntimeWarning); centering the
        # map on NaN would break it, so keep the previous center instead.
        if len(flat) > 0:
            center = flat.mean(0)
            self.fig.update_layout(
                mapbox_center=dict(zip(("lat", "lon"), center)),
            )

    def bbox(self, bbox, color, name=None, **kwargs):
        """Draw the outline of ``bbox`` and recenter the map on its center.

        Args:
            bbox: Box exposing ``min_``, ``left_top``, ``max_``,
                ``right_bottom`` and ``center`` as (lat, lon) pairs.
            color: Outline color.
            name: Legend name of the trace.
            **kwargs: Extra arguments for ``go.Scattermapbox``.
        """
        # Repeat the first corner so the line closes into a loop.
        corners = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
        )
        self.fig.add_trace(
            go.Scattermapbox(
                lat=corners[:, 0],
                lon=corners[:, 1],
                mode="lines",
                marker_color=color,
                name=name,
                **kwargs,
            )
        )
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
        )

    def raster(self, raster, bbox, below="traces", **kwargs):
        """Overlay an image on the map, stretched over ``bbox``.

        Args:
            raster: Image array. Integer arrays are handed to PIL unchanged.
                Any other dtype is treated as intensities in [0, 1], clipped to
                that range and scaled to uint8.
            bbox: Box exposing ``min_``, ``left_top``, ``max_`` and
                ``right_bottom`` corners as (lat, lon) pairs.
            below: Value of the image layer's ``below`` attribute.
            **kwargs: Extra options for the Mapbox image layer.
        """
        if not np.issubdtype(raster.dtype, np.integer):
            # Clip before casting: out-of-range values would otherwise wrap
            # around in the uint8 conversion (e.g. 1.01 * 255 -> 1).
            raster = (np.clip(raster, 0, 1) * 255).astype(np.uint8)
        raster = PIL.Image.fromarray(raster)
        # Reverse the corner order and swap each (lat, lon) into (lon, lat).
        # Mapbox image layers take [lon, lat] corners ordered top-left,
        # top-right, bottom-right, bottom-left.
        # NOTE(review): that ordering relies on the box's corner naming
        # conventions - confirm against the BoundaryBox definition.
        corners = np.stack(
            [
                bbox.min_,
                bbox.left_top,
                bbox.max_,
                bbox.right_bottom,
            ]
        )[::-1, ::-1]
        layers = [*self.fig.layout.mapbox.layers]
        layers.append(
            dict(
                sourcetype="image",
                source=raster,
                coordinates=corners,
                below=below,
                **kwargs,
            )
        )
        self.fig.layout.mapbox.layers = layers
# RGB color (0-255 per channel) for each map class name. Colormap builds its
# lookup tables from "void" plus every name in Groups.areas and Groups.ways,
# so each of those names needs an entry here.
map_colors = dict(
    building=(84, 155, 255),
    parking=(255, 229, 145),
    playground=(150, 133, 125),
    grass=(188, 255, 143),
    park=(0, 158, 16),
    forest=(0, 92, 9),
    water=(184, 213, 255),
    fence=(238, 0, 255),
    wall=(0, 0, 0),
    hedge=(107, 68, 48),
    kerb=(255, 234, 0),
    building_outline=(0, 0, 255),
    cycleway=(0, 251, 255),
    path=(8, 237, 0),
    road=(255, 0, 0),
    tree_row=(0, 92, 9),
    busway=(255, 128, 0),
    # Light gray (229, 229, 229); Colormap uses it for label 0.
    void=[int(255 * 0.9)] * 3,
)
class Colormap:
    """Color lookup tables for rasterized map classes.

    Row 0 of each table is the ``"void"`` color. Row ``i >= 1`` is the color of
    ``Groups.areas[i - 1]`` (``colors_areas``) or ``Groups.ways[i - 1]``
    (``colors_ways``), taken from ``map_colors``. Rows hold 0-255 RGB values.
    """

    colors_areas = np.stack([map_colors[label] for label in ["void"] + Groups.areas])
    colors_ways = np.stack([map_colors[label] for label in ["void"] + Groups.ways])

    @classmethod
    def apply(cls, rasters):
        """Colorize a stacked pair of (area ids, way ids) label maps.

        Args:
            rasters: Array where ``rasters[0]`` holds area class ids and
                ``rasters[1]`` holds way class ids (integer indices into the
                lookup tables).

        Returns:
            Float RGB image with values in [0, 1]. A pixel takes its way color
            wherever its way id is non-zero, and its area color otherwise.
        """
        is_way = rasters[1, ..., None] > 0
        way_rgb = cls.colors_ways[rasters[1]]
        area_rgb = cls.colors_areas[rasters[0]]
        blended = np.where(is_way, way_rgb, area_rgb)
        return blended / 255.0

    @classmethod
    def add_colorbar(cls):
        """Attach a vertical legend colorbar with one labelled cell per class.

        The colorbar axis sits at the right edge of the current figure (figure
        coordinates ``[1, 0.1, 0.02, 0.8]``). Area classes come first, then way
        classes, read from top to bottom; ``"void"`` is left out.
        """
        cbar_ax = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
        rgb = np.concatenate([cls.colors_areas[1:], cls.colors_ways[1:]]) / 255.0
        num_classes = len(rgb)
        # One tick at the middle of each of the num_classes equal-height cells.
        centers = np.linspace(0, 1, num_classes, endpoint=False) + 1 / num_classes / 2
        # A vertical colorbar runs bottom-to-top, so colors and labels are both
        # reversed to list the classes top-down in their declared order.
        colorbar = mpl.colorbar.ColorbarBase(
            cbar_ax,
            cmap=mpl.colors.ListedColormap(rgb[::-1]),
            orientation="vertical",
            ticks=centers,
        )
        colorbar.set_ticklabels((Groups.areas + Groups.ways)[::-1])
        cbar_ax.tick_params(labelsize=15)
def plot_nodes(idx, raster, fontsize=8, size=15, show=True):
    """Scatter and label the non-zero pixels of a node raster on an existing axis.

    Args:
        idx: Index into ``plt.gcf().axes`` of the axis to draw on; it is also
            forwarded to ``add_text``.
        raster: 2-D label map. 0 means "no node", and a value ``v > 0`` is
            labelled with ``Groups.nodes[v - 1]``.
        fontsize: Font size of the text labels.
        size: Marker size passed to ``ax.scatter``.
        show: Whether to call ``plt.show()`` after drawing. Defaults to True,
            matching the previous behavior. Pass False to keep annotating or
            to save the figure before it is displayed.
    """
    ax = plt.gcf().axes[idx]
    # Freeze the current axis limits so the scatter does not change the view.
    ax.autoscale(enable=False)
    rows, cols = np.nonzero(raster > 0)
    # Plot coordinates are (x, y) = (column, row).
    nodes_xy = np.stack([cols, rows], -1)
    nodes_val = raster[rows, cols] - 1
    ax.scatter(*nodes_xy.T, c="k", s=size)
    for xy, val in zip(nodes_xy, nodes_val):
        # int() allows list indexing even when the raster is not integer-typed.
        group = Groups.nodes[int(val)]
        # Shift each label 2 units off its marker (add_text gets normalized=False).
        add_text(
            idx,
            group,
            xy + 2,
            lcolor=None,
            fs=fontsize,
            color="k",
            normalized=False,
            ha="center",
        )
    if show:
        plt.show()