import spaces
import os
import tempfile
from pathlib import Path

import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix

import gradio as gr
import utils

description_markdown = """
**GitHub:** https://github.com/hhaentze/mrsegmentator
**Paper:** https://arxiv.org/abs/2405.06463
"""

css = """
h1 {
    text-align: center;
    display:block;
}

.markdown-block {
    background-color: #0b0f1a; /* Light gray background */
    color: #ffffff; /* white text */
    padding: 10px; /* Padding around the text */
    border-radius: 5px; /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1);
    display: inline-flex; /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center; /* Vertically center content */
    align-items: center; /* Horizontally center items within */
    margin: auto; /* Center the block */
}

footer {
    display:none !important
}
"""

# Bundled sample images (served from images/); their masks are precomputed
# and stored under segmentations/.
examples = ["amos_0555.nii.gz", "amos_0517.nii.gz", "amos_0541.nii.gz", "amos_0571.nii.gz"]


def _example_segmentation_path(filename):
    """Return the path of the precomputed mask for the bundled example *filename* (a basename)."""
    return str(Path("segmentations") / add_postfix(filename, "seg"))


def save_file(segmentation, path):
    """Return a downloadable file path for *segmentation*.

    Args:
        segmentation: SimpleITK image holding the predicted mask.
        path: Path of the input image the mask was computed from.

    Returns:
        For bundled example images, the path of the precomputed mask under
        ``segmentations/`` (nothing is written). Otherwise the mask is written
        next to the input (in the temporary directory Gradio allocated for the
        upload) under a ``_seg``-postfixed name, and that path is returned.
        The input image itself is never overwritten, so running inference
        again on the same upload still segments the original image.
    """
    name = Path(path).name
    if name in examples:
        return _example_segmentation_path(name)

    seg_path = str(Path(path).with_name(add_postfix(name, "seg")))
    sitk.WriteImage(segmentation, seg_path)
    return seg_path


@spaces.GPU(duration=150)
def infer(image_path):
    """Segment the image at *image_path* with MRSegmentator.

    Runs with ``cpu_only=False`` when CUDA is available and ``cpu_only=True``
    otherwise. The prediction is written to a throw-away directory and read
    back as a SimpleITK image before that directory is deleted.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): [0] is presumably the list of model folds to use — confirm
        # against mrsegmentator.inference.infer.
        inference.infer(
            [image_path],
            tmpdirname,
            [0],
            cpu_only=not torch.cuda.is_available(),
            split_level=2,
        )
        seg_file = Path(tmpdirname) / add_postfix(Path(image_path).name, "seg")
        # Read inside the with-block: the directory is removed on exit.
        return sitk.ReadImage(str(seg_file))


def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Gradio callback for the "Run" button.

    Args:
        input_file: Value of the ``gr.File`` input (a path string; Gradio may
            pass a str subclass that also exposes the path as ``.name``).
        image_state: Per-session list of loaded images; the last entry is shown.
        seg_state: Per-session list of masks; the new mask is appended in place.
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        (overlay for the selected slice, updated seg_state, path of the
        downloadable segmentation file).

    Raises:
        gr.Error: if no image has been uploaded/loaded yet.
    """
    if input_file is None or image_state is None or len(image_state) == 0:
        raise gr.Error("Please upload an image (or select an example) before running.")

    image_path = str(getattr(input_file, "name", input_file))
    filename = Path(image_path).name

    if filename in examples:
        # Examples ship with precomputed masks, so skip the expensive inference.
        segmentation = sitk.ReadImage(_example_segmentation_path(filename))
    else:
        segmentation = infer(image_path)

    seg_path = save_file(segmentation, image_path)

    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path


with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# MRSegmentator: Multi-Modality Segmentation of 40 Classes in MRI and CT")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session state; infer_wrapper reads the last entry of each list.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".nii" is listed so uncompressed NIfTI files match the label.
                file_types=[".nii", ".gz", ".nii.gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # NOTE(review): here utils.display receives the whole state lists, whereas
    # infer_wrapper passes their last elements — confirm display accepts both.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])

if __name__ == "__main__":
    # NOTE(review): mrsegmentator is already imported at this point; this only
    # takes effect if the library reads MRSEG_WEIGHTS_PATH lazily — confirm.
    os.environ["MRSEG_WEIGHTS_PATH"] = "weights_v1.2"
    iface.queue()
    iface.launch()