| | import cv2 |
| | import glob |
| | import numpy as np |
| | import os |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from einops import rearrange |
| | from transformers import PreTrainedModel |
| | from timm import create_model |
| |
|
| | from .configuration import TotalClassifierConfig |
| | from .label2index import label2index |
| |
|
# Optional dependency: pydicom is only needed by the DICOM loading helpers.
try:
    from pydicom import dcmread
except ModuleNotFoundError:
    _PYDICOM_AVAILABLE = False
else:
    _PYDICOM_AVAILABLE = True

# Optional dependency: pandas is only needed when results are requested
# as DataFrames.
try:
    import pandas as pd
except ModuleNotFoundError:
    _PANDAS_AVAILABLE = False
else:
    _PANDAS_AVAILABLE = True
| |
|
| |
|
| | class RNNHead(nn.Module): |
| | def __init__( |
| | self, |
| | rnn_type: str, |
| | rnn_num_layers: int, |
| | rnn_dropout: float, |
| | feature_dim: int, |
| | linear_dropout: float, |
| | num_classes: int, |
| | ): |
| | super().__init__() |
| | self.rnn = getattr(nn, rnn_type)( |
| | input_size=feature_dim, |
| | hidden_size=feature_dim // 2, |
| | num_layers=rnn_num_layers, |
| | dropout=rnn_dropout, |
| | batch_first=True, |
| | bidirectional=True, |
| | ) |
| | self.dropout = nn.Dropout(linear_dropout) |
| | self.linear = nn.Linear(feature_dim, num_classes) |
| |
|
| | @staticmethod |
| | def convert_seq_and_mask_to_packed_sequence( |
| | seq: torch.Tensor, mask: torch.Tensor |
| | ) -> tuple[torch.Tensor, torch.Tensor]: |
| | assert seq.shape[0] == mask.shape[0] |
| | lengths = mask.sum(1) |
| | seq = nn.utils.rnn.pack_padded_sequence( |
| | seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False |
| | ) |
| | return seq |
| |
|
| | def forward( |
| | self, x: torch.Tensor, mask: torch.Tensor | None = None |
| | ) -> torch.Tensor: |
| | skip = x |
| | if mask is not None: |
| | |
| | L = x.shape[1] |
| | x = self.convert_seq_and_mask_to_packed_sequence(x, mask) |
| |
|
| | x, _ = self.rnn(x) |
| |
|
| | if mask is not None: |
| | |
| | x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=L)[0] |
| |
|
| | x = x + skip |
| | return self.linear(self.dropout(x)) |
| |
|
| |
|
class TotalClassifierModel(PreTrainedModel):
    """Per-slice classifier: a timm CNN backbone followed by an `RNNHead` over the slice axis.

    Besides the network itself, the class bundles helpers for reading DICOM
    files, intensity windowing, resizing / tensor conversion, and cropping a
    volume to the range of slices where requested labels are predicted.
    """

    config_class = TotalClassifierConfig

    def __init__(self, config):
        """Build the backbone, dropout and recurrent head from `config`."""
        super().__init__(config)
        self.image_size = config.image_size
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=True,
            in_chans=config.in_chans,
        )
        self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
        self.head = RNNHead(
            rnn_type=config.rnn_type,
            rnn_num_layers=config.rnn_num_layers,
            rnn_dropout=config.rnn_dropout,
            feature_dim=config.feature_dim,
            linear_dropout=config.linear_dropout,
            num_classes=config.num_classes,
        )
        self.label2index = label2index
        # Inverse mapping, used to turn class indices back into label names.
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        return_logits: bool = False,
        return_as_dict: bool = False,
        return_as_list: bool = False,
        return_as_df: bool = False,
        threshold: float = 0.5,
    ) -> torch.Tensor | list:
        """Classify every slice of every stack in the batch.

        Args:
            x: Tensor of shape ``(batch, slices, channels, height, width)``
                with intensities in ``[0, 255]`` (see `normalize`).
            mask: Optional mask forwarded to the recurrent head.
            return_logits: Return raw logits; takes precedence over the other
                ``return_*`` flags.
            return_as_dict: Return one ``{label: per-slice probabilities}``
                dict per batch element.
            return_as_list: Return, per batch element, one list per slice
                holding the labels whose probability is ``>= threshold``.
            return_as_df: Like ``return_as_dict`` but each dict is converted
                to a ``pandas.DataFrame``; requires pandas.
            threshold: Probability cut-off used by ``return_as_list``.

        Returns:
            Logits or sigmoid probabilities of shape
            ``(batch, slices, num_classes)``, or one of the list forms above.
        """
        if return_as_df:
            assert (
                _PANDAS_AVAILABLE
            ), "`return_as_df=True` requires pandas to be installed"

        b, n, c, h, w = x.shape
        # Fold the slice axis into the batch so the 2D backbone sees images.
        x = self.normalize(x.reshape(b * n, c, h, w))
        features = self.backbone(x)
        # Use the last feature map only, globally average-pooled.
        features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        features = self.cnn_dropout(features)
        # Restore the (batch, slices, features) layout for the sequence head.
        features = features.reshape(b, n, -1)
        logits = self.head(features, mask=mask)
        if return_logits:
            return logits
        probas = logits.sigmoid()

        if return_as_dict or return_as_df:
            batch_list = []
            for probas_i in probas:
                dict_for_batch = {
                    self.index2label[each_class]: probas_i[:, each_class]
                    for each_class in range(probas_i.shape[1])
                }
                if return_as_df:
                    # detach() so the conversion also works with autograd enabled.
                    batch_list.append(
                        pd.DataFrame(
                            {
                                k: v.detach().cpu().numpy()
                                for k, v in dict_for_batch.items()
                            }
                        )
                    )
                else:
                    batch_list.append(dict_for_batch)
            return batch_list

        if return_as_list:
            # One host transfer for the whole batch; exactly one label list
            # is produced per slice.
            hits = (probas >= threshold).cpu().tolist()
            return [
                [
                    [self.index2label[idx] for idx, hit in enumerate(slice_hits) if hit]
                    for slice_hits in sample_hits
                ]
                for sample_hits in hits
            ]

        return probas

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map intensities from ``[0, 255]`` to ``[-1, 1]``."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        """Clip `x` to a window and rescale it to ``uint8`` in ``[0, 255]``.

        Args:
            x: Array of intensities.
            WL: Window level (centre of the window).
            WW: Window width; the bounds are ``WL -/+ WW // 2``.

        Returns:
            ``uint8`` array of the same shape as `x`.
        """
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        """Assert that `windows` is an ``(int, int)`` tuple or a list of such tuples.

        Raises:
            AssertionError: If the structure or element types do not match.
        """
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            # all() is required: asserting the bare list would always pass.
            assert all(isinstance(value, int) for value in windows)
        else:
            assert all(isinstance(win, tuple) for win in windows)
            assert all(len(win) == 2 for win in windows)
            assert all(isinstance(value, int) for win in windows for value in win)

    def _apply_windows(
        self, array: np.ndarray, windows: tuple[int, int] | list[tuple[int, int]]
    ) -> np.ndarray:
        """Validate `windows`, apply each to `array`, and stack results on a new last axis.

        A single window yields an array without the extra trailing axis.
        """
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        # `window` does not modify its input, so no defensive copy is needed.
        windowed = np.stack(
            [self.window(array, WL, WW) for WL, WW in windows], axis=-1
        )
        if windowed.shape[-1] == 1:
            windowed = np.squeeze(windowed, axis=-1)
        return windowed

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        """Return the index (0, 1 or 2) of the dominant component of the slice normal.

        The normal is the cross product of the two direction vectors stored in
        ``ds.ImageOrientationPatient``. Index 2 is also returned when neither
        of the first two components is strictly the largest.
        """
        iop = ds.ImageOrientationPatient
        normal_vector = np.cross(iop[:3], iop[3:])
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0
        if abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1
        return 2

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        """Read one DICOM file and return its rescaled (optionally windowed) pixels.

        Args:
            path: Path of the DICOM file.
            windows: ``(WL, WW)`` tuple or list of such tuples. When ``None``
                the rescaled ``float32`` array is returned unwindowed.

        Returns:
            The rescaled array, or the ``uint8`` windowed array with one
            trailing channel per window (no extra axis for a single window).

        Raises:
            Exception: If pydicom is not installed.
        """
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        return self._apply_windows(array, windows)

    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ) -> bool:
        """Check that `ds` exposes every attribute needed to build a stack.

        Args:
            ds: Dataset to check.
            fname: File name, used only in the error message.
            sort_by_instance_number: Require ``InstanceNumber`` instead of
                ``ImagePositionPatient`` / ``ImageOrientationPatient``.
            exclude_invalid_dicoms: When ``True`` an invalid dataset returns
                ``False`` instead of raising.

        Raises:
            Exception: If attributes are missing and
                ``exclude_invalid_dicoms`` is ``False``.
        """
        attributes = ["pixel_array", "RescaleSlope", "RescaleIntercept"]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.extend(["ImagePositionPatient", "ImageOrientationPatient"])
        missing = [attr for attr in attributes if not hasattr(ds, attr)]
        valid = not missing
        if not valid and not exclude_invalid_dicoms:
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: {missing}"
            )
        return valid

    @staticmethod
    def most_common_element(lst):
        """Return the most frequent element of `lst`.

        Ties are resolved in favour of the element encountered first.

        Raises:
            ValueError: If `lst` is empty.
        """
        # Local import keeps this block self-contained.
        from collections import Counter

        if len(lst) == 0:
            raise ValueError("most_common_element() arg is an empty sequence")
        return Counter(lst).most_common(1)[0][0]

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        """Center-crop or zero-pad the first two axes of `image` to `size`.

        Args:
            image: Array with at least two dimensions; any further axes
                (e.g. channels) are left untouched.
            size: Target ``(height, width)``; only its first two entries are used.

        Returns:
            Array whose first two dimensions equal the target size.
        """
        height, width = image.shape[:2]
        new_height, new_width = (int(value) for value in size[:2])
        # Zero padding for any axes beyond height and width.
        trailing = [(0, 0)] * (image.ndim - 2)

        if new_height < height:
            crop_top = (height - new_height) // 2
            image = image[crop_top : crop_top + new_height]
        elif new_height > height:
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                [(pad_top, pad_bottom), (0, 0)] + trailing,
                mode="constant",
                constant_values=0,
            )

        if new_width < width:
            crop_left = (width - new_width) // 2
            image = image[:, crop_left : crop_left + new_width]
        elif new_width > width:
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                [(0, 0), (pad_left, pad_right)] + trailing,
                mode="constant",
                constant_values=0,
            )

        return image

    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        """Load every DICOM in a folder into one sorted, rescaled stack.

        Args:
            path: Folder containing the DICOM files.
            windows: Optional ``(WL, WW)`` tuple or list of tuples applied
                after rescaling.
            dicom_extension: File suffix used to find the files.
            sort_by_instance_number: Sort by ``InstanceNumber``; otherwise by
                the ``ImagePositionPatient`` component along the stack's
                most common slice normal.
            exclude_invalid_dicoms: Drop files missing required attributes
                instead of raising.
            fix_unequal_shapes: ``"crop_pad"`` or ``"resize"``; how slices
                that differ from the most common shape are reconciled.
            return_sorted_dicom_files: Also return the file paths in stack order.

        Returns:
            The stack (slices first), optionally with the sorted file list.

        Raises:
            Exception: If pydicom is missing, no file matches, or a file is
                invalid and ``exclude_invalid_dicoms`` is ``False``.
            ValueError: If no valid DICOM remains, or shapes differ and
                `fix_unequal_shapes` is not a supported mode.
        """
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
        if len(dicom_files) == 0:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        # Keep each dataset paired with its path so both are filtered together.
        valid = [
            (ds, fname)
            for ds, fname in zip(dicoms, dicom_files)
            if self.is_valid_dicom(
                ds, fname, sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        if not valid:
            raise ValueError(f"No valid DICOM files found in `{path}`")
        dicoms = [ds for ds, _ in valid]
        dicom_files = [fname for _, fname in valid]

        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            if fix_unequal_shapes not in ("crop_pad", "resize"):
                raise ValueError(
                    "`fix_unequal_shapes` must be 'crop_pad' or 'resize', "
                    f"got {fix_unequal_shapes!r}"
                )
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            # Plain ints keep the message readable and are accepted by cv2.
            standard_shape = tuple(int(v) for v in unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            else:
                # cv2.resize takes (width, height); array shapes are (height, width).
                dsize = (standard_shape[1], standard_shape[0])
                slices = [
                    cv2.resize(s, dsize) if s.shape != standard_shape else s
                    for s in slices
                ]
        slices = np.stack(slices, axis=0)

        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            # Orientation is only needed (and only validated) in this mode.
            orientation = self.most_common_element(
                [self.determine_dicom_orientation(dcm) for dcm in dicoms]
            )
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]

        # One slope/intercept pair (the most common) is applied to the whole stack.
        m = self.most_common_element([float(d.RescaleSlope) for d in dicoms])
        b = self.most_common_element([float(d.RescaleIntercept) for d in dicoms])
        slices = slices * m + b
        if windows is not None:
            slices = self._apply_windows(slices, windows)

        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices

    def preprocess(
        self,
        x: np.ndarray,
        mode: str = "2d",
        torchify: bool = True,
        add_batch_dim: bool = False,
        device: str | torch.device | None = None,
    ) -> np.ndarray | torch.Tensor:
        """Resize to ``self.image_size`` and optionally convert to a channel-first tensor.

        Args:
            x: A single image (``mode="2d"``) or a stack of images along the
                first axis (``mode="3d"``).
            mode: ``"2d"`` or ``"3d"`` (case-insensitive). Any other value
                leaves `x` unresized.
            torchify: Convert to a float tensor laid out ``c h w`` /
                ``n c h w``.
            add_batch_dim: Prepend a batch axis.
            device: Target device; requires ``torchify=True``.

        Returns:
            The processed array, or a tensor when ``torchify`` is ``True``.
        """
        if device is not None:
            assert torchify, "`torchify` must be `True` if specifying `device`"
        mode = mode.lower()
        # NOTE(review): cv2.resize interprets the size as (width, height);
        # confirm that config.image_size follows that order.
        if mode == "2d":
            x = cv2.resize(x, self.image_size)
            if x.ndim == 2:
                # Restore the channel axis for single-channel images.
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        if torchify:
            if x.ndim == 3:
                x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
            elif x.ndim == 4:
                x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
        if add_batch_dim:
            x = x.unsqueeze(0) if torchify else x[np.newaxis]
        if device is not None:
            x = x.to(device)
        return x

    def crop_single_plane(
        self,
        x: np.ndarray,
        device: str | torch.device,
        organ: str | list[str],
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> tuple[int, int]:
        """Find the inclusive slice range along axis 0 where any requested label is predicted.

        Args:
            x: Stack of images along the first axis.
            device: Device the input tensor is moved to before inference.
            organ: Label name or list of label names (keys of ``label2index``).
            threshold: Minimum probability for a slice to count.
            buffer: Margin added on both sides; an ``int`` is a number of
                slices, a ``float`` a fraction of the detected range.
            speed_up: ``None``, ``"fast"``, ``"faster"`` or ``"fastest"``;
                classifies 3/4, 1/2 or 1/3 of the slices and linearly
                interpolates the probabilities back to full length.

        Returns:
            ``(slice_min, slice_max)``, both inclusive.

        Raises:
            RuntimeError: If no slice reaches `threshold` for any label.
        """
        if isinstance(organ, str):
            # A bare string would otherwise be iterated character by character.
            organ = [organ]
        num_slices = x.shape[0]
        if speed_up is not None:
            assert speed_up in ["fast", "faster", "fastest"]
            reduce_num_slices = {
                "fast": 3 * num_slices // 4,
                "faster": num_slices // 2,
                "fastest": num_slices // 3,
            }[speed_up]
            # Very short stacks would otherwise be reduced to zero slices.
            reduce_num_slices = max(1, reduce_num_slices)
            indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
            x = x[indices]
        # preprocess already returns a float (1, n, c, h, w) tensor on `device`.
        x = self.preprocess(
            x, mode="3d", torchify=True, add_batch_dim=True, device=device
        )
        if x.size(2) > 1:
            # Collapse multi-channel input to one channel by averaging.
            x = x.mean(2, keepdim=True)
        organ_cls = self.forward(x)[0]
        if speed_up is not None:
            # Interpolate over the slice axis back to the original length.
            organ_cls = (
                F.interpolate(
                    organ_cls.transpose(1, 0).unsqueeze(0),
                    size=(num_slices,),
                    mode="linear",
                )
                .squeeze(0)
                .transpose(1, 0)
            )
        assert organ_cls.shape[0] == num_slices
        slices = torch.cat(
            [
                torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
                for each_organ in organ
            ]
        )
        if slices.numel() == 0:
            raise RuntimeError(
                f"no slices with probability >= {threshold} found for {organ}"
            )
        slice_min, slice_max = slices.min().item(), slices.max().item()
        if buffer > 0:
            if isinstance(buffer, float):
                buf = int(buffer * (slice_max - slice_min))
            else:
                buf = buffer
            slice_min = max(0, slice_min - buf)
            slice_max = min(num_slices - 1, slice_max + buf)
        return slice_min, slice_max

    @torch.no_grad()
    def crop(
        self,
        x: np.ndarray,
        organ: str | list[str],
        crop_dims: int | list[int] = 0,
        device: str | torch.device | None = None,
        raw_hu: bool = False,
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> np.ndarray:
        """Crop a volume to the region where the requested labels are predicted.

        Args:
            x: 3D or 4D array.
            organ: Label name or list of label names.
            crop_dims: Axis or axes (each in 0..2) to crop along; other axes
                are left at full extent.
            device: Inference device; defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
            raw_hu: When ``True``, `x` is first windowed with ``WL=50``,
                ``WW=400``.
            threshold: Probability threshold, see `crop_single_plane`.
            buffer: Margin, see `crop_single_plane`; a ``float`` must be ``< 1``.
            speed_up: Sub-sampling mode, see `crop_single_plane`.

        Returns:
            The cropped (and, with ``raw_hu``, windowed) array.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert x.ndim in {
            3,
            4,
        }, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"

        if raw_hu:
            x = self.window(x, WL=50, WW=400)

        x0 = x
        if not isinstance(organ, list):
            organ = [organ]
        if not isinstance(crop_dims, list):
            crop_dims = [crop_dims]

        assert max(crop_dims) <= 2
        assert min(crop_dims) >= 0

        if isinstance(buffer, float):
            assert buffer < 1

        # Inclusive (min, max) bounds per axis, defaulting to the full extent.
        bounds = [(0, x0.shape[axis] - 1) for axis in range(3)]
        for axis in range(3):
            if axis in crop_dims:
                # Move the axis being cropped to the front so it is the slice axis.
                plane = x0 if axis == 0 else x0.swapaxes(axis, 0)
                bounds[axis] = self.crop_single_plane(
                    plane, device, organ, threshold, buffer, speed_up
                )

        (smin0, smax0), (smin1, smax1), (smin2, smax2) = bounds
        return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]
| |
|