# Doctoria_CXR / app.py
# luxmorocco - "Update app.py" (commit 615de6b, verified)
import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms
from PIL import Image
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import pytorch_lightning as pl
from tqdm import tqdm
from torchvision import models
# Function Definitions
# Human-readable finding names, in model-class order. Code in this file
# subtracts 1 from a model label before indexing this list, so model
# label k corresponds to label_names[k - 1].
label_names = [
    "Aortic_enlargement",   # model label 1
    "Atelectasis",          # model label 2
    "Calcification",        # model label 3
    "Cardiomegaly",         # model label 4
    "Consolidation",        # model label 5
    "ILD",                  # model label 6
    "Infiltration",         # model label 7
    "Lung_Opacity",         # model label 8
    "Nodule/Mass",          # model label 9
    "Other_lesion",         # model label 10
    "Pleural_effusion",     # model label 11
    "Pleural_thickening",   # model label 12
    "Pneumothorax",         # model label 13
    "Pulmonary_fibrosis",   # model label 14
]
class VinDetector(pl.LightningModule):
    """PyTorch Lightning module for training the Faster R-CNN X-ray detector.

    Wraps torchvision's ``fasterrcnn_resnet50_fpn`` and replaces its box
    predictor with a 15-class head. Nothing else in this file instantiates
    this class; the Streamlit app builds its inference model with
    ``create_model`` instead.

    NOTE(review): training still depends on names this file neither defines
    nor imports -- ``VBDDataset`` and, via ``get_train_transform``, the
    albumentations ``A`` / ``ToTensorV2`` -- so it only runs in an
    environment that provides them.
    """

    def __init__(self, **kwargs):
        # NOTE(review): **kwargs is accepted but not used.
        super().__init__()
        self.model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        # 15 outputs = the 14 entries of ``label_names`` plus one extra class
        # (presumably background at index 0, since label ids are shifted
        # down by 1 elsewhere in this file -- confirm).
        num_classes = 15
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.learning_rate = 1e-3
        self.batch_size = 4

    def forward(self, x):
        """Delegate to the wrapped torchvision detector."""
        return self.model(x)

    def prepare_data(self):
        """Read the VinBigData training CSV and build ``self.train_dataset``."""
        # pandas is not imported at module level; importing it here fixes the
        # NameError this method previously raised.
        import pandas as pd

        df = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')
        # Drop rows with class_id 14: ``label_names`` only names ids 0-13
        # (presumably 14 is the "no finding" class -- confirm).
        df = df[df['class_id'] != 14].reset_index(drop=True)
        # NOTE(review): VBDDataset is not defined in this file.
        self.train_dataset = VBDDataset(df, '../input/vinbigdata-chest-xray-original-png/train', get_train_transform())

    def train_dataloader(self):
        """Return a shuffled DataLoader over ``self.train_dataset``."""
        # DataLoader is not imported at module level either.
        from torch.utils.data import DataLoader

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            # Keep each field of the batch as a tuple rather than stacking it.
            collate_fn=collate_fn,
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the summed detector losses for one batch."""
        images, targets = batch
        # Shallow-copy each target into a plain dict before handing it over.
        targets = [dict(t) for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        self.log('Loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        """SGD with Nesterov momentum and a cosine-annealed learning rate."""
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.learning_rate,
            momentum=0.95,
            weight_decay=1e-5,
            nesterov=True,
        )
        # T_max=6 presumably matches the planned number of epochs -- confirm.
        # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0, verbose=True)
        return [optimizer], [scheduler]
def get_train_transform():
    """Return the albumentations training pipeline.

    The pipeline applies a random flip with probability 0.5, then converts
    to a tensor. Bounding boxes are declared in ``pascal_voc`` format, with
    their class labels taken from the ``labels`` field.

    NOTE(review): ``A`` and ``ToTensorV2`` (albumentations) are not imported
    in this file.
    """
    augmentations = [
        A.Flip(0.5),
        ToTensorV2(p=1.0),
    ]
    box_params = {'format': 'pascal_voc', 'label_fields': ['labels']}
    return A.Compose(augmentations, bbox_params=box_params)
def get_valid_transform():
    """Return the albumentations validation pipeline: tensor conversion only.

    NOTE(review): not referenced elsewhere in this file; ``A`` and
    ``ToTensorV2`` are not imported here either.
    """
    steps = [ToTensorV2(p=1.0)]
    return A.Compose(steps)
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    ``[(img1, tgt1), (img2, tgt2)]`` becomes ``((img1, img2), (tgt1, tgt2))``.
    Fields are left as tuples; nothing is stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)
def format_prediction_string(labels, boxes, scores):
    """Serialise detections as one space-separated string.

    Each detection becomes ``"<label> <score:.4f> <b0> <b1> <b2> <b3>"``,
    where ``b0..b3`` are the first four entries of its box. The entries are
    joined with single spaces; empty input yields ``""``.
    """
    return " ".join(
        f"{label} {score:.4f} {box[0]} {box[1]} {box[2]} {box[3]}"
        for label, score, box in zip(labels, scores, boxes)
    )
def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Summarise one image's detections as a plain-text report.

    Args:
        predictions: Detector output for a single image. Its ``'scores'``
            and ``'labels'`` entries are read and must be tensors
            (``.cpu()`` is called on them).
        labels: Finding names. Model label ``k`` is reported as
            ``labels[k - 1]``; any label outside that range (including 0)
            is reported as ``"Unknown"`` instead of wrapping around or
            raising IndexError.
        threshold: A detection is listed only if its score is strictly
            greater than this value.

    Returns:
        The header ``"Diagnostic Report:"`` and a blank line. Then either
        one ``"- <name> detected with probability <score>"`` line per listed
        detection (score to 2 decimals, each line newline-terminated), or
        ``"No significant abnormalities detected."`` when none qualify.
    """
    header = "Diagnostic Report:\n\n"
    # Transfer each tensor once. tolist() yields plain Python numbers and,
    # unlike numpy(), also works on tensors that still require grad.
    scores = predictions['scores'].cpu().tolist()
    label_ids = predictions['labels'].cpu().tolist()

    findings = []
    for score, label_id in zip(scores, label_ids):
        if score > threshold:
            label_index = label_id - 1
            # Check both bounds: a label of 0 would otherwise index -1 and
            # silently report the last finding name.
            if 0 <= label_index < len(labels):
                label_name = labels[label_index]
            else:
                label_name = "Unknown"
            findings.append(f"- {label_name} detected with probability {score:.2f}\n")

    if not findings:
        return header + "No significant abnormalities detected."
    return header + "".join(findings)
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: Array that OpenCV can draw on; it is modified in place.
        boxes: Sequence of boxes. ``(box[0], box[1])`` and
            ``(box[2], box[3])`` are used as opposite rectangle corners.
        labels: Model class ids; id ``k`` maps to entry ``k - 1`` of the
            name/colour tables below. Ids outside the tables (including 0)
            are drawn as ``"Unknown"`` in white.
        scores: Detection confidences; only boxes with score strictly
            greater than ``threshold`` are drawn.
        threshold: Score cut-off.
        font_scale: Scale passed to ``cv2.putText``.
        thickness: Line thickness for both the rectangle and the text.

    Returns:
        The same ``image`` object, with the annotations drawn on it.
    """
    # Define your labels and their corresponding colors
    label_names = [
        "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
        "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
        "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
        "Pulmonary_fibrosis"
    ]
    label2color = [
        [59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
        [210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
        [209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100]
    ]
    for box, label, score in zip(boxes, labels, scores):
        if not score > threshold:
            continue
        # Model labels are 1-based relative to the tables above.
        label_index = int(label) - 1
        # Check both bounds: label 0 would otherwise index -1 and silently
        # reuse the last class's name and colour.
        if 0 <= label_index < len(label_names):
            label_name = label_names[label_index]
        else:
            label_name = "Unknown"
        if 0 <= label_index < len(label2color):
            color = tuple(label2color[label_index])
        else:
            color = (255, 255, 255)  # Default to white for unknown labels
        label_text = f'{label_name}: {score:.2f}'
        cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
        # Text baseline sits 10 px above the box's top edge.
        cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
    return image
# Heatmap Generation Function
def plot_image_with_colored_mask_overlay_and_original(image, predictions, threshold=0.5):
    """Plot ``image`` next to a copy overlaid with a score-weighted box heatmap.

    Each kept box adds its score to every pixel it covers. The accumulated
    mask is min-max normalised to [0, 1] and coloured with a custom colormap
    running from transparent black through blue, green and yellow to red.
    Covered pixels are then blended 50/50 with the image.

    Args:
        image: Array shown via ``cv2.COLOR_BGR2RGB``, so it is expected in
            BGR channel order (assumes 3 channels -- TODO confirm for
            grayscale inputs).
        predictions: Mapping with ``'boxes'`` and ``'scores'`` tensors.
            Boxes are unpacked as ``x_min, y_min, x_max, y_max``.
        threshold: Boxes with score strictly greater than this contribute.
            Defaults to 0.5, the value previously hard-coded.

    Returns:
        The matplotlib Figure (previously ``None``), so callers such as a
        Streamlit page can render it, e.g. with ``st.pyplot``. The figure is
        not closed here.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()
    # Blank mask matching the image's height and width.
    mask = np.zeros(image.shape[:2], dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > threshold:
            x_min, y_min, x_max, y_max = map(int, box)
            # Clamp starts at 0: a negative start index would wrap around
            # and leave the slice empty, dropping the box from the mask.
            x_min, y_min = max(x_min, 0), max(y_min, 0)
            mask[y_min:y_max, x_min:x_max] += score
    normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # RGBA stops: transparent black -> blue -> green -> yellow -> red.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    cmap_name = 'doctoria_heatmap'
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)
    heatmap = custom_cmap(normed_mask)
    # Take RGBA channels 2, 1, 0 -> BGR, scaled to uint8.
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)

    # Blend only the pixels covered by at least one kept box.
    covered = mask > 0
    overlayed_image = image.copy()
    blended = overlayed_image[covered] * 0.5 + heatmap_bgr[covered] * 0.5
    overlayed_image[covered] = blended.astype(image.dtype)

    fig, axs = plt.subplots(1, 2, figsize=[12, 6])
    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[0].axis('off')
    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')
    axs[1].axis('off')
    # Colorbar for the normalised [0, 1] heatmap scale.
    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
    plt.show()
    return fig
