File size: 3,142 Bytes
51718fb
 
e1c2e43
51718fb
 
 
 
 
 
e1c2e43
51718fb
 
 
e1c2e43
51718fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1c2e43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
004c076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# import the necessary packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from PIL import Image
from io import BytesIO
import requests
import numpy as np
from matplotlib import pyplot as plt


# Side length used for the square center crop, and the patch edge length used
# when mapping attention scores back onto the image grid.
RESOLUTION = 224
PATCH_SIZE = 16

# Per-channel mean / standard deviation expressed on a [0, 1] scale. They are
# multiplied by 255 below, presumably because inputs are [0, 255] pixel values
# and these are the usual ImageNet statistics -- TODO confirm.
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# Center-crops a batch of images to RESOLUTION x RESOLUTION.
crop_layer = layers.CenterCrop(height=RESOLUTION, width=RESOLUTION)

# Standardises each channel: (x - mean) / sqrt(variance).
norm_layer = layers.Normalization(
    mean=[m * 255 for m in _CHANNEL_MEAN],
    variance=[(s * 255) ** 2 for s in _CHANNEL_STD],
)

# Maps x to x / 127.5 - 1, i.e. [0, 255] onto [-1, 1].
rescale_layer = layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def preprocess_image(image, model_type, size=RESOLUTION):
    """Convert an image into a batched, model-ready numpy array.

    The image is resized (bicubic) to ``int(256 / 224 * size)`` on both sides
    and then center-cropped to ``size`` x ``size``. Depending on
    ``model_type`` it is either rescaled with ``rescale_layer`` before the
    resize or standardised with ``norm_layer`` after the crop.

    Args:
        image: A PIL image or anything ``np.array`` can turn into an
            ``(height, width, channels)`` array. PIL images are converted to
            RGB first so that they always carry three channels.
        model_type: ``"original_vit"`` selects rescaling to [-1, 1]; any other
            value selects per-channel mean/variance normalisation.
        size: Side length of the final square crop. Defaults to
            ``RESOLUTION``.

    Returns:
        A numpy array with a leading batch dimension of 1.
    """
    # Grayscale, palette, RGBA or CMYK images would otherwise yield an array
    # whose channel count does not match the 3-channel ``norm_layer``.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    # Turn the image into a numpy array and add batch dim.
    image = np.array(image)
    image = tf.expand_dims(image, 0)

    # If model type is vit rescale the image to [-1, 1].
    if model_type == "original_vit":
        image = rescale_layer(image)

    # Resize the image using bicubic interpolation.
    resize_size = int((256 / 224) * size)
    image = tf.image.resize(
        image,
        (resize_size, resize_size),
        method="bicubic",
    )

    # Crop the image. The shared ``crop_layer`` is fixed to RESOLUTION, so a
    # matching crop is built whenever the caller asks for a different size;
    # otherwise the resize and the crop would disagree.
    if size == RESOLUTION:
        image = crop_layer(image)
    else:
        image = layers.CenterCrop(size, size)(image)

    # If model type is DeiT or DINO normalize the image.
    if model_type != "original_vit":
        image = norm_layer(image)

    return image.numpy()
    

def load_image_from_url(url, model_type):
    """Download an image and return it together with its preprocessed form.

    Args:
        url: Address of the image to fetch over HTTP(S).
        model_type: Forwarded to ``preprocess_image`` to pick the
            preprocessing variant.

    Returns:
        A ``(image, preprocessed_image)`` tuple: the PIL image as opened from
        the response body and the result of ``preprocess_image`` on it.

    Raises:
        requests.RequestException: If the request times out, fails, or the
            server answers with an HTTP error status.
    """
    # Credit: Willi Gierke
    # A timeout keeps an unresponsive server from blocking forever.
    response = requests.get(url, timeout=30)
    # Fail here on 4xx/5xx instead of handing an error page to the decoder.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image


def attention_heatmap(attention_score_dict, image, model_type="dino", num_heads=12):
    """Build per-head attention maps upsampled to the image's patch grid size.

    Args:
        attention_score_dict: Mapping from block name to attention scores.
            Each key's second-to-last ``_``-separated field must parse as an
            integer; the entry with the largest such integer is used. Scores
            are indexed as ``[0, :, 0, num_tokens:]`` -- assumes a layout of
            (batch, heads, query tokens, key tokens); TODO confirm.
        image: Array/tensor whose ``shape[1]`` and ``shape[2]`` are read as
            height and width -- assumes (batch, height, width, channels);
            TODO confirm against caller.
        model_type: If the string contains ``"distilled"``, two leading
            tokens are skipped instead of one.
        num_heads: Number of attention heads in the selected scores.

    Returns:
        A tensor of shape ``(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE,
        num_heads)`` produced by ``tf.image.resize``.
    """
    num_tokens = 2 if "distilled" in model_type else 1

    # Pick the transformer block with the highest depth index.
    deepest_block = sorted(
        attention_score_dict,
        key=lambda name: int(name.split("_")[-2]),
        reverse=True,
    )[0]
    # ``np.asarray`` lets the scores be either a numpy array or a TF tensor;
    # the reshape/transpose calls below are numpy methods.
    attention_scores = np.asarray(attention_score_dict[deepest_block])

    # Number of patches along each image side.
    h_featmap = image.shape[1] // PATCH_SIZE
    w_featmap = image.shape[2] // PATCH_SIZE

    # Taking the representations from CLS token.
    cls_attention = attention_scores[0, :, 0, num_tokens:]

    # Reshape the attention scores to resemble mini patches. Height comes
    # first so the axes agree with the (height, width) resize target below;
    # for square inputs this is identical to the previous (w, h) ordering.
    # Assumes patch tokens are ordered row by row -- TODO confirm.
    attentions = cls_attention.reshape(num_heads, h_featmap, w_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # Resize the attention patches back to pixel resolution
    # (e.g. 14 patches * 16 px = 224).
    attentions = tf.image.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
    )
    return attentions


def plot(attentions, image):
    """Overlay each attention head's map on the image in a grid of subplots.

    Args:
        attentions: Array/tensor whose last axis indexes attention heads;
            ``attentions[..., k]`` is drawn as the heatmap for head ``k``.
        image: Indexable whose first element (``image[0]``) is the picture
            shown underneath every heatmap -- presumably a batch of one
            image; TODO confirm against caller.

    Returns:
        The ``matplotlib.pyplot`` module, with the figure created and laid
        out, so the caller can show or save it.
    """
    # The head count lives on the last axis; ``len(attentions)`` would give
    # the size of the first axis instead.
    num_heads = int(attentions.shape[-1])

    # Keep the 3x4 grid for up to 12 heads and add rows beyond that so no
    # head is dropped.
    ncols = 4
    nrows = max(3, (num_heads + ncols - 1) // ncols)
    _, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(13, 13 * nrows / 3)
    )

    for head_idx, ax in enumerate(axes.flat):
        if head_idx < num_heads:
            ax.imshow(image[0])
            ax.imshow(attentions[..., head_idx], cmap="inferno", alpha=0.6)
            ax.title.set_text(f"Attention head: {head_idx}")
        # Hide ticks/frames on used and unused cells alike.
        ax.axis("off")

    plt.tight_layout()
    return plt