Spaces:
Sleeping
Sleeping
File size: 7,384 Bytes
e45f24a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms
from PIL import Image
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
# Function Definitions
# Finding names in label order. The report code subtracts 1 from a predicted
# label before indexing, so entry k is the name shown for model label k + 1.
label_names = [
    "Aortic_enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "ILD",
    "Infiltration",
    "Lung_Opacity",
    "Nodule/Mass",
    "Other_lesion",
    "Pleural_effusion",
    "Pleural_thickening",
    "Pneumothorax",
    "Pulmonary_fibrosis",
]
def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Build a plain-text report of every detection scoring above ``threshold``.

    Args:
        predictions: Mapping with ``'scores'`` and ``'labels'`` entries of
            equal length (for example one element of a torchvision detection
            model's output). Each score/label may be a 0-d tensor, a NumPy
            value or a plain number.
        labels: Sequence of finding names. A predicted label ``k`` is
            reported as ``labels[k - 1]``.
        threshold: A detection is listed only when its score is strictly
            greater than this value.

    Returns:
        The report text. It always starts with ``"Diagnostic Report:"`` and
        is followed either by one ``"- <name> detected with probability
        <score>"`` line per reported detection or by a single sentence
        stating that nothing significant was detected.
    """
    header = "Diagnostic Report:\n\n"
    finding_lines = []

    for raw_score, raw_label in zip(predictions['scores'], predictions['labels']):
        # float()/int() accept tensors, NumPy scalars and plain numbers alike.
        score = float(raw_score)
        if not score > threshold:
            continue

        label_index = int(raw_label) - 1
        # Guard both ends: label 0 would otherwise become index -1 and wrap
        # around to the last name, and a label past the end would raise.
        if 0 <= label_index < len(labels):
            label_name = labels[label_index]
        else:
            label_name = "Unknown"

        finding_lines.append(
            f"- {label_name} detected with probability {score:.2f}\n"
        )

    if not finding_lines:
        return header + "No significant abnormalities detected."
    return header + "".join(finding_lines)
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: Pixel array that OpenCV can draw on; it is modified in place.
        boxes: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: Sequence of 1-based class labels, parallel to ``boxes``.
            Label ``k`` maps to entry ``k - 1`` of the built-in name and
            colour tables.
        scores: Sequence of confidence scores, parallel to ``boxes``.
        threshold: Only boxes whose score is strictly greater than this
            value are drawn.
        font_scale: Font scale passed to ``cv2.putText``.
        thickness: Line thickness used for both the rectangle and the text.

    Returns:
        The same ``image`` object that was passed in.
    """
    label_names = (
        "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
        "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
        "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
        "Pulmonary_fibrosis",
    )
    # One colour per entry of label_names, in the same order.
    label2color = (
        (59, 238, 119), (222, 21, 229), (94, 49, 164), (206, 221, 133), (117, 75, 3),
        (210, 224, 119), (211, 176, 166), (63, 7, 197), (102, 65, 77), (194, 134, 175),
        (209, 219, 50), (255, 44, 47), (89, 125, 149), (110, 27, 100),
    )
    unknown_color = (255, 255, 255)  # white for labels outside the tables

    for box, label, score in zip(boxes, labels, scores):
        if not score > threshold:
            continue

        label_index = int(label) - 1
        # Check the lower bound too: label 0 would otherwise become index -1
        # and silently pick the last class's name and colour.
        if 0 <= label_index < len(label_names):
            label_name = label_names[label_index]
            color = label2color[label_index]
        else:
            label_name = "Unknown"
            color = unknown_color

        label_text = f'{label_name}: {score:.2f}'
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(image, top_left, bottom_right, color, thickness)
        # The caption sits 10 pixels above the box's top-left corner.
        cv2.putText(
            image,
            label_text,
            (int(box[0]), int(box[1] - 10)),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            thickness,
        )
    return image
# Heatmap Generation Function
def plot_image_with_colored_mask_overlay_and_original(image, predictions, threshold=0.5):
    """Show ``image`` next to a copy overlaid with a detection heatmap.

    Each box scoring above ``threshold`` adds its score to a per-pixel mask;
    the mask is min-max normalised to [0, 1], coloured with a
    transparent-blue-green-yellow-red colormap and blended 50/50 into the
    image wherever the mask is non-zero. The two images are displayed with
    matplotlib.

    Args:
        image: Pixel array whose first two dimensions set the mask size. It
            is converted with ``cv2.COLOR_BGR2RGB`` for display, i.e. it is
            treated as BGR. NOTE(review): presumably uint8 - confirm with
            callers.
        predictions: Mapping with ``'boxes'`` and ``'scores'`` tensors, in
            the format of Faster R-CNN outputs.
        threshold: Boxes whose score is not strictly greater than this value
            do not contribute to the heatmap.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()

    # Accumulate scores so that overlapping detections reinforce each other.
    mask = np.zeros(image.shape[:2], dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > threshold:
            x_min, y_min, x_max, y_max = map(int, box)
            # Negative starts would wrap around in slicing; clamp them to 0.
            # Ends past the image edge are already clipped by slicing.
            x_min = max(x_min, 0)
            y_min = max(y_min, 0)
            mask[y_min:y_max, x_min:x_max] += score

    normed_mask = cv2.normalize(
        mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )

    # RGBA stops: fully transparent, then blue, green, yellow, red.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    cmap_name = 'doctoria_heatmap'
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)

    heatmap = custom_cmap(normed_mask)
    # Channels 2, 1, 0 of the RGBA result give B, G, R; alpha is dropped.
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)

    # Blend only where at least one detection contributed.
    covered = mask > 0
    overlayed_image = image.copy()
    overlayed_image[covered] = overlayed_image[covered] * 0.5 + heatmap_bgr[covered] * 0.5

    fig, axs = plt.subplots(1, 2, figsize=[12, 6])

    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')
    axs[1].axis('off')

    # Colorbar for the heatmap scale, attached to the overlay panel.
    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
    plt.show()
# Load the model
def create_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    The stock box predictor is replaced by a ``FastRCNNPredictor`` that
    outputs ``num_classes`` classes. The detector is created with
    ``pretrained=False``.

    NOTE(review): torchvision detection heads usually count the background
    as one of ``num_classes`` - confirm this matches the trained checkpoint.

    Args:
        num_classes: Number of classes the new box predictor outputs.

    Returns:
        The model with its box predictor replaced.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False)
    # Reuse the existing head's input width so the new head fits the backbone.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector
# Page heading
st.title("Doctoria CXR")

# The sidebar hosts the upload control; only PNG/JPEG files are accepted
st.sidebar.title("Upload Chest X-ray Image")
uploaded_file = st.sidebar.file_uploader(
    "Upload Chest X-ray image",
    type=["png", "jpg", "jpeg"],
)

# Build the detector, restore the trained weights onto the CPU and switch to
# inference mode. The checkpoint path is relative to the working directory.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model = create_model(num_classes=14)
cpu_device = torch.device('cpu')
state_dict = torch.load('Models/Doctoria CXR Thoraric Full Model.pth', map_location=cpu_device)
model.load_state_dict(state_dict)
model.eval()
def process_image(image_path):
    """Run the module-level detector on one image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts. The app passes the
            file object returned by the Streamlit uploader.

    Returns:
        A ``(prediction, image)`` pair: the model's output for the
        one-image batch, and the input tensor including its leading batch
        dimension.
    """
    image = Image.open(image_path).convert('RGB')

    # The original called get_transform(), which is not defined or imported
    # anywhere in this file. ToTensor is used instead; the caller converts
    # the tensor back with ToPILImage, which is its inverse.
    # NOTE(review): assumes the checkpoint was trained on plain ToTensor
    # input - confirm against the training code.
    to_tensor = transforms.ToTensor()
    image = to_tensor(image).unsqueeze(0)

    # No gradients are needed for inference.
    with torch.no_grad():
        prediction = model(image)
    return prediction, image
# Everything below runs only once an image has been uploaded
if uploaded_file is not None:
    # Echo the raw upload back to the user
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
    st.write("")

    # Run detection; the batch holds a single image, so use its first entry
    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]

    # Turn the input tensor back into an RGB pixel array to draw on
    to_pil = transforms.ToPILImage()
    image_pil = to_pil(image_tensor.squeeze(0)).convert("RGB")
    image_np = np.array(image_pil)

    # draw_boxes_cv2 applies its default 0.5 score cut-off to every detection
    # and draws the survivors in order; font_scale=3 enlarges the captions
    draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        font_scale=3,
    )
    image_pil_processed = Image.fromarray(image_np)

    # Show the annotated image
    st.image(
        image_pil_processed,
        caption="Processed X-ray with Abnormalities Marked",
        use_column_width=True,
    )

    # Text summary of the detections above the same 0.5 cut-off
    report = generate_diagnostic_report(detections, label_names, 0.5)
    st.write(report)
# Usage instructions, written to the sidebar one line at a time
for instruction_line in (
    "Instructions:",
    "1. Upload an X-ray image.",
    "2. View the processed image and diagnostic report.",
):
    st.sidebar.write(instruction_line)
|