chanycha commited on
Commit
5580778
1 Parent(s): 722ed82
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -484,12 +484,12 @@ class ExplainerCheckbox(Component):
484
  metric_info = self.experiment.manager.get_metrics()
485
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
486
  return metric_info[1][idx]
 
 
 
487
 
488
- # @spaces.GPU
489
  def optimize(self):
490
- # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
491
- # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
492
- # return [gr.update()] * 2
493
  data_id = self.gallery.selected_index
494
 
495
  opt_output = self.experiment.optimize(
@@ -501,13 +501,9 @@ class ExplainerCheckbox(Component):
501
  n_trials=OPT_N_TRIALS,
502
  )
503
 
504
-
505
- def get_str_ppid(pp_obj):
506
- return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
507
-
508
- str_id = get_str_ppid(opt_output.postprocessor)
509
  for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
510
- if get_str_ppid(pp_obj) == str_id:
511
  opt_postprocessor_id = pp_id
512
  break
513
 
 
484
  metric_info = self.experiment.manager.get_metrics()
485
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
486
  return metric_info[1][idx]
487
+
488
+ def get_str_ppid(self, pp_obj):
489
+ return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__
490
 
491
+ @spaces.GPU
492
  def optimize(self):
 
 
 
493
  data_id = self.gallery.selected_index
494
 
495
  opt_output = self.experiment.optimize(
 
501
  n_trials=OPT_N_TRIALS,
502
  )
503
 
504
+ str_id = self.get_str_ppid(opt_output.postprocessor)
 
 
 
 
505
  for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
506
+ if self.get_str_ppid(pp_obj) == str_id:
507
  opt_postprocessor_id = pp_id
508
  break
509