File size: 2,846 Bytes
22001ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b9bd3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
'''
Grad-CAM visualization demo

2021-12-18 first created
'''
from PIL import Image
import matplotlib.pyplot as plt
from PIL import Image
import os
import io
from glob import glob
from loguru import logger
import gradio as gr

from utils import (get_imagenet_classes, get_xception_model, get_img_4d_array,
                   make_gradcam_heatmap, align_image_with_heatmap)

# ----- Settings -----
# '-1' is not a valid CUDA device id, so no GPU is visible to the process
# (presumably forcing the model onto the CPU).
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

EXAMPLE_DIR = 'examples'
# Colormap names offered in the "Grad-CAM heatmap colormap" dropdown.
CMAP_CHOICES = [
    'jet', 'rainbow', 'gist_ncar', 'autumn', 'hot', 'winter', 'hsv',
]
# One row per example .jpg, ordered by path. Each row's fields line up with
# the interface inputs: image path, class label, heatmap alpha, colormap.
examples = [
    [path, 'French_bulldog', 0.3, 'jet']
    for path in sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))
]

# ----- Logging -----
LOG_PATH = 'app.log'
# mode='a' opens the file for appending, so entries from earlier runs are kept.
logger.add(LOG_PATH, mode='a')
# Marker line that separates this run from previous ones in the log file.
logger.info('===== APP RESTARTED =====')

# ----- Model -----
# Load the classifier bundle; grad_model is what make_gradcam_heatmap()
# consumes in predict(), and preprocessor is applied to each input array.
model, grad_model, preprocessor, decode_predictions = get_xception_model()
idx2lab, lab2idx = get_imagenet_classes()
# Dropdown choices: the 'none' sentinel first, then every label alphabetically.
classes = ['none', *sorted(lab2idx)]

def predict(image_obj, pred_class, alpha, cmap):
    """Overlay a Grad-CAM heatmap on an uploaded image.

    Args:
        image_obj: Uploaded file object from the Gradio ``Image`` input
            (``type='file'``); only its ``.name`` path is used.
        pred_class: Label to explain, looked up in ``lab2idx``. ``'none'``
            (or an empty selection) means the model's predicted class is
            used instead, as the UI label describes.
        alpha: Alpha level for the heatmap overlay (the UI slider spans 0-1).
        cmap: Colormap name for the heatmap (one of ``CMAP_CHOICES``).

    Returns:
        The image produced by ``align_image_with_heatmap``, resized back to
        the uploaded image's original width and height.

    Raises:
        KeyError: if ``pred_class`` is not a known label.
    """
    image_file = image_obj.name
    logger.info(f'--- image loaded: class={pred_class} | alpha={alpha} | cmap={cmap}')

    # Only the original dimensions are needed; close the file handle promptly.
    with Image.open(image_file) as img:
        width, height = img.size

    img_4d_array = preprocessor(get_img_4d_array(image_file))

    # None lets make_gradcam_heatmap fall back to the model's own prediction.
    if pred_class in (None, '', 'none'):
        pred_idx = None
    else:
        pred_idx = lab2idx[pred_class]
    heatmap = make_gradcam_heatmap(grad_model, img_4d_array, pred_idx=pred_idx)
    # Pass the user-selected alpha through so the UI slider takes effect.
    img_pil = align_image_with_heatmap(img_4d_array, heatmap, alpha=alpha, cmap=cmap)
    img_pil = img_pil.resize((width, height))
    logger.info('--- Grad-CAM visualized')
    return img_pil

# Gradio UI. The four inputs map positionally onto predict()'s parameters
# (image_obj, pred_class, alpha, cmap); the single output is the overlay image.
iface = gr.Interface(
    predict,
    title='Gradient Class Activation Map (Grad-CAM) Visualization Demo',
    description='Provide an image with image class or just image alone. For all 1000 imagenet classes, see https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a',
    inputs=[
        gr.inputs.Image(label='Input image', type='file'),
        gr.inputs.Dropdown(label='Predicted class (if "none", predicted class will be used)',
                           choices=classes, default='none', type='value'),
        gr.inputs.Slider(label='Output image alpha level for heatmap',
                         minimum=0, maximum=1, step=0.1, default=0.4),
        gr.inputs.Dropdown(label='Grad-CAM heatmap colormap',
                           choices=CMAP_CHOICES, default='jet', type='value'),
    ],
    outputs=[
        gr.outputs.Image(label='Output image', type='pil')
    ],
    examples=examples,
    article='<p style="text-align:center">Based on <a href="https://keras.io/examples/vision/grad_cam/">the example</a> written by <a href="https://twitter.com/fchollet">fchollet</a></p>',
)

# Launched at module level (no __main__ guard), so running the script starts
# the app. NOTE(review): enable_queue is a legacy Gradio launch option.
iface.launch(debug=True, enable_queue=True)