import pathlib

import cv2
import gradio as gr
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download

from ViT.ViT_new import vit_base_patch16_224 as vit
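
# Gradio demo for RobustViT: an input image is run through both the original
# pretrained ViT and the relevance-finetuned ViT, and each model's
# classification is shown alongside its relevance heatmap.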


# Create a heatmap from the relevance mask and overlay it on the image.
# Both `img` and `mask` are expected as float arrays scaled to [0, 1].
def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam

start_layer = 0  # index of the first transformer block included in the relevance rollout

# Rule 5 from the paper: fuse the attention heads of a block by weighting each
# head's attention map with its gradient, clamping negative values, and
# averaging over heads.
def avg_heads(cam, grad):
    cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    cam = grad * cam
    cam = cam.clamp(min=0).mean(dim=0)
    return cam

# Rule 6 from the paper: propagate the fused attention map through the
# accumulated relevance matrix.
def apply_self_attention_rules(R_ss, cam_ss):
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition
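
# Together, rules 5 and 6 implement a gradient-weighted attention rollout:
# starting from R = I, each block updates R <- R + mean_h[(grad_h * A_h)^+] . R,
# where A_h is head h's attention map and (.)^+ clamps negative values to zero.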

def generate_relevance(model, input, index=None):
    # Forward pass with hooks registered so attention maps and their gradients
    # can be read back from each block.
    output = model(input, register_hook=True)
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # Backpropagate a one-hot vector for the target class to populate the
    # attention gradients.
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot * output)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    # Accumulate relevance across blocks, starting from the identity matrix.
    num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]
    R = torch.eye(num_tokens, num_tokens)
    for i, blk in enumerate(model.blocks):
        if i < start_layer:
            continue
        grad = blk.attn.get_attn_gradients()
        cam = blk.attn.get_attention_map()
        cam = avg_heads(cam, grad)
        R += apply_self_attention_rules(R, cam)
    # Relevance of each patch token with respect to the [CLS] token.
    return R[0, 1:]

def generate_visualization(model, original_image, class_index=None):
    with torch.enable_grad():
        transformer_attribution = generate_relevance(model, original_image.unsqueeze(0), index=class_index).detach()
    # Reshape the 196 patch relevances into the 14x14 patch grid and upsample
    # (16x bilinear) to the 224x224 image resolution.
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
    # Min-max normalize both the relevance map and the image to [0, 1] before blending.
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())

    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis
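
# Example (hypothetical standalone use, assuming `model` is one of the hooked
# ViT models loaded below and `img` is a normalized 3x224x224 tensor):
#     vis = generate_visualization(model, img)  # uint8 heatmap overlay
#     cv2.imwrite('relevance.png', vis)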
	
# Populated by _load_model below.
model_finetuned = None
model = None

# ViT weights trained with Inception-style preprocessing expect 0.5 mean/std.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_224 = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
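# (No Resize is needed here: the Gradio Image input below is declared with
# shape=(224, 224), so uploads arrive already resized.)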

# Download human-readable labels for ImageNet.
# Note: git.io shortlinks have since been deprecated by GitHub, so this URL may
# need to be replaced with the full raw URL of the labels file.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def image_classifier(inp):
    image = transform_224(inp)
    with torch.no_grad():
        # Finetuned (robust) model: class confidences and relevance heatmap.
        prediction = torch.nn.functional.softmax(model_finetuned(image.unsqueeze(0))[0], dim=0)
        confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
        heatmap = generate_visualization(model_finetuned, image)

        # Original pretrained model, for side-by-side comparison.
        prediction_orig = torch.nn.functional.softmax(model(image.unsqueeze(0))[0], dim=0)
        confidences_orig = {labels[i]: float(prediction_orig[i]) for i in range(1000)}
        heatmap_orig = generate_visualization(model, image)
    return confidences, heatmap, confidences_orig, heatmap_orig

def _load_model(model_name: str):
    global model_finetuned, model
    # Fetch the finetuned checkpoint from the Hugging Face Hub.
    path = hf_hub_download('Hila/RobustViT', f'{model_name}')

    # Original pretrained ViT, used as the comparison baseline.
    model = vit(pretrained=True)
    model.eval()
    # Finetuned (robust) ViT, initialized from the downloaded checkpoint.
    model_finetuned = vit()
    checkpoint = torch.load(path, map_location='cpu')
    model_finetuned.load_state_dict(checkpoint['state_dict'])
    model_finetuned.eval()

_load_model('ar_base.tar')
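# A different checkpoint from the same Hub repo could be loaded by calling
# _load_model with its filename.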

def _set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])

def _clear_image():
    return None

demo = gr.Blocks(css='style.css')

with demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown('## [Optimizing Relevance Maps of Vision Transformers Improves Robustness](https://github.com/hila-chefer/RobustViT) - Official Demo')
            gr.Markdown('Select or upload an image and then click **Submit** to see the output.')
            with gr.Row():
                input_image = gr.Image(shape=(224, 224))
            with gr.Row():
                btn = gr.Button('Submit', variant='primary')
                clear_btn = gr.Button('Clear')
        with gr.Column():
            gr.Markdown('### Examples')
            gr.Markdown('#### Corrected Prediction')
            with gr.Row():
                paths = sorted(pathlib.Path('samples/corrected').rglob('*.png'))
                corrected_pred_examples = gr.Dataset(components=[input_image], headers=['header'],
                                                     samples=[[path.as_posix()] for path in paths])

            gr.Markdown('#### Improved Explainability')
            with gr.Row():
                paths = sorted(pathlib.Path('samples/better_expl').rglob('*.png'))
                better_expl = gr.Dataset(components=[input_image], headers=['header'],
                                         samples=[[path.as_posix()] for path in paths])

    with gr.Row():
        with gr.Column():
            gr.Markdown('### Ours (finetuned model)')
            out1 = gr.Label(label='Our Classification', num_top_classes=3)
            out2 = gr.Image(label='Our Relevance', shape=(224, 224), elem_id='expl1')
        with gr.Column():
            gr.Markdown('### Original model')
            out3 = gr.Label(label='Original Classification', num_top_classes=3)
            out4 = gr.Image(label='Original Relevance', shape=(224, 224), elem_id='expl2')

    # Wire the example galleries, Submit, and Clear to their callbacks.
    corrected_pred_examples.click(fn=_set_example_image, inputs=corrected_pred_examples, outputs=input_image)
    better_expl.click(fn=_set_example_image, inputs=better_expl, outputs=input_image)
    btn.click(fn=image_classifier, inputs=input_image, outputs=[out1, out2, out3, out4])
    clear_btn.click(fn=_clear_image, inputs=[], outputs=[input_image])

demo.launch()