# Spaces:
# Sleeping
# Sleeping
import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# Human-readable names for the detector's finding classes.
# Position i in this list is looked up with ``model_label - 1`` by the
# report/drawing helpers, i.e. model label ids are offset by one from it.
label_names = [
    "Aortic_enlargement",
    "Atelectasis",
    "Calcification",
    "Cardiomegaly",
    "Consolidation",
    "ILD",
    "Infiltration",
    "Lung_Opacity",
    "Nodule/Mass",
    "Other_lesion",
    "Pleural_effusion",
    "Pleural_thickening",
    "Pneumothorax",
    "Pulmonary_fibrosis",
]
class VinDetector(pl.LightningModule):
    """Lightning training wrapper around a torchvision Faster R-CNN.

    The COCO-pretrained box predictor is replaced by a ``FastRCNNPredictor``
    with 15 outputs. ``VBDDataset`` shifts ``class_id`` by +1, so id 0 stays
    free for torchvision's background class.

    NOTE(review): ``pl`` (pytorch_lightning), ``get_train_transform`` and
    ``collate_fn`` are not defined in the visible file. Unless they are
    provided elsewhere, defining this class raises ``NameError`` at import
    time (``pl``), and ``prepare_data`` / ``train_dataloader`` fail when
    called. The Streamlit app in this file never uses this class.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # ``models`` is never imported in this file; use the constructor that is.
        self.model = fasterrcnn_resnet50_fpn(pretrained=True)
        num_classes = 15
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.learning_rate = 1e-3
        self.batch_size = 4

    def forward(self, x):
        """Delegate to the wrapped detector (training or inference mode)."""
        return self.model(x)

    def prepare_data(self):
        """Read the annotation CSV and build ``self.train_dataset``."""
        df = pd.read_csv('../input/vinbigdata-chest-xray-abnormalities-detection/train.csv')
        # class_id 14 has no entry in label_names (presumably the dataset's
        # "no finding" class) and is excluded from detection training.
        df = df[df['class_id'] != 14].reset_index(drop=True)
        self.train_dataset = VBDDataset(
            df,
            '../input/vinbigdata-chest-xray-original-png/train',
            get_train_transform(),
        )

    def train_dataloader(self):
        """Shuffled loader over ``self.train_dataset`` using the external ``collate_fn``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            collate_fn=collate_fn,
        )

    def training_step(self, batch, batch_idx):
        """Sum the detector's loss dict, log it as ``'Loss'`` and return it."""
        images, targets = batch
        # Shallow-copy each target dict so the model cannot mutate the batch's dicts.
        targets = [dict(t) for t in targets]
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        self.log('Loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return {"loss": loss}

    def configure_optimizers(self):
        """Nesterov SGD with cosine annealing over ``T_max=6`` scheduler steps."""
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.learning_rate,
            momentum=0.95,
            weight_decay=1e-5,
            nesterov=True,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0, verbose=True)
        return [optimizer], [scheduler]
# VBDDataset class
class VBDDataset(Dataset):
    """Map-style dataset yielding one PNG image per unique ``image_id``.

    Args:
        dataframe: Annotation table with ``image_id``, ``class_id`` and
            ``x_min``/``y_min``/``x_max``/``y_max`` columns, one row per box.
        image_dir: Directory containing ``<image_id>.png`` files.
        transforms: Optional callable invoked with keyword arguments
            (``image=``, plus ``bboxes=``/``labels=`` outside test phase) that
            returns a dict. This matches the albumentations calling convention
            (assumption: confirm against ``get_train_transform``).
        phase: ``'test'`` makes ``__getitem__`` return ``(image, image_id)``;
            any other value returns ``(image, target)``.
    """

    def __init__(self, dataframe, image_dir, transforms=None, phase='train'):
        super().__init__()
        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms
        self.phase = phase

    def __getitem__(self, idx):
        """Load image ``idx`` as float32 RGB in [0, 1] and, outside test phase, its target dict.

        Raises:
            FileNotFoundError: if the PNG cannot be read.
        """
        image_id = self.image_ids[idx]
        image_path = f'{self.image_dir}/{image_id}.png'
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0

        if self.phase == 'test':
            if self.transforms:
                image = self.transforms(image=image)['image']
            return image, image_id

        records = self.df[self.df['image_id'] == image_id]
        boxes = records[['x_min', 'y_min', 'x_max', 'y_max']].values
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Always a 1-D (N,) tensor. The previous squeeze() turned a
        # single-box image into a 0-d scalar tensor.
        labels = torch.as_tensor(records['class_id'].values + 1, dtype=torch.int64)

        target = {
            # Float tensor of shape (N, 4) even when N == 0 or no transforms run.
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': labels,
            'area': torch.as_tensor(area, dtype=torch.float32),
            'image_id': torch.tensor([idx]),
            'iscrowd': torch.zeros((records.shape[0],), dtype=torch.int64),
        }

        if self.transforms:
            sample = self.transforms(image=image, bboxes=boxes, labels=labels)
            image = sample['image']
            target['boxes'] = torch.as_tensor(sample['bboxes'], dtype=torch.float32).reshape(-1, 4)
            # Transforms may drop boxes; keep labels aligned with what survived.
            if 'labels' in sample:
                target['labels'] = torch.as_tensor(sample['labels'], dtype=torch.int64)
        return image, target

    def __len__(self):
        """Number of unique image ids."""
        return self.image_ids.shape[0]
