Spaces:
Sleeping
Sleeping
attila-balint-kul
commited on
Upload 4 files
Browse files- app.py +20 -21
- components.py +297 -111
- requirements.txt +2 -1
- utils.py +18 -6
app.py
CHANGED
@@ -1,34 +1,39 @@
|
|
1 |
import streamlit as st
|
2 |
|
|
|
3 |
from components import (
|
4 |
buildings_view,
|
5 |
-
models_view,
|
6 |
-
performance_view,
|
7 |
computation_view,
|
|
|
8 |
logos,
|
9 |
model_selector,
|
10 |
-
|
11 |
overview_view,
|
|
|
|
|
|
|
12 |
)
|
13 |
-
import utils
|
14 |
|
|
|
|
|
15 |
PAGES = [
|
16 |
"Overview",
|
17 |
"Buildings",
|
18 |
"Models",
|
19 |
-
"
|
|
|
20 |
"Computational Resources",
|
21 |
]
|
22 |
|
23 |
|
24 |
-
st.set_page_config(page_title="
|
25 |
|
26 |
|
27 |
@st.cache_data(ttl=86400)
|
28 |
def fetch_data():
|
29 |
return utils.get_wandb_data(
|
30 |
entity=st.secrets["wandb_entity"],
|
31 |
-
project="
|
32 |
api_key=st.secrets["wandb_api_key"],
|
33 |
job_type="metrics",
|
34 |
)
|
@@ -45,19 +50,11 @@ with st.sidebar:
|
|
45 |
logos()
|
46 |
view = st.selectbox("View", PAGES, index=0)
|
47 |
|
48 |
-
if view
|
49 |
-
models_to_plot = model_selector(models)
|
50 |
|
51 |
if view == "Overview":
|
52 |
-
|
53 |
-
st.link_button("GitHub Repository", url="https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit", use_container_width=True)
|
54 |
-
st.link_button("Documentation", url="https://attila-balint-kul.github.io/energy-forecast-benchmark-toolkit/", use_container_width=True)
|
55 |
-
st.link_button("Electricity Demand Dataset", url="https://huggingface.co/datasets/EDS-lab/electricity-demand", use_container_width=True)
|
56 |
-
st.link_button("HuggingFace Organization", url="https://huggingface.co/EDS-lab", use_container_width=True)
|
57 |
-
|
58 |
-
st.header("Other Dashboards")
|
59 |
-
st.link_button("Electricity Demand", url="https://huggingface.co/spaces/EDS-lab/EnFoBench-ElectricityDemand", use_container_width=True)
|
60 |
-
st.link_button("PV Generation", url="https://huggingface.co/spaces/EDS-lab/EnFoBench-PVGeneration", use_container_width=True)
|
61 |
|
62 |
st.header("Refresh data")
|
63 |
refresh = st.button(
|
@@ -68,7 +65,7 @@ with st.sidebar:
|
|
68 |
st.rerun()
|
69 |
|
70 |
|
71 |
-
header()
|
72 |
|
73 |
if view == "Overview":
|
74 |
overview_view(data)
|
@@ -76,8 +73,10 @@ elif view == "Buildings":
|
|
76 |
buildings_view(data)
|
77 |
elif view == "Models":
|
78 |
models_view(data)
|
79 |
-
elif view == "
|
80 |
-
|
|
|
|
|
81 |
elif view == "Computational Resources":
|
82 |
computation_view(data, models_to_plot)
|
83 |
else:
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
import utils
|
4 |
from components import (
|
5 |
buildings_view,
|
|
|
|
|
6 |
computation_view,
|
7 |
+
header,
|
8 |
logos,
|
9 |
model_selector,
|
10 |
+
models_view,
|
11 |
overview_view,
|
12 |
+
accuracy_view,
|
13 |
+
relative_performance_view,
|
14 |
+
links,
|
15 |
)
|
|
|
16 |
|
17 |
+
|
18 |
+
USE_CASE = st.secrets["enfobench_usecase"]
|
19 |
PAGES = [
|
20 |
"Overview",
|
21 |
"Buildings",
|
22 |
"Models",
|
23 |
+
"Accuracy",
|
24 |
+
"Relative Performance",
|
25 |
"Computational Resources",
|
26 |
]
|
27 |
|
28 |
|
29 |
+
st.set_page_config(page_title=f"{USE_CASE} Dashboard", layout="wide")
|
30 |
|
31 |
|
32 |
@st.cache_data(ttl=86400)
|
33 |
def fetch_data():
|
34 |
return utils.get_wandb_data(
|
35 |
entity=st.secrets["wandb_entity"],
|
36 |
+
project=st.secrets["wandb_project"],
|
37 |
api_key=st.secrets["wandb_api_key"],
|
38 |
job_type="metrics",
|
39 |
)
|
|
|
50 |
logos()
|
51 |
view = st.selectbox("View", PAGES, index=0)
|
52 |
|
53 |
+
if view in ["Accuracy", "Relative Performance", "Computational Resources"]:
|
54 |
+
models_to_plot = model_selector(models, data)
|
55 |
|
56 |
if view == "Overview":
|
57 |
+
links(current=USE_CASE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
st.header("Refresh data")
|
60 |
refresh = st.button(
|
|
|
65 |
st.rerun()
|
66 |
|
67 |
|
68 |
+
header(f"EnFoBench - {USE_CASE}")
|
69 |
|
70 |
if view == "Overview":
|
71 |
overview_view(data)
|
|
|
73 |
buildings_view(data)
|
74 |
elif view == "Models":
|
75 |
models_view(data)
|
76 |
+
elif view == "Accuracy":
|
77 |
+
accuracy_view(data, models_to_plot)
|
78 |
+
elif view == "Relative Performance":
|
79 |
+
relative_performance_view(data, models_to_plot)
|
80 |
elif view == "Computational Resources":
|
81 |
computation_view(data, models_to_plot)
|
82 |
else:
|
components.py
CHANGED
@@ -1,12 +1,19 @@
|
|
1 |
import pandas as pd
|
2 |
-
import streamlit as st
|
3 |
import plotly.express as px
|
|
|
|
|
4 |
|
5 |
-
from utils import get_leaderboard
|
6 |
|
7 |
|
8 |
-
def header() -> None:
|
9 |
-
st.title(
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
st.divider()
|
11 |
|
12 |
|
@@ -18,7 +25,51 @@ def logos() -> None:
|
|
18 |
st.image("./images/energyville_logo.png")
|
19 |
|
20 |
|
21 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
# Group models by their prefix
|
23 |
model_groups: dict[str, list[str]] = {}
|
24 |
for model in models:
|
@@ -30,6 +81,35 @@ def model_selector(models: list[str]) -> set[str]:
|
|
30 |
models_to_plot = set()
|
31 |
|
32 |
st.header("Models to include")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
left, right = st.columns(2)
|
34 |
with left:
|
35 |
select_none = st.button("Select None", use_container_width=True)
|
@@ -53,18 +133,7 @@ def model_selector(models: list[str]) -> set[str]:
|
|
53 |
return models_to_plot
|
54 |
|
55 |
|
56 |
-
def overview_view(data):
|
57 |
-
st.markdown(
|
58 |
-
"""
|
59 |
-
[EnFoBench](https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit)
|
60 |
-
is a community driven benchmarking framework for energy forecasting models.
|
61 |
-
|
62 |
-
This dashboard presents the results of the gas demand forecasting usecase. All models were cross-validated
|
63 |
-
on **365 days** of day ahead forecasting horizon *(10AM until midnight of the next day)*.
|
64 |
-
"""
|
65 |
-
)
|
66 |
-
|
67 |
-
st.divider()
|
68 |
st.markdown("## Leaderboard")
|
69 |
|
70 |
leaderboard = get_leaderboard(data, ["MAE.mean", "RMSE.mean", "rMAE.mean"])
|
@@ -78,7 +147,10 @@ def overview_view(data):
|
|
78 |
)
|
79 |
fig = px.bar(best_models_mae, x="MAE.mean", y=best_models_mae.index)
|
80 |
fig.update_layout(
|
81 |
-
title="Top 10 models by MAE",
|
|
|
|
|
|
|
82 |
)
|
83 |
st.plotly_chart(fig, use_container_width=True)
|
84 |
|
@@ -89,7 +161,9 @@ def overview_view(data):
|
|
89 |
.sort_values("RMSE.mean")
|
90 |
)
|
91 |
fig = px.bar(best_models_mae, x="RMSE.mean", y=best_models_mae.index)
|
92 |
-
fig.update_layout(
|
|
|
|
|
93 |
st.plotly_chart(fig, use_container_width=True)
|
94 |
|
95 |
with right:
|
@@ -99,13 +173,20 @@ def overview_view(data):
|
|
99 |
.sort_values("rMAE.mean")
|
100 |
)
|
101 |
fig = px.bar(best_models_mae, x="rMAE.mean", y=best_models_mae.index)
|
102 |
-
fig.update_layout(
|
|
|
|
|
103 |
st.plotly_chart(fig, use_container_width=True)
|
104 |
|
105 |
st.dataframe(leaderboard, use_container_width=True)
|
106 |
|
107 |
|
108 |
-
def buildings_view(data):
|
|
|
|
|
|
|
|
|
|
|
109 |
buildings = (
|
110 |
data[
|
111 |
[
|
@@ -115,6 +196,8 @@ def buildings_view(data):
|
|
115 |
"metadata.location_id",
|
116 |
"metadata.timezone",
|
117 |
"dataset.available_history.days",
|
|
|
|
|
118 |
]
|
119 |
]
|
120 |
.groupby("unique_id")
|
@@ -126,29 +209,32 @@ def buildings_view(data):
|
|
126 |
"metadata.location_id": "Location ID",
|
127 |
"metadata.timezone": "Timezone",
|
128 |
"dataset.available_history.days": "Available history (days)",
|
|
|
|
|
129 |
}
|
130 |
)
|
131 |
)
|
132 |
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
st.divider()
|
135 |
|
136 |
-
st.
|
137 |
-
st.dataframe(
|
138 |
-
buildings,
|
139 |
-
use_container_width=True,
|
140 |
-
column_config={
|
141 |
-
"Available history (days)": st.column_config.ProgressColumn(
|
142 |
-
"Available history (days)",
|
143 |
-
help="Available training data during the first prediction.",
|
144 |
-
format="%f",
|
145 |
-
min_value=0,
|
146 |
-
max_value=float(buildings["Available history (days)"].max()),
|
147 |
-
),
|
148 |
-
},
|
149 |
-
)
|
150 |
-
|
151 |
-
left, right = st.columns(2, gap="large")
|
152 |
with left:
|
153 |
st.markdown("#### Building classes")
|
154 |
fig = px.pie(
|
@@ -156,19 +242,61 @@ def buildings_view(data):
|
|
156 |
values=0,
|
157 |
names="Building class",
|
158 |
)
|
|
|
|
|
|
|
159 |
st.plotly_chart(fig, use_container_width=True)
|
160 |
|
161 |
-
with
|
162 |
st.markdown("#### Timezones")
|
163 |
fig = px.pie(
|
164 |
buildings.groupby("Timezone").size().reset_index(),
|
165 |
values=0,
|
166 |
names="Timezone",
|
167 |
)
|
|
|
|
|
|
|
168 |
st.plotly_chart(fig, use_container_width=True)
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
models = (
|
173 |
data[
|
174 |
[
|
@@ -197,12 +325,21 @@ def models_view(data):
|
|
197 |
)
|
198 |
)
|
199 |
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
st.divider()
|
202 |
|
203 |
-
st.markdown("### Models")
|
204 |
-
st.dataframe(models, use_container_width=True)
|
205 |
-
|
206 |
left, right = st.columns(2, gap="large")
|
207 |
with left:
|
208 |
st.markdown("#### Variate types")
|
@@ -224,8 +361,12 @@ def models_view(data):
|
|
224 |
)
|
225 |
st.plotly_chart(fig, use_container_width=True)
|
226 |
|
|
|
|
|
|
|
|
|
227 |
|
228 |
-
def
|
229 |
data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
|
230 |
by="model", ascending=True
|
231 |
)
|
@@ -239,24 +380,20 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
239 |
)
|
240 |
st.markdown(f"#### {aggregation.capitalize()} {metric} per building")
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
.
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
points="all",
|
257 |
-
)
|
258 |
-
fig.update_layout(showlegend=False, height=40 * len(models_to_plot))
|
259 |
-
st.plotly_chart(fig, use_container_width=True)
|
260 |
|
261 |
st.divider()
|
262 |
|
@@ -285,14 +422,17 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
285 |
st.markdown(
|
286 |
f"#### {x_aggregation.capitalize()} {x_metric} vs {y_aggregation.capitalize()} {y_metric}"
|
287 |
)
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
|
|
|
|
|
|
296 |
|
297 |
st.divider()
|
298 |
|
@@ -309,9 +449,7 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
309 |
key="table_aggregation",
|
310 |
)
|
311 |
|
312 |
-
metrics_table = data_to_plot.groupby(["model"]).agg(
|
313 |
-
aggregation, numeric_only=True
|
314 |
-
)[
|
315 |
[
|
316 |
f"{metric}.min",
|
317 |
f"{metric}.mean",
|
@@ -319,7 +457,7 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
319 |
f"{metric}.max",
|
320 |
f"{metric}.std",
|
321 |
]
|
322 |
-
]
|
323 |
|
324 |
def custom_table(styler):
|
325 |
styler.background_gradient(cmap="seismic", axis=0)
|
@@ -339,15 +477,12 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
339 |
.reset_index()
|
340 |
.pivot(index="model", columns="unique_id", values=f"{metric}.{aggregation}")
|
341 |
)
|
342 |
-
metrics_per_building_table.insert(
|
343 |
-
0, "median", metrics_per_building_table.median(axis=1)
|
344 |
-
)
|
345 |
metrics_per_building_table.insert(
|
346 |
0, "mean", metrics_per_building_table.mean(axis=1)
|
347 |
)
|
348 |
-
metrics_per_building_table = metrics_per_building_table.sort_values(by="mean")
|
349 |
|
350 |
-
def custom_table(styler):
|
351 |
styler.background_gradient(cmap="seismic", axis=None)
|
352 |
styler.format(precision=2)
|
353 |
|
@@ -360,29 +495,54 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
|
360 |
st.dataframe(styled_table, use_container_width=True)
|
361 |
|
362 |
|
363 |
-
def
|
364 |
data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
|
365 |
by="model", ascending=True
|
366 |
)
|
367 |
|
368 |
-
st.markdown("####
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
"
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
)
|
383 |
-
|
384 |
|
385 |
-
st.
|
386 |
|
387 |
left, center, right = st.columns(3, gap="small")
|
388 |
with left:
|
@@ -399,17 +559,43 @@ def computation_view(data, models_to_plot: set[str]):
|
|
399 |
st.markdown(
|
400 |
f"#### {aggregation_per_model.capitalize()} {aggregation_per_building.capitalize()} {metric} vs CPU usage"
|
401 |
)
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
|
|
2 |
import plotly.express as px
|
3 |
+
import streamlit as st
|
4 |
+
from pandas.io.formats.style import Styler
|
5 |
|
6 |
+
from utils import get_leaderboard, get_model_ranks
|
7 |
|
8 |
|
9 |
+
def header(title: str) -> None:
|
10 |
+
st.title(title)
|
11 |
+
st.markdown(
|
12 |
+
"""
|
13 |
+
[EnFoBench](https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit)
|
14 |
+
is a community driven benchmarking framework for energy forecasting models.
|
15 |
+
"""
|
16 |
+
)
|
17 |
st.divider()
|
18 |
|
19 |
|
|
|
25 |
st.image("./images/energyville_logo.png")
|
26 |
|
27 |
|
28 |
+
def links(current: str) -> None:
|
29 |
+
st.header("Sources")
|
30 |
+
st.link_button(
|
31 |
+
"GitHub Repository",
|
32 |
+
url="https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit",
|
33 |
+
use_container_width=True,
|
34 |
+
)
|
35 |
+
st.link_button(
|
36 |
+
"Documentation",
|
37 |
+
url="https://attila-balint-kul.github.io/energy-forecast-benchmark-toolkit/",
|
38 |
+
use_container_width=True,
|
39 |
+
)
|
40 |
+
st.link_button(
|
41 |
+
"Electricity Demand Dataset",
|
42 |
+
url="https://huggingface.co/datasets/EDS-lab/electricity-demand",
|
43 |
+
use_container_width=True,
|
44 |
+
)
|
45 |
+
st.link_button(
|
46 |
+
"HuggingFace Organization",
|
47 |
+
url="https://huggingface.co/EDS-lab",
|
48 |
+
use_container_width=True,
|
49 |
+
)
|
50 |
+
|
51 |
+
st.header("Other Dashboards")
|
52 |
+
if current != "ElectricityDemand":
|
53 |
+
st.link_button(
|
54 |
+
"Electricity Demand",
|
55 |
+
url="https://huggingface.co/spaces/EDS-lab/EnFoBench-ElectricityDemand",
|
56 |
+
use_container_width=True,
|
57 |
+
)
|
58 |
+
if current != "GasDemand":
|
59 |
+
st.link_button(
|
60 |
+
"Gas Demand",
|
61 |
+
url="https://huggingface.co/spaces/EDS-lab/EnFoBench-GasDemand",
|
62 |
+
use_container_width=True,
|
63 |
+
)
|
64 |
+
if current != "PVGeneration":
|
65 |
+
st.link_button(
|
66 |
+
"PVGeneration",
|
67 |
+
url="https://huggingface.co/spaces/EDS-lab/EnFoBench-PVGeneration",
|
68 |
+
use_container_width=True,
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
def model_selector(models: list[str], data: pd.DataFrame) -> set[str]:
|
73 |
# Group models by their prefix
|
74 |
model_groups: dict[str, list[str]] = {}
|
75 |
for model in models:
|
|
|
81 |
models_to_plot = set()
|
82 |
|
83 |
st.header("Models to include")
|
84 |
+
left, middle, right = st.columns(3)
|
85 |
+
with left:
|
86 |
+
best_by_mae = st.button("Best by MAE", use_container_width=True)
|
87 |
+
if best_by_mae:
|
88 |
+
best_models_by_mae = get_model_ranks(data, "MAE.mean").head(10).model.tolist()
|
89 |
+
for model in models:
|
90 |
+
if model in best_models_by_mae:
|
91 |
+
st.session_state[model] = True
|
92 |
+
else:
|
93 |
+
st.session_state[model] = False
|
94 |
+
with middle:
|
95 |
+
best_by_rmse = st.button("Best by RMSE", use_container_width=True)
|
96 |
+
if best_by_rmse:
|
97 |
+
best_models_by_rmse = get_model_ranks(data, "RMSE.mean").head(10).model.tolist()
|
98 |
+
for model in models:
|
99 |
+
if model in best_models_by_rmse:
|
100 |
+
st.session_state[model] = True
|
101 |
+
else:
|
102 |
+
st.session_state[model] = False
|
103 |
+
with right:
|
104 |
+
best_by_rmae = st.button("Best by rMAE", use_container_width=True)
|
105 |
+
if best_by_rmae:
|
106 |
+
best_models_by_rmae = get_model_ranks(data, "rMAE.mean").head(10).model.tolist()
|
107 |
+
for model in models:
|
108 |
+
if model in best_models_by_rmae:
|
109 |
+
st.session_state[model] = True
|
110 |
+
else:
|
111 |
+
st.session_state[model] = False
|
112 |
+
|
113 |
left, right = st.columns(2)
|
114 |
with left:
|
115 |
select_none = st.button("Select None", use_container_width=True)
|
|
|
133 |
return models_to_plot
|
134 |
|
135 |
|
136 |
+
def overview_view(data: pd.DataFrame):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
st.markdown("## Leaderboard")
|
138 |
|
139 |
leaderboard = get_leaderboard(data, ["MAE.mean", "RMSE.mean", "rMAE.mean"])
|
|
|
147 |
)
|
148 |
fig = px.bar(best_models_mae, x="MAE.mean", y=best_models_mae.index)
|
149 |
fig.update_layout(
|
150 |
+
title="Top 10 models by MAE",
|
151 |
+
xaxis_title="",
|
152 |
+
yaxis_title="Model",
|
153 |
+
height=600,
|
154 |
)
|
155 |
st.plotly_chart(fig, use_container_width=True)
|
156 |
|
|
|
161 |
.sort_values("RMSE.mean")
|
162 |
)
|
163 |
fig = px.bar(best_models_mae, x="RMSE.mean", y=best_models_mae.index)
|
164 |
+
fig.update_layout(
|
165 |
+
title="Top 10 models by RMSE", xaxis_title="", yaxis_title="", height=600
|
166 |
+
)
|
167 |
st.plotly_chart(fig, use_container_width=True)
|
168 |
|
169 |
with right:
|
|
|
173 |
.sort_values("rMAE.mean")
|
174 |
)
|
175 |
fig = px.bar(best_models_mae, x="rMAE.mean", y=best_models_mae.index)
|
176 |
+
fig.update_layout(
|
177 |
+
title="Top 10 models by rMAE", xaxis_title="", yaxis_title="", height=600
|
178 |
+
)
|
179 |
st.plotly_chart(fig, use_container_width=True)
|
180 |
|
181 |
st.dataframe(leaderboard, use_container_width=True)
|
182 |
|
183 |
|
184 |
+
def buildings_view(data: pd.DataFrame):
|
185 |
+
if 'metadata.cluster_size' not in data.columns:
|
186 |
+
data['metadata.cluster_size'] = 1
|
187 |
+
if 'metadata.building_class' not in data.columns:
|
188 |
+
data['metadata.building_class'] = "Unknown"
|
189 |
+
|
190 |
buildings = (
|
191 |
data[
|
192 |
[
|
|
|
196 |
"metadata.location_id",
|
197 |
"metadata.timezone",
|
198 |
"dataset.available_history.days",
|
199 |
+
"dataset.available_history.observations",
|
200 |
+
"metadata.freq",
|
201 |
]
|
202 |
]
|
203 |
.groupby("unique_id")
|
|
|
209 |
"metadata.location_id": "Location ID",
|
210 |
"metadata.timezone": "Timezone",
|
211 |
"dataset.available_history.days": "Available history (days)",
|
212 |
+
"dataset.available_history.observations": "Available history (#)",
|
213 |
+
"metadata.freq": "Frequency",
|
214 |
}
|
215 |
)
|
216 |
)
|
217 |
|
218 |
+
left, middle, right = st.columns(3)
|
219 |
+
with left:
|
220 |
+
st.metric("Number of buildings", data["unique_id"].nunique())
|
221 |
+
with middle:
|
222 |
+
st.metric(
|
223 |
+
"Residential",
|
224 |
+
data[data["metadata.building_class"] == "Residential"][
|
225 |
+
"unique_id"
|
226 |
+
].nunique(),
|
227 |
+
)
|
228 |
+
with right:
|
229 |
+
st.metric(
|
230 |
+
"Commercial",
|
231 |
+
data[data["metadata.building_class"] == "Commercial"][
|
232 |
+
"unique_id"
|
233 |
+
].nunique(),
|
234 |
+
)
|
235 |
st.divider()
|
236 |
|
237 |
+
left, middle, right = st.columns(3, gap="large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
with left:
|
239 |
st.markdown("#### Building classes")
|
240 |
fig = px.pie(
|
|
|
242 |
values=0,
|
243 |
names="Building class",
|
244 |
)
|
245 |
+
fig.update_layout(
|
246 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
|
247 |
+
)
|
248 |
st.plotly_chart(fig, use_container_width=True)
|
249 |
|
250 |
+
with middle:
|
251 |
st.markdown("#### Timezones")
|
252 |
fig = px.pie(
|
253 |
buildings.groupby("Timezone").size().reset_index(),
|
254 |
values=0,
|
255 |
names="Timezone",
|
256 |
)
|
257 |
+
fig.update_layout(
|
258 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
|
259 |
+
)
|
260 |
st.plotly_chart(fig, use_container_width=True)
|
261 |
|
262 |
+
with right:
|
263 |
+
st.markdown("#### Frequencies")
|
264 |
+
fig = px.pie(
|
265 |
+
buildings.groupby("Frequency").size().reset_index(),
|
266 |
+
values=0,
|
267 |
+
names="Frequency",
|
268 |
+
)
|
269 |
+
fig.update_layout(
|
270 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
|
271 |
+
)
|
272 |
+
st.plotly_chart(fig, use_container_width=True)
|
273 |
|
274 |
+
st.divider()
|
275 |
+
|
276 |
+
st.markdown("#### Buildings")
|
277 |
+
st.dataframe(
|
278 |
+
buildings.sort_values("Available history (days)"),
|
279 |
+
use_container_width=True,
|
280 |
+
column_config={
|
281 |
+
"Available history (days)": st.column_config.ProgressColumn(
|
282 |
+
"Available history (days)",
|
283 |
+
help="Available training data during the first prediction.",
|
284 |
+
format="%f",
|
285 |
+
min_value=0,
|
286 |
+
max_value=float(buildings["Available history (days)"].max()),
|
287 |
+
),
|
288 |
+
"Available history (#)": st.column_config.ProgressColumn(
|
289 |
+
"Available history (#)",
|
290 |
+
help="Available training data during the first prediction.",
|
291 |
+
format="%f",
|
292 |
+
min_value=0,
|
293 |
+
max_value=float(buildings["Available history (#)"].max()),
|
294 |
+
),
|
295 |
+
},
|
296 |
+
)
|
297 |
+
|
298 |
+
|
299 |
+
def models_view(data: pd.DataFrame):
|
300 |
models = (
|
301 |
data[
|
302 |
[
|
|
|
325 |
)
|
326 |
)
|
327 |
|
328 |
+
left, middle, right = st.columns(3)
|
329 |
+
with left:
|
330 |
+
st.metric("Models", len(models))
|
331 |
+
with middle:
|
332 |
+
st.metric(
|
333 |
+
"Univariate",
|
334 |
+
data[data["model_info.variate_type"] == "univariate"]["model"].nunique(),
|
335 |
+
)
|
336 |
+
with right:
|
337 |
+
st.metric(
|
338 |
+
"Univariate",
|
339 |
+
data[data["model_info.variate_type"] == "multivariate"]["model"].nunique(),
|
340 |
+
)
|
341 |
st.divider()
|
342 |
|
|
|
|
|
|
|
343 |
left, right = st.columns(2, gap="large")
|
344 |
with left:
|
345 |
st.markdown("#### Variate types")
|
|
|
361 |
)
|
362 |
st.plotly_chart(fig, use_container_width=True)
|
363 |
|
364 |
+
st.divider()
|
365 |
+
st.markdown("### Models")
|
366 |
+
st.dataframe(models, use_container_width=True)
|
367 |
+
|
368 |
|
369 |
+
def accuracy_view(data: pd.DataFrame, models_to_plot: set[str]):
|
370 |
data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
|
371 |
by="model", ascending=True
|
372 |
)
|
|
|
380 |
)
|
381 |
st.markdown(f"#### {aggregation.capitalize()} {metric} per building")
|
382 |
|
383 |
+
if data_to_plot.empty:
|
384 |
+
st.warning("No data to display.")
|
385 |
+
else:
|
386 |
+
model_ranks = get_model_ranks(data_to_plot, f"{metric}.{aggregation}")
|
387 |
+
|
388 |
+
fig = px.box(
|
389 |
+
data_to_plot.merge(model_ranks, on="model").sort_values(by="rank"),
|
390 |
+
x=f"{metric}.{aggregation}",
|
391 |
+
y="model",
|
392 |
+
color="model",
|
393 |
+
points="all",
|
394 |
+
)
|
395 |
+
fig.update_layout(showlegend=False, height=50 * len(models_to_plot))
|
396 |
+
st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
|
|
|
397 |
|
398 |
st.divider()
|
399 |
|
|
|
422 |
st.markdown(
|
423 |
f"#### {x_aggregation.capitalize()} {x_metric} vs {y_aggregation.capitalize()} {y_metric}"
|
424 |
)
|
425 |
+
if data_to_plot.empty:
|
426 |
+
st.warning("No data to display.")
|
427 |
+
else:
|
428 |
+
fig = px.scatter(
|
429 |
+
data_to_plot,
|
430 |
+
x=f"{x_metric}.{x_aggregation}",
|
431 |
+
y=f"{y_metric}.{y_aggregation}",
|
432 |
+
color="model",
|
433 |
+
)
|
434 |
+
fig.update_layout(height=600)
|
435 |
+
st.plotly_chart(fig, use_container_width=True)
|
436 |
|
437 |
st.divider()
|
438 |
|
|
|
449 |
key="table_aggregation",
|
450 |
)
|
451 |
|
452 |
+
metrics_table = data_to_plot.groupby(["model"]).agg(aggregation, numeric_only=True)[
|
|
|
|
|
453 |
[
|
454 |
f"{metric}.min",
|
455 |
f"{metric}.mean",
|
|
|
457 |
f"{metric}.max",
|
458 |
f"{metric}.std",
|
459 |
]
|
460 |
+
].sort_values(by=f"{metric}.mean")
|
461 |
|
462 |
def custom_table(styler):
|
463 |
styler.background_gradient(cmap="seismic", axis=0)
|
|
|
477 |
.reset_index()
|
478 |
.pivot(index="model", columns="unique_id", values=f"{metric}.{aggregation}")
|
479 |
)
|
|
|
|
|
|
|
480 |
metrics_per_building_table.insert(
|
481 |
0, "mean", metrics_per_building_table.mean(axis=1)
|
482 |
)
|
483 |
+
metrics_per_building_table = metrics_per_building_table.sort_values(by="mean").drop(columns="mean")
|
484 |
|
485 |
+
def custom_table(styler: Styler):
|
486 |
styler.background_gradient(cmap="seismic", axis=None)
|
487 |
styler.format(precision=2)
|
488 |
|
|
|
495 |
st.dataframe(styled_table, use_container_width=True)
|
496 |
|
497 |
|
498 |
+
def relative_performance_view(data: pd.DataFrame, models_to_plot: set[str]):
|
499 |
data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
|
500 |
by="model", ascending=True
|
501 |
)
|
502 |
|
503 |
+
st.markdown("#### Relative performance")
|
504 |
+
if data_to_plot.empty:
|
505 |
+
st.warning("No data to display.")
|
506 |
+
else:
|
507 |
+
baseline_choices = sorted(
|
508 |
+
data.filter(like="better_than")
|
509 |
+
.columns.str.removeprefix("better_than.")
|
510 |
+
.tolist()
|
511 |
+
)
|
512 |
+
if len(baseline_choices) > 1:
|
513 |
+
better_than_baseline = st.selectbox("Baseline model", options=baseline_choices)
|
514 |
+
else:
|
515 |
+
better_than_baseline = baseline_choices[0]
|
516 |
+
data_to_plot.loc[:, f"better_than.{better_than_baseline}.percentage"] = (
|
517 |
+
pd.json_normalize(data_to_plot[f"better_than.{better_than_baseline}"])[
|
518 |
+
"percentage"
|
519 |
+
].values
|
520 |
+
* 100
|
521 |
+
)
|
522 |
+
model_rank = get_model_ranks(data_to_plot, f"better_than.{better_than_baseline}.percentage")
|
523 |
+
|
524 |
+
fig = px.box(
|
525 |
+
data_to_plot.merge(model_rank).sort_values(by="rank"),
|
526 |
+
x=f"better_than.{better_than_baseline}.percentage",
|
527 |
+
y="model",
|
528 |
+
points="all",
|
529 |
+
)
|
530 |
+
fig.update_xaxes(range=[0, 100], title_text="Better than baseline (%)")
|
531 |
+
fig.update_layout(
|
532 |
+
showlegend=False,
|
533 |
+
height=50 * len(models_to_plot),
|
534 |
+
title=f"Better than {better_than_baseline} on % of days per building",
|
535 |
+
)
|
536 |
+
st.plotly_chart(fig, use_container_width=True)
|
537 |
+
|
538 |
+
|
539 |
+
def computation_view(data: pd.DataFrame, models_to_plot: set[str]):
|
540 |
+
data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
|
541 |
+
by="model", ascending=True
|
542 |
)
|
543 |
+
data_to_plot["resource_usage.CPU"] /= 3600
|
544 |
|
545 |
+
st.markdown("#### Computational Resources")
|
546 |
|
547 |
left, center, right = st.columns(3, gap="small")
|
548 |
with left:
|
|
|
559 |
st.markdown(
|
560 |
f"#### {aggregation_per_model.capitalize()} {aggregation_per_building.capitalize()} {metric} vs CPU usage"
|
561 |
)
|
562 |
+
if data_to_plot.empty:
|
563 |
+
st.warning("No data to display.")
|
564 |
+
else:
|
565 |
+
aggregated_data = (
|
566 |
+
data_to_plot.groupby("model")
|
567 |
+
.agg(aggregation_per_building, numeric_only=True)
|
568 |
+
.reset_index()
|
569 |
+
)
|
570 |
+
fig = px.scatter(
|
571 |
+
aggregated_data,
|
572 |
+
x="resource_usage.CPU",
|
573 |
+
y=f"{metric}.{aggregation_per_model}",
|
574 |
+
color="model",
|
575 |
+
log_x=True,
|
576 |
+
)
|
577 |
+
fig.update_layout(height=600)
|
578 |
+
fig.update_xaxes(title_text="CPU usage (hours)")
|
579 |
+
fig.update_yaxes(
|
580 |
+
title_text=f"{metric} ({aggregation_per_building}, {aggregation_per_model})"
|
581 |
+
)
|
582 |
+
st.plotly_chart(fig, use_container_width=True)
|
583 |
+
|
584 |
+
st.divider()
|
585 |
+
|
586 |
+
st.markdown("#### Computational time vs historical data")
|
587 |
+
if data_to_plot.empty:
|
588 |
+
st.warning("No data to display.")
|
589 |
+
else:
|
590 |
+
fig = px.scatter(
|
591 |
+
data_to_plot,
|
592 |
+
x="dataset.available_history.observations",
|
593 |
+
y="resource_usage.CPU",
|
594 |
+
color="model",
|
595 |
+
trendline="ols",
|
596 |
+
hover_data=["model", "unique_id"],
|
597 |
+
)
|
598 |
+
fig.update_layout(height=600)
|
599 |
+
fig.update_xaxes(title_text="Available historical observations (#)")
|
600 |
+
fig.update_yaxes(title_text="CPU usage (hours)")
|
601 |
+
st.plotly_chart(fig, use_container_width=True)
|
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
wandb==0.17.0
|
2 |
-
plotly==5.20.0
|
|
|
|
1 |
wandb==0.17.0
|
2 |
+
plotly==5.20.0
|
3 |
+
statsmodels==0.14.2
|
utils.py
CHANGED
@@ -2,7 +2,9 @@ import pandas as pd
|
|
2 |
import wandb
|
3 |
|
4 |
|
5 |
-
def get_wandb_data(
|
|
|
|
|
6 |
api = wandb.Api(api_key=api_key)
|
7 |
|
8 |
# Project is specified by <entity/project-name>
|
@@ -17,7 +19,7 @@ def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd
|
|
17 |
|
18 |
# .config contains the hyperparameters.
|
19 |
# We remove special values that start with _.
|
20 |
-
config_list.append(
|
21 |
|
22 |
# .name is the human-readable name of the run.
|
23 |
name_list.append(run.name)
|
@@ -30,10 +32,9 @@ def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd
|
|
30 |
|
31 |
|
32 |
def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
|
33 |
-
leaderboard = pd.DataFrame(
|
34 |
-
|
35 |
-
|
36 |
-
).fillna(0)
|
37 |
|
38 |
for _, building_df in runs_df.groupby("unique_id"):
|
39 |
for column in leaderboard.columns:
|
@@ -42,3 +43,14 @@ def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
|
|
42 |
|
43 |
leaderboard = leaderboard.sort_values(by=list(leaderboard.columns), ascending=False)
|
44 |
return leaderboard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import wandb
|
3 |
|
4 |
|
5 |
+
def get_wandb_data(
|
6 |
+
entity: str, project: str, api_key: str, job_type: str
|
7 |
+
) -> pd.DataFrame:
|
8 |
api = wandb.Api(api_key=api_key)
|
9 |
|
10 |
# Project is specified by <entity/project-name>
|
|
|
19 |
|
20 |
# .config contains the hyperparameters.
|
21 |
# We remove special values that start with _.
|
22 |
+
config_list.append(run.config)
|
23 |
|
24 |
# .name is the human-readable name of the run.
|
25 |
name_list.append(run.name)
|
|
|
32 |
|
33 |
|
34 |
def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
|
35 |
+
leaderboard = pd.DataFrame(index=runs_df["model"].unique(), columns=metrics).fillna(
|
36 |
+
0
|
37 |
+
)
|
|
|
38 |
|
39 |
for _, building_df in runs_df.groupby("unique_id"):
|
40 |
for column in leaderboard.columns:
|
|
|
43 |
|
44 |
leaderboard = leaderboard.sort_values(by=list(leaderboard.columns), ascending=False)
|
45 |
return leaderboard
|
46 |
+
|
47 |
+
|
48 |
+
def get_model_ranks(runs_df: pd.DataFrame, metric: str) -> pd.DataFrame:
|
49 |
+
return (
|
50 |
+
runs_df.groupby(["model"])
|
51 |
+
.median(numeric_only=True)
|
52 |
+
.sort_values(by=metric)
|
53 |
+
.reset_index()
|
54 |
+
.rename_axis("rank")
|
55 |
+
.reset_index()[["rank", "model"]]
|
56 |
+
)
|