luxmorocco commited on
Commit
e45f24a
1 Parent(s): 828e90b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +192 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.patches as patches
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
8
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
+ import cv2
10
+ import numpy as np
11
+ from matplotlib.colors import LinearSegmentedColormap
12
+
13
+ # Function Definitions
14
+
15
+ label_names = [
16
+ "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
17
+ "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
18
+ "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
19
+ "Pulmonary_fibrosis"
20
+ ]
21
+
22
+ def generate_diagnostic_report(predictions, labels, threshold=0.5):
23
+ # Initialize an empty report string
24
+ report = "Diagnostic Report:\n\n"
25
+ findings_present = False
26
+
27
+ # Loop through each detection
28
+ for element in range(len(predictions['boxes'])):
29
+ score = predictions['scores'][element].cpu().numpy()
30
+ if score > threshold:
31
+ label_index = predictions['labels'][element].cpu().numpy() - 1
32
+ label_name = labels[label_index]
33
+ report += f"- {label_name} detected with probability {score:.2f}\n"
34
+ findings_present = True
35
+
36
+ # If no findings above the threshold, report no significant abnormalities
37
+ if not findings_present:
38
+ report += "No significant abnormalities detected."
39
+
40
+ return report
41
+
42
+ def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
43
+ # Define your labels and their corresponding colors
44
+ label_names = [
45
+ "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
46
+ "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
47
+ "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
48
+ "Pulmonary_fibrosis"
49
+ ]
50
+
51
+ label2color = [
52
+ [59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
53
+ [210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
54
+ [209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100]
55
+ ]
56
+
57
+ for i, box in enumerate(boxes):
58
+ if scores[i] > threshold:
59
+ # Subtract 1 from label_index to match the zero-indexed Python list
60
+ label_index = labels[i] - 1
61
+ label_name = label_names[label_index] if label_index < len(label_names) else "Unknown"
62
+ color = label2color[label_index] if label_index < len(label2color) else (255, 255, 255) # Default to white for unknown labels
63
+
64
+ label_text = f'{label_name}: {scores[i]:.2f}'
65
+ cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
66
+ cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
67
+
68
+ return image
69
+
70
+ # Heatmap Generation Function
71
+ def plot_image_with_colored_mask_overlay_and_original(image, predictions):
72
+ # Assuming predictions are in the same format as Faster R-CNN outputs
73
+ boxes = predictions['boxes'].cpu().numpy()
74
+ scores = predictions['scores'].cpu().numpy()
75
+
76
+ # Create a blank mask matching image size
77
+ mask = np.zeros(image.shape[:2], dtype=np.float32)
78
+
79
+ # Fill mask based on bounding boxes and scores
80
+ for box, score in zip(boxes, scores):
81
+ if score > 0.5: # Threshold can be adjusted
82
+ x_min, y_min, x_max, y_max = map(int, box)
83
+ mask[y_min:y_max, x_min:x_max] += score # Increase mask intensity with score
84
+
85
+ # Normalize mask
86
+ normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
87
+
88
+ # Create a custom colormap with transparency
89
+ colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
90
+ cmap_name = 'doctoria_heatmap'
91
+ custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)
92
+
93
+ # Apply custom colormap
94
+ heatmap = custom_cmap(normed_mask)
95
+
96
+ # Convert heatmap to BGR format with uint8 type
97
+ heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)
98
+
99
+ # Overlay heatmap on original image
100
+ overlayed_image = image.copy()
101
+ overlayed_image[mask > 0] = overlayed_image[mask > 0] * 0.5 + heatmap_bgr[mask > 0] * 0.5
102
+
103
+ # Plotting
104
+ fig, axs = plt.subplots(1, 2, figsize=[12, 6])
105
+
106
+ # Original image
107
+ axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
108
+ axs[0].set_title('Original Image')
109
+ axs[0].axis('off')
110
+
111
+ # Image with heatmap
112
+ axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
113
+ axs[1].set_title('Image with Heatmap Overlay')
114
+ axs[1].axis('off')
115
+
116
+ # Adding colorbar
117
+ sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
118
+ sm.set_array([])
119
+ fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
120
+
121
+ plt.show()
122
+
123
+
124
+ # Load the model
125
+ def create_model(num_classes):
126
+ model = fasterrcnn_resnet50_fpn(pretrained=False)
127
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
128
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
129
+ return model
130
+
131
+
132
+ # Streamlit app title
133
+ st.title("Doctoria CXR")
134
+
135
+ # Sidebar for user input
136
+ st.sidebar.title("Upload Chest X-ray Image")
137
+
138
+ # File uploader allows user to add their own image
139
+ uploaded_file = st.sidebar.file_uploader("Upload Chest X-ray image", type=["png", "jpg", "jpeg"])
140
+
141
+ # Load the model (use your model loading function)
142
+ # Ensure the model path is correct and accessible
143
+ model = create_model(num_classes=14)
144
+ model.load_state_dict(torch.load('Models/Doctoria CXR Thoraric Full Model.pth', map_location=torch.device('cpu')))
145
+ model.eval()
146
+
147
+ def process_image(image_path):
148
+ # Load and transform the image
149
+ image = Image.open(image_path).convert('RGB')
150
+ transform = get_transform()
151
+ image = transform(image).unsqueeze(0)
152
+
153
+ # Perform inference
154
+ with torch.no_grad():
155
+ prediction = model(image)
156
+
157
+ return prediction, image
158
+
159
+ # When the user uploads a file
160
+ if uploaded_file is not None:
161
+ # Display the uploaded image
162
+ st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
163
+ st.write("")
164
+
165
+ # Process the uploaded image
166
+ prediction, image_tensor = process_image(uploaded_file)
167
+
168
+ # Convert tensor to PIL Image for display
169
+ image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
170
+
171
+ # Visualization and report generation
172
+ image_np = np.array(image_pil)
173
+ for element in range(len(prediction[0]['boxes'])):
174
+ box = prediction[0]['boxes'][element].cpu().numpy()
175
+ score = prediction[0]['scores'][element].cpu().numpy()
176
+ label_index = prediction[0]['labels'][element].cpu().numpy()
177
+ if score > 0.5:
178
+ draw_boxes_cv2(image_np, [box], [label_index], [score], font_scale=3) # Increased font size
179
+
180
+ image_pil_processed = Image.fromarray(image_np)
181
+
182
+ # Display processed image
183
+ st.image(image_pil_processed, caption="Processed X-ray with Abnormalities Marked", use_column_width=True)
184
+
185
+ # Generate the diagnostic report
186
+ report = generate_diagnostic_report(prediction[0], label_names, 0.5)
187
+ st.write(report)
188
+
189
+ # Instructions
190
+ st.sidebar.write("Instructions:")
191
+ st.sidebar.write("1. Upload an X-ray image.")
192
+ st.sidebar.write("2. View the processed image and diagnostic report.")
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ tensorflow==2.4.1
3
+ torch==1.7.0
4
+ torchvision
5
+ matplotlib
6
+ Pillow
7
+ tqdm==4.56.2
8
+ opencv-python-headless
9
+ pytorch-lightning==1.2.3
10
+ numpy