"""Streamlit app that runs a Faster R-CNN detector on an uploaded chest X-ray.

The script draws the detections above a confidence threshold onto the image
and prints a short textual report listing the detected findings.
"""

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Function Definitions

# Finding names. The model's label ids are treated as 1-based indices into
# this list (the code subtracts 1 before every lookup).
label_names = [
    "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
    "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
    "Pulmonary_fibrosis",
]

# One colour per entry of ``label_names`` (same order), used for box drawing.
_LABEL_COLORS = [
    (59, 238, 119), (222, 21, 229), (94, 49, 164), (206, 221, 133),
    (117, 75, 3), (210, 224, 119), (211, 176, 166), (63, 7, 197),
    (102, 65, 77), (194, 134, 175), (209, 219, 50), (255, 44, 47),
    (89, 125, 149), (110, 27, 100),
]
_UNKNOWN_LABEL = "Unknown"
_UNKNOWN_COLOR = (255, 255, 255)  # white for labels outside the known range

# Confidence cut-off shared by the drawing step and the report in the UI flow.
SCORE_THRESHOLD = 0.5

MODEL_PATH = 'Models/Doctoria CXR Thoraric Full Model.pth'
# NOTE(review): there are 14 finding names and label ids are handled as
# 1-based, which would usually mean 15 outputs including background. The
# value is kept as-is because it must match the checkpoint -- confirm.
NUM_CLASSES = 14


def _zero_based_index(label_id, size):
    """Map a 1-based label id to a list index, or None when out of range.

    The explicit lower bound matters: without it a label id of 0 would
    become index -1 and silently select the last entry of the list.
    """
    index = int(label_id) - 1
    return index if 0 <= index < size else None


