Spaces:
Runtime error
import json | |
from pathlib import Path | |
import cv2 | |
import gradio as gr | |
import numpy as np | |
import torch | |
from torchvision import models, transforms | |
from torchvision.models.feature_extraction import create_feature_extractor | |
from transformers import ViTForImageClassification | |
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Class names indexed by predicted class id; both get_cam and
# get_attention_mask look their label up here. Presumably a JSON list of
# ImageNet-1k names in class-index order -- verify against the file.
# Decode explicitly as UTF-8 so loading does not depend on the host locale.
labels = json.loads(Path("imagenet-simple-labels.json").read_text(encoding="utf-8"))

# Load ResNet-50 with pretrained weights, in inference mode.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favour
# of `weights=`; switch once the deployed torchvision version is confirmed.
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()
# Return both the last convolutional feature maps ("layer4") and the class
# logits ("fc") from a single forward pass, as needed for CAM.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
# Per-class weights of the final linear layer. They are only used as
# constants for the CAM, so detach them from autograd; then get_cam's
# `.numpy()` works even outside a torch.no_grad() context.
fc_layer_weights = resnet50.fc.weight.detach()

# Load the Vision Transformer classifier, in inference mode.
vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(
    device
)
vit.eval()

# ImageNet mean/std normalization plus a fixed 224x224 resize for ResNet-50.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Example inputs for the UI: every entry of ./examples, in sorted order.
examples = sorted(f.as_posix() for f in Path("examples").glob("*"))
def get_cam(img_tensor):
    """Compute a class activation map (CAM) and the top-1 label with ResNet-50.

    The final convolutional feature maps ("layer4") are weighted by the
    fc-layer weights of the predicted class and summed over channels.

    Args:
        img_tensor: preprocessed image batch fed to ``feature_extractor``.
            Only the first image of the batch is used (callers pass a batch
            of one).

    Returns:
        ``(cam, label)``: ``cam`` is a 2-D numpy array with the spatial size
        of the layer4 feature maps; ``label`` is ``labels[class_id]`` for the
        arg-max class.
    """
    output = feature_extractor(img_tensor)
    # Drop only the batch axis. A bare squeeze() would also remove spatial
    # dimensions of size 1 and break the reshape below.
    cnn_features = output["layer4"][0]  # (channels, height, width)
    class_id = output["fc"][0].argmax().item()
    channels, height, width = cnn_features.shape
    # Weighted sum over channels: (C,) @ (C, H*W) -> (H*W,)
    cam = fc_layer_weights[class_id] @ cnn_features.reshape(channels, height * width)
    cam = cam.reshape(height, width)
    return cam.detach().cpu().numpy(), labels[class_id]
def get_attention_mask(img_tensor):
    """Compute an attention-rollout saliency map and the top-1 label with the ViT.

    Attention rollout: each layer's attention is averaged over heads, mixed
    50/50 with the identity to account for residual connections, and
    row-normalized. The layers are then multiplied together. The CLS-token row
    (index 0), without its CLS->CLS entry, gives one weight per image patch.

    Args:
        img_tensor: preprocessed image batch passed straight to ``vit``.
            Assumed to contain a single image, since the batch axis is
            squeezed away (TODO: confirm no caller batches).

    Returns:
        ``(mask, label)``: ``mask`` is a square 2-D numpy array with one value
        per patch; ``label`` is ``labels[class_id]`` for the arg-max class.
    """
    result = vit(img_tensor, output_attentions=True)
    class_id = result[0].argmax().item()
    # Stack the per-layer attentions and drop the batch axis. Assumed
    # per-layer shape is (batch, heads, tokens, tokens) per the HF output.
    attention_probs = torch.stack(result[1]).squeeze(1)
    # Average the attention at each layer over all heads.
    attention_probs = attention_probs.mean(dim=1)
    # Model the residual connection with an identity matrix. It is created
    # on the attention's own device/dtype; the previous hard-coded "cuda"
    # crashed on CPU-only machines.
    residual = torch.eye(
        attention_probs.size(-1),
        device=attention_probs.device,
        dtype=attention_probs.dtype,
    )
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # Re-normalize each row so it sums to 1 again.
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
    # Propagate attention through the layers, first layer first.
    attention_rollout = attention_probs[0]
    for layer_attention in attention_probs[1:]:
        attention_rollout = layer_attention @ attention_rollout
    # Attention from the CLS token to every patch token.
    mask = attention_rollout[0, 1:]
    mask_size = int(round(np.sqrt(mask.size(0))))
    mask = mask.reshape(mask_size, mask_size)
    return mask.detach().cpu().numpy(), labels[class_id]
def plot_mask_on_image(image, mask):
    """Overlay a saliency mask on an image as a JET-coloured heatmap.

    Args:
        image: PIL image. ``image.size`` (width, height) is the resize target
            and ``np.asarray(image)`` is the base layer. Assumed to be RGB
            (``inference`` converts uploads to RGB).
        mask: 2-D numpy array of saliency values at any resolution.

    Returns:
        ``uint8`` numpy array of the blended image, in RGB channel order.
    """
    mask = np.asarray(mask, dtype=np.float32)
    # Min-max normalize to [0, 1]. Dividing by max() alone (as before) does
    # not bound the result when min() is negative, and the later uint8 cast
    # would wrap around. A constant mask maps to all zeros instead of
    # dividing by zero.
    mask_min = mask.min()
    value_range = float(mask.max() - mask_min)
    if value_range > 0:
        mask = (mask - mask_min) / value_range
    else:
        mask = np.zeros_like(mask)
    mask = (255 * mask).astype(np.uint8)
    # cv2.resize takes dsize as (width, height), the same order as PIL's size.
    mask = cv2.resize(mask, image.size)
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert it so it blends with the RGB pixels.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    result = heatmap * 0.3 + np.asarray(image) * 0.5
    return result.astype(np.uint8)
def inference(img):
    """Run the ResNet-50 CAM and ViT attention-rollout pipelines on one image.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the input is
            cleared.

    Returns:
        ``(resnet_label, cam_overlay, vit_label, rollout_overlay)``; all four
        are ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # A live interface can fire with an empty input; nothing to explain.
        return None, None, None, None
    # Force 3-channel RGB. RGBA or grayscale uploads would otherwise break
    # the 3-channel Normalize and the RGB overlay.
    img = img.convert("RGB")
    resnet_tensor = preprocess(img).unsqueeze(0).to(device)
    # The ViT checkpoint expects inputs normalized with mean = std = 0.5
    # (per its Hugging Face model card -- verify if the checkpoint changes),
    # not the ImageNet statistics that `preprocess` applies for ResNet-50.
    vit_preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    vit_tensor = vit_preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        cam, resnet_label = get_cam(resnet_tensor)
        attention_mask, vit_label = get_attention_mask(vit_tensor)
    cam_result = plot_mask_on_image(img, cam)
    rollout_result = plot_mask_on_image(img, attention_mask)
    return resnet_label, cam_result, vit_label, rollout_result
if __name__ == "__main__":
    # Build and launch the web UI. gr.inputs / gr.outputs were deprecated in
    # Gradio 3 and removed in Gradio 4, so use the top-level components.
    # The image outputs receive uint8 numpy arrays from inference().
    interface = gr.Interface(
        fn=inference,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=[
            gr.Label(num_top_classes=1, label="ResNet Label"),
            gr.Image(label="ResNet CAM"),
            gr.Label(num_top_classes=1, label="ViT Label"),
            gr.Image(label="Rollout Attn Flow"),
        ],
        examples=examples,
        title="CNN - Transformer Explainability",
        live=True,
    )
    interface.launch()