# GradCAMViT / app.py
# Author: raedinkhaled — commit 5e7fd2b ("hjj")
import json
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification
# Run on the first CUDA device when one is available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Class-name list indexed by integer class id. JSON text is UTF-8 by
# specification, so decode explicitly instead of using the locale default
# (which differs across platforms, e.g. cp1252 on Windows).
labels = json.loads(Path("labels.json").read_text(encoding="utf-8"))

# Load ResNet-50 (ImageNet-pretrained) in inference mode.
# NOTE: `pretrained=True` is deprecated in newer torchvision in favour of
# `weights=`; kept here for compatibility with older installations.
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()
# Expose both the last convolutional stage ("layer4") and the logits ("fc")
# from a single forward pass, which is what the CAM computation needs.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
# Final linear layer weights, shape (num_classes, in_features); row c weights
# the layer4 channels for class c.
fc_layer_weights = resnet50.fc.weight

# Load the fine-tuned ViT classifier in inference mode.
vit = ViTForImageClassification.from_pretrained("raedinkhaled/vit-base-mri").to(
    device
)
vit.eval()

# ImageNet mean/std normalization, shared by both models.
# NOTE(review): the ViT receives the same normalization as ResNet-50; confirm
# this matches the preprocessing the ViT checkpoint was fine-tuned with.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)
# Example inputs shown in the UI, in a stable (sorted) order.
examples = sorted(f.as_posix() for f in Path("examples").glob("*"))
def get_cam(img_tensor):
    """Build a class activation map for ResNet-50's top-scoring class.

    The layer4 feature maps are weighted by the final linear layer's row for
    the predicted class and summed over channels.

    Args:
        img_tensor: preprocessed image batch; the batch dimension is squeezed
            away, so a batch of one is assumed.

    Returns:
        (cam, label): ``cam`` is a 2-D numpy array with the spatial size of the
        layer4 feature map; ``label`` is ``labels[predicted_class]``.
    """
    features = feature_extractor(img_tensor)
    layer4_maps = features["layer4"].squeeze()
    top_class = features["fc"].argmax()

    class_weights = fc_layer_weights[top_class]
    # (channels,) @ (channels, H*W) -> (H*W,), then back to the H x W grid.
    weighted_sum = class_weights @ layer4_maps.flatten(1)
    grid_h, grid_w = layer4_maps.shape[1], layer4_maps.shape[2]
    cam = weighted_sum.reshape(grid_h, grid_w)

    return cam.cpu().numpy(), labels[top_class]
def get_attention_mask(img_tensor):
    """Compute an attention-rollout saliency map and the ViT's predicted label.

    Attention rollout:
    - Average each layer's attention over heads.
    - Mix 50/50 with the identity to model residual connections.
    - Re-normalize the rows.
    - Multiply the layers together.

    The first token's row (excluding its self-entry) is reshaped into a square
    patch grid.

    Args:
        img_tensor: preprocessed image batch; the batch dimension is squeezed
            away, so a batch of one is assumed.

    Returns:
        (mask, label): ``mask`` is a square 2-D numpy array over patch
        positions; ``label`` is the ViT checkpoint's own class name for its
        top-scoring class.
    """
    result = vit(img_tensor, output_attentions=True)
    # With no labels passed and no hidden states requested, position 0 of the
    # model output is the logits and position 1 the per-layer attention tuple.
    class_id = int(result[0].argmax())
    # The ViT checkpoint has its own label set. The shared `labels` list is
    # sized for ResNet-50's ImageNet head, so indexing it with a ViT class id
    # would return an unrelated name; use the model's id2label mapping instead.
    vit_label = vit.config.id2label[class_id]

    # Stack layers, then drop the size-1 batch dim.
    attention_probs = torch.stack(result[1]).squeeze(1)
    # Average the attention at each layer over all heads.
    attention_probs = attention_probs.mean(dim=1)
    # Account for residual connections by mixing in the identity.
    residual = torch.eye(attention_probs.size(-1), device=attention_probs.device)
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # Re-normalize each row so it sums to 1 (per layer).
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)

    # Propagate attention through the layers, first to last.
    attention_rollout = attention_probs[0]
    for layer_probs in attention_probs[1:]:
        attention_rollout = torch.matmul(layer_probs, attention_rollout)

    # Attention from the first (CLS) token to every patch token.
    mask = attention_rollout[0, 1:]
    mask_size = np.sqrt(mask.size(0)).astype(int)
    mask = mask.reshape(mask_size, mask_size)
    return mask.cpu().numpy(), vit_label
def plot_mask_on_image(image, mask):
    """Overlay a saliency mask on an image as a JET-colormap heatmap.

    Args:
        image: PIL image. ``image.size`` (width, height) is the resize target
            for the mask.
        mask: 2-D numpy array of saliency scores at any resolution; values may
            be negative.

    Returns:
        uint8 numpy array of shape (height, width, 3) in RGB channel order,
        equal to ``0.3 * heatmap + 0.5 * image``.
    """
    mask = np.asarray(mask, dtype=np.float32)
    # Min-max normalize to [0, 1]: shift by min and divide by the range
    # (max - min), so negative scores cannot wrap around in the uint8 cast.
    lo, hi = mask.min(), mask.max()
    span = hi - lo
    if span > 0:
        mask = (mask - lo) / span
    else:
        # Constant mask carries no saliency information; avoid dividing by 0.
        mask = np.zeros_like(mask)
    mask = (255 * mask).astype(np.uint8)
    mask = cv2.resize(mask, image.size)
    # applyColorMap produces BGR, while the PIL image array is RGB.
    heatmap = cv2.cvtColor(
        cv2.applyColorMap(mask, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )
    # Force 3 channels so grayscale / RGBA inputs broadcast against the heatmap.
    rgb_image = np.asarray(image.convert("RGB"))
    result = heatmap * 0.3 + rgb_image * 0.5
    return result.astype(np.uint8)
def inference(img):
    """Run ResNet-50 CAM and ViT attention rollout on one image.

    Args:
        img: PIL image from the Gradio input, or ``None`` (the live interface
            may invoke the callback without an image, e.g. after clearing).

    Returns:
        (resnet_label, cam_overlay, vit_label, rollout_overlay): the two labels
        and RGB uint8 overlay images; all ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None, None, None, None
    # The 3-channel Normalize and the 3-channel overlay both need RGB input;
    # grayscale or RGBA images would otherwise fail.
    img = img.convert("RGB")
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        cam, resnet_label = get_cam(img_tensor)
        attention_mask, vit_label = get_attention_mask(img_tensor)
    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)
    return resnet_label, cam_result, vit_label, rollout_result
if __name__ == "__main__":
    # Single PIL image input feeding inference().
    input_component = gr.inputs.Image(type="pil", label="Input Image")
    # Output widgets, in the same order as inference()'s return tuple.
    output_components = [
        gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
        gr.outputs.Image(type="auto", label="ResNet CAM"),
        gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
        gr.outputs.Image(type="auto", label="raedinkhaled/vit-base-mri CAM"),
    ]
    interface = gr.Interface(
        fn=inference,
        inputs=input_component,
        outputs=output_components,
        examples=examples,
        title="Transformer Explainability On Our Pre Trained Model",
        live=True,
    )
    interface.launch()