"""Define all losses. When possible, losses inherit from nn.Module.
Several losses move their target to the prediction's device before comparing.
"""
from random import random as rand |
import numpy as np |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from torchvision import models |
class GANLoss(nn.Module):
    """Adversarial loss (LSGAN/MSE or BCE-with-logits) with label smoothing and flipping."""

    def __init__(
        self,
        use_lsgan=True,
        target_real_label=1.0,
        target_fake_label=0.0,
        soft_shift=0.0,
        flip_prob=0.0,
        verbose=0,
    ):
        """Defines the GAN loss which uses either LSGAN or the regular GAN.

        When LSGAN is used, it is basically the same as MSELoss, but it
        abstracts away the need to create the target label tensor that has
        the same size as the input. It also supports:
          * label smoothing: either set e.g. target_real_label=0.75, or use
            soft_shift to shift real/fake targets towards each other by a
            random amount sampled uniformly in [0, soft_shift]
          * label flipping: with probability flip_prob, the real/fake target
            is swapped (once per call, for every discriminator output)

        source: https://github.com/sangwoomo/instagan/blob
        /b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py

        Args:
            use_lsgan (bool, optional): Use MSE (True) or BCE-with-logits
                (False). Defaults to True.
            target_real_label (float, optional): Value for the real target.
                Defaults to 1.0.
            target_fake_label (float, optional): Value for the fake target.
                Defaults to 0.0.
            soft_shift (float, optional): Upper bound of the random shift
                applied to the target labels. Defaults to 0.0.
            flip_prob (float, optional): Probability of flipping the label
                (use for real target in Discriminator only). Defaults to 0.0.
            verbose (int, optional): If > 0, print the sampled shift.
                Defaults to 0.
        """
        super().__init__()
        self.soft_shift = soft_shift
        self.verbose = verbose
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCEWithLogitsLoss()
        self.flip_prob = flip_prob

    def get_target_tensor(self, input, target_is_real):
        """Return the (possibly smoothed) target label expanded to ``input``'s shape.

        A shift is sampled uniformly in [0, soft_shift] on every call; it is
        subtracted from the real label or added to the fake label.
        """
        # Sample on the buffers' device so the arithmetic below still works
        # after the module has been moved with .to(device).
        soft_change = torch.empty(
            1, dtype=torch.float, device=self.real_label.device
        ).uniform_(0, self.soft_shift)
        if self.verbose > 0:
            print("GANLoss sampled soft_change:", soft_change.item())
        if target_is_real:
            target_tensor = self.real_label - soft_change
        else:
            target_tensor = self.fake_label + soft_change
        return target_tensor.expand_as(input)

    def __call__(self, input, target_is_real, *args, **kwargs):
        """Compute the loss for one prediction tensor or a list of them.

        Args:
            input (torch.Tensor | list): discriminator output(s). For a list,
                each element may itself be a list, in which case only its last
                element is used; the per-element losses are averaged.
            target_is_real (bool): whether the target is the real label.

        Returns:
            torch.Tensor: the loss.
        """
        # Decide the flip once per call so that every output of a
        # multi-scale discriminator is compared with the same target.
        if rand() < self.flip_prob:
            target_is_real = not target_is_real

        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                target_tensor = self.get_target_tensor(pred_i, target_is_real)
                loss += self.loss(pred_i, target_tensor.to(pred_i.device))
            return loss / len(input)

        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor.to(input.device))
class FeatMatchLoss(nn.Module):
    """Feature-matching loss between discriminator activations.

    For each discriminator, every output except the last one is compared
    between the fake and the (detached) real branch with an L1 loss. Each
    term is divided by the number of discriminators before being summed.
    """

    def __init__(self):
        super().__init__()
        self.criterionFeat = nn.L1Loss()

    def __call__(self, pred_real, pred_fake):
        """
        Args:
            pred_real: indexable collection (one entry per discriminator) of
                indexable collections of tensors.
            pred_fake: same structure as ``pred_real``.

        Returns:
            The accumulated loss (the float 0.0 if nothing was compared).
        """
        n_discriminators = len(pred_fake)
        total = 0.0
        for d_idx, fake_outputs in enumerate(pred_fake):
            real_outputs = pred_real[d_idx]
            # The final element is skipped: only intermediate outputs are matched.
            for layer_idx in range(len(fake_outputs) - 1):
                term = self.criterionFeat(
                    fake_outputs[layer_idx], real_outputs[layer_idx].detach()
                )
                total += term / n_discriminators
        return total
