# import the necessary packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from PIL import Image
from io import BytesIO
import requests
import numpy as np
from matplotlib import pyplot as plt
RESOLUTION = 224
PATCH_SIZE = 16

crop_layer = layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
rescale_layer = layers.Rescaling(scale=1.0 / 127.5, offset=-1)
def preprocess_image(image, model_type, size=RESOLUTION):
    # Turn the image into a numpy array and add a batch dimension.
    image = np.array(image)
    image = tf.expand_dims(image, 0)

    # If the model type is the original ViT, rescale the image to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize the image using bicubic interpolation.
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(image, (resize_size, resize_size), method="bicubic")

    # Center-crop the image to RESOLUTION x RESOLUTION.
    image = crop_layer(image)

    # If the model type is DeiT or DINO, normalize with ImageNet statistics.
    if model_type != "original_vit":
        image = norm_layer(image)
    return image.numpy()
def load_image_from_url(url, model_type):
    # Credit: Willi Gierke
    # Download the image, decode it with PIL, and preprocess it for the model.
    response = requests.get(url)
    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image
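
As a quick sanity check, the helper can be exercised on any publicly reachable image; the URL below is only a placeholder, and the expected output shape follows from the resize-then-crop pipeline above:

# NOTE: placeholder URL -- substitute any publicly reachable JPEG/PNG.
image_url = "https://example.com/sample.jpg"

# Preprocess for a DINO backbone; pass "original_vit" instead to hit the
# [-1, 1] rescaling branch rather than ImageNet normalization.
image, preprocessed_image = load_image_from_url(image_url, model_type="dino")
print(preprocessed_image.shape)  # expected: (1, 224, 224, 3)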
def attention_heatmap(attention_score_dict, image, model_type="dino", num_heads=12):
    # Distilled models (DeiT) carry an extra distillation token alongside CLS.
    num_tokens = 2 if "distilled" in model_type else 1

    # Sort the transformer blocks in order of their depth.
    attention_score_list = list(attention_score_dict.keys())
    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)

    # Process the attention maps for overlay.
    w_featmap = image.shape[2] // PATCH_SIZE
    h_featmap = image.shape[1] // PATCH_SIZE
    attention_scores = attention_score_dict[attention_score_list[0]]

    # Take the attention scores of the CLS token over all patch tokens.
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches to 224x224 (224 = 14 x 16).
    attentions = tf.image.resize(
        attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
    )
    return attentions
def plot(attentions, image):
    # Lay out the attention heads on a 3x4 grid, each overlaid on the image.
    fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
    img_count = 0

    for i in range(3):
        for j in range(4):
            if img_count < attentions.shape[-1]:
                axes[i, j].imshow(image[0])
                axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
                axes[i, j].title.set_text(f"Attention head: {img_count}")
                axes[i, j].axis("off")
                img_count += 1

    # Save the figure to disk.
    plt.tight_layout()
    plt.savefig("heat_map.png")
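
To tie the utilities together, the sketch below assumes a `loaded_model` (not defined in this section) whose forward pass returns a tuple of predictions and a dictionary of per-block attention scores keyed by names such as `transformer_block_11_att`; adjust the call to whatever API your ViT/DINO model actually exposes:

# ASSUMPTION: `loaded_model` is a ViT/DINO model returning
# (predictions, attention_score_dict), where each dict entry has shape
# (batch, num_heads, num_tokens, num_tokens).
predictions, attention_score_dict = loaded_model.predict(preprocessed_image)

# De-normalize the preprocessed image back to [0, 1] so the overlay is legible
# (these constants mirror norm_layer above).
in1k_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
in1k_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
display_image = (preprocessed_image * in1k_std) + in1k_mean
display_image = tf.clip_by_value(display_image / 255.0, 0.0, 1.0).numpy()

# Build one heatmap per attention head from the deepest block and plot them.
attentions = attention_heatmap(attention_score_dict, display_image, model_type="dino")
plot(attentions, display_image)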