|
"""Define all losses. When possible, as inheriting from nn.Module |
|
To send predictions to target.device |
|
""" |
|
from random import random as rand |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import models |
|
|
|
|
|
class GANLoss(nn.Module): |
|
def __init__( |
|
self, |
|
use_lsgan=True, |
|
target_real_label=1.0, |
|
target_fake_label=0.0, |
|
soft_shift=0.0, |
|
flip_prob=0.0, |
|
verbose=0, |
|
): |
|
"""Defines the GAN loss which uses either LSGAN or the regular GAN. |
|
When LSGAN is used, it is basically same as MSELoss, |
|
but it abstracts away the need to create the target label tensor |
|
that has the same size as the input + |
|
|
|
* label smoothing: target_real_label=0.75 |
|
* label flipping: flip_prob > 0. |
|
|
|
source: https://github.com/sangwoomo/instagan/blob |
|
/b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py |
|
|
|
Args: |
|
use_lsgan (bool, optional): Use MSE or BCE. Defaults to True. |
|
target_real_label (float, optional): Value for the real target. |
|
Defaults to 1.0. |
|
target_fake_label (float, optional): Value for the fake target. |
|
Defaults to 0.0. |
|
flip_prob (float, optional): Probability of flipping the label |
|
(use for real target in Discriminator only). Defaults to 0.0. |
|
""" |
|
super().__init__() |
|
|
|
self.soft_shift = soft_shift |
|
self.verbose = verbose |
|
|
|
self.register_buffer("real_label", torch.tensor(target_real_label)) |
|
self.register_buffer("fake_label", torch.tensor(target_fake_label)) |
|
if use_lsgan: |
|
self.loss = nn.MSELoss() |
|
else: |
|
self.loss = nn.BCEWithLogitsLoss() |
|
self.flip_prob = flip_prob |
|
|
|
def get_target_tensor(self, input, target_is_real): |
|
        # sample a random label-softening offset in [0, soft_shift)
        soft_change = torch.FloatTensor(1).uniform_(0, self.soft_shift)
|
if self.verbose > 0: |
|
print("GANLoss sampled soft_change:", soft_change.item()) |
|
if target_is_real: |
|
target_tensor = self.real_label - soft_change |
|
else: |
|
target_tensor = self.fake_label + soft_change |
|
return target_tensor.expand_as(input) |
|
|
|
    def __call__(self, input, target_is_real, *args, **kwargs):
        # flip the label at most once per call, not once per scale:
        # flipping inside the loop would toggle the label back and forth
        # across multi-scale predictions
        if rand() < self.flip_prob:
            target_is_real = not target_is_real
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                target_tensor = self.get_target_tensor(pred_i, target_is_real)
                loss_tensor = self.loss(pred_i, target_tensor.to(pred_i.device))
                loss += loss_tensor
            return loss / len(input)
        else:
            target_tensor = self.get_target_tensor(input, target_is_real)
            return self.loss(input, target_tensor.to(input.device))
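
# Illustrative usage sketch; `real_logits` and `fake_logits` stand for
# hypothetical discriminator outputs, not defined in this module:
#
#   gan_loss = GANLoss(use_lsgan=True, soft_shift=0.2, flip_prob=0.05)
#   loss_D_real = gan_loss(real_logits, target_is_real=True)
#   loss_D_fake = gan_loss(fake_logits, target_is_real=False)
#   loss_D = 0.5 * (loss_D_real + loss_D_fake)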
|
|
|
|
|
class FeatMatchLoss(nn.Module):
    """Feature-matching loss: L1 distance between the discriminator's
    intermediate activations on real vs. fake inputs, averaged over
    discriminators (as in pix2pixHD).
    """

    def __init__(self):
        super().__init__()
        self.criterionFeat = nn.L1Loss()
|
|
|
    def __call__(self, pred_real, pred_fake):
        num_D = len(pred_fake)
|
GAN_Feat_loss = 0.0 |
|
        for i in range(num_D):
            # the last element of each sublist is the final prediction,
            # not an intermediate feature map
            num_intermediate_outputs = len(pred_fake[i]) - 1
|
for j in range(num_intermediate_outputs): |
|
unweighted_loss = self.criterionFeat( |
|
pred_fake[i][j], pred_real[i][j].detach() |
|
) |
|
GAN_Feat_loss += unweighted_loss / num_D |
|
return GAN_Feat_loss |
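
# Illustrative sketch of the nested structure FeatMatchLoss expects
# (hypothetical names): each discriminator contributes a list of
# intermediate features followed by its final prediction.
#
#   pred_real = [[feat_a, feat_b, final_pred]]      # one discriminator
#   pred_fake = [[feat_a_f, feat_b_f, final_pred_f]]
#   fm = FeatMatchLoss()(pred_real, pred_fake)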
|
|
|
|
|
class CrossEntropy(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.loss = nn.CrossEntropyLoss() |
|
|
|
def __call__(self, logits, target): |
|
return self.loss(logits, target.to(logits.device).long()) |
|
|
|
|
|
class TravelLoss(nn.Module):
    """Transformation-vector (TraVeL) loss: keeps the pairwise difference
    vectors of real encodings cosine-aligned with those of the
    corresponding fake encodings.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def cosine_loss(self, real, fake):
        norm_real = torch.norm(real, p=2, dim=1)[:, None]
        norm_fake = torch.norm(fake, p=2, dim=1)[:, None]
        mat_real = real / norm_real
        mat_fake = fake / norm_fake
        # clamp components to at least eps to avoid degenerate values
        mat_real = torch.max(mat_real, self.eps * torch.ones_like(mat_real))
        mat_fake = torch.max(mat_fake, self.eps * torch.ones_like(mat_fake))
        # row-wise dot products between corresponding unit vectors
        return torch.einsum("ij, ij -> i", mat_fake, mat_real).sum()
|
|
|
def __call__(self, S_real, S_fake): |
|
self.v_real = [] |
|
self.v_fake = [] |
|
for i in range(len(S_real)): |
|
for j in range(i): |
|
self.v_real.append((S_real[i] - S_real[j])[None, :]) |
|
self.v_fake.append((S_fake[i] - S_fake[j])[None, :]) |
|
self.v_real_t = torch.cat(self.v_real, dim=0) |
|
self.v_fake_t = torch.cat(self.v_fake, dim=0) |
|
return self.cosine_loss(self.v_real_t, self.v_fake_t) |
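
# Illustrative sketch (hypothetical encodings of shape [batch, features];
# the loss compares all pairwise difference vectors):
#
#   S_real = torch.randn(4, 128)
#   S_fake = torch.randn(4, 128)
#   t_loss = TravelLoss()(S_real, S_fake)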
|
|
|
|
|
class TVLoss(nn.Module): |
|
"""Total Variational Regularization: Penalizes differences in |
|
neighboring pixel values |
|
|
|
source: |
|
https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py |
|
""" |
|
|
|
def __init__(self, tvloss_weight=1): |
|
""" |
|
Args: |
|
TVLoss_weight (int, optional): [lambda i.e. weight for loss]. Defaults to 1. |
|
""" |
|
super(TVLoss, self).__init__() |
|
self.tvloss_weight = tvloss_weight |
|
|
|
def forward(self, x): |
|
batch_size = x.size()[0] |
|
h_x = x.size()[2] |
|
w_x = x.size()[3] |
|
count_h = self._tensor_size(x[:, :, 1:, :]) |
|
count_w = self._tensor_size(x[:, :, :, 1:]) |
|
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, : h_x - 1, :]), 2).sum() |
|
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, : w_x - 1]), 2).sum() |
|
return self.tvloss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size |
|
|
|
def _tensor_size(self, t): |
|
return t.size()[1] * t.size()[2] * t.size()[3] |
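
# Illustrative sketch: total variation of a random NCHW batch; a constant
# image has zero variation:
#
#   x = torch.rand(2, 3, 64, 64)
#   tv = TVLoss(tvloss_weight=1)(x)
#   assert TVLoss()(torch.ones(1, 3, 8, 8)) == 0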
|
|
|
|
|
class MinentLoss(nn.Module): |
|
""" |
|
Loss for the minimization of the entropy map |
|
Source for version 1: https://github.com/valeoai/ADVENT |
|
|
|
Version 2 adds the variance of the entropy map in the computation of the loss |
|
""" |
|
|
|
def __init__(self, version=1, lambda_var=0.1): |
|
super().__init__() |
|
self.version = version |
|
self.lambda_var = lambda_var |
|
|
|
    def __call__(self, pred):
        # pred: softmaxed class probabilities of shape [n, c, h, w]
        assert pred.dim() == 4
        n, c, h, w = pred.size()
        # pixel-wise entropy, normalized by the maximum entropy log2(c)
        entropy_map = -torch.mul(pred, torch.log2(pred + 1e-30)) / np.log2(c)
|
if self.version == 1: |
|
return torch.sum(entropy_map) / (n * h * w) |
|
else: |
|
entropy_map_demean = entropy_map - torch.sum(entropy_map) / (n * h * w) |
|
entropy_map_squ = torch.mul(entropy_map_demean, entropy_map_demean) |
|
return torch.sum(entropy_map + self.lambda_var * entropy_map_squ) / ( |
|
n * h * w |
|
) |
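
# Illustrative sketch (the loss expects softmaxed class probabilities,
# here from hypothetical segmentation logits):
#
#   probs = torch.softmax(torch.randn(2, 11, 32, 32), dim=1)
#   ent = MinentLoss(version=2, lambda_var=0.1)(probs)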
|
|
|
|
|
class MSELoss(nn.Module): |
|
""" |
|
Creates a criterion that measures the mean squared error |
|
    (squared L2 norm) between each element in the input x and target y.
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.loss = nn.MSELoss() |
|
|
|
def __call__(self, prediction, target): |
|
return self.loss(prediction, target.to(prediction.device)) |
|
|
|
|
|
class L1Loss(MSELoss): |
|
""" |
|
Creates a criterion that measures the mean absolute error |
|
    (MAE) between each element in the input x and target y.
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.loss = nn.L1Loss() |
|
|
|
|
|
class SIMSELoss(nn.Module): |
|
"""Scale invariant MSE Loss""" |
|
|
|
def __init__(self): |
|
super(SIMSELoss, self).__init__() |
|
|
|
    def __call__(self, prediction, target):
        d = prediction - target
        # mean squared difference minus the squared mean difference:
        # invariant to a constant shift between prediction and target
        diff = torch.mean(d * d)
        relDiff = torch.mean(d) * torch.mean(d)
        return diff - relDiff
|
|
|
|
|
class SIGMLoss(nn.Module): |
|
"""loss from MiDaS paper |
|
MiDaS did not specify how the gradients were computed but we use Sobel |
|
filters which approximate the derivative of an image. |
|
""" |
|
|
|
def __init__(self, gmweight=0.5, scale=4, device="cuda"): |
|
super(SIGMLoss, self).__init__() |
|
self.gmweight = gmweight |
|
self.sobelx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(device) |
|
self.sobely = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(device) |
|
self.scale = scale |
|
|
|
    def __call__(self, prediction, target):
        # align scale and shift: center each map with its median and
        # scale with its mean absolute deviation
        t_pred = torch.median(prediction)
        t_targ = torch.median(target)
        s_pred = torch.mean(torch.abs(prediction - t_pred))
        s_targ = torch.mean(torch.abs(target - t_targ))
        pred = (prediction - t_pred) / s_pred
        targ = (target - t_targ) / s_targ

        # residual between the aligned maps
        R = pred - targ

        # multi-scale gradient matching with Sobel filters
        batch_size = prediction.size()[0]
        num_pix = prediction.size()[-1] * prediction.size()[-2]
        sobelx = (self.sobelx).expand((batch_size, 1, -1, -1))
        sobely = (self.sobely).expand((batch_size, 1, -1, -1))
        gmLoss = 0
        for k in range(self.scale):
            R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
            Rx = F.conv2d(R_, sobelx, stride=1)
            Ry = F.conv2d(R_, sobely, stride=1)
            gmLoss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
        gmLoss = self.gmweight / num_pix * gmLoss

        # scale-and-shift-invariant term on the residual
        simseLoss = 0.5 / num_pix * torch.sum(torch.abs(R))
        loss = simseLoss + gmLoss
        return loss
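
# Illustrative sketch (hypothetical [n, 1, h, w] depth maps; device set
# to "cpu" so the Sobel buffers and the inputs match):
#
#   criterion = SIGMLoss(gmweight=0.5, scale=4, device="cpu")
#   loss = criterion(torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64))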
|
|
|
|
|
class ContextLoss(nn.Module): |
|
""" |
|
Masked L1 loss on non-water |
|
""" |
|
|
|
def __call__(self, input, target, mask): |
|
return torch.mean(torch.abs(torch.mul((input - target), 1 - mask))) |
|
|
|
|
|
class ReconstructionLoss(nn.Module): |
|
""" |
|
Masked L1 loss on water |
|
""" |
|
|
|
def __call__(self, input, target, mask): |
|
return torch.mean(torch.abs(torch.mul((input - target), mask))) |
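
# Illustrative sketch (hypothetical flood mask, 1 on water pixels,
# broadcast over channels):
#
#   x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
#   mask = (torch.rand(1, 1, 8, 8) > 0.5).float()
#   keep_context = ContextLoss()(x, y, mask)        # error on non-water
#   reconstruct = ReconstructionLoss()(x, y, mask)  # error on water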
|
|
|
|
class Vgg19(nn.Module):
    """Pretrained VGG19 split into five slices returning the activations
    after relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 (as in
    pix2pixHD), used by the perceptual VGGLoss below.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
|
self.slice1 = nn.Sequential() |
|
self.slice2 = nn.Sequential() |
|
self.slice3 = nn.Sequential() |
|
self.slice4 = nn.Sequential() |
|
self.slice5 = nn.Sequential() |
|
for x in range(2): |
|
self.slice1.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(2, 7): |
|
self.slice2.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(7, 12): |
|
self.slice3.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(12, 21): |
|
self.slice4.add_module(str(x), vgg_pretrained_features[x]) |
|
for x in range(21, 30): |
|
self.slice5.add_module(str(x), vgg_pretrained_features[x]) |
|
if not requires_grad: |
|
for param in self.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward(self, X): |
|
h_relu1 = self.slice1(X) |
|
h_relu2 = self.slice2(h_relu1) |
|
h_relu3 = self.slice3(h_relu2) |
|
h_relu4 = self.slice4(h_relu3) |
|
h_relu5 = self.slice5(h_relu4) |
|
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] |
|
return out |
|
|
|
|
|
|
|
class VGGLoss(nn.Module): |
|
def __init__(self, device): |
|
super().__init__() |
|
self.vgg = Vgg19().to(device).eval() |
|
self.criterion = nn.L1Loss() |
|
self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] |
|
|
|
def forward(self, x, y): |
|
x_vgg, y_vgg = self.vgg(x), self.vgg(y) |
|
loss = 0 |
|
for i in range(len(x_vgg)): |
|
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) |
|
return loss |
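
# Illustrative sketch (downloads the pretrained VGG19 weights on first
# use; inputs are [n, 3, h, w] image batches):
#
#   perceptual = VGGLoss(device="cpu")
#   loss = perceptual(torch.rand(1, 3, 224, 224), torch.rand(1, 3, 224, 224))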
|
|
|
|
|
def get_losses(opts, verbose, device=None): |
|
"""Sets the loss functions to be used by G, D and C, as specified |
|
in the opts and returns a dictionnary of losses: |
|
|
|
losses = { |
|
"G": { |
|
"gan": {"a": ..., "t": ...}, |
|
"cycle": {"a": ..., "t": ...} |
|
"auto": {"a": ..., "t": ...} |
|
"tasks": {"h": ..., "d": ..., "s": ..., etc.} |
|
}, |
|
"D": GANLoss, |
|
"C": ... |
|
} |
|
""" |
|
|
|
losses = { |
|
"G": {"a": {}, "p": {}, "tasks": {}}, |
|
"D": {"default": {}, "advent": {}}, |
|
"C": {}, |
|
} |
|
|
|
|
|
|
if "p" in opts.tasks: |
|
losses["G"]["p"]["gan"] = ( |
|
HingeLoss() |
|
if opts.gen.p.loss == "hinge" |
|
else GANLoss( |
|
use_lsgan=False, |
|
soft_shift=opts.dis.soft_shift, |
|
flip_prob=opts.dis.flip_prob, |
|
) |
|
) |
|
losses["G"]["p"]["dm"] = MSELoss() |
|
losses["G"]["p"]["vgg"] = VGGLoss(device) |
|
losses["G"]["p"]["tv"] = TVLoss() |
|
losses["G"]["p"]["context"] = ContextLoss() |
|
losses["G"]["p"]["reconstruction"] = ReconstructionLoss() |
|
losses["G"]["p"]["featmatch"] = FeatMatchLoss() |
|
|
|
|
|
if "d" in opts.tasks: |
|
if not opts.gen.d.classify.enable: |
|
if opts.gen.d.loss == "dada": |
|
depth_func = DADADepthLoss() |
|
else: |
|
depth_func = SIGMLoss(opts.train.lambdas.G.d.gml) |
|
else: |
|
depth_func = CrossEntropy() |
|
|
|
losses["G"]["tasks"]["d"] = depth_func |
|
|
|
|
|
if "s" in opts.tasks: |
|
losses["G"]["tasks"]["s"] = {} |
|
losses["G"]["tasks"]["s"]["crossent"] = CrossEntropy() |
|
losses["G"]["tasks"]["s"]["minent"] = MinentLoss() |
|
losses["G"]["tasks"]["s"]["advent"] = ADVENTAdversarialLoss( |
|
opts, gan_type=opts.dis.s.gan_type |
|
) |
|
|
|
|
|
if "m" in opts.tasks: |
|
losses["G"]["tasks"]["m"] = {} |
|
losses["G"]["tasks"]["m"]["bce"] = nn.BCEWithLogitsLoss() |
|
if opts.gen.m.use_minent_var: |
|
losses["G"]["tasks"]["m"]["minent"] = MinentLoss( |
|
version=2, lambda_var=opts.train.lambdas.advent.ent_var |
|
) |
|
else: |
|
losses["G"]["tasks"]["m"]["minent"] = MinentLoss() |
|
losses["G"]["tasks"]["m"]["tv"] = TVLoss() |
|
losses["G"]["tasks"]["m"]["advent"] = ADVENTAdversarialLoss( |
|
opts, gan_type=opts.dis.m.gan_type |
|
) |
|
losses["G"]["tasks"]["m"]["gi"] = GroundIntersectionLoss() |
|
|
|
|
if "p" in opts.tasks: |
|
losses["D"]["p"] = losses["G"]["p"]["gan"] |
|
if "m" in opts.tasks or "s" in opts.tasks: |
|
losses["D"]["advent"] = ADVENTAdversarialLoss(opts) |
|
return losses |
|
|
|
|
|
class GroundIntersectionLoss(nn.Module):
    """
    Penalize areas that are in the pseudo ground segmentation but not
    in the predicted flood mask
    """

    def __call__(self, pred, pseudo_ground):
        # hard threshold: fraction of pixels where pseudo_ground exceeds
        # pred by more than 0.5 (not differentiable w.r.t. pred)
        return torch.mean(1.0 * ((pseudo_ground - pred) > 0.5))
|
|
|
|
|
def prob_2_entropy(prob): |
|
""" |
|
convert probabilistic prediction maps to weighted self-information maps |
|
""" |
|
n, c, h, w = prob.size() |
|
return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c) |
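
# Illustrative sketch: for a uniform distribution over c = 4 classes,
# each element is -0.25 * log2(0.25) / log2(4) = 0.25, so the per-pixel
# sum over channels is 1 (maximal normalized entropy):
#
#   probs = torch.full((1, 4, 2, 2), 0.25)
#   ent = prob_2_entropy(probs)  # all elements equal 0.25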
|
|
|
|
|
class CustomBCELoss(nn.Module): |
|
""" |
|
The first argument is a tensor and the second argument is an int. |
|
There is no need to take sigmoid before calling this function. |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
self.loss = nn.BCEWithLogitsLoss() |
|
|
|
    def __call__(self, prediction, target):
        return self.loss(
            prediction,
            torch.FloatTensor(prediction.size())
            .fill_(target)
            .to(prediction.device),
        )
|
|
|
|
|
class ADVENTAdversarialLoss(nn.Module): |
|
""" |
|
The class is for calculating the advent loss. |
|
It is used to indirectly shrink the domain gap between sim and real |
|
|
|
_call_ function: |
|
prediction: torch.tensor with shape of [bs,c,h,w] |
|
target: int; domain label: 0 (sim) or 1 (real) |
|
discriminator: the discriminator model tells if a tensor is from sim or real |
|
|
|
output: the loss value of GANLoss |
|
""" |
|
|
|
def __init__(self, opts, gan_type="GAN"): |
|
super().__init__() |
|
self.opts = opts |
|
        if gan_type == "GAN":
            self.loss = CustomBCELoss()
        elif gan_type in ("WGAN", "WGAN_gp", "WGAN_norm"):
            self.loss = lambda x, y: -torch.mean(y * x + (1 - y) * (1 - x))
        else:
            raise NotImplementedError
|
|
|
def __call__(self, prediction, target, discriminator, depth_preds=None): |
|
""" |
|
Compute the GAN loss from the Advent Discriminator given |
|
normalized (softmaxed) predictions (=pixel-wise class probabilities), |
|
and int labels (target). |
|
|
|
Args: |
|
            prediction (torch.Tensor): pixel-wise probability distribution
                over classes
            target (int): domain label, 0 (sim) or 1 (real)
            discriminator (torch.nn.Module): Discriminator to get the loss
|
|
|
Returns: |
|
torch.Tensor: float 0-D loss |
|
""" |
|
d_out = prob_2_entropy(prediction) |
|
if depth_preds is not None: |
|
d_out = d_out * depth_preds |
|
d_out = discriminator(d_out) |
|
if self.opts.dis.m.architecture == "OmniDiscriminator": |
|
d_out = multiDiscriminatorAdapter(d_out, self.opts) |
|
loss_ = self.loss(d_out, target) |
|
return loss_ |
|
|
|
|
|
def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.Tensor:
    """
    The OmniDiscriminator returns a list of tensors rather than a single
    tensor. Since there is no multilevel masker, the 0th tensor in the
    list is all we want: this adapter returns the first element (tensor)
    of the list the OmniDiscriminator returns.
    """
    if isinstance(d_out, list) and len(d_out) == 1:
        if not opts.dis.p.get_intermediate_features:
            d_out = d_out[0][0]
        else:
            d_out = d_out[0]
    else:
        raise Exception(
            "Check the setting of OmniDiscriminator! "
            + "For now, we don't support multi-scale OmniDiscriminator."
        )
    return d_out
|
|
|
|
|
class HingeLoss(nn.Module): |
|
""" |
|
Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py |
|
for the painter |
|
""" |
|
|
|
def __init__(self, tensor=torch.FloatTensor): |
|
super().__init__() |
|
self.zero_tensor = None |
|
self.Tensor = tensor |
|
|
|
def get_zero_tensor(self, input): |
|
if self.zero_tensor is None: |
|
self.zero_tensor = self.Tensor(1).fill_(0) |
|
self.zero_tensor.requires_grad_(False) |
|
self.zero_tensor = self.zero_tensor.to(input.device) |
|
return self.zero_tensor.expand_as(input) |
|
|
|
def loss(self, input, target_is_real, for_discriminator=True): |
|
if for_discriminator: |
|
if target_is_real: |
|
minval = torch.min(input - 1, self.get_zero_tensor(input)) |
|
loss = -torch.mean(minval) |
|
else: |
|
minval = torch.min(-input - 1, self.get_zero_tensor(input)) |
|
loss = -torch.mean(minval) |
|
else: |
|
assert target_is_real, "The generator's hinge loss must be aiming for real" |
|
loss = -torch.mean(input) |
|
return loss |
|
|
|
    def __call__(self, input, target_is_real, for_discriminator=True):
        # multi-scale discriminators return a list of predictions
        if isinstance(input, list):
|
loss = 0 |
|
for pred_i in input: |
|
if isinstance(pred_i, list): |
|
pred_i = pred_i[-1] |
|
loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) |
|
loss += loss_tensor |
|
return loss / len(input) |
|
else: |
|
return self.loss(input, target_is_real, for_discriminator) |
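
# Illustrative sketch (hypothetical discriminator logits):
#
#   hinge = HingeLoss()
#   d_loss = hinge(real_logits, True) + hinge(fake_logits, False)
#   g_loss = hinge(fake_logits, True, for_discriminator=False)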
|
|
|
|
|
class DADADepthLoss: |
|
"""Defines the reverse Huber loss from DADA paper for depth prediction |
|
- Samples with larger residuals are penalized more by l2 term |
|
- Samples with smaller residuals are penalized more by l1 term |
|
From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py |
|
""" |
|
|
|
def loss_calc_depth(self, pred, label): |
|
n, c, h, w = pred.size() |
|
assert c == 1 |
|
|
|
pred = pred.squeeze() |
|
label = label.squeeze() |
|
|
|
        adiff = torch.abs(pred - label)
        # reverse Huber (berHu) threshold: 20% of the batch's max residual
        batch_max = 0.2 * torch.max(adiff).item()
        t1_mask = adiff.le(batch_max).float()
        t2_mask = adiff.gt(batch_max).float()
|
t1 = adiff * t1_mask |
|
t2 = (adiff * adiff + batch_max * batch_max) / (2 * batch_max) |
|
t2 = t2 * t2_mask |
|
        return (torch.sum(t1) + torch.sum(t2)) / torch.numel(pred)
|
|
|
def __call__(self, pred, label): |
|
return self.loss_calc_depth(pred, label) |
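
# Illustrative sketch (hypothetical [n, 1, h, w] depth prediction and
# label; residuals above the berHu threshold are penalized
# quadratically, smaller ones linearly):
#
#   criterion = DADADepthLoss()
#   loss = criterion(torch.rand(2, 1, 16, 16), torch.rand(2, 1, 16, 16))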
|
|