Spaces:
Running
Running
luxmorocco
commited on
Commit
•
e45f24a
1
Parent(s):
828e90b
Upload 2 files
Browse files- app.py +192 -0
- requirements.txt +10 -0
app.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import matplotlib.patches as patches
|
4 |
+
from torchvision import transforms
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
8 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
from matplotlib.colors import LinearSegmentedColormap
|
12 |
+
|
13 |
+
# Function Definitions
|
14 |
+
|
15 |
+
label_names = [
|
16 |
+
"Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
|
17 |
+
"Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
|
18 |
+
"Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
|
19 |
+
"Pulmonary_fibrosis"
|
20 |
+
]
|
21 |
+
|
22 |
+
def generate_diagnostic_report(predictions, labels, threshold=0.5):
|
23 |
+
# Initialize an empty report string
|
24 |
+
report = "Diagnostic Report:\n\n"
|
25 |
+
findings_present = False
|
26 |
+
|
27 |
+
# Loop through each detection
|
28 |
+
for element in range(len(predictions['boxes'])):
|
29 |
+
score = predictions['scores'][element].cpu().numpy()
|
30 |
+
if score > threshold:
|
31 |
+
label_index = predictions['labels'][element].cpu().numpy() - 1
|
32 |
+
label_name = labels[label_index]
|
33 |
+
report += f"- {label_name} detected with probability {score:.2f}\n"
|
34 |
+
findings_present = True
|
35 |
+
|
36 |
+
# If no findings above the threshold, report no significant abnormalities
|
37 |
+
if not findings_present:
|
38 |
+
report += "No significant abnormalities detected."
|
39 |
+
|
40 |
+
return report
|
41 |
+
|
42 |
+
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
|
43 |
+
# Define your labels and their corresponding colors
|
44 |
+
label_names = [
|
45 |
+
"Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
|
46 |
+
"Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
|
47 |
+
"Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
|
48 |
+
"Pulmonary_fibrosis"
|
49 |
+
]
|
50 |
+
|
51 |
+
label2color = [
|
52 |
+
[59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
|
53 |
+
[210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
|
54 |
+
[209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100]
|
55 |
+
]
|
56 |
+
|
57 |
+
for i, box in enumerate(boxes):
|
58 |
+
if scores[i] > threshold:
|
59 |
+
# Subtract 1 from label_index to match the zero-indexed Python list
|
60 |
+
label_index = labels[i] - 1
|
61 |
+
label_name = label_names[label_index] if label_index < len(label_names) else "Unknown"
|
62 |
+
color = label2color[label_index] if label_index < len(label2color) else (255, 255, 255) # Default to white for unknown labels
|
63 |
+
|
64 |
+
label_text = f'{label_name}: {scores[i]:.2f}'
|
65 |
+
cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
|
66 |
+
cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
|
67 |
+
|
68 |
+
return image
|
69 |
+
|
70 |
+
# Heatmap Generation Function
|
71 |
+
def plot_image_with_colored_mask_overlay_and_original(image, predictions):
|
72 |
+
# Assuming predictions are in the same format as Faster R-CNN outputs
|
73 |
+
boxes = predictions['boxes'].cpu().numpy()
|
74 |
+
scores = predictions['scores'].cpu().numpy()
|
75 |
+
|
76 |
+
# Create a blank mask matching image size
|
77 |
+
mask = np.zeros(image.shape[:2], dtype=np.float32)
|
78 |
+
|
79 |
+
# Fill mask based on bounding boxes and scores
|
80 |
+
for box, score in zip(boxes, scores):
|
81 |
+
if score > 0.5: # Threshold can be adjusted
|
82 |
+
x_min, y_min, x_max, y_max = map(int, box)
|
83 |
+
mask[y_min:y_max, x_min:x_max] += score # Increase mask intensity with score
|
84 |
+
|
85 |
+
# Normalize mask
|
86 |
+
normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
|
87 |
+
|
88 |
+
# Create a custom colormap with transparency
|
89 |
+
colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
|
90 |
+
cmap_name = 'doctoria_heatmap'
|
91 |
+
custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)
|
92 |
+
|
93 |
+
# Apply custom colormap
|
94 |
+
heatmap = custom_cmap(normed_mask)
|
95 |
+
|
96 |
+
# Convert heatmap to BGR format with uint8 type
|
97 |
+
heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)
|
98 |
+
|
99 |
+
# Overlay heatmap on original image
|
100 |
+
overlayed_image = image.copy()
|
101 |
+
overlayed_image[mask > 0] = overlayed_image[mask > 0] * 0.5 + heatmap_bgr[mask > 0] * 0.5
|
102 |
+
|
103 |
+
# Plotting
|
104 |
+
fig, axs = plt.subplots(1, 2, figsize=[12, 6])
|
105 |
+
|
106 |
+
# Original image
|
107 |
+
axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
108 |
+
axs[0].set_title('Original Image')
|
109 |
+
axs[0].axis('off')
|
110 |
+
|
111 |
+
# Image with heatmap
|
112 |
+
axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
|
113 |
+
axs[1].set_title('Image with Heatmap Overlay')
|
114 |
+
axs[1].axis('off')
|
115 |
+
|
116 |
+
# Adding colorbar
|
117 |
+
sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
|
118 |
+
sm.set_array([])
|
119 |
+
fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
|
120 |
+
|
121 |
+
plt.show()
|
122 |
+
|
123 |
+
|
124 |
+
# Load the model
|
125 |
+
def create_model(num_classes):
|
126 |
+
model = fasterrcnn_resnet50_fpn(pretrained=False)
|
127 |
+
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
128 |
+
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
|
129 |
+
return model
|
130 |
+
|
131 |
+
|
132 |
+
# Streamlit app title
|
133 |
+
st.title("Doctoria CXR")
|
134 |
+
|
135 |
+
# Sidebar for user input
|
136 |
+
st.sidebar.title("Upload Chest X-ray Image")
|
137 |
+
|
138 |
+
# File uploader allows user to add their own image
|
139 |
+
uploaded_file = st.sidebar.file_uploader("Upload Chest X-ray image", type=["png", "jpg", "jpeg"])
|
140 |
+
|
141 |
+
# Load the model (use your model loading function)
|
142 |
+
# Ensure the model path is correct and accessible
|
143 |
+
model = create_model(num_classes=14)
|
144 |
+
model.load_state_dict(torch.load('Models/Doctoria CXR Thoraric Full Model.pth', map_location=torch.device('cpu')))
|
145 |
+
model.eval()
|
146 |
+
|
147 |
+
def process_image(image_path):
|
148 |
+
# Load and transform the image
|
149 |
+
image = Image.open(image_path).convert('RGB')
|
150 |
+
transform = get_transform()
|
151 |
+
image = transform(image).unsqueeze(0)
|
152 |
+
|
153 |
+
# Perform inference
|
154 |
+
with torch.no_grad():
|
155 |
+
prediction = model(image)
|
156 |
+
|
157 |
+
return prediction, image
|
158 |
+
|
159 |
+
# When the user uploads a file
|
160 |
+
if uploaded_file is not None:
|
161 |
+
# Display the uploaded image
|
162 |
+
st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
|
163 |
+
st.write("")
|
164 |
+
|
165 |
+
# Process the uploaded image
|
166 |
+
prediction, image_tensor = process_image(uploaded_file)
|
167 |
+
|
168 |
+
# Convert tensor to PIL Image for display
|
169 |
+
image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
|
170 |
+
|
171 |
+
# Visualization and report generation
|
172 |
+
image_np = np.array(image_pil)
|
173 |
+
for element in range(len(prediction[0]['boxes'])):
|
174 |
+
box = prediction[0]['boxes'][element].cpu().numpy()
|
175 |
+
score = prediction[0]['scores'][element].cpu().numpy()
|
176 |
+
label_index = prediction[0]['labels'][element].cpu().numpy()
|
177 |
+
if score > 0.5:
|
178 |
+
draw_boxes_cv2(image_np, [box], [label_index], [score], font_scale=3) # Increased font size
|
179 |
+
|
180 |
+
image_pil_processed = Image.fromarray(image_np)
|
181 |
+
|
182 |
+
# Display processed image
|
183 |
+
st.image(image_pil_processed, caption="Processed X-ray with Abnormalities Marked", use_column_width=True)
|
184 |
+
|
185 |
+
# Generate the diagnostic report
|
186 |
+
report = generate_diagnostic_report(prediction[0], label_names, 0.5)
|
187 |
+
st.write(report)
|
188 |
+
|
189 |
+
# Instructions
|
190 |
+
st.sidebar.write("Instructions:")
|
191 |
+
st.sidebar.write("1. Upload an X-ray image.")
|
192 |
+
st.sidebar.write("2. View the processed image and diagnostic report.")
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
tensorflow==2.4.1
|
3 |
+
torch==1.7.0
|
4 |
+
torchvision
|
5 |
+
matplotlib
|
6 |
+
Pillow
|
7 |
+
tqdm==4.56.2
|
8 |
+
opencv-python-headless
|
9 |
+
pytorch-lightning==1.2.3
|
10 |
+
numpy
|