|
import torch |
|
import numpy as np |
|
import cv2 |
|
|
|
|
|
def avg_heads(cam, grad):
    """Combine an attention map with its gradient and average over heads.

    Both tensors are flattened to 4-D using the trailing three dimensions of
    ``cam`` (the third-from-last is treated as the head axis), multiplied
    element-wise, clamped so negative contributions are dropped, and then
    averaged over the head axis.

    Args:
        cam: Attention map whose last three dimensions are
            ``(heads, rows, cols)``.
        grad: Gradient tensor holding the same number of elements per head
            as ``cam``.

    Returns:
        Tensor of shape ``(-1, rows, cols)`` with the head axis averaged out.
    """
    heads = cam.shape[-3]
    cam_4d = cam.reshape(-1, heads, cam.shape[-2], cam.shape[-1])
    grad_4d = grad.reshape(-1, heads, grad.shape[-2], grad.shape[-1])
    # Keep only the positive gradient-weighted attention before averaging.
    positive = (grad_4d * cam_4d).clamp(min=0)
    return positive.mean(dim=1)
|
|
|
|
|
def apply_self_attention_rules(R_ss, cam_ss):
    """Return the relevance update contributed by one self-attention layer.

    Args:
        R_ss: Relevance accumulated so far.
        cam_ss: Head-averaged attention map of the current layer.

    Returns:
        The matrix product ``cam_ss @ R_ss``; the caller is responsible for
        adding it to the running relevance.
    """
    return torch.matmul(cam_ss, R_ss)
|
|
|
def upscale_relevance(relevance, grid_size=14, scale_factor=16):
    """Upsample per-patch relevance to image resolution and min-max normalise.

    The input is reshaped to ``(-1, 1, grid_size, grid_size)``, bilinearly
    interpolated by ``scale_factor``, and every sample is rescaled
    independently to the ``[0, 1]`` range.

    Args:
        relevance: Tensor whose element count is a multiple of
            ``grid_size * grid_size``.
        grid_size: Side length of the square patch grid. The default of 14
            reproduces the previously hard-coded behaviour.
        scale_factor: Spatial upsampling factor. The default of 16 yields a
            224x224 map for a 14x14 grid.

    Returns:
        Tensor of shape ``(batch, 1, H, W)`` where ``H`` and ``W`` are the
        interpolated height and width.
    """
    relevance = relevance.reshape(-1, 1, grid_size, grid_size)
    # align_corners=False is the library default for bilinear mode; it is
    # spelled out here so the behaviour is explicit.
    relevance = torch.nn.functional.interpolate(
        relevance,
        scale_factor=scale_factor,
        mode='bilinear',
        align_corners=False,
    )
    out_height, out_width = relevance.shape[-2], relevance.shape[-1]

    # Normalise each sample over all of its pixels.
    flat = relevance.reshape(relevance.shape[0], -1)
    lowest = flat.min(1, keepdim=True)[0]
    highest = flat.max(1, keepdim=True)[0]
    span = highest - lowest
    # A constant map has zero span; dividing by it would yield NaN, so such
    # samples are divided by one instead and come out as all zeros.
    span = torch.where(span > 0, span, torch.ones_like(span))
    flat = (flat - lowest) / span

    return flat.reshape(-1, 1, out_height, out_width)
|
|
|
def generate_relevance(model, input, index=None):
    """Compute an upscaled relevance map for the given class indices.

    The model is run with ``register_hook=True``; for every block in
    ``model.blocks`` the attention map is weighted by the gradient of the
    selected class score, averaged over heads, and accumulated into a running
    relevance matrix that starts as the identity.

    Args:
        model: Callable as ``model(input, register_hook=True)``, exposing
            ``model.blocks`` and ``model.zero_grad()``. Each block must
            provide ``attn.get_attention_map()`` and ``attn.attention_map``.
        input: Batched input tensor; its first dimension is the batch size.
        index: Target class per sample. ``None`` selects the arg-max class of
            the model output. A tensor, an int, or any array-like accepted by
            ``torch.as_tensor`` is supported; a single value is broadcast
            over the batch.

    Returns:
        The result of ``upscale_relevance`` applied to row 0, columns 1
        onwards, of the accumulated relevance matrix.
    """
    batch_size = input.shape[0]
    output = model(input, register_hook=True)

    # `is None` instead of `== None`: the latter is ambiguous (or raises)
    # when `index` is an array-like.
    if index is None:
        index = output.argmax(dim=-1)
    index = torch.as_tensor(index, device=output.device).detach().long().reshape(-1)

    # Sum of the selected logits, expressed as a one-hot mask over the output.
    one_hot = torch.zeros(
        batch_size, output.shape[-1], dtype=torch.float32, device=output.device
    )
    one_hot[torch.arange(batch_size, device=output.device), index] = 1
    target_score = torch.sum(one_hot * output)
    model.zero_grad()

    first_map = model.blocks[0].attn.get_attention_map()
    num_tokens = first_map.shape[-1]
    # Build the identity on the attention map's own device/dtype rather than
    # forcing CUDA, so the routine also runs on CPU-only hosts.
    R = torch.eye(num_tokens, dtype=first_map.dtype, device=first_map.device)
    R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)

    for blk in model.blocks:
        # retain_graph=True keeps the graph alive for the remaining blocks.
        grad = torch.autograd.grad(
            target_score, [blk.attn.attention_map], retain_graph=True
        )[0]
        cam = blk.attn.get_attention_map()
        cam = avg_heads(cam, grad)
        R = R + apply_self_attention_rules(R, cam)

    # NOTE(review): token 0 is presumably the class token and the remaining
    # tokens the image patches — confirm against the model definition.
    relevance = R[:, 0, 1:]
    return upscale_relevance(relevance)
|
|
|
|
|
def show_cam_on_image(img, mask):
    """Overlay a JET-coloured heatmap of ``mask`` on top of ``img``.

    Args:
        img: Image array that can be converted with ``np.float32``.
            NOTE(review): presumably scaled to [0, 1] with the same channel
            order as OpenCV's colour map output — confirm with callers.
        mask: Array that is multiplied by 255 and cast to ``uint8`` before
            colour mapping (so presumably in [0, 1] — confirm).

    Returns:
        Float array equal to ``heatmap + img`` divided by its maximum value.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    colored = np.float32(colored) / 255
    overlay = colored + np.float32(img)
    # Rescale so the brightest value of the blend is exactly 1.
    return overlay / np.max(overlay)
|
|
|
|
|
def get_image_with_relevance(image, relevance):
    """Weight a min-max normalised image by a relevance map.

    Both tensors have their first axis moved to the end via
    ``permute(1, 2, 0)``. The image is rescaled to ``[0, 255]`` using its
    global minimum and maximum and then multiplied element-wise (with
    broadcasting) by the relevance map.

    Args:
        image: 3-D tensor; assumed channel-first — TODO confirm.
        relevance: 3-D tensor broadcastable against the permuted image.

    Returns:
        NumPy array holding the relevance-weighted image.
    """
    hwc_image = image.permute(1, 2, 0)
    hwc_relevance = relevance.permute(1, 2, 0)

    lowest = hwc_image.min()
    highest = hwc_image.max()
    scaled = 255 * ((hwc_image - lowest) / (highest - lowest))

    weighted = scaled * hwc_relevance
    return weighted.data.cpu().numpy()