import streamlit as st
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from torchvision import transforms
from PIL import Image
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import pytorch_lightning as pl
from tqdm import tqdm
from torchvision import models
# Function Definitions
# Human-readable finding names, indexed by (model label - 1): the detector's
# labels are treated as 1-based wherever this list is indexed. The 14 entries
# line up with the 14 colours used when drawing boxes.
label_names = [
    "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
    "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
    "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
    "Pulmonary_fibrosis",
]
class VinDetector(pl.LightningModule):
    """Lightning wrapper around a torchvision Faster R-CNN (ResNet-50 FPN).

    The stock box predictor is replaced with a 15-output ``FastRCNNPredictor``.
    The training hooks read VinBigData annotations/images from hard-coded
    ``../input/...`` paths.
    """

    def __init__(self, **kwargs):
        """Build the detector and training hyper-parameters.

        Args:
            **kwargs: Accepted so call sites such as ``VinDetector(num_classes=14)``
                do not fail, but currently ignored; the head size is fixed below.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it ``self.model = ...`` raises AttributeError.
        super().__init__()
        self.model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        # Presumably 14 finding classes + background (torchvision counts the
        # background class in num_classes) — TODO confirm against the checkpoint.
        num_classes = 15
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.learning_rate = 1e-3
        self.batch_size = 4

    def forward(self, x):
        """Delegate to the wrapped torchvision detector."""
        return self.model(x)

    def prepare_data(self):
        """Read the annotation CSV and build ``self.train_dataset``.

        Rows with ``class_id == 14`` (presumably "No finding") are dropped.
        NOTE(review): Lightning calls ``prepare_data`` on a single process, so
        state assigned here is not visible to other ranks in distributed runs;
        ``setup()`` is the usual place for it. ``VBDDataset`` is not defined or
        imported in the visible file.
        """
        import pandas as pd  # pandas is not imported at file level

        df = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')
        df = df[df['class_id'] != 14].reset_index(drop=True)
        self.train_dataset = VBDDataset(df, '../input/vinbigdata-chest-xray-original-png/train', get_train_transform())

    def train_dataloader(self):
        """Shuffled loader; ``collate_fn`` keeps images/targets as tuples."""
        from torch.utils.data import DataLoader  # not imported at file level

        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True, num_workers=4, collate_fn=collate_fn)

    def training_step(self, batch, batch_idx):
        """Sum the detector's loss dict and log it."""
        images, targets = batch
        # Shallow-copy each target dict into a list, as the detector expects list[dict].
        targets = [dict(t) for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        self.log('Loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        """SGD (Nesterov momentum) with cosine annealing over ``T_max=6`` scheduler steps."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.95, weight_decay=1e-5, nesterov=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0, verbose=True)
        return [optimizer], [scheduler]
def get_train_transform():
    """Return the Albumentations pipeline used for training samples.

    As written, no pixel transforms are listed; the pipeline only declares
    Pascal VOC bounding boxes whose class ids live in the ``labels`` field.
    NOTE(review): ``A`` (albumentations) is not imported in the visible file header.
    """
    bbox_settings = {'format': 'pascal_voc', 'label_fields': ['labels']}
    no_transforms = []
    return A.Compose(no_transforms, bbox_params=bbox_settings)
def get_valid_transform():
    """Return the Albumentations pipeline used for validation samples.

    NOTE(review): this definition was an unterminated ``A.Compose([`` call; it
    is completed to mirror ``get_train_transform`` (no pixel transforms, Pascal
    VOC boxes with a ``labels`` field). Confirm whether validation-specific
    transforms were intended.
    """
    return A.Compose([], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``,
    leaving individual samples unstacked.
    """
    transposed = zip(*batch)
    return tuple(transposed)
def format_prediction_string(labels, boxes, scores):
    """Serialise detections as space-separated ``label score b0 b1 b2 b3`` groups.

    Scores are written with four decimals; labels and the first four box values
    use their default string form. Iteration stops at the shortest input.
    """
    template = "{0} {1:.4f} {2} {3} {4} {5}"
    return " ".join(
        template.format(label, score, box[0], box[1], box[2], box[3])
        for label, score, box in zip(labels, scores, boxes)
    )
def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Summarise detections scoring above ``threshold`` as a plain-text report.

    Args:
        predictions: Detection dict with equal-length ``'labels'`` and
            ``'scores'`` sequences (torch tensors, NumPy arrays or plain Python
            numbers all work). ``'boxes'`` is not read.
        labels: Class names indexed by ``label - 1`` (model labels are treated
            as 1-based).
        threshold: A detection is reported only if its score is strictly greater.

    Returns:
        A "Diagnostic Report:" header and a blank line, followed by one bullet
        per finding, or a "No significant abnormalities detected." sentence when
        nothing passes the threshold. Labels with no matching name are reported
        as "Unknown".
    """
    findings = []
    for raw_label, raw_score in zip(predictions['labels'], predictions['scores']):
        # float()/int() accept 0-d tensors, NumPy scalars and Python numbers alike.
        score = float(raw_score)
        if score > threshold:
            label_index = int(raw_label) - 1
            # Guard both ends: label 0 would otherwise wrap around to labels[-1].
            if 0 <= label_index < len(labels):
                label_name = labels[label_index]
            else:
                label_name = "Unknown"
            findings.append(f"- {label_name} detected with probability {score:.2f}\n")

    report = "Diagnostic Report:\n\n"
    if findings:
        return report + "".join(findings)
    return report + "No significant abnormalities detected."
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
    """Draw labelled detection boxes onto ``image`` in place and return it.

    Args:
        image: Image array that OpenCV can draw on; it is modified in place.
        boxes: Per-detection ``(x_min, y_min, x_max, y_max)`` coordinates.
        labels: Per-detection class ids, treated as 1-based (``label - 1``
            indexes the name/colour tables); ids without an entry are drawn as
            "Unknown" in white.
        scores: Per-detection confidence; only ``score > threshold`` is drawn.
        threshold: Score cut-off.
        font_scale: Scale passed to ``cv2.putText``.
        thickness: Line thickness for both the rectangle and the caption.

    Returns:
        The same ``image`` object.
    """
    # Name and colour tables, both indexed by (label - 1).
    label_names = [
        "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
        "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
        "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
        "Pulmonary_fibrosis",
    ]
    label2color = [
        [59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
        [210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
        [209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100],
    ]
    unknown_color = (255, 255, 255)

    for box, label, score in zip(boxes, labels, scores):
        if score > threshold:
            label_index = int(label) - 1
            # Check the lower bound too: label 0 would otherwise wrap to the last entry.
            label_name = label_names[label_index] if 0 <= label_index < len(label_names) else "Unknown"
            color = label2color[label_index] if 0 <= label_index < len(label2color) else unknown_color
            label_text = f'{label_name}: {score:.2f}'
            cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
            # Caption sits 10 px above the box's top edge.
            cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
    return image
# Heatmap Generation Function
def plot_image_with_colored_mask_overlay_and_original(image, predictions, threshold=0.5):
    """Plot the original image beside a score-weighted heatmap overlay.

    Args:
        image: Image array treated as BGR (it is converted with
            ``COLOR_BGR2RGB`` for display); presumably 3-channel uint8 — TODO
            confirm with callers.
        predictions: Faster R-CNN-style dict with tensor ``'boxes'``
            (``x_min, y_min, x_max, y_max``) and ``'scores'``.
        threshold: Boxes contribute to the heatmap only if ``score > threshold``.

    Returns:
        The matplotlib ``Figure`` so callers can render it (e.g. ``st.pyplot``)
        and close it afterwards.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()

    # Accumulate scores so regions covered by several boxes get hotter.
    mask = np.zeros(image.shape[:2], dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > threshold:
            x_min, y_min, x_max, y_max = (max(int(v), 0) for v in box[:4])
            # Clamped at 0 above: a negative slice start would wrap around.
            mask[y_min:y_max, x_min:x_max] += score

    # Rescale to [0, 1] for the colormap.
    normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # Fully transparent at 0, then blue -> green -> yellow -> red.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    cmap_name = 'doctoria_heatmap'
    custom_cmap = LinearSegmentedColormap.from_list(cmap_name, colors, N=256)
    heatmap = custom_cmap(normed_mask)  # RGBA floats in [0, 1], shape mask.shape + (4,)
    # Drop alpha and reverse RGB -> BGR to match ``image``.
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)

    # Blend 50/50 only where at least one box contributed.
    overlayed_image = image.copy()
    covered = mask > 0
    overlayed_image[covered] = overlayed_image[covered] * 0.5 + heatmap_bgr[covered] * 0.5

    fig, axs = plt.subplots(1, 2, figsize=[12, 6])
    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')

    # Colour bar for the normalised [0, 1] intensity scale.
    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])  # needed by older matplotlib versions before colorbar()
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)
    return fig
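# Build a Faster R-CNN (ResNet-50 FPN) without COCO-pretrained detector weights, with a num_classes-way box predictor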
def create_model(num_classes):
    """Return a Faster R-CNN (ResNet-50 FPN) whose box head has ``num_classes`` outputs.

    The detector is built with ``pretrained=False``, and its stock box predictor
    is swapped for a fresh ``FastRCNNPredictor`` of matching input width.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False)
    roi_heads = detector.roi_heads
    feature_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)
    return detector
# --- Page header ---------------------------------------------------------------
st.title("Doctoria CXR")

# --- Sidebar: image upload -----------------------------------------------------
# file_uploader yields None until a file is chosen; the main section below
# checks for that before running inference.
st.sidebar.title("Upload Chest X-ray Image")
uploaded_file = st.sidebar.file_uploader(
    "Upload Chest X-ray image",
    type=["png", "jpg", "jpeg"],
)

# Alternative (disabled) loader that uses the bare torchvision builder instead
# of VinDetector:
#   model = create_model(num_classes=15)
#   model.load_state_dict(torch.load('Doctoria CXR Thoraric Full Model.pth',
#                                    map_location=torch.device('cpu')))
def load_model(model_path):
    """Build a ``VinDetector``, load weights from ``model_path`` onto the CPU, and return it in eval mode.

    The file is presumably a ``state_dict`` saved from ``VinDetector`` — TODO
    confirm, since ``load_state_dict`` requires matching keys.
    """
    # VinDetector ignores keyword arguments and always builds its own head, so
    # passing a class count here had no effect (and 14 would not match its
    # 15-way predictor anyway).
    model = VinDetector()
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # In eval mode torchvision detectors return detections instead of losses.
    model.eval()
    return model
# Global detector used by process_image(). Built at module level, so Streamlit
# rebuilds it on every script rerun.
# NOTE(review): consider caching load_model (e.g. st.cache_resource) if the
# installed Streamlit version supports it.
model = load_model(model_path='Doctoria CXR Thoraric Full Model.pth')
def process_image(image_path):
    """Run the module-level ``model`` on a single image.

    Args:
        image_path: Path or binary file-like object (e.g. a Streamlit upload)
            accepted by ``PIL.Image.open``.

    Returns:
        ``(prediction, image)``: the detector output (a list with one dict per
        input image) and the input as a float tensor with a leading batch
        dimension of 1.
    """
    # Force three channels (e.g. greyscale PNGs become RGB).
    image = Image.open(image_path).convert('RGB')
    # Only convert to a float tensor in [0, 1]; torchvision detection models
    # resize and normalise internally. (The previous ``get_transform`` helper
    # is not defined anywhere in this file.)
    image = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        # torchvision detectors take a list of CHW tensors.
        prediction = model(list(image))
    return prediction, image
# --- Main flow: runs once a file has been uploaded -----------------------------
if uploaded_file is not None:
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)

    # Run the detector; image_tensor is the batched input that was fed to it.
    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]

    # Back to an RGB uint8 array so OpenCV can draw on it.
    image_pil = transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB")
    image_np = np.array(image_pil)

    # draw_boxes_cv2 applies the 0.5 score threshold itself, so hand it every
    # detection in one call.
    image_np = draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        threshold=0.5,
        font_scale=3,  # enlarged captions
    )
    image_pil_processed = Image.fromarray(image_np)
    st.image(image_pil_processed, caption="Processed X-ray with Abnormalities Marked", use_column_width=True)

    # Build and show the text report (st.text keeps the line breaks).
    report = generate_diagnostic_report(detections, label_names, 0.5)
    st.text(report)

# --- Usage instructions (always shown) -------------------------------------------
st.sidebar.write("1. Upload an X-ray image.")
st.sidebar.write("2. View the processed image and diagnostic report.")