Spaces:
Sleeping
Sleeping
import spaces | |
import os | |
import tempfile | |
from pathlib import Path | |
import SimpleITK as sitk | |
import torch | |
from mrsegmentator import inference | |
from mrsegmentator.utils import add_postfix | |
import gradio as gr | |
import utils | |
# Markdown shown under the page title: links to the project repository and the paper.
description_markdown = (
    "\n"
    "**GitHub:** https://github.com/hhaentze/mrsegmentator <br>\n"
    "**Paper:** https://arxiv.org/abs/2405.06463\n"
)
# Page styling passed to gr.Blocks(css=...): centered title, a dark "card" around the
# description markdown (elem_classes="markdown-block"), and a hidden Gradio footer.
css = """
h1 {
    text-align: center;
    display:block;
}
.markdown-block {
    background-color: #0b0f1a; /* Near-black background */
    color: #ffffff; /* white text */
    padding: 10px; /* Padding around the text */
    border-radius: 5px; /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1);
    display: inline-flex; /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center; /* Vertically center content */
    align-items: center; /* Horizontally center items within */
    margin: auto; /* Center the block */
}
footer {
    display:none !important;
}
"""
# Unused alternative styling for lists inside .markdown-block (NOT part of `css` above):
# .markdown-block ul, .markdown-block ol {
#     background-color: #1e2936;
#     border-radius: 5px;
#     padding: 10px;
#     box-shadow: 0 0 10px rgba(0,0,0,0.3);
#     padding-left: 20px; /* Adjust padding for bullet alignment */
#     text-align: left; /* Ensure text within list is left-aligned */
#     list-style-position: outside; /* Ensures bullets/numbers are outside the content flow */
# }
# File names of the bundled sample scans. The images are served from "images/" and their
# precomputed segmentations are read from "segmentations/" (see infer_wrapper / save_file).
examples = [f"amos_{case_id}.nii.gz" for case_id in ("0555", "0517", "0541", "0571")]
def save_file(segmentation, path):
    """Return a downloadable path for ``segmentation``.

    For the bundled example images, the precomputed segmentation in
    ``segmentations/`` is returned and nothing is written. For any other input,
    the segmentation is written next to the input image under the name
    ``add_postfix(<input name>, "seg")``.

    The input image itself is not overwritten. The previous version wrote the
    segmentation over the uploaded file, so a second "Run" on the same upload
    would have segmented the label map instead of the image.

    Args:
        segmentation: SimpleITK image to save (ignored for example inputs).
        path: Path of the input image (string or path-like).

    Returns:
        Path of the segmentation file, as a string.
    """
    image = Path(path)
    seg_name = add_postfix(image.name, "seg")

    if image.name in examples:
        # Only the file name may be appended: `path` usually includes a directory
        # (e.g. an upload/cache folder) that must not end up under "segmentations/".
        return "segmentations/" + seg_name

    seg_path = str(image.with_name(seg_name))
    sitk.WriteImage(segmentation, seg_path)
    return seg_path
def infer(image_path):
    """Run MRSegmentator on one image and load the resulting segmentation.

    The model writes its output into a temporary directory. The segmentation
    file is read back with SimpleITK before that directory is deleted.

    Args:
        image_path: Path of the image to segment.

    Returns:
        The segmentation as a SimpleITK image.
    """
    # Fall back to CPU inference when no CUDA device is available.
    use_cpu = not torch.cuda.is_available()
    with tempfile.TemporaryDirectory() as out_dir:
        # NOTE(review): `[0]` and `split_level=2` are passed through unchanged; their exact
        # meaning is defined by mrsegmentator.inference.infer — confirm there.
        inference.infer([image_path], out_dir, [0], cpu_only=use_cpu, split_level=2)
        seg_name = add_postfix(Path(image_path).name, "seg")
        return sitk.ReadImage(f"{out_dir}/{seg_name}")
def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Segment the selected image and refresh the overlay view.

    This is the ``submit_button.click`` handler. Bundled example images reuse
    their precomputed segmentation; any other input goes through ``infer``.

    Args:
        input_file: Value of the ``gr.File`` component (``type="filepath"``),
            presumably a path string. A ``.name`` attribute is used when present,
            so file-like values also work. ``None`` when nothing is selected.
        image_state: Session list filled by ``utils.read_and_display``; its last
            entry is the image shown under the segmentation.
        seg_state: Session list of segmentations; the new result is appended.
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        A tuple of (overlay for ``gr.AnnotatedImage``, updated ``seg_state``,
        path of the downloadable segmentation file).

    Raises:
        gr.Error: If no image has been selected or loaded yet.
    """
    # Without a file, or before the upload handler stored the image, there is
    # nothing to segment or display; fail with a readable message instead of
    # a TypeError/IndexError.
    if input_file is None or image_state is None or len(image_state) == 0:
        raise gr.Error("Please upload an image before running the segmentation.")

    # Resolve the path once and use it consistently below.
    image_path = str(getattr(input_file, "name", input_file))
    filename = Path(image_path).name

    if filename in examples:
        # Sample images ship with precomputed segmentations; skip inference.
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(image_path)

    seg_path = save_file(segmentation, image_path)
    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
# Assemble the Gradio UI. The left column holds the file picker, examples, Run/Clear
# buttons, slice slider and download; the right column holds the segmentation overlay.
with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# MRSegmentator: Multi-Modality Segmentation of 40 Classes in MRI and CT")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session lists. read_and_display updates them when the file changes;
    # infer_wrapper appends to seg_state and displays the last entries.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".nii" is listed so uncompressed NIfTI files are accepted, as the label
                # promises. ".gz" is kept because the last suffix of "x.nii.gz" is ".gz";
                # presumably the upload filter matches on that suffix.
                file_types=[".gz", ".nii.gz", ".nii"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # A new file (uploaded or picked from the examples) refreshes the preview and state.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # Moving the slider only re-renders the overlay from the stored state.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
if __name__ == "__main__":
    # Select the model weights directory used by mrsegmentator.
    # NOTE(review): mrsegmentator is imported at the top of this file, so this only
    # takes effect if the variable is read at inference time, not at import — confirm.
    os.environ.update(MRSEG_WEIGHTS_PATH="weights_v1.2")

    # Enable Gradio's request queue, then start the server.
    iface.queue()
    iface.launch()