Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# Loss functions (PyTorch and own defined) | |
# | |
# Own defined loss functions: | |
# xentropy_loss, dice_loss, mse_loss and msge_loss (https://github.com/vqdang/hover_net) | |
# WeightedBaseLoss, MAEWeighted, MSEWeighted, BCEWeighted, CEWeighted (https://github.com/okunator/cellseg_models.pytorch) | |
# @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
# Institute for Artifical Intelligence in Medicine, | |
# University Medicine Essen | |
import torch | |
import torch.nn.functional as F | |
from typing import List, Tuple | |
from torch import nn | |
from torch.nn.modules.loss import _Loss | |
from base_ml.base_utils import filter2D, gaussian_kernel2d | |
class XentropyLoss(_Loss): | |
"""Cross entropy loss""" | |
def __init__(self, reduction: str = "mean") -> None: | |
super().__init__(size_average=None, reduce=None, reduction=reduction) | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
"""Assumes NCHW shape of array, must be torch.float32 dtype | |
Args: | |
input (torch.Tensor): Ground truth array with shape (N, C, H, W) with N being the batch-size, H the height, W the width and C the number of classes | |
target (torch.Tensor): Prediction array with shape (N, C, H, W) with N being the batch-size, H the height, W the width and C the number of classes | |
Returns: | |
torch.Tensor: Cross entropy loss, with shape () [scalar], grad_fn = MeanBackward0 | |
""" | |
# reshape | |
input = input.permute(0, 2, 3, 1) | |
target = target.permute(0, 2, 3, 1) | |
epsilon = 10e-8 | |
# scale preds so that the class probs of each sample sum to 1 | |
pred = input / torch.sum(input, -1, keepdim=True) | |
# manual computation of crossentropy | |
pred = torch.clamp(pred, epsilon, 1.0 - epsilon) | |
loss = -torch.sum((target * torch.log(pred)), -1, keepdim=True) | |
loss = loss.mean() if self.reduction == "mean" else loss.sum() | |
return loss | |
class DiceLoss(_Loss): | |
"""Dice loss | |
Args: | |
smooth (float, optional): Smoothing value. Defaults to 1e-3. | |
""" | |
def __init__(self, smooth: float = 1e-3) -> None: | |
super().__init__(size_average=None, reduce=None, reduction="mean") | |
self.smooth = smooth | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
"""Assumes NCHW shape of array, must be torch.float32 dtype | |
`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC. | |
Args: | |
input (torch.Tensor): Prediction array with shape (N, C, H, W) with N being the batch-size, H the height, W the width and C the number of classes | |
target (torch.Tensor): Ground truth array with shape (N, C, H, W) with N being the batch-size, H the height, W the width and C the number of classes | |
Returns: | |
torch.Tensor: Dice loss, with shape () [scalar], grad_fn=SumBackward0 | |
""" | |
input = input.permute(0, 2, 3, 1) | |
target = target.permute(0, 2, 3, 1) | |
inse = torch.sum(input * target, (0, 1, 2)) | |
l = torch.sum(input, (0, 1, 2)) | |
r = torch.sum(target, (0, 1, 2)) | |
loss = 1.0 - (2.0 * inse + self.smooth) / (l + r + self.smooth) | |
loss = torch.sum(loss) | |
return loss | |
class MSELossMaps(_Loss): | |
"""Calculate mean squared error loss for combined horizontal and vertical maps of segmentation tasks.""" | |
def __init__(self) -> None: | |
super().__init__(size_average=None, reduce=None, reduction="mean") | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
"""Loss calculation | |
Args: | |
input (torch.Tensor): Prediction of combined horizontal and vertical maps | |
with shape (N, 2, H, W), channel 0 is vertical and channel 1 is horizontal | |
target (torch.Tensor): Ground truth of combined horizontal and vertical maps | |
with shape (N, 2, H, W), channel 0 is vertical and channel 1 is horizontal | |
Returns: | |
torch.Tensor: Mean squared error per pixel with shape (N, 2, H, W), grad_fn=SubBackward0 | |
""" | |
# reshape | |
loss = input - target | |
loss = (loss * loss).mean() | |
return loss | |
class MSGELossMaps(_Loss): | |
def __init__(self) -> None: | |
super().__init__(size_average=None, reduce=None, reduction="mean") | |
def get_sobel_kernel( | |
self, size: int, device: str | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
"""Get sobel kernel with a given size. | |
Args: | |
size (int): Kernel site | |
device (str): Cuda device | |
Returns: | |
Tuple[torch.Tensor, torch.Tensor]: Horizontal and vertical sobel kernel, each with shape (size, size) | |
""" | |
assert size % 2 == 1, "Must be odd, get size=%d" % size | |
h_range = torch.arange( | |
-size // 2 + 1, | |
size // 2 + 1, | |
dtype=torch.float32, | |
device=device, | |
requires_grad=False, | |
) | |
v_range = torch.arange( | |
-size // 2 + 1, | |
size // 2 + 1, | |
dtype=torch.float32, | |
device=device, | |
requires_grad=False, | |
) | |
h, v = torch.meshgrid(h_range, v_range, indexing="ij") | |
kernel_h = h / (h * h + v * v + 1.0e-15) | |
kernel_v = v / (h * h + v * v + 1.0e-15) | |
return kernel_h, kernel_v | |
def get_gradient_hv(self, hv: torch.Tensor, device: str) -> torch.Tensor: | |
"""For calculating gradient of horizontal and vertical prediction map | |
Args: | |
hv (torch.Tensor): horizontal and vertical map | |
device (str): CUDA device | |
Returns: | |
torch.Tensor: Gradient with same shape as input | |
""" | |
kernel_h, kernel_v = self.get_sobel_kernel(5, device=device) | |
kernel_h = kernel_h.view(1, 1, 5, 5) # constant | |
kernel_v = kernel_v.view(1, 1, 5, 5) # constant | |
h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW | |
v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW | |
# can only apply in NCHW mode | |
h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) | |
v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) | |
dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) | |
dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC | |
return dhv | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
focus: torch.Tensor, | |
device: str, | |
) -> torch.Tensor: | |
"""MSGE (Gradient of MSE) loss | |
Args: | |
input (torch.Tensor): Input with shape (B, C, H, W) | |
target (torch.Tensor): Target with shape (B, C, H, W) | |
focus (torch.Tensor): Focus, type of masking (B, C, W, W) | |
device (str): CUDA device to work with. | |
Returns: | |
torch.Tensor: MSGE loss | |
""" | |
input = input.permute(0, 2, 3, 1) | |
target = target.permute(0, 2, 3, 1) | |
focus = focus.permute(0, 2, 3, 1) | |
focus = focus[..., 1] | |
focus = (focus[..., None]).float() # assume input NHW | |
focus = torch.cat([focus, focus], axis=-1).to(device) | |
true_grad = self.get_gradient_hv(target, device) | |
pred_grad = self.get_gradient_hv(input, device) | |
loss = pred_grad - true_grad | |
loss = focus * (loss * loss) | |
# artificial reduce_mean with focused region | |
loss = loss.sum() / (focus.sum() + 1.0e-8) | |
return loss | |
class FocalTverskyLoss(nn.Module): | |
"""FocalTverskyLoss | |
PyTorch implementation of the Focal Tversky Loss Function for multiple classes | |
doi: 10.1109/ISBI.2019.8759329 | |
Abraham, N., & Khan, N. M. (2019). | |
A Novel Focal Tversky Loss Function With Improved Attention U-Net for Lesion Segmentation. | |
In International Symposium on Biomedical Imaging. https://doi.org/10.1109/isbi.2019.8759329 | |
@ Fabian Hörst, fabian.hoerst@uk-essen.de | |
Institute for Artifical Intelligence in Medicine, | |
University Medicine Essen | |
Args: | |
alpha_t (float, optional): Alpha parameter for tversky loss (multiplied with false-negatives). Defaults to 0.7. | |
beta_t (float, optional): Beta parameter for tversky loss (multiplied with false-positives). Defaults to 0.3. | |
gamma_f (float, optional): Gamma Focal parameter. Defaults to 4/3. | |
smooth (float, optional): Smooting factor. Defaults to 0.000001. | |
""" | |
def __init__( | |
self, | |
alpha_t: float = 0.7, | |
beta_t: float = 0.3, | |
gamma_f: float = 4 / 3, | |
smooth: float = 1e-6, | |
) -> None: | |
super().__init__() | |
self.alpha_t = alpha_t | |
self.beta_t = beta_t | |
self.gamma_f = gamma_f | |
self.smooth = smooth | |
self.num_classes = 2 | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
"""Loss calculation | |
Args: | |
input (torch.Tensor): Predictions, logits (without Softmax). Shape: (B, C, H, W) | |
target (torch.Tensor): Targets, either flattened (Shape: (C, H, W) or as one-hot encoded (Shape: (batch-size, C, H, W)). | |
Raises: | |
ValueError: Error if there is a shape missmatch | |
Returns: | |
torch.Tensor: FocalTverskyLoss (weighted) | |
""" | |
input = input.permute(0, 2, 3, 1) | |
if input.shape[-1] != self.num_classes: | |
raise ValueError( | |
"Predictions must be a logit tensor with the last dimension shape beeing equal to the number of classes" | |
) | |
if len(target.shape) != len(input.shape): | |
# convert the targets to onehot | |
target = F.one_hot(target, num_classes=self.num_classes) | |
# flatten | |
target = target.permute(0, 2, 3, 1) | |
target = target.view(-1) | |
input = torch.softmax(input, dim=-1).view(-1) | |
# calculate true positives, false positives and false negatives | |
tp = (input * target).sum() | |
fp = ((1 - target) * input).sum() | |
fn = (target * (1 - input)).sum() | |
Tversky = (tp + self.smooth) / ( | |
tp + self.alpha_t * fn + self.beta_t * fp + self.smooth | |
) | |
FocalTversky = (1 - Tversky) ** self.gamma_f | |
return FocalTversky | |
class MCFocalTverskyLoss(FocalTverskyLoss): | |
"""Multiclass FocalTverskyLoss | |
PyTorch implementation of the Focal Tversky Loss Function for multiple classes | |
doi: 10.1109/ISBI.2019.8759329 | |
Abraham, N., & Khan, N. M. (2019). | |
A Novel Focal Tversky Loss Function With Improved Attention U-Net for Lesion Segmentation. | |
In International Symposium on Biomedical Imaging. https://doi.org/10.1109/isbi.2019.8759329 | |
@ Fabian Hörst, fabian.hoerst@uk-essen.de | |
Institute for Artifical Intelligence in Medicine, | |
University Medicine Essen | |
Args: | |
alpha_t (float, optional): Alpha parameter for tversky loss (multiplied with false-negatives). Defaults to 0.7. | |
beta_t (float, optional): Beta parameter for tversky loss (multiplied with false-positives). Defaults to 0.3. | |
gamma_f (float, optional): Gamma Focal parameter. Defaults to 4/3. | |
smooth (float, optional): Smooting factor. Defaults to 0.000001. | |
num_classes (int, optional): Number of output classes. For binary segmentation, prefer FocalTverskyLoss (speed optimized). Defaults to 2. | |
class_weights (List[int], optional): Weights for each class. If not provided, equal weight. Length must be equal to num_classes. Defaults to None. | |
""" | |
def __init__( | |
self, | |
alpha_t: float = 0.7, | |
beta_t: float = 0.3, | |
gamma_f: float = 4 / 3, | |
smooth: float = 0.000001, | |
num_classes: int = 2, | |
class_weights: List[int] = None, | |
) -> None: | |
super().__init__(alpha_t, beta_t, gamma_f, smooth) | |
self.num_classes = num_classes | |
if class_weights is None: | |
self.class_weights = [1 for i in range(self.num_classes)] | |
else: | |
assert ( | |
len(class_weights) == self.num_classes | |
), "Please provide matching weights" | |
self.class_weights = class_weights | |
self.class_weights = torch.Tensor(self.class_weights) | |
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
"""Loss calculation | |
Args: | |
input (torch.Tensor): Predictions, logits (without Softmax). Shape: (B, num_classes, H, W) | |
target (torch.Tensor): Targets, either flattened (Shape: (B, H, W) or as one-hot encoded (Shape: (B, num_classes, H, W)). | |
Raises: | |
ValueError: Error if there is a shape missmatch | |
Returns: | |
torch.Tensor: FocalTverskyLoss (weighted) | |
""" | |
input = input.permute(0, 2, 3, 1) | |
if input.shape[-1] != self.num_classes: | |
raise ValueError( | |
"Predictions must be a logit tensor with the last dimension shape beeing equal to the number of classes" | |
) | |
if len(target.shape) != len(input.shape): | |
# convert the targets to onehot | |
target = F.one_hot(target, num_classes=self.num_classes) | |
target = target.permute(0, 2, 3, 1) | |
# Softmax | |
input = torch.softmax(input, dim=-1) | |
# Reshape | |
input = torch.permute(input, (3, 1, 2, 0)) | |
target = torch.permute(target, (3, 1, 2, 0)) | |
input = torch.flatten(input, start_dim=1) | |
target = torch.flatten(target, start_dim=1) | |
tp = torch.sum(input * target, 1) | |
fp = torch.sum((1 - target) * input, 1) | |
fn = torch.sum(target * (1 - input), 1) | |
Tversky = (tp + self.smooth) / ( | |
tp + self.alpha_t * fn + self.beta_t * fp + self.smooth | |
) | |
FocalTversky = (1 - Tversky) ** self.gamma_f | |
self.class_weights = self.class_weights.to(FocalTversky.device) | |
return torch.sum(self.class_weights * FocalTversky) | |
class WeightedBaseLoss(nn.Module): | |
"""Init a base class for weighted cross entropy based losses. | |
Enables weighting for object instance edges and classes. | |
Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617) | |
Args: | |
apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the | |
loss matrix. Defaults to False. | |
apply_ls (bool, optional): If True, Label smoothing will be applied to the target.. Defaults to False. | |
apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target. Defaults to False. | |
apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W). Defaults to False. | |
class_weights (torch.Tensor, optional): Class weights. A tensor of shape (C, ). Defaults to None. | |
edge_weight (float, optional): Weight for the object instance border pixels. Defaults to None. | |
""" | |
def __init__( | |
self, | |
apply_sd: bool = False, | |
apply_ls: bool = False, | |
apply_svls: bool = False, | |
apply_mask: bool = False, | |
class_weights: torch.Tensor = None, | |
edge_weight: float = None, | |
**kwargs, | |
) -> None: | |
super().__init__() | |
self.apply_sd = apply_sd | |
self.apply_ls = apply_ls | |
self.apply_svls = apply_svls | |
self.apply_mask = apply_mask | |
self.class_weights = class_weights | |
self.edge_weight = edge_weight | |
def apply_spectral_decouple( | |
self, loss_matrix: torch.Tensor, yhat: torch.Tensor, lam: float = 0.01 | |
) -> torch.Tensor: | |
"""Apply spectral decoupling L2 norm after the loss. | |
https://arxiv.org/abs/2011.09468 | |
Args: | |
loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W). | |
yhat (torch.Tensor): The pixel predictions of the model. Shape (B, C, H, W). | |
lam (float, optional): Lambda constant.. Defaults to 0.01. | |
Returns: | |
torch.Tensor: SD-regularized loss matrix. Same shape as input. | |
""" | |
return loss_matrix + (lam / 2) * (yhat**2).mean(axis=1) | |
def apply_ls_to_target( | |
self, | |
target: torch.Tensor, | |
num_classes: int, | |
label_smoothing: float = 0.1, | |
) -> torch.Tensor: | |
"""_summary_ | |
Args: | |
target (torch.Tensor): Number of classes in the data. | |
num_classes (int): The target one hot tensor. Shape (B, C, H, W) | |
label_smoothing (float, optional): The smoothing coeff alpha. Defaults to 0.1. | |
Returns: | |
torch.Tensor: Label smoothed target. Same shape as input. | |
""" | |
return target * (1 - label_smoothing) + label_smoothing / num_classes | |
def apply_svls_to_target( | |
self, | |
target: torch.Tensor, | |
num_classes: int, | |
kernel_size: int = 5, | |
sigma: int = 3, | |
**kwargs, | |
) -> torch.Tensor: | |
"""Apply spatially varying label smoothihng to target map. | |
https://arxiv.org/abs/2104.05788 | |
Args: | |
target (torch.Tensor): The target one hot tensor. Shape (B, C, H, W). | |
num_classes (int): Number of classes in the data. | |
kernel_size (int, optional): Size of a square kernel.. Defaults to 5. | |
sigma (int, optional): The std of the gaussian. Defaults to 3. | |
Returns: | |
torch.Tensor: Label smoothed target. Same shape as input. | |
""" | |
my, mx = kernel_size // 2, kernel_size // 2 | |
gaussian_kernel = gaussian_kernel2d( | |
kernel_size, sigma, num_classes, device=target.device | |
) | |
neighborsum = (1 - gaussian_kernel[..., my, mx]) + 1e-16 | |
gaussian_kernel = gaussian_kernel.clone() | |
gaussian_kernel[..., my, mx] = neighborsum | |
svls_kernel = gaussian_kernel / neighborsum[0] | |
return filter2D(target.float(), svls_kernel) / svls_kernel[0].sum() | |
def apply_class_weights( | |
self, loss_matrix: torch.Tensor, target: torch.Tensor | |
) -> torch.Tensor: | |
"""Multiply pixelwise loss matrix by the class weights. | |
NOTE: No normalization | |
Args: | |
loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W). | |
target (torch.Tensor): The target mask. Shape (B, H, W). | |
Returns: | |
torch.Tensor: The loss matrix scaled with the weight matrix. Shape (B, H, W). | |
""" | |
weight_mat = self.class_weights[target.long()].to(target.device) # to (B, H, W) | |
loss = loss_matrix * weight_mat | |
return loss | |
def apply_edge_weights( | |
self, loss_matrix: torch.Tensor, weight_map: torch.Tensor | |
) -> torch.Tensor: | |
"""Apply weights to the object boundaries. | |
Basically just computes `edge_weight`**`weight_map`. | |
Args: | |
loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W). | |
weight_map (torch.Tensor): Map that points to the pixels that will be weighted. Shape (B, H, W). | |
Returns: | |
torch.Tensor: The loss matrix scaled with the nuclear boundary weights. Shape (B, H, W). | |
""" | |
return loss_matrix * self.edge_weight**weight_map | |
def apply_mask_weight( | |
self, loss_matrix: torch.Tensor, mask: torch.Tensor, norm: bool = True | |
) -> torch.Tensor: | |
"""Apply a mask to the loss matrix. | |
Args: | |
loss_matrix (torch.Tensor): Pixelwise losses. A tensor of shape (B, H, W). | |
mask (torch.Tensor): The mask. Shape (B, H, W). | |
norm (bool, optional): If True, the loss matrix will be normalized by the mean of the mask. Defaults to True. | |
Returns: | |
torch.Tensor: The loss matrix scaled with the mask. Shape (B, H, W). | |
""" | |
loss_matrix *= mask | |
if norm: | |
norm_mask = torch.mean(mask.float()) + 1e-7 | |
loss_matrix /= norm_mask | |
return loss_matrix | |
def extra_repr(self) -> str: | |
"""Add info to print.""" | |
s = "apply_sd={apply_sd}, apply_ls={apply_ls}, apply_svls={apply_svls}, apply_mask={apply_mask}, class_weights={class_weights}, edge_weight={edge_weight}" # noqa | |
return s.format(**self.__dict__) | |
class MAEWeighted(WeightedBaseLoss): | |
"""Compute the MAE loss. Used in the stardist method. | |
Stardist: | |
https://arxiv.org/pdf/1806.03535.pdf | |
Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617) | |
NOTE: We have added the option to apply spectral decoupling and edge weights | |
to the loss matrix. | |
Args: | |
alpha (float, optional): Weight regulizer b/w [0,1]. In stardist repo, this is the | |
'train_background_reg' parameter. Defaults to 1e-4. | |
apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the | |
loss matrix. Defaults to False. | |
apply_mask (bool, optional): f True, a mask will be applied to the loss matrix. Mask shape: (B, H, W). Defaults to False. | |
edge_weight (float, optional): Weight that is added to object borders. Defaults to None. | |
""" | |
def __init__( | |
self, | |
alpha: float = 1e-4, | |
apply_sd: bool = False, | |
apply_mask: bool = False, | |
edge_weight: float = None, | |
**kwargs, | |
) -> None: | |
super().__init__(apply_sd, False, False, apply_mask, False, edge_weight) | |
self.alpha = alpha | |
self.eps = 1e-7 | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
target_weight: torch.Tensor = None, | |
mask: torch.Tensor = None, | |
**kwargs, | |
) -> torch.Tensor: | |
"""Compute the masked MAE loss. | |
Args: | |
input (torch.Tensor): The prediction map. Shape (B, C, H, W). | |
target (torch.Tensor): The ground truth annotations. Shape (B, H, W). | |
target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None. | |
mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None. | |
Raises: | |
ValueError: Pred and target shapes must match. | |
Returns: | |
torch.Tensor: Computed MAE loss (scalar). | |
""" | |
yhat = input | |
n_classes = yhat.shape[1] | |
if target.size() != yhat.size(): | |
target = target.unsqueeze(1).repeat_interleave(n_classes, dim=1) | |
if not yhat.shape == target.shape: | |
raise ValueError( | |
f"Pred and target shapes must match. Got: {yhat.shape}, {target.shape}" | |
) | |
# compute the MAE loss with alpha as weight | |
mae_loss = torch.mean(torch.abs(target - yhat), axis=1) # (B, H, W) | |
if self.apply_mask and mask is not None: | |
mae_loss = self.apply_mask_weight(mae_loss, mask, norm=True) # (B, H, W) | |
# add the background regularization | |
if self.alpha > 0: | |
reg = torch.mean(((1 - mask).unsqueeze(1)) * torch.abs(yhat), axis=1) | |
mae_loss += self.alpha * reg | |
if self.apply_sd: | |
mae_loss = self.apply_spectral_decouple(mae_loss, yhat) | |
if self.edge_weight is not None: | |
mae_loss = self.apply_edge_weights(mae_loss, target_weight) | |
return mae_loss.mean() | |
class MSEWeighted(WeightedBaseLoss): | |
"""MSE-loss. | |
Args: | |
apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the | |
loss matrix. Defaults to False. | |
apply_ls (bool, optional): If True, Label smoothing will be applied to the target. Defaults to False. | |
apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target. Defaults to False. | |
apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W). Defaults to False. | |
edge_weight (float, optional): Weight that is added to object borders. Defaults to None. | |
class_weights (torch.Tensor, optional): Class weights. A tensor of shape (n_classes,). Defaults to None. | |
""" | |
def __init__( | |
self, | |
apply_sd: bool = False, | |
apply_ls: bool = False, | |
apply_svls: bool = False, | |
apply_mask: bool = False, | |
edge_weight: float = None, | |
class_weights: torch.Tensor = None, | |
**kwargs, | |
) -> None: | |
super().__init__( | |
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight | |
) | |
def tensor_one_hot(type_map: torch.Tensor, n_classes: int) -> torch.Tensor: | |
"""Convert a segmentation mask into one-hot-format. | |
I.e. Takes in a segmentation mask of shape (B, H, W) and reshapes it | |
into a tensor of shape (B, C, H, W). | |
Args: | |
type_map (torch.Tensor): Multi-label Segmentation mask. Shape (B, H, W). | |
n_classes (int): Number of classes. (Zero-class included.) | |
Raises: | |
TypeError: Input `type_map` should have dtype: torch.int64. | |
Returns: | |
torch.Tensor: A one hot tensor. Shape: (B, C, H, W). Dtype: torch.FloatTensor. | |
""" | |
if not type_map.dtype == torch.int64: | |
raise TypeError( | |
f""" | |
Input `type_map` should have dtype: torch.int64. Got: {type_map.dtype}.""" | |
) | |
one_hot = torch.zeros( | |
type_map.shape[0], | |
n_classes, | |
*type_map.shape[1:], | |
device=type_map.device, | |
dtype=type_map.dtype, | |
) | |
return one_hot.scatter_(dim=1, index=type_map.unsqueeze(1), value=1.0) + 1e-7 | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
target_weight: torch.Tensor = None, | |
mask: torch.Tensor = None, | |
**kwargs, | |
) -> torch.Tensor: | |
"""Compute the MSE-loss. | |
Args: | |
input (torch.Tensor): The prediction map. Shape (B, C, H, W, C). | |
target (torch.Tensor): The ground truth annotations. Shape (B, H, W). | |
target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None. | |
mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None. | |
Returns: | |
torch.Tensor: Computed MSE loss (scalar). | |
""" | |
yhat = input | |
target_one_hot = target | |
num_classes = yhat.shape[1] | |
if target.size() != yhat.size(): | |
if target.dtype == torch.float32: | |
target_one_hot = target.unsqueeze(1) | |
else: | |
target_one_hot = MSEWeighted.tensor_one_hot(target, num_classes) | |
if self.apply_svls: | |
target_one_hot = self.apply_svls_to_target( | |
target_one_hot, num_classes, **kwargs | |
) | |
if self.apply_ls: | |
target_one_hot = self.apply_ls_to_target( | |
target_one_hot, num_classes, **kwargs | |
) | |
mse = F.mse_loss(yhat, target_one_hot, reduction="none") # (B, C, H, W) | |
mse = torch.mean(mse, dim=1) # to (B, H, W) | |
if self.apply_mask and mask is not None: | |
mse = self.apply_mask_weight(mse, mask, norm=False) # (B, H, W) | |
if self.apply_sd: | |
mse = self.apply_spectral_decouple(mse, yhat) | |
if self.class_weights is not None: | |
mse = self.apply_class_weights(mse, target) | |
if self.edge_weight is not None: | |
mse = self.apply_edge_weights(mse, target_weight) | |
return torch.mean(mse) | |
class BCEWeighted(WeightedBaseLoss): | |
def __init__( | |
self, | |
apply_sd: bool = False, | |
apply_ls: bool = False, | |
apply_svls: bool = False, | |
apply_mask: bool = False, | |
edge_weight: float = None, | |
class_weights: torch.Tensor = None, | |
**kwargs, | |
) -> None: | |
"""Binary cross entropy loss with weighting and other tricks. | |
Parameters | |
---------- | |
apply_sd : bool, default=False | |
If True, Spectral decoupling regularization will be applied to the | |
loss matrix. | |
apply_ls : bool, default=False | |
If True, Label smoothing will be applied to the target. | |
apply_svls : bool, default=False | |
If True, spatially varying label smoothing will be applied to the target | |
apply_mask : bool, default=False | |
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W) | |
edge_weight : float, default=None | |
Weight that is added to object borders. | |
class_weights : torch.Tensor, default=None | |
Class weights. A tensor of shape (n_classes,). | |
""" | |
super().__init__( | |
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight | |
) | |
self.eps = 1e-8 | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
target_weight: torch.Tensor = None, | |
mask: torch.Tensor = None, | |
**kwargs, | |
) -> torch.Tensor: | |
"""Compute binary cross entropy loss. | |
Parameters | |
---------- | |
yhat : torch.Tensor | |
The prediction map. Shape (B, C, H, W). | |
target : torch.Tensor | |
the ground truth annotations. Shape (B, H, W). | |
target_weight : torch.Tensor, default=None | |
The edge weight map. Shape (B, H, W). | |
mask : torch.Tensor, default=None | |
The mask map. Shape (B, H, W). | |
Returns | |
------- | |
torch.Tensor: | |
Computed BCE loss (scalar). | |
""" | |
# Logits input | |
yhat = input | |
num_classes = yhat.shape[1] | |
yhat = torch.clip(yhat, self.eps, 1.0 - self.eps) | |
if target.size() != yhat.size(): | |
target = target.unsqueeze(1).repeat_interleave(num_classes, dim=1) | |
if self.apply_svls: | |
target = self.apply_svls_to_target(target, num_classes, **kwargs) | |
if self.apply_ls: | |
target = self.apply_ls_to_target(target, num_classes, **kwargs) | |
bce = F.binary_cross_entropy_with_logits( | |
yhat.float(), target.float(), reduction="none" | |
) # (B, C, H, W) | |
bce = torch.mean(bce, dim=1) # (B, H, W) | |
if self.apply_mask and mask is not None: | |
bce = self.apply_mask_weight(bce, mask, norm=False) # (B, H, W) | |
if self.apply_sd: | |
bce = self.apply_spectral_decouple(bce, yhat) | |
if self.class_weights is not None: | |
bce = self.apply_class_weights(bce, target) | |
if self.edge_weight is not None: | |
bce = self.apply_edge_weights(bce, target_weight) | |
return torch.mean(bce) | |
# class BCEWeighted(WeightedBaseLoss): | |
# """Binary cross entropy loss with weighting and other tricks. | |
# Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617) | |
# Args: | |
# apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the | |
# loss matrix. Defaults to False. | |
# apply_ls (bool, optional): If True, Label smoothing will be applied to the target. Defaults to False. | |
# apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target. Defaults to False. | |
# apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W). Defaults to False. | |
# edge_weight (float, optional): Weight that is added to object borders. Defaults to None. | |
# class_weights (torch.Tensor, optional): Class weights. A tensor of shape (n_classes,). Defaults to None. | |
# """ | |
# def __init__( | |
# self, | |
# apply_sd: bool = False, | |
# apply_ls: bool = False, | |
# apply_svls: bool = False, | |
# apply_mask: bool = False, | |
# edge_weight: float = None, | |
# class_weights: torch.Tensor = None, | |
# **kwargs, | |
# ) -> None: | |
# super().__init__( | |
# apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight | |
# ) | |
# self.eps = 1e-8 | |
# def forward( | |
# self, | |
# input: torch.Tensor, | |
# target: torch.Tensor, | |
# target_weight: torch.Tensor = None, | |
# mask: torch.Tensor = None, | |
# **kwargs, | |
# ) -> torch.Tensor: | |
# """Compute binary cross entropy loss. | |
# Args: | |
# input (torch.Tensor): The prediction map. We internally convert back via logit function. Shape (B, C, H, W). | |
# target (torch.Tensor): the ground truth annotations. Shape (B, H, W). | |
# target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None. | |
# mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None. | |
# Returns: | |
# torch.Tensor: Computed BCE loss (scalar). | |
# """ | |
# yhat = input | |
# yhat = torch.special.logit(yhat) | |
# num_classes = yhat.shape[1] | |
# yhat = torch.clip(yhat, self.eps, 1.0 - self.eps) | |
# if target.size() != yhat.size(): | |
# target = target.unsqueeze(1).repeat_interleave(num_classes, dim=1) | |
# if self.apply_svls: | |
# target = self.apply_svls_to_target(target, num_classes, **kwargs) | |
# if self.apply_ls: | |
# target = self.apply_ls_to_target(target, num_classes, **kwargs) | |
# bce = F.binary_cross_entropy_with_logits( | |
# yhat.float(), target.float(), reduction="none" | |
# ) # (B, C, H, W) | |
# bce = torch.mean(bce, dim=1) # (B, H, W) | |
# if self.apply_mask and mask is not None: | |
# bce = self.apply_mask_weight(bce, mask, norm=False) # (B, H, W) | |
# if self.apply_sd: | |
# bce = self.apply_spectral_decouple(bce, yhat) | |
# if self.class_weights is not None: | |
# bce = self.apply_class_weights(bce, target) | |
# if self.edge_weight is not None: | |
# bce = self.apply_edge_weights(bce, target_weight) | |
# return torch.mean(bce) | |
class CEWeighted(WeightedBaseLoss): | |
def __init__( | |
self, | |
apply_sd: bool = False, | |
apply_ls: bool = False, | |
apply_svls: bool = False, | |
apply_mask: bool = False, | |
edge_weight: float = None, | |
class_weights: torch.Tensor = None, | |
**kwargs, | |
) -> None: | |
"""Cross-Entropy loss with weighting. | |
Parameters | |
---------- | |
apply_sd : bool, default=False | |
If True, Spectral decoupling regularization will be applied to the | |
loss matrix. | |
apply_ls : bool, default=False | |
If True, Label smoothing will be applied to the target. | |
apply_svls : bool, default=False | |
If True, spatially varying label smoothing will be applied to the target | |
apply_mask : bool, default=False | |
If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W) | |
edge_weight : float, default=None | |
Weight that is added to object borders. | |
class_weights : torch.Tensor, default=None | |
Class weights. A tensor of shape (n_classes,). | |
""" | |
super().__init__( | |
apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight | |
) | |
self.eps = 1e-8 | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
target_weight: torch.Tensor = None, | |
mask: torch.Tensor = None, | |
**kwargs, | |
) -> torch.Tensor: | |
"""Compute the cross entropy loss. | |
Parameters | |
---------- | |
yhat : torch.Tensor | |
The prediction map. Shape (B, C, H, W). | |
target : torch.Tensor | |
the ground truth annotations. Shape (B, H, W). | |
target_weight : torch.Tensor, default=None | |
The edge weight map. Shape (B, H, W). | |
mask : torch.Tensor, default=None | |
The mask map. Shape (B, H, W). | |
Returns | |
------- | |
torch.Tensor: | |
Computed CE loss (scalar). | |
""" | |
yhat = input # TODO: remove doubled Softmax -> this function needs logits instead of softmax output | |
input_soft = F.softmax(yhat, dim=1) + self.eps # (B, C, H, W) | |
num_classes = yhat.shape[1] | |
if len(target.shape) != len(yhat.shape) and target.shape[1] != num_classes: | |
target_one_hot = MSEWeighted.tensor_one_hot( | |
target, num_classes | |
) # (B, C, H, W) | |
else: | |
target_one_hot = target | |
target = torch.argmax(target, dim=1) | |
assert target_one_hot.shape == yhat.shape | |
if self.apply_svls: | |
target_one_hot = self.apply_svls_to_target( | |
target_one_hot, num_classes, **kwargs | |
) | |
if self.apply_ls: | |
target_one_hot = self.apply_ls_to_target( | |
target_one_hot, num_classes, **kwargs | |
) | |
loss = -torch.sum(target_one_hot * torch.log(input_soft), dim=1) # (B, H, W) | |
if self.apply_mask and mask is not None: | |
loss = self.apply_mask_weight(loss, mask, norm=False) # (B, H, W) | |
if self.apply_sd: | |
loss = self.apply_spectral_decouple(loss, yhat) | |
if self.class_weights is not None: | |
loss = self.apply_class_weights(loss, target) | |
if self.edge_weight is not None: | |
loss = self.apply_edge_weights(loss, target_weight) | |
return loss.mean() | |
# class CEWeighted(WeightedBaseLoss): | |
# """Cross-Entropy loss with weighting. | |
# Adapted/Copied from: https://github.com/okunator/cellseg_models.pytorch (10.5281/zenodo.7064617) | |
# Args: | |
# apply_sd (bool, optional): If True, Spectral decoupling regularization will be applied to the loss matrix. Defaults to False. | |
# apply_ls (bool, optional): If True, Label smoothing will be applied to the target. Defaults to False. | |
# apply_svls (bool, optional): If True, spatially varying label smoothing will be applied to the target. Defaults to False. | |
# apply_mask (bool, optional): If True, a mask will be applied to the loss matrix. Mask shape: (B, H, W). Defaults to False. | |
# edge_weight (float, optional): Weight that is added to object borders. Defaults to None. | |
# class_weights (torch.Tensor, optional): Class weights. A tensor of shape (n_classes,). Defaults to None. | |
# logits (bool, optional): If work on logit values. Defaults to False. Defaults to False. | |
# """ | |
# def __init__( | |
# self, | |
# apply_sd: bool = False, | |
# apply_ls: bool = False, | |
# apply_svls: bool = False, | |
# apply_mask: bool = False, | |
# edge_weight: float = None, | |
# class_weights: torch.Tensor = None, | |
# logits: bool = False, | |
# **kwargs, | |
# ) -> None: | |
# super().__init__( | |
# apply_sd, apply_ls, apply_svls, apply_mask, class_weights, edge_weight | |
# ) | |
# self.eps = 1e-8 | |
# self.logits = logits | |
# def forward( | |
# self, | |
# input: torch.Tensor, | |
# target: torch.Tensor, | |
# target_weight: torch.Tensor = None, | |
# mask: torch.Tensor = None, | |
# **kwargs, | |
# ) -> torch.Tensor: | |
# """Compute the cross entropy loss. | |
# Args: | |
# input (torch.Tensor): The prediction map. Shape (B, C, H, W). | |
# target (torch.Tensor): The ground truth annotations. Shape (B, H, W). | |
# target_weight (torch.Tensor, optional): The edge weight map. Shape (B, H, W). Defaults to None. | |
# mask (torch.Tensor, optional): The mask map. Shape (B, H, W). Defaults to None. | |
# Returns: | |
# torch.Tensor: Computed CE loss (scalar). | |
# """ | |
# yhat = input | |
# if self.logits: | |
# input_soft = ( | |
# F.softmax(yhat, dim=1) + self.eps | |
# ) # (B, C, H, W) # check if doubled softmax | |
# else: | |
# input_soft = input | |
# num_classes = yhat.shape[1] | |
# if len(target.shape) != len(yhat.shape) and target.shape[1] != num_classes: | |
# target_one_hot = MSEWeighted.tensor_one_hot( | |
# target, num_classes | |
# ) # (B, C, H, W) | |
# else: | |
# target_one_hot = target | |
# target = torch.argmax(target, dim=1) | |
# assert target_one_hot.shape == yhat.shape | |
# if self.apply_svls: | |
# target_one_hot = self.apply_svls_to_target( | |
# target_one_hot, num_classes, **kwargs | |
# ) | |
# if self.apply_ls: | |
# target_one_hot = self.apply_ls_to_target( | |
# target_one_hot, num_classes, **kwargs | |
# ) | |
# loss = -torch.sum(target_one_hot * torch.log(input_soft), dim=1) # (B, H, W) | |
# if self.apply_mask and mask is not None: | |
# loss = self.apply_mask_weight(loss, mask, norm=False) # (B, H, W) | |
# if self.apply_sd: | |
# loss = self.apply_spectral_decouple(loss, yhat) | |
# if self.class_weights is not None: | |
# loss = self.apply_class_weights(loss, target) | |
# if self.edge_weight is not None: | |
# loss = self.apply_edge_weights(loss, target_weight) | |
# return loss.mean() | |
### Stardist loss functions | |
class L1LossWeighted(nn.Module): | |
def __init__(self) -> None: | |
super().__init__() | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
target_weight: torch.Tensor = None, | |
) -> torch.Tensor: | |
l1loss = F.l1_loss(input, target, size_average=True, reduce=False) | |
l1loss = torch.mean(l1loss, dim=1) | |
if target_weight is not None: | |
l1loss = torch.mean(target_weight * l1loss) | |
else: | |
l1loss = torch.mean(l1loss) | |
return l1loss | |
def retrieve_loss_fn(loss_name: dict, **kwargs) -> _Loss: | |
"""Return the loss function with given name defined in the LOSS_DICT and initialize with kwargs | |
kwargs must match with the parameters defined in the initialization method of the selected loss object | |
Args: | |
loss_name (dict): Name of the loss function | |
Returns: | |
_Loss: Loss | |
""" | |
loss_fn = LOSS_DICT[loss_name] | |
loss_fn = loss_fn(**kwargs) | |
return loss_fn | |
LOSS_DICT = { | |
"xentropy_loss": XentropyLoss, | |
"dice_loss": DiceLoss, | |
"mse_loss_maps": MSELossMaps, | |
"msge_loss_maps": MSGELossMaps, | |
"FocalTverskyLoss": FocalTverskyLoss, | |
"MCFocalTverskyLoss": MCFocalTverskyLoss, | |
"CrossEntropyLoss": nn.CrossEntropyLoss, # input logits, targets | |
"L1Loss": nn.L1Loss, | |
"MSELoss": nn.MSELoss, | |
"CTCLoss": nn.CTCLoss, # probability | |
"NLLLoss": nn.NLLLoss, # log-probabilities of each class | |
"PoissonNLLLoss": nn.PoissonNLLLoss, | |
"GaussianNLLLoss": nn.GaussianNLLLoss, | |
"KLDivLoss": nn.KLDivLoss, # argument input in log-space | |
"BCELoss": nn.BCELoss, # probabilities | |
"BCEWithLogitsLoss": nn.BCEWithLogitsLoss, # logits | |
"MarginRankingLoss": nn.MarginRankingLoss, | |
"HingeEmbeddingLoss": nn.HingeEmbeddingLoss, | |
"MultiLabelMarginLoss": nn.MultiLabelMarginLoss, | |
"HuberLoss": nn.HuberLoss, | |
"SmoothL1Loss": nn.SmoothL1Loss, | |
"SoftMarginLoss": nn.SoftMarginLoss, # logits | |
"MultiLabelSoftMarginLoss": nn.MultiLabelSoftMarginLoss, | |
"CosineEmbeddingLoss": nn.CosineEmbeddingLoss, | |
"MultiMarginLoss": nn.MultiMarginLoss, | |
"TripletMarginLoss": nn.TripletMarginLoss, | |
"TripletMarginWithDistanceLoss": nn.TripletMarginWithDistanceLoss, | |
"MAEWeighted": MAEWeighted, | |
"MSEWeighted": MSEWeighted, | |
"BCEWeighted": BCEWeighted, # logits | |
"CEWeighted": CEWeighted, # logits | |
"L1LossWeighted": L1LossWeighted, | |
} | |