Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import plotly.graph_objects as go | |
import PIL.Image | |
from utils.viz_2d import add_text | |
from .parser import Groups | |
class GeoPlotter:
    """Incrementally build a Plotly figure drawn over an OpenStreetMap base map.

    The figure is exposed as ``self.fig``; each method adds a trace or an image
    layer to it and, for ``points`` and ``bbox``, re-centers the map view.
    """

    def __init__(self, zoom=12, **kwargs):
        """Create the empty figure and configure its layout.

        Args:
            zoom: value assigned to the layout's ``mapbox_zoom``.
            **kwargs: extra layout options forwarded to ``update_layout``.
        """
        self.fig = go.Figure()
        no_margin = {"r": 0, "t": 0, "l": 0, "b": 0}
        self.fig.update_layout(
            mapbox_style="open-street-map",
            autosize=True,
            mapbox_zoom=zoom,
            margin=no_margin,
            showlegend=True,
            **kwargs,
        )

    def _center_on(self, latlon):
        """Move the map center to the given (lat, lon) pair."""
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), latlon)),
        )

    def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
        """Add a marker trace and center the map on the mean of the points.

        Args:
            latlons: array-like whose last axis is (latitude, longitude).
            color: marker color.
            text: optional text attached to the markers.
            name: optional trace name.
            size: marker size.
            **kwargs: extra options forwarded to ``go.Scattermapbox``.
        """
        coords = np.asarray(latlons)
        markers = go.Scattermapbox(
            lat=coords[..., 0],
            lon=coords[..., 1],
            mode="markers",
            text=text,
            marker_color=color,
            marker_size=size,
            name=name,
            **kwargs,
        )
        self.fig.add_trace(markers)
        # Flatten any leading axes so the mean is taken over every point.
        self._center_on(coords.reshape(-1, 2).mean(0))

    def bbox(self, bbox, color, name=None, **kwargs):
        """Draw the outline of a bounding box and center the map on it.

        Args:
            bbox: object exposing ``min_``, ``left_top``, ``max_``,
                ``right_bottom`` and ``center``; the corners are read as
                (latitude, longitude) pairs.
            color: color of the outline.
            name: optional trace name.
            **kwargs: extra options forwarded to ``go.Scattermapbox``.
        """
        # The first corner is repeated at the end to close the polygon.
        outline = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
        )
        border = go.Scattermapbox(
            lat=outline[:, 0],
            lon=outline[:, 1],
            mode="lines",
            marker_color=color,
            name=name,
            **kwargs,
        )
        self.fig.add_trace(border)
        self._center_on(bbox.center)

    def raster(self, raster, bbox, below="traces", **kwargs):
        """Overlay an image on the map, stretched over a bounding box.

        Args:
            raster: image array; non-integer arrays are multiplied by 255 and
                cast to ``uint8`` before conversion to a PIL image.
            bbox: object exposing ``min_``, ``left_top``, ``max_`` and
                ``right_bottom`` corners.
            below: value of the layer's ``below`` option.
            **kwargs: extra options merged into the layer definition.
        """
        if not np.issubdtype(raster.dtype, np.integer):
            raster = (raster * 255).astype(np.uint8)
        image = PIL.Image.fromarray(raster)

        quad = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom]
        )
        # Reverse the corner order and swap the two columns of each corner.
        # NOTE(review): presumably converts (lat, lon) corners into the
        # (lon, lat) ordering expected by image layers - confirm.
        quad = quad[::-1, ::-1]

        image_layer = dict(
            sourcetype="image",
            source=image,
            coordinates=quad,
            below=below,
            **kwargs,
        )
        # Rebuild the layer list instead of mutating the layout value in place.
        self.fig.layout.mapbox.layers = [
            *self.fig.layout.mapbox.layers,
            image_layer,
        ]
# Display color of every map class, keyed by class name.
# Components are RGB values in the 0-255 range.
map_colors = {
    "building": (84, 155, 255),
    "parking": (255, 229, 145),
    "playground": (150, 133, 125),
    "grass": (188, 255, 143),
    "park": (0, 158, 16),
    "forest": (0, 92, 9),
    "water": (184, 213, 255),
    "fence": (238, 0, 255),
    "wall": (0, 0, 0),
    "hedge": (107, 68, 48),
    "kerb": (255, 234, 0),
    "building_outline": (0, 0, 255),
    "cycleway": (0, 251, 255),
    "path": (8, 237, 0),
    "road": (255, 0, 0),
    "tree_row": (0, 92, 9),
    "busway": (255, 128, 0),
    # Background: a light gray at 90% of full intensity on each channel.
    "void": 3 * [int(0.9 * 255)],
}
class Colormap:
    """Color lookup tables for rasterized map classes, plus rendering helpers.

    Both helpers are class methods, so they can be called directly on the
    class (``Colormap.apply(rasters)``) without creating an instance.
    """

    # Row 0 of each table is the "void" color; the remaining rows follow the
    # order of ``Groups.areas`` and ``Groups.ways`` respectively.
    colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
    colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])

    @classmethod
    def apply(cls, rasters):
        """Convert stacked index rasters into a color image.

        Args:
            rasters: integer array; ``rasters[0]`` indexes ``colors_areas`` and
                ``rasters[1]`` indexes ``colors_ways``.

        Returns:
            Array with a trailing color axis and values scaled to [0, 1].
            Wherever the way index is positive the way color is used,
            otherwise the area color is used.
        """
        has_way = rasters[1, ..., None] > 0
        return (
            np.where(
                has_way,
                cls.colors_ways[rasters[1]],
                cls.colors_areas[rasters[0]],
            )
            / 255.0
        )

    @classmethod
    def add_colorbar(cls):
        """Add a vertical legend of all area and way classes to the current figure."""
        ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
        # Skip row 0 ("void") of each table; areas come first, then ways.
        color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
        # Colors and labels are both reversed so that they stay aligned.
        cmap = mpl.colors.ListedColormap(color_list[::-1])
        # One tick per color, shifted by half a band to sit at its center.
        ticks = np.linspace(0, 1, len(color_list), endpoint=False)
        ticks += 1 / len(color_list) / 2
        cb = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            orientation="vertical",
            ticks=ticks,
        )
        cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
        ax2.tick_params(labelsize=15)
def plot_nodes(idx, raster, fontsize=8, size=15):
    """Draw the nodes of a raster, with their group names, on an existing axis.

    Args:
        idx: index of the target axis in the current figure's ``axes`` list;
            also passed through to ``add_text``.
        raster: array in which every value ``v > 0`` marks a node whose group
            name is ``Groups.nodes[v - 1]``.
        fontsize: font size of the text labels.
        size: marker size of the scattered nodes.
    """
    axis = plt.gcf().axes[idx]
    # Freeze the axis limits so the markers do not rescale the view.
    axis.autoscale(enable=False)

    hits = np.nonzero(raster > 0)
    # Index tuples are reversed so that each row reads (x, y).
    nodes_xy = np.stack(hits[::-1], axis=-1)
    group_ids = raster[hits] - 1

    axis.scatter(*nodes_xy.T, c="k", s=size)
    for position, group_id in zip(nodes_xy, group_ids):
        add_text(
            idx,
            Groups.nodes[group_id],
            position + 2,  # shift the label away from the marker
            lcolor=None,
            fs=fontsize,
            color="k",
            normalized=False,
            ha="center",
        )
    # NOTE(review): the source's indentation was lost; this call is assumed to
    # be the last statement of the function - confirm.
    plt.show()