William Arnold committed
Commit d056c4a · Parent(s): 34898cb

Add AUC, ROC, etc

Files changed:
- README.md +4 -3
- requirements.txt +2 -0
- src/rbeval/dash.py +18 -26
- src/rbeval/plot/data.py +2 -2
- src/rbeval/plot/model_comp.py +7 -7
- src/rbeval/plot/score_cdf.py +132 -42
- src/rbeval/plot/utils.py +19 -13
README.md
CHANGED
@@ -17,14 +17,14 @@ This dashboard is best viewed at [the huggingface space](https://huggingface.co/
 
 LLM MCQA (multiple choice question-answering) benchmarks are measured in the following way:
 1. Some number of few shot examples are pulled from the validation set of the MCQA benchmark and formatted as
-   > **
+   > **Question**: What is the capital of France? \
    > (A) Paris \
    > (B) London \
    > (C) Berlin \
    > (D) Madrid \
    > **Answer**: A
 2. The target question is then appended, without the answer, and fed into the model as
-   > **
+   > **Question**: What is the capital of France? \
    > (A) Paris \
    > (B) London \
    > (C) Berlin \
@@ -62,6 +62,7 @@ Here, $\Delta$ is a measure of how much more confident the model is in the correct answer
 
 An ideal model would have $\Phi = 1$ (and therefore $\Delta = 1$) always, while a model that performs random guessing would have $p_i = \Phi = 0.25$ (and therefore $\Delta = 0$) always.
 
+<!---
 ### Reading $\Phi$ plots
 Let's look at an example: MMLU on Llama-7b and Guanaco-7b, an early example of instruction tuning, in the 5-shot setting.
 
@@ -81,7 +82,7 @@ Again, the <span style="color:lightblue">**blue line is Llama-7b**</span> and the
 * The 'accuracy' as we defined it earlier is the percentage of samples with $\Delta > 0$. We can see this as the intersection of the curves with the vertical line at $\Delta = 0$. We can see that while instruction tuning doesn't seem to have changed the accuracy significantly, it has *vastly* altered the distribution of $\Delta$ values.
 * Guanaco-7b has a higher percentage of samples with large $\Delta$ values than Llama-7b. For example, in ~12-13% of the samples, Guanaco-7b predicts the correct answer with a probability at least 0.2 greater than the most confident incorrect answer.
 * Guanaco-7b also has a higher percentage of samples with very low $\Delta$ values. For example, we can read that ~75% of the samples have $\Delta > -0.2$, meaning that ~25% have $\Delta \leq -0.2$. That is, Guanaco-7b predicts the wrong answer with a probability at least 0.2 greater than the correct answer in ~25% of the samples, while Llama-7b only does so in ~6-7% of the samples.
-
+-->
 
 ## How to use this notebook
 
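For reference, the $\Phi$ and $\Delta$ definitions quoted in the README hunks above reduce to a few lines of NumPy. A minimal illustrative sketch, not part of the commit, with made-up per-choice probabilities (it mirrors the gap computation in `CorrIncorrDiffConfig.get_cdf` below):

```python
import numpy as np

# Made-up probabilities for one question's four choices; index 0 is the correct answer.
probs = np.array([0.50, 0.20, 0.15, 0.05])

# Phi: probability assigned to the correct answer, renormalized over the choices.
phi = probs[0] / probs.sum()

# Delta: how much more confident the model is in the correct answer than in the
# most confident incorrect answer; the sample counts toward accuracy when delta > 0.
delta = phi - (probs[1:] / probs.sum()).max()

print(f"phi={phi:.3f}, delta={delta:.3f}")  # phi=0.556, delta=0.333
```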
requirements.txt
CHANGED
@@ -5,3 +5,5 @@ tqdm>=4.66.4
 numpy>=1.26.4
 dacite>=1.8.1
 seaborn>=0.13.1
+polars>=1.5.0
+scikit-learn>=1.5.1
src/rbeval/dash.py
CHANGED
@@ -1,6 +1,7 @@
 from dataclasses import asdict
 from pathlib import Path
 from typing import List, Optional
+import pandas as pd
 import streamlit as st
 import argparse
 from dacite import from_dict
@@ -9,7 +10,6 @@ from rbeval.plot.dash_utils import markdown_insert_images
 from rbeval.plot.data import EvalGroup, get_samples
 from rbeval.plot.score_cdf import (
     CdfPlotConfig,
-    PlotData,
     plot_with_data,
     get_plot_data,
     plot_cfgs,
@@ -29,7 +29,7 @@ def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
 @st.cache_data
 def cached_score_cdf(
     dir: Path, name_filter: Optional[str]
-) -> tuple[List[
+) -> tuple[List[pd.DataFrame], List[CdfPlotConfig]]:
     samples = cached_samples(dir, name_filter)
     cfgs = plot_cfgs()
     data = [get_plot_data(cfg, samples) for cfg in cfgs]
@@ -48,20 +48,6 @@ def cache_compare(
     return grouped_dict, base_name, comp_name
 
 
-def filter_for_group(data: List[PlotData], group: str) -> List[PlotData]:
-    return [
-        PlotData(
-            renorm=[df for df in d.renorm if df["group"].iloc[0] == group],
-            norenorm=[df for df in d.norenorm if df["group"].iloc[0] == group],
-        )
-        for d in data
-    ]
-
-
-def get_group_names(data: List[PlotData]) -> List[str]:
-    return sorted(set([df["group"].iloc[0] for d in data for df in d.renorm]))
-
-
 def main():
     parser = argparse.ArgumentParser(description="rbeval dashboard")
     parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
@@ -77,26 +63,32 @@ def main():
     st.markdown(markdown_insert_images(markdown), unsafe_allow_html=True)
 
     score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
-
+    assert len(score_cdf_data) > 0, "No score cdfs found"
+    group_names: List[str] = sorted(
+        score_cdf_data[0]["group"].unique().tolist(), reverse=True
+    )
 
     st.markdown("""
-        Below is a toggle which renormalizes multiple choice answer probabilities to sum to 1.
+        Below is a toggle which renormalizes the multiple choice answer probabilities to sum to 1.
         For more performant models (anything after Llama 1) or in higher fewshot scenarios, this doesn't impact the results very much.
     """)
 
     renormed = st.toggle("Renormalize Probabilities", True)
+    fs_names = [str(i) + "-shot" for i in range(0, 5 + 1)]
+    fs_filt_sel = st.multiselect("Fewshot Filter", fs_names, default=fs_names)
+    fs_filt = [int(i.split("-")[0]) for i in fs_filt_sel]
 
     st.subheader("Model Performance Curves")
     for group in group_names:
-        group_data = filter_for_group(score_cdf_data, group)
         with st.expander(group):
-
-
-
-
-
-
-
+            for cfg, df in zip(cfgs, score_cdf_data):
+                group_data = df[
+                    (df["group"] == group)
+                    & (df["renorm"] == renormed)
+                    & (df["fewshot"].isin(fs_filt))
+                ]
+                for fig in plot_with_data(cfg, group_data):
+                    st.altair_chart(fig.chart, use_container_width=True)  # type: ignore
 
     model_names = set(
         [
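For context on the dashboard change above: plot data now arrives as plain long-format DataFrames, and each expander slices them with boolean masks instead of the removed `filter_for_group` helper. A self-contained sketch of that filtering pattern with toy values, not part of the commit (real frames come from `get_plot_data`):

```python
import pandas as pd

# Toy stand-in for one entry of score_cdf_data.
df = pd.DataFrame({
    "group": ["mmlu", "mmlu", "arc"],
    "renorm": [True, False, True],
    "fewshot": [5, 5, 0],
    "label": ["llama-7b"] * 3,
})

renormed = True
fs_filt = [0, 5]
# Same mask shape as in main(): group expander, renorm toggle, fewshot multiselect.
group_data = df[
    (df["group"] == "mmlu") & (df["renorm"] == renormed) & (df["fewshot"].isin(fs_filt))
]
print(group_data)  # keeps only the renormalized 5-shot mmlu row
```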
src/rbeval/plot/data.py
CHANGED
@@ -59,8 +59,8 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
             inc_logprobs.append(probs)
         eval = Eval(
             name=samples_file.stem,
-            cor_logprobs=np.array(cor_logprobs),
-            inc_logprobs=np.array(inc_logprobs),
+            cor_logprobs=np.array(cor_logprobs, dtype=np.float64),
+            inc_logprobs=np.array(inc_logprobs, dtype=np.float64),
         )
         model_eval.evals.append(eval)
         np.save(str(model_eval_cache_file), asdict(model_eval))  # type: ignore
src/rbeval/plot/model_comp.py
CHANGED
@@ -11,7 +11,7 @@ import numpy as np
 
 from rbeval.eval_spec import EvalSpec
 from rbeval.plot.data import EvalGroup, Figure, ModelEval
-from rbeval.plot.utils import
+from rbeval.plot.utils import PlotData, renormed
 from typing import Any
 
 
@@ -135,23 +135,23 @@ def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
     diff_cdf_data: List[pd.DataFrame] = []
     corr_cdf_data: List[pd.DataFrame] = []
     for score in score_list:
-        diff_cdf =
+        diff_cdf = PlotData.perf_curve_from_samples(score.cor_minus_inc_samples)
         diff_cdf_data.append(
             pd.DataFrame(
                 {
-                    "p": diff_cdf.
-                    "1-CDF(p)": diff_cdf.
+                    "p": diff_cdf.x,
+                    "1-CDF(p)": diff_cdf.y,
                     "fewshot": score.spec.fewshot,
                     "model": score.spec.model_name,
                 }
             )
         )
-        corr_cdf =
+        corr_cdf = PlotData.perf_curve_from_samples(score.cor_samples)
         corr_cdf_data.append(
             pd.DataFrame(
                 {
-                    "p": corr_cdf.
-                    "1-CDF(p)": corr_cdf.
+                    "p": corr_cdf.x,
+                    "1-CDF(p)": corr_cdf.y,
                     "fewshot": score.spec.fewshot,
                     "model": score.spec.model_name,
                 }
src/rbeval/plot/score_cdf.py
CHANGED
@@ -1,5 +1,4 @@
-from
-from typing import List, Optional
+from typing import List, Literal, Optional
 
 from numpy._typing import NDArray
 from rbeval.plot.data import Eval, EvalGroup, Figure
@@ -7,83 +6,81 @@ from abc import ABC, abstractmethod
 import numpy as np
 import altair as alt
 import pandas as pd
+from sklearn.metrics import roc_curve, roc_auc_score  # type: ignore
 
-from rbeval.plot.utils import
-
-
-@dataclass
-class PlotData:
-    renorm: List[pd.DataFrame] = field(default_factory=list)
-    norenorm: List[pd.DataFrame] = field(default_factory=list)
+from rbeval.plot.utils import PlotData, renormed
 
 
 def plot_cfgs():
-    return [
+    return [
+        CorrectProbCdfPlot(),
+        CorrIncorrDiffConfig(),
+        ROCCurve(),
+        MaxIncorProbCdfPlot(),
+        AccVsLoss(),
+        AccVsAUC(),
+    ]
 
 
 def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
     return [
         a
         for cfg in plot_cfgs()
-        for
-        for a in plot_with_data(cfg, get_plot_data(cfg, samples), renorm)
+        for a in plot_with_data(cfg, get_plot_data(cfg, samples))
     ]
 
 
 def get_plot_data(
     cfg: "CdfPlotConfig",
     samples: List[EvalGroup],
-) ->
-
+) -> pd.DataFrame:
+    records = []
     for renorm in [True, False]:
-        gfs = data.renorm if renorm else data.norenorm
         for group in samples:
-            dfs: List[pd.DataFrame] = []
             for m in group.model_evals:
                 spec = m.eval_spec
                 cdf = cfg.get_cdf(m.evals, renorm)
-
+                records.append(
                     {
-                        "x": cdf.
-                        "y": cdf.
+                        "x": cdf.x,
+                        "y": cdf.y,
                         "label": m.model_name,
                         "group": group.name,
                         "renorm": renorm,
                         "fewshot": spec.fewshot,
                     }
                 )
-
-            gfs.append(pd.concat(dfs))
+    data = pd.DataFrame.from_records(records)
     return data
 
 
 def plot_with_data(
     cfg: "CdfPlotConfig",
-    data:
-    renorm: bool = True,
+    data: pd.DataFrame,
 ) -> List[Figure]:
     figures: List[Figure] = []
-
-
-
+    for (group_name, renorm), df in data.groupby(["group", "renorm"]):
+        assert isinstance(group_name, str)
+        assert isinstance(renorm, (bool, np.bool_))
         label_selection = alt.selection_point(fields=["label"], bind="legend")  # type: ignore
         fs_selection = alt.selection_point(fields=["fewshot"], bind="legend")  # type: ignore
+        chart = alt.Chart(df.explode(["x", "y"]))  # type: ignore
+        chart = chart.mark_line() if cfg.type == "line" else chart.mark_point()
        chart = (
-
-
-
-            x=alt.X("x:Q", title=cfg.xlabel),
-            y=alt.Y("y:Q", title=cfg.ylabel),
+            chart.encode(
+                x=alt.X("x:Q", title=cfg.xlabel, scale=alt.Scale(zero=False)),
+                y=alt.Y("y:Q", title=cfg.ylabel, scale=alt.Scale(zero=False)),
                 color=alt.Color(
                     "label:N", legend=alt.Legend(symbolOpacity=1.0, labelLimit=1000)
-                ).scale(scheme="
+                ).scale(scheme="dark2"),
+                shape="label:N" if cfg.type == "scatter" else alt.Undefined,
                 opacity=alt.condition(  # type: ignore
                     label_selection & fs_selection,
                     alt.Opacity("fewshot:O"),
                     alt.value(0.0),  # type: ignore
                 ),
            )
-            .properties(title=cfg.title(group_name, renorm))
+            .properties(title=cfg.title(group_name, renorm))  # type: ignore
            .add_params(fs_selection, label_selection)
            .interactive()
        )
@@ -102,9 +99,10 @@ class CdfPlotConfig(ABC):
     ylabel: str
     name: str = ""
     xline: Optional[float] = None
+    type: Literal["line", "scatter"] = "line"
 
     @abstractmethod
-    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
         pass
 
     def title(self, group_name: str, prob_renorm: bool) -> str:
@@ -119,25 +117,74 @@ class CdfPlotConfig(ABC):
 
 
 class CorrectProbCdfPlot(CdfPlotConfig):
-    name = "𝚽
+    name = "CDF(𝚽)"
     xlabel = "𝚽"
-    ylabel = "% of correct answers with 𝚽
+    ylabel = "% of correct answers with 𝚽 < x"
     xline = 0.25
 
-    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
         samples = [np.exp(e.cor_logprobs) for e in evals]
         if prob_renorm:
             samples = [renormed(e)[0] for e in evals]
-        return
+        return PlotData.perf_curve_from_samples(samples)
+
+
+class MaxIncorProbCdfPlot(CdfPlotConfig):
+    name = "CDF(Max(Incorrect))"
+    xlabel = "max(incorrect)"
+    ylabel = "% of correct answers with max(incorrect) < x"
+    xline = 0.25
+
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
+        if prob_renorm:
+            samples = [renormed(e)[1].max(axis=1) for e in evals]
+        else:
+            samples = [np.exp(np.max(e.inc_logprobs, axis=1)) for e in evals]
+        return PlotData.perf_curve_from_samples(samples)
+
+
+class AccVsLoss(CdfPlotConfig):
+    name = "Cross Entropy Loss vs Accuracy"
+    xlabel = "Accuracy"
+    ylabel = "CE Loss"
+    xline = None
+    type = "scatter"
+
+    def get_cdf(self, evals: List[Eval], _prob_renorm: bool) -> "PlotData":
+        cor, incor = zip(*[renormed(e) for e in evals])
+        cor = np.concatenate(cor)
+        incor = np.concatenate(incor).max(axis=1)
+        pct_corr = np.mean(cor > incor)
+
+        celoss = np.mean(-np.log(cor))
+        # PlotData's fields are (y, x): y is the CE loss, x the accuracy.
+        return PlotData(np.array([celoss]), np.array([pct_corr]))
+
+
+class AccVsAUC(CdfPlotConfig):
+    name = "Simulated AUROC vs Accuracy"
+    xlabel = "Accuracy"
+    ylabel = "Simulated AUROC"
+    xline = None
+    type = "scatter"
+
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
+        cor, incor = zip(*[renormed(e) for e in evals])
+        cor = np.concatenate(cor)
+        incor = np.concatenate(incor).max(axis=1)
+        pct_corr = np.mean(cor > incor)
+
+        scores, labels, weights = roc_data(evals, prob_renorm)
+        auc = roc_auc_score(labels, scores, sample_weight=weights)
+        return PlotData(np.array([auc]), np.array([pct_corr]))
 
 
 class CorrIncorrDiffConfig(CdfPlotConfig):
-    name = "𝚫
+    name = "CDF(𝚫)"
     xline = 0.0
     xlabel = "𝚫"
-    ylabel = "% of samples with 𝚫
+    ylabel = "% of samples with 𝚫 < x"
 
-    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
         score_arrs: List[NDArray[np.float64]] = []
         for e in evals:
             if prob_renorm:
@@ -148,4 +195,47 @@ class CorrIncorrDiffConfig(CdfPlotConfig):
 
             score_arrs.append(cor_probs - inc_probs.max(axis=1))
 
-        return
+        return PlotData.perf_curve_from_samples(score_arrs, per_sample_weighting=True)
+
+
+class ROCCurve(CdfPlotConfig):
+    name = "Simulated ROC Curve"
+    xline = None
+    xlabel = "FPR"
+    ylabel = "TPR"
+
+    def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
+        scores, labels, weights = roc_data(evals, prob_renorm)
+        assert len(scores) == len(labels) == len(weights)
+        # roc_curve returns (fpr, tpr, thresholds); resample onto a uniform grid.
+        fpr, tpr, _ = roc_curve(labels, scores, sample_weight=weights)
+
+        x_interp = np.linspace(0, 1, 600)
+        y_interp = np.interp(x_interp, fpr, tpr)
+
+        # PlotData's fields are (y, x): y is TPR, x the FPR grid.
+        return PlotData(y_interp, x_interp)
+
+
+def roc_data(evals: List[Eval], prob_renorm):
+    weight_arrs = []
+    total = sum(len(e.cor_logprobs) for e in evals)
+    for samples in evals:
+        this = np.ones(2 * len(samples.cor_logprobs)) / (2 * total)
+        weight_arrs.append(this)
+
+    score_arrs = []
+    label_arrs = []
+    for e in evals:
+        if prob_renorm:
+            cor_probs, inc_probs = renormed(e)
+        else:
+            cor_probs = np.exp(e.cor_logprobs)
+            inc_probs = np.exp(e.inc_logprobs)
+        score_arrs.append(cor_probs)
+        label_arrs.append(np.ones(len(cor_probs)))
+        score_arrs.append(inc_probs.max(axis=1))
+        label_arrs.append(np.zeros(inc_probs.shape[0]))
+
+    scores = np.concatenate(score_arrs)
+    labels = np.concatenate(label_arrs)
+    weights = np.concatenate(weight_arrs)
+    return scores, labels, weights
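For context on `roc_data` above: the "simulated" ROC treats each question's correct-answer probability as a positive-class score and its most confident incorrect-answer probability as a negative-class score; the weights only keep each eval contributing equally. A toy sketch, not part of the commit, using the same scikit-learn calls with made-up probabilities:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

# Made-up per-question probabilities for a single eval.
cor_probs = np.array([0.7, 0.4, 0.9])  # P(correct answer) per question -> positives
max_inc = np.array([0.2, 0.5, 0.05])   # max P(incorrect answer) per question -> negatives

scores = np.concatenate([cor_probs, max_inc])
labels = np.concatenate([np.ones_like(cor_probs), np.zeros_like(max_inc)])

fpr, tpr, _ = roc_curve(labels, scores)  # roc_curve returns (fpr, tpr, thresholds)
print(roc_auc_score(labels, scores))     # 8 of 9 positive/negative pairs ordered -> ~0.889
print(list(zip(fpr, tpr)))               # points tracing the simulated ROC curve
```

An AUC of 1.0 would mean every correct-answer probability exceeds every max-incorrect probability across the whole eval.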
src/rbeval/plot/utils.py
CHANGED
@@ -17,14 +17,17 @@ def renormed(eval: Eval) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
 
 
 @dataclass
-class CdfData:
-
-
+class PlotData:
+    y: np.ndarray
+    x: np.ndarray
 
     @classmethod
-    def
-        cls,
-
+    def perf_curve_from_samples(
+        cls,
+        samples: List[NDArray[np.float64]],
+        per_sample_weighting: bool = True,
+        one_minus: bool = False,
+    ) -> "PlotData":
         num_cats = len(samples)
         scores = np.concatenate(samples)
         if per_sample_weighting:
@@ -37,26 +40,29 @@ class CdfData:
             weights = np.concatenate(weight_arrs)
         else:
             weights = np.ones_like(scores) / len(scores)
-        return cls.
+        return cls.perf_curve_from_weights(weights, scores, one_minus=one_minus)
 
     @classmethod
-    def
+    def perf_curve_from_weights(
         cls,
         weights: NDArray[np.float64],
        base_scores: NDArray[np.float64],
        max_p: int = 600,
+        one_minus: bool = True,
+    ) -> "PlotData":
         sort_perm = base_scores.argsort()
         base_weights = weights[sort_perm]
         base_scores = base_scores[sort_perm]
-        base_cdf_p =
+        base_cdf_p = np.cumsum(base_weights)
+        if one_minus:
+            base_cdf_p = 1 - base_cdf_p
         minscore, maxscore = base_scores[0], base_scores[-1]
         if len(base_scores) > max_p:
             scores = np.linspace(minscore, maxscore, max_p)  # type: ignore
             cdf_p = np.interp(scores, base_scores, base_cdf_p)  # type: ignore
         else:
             scores, cdf_p = base_scores, base_cdf_p
-        return
-
-
+        return PlotData(
+            y=cdf_p,
+            x=scores,  # type: ignore
         )
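A usage sketch for the renamed `PlotData` helper above, not part of the commit: it assumes the post-commit `rbeval.plot.utils` is importable, and the exact per-category weighting sits in lines the diff does not show. The curve itself is an empirical CDF built by sorting the pooled scores and accumulating weights with `np.cumsum`, optionally flipped to 1 - CDF via `one_minus`.

```python
import numpy as np
from rbeval.plot.utils import PlotData  # assumes the post-commit module layout

a = np.array([0.1, 0.2, 0.9])
b = np.array([0.5, 0.6])  # a second category with fewer samples

curve = PlotData.perf_curve_from_samples([a, b], per_sample_weighting=True)
# curve.x: sorted score grid; curve.y: cumulative weight (the CDF) at each score
print(list(zip(curve.x.round(2), curve.y.round(2))))
```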