Spaces:
Running
Running
File size: 9,755 Bytes
e45f24a 3a60abf 88e38d2 e45f24a b3991da d05d494 e45f24a d05d494 615de6b e45f24a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms
from PIL import Image
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import pytorch_lightning as pl
from tqdm import tqdm
from torchvision import models
# Abnormality class names. Callers index this list with (model label - 1),
# so label_names[i] corresponds to model output label i + 1.
label_names = [
    "Aortic_enlargement",   # label 1
    "Atelectasis",          # label 2
    "Calcification",        # label 3
    "Cardiomegaly",         # label 4
    "Consolidation",        # label 5
    "ILD",                  # label 6
    "Infiltration",         # label 7
    "Lung_Opacity",         # label 8
    "Nodule/Mass",          # label 9
    "Other_lesion",         # label 10
    "Pleural_effusion",     # label 11
    "Pleural_thickening",   # label 12
    "Pneumothorax",         # label 13
    "Pulmonary_fibrosis",   # label 14
]
class VinDetector(pl.LightningModule):
    """PyTorch Lightning wrapper for fine-tuning Faster R-CNN on chest X-rays.

    The Streamlit app in this file does not use this class (it builds its
    model with ``create_model``); it documents how the detector was trained.
    It relies on names that are not defined in this file (``VBDDataset``, and
    albumentations ``A`` / ``ToTensorV2`` through ``get_train_transform``) --
    NOTE(review): confirm where those come from before re-running training.

    Keyword Args:
        learning_rate: SGD learning rate (default ``1e-3``).
        batch_size: Training batch size (default ``4``).
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        # 14 abnormality classes plus one background class.
        num_classes = 15
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        # Previously **kwargs was accepted but ignored; honour the two
        # hyper-parameters while keeping the original defaults.
        self.learning_rate = kwargs.get('learning_rate', 1e-3)
        self.batch_size = kwargs.get('batch_size', 4)

    def forward(self, x):
        """Run the wrapped torchvision detector on ``x``."""
        return self.model(x)

    def prepare_data(self):
        """Read the training annotations and build ``self.train_dataset``.

        Rows with ``class_id == 14`` are dropped (presumably the dataset's
        "no finding" class -- confirm against the competition data).
        """
        import pandas as pd  # not imported at module level

        df = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')
        df = df[df['class_id'] != 14].reset_index(drop=True)
        # NOTE(review): Lightning runs prepare_data on a single process only,
        # so state assigned here is missing on other ranks in multi-device
        # training; consider moving this assignment into setup().
        self.train_dataset = VBDDataset(df, '../input/vinbigdata-chest-xray-original-png/train', get_train_transform())

    def train_dataloader(self):
        """Return a shuffled DataLoader that keeps samples as tuples (see collate_fn)."""
        from torch.utils.data import DataLoader  # not imported at module level

        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, num_workers=4, collate_fn=collate_fn)

    def training_step(self, batch, batch_idx):
        """Compute the summed Faster R-CNN training losses for one batch."""
        images, targets = batch
        # In training mode the torchvision detector returns a dict of loss tensors.
        loss_dict = self.model(images, list(targets))
        loss = sum(loss_dict.values())
        self.log('Loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        """Nesterov SGD with cosine annealing of the learning rate over 6 epochs."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.95, weight_decay=1e-5, nesterov=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0, verbose=True)
        return [optimizer], [scheduler]
def get_train_transform():
    """Build the training-time albumentations pipeline.

    Applies ``A.Flip`` with probability 0.5, then always converts to a tensor.
    Bounding boxes are declared in ``pascal_voc`` format, and their class ids
    are read from the ``labels`` field so they are transformed alongside the
    image.
    """
    steps = [A.Flip(0.5), ToTensorV2(p=1.0)]
    bbox_config = {'format': 'pascal_voc', 'label_fields': ['labels']}
    return A.Compose(steps, bbox_params=bbox_config)
def get_valid_transform():
    """Build the evaluation-time albumentations pipeline.

    It only converts to a tensor: no augmentation and no bbox parameters.
    """
    to_tensor = ToTensorV2(p=1.0)
    return A.Compose([to_tensor])
