William Arnold
commited on
Commit
·
3cb144f
1
Parent(s):
1e00fb8
Add yj score
Browse files- src/rbeval/plot/score_cdf.py +19 -0
src/rbeval/plot/score_cdf.py
CHANGED
@@ -19,6 +19,7 @@ def plot_cfgs():
|
|
19 |
#MaxIncorProbCdfPlot(),
|
20 |
AccVsLoss(),
|
21 |
AccVsAUC(),
|
|
|
22 |
]
|
23 |
|
24 |
|
@@ -219,6 +220,24 @@ class ROCCurve(CdfPlotConfig):
|
|
219 |
|
220 |
return PlotData(x_interp, y_interp)
|
221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
def roc_data(evals: List[Eval], prob_renorm):
|
224 |
weight_arrs = []
|
|
|
19 |
#MaxIncorProbCdfPlot(),
|
20 |
AccVsLoss(),
|
21 |
AccVsAUC(),
|
22 |
+
YjVsAcc(),
|
23 |
]
|
24 |
|
25 |
|
|
|
220 |
|
221 |
return PlotData(x_interp, y_interp)
|
222 |
|
223 |
+
class YjVsAcc(CdfPlotConfig):
|
224 |
+
name = "Yj vs Accuracy"
|
225 |
+
xline = None
|
226 |
+
xlabel = "Accuracy"
|
227 |
+
ylabel = "Yj"
|
228 |
+
type = "scatter"
|
229 |
+
|
230 |
+
def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
|
231 |
+
cor, incor = zip(*[renormed(e) for e in evals])
|
232 |
+
cor = np.concatenate(cor)
|
233 |
+
incor = np.concatenate(incor).max(axis=1)
|
234 |
+
delta = cor - incor
|
235 |
+
pos = delta[delta > 0].mean()
|
236 |
+
neg = -delta[delta <= 0].mean()
|
237 |
+
pct_corr = np.mean(cor > incor)
|
238 |
+
|
239 |
+
yj = pos - neg
|
240 |
+
return PlotData(np.array([yj]), np.array([pct_corr]))
|
241 |
|
242 |
def roc_data(evals: List[Eval], prob_renorm):
|
243 |
weight_arrs = []
|