def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Build a plain-text report of the detections scoring above ``threshold``.

    Args:
        predictions: Mapping with parallel ``'scores'`` and ``'labels'``
            sequences (for example one element of a torchvision detection
            model's output list).
        labels: Finding names, indexed by ``label id - 1``.
        threshold: Only detections with a score strictly greater than this
            value are listed.

    Returns:
        The report text. Label ids that do not map to an entry of ``labels``
        are listed as "Unknown" instead of raising or wrapping around.
    """
    findings = []
    for raw_score, raw_label in zip(predictions['scores'], predictions['labels']):
        score = float(raw_score)
        if score > threshold:
            index = _zero_based_index(raw_label, len(labels))
            name = labels[index] if index is not None else _UNKNOWN_LABEL
            findings.append(f"- {name} detected with probability {score:.2f}\n")

    report = "Diagnostic Report:\n\n" + "".join(findings)
    if not findings:
        report += "No significant abnormalities detected."
    return report


def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5,
                   font_scale=1.0, thickness=3):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: Array that OpenCV can draw on; it is modified in place.
        boxes: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: 1-based label ids, parallel to ``boxes``.
        scores: Confidence scores, parallel to ``boxes``.
        threshold: Boxes are drawn only when their score is strictly greater.
        font_scale: OpenCV font scale for the caption.
        thickness: Line thickness for both the rectangle and the caption.

    Returns:
        The same ``image`` object that was passed in.
    """
    for box, label_id, score in zip(boxes, labels, scores):
        if score > threshold:
            index = _zero_based_index(label_id, len(label_names))
            if index is None:
                name, color = _UNKNOWN_LABEL, _UNKNOWN_COLOR
            else:
                name, color = label_names[index], _LABEL_COLORS[index]

            caption = f'{name}: {float(score):.2f}'
            top_left = (int(box[0]), int(box[1]))
            bottom_right = (int(box[2]), int(box[3]))
            cv2.rectangle(image, top_left, bottom_right, color, thickness)
            # The caption sits 10 px above the top-left corner of the box.
            cv2.putText(image, caption, (int(box[0]), int(box[1] - 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
    return image


# Heatmap Generation Function
def plot_image_with_colored_mask_overlay_and_original(image, predictions):
    """Plot ``image`` next to a copy overlaid with a detection heatmap.

    Every box scoring above 0.5 adds its score to the pixels it covers; the
    accumulated mask is min-max normalised and rendered through a custom
    transparent-to-red colormap.

    Args:
        image: Image array; ``cv2.COLOR_BGR2RGB`` is applied before display,
            so a 3-channel BGR array is assumed -- TODO confirm with callers.
        predictions: Mapping with ``'boxes'`` and ``'scores'`` tensors.

    Returns:
        The matplotlib figure. ``plt.show()`` is still called; the figure is
        returned so a caller can also render it some other way.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()

    # Accumulate scores into a mask matching the image's height and width.
    mask = np.zeros(image.shape[:2], dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > 0.5:  # Threshold can be adjusted
            # Clamp at 0 so a negative coordinate cannot wrap around when
            # slicing; coordinates past the far edge are cut off by slicing.
            x_min, y_min, x_max, y_max = (max(int(value), 0) for value in box)
            mask[y_min:y_max, x_min:x_max] += score

    normed_mask = cv2.normalize(mask, None, alpha=0, beta=1,
                                norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Custom colormap: fully transparent black, then blue, green, yellow, red.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    cmap_name = 'doctoria_heatmap'
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)

    heatmap = custom_cmap(normed_mask)
    # Keep the first three channels in reversed order (RGB -> BGR), as uint8.
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)

    # Blend 50/50 wherever at least one box contributed to the mask.
    covered = mask > 0
    overlayed_image = image.copy()
    overlayed_image[covered] = (
        overlayed_image[covered] * 0.5 + heatmap_bgr[covered] * 0.5
    )

    fig, axs = plt.subplots(1, 2, figsize=[12, 6])

    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')
    axs[1].axis('off')

    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)

    plt.show()
    return fig


# Load the model
def create_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) with a ``num_classes`` box head.

    NOTE(review): newer torchvision releases deprecate ``pretrained`` in
    favour of ``weights``. The call is left unchanged because the
    construction options must stay compatible with the saved checkpoint.
    """
    model = fasterrcnn_resnet50_fpn(pretrained=False)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def get_transform():
    """Return the preprocessing applied to a PIL image before inference.

    NOTE(review): assumes the detector was trained on plain ``ToTensor``
    output (float tensors in [0, 1]) -- confirm against the training code.
    """
    return transforms.Compose([transforms.ToTensor()])


def _load_model():
    """Create the detector, load the checkpoint on CPU and set eval mode."""
    net = create_model(num_classes=NUM_CLASSES)
    net.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    net.eval()
    return net


# Streamlit re-executes this script on every interaction. Cache the loaded
# model when the installed Streamlit offers ``cache_resource``; otherwise
# fall back to loading it on each run.
_cache_resource = getattr(st, "cache_resource", None)
load_model = _cache_resource(_load_model) if _cache_resource is not None else _load_model

# Streamlit app title
st.title("Doctoria CXR")

# Sidebar for user input
st.sidebar.title("Upload Chest X-ray Image")

# File uploader allows user to add their own image
uploaded_file = st.sidebar.file_uploader(
    "Upload Chest X-ray image", type=["png", "jpg", "jpeg"]
)

# Ensure the model path is correct and accessible
model = load_model()


def process_image(image_path):
    """Run the detector on one image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a path or a
            file-like object such as the Streamlit upload).

    Returns:
        ``(prediction, image)``: the model output and the batched input
        tensor (batch dimension of 1) that was fed to it.
    """
    image = Image.open(image_path).convert('RGB')
    batch = get_transform()(image).unsqueeze(0)

    with torch.no_grad():
        prediction = model(batch)

    return prediction, batch


# When the user uploads a file
if uploaded_file is not None:
    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
    st.write("")

    # Process the uploaded image
    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]

    # Convert the input tensor back to an RGB array that OpenCV can draw on.
    image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
    image_np = np.array(image_pil)

    # One call draws every detection above the threshold.
    draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        threshold=SCORE_THRESHOLD,
        font_scale=3,  # Increased font size
    )
    image_pil_processed = Image.fromarray(image_np)

    # Display processed image
    st.image(image_pil_processed,
             caption="Processed X-ray with Abnormalities Marked",
             use_column_width=True)

    # Generate the diagnostic report
    report = generate_diagnostic_report(detections, label_names, SCORE_THRESHOLD)
    st.write(report)

# Instructions
st.sidebar.write("Instructions:")
st.sidebar.write("1. Upload an X-ray image.")
st.sidebar.write("2. View the processed image and diagnostic report.")