Spaces:
Running
Running
# Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol | |
import tempfile | |
from pathlib import Path | |
import SimpleITK as sitk | |
from mrsegmentator.utils import add_postfix | |
import streamlit as st | |
import utils | |
print("script run")
st.title("MRSegmentator")
st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")

#############################################
# init session_state
# Streamlit re-executes this script on every interaction, so only keys that are
# still missing are seeded; existing values survive reruns.
_SESSION_DEFAULTS = {
    "option": None,  # path of the currently selected/uploaded case, or None
    "reset_demo_case": False,  # True -> the image is re-read on the next run
    "preds_3D": None,  # segmentation array as returned by utils.read_image
    "preds_3D_ori": None,  # segmentation as a SimpleITK image (used for download)
    "preds_path": None,
    "data_item": None,  # image array as returned by utils.read_image
    "rectangle_3Dbox": [0, 0, 0, 0, 0, 0],  # box prompt, reset by clear_prompts()
    "running": False,  # True between pressing "Run" and the segmentation pass
    "transparency": 0.25,  # mask opacity passed to utils.make_fig
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Bundled demo cases; the case-loading code resolves these names under
# images/ (the volume) and segmentations/ (its precomputed mask).
case_list = [
    "amos_0517_MRI.nii.gz",
    "amos_0541_MRI.nii.gz",
    "amos_0571_MRI.nii.gz",
]
#############################################
############################################# | |
############################################# | |
# reset functions | |
def clear_prompts():
    """Reset the 3D box prompt stored in session state to an all-zero box."""
    st.session_state["rectangle_3Dbox"] = [0] * 6
def reset_demo_case():
    """Forget the loaded image and any predictions for the current case.

    Registered as the on_change callback of the case selector / uploader.
    Setting the ``reset_demo_case`` flag makes the main script re-read the
    image on its next run.
    """
    for stale_key in ("data_item", "preds_3D", "preds_3D_ori"):
        st.session_state[stale_key] = None
    st.session_state["reset_demo_case"] = True
    clear_prompts()
def clear_file():
    """Forget the current case entirely (used when the demo source changes).

    Unsets the selected case path and delegates the remaining cleanup to
    reset_demo_case(), which also clears the box prompt, so no separate
    clear_prompts() call is needed here.
    """
    st.session_state.option = None
    reset_demo_case()
#############################################
# Project links, shown side by side.
github_col, arxive_col = st.columns(2)
github_col.write("Git: https://github.com/hhaentze/mrsegmentator")
arxive_col.write("Paper: https://arxiv.org/abs/2405.06463")

# Where the demo case comes from: a bundled case with a precomputed
# segmentation, or a user upload. Switching the source clears the current case.
demo_type = st.radio("Demo case source", ["Select (presegmented)", "Upload"], on_change=clear_file)
with tempfile.TemporaryDirectory() as tmpdirname:
    # Streamlit reruns the whole script on every interaction, so this directory
    # (and any upload copied into it) only exists for the current run.

    # --- choose the case to display ------------------------------------------
    if demo_type == "Select (presegmented)":
        selection = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
        if selection:
            uploaded_file = "images/" + selection
            # Bundled cases ship with a precomputed segmentation.
            # NOTE(review): this is re-read on every rerun, so the "Clear"
            # button below cannot hide the mask of a presegmented case.
            seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
            st.session_state.preds_3D = utils.read_image(seg_path)
            st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
        else:
            uploaded_file = None
    else:
        uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
        if uploaded_file is not None:
            # Keep only the base name: the name is supplied by the client and
            # must not be able to point outside the temporary directory.
            upload_path = Path(tmpdirname) / Path(uploaded_file.name).name
            upload_path.write_bytes(uploaded_file.getvalue())
            uploaded_file = str(upload_path)
    st.session_state.option = uploaded_file

    # (Re)load the image when a new case was chosen or nothing is loaded yet.
    if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None
    ):
        # Joining an absolute upload path onto the script directory yields the
        # upload path itself; bundled relative paths resolve next to the script.
        st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
        st.session_state.reset_demo_case = False

    # --- viewer ----------------------------------------------------------------
    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        image_3D = st.session_state.data_item
        px_min, px_max = int(image_3D.min()), int(image_3D.max())
        px_range = st.slider("Select intensity range", px_min, px_max, (px_min, px_max))

        col_control1, col_control2 = st.columns(2)
        with col_control1:
            # Axial slices index axis 0 of the volume.
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )
        with col_control2:
            # Coronal slices index axis 1 of the volume.
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)
        if st.session_state.preds_3D is not None:
            st.session_state.transparency = st.slider(
                "Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running
            )
        with col_image1:
            image_z_array = image_3D[selected_index_z]
            preds_z_array = None
            if st.session_state.preds_3D is not None:
                preds_z_array = st.session_state.preds_3D[selected_index_z]
            image_z = utils.make_fig(image_z_array, preds_z_array, px_range, st.session_state.transparency)
            st.image(image_z, use_column_width=False)
        with col_image2:
            image_y_array = image_3D[:, selected_index_y, :]
            preds_y_array = None
            if st.session_state.preds_3D is not None:
                preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
            image_y = utils.make_fig(image_y_array, preds_y_array, px_range, st.session_state.transparency)
            st.image(image_y, use_column_width=False)

    ######################################################
    # --- actions -----------------------------------------------------------------
    col1, col2, col3 = st.columns(3)
    with col1:
        # Presumably empty headers used as vertical spacers to line the button
        # up with the radio in col3.
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if st.button(
            "Clear",
            use_container_width=True,
            disabled=(st.session_state.option is None or (st.session_state.preds_3D is None)),
        ):
            clear_prompts()
            st.session_state.preds_3D = None
            st.session_state.preds_3D_ori = None
            st.session_state.preds_path = None
            st.rerun()
    with col2:
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
            # Serialize the SimpleITK segmentation to NIfTI bytes for download.
            with tempfile.NamedTemporaryFile(suffix=".nii.gz") as tmpfile:
                preds = st.session_state.preds_3D_ori
                sitk.WriteImage(preds, tmpfile.name)
                with open(tmpfile.name, "rb") as f:
                    bytes_data = f.read()
            st.download_button(
                label="Download result (.nii.gz)",
                data=bytes_data,
                file_name="segmentation.nii.gz",
                mime="application/octet-stream",
                disabled=False,
            )
    with col3:
        # The comparison below must use the exact option label; it previously
        # tested a different string, so the fast single-fold model could never
        # be selected.
        fold1_option = "Model of Fold 1 (fast)"
        folds = st.radio(
            "Model selection",
            [fold1_option, "Ensemble Segmentation"],
            label_visibility="collapsed",
        )
        # Fold indices are zero-based: "Fold 1" is fold 0. The ensemble uses all
        # five folds. NOTE(review): folds is only stored here; presumably read
        # by utils.run via session state - confirm.
        st.session_state.folds = (0,) if folds == fold1_option else (0, 1, 2, 3, 4)

        run_button_name = "Running" if st.session_state.running else "Run"
        if st.button(
            run_button_name,
            type="primary",
            use_container_width=True,
            # On-site segmentation is switched off (no GPU). To re-enable, use:
            # disabled=(st.session_state.data_item is None or st.session_state.running),
            disabled=True,
        ):
            # Rerun first so the widgets render disabled while the model runs.
            st.session_state.running = True
            st.rerun()

    # Second pass after "Run": do the work once, then rerun to show the results.
    if st.session_state.running:
        st.session_state.running = False
        with st.status("Running...", expanded=False):
            utils.run(tmpdirname)
        st.rerun()