Spaces:
Runtime error
Runtime error
'''
Grad-CAM visualization demo
2021-12-18 first created
'''
from PIL import Image | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import os | |
import io | |
from glob import glob | |
from loguru import logger | |
import gradio as gr | |
from utils import (get_imagenet_classes, get_xception_model, get_img_4d_array, | |
make_gradcam_heatmap, align_image_with_heatmap) | |
# ----- Settings -----
# GPU selection is exported through CUDA_VISIBLE_DEVICES before the model is
# built further down; '-1' is presumably meant to force CPU-only execution.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

EXAMPLE_DIR = 'examples'
CMAP_CHOICES = [
    'jet',
    'rainbow',
    'gist_ncar',
    'autumn',
    'hot',
    'winter',
    'hsv',
]

# One example row per bundled JPEG, in sorted path order. Each row follows the
# argument order of `predict`: [image path, class label, alpha, colormap].
examples = [
    [image_path, 'French_bulldog', 0.3, 'jet']
    for image_path in sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))
]

# ----- Logging -----
# Append to app.log so earlier runs are kept; the banner marks each restart.
logger.add('app.log', mode='a')
logger.info('===== APP RESTARTED =====')

# ----- Model -----
model, grad_model, preprocessor, decode_predictions = get_xception_model()
idx2lab, lab2idx = get_imagenet_classes()

# Dropdown choices: the 'none' sentinel first, then every known label sorted.
classes = ['none', *sorted(lab2idx.keys())]
def predict(image_obj, pred_class, alpha, cmap):
    '''Build a Grad-CAM overlay for an uploaded image.

    Args:
        image_obj: Uploaded file object; only its ``name`` attribute (the
            path of the image on disk) is read.
        pred_class: Class label to visualize. The sentinel ``'none'`` passes
            ``pred_idx=None`` to ``make_gradcam_heatmap``; any other value is
            looked up in ``lab2idx``.
        alpha: Heatmap blending level, forwarded to
            ``align_image_with_heatmap``.
        cmap: Colormap name, forwarded to ``align_image_with_heatmap``.

    Returns:
        The image produced by ``align_image_with_heatmap``, resized to the
        width and height of the input image.

    Raises:
        KeyError: If ``pred_class`` is neither ``'none'`` nor a key of
            ``lab2idx``.
    '''
    image_file = image_obj.name
    logger.info(f'--- image loaded: class={pred_class} | alpha={alpha} | cmap={cmap}')

    # Only the original dimensions are needed here, so release the file
    # handle immediately instead of leaving it open.
    with Image.open(image_file) as img:
        width, height = img.size

    img_4d_array = preprocessor(get_img_4d_array(image_file))

    pred_idx = None if pred_class == 'none' else lab2idx[pred_class]
    heatmap = make_gradcam_heatmap(grad_model, img_4d_array, pred_idx=pred_idx)

    # Forward the caller's alpha so the UI slider actually controls the
    # overlay strength (a fixed 0.3 would silently ignore it).
    img_pil = align_image_with_heatmap(img_4d_array, heatmap, alpha=alpha, cmap=cmap)

    # Scale the overlay back to the input image's original size.
    img_pil = img_pil.resize((width, height))
    logger.info('--- Grad-CAM visualized')
    return img_pil
# ----- Interface -----
# Gradio UI around `predict`. The inputs are listed in the same order as the
# parameters of `predict` (image, class, alpha, cmap) and as the columns of
# each row in `examples`.
iface = gr.Interface(
    predict,
    title='Gradient Class Activation Map (Grad-CAM) Visualization Demo',
    description='Provide an image with image class or just image alone. For all 1000 imagenet classes, see https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a',
    inputs=[
        # type='file' hands `predict` a file object; it reads the `.name` path.
        gr.inputs.Image(label='Input image', type='file'),
        gr.inputs.Dropdown(label='Predicted class (if "none", predicted class will be used)',
                           choices=classes, default='none', type='value'),
        gr.inputs.Slider(label='Output image alpha level for heatmap',
                         minimum=0, maximum=1, step=0.1, default=0.4),
        gr.inputs.Dropdown(label='Grad-CAM heatmap colormap',
                           choices=CMAP_CHOICES, default='jet', type='value'),
    ],
    outputs=[
        gr.outputs.Image(label='Output image', type='pil')
    ],
    examples=examples,
    article='<p style="text-align:center">Based on <a href="https://keras.io/examples/vision/grad_cam/">the example</a> written by <a href="https://twitter.com/fchollet">fchollet</a></p>',
)

# NOTE(review): `gr.inputs` / `gr.outputs` and `enable_queue` are the legacy
# Gradio API; confirm the pinned Gradio version still provides them.
iface.launch(debug=True, enable_queue=True)