RobustViT / ViT /explainer.py
Hila's picture
init commit
7754b29
import torch
import numpy as np
import cv2
# rule 5 from paper
def avg_heads(cam, grad):
cam = cam.reshape(-1, cam.shape[-3], cam.shape[-2], cam.shape[-1])
grad = grad.reshape(-1, cam.shape[-3], grad.shape[-2], grad.shape[-1])
cam = grad * cam
cam = cam.clamp(min=0).mean(dim=1)
return cam
# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
R_ss_addition = torch.matmul(cam_ss, R_ss)
return R_ss_addition
def upscale_relevance(relevance):
relevance = relevance.reshape(-1, 1, 14, 14)
relevance = torch.nn.functional.interpolate(relevance, scale_factor=16, mode='bilinear')
# normalize between 0 and 1
relevance = relevance.reshape(relevance.shape[0], -1)
min = relevance.min(1, keepdim=True)[0]
max = relevance.max(1, keepdim=True)[0]
relevance = (relevance - min) / (max - min)
relevance = relevance.reshape(-1, 1, 224, 224)
return relevance
def generate_relevance(model, input, index=None):
# a batch of samples
batch_size = input.shape[0]
output = model(input, register_hook=True)
if index == None:
index = np.argmax(output.cpu().data.numpy(), axis=-1)
index = torch.tensor(index)
one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
one_hot = torch.sum(one_hot.to(input.device) * output)
model.zero_grad()
num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]
R = torch.eye(num_tokens, num_tokens).cuda()
R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
for i, blk in enumerate(model.blocks):
grad = torch.autograd.grad(one_hot, [blk.attn.attention_map], retain_graph=True)[0]
cam = blk.attn.get_attention_map()
cam = avg_heads(cam, grad)
R = R + apply_self_attention_rules(R, cam)
relevance = R[:, 0, 1:]
return upscale_relevance(relevance)
# create heatmap from mask on image
def show_cam_on_image(img, mask):
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
cam = heatmap + np.float32(img)
cam = cam / np.max(cam)
return cam
def get_image_with_relevance(image, relevance):
image = image.permute(1, 2, 0)
relevance = relevance.permute(1, 2, 0)
image = (image - image.min()) / (image.max() - image.min())
image = 255 * image
vis = image * relevance
return vis.data.cpu().numpy()