Spaces:
Sleeping
Sleeping
File size: 1,827 Bytes
1046772 a2ca503 1ae2bef fc4cd1f e53b04f 1f029d6 321f9f3 51718fb d7f67d7 ad572cd d7f67d7 1f029d6 fc4cd1f a2ca503 3010861 e029f9c a2ca503 3c26522 e029f9c a2ca503 e029f9c 85a0494 a2ca503 9d17ea5 e670f36 85a0494 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import utils
from huggingface_hub.keras_mixin import from_pretrained_keras
from PIL import Image
import streamlit as st
import tensorflow as tf
# Without the leading "@" this line was a bare call whose result was discarded,
# so the model was reloaded on every Streamlit rerun.
# allow_output_mutation=True stops st.cache from hashing the returned Keras
# model on every cache hit to check whether it was mutated.
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Load the pretrained DINO model "probing-vits/vit-dino-base16" from the Hub.

    The function is wrapped in ``st.cache`` so Streamlit can reuse the loaded
    model across script reruns instead of rebuilding it each time.

    Returns:
        The Keras model returned by ``from_pretrained_keras``.
    """
    # Load the DINO model
    dino = from_pretrained_keras("probing-vits/vit-dino-base16")
    return dino
dino = load_model()

# Inputs
st.title("Input your image")
image_url = st.text_input(
    label="URL of image",
    value="https://dl.fbaipublicfiles.com/dino/img.png",
    placeholder="https://your-favourite-image.png",
)
uploaded_file = st.file_uploader("or an image file", type=["jpg", "jpeg"])

# Outputs
# An uploaded file takes precedence over the URL. The URL is only fetched when
# nothing was uploaded, so we never download an image only to discard it.
if uploaded_file is not None:
    st.title("Original Image from upload")
    # Force 3 channels: the de-normalization below uses 3-element mean/std, and
    # grayscale/CMYK JPEGs would otherwise not match.
    image = Image.open(uploaded_file).convert("RGB")
    # Preprocess (with normalization) using the same model type as the URL path.
    preprocessed_image = utils.preprocess_image(image, "dino")
else:
    st.title("Original Image from URL")
    # Load the image and also get a preprocessed copy with normalization.
    image, preprocessed_image = utils.load_image_from_url(
        image_url,
        model_type="dino",
    )
st.image(image, caption="Original Image")

with st.spinner("Generating the attention scores..."):
    # Get the attention scores; the first model output is unused here.
    _, attention_score_dict = dino.predict(preprocessed_image)

with st.spinner("Generating the heat maps... HOLD ON!"):
    # De-normalize the image for visual clarity: undo (x - mean) / std using the
    # in1k mean/std scaled to 0-255, then rescale and clip to [0, 1].
    in1k_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
    in1k_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
    preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
    preprocessed_img_orig = preprocessed_img_orig / 255.0
    preprocessed_img_orig = tf.clip_by_value(preprocessed_img_orig, 0.0, 1.0).numpy()
    attentions = utils.attention_heatmap(
        attention_score_dict=attention_score_dict,
        image=preprocessed_img_orig,
    )
    # NOTE(review): utils.plot presumably writes "heat_map.png", which is read
    # back below — confirm against utils.
    utils.plot(attentions=attentions, image=preprocessed_img_orig)

# Show the attention maps
st.title("Attention 🔥 Maps")
# Open via a context manager so the file handle is released after rendering.
with Image.open("heat_map.png") as image:
    st.image(image, caption="Attention Heat Maps")