class CrossEntropy(nn.Module):
    """Cross-entropy on logits.

    The target is moved to the logits' device and cast to ``long`` before
    being passed to ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, target):
        labels = target.to(logits.device)
        labels = labels.long()
        return self.loss(logits, labels)
class TravelLoss(nn.Module):
    """Cosine-similarity term between pairwise transformation vectors.

    For every pair (i, j) with j < i, the difference vectors
    ``S_real[i] - S_real[j]`` and ``S_fake[i] - S_fake[j]`` are built and the
    sum of the cosine similarities between matching real/fake vectors is
    returned.

    NOTE(review): minimizing this value pushes the real and fake vectors
    apart; TraVeLGAN-style objectives usually minimize a distance. Confirm
    the sign expected by the caller.
    """

    def __init__(self, eps=1e-12):
        """
        Args:
            eps (float, optional): lower bound for the vector norms, avoiding
                division by zero. Defaults to 1e-12.
        """
        super().__init__()
        self.eps = eps

    def cosine_loss(self, real, fake):
        """Sum over rows of cos(real[k], fake[k]).

        Args:
            real (torch.Tensor): 2-D tensor (num_pairs, dim).
            fake (torch.Tensor): 2-D tensor with the same shape as ``real``.

        Returns:
            torch.Tensor: 0-D sum of row-wise cosine similarities.
        """
        # Guard the norms (not the normalized components): zero vectors then
        # give a similarity of 0 instead of NaN, and negative components are kept.
        norm_real = torch.norm(real, p=2, dim=1).clamp_min(self.eps)[:, None]
        norm_fake = torch.norm(fake, p=2, dim=1).clamp_min(self.eps)[:, None]
        mat_real = real / norm_real
        mat_fake = fake / norm_fake
        # Row-wise dot product: one similarity per pair.
        return torch.einsum("ij,ij->i", mat_fake, mat_real).sum()

    def __call__(self, S_real, S_fake):
        """
        Args:
            S_real: indexable collection of embeddings; each ``S_real[i]`` is
                assumed to be 1-D (TODO confirm with caller).
            S_fake: same structure as ``S_real``.

        Returns:
            torch.Tensor: see ``cosine_loss``.
        """
        self.v_real = []
        self.v_fake = []
        for i in range(len(S_real)):
            for j in range(i):
                self.v_real.append((S_real[i] - S_real[j])[None, :])
                self.v_fake.append((S_fake[i] - S_fake[j])[None, :])
        self.v_real_t = torch.cat(self.v_real, dim=0)
        self.v_fake_t = torch.cat(self.v_fake, dim=0)
        return self.cosine_loss(self.v_real_t, self.v_fake_t)
class TVLoss(nn.Module):
    """Total Variation regularization: penalizes differences between
    neighboring pixel values.

    source:
    https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
    """

    def __init__(self, tvloss_weight=1):
        """
        Args:
            tvloss_weight (int, optional): lambda, i.e. the weight of the loss.
                Defaults to 1.
        """
        super(TVLoss, self).__init__()
        self.tvloss_weight = tvloss_weight

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): 4-D tensor (presumably batch, channels, H, W).

        Returns:
            torch.Tensor: weight * 2 * (mean squared difference along dim 2
            + mean squared difference along dim 3) / x.size(0), where each
            mean is over dims 1-3 of the difference tensor.
        """
        n = x.size()[0]
        below, above = x[:, :, 1:, :], x[:, :, :-1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :-1]
        h_tv = (below - above).pow(2).sum()
        w_tv = (right - left).pow(2).sum()
        h_count = self._tensor_size(below)
        w_count = self._tensor_size(right)
        return self.tvloss_weight * 2 * (h_tv / h_count + w_tv / w_count) / n

    def _tensor_size(self, t):
        """Number of elements in one sample (product of dims 1, 2 and 3)."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
class MinentLoss(nn.Module):
    """
    Loss for the minimization of the entropy map.

    Source for version 1: https://github.com/valeoai/ADVENT
    Version 2 additionally adds ``lambda_var`` times the squared deviation of
    each entropy entry from sum(entropy) / (n * h * w).
    """

    def __init__(self, version=1, lambda_var=0.1):
        super().__init__()
        self.version = version
        self.lambda_var = lambda_var

    def __call__(self, pred):
        """
        Args:
            pred (torch.Tensor): 4-D tensor (n, c, h, w), presumably
                per-class probabilities.

        Returns:
            torch.Tensor: 0-D loss, normalized by n * h * w.
        """
        assert pred.dim() == 4
        n, c, h, w = pred.size()
        n_pixels = n * h * w
        # Normalized self-information; the 1e-30 avoids log2(0).
        ent = -(pred * torch.log2(pred + 1e-30)) / np.log2(c)
        if self.version == 1:
            return torch.sum(ent) / n_pixels
        centered = ent - torch.sum(ent) / n_pixels
        penalty = self.lambda_var * (centered * centered)
        return torch.sum(ent + penalty) / n_pixels
class MSELoss(nn.Module):
    """
    Mean squared error (squared L2 norm) between each element of the
    prediction and the target. The target is moved to the prediction's
    device first.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def __call__(self, prediction, target):
        target_on_device = target.to(prediction.device)
        return self.loss(prediction, target_on_device)