def generate_diagnostic_report(predictions, labels, threshold=0.5):
    """Summarise detections scoring above ``threshold`` as a plain-text report.

    Args:
        predictions: One detector output dict with ``'scores'`` and
            ``'labels'`` tensors (e.g. ``model(image)[0]``).
        labels: Class names; model label ``k`` maps to ``labels[k - 1]``.
            Labels outside that range (including 0) are reported as "Unknown".
        threshold: Only scores strictly greater than this are reported.

    Returns:
        ``"Diagnostic Report:\\n\\n"`` followed by one
        ``"- <name> detected with probability <score>\\n"`` line per finding,
        or by ``"No significant abnormalities detected."`` when none pass.
    """
    findings = []
    for score, label in zip(predictions['scores'], predictions['labels']):
        score = float(score)
        if not score > threshold:
            continue
        label_index = int(label) - 1
        # Without this guard, label 0 would wrap to labels[-1] and an id past
        # the end would raise IndexError.
        if 0 <= label_index < len(labels):
            label_name = labels[label_index]
        else:
            label_name = "Unknown"
        findings.append(f"- {label_name} detected with probability {score:.2f}\n")

    header = "Diagnostic Report:\n\n"
    if not findings:
        return header + "No significant abnormalities detected."
    return header + "".join(findings)
def draw_boxes_cv2(image, boxes, labels, scores, threshold=0.5, font_scale=1.0, thickness=3):
    """Draw a labelled rectangle for every detection scoring above ``threshold``.

    Drawing happens in place on ``image`` via cv2, and the same object is
    returned.

    Args:
        image: Image array accepted by ``cv2.rectangle``/``cv2.putText``; mutated.
        boxes: Sequence of ``(x_min, y_min, x_max, y_max)`` boxes.
        labels: Model label ids; label ``k`` maps to entry ``k - 1`` of the
            name/colour tables. Ids outside the table (including 0) are drawn
            in white as "Unknown".
        scores: Confidence per box; only ``score > threshold`` is drawn.
        threshold: Score cut-off (exclusive).
        font_scale: Scale passed to ``cv2.putText``.
        thickness: Line thickness for both rectangle and text.

    Returns:
        ``image``, after drawing.
    """
    # Local copies of the class names and their per-class colours.
    class_names = [
        "Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly",
        "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass",
        "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax",
        "Pulmonary_fibrosis",
    ]
    class_colors = [
        [59, 238, 119], [222, 21, 229], [94, 49, 164], [206, 221, 133], [117, 75, 3],
        [210, 224, 119], [211, 176, 166], [63, 7, 197], [102, 65, 77], [194, 134, 175],
        [209, 219, 50], [255, 44, 47], [89, 125, 149], [110, 27, 100],
    ]
    for box, label, score in zip(boxes, labels, scores):
        if not score > threshold:
            continue
        label_index = int(label) - 1
        # Negative indices must also count as unknown: label 0 would otherwise
        # wrap around to the last name/colour.
        if 0 <= label_index < len(class_names):
            label_name = class_names[label_index]
            color = tuple(class_colors[label_index])
        else:
            label_name = "Unknown"
            color = (255, 255, 255)
        label_text = f'{label_name}: {float(score):.2f}'
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(image, top_left, bottom_right, color, thickness)
        # Text sits 10 px above the box's top edge.
        cv2.putText(image, label_text, (int(box[0]), int(box[1] - 10)),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness)
    return image
