# smb-vision-base / tumor_mask.py
# chenz53's picture
# Upload tumor_mask.py
# 3f039bf verified
import numpy as np
from scipy import ndimage
def extract_tumor_and_peritumoral(mask_volume, peritumoral_margin=2, patch_size=(16, 16, 16)):
    """
    Locate tumor voxels and flag the patches touched by the tumor or its margin.

    The annotation mask is dilated with a cubic structuring element of side
    ``2 * peritumoral_margin + 1`` (so diagonal neighbours are included), the
    volume is tiled into non-overlapping patches of ``patch_size``, and every
    patch holding at least one voxel of the dilated mask is marked with 1.

    Parameters:
        mask_volume: 3D array-like (z, y, x) with tumor annotations
            (1 for tumor, 0 for background).
        peritumoral_margin: Non-negative integer margin (in voxels) added
            around the tumor before the patches are flagged. 0 disables the
            dilation.
        patch_size: Tuple (z, y, x) of positive integers giving the patch size.

    Returns:
        tumor_coords: List of (z, y, x) index tuples of the voxels equal to 1.
        patch_mask: 1D float64 array with one entry per patch, flattened in
            C order (z-major, then y, then x); 1 if the patch overlaps the
            dilated mask, 0 otherwise.

    Raises:
        ValueError: If ``mask_volume`` is not 3D, ``patch_size`` does not hold
            exactly three positive values, or ``peritumoral_margin`` is negative.

    Notes:
        * Only whole patches are considered: trailing voxels that do not fill
          a complete patch along an axis are ignored.
        * ``tumor_coords`` selects voxels equal to 1, whereas the dilation
          treats every non-zero voxel as foreground; the two only agree for a
          strictly binary mask.
    """
    mask_volume = np.asarray(mask_volume)
    if mask_volume.ndim != 3:
        raise ValueError(f"mask_volume must be 3D (z, y, x), got {mask_volume.ndim}D")

    pz, py, px = patch_size
    if min(pz, py, px) <= 0:
        raise ValueError(f"patch_size entries must be positive, got {tuple(patch_size)}")
    if peritumoral_margin < 0:
        raise ValueError(f"peritumoral_margin must be non-negative, got {peritumoral_margin}")

    # Get tumor coordinates as a list of (z, y, x) tuples
    tumor_coords = list(zip(*np.nonzero(mask_volume == 1)))

    # Create dilated mask covering the tumor plus the peritumoral margin
    width = 2 * peritumoral_margin + 1
    dilated_mask = ndimage.binary_dilation(
        mask_volume,
        structure=np.ones((width, width, width)),
    )

    # Number of whole patches that fit along each axis
    z_steps = mask_volume.shape[0] // pz
    y_steps = mask_volume.shape[1] // py
    x_steps = mask_volume.shape[2] // px

    # Drop the voxels beyond the last whole patch, then split every axis into
    # (patch index, offset inside patch) so a single reduction over the three
    # offset axes answers "does this patch contain any dilated voxel?".
    covered = dilated_mask[: z_steps * pz, : y_steps * py, : x_steps * px]
    blocks = covered.reshape(z_steps, pz, y_steps, py, x_steps, px)
    patch_mask = blocks.any(axis=(1, 3, 5)).astype(np.float64)

    return tumor_coords, patch_mask.flatten()
# Example usage
def main():
    """
    Run the extraction on a synthetic volume and sanity-check the result.

    Builds a 96x96x96 volume with a 20x20x20 cube of tumor voxels in the
    middle, calls ``extract_tumor_and_peritumoral`` on it, prints a short
    summary and asserts that the outputs have the expected size.

    Returns:
        tumor_coords: List of (z, y, x) tumor voxel coordinates.
        patch_mask: Flattened per-patch mask returned by the extraction.
    """
    # Create sample data for testing
    volume_shape = (96, 96, 96)
    mask_volume = np.zeros(volume_shape)

    # Create a synthetic tumor mask in the middle
    mask_volume[40:60, 40:60, 40:60] = 1

    # Test parameters
    patch_size = (16, 16, 16)
    peritumoral_margin = 5

    # Call function and get results
    tumor_coords, patch_mask = extract_tumor_and_peritumoral(
        mask_volume, peritumoral_margin=peritumoral_margin, patch_size=patch_size
    )

    # Print test results (counts are reported as plain integers, not as a
    # shape tuple or a float sum)
    print(f"Volume shape: {volume_shape}")
    print(f"Tumor volume: {len(tumor_coords)}")
    print(f"Number of total patches: {patch_mask.size}")
    print(f"Number of patches containing tumor/peritumoral region: {int(np.sum(patch_mask))}")

    # Validate results
    expected_patches = int(np.prod(np.array(volume_shape) // np.array(patch_size)))
    assert len(tumor_coords) > 0, "No tumor coordinates found"
    assert len(patch_mask) == expected_patches, "Incorrect patch mask size"

    return tumor_coords, patch_mask
# Script entry point: run the demo only when the file is executed directly,
# never as a side effect of importing the module.
if __name__ == "__main__":
    # Keep the results bound at module level so they stay available for
    # interactive inspection after the run.
    tumor_coords, patch_mask = main()