class L1Loss(MSELoss):
    """
    Mean absolute error (MAE) between each element of the prediction and
    the target. Reuses ``MSELoss.__call__`` with an L1 criterion.
    """

    def __init__(self):
        super().__init__()
        # Swap the criterion built by MSELoss for an L1 one.
        self.loss = nn.L1Loss()
class SIMSELoss(nn.Module):
    """Scale invariant MSE loss.

    With d = prediction - target, returns mean(d ** 2) - mean(d) ** 2, i.e.
    the variance of d. (Scale-invariant when the inputs are in log space,
    presumably log-depth; confirm with caller.)
    """

    def __init__(self):
        super(SIMSELoss, self).__init__()

    def __call__(self, prediction, target):
        residual = prediction - target
        mean_of_squares = torch.mean(residual * residual)
        mean_residual = torch.mean(residual)
        return mean_of_squares - mean_residual * mean_residual
class SIGMLoss(nn.Module):
    """Shift/scale-aligned L1 loss plus a multi-scale gradient term (MiDaS).

    MiDaS did not specify how the gradients were computed, so we use Sobel
    filters, which approximate the derivative of an image.
    """

    def __init__(self, gmweight=0.5, scale=4, device="cuda"):
        """
        Args:
            gmweight (float, optional): weight of the gradient term.
                Defaults to 0.5.
            scale (int, optional): number of scales k = 0..scale-1; scale k
                downsamples the residual by 2 ** k. Defaults to 4.
            device (str, optional): device the Sobel kernels are created on.
                Defaults to "cuda".
        """
        super(SIGMLoss, self).__init__()
        self.gmweight = gmweight
        self.sobelx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(device)
        self.sobely = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(device)
        self.scale = scale

    def __call__(self, prediction, target):
        """
        Args:
            prediction (torch.Tensor): assumed (batch, 1, H, W); the Sobel
                kernels have a single input channel. TODO confirm.
            target (torch.Tensor): same shape as ``prediction``.

        Returns:
            torch.Tensor: 0-D loss (summed over the batch, normalized by H * W).
        """
        # Align shift (median) and scale (mean absolute deviation); both are
        # computed over the whole batch.
        t_pred = torch.median(prediction)
        t_targ = torch.median(target)
        s_pred = torch.mean(torch.abs(prediction - t_pred))
        s_targ = torch.mean(torch.abs(target - t_targ))
        pred = (prediction - t_pred) / s_pred
        targ = (target - t_targ) / s_targ
        R = pred - targ

        num_pix = prediction.size()[-1] * prediction.size()[-2]
        # conv2d weights are (out_channels, in_channels, kH, kW): a single
        # output channel computes each sample's gradient exactly once.
        sobelx = self.sobelx.to(device=R.device, dtype=R.dtype).view(1, 1, 3, 3)
        sobely = self.sobely.to(device=R.device, dtype=R.dtype).view(1, 1, 3, 3)

        gm_loss = 0
        for k in range(self.scale):
            R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
            Rx = F.conv2d(R_, sobelx, stride=1)
            Ry = F.conv2d(R_, sobely, stride=1)
            gm_loss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
        gm_loss = self.gmweight / num_pix * gm_loss
        simse_loss = 0.5 / num_pix * torch.sum(torch.abs(R))
        return simse_loss + gm_loss
