import sys
import os
import glob
import shutil
import torch
import mediapy
import cv2
import numpy as np
import gradio as gr
from skimage import color, img_as_ubyte
from monai import transforms, data
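
# Fetch the community SwinUNETR (BTCV) code and its Hugging Face model wrapper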
os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc")
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig
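
# mediapy needs an ffmpeg binary on the PATH to encode the output video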
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
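
# Load the pretrained tiny SwinUNETR BTCV checkpoint from the Hugging Face Hub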
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()
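
# Collect the sample NIfTI volumes shipped with the cloned repo, keyed by filename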
input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
input_files = dict((f.split('/')[-1], f) for f in input_files)

# Load and preprocess each scan with MONAI transforms
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        # NOTE: AddChanneld is deprecated in newer MONAI releases (use EnsureChannelFirstd)
        transforms.AddChanneld(keys=["image"]),
        transforms.Spacingd(keys="image",
                            pixdim=(1.5, 1.5, 2.0),
                            mode="bilinear"),
        transforms.ScaleIntensityRanged(keys=["image"],
                                        a_min=-175.0,
                                        a_max=250.0,
                                        b_min=0.0,
                                        b_max=1.0,
                                        clip=True),
        # transforms.Resized(keys=["image"], spatial_size=(256, 256, -1)),
        transforms.ToTensord(keys=["image"]),
    ])
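
# Quick sanity check of the transform pipeline (disabled). A minimal sketch
# assuming the sample volumes cloned above exist; the printed shape and range
# are expectations, not verified output.
'''
sample = test_transform([{'image': list(input_files.values())[0]}])[0]
print(sample['image'].shape)   # channel-first volume, e.g. (1, H, W, D)
print(sample['image'].min(), sample['image'].max())  # clipped to [0.0, 1.0]
'''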

# Create Data Loader
def create_dl(test_files):
    ds = test_transform(test_files)
    loader = data.DataLoader(ds,
                             batch_size=1,
                             shuffle=False)
    return loader
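
# Example usage (disabled): a minimal sketch building a loader for one sample volume
'''
loader = create_dl([{'image': list(input_files.values())[0]}])
batch = next(iter(loader))
print(batch['image'].shape)  # batched tensor, e.g. (1, 1, H, W, D)
'''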

# Inference and video generation
def generate_dicom_video(selected_file):
    # Resolve the dropdown key to a file path and preprocess the volume
    test_file = input_files[selected_file]
    test_files = [{'image': test_file}]
    dl = create_dl(test_files)
    batch = next(iter(dl))
    tst_inputs = batch["image"]
    # Keep only the last 32 axial slices to bound inference time
    tst_inputs = tst_inputs[:, :, :, :, -32:]
    # Sliding-window inference; the positional args mirror MONAI's
    # sliding_window_inference (roi_size, sw_batch_size)
    with torch.no_grad():
        outputs = model(tst_inputs,
                        (96, 96, 96),
                        8,
                        overlap=0.5,
                        mode="gaussian")
    # Per-voxel class probabilities -> discrete label map
    tst_outputs = torch.softmax(outputs.logits, 1)
    tst_outputs = torch.argmax(tst_outputs, axis=1)
    # Write the input slices and the segmentation overlay side by side to a video
    for inp, outp in zip(tst_inputs, tst_outputs):
        frames = []
        for idx in range(inp.shape[-1]):
            # Predicted label map for this slice
            seg = outp[:, :, idx].numpy().astype(np.uint8)
            # Input frame, rescaled from [0, 1] to 8-bit grayscale
            img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            # Color the labels over the image, then concatenate input | overlay
            frame = color.label2rgb(seg, img, bg_label=0)
            frame = img_as_ubyte(frame)
            frame = np.concatenate((img, frame), 1)
            frames.append(frame)
        mediapy.write_video("dicom.mp4", frames, fps=4)
    return 'dicom.mp4'

# Local smoke test (disabled): generate_dicom_video expects a dropdown key
# (a filename), not a full path
'''
test_file = sorted(input_files)[0]
generate_dicom_video(test_file)
'''

# Gradio UI: pick a sample scan, run inference, show the resulting video
demo = gr.Blocks()
with demo:
    selected_dicom_key = gr.Dropdown(
        choices=sorted(input_files),
        type="value",
        label="Select a dicom file")
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(fn=generate_dicom_video,
                           inputs=selected_dicom_key,
                           outputs=output_interpolation)
# queue() replaces the deprecated enable_queue launch flag
demo.queue()
demo.launch(debug=True)