attila-balint-kul committed on
Commit
a6cfc29
·
verified ·
1 Parent(s): 9c958d6

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +20 -21
  2. components.py +297 -111
  3. requirements.txt +2 -1
  4. utils.py +18 -6
app.py CHANGED
@@ -1,34 +1,39 @@
1
  import streamlit as st
2
 
 
3
  from components import (
4
  buildings_view,
5
- models_view,
6
- performance_view,
7
  computation_view,
 
8
  logos,
9
  model_selector,
10
- header,
11
  overview_view,
 
 
 
12
  )
13
- import utils
14
 
 
 
15
  PAGES = [
16
  "Overview",
17
  "Buildings",
18
  "Models",
19
- "Performance",
 
20
  "Computational Resources",
21
  ]
22
 
23
 
24
- st.set_page_config(page_title="Gas Demand Dashboard", layout="wide")
25
 
26
 
27
  @st.cache_data(ttl=86400)
28
  def fetch_data():
29
  return utils.get_wandb_data(
30
  entity=st.secrets["wandb_entity"],
31
- project="enfobench-gas-demand",
32
  api_key=st.secrets["wandb_api_key"],
33
  job_type="metrics",
34
  )
@@ -45,19 +50,11 @@ with st.sidebar:
45
  logos()
46
  view = st.selectbox("View", PAGES, index=0)
47
 
48
- if view == "Performance" or view == "Computational Resources":
49
- models_to_plot = model_selector(models)
50
 
51
  if view == "Overview":
52
- st.header("Sources")
53
- st.link_button("GitHub Repository", url="https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit", use_container_width=True)
54
- st.link_button("Documentation", url="https://attila-balint-kul.github.io/energy-forecast-benchmark-toolkit/", use_container_width=True)
55
- st.link_button("Electricity Demand Dataset", url="https://huggingface.co/datasets/EDS-lab/electricity-demand", use_container_width=True)
56
- st.link_button("HuggingFace Organization", url="https://huggingface.co/EDS-lab", use_container_width=True)
57
-
58
- st.header("Other Dashboards")
59
- st.link_button("Electricity Demand", url="https://huggingface.co/spaces/EDS-lab/EnFoBench-ElectricityDemand", use_container_width=True)
60
- st.link_button("PV Generation", url="https://huggingface.co/spaces/EDS-lab/EnFoBench-PVGeneration", use_container_width=True)
61
 
62
  st.header("Refresh data")
