# Imports
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

import streamlit as st

from app_utils import *
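# app_utils is expected to provide the constants used below (IMG_WIDTH,
# IMG_HEIGHT, ITERATIONS, LEARNING_RATE, VIS_OPTION), the page texts (title,
# info_text, credits, replicate, vit_info, self_credit) and initialize_states()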

# The functions below (except main) are adapted from the Keras example
# "Visualizing what convnets learn"
def compute_loss(feature_extractor, input_image, filter_index):
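    """Mean activation of one filter in the feature extractor's output.

    Gradient ascent on this scalar drives the input image toward a pattern
    that strongly activates the chosen filter.
    """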
    activation = feature_extractor(input_image)
    # We avoid border artifacts by only involving non-border pixels in the loss.
    filter_activation = activation[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(filter_activation)


@tf.function
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
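    """Take one gradient-ascent step on the input image.

    The gradient of the loss w.r.t. the image is L2-normalized, so the step
    size is controlled by learning_rate alone.
    """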
    with tf.GradientTape() as tape:
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # Compute gradients.
    grads = tape.gradient(loss, img)
    # Normalize gradients.
    grads = tf.math.l2_normalize(grads)
    img += learning_rate * grads
    return loss, img


def initialize_image():
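    """Return a low-contrast random image of shape (1, IMG_WIDTH, IMG_HEIGHT, 3)."""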
    # We start from a gray image with some random noise
    img = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
    # The original Keras example targets ResNet50V2, which expects inputs
    # in the range [-1, +1]. Here we scale our random inputs to
    # [-0.125, +0.125] so optimization starts from a low-contrast image.
    return (img - 0.5) * 0.25


def visualize_filter(feature_extractor, filter_index):
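    """Optimize a random image to maximize one filter's activation.

    Returns the final loss and a display-ready uint8 image produced by
    deprocess_image.
    """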
    # Run gradient ascent for ITERATIONS steps (defined in app_utils)
    img = initialize_image()
    for _ in range(ITERATIONS):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, LEARNING_RATE
        )

    # Decode the resulting input image
    img = deprocess_image(img[0].numpy())
    return loss, img


def deprocess_image(img):
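    """Convert the optimized float image into a displayable uint8 RGB array."""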
    # Normalize array: center on 0., set the standard deviation to 0.15
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop
    img = img[25:-25, 25:-25, :]

    # Clip to [0, 1]
    img += 0.5
    img = np.clip(img, 0, 1)

    # Convert to RGB array
    img *= 255
    img = np.clip(img, 0, 255).astype("uint8")
    return img
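

# A minimal standalone sketch (not called by the app) of how visualize_filter
# could be used without Streamlit. The model and the layer name
# "conv3_block4_out" (from the Keras example) are illustrative choices, not
# part of the app itself.
def _demo_single_filter(layer_name="conv3_block4_out", filter_index=0):
    model = keras.applications.ResNet50V2(weights="imagenet", include_top=False)
    layer = model.get_layer(name=layer_name)
    feature_extractor = keras.Model(inputs=model.inputs, outputs=layer.output)
    _, img = visualize_filter(feature_extractor, filter_index)
    plt.imshow(img)
    plt.axis("off")
    plt.show()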


# The Streamlit app driver
def main():
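    """Streamlit flow: pick a model, pick a layer, then visualize its filters.

    The loaded model, its layer list and the feature extractor are kept in
    st.session_state so they survive Streamlit's script reruns.
    """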
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)

    # Only reload the model when the selection changes
    # (Streamlit reruns the script on every widget interaction)
    if mn_option != st.session_state.model_name:
        model = getattr(keras.applications, mn_option)(
            weights="imagenet", include_top=False
        )
        st.session_state.layer_list = ["<select layer>"] + [
            layer.name for layer in model.layers
        ]
        st.session_state.model = model
        st.session_state.model_name = mn_option

    # Layer selector; the feature extractor is cached in session state
    # so it is not rebuilt on every rerun
    if st.session_state.model_name:
        ln_option = st.selectbox(
            "Select the target layer (best to pick somewhere in the middle of the model) -",
            st.session_state.layer_list,
        )
        if ln_option != "<select layer>":
            if ln_option != st.session_state.layer_name:
                layer = st.session_state.model.get_layer(name=ln_option)
                st.session_state.feat_extract = keras.Model(
                    inputs=st.session_state.model.inputs, outputs=layer.output
                )
                st.session_state.layer_name = ln_option

    # Filter index selector
    if st.session_state.layer_name:
        warn_ph = st.empty()
        layer_ph = st.empty()

        filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())

        if VIS_OPTION[filter_select] == 0:
            loss, img = visualize_filter(st.session_state.feat_extract, 0)
            st.image(img)
        else:
            layer = st.session_state.model.get_layer(name=st.session_state.layer_name)
            num_filters = layer.output.shape[-1]

            warn_ph.warning(
                ":exclamation: Calculating the gradients can take a while..."
            )
            if num_filters < 64:
                layer_ph.info(
                    f"{st.session_state.layer_name} has only {num_filters} filters; visualizing just those."
                )

            prog_bar = st.progress(0)
            fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
            for filter_index, ax in enumerate(axes.ravel()[: min(num_filters, 64)]):
                prog_bar.progress((filter_index + 1) / min(num_filters, 64))
                loss, img = visualize_filter(
                    st.session_state.feat_extract, filter_index
                )
                ax.imshow(img)
                ax.set_title(filter_index + 1)
                ax.set_axis_off()

            # Hide the axes left over when the layer has fewer than 64 filters
            for ax in axes.ravel()[num_filters:]:
                ax.set_axis_off()

            st.write(fig)
            warn_ph.empty()


if __name__ == "__main__":
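    # model_names.txt is expected to list keras.applications model class
    # names, one per line; each is loaded via getattr(keras.applications, ...)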

    with open("model_names.txt", "r") as op:
        AVAILABLE_MODELS = [i.strip() for i in op.readlines()]

    st.set_page_config(layout="wide")

    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)

    main()