File size: 3,575 Bytes
68facde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import os
import glob
import shutil
import torch
import argparse
import mediapy
import cv2
import numpy as np
import gradio as gr 
from skimage import color, img_as_ubyte
from monai import transforms, data

# Fetch the SwinUNETR reference implementation at start-up and make its BTCV
# package importable. The clone is skipped when the checkout is already
# present, so a restart does not re-run `git clone` against an existing,
# non-empty directory (which git refuses).
if not os.path.isdir("pmrc/SwinUNETR/BTCV"):
    os.system("git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc")
sys.path.append("pmrc/SwinUNETR/BTCV")
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig


# Point mediapy at the ffmpeg binary found on PATH. shutil.which returns None
# when no ffmpeg is found; in that case mediapy's existing setting is left
# untouched rather than being overwritten with None.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is not None:
    mediapy.set_ffmpeg(ffmpeg_path)

# Load the pretrained model and switch it to evaluation mode.
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()

# Map each sample volume's file name to its path. The keys are what the UI
# dropdown lists; os.path.basename keeps them free of directory components
# regardless of the platform's path separator.
input_files = {
    os.path.basename(path): path
    for path in glob.glob('pmrc/SwinUNETR/BTCV/dataset/imagesSampleTs/*.nii.gz')
}

# Preprocessing pipeline applied to each {"image": <path>} record before
# inference. Steps run in list order.
_preprocessing_steps = [
    # Read the file referenced by the "image" key.
    transforms.LoadImaged(keys=["image"]),
    # Add a channel dimension to the loaded image.
    transforms.AddChanneld(keys=["image"]),
    # Resample to a pixdim of (1.5, 1.5, 2.0) using bilinear interpolation.
    transforms.Spacingd(
        keys="image",
        pixdim=(1.5, 1.5, 2.0),
        mode="bilinear",
    ),
    # Map intensities from [-175, 250] onto [0, 1]; clip=True clamps values
    # that fall outside the target range.
    transforms.ScaleIntensityRanged(
        keys=["image"],
        a_min=-175.0,
        a_max=250.0,
        b_min=0.0,
        b_max=1.0,
        clip=True,
    ),
    # Disabled step, kept for reference:
    # transforms.Resized(keys=["image"], spatial_size = (256,256,-1)),
    # Convert the image to a tensor.
    transforms.ToTensord(keys=["image"]),
]
test_transform = transforms.Compose(_preprocessing_steps)



# Data loader construction
def create_dl(test_files):
    """Run the preprocessing pipeline on *test_files* and wrap the result.

    Args:
        test_files: Records to preprocess; the caller in this file passes a
            list of ``{'image': <path>}`` dicts.

    Returns:
        A ``data.DataLoader`` over the transformed records that yields them
        one per batch, in their original order.
    """
    transformed = test_transform(test_files)
    return data.DataLoader(transformed, batch_size=1, shuffle=False)

# Inference and video generation
def _overlay_frames(volume, labels):
    """Build one side-by-side RGB frame per slice along the last axis.

    Args:
        volume: Image tensor for one batch item, read as
            ``volume[0, :, :, idx]`` and scaled by 255 before the uint8 cast.
        labels: Label tensor for the same item, read as ``labels[:, :, idx]``.

    Returns:
        A list of uint8 arrays. Each one is the grayscale slice (converted to
        RGB) concatenated along axis 1 with the same slice overlaid with the
        label colours; label 0 is treated as background.
    """
    frames = []
    for idx in range(volume.shape[-1]):
        # Segmentation
        seg = labels[:, :, idx].numpy().astype(np.uint8)
        # Input image slice
        img = (volume[0, :, :, idx] * 255).numpy().astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        overlay = img_as_ubyte(color.label2rgb(seg, img, bg_label=0))
        frames.append(np.concatenate((img, overlay), 1))
    return frames


def generate_dicom_video(selected_file):
    """Segment the selected sample volume and render it as a video.

    Args:
        selected_file: Key of ``input_files`` identifying the volume.

    Returns:
        The path of the written video, ``'dicom.mp4'``.

    Raises:
        KeyError: If ``selected_file`` is not a key of ``input_files``
            (for example when nothing has been selected).
    """
    try:
        test_file = input_files[selected_file]
    except KeyError:
        raise KeyError(f"Unknown input file: {selected_file!r}") from None

    batch = next(iter(create_dl([{'image': test_file}])))

    # Keep only the last 32 positions of the final axis.
    tst_inputs = batch["image"][:, :, :, :, -32:]

    with torch.no_grad():
        # NOTE(review): (96, 96, 96) and 8 are presumably the sliding-window
        # ROI size and batch size expected by the model wrapper - confirm.
        outputs = model(tst_inputs,
                        (96, 96, 96),
                        8,
                        overlap=0.5,
                        mode="gaussian")

    # Softmax is monotonic, so the per-voxel argmax over axis 1 can be taken
    # directly on the logits without materialising a probability volume.
    tst_outputs = torch.argmax(outputs.logits, dim=1)

    # Write frames to video (each batch item overwrites the same file; the
    # loader yields a single item per batch).
    for inp, outp in zip(tst_inputs, tst_outputs):
        mediapy.write_video("dicom.mp4", _overlay_frames(inp, outp), fps=4)

    return 'dicom.mp4'


'''
Manual smoke test (disabled). generate_dicom_video expects a key of
input_files (the file's base name), not a full path:

test_key = sorted(input_files)[0]
generate_dicom_video(test_key)
'''

# Gradio UI: choose one of the keys of input_files, then render the
# segmentation video on demand.
with gr.Blocks() as demo:
    selected_dicom_key = gr.inputs.Dropdown(
        label="Select a dicom file",
        type="value",
        choices=sorted(input_files),
    )
    button_gen_video = gr.Button("Generate Video")
    output_interpolation = gr.Video(label="Generated Video")
    # The dropdown value is passed to generate_dicom_video and the returned
    # file path is shown in the video component.
    button_gen_video.click(
        fn=generate_dicom_video,
        inputs=selected_dicom_key,
        outputs=output_interpolation,
    )

demo.launch(enable_queue=True, debug=True)