63
  refresh = st.button(
@@ -68,7 +65,7 @@ with st.sidebar:
68
  st.rerun()
69
 
70
 
71
- header()
72
 
73
  if view == "Overview":
74
  overview_view(data)
@@ -76,8 +73,10 @@ elif view == "Buildings":
76
  buildings_view(data)
77
  elif view == "Models":
78
  models_view(data)
79
- elif view == "Performance":
80
- performance_view(data, models_to_plot)
 
 
81
  elif view == "Computational Resources":
82
  computation_view(data, models_to_plot)
83
  else:
 
1
  import streamlit as st
2
 
3
+ import utils
4
  from components import (
5
  buildings_view,
 
 
6
  computation_view,
7
+ header,
8
  logos,
9
  model_selector,
10
+ models_view,
11
  overview_view,
12
+ accuracy_view,
13
+ relative_performance_view,
14
+ links,
15
  )
 
16
 
17
+
18
+ USE_CASE = st.secrets["enfobench_usecase"]
19
  PAGES = [
20
  "Overview",
21
  "Buildings",
22
  "Models",
23
+ "Accuracy",
24
+ "Relative Performance",
25
  "Computational Resources",
26
  ]
27
 
28
 
29
+ st.set_page_config(page_title=f"{USE_CASE} Dashboard", layout="wide")
30
 
31
 
32
@st.cache_data(ttl=86400)
def fetch_data():
    """Fetch the benchmark runs of job type ``"metrics"`` from Weights & Biases.

    The W&B entity, project and API key are read from ``st.secrets``; the
    result of ``utils.get_wandb_data`` is returned unchanged.  The call is
    wrapped in ``st.cache_data`` with ``ttl=86400`` so repeated script reruns
    reuse the previously downloaded result until the entry expires.
    """
    secrets = st.secrets
    runs = utils.get_wandb_data(
        entity=secrets["wandb_entity"],
        project=secrets["wandb_project"],
        api_key=secrets["wandb_api_key"],
        job_type="metrics",
    )
    return runs
 
50
  logos()
51
  view = st.selectbox("View", PAGES, index=0)
52
 
53
+ if view in ["Accuracy", "Relative Performance", "Computational Resources"]:
54
+ models_to_plot = model_selector(models, data)
55
 
56
  if view == "Overview":
57
+ links(current=USE_CASE)
 
 
 
 
 
 
 
 
58
 
59
  st.header("Refresh data")
60
  refresh = st.button(
 
65
  st.rerun()
66
 
67
 
68
+ header(f"EnFoBench - {USE_CASE}")
69
 
70
  if view == "Overview":
71
  overview_view(data)
 
73
  buildings_view(data)
74
  elif view == "Models":
75
  models_view(data)
76
+ elif view == "Accuracy":
77
+ accuracy_view(data, models_to_plot)
78
+ elif view == "Relative Performance":
79
+ relative_performance_view(data, models_to_plot)
80
  elif view == "Computational Resources":
81
  computation_view(data, models_to_plot)
82
  else:
components.py CHANGED
@@ -1,12 +1,19 @@
1
  import pandas as pd
2
- import streamlit as st
3
  import plotly.express as px
 
 
4
 
5
- from utils import get_leaderboard
6
 
7
 
8
- def header() -> None:
9
- st.title("EnFoBench - Gas Demand")
 
 
 
 
 
 
10
  st.divider()
11
 
12
 
@@ -18,7 +25,51 @@ def logos() -> None:
18
  st.image("./images/energyville_logo.png")
19
 
20
 
21
- def model_selector(models: list[str]) -> set[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # Group models by their prefix
23
  model_groups: dict[str, list[str]] = {}
24
  for model in models:
@@ -30,6 +81,35 @@ def model_selector(models: list[str]) -> set[str]:
30
  models_to_plot = set()
31
 
32
  st.header("Models to include")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  left, right = st.columns(2)
34
  with left:
35
  select_none = st.button("Select None", use_container_width=True)
@@ -53,18 +133,7 @@ def model_selector(models: list[str]) -> set[str]:
53
  return models_to_plot
54
 
55
 
56
- def overview_view(data):
57
- st.markdown(
58
- """
59
- [EnFoBench](https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit)
60
- is a community driven benchmarking framework for energy forecasting models.
61
-
62
- This dashboard presents the results of the gas demand forecasting usecase. All models were cross-validated
63
- on **365 days** of day ahead forecasting horizon *(10AM until midnight of the next day)*.
64
- """
65
- )
66
-
67
- st.divider()
68
  st.markdown("## Leaderboard")
69
 
70
  leaderboard = get_leaderboard(data, ["MAE.mean", "RMSE.mean", "rMAE.mean"])
@@ -78,7 +147,10 @@ def overview_view(data):
78
  )
79
  fig = px.bar(best_models_mae, x="MAE.mean", y=best_models_mae.index)
80
  fig.update_layout(
81
- title="Top 10 models by MAE", xaxis_title="", yaxis_title="Model"
 
 
 
82
  )
83
  st.plotly_chart(fig, use_container_width=True)
84
 
@@ -89,7 +161,9 @@ def overview_view(data):
89
  .sort_values("RMSE.mean")
90
  )
91
  fig = px.bar(best_models_mae, x="RMSE.mean", y=best_models_mae.index)
92
- fig.update_layout(title="Top 10 models by RMSE", xaxis_title="", yaxis_title="")
 
 
93
  st.plotly_chart(fig, use_container_width=True)
94
 
95
  with right:
@@ -99,13 +173,20 @@ def overview_view(data):
99
  .sort_values("rMAE.mean")
100
  )
101
  fig = px.bar(best_models_mae, x="rMAE.mean", y=best_models_mae.index)
102
- fig.update_layout(title="Top 10 models by rMAE", xaxis_title="", yaxis_title="")
 
 
103
  st.plotly_chart(fig, use_container_width=True)
104
 
105
  st.dataframe(leaderboard, use_container_width=True)
106
 
107
 
108
- def buildings_view(data):
 
 
 
 
 
109
  buildings = (
110
  data[
111
  [
@@ -115,6 +196,8 @@ def buildings_view(data):
115
  "metadata.location_id",
116
  "metadata.timezone",
117
  "dataset.available_history.days",
 
 
118
  ]
119
  ]
120
  .groupby("unique_id")
@@ -126,29 +209,32 @@ def buildings_view(data):
126
  "metadata.location_id": "Location ID",
127
  "metadata.timezone": "Timezone",
128
  "dataset.available_history.days": "Available history (days)",
 
 
129
  }
130
  )
131
  )
132
 
133
- st.metric("Number of buildings", len(buildings))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  st.divider()
135
 
