# python image_gradio.py >> ./logs/image_gradio.log 2>&1
"""Gradio demo of pnpxai auto-explanation for image classification.

Running (or importing) this module builds two experiments (ResNet-18 and
ViT-B/16 on the dataset returned by ``helpers.get_imagenet_dataset``) and
launches a Gradio app with three tabs: Overview, Detection (model
architecture graph) and Local Explanation (per-image attribution heatmaps,
evaluation metrics and optional hyperparameter optimization).
"""
import os
import secrets
import time  # NOTE(review): not referenced in this module

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import spaces
import torch
from pnpxai.core.detector.detector import extract_graph_data, symbolic_trace
from pnpxai.core.experiment.auto_explanation import AutoExplanationForImageClassification
from torch.utils.data import DataLoader

from helpers import denormalize_image, get_imagenet_dataset, get_torchvision_model

PLOT_PER_LINE = 4  # widgets (result plots / explainer accordions) per UI row
N_FEATURES_TO_SHOW = 5  # NOTE(review): not referenced in this module
OPT_N_TRIALS = 10  # trials per hyperparameter-optimization run
OBJECTIVE_METRIC = "AbPC"  # metric class name maximized by the "Optimize" button
SAMPLE_METHOD = "tpe"  # sampler name forwarded to experiment.optimize
# Explainers that start checked (and whose accordions start open).
DEFAULT_EXPLAINER = ["GradientXInput", "IntegratedGradients", "LRPEpsilonPlus"]