class ContextLoss(nn.Module):
    """
    Masked L1 loss on non-water: mean of |input - target| * (1 - mask).

    The mean is taken over all elements, including the masked-out ones.
    """

    def __call__(self, input, target, mask):
        keep = 1 - mask
        return (input - target).mul(keep).abs().mean()
class ReconstructionLoss(nn.Module):
    """
    Masked L1 loss on water: mean of |input - target| * mask.

    The mean is taken over all elements, including the masked-out ones.
    """

    def __call__(self, input, target, mask):
        return (input - target).mul(mask).abs().mean()
class Vgg19(nn.Module):
    """Pretrained torchvision VGG19 feature extractor split into five stages.

    Stage k (``slice1`` .. ``slice5``) holds layers [0, 2), [2, 7), [7, 12),
    [12, 21) and [21, 30) of ``vgg19(pretrained=True).features``. ``forward``
    returns the output of every stage.
    """

    def __init__(self, requires_grad=False):
        """
        Args:
            requires_grad (bool, optional): if False, all parameters are
                frozen. Defaults to False.
        """
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for stage_idx, (start, stop) in enumerate(bounds, start=1):
            stage = nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original layer index as the submodule name.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, f"slice{stage_idx}", stage)
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, X):
        """Run the stages in sequence and return the list of their outputs."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5)
        outputs = []
        h = X
        for stage in stages:
            h = stage(h)
            outputs.append(h)
        return outputs
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 stage outputs.

    Stage weights are 1/32, 1/16, 1/8, 1/4 and 1 (shallow to deep); the
    features of ``y`` are detached.
    """

    def __init__(self, device):
        super().__init__()
        self.vgg = Vgg19().to(device).eval()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        loss = 0
        for i, fx in enumerate(feats_x):
            loss += self.weights[i] * self.criterion(fx, feats_y[i].detach())
        return loss