136
- st.markdown("### Buildings")
137
- st.dataframe(
138
- buildings,
139
- use_container_width=True,
140
- column_config={
141
- "Available history (days)": st.column_config.ProgressColumn(
142
- "Available history (days)",
143
- help="Available training data during the first prediction.",
144
- format="%f",
145
- min_value=0,
146
- max_value=float(buildings["Available history (days)"].max()),
147
- ),
148
- },
149
- )
150
-
151
- left, right = st.columns(2, gap="large")
152
  with left:
153
  st.markdown("#### Building classes")
154
  fig = px.pie(
@@ -156,19 +242,61 @@ def buildings_view(data):
156
  values=0,
157
  names="Building class",
158
  )
 
 
 
159
  st.plotly_chart(fig, use_container_width=True)
160
 
161
- with right:
162
  st.markdown("#### Timezones")
163
  fig = px.pie(
164
  buildings.groupby("Timezone").size().reset_index(),
165
  values=0,
166
  names="Timezone",
167
  )
 
 
 
168
  st.plotly_chart(fig, use_container_width=True)
169
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- def models_view(data):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  models = (
173
  data[
174
  [
@@ -197,12 +325,21 @@ def models_view(data):
197
  )
198
  )
199
 
200
- st.metric("Number of models", len(models))
 
 
 
 
 
 
 
 
 
 
 
 
201
  st.divider()
202
 
203
- st.markdown("### Models")
204
- st.dataframe(models, use_container_width=True)
205
-
206
  left, right = st.columns(2, gap="large")
207
  with left:
208
  st.markdown("#### Variate types")
@@ -224,8 +361,12 @@ def models_view(data):
224
  )
225
  st.plotly_chart(fig, use_container_width=True)
226
 
 
 
 
 
227
 
228
- def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
229
  data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
230
  by="model", ascending=True
231
  )
@@ -239,24 +380,20 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
239
  )
240
  st.markdown(f"#### {aggregation.capitalize()} {metric} per building")
241
 
242
- rank_df = (
243
- data_to_plot.groupby(["model"])
244
- .agg("median", numeric_only=True)
245
- .sort_values(by=f"{metric}.{aggregation}")
246
- .reset_index()
247
- .rename_axis("rank")
248
- .reset_index()[["rank", "model"]]
249
- )
250
-
251
- fig = px.box(
252
- data_to_plot.merge(rank_df, on="model").sort_values(by="rank"),
253
- x=f"{metric}.{aggregation}",
254
- y="model",
255
- color="model",
256
- points="all",
257
- )
258
- fig.update_layout(showlegend=False, height=40 * len(models_to_plot))
259
- st.plotly_chart(fig, use_container_width=True)
260
 
261
  st.divider()
262
 
@@ -285,14 +422,17 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
285
  st.markdown(
286
  f"#### {x_aggregation.capitalize()} {x_metric} vs {y_aggregation.capitalize()} {y_metric}"
287
  )
288
- fig = px.scatter(
289
- data_to_plot,
290
- x=f"{x_metric}.{x_aggregation}",
291
- y=f"{y_metric}.{y_aggregation}",
292
- color="model",
293
- )
294
- fig.update_layout(height=600)
295
- st.plotly_chart(fig, use_container_width=True)
 
 
 
296
 
297
  st.divider()
298
 
@@ -309,9 +449,7 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
309
  key="table_aggregation",
310
  )
311
 
312
- metrics_table = data_to_plot.groupby(["model"]).agg(
313
- aggregation, numeric_only=True
314
- )[
315
  [
316
  f"{metric}.min",
317
  f"{metric}.mean",
@@ -319,7 +457,7 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
319
  f"{metric}.max",
320
  f"{metric}.std",
321
  ]
322
- ]
323
 
324
  def custom_table(styler):
325
  styler.background_gradient(cmap="seismic", axis=0)
@@ -339,15 +477,12 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
339
  .reset_index()
340
  .pivot(index="model", columns="unique_id", values=f"{metric}.{aggregation}")
341
  )
342
- metrics_per_building_table.insert(
343
- 0, "median", metrics_per_building_table.median(axis=1)
344
- )
345
  metrics_per_building_table.insert(
346
  0, "mean", metrics_per_building_table.mean(axis=1)
347
  )
348
- metrics_per_building_table = metrics_per_building_table.sort_values(by="mean")
349
 
350
- def custom_table(styler):
351
  styler.background_gradient(cmap="seismic", axis=None)
352
  styler.format(precision=2)
353
 