def collate_fn(batch):
    """Transpose a list of samples into a tuple of per-field tuples.

    ``[(img1, tgt1), (img2, tgt2)]`` becomes ``((img1, img2), (tgt1, tgt2))``.
    Variable-sized images and targets are therefore grouped rather than
    stacked into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)
def format_prediction_string(labels, boxes, scores):
    """Serialise detections into one space-separated prediction string.

    Each detection contributes ``"<label> <score:.4f> <b0> <b1> <b2> <b3>"``,
    and detections are joined with single spaces. The three sequences are
    paired positionally, so output stops at the shortest one.
    """
    return " ".join(
        f"{label} {score:.4f} {box[0]} {box[1]} {box[2]} {box[3]}"
        for label, score, box in zip(labels, scores, boxes)
    )
def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Build a plain-text report of detections scoring above ``threshold``.

    Args:
        predictions: Detection dict for a single image, with ``'scores'`` and
            ``'labels'`` tensors (``.cpu().numpy()`` is called on both).
            ``'boxes'`` is not read.
        labels: Class names indexed by ``model_label - 1``.
        threshold: Only detections with a score strictly greater than this
            are reported.

    Returns:
        ``"Diagnostic Report:\\n\\n"`` followed by one
        ``"- <name> detected with probability <p>\\n"`` line per finding. If
        nothing passes the threshold, it is followed by
        ``"No significant abnormalities detected."`` instead. A label whose
        shifted index falls outside ``labels`` is reported as ``"Unknown"``.
    """
    # Convert once instead of per element.
    scores = predictions['scores'].cpu().numpy()
    label_ids = predictions['labels'].cpu().numpy()

    findings = []
    for score, label_id in zip(scores, label_ids):
        if score > threshold:
            label_index = int(label_id) - 1
            # Guard both ends: a label of 0 would give index -1, which Python
            # silently maps to the last class name.
            label_name = labels[label_index] if 0 <= label_index < len(labels) else "Unknown"
            findings.append(f"- {label_name} detected with probability {score:.2f}\n")

    report = "Diagnostic Report:\n\n" + "".join(findings)
    # If no findings above the threshold, report no significant abnormalities
    if not findings:
        report += "No significant abnormalities detected."
    return report
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Only detections whose score is strictly greater than ``threshold`` are
    drawn. Each one gets a rectangle plus a ``"<name>: <score>"`` caption
    placed 10 px above the box's top edge, both in the class colour.

    Args:
        image: Image array that ``cv2.rectangle`` / ``cv2.putText`` modify in place.
        boxes: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: Model class ids, parallel to ``boxes``; looked up at ``label - 1``.
        scores: Detection scores, parallel to ``boxes``.
        threshold: Minimum score (exclusive) for a box to be drawn.
        font_scale: Font scale passed to ``cv2.putText``.
        thickness: Line thickness for both the rectangle and the caption.

    Returns:
        The same ``image`` object that was passed in.
    """
    # Class names and colours, both indexed by (label - 1).
    label_names = [
        "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
        "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
        "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
        "Pulmonary_fibrosis"
    ]
    label2color = [
        [59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
        [210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
        [209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100]
    ]
    for box, label, score in zip(boxes, labels, scores):
        if not score > threshold:
            continue
        label_index = int(label) - 1
        # Guard the lower bound too: label 0 would give index -1, and Python
        # would silently use the last class's name and colour.
        if 0 <= label_index < len(label_names):
            label_name = label_names[label_index]
        else:
            label_name = "Unknown"
        if 0 <= label_index < len(label2color):
            color = tuple(label2color[label_index])
        else:
            color = (255, 255, 255)  # Default to white for unknown labels
        label_text = f'{label_name}: {score:.2f}'
        cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
        cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
    return image
# Heatmap Generation Function
def plot_image_with_colored_mask_overlay_and_original(image, predictions, threshold=0.5):
    """Plot ``image`` next to a copy overlaid with a detection heatmap.

    Every detection scoring above ``threshold`` adds its score to each pixel
    inside its box. The accumulated mask is min-max normalised to [0, 1] and
    coloured with a transparent -> blue -> green -> yellow -> red colormap.
    That colour is then blended 50/50 with the image wherever the mask is
    non-zero.

    Args:
        image: 3-channel BGR image array (it is shown via
            ``cv2.COLOR_BGR2RGB``). Presumably uint8 -- confirm with callers.
        predictions: Detection dict for a single image, with ``'boxes'``
            (rows of ``x_min, y_min, x_max, y_max``) and ``'scores'`` tensors.
        threshold: Minimum score (exclusive) for a box to contribute.
            Defaults to the previously hard-coded 0.5.

    Returns:
        The matplotlib Figure, so a Streamlit caller can pass it to
        ``st.pyplot`` or close it. ``plt.show()`` is still called as before.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()
    # Per-pixel intensity mask matching the image size.
    mask = np.zeros(image.shape[:2], dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > threshold:
            # Clamp to 0: a negative slice bound would wrap around to the far
            # edge of the image instead of clipping.
            x_min, y_min, x_max, y_max = (max(int(v), 0) for v in box)
            mask[y_min:y_max, x_min:x_max] += score  # overlapping boxes accumulate
    # Min-max normalise; an all-zero mask stays all zero.
    normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    # Custom colormap whose lowest value is fully transparent.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    cmap_name = 'doctoria_heatmap'
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)
    heatmap = custom_cmap(normed_mask)  # RGBA floats in [0, 1]
    # Take channels 2, 1, 0 (drop alpha) to get BGR, then scale to uint8.
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)
    # Blend only where at least one confident box contributed.
    region = mask > 0
    overlayed_image = image.copy()
    overlayed_image[region] = overlayed_image[region] * 0.5 + heatmap_bgr[region] * 0.5
    # Plotting
    fig, axs = plt.subplots(1, 2, figsize=[12, 6])
    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')
    axs[1].axis('off')
    # Colorbar spans the normalised [0, 1] mask range.
    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
    plt.show()
    return fig
