Spaces:
Runtime error
Runtime error
import sys | |
import os | |
import glob | |
import shutil | |
import torch | |
import argparse | |
import mediapy | |
import cv2 | |
import numpy as np | |
import gradio as gr | |
from skimage import color, img_as_ubyte | |
from monai import transforms, data | |
# Fetch the SwinUNETR research code, which provides the `swinunetr` module.
# Skip the clone when a previous run already created the checkout, and fail
# loudly here (rather than with a confusing import error below) if it fails.
if not os.path.isdir("pmrc"):
    if os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc") != 0:
        raise RuntimeError("git clone of the SwinUNETR research repository into 'pmrc' failed")
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig  # noqa: E402

# Point mediapy at the ffmpeg binary found on PATH. If none is found, keep
# mediapy's default setting instead of handing it None as the binary name.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Load the pretrained weights (repo id 'darragh/swinunetr-btcv-tiny',
# presumably fetched from the Hugging Face Hub) and switch to eval mode.
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()

# Map each sample volume's file name to its full path; the UI offers the names.
input_files = {
    os.path.basename(path): path
    for path in glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
}
# MONAI preprocessing pipeline, applied to each {"image": <file path>} record.
test_transform = transforms.Compose([
    # Read the volume from disk into the "image" entry.
    transforms.LoadImaged(keys=["image"]),
    # Add a length-1 channel dimension in front of the spatial axes.
    transforms.AddChanneld(keys=["image"]),
    # Resample to a 1.5 x 1.5 x 2.0 voxel spacing, interpolating bilinearly.
    transforms.Spacingd(keys="image", pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    # Linearly map intensities in [-175, 250] onto [0, 1]; values outside
    # that window are clipped.
    transforms.ScaleIntensityRanged(
        keys=["image"],
        a_min=-175.0,
        a_max=250.0,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    # (An in-plane resize to (256, 256, -1) was tried here and is disabled.)
    # Hand the result to the model as a torch tensor.
    transforms.ToTensord(keys=["image"]),
])
# Create Data Loader
def create_dl(test_files, batch_size=1):
    """Build a DataLoader that preprocesses image records with ``test_transform``.

    Args:
        test_files: list of dicts, each with an ``"image"`` key holding the
            path of a volume to load.
        batch_size: number of records per batch (default 1).

    Returns:
        A ``monai.data.DataLoader`` yielding batches in input order
        (``shuffle=False``).
    """
    # monai.data.Dataset applies the transform per item when it is indexed,
    # instead of loading and resampling every volume up front.
    ds = data.Dataset(data=test_files, transform=test_transform)
    return data.DataLoader(ds, batch_size=batch_size, shuffle=False)
# Inference and video generation
def _pad_to_even(frame):
    """Zero-pad an (H, W, C) frame on the bottom/right so H and W are even.

    mediapy encodes with the ``yuv420p`` pixel format by default, which
    requires even frame dimensions. Resampled volumes can have an odd size,
    and those frames would otherwise make the encode fail.
    """
    pad_h = frame.shape[0] % 2
    pad_w = frame.shape[1] % 2
    if pad_h == 0 and pad_w == 0:
        return frame
    return np.pad(frame, ((0, pad_h), (0, pad_w), (0, 0)))


def _render_frame(inp, outp, idx):
    """Render slice ``idx`` as an RGB uint8 image: [input | label overlay].

    Args:
        inp: one preprocessed volume, indexed as ``inp[0, :, :, idx]``
            (channel first, slice axis last), with intensities in [0, 1].
        outp: the matching label volume, indexed as ``outp[:, :, idx]``.
        idx: slice index along the last axis.
    """
    # Label map for this slice; label 0 is drawn as background (no colour).
    seg = outp[:, :, idx].numpy().astype(np.uint8)
    # Scale the [0, 1] intensities to 8-bit grey and expand to 3 channels.
    img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    overlay = img_as_ubyte(color.label2rgb(seg, img, bg_label=0))
    return _pad_to_even(np.concatenate((img, overlay), 1))


def generate_dicom_video(selected_file, num_slices=32, roi_size=(96, 96, 96),
                         sw_batch_size=8, fps=4, output_path="dicom.mp4"):
    """Segment a sample volume and write a side-by-side MP4 of its slices.

    Args:
        selected_file: key of ``input_files`` (a sample file name).
        num_slices: number of trailing slices along the last axis to segment
            and render; ``None`` or 0 keeps every slice.
        roi_size: second positional argument of the model call (looks like a
            sliding-window ROI size -- TODO confirm against swinunetr).
        sw_batch_size: third positional argument of the model call
            (presumably the sliding-window batch size -- TODO confirm).
        fps: frame rate of the written video.
        output_path: file the MP4 is written to; overwritten on each call.

    Returns:
        ``output_path``.

    Raises:
        KeyError: if ``selected_file`` is not a known sample, e.g. when
            nothing is selected in the UI.
    """
    if selected_file not in input_files:
        raise KeyError(
            f"Unknown sample file {selected_file!r}; "
            f"expected one of {sorted(input_files)}"
        )
    test_files = [{'image': input_files[selected_file]}]
    batch = next(iter(create_dl(test_files)))
    tst_inputs = batch["image"]
    if num_slices:
        # Keep only the last `num_slices` slices along the final axis.
        tst_inputs = tst_inputs[..., -num_slices:]
    with torch.no_grad():
        outputs = model(tst_inputs,
                        roi_size,
                        sw_batch_size,
                        overlap=0.5,
                        mode="gaussian")
    # Per-voxel class = argmax over the class axis. Softmax is monotonic, so
    # applying it first would not change the winner.
    tst_outputs = torch.argmax(outputs.logits, dim=1)

    # create_dl yields batches of one record, so the batch holds one volume.
    inp, outp = tst_inputs[0], tst_outputs[0]
    frames = [_render_frame(inp, outp, idx) for idx in range(inp.shape[-1])]
    mediapy.write_video(output_path, frames, fps=fps)
    return output_path
# Manual smoke test (e.g. from a Python session after importing this module).
# generate_dicom_video expects a key of `input_files` (the sample's file
# name), not a full path:
#     generate_dicom_video(sorted(input_files)[0])
# Gradio UI: choose a sample volume, generate the segmentation video, show it.
with gr.Blocks() as demo:
    dicom_choices = sorted(input_files)
    selected_dicom_key = gr.Dropdown(
        choices=dicom_choices,
        # Preselect the first sample so pressing the button without choosing
        # one runs instead of failing on an empty selection.
        value=dicom_choices[0] if dicom_choices else None,
        type="value",
        label="Select a dicom file",
    )
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(
        fn=generate_dicom_video,
        inputs=selected_dicom_key,
        outputs=output_interpolation,
    )

# `gr.inputs.*` components and `launch(enable_queue=...)` are legacy APIs that
# newer Gradio releases removed. Enable the queue through `Blocks.queue()`
# when it exists, and fall back to the old keyword on versions that lack it.
if hasattr(demo, "queue"):
    demo.queue()
    demo.launch(debug=True)
else:
    demo.launch(debug=True, enable_queue=True)