# Heatmap generation
def plot_image_with_colored_mask_overlay_and_original(image, predictions, threshold=0.5):
    """Plot ``image`` beside a copy overlaid with a detection-density heatmap.

    Each box scoring above ``threshold`` adds its score to a per-pixel mask.
    The mask is min-max normalised, coloured with a custom
    transparent -> blue -> green -> yellow -> red colormap, and blended 50/50
    into the pixels it covers.

    Args:
        image: H x W x 3 array treated as BGR (it is converted with
            ``COLOR_BGR2RGB`` for display). Presumably uint8; confirm with callers.
        predictions: Detector output dict with ``'boxes'`` and ``'scores'`` tensors.
        threshold: Score cut-off (exclusive) for boxes that contribute to the mask.

    Returns:
        The matplotlib ``Figure``. It is shown with ``plt.show()`` and then
        closed in pyplot, so repeated calls (e.g. Streamlit reruns) do not
        accumulate open figures. The returned object can still be saved or
        passed to ``st.pyplot``.
    """
    boxes = predictions['boxes'].cpu().numpy()
    scores = predictions['scores'].cpu().numpy()
    height, width = image.shape[:2]

    # Accumulate the scores of confident boxes; overlapping boxes add up.
    mask = np.zeros((height, width), dtype=np.float32)
    for box, score in zip(boxes, scores):
        if score > threshold:
            x_min, y_min, x_max, y_max = map(int, box)
            # Clamp into the image: a negative coordinate would otherwise be
            # read as an index from the end and paint the wrong region.
            x_min, x_max = min(max(x_min, 0), width), min(max(x_max, 0), width)
            y_min, y_max = min(max(y_min, 0), height), min(max(y_max, 0), height)
            mask[y_min:y_max, x_min:x_max] += score

    # Rescale to [0, 1] so the colormap spans its full range.
    normed_mask = cv2.normalize(mask, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    # RGBA stops: fully transparent at 0, then blue, green, yellow, red.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1), (1, 1, 0, 1), (1, 0, 0, 1)]
    custom_cmap = LinearSegmentedColormap.from_list('doctoria_heatmap', colors, N=256)

    # RGBA floats -> BGR uint8 (channels 2,1,0; alpha dropped).
    heatmap = custom_cmap(normed_mask)
    heatmap_bgr = (heatmap[:, :, 2::-1] * 255).astype(np.uint8)

    # Blend only where at least one confident box contributed.
    overlayed_image = image.copy()
    covered = mask > 0
    overlayed_image[covered] = overlayed_image[covered] * 0.5 + heatmap_bgr[covered] * 0.5

    fig, axs = plt.subplots(1, 2, figsize=[12, 6])

    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    axs[1].imshow(cv2.cvtColor(overlayed_image, cv2.COLOR_BGR2RGB))
    axs[1].set_title('Image with Heatmap Overlay')
    axs[1].axis('off')

    sm = plt.cm.ScalarMappable(cmap=custom_cmap, norm=plt.Normalize(0, 1))
    sm.set_array([])
    fig.colorbar(sm, ax=axs[1], orientation='vertical', fraction=0.046, pad=0.04)

    plt.show()
    # Release pyplot's reference; under a non-interactive backend (as on a
    # Streamlit server) plt.show() is a no-op and the figure would stay open.
    plt.close(fig)
    return fig
# Model construction
def create_model(num_classes):
    """Build an untrained Faster R-CNN (ResNet-50 FPN) with a ``num_classes``-way box head.

    ``pretrained=False`` means no downloaded weights are loaded here; callers
    are expected to load a state dict themselves.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False)
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector
# Streamlit app title
st.title("Doctoria CXR")

# Sidebar for user input
st.sidebar.title("Upload Chest X-ray Image")

# File uploader allows user to add their own image
uploaded_file = st.sidebar.file_uploader("Upload Chest X-ray image", type=["png", "jpg", "jpeg"])


@st.cache_resource
def load_model(weights_path='Doctoria CXR Thoraric Full Model.pth', num_classes=14):
    """Build the detector, load CPU weights from ``weights_path`` and set eval mode.

    Streamlit re-executes this whole script on every widget interaction.
    Caching the model as a shared resource means the checkpoint is read once
    per server process instead of on every rerun.

    ``num_classes`` must match the checkpoint's box head, because
    ``load_state_dict`` is strict by default and raises on a shape mismatch.
    NOTE(review): ``VinDetector`` trains with 15 classes; confirm which head
    this checkpoint actually has.
    """
    detector = create_model(num_classes=num_classes)
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    detector.load_state_dict(state_dict)
    detector.eval()
    return detector


# Module-level handle used by process_image().
model = load_model()
def process_image(image_path):
    """Run the module-level ``model`` on a single image without tracking gradients.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts. The upload handler
            passes the Streamlit uploaded-file object.

    Returns:
        ``(prediction, image)``: the raw model output and the input tensor
        after ``get_transform()`` with a leading batch dimension added.

    NOTE(review): ``get_transform`` is not defined in the visible file. The
    caller turns the returned tensor back into an image with ``ToPILImage``,
    which suggests a ``ToTensor``-style transform. Confirm it exists.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batch = get_transform()(rgb_image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    return outputs, batch
# React to an uploaded file
if uploaded_file is not None:
    st.image(uploaded_file, caption="Uploaded X-ray", use_column_width=True)
    st.write("")

    prediction, image_tensor = process_image(uploaded_file)
    detections = prediction[0]

    # Tensor -> PIL -> NumPy so cv2 can draw on a writable RGB array in place.
    image_np = np.array(transforms.ToPILImage()(image_tensor.squeeze(0)).convert("RGB"))

    # One call over every detection. draw_boxes_cv2 applies the same
    # score > 0.5 filter (its default threshold) in the same order.
    draw_boxes_cv2(
        image_np,
        detections['boxes'].cpu().numpy(),
        detections['labels'].cpu().numpy(),
        detections['scores'].cpu().numpy(),
        font_scale=3,  # enlarged so labels are legible on full-size X-rays
    )

    st.image(
        Image.fromarray(image_np),
        caption="Processed X-ray with Abnormalities Marked",
        use_column_width=True,
    )

    st.write(generate_diagnostic_report(detections, label_names, 0.5))

# Sidebar usage instructions
st.sidebar.write("Instructions:")
st.sidebar.write("1. Upload an X-ray image.")
st.sidebar.write("2. View the processed image and diagnostic report.")