William Arnold commited on
Commit
3cb144f
·
1 Parent(s): 1e00fb8

Add yj score

Browse files
Files changed (1) hide show
  1. src/rbeval/plot/score_cdf.py +19 -0
src/rbeval/plot/score_cdf.py CHANGED
@@ -19,6 +19,7 @@ def plot_cfgs():
19
  #MaxIncorProbCdfPlot(),
20
  AccVsLoss(),
21
  AccVsAUC(),
 
22
  ]
23
 
24
 
@@ -219,6 +220,24 @@ class ROCCurve(CdfPlotConfig):
219
 
220
  return PlotData(x_interp, y_interp)
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  def roc_data(evals: List[Eval], prob_renorm):
224
  weight_arrs = []
 
19
  #MaxIncorProbCdfPlot(),
20
  AccVsLoss(),
21
  AccVsAUC(),
22
+ YjVsAcc(),
23
  ]
24
 
25
 
 
220
 
221
  return PlotData(x_interp, y_interp)
222
 
223
+ class YjVsAcc(CdfPlotConfig):
224
+ name = "Yj vs Accuracy"
225
+ xline = None
226
+ xlabel = "Accuracy"
227
+ ylabel = "Yj"
228
+ type = "scatter"
229
+
230
+ def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "PlotData":
231
+ cor, incor = zip(*[renormed(e) for e in evals])
232
+ cor = np.concatenate(cor)
233
+ incor = np.concatenate(incor).max(axis=1)
234
+ delta = cor - incor
235
+ pos = delta[delta > 0].mean()
236
+ neg = -delta[delta <= 0].mean()
237
+ pct_corr = np.mean(cor > incor)
238
+
239
+ yj = pos - neg
240
+ return PlotData(np.array([yj]), np.array([pct_corr]))
241
 
242
  def roc_data(evals: List[Eval], prob_renorm):
243
  weight_arrs = []