def get_losses(opts, verbose, device=None):
    """Build the loss functions used by the generator (G) and discriminators (D).

    Which losses are created depends on the task letters in ``opts.tasks``
    ("p", "d", "s", "m"). The returned dictionary has the following layout;
    task-specific keys are only present when the task is enabled:

        losses = {
            "G": {
                "a": {},  # created empty here
                "p": {"gan", "dm", "vgg", "tv", "context",
                      "reconstruction", "featmatch"},
                "tasks": {
                    "d": <depth loss>,
                    "s": {"crossent", "minent", "advent"},
                    "m": {"bce", "minent", "tv", "advent", "gi"},
                },
            },
            "D": {
                "default": {},  # created empty here
                "advent": ADVENTAdversarialLoss if "m" or "s" is enabled, else {},
                "p": the same object as losses["G"]["p"]["gan"],
            },
            "C": {},  # created empty here
        }

    Args:
        opts: options object read with attribute access (e.g. ``opts.gen.p.loss``,
            ``opts.dis.soft_shift``, ``opts.train.lambdas``).
        verbose: not used in this function.
        device: only forwarded to ``VGGLoss``.

    Returns:
        dict: the nested loss dictionary described above.
    """
    losses = {
        "G": {"a": {}, "p": {}, "tasks": {}},
        "D": {"default": {}, "advent": {}},
        "C": {},
    }
    # "p" losses (presumably the painter, cf. HingeLoss's docstring).
    if "p" in opts.tasks:
        # Hinge adversarial loss if opts.gen.p.loss == "hinge", otherwise a
        # BCE-with-logits GANLoss with optional label smoothing / flipping.
        losses["G"]["p"]["gan"] = (
            HingeLoss()
            if opts.gen.p.loss == "hinge"
            else GANLoss(
                use_lsgan=False,
                soft_shift=opts.dis.soft_shift,
                flip_prob=opts.dis.flip_prob,
            )
        )
        losses["G"]["p"]["dm"] = MSELoss()
        losses["G"]["p"]["vgg"] = VGGLoss(device)
        losses["G"]["p"]["tv"] = TVLoss()
        losses["G"]["p"]["context"] = ContextLoss()
        losses["G"]["p"]["reconstruction"] = ReconstructionLoss()
        losses["G"]["p"]["featmatch"] = FeatMatchLoss()
    # "d" (depth) loss: cross-entropy when depth is treated as classification,
    # otherwise DADA's reverse Huber loss or SIGMLoss.
    if "d" in opts.tasks:
        if not opts.gen.d.classify.enable:
            if opts.gen.d.loss == "dada":
                depth_func = DADADepthLoss()
            else:
                # gml is passed as SIGMLoss's gmweight.
                # NOTE(review): SIGMLoss is built with its default
                # device="cuda"; the `device` argument is not forwarded. Confirm.
                depth_func = SIGMLoss(opts.train.lambdas.G.d.gml)
        else:
            depth_func = CrossEntropy()
        losses["G"]["tasks"]["d"] = depth_func
    # "s" losses (presumably segmentation).
    if "s" in opts.tasks:
        losses["G"]["tasks"]["s"] = {}
        losses["G"]["tasks"]["s"]["crossent"] = CrossEntropy()
        losses["G"]["tasks"]["s"]["minent"] = MinentLoss()
        losses["G"]["tasks"]["s"]["advent"] = ADVENTAdversarialLoss(
            opts, gan_type=opts.dis.s.gan_type
        )
    # "m" losses (presumably the flood mask, cf. GroundIntersectionLoss).
    if "m" in opts.tasks:
        losses["G"]["tasks"]["m"] = {}
        losses["G"]["tasks"]["m"]["bce"] = nn.BCEWithLogitsLoss()
        if opts.gen.m.use_minent_var:
            # Version 2 adds a variance term weighted by ent_var.
            losses["G"]["tasks"]["m"]["minent"] = MinentLoss(
                version=2, lambda_var=opts.train.lambdas.advent.ent_var
            )
        else:
            losses["G"]["tasks"]["m"]["minent"] = MinentLoss()
        losses["G"]["tasks"]["m"]["tv"] = TVLoss()
        losses["G"]["tasks"]["m"]["advent"] = ADVENTAdversarialLoss(
            opts, gan_type=opts.dis.m.gan_type
        )
        losses["G"]["tasks"]["m"]["gi"] = GroundIntersectionLoss()
    # Discriminator losses: the "p" discriminator shares the generator's GAN
    # loss object.
    if "p" in opts.tasks:
        losses["D"]["p"] = losses["G"]["p"]["gan"]
    if "m" in opts.tasks or "s" in opts.tasks:
        # NOTE(review): uses the default gan_type="GAN", unlike the generator
        # ADVENT losses above, which read opts.dis.{s,m}.gan_type. Confirm.
        losses["D"]["advent"] = ADVENTAdversarialLoss(opts)
    return losses
