chanycha commited on
Commit
20171bc
1 Parent(s): cd20edc
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +5 -5
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .vscode
2
+ .tmp/res
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import time
3
  import os
4
  import gradio as gr
5
- # import spaces
6
  from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
7
  from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
8
  import matplotlib.pyplot as plt
@@ -343,7 +343,7 @@ class Experiment(Component):
343
  _plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
344
  return _plots
345
 
346
- # @spaces.GPU
347
  def render_plots(data_id, *metric_inputs):
348
  # Clear Cache Files
349
  # print(f"GPU Check: {torch.cuda.is_available()}")
@@ -485,7 +485,7 @@ class ExplainerCheckbox(Component):
485
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
486
  return metric_info[1][idx]
487
 
488
- # @spaces.GPU
489
  def optimize(self):
490
  # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
491
  # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
@@ -649,8 +649,8 @@ from torch.utils.data import DataLoader
649
  from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
650
 
651
  os.environ['GRADIO_TEMP_DIR'] = '.tmp'
652
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
653
- device = torch.device("cpu")
654
 
655
  def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
656
 
 
2
  import time
3
  import os
4
  import gradio as gr
5
+ import spaces
6
  from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
7
  from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
8
  import matplotlib.pyplot as plt
 
343
  _plots += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
344
  return _plots
345
 
346
+ @spaces.GPU
347
  def render_plots(data_id, *metric_inputs):
348
  # Clear Cache Files
349
  # print(f"GPU Check: {torch.cuda.is_available()}")
 
485
  idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
486
  return metric_info[1][idx]
487
 
488
+ @spaces.GPU
489
  def optimize(self):
490
  # if self.explainer_name in ["Lime", "KernelShap", "IntegratedGradients"]:
491
  # gr.Info("Lime, KernelShap and IntegratedGradients currently do not support hyperparameter optimization.")
 
649
  from helpers import get_imagenet_dataset, get_torchvision_model, denormalize_image
650
 
651
  os.environ['GRADIO_TEMP_DIR'] = '.tmp'
652
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
653
+ # device = torch.device("cpu")
654
 
655
  def target_visualizer(x): return dataset.dataset.idx_to_label(x.item())
656