# Load the model
def create_model(num_classes):
    """Build a Faster R-CNN ResNet-50 FPN with a ``num_classes``-way box head.

    The box-predictor swap is the same one ``VinDetector.__init__``
    performs, but here the detector is requested with ``pretrained=False``.
    Weights are expected to be loaded afterwards via ``load_state_dict``.

    NOTE(review): depending on the torchvision version, the backbone may
    still fetch ImageNet weights here; they are overwritten by the later
    ``load_state_dict`` anyway -- confirm before changing these arguments.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False)
    head_inputs = detector.roi_heads.box_predictor.cls_score.in_features
    # Swap in a predictor sized for our classes.
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_inputs, num_classes)
    return detector
# Streamlit app title
st.title("Doctoria CXR")
# Sidebar for user input
st.sidebar.title("Upload Chest X-ray Image")
# File uploader allows user to add their own image
uploaded_file = st.sidebar.file_uploader("Upload Chest X-ray image", type=["png", "jpg", "jpeg"])

# Trained detector weights, resolved relative to the working directory.
MODEL_PATH = 'Doctoria CXR Thoraric Full Model.pth'


def _cache_resource(func):
    """Wrap ``func`` with ``st.cache_resource`` when this Streamlit has it.

    Streamlit re-executes the whole script on every user interaction, so an
    uncached loader would rebuild the network and re-read the weights file
    on every rerun. On older Streamlit releases without ``cache_resource``,
    the function is returned uncached (the previous behaviour).
    """
    cache = getattr(st, "cache_resource", None)
    return func if cache is None else cache(func)


@_cache_resource
def load_model(model_path=MODEL_PATH, num_classes=15):
    """Build the detector, load its weights onto the CPU and set eval mode."""
    detector = create_model(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    detector.load_state_dict(state_dict)
    detector.eval()
    return detector


# Module-level ``model`` is what process_image() runs inference with.
model = load_model()
def process_image(image_path):
    """Run the module-level detector ``model`` on one image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts -- a path or a
            binary file-like object. The app passes the Streamlit upload.

    Returns:
        ``(prediction, image)``. ``prediction`` is the detector output for a
        batch of one, so the caller reads ``prediction[0]``. ``image`` is
        the ``(1, C, H, W)`` float tensor that was fed to the model.
    """
    pil_image = Image.open(image_path).convert('RGB')
    # ``get_transform`` is not defined anywhere in this file, so the previous
    # call raised NameError on every upload. ToTensor yields a float CHW
    # tensor scaled to [0, 1]; the caller's transforms.ToPILImage inverts it.
    tensor = transforms.ToTensor()(pil_image)
    image = tensor.unsqueeze(0)
    # Inference only: no autograd bookkeeping.
    with torch.no_grad():
        # torchvision detection models take a list of CHW tensors.
        prediction = model(list(image))
    return prediction, image
# When the user uploads a file
if uploaded_file is not None:
    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
    st.write("")
    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]

    # Convert the model-input tensor back to an RGB array, so the boxes are
    # drawn in the same coordinate space the detector saw.
    image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
    image_np = np.array(image_pil)

    # draw_boxes_cv2 applies the score threshold itself and draws in place,
    # so every detection is passed in one call. This replaces the
    # per-element loop with its duplicate `score > 0.5` check.
    draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        threshold=0.5,
        font_scale=3,  # Larger text for full-resolution X-rays
    )
    image_pil_processed = Image.fromarray(image_np)
    st.image(image_pil_processed, caption="Processed X-ray with Abnormalities Marked", use_column_width=True)

    # Generate the diagnostic report
    report = generate_diagnostic_report(detections, label_names, 0.5)
    st.write(report)
# Sidebar usage instructions, written in order.
for _instruction_line in (
    "Instructions:",
    "1. Upload an X-ray image.",
    "2. View the processed image and diagnostic report.",
):
    st.sidebar.write(_instruction_line)