demo-gradcam-imagenet / gradio_gradcam.py
jaekookang
update launch methd
4b9bd3c
'''
Grad-CAM visualization demo
2021-12-18 first created
'''
from PIL import Image
import matplotlib.pyplot as plt
from PIL import Image
import os
import io
from glob import glob
from loguru import logger
import gradio as gr
from utils import (get_imagenet_classes, get_xception_model, get_img_4d_array,
make_gradcam_heatmap, align_image_with_heatmap)
# ----- Settings -----
# '-1' is not a valid device index, so CUDA exposes no GPU and the model runs
# on CPU.
# NOTE(review): this executes after `utils` has been imported; if that import
# already initialises CUDA the setting may come too late — confirm.
GPU_ID = '-1'
os.environ.update(CUDA_VISIBLE_DEVICES=GPU_ID)
EXAMPLE_DIR = 'examples'
CMAP_CHOICES = ['jet', 'rainbow', 'gist_ncar', 'autumn', 'hot', 'winter', 'hsv']
# One row per example image; columns follow the order of the Interface inputs:
# (image path, target class, heatmap alpha, colormap).
examples = [
    [image_path, 'French_bulldog', 0.3, 'jet']
    for image_path in sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))
]
# ----- Logging -----
# File sink opened in append mode, so earlier runs' entries are kept and each
# start-up is marked with a banner line.
logger.add(sink='app.log', mode='a')
logger.info('===== APP RESTARTED =====')
# ----- Model -----
# Built once at import time so every request reuses the same model objects.
model, grad_model, preprocessor, decode_predictions = get_xception_model()
idx2lab, lab2idx = get_imagenet_classes()
# Dropdown choices: the 'none' sentinel first, then every class label in
# alphabetical order. predict() maps 'none' to pred_idx=None.
classes = ['none', *sorted(lab2idx)]
def predict(image_obj, pred_class, alpha, cmap):
    """Overlay a Grad-CAM heatmap on an uploaded image.

    Args:
        image_obj: File object from the ``gr.inputs.Image(type='file')``
            component; only its ``.name`` (path on disk) is used.
        pred_class: ImageNet class label to explain, or ``'none'`` to pass
            ``pred_idx=None`` to ``make_gradcam_heatmap`` (per the UI label,
            the model's predicted class is then used).
        alpha: Heatmap opacity from the slider, forwarded to
            ``align_image_with_heatmap``.
        cmap: Colormap name chosen from the colormap dropdown.

    Returns:
        The blended image, resized to the uploaded image's width and height.

    Raises:
        ValueError: If no image was uploaded.
    """
    if image_obj is None:
        raise ValueError('No input image provided')
    image_file = image_obj.name
    logger.info(f'--- image loaded: class={pred_class} | alpha={alpha} | cmap={cmap}')
    # Only the original size is needed; the context manager releases the file
    # handle that Image.open otherwise keeps open.
    with Image.open(image_file) as img:
        width, height = img.size
    img_4d_array = preprocessor(get_img_4d_array(image_file))
    pred_idx = None if pred_class == 'none' else lab2idx[pred_class]
    heatmap = make_gradcam_heatmap(grad_model, img_4d_array, pred_idx=pred_idx)
    # Honour the user's alpha slider value instead of a fixed opacity.
    img_pil = align_image_with_heatmap(img_4d_array, heatmap, alpha=alpha, cmap=cmap)
    # The overlay is built from the preprocessed array, whose size presumably
    # differs from the upload; restore the original dimensions.
    img_pil = img_pil.resize((width, height))
    logger.info('--- Grad-CAM visualized')
    return img_pil
# Gradio UI around predict(). The inputs are listed in the same order as
# predict's parameters (image, class, alpha, cmap), which is also the column
# order of each row in `examples`.
iface = gr.Interface(
    predict,
    title='Gradient Class Activation Map (Grad-CAM) Visualization Demo',
    description='Provide an image with image class or just image alone. For all 1000 imagenet classes, see https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a',
    inputs=[
        # type='file' hands predict() a file object whose .name is a path.
        gr.inputs.Image(label='Input image', type='file'),
        gr.inputs.Dropdown(label='Predicted class (if "none", predicted class will be used)',
                           choices=classes, default='none', type='value'),
        gr.inputs.Slider(label='Output image alpha level for heatmap',
                         minimum=0, maximum=1, step=0.1, default=0.4),
        gr.inputs.Dropdown(label='Grad-CAM heatmap colormap',
                           choices=CMAP_CHOICES, default='jet', type='value'),
    ],
    outputs=[
        gr.outputs.Image(label='Output image', type='pil')
    ],
    examples=examples,
    article='<p style="text-align:center">Based on <a href="https://keras.io/examples/vision/grad_cam/">the example</a> written by <a href="https://twitter.com/fchollet">fchollet</a></p>',
)
# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) belong to
# the legacy Gradio API; keep the Gradio version pinned accordingly — confirm.
iface.launch(debug=True, enable_queue=True)