# vdprabhu's picture
# fix warning message persistence and filter miss
# f7012f0
# Imports
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import streamlit as st
from app_utils import *
# The functions (except main) are taken straight from Keras Example
def compute_loss(feature_extractor, input_image, filter_index):
    """Return the mean activation of one filter for ``input_image``.

    Args:
        feature_extractor: Callable applied to ``input_image``; its result is
            indexed with four axes (presumably batch, rows, cols, filters).
        input_image: Value passed straight to ``feature_extractor``.
        filter_index: Index selected on the last axis of the activation.

    Returns:
        The ``tf.reduce_mean`` of the selected, border-trimmed activation.
    """
    feature_maps = feature_extractor(input_image)
    # Trim two entries from each end of axes 1 and 2 so that border artifacts
    # do not contribute to the loss.
    interior = feature_maps[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(interior)
@tf.function
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Move ``img`` one step along the gradient of the filter loss.

    Args:
        feature_extractor: Forwarded to ``compute_loss``.
        img: Image tensor being optimized.
        filter_index: Forwarded to ``compute_loss``.
        learning_rate: Multiplier applied to the normalized gradient.

    Returns:
        A ``(loss, img)`` tuple: the loss measured before the update and the
        updated image.
    """
    with tf.GradientTape() as tape:
        # Explicitly watch img so the tape records gradients with respect to it.
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # L2-normalize the raw gradient before taking the step.
    step_direction = tf.math.l2_normalize(tape.gradient(loss, img))
    return loss, img + learning_rate * step_direction
def initialize_image():
    """Create the starting image: low-amplitude uniform noise around zero.

    Returns:
        A tensor of shape ``(1, IMG_WIDTH, IMG_HEIGHT, 3)`` with values scaled
        to roughly [-0.125, +0.125].
    """
    noise = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
    # Comment inherited from the Keras example: ResNet50V2 expects inputs in
    # [-1, +1], so the noise is centred on 0 and shrunk to a quarter of its
    # range. NOTE(review): other selectable models may expect other ranges.
    centered = noise - 0.5
    return centered * 0.25
def visualize_filter(feature_extractor, filter_index):
    """Run gradient ascent from noise and return the resulting filter image.

    Args:
        feature_extractor: Forwarded to ``gradient_ascent_step``.
        filter_index: Index of the filter whose activation is maximized.

    Returns:
        A ``(loss, img)`` tuple: the loss from the last ascent step and the
        image produced by ``deprocess_image``.
    """
    candidate = initialize_image()
    # Gradient ascent for ITERATIONS steps with a fixed LEARNING_RATE.
    for _step in range(ITERATIONS):
        loss, candidate = gradient_ascent_step(
            feature_extractor, candidate, filter_index, LEARNING_RATE
        )
    # Drop the batch axis and decode the tensor into a displayable array.
    pixels = deprocess_image(candidate[0].numpy())
    return loss, pixels
def deprocess_image(img):
    """Convert an optimized input array into a displayable uint8 RGB image.

    The input is standardized (zero mean, standard deviation 0.15), cropped by
    25 entries on every side of its first two axes, shifted by 0.5, clipped to
    [0, 1] and scaled to [0, 255].

    Args:
        img: Array-like indexed as ``[rows, cols, channels]``. It is never
            modified; non-float inputs are promoted to float64 first.

    Returns:
        A new ``uint8`` NumPy array holding the centre crop.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.floating):
        # Work on a private copy: the in-place arithmetic below must not leak
        # into the caller's array (or into memory shared with a tensor).
        img = img.copy()
    else:
        # In-place float arithmetic is not allowed on integer arrays.
        img = img.astype(np.float64)

    # Normalize array: center on 0., ensure variance is 0.15
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop
    img = img[25:-25, 25:-25, :]

    # Shift to the middle of the range and clip to [0, 1]
    img += 0.5
    img = np.clip(img, 0, 1)

    # Convert to RGB array; values are already within [0, 255] after the clip
    # above, so no second clip is needed.
    img *= 255
    return img.astype("uint8")
# The visualization function
def main():
    """Render the Streamlit page: pick a model, pick a layer, show filters.

    The loaded model, its layer names, the selected layer and the feature
    extractor built for it are kept in ``st.session_state`` so that they are
    only rebuilt when the corresponding selection changes.
    """
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)

    # Only (re)load the model when the selection changed, not on every rerun
    # triggered by one of the other widgets.
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option
        # Forget the layer picked for the previous model. Otherwise choosing a
        # layer with the same name in the new model would skip the rebuild
        # below and keep using the old model's feature extractor.
        # NOTE(review): None is falsy, which is all this function relies on;
        # confirm it is compatible with what initialize_states sets.
        st.session_state.layer_name = None

    if not st.session_state.model_name:
        return

    # Layer selector, saves the feature selector in case 64 filters are to be seen
    ln_option = st.selectbox(
        "Select the target layer (best to pick somewhere in the middle of the model) -",
        st.session_state.layer_list,
    )
    if ln_option == "<select layer>":
        return

    # Rebuild the feature extractor only when the target layer changed.
    if ln_option != st.session_state.layer_name:
        layer = st.session_state.model.get_layer(name=ln_option)
        st.session_state.feat_extract = keras.Model(
            inputs=st.session_state.model.inputs, outputs=layer.output
        )
        st.session_state.layer_name = ln_option

    if not st.session_state.layer_name:
        return

    # Filter index selector. The placeholders are created first so that the
    # messages appear above the selector.
    warn_ph = st.empty()
    layer_ph = st.empty()
    filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())

    if VIS_OPTION[filter_select] == 0:
        # Single-filter mode: only the first filter is visualized.
        _, img = visualize_filter(st.session_state.feat_extract, 0)
        st.image(img)
        return

    layer = st.session_state.model.get_layer(name=st.session_state.layer_name)
    num_filters = layer.get_output_at(0).get_shape().as_list()[-1]
    warn_ph.warning(
        ":exclamation: Calculating the gradients can take a while.."
    )
    if num_filters < 64:
        layer_ph.info(
            f"{st.session_state.layer_name} has only {num_filters} filters, visualizing only those filters.."
        )

    # The grid holds at most 8 x 8 = 64 filters.
    shown = min(num_filters, 64)
    prog_bar = st.progress(0)
    fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
    try:
        cells = axis.ravel()
        for filter_index, ax in enumerate(cells[:shown]):
            prog_bar.progress((filter_index + 1) / shown)
            _, img = visualize_filter(st.session_state.feat_extract, filter_index)
            ax.imshow(img)
            ax.set_title(filter_index + 1)
            ax.set_axis_off()
        # Hide the grid cells that received no filter.
        for ax in cells[shown:]:
            ax.set_axis_off()
        st.write(fig)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # without this each rerun would leak one 8x8 figure.
        plt.close(fig)
    warn_ph.empty()
if __name__ == "__main__":
    # One model name per line; main() looks each name up on
    # keras.applications. Blank lines (e.g. a trailing newline at the end of
    # the file) are skipped so they never show up as an empty, unloadable
    # entry in the model selector.
    with open("model_names.txt", "r") as op:
        AVAILABLE_MODELS = [name for name in map(str.strip, op) if name]

    # Page chrome. title, info_text, credits, replicate, vit_info and
    # self_credit are not defined in this file; presumably they come from the
    # app_utils star import.
    st.set_page_config(layout="wide")
    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)

    main()