def _ceil_div(numerator, denominator):
    """Return ``ceil(numerator / denominator)`` for non-negative integers."""
    return -(-numerator // denominator)


class App:
    """Base class for Gradio applications."""

    def __init__(self):
        pass


class Component:
    """Base class for UI pieces that render themselves through ``show()``."""

    def __init__(self):
        pass


class Tab(Component):
    """Base class for top-level tabs."""

    def __init__(self):
        pass


class OverviewTab(Tab):
    """Tab that displays a static HTML description."""

    def __init__(self):
        pass

    def show(self):
        """Render the overview tab inside the current Blocks context."""
        with gr.Tab(label="Overview"):
            gr.Label("This is the overview tab.")
            gr.HTML(self.desc())

    def desc(self):
        """Return the contents of ``static/overview.html`` (path relative to the CWD)."""
        with open("static/overview.html", "r", encoding="utf-8") as f:
            return f.read()


class DetectionTab(Tab):
    """Tab showing, for every experiment, the traced model architecture."""

    def __init__(self, experiments):
        # experiments: dict of name -> exp_info dict (see the script section below).
        self.experiments = experiments

    def show(self):
        """Render one ``DetectorRes`` per experiment."""
        with gr.Tab(label="Detection"):
            gr.Label("This is the detection tab.")
            for exp_info in self.experiments.values():
                DetectorRes(exp_info['experiment']).show()


class LocalExpTab(Tab):
    """Tab with the per-image (local) explanation UI of every experiment."""

    def __init__(self, experiments):
        self.experiments = experiments
        self.experiment_components = [Experiment(exp_info) for exp_info in self.experiments.values()]

    def description(self):
        return "This tab shows the local explanation."

    def show(self):
        """Render the tab and every experiment component, in dict order."""
        with gr.Tab(label="Local Explanation"):
            gr.Label("This is the local explanation tab.")
            for component in self.experiment_components:
                component.show()


class DetectorRes(Component):
    """Draws the graph extracted from a symbolically traced model."""

    def __init__(self, experiment):
        self.experiment = experiment
        graph_module = symbolic_trace(experiment.model)
        # Dict with 'nodes' (each with 'name' and 'op') and 'edges' (each with 'source' and 'target'),
        # as consumed in show().
        self.graph_data = extract_graph_data(graph_module)

    def describe(self):
        return "This component shows the detection result."

    def show(self):
        """Render task/model textboxes and a matplotlib plot of the model graph."""
        G = nx.DiGraph()
        root = None  # last 'placeholder' node; only needed by the alternative layout get_pos2
        for node in self.graph_data['nodes']:
            if node['op'] == 'placeholder':
                root = node['name']
            G.add_node(node['name'])

        # Skip edges whose endpoints were not registered as nodes.
        for edge in self.graph_data['edges']:
            if edge['source'] in G.nodes and edge['target'] in G.nodes:
                G.add_edge(edge['source'], edge['target'])

        def get_pos1(graph):
            """Layered layout: one horizontal layer per topological generation (inputs on top)."""
            graph = graph.copy()
            for layer, nodes in enumerate(reversed(tuple(nx.topological_generations(graph)))):
                for node in nodes:
                    graph.nodes[node]["layer"] = layer
            return nx.multipartite_layout(graph, subset_key="layer", align='horizontal')

        def get_pos2(graph, root, levels=None, width=1., height=1.):
            '''
            graph: the graph
            root: the root node
            levels: a dictionary
                    key: level number (starting from 0)
                    value: number of nodes in this level
            width: horizontal space allocated for drawing
            height: vertical space allocated for drawing
            '''
            TOTAL = "total"
            CURRENT = "current"

            def make_levels(levels, node=root, current_level=0, parent=None):
                """Count the number of nodes on each level (depth-first from ``node``)."""
                if current_level not in levels:
                    levels[current_level] = {TOTAL: 0, CURRENT: 0}
                levels[current_level][TOTAL] += 1
                for neighbor in graph.neighbors(node):
                    if neighbor != parent:
                        levels = make_levels(levels, neighbor, current_level + 1, node)
                return levels

            def make_pos(pos, node=root, current_level=0, parent=None, vert_loc=0):
                """Assign evenly spaced x positions per level, descending by ``vert_gap``."""
                dx = 1 / levels[current_level][TOTAL]
                left = dx / 2
                pos[node] = ((left + dx * levels[current_level][CURRENT]) * width, vert_loc)
                levels[current_level][CURRENT] += 1
                for neighbor in graph.neighbors(node):
                    if neighbor != parent:
                        pos = make_pos(pos, neighbor, current_level + 1, node, vert_loc - vert_gap)
                return pos

            if levels is None:
                levels = make_levels({})
            else:
                levels = {lvl: {TOTAL: levels[lvl], CURRENT: 0} for lvl in levels}
            vert_gap = height / (max(levels) + 1)
            return make_pos({})

        def plot_graph(graph, pos):
            """Draw ``graph`` at positions ``pos`` on a new 12x24-inch figure."""
            fig = plt.figure(figsize=(12, 24))
            ax = fig.gca()
            nx.draw(graph, pos=pos, with_labels=True, node_size=60, font_size=8, ax=ax)
            fig.tight_layout()
            return fig

        fig = plot_graph(G, get_pos1(G))
        # Alternative tree layout: fig = plot_graph(G, get_pos2(G, root))

        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")

        gr.Plot(value=fig, label=f"Model Architecture of {self.experiment.model.__class__.__name__}", visible=True)


class ImgGallery(Component):
    """Image gallery whose clicked index is mirrored into a hidden ``gr.Number``."""

    def __init__(self, imgs):
        self.imgs = imgs
        # Created here so other components can use it as an event input.
        self.selected_index = gr.Number(value=0, label="Selected Index", visible=False)

    def on_select(self, evt: gr.SelectData):
        return evt.index

    def show(self):
        self.gallery_obj = gr.Gallery(value=self.imgs, label="Input Data Gallery", columns=6, height=200)
        self.gallery_obj.select(self.on_select, outputs=self.selected_index)


class Experiment(Component):
    """Local-explanation UI for a single experiment.

    ``exp_info`` must provide the keys 'experiment', 'input_visualizer'
    (input tensor -> displayable image) and 'target_visualizer'
    (label tensor -> label text).
    """

    def __init__(self, exp_info):
        self.exp_info = exp_info
        self.experiment = exp_info['experiment']
        self.input_visualizer = exp_info['input_visualizer']
        self.target_visualizer = exp_info['target_visualizer']

    def viz_input(self, input, data_id):
        """Return a plotly figure of ``input`` titled with ``data_id``, without axis ticks/grid."""
        orig_img_np = self.input_visualizer(input)
        orig_img = px.imshow(orig_img_np)
        orig_img.update_layout(
            title=f"Data ID: {data_id}",
            width=400,
            height=350,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
        )
        return orig_img

    def get_prediction(self, record, topk=3):
        """Format the ground-truth label and the ``topk`` softmax predictions as text."""
        # .cpu() so this also works when the model ran on a CUDA device.
        probs = record['output'].softmax(-1).squeeze().detach().cpu().numpy()
        text = f"Ground Truth Label: {self.target_visualizer(record['label'])}\n"
        for rank, pred in enumerate(probs.argsort()[-topk:][::-1], start=1):
            label = self.target_visualizer(torch.tensor(pred))
            text += f"Top {rank} Prediction: {label} ({probs[pred]:.2f})\n"
        return text

    def get_exp_plot(self, data_index, exp_res):
        """Render one explanation to a PNG and return its path."""
        return ExpRes(data_index, exp_res).show()

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric whose class name is ``metric_name``.

        Raises ValueError if no such metric is registered.
        """
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    def generate_record(self, data_id, metric_names):
        """Run every checked explainer (and the requested metrics) on one data point.

        Returns a dict with the batch 'input', 'label', 'output', 'target' and an
        'explanations' list. When at least one explanation has evaluations, the
        list is sorted by the first metric's value, highest first.
        """
        # NOTE(review): run_batch positional args appear to be
        # (data ids, explainer id, postprocessor id, metric id) — confirm in pnpxai.
        base = self.experiment.run_batch([data_id], 0, 0, 0)
        record = {
            'data_id': data_id,
            'input': base['inputs'],
            'label': base['labels'],
            'output': base['outputs'],
            'target': base['targets'],
            'explanations': [],
        }

        metric_ids = [self.get_metric_id_by_name(metric_nm) for metric_nm in metric_names]
        for info in self.explainer_checkbox_group.info:
            if not info['checked']:
                continue
            exp_base = self.experiment.run_batch([data_id], info['id'], info['pp_id'], 0)
            explanation = {
                'explainer_nm': exp_base['explainer'].__class__.__name__,
                'value': exp_base['postprocessed'],
                'mode': info['mode'],
                'evaluations': [],
            }
            record['explanations'].append(explanation)
            for metric_id in metric_ids:
                res = self.experiment.run_batch([data_id], info['id'], info['pp_id'], metric_id)
                explanation['evaluations'].append({
                    'metric_nm': res['metric'].__class__.__name__,
                    'value': res['evaluation'],
                })

        # Guard against an empty list: the user may uncheck every explainer.
        explanations = record['explanations']
        if explanations and explanations[0]['evaluations']:
            record['explanations'] = sorted(
                explanations, key=lambda x: x['evaluations'][0]['value'], reverse=True
            )
        return record

    def show(self):
        """Build the gallery, explainer/metric selectors, Explain button and result grid."""
        with gr.Row():
            gr.Textbox(value="Image Classification", label="Task")
            gr.Textbox(value=f"{self.experiment.model.__class__.__name__}", label="Model")
            gr.Textbox(value="Heatmap", label="Explanation Type")

        dset = self.experiment.manager._data.dataset
        imgs = [self.input_visualizer(dset[i][0]) for i in range(len(dset))]
        gallery = ImgGallery(imgs)
        gallery.show()

        explainers, _ = self.experiment.manager.get_explainers()
        explainer_names = [exp.__class__.__name__ for exp in explainers]

        self.explainer_checkbox_group = ExplainerCheckboxGroup(explainer_names, self.experiment, gallery)
        self.explainer_checkbox_group.show()

        cr_metrics_names = ["AbPC", "MoRF", "LeRF", "MuFidelity"]
        cn_metrics_names = ["Sensitivity"]
        cp_metrics_names = ["Complexity"]
        with gr.Accordion("Evaluators", open=True):
            with gr.Row():
                cr_metrics = gr.CheckboxGroup(choices=cr_metrics_names, value=[cr_metrics_names[0]], label="Correctness")

                def on_select(metrics):
                    # generate_record sorts by the first requested metric, so AbPC is forced back in.
                    if cr_metrics_names[0] not in metrics:
                        gr.Warning(f"{cr_metrics_names[0]} is required for sorting the explanations.")
                        return [cr_metrics_names[0]] + metrics
                    return metrics

                cr_metrics.select(on_select, inputs=cr_metrics, outputs=cr_metrics)
            with gr.Row():
                cn_metrics = gr.CheckboxGroup(choices=cn_metrics_names, label="Continuity")
            with gr.Row():
                cp_metrics = gr.CheckboxGroup(choices=cp_metrics_names, label="Compactness")

        metric_inputs = [cr_metrics, cn_metrics, cp_metrics]

        data_id = gallery.selected_index
        bttn = gr.Button("Explain", variant="primary")

        # Pre-allocate hidden image slots: up to two results (default + optimized) per explainer.
        buffer_size = 2 * len(explainer_names)
        buffer_n_rows = _ceil_div(buffer_size, PLOT_PER_LINE)

        plots = [gr.Textbox(label="Prediction result", visible=False)]
        for _ in range(buffer_n_rows):
            with gr.Row():
                for _ in range(PLOT_PER_LINE):
                    plots.append(gr.Image(value=None, label="Blank", visible=False))

        def n_checked():
            """Number of currently checked explainer entries."""
            return sum(1 for info in self.explainer_checkbox_group.info if info['checked'])

        def show_plots():
            """Immediately reveal blank placeholders for the rows about to be filled."""
            n_rows = _ceil_div(n_checked(), PLOT_PER_LINE)
            updates = [gr.Textbox(label="Prediction result", visible=False)]
            updates += [gr.Image(value=None, label="Blank", visible=True)] * (n_rows * PLOT_PER_LINE)
            updates += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
            return updates

        @spaces.GPU
        def render_plots(data_id, *metric_inputs):
            """Compute explanations for ``data_id`` and fill the result grid."""
            cuda_available = torch.cuda.is_available()
            print(f"GPU Check: {cuda_available}")
            # current_device() raises when CUDA is unavailable, so only query it on GPU hosts.
            if cuda_available:
                print("Which GPU: ", torch.cuda.current_device())

            # Clear cached result images; ExpRes.show names them secrets.token_hex(8) (16 chars).
            cache_dir = f"{os.environ['GRADIO_TEMP_DIR']}/res"
            os.makedirs(cache_dir, exist_ok=True)
            for fname in os.listdir(cache_dir):
                if len(fname.split(".")[0]) == 16:
                    os.remove(os.path.join(cache_dir, fname))

            # Flatten the three checkbox groups (unselected groups may be empty/None).
            metric_names = [name for group in metric_inputs if group for name in group]
            record = self.generate_record(data_id, metric_names)
            pred = self.get_prediction(record)

            outputs = [gr.Textbox(label="Prediction result", value=pred, visible=True)]
            n_rows = _ceil_div(n_checked(), PLOT_PER_LINE)
            explanations = record['explanations']
            for slot in range(n_rows * PLOT_PER_LINE):
                if slot < len(explanations):
                    exp_res = explanations[slot]
                    path = self.get_exp_plot(data_id, exp_res)
                    outputs.append(gr.Image(value=path, label=f"{exp_res['explainer_nm']} ({exp_res['mode']})", visible=True))
                else:
                    outputs.append(gr.Image(value=None, label="Blank", visible=True))
            outputs += [gr.Image(value=None, label="Blank", visible=False)] * ((buffer_n_rows - n_rows) * PLOT_PER_LINE)
            return outputs

        bttn.click(show_plots, outputs=plots)
        bttn.click(render_plots, inputs=[data_id] + metric_inputs, outputs=plots)


class ExplainerCheckboxGroup(Component):
    """Tracks which explainer configurations are checked and renders their controls.

    Each ``info`` entry is a dict with keys 'nm' (explainer class name), 'id'
    (manager explainer id), 'pp_id' (postprocessor id), 'mode'
    ('default' or 'optimal') and 'checked'.
    """

    def __init__(self, explainer_names, experiment, gallery):
        super().__init__()
        self.explainer_names = explainer_names
        self.explainer_objs = []
        self.experiment = experiment
        self.gallery = gallery
        explainers, exp_ids = self.experiment.manager.get_explainers()
        self.info = []
        for exp, exp_id in zip(explainers, exp_ids):
            exp_nm = exp.__class__.__name__
            self.info.append({
                'nm': exp_nm,
                'id': exp_id,
                'pp_id': 0,
                'mode': 'default',
                'checked': exp_nm in DEFAULT_EXPLAINER,
            })

    def update_check(self, exp_id, val=None):
        """Set (or toggle, when ``val`` is None) the checked state of entries with ``exp_id``."""
        for info in self.info:
            if info['id'] == exp_id:
                info['checked'] = (not info['checked']) if val is None else val

    def insert_check(self, exp_nm, exp_id, pp_id):
        """Register an unchecked 'optimal' entry; no-op if ``exp_id`` already exists."""
        if any(info['id'] == exp_id for info in self.info):
            return
        self.info.append({'nm': exp_nm, 'id': exp_id, 'pp_id': pp_id, 'mode': 'optimal', 'checked': False})

    def update_gallery_change(self):
        """Reset all checkboxes/buttons when a new gallery image is selected.

        Returns component updates in ``get_checkboxes() + get_bttns()`` order.
        """
        checkboxes = [
            gr.Checkbox(label="Default Parameter", value=exp.explainer_name in DEFAULT_EXPLAINER, interactive=True)
            for exp in self.explainer_objs
        ]
        checkboxes += [gr.Checkbox(label="Optimized Parameter (Not Optimal)", value=False, interactive=False)] * len(self.explainer_objs)
        bttns = [gr.Button(value="Optimize", size="sm", variant="primary")] * len(self.explainer_objs)

        # Mirror the reset in the tracked state; previous optimal results no longer apply.
        for exp in self.explainer_objs:
            self.update_check(exp.default_exp_id, exp.explainer_name in DEFAULT_EXPLAINER)
            if hasattr(exp, "optimal_exp_id"):
                self.update_check(exp.optimal_exp_id, False)
        return checkboxes + bttns

    def get_checkboxes(self):
        """All default-parameter checkboxes followed by all optimized-parameter checkboxes."""
        return [exp.default_check for exp in self.explainer_objs] + [exp.opt_check for exp in self.explainer_objs]

    def get_bttns(self):
        return [exp.bttn for exp in self.explainer_objs]

    def show(self):
        """Render one ExplainerCheckbox per explainer, defaults first, PLOT_PER_LINE per row."""
        sorted_info = sorted(self.info, key=lambda x: (x['nm'] not in DEFAULT_EXPLAINER, x['nm']))
        with gr.Accordion("Explainers", open=True):
            for start in range(0, len(self.explainer_names), PLOT_PER_LINE):
                with gr.Row():
                    for info in sorted_info[start:start + PLOT_PER_LINE]:
                        explainer_obj = ExplainerCheckbox(info['nm'], self, self.experiment, self.gallery)
                        self.explainer_objs.append(explainer_obj)
                        explainer_obj.show()

        checkboxes = self.get_checkboxes()
        bttns = self.get_bttns()
        self.gallery.gallery_obj.select(fn=self.update_gallery_change, outputs=checkboxes + bttns)


class ExplainerCheckbox(Component):
    """Controls for one explainer: default/optimized checkboxes and an Optimize button."""

    def __init__(self, explainer_name, groups, experiment, gallery):
        self.explainer_name = explainer_name
        self.groups = groups
        self.experiment = experiment
        self.gallery = gallery
        self.default_exp_id = self.get_explainer_id_by_name(explainer_name)
        self.obj_metric = self.get_metric_id_by_name(OBJECTIVE_METRIC)

    def get_explainer_id_by_name(self, explainer_name):
        """Return the manager id of the first explainer with class name ``explainer_name``."""
        explainer_info = self.experiment.manager.get_explainers()
        idx = [exp.__class__.__name__ for exp in explainer_info[0]].index(explainer_name)
        return explainer_info[1][idx]

    def get_metric_id_by_name(self, metric_name):
        """Return the manager id of the first metric with class name ``metric_name``."""
        metric_info = self.experiment.manager.get_metrics()
        idx = [metric.__class__.__name__ for metric in metric_info[0]].index(metric_name)
        return metric_info[1][idx]

    @spaces.GPU
    def optimize(self, data_id=None):
        """Optimize this explainer's hyperparameters for the selected image.

        ``data_id`` is the gallery's selected index, supplied by the click event.
        Reading ``selected_index.value`` server-side only yields the component's
        initial value (0), so that is used purely as a fallback.
        Registers the optimized explainer with the manager and returns updates
        for [optimized checkbox, optimize button].
        """
        if data_id is None:
            data_id = self.gallery.selected_index.value
        opt_output = self.experiment.optimize(
            data_ids=int(data_id),
            explainer_id=self.default_exp_id,
            metric_id=self.obj_metric,
            direction='maximize',
            sampler=SAMPLE_METHOD,
            n_trials=OPT_N_TRIALS,
        )

        def get_str_ppid(pp_obj):
            # Postprocessors are matched by their pooling + normalization class names.
            return pp_obj.pooling_fn.__class__.__name__ + pp_obj.normalization_fn.__class__.__name__

        target_key = get_str_ppid(opt_output.postprocessor)
        for pp_obj, pp_id in zip(*self.experiment.manager.get_postprocessors()):
            if get_str_ppid(pp_obj) == target_key:
                opt_postprocessor_id = pp_id
                break
        else:
            raise gr.Error(f"No registered postprocessor matches the optimized one ({target_key}).")

        opt_explainer_id = max(x['id'] for x in self.groups.info) + 1
        opt_output.explainer.model = self.experiment.model
        self.experiment.manager._explainers.append(opt_output.explainer)
        self.experiment.manager._explainer_ids.append(opt_explainer_id)
        self.groups.insert_check(self.explainer_name, opt_explainer_id, opt_postprocessor_id)
        self.optimal_exp_id = opt_explainer_id

        checkbox = gr.update(label="Optimized Parameter (Optimal)", interactive=True)
        bttn = gr.update(value="Optimized", variant="secondary")
        return [checkbox, bttn]

    def default_on_select(self, evt: gr.EventData):
        self.groups.update_check(self.default_exp_id, evt._data['value'])

    def optimal_on_select(self, evt: gr.EventData):
        if not hasattr(self, "optimal_exp_id"):
            raise ValueError("Optimal explainer id is not found.")
        self.groups.update_check(self.optimal_exp_id, evt._data['value'])

    def show(self):
        """Render this explainer's accordion (open for default explainers)."""
        val = self.explainer_name in DEFAULT_EXPLAINER
        with gr.Accordion(self.explainer_name, open=val):
            checked = next(filter(lambda x: x['nm'] == self.explainer_name, self.groups.info))['checked']
            self.default_check = gr.Checkbox(label="Default Parameter", value=checked, interactive=True)
            self.opt_check = gr.Checkbox(label="Optimized Parameter (Not Optimal)", interactive=False)

            self.default_check.select(self.default_on_select)
            self.opt_check.select(self.optimal_on_select)

            self.bttn = gr.Button(value="Optimize", size="sm", variant="primary")
            # Pass the selected index as an event input so the session's current value is used.
            self.bttn.click(
                self.optimize,
                inputs=[self.gallery.selected_index],
                outputs=[self.opt_check, self.bttn],
                queue=True,
                concurrency_limit=1,
            )


class ExpRes(Component):
    """Renders one explanation as a heatmap PNG annotated with its metric values."""

    def __init__(self, data_index, exp_res):
        self.data_index = data_index
        self.exp_res = exp_res

    def show(self):
        """Write the figure to ``$GRADIO_TEMP_DIR/res/<16 hex chars>.png`` and return that path."""
        value = self.exp_res['value']
        fig = go.Figure(data=go.Heatmap(
            # .cpu() so CUDA tensors can be converted; flipud because heatmap rows start at the bottom.
            z=np.flipud(value[0].detach().cpu().numpy()),
            colorscale='Reds',
            showscale=False,  # remove color bar
        ))

        metric_values = [
            f"{ev['metric_nm'][:4]}: {ev['value'].item():.2f}"
            for ev in self.exp_res['evaluations']
            if ev['value'] is not None
        ]
        per_line = 3
        n_lines = _ceil_div(len(metric_values), per_line)
        for line in range(n_lines):
            fig.add_annotation(
                x=0,
                y=-0.1 * (line + 1),
                xref='paper',
                yref='paper',
                text=', '.join(metric_values[line * per_line:(line + 1) * per_line]),
                showarrow=False,
                font=dict(size=18),
            )

        fig = fig.update_layout(
            width=380,
            height=400,
            xaxis=dict(showticklabels=False, ticks='', showgrid=False),
            yaxis=dict(showticklabels=False, ticks='', showgrid=False),
            margin=dict(t=40, b=40 * n_lines, l=20, r=20),  # room for the annotation lines
        )

        # Random unique file name; render_plots relies on the 16-character stem to clean up.
        root = f"{os.environ['GRADIO_TEMP_DIR']}/res"
        os.makedirs(root, exist_ok=True)
        path = f"{root}/{secrets.token_hex(8)}.png"
        fig.write_image(path)
        return path


class ImageClsApp(App):
    """Top-level app assembling the Overview, Detection and Local Explanation tabs."""

    def __init__(self, experiments, **kwargs):
        self.name = "Image Classification App"
        super().__init__(**kwargs)
        self.experiments = experiments
        self.overview_tab = OverviewTab()
        self.detection_tab = DetectionTab(self.experiments)
        self.local_exp_tab = LocalExpTab(self.experiments)

    def title(self):
        return """

Plug and Play XAI Platform for Image Classification

"""

    def launch(self, **kwargs):
        """Build and return the ``gr.Blocks`` demo (the caller launches it). ``kwargs`` are unused."""
        with gr.Blocks(title=self.name) as demo:
            file_path = os.path.dirname(os.path.abspath(__file__))
            gr.set_static_paths(file_path)
            gr.HTML(self.title())
            self.overview_tab.show()
            self.detection_tab.show()
            self.local_exp_tab.show()
        return demo


# --- Script section: runs at import time (there is no __main__ guard). ---
os.environ['GRADIO_TEMP_DIR'] = '.tmp'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def target_visualizer(x):
    """Map a single-element label tensor to its label text via the current module-level ``dataset``."""
    return dataset.dataset.idx_to_label(x.item())


experiments = {}

model, transform = get_torchvision_model('resnet18')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment1 = AutoExplanationForImageClassification(
    model=model.to(device),
    data=loader,
    input_extractor=lambda batch: batch[0].to(device),
    label_extractor=lambda batch: batch[-1].to(device),
    target_extractor=lambda outputs: outputs.argmax(-1).to(device),
    channel_dim=1,
)
experiments['experiment1'] = {
    'name': 'ResNet18',
    'experiment': experiment1,
    # Bind this experiment's transform now; a plain closure would see the
    # ``transform`` reassigned for ViT below.
    'input_visualizer': lambda x, t=transform: denormalize_image(x, t.mean, t.std),
    'target_visualizer': target_visualizer,
}

model, transform = get_torchvision_model('vit_b_16')
dataset = get_imagenet_dataset(transform)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
experiment2 = AutoExplanationForImageClassification(
    model=model.to(device),
    data=loader,
    input_extractor=lambda batch: batch[0].to(device),
    label_extractor=lambda batch: batch[-1].to(device),
    target_extractor=lambda outputs: outputs.argmax(-1).to(device),
    channel_dim=1,
)
experiments['experiment2'] = {
    'name': 'ViT-B_16',
    'experiment': experiment2,
    'input_visualizer': lambda x, t=transform: denormalize_image(x, t.mean, t.std),
    'target_visualizer': target_visualizer,
}

app = ImageClsApp(experiments)
demo = app.launch()
demo.launch(favicon_path="static/XAI-Top-PnP.svg")