Spaces:
Runtime error
Runtime error
File size: 4,887 Bytes
af72b72 806e6c5 21c7de8 af72b72 21c7de8 39a6dd6 bcaf154 57064d0 39a6dd6 57064d0 39a6dd6 af72b72 39a6dd6 825e1d8 bcaf154 c9b69b7 af72b72 39a6dd6 21c7de8 39a6dd6 dbb7b85 39a6dd6 57064d0 39a6dd6 c9b69b7 21c7de8 39a6dd6 21c7de8 bcaf154 21c7de8 39a6dd6 21c7de8 bcaf154 21c7de8 bcaf154 c31a89f 21c7de8 806e6c5 21c7de8 c31a89f 39a6dd6 21c7de8 c31a89f 39a6dd6 c31a89f 39a6dd6 806e6c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
'''Artist Classifier
prototype
---
- 2022-01-18 jkang first created
'''
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns
import io
import json
import numpy as np
import skimage
import skimage.io
from skimage.transform import resize
from loguru import logger
from huggingface_hub import from_pretrained_keras
import gradio as gr
import tensorflow as tf
tfk = tf.keras
from gradcam_utils import get_img_4d_array, make_gradcam_heatmap, align_image_with_heatmap
# ---------- Settings ----------
# Label metadata files; predict() reads them as {label name: class id}.
ARTIST_META = 'artist.json'
TREND_META = 'trend.json'
# Sample inputs offered in the Gradio interface.
EXAMPLES = [
    'monet2.jpg',
    'surrelaism.png',
    'graffitiart.png',
    'lichtenstein_popart.jpg',
    'pierre_augste_renoir.png',
]
# Passed as `alpha` to align_image_with_heatmap when blending the GradCAM overlay.
ALPHA = 0.9
# Every input image is resized to this size before inference.
IMG_WIDTH, IMG_HEIGHT = 299, 299

# ---------- Logging ----------
logger.add('app.log', mode='a')  # append mode: restarts keep earlier history
logger.info('============================= App restarted =============================')