class GroundIntersectionLoss(nn.Module):
    """
    Penalize areas in ground seg but not in flood mask.

    Returns the fraction of elements where ``pseudo_ground - pred > 0.5``.
    """

    def __call__(self, pred, pseudo_ground):
        # NOTE(review): this is a hard threshold, so the result carries no
        # gradient w.r.t. `pred`; confirm whether it is meant to be optimized
        # or only monitored.
        uncovered = (pseudo_ground - pred) > 0.5
        return torch.mean(uncovered * 1.0)
def prob_2_entropy(prob):
    """
    Convert probabilistic prediction maps to weighted self-information maps.

    Each entry p becomes -p * log2(p) / log2(C), where C = prob.size(1).
    ``prob`` must be 4-D (it is unpacked into four sizes).
    """
    _, num_classes, _, _ = prob.size()
    # The 1e-30 keeps log2 finite where p == 0.
    self_info = -(prob * torch.log2(prob + 1e-30))
    return self_info / np.log2(num_classes)
class CustomBCELoss(nn.Module):
    """
    BCE-with-logits loss against a constant target label.

    The first argument is a tensor of logits and the second argument is a
    scalar label (e.g. an int domain label). There is no need to take the
    sigmoid before calling this function.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, prediction, target):
        """
        Args:
            prediction (torch.Tensor): logits of any shape.
            target (int | float): label every element is compared with.

        Returns:
            torch.Tensor: 0-D mean BCE-with-logits loss.
        """
        # Build the target directly on prediction's device/dtype. The previous
        # prediction.get_device() returns -1 for CPU tensors, which .to() rejects.
        target_tensor = torch.full_like(prediction, float(target))
        return self.loss(prediction, target_tensor)
class ADVENTAdversarialLoss(nn.Module):
    """
    Compute the ADVENT adversarial loss.

    It is used to indirectly shrink the domain gap between sim and real.
    __call__:
        prediction: torch.Tensor with shape [bs, c, h, w]
        target: int; domain label: 0 (sim) or 1 (real)
        discriminator: the discriminator model that tells whether a tensor
            comes from sim or real
        output: the loss value
    """

    def __init__(self, opts, gan_type="GAN"):
        """
        Args:
            opts: options object; ``opts.dis.m.architecture`` is read in
                ``__call__`` (and, for the OmniDiscriminator,
                ``opts.dis.p.get_intermediate_features``).
            gan_type (str, optional): "GAN" for BCE-with-logits, or one of
                "WGAN", "WGAN_gp", "WGAN_norm" for the WGAN-style loss.
                Defaults to "GAN".

        Raises:
            NotImplementedError: for any other ``gan_type``.
        """
        super().__init__()
        self.opts = opts
        if gan_type == "GAN":
            self.loss = CustomBCELoss()
        elif gan_type in ("WGAN", "WGAN_gp", "WGAN_norm"):
            self.loss = lambda x, y: -torch.mean(y * x + (1 - y) * (1 - x))
        else:
            raise NotImplementedError(f"Unsupported gan_type: {gan_type!r}")

    def __call__(self, prediction, target, discriminator, depth_preds=None):
        """
        Compute the GAN loss from the ADVENT discriminator given normalized
        (softmaxed) predictions (= pixel-wise class probabilities) and a
        domain label (target).

        Args:
            prediction (torch.Tensor): pixel-wise probability distribution
                over classes, shape [bs, c, h, w]
            target (int): domain label, 0 (sim) or 1 (real)
            discriminator (torch.nn.Module): discriminator used to get the loss
            depth_preds (torch.Tensor, optional): if given, multiplied
                element-wise with the entropy map before the discriminator.

        Returns:
            torch.Tensor: float 0-D loss
        """
        d_out = prob_2_entropy(prediction)
        if depth_preds is not None:
            d_out = d_out * depth_preds
        d_out = discriminator(d_out)
        # NOTE(review): the architecture is read from opts.dis.m even when
        # this loss serves another task (e.g. "s"); confirm this is intended.
        if self.opts.dis.m.architecture == "OmniDiscriminator":
            d_out = multiDiscriminatorAdapter(d_out, self.opts)
        loss_ = self.loss(d_out, target)
        return loss_
def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.tensor:
    """
    Unwrap the output of a single-scale OmniDiscriminator.

    The OmniDiscriminator returns a list (one entry per scale) rather than a
    tensor. Since there is no multilevel masker, only a single-entry list is
    accepted:
      * if ``opts.dis.p.get_intermediate_features`` is falsy, ``d_out[0][0]``
        is returned;
      * otherwise ``d_out[0]`` is returned as-is.

    Raises:
        Exception: if ``d_out`` is not a list of length 1.
    """
    if not (isinstance(d_out, list) and len(d_out) == 1):
        raise Exception(
            "Check the setting of OmniDiscriminator! "
            + "For now, we don't support multi-scale OmniDiscriminator."
        )
    only_scale = d_out[0]
    if opts.dis.p.get_intermediate_features:
        return only_scale
    return only_scale[0]
class HingeLoss(nn.Module):
    """
    Hinge adversarial loss for the painter.

    Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py
    """

    def __init__(self, tensor=torch.FloatTensor):
        """
        Args:
            tensor (type, optional): tensor constructor used for the cached
                zero tensor. Defaults to torch.FloatTensor.
        """
        super().__init__()
        self.zero_tensor = None
        self.Tensor = tensor

    def get_zero_tensor(self, input):
        """Return a cached zero tensor, moved to input's device and expanded to its shape."""
        if self.zero_tensor is None:
            zeros = self.Tensor(1).fill_(0)
            zeros.requires_grad_(False)
            self.zero_tensor = zeros
        self.zero_tensor = self.zero_tensor.to(input.device)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """
        Discriminator: -mean(min(x - 1, 0)) for real, -mean(min(-x - 1, 0))
        for fake. Generator: -mean(x), and the target must be real.
        """
        if not for_discriminator:
            assert target_is_real, "The generator's hinge loss must be aiming for real"
            return -torch.mean(input)
        margin = input - 1 if target_is_real else -input - 1
        return -torch.mean(torch.min(margin, self.get_zero_tensor(input)))

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Loss for one tensor, or the average over a list of outputs.

        In a list, an element that is itself a list contributes only its
        last entry.
        """
        if not isinstance(input, list):
            return self.loss(input, target_is_real, for_discriminator)
        loss = 0
        for pred_i in input:
            last = pred_i[-1] if isinstance(pred_i, list) else pred_i
            loss += self.loss(last, target_is_real, for_discriminator)
        return loss / len(input)
class DADADepthLoss(nn.Module):
    """Defines the reverse Huber loss from the DADA paper for depth prediction.

    With c = 0.2 * max|pred - label|:
    - residuals <= c are penalized by the l1 term |d|
    - residuals > c are penalized by the l2-like term (d**2 + c**2) / (2c)
    From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py
    """

    def loss_calc_depth(self, pred, label):
        """
        Args:
            pred (torch.Tensor): 4-D prediction with a single channel (n, 1, h, w).
            label (torch.Tensor): ground truth; assumed to squeeze to the same
                shape as ``pred`` (TODO confirm with caller).

        Returns:
            torch.Tensor: 0-D mean penalty over all elements of ``pred``.
        """
        n, c, h, w = pred.size()
        assert c == 1
        pred = pred.squeeze()
        label = label.squeeze()
        adiff = torch.abs(pred - label)
        # Regime threshold; .item() makes it a plain float (no gradient).
        batch_max = 0.2 * torch.max(adiff).item()
        if batch_max == 0:
            # All residuals are zero: every element is in the l1 regime, and
            # the l2 term below would compute 0 / 0 (NaN).
            return torch.sum(adiff) / torch.numel(pred.data)
        t1_mask = adiff.le(batch_max).float()
        t2_mask = adiff.gt(batch_max).float()
        t1 = adiff * t1_mask
        t2 = (adiff * adiff + batch_max * batch_max) / (2 * batch_max)
        t2 = t2 * t2_mask
        return (torch.sum(t1) + torch.sum(t2)) / torch.numel(pred.data)

    def __call__(self, pred, label):
        return self.loss_calc_depth(pred, label)