# GradCAMViT / app.py
# Author: raedinkhaled — commit 5174b1f ("Create app.py"), 3.65 kB.
import json
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification
# Run both models on the GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Class-name list indexed by ResNet-50's predicted class id (path is relative to
# the current working directory). Read explicitly as UTF-8 so the result does
# not depend on the platform's locale encoding.
labels = json.loads(Path("labels.json").read_text(encoding="utf-8"))

# Load ResNet-50 with torchvision's pretrained weights.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept here for
# compatibility with older torchvision releases.
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()

# Expose the last convolutional stage ("layer4") and the classifier logits
# ("fc") from a single forward pass; get_cam combines them into a CAM.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
# Weight matrix of the final linear layer: row k holds the weights of class k.
fc_layer_weights = resnet50.fc.weight

# Load the fine-tuned ViT classifier from the Hugging Face Hub.
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(
    device
)
vit.eval()

# Shared preprocessing for both models: resize to 224x224, convert to a float
# tensor, and normalise with the standard ImageNet mean/std used by torchvision.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Example inputs for the Gradio UI: every regular file under ./examples, in a
# stable (sorted) order. Sub-directories are skipped since they are not images.
examples = sorted(f.as_posix() for f in Path("examples").glob("*") if f.is_file())
def get_cam(img_tensor):
    """Compute a Class Activation Map (CAM) for ResNet-50's top prediction.

    Args:
        img_tensor: Preprocessed image batch for ResNet-50. Only the first
            image is used; the app always passes a batch of one.

    Returns:
        A tuple ``(cam, label)``. ``cam`` is a 2-D NumPy array with the spatial
        size of the ``layer4`` feature map. It is unnormalised and may contain
        negative values. ``label`` is ``labels[class_id]`` for the argmax class.
    """
    output = feature_extractor(img_tensor)
    # Select sample 0 explicitly instead of squeeze(). squeeze() would also
    # drop spatial dimensions of size 1 and break the flatten/reshape below.
    cnn_features = output["layer4"][0]  # (channels, height, width)
    # Plain int so the lookup does not rely on tensor __index__ support.
    class_id = int(output["fc"][0].argmax())
    # CAM = the predicted class's fc weights applied to every spatial position:
    # (C,) @ (C, H*W) -> (H*W,)
    cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
    cam = cam.reshape(cnn_features.shape[1], cnn_features.shape[2])
    # detach() so .numpy() also works when called outside torch.no_grad().
    return cam.detach().cpu().numpy(), labels[class_id]
def get_attention_mask(img_tensor):
    """Compute an attention-rollout saliency mask for the ViT's top prediction.

    Attention rollout proceeds in four steps:

    1. For each layer, average the attention over heads.
    2. Mix in the identity to account for residual connections.
    3. Re-normalise the rows.
    4. Multiply the layers together.

    The [CLS] token's row, restricted to the patch tokens, is the mask.

    Args:
        img_tensor: Preprocessed image batch. Only the first image is used;
            the app always passes a batch of one.

    Returns:
        A tuple ``(mask, label)``. ``mask`` is a square 2-D NumPy array with
        one value per image patch. ``label`` is the class name from the ViT's
        own ``config.id2label`` mapping.

    Raises:
        ValueError: If the number of patch tokens is not a perfect square.
    """
    result = vit(img_tensor, output_attentions=True, return_dict=True)
    class_id = int(result.logits[0].argmax())
    # Stack the per-layer attentions and keep sample 0:
    # (layers, heads, tokens, tokens).
    attention_probs = torch.stack(result.attentions)[:, 0]
    # Average the attention at each layer over all heads -> (layers, T, T).
    attention_probs = attention_probs.mean(dim=1)
    # Add the identity to model the residual connection around attention.
    residual = torch.eye(attention_probs.size(-1), device=attention_probs.device)
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # Re-normalise so every row sums to 1 again.
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
    # Roll out: left-multiply each later layer onto the running product.
    attention_rollout = attention_probs[0]
    for layer_attention in attention_probs[1:]:
        attention_rollout = layer_attention.matmul(attention_rollout)
    # Row 0 is the [CLS] token; columns 1: are the patch tokens.
    mask = attention_rollout[0, 1:]
    num_patches = mask.numel()
    mask_size = int(round(num_patches ** 0.5))
    if mask_size * mask_size != num_patches:
        raise ValueError(
            f"expected a square grid of patch tokens, got {num_patches} tokens"
        )
    mask = mask.reshape(mask_size, mask_size)
    # `labels` lists ResNet-50's (ImageNet) classes. The fine-tuned ViT has its
    # own label set stored in its config, so look the class name up there.
    # NOTE(review): confirm the checkpoint's id2label holds meaningful names.
    return mask.detach().cpu().numpy(), vit.config.id2label[class_id]
def plot_mask_on_image(image, mask, *, heatmap_weight=0.3, image_weight=0.5):
    """Overlay a saliency mask on an image as a JET-colormap heatmap.

    Args:
        image: PIL image to draw on. It is converted to RGB before blending.
        mask: 2-D array of saliency scores at any resolution. It is min-max
            scaled to [0, 255] and resized to the image size.
        heatmap_weight: Blend weight for the heatmap.
        image_weight: Blend weight for the original image.

    Returns:
        A ``uint8`` RGB array of shape (height, width, 3).
    """
    mask = np.asarray(mask, dtype=np.float32)
    # Min-max normalisation to [0, 1]. Dividing by (max - min) rather than by
    # max keeps values in range when the mask has negative entries (CAMs can).
    # Those values previously wrapped around in the uint8 cast.
    mask_min = mask.min()
    value_range = mask.max() - mask_min
    if value_range > 0:
        mask = (mask - mask_min) / value_range
    else:
        # A constant mask has nothing to highlight; this also avoids 0 / 0.
        mask = np.zeros_like(mask)
    mask = (255 * mask).astype(np.uint8)
    # PIL's size is (width, height), which is the dsize order cv2.resize expects.
    mask = cv2.resize(mask, image.size)
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so the channels match the RGB image.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    rgb_image = np.asarray(image.convert("RGB"))
    result = heatmap * heatmap_weight + rgb_image * image_weight
    return np.clip(result, 0, 255).astype(np.uint8)
def inference(img):
    """Classify ``img`` with both models and render their saliency overlays.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the input is
            empty. The interface runs with ``live=True``, so this function is
            re-invoked whenever the input changes.

    Returns:
        A tuple ``(resnet_label, cam_overlay, vit_label, rollout_overlay)``.
        All four are ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None, None, None, None
    # Both models expect 3-channel input: drop alpha and expand grayscale.
    img = img.convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        cam, resnet_label = get_cam(img_tensor)
        attention_mask, vit_label = get_attention_mask(img_tensor)
    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)
    return resnet_label, cam_result, vit_label, rollout_result
if __name__ == "__main__":
    # One PIL image in. For each model, a top-1 label and a heatmap overlay out.
    # NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces.
    # They were removed in Gradio 4, so this requires an older Gradio release.
    image_input = gr.inputs.Image(type="pil", label="Input Image")
    resnet_outputs = [
        gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
        gr.outputs.Image(type="auto", label="ResNet CAM"),
    ]
    vit_outputs = [
        gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
        gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri CAM"),
    ]
    interface = gr.Interface(
        fn=inference,
        inputs=image_input,
        # Order must match the tuple returned by inference().
        outputs=resnet_outputs + vit_outputs,
        examples=examples,
        title="Transformer Explainability On Our Pre Trained Model",
        # Re-run inference whenever the input changes.
        live=True,
    )
    interface.launch()