File size: 4,950 Bytes
ab163d2 f722806 ab163d2 f722806 ab163d2 f722806 ab163d2 f722806 ab163d2 5da7918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import utils
import Model_Class
import Model_Seg
import SimpleITK as sitk
import torch
from numpy import uint8
import spaces
# Run inference on the GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pipeline illustration rendered at the bottom of the page, embedded inline
# as a base64 PNG data URI.
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
article_html = (
    "<img src='data:image/png;base64,"
    f"{image_base64}"
    "' alt='Anatomical pipeline illustration' style='width:100%;'>"
)

# Intro text shown under the title (styled by the .markdown-block CSS class).
# NOTE(review): the "Privacy" bullet says the tool runs completely locally, but
# predict_image is decorated with @spaces.GPU; if this is served as a hosted
# Space, uploads are processed on the host, not the user's machine — confirm
# the wording.
description_markdown = """
- This tool combines a U-Net Segmentation Model with a ResNet-50 for Classification.
- **Usage:** Just drag a pelvic x-ray into the box and hit run.
- **Process:** The input image will be segmented and cropped to the SIJ before classification.
- **Please Note:** This tool is intended for research purposes only.
- **Privacy:** This tool runs completely locally, ensuring data privacy.
"""

# Page styling passed to gr.Blocks: centred title, boxed description block,
# hidden Gradio footer.
# NOTE(review): the inline CSS comment "Light gray background" does not match
# #0b0f1a (a near-black navy), and "color: black" on that background may be
# unreadable unless the theme overrides it — verify intended colours.
css = """
h1 {
text-align: center;
display:block;
}
.markdown-block {
background-color: #0b0f1a; /* Light gray background */
color: black; /* Black text */
padding: 10px; /* Padding around the text */
border-radius: 5px; /* Rounded corners */
box-shadow: 0 0 10px rgba(11,15,26,1);
display: inline-flex; /* Use inline-flex to shrink to content size */
flex-direction: column;
justify-content: center; /* Vertically center content */
align-items: center; /* Horizontally center items within */
margin: auto; /* Center the block */
}
.markdown-block ul, .markdown-block ol {
background-color: #1e2936;
border-radius: 5px;
padding: 10px;
box-shadow: 0 0 10px rgba(0,0,0,0.3);
padding-left: 20px; /* Adjust padding for bullet alignment */
text-align: left; /* Ensure text within list is left-aligned */
list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
}
footer {
display:none !important
}
"""
# Cut-off applied to the r-axSpA probability. Defined once so the decision and
# the threshold quoted in the user-facing explanation cannot drift apart.
R_AXSPA_THRESHOLD = 0.59


@spaces.GPU
def predict_image(input_image, input_file):
    """Segment, crop and classify a pelvic radiograph for the Gradio UI.

    The image is segmented, the mask is used to crop the input, and the crop
    is classified. Results are returned in the order expected by the Run
    button's ``outputs`` list.

    Args:
        input_image: File path from the PNG/JPG tab, or ``None`` if unused.
        input_file: File path from the NIfTI/DICOM tab, or ``None``. Only
            consulted when ``input_image`` is ``None``.

    Returns:
        ``(overlay, class_probs, explanation, gradcam, cropped)`` where
        ``class_probs`` maps ``"nr-axSpA"`` / ``"r-axSpA"`` to floats. If
        neither input is set, returns ``(None, None, <prompt message>, None,
        None)`` instead.
    """
    # The PNG/JPG tab takes precedence when both inputs are filled in.
    image_path = input_image if input_image is not None else input_file
    if image_path is None:
        return None, None, "Please input an image before pressing run", None, None

    # 1. Segment the image and build the mask overlay shown in the UI.
    seg_mask = Model_Seg.load_and_segment_image(image_path, device)
    overlay, original = utils.overlay_mask(image_path, seg_mask)

    # 2. Crop to the segmented region. A leading singleton axis is added so
    #    SimpleITK (which reads numpy arrays in z, y, x order) receives a
    #    single-slice volume; presumably both arrays are 2-D (H, W) — confirm.
    mask_im = sitk.GetImageFromArray(seg_mask[None, :, :].astype(uint8))
    image_im = sitk.GetImageFromArray(original[None, :, :].astype(uint8))
    cropped_im, _ = utils.mask_and_crop(image_im, mask_im)
    cropped = sitk.GetArrayFromImage(cropped_im)
    cropped_display = cropped.squeeze()  # drop singleton axes for display

    # 3. Classify the crop and compute the GradCAM visualisation.
    scores, transformed = Model_Class.load_and_classify_image(torch.Tensor(cropped), device)
    gradcam_map = Model_Class.make_GradCAM(transformed, device)
    nr_prob = float(scores[0].item())
    r_prob = float(scores[1].item())

    # 4. Apply the cut-off and phrase the decision for the user.
    considered = (
        "be considered r-axSpA" if r_prob > R_AXSPA_THRESHOLD else "not be considered r-axSpA"
    )
    explanation_text = (
        f"According to the pre-determined cut-off threshold of {R_AXSPA_THRESHOLD}, "
        f"the image should {considered}. This Tool is for research purposes only."
    )
    class_probs = {"nr-axSpA": nr_prob, "r-axSpA": r_prob}
    return overlay, class_probs, explanation_text, gradcam_map, cropped_display
# Page layout: inputs and buttons on the left, results on the right, and the
# pipeline illustration below.
with gr.Blocks(css=css, title="Anatomy Aware axSpA") as iface:
    gr.Markdown("# Anatomy-Aware Image Classification for radiographic axSpA")
    gr.Markdown(description_markdown, elem_classes="markdown-block")
    with gr.Row():
        # Left column: two alternative input tabs plus the action buttons.
        with gr.Column():
            with gr.Tab("PNG/JPG"):
                input_image = gr.Image(type='filepath', label="Upload an X-ray Image")
            with gr.Tab("NIfTI/DICOM"):
                input_file = gr.File(type='filepath', label="Upload an X-ray Image")
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
        # Right column: results, in the same order predict_image returns them.
        with gr.Column():
            overlay_image_np = gr.Image(label="Segmentation Mask")
            pred_dict = gr.Label(label="Prediction")
            explanation = gr.Textbox(label="Classification Decision")
            with gr.Accordion("Additional Information", open=False):
                gradcam = gr.Image(label="GradCAM")
                cropped_boxed_array_disp = gr.Image(label="Bounding Box")

    # Shared by Run and Clear so the two wirings cannot diverge.
    result_components = [overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp]
    submit_button.click(predict_image, inputs=[input_image, input_file], outputs=result_components)
    # Clear must reset BOTH input tabs; previously the NIfTI/DICOM file
    # survived a Clear.
    clear_button.add([input_image, input_file, *result_components])
    gr.HTML(article_html)
if __name__ == "__main__":
    # Turn on Gradio's request queue (queue() returns the Blocks app itself),
    # then start serving.
    iface.queue().launch()
|