File size: 3,651 Bytes
5174b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import json
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Class names indexed by classifier output id (looked up via ``labels[class_id]``).
# JSON text is UTF-8; don't depend on the platform's default encoding.
labels = json.loads(Path("labels.json").read_text(encoding="utf-8"))

# Load ResNet-50.
# NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in favor
# of ``weights=...``; kept here for compatibility with older torchvision installs.
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()

# Feature extractor exposing the last conv stage ("layer4") and the logits ("fc");
# the fc weights are reused to weight the layer4 feature maps when building the CAM.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
fc_layer_weights = resnet50.fc.weight

# Load ViT
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(
    device
)
vit.eval()


# NOTE(review): the same ImageNet mean/std preprocessing is fed to both models;
# confirm the ViT checkpoint was trained with this normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Example inputs for the UI: regular, non-hidden files only, so sub-directories
# and files such as ``.gitkeep`` are never offered as example images.
examples = sorted(
    f.as_posix()
    for f in Path("examples").glob("*")
    if f.is_file() and not f.name.startswith(".")
)


def get_cam(img_tensor):
    """Compute a class activation map for ResNet-50's top-scoring class.

    The ``layer4`` feature maps are weighted by the row of the final fully
    connected layer that belongs to the predicted class and summed over
    channels.

    Args:
        img_tensor: preprocessed batch holding a single image, e.g.
            ``preprocess(img).unsqueeze(0)`` (assumes batch size 1).

    Returns:
        ``(cam, label)`` where ``cam`` is a 2-D numpy array over the ``layer4``
        spatial grid (unnormalized; may contain negative values) and ``label``
        is ``labels[class_id]`` for the arg-max class.
    """
    output = feature_extractor(img_tensor)
    # Drop only the batch axis -> (channels, H, W). A bare ``squeeze()`` would
    # also collapse a spatial axis of size 1 and break the reshape below.
    cnn_features = output["layer4"].squeeze(0)
    # Plain Python int so it indexes both the weight matrix and ``labels``.
    class_id = int(output["fc"].argmax())

    # (channels,) @ (channels, H*W) -> (H*W,), then back onto the spatial grid.
    cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
    cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])

    # detach() so the numpy conversion also works outside ``torch.no_grad()``
    # (the fc weights require grad).
    return cam.detach().cpu().numpy(), labels[class_id]


def get_attention_mask(img_tensor):
    """Compute an attention-rollout saliency mask for the ViT's top prediction.

    Per layer, attention is averaged over heads, mixed 50/50 with the identity
    (to account for residual connections) and row-normalized. The layers are
    then multiplied together, and the CLS-token row over the patch tokens is
    returned as a square mask.

    Args:
        img_tensor: preprocessed batch holding a single image (assumes batch
            size 1).

    Returns:
        ``(mask, label)`` where ``mask`` is a square 2-D numpy array with one
        value per image patch and ``label`` is the ViT checkpoint's own class
        name for its arg-max prediction.
    """
    result = vit(img_tensor, output_attentions=True)
    class_id = int(result[0].argmax())
    # Tuple of per-layer attentions -> (layers, heads, tokens, tokens) after
    # stacking and dropping the batch axis.
    attention_probs = torch.stack(result[1]).squeeze(1)

    # Average the attention at each layer over all heads -> (layers, tokens, tokens)
    attention_probs = attention_probs.mean(dim=1)
    residual = torch.eye(
        attention_probs.size(-1),
        device=attention_probs.device,
        dtype=attention_probs.dtype,
    )
    attention_probs = 0.5 * attention_probs + 0.5 * residual

    # Re-normalize so every row sums to 1 again.
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)

    # Roll out: left-multiply each deeper layer onto the accumulated product.
    attention_rollout = attention_probs[0]
    for layer_attention in attention_probs[1:]:
        attention_rollout = layer_attention.matmul(attention_rollout)

    # Attention from the CLS token (index 0) to every patch token.
    mask = attention_rollout[0, 1:]
    mask_size = int(round(mask.numel() ** 0.5))
    mask = mask.reshape(mask_size, mask_size)

    # The ViT checkpoint has its own classifier head and label set; ``labels``
    # (labels.json) belongs to the ResNet-50 head, so use the model config here.
    return mask.detach().cpu().numpy(), vit.config.id2label[class_id]


def plot_mask_on_image(image, mask, heatmap_weight=0.3, image_weight=0.5):
    """Overlay a saliency mask on an image as a JET-colored heatmap.

    Args:
        image: PIL image; ``image.size`` (width, height) is the resize target.
        mask: 2-D array of saliency scores in any range (may be negative).
        heatmap_weight: blend weight of the colored heatmap.
        image_weight: blend weight of the original image.

    Returns:
        ``uint8`` RGB array of shape (height, width, 3).
    """
    mask = np.asarray(mask, dtype=np.float32)

    # Min-max normalization to [0, 1]. Dividing by (max - min), not max, keeps
    # negative scores from pushing values above 1 (which would wrap around in
    # the uint8 cast). A constant mask has no contrast and maps to all zeros.
    mask_min = mask.min()
    mask_range = mask.max() - mask_min
    if mask_range > 0:
        mask = (mask - mask_min) / mask_range
    else:
        mask = np.zeros_like(mask)
    mask = (255 * mask).astype(np.uint8)
    mask = cv2.resize(mask, image.size)

    # applyColorMap produces BGR; convert so it matches the RGB image pixels.
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    rgb_image = np.asarray(image.convert("RGB"))
    result = heatmap * heatmap_weight + rgb_image * image_weight
    # Clip so custom weights summing above 1 cannot overflow uint8.
    return np.clip(result, 0, 255).astype(np.uint8)


def inference(img):
    """Run both explainability pipelines on one PIL image.

    Returns a 4-tuple in the order of the interface outputs:
    (ResNet label, ResNet CAM overlay, ViT label, ViT attention-rollout overlay).
    """
    batch = preprocess(img).unsqueeze(0).to(device)

    # Model forward passes only; no gradients are needed for either map.
    with torch.no_grad():
        cam, resnet_label = get_cam(batch)
        attention_mask, vit_label = get_attention_mask(batch)

    return (
        resnet_label,
        plot_mask_on_image(img, cam),
        vit_label,
        plot_mask_on_image(img, attention_mask),
    )

if __name__ == "__main__":
    # NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
    # namespaces; pin the gradio version this script was written against.
    image_input = gr.inputs.Image(type="pil", label="Input Image")

    # One output slot per element of the tuple returned by ``inference``.
    result_outputs = [
        gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
        gr.outputs.Image(type="auto", label="ResNet CAM"),
        gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
        gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri CAM"),
    ]

    demo = gr.Interface(
        fn=inference,
        inputs=image_input,
        outputs=result_outputs,
        examples=examples,
        title="Transformer Explainability On Our Pre Trained Model",
        live=True,
    )
    demo.launch()