raedinkhaled commited on
Commit
a14b289
1 Parent(s): 0304304

Create plot.py

Browse files
Files changed (1) hide show
  1. plot.py +26 -0
plot.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from app import inference, examples
3
+ from PIL import Image
4
+
5
+ plt.rcParams["figure.figsize"] = (11,2)
6
+
7
+ title = ["CAM", "ROLLOUT"]
8
+
9
+ fig_resnet, axis_resnet = plt.subplots(1, len(examples))
10
+
11
+ plots = [plt.subplots(1, len(examples)) for _ in range(2)]
12
+
13
+ for i, image_path in enumerate(examples):
14
+ image = Image.open(image_path)
15
+
16
+ result = inference(image)
17
+
18
+ for j, (fig, axis) in enumerate(plots):
19
+ axis[i].imshow(result[2*j+1])
20
+ axis[i].set_title(result[2*j])
21
+ axis[i].set_axis_off()
22
+
23
+
24
+ for i, (plot, title) in enumerate(zip(plots, title)):
25
+ # plot[0].suptitle(title)
26
+ plot[0].savefig(f"{title}.png")