|
import numpy as np |
|
from typing import List, Tuple |
|
from monai.transforms import Compose, AddChannelD, MaskIntensityD, DeleteItemsD, CropForegroundD, ResizeD |
|
|
|
class SelectMaskByLevelD:
    """
    Binarize a label mask so that only the segment of one requested level remains.

    The array stored under ``mask_key`` is replaced by an array of the same shape and
    dtype that is 1 where the original label equals ``data[level_idx_key]`` and 0
    everywhere else. May also be applied to a single channel.

    Args:
        mask_key: data-dict key of the label mask; its value is overwritten in the output.
        level_idx_key: data-dict key holding the label value to keep.
    """

    def __init__(self, mask_key: str, level_idx_key: str):
        self.mask_key, self.level_idx_key = mask_key, level_idx_key

    def __call__(self, data):
        """Return a shallow copy of ``data`` whose mask entry is the binary selection of the requested level."""
        result = dict(data)  # shallow copy so the caller's dict is not modified
        labels = result[self.mask_key]
        # zeros_like keeps the shape and dtype of the original label array
        selected = np.zeros_like(labels)
        selected[labels == result[self.level_idx_key]] = 1
        result[self.mask_key] = selected
        return result
|
|
|
def get_mask_transform(hparams, loaded_keys: List[str], level_idx_key='level_idx') -> Tuple[Compose, List[str]]:
    """
    Build the masking / cropping transform selected by ``hparams.mask``.

    Depending on the configuration value ``hparams.mask``, the returned transform does one of the following:

    - ``'none'``: nothing; ``loaded_keys`` is returned as-is
    - ``'apply'``: applies the mask of the critical vertebra (mask voxels equal to
      ``data[level_idx_key]``) to the image, then deletes the mask entry
    - ``'apply_all'``: applies the mask of all visible vertebrae to the image, then deletes the mask entry
    - ``'channel'``: keeps the binarized critical-vertebra mask under the mask key s.t. it will later be
      stacked with the image
    - ``'crop'``: crops the image to the critical vertebra (margin 2), adds a channel dimension and
      resizes it to ``hparams.input_size`` along each of the ``hparams.input_dim`` spatial axes

    Args:
        hparams: configuration object. Reads ``mask``; the ``'crop'`` mode also reads
            ``input_size`` and ``input_dim``.
        loaded_keys: data-dict keys loaded so far. For every mode except ``'none'`` this must be
            exactly ``[image_key, mask_key]``.
        level_idx_key: data-dict key holding the label value of the critical vertebra.

    Returns:
        ``(transform, keys)``, where ``keys`` are the data-dict keys still present after the transform.

    Raises:
        ValueError: if ``hparams.mask`` is not one of the supported modes, or if a mode that needs the
            mask is used and ``loaded_keys`` does not contain exactly two keys.
    """
    supported_modes = ('none', 'apply', 'apply_all', 'channel', 'crop')
    mode = hparams.mask
    # Fail loudly on a typo in the config instead of implicitly returning None.
    if mode not in supported_modes:
        raise ValueError(f"Unsupported mask mode {mode!r}; expected one of {supported_modes}")

    if mode == 'none':
        return Compose([]), loaded_keys

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if len(loaded_keys) != 2:
        raise ValueError(
            f"Mask mode {mode!r} expects exactly two loaded keys (image, mask), got {loaded_keys!r}"
        )
    image_key, mask_key = loaded_keys

    if mode == 'apply':
        return Compose([
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    if mode == 'apply_all':
        return Compose([
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    if mode == 'channel':
        return Compose([
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
        ]), loaded_keys

    # mode == 'crop'
    # torch interpolation only accepts 'trilinear' for 3 spatial dims; pick the linear mode that
    # matches input_dim (3D and any other value keep the original 'trilinear').
    interp_mode = {1: 'linear', 2: 'bilinear'}.get(hparams.input_dim, 'trilinear')
    # NOTE(review): MONAI's CropForegroundD treats the first axis as the channel axis, but the channel
    # dimension is only added afterwards by AddChannelD — confirm the image/mask layout here.
    return Compose([
        SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
        CropForegroundD(keys=image_key, source_key=mask_key, margin=2),
        DeleteItemsD(keys=mask_key),
        AddChannelD(keys=image_key),
        ResizeD(keys=image_key, spatial_size=[hparams.input_size] * hparams.input_dim, mode=interp_mode),
    ]), [image_key]