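# Gradio demo: run SwinUNETR (BTCV) segmentation on a sample CT volume and
# render each slice next to its predicted labels as an MP4 video.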
import sys
import os
import glob
import shutil
import torch
import argparse
import mediapy
import cv2
import numpy as np
import gradio as gr
from skimage import color, img_as_ubyte
from monai import transforms, data
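# Clone the Project-MONAI research-contributions fork and add the BTCV SwinUNETR
# code to the path so the inference wrapper can be imported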
os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc")
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig
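# Point mediapy at the ffmpeg binary on PATH so it can encode the output video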
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
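# Load the pretrained tiny SwinUNETR BTCV checkpoint from the Hugging Face Hub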
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()
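# Index the bundled sample test volumes (.nii.gz) by filename for the dropdown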
input_files = glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
input_files = {os.path.basename(f): f for f in input_files}
# Load and preprocess the volume with MONAI transforms
test_transform = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.AddChanneld(keys=["image"]),
        # Resample to a common voxel spacing
        transforms.Spacingd(keys="image",
                            pixdim=(1.5, 1.5, 2.0),
                            mode="bilinear"),
        # Clip and rescale CT intensities to [0, 1]
        transforms.ScaleIntensityRanged(keys=["image"],
                                        a_min=-175.0,
                                        a_max=250.0,
                                        b_min=0.0,
                                        b_max=1.0,
                                        clip=True),
        # transforms.Resized(keys=["image"], spatial_size=(256, 256, -1)),
        transforms.ToTensord(keys=["image"]),
    ])
# Create Data Loader
def create_dl(test_files):
    ds = test_transform(test_files)
    loader = data.DataLoader(ds,
                             batch_size=1,
                             shuffle=False)
    return loader
# Inference and video generation
def generate_dicom_video(selected_file):
    # Look up the full path for the selected filename and build the loader
    test_file = input_files[selected_file]
    test_files = [{'image': test_file}]
    dl = create_dl(test_files)
    batch = next(iter(dl))

    # Keep only the last 32 axial slices to bound inference time
    tst_inputs = batch["image"]
    tst_inputs = tst_inputs[:, :, :, :, -32:]

    # Sliding-window style inference over 96x96x96 patches
    with torch.no_grad():
        outputs = model(tst_inputs,
                        (96, 96, 96),
                        8,
                        overlap=0.5,
                        mode="gaussian")
    tst_outputs = torch.softmax(outputs.logits, 1)
    tst_outputs = torch.argmax(tst_outputs, dim=1)

    # Write frames to video
    for inp, outp in zip(tst_inputs, tst_outputs):
        frames = []
        for idx in range(inp.shape[-1]):
            # Segmentation mask for this slice
            seg = outp[:, :, idx].numpy().astype(np.uint8)
            # Input dicom frame, rescaled to 8-bit RGB
            img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            # Overlay the labels and place the result beside the raw slice
            frame = color.label2rgb(seg, img, bg_label=0)
            frame = img_as_ubyte(frame)
            frame = np.concatenate((img, frame), 1)
            frames.append(frame)
        mediapy.write_video("dicom.mp4", frames, fps=4)
    return 'dicom.mp4'
'''
# Quick local sanity check: generate_dicom_video expects a filename key
# from input_files, not a full path.
test_file = sorted(input_files)[0]
generate_dicom_video(test_file)
'''
# Gradio UI: pick a sample volume, run inference, and show the rendered video
demo = gr.Blocks()
with demo:
    selected_dicom_key = gr.Dropdown(
        choices=sorted(input_files),
        type="value",
        label="Select a dicom file")
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    button_gen_video.click(fn=generate_dicom_video,
                           inputs=selected_dicom_key,
                           outputs=output_interpolation)

demo.launch(debug=True, enable_queue=True)