# Copyright (c) Meta Platforms, Inc. and affiliates.

from collections import Counter, defaultdict
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from .parser import (
    filter_area,
    filter_node,
    filter_way,
    match_to_group,
    parse_area,
    parse_node,
    parse_way,
    Patterns,
)
from .reader import OSMData


def recover_hierarchy(counter: Counter) -> Dict:
    """Recover a two-level hierarchy from the flat group labels.

    Labels are visited by decreasing count. A label ``"prefix:group"`` is
    stored as ``result[prefix][group] = count``; a label without a colon is
    stored as ``result[label] = count``.

    When a bare label is also used as a prefix (e.g. both ``"a"`` and
    ``"a:x"`` occur), the bare count is kept inside the nested dict under the
    prefix's own name (``result["a"]["a"]``) regardless of which of the two
    labels is visited first. Only the first colon splits a label, so
    ``"a:b:c"`` yields prefix ``"a"`` and group ``"b:c"``.

    Args:
        counter: Mapping from flat label to its number of occurrences.

    Returns:
        A plain dict whose values are either an int count or a dict mapping
        sub-group names to int counts.
    """
    groups: Dict = {}
    for label, count in sorted(counter.items(), key=lambda item: -item[1]):
        prefix, sep, group = label.partition(":")
        existing = groups.get(prefix)
        if not sep:
            if not isinstance(existing, dict):
                # Bare label with no sub-groups recorded so far: keep it flat.
                groups[prefix] = count
                continue
            # Sub-groups already exist: file the bare count under its own name.
            group = prefix
        if not isinstance(existing, dict):
            # Promote a previously stored bare count into the nested dict
            # instead of discarding it.
            existing = {} if existing is None else {prefix: existing}
            groups[prefix] = existing
        # Accumulate so that "a" and "a:a" cannot silently overwrite each other.
        existing[group] = existing.get(group, 0) + count
    return groups


def bar_autolabel(rects, fontsize):
    """Annotate each horizontal bar in *rects* with its width.

    The text is drawn on the current axes (``plt.gca()``), just to the right
    of the bar's end and vertically centered on the bar.
    """
    ax = plt.gca()
    for rect in rects:
        width = rect.get_width()
        ax.annotate(
            f"{width}",
            xy=(width, rect.get_y() + rect.get_height() / 2),
            xytext=(3, 0),  # 3 points horizontal offset
            textcoords="offset points",
            ha="left",
            va="center",
            fontsize=fontsize,
        )


def plot_histogram(counts, fontsize, dpi):
    """Draw a horizontal, log-scaled bar chart of hierarchical counts.

    Args:
        counts: Mapping as returned by ``recover_hierarchy``: each value is
            either a single count or a dict of sub-group counts. Every
            top-level entry becomes one bar series (one legend entry).
        fontsize: Font size for the tick labels and the bar annotations.
        dpi: Resolution of the newly created figure.
    """
    _, ax = plt.subplots(dpi=dpi, figsize=(8, 20))
    labels = []
    for name, value in counts.items():
        if isinstance(value, dict):
            series_labels = list(value.keys())
            widths = list(value.values())
        else:
            series_labels = [name]
            widths = [value]
        # Bars of this series occupy the next free rows.
        positions = len(labels) + np.arange(len(widths))
        labels += series_labels
        bars = ax.barh(positions, widths, height=0.9, label=name)
        bar_autolabel(bars, fontsize)
    # Fix the tick positions before assigning labels so that every label and
    # its font size are bound to the right tick.
    ax.set_yticks(np.arange(len(labels)))
    ax.set_yticklabels(labels, fontsize=fontsize)
    ax.axes.xaxis.set_ticklabels([])
    ax.xaxis.tick_top()
    ax.invert_yaxis()
    ax.set_xscale("log")
    # ncol must be at least 1, even when there is nothing to plot.
    ax.legend(ncol=max(1, len(counts)), loc="upper center")


def count_elements(elems: Dict, filter_fn, parse_fn) -> Dict:
    """Count the number of elements in each group.

    Args:
        elems: Mapping whose values are elements exposing a ``tags``
            attribute.
        filter_fn: Predicate applied to each element; only elements for which
            it is truthy are counted.
        parse_fn: Maps an element's ``tags`` to a group label, or to ``None``
            if the element should be skipped.

    Returns:
        The per-group counts arranged by ``recover_hierarchy``.
    """
    parsed = (parse_fn(elem.tags) for elem in filter(filter_fn, elems.values()))
    counts = Counter(group for group in parsed if group is not None)
    return recover_hierarchy(counts)