# Load the model
def create_model(num_classes):
    """Build an un-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    The default box predictor is swapped for a ``FastRCNNPredictor`` that
    outputs ``num_classes`` classes. The caller is expected to load trained
    weights afterwards.

    NOTE(review): with torchvision's defaults, the ResNet-50 backbone may
    still fetch ImageNet weights here even though ``pretrained=False``. That
    is harmless once a full state dict is loaded, but it may need network
    access -- confirm for offline deployments.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False)
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)
    return detector
# Streamlit app title
st.title("Doctoria CXR")
# Sidebar for user input
st.sidebar.title("Upload Chest X-ray Image")
# File uploader allows user to add their own image
uploaded_file = st.sidebar.file_uploader("Upload Chest X-ray image", type=["png", "jpg", "jpeg"])


@st.cache_resource
def _load_model(weights_path, num_classes=15):
    """Build the detector, load weights from ``weights_path`` on CPU, and set eval mode.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded model so the checkpoint is read
    from disk once instead of on every rerun.
    """
    net = create_model(num_classes=num_classes)
    net.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    net.eval()
    return net


# Ensure the model path is correct and accessible
model = _load_model('Doctoria CXR Thoraric Full Model.pth')
def process_image(image_path):
    """Run the module-level ``model`` on a single image.

    Args:
        image_path: A path or binary file-like object accepted by
            ``PIL.Image.open``. The app passes the Streamlit uploaded file.

    Returns:
        A tuple ``(prediction, image)``. ``prediction`` is the model output
        (indexed ``[0]`` by the caller for this single image). ``image`` is
        the 1 x 3 x H x W float input tensor with values in [0, 1].
    """
    # Load as RGB; the context manager releases any file PIL itself opened.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    # No `get_transform` helper is defined in this file, so calling it raised
    # NameError on every upload. ToTensor (PIL -> CHW float in [0, 1]) is
    # the input torchvision detectors expect; they normalise internally.
    transform = transforms.ToTensor()
    image = transform(image).unsqueeze(0)
    # Inference only: no gradients needed.
    with torch.no_grad():
        prediction = model(image)
    return prediction, image
# When the user uploads a file
if uploaded_file is not None:
    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
    st.write("")
    # Run detection on the uploaded image
    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]  # only one image was submitted
    # Convert the [0, 1] tensor back to an RGB PIL image, then to a (copied,
    # writable) NumPy array that cv2 can draw on in place.
    image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
    image_np = np.array(image_pil)
    # Move detections to NumPy once. draw_boxes_cv2 applies the 0.5 score
    # threshold itself, so all detections can be drawn in a single call.
    draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        threshold=0.5,
        font_scale=3,  # Increased font size
    )
    image_pil_processed = Image.fromarray(image_np)
    # Display processed image
    st.image(image_pil_processed, caption="Processed X-ray with Abnormalities Marked", use_column_width=True)
    # Generate the diagnostic report
    report = generate_diagnostic_report(detections, label_names, 0.5)
    st.write(report)
# Sidebar usage instructions, written one line at a time in order.
for _instruction_line in (
    "Instructions:",
    "1. Upload an X-ray image.",
    "2. View the processed image and diagnostic report.",
):
    st.sidebar.write(_instruction_line)
|