Spaces:
Runtime error
Runtime error
File size: 3,639 Bytes
491ca4f 94d34c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import json
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
from torchvision import models, transforms
from torchvision.models.feature_extraction import create_feature_extractor
from transformers import ViTForImageClassification
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# List of class names, indexed by predicted class id elsewhere in this file.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's default locale encoding.
labels = json.loads(
    Path("imagenet-simple-labels.json").read_text(encoding="utf-8")
)

# Load ResNet-50
resnet50 = models.resnet50(pretrained=True).to(device)
resnet50.eval()

# Create ResNet feature extractor
# It returns both the "layer4" feature maps and the "fc" output in one pass.
feature_extractor = create_feature_extractor(resnet50, return_nodes=["layer4", "fc"])
# Weight matrix of the final fully connected layer; one row per class.
fc_layer_weights = resnet50.fc.weight

# Load ViT
vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224").to(
    device
)
vit.eval()

# Shared preprocessing: resize to 224x224, convert to a tensor, then apply
# per-channel mean/std normalisation.
# NOTE(review): this single transform feeds both ResNet-50 and the ViT --
# confirm the ViT checkpoint expects the same mean/std values.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Every entry of the local "examples" directory, as POSIX-style paths in
# sorted order.
examples = sorted(f.as_posix() for f in Path("examples").glob("*"))
def get_cam(img_tensor):
    """Compute a class activation map (CAM) for ResNet-50's top prediction.

    The CAM is the ``layer4`` feature maps weighted by the row of the final
    fully connected layer's weight matrix that belongs to the predicted class.

    Args:
        img_tensor: Preprocessed image batch containing a single image
            (assumed shape ``(1, 3, H, W)`` on the same device as the model
            -- confirm against the caller).

    Returns:
        A tuple ``(cam, label)`` where ``cam`` is a 2-D NumPy array with the
        spatial size of the ``layer4`` output and ``label`` is the entry of
        ``labels`` for the predicted class.
    """
    output = feature_extractor(img_tensor)
    # Drop only the batch axis. A bare squeeze() would also collapse a
    # spatial axis of size 1 and break the flatten/reshape below.
    cnn_features = output["layer4"].squeeze(0)
    # Plain int so it can index both the weight tensor and the labels list.
    class_id = int(output["fc"].argmax())
    # (channels,) @ (channels, height * width) -> (height * width,)
    cam = fc_layer_weights[class_id].matmul(cnn_features.flatten(1))
    cam = cam.reshape(cnn_features.shape[1:])
    # detach() keeps the NumPy conversion valid even when the caller is not
    # inside torch.no_grad() (the fc weights require gradients).
    return cam.detach().cpu().numpy(), labels[class_id]
def get_attention_mask(img_tensor):
    """Compute an attention-rollout mask for the ViT's top prediction.

    Per-layer attention is averaged over heads, mixed 50/50 with an identity
    matrix, re-normalised so each row sums to one, and then multiplied
    through all layers. The row of the first token (treated as the class
    token) over the remaining tokens is reshaped into a square mask.

    Args:
        img_tensor: Preprocessed image batch containing a single image
            (assumed shape ``(1, 3, H, W)`` on the same device as the model
            -- confirm against the caller).

    Returns:
        A tuple ``(mask, label)`` where ``mask`` is a square 2-D NumPy array
        and ``label`` is the entry of ``labels`` for the predicted class.
    """
    result = vit(img_tensor, output_attentions=True)
    # Plain int so it can index the labels list directly.
    class_id = int(result.logits.argmax())
    # Stack the per-layer attentions and drop the (size-1) batch axis.
    attention_probs = torch.stack(result.attentions).squeeze(1)
    # Average the attention at each layer over all heads
    attention_probs = attention_probs.mean(dim=1)
    # Identity term added to every layer's attention. Create it on the same
    # device and dtype as the attention maps instead of a hard-coded "cuda",
    # so the function also works on CPU-only hosts.
    residual = torch.eye(
        attention_probs.size(-1),
        device=attention_probs.device,
        dtype=attention_probs.dtype,
    )
    attention_probs = 0.5 * attention_probs + 0.5 * residual
    # normalize by layer
    attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)
    # Roll the attention out: multiply each layer's matrix onto the product
    # of the layers before it.
    attention_rollout = attention_probs[0]
    for layer_attention in attention_probs[1:]:
        attention_rollout = torch.matmul(layer_attention, attention_rollout)
    # Attention between cls token and patches
    mask = attention_rollout[0, 1:]
    # The patch tokens are laid out on a square grid.
    mask_size = int(round(mask.size(0) ** 0.5))
    mask = mask.reshape(mask_size, mask_size)
    return mask.detach().cpu().numpy(), labels[class_id]
def plot_mask_on_image(image, mask):
    """Overlay a saliency mask on an image as a JET heat map.

    Args:
        image: PIL image; its ``size`` sets the output resolution and its
            pixels are blended with the heat map (assumed to be in RGB mode
            -- confirm against the caller).
        mask: 2-D NumPy array of saliency values; any value range is accepted.

    Returns:
        A ``uint8`` NumPy array of the image's size: the heat map weighted by
        0.3 plus the image weighted by 0.5.
    """
    # min-max normalization to [0, 1]: divide by the value *range*, not by the
    # maximum, otherwise masks containing negative values exceed 1 and wrap
    # around in the uint8 cast below.
    mask = mask - mask.min()
    value_range = mask.max()
    if value_range > 0:
        mask = mask / value_range
    # (A constant mask is already all zeros here; skip the division by zero.)
    mask = (255 * mask).astype(np.uint8)
    # PIL's size and cv2.resize's dsize are both (width, height).
    mask = cv2.resize(mask, image.size)
    heatmap = cv2.applyColorMap(mask, cv2.COLORMAP_JET)
    # OpenCV colour maps are in BGR channel order; convert so the heat map
    # matches the RGB image it is blended with.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # The weights sum to 0.8, so the blend cannot overflow uint8.
    result = heatmap * 0.3 + np.array(image) * 0.5
    return result.astype(np.uint8)
def inference(img):
    """Classify an image with both models and render their explanations.

    Args:
        img: Input image accepted by ``preprocess`` and
            ``plot_mask_on_image`` (the interface supplies a PIL image).

    Returns:
        A 4-tuple ``(resnet_label, cam_overlay, vit_label, rollout_overlay)``.
    """
    # Add a batch axis and move the tensor to the models' device.
    batch = preprocess(img).unsqueeze(0)
    batch = batch.to(device)

    # Only forward passes are needed, so gradient tracking is switched off.
    with torch.no_grad():
        resnet_explanation = get_cam(batch)
        vit_explanation = get_attention_mask(batch)

    cam, resnet_label = resnet_explanation
    attention_mask, vit_label = vit_explanation

    return (
        resnet_label,
        plot_mask_on_image(img, cam),
        vit_label,
        plot_mask_on_image(img, attention_mask),
    )
if __name__ == "__main__":
    # Single image input, handed to `inference` as a PIL image.
    image_input = gr.inputs.Image(type="pil", label="Input Image")

    # Output widgets, in the same order as the tuple returned by `inference`.
    result_widgets = [
        gr.outputs.Label(num_top_classes=1, type="auto", label="ResNet Label"),
        gr.outputs.Image(type="auto", label="ResNet CAM"),
        gr.outputs.Label(num_top_classes=1, type="auto", label="ViT Label"),
        gr.outputs.Image(type="auto", label="Rollout Attn Flow"),
    ]

    interface = gr.Interface(
        fn=inference,
        inputs=image_input,
        outputs=result_widgets,
        examples=examples,
        title="CNN - Transformer Explainability",
        live=True,
    )
    interface.launch()
|