Spaces:
Sleeping
Sleeping
File size: 3,142 Bytes
51718fb e1c2e43 51718fb e1c2e43 51718fb e1c2e43 51718fb e1c2e43 004c076 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
# import the necessary packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from PIL import Image
from io import BytesIO
import requests
import numpy as np
from matplotlib import pyplot as plt
# Side length, in pixels, of the square image produced by the crop layer.
RESOLUTION = 224
# Side length, in pixels, of one square patch; used to derive the patch grid.
PATCH_SIZE = 16

# Crops the centre RESOLUTION x RESOLUTION window out of a batch of images.
crop_layer = layers.CenterCrop(height=RESOLUTION, width=RESOLUTION)

# Standardises each channel. The statistics are written on the 0-1 scale
# (they match the widely used ImageNet mean/std values) and are multiplied
# by 255 because the layer is applied to pixels in the 0-255 range.
norm_layer = layers.Normalization(
    mean=[channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)],
    variance=[
        (channel_std * 255) ** 2 for channel_std in (0.229, 0.224, 0.225)
    ],
)

# Maps pixel values from [0, 255] to [-1, 1].
rescale_layer = layers.Rescaling(scale=1.0 / 127.5, offset=-1)
def preprocess_image(image, model_type, size=RESOLUTION):
    """Convert a single image into a batched, model-ready NumPy array.

    The image is resized (bicubic) so that its sides are ``256 / 224`` times
    ``size``, centre-cropped to ``size`` x ``size`` and then scaled according
    to ``model_type``.

    Args:
        image: A PIL image or an array-like of pixels in the 0-255 range.
        model_type: ``"original_vit"`` rescales pixels to [-1, 1]; any other
            value applies per-channel mean/variance normalisation instead.
        size: Side length of the square output. Defaults to ``RESOLUTION``.

    Returns:
        A NumPy array with a leading batch dimension of 1.
    """
    # The normalisation layer is configured for exactly three channels, so
    # collapse RGBA / grayscale / palette PIL images to RGB up front.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    # Turn the image into a numpy array and add batch dim.
    image = np.array(image)
    image = tf.expand_dims(image, 0)

    # If model type is vit rescale the image to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize the image using bicubic interpolation.
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(
        image,
        (resize_size, resize_size),
        method="bicubic",
    )

    # Crop the image. The shared crop layer is fixed at RESOLUTION, so a
    # matching layer is built whenever the caller asks for another size;
    # otherwise `size` would only affect the resize and not the output.
    if size == RESOLUTION:
        cropper = crop_layer
    else:
        cropper = layers.CenterCrop(size, size)
    image = cropper(image)

    # If model type is DeiT or DINO normalize the image.
    if model_type != "original_vit":
        image = norm_layer(image)

    return image.numpy()
def load_image_from_url(url, model_type, timeout=30):
    """Download an image and return it together with its preprocessed form.

    Args:
        url: Address of the image to fetch.
        model_type: Forwarded to ``preprocess_image`` to select the scaling.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        A ``(image, preprocessed_image)`` tuple: the decoded PIL image and
        the result of ``preprocess_image`` for it.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Credit: Willi Gierke
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here rather than handing an error page to the decoder.
    response.raise_for_status()

    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image
def attention_heatmap(attention_score_dict, image, model_type="dino", num_heads=12):
    """Build per-head attention maps from the deepest block's scores.

    Args:
        attention_score_dict: Mapping from block name to attention scores.
            The second-to-last ``_``-separated field of each name must be
            the integer block index. The scores are indexed as
            ``[0, :, 0, num_tokens:]`` and reshaped with ``.reshape``, so a
            rank-4 NumPy array is assumed, presumably
            (batch, heads, tokens, tokens) - TODO confirm against the model.
        image: Batched image whose axes 1 and 2 are height and width.
        model_type: Names containing ``"distilled"`` skip two leading
            tokens instead of one.
        num_heads: Number of attention heads in the scores.

    Returns:
        A tensor of shape (height, width, num_heads), where height and
        width are the patch-grid size multiplied by ``PATCH_SIZE``.
    """
    num_tokens = 2 if "distilled" in model_type else 1

    # Sort the transformer blocks in order of their depth, deepest first.
    attention_score_list = sorted(
        attention_score_dict.keys(),
        key=lambda name: int(name.split("_")[-2]),
        reverse=True,
    )
    attention_scores = attention_score_dict[attention_score_list[0]]

    # Size of the patch grid along each image axis.
    h_featmap = image.shape[1] // PATCH_SIZE
    w_featmap = image.shape[2] // PATCH_SIZE

    # Take the first token's attention over the remaining tokens and lay it
    # out as a (rows, columns) grid per head. Rows come first so the grid
    # matches the (height, width) target of the resize below; the previous
    # (width, height) order only worked for square inputs.
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(
        num_heads, h_featmap, w_featmap
    )
    # Move heads last so each head acts as a channel during the resize.
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches back to pixel resolution.
    attentions = tf.image.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
    )
    return attentions
def plot(attentions, image):
    """Overlay each head's attention map on the image in a 3 x 4 grid.

    Args:
        attentions: Attention maps with the heads on the last axis. Only the
            first 12 heads are drawn because the grid has 12 cells.
        image: Batched image; ``image[0]`` is drawn beneath every overlay.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure as its current
        figure.
    """
    _, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))

    # Heads live on the last axis. `len(attentions)` would return the image
    # height instead and let the loop index past the last head.
    num_heads = attentions.shape[-1]

    # `axes.flat` walks the grid row by row.
    for head_idx, ax in enumerate(axes.flat):
        if head_idx < num_heads:
            ax.imshow(image[0])
            ax.imshow(attentions[..., head_idx], cmap="inferno", alpha=0.6)
            ax.title.set_text(f"Attention head: {head_idx}")
        # Hide ticks and frame on every cell, including unused ones, so that
        # models with fewer than 12 heads do not show empty axes.
        ax.axis("off")

    plt.tight_layout()
    return plt