# Hugging Face file-viewer header (not part of the program):
# DiGuaQiu's picture | Update app.py | 96a9c3f verified | raw | history blame | 4.64 kB
import spaces
import os
import tempfile
from pathlib import Path
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
import gradio as gr
import utils
# Markdown shown under the page title: links to the project repository and paper.
description_markdown = (
    "\n"
    "**GitHub:** https://github.com/hhaentze/mrsegmentator <br>\n"
    "**Paper:** https://arxiv.org/abs/2405.06463\n"
)
# Custom stylesheet handed to gr.Blocks(css=...):
#   * h1              - centres the page title.
#   * .markdown-block - dark (#0b0f1a) card with white text, centred on the page and
#                       shrunk to its content; used via elem_classes="markdown-block".
#                       NOTE(review): the in-string remark "Light gray background" is
#                       stale; #0b0f1a is rgb(11, 15, 26), i.e. a very dark colour.
#   * footer          - hidden.
css = """
h1 {
text-align: center;
display:block;
}
.markdown-block {
background-color: #0b0f1a; /* Light gray background */
color: #ffffff; /* white text */
padding: 10px; /* Padding around the text */
border-radius: 5px; /* Rounded corners */
box-shadow: 0 0 10px rgba(11,15,26,1);
display: inline-flex; /* Use inline-flex to shrink to content size */
flex-direction: column;
justify-content: center; /* Vertically center content */
align-items: center; /* Horizontally center items within */
margin: auto; /* Center the block */
}
footer {
display:none !important
}
"""
# Disabled CSS rule for lists inside .markdown-block (not part of the stylesheet string):
# .markdown-block ul, .markdown-block ol {
# background-color: #1e2936;
# border-radius: 5px;
# padding: 10px;
# box-shadow: 0 0 10px rgba(0,0,0,0.3);
# padding-left: 20px; /* Adjust padding for bullet alignment */
# text-align: left; /* Ensure text within list is left-aligned */
# list-style-position: outside; /* Ensures bullets/numbers are outside the content flow */
# }
# File names of the bundled sample scans. The UI offers them from "images/" and
# the matching precomputed results are looked up under "segmentations/".
examples = [
    "amos_0555.nii.gz",
    "amos_0517.nii.gz",
    "amos_0541.nii.gz",
    "amos_0571.nii.gz",
]
def save_file(segmentation, path):
    """Return the path of a segmentation file that can be offered for download.

    Args:
        segmentation: Image object accepted by ``sitk.WriteImage``.
        path: Path of the input image the segmentation was computed for.

    Returns:
        For a bundled sample file, the path of its precomputed segmentation
        under ``segmentations/`` (nothing is written). Otherwise ``path``
        itself, after the segmentation has been written to it. Note that this
        overwrites the file at ``path``, reusing the temporary file that was
        allocated for the uploaded input image.
    """
    filename = Path(path).name
    if filename in examples:
        # Build the location from the bare file name. Appending the full input
        # path (e.g. "/tmp/.../amos_0555.nii.gz") to "segmentations/" would
        # point at a location that does not exist.
        return "segmentations/" + add_postfix(filename, "seg")

    sitk.WriteImage(segmentation, path)
    return path
@spaces.GPU(duration=150)
def infer(image_path):
    """Run MRSegmentator on a single image and return the resulting segmentation.

    The model writes its output into a throw-away directory; the result is read
    back with SimpleITK before that directory is removed.

    Args:
        image_path: Path of the image to segment.

    Returns:
        The segmentation as returned by ``sitk.ReadImage``.
    """
    # Fall back to CPU-only inference when no CUDA device is available.
    use_cpu = not torch.cuda.is_available()
    result_name = add_postfix(Path(image_path).name, "seg")

    with tempfile.TemporaryDirectory() as tmpdirname:
        inference.infer([image_path], tmpdirname, [0], cpu_only=use_cpu, split_level=2)
        # Read while the temporary directory still exists.
        segmentation = sitk.ReadImage(f"{tmpdirname}/{result_name}")

    return segmentation
def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Segment the selected file and return everything the UI needs to show it.

    Bundled sample files are served from their precomputed segmentation; any
    other file is passed through ``infer``.

    Args:
        input_file: Path of the selected file as delivered by the file input
            (a string, possibly a str subclass that also exposes ``.name``).
        image_state: List whose last entry is the currently displayed image.
        seg_state: List of segmentations; the new one is appended in place.
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        A tuple ``(display, seg_state, seg_path)``: the output of
        ``utils.display`` for the latest image/segmentation pair, the updated
        segmentation list, and the path of the downloadable segmentation file.

    Raises:
        gr.Error: If no file has been selected.
    """
    if not input_file:
        # Without this guard, pressing "Run" on an empty or cleared form
        # fails with an opaque TypeError from Path(None).
        raise gr.Error("Please upload an image or select an example first.")

    # Normalise once: works for a plain str as well as for a str subclass
    # carrying the same path in ``.name``.
    input_path = str(getattr(input_file, "name", input_file))
    filename = Path(input_path).name

    # inference
    if filename in examples:
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(input_path)

    # save file
    seg_path = save_file(segmentation, input_path)

    seg_state.append(utils.sitk2numpy(segmentation))
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
# User interface: file input, example picker and controls on the left,
# annotated image on the right, followed by the event wiring.
# NOTE(review): the nesting of the layout containers was reconstructed from a
# flattened copy of this file; confirm it against the deployed layout.
with gr.Blocks(css=css, title="MRSegmentator") as iface:
    gr.Markdown("# MRSegmentator: Multi-Modality Segmentation of 40 Classes in MRI and CT")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session lists holding the loaded images and their segmentations.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".nii" added so that uncompressed NIfTI files, which the
                # label advertises, are accepted as well.
                file_types=[".nii", ".gz", ".nii.gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)

            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()

            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)

        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # Selecting a file refreshes the view and both state lists.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # Moving the slider only re-renders the view.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    # "Run" performs the segmentation and publishes the downloadable result.
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    # The clear button resets every component and state listed here.
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
if __name__ == "__main__":
    # Set before launching; presumably read by mrsegmentator to locate its
    # model weights (TODO confirm against the mrsegmentator configuration).
    os.environ["MRSEG_WEIGHTS_PATH"] = "weights_v1.2"

    # Enable request queuing, then start the web server.
    iface.queue()
    iface.launch()