ariG23498 HF staff commited on
Commit
e029f9c
1 Parent(s): ad572cd

chore: adding spinner

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -25,25 +25,27 @@ st.image(image, caption="Original Image")
25
  st.title("Attention Heat Maps")
26
 
27
  # Load the DINO model
28
- dino = from_pretrained_keras("probing-vits/vit-dino-base16")
 
29
 
30
- # Get the attention scores
31
- _, attention_score_dict = dino.predict(preprocessed_image)
32
 
33
- # De-normalize the image for visual clarity.
34
- in1k_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
35
- in1k_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
36
- preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
37
- preprocessed_img_orig = preprocessed_img_orig / 255.
38
- preprocessed_img_orig = tf.clip_by_value(preprocessed_img_orig, 0.0, 1.0).numpy()
39
 
40
- attentions = utils.attention_heatmap(
41
- attention_score_dict=attention_score_dict,
42
- image=preprocessed_img_orig
43
- )
 
 
44
 
45
- utils.plot(attentions=attentions, image=preprocessed_img_orig)
 
46
 
47
- # Show the attention maps
48
- image = Image.open("heat_map.png")
49
  st.image(image, caption="Attention Heat Maps")
 
25
  st.title("Attention Heat Maps")
26
 
27
  # Load the DINO model
28
+ with st.spinner('Processing the heat maps. Wait for it 🔥...'):
29
+ dino = from_pretrained_keras("probing-vits/vit-dino-base16")
30
 
31
+ # Get the attention scores
32
+ _, attention_score_dict = dino.predict(preprocessed_image)
33
 
34
+ # De-normalize the image for visual clarity.
35
+ in1k_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])
36
+ in1k_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])
37
+ preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
38
+ preprocessed_img_orig = preprocessed_img_orig / 255.
39
+ preprocessed_img_orig = tf.clip_by_value(preprocessed_img_orig, 0.0, 1.0).numpy()
40
 
41
+ attentions = utils.attention_heatmap(
42
+ attention_score_dict=attention_score_dict,
43
+ image=preprocessed_img_orig
44
+ )
45
+
46
+ utils.plot(attentions=attentions, image=preprocessed_img_orig)
47
 
48
+ # Show the attention maps
49
+ image = Image.open("heat_map.png")
50
 
 
 
51
  st.image(image, caption="Attention Heat Maps")