CindyBSydney commited on
Commit
6605792
·
1 Parent(s): 6499566

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -77,18 +77,48 @@ def process_image(image_path):
77
 
78
  # Check for anomaly
79
  if is_anomaly(clf, feature_extractor, input_image):
80
- return "Anomaly detected. Image will not be classified.", None, None
81
 
82
  # Classify image
83
  predicted_class, probability = classify_image(classification_model, input_image)
84
  result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"
85
 
86
- # Further processing for heatmap or additional features can be added here
87
-
88
- return result, None, None # Returning placeholders for additional outputs if needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  # Gradio interface
91
-
92
  iface = gr.Interface(
93
  fn=process_image,
94
  inputs=File(type="filepath"),
@@ -97,7 +127,7 @@ iface = gr.Interface(
97
  description="Upload an image to classify it as normal or abnormal.",
98
  article="Above is a sample image to test the results of the model. Click it to see the results.",
99
  examples=[
100
- ["Gastric_Images/Ladybug.png"],
101
  ],
102
  allow_flagging="never",
103
  )
 
77
 
78
  # Check for anomaly
79
  if is_anomaly(clf, feature_extractor, input_image):
80
+ return "Anomaly detected. Image will not be classified.", None
81
 
82
  # Classify image
83
  predicted_class, probability = classify_image(classification_model, input_image)
84
  result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"
85
 
86
+ # Generate heatmap
87
+ heatmap = generate_heatmap(classification_model, input_image)
88
+ heatmap_image = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255))
89
+
90
+ return result, heatmap_image
91
+
92
+ # Function to generate heatmap
93
+ def generate_heatmap(model, image):
94
+ activation = []
95
+ def hook_fn(module, input, output):
96
+ activation.append(output)
97
+ for module in model.named_modules():
98
+ if isinstance(module[1], torch.nn.ReLU):
99
+ module[1].register_forward_hook(hook_fn)
100
+
101
+ # Forward pass
102
+ output = model(image)
103
+ prediction = output.argmax(1)
104
+
105
+ # Backpropagation to compute gradients
106
+ model.zero_grad()
107
+ one_hot_output = torch.FloatTensor(1, output.size()[-1]).zero_().to(device)
108
+ one_hot_output[0][prediction] = 1
109
+ output.backward(gradient=one_hot_output)
110
+
111
+ # Compute the heatmap
112
+ if len(activation) > 0:
113
+ gradients = activation[0][0].detach().cpu().numpy()
114
+ heatmap = gradients.max(axis=0)
115
+ threshold = 0.5 # Adjust this threshold value as needed
116
+ heatmap[heatmap < threshold] = 0
117
+ return heatmap
118
+ else:
119
+ return np.zeros((224, 224)) # Return an empty heatmap if no activation is recorded
120
 
121
  # Gradio interface
 
122
  iface = gr.Interface(
123
  fn=process_image,
124
  inputs=File(type="filepath"),
 
127
  description="Upload an image to classify it as normal or abnormal.",
128
  article="Above is a sample image to test the results of the model. Click it to see the results.",
129
  examples=[
130
+ ["/Gastric_Images/Ladybug.png"],
131
  ],
132
  allow_flagging="never",
133
  )