@@ -360,29 +495,54 @@ def performance_view(data: pd.DataFrame, models_to_plot: set[str]):
360
  st.dataframe(styled_table, use_container_width=True)
361
 
362
 
363
- def computation_view(data, models_to_plot: set[str]):
364
  data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
365
  by="model", ascending=True
366
  )
367
 
368
- st.markdown("#### Computational Resources")
369
- fig = px.parallel_coordinates(
370
- data_to_plot.groupby("model").mean(numeric_only=True).reset_index(),
371
- dimensions=[
372
- "model",
373
- "resource_usage.CPU",
374
- "resource_usage.memory",
375
- "MAE.mean",
376
- "RMSE.mean",
377
- "MBE.mean",
378
- "rMAE.mean",
379
- ],
380
- color="rMAE.mean",
381
- color_continuous_scale=px.colors.diverging.Portland,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  )
383
- st.plotly_chart(fig, use_container_width=True)
384
 
385
- st.divider()
386
 
387
  left, center, right = st.columns(3, gap="small")
388
  with left:
@@ -399,17 +559,43 @@ def computation_view(data, models_to_plot: set[str]):
399
  st.markdown(
400
  f"#### {aggregation_per_model.capitalize()} {aggregation_per_building.capitalize()} {metric} vs CPU usage"
401
  )
402
- aggregated_data = (
403
- data_to_plot.groupby("model")
404
- .agg(aggregation_per_building, numeric_only=True)
405
- .reset_index()
406
- )
407
- fig = px.scatter(
408
- aggregated_data,
409
- x="resource_usage.CPU",
410
- y=f"{metric}.{aggregation_per_model}",
411
- color="model",
412
- log_x=True,
413
- )
414
- fig.update_layout(height=600)
415
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
 
2
  import plotly.express as px
3
+ import streamlit as st
4
+ from pandas.io.formats.style import Styler
5
 
6
+ from utils import get_leaderboard, get_model_ranks
7
 
8
 
9
def header(title: str) -> None:
    """Render the page heading: the title, a short EnFoBench blurb, a divider.

    Args:
        title: Text shown as the page title.
    """
    # Markdown blurb linking to the benchmark toolkit repository.
    blurb = """
        [EnFoBench](https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit)
        is a community driven benchmarking framework for energy forecasting models.
        """
    st.title(title)
    st.markdown(blurb)
    st.divider()
18
 
19
 
 
25
  st.image("./images/energyville_logo.png")
26
 
27
 
28
def links(current: str) -> None:
    """Render the "Sources" and "Other Dashboards" link buttons.

    Args:
        current: Use-case identifier of the dashboard being shown (one of
            ``"ElectricityDemand"``, ``"GasDemand"``, ``"PVGeneration"``).
            The button pointing at this dashboard is omitted; any other
            value results in all three dashboard buttons being shown.
    """
    # (button label, target URL) pairs shown for every use case.
    sources = (
        (
            "GitHub Repository",
            "https://github.com/attila-balint-kul/energy-forecast-benchmark-toolkit",
        ),
        (
            "Documentation",
            "https://attila-balint-kul.github.io/energy-forecast-benchmark-toolkit/",
        ),
        (
            "Electricity Demand Dataset",
            "https://huggingface.co/datasets/EDS-lab/electricity-demand",
        ),
        (
            "HuggingFace Organization",
            "https://huggingface.co/EDS-lab",
        ),
    )
    # (use-case identifier, button label); the identifier is also the suffix
    # of the HuggingFace space name.
    dashboards = (
        ("ElectricityDemand", "Electricity Demand"),
        ("GasDemand", "Gas Demand"),
        ("PVGeneration", "PV Generation"),
    )

    st.header("Sources")
    for label, url in sources:
        st.link_button(label, url=url, use_container_width=True)

    st.header("Other Dashboards")
    for use_case, label in dashboards:
        if use_case == current:
            # Do not link a dashboard to itself.
            continue
        st.link_button(
            label,
            url=f"https://huggingface.co/spaces/EDS-lab/EnFoBench-{use_case}",
            use_container_width=True,
        )
70
+
71
+
72
+ def model_selector(models: list[str], data: pd.DataFrame) -> set[str]:
73
  # Group models by their prefix
74
  model_groups: dict[str, list[str]] = {}
75
  for model in models:
 
81
  models_to_plot = set()
82
 
83
  st.header("Models to include")