def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150):
    """Plot one histogram figure each for the nodes, ways and areas of *osm*.

    Areas are counted from ``osm.ways`` using the area filter and parser.
    """
    specs = (
        ("nodes", osm.nodes, filter_node, parse_node),
        ("ways", osm.ways, filter_way, parse_way),
        ("areas", osm.ways, filter_area, parse_area),
    )
    for title, elems, filter_fn, parse_fn in specs:
        counts = count_elements(elems, filter_fn, parse_fn)
        plot_histogram(counts, fontsize, dpi)
        plt.title(title)


def plot_sankey_hierarchy(osm: OSMData):
    """Show a Sankey diagram linking node keys to tags to semantic groups.

    Each filtered node with a parsable label contributes one
    ``(key, tag, group)`` triplet. The group is looked up first in
    ``Patterns.nodes``, then in ``Patterns.ways``, and falls back to
    ``"null"``.

    Returns:
        The plotly figure, after it has been displayed with ``fig.show()``.

    Raises:
        ValueError: If no node of *osm* yields a label.
    """
    triplets = []
    for node in filter(filter_node, osm.nodes.values()):
        label = parse_node(node.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.nodes)
        if group is None:
            group = match_to_group(label, Patterns.ways)
        if group is None:
            group = "null"
        key, sep, tag = label.partition(":")
        if not sep or tag == "yes":
            # Bare labels and "key:yes" labels use the key itself as the tag.
            tag = key
        triplets.append((key, tag, group))
    if not triplets:
        raise ValueError("Cannot plot a Sankey hierarchy: no labeled nodes found.")

    keys, tags, groups = zip(*triplets)
    counts_key_tag = Counter(zip(keys, tags))
    counts_key_tag_group = Counter(triplets)

    tags_per_key = defaultdict(set)
    for key, tag in counts_key_tag:
        tags_per_key[key].add(tag)
    # If a (key, tag) pair was seen with several groups, the last one wins.
    keytag2group = dict(zip(zip(keys, tags), groups))

    key_names = sorted(tags_per_key)
    tag_names = [(k, t) for k in key_names for t in sorted(tags_per_key[k])]
    # Groups in order of first appearance among the sorted tags; "null" is
    # always listed last, even when no node fell into it.
    ordered_groups = dict.fromkeys(keytag2group[kt] for kt in tag_names)
    ordered_groups.pop("null", None)
    group_names = [*ordered_groups, "null"]

    # Sankey nodes are numbered consecutively: keys, then tags, then groups.
    key2idx = {k: i for i, k in enumerate(key_names)}
    tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)}
    group_offset = len(key2idx) + len(tag2idx)
    group2idx = {g: i + group_offset for i, g in enumerate(group_names)}

    key_counts = Counter(keys)
    group_counts = Counter(groups)
    key_text = [f"{k} {key_counts[k]}" for k in key_names]
    tag_text = [f"{t} {counts_key_tag[k, t]}" for k, t in tag_names]
    group_text = [f"{g} {group_counts[g]}" for g in group_names]

    fig = go.Figure(
        data=[
            go.Sankey(
                orientation="h",
                node=dict(
                    pad=15,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    label=key_text + tag_text + group_text,
                    # NOTE(review): plotly describes node.x as a normalized
                    # position; values 1 and 2 are kept as in the original.
                    x=[0] * len(key_names)
                    + [1] * len(tag_names)
                    + [2] * len(group_names),
                    color="blue",
                ),
                arrangement="fixed",
                link=dict(
                    # First the key -> tag links, then the tag -> group links.
                    source=[key2idx[k] for k, _ in counts_key_tag]
                    + [tag2idx[k, t] for k, t, _ in counts_key_tag_group],
                    target=[tag2idx[k, t] for k, t in counts_key_tag]
                    + [group2idx[g] for _, _, g in counts_key_tag_group],
                    value=list(counts_key_tag.values())
                    + list(counts_key_tag_group.values()),
                ),
            )
        ]
    )
    fig.update_layout(autosize=False, width=800, height=2000, font_size=10)
    fig.show()
    return fig