# -*- coding: utf-8 -*-
# CellVit Experiment Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import argparse
import copy
import datetime
import inspect
import math
import os
import shutil
import sys

import numpy as np
import yaml

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import uuid
from pathlib import Path
from typing import Callable, Tuple, Union

import albumentations as A
import torch
import torch.nn as nn
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingLR,
    ExponentialLR,
    SequentialLR,
    _LRScheduler,
    CosineAnnealingWarmRestarts,
)
from torch.utils.data import (
    DataLoader,
    Dataset,
    RandomSampler,
    Sampler,
    Subset,
    WeightedRandomSampler,
)
from torchinfo import summary
from wandb.sdk.lib.runid import generate_id

from base_ml.base_early_stopping import EarlyStopping
from base_ml.base_experiment import BaseExperiment
from base_ml.base_loss import retrieve_loss_fn
from base_ml.base_trainer import BaseTrainer
from cell_segmentation.datasets.base_cell import CellDataset
from cell_segmentation.datasets.dataset_coordinator import select_dataset
from cell_segmentation.trainer.trainer_cellvit import CellViTTrainer
from models.segmentation.cell_segmentation.cellvit import CellViT
from utils.tools import close_logger


class WarmupCosineAnnealingLR(CosineAnnealingLR):
    """CosineAnnealingLR with a linear warmup over the first warmup_epochs epochs."""

    def __init__(self, optimizer, T_max, eta_min=0, warmup_epochs=0, warmup_factor=0):
        # The warmup attributes must be set before calling the parent constructor,
        # because the parent __init__ performs an initial step that already calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor
        self.initial_lr = [group["lr"] for group in optimizer.param_groups]  # initial learning rates
        super().__init__(optimizer, T_max=T_max, eta_min=eta_min)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # linear warmup from warmup_factor * base_lr to base_lr
            warmup_factor = self.warmup_factor + (1.0 - self.warmup_factor) * (
                self.last_epoch / self.warmup_epochs
            )
            return [base_lr * warmup_factor for base_lr in self.initial_lr]
        else:
            return [base_lr * self.get_lr_ratio() for base_lr in self.initial_lr]

    def get_lr_ratio(self):
        # cosine decay from 1 to 0 over the remaining (T_max - warmup_epochs) epochs
        T_cur = min(self.last_epoch - self.warmup_epochs, self.T_max - self.warmup_epochs)
        return 0.5 * (1 + math.cos(math.pi * T_cur / (self.T_max - self.warmup_epochs)))
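

# Minimal usage sketch for WarmupCosineAnnealingLR (illustrative only, not called by the
# experiment itself). It uses a throwaway parameter and a plain SGD optimizer just to show
# the linear warmup over the first epochs followed by the cosine decay.
def _demo_warmup_cosine_annealing() -> None:
    dummy_param = nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy_param], lr=1e-3)
    scheduler = WarmupCosineAnnealingLR(
        optimizer, T_max=100, eta_min=1e-6, warmup_epochs=5, warmup_factor=0.1
    )
    for epoch in range(10):
        optimizer.step()
        scheduler.step()
        print(f"epoch {epoch}: lr={scheduler.get_last_lr()[0]:.6f}")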


class ExperimentCellVitPanNuke(BaseExperiment):
    def __init__(self, default_conf: dict, checkpoint=None) -> None:
        super().__init__(default_conf, checkpoint)
        self.load_dataset_setup(dataset_path=self.default_conf["data"]["dataset_path"])

    def run_experiment(self) -> str:
        """Main experiment code.

        Returns:
            str: Path to the log directory of this run
        """
        ### Setup
        # close loggers
        self.close_remaining_logger()

        # get the config for the current run
        self.run_conf = copy.deepcopy(self.default_conf)
        self.run_conf["dataset_config"] = self.dataset_config
        self.run_name = f"{datetime.datetime.now().strftime('%Y-%m-%dT%H%M%S')}_{self.run_conf['logging']['log_comment']}"

        wandb_run_id = generate_id()
        resume = None
        if self.checkpoint is not None:
            wandb_run_id = self.checkpoint["wandb_id"]
            resume = "must"
            self.run_name = self.checkpoint["run_name"]

        # initialize wandb
        run = wandb.init(
            project=self.run_conf["logging"]["project"],
            tags=self.run_conf["logging"].get("tags", []),
            name=self.run_name,
            notes=self.run_conf["logging"]["notes"],
            dir=self.run_conf["logging"]["wandb_dir"],
            mode=self.run_conf["logging"]["mode"].lower(),
            group=self.run_conf["logging"].get("group", str(uuid.uuid4())),
            allow_val_change=True,
            id=wandb_run_id,
            resume=resume,
            settings=wandb.Settings(start_method="fork"),
        )

        # get ids
        self.run_conf["logging"]["run_id"] = run.id
        self.run_conf["logging"]["wandb_file"] = run.id

        # overwrite configuration with sweep values or leave it as it is
        if self.run_conf["run_sweep"] is True:
            self.run_conf["logging"]["sweep_id"] = run.sweep_id
            self.run_conf["logging"]["log_dir"] = str(
                Path(self.default_conf["logging"]["log_dir"])
                / f"sweep_{run.sweep_id}"
                / f"{self.run_name}_{self.run_conf['logging']['run_id']}"
            )
            self.overwrite_sweep_values(self.run_conf, run.config)
        else:
            self.run_conf["logging"]["log_dir"] = str(
                Path(self.default_conf["logging"]["log_dir"]) / self.run_name
            )

        # update wandb
        wandb.config.update(self.run_conf, allow_val_change=True)

        # create output folder, instantiate logger and store config
        self.create_output_dir(self.run_conf["logging"]["log_dir"])
        self.logger = self.instantiate_logger()
        self.logger.info("Instantiated Logger. WandB init and config update finished.")
        self.logger.info(f"Run is stored here: {self.run_conf['logging']['log_dir']}")
        self.store_config()
        self.logger.info(
            f"Cuda devices: {[torch.cuda.device(i) for i in range(torch.cuda.device_count())]}"
        )

        ### Machine Learning
        device = f"cuda:{self.run_conf['gpu']}"
        self.logger.info(f"Using device: {device}")

        # loss functions
        loss_fn_dict = self.get_loss_fn(self.run_conf.get("loss", {}))
        self.logger.info("Loss functions:")
        self.logger.info(loss_fn_dict)

        # model
        model = self.get_train_model(
            pretrained_encoder=self.run_conf["model"].get("pretrained_encoder", None),
            pretrained_model=self.run_conf["model"].get("pretrained", None),
            backbone_type=self.run_conf["model"].get("backbone", "default"),
            shared_decoders=self.run_conf["model"].get("shared_decoders", False),
            regression_loss=self.run_conf["model"].get("regression_loss", False),
        )
        model.to(device)

        # optimizer
        optimizer = self.get_optimizer(
            model,
            self.run_conf["training"]["optimizer"].lower(),
            self.run_conf["training"]["optimizer_hyperparameter"],
            self.run_conf["training"]["layer_decay"],
        )

        # scheduler
        scheduler = self.get_scheduler(
            optimizer=optimizer,
            scheduler_type=self.run_conf["training"]["scheduler"]["scheduler_type"],
        )

        # early stopping (no early stopping for basic setup)
        early_stopping = None
        if "early_stopping_patience" in self.run_conf["training"]:
            if self.run_conf["training"]["early_stopping_patience"] is not None:
                early_stopping = EarlyStopping(
                    patience=self.run_conf["training"]["early_stopping_patience"],
                    strategy="maximize",
                )

        ### Data handling
        train_transforms, val_transforms = self.get_transforms(
            self.run_conf["transformations"],
            input_shape=self.run_conf["data"].get("input_shape", 256),
        )
        train_dataset, val_dataset = self.get_datasets(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
        )

        # load sampler
        training_sampler = self.get_sampler(
            train_dataset=train_dataset,
            strategy=self.run_conf["training"].get("sampling_strategy", "random"),
            gamma=self.run_conf["training"].get("sampling_gamma", 1),
        )

        # define dataloaders
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.run_conf["training"]["batch_size"],
            sampler=training_sampler,
            num_workers=16,
            pin_memory=False,
            worker_init_fn=self.seed_worker,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=64,
            num_workers=8,
            pin_memory=True,
            worker_init_fn=self.seed_worker,
        )

        # start Training
        self.logger.info("Instantiate Trainer")
        trainer_fn = self.get_trainer()
        trainer = trainer_fn(
            model=model,
            loss_fn_dict=loss_fn_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            logger=self.logger,
            logdir=self.run_conf["logging"]["log_dir"],
            num_classes=self.run_conf["data"]["num_nuclei_classes"],
            dataset_config=self.dataset_config,
            early_stopping=early_stopping,
            experiment_config=self.run_conf,
            log_images=self.run_conf["logging"].get("log_images", False),
            magnification=self.run_conf["data"].get("magnification", 40),
            mixed_precision=self.run_conf["training"].get("mixed_precision", False),
        )

        # Load checkpoint if provided
        if self.checkpoint is not None:
            self.logger.info("Checkpoint was provided. Restore ...")
            trainer.resume_checkpoint(self.checkpoint)

        # Call fit method
        self.logger.info("Calling Trainer Fit")
        trainer.fit(
            epochs=self.run_conf["training"]["epochs"],
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            metric_init=self.get_wandb_init_dict(),
            unfreeze_epoch=self.run_conf["training"]["unfreeze_epoch"],
            eval_every=self.run_conf["training"].get("eval_every", 1),
        )

        # Select best model if not provided by early stopping
        checkpoint_dir = Path(self.run_conf["logging"]["log_dir"]) / "checkpoints"
        if not (checkpoint_dir / "model_best.pth").is_file():
            shutil.copy(
                checkpoint_dir / "latest_checkpoint.pth",
                checkpoint_dir / "model_best.pth",
            )

        # At the end close logger
        self.logger.info(f"Finished run {run.id}")
        close_logger(self.logger)

        return self.run_conf["logging"]["log_dir"]

    def load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
        """Load the configuration of the cell segmentation dataset.

        The dataset must provide a dataset_config.yaml file in its dataset path with the following entries:
            * tissue_types: present tissue types with their corresponding integer labels
            * nuclei_types: present nuclei types with their corresponding integer labels

        Args:
            dataset_path (Union[Path, str]): Path to dataset folder
        """
        dataset_config_path = Path(dataset_path) / "dataset_config.yaml"
        with open(dataset_config_path, "r") as dataset_config_file:
            yaml_config = yaml.safe_load(dataset_config_file)
            self.dataset_config = dict(yaml_config)
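
    # Illustrative sketch of a matching dataset_config.yaml (the entry names follow the
    # docstring above; the concrete type/label pairs are only assumed examples):
    #
    #   tissue_types:
    #     Breast: 0
    #     Colon: 1
    #   nuclei_types:
    #     Background: 0
    #     Neoplastic: 1
    #     Inflammatory: 2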

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches

        Branches: "nuclei_binary_map", "hv_map", "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure:
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching one of the loss functions defined in the LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

            If a branch is not provided, the default settings (described below) are used.

            For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
            under the section "loss".

            Example:
                nuclei_binary_map:
                    bce:
                        loss_fn: xentropy_loss
                        weight: 1
                    dice:
                        loss_fn: dice_loss
                        weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss; the losses of all branches are summed for the backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss; the losses of all branches are summed for the backward pass
                branch_name(str):
                    ...

        Default loss dictionary:
            nuclei_binary_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            hv_map:
                mse:
                    loss_fn: mse_loss_maps
                    weight: 1
                msge:
                    loss_fn: msge_loss_maps
                    weight: 1
            nuclei_type_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_types:
                ce:
                    loss_fn: nn.CrossEntropyLoss()
                    weight: 1
        """
        loss_fn_dict = {}
        if "nuclei_binary_map" in loss_fn_settings.keys():
            loss_fn_dict["nuclei_binary_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["nuclei_binary_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["nuclei_binary_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["nuclei_binary_map"] = {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        if "hv_map" in loss_fn_settings.keys():
            loss_fn_dict["hv_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["hv_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["hv_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["hv_map"] = {
                "mse": {"loss_fn": retrieve_loss_fn("mse_loss_maps"), "weight": 1},
                "msge": {"loss_fn": retrieve_loss_fn("msge_loss_maps"), "weight": 1},
            }
        if "nuclei_type_map" in loss_fn_settings.keys():
            loss_fn_dict["nuclei_type_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["nuclei_type_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["nuclei_type_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["nuclei_type_map"] = {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        if "tissue_types" in loss_fn_settings.keys():
            loss_fn_dict["tissue_types"] = {}
            for loss_name, loss_sett in loss_fn_settings["tissue_types"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["tissue_types"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["tissue_types"] = {
                "ce": {"loss_fn": nn.CrossEntropyLoss(), "weight": 1},
            }
        if "regression_loss" in loss_fn_settings.keys():
            loss_fn_dict["regression_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["regression_loss"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["regression_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        elif "regression_loss" in self.run_conf["model"].keys():
            loss_fn_dict["regression_map"] = {
                "mse": {"loss_fn": retrieve_loss_fn("mse_loss_maps"), "weight": 1},
            }
        return loss_fn_dict
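
    # Illustrative sketch of a "loss" config section that get_loss_fn would consume
    # (weights below are assumed example values, not recommended settings; branches that
    # are omitted fall back to the defaults documented above):
    #
    #   loss:
    #     nuclei_binary_map:
    #       bce:
    #         loss_fn: xentropy_loss
    #         weight: 1
    #       dice:
    #         loss_fn: dice_loss
    #         weight: 1
    #     hv_map:
    #       mse:
    #         loss_fn: mse_loss_maps
    #         weight: 2.5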

    def get_scheduler(self, scheduler_type: str, optimizer: Optimizer) -> _LRScheduler:
        """Get the learning rate scheduler for CellViT

        The configuration of the scheduler is given in the "training" -> "scheduler" section.
        Currently, "constant", "exponential" and "cosine" schedulers are implemented.

        Required parameters for the implemented schedulers:
            - "constant": None
            - "exponential": gamma (optional, defaults to 0.95)
            - "cosine": eta_min (optional, defaults to 1e-5)

        Args:
            scheduler_type (str): Type of scheduler as a string. Currently implemented:
                - "constant" (lowering by a factor of ten after 25 epochs, increasing again after 50, decreasing again after 75)
                - "exponential" (ExponentialLR with given gamma, gamma defaults to 0.95)
                - "cosine" (CosineAnnealingLR, eta_min as parameter, defaults to 1e-5)
            optimizer (Optimizer): Optimizer

        Returns:
            _LRScheduler: PyTorch Scheduler
        """
        implemented_schedulers = ["constant", "exponential", "cosine", "default"]
        if scheduler_type.lower() not in implemented_schedulers:
            self.logger.warning(
                f"Unknown Scheduler - No scheduler from the list {implemented_schedulers} selected. Using default scheduling."
            )
        if scheduler_type.lower() == "constant":
            scheduler = SequentialLR(
                optimizer=optimizer,
                schedulers=[
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=25),
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=1000),
                ],
                milestones=[24, 49, 74],
            )
        elif scheduler_type.lower() == "exponential":
            scheduler = ExponentialLR(
                optimizer,
                gamma=self.run_conf["training"]["scheduler"].get("gamma", 0.95),
            )
        elif scheduler_type.lower() == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=self.run_conf["training"]["epochs"],
                eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
            )
        # elif scheduler_type.lower() == "cosinewarmrestarts":
        #     scheduler = CosineAnnealingWarmRestarts(
        #         optimizer,
        #         T_0=self.run_conf["training"]["scheduler"]["T_0"],
        #         T_mult=self.run_conf["training"]["scheduler"]["T_mult"],
        #         eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
        #     )
        else:
            # "default" and any unknown scheduler type fall back to the base experiment
            # scheduler, as announced in the warning above
            scheduler = super().get_scheduler(optimizer)
        return scheduler
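
    # Illustrative sketch of a matching "training" -> "scheduler" config section
    # (the values are assumed examples):
    #
    #   training:
    #     epochs: 130
    #     scheduler:
    #       scheduler_type: exponential
    #       gamma: 0.85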

    def get_datasets(
        self,
        train_transforms: Callable = None,
        val_transforms: Callable = None,
    ) -> Tuple[Dataset, Dataset]:
        """Retrieve training dataset and validation dataset

        Args:
            train_transforms (Callable, optional): PyTorch transformations for train set. Defaults to None.
            val_transforms (Callable, optional): PyTorch transformations for validation set. Defaults to None.

        Returns:
            Tuple[Dataset, Dataset]: Training dataset and validation dataset
        """
        if (
            "val_split" in self.run_conf["data"]
            and "val_folds" in self.run_conf["data"]
        ):
            raise RuntimeError(
                "Provide either val_split or val_folds in the configuration file, not both."
            )
        if (
            "val_split" not in self.run_conf["data"]
            and "val_folds" not in self.run_conf["data"]
        ):
            raise RuntimeError(
                "Provide either val_split or val_folds in the configuration file, one is necessary."
            )
        if "regression_loss" in self.run_conf["model"].keys():
            self.run_conf["data"]["regression_loss"] = True

        full_dataset = select_dataset(
            dataset_name="pannuke",
            split="train",
            dataset_config=self.run_conf["data"],
            transforms=train_transforms,
        )
        if "val_split" in self.run_conf["data"]:
            generator_split = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            val_splits = float(self.run_conf["data"]["val_split"])
            train_dataset, val_dataset = torch.utils.data.random_split(
                full_dataset,
                lengths=[1 - val_splits, val_splits],
                generator=generator_split,
            )
            val_dataset.dataset = copy.deepcopy(full_dataset)
            val_dataset.dataset.set_transforms(val_transforms)
        else:
            train_dataset = full_dataset
            val_dataset = select_dataset(
                dataset_name="pannuke",
                split="validation",
                dataset_config=self.run_conf["data"],
                transforms=val_transforms,
            )

        return train_dataset, val_dataset
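
    # Illustrative sketch of the two mutually exclusive "data" options handled above
    # (paths and values are assumed examples):
    #
    #   data:
    #     dataset_path: ./datasets/pannuke
    #     val_split: 0.1        # random 10% split of the training data
    #
    #   # or, alternatively, use a dedicated validation split of the dataset:
    #   data:
    #     dataset_path: ./datasets/pannuke
    #     val_folds: [2]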

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str] = None,
        pretrained_model: Union[Path, str] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        regression_loss: bool = False,
        **kwargs,
    ) -> CellViT:
        """Return the CellViT training model

        Args:
            pretrained_encoder (Union[Path, str]): Path to a pretrained encoder. Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a pretrained model. Defaults to None.
            backbone_type (str, optional): Backbone type. Currently supported: default, UniRepLKNet, ViT256, SAM-B, SAM-L, SAM-H. Defaults to "default".
            shared_decoders (bool, optional): If shared skip decoders should be used. Defaults to False.
            regression_loss (bool, optional): If regression loss is used. Defaults to False.

        Returns:
            CellViT: CellViT training model with given setup
        """
        # reseed needed, due to subprocess seeding compatibility
        self.seed_run(self.default_conf["random_seed"])

        # check for backbones (compared in lowercase to match the .lower() call below)
        implemented_backbones = ["default", "unireplknet", "vit256", "sam-b", "sam-l", "sam-h"]
        if backbone_type.lower() not in implemented_backbones:
            raise NotImplementedError(
                f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
            )
        if backbone_type.lower() == "default":
            model_class = CellViT
            model = model_class(
                model256_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                in_channels=self.run_conf["model"].get("input_chanels", 3),
                dropout=self.run_conf["training"].get("drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0.1),
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model)
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
                self.logger.info("Loaded CellViT model")

        self.logger.info(f"\nModel: {model}")
        model = model.to("cuda")
        self.logger.info(
            f"\n{summary(model, input_size=(1, 3, 256, 256), device='cuda')}"
        )

        # count trainable and non-trainable parameters
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in model.parameters():
            multvalue = np.prod(param.size())
            total_params += multvalue
            if param.requires_grad:
                trainable_params += multvalue  # trainable parameter count
            else:
                non_trainable_params += multvalue  # non-trainable parameter count
        self.logger.info(f"Total params: {total_params}")
        self.logger.info(f"Trainable params: {trainable_params}")
        self.logger.info(f"Non-trainable params: {non_trainable_params}")

        return model

    def get_wandb_init_dict(self) -> dict:
        pass

    def get_transforms(
        self, transform_settings: dict, input_shape: int = 256
    ) -> Tuple[Callable, Callable]:
        """Get transformations (Albumentations transformations). Return both training and validation transformations.

        The transformation settings are given in the following format:
            key: dict with parameters
        Example:
            colorjitter:
                p: 0.1
                scale_setting: 0.5
                scale_color: 0.1

        For further information on how to set up the dictionary and default (recommended) values, see:
        configs/examples/cell_segmentation/train_cellvit.yaml

        Training Transformations:
            Implemented are:
                - A.RandomRotate90: Key in transform_settings: randomrotate90, parameters: p
                - A.HorizontalFlip: Key in transform_settings: horizontalflip, parameters: p
                - A.VerticalFlip: Key in transform_settings: verticalflip, parameters: p
                - A.Downscale: Key in transform_settings: downscale, parameters: p, scale
                - A.Blur: Key in transform_settings: blur, parameters: p, blur_limit
                - A.GaussNoise: Key in transform_settings: gaussnoise, parameters: p, var_limit
                - A.ColorJitter: Key in transform_settings: colorjitter, parameters: p, scale_setting, scale_color
                - A.Superpixels: Key in transform_settings: superpixels, parameters: p
                - A.ZoomBlur: Key in transform_settings: zoomblur, parameters: p
                - A.RandomSizedCrop: Key in transform_settings: randomsizedcrop, parameters: p
                - A.ElasticTransform: Key in transform_settings: elastictransform, parameters: p
            Always applied at the end of the pipeline:
                - A.Normalize with given mean (default: (0.5, 0.5, 0.5)) and std (default: (0.5, 0.5, 0.5))

        Validation Transformations:
            A.Normalize with given mean (default: (0.5, 0.5, 0.5)) and std (default: (0.5, 0.5, 0.5))

        Args:
            transform_settings (dict): Dictionary with the transformation settings.
            input_shape (int, optional): Input shape of the images to be used. Defaults to 256.

        Returns:
            Tuple[Callable, Callable]: Train transformations, validation transformations
        """
        transform_list = []
        transform_settings = {k.lower(): v for k, v in transform_settings.items()}
        if "RandomRotate90".lower() in transform_settings:
            p = transform_settings["randomrotate90"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.RandomRotate90(p=p))
        if "HorizontalFlip".lower() in transform_settings.keys():
            p = transform_settings["horizontalflip"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.HorizontalFlip(p=p))
        if "VerticalFlip".lower() in transform_settings:
            p = transform_settings["verticalflip"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.VerticalFlip(p=p))
        if "Downscale".lower() in transform_settings:
            p = transform_settings["downscale"]["p"]
            scale = transform_settings["downscale"]["scale"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.Downscale(p=p, scale_max=scale, scale_min=scale)
                )
        if "Blur".lower() in transform_settings:
            p = transform_settings["blur"]["p"]
            blur_limit = transform_settings["blur"]["blur_limit"]
            if p > 0 and p <= 1:
                transform_list.append(A.Blur(p=p, blur_limit=blur_limit))
        if "GaussNoise".lower() in transform_settings:
            p = transform_settings["gaussnoise"]["p"]
            var_limit = transform_settings["gaussnoise"]["var_limit"]
            if p > 0 and p <= 1:
                transform_list.append(A.GaussNoise(p=p, var_limit=var_limit))
        if "ColorJitter".lower() in transform_settings:
            p = transform_settings["colorjitter"]["p"]
            scale_setting = transform_settings["colorjitter"]["scale_setting"]
            scale_color = transform_settings["colorjitter"]["scale_color"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.ColorJitter(
                        p=p,
                        brightness=scale_setting,
                        contrast=scale_setting,
                        saturation=scale_color,
                        hue=scale_color / 2,
                    )
                )
        if "Superpixels".lower() in transform_settings:
            p = transform_settings["superpixels"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.Superpixels(
                        p=p,
                        p_replace=0.1,
                        n_segments=200,
                        max_size=int(input_shape / 2),
                    )
                )
        if "ZoomBlur".lower() in transform_settings:
            p = transform_settings["zoomblur"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.ZoomBlur(p=p, max_factor=1.05))
        if "RandomSizedCrop".lower() in transform_settings:
            p = transform_settings["randomsizedcrop"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.RandomSizedCrop(
                        # integer division: the crop height range must be integer-valued
                        min_max_height=(input_shape // 2, input_shape),
                        height=input_shape,
                        width=input_shape,
                        p=p,
                    )
                )
        if "ElasticTransform".lower() in transform_settings:
            p = transform_settings["elastictransform"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.ElasticTransform(p=p, sigma=25, alpha=0.5, alpha_affine=15)
                )

        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        transform_list.append(A.Normalize(mean=mean, std=std))

        train_transforms = A.Compose(transform_list)
        val_transforms = A.Compose([A.Normalize(mean=mean, std=std)])

        return train_transforms, val_transforms
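
    # Illustrative sketch of a "transformations" config section that get_transforms would
    # consume (probabilities and parameter values below are assumed examples):
    #
    #   transformations:
    #     randomrotate90:
    #       p: 0.5
    #     horizontalflip:
    #       p: 0.5
    #     blur:
    #       p: 0.2
    #       blur_limit: 10
    #     normalize:
    #       mean: [0.5, 0.5, 0.5]
    #       std: [0.5, 0.5, 0.5]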

    def get_sampler(
        self, train_dataset: CellDataset, strategy: str = "random", gamma: float = 1
    ) -> Sampler:
        """Return the sampler (either RandomSampler or WeightedRandomSampler)

        Args:
            train_dataset (CellDataset): Dataset for training
            strategy (str, optional): Sampling strategy. Defaults to "random" (random sampling).
                Implemented are "random", "cell", "tissue", "cell+tissue".
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Raises:
            NotImplementedError: Not implemented sampler is selected

        Returns:
            Sampler: Sampler for training
        """
        if strategy.lower() == "random":
            sampling_generator = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            sampler = RandomSampler(train_dataset, generator=sampling_generator)
            self.logger.info("Using RandomSampler")
        else:
            # this solution is not accurate when a subset is used since the weights are calculated on the whole training dataset
            if isinstance(train_dataset, Subset):
                ds = train_dataset.dataset
            else:
                ds = train_dataset
            ds.load_cell_count()
            if strategy.lower() == "cell":
                weights = ds.get_sampling_weights_cell(gamma)
            elif strategy.lower() == "tissue":
                weights = ds.get_sampling_weights_tissue(gamma)
            elif strategy.lower() == "cell+tissue":
                weights = ds.get_sampling_weights_cell_tissue(gamma)
            else:
                raise NotImplementedError(
                    "Unknown sampling strategy - Implemented are cell, tissue and cell+tissue"
                )

            if isinstance(train_dataset, Subset):
                weights = torch.Tensor([weights[i] for i in train_dataset.indices])

            sampling_generator = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            sampler = WeightedRandomSampler(
                weights=weights,
                num_samples=len(train_dataset),
                replacement=True,
                generator=sampling_generator,
            )

            self.logger.info(f"Using Weighted Sampling with strategy: {strategy}")
            self.logger.info(f"Unique-Weights: {torch.unique(weights)}")

        return sampler
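
    # Illustrative sketch of the "training" entries that run_experiment reads before calling
    # get_sampler (assumed example values):
    #
    #   training:
    #     sampling_strategy: cell+tissue
    #     sampling_gamma: 0.85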

    def get_trainer(self) -> BaseTrainer:
        """Return Trainer matching to this network

        Returns:
            BaseTrainer: Trainer
        """
        return CellViTTrainer
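

if __name__ == "__main__":
    # Minimal launch sketch (illustrative only): the config path and its contents are
    # assumptions, and the original project ships its own CLI entry point for training.
    with open("configs/examples/cell_segmentation/train_cellvit.yaml", "r") as f:
        config = yaml.safe_load(f)
    experiment = ExperimentCellVitPanNuke(default_conf=config)
    log_dir = experiment.run_experiment()
    print(f"Run artifacts stored in: {log_dir}")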