# ---------- Model ----------
logger.info('loading models...')
# Unpacking the generator loads the artist model first, then the trend model.
artist_model, trend_model = (
    from_pretrained_keras(repo_id)
    for repo_id in (
        "jkang/drawing-artist-classifier",
        "jkang/drawing-artistic-trend-classifier",
    )
)
logger.info('both models loaded')
def load_json_as_dict(json_file):
    """Load the JSON object stored in *json_file* and return it as a plain ``dict``.

    In this app it reads the label metadata files, whose keys are label names
    and whose values are class ids (predict() inverts them for lookup).

    The file is decoded as UTF-8 (the JSON standard encoding) instead of the
    platform's locale encoding, so non-ASCII label names load identically on
    every host.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        out = json.load(f)
    return dict(out)
def load_image_as_array(image_file):
    """Read *image_file* and return it as an array with exactly 3 channels.

    - 2-D (single-channel grayscale) images are replicated into 3 channels.
    - 2-channel images (grayscale + alpha) keep only the gray channel,
      replicated into 3 channels.
    - Images with more than 3 channels (e.g. RGBA) are truncated to the
      first 3 channels.
    - 3-channel images are returned unchanged.

    The dtype is whatever ``skimage.io.imread`` produced (it is not converted here).
    """
    img = skimage.io.imread(image_file, as_gray=False, plugin='matplotlib')
    # Check ndim first: for a 2-D array, shape[-1] is the image width,
    # not a channel count.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[-1] == 2:
        # Gray + alpha: drop alpha, replicate the gray channel.
        img = np.repeat(img[..., :1], 3, axis=-1)
    elif img.shape[-1] > 3:
        # RGBA (or more): keep the colour channels, drop alpha.
        img = img[..., :3]
    return img
def resize_image(img_array, width, height):
    """Resize *img_array* to ``height`` x ``width`` (anti-aliased) and return it as uint8.

    With ``preserve_range=False`` skimage rescales intensities while resizing;
    ``img_as_ubyte`` then maps the result back to the 0-255 uint8 range.
    """
    target_shape = (height, width)
    resized = resize(
        img_array,
        target_shape,
        preserve_range=False,
        anti_aliasing=True,
    )
    as_uint8 = skimage.img_as_ubyte(resized)
    return as_uint8
def _gradcam_overlay(model, img_4d_array, id2label):
    """Run GradCAM for one model; return (overlay, top_label, top_prob, pred_out).

    ``overlay`` is the heatmap blended onto the input by
    ``align_image_with_heatmap``, converted to float32 and divided by 255 for
    ``imshow``. ``pred_out`` is the third value returned by
    ``make_gradcam_heatmap``. It is presumably the per-class score vector
    indexed by class id (TODO confirm in gradcam_utils).
    """
    heatmap, pred_id, pred_out = make_gradcam_heatmap(model, img_4d_array, pred_idx=None)
    overlay_pil = align_image_with_heatmap(img_4d_array, heatmap, alpha=ALPHA, cmap='jet')
    overlay = np.asarray(overlay_pil).astype('float32') / 255
    return overlay, id2label[pred_id], pred_out[pred_id], pred_out


def _render_figure(img_3d_array, a_img, a_label, a_prob, t_img, t_label, t_prob):
    """Draw the input and both GradCAM overlays side by side; return a PIL JPEG image."""
    with sns.plotting_context('poster', font_scale=0.7):
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6), facecolor='white')
        try:
            for ax in (ax1, ax2, ax3):
                ax.set_xticks([])
                ax.set_yticks([])
            ax1.imshow(img_3d_array)
            ax2.imshow(a_img)
            ax3.imshow(t_img)
            ax1.set_title('Input Image', ha='left', x=0, y=1.05)
            ax2.set_title(f'Artist Prediction:\n => {a_label} ({a_prob:.2f})', ha='left', x=0, y=1.05)
            ax3.set_title(f'Style Prediction:\n => {t_label} ({t_prob:.2f})', ha='left', x=0, y=1.05)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, bbox_inches='tight', format='jpg')
        finally:
            # Close *this* figure, not pyplot's "current" one (which may belong
            # to another request), and do it even if rendering fails, so
            # figures never accumulate.
            plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def predict(input_image):
    """Classify *input_image* by artist and by style, with GradCAM visualisations.

    Returns a 3-tuple:
      - {artist label: score} for every artist class,
      - {style label: score} for every style class,
      - a PIL image showing the input next to both GradCAM overlays.
    """
    img_3d_array = load_image_as_array(input_image)
    img_3d_array = resize_image(img_3d_array, IMG_WIDTH, IMG_HEIGHT)
    img_4d_array = img_3d_array[np.newaxis, ...]  # add a batch axis of size 1
    logger.info(f'--- {input_image} loaded')

    # Metadata files map label -> class id; invert them to look up predictions.
    id2artist = {idx: name for name, idx in load_json_as_dict(ARTIST_META).items()}
    id2trend = {idx: name for name, idx in load_json_as_dict(TREND_META).items()}

    a_img, a_label, a_prob, a_pred_out = _gradcam_overlay(artist_model, img_4d_array, id2artist)
    t_img, t_label, t_prob, t_pred_out = _gradcam_overlay(trend_model, img_4d_array, id2trend)

    pil_img = _render_figure(img_3d_array, a_img, a_label, a_prob, t_img, t_label, t_prob)
    logger.info('--- image generated')

    a_labels = {id2artist[i]: float(pred) for i, pred in enumerate(a_pred_out)}
    t_labels = {id2trend[i]: float(pred) for i, pred in enumerate(t_pred_out)}
    return a_labels, t_labels, pil_img
# Gradio UI: one image input; artist scores, style scores and the GradCAM figure out.
# NOTE(review): `gr.inputs` / `gr.outputs`, `type='file'` and `enable_queue`
# are legacy Gradio APIs (deprecated in 3.x, removed in 4.0) and may explain
# the Space's runtime error; confirm the pinned gradio version before changing.
iface = gr.Interface(
    predict,
    # Emoji written as escapes: U+1F3A8 U+1F468 U+1F3FB U+200D U+1F3A8 (palette, man artist).
    title='Predict Artist and Artistic Style of Drawings '
          '\U0001F3A8\U0001F468\U0001F3FB\u200d\U0001F3A8 (prototype)',
    description='Upload a drawing/image and the model will predict how likely it seems given 10 artists and their trend/style',
    inputs=[
        gr.inputs.Image(label='Upload a drawing/image', type='file')
    ],
    outputs=[
        gr.outputs.Label(label='Artists', num_top_classes=5, type='auto'),
        gr.outputs.Label(label='Styles', num_top_classes=5, type='auto'),
        gr.outputs.Image(label='Prediction with GradCAM')
    ],
    examples=EXAMPLES,
)
iface.launch(debug=True, enable_queue=True)
|