'''
Grad-CAM visualization demo

2021-12-18 first created
'''
from PIL import Image
import matplotlib.pyplot as plt
import os
import io
from glob import glob

from loguru import logger
import gradio as gr

from utils import (get_imagenet_classes, get_xception_model, get_img_4d_array,
                   make_gradcam_heatmap, align_image_with_heatmap)

# ----- Settings -----
# '-1' hides all GPUs from the process.
# NOTE(review): this runs after `utils` is imported; if `utils` initializes the
# GPU device at import time, this setting may not take effect — confirm.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

EXAMPLE_DIR = 'examples'
CMAP_CHOICES = ['jet', 'rainbow', 'gist_ncar', 'autumn', 'hot', 'winter', 'hsv']

# Each example row matches the Interface inputs: image, class, alpha, cmap.
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))
examples = [[image, 'French_bulldog', 0.3, 'jet'] for image in examples]

# ----- Logging -----
logger.add('app.log', mode='a')
logger.info('===== APP RESTARTED =====')

# ----- Model -----
model, grad_model, preprocessor, decode_predictions = get_xception_model()
idx2lab, lab2idx = get_imagenet_classes()
# 'none' is a sentinel meaning "use the model's own predicted class".
classes = ['none'] + sorted(list(lab2idx.keys()))


def predict(image_obj, pred_class, alpha, cmap):
    """Overlay a Grad-CAM heatmap on the uploaded image.

    Args:
        image_obj: file object from the Gradio image input (type='file');
            only its ``.name`` path is used.
        pred_class: ImageNet label to explain, or 'none' to pass
            ``pred_idx=None`` so the predicted class is used.
        alpha: heatmap opacity chosen with the slider, forwarded to
            ``align_image_with_heatmap``.
        cmap: name of the colormap used to render the heatmap.

    Returns:
        PIL image with the heatmap overlay, resized to the input image's
        original width and height.
    """
    image_file = image_obj.name
    logger.info(f'--- image loaded: class={pred_class} | alpha={alpha} | cmap={cmap}')

    # Only the original dimensions are needed; close the file handle promptly.
    with Image.open(image_file) as img:
        width, height = img.size

    img_4d_array = get_img_4d_array(image_file)
    img_4d_array = preprocessor(img_4d_array)

    pred_idx = None if pred_class == 'none' else lab2idx[pred_class]

    heatmap = make_gradcam_heatmap(grad_model, img_4d_array, pred_idx=pred_idx)
    # Forward the user's slider value (previously hard-coded to 0.3, which
    # made the alpha slider a no-op).
    img_pil = align_image_with_heatmap(img_4d_array, heatmap, alpha=alpha, cmap=cmap)
    # The overlay is produced at model input resolution; restore original size.
    img_pil = img_pil.resize((width, height))
    logger.info('--- Grad-CAM visualized')
    return img_pil


iface = gr.Interface(
    predict,
    title='Gradient Class Activation Map (Grad-CAM) Visualization Demo',
    description=('Provide an image with image class or just image alone. \n'
                 'For all 1000 imagenet classes, see '
                 'https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a'),
    inputs=[
        gr.inputs.Image(label='Input image', type='file'),
        gr.inputs.Dropdown(label='Predicted class (if "none", predicted class will be used)',
                           choices=classes, default='none', type='value'),
        gr.inputs.Slider(label='Output image alpha level for heatmap',
                         minimum=0, maximum=1, step=0.1, default=0.4),
        gr.inputs.Dropdown(label='Grad-CAM heatmap colormap',
                           choices=CMAP_CHOICES, default='jet', type='value'),
    ],
    outputs=[
        gr.outputs.Image(label='Output image', type='pil')
    ],
    examples=examples,
    article='\nBased on the example written by fchollet\n',
)
iface.launch(debug=True, enable_queue=True)