# -*- coding: utf-8 -*-
# StarDist Experiment Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import inspect
import os
import sys

# Make the project root importable when this file is executed directly,
# so that local packages such as base_ml and models resolve.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import yaml

from base_ml.base_trainer import BaseTrainer
from pathlib import Path
from typing import Callable, Tuple, Union
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
ConstantLR,
CosineAnnealingLR,
ExponentialLR,
ReduceLROnPlateau,
SequentialLR,
_LRScheduler,
)
from torch.utils.data import Dataset
from torchinfo import summary
from base_ml.base_loss import retrieve_loss_fn
from cell_unireplknet.cell_segmentation.experiments.experiment_cellvit_pannuke_origin import (
ExperimentCellVitPanNuke,
)
from cell_segmentation.trainer.trainer_stardist import CellViTStarDistTrainer
from models.segmentation.cell_segmentation.cellvit_stardist import (
CellViTStarDist,
CellViT256StarDist,
CellViTSAMStarDist,
)
from models.segmentation.cell_segmentation.cellvit_stardist_shared import (
CellViTStarDistShared,
CellViT256StarDistShared,
CellViTSAMStarDistShared,
)
from models.segmentation.cell_segmentation.cpp_net_stardist_rn50 import StarDistRN50
class ExperimentCellViTStarDist(ExperimentCellVitPanNuke):
def load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
"""Load the configuration of the PanNuke cell segmentation dataset.
The dataset must have a dataset_config.yaml file in their dataset path with the following entries:
* tissue_types: describing the present tissue types with corresponding integer
* nuclei_types: describing the present nuclei types with corresponding integer
Args:
dataset_path (Union[Path, str]): Path to dataset folder
"""
dataset_config_path = Path(dataset_path) / "dataset_config.yaml"
with open(dataset_config_path, "r") as dataset_config_file:
yaml_config = yaml.safe_load(dataset_config_file)
self.dataset_config = dict(yaml_config)
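    # Illustrative example of a dataset_config.yaml as expected by load_dataset_setup
    # (entries shown here are placeholders, not the full PanNuke label set):
    #
    #   tissue_types:
    #     Breast: 0
    #     Colon: 1
    #   nuclei_types:
    #     Background: 0
    #     Neoplastic: 1
    #     Inflammatory: 2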
def get_loss_fn(self, loss_fn_settings: dict) -> dict:
"""Create a dictionary with loss functions for all branches
Branches: "dist_map", "stardist_map", "nuclei_type_map", "tissue_types"
Args:
loss_fn_settings (dict): Dictionary with the loss function settings. Structure
branch_name(str):
loss_name(str):
loss_fn(str): String matching to the loss functions defined in the LOSS_DICT (base_ml.base_loss)
weight(float): Weighting factor as float value
(optional) args: Optional parameters for initializing the loss function
arg_name: value
If a branch is not provided, the defaults settings (described below) are used.
For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
under the section "loss"
Example:
nuclei_type_map:
bce:
loss_fn: xentropy_loss
weight: 1
dice:
loss_fn: dice_loss
weight: 1
Returns:
dict: Dictionary with loss functions for each branch. Structure:
branch_name(str):
loss_name(str):
"loss_fn": Callable loss function
"weight": weight of the loss since in the end all losses of all branches are added together for backward pass
loss_name(str):
"loss_fn": Callable loss function
"weight": weight of the loss since in the end all losses of all branches are added together for backward pass
branch_name(str)
...
Default loss dictionary:
dist_map:
bceweighted:
loss_fn: BCEWithLogitsLoss
weight: 1
stardist_map:
L1LossWeighted:
loss_fn: L1LossWeighted
weight: 1
nuclei_type_map
bce:
loss_fn: xentropy_loss
weight: 1
dice:
loss_fn: dice_loss
weight: 1
tissue_type has no default loss and might be skipped
"""
loss_fn_dict = {}
if "dist_map" in loss_fn_settings.keys():
loss_fn_dict["dist_map"] = {}
for loss_name, loss_sett in loss_fn_settings["dist_map"].items():
parameters = loss_sett.get("args", {})
loss_fn_dict["dist_map"][loss_name] = {
"loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
"weight": loss_sett["weight"],
}
else:
loss_fn_dict["dist_map"] = {
"bceweighted": {
"loss_fn": retrieve_loss_fn("BCEWithLogitsLoss"),
"weight": 1,
},
}
if "stardist_map" in loss_fn_settings.keys():
loss_fn_dict["stardist_map"] = {}
for loss_name, loss_sett in loss_fn_settings["stardist_map"].items():
parameters = loss_sett.get("args", {})
loss_fn_dict["stardist_map"][loss_name] = {
"loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
"weight": loss_sett["weight"],
}
else:
loss_fn_dict["stardist_map"] = {
"L1LossWeighted": {
"loss_fn": retrieve_loss_fn("L1LossWeighted"),
"weight": 1,
},
}
if "nuclei_type_map" in loss_fn_settings.keys():
loss_fn_dict["nuclei_type_map"] = {}
for loss_name, loss_sett in loss_fn_settings["nuclei_type_map"].items():
parameters = loss_sett.get("args", {})
loss_fn_dict["nuclei_type_map"][loss_name] = {
"loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
"weight": loss_sett["weight"],
}
else:
loss_fn_dict["nuclei_type_map"] = {
"bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
"dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
}
if "tissue_types" in loss_fn_settings.keys():
loss_fn_dict["tissue_types"] = {}
for loss_name, loss_sett in loss_fn_settings["tissue_types"].items():
parameters = loss_sett.get("args", {})
loss_fn_dict["tissue_types"][loss_name] = {
"loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
"weight": loss_sett["weight"],
}
# skip default tissue loss!
return loss_fn_dict
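    # Illustrative "loss" section of a run config as consumed by get_loss_fn: here the
    # dist_map branch is overridden while the remaining branches fall back to their
    # defaults (branch and loss names are examples, see LOSS_DICT in base_ml.base_loss):
    #
    #   loss:
    #     dist_map:
    #       bceweighted:
    #         loss_fn: BCEWithLogitsLoss
    #         weight: 1
    #
    # get_loss_fn(run_conf["loss"]) then returns a nested dict of the form
    #   {"dist_map": {"bceweighted": {"loss_fn": <Callable>, "weight": 1}}, "stardist_map": {...}, ...}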
def get_scheduler(self, scheduler_type: str, optimizer: Optimizer) -> _LRScheduler:
"""Get the learning rate scheduler for CellViT
The configuration of the scheduler is given in the "training" -> "scheduler" section.
Currenlty, "constant", "exponential" and "cosine" schedulers are implemented.
Required parameters for implemented schedulers:
- "constant": None
- "exponential": gamma (optional, defaults to 0.95)
- "cosine": eta_min (optional, defaults to 1-e5)
- "reducelronplateau": everything hardcoded right now, uses vall los for checking
Args:
scheduler_type (str): Type of scheduler as a string. Currently implemented:
- "constant" (lowering by a factor of ten after 25 epochs, increasing after 50, decreasimg again after 75)
- "exponential" (ExponentialLR with given gamma, gamma defaults to 0.95)
- "cosine" (CosineAnnealingLR, eta_min as parameter, defaults to 1-e5)
optimizer (Optimizer): Optimizer
Returns:
_LRScheduler: PyTorch Scheduler
"""
implemented_schedulers = [
"constant",
"exponential",
"cosine",
"reducelronplateau",
]
if scheduler_type.lower() not in implemented_schedulers:
            self.logger.warning(
                f"Unknown scheduler - no scheduler from the list {implemented_schedulers} selected. Using default scheduling."
            )
if scheduler_type.lower() == "constant":
scheduler = SequentialLR(
optimizer=optimizer,
schedulers=[
ConstantLR(optimizer, factor=1, total_iters=25),
ConstantLR(optimizer, factor=0.1, total_iters=25),
ConstantLR(optimizer, factor=1, total_iters=25),
ConstantLR(optimizer, factor=0.1, total_iters=1000),
],
milestones=[24, 49, 74],
)
elif scheduler_type.lower() == "exponential":
scheduler = ExponentialLR(
optimizer,
gamma=self.run_conf["training"]["scheduler"].get("gamma", 0.95),
)
elif scheduler_type.lower() == "cosine":
scheduler = CosineAnnealingLR(
optimizer,
T_max=self.run_conf["training"]["epochs"],
eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
)
elif scheduler_type.lower() == "reducelronplateau":
scheduler = ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.5,
min_lr=0.0000001,
patience=10,
threshold=1e-20,
)
else:
scheduler = super().get_scheduler(optimizer)
return scheduler
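    # Illustrative "training" section of a run config for get_scheduler (the exact
    # key names, e.g. "scheduler_type", follow the example configs and are assumptions here):
    #
    #   training:
    #     epochs: 130
    #     scheduler:
    #       scheduler_type: exponential
    #       gamma: 0.95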
def get_datasets(
self,
train_transforms: Callable = None,
val_transforms: Callable = None,
) -> Tuple[Dataset, Dataset]:
"""Retrieve training dataset and validation dataset
Args:
dataset_name (str): Name of dataset to use
train_transforms (Callable, optional): PyTorch transformations for train set. Defaults to None.
val_transforms (Callable, optional): PyTorch transformations for validation set. Defaults to None.
Returns:
Tuple[Dataset, Dataset]: Training dataset and validation dataset
"""
self.run_conf["data"]["stardist"] = True
train_dataset, val_dataset = super().get_datasets(
train_transforms=train_transforms,
val_transforms=val_transforms,
)
return train_dataset, val_dataset
def get_train_model(
self,
pretrained_encoder: Union[Path, str] = None,
pretrained_model: Union[Path, str] = None,
backbone_type: str = "default",
shared_decoders: bool = False,
**kwargs,
) -> nn.Module:
"""Return the CellViTStarDist training model
Args:
pretrained_encoder (Union[Path, str]): Path to a pretrained encoder. Defaults to None.
pretrained_model (Union[Path, str], optional): Path to a pretrained model. Defaults to None.
            backbone_type (str, optional): Backbone type. Currently supported: "default", "ViT256", "SAM-B", "SAM-L", "SAM-H", "RN50". Defaults to "default".
shared_decoders (bool, optional): If shared skip decoders should be used. Defaults to False.
Returns:
nn.Module: StarDist training model with given setup
"""
# reseed needed, due to subprocess seeding compatibility
self.seed_run(self.default_conf["random_seed"])
# check for backbones
implemented_backbones = ["default", "vit256", "sam-b", "sam-l", "sam-h", "rn50"]
if backbone_type.lower() not in implemented_backbones:
raise NotImplementedError(
f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
)
if backbone_type.lower() == "default":
if shared_decoders:
model_class = CellViTStarDistShared
else:
model_class = CellViTStarDist
model = model_class(
num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
embed_dim=self.run_conf["model"]["embed_dim"],
input_channels=self.run_conf["model"].get("input_channels", 3),
depth=self.run_conf["model"]["depth"],
num_heads=self.run_conf["model"]["num_heads"],
extract_layers=self.run_conf["model"]["extract_layers"],
drop_rate=self.run_conf["training"].get("drop_rate", 0),
attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
nrays=self.run_conf["model"].get("nrays", 32),
)
if pretrained_model is not None:
self.logger.info(
f"Loading pretrained CellViT model from path: {pretrained_model}"
)
            cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
self.logger.info("Loaded CellViT model")
if backbone_type.lower() == "vit256":
if shared_decoders:
model_class = CellViT256StarDistShared
else:
model_class = CellViT256StarDist
model = model_class(
model256_path=pretrained_encoder,
num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
drop_rate=self.run_conf["training"].get("drop_rate", 0),
attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
nrays=self.run_conf["model"].get("nrays", 32),
)
model.load_pretrained_encoder(model.model256_path)
if pretrained_model is not None:
self.logger.info(
f"Loading pretrained CellViT model from path: {pretrained_model}"
)
cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
model.freeze_encoder()
self.logger.info("Loaded CellVit256 model")
if backbone_type.lower() in ["sam-b", "sam-l", "sam-h"]:
if shared_decoders:
model_class = CellViTSAMStarDistShared
else:
model_class = CellViTSAMStarDist
model = model_class(
model_path=pretrained_encoder,
num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
vit_structure=backbone_type,
drop_rate=self.run_conf["training"].get("drop_rate", 0),
nrays=self.run_conf["model"].get("nrays", 32),
)
model.load_pretrained_encoder(model.model_path)
if pretrained_model is not None:
self.logger.info(
f"Loading pretrained CellViT model from path: {pretrained_model}"
)
cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
model.freeze_encoder()
self.logger.info(f"Loaded CellViT-SAM model with backbone: {backbone_type}")
if backbone_type.lower() == "rn50":
model = StarDistRN50(
n_rays=self.run_conf["model"].get("nrays", 32),
n_seg_cls=self.run_conf["data"]["num_nuclei_classes"],
)
self.logger.info(f"\nModel: {model}")
model = model.to("cpu")
self.logger.info(
f"\n{summary(model, input_size=(1, 3, 256, 256), device='cpu')}"
)
return model
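    # Illustrative excerpt of the config entries read by get_train_model for the
    # ViT256 backbone (key names, paths and values are placeholders, not verified defaults):
    #
    #   model:
    #     backbone: ViT256
    #     pretrained_encoder: ./models/pretrained/vit256_encoder.pth
    #     nrays: 32
    #   data:
    #     num_nuclei_classes: 6
    #     num_tissue_classes: 19
    #   training:
    #     drop_rate: 0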
def get_trainer(self) -> BaseTrainer:
"""Return Trainer matching to this network
Returns:
BaseTrainer: Trainer
"""
return CellViTStarDistTrainer
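# Typical usage (sketch only, assuming the constructor signature inherited from
# ExperimentCellVitPanNuke / BaseExperiment; the argument name is illustrative):
#
#   experiment = ExperimentCellViTStarDist(default_conf=run_conf)
#   experiment.run_experiment()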