84
+ left, middle, right = st.columns(3)
85
+ with left:
86
+ best_by_mae = st.button("Best by MAE", use_container_width=True)
87
+ if best_by_mae:
88
+ best_models_by_mae = get_model_ranks(data, "MAE.mean").head(10).model.tolist()
89
+ for model in models:
90
+ if model in best_models_by_mae:
91
+ st.session_state[model] = True
92
+ else:
93
+ st.session_state[model] = False
94
+ with middle:
95
+ best_by_rmse = st.button("Best by RMSE", use_container_width=True)
96
+ if best_by_rmse:
97
+ best_models_by_rmse = get_model_ranks(data, "RMSE.mean").head(10).model.tolist()
98
+ for model in models:
99
+ if model in best_models_by_rmse:
100
+ st.session_state[model] = True
101
+ else:
102
+ st.session_state[model] = False
103
+ with right:
104
+ best_by_rmae = st.button("Best by rMAE", use_container_width=True)
105
+ if best_by_rmae:
106
+ best_models_by_rmae = get_model_ranks(data, "rMAE.mean").head(10).model.tolist()
107
+ for model in models:
108
+ if model in best_models_by_rmae:
109
+ st.session_state[model] = True
110
+ else:
111
+ st.session_state[model] = False
112
+
113
  left, right = st.columns(2)
114
  with left:
115
  select_none = st.button("Select None", use_container_width=True)
 
133
  return models_to_plot
134
 
135
 
136
+ def overview_view(data: pd.DataFrame):
 
 
 
 
 
 
 
 
 
 
 
137
  st.markdown("## Leaderboard")
138
 
139
  leaderboard = get_leaderboard(data, ["MAE.mean", "RMSE.mean", "rMAE.mean"])
 
147
  )
148
  fig = px.bar(best_models_mae, x="MAE.mean", y=best_models_mae.index)
149
  fig.update_layout(
150
+ title="Top 10 models by MAE",
151
+ xaxis_title="",
152
+ yaxis_title="Model",
153
+ height=600,
154
  )
155
  st.plotly_chart(fig, use_container_width=True)
156
 
 
161
  .sort_values("RMSE.mean")
162
  )
163
  fig = px.bar(best_models_mae, x="RMSE.mean", y=best_models_mae.index)
164
+ fig.update_layout(
165
+ title="Top 10 models by RMSE", xaxis_title="", yaxis_title="", height=600
166
+ )
167
  st.plotly_chart(fig, use_container_width=True)
168
 
169
  with right:
 
173
  .sort_values("rMAE.mean")
174
  )
175
  fig = px.bar(best_models_mae, x="rMAE.mean", y=best_models_mae.index)
176
+ fig.update_layout(
177
+ title="Top 10 models by rMAE", xaxis_title="", yaxis_title="", height=600
178
+ )
179
  st.plotly_chart(fig, use_container_width=True)
180
 
181
  st.dataframe(leaderboard, use_container_width=True)
182
 
183
 
184
+ def buildings_view(data: pd.DataFrame):
185
+ if 'metadata.cluster_size' not in data.columns:
186
+ data['metadata.cluster_size'] = 1
187
+ if 'metadata.building_class' not in data.columns:
188
+ data['metadata.building_class'] = "Unknown"
189
+
190
  buildings = (
191
  data[
192
  [
 
196
  "metadata.location_id",
197
  "metadata.timezone",
198
  "dataset.available_history.days",
199
+ "dataset.available_history.observations",
200
+ "metadata.freq",
201
  ]
202
  ]
203
  .groupby("unique_id")
 
209
  "metadata.location_id": "Location ID",
210
  "metadata.timezone": "Timezone",
211
  "dataset.available_history.days": "Available history (days)",
212
+ "dataset.available_history.observations": "Available history (#)",
213
+ "metadata.freq": "Frequency",
214
  }
215
  )
216
  )
217
 
218
+ left, middle, right = st.columns(3)
219
+ with left:
220
+ st.metric("Number of buildings", data["unique_id"].nunique())
221
+ with middle:
222
+ st.metric(
223
+ "Residential",
224
+ data[data["metadata.building_class"] == "Residential"][
225
+ "unique_id"
226
+ ].nunique(),
227
+ )
228
+ with right:
229
+ st.metric(
230
+ "Commercial",
231
+ data[data["metadata.building_class"] == "Commercial"][
232
+ "unique_id"
233
+ ].nunique(),
234
+ )
235
  st.divider()
236
 
237
+ left, middle, right = st.columns(3, gap="large")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  with left:
239
  st.markdown("#### Building classes")
240
  fig = px.pie(
 
242
  values=0,
243
  names="Building class",
244
  )
245
+ fig.update_layout(
246
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
247
+ )
248
  st.plotly_chart(fig, use_container_width=True)
249
 
250
+ with middle:
251
  st.markdown("#### Timezones")
252
  fig = px.pie(
253
  buildings.groupby("Timezone").size().reset_index(),
254
  values=0,
255
  names="Timezone",
256
  )
257
+ fig.update_layout(
258
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
259
+ )
260
  st.plotly_chart(fig, use_container_width=True)
261
 
262
+ with right:
263
+ st.markdown("#### Frequencies")
264
+ fig = px.pie(
265
+ buildings.groupby("Frequency").size().reset_index(),
266
+ values=0,
267
+ names="Frequency",
268
+ )
269
+ fig.update_layout(
270
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
271
+ )
272
+ st.plotly_chart(fig, use_container_width=True)
273
 
274
+ st.divider()
275
+
276
+ st.markdown("#### Buildings")
277
+ st.dataframe(
278
+ buildings.sort_values("Available history (days)"),
279
+ use_container_width=True,
280
+ column_config={
281
+ "Available history (days)": st.column_config.ProgressColumn(
282
+ "Available history (days)",
283
+ help="Available training data during the first prediction.",
284
+ format="%f",
285
+ min_value=0,
286
+ max_value=float(buildings["Available history (days)"].max()),
287
+ ),
288
+ "Available history (#)": st.column_config.ProgressColumn(
289
+ "Available history (#)",
290
+ help="Available training data during the first prediction.",
291
+ format="%f",
292
+ min_value=0,
293
+ max_value=float(buildings["Available history (#)"].max()),
294
+ ),
295
+ },
296
+ )
297
+
298
+
299
+ def models_view(data: pd.DataFrame):
300
  models = (
301
  data[
302
  [
 
325
  )
326
  )
327
 
328
+ left, middle, right = st.columns(3)
329
+ with left:
330
+ st.metric("Models", len(models))
331
+ with middle:
332
+ st.metric(
333
+ "Univariate",
334
+ data[data["model_info.variate_type"] == "univariate"]["model"].nunique(),
335
+ )
336
+ with right:
337
+ st.metric(
338
+ "Multivariate",
339
+ data[data["model_info.variate_type"] == "multivariate"]["model"].nunique(),
340
+ )
341
  st.divider()
342
 
 
 
 
343
  left, right = st.columns(2, gap="large")
344
  with left:
345
  st.markdown("#### Variate types")
 
361
  )
362
  st.plotly_chart(fig, use_container_width=True)
363
 
364
+ st.divider()
365
+ st.markdown("### Models")
366
+ st.dataframe(models, use_container_width=True)
367
+
368
 
369
+ def accuracy_view(data: pd.DataFrame, models_to_plot: set[str]):
370
  data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
371
  by="model", ascending=True
372
  )
 
380
  )
381
  st.markdown(f"#### {aggregation.capitalize()} {metric} per building")
382
 
383
+ if data_to_plot.empty:
384
+ st.warning("No data to display.")
385
+ else:
386
+ model_ranks = get_model_ranks(data_to_plot, f"{metric}.{aggregation}")
387
+
388
+ fig = px.box(
389
+ data_to_plot.merge(model_ranks, on="model").sort_values(by="rank"),
390
+ x=f"{metric}.{aggregation}",
391
+ y="model",
392
+ color="model",
393
+ points="all",
394
+ )
395
+ fig.update_layout(showlegend=False, height=50 * len(models_to_plot))
396
+ st.plotly_chart(fig, use_container_width=True)
 
 
 
 
397
 
398
  st.divider()
399
 
 
422
  st.markdown(
423
  f"#### {x_aggregation.capitalize()} {x_metric} vs {y_aggregation.capitalize()} {y_metric}"
424
  )
425
+ if data_to_plot.empty:
426
+ st.warning("No data to display.")
427
+ else:
428
+ fig = px.scatter(
429
+ data_to_plot,
430
+ x=f"{x_metric}.{x_aggregation}",
431
+ y=f"{y_metric}.{y_aggregation}",
432
+ color="model",
433
+ )
434
+ fig.update_layout(height=600)
435
+ st.plotly_chart(fig, use_container_width=True)
436
 
437
  st.divider()
438
 
 
449
  key="table_aggregation",
450
  )
451
 
452
+ metrics_table = data_to_plot.groupby(["model"]).agg(aggregation, numeric_only=True)[
 
 
453
  [
454
  f"{metric}.min",
455
  f"{metric}.mean",
 
457
  f"{metric}.max",
458
  f"{metric}.std",
459
  ]
460
+ ].sort_values(by=f"{metric}.mean")
461
 
462
  def custom_table(styler):
463
  styler.background_gradient(cmap="seismic", axis=0)
 
477
  .reset_index()
478
  .pivot(index="model", columns="unique_id", values=f"{metric}.{aggregation}")
479
  )
 
 
 
480
  metrics_per_building_table.insert(
481
  0, "mean", metrics_per_building_table.mean(axis=1)
482
  )
483
+ metrics_per_building_table = metrics_per_building_table.sort_values(by="mean").drop(columns="mean")
484
 
485
+ def custom_table(styler: Styler):
486
  styler.background_gradient(cmap="seismic", axis=None)
487
  styler.format(precision=2)
488
 
 
495
  st.dataframe(styled_table, use_container_width=True)
496
 
497
 
498
def relative_performance_view(data: pd.DataFrame, models_to_plot: set[str]):
    """Render a box plot of how often each model beats a baseline model.

    Args:
        data: Benchmark runs; must contain a ``model`` column and is expected
            to contain one ``better_than.<baseline>`` column per baseline.
            NOTE(review): each cell of those columns is presumably a
            dict-like record with a ``percentage`` key holding a fraction in
            [0, 1] -- confirm against the W&B export.
        models_to_plot: Names of the models to include in the plot.
    """
    # Work on an explicit copy: a derived column is added below, and the
    # caller's frame must not be affected by that assignment.
    data_to_plot = (
        data[data["model"].isin(models_to_plot)]
        .sort_values(by="model", ascending=True)
        .copy()
    )

    st.markdown("#### Relative performance")
    if data_to_plot.empty:
        st.warning("No data to display.")
        return

    baseline_choices = sorted(
        data.filter(like="better_than")
        .columns.str.removeprefix("better_than.")
        .tolist()
    )
    if not baseline_choices:
        # Without any ``better_than.*`` column there is nothing to compare
        # against; indexing ``baseline_choices[0]`` would raise IndexError.
        st.warning("No baseline comparison available.")
        return

    if len(baseline_choices) > 1:
        better_than_baseline = st.selectbox("Baseline model", options=baseline_choices)
    else:
        # A single baseline needs no selector.
        better_than_baseline = baseline_choices[0]

    baseline_column = f"better_than.{better_than_baseline}"
    percentage_column = f"{baseline_column}.percentage"

    # Expand the per-run records and scale the "percentage" field by 100 so
    # it matches the 0-100 x-axis range used below.
    data_to_plot[percentage_column] = (
        pd.json_normalize(data_to_plot[baseline_column])["percentage"].values * 100
    )
    model_rank = get_model_ranks(data_to_plot, percentage_column)

    fig = px.box(
        # Join on the model name explicitly, as accuracy_view does.
        data_to_plot.merge(model_rank, on="model").sort_values(by="rank"),
        x=percentage_column,
        y="model",
        points="all",
    )
    fig.update_xaxes(range=[0, 100], title_text="Better than baseline (%)")
    fig.update_layout(
        showlegend=False,
        height=50 * len(models_to_plot),
        title=f"Better than {better_than_baseline} on % of days per building",
    )
    st.plotly_chart(fig, use_container_width=True)
537
+
538
+
539
+ def computation_view(data: pd.DataFrame, models_to_plot: set[str]):
540
+ data_to_plot = data[data["model"].isin(models_to_plot)].sort_values(
541
+ by="model", ascending=True
542
  )
543
+ data_to_plot["resource_usage.CPU"] /= 3600
544
 
545
+ st.markdown("#### Computational Resources")
546
 
547
  left, center, right = st.columns(3, gap="small")
548
  with left:
 
559
  st.markdown(
560
  f"#### {aggregation_per_model.capitalize()} {aggregation_per_building.capitalize()} {metric} vs CPU usage"
561
  )
562
+ if data_to_plot.empty:
563
+ st.warning("No data to display.")
564
+ else:
565
+ aggregated_data = (
566
+ data_to_plot.groupby("model")
567
+ .agg(aggregation_per_building, numeric_only=True)
568
+ .reset_index()
569
+ )
570
+ fig = px.scatter(
571
+ aggregated_data,
572
+ x="resource_usage.CPU",
573
+ y=f"{metric}.{aggregation_per_model}",
574
+ color="model",
575
+ log_x=True,
576
+ )
577
+ fig.update_layout(height=600)
578
+ fig.update_xaxes(title_text="CPU usage (hours)")
579
+ fig.update_yaxes(
580
+ title_text=f"{metric} ({aggregation_per_building}, {aggregation_per_model})"
581
+ )
582
+ st.plotly_chart(fig, use_container_width=True)
583
+
584
+ st.divider()
585
+
586
+ st.markdown("#### Computational time vs historical data")
587
+ if data_to_plot.empty:
588
+ st.warning("No data to display.")
589
+ else:
590
+ fig = px.scatter(
591
+ data_to_plot,
592
+ x="dataset.available_history.observations",
593
+ y="resource_usage.CPU",
594
+ color="model",
595
+ trendline="ols",
596
+ hover_data=["model", "unique_id"],
597
+ )
598
+ fig.update_layout(height=600)
599
+ fig.update_xaxes(title_text="Available historical observations (#)")
600
+ fig.update_yaxes(title_text="CPU usage (hours)")
601
+ st.plotly_chart(fig, use_container_width=True)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  wandb==0.17.0
2
- plotly==5.20.0
 
 
1
  wandb==0.17.0
2
+ plotly==5.20.0
3
+ statsmodels==0.14.2
utils.py CHANGED
@@ -2,7 +2,9 @@ import pandas as pd
2
  import wandb
3
 
4
 
5
- def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd.DataFrame:
 
 
6
  api = wandb.Api(api_key=api_key)
7
 
8
  # Project is specified by <entity/project-name>
@@ -17,7 +19,7 @@ def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd
17
 
18
  # .config contains the hyperparameters.
19
  # We remove special values that start with _.
20
- config_list.append({k: v for k, v in run.config.items()})
21
 
22
  # .name is the human-readable name of the run.
23
  name_list.append(run.name)
@@ -30,10 +32,9 @@ def get_wandb_data(entity: str, project: str, api_key: str, job_type: str) -> pd
30
 
31
 
32
  def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
33
- leaderboard = pd.DataFrame(
34
- index=runs_df['model'].unique(),
35
- columns=metrics
36
- ).fillna(0)
37
 
38
  for _, building_df in runs_df.groupby("unique_id"):
39
  for column in leaderboard.columns:
@@ -42,3 +43,14 @@ def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
42
 
43
  leaderboard = leaderboard.sort_values(by=list(leaderboard.columns), ascending=False)
44
  return leaderboard
 
 
 
 
 
 
 
 
 
 
 
 
2
  import wandb
3
 
4
 
5
+ def get_wandb_data(
6
+ entity: str, project: str, api_key: str, job_type: str
7
+ ) -> pd.DataFrame:
8
  api = wandb.Api(api_key=api_key)
9
 
10
  # Project is specified by <entity/project-name>
 
19
 
20
  # .config contains the hyperparameters.
21
  # We remove special values that start with _.
22
+ config_list.append(run.config)
23
 
24
  # .name is the human-readable name of the run.
25
  name_list.append(run.name)
 
32
 
33
 
34
  def get_leaderboard(runs_df: pd.DataFrame, metrics: list[str]) -> pd.DataFrame:
35
+ leaderboard = pd.DataFrame(index=runs_df["model"].unique(), columns=metrics).fillna(
36
+ 0
37
+ )
 
38
 
39
  for _, building_df in runs_df.groupby("unique_id"):
40
  for column in leaderboard.columns:
 
43
 
44
  leaderboard = leaderboard.sort_values(by=list(leaderboard.columns), ascending=False)
45
  return leaderboard
46
+
47
+
48
def get_model_ranks(
    runs_df: pd.DataFrame, metric: str, ascending: bool = True
) -> pd.DataFrame:
    """Rank models by the median of ``metric`` over all of their runs.

    Args:
        runs_df: One row per run; must contain a ``model`` column and a
            numeric ``metric`` column.
        metric: Name of the column to rank by.
        ascending: If True (the default) the model with the smallest median
            gets rank 0, which suits error metrics.  Pass False for metrics
            where larger values are better.

    Returns:
        A DataFrame with exactly two columns, ``rank`` (0-based position in
        the ordering) and ``model``, sorted by ``rank``.

    Raises:
        KeyError: If ``metric`` is not among the numeric columns of
            ``runs_df``.
    """
    medians = runs_df.groupby(["model"]).median(numeric_only=True)
    # A stable sort keeps tied models in the (sorted) group-key order, so the
    # resulting ranking is deterministic.
    ordered = medians.sort_values(by=metric, ascending=ascending, kind="stable")
    # First reset turns the model index into a column and leaves a fresh
    # 0..n-1 index; naming that index "rank" and resetting again exposes the
    # position as a regular column.
    ranked = ordered.reset_index().rename_axis("rank").reset_index()
    return ranked[["rank", "model"]]