# -*- coding: utf-8 -*-
# StarDist Experiment Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import inspect
import os
import sys

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from pathlib import Path
from typing import Callable, Tuple, Union

import torch
import torch.nn as nn
import yaml
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingLR,
    ExponentialLR,
    ReduceLROnPlateau,
    SequentialLR,
    _LRScheduler,
)
from torch.utils.data import Dataset
from torchinfo import summary

from base_ml.base_loss import retrieve_loss_fn
from base_ml.base_trainer import BaseTrainer
from cell_unireplknet.cell_segmentation.experiments.experiment_cellvit_pannuke_origin import (
    ExperimentCellVitPanNuke,
)
from cell_segmentation.trainer.trainer_stardist import CellViTStarDistTrainer
from models.segmentation.cell_segmentation.cellvit_stardist import (
    CellViTStarDist,
    CellViT256StarDist,
    CellViTSAMStarDist,
)
from models.segmentation.cell_segmentation.cellvit_stardist_shared import (
    CellViTStarDistShared,
    CellViT256StarDistShared,
    CellViTSAMStarDistShared,
)
from models.segmentation.cell_segmentation.cpp_net_stardist_rn50 import StarDistRN50


class ExperimentCellViTStarDist(ExperimentCellVitPanNuke):
    def load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
        """Load the configuration of the PanNuke cell segmentation dataset.

        The dataset folder must contain a dataset_config.yaml file with the following entries:
            * tissue_types: the present tissue types with their corresponding integer labels
            * nuclei_types: the present nuclei types with their corresponding integer labels

        Args:
            dataset_path (Union[Path, str]): Path to the dataset folder
        """
        dataset_config_path = Path(dataset_path) / "dataset_config.yaml"
        with open(dataset_config_path, "r") as dataset_config_file:
            yaml_config = yaml.safe_load(dataset_config_file)
            self.dataset_config = dict(yaml_config)
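
    # Illustrative dataset_config.yaml layout, assuming the usual PanNuke class
    # names; the names and integer labels below are examples only and must match
    # the dataset actually being used:
    #
    #   tissue_types:
    #     Breast: 0
    #     Colon: 1
    #   nuclei_types:
    #     Background: 0
    #     Neoplastic: 1
    #     Inflammatory: 2
    #     Connective: 3
    #     Dead: 4
    #     Epithelial: 5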

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches

        Branches: "dist_map", "stardist_map", "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure:
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching a loss function defined in the LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

            If a branch is not provided, the default settings (described below) are used.

            For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
            under the section "loss".

            Example:
                nuclei_type_map:
                    bce:
                        loss_fn: xentropy_loss
                        weight: 1
                    dice:
                        loss_fn: dice_loss
                        weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss, since in the end all losses of all branches are added together for the backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss, since in the end all losses of all branches are added together for the backward pass
                branch_name(str):
                ...

        Default loss dictionary:
            dist_map:
                bceweighted:
                    loss_fn: BCEWithLogitsLoss
                    weight: 1
            stardist_map:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            nuclei_type_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_types has no default loss and may be skipped
        """
        loss_fn_dict = {}
        if "dist_map" in loss_fn_settings.keys():
            loss_fn_dict["dist_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["dist_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["dist_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["dist_map"] = {
                "bceweighted": {
                    "loss_fn": retrieve_loss_fn("BCEWithLogitsLoss"),
                    "weight": 1,
                },
            }
        if "stardist_map" in loss_fn_settings.keys():
            loss_fn_dict["stardist_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["stardist_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["stardist_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["stardist_map"] = {
                "L1LossWeighted": {
                    "loss_fn": retrieve_loss_fn("L1LossWeighted"),
                    "weight": 1,
                },
            }
        if "nuclei_type_map" in loss_fn_settings.keys():
            loss_fn_dict["nuclei_type_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["nuclei_type_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["nuclei_type_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["nuclei_type_map"] = {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        if "tissue_types" in loss_fn_settings.keys():
            loss_fn_dict["tissue_types"] = {}
            for loss_name, loss_sett in loss_fn_settings["tissue_types"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["tissue_types"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        # skip default tissue loss!
        return loss_fn_dict
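
    # A minimal loss_fn_settings example that mirrors the defaults above; weights
    # and the optional "args" block can be adapted in the YAML "loss" section:
    #
    #   loss_fn_settings = {
    #       "dist_map": {"bceweighted": {"loss_fn": "BCEWithLogitsLoss", "weight": 1}},
    #       "stardist_map": {"L1LossWeighted": {"loss_fn": "L1LossWeighted", "weight": 1}},
    #       "nuclei_type_map": {
    #           "bce": {"loss_fn": "xentropy_loss", "weight": 1},
    #           "dice": {"loss_fn": "dice_loss", "weight": 1},
    #       },
    #   }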

    def get_scheduler(self, scheduler_type: str, optimizer: Optimizer) -> _LRScheduler:
        """Get the learning rate scheduler for CellViT

        The configuration of the scheduler is given in the "training" -> "scheduler" section.
        Currently, "constant", "exponential", "cosine" and "reducelronplateau" schedulers are implemented.

        Required parameters for the implemented schedulers:
            - "constant": None
            - "exponential": gamma (optional, defaults to 0.95)
            - "cosine": eta_min (optional, defaults to 1e-5)
            - "reducelronplateau": everything is hardcoded right now; uses the validation loss for monitoring

        Args:
            scheduler_type (str): Type of scheduler as a string. Currently implemented:
                - "constant" (lowering by a factor of ten after 25 epochs, increasing again after 50, decreasing again after 75)
                - "exponential" (ExponentialLR with given gamma, gamma defaults to 0.95)
                - "cosine" (CosineAnnealingLR, eta_min as parameter, defaults to 1e-5)
                - "reducelronplateau" (ReduceLROnPlateau on the validation loss)
            optimizer (Optimizer): Optimizer

        Returns:
            _LRScheduler: PyTorch Scheduler
        """
        implemented_schedulers = [
            "constant",
            "exponential",
            "cosine",
            "reducelronplateau",
        ]
        if scheduler_type.lower() not in implemented_schedulers:
            self.logger.warning(
                f"Unknown scheduler - no scheduler from the list {implemented_schedulers} selected. Using default scheduling."
            )
        if scheduler_type.lower() == "constant":
            scheduler = SequentialLR(
                optimizer=optimizer,
                schedulers=[
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=25),
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=1000),
                ],
                milestones=[24, 49, 74],
            )
        elif scheduler_type.lower() == "exponential":
            scheduler = ExponentialLR(
                optimizer,
                gamma=self.run_conf["training"]["scheduler"].get("gamma", 0.95),
            )
        elif scheduler_type.lower() == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=self.run_conf["training"]["epochs"],
                eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
            )
        elif scheduler_type.lower() == "reducelronplateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.5,
                min_lr=0.0000001,
                patience=10,
                threshold=1e-20,
            )
        else:
            scheduler = super().get_scheduler(optimizer)
        return scheduler
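
    # Illustrative "training" section of the run configuration as read by this
    # method; the key carrying the scheduler name ("scheduler_type") is an
    # assumption here, only "epochs", "gamma" and "eta_min" are accessed above:
    #
    #   training:
    #     epochs: 100
    #     scheduler:
    #       scheduler_type: exponential
    #       gamma: 0.95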

    def get_datasets(
        self,
        train_transforms: Callable = None,
        val_transforms: Callable = None,
    ) -> Tuple[Dataset, Dataset]:
        """Retrieve training dataset and validation dataset

        Args:
            train_transforms (Callable, optional): PyTorch transformations for the train set. Defaults to None.
            val_transforms (Callable, optional): PyTorch transformations for the validation set. Defaults to None.

        Returns:
            Tuple[Dataset, Dataset]: Training dataset and validation dataset
        """
        # make the datasets provide StarDist targets (dist_map and stardist_map)
        self.run_conf["data"]["stardist"] = True
        train_dataset, val_dataset = super().get_datasets(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
        )
        return train_dataset, val_dataset

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str] = None,
        pretrained_model: Union[Path, str] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        **kwargs,
    ) -> nn.Module:
        """Return the CellViTStarDist training model

        Args:
            pretrained_encoder (Union[Path, str]): Path to a pretrained encoder. Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a pretrained model. Defaults to None.
            backbone_type (str, optional): Backbone type. Currently supported: "default", "ViT256", "SAM-B", "SAM-L", "SAM-H", "RN50". Defaults to "default".
            shared_decoders (bool, optional): If shared skip connection decoders should be used. Defaults to False.

        Returns:
            nn.Module: StarDist training model with the given setup
        """
        # reseed needed, due to subprocess seeding compatibility
        self.seed_run(self.default_conf["random_seed"])

        # check for backbones
        implemented_backbones = ["default", "vit256", "sam-b", "sam-l", "sam-h", "rn50"]
        if backbone_type.lower() not in implemented_backbones:
            raise NotImplementedError(
                f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
            )
        if backbone_type.lower() == "default":
            if shared_decoders:
                model_class = CellViTStarDistShared
            else:
                model_class = CellViTStarDist
            model = model_class(
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                embed_dim=self.run_conf["model"]["embed_dim"],
                input_channels=self.run_conf["model"].get("input_channels", 3),
                depth=self.run_conf["model"]["depth"],
                num_heads=self.run_conf["model"]["num_heads"],
                extract_layers=self.run_conf["model"]["extract_layers"],
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model)
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
                self.logger.info("Loaded CellViT model")
        if backbone_type.lower() == "vit256":
            if shared_decoders:
                model_class = CellViT256StarDistShared
            else:
                model_class = CellViT256StarDist
            model = model_class(
                model256_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
            model.freeze_encoder()
            self.logger.info("Loaded CellViT256 model")
        if backbone_type.lower() in ["sam-b", "sam-l", "sam-h"]:
            if shared_decoders:
                model_class = CellViTSAMStarDistShared
            else:
                model_class = CellViTSAMStarDist
            model = model_class(
                model_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                vit_structure=backbone_type,
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
            model.freeze_encoder()
            self.logger.info(f"Loaded CellViT-SAM model with backbone: {backbone_type}")
        if backbone_type.lower() == "rn50":
            model = StarDistRN50(
                n_rays=self.run_conf["model"].get("nrays", 32),
                n_seg_cls=self.run_conf["data"]["num_nuclei_classes"],
            )

        self.logger.info(f"\nModel: {model}")
        model = model.to("cpu")
        self.logger.info(
            f"\n{summary(model, input_size=(1, 3, 256, 256), device='cpu')}"
        )

        return model
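
    # Illustrative run_conf fragment covering the keys accessed above (values are
    # placeholders for a ViT-like default backbone, not recommendations):
    #
    #   run_conf = {
    #       "data": {"num_nuclei_classes": 6, "num_tissue_classes": 19},
    #       "model": {"embed_dim": 768, "depth": 12, "num_heads": 12,
    #                 "extract_layers": [3, 6, 9, 12], "nrays": 32},
    #       "training": {"drop_rate": 0, "attn_drop_rate": 0, "drop_path_rate": 0},
    #   }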

    def get_trainer(self) -> BaseTrainer:
        """Return the trainer matching this network

        Returns:
            BaseTrainer: Trainer
        """
        return CellViTStarDistTrainer
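

# Minimal usage sketch (kept as a comment): construction and the run loop are
# inherited from ExperimentCellVitPanNuke, so the exact signatures below are
# assumptions and should be checked against the base experiment class:
#
#   experiment = ExperimentCellViTStarDist(default_conf=config)
#   experiment.run_experiment()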