William Arnold committed on
Commit d056c4a · 1 Parent(s): 34898cb

Add AUC, ROC, etc

README.md CHANGED
@@ -17,14 +17,14 @@ This dashboard is best viewed at [the huggingface space](https://huggingface.co/
17
 
18
  LLM MCQA (multiple choice question-answering) benchmarks are measured in the following way:
19
  1. Some number of few shot examples are pulled from the validation set of the MCQA benchmark and formatted as
20
- > **Quesiton**: What is the capital of France? \
21
  > (A) Paris \
22
  > (B) London \
23
  > (C) Berlin \
24
  > (D) Madrid \
25
  > **Answer**: A
26
  2. The target question is then appended, without the answer, and fed into the model as
27
- > **Quesiton**: What is the capital of France? \
28
  > (A) Paris \
29
  > (B) London \
30
  > (C) Berlin \
@@ -62,6 +62,7 @@ Here, $\Delta$ is a measure of how much more confident the model is in the corre
62
 
63
  An ideal model would have $\Phi = 1$ (and therefore $\Delta=1$) always, while a model that performs random guessing would have $p_i = \Phi = 0.25$ (and therefore $\Delta=0$) always.
64
 
 
65
  ### Reading $\Phi$ plots
66
  Let's look at an example: MMLU on Llama-7b and Guanaco-7b, an early example of instruction tuning, in the 5-shot setting.
67
 
@@ -81,7 +82,7 @@ Again, the <span style="color:lightblue">**blue line is Llama-7b**</span> and th
81
  * The 'accuracy' as we defined earlier is the percentage of samples with $\Delta > 0$. We can see this as the intersection of the curves with the vertical line at $\Delta = 0$. We can see that while instruction tuning doesn't seem to have changed the accuracy significantly, it has *vastly* altered the distribution of $\Delta$ values.
82
  * Guanaco-7b has a higher percentage of samples with large $\Delta$ values than Llama-7b. For example, in ~12-13% of the samples, Guanaco-7b predicts the correct answer with a probability at least 0.2 greater than the most confident incorrect answer.
83
  * Guanaco-7b also has a higher percentage of samples with very low $\Delta$ values. For example, we can read that ~75% of the samples have $\Delta > -0.2$, meaning that ~25% have $\Delta \leq -0.2$. In other words, Guanaco-7b predicts the wrong answer with a probability at least 0.2 greater than the correct answer in ~25% of the samples, compared to Llama-7b, which does so in only ~6-7% of the samples.
84
-
85
 
86
  ## How to use this notebook
87
 
 
17
 
18
  LLM MCQA (multiple choice question-answering) benchmarks are measured in the following way:
19
  1. Some number of few shot examples are pulled from the validation set of the MCQA benchmark and formatted as
20
+ > **Question**: What is the capital of France? \
21
  > (A) Paris \
22
  > (B) London \
23
  > (C) Berlin \
24
  > (D) Madrid \
25
  > **Answer**: A
26
  2. The target question is then appended, without the answer, and fed into the model as
27
+ > **Question**: What is the capital of France? \
28
  > (A) Paris \
29
  > (B) London \
30
  > (C) Berlin \
 
62
 
63
  An ideal model would have $\Phi = 1$ (and therefore $\Delta=1$) always, while a model that performs random guessing would have $p_i = \Phi = 0.25$ (and therefore $\Delta=0$) always.
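For concreteness, here is a minimal numpy sketch of these quantities computed from per-choice probabilities for a 4-way task (the arrays are illustrative, not from this repo):

```python
import numpy as np

cor_probs = np.array([0.55, 0.20, 0.40])           # Phi: probability of the correct choice, per question
inc_probs = np.array([[0.20, 0.15, 0.10],          # probabilities of the three incorrect choices
                      [0.45, 0.25, 0.10],
                      [0.35, 0.15, 0.10]])

delta = cor_probs - inc_probs.max(axis=1)          # Delta: margin over the most confident wrong choice
accuracy = (delta > 0).mean()                      # fraction of questions where the correct choice wins
```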
64
 
65
+ <!---
66
  ### Reading $\Phi$ plots
67
  Let's look at an example: MMLU on Llama-7b and Guanaco-7b, an early example of instruction tuning, in the 5-shot setting.
68
 
 
82
  * The 'accuracy' as we defined earlier is the percentage of samples with $\Delta > 0$. We can see this as the intersection of the curves with the vertical line at $\Delta = 0$. We can see that while instruction tuning doesn't seem to have changed the accuracy significantly, it has *vastly* altered the distribution of $\Delta$ values.
83
  * Guanaco-7b has a higher percentage of samples with large $\Delta$ values than Llama-7b. For example, in ~12-13% of the samples, Guanaco-7b predicts the correct answer with a probability at least 0.2 greater than the most confident incorrect answer.
84
  * Guanaco-7b also has a higher percentage of samples with very low $\Delta$ values. For example, we can read that ~75% of the samples have $\Delta > -0.2$, meaning that ~25% have $\Delta \leq -0.2$. In other words, Guanaco-7b predicts the wrong answer with a probability at least 0.2 greater than the correct answer in ~25% of the samples, compared to Llama-7b, which does so in only ~6-7% of the samples.
85
+ -->
86
 
87
  ## How to use this notebook
88
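Relatedly, a small sketch of how a few-shot prompt in the Question/(A)-(D)/Answer format described at the top of this README might be assembled (the helper below is illustrative, not part of this repo):

```python
def format_example(question, choices, answer=None):
    lines = [f"Question: {question}"]
    lines += [f"({letter}) {choice}" for letter, choice in zip("ABCD", choices)]
    lines.append(f"Answer: {answer}" if answer is not None else "Answer:")
    return "\n".join(lines)

fewshot = format_example("What is the capital of France?",
                         ["Paris", "London", "Berlin", "Madrid"], answer="A")
target = format_example("Which planet is largest?",
                        ["Mars", "Jupiter", "Venus", "Mercury"])
prompt = fewshot + "\n\n" + target   # the model then scores each answer letter
```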
 
requirements.txt CHANGED
@@ -5,3 +5,5 @@ tqdm>=4.66.4
5
  numpy>=1.26.4
6
  dacite>=1.8.1
7
  seaborn>=0.13.1
 
 
 
5
  numpy>=1.26.4
6
  dacite>=1.8.1
7
  seaborn>=0.13.1
8
+ polars>=1.5.0
9
+ scikit-learn>=1.5.1
src/rbeval/dash.py CHANGED
@@ -1,6 +1,7 @@
1
  from dataclasses import asdict
2
  from pathlib import Path
3
  from typing import List, Optional
 
4
  import streamlit as st
5
  import argparse
6
  from dacite import from_dict
@@ -9,7 +10,6 @@ from rbeval.plot.dash_utils import markdown_insert_images
9
  from rbeval.plot.data import EvalGroup, get_samples
10
  from rbeval.plot.score_cdf import (
11
  CdfPlotConfig,
12
- PlotData,
13
  plot_with_data,
14
  get_plot_data,
15
  plot_cfgs,
@@ -29,7 +29,7 @@ def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
29
  @st.cache_data
30
  def cached_score_cdf(
31
  dir: Path, name_filter: Optional[str]
32
- ) -> tuple[List[PlotData], List[CdfPlotConfig]]:
33
  samples = cached_samples(dir, name_filter)
34
  cfgs = plot_cfgs()
35
  data = [get_plot_data(cfg, samples) for cfg in cfgs]
@@ -48,20 +48,6 @@ def cache_compare(
48
  return grouped_dict, base_name, comp_name
49
 
50
 
51
- def filter_for_group(data: List[PlotData], group: str) -> List[PlotData]:
52
- return [
53
- PlotData(
54
- renorm=[df for df in d.renorm if df["group"].iloc[0] == group],
55
- norenorm=[df for df in d.norenorm if df["group"].iloc[0] == group],
56
- )
57
- for d in data
58
- ]
59
-
60
-
61
- def get_group_names(data: List[PlotData]) -> List[str]:
62
- return sorted(set([df["group"].iloc[0] for d in data for df in d.renorm]))
63
-
64
-
65
  def main():
66
  parser = argparse.ArgumentParser(description="rbeval dashboard")
67
  parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
@@ -77,26 +63,32 @@ def main():
77
  st.markdown(markdown_insert_images(markdown), unsafe_allow_html=True)
78
 
79
  score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
80
- group_names = sorted([g.name for g in cached_samples(eval_dir, None)])
81
 
82
  st.markdown("""
83
- Below is a toggle which renormalizes multiple choice answer probabilities to sum to 1.
84
  For more performant models (anything after Llama 1) or in higher fewshot scenarios, this doesn't impact the results very much.
85
  """)
86
 
87
  renormed = st.toggle("Renormalize Probabilities", True)
88
 
89
  st.subheader("Model Performance Curves")
90
  for group in group_names:
91
- group_data = filter_for_group(score_cdf_data, group)
92
  with st.expander(group):
93
- figs = [
94
- fig
95
- for data, cdf in zip(group_data, cfgs)
96
- for fig in plot_with_data(cdf, data, renormed)
97
- ]
98
- for fig in figs:
99
- st.altair_chart(fig.chart, use_container_width=True) # type: ignore
 
100
 
101
  model_names = set(
102
  [
 
1
  from dataclasses import asdict
2
  from pathlib import Path
3
  from typing import List, Optional
4
+ import pandas as pd
5
  import streamlit as st
6
  import argparse
7
  from dacite import from_dict
 
10
  from rbeval.plot.data import EvalGroup, get_samples
11
  from rbeval.plot.score_cdf import (
12
  CdfPlotConfig,
 
13
  plot_with_data,
14
  get_plot_data,
15
  plot_cfgs,
 
29
  @st.cache_data
30
  def cached_score_cdf(
31
  dir: Path, name_filter: Optional[str]
32
+ ) -> tuple[List[pd.DataFrame], List[CdfPlotConfig]]:
33
  samples = cached_samples(dir, name_filter)
34
  cfgs = plot_cfgs()
35
  data = [get_plot_data(cfg, samples) for cfg in cfgs]
 
48
  return grouped_dict, base_name, comp_name
49
 
50
51
  def main():
52
  parser = argparse.ArgumentParser(description="rbeval dashboard")
53
  parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
 
63
  st.markdown(markdown_insert_images(markdown), unsafe_allow_html=True)
64
 
65
  score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
66
+ assert len(score_cdf_data) > 0, "No score cdfs found"
67
+ group_names: List[str] = sorted(
68
+ score_cdf_data[0]["group"].unique().tolist(), reverse=True
69
+ )
70
 
71
  st.markdown("""
72
+ Below is a toggle that renormalizes the multiple-choice answer probabilities to sum to 1.
73
  For more performant models (anything after Llama 1) or in higher fewshot scenarios, this doesn't impact the results very much.
74
  """)
75
 
76
  renormed = st.toggle("Renormalize Probabilities", True)
77
+ fs_names = [str(i) + "-shot" for i in range(0, 5 + 1)]
78
+ fs_filt_sel = st.multiselect("Fewshot Filter", fs_names, default=fs_names)
79
+ fs_filt = [int(i.split("-")[0]) for i in fs_filt_sel]
80
 
81
  st.subheader("Model Performance Curves")
82
  for group in group_names:
 
83
  with st.expander(group):
84
+ for cfg, df in zip(cfgs, score_cdf_data):
85
+ group_data = df[
86
+ (df["group"] == group)
87
+ & (df["renorm"] == renormed)
88
+ & (df["fewshot"].isin(fs_filt))
89
+ ]
90
+ for fig in plot_with_data(cfg, group_data):
91
+ st.altair_chart(fig.chart, use_container_width=True) # type: ignore
92
 
93
  model_names = set(
94
  [
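For reference, a minimal sketch of what the "Renormalize Probabilities" toggle described in the markdown above refers to (illustrative numbers; the per-eval version is presumably what rbeval.plot.utils.renormed implements):

```python
import numpy as np

# Raw per-choice probabilities for one question; the model also assigns mass
# to other tokens, so these need not sum to 1.
choice_probs = np.array([0.40, 0.15, 0.10, 0.05])
renormalized = choice_probs / choice_probs.sum()   # now sums to 1
```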
src/rbeval/plot/data.py CHANGED
@@ -59,8 +59,8 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
59
  inc_logprobs.append(probs)
60
  eval = Eval(
61
  name=samples_file.stem,
62
- cor_logprobs=np.array(cor_logprobs),
63
- inc_logprobs=np.array(inc_logprobs),
64
  )
65
  model_eval.evals.append(eval)
66
  np.save(str(model_eval_cache_file), asdict(model_eval)) # type: ignore
 
59
  inc_logprobs.append(probs)
60
  eval = Eval(
61
  name=samples_file.stem,
62
+ cor_logprobs=np.array(cor_logprobs, dtype=np.float64),
63
+ inc_logprobs=np.array(inc_logprobs, dtype=np.float64),
64
  )
65
  model_eval.evals.append(eval)
66
  np.save(str(model_eval_cache_file), asdict(model_eval)) # type: ignore
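Presumably the explicit dtype is there so the cached arrays are always float64, whatever mix of Python ints and floats was accumulated into the lists; a minimal illustration:

```python
import numpy as np

cor_logprobs = [[-0.1, -2], [0, -3.5]]            # ints and floats mixed
arr = np.array(cor_logprobs, dtype=np.float64)    # guaranteed float64, 2-D
assert arr.dtype == np.float64
```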
src/rbeval/plot/model_comp.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
 
12
  from rbeval.eval_spec import EvalSpec
13
  from rbeval.plot.data import EvalGroup, Figure, ModelEval
14
- from rbeval.plot.utils import CdfData, renormed
15
  from typing import Any
16
 
17
 
@@ -135,23 +135,23 @@ def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
135
  diff_cdf_data: List[pd.DataFrame] = []
136
  corr_cdf_data: List[pd.DataFrame] = []
137
  for score in score_list:
138
- diff_cdf = CdfData.from_samples(score.cor_minus_inc_samples)
139
  diff_cdf_data.append(
140
  pd.DataFrame(
141
  {
142
- "p": diff_cdf.scores,
143
- "1-CDF(p)": diff_cdf.cdf_p,
144
  "fewshot": score.spec.fewshot,
145
  "model": score.spec.model_name,
146
  }
147
  )
148
  )
149
- corr_cdf = CdfData.from_samples(score.cor_samples)
150
  corr_cdf_data.append(
151
  pd.DataFrame(
152
  {
153
- "p": corr_cdf.scores,
154
- "1-CDF(p)": corr_cdf.cdf_p,
155
  "fewshot": score.spec.fewshot,
156
  "model": score.spec.model_name,
157
  }
 
11
 
12
  from rbeval.eval_spec import EvalSpec
13
  from rbeval.plot.data import EvalGroup, Figure, ModelEval
14
+ from rbeval.plot.utils import PlotData, renormed
15
  from typing import Any
16
 
17
 
 
135
  diff_cdf_data: List[pd.DataFrame] = []
136
  corr_cdf_data: List[pd.DataFrame] = []
137
  for score in score_list:
138
+ diff_cdf = PlotData.perf_curve_from_samples(score.cor_minus_inc_samples)
139
  diff_cdf_data.append(
140
  pd.DataFrame(
141
  {
142
+ "p": diff_cdf.x,
143
+ "1-CDF(p)": diff_cdf.y,
144
  "fewshot": score.spec.fewshot,
145
  "model": score.spec.model_name,
146
  }
147
  )
148
  )
149
+ corr_cdf = PlotData.perf_curve_from_samples(score.cor_samples)
150
  corr_cdf_data.append(
151
  pd.DataFrame(
152
  {
153
+ "p": corr_cdf.x,
154
+ "1-CDF(p)": corr_cdf.y,
155
  "fewshot": score.spec.fewshot,
156
  "model": score.spec.model_name,
157
  }
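The "1-CDF(p)" values plotted here are a weighted survival curve over the per-sample scores; a simplified numpy sketch of that computation (mirroring perf_curve_from_weights later in this commit, without the interpolation step):

```python
import numpy as np

def survival_curve(scores, weights):
    """Return (x, y): y[i] is the weighted fraction of samples with score above x[i]."""
    order = scores.argsort()
    scores, weights = scores[order], weights[order]
    return scores, 1.0 - np.cumsum(weights)

scores = np.array([0.3, -0.1, 0.6, 0.2])
weights = np.full(4, 0.25)                 # uniform weights over samples
x, y = survival_curve(scores, weights)
```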
src/rbeval/plot/score_cdf.py CHANGED
@@ -1,5 +1,4 @@
1
- from dataclasses import dataclass, field
2
- from typing import List, Optional
3
 
4
  from numpy._typing import NDArray
5
  from rbeval.plot.data import Eval, EvalGroup, Figure
@@ -7,83 +6,81 @@ from abc import ABC, abstractmethod
7
  import numpy as np
8
  import altair as alt
9
  import pandas as pd
 
10
 
11
- from rbeval.plot.utils import CdfData, renormed
12
-
13
-
14
- @dataclass
15
- class PlotData:
16
- renorm: List[pd.DataFrame] = field(default_factory=list)
17
- norenorm: List[pd.DataFrame] = field(default_factory=list)
18
 
19
 
20
  def plot_cfgs():
21
- return [CorrectProbCdfPlot(), CorrIncorrDiffConfig()]
22
 
23
 
24
  def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
25
  return [
26
  a
27
  for cfg in plot_cfgs()
28
- for renorm in [True, False]
29
- for a in plot_with_data(cfg, get_plot_data(cfg, samples), renorm)
30
  ]
31
 
32
 
33
  def get_plot_data(
34
  cfg: "CdfPlotConfig",
35
  samples: List[EvalGroup],
36
- ) -> PlotData:
37
- data = PlotData()
38
  for renorm in [True, False]:
39
- gfs = data.renorm if renorm else data.norenorm
40
  for group in samples:
41
- dfs: List[pd.DataFrame] = []
42
  for m in group.model_evals:
43
  spec = m.eval_spec
44
  cdf = cfg.get_cdf(m.evals, renorm)
45
- df = pd.DataFrame(
46
  {
47
- "x": cdf.scores,
48
- "y": cdf.cdf_p,
49
  "label": m.model_name,
50
  "group": group.name,
51
  "renorm": renorm,
52
  "fewshot": spec.fewshot,
53
  }
54
  )
55
- dfs.append(df)
56
- gfs.append(pd.concat(dfs))
57
  return data
58
 
59
 
60
  def plot_with_data(
61
  cfg: "CdfPlotConfig",
62
- data: PlotData,
63
- renorm: bool = True,
64
  ) -> List[Figure]:
65
  figures: List[Figure] = []
66
- group_dfs = data.renorm if renorm else data.norenorm
67
- for df in group_dfs:
68
- group_name: str = str(df["group"].iloc[0]) # type: ignore
69
  label_selection = alt.selection_point(fields=["label"], bind="legend") # type: ignore
70
  fs_selection = alt.selection_point(fields=["fewshot"], bind="legend") # type: ignore
 
 
71
  chart = (
72
- alt.Chart(df) # type: ignore
73
- .mark_line()
74
- .encode(
75
- x=alt.X("x:Q", title=cfg.xlabel),
76
- y=alt.Y("y:Q", title=cfg.ylabel),
77
  color=alt.Color(
78
  "label:N", legend=alt.Legend(symbolOpacity=1.0, labelLimit=1000)
79
- ).scale(scheme="set1"),
 
80
  opacity=alt.condition( # type: ignore
81
  label_selection & fs_selection,
82
  alt.Opacity("fewshot:O"),
83
  alt.value(0.0), # type: ignore
84
  ),
85
  )
86
- .properties(title=cfg.title(group_name, renorm))
87
  .add_params(fs_selection, label_selection)
88
  .interactive()
89
  )
@@ -102,9 +99,10 @@ class CdfPlotConfig(ABC):
102
  ylabel: str
103
  name: str = ""
104
  xline: Optional[float] = None
 
105
 
106
  @abstractmethod
107
- def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
108
  pass
109
 
110
  def title(self, group_name: str, prob_renorm: bool) -> str:
@@ -119,25 +117,74 @@ class CdfPlotConfig(ABC):
119
 
120
 
121
  class CorrectProbCdfPlot(CdfPlotConfig):
122
- name = "𝚽 Performance Curve"
123
  xlabel = "𝚽"
124
- ylabel = "% of correct answers with 𝚽 > x"
125
  xline = 0.25
126
 
127
- def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
128
  samples = [np.exp(e.cor_logprobs) for e in evals]
129
  if prob_renorm:
130
  samples = [renormed(e)[0] for e in evals]
131
- return CdfData.from_samples(samples)
132
 
133
 
134
  class CorrIncorrDiffConfig(CdfPlotConfig):
135
- name = "𝚫 Performance Curve"
136
  xline = 0.0
137
  xlabel = "𝚫"
138
- ylabel = "% of samples with 𝚫 > x"
139
 
140
- def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
141
  score_arrs: List[NDArray[np.float64]] = []
142
  for e in evals:
143
  if prob_renorm:
@@ -148,4 +195,47 @@ class CorrIncorrDiffConfig(CdfPlotConfig):
148
 
149
  score_arrs.append(cor_probs - inc_probs.max(axis=1))
150
 
151
- return CdfData.from_samples(score_arrs, per_sample_weighting=True)
1
+ from typing import List, Literal, Optional
 
2
 
3
  from numpy._typing import NDArray
4
  from rbeval.plot.data import Eval, EvalGroup, Figure
 
6
  import numpy as np
7
  import altair as alt
8
  import pandas as pd
9
+ from sklearn.metrics import roc_curve, roc_auc_score # type: ignore
10
 
11
+ from rbeval.plot.utils import PlotData, renormed
12
 
13
 
14
  def plot_cfgs():
15
+ return [
16
+ CorrectProbCdfPlot(),
17
+ CorrIncorrDiffConfig(),
18
+ ROCCurve(),
19
+ MaxIncorProbCdfPlot(),
20
+ AccVsLoss(),
21
+ AccVsAUC(),
22
+ ]
23
 
24
 
25
  def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
26
  return [
27
  a
28
  for cfg in plot_cfgs()
29
+ for a in plot_with_data(cfg, get_plot_data(cfg, samples))
 
30
  ]
31
 
32
 
33
  def get_plot_data(
34
  cfg: "CdfPlotConfig",
35
  samples: List[EvalGroup],
36
+ ) -> pd.DataFrame:
37
+ records = []
38
  for renorm in [True, False]:
 
39
  for group in samples:
 
40
  for m in group.model_evals:
41
  spec = m.eval_spec
42
  cdf = cfg.get_cdf(m.evals, renorm)
43
+ records.append(
44
  {
45
+ "x": cdf.x,
46
+ "y": cdf.y,
47
  "label": m.model_name,
48
  "group": group.name,
49
  "renorm": renorm,
50
  "fewshot": spec.fewshot,
51
  }
52
  )
53
+ data = pd.DataFrame.from_records(records)
 
54
  return data
55
 
56
 
57
  def plot_with_data(
58
  cfg: "CdfPlotConfig",
59
+ data: pd.DataFrame,
 
60
  ) -> List[Figure]:
61
  figures: List[Figure] = []
62
+ for (group_name, renorm), df in data.groupby(["group", "renorm"]):
63
+ assert isinstance(group_name, str)
64
+ assert isinstance(renorm, (bool, np.bool_))
65
  label_selection = alt.selection_point(fields=["label"], bind="legend") # type: ignore
66
  fs_selection = alt.selection_point(fields=["fewshot"], bind="legend") # type: ignore
67
+ chart = alt.Chart(df.explode(["x", "y"])) # type: ignore
68
+ chart = chart.mark_line() if cfg.type == "line" else chart.mark_point()
69
  chart = (
70
+ chart.encode(
71
+ x=alt.X("x:Q", title=cfg.xlabel, scale=alt.Scale(zero=False)),
72
+ y=alt.Y("y:Q", title=cfg.ylabel, scale=alt.Scale(zero=False)),
 
 
73
  color=alt.Color(
74
  "label:N", legend=alt.Legend(symbolOpacity=1.0, labelLimit=1000)
75
+ ).scale(scheme="dark2"),
76
+ shape="label:N" if cfg.type == "scatter" else alt.Undefined,
77
  opacity=alt.condition( # type: ignore
78
  label_selection & fs_selection,
79
  alt.Opacity("fewshot:O"),
80
  alt.value(0.0), # type: ignore
81
  ),
82
  )
83
+ .properties(title=cfg.title(group_name, renorm)) # type: ignore
84
  .add_params(fs_selection, label_selection)
85
  .interactive()
86
  )
 
99
  ylabel: str
100
  name: str = ""
101
  xline: Optional[float] = None
102
+ type: Literal["line", "scatter"] = "line"
103
 
104
  @abstractmethod
105
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
106
  pass
107
 
108
  def title(self, group_name: str, prob_renorm: bool) -> str:
 
117
 
118
 
119
  class CorrectProbCdfPlot(CdfPlotConfig):
120
+ name = "CDF(𝚽)"
121
  xlabel = "𝚽"
122
+ ylabel = "% of correct answers with 𝚽 < x"
123
  xline = 0.25
124
 
125
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
126
  samples = [np.exp(e.cor_logprobs) for e in evals]
127
  if prob_renorm:
128
  samples = [renormed(e)[0] for e in evals]
129
+ return PlotData.perf_curve_from_samples(samples)
130
+
131
+
132
+ class MaxIncorProbCdfPlot(CdfPlotConfig):
133
+ name = "CDF(Max(Incorrect))"
134
+ xlabel = "max(incorrect)"
135
+ ylabel = "% of correct answers with max(incorrect) < x"
136
+ xline = 0.25
137
+
138
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
139
+ if prob_renorm:
140
+ samples = [renormed(e)[1].max(axis=1) for e in evals]
141
+ else:
142
+ samples = [np.exp(np.max(e.inc_logprobs, axis=1)) for e in evals]
143
+ return PlotData.perf_curve_from_samples(samples)
144
+
145
+
146
+ class AccVsLoss(CdfPlotConfig):
147
+ name = "Cross Entropy Loss vs Accuracy"
148
+ xlabel = "Accuracy"
149
+ ylabel = "CE Loss"
150
+ xline = None
151
+ type = "scatter"
152
+
153
+ def get_cdf(self, evals: List[Eval], _prob_renorm: bool) -> "PlotData":
154
+ cor, incor = zip(*[renormed(e) for e in evals])
155
+ cor = np.concatenate(cor)
156
+ incor = np.concatenate(incor).max(axis=1)
157
+ pct_corr = np.mean(cor > incor)
158
+
159
+ celoss = np.mean(-np.log(cor))
160
+ return PlotData(np.array([celoss]), np.array([pct_corr]))
161
+
162
+
163
+ class AccVsAUC(CdfPlotConfig):
164
+ name = "Simulated AUROC vs Accuracy"
165
+ xlabel = "Accuracy"
166
+ ylabel = "Simulated AUROC"
167
+ xline = None
168
+ type = "scatter"
169
+
170
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
171
+ cor, incor = zip(*[renormed(e) for e in evals])
172
+ cor = np.concatenate(cor)
173
+ incor = np.concatenate(incor).max(axis=1)
174
+ pct_corr = np.mean(cor > incor)
175
+
176
+ scores, labels, weights = roc_data(evals, prob_renorm)
177
+ auc = roc_auc_score(labels, scores, sample_weight=weights)
178
+ return PlotData(np.array([auc]), np.array([pct_corr]))
179
 
180
 
181
  class CorrIncorrDiffConfig(CdfPlotConfig):
182
+ name = "CDF(𝚫)"
183
  xline = 0.0
184
  xlabel = "𝚫"
185
+ ylabel = "% of samples with 𝚫 < x"
186
 
187
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
188
  score_arrs: List[NDArray[np.float64]] = []
189
  for e in evals:
190
  if prob_renorm:
 
195
 
196
  score_arrs.append(cor_probs - inc_probs.max(axis=1))
197
 
198
+ return PlotData.perf_curve_from_samples(score_arrs, per_sample_weighting=True)
199
+
200
+
201
+ class ROCCurve(CdfPlotConfig):
202
+ name = "Simulated ROC Curve"
203
+ xline = None
204
+ xlabel = "FPR"
205
+ ylabel = "TPR"
206
+
207
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
208
+ scores, labels, weights = roc_data(evals, prob_renorm)
209
+ assert len(scores) == len(labels) == len(weights)
210
+ fpr, tpr, _ = roc_curve(labels, scores, sample_weight=weights)
211
+
212
+ x_interp = np.linspace(0, 1, 600)
213
+ y_interp = np.interp(x_interp, fpr, tpr)
214
+
215
+ return PlotData(y=y_interp, x=x_interp)
216
+
217
+
218
+ def roc_data(evals: List[Eval], prob_renorm: bool):
219
+ weight_arrs = []
220
+ total = sum(len(e.cor_logprobs) for e in evals)
221
+ for samples in evals:
222
+ this = np.ones(2 * len(samples.cor_logprobs)) / (2 * total)
223
+ weight_arrs.append(this)
224
+
225
+ score_arrs = []
226
+ label_arrs = []
227
+ for e in evals:
228
+ if prob_renorm:
229
+ cor_probs, inc_probs = renormed(e)
230
+ else:
231
+ cor_probs = np.exp(e.cor_logprobs)
232
+ inc_probs = np.exp(e.inc_logprobs)
233
+ score_arrs.append(cor_probs)
234
+ label_arrs.append(np.ones(len(cor_probs)))
235
+ score_arrs.append(inc_probs.max(axis=1))
236
+ label_arrs.append(np.zeros(inc_probs.shape[0]))
237
+
238
+ scores = np.concatenate(score_arrs)
239
+ labels = np.concatenate(label_arrs)
240
+ weights = np.concatenate(weight_arrs)
241
+ return scores, labels, weights
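In outline, the simulated ROC treats each question's correct-answer probability as a positive-class score and its largest incorrect-answer probability as a negative-class score, then hands both to scikit-learn; a standalone sketch (illustrative arrays, unweighted):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

cor_probs = np.array([0.55, 0.20, 0.40])       # correct-answer probability per question
max_inc = np.array([0.20, 0.45, 0.35])         # most confident incorrect answer per question

scores = np.concatenate([cor_probs, max_inc])
labels = np.concatenate([np.ones_like(cor_probs), np.zeros_like(max_inc)])

fpr, tpr, _ = roc_curve(labels, scores)        # points of the simulated ROC curve
auc = roc_auc_score(labels, scores)            # the "Simulated AUROC" paired with accuracy in AccVsAUC
```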
src/rbeval/plot/utils.py CHANGED
@@ -17,14 +17,17 @@ def renormed(eval: Eval) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
17
 
18
 
19
  @dataclass
20
- class CdfData:
21
- cdf_p: np.ndarray
22
- scores: np.ndarray
23
 
24
  @classmethod
25
- def from_samples(
26
- cls, samples: List[NDArray[np.float64]], per_sample_weighting: bool = True
27
- ) -> "CdfData":
 
 
 
28
  num_cats = len(samples)
29
  scores = np.concatenate(samples)
30
  if per_sample_weighting:
@@ -37,26 +40,29 @@ class CdfData:
37
  weights = np.concatenate(weight_arrs)
38
  else:
39
  weights = np.ones_like(scores) / len(scores)
40
- return cls.from_weights(weights, scores)
41
 
42
  @classmethod
43
- def from_weights(
44
  cls,
45
  weights: NDArray[np.float64],
46
  base_scores: NDArray[np.float64],
47
  max_p: int = 600,
48
- ) -> "CdfData":
 
49
  sort_perm = base_scores.argsort()
50
  base_weights = weights[sort_perm]
51
  base_scores = base_scores[sort_perm]
52
- base_cdf_p = 1 - np.cumsum(base_weights)
 
 
53
  minscore, maxscore = base_scores[0], base_scores[-1]
54
  if len(base_scores) > max_p:
55
  scores = np.linspace(minscore, maxscore, max_p) # type: ignore
56
  cdf_p = np.interp(scores, base_scores, base_cdf_p) # type: ignore
57
  else:
58
  scores, cdf_p = base_scores, base_cdf_p
59
- return CdfData(
60
- cdf_p=cdf_p,
61
- scores=scores, # type: ignore
62
  )
 
17
 
18
 
19
  @dataclass
20
+ class PlotData:
21
+ y: np.ndarray
22
+ x: np.ndarray
23
 
24
  @classmethod
25
+ def perf_curve_from_samples(
26
+ cls,
27
+ samples: List[NDArray[np.float64]],
28
+ per_sample_weighting: bool = True,
29
+ one_minus: bool = False,
30
+ ) -> "PlotData":
31
  num_cats = len(samples)
32
  scores = np.concatenate(samples)
33
  if per_sample_weighting:
 
40
  weights = np.concatenate(weight_arrs)
41
  else:
42
  weights = np.ones_like(scores) / len(scores)
43
+ return cls.perf_curve_from_weights(weights, scores, one_minus=one_minus)
44
 
45
  @classmethod
46
+ def perf_curve_from_weights(
47
  cls,
48
  weights: NDArray[np.float64],
49
  base_scores: NDArray[np.float64],
50
  max_p: int = 600,
51
+ one_minus: bool = True,
52
+ ) -> "PlotData":
53
  sort_perm = base_scores.argsort()
54
  base_weights = weights[sort_perm]
55
  base_scores = base_scores[sort_perm]
56
+ base_cdf_p = np.cumsum(base_weights)
57
+ if one_minus:
58
+ base_cdf_p = 1 - base_cdf_p
59
  minscore, maxscore = base_scores[0], base_scores[-1]
60
  if len(base_scores) > max_p:
61
  scores = np.linspace(minscore, maxscore, max_p) # type: ignore
62
  cdf_p = np.interp(scores, base_scores, base_cdf_p) # type: ignore
63
  else:
64
  scores, cdf_p = base_scores, base_cdf_p
65
+ return PlotData(
66
+ y=cdf_p,
67
+ x=scores, # type: ignore
68
  )
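A short usage sketch of the renamed helper, assuming imports resolve as in this commit:

```python
import numpy as np
from rbeval.plot.utils import PlotData

# Per-sample scores for two evals; per_sample_weighting gives each eval equal total weight.
samples = [np.array([0.1, 0.4, 0.9]), np.array([0.2, 0.6])]
cdf = PlotData.perf_curve_from_samples(samples)                       # y: fraction of samples with score < x
survival = PlotData.perf_curve_from_samples(samples, one_minus=True)  # y: 1 - CDF, as used for the "1-CDF(p)" plots
```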