Hila commited on
Commit
a63d7e6
1 Parent(s): 1a33535

Change layout

Browse files
app.py CHANGED
@@ -9,6 +9,7 @@ import requests
9
  from PIL import Image
10
  import numpy as np
11
  import cv2
 
12
 
13
 
14
  # create heatmap from mask on image
@@ -113,28 +114,62 @@ def _load_model(model_name: str):
113
  model_finetuned.eval()
114
 
115
  _load_model('ar_base.tar')
116
- #demo = gr.Interface(image_classifier, gr.inputs.Image(shape=(224,224)), [gr.outputs.Label(label="Our Classification", num_top_classes=3), gr.Image(label="Our Relevance",shape=(64,64)), gr.outputs.Label(label="Original Classification", num_top_classes=3), gr.Image(label="Original Relevance",shape=(64,64))],examples=["samples/augreg_base/tank.png", "samples/augreg_base/sundial.png", "samples/augreg_base/lizard.png", "samples/augreg_base/storck.png", "samples/augreg_base/hummingbird2.png", "samples/augreg_base/hummingbird.png"], capture_session=True)
117
- #demo.launch(debug=True)
118
 
119
- demo = gr.Blocks()
 
 
 
 
 
 
120
 
121
  with demo:
122
- gr.Markdown('Select an image and then click **submit** to see the output.')
123
 
124
- with gr.Row():
125
- inp = gr.inputs.Image(shape=(224,224))
126
 
127
- outs = []
128
  with gr.Row():
129
- out1 = gr.outputs.Label(label="Our Classification", num_top_classes=3)
130
- out2 = gr.Image(label="Our Relevance",shape=(224,224))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
 
 
132
  with gr.Row():
133
- out3 = gr.outputs.Label(label="Original Classification", num_top_classes=3)
134
- out4 = gr.Image(label="Original Relevance",shape=(224,224))
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- btn = gr.Button('Submit')
137
- btn.click(fn=image_classifier, inputs=inp, outputs=[out1, out2, out3, out4])
138
 
139
  demo.launch()
140
 
 
9
  from PIL import Image
10
  import numpy as np
11
  import cv2
12
+ import pathlib
13
 
14
 
15
  # create heatmap from mask on image
 
114
  model_finetuned.eval()
115
 
116
  _load_model('ar_base.tar')
 
 
117
 
118
+ def _set_example_image(example: list) -> dict:
119
+ return gr.Image.update(value=example[0])
120
+
121
+ def _clear_image():
122
+ return None
123
+
124
+ demo = gr.Blocks(css='style.css')
125
 
126
  with demo:
 
127
 
 
 
128
 
 
129
  with gr.Row():
130
+ with gr.Column():
131
+ gr.Markdown('## [Optimizing Relevance Maps of Vision Transformers Improves Robustness](https://github.com/hila-chefer/RobustViT) - Official Demo')
132
+ # gr.Markdown('This is an official demo for [Optimizing Relevance Maps of Vision Transformers Improves Robustness](https://github.com/hila-chefer/RobustViT).')
133
+ gr.Markdown('Select or upload an image and then click **Submit** to see the output.')
134
+ with gr.Row():
135
+ input_image = gr.Image(shape=(224,224))
136
+ with gr.Row():
137
+ btn = gr.Button("Submit", variant="primary")
138
+ clear_btn = gr.Button('Clear')
139
+ with gr.Column():
140
+ gr.Markdown('### Examples')
141
+ gr.Markdown('#### Corrected Prediction')
142
+ with gr.Row():
143
+ paths = sorted(pathlib.Path('samples/corrected').rglob('*.png'))
144
+ corrected_pred_examples = gr.Dataset(components=[input_image], headers=['header'],
145
+ samples=[[path.as_posix()] for path in paths])
146
+
147
+ gr.Markdown('#### Improved Explainability')
148
+ with gr.Row():
149
+ paths = sorted(pathlib.Path('samples/better_expl').rglob('*.png'))
150
+ better_expl = gr.Dataset(components=[input_image], headers=['header'],
151
+ samples=[[path.as_posix()] for path in paths])
152
 
153
+
154
+ #gr.Markdown('### Results:')
155
+
156
  with gr.Row():
157
+ with gr.Column():
158
+ gr.Markdown('### Ours (finetuned model)')
159
+ out1 = gr.outputs.Label(label="Our Classification", num_top_classes=3)
160
+ out2 = gr.Image(label="Our Relevance",shape=(224,224), elem_id="expl1")
161
+
162
+ with gr.Column():
163
+ gr.Markdown('### Original model')
164
+ out3 = gr.outputs.Label(label="Original Classification", num_top_classes=3)
165
+ out4 = gr.Image(label="Original Relevance",shape=(224,224),elem_id="expl2")
166
+
167
+
168
+ corrected_pred_examples.click(fn=_set_example_image, inputs=corrected_pred_examples, outputs=input_image)
169
+ better_expl.click(fn=_set_example_image, inputs=better_expl, outputs=input_image)
170
+ btn.click(fn=image_classifier, inputs=input_image, outputs=[out1, out2, out3, out4])
171
+ clear_btn.click(fn=_clear_image, inputs=[], outputs=[input_image])
172
 
 
 
173
 
174
  demo.launch()
175
 
samples/augreg_base/hummingbird2.png DELETED
Binary file (98.1 kB)
 
samples/{augreg_base/sundial.png → better_expl/1.png} RENAMED
File without changes
samples/{augreg_base/storck.png → better_expl/2.png} RENAMED
File without changes
samples/{augreg_base/lizard.png → better_expl/3.png} RENAMED
File without changes
samples/better_expl/4.png ADDED
samples/better_expl/5.png ADDED
samples/better_expl/6.png ADDED
samples/better_expl/7.png ADDED
samples/catdog.png DELETED
Binary file (90.2 kB)
 
samples/corrected/0_new.png ADDED
samples/{augreg_base/tank.png → corrected/1_new.png} RENAMED
File without changes
samples/corrected/2_new.png ADDED
samples/{augreg_base/hummingbird.png → corrected/3_new.png} RENAMED
File without changes
samples/corrected/4_new.png ADDED
samples/corrected/5_new.png ADDED
samples/corrected/6_new.png ADDED
samples/corrected/7_new.png ADDED
samples/corrected/8_new.png ADDED
samples/deit_base/1_in.png DELETED
Binary file (110 kB)
 
samples/deit_base/2_in.png DELETED
Binary file (81.6 kB)
 
samples/deit_base/3_in.png DELETED
Binary file (81.9 kB)
 
samples/deit_base/a.png DELETED
Binary file (102 kB)
 
samples/deit_base/a_2.png DELETED
Binary file (81.7 kB)
 
samples/deit_base/a_3.png DELETED
Binary file (80.3 kB)
 
samples/dogbird.png DELETED
Binary file (96.8 kB)
 
samples/el2.png DELETED
Binary file (13.4 kB)
 
samples/vit_base/1_in.png DELETED
Binary file (84 kB)
 
samples/vit_base/2_in.png DELETED
Binary file (91.8 kB)
 
samples/vit_base/3_in.png DELETED
Binary file (83.2 kB)
 
samples/vit_base/a.png DELETED
Binary file (72.4 kB)
 
samples/vit_base/a_2.png DELETED
Binary file (108 kB)
 
samples/vit_base/a_3.png DELETED
Binary file (90.5 kB)
 
style.css ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ div#expl1 {
2
+ max-width: 224px;
3
+ max-height: 224px;
4
+ }
5
+
6
+ div#expl2 {
7
+ max-width: 224px;
8
+ max-height: 224px;
9
+ }