""" Loss function implementations. """ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from kornia.geometry import warp_perspective from ..misc.geometry_utils import (keypoints_to_grid, get_dist_mask, get_common_line_mask) def get_loss_and_weights(model_cfg, device=torch.device("cuda")): """ Get loss functions and either static or dynamic weighting. """ # Get the global weighting policy w_policy = model_cfg.get("weighting_policy", "static") if not w_policy in ["static", "dynamic"]: raise ValueError("[Error] Not supported weighting policy.") loss_func = {} loss_weight = {} # Get junction loss function and weight w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy) loss_func["junc_loss"] = junc_loss_func.to(device) loss_weight["w_junc"] = w_junc # Get heatmap loss function and weight w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight( model_cfg, w_policy, device) loss_func["heatmap_loss"] = heatmap_loss_func.to(device) loss_weight["w_heatmap"] = w_heatmap # [Optionally] get descriptor loss function and weight if model_cfg.get("descriptor_loss_func", None) is not None: w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight( model_cfg, w_policy) loss_func["descriptor_loss"] = descriptor_loss_func.to(device) loss_weight["w_desc"] = w_descriptor return loss_func, loss_weight def get_junction_loss_and_weight(model_cfg, global_w_policy): """ Get the junction loss function and weight. """ junction_loss_cfg = model_cfg.get("junction_loss_cfg", {}) # Get the junction loss weight w_policy = junction_loss_cfg.get("policy", global_w_policy) if w_policy == "static": w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32) elif w_policy == "dynamic": w_junc = nn.Parameter( torch.tensor(model_cfg["w_junc"], dtype=torch.float32), requires_grad=True) else: raise ValueError( "[Error] Unknown weighting policy for junction loss weight.") # Get the junction loss function junc_loss_name = model_cfg.get("junction_loss_func", "superpoint") if junc_loss_name == "superpoint": junc_loss_func = JunctionDetectionLoss(model_cfg["grid_size"], model_cfg["keep_border_valid"]) else: raise ValueError("[Error] Not supported junction loss function.") return w_junc, junc_loss_func def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device): """ Get the heatmap loss function and weight. """ heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {}) # Get the heatmap loss weight w_policy = heatmap_loss_cfg.get("policy", global_w_policy) if w_policy == "static": w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32) elif w_policy == "dynamic": w_heatmap = nn.Parameter( torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32), requires_grad=True) else: raise ValueError( "[Error] Unknown weighting policy for junction loss weight.") # Get the corresponding heatmap loss based on the config heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy") if heatmap_loss_name == "cross_entropy": # Get the heatmap class weight (always static) heatmap_class_w = model_cfg.get("w_heatmap_class", 1.) class_weight = torch.tensor( np.array([1., heatmap_class_w])).to(torch.float).to(device) heatmap_loss_func = HeatmapLoss(class_weight=class_weight) else: raise ValueError("[Error] Not supported heatmap loss function.") return w_heatmap, heatmap_loss_func def get_descriptor_loss_and_weight(model_cfg, global_w_policy): """ Get the descriptor loss function and weight. 
""" descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {}) # Get the descriptor loss weight w_policy = descriptor_loss_cfg.get("policy", global_w_policy) if w_policy == "static": w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32) elif w_policy == "dynamic": w_descriptor = nn.Parameter(torch.tensor(model_cfg["w_desc"], dtype=torch.float32), requires_grad=True) else: raise ValueError( "[Error] Unknown weighting policy for descriptor loss weight.") # Get the descriptor loss function descriptor_loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling") if descriptor_loss_name == "regular_sampling": descriptor_loss_func = TripletDescriptorLoss( descriptor_loss_cfg["grid_size"], descriptor_loss_cfg["dist_threshold"], descriptor_loss_cfg["margin"]) else: raise ValueError("[Error] Not supported descriptor loss function.") return w_descriptor, descriptor_loss_func def space_to_depth(input_tensor, grid_size): """ PixelUnshuffle for pytorch. """ N, C, H, W = input_tensor.size() # (N, C, H//bs, bs, W//bs, bs) x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size) # (N, bs, bs, C, H//bs, W//bs) x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, C*bs^2, H//bs, W//bs) x = x.view(N, C * (grid_size ** 2), H // grid_size, W // grid_size) return x def junction_detection_loss(junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True): """ Junction detection loss. """ # Convert junc_map to channel tensor junc_map = space_to_depth(junction_map, grid_size) map_shape = junc_map.shape[-2:] batch_size = junc_map.shape[0] dust_bin_label = torch.ones( [batch_size, 1, map_shape[0], map_shape[1]]).to(junc_map.device).to(torch.int) junc_map = torch.cat([junc_map*2, dust_bin_label], dim=1) labels = torch.argmax( junc_map.to(torch.float) + torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device), dim=1) # Also convert the valid mask to channel tensor valid_mask = (torch.ones(junction_map.shape) if valid_mask is None else valid_mask) valid_mask = space_to_depth(valid_mask, grid_size) # Compute junction loss on the border patch or not if keep_border: valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) > 0 else: valid_mask = torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) >= grid_size * grid_size # Compute the classification loss loss_func = nn.CrossEntropyLoss(reduction="none") # The loss still need NCHW format loss = loss_func(input=junc_predictions, target=labels.to(torch.long)) # Weighted sum by the valid mask loss_ = torch.sum(loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[0, 1, 2]) loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), dim=1)) return loss_final def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None): """ Heatmap prediction loss. 
""" # Compute the classification loss on each pixel if class_weight is None: loss_func = nn.CrossEntropyLoss(reduction="none") else: loss_func = nn.CrossEntropyLoss(class_weight, reduction="none") loss = loss_func(input=heatmap_pred, target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)) # Weighted sum by the valid mask # Sum over H and W loss_spatial_sum = torch.sum(loss * torch.squeeze( valid_mask.to(torch.float), dim=1), dim=[1, 2]) valid_spatial_sum = torch.sum(torch.squeeze(valid_mask.to(torch.float32), dim=1), dim=[1, 2]) # Mean to single scalar over batch dimension loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum) return loss class JunctionDetectionLoss(nn.Module): """ Junction detection loss. """ def __init__(self, grid_size, keep_border): super(JunctionDetectionLoss, self).__init__() self.grid_size = grid_size self.keep_border = keep_border def forward(self, prediction, target, valid_mask=None): return junction_detection_loss(target, prediction, valid_mask, self.grid_size, self.keep_border) class HeatmapLoss(nn.Module): """ Heatmap prediction loss. """ def __init__(self, class_weight): super(HeatmapLoss, self).__init__() self.class_weight = class_weight def forward(self, prediction, target, valid_mask=None): return heatmap_loss(target, prediction, valid_mask, self.class_weight) class RegularizationLoss(nn.Module): """ Module for regularization loss. """ def __init__(self): super(RegularizationLoss, self).__init__() self.name = "regularization_loss" self.loss_init = torch.zeros([]) def forward(self, loss_weights): # Place it to the same device loss = self.loss_init.to(loss_weights["w_junc"].device) for _, val in loss_weights.items(): if isinstance(val, nn.Parameter): loss += val return loss def triplet_loss(desc_pred1, desc_pred2, points1, points2, line_indices, epoch, grid_size=8, dist_threshold=8, init_dist_threshold=64, margin=1): """ Regular triplet loss for descriptor learning. 
""" b_size, _, Hc, Wc = desc_pred1.size() img_size = (Hc * grid_size, Wc * grid_size) device = desc_pred1.device # Extract valid keypoints n_points = line_indices.size()[1] valid_points = line_indices.bool().flatten() n_correct_points = torch.sum(valid_points).item() if n_correct_points == 0: return torch.tensor(0., dtype=torch.float, device=device) # Check which keypoints are too close to be matched # dist_threshold is decreased at each epoch for easier training dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1)) dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold) # Additionally ban negative mining along the same line common_line_mask = get_common_line_mask(line_indices, valid_points) dist_mask = dist_mask | common_line_mask # Convert the keypoints to a grid suitable for interpolation grid1 = keypoints_to_grid(points1, img_size) grid2 = keypoints_to_grid(points2, img_size) # Extract the descriptors desc1 = F.grid_sample(desc_pred1, grid1).permute( 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] desc1 = F.normalize(desc1, dim=1) desc2 = F.grid_sample(desc_pred2, grid2).permute( 0, 2, 3, 1).reshape(b_size * n_points, -1)[valid_points] desc2 = F.normalize(desc2, dim=1) desc_dists = 2 - 2 * (desc1 @ desc2.t()) # Positive distance loss pos_dist = torch.diag(desc_dists) # Negative distance loss max_dist = torch.tensor(4., dtype=torch.float, device=device) desc_dists[ torch.arange(n_correct_points, dtype=torch.long), torch.arange(n_correct_points, dtype=torch.long)] = max_dist desc_dists[dist_mask] = max_dist neg_dist = torch.min(torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]) triplet_loss = F.relu(margin + pos_dist - neg_dist) return triplet_loss, grid1, grid2, valid_points class TripletDescriptorLoss(nn.Module): """ Triplet descriptor loss. """ def __init__(self, grid_size, dist_threshold, margin): super(TripletDescriptorLoss, self).__init__() self.grid_size = grid_size self.init_dist_threshold = 64 self.dist_threshold = dist_threshold self.margin = margin def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch): return self.descriptor_loss(desc_pred1, desc_pred2, points1, points2, line_indices, epoch) # The descriptor loss based on regularly sampled points along the lines def descriptor_loss(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch): return torch.mean(triplet_loss( desc_pred1, desc_pred2, points1, points2, line_indices, epoch, self.grid_size, self.dist_threshold, self.init_dist_threshold, self.margin)[0]) class TotalLoss(nn.Module): """ Total loss summing junction, heatma, descriptor and regularization losses. """ def __init__(self, loss_funcs, loss_weights, weighting_policy): super(TotalLoss, self).__init__() # Whether we need to compute the descriptor loss self.compute_descriptors = "descriptor_loss" in loss_funcs.keys() self.loss_funcs = loss_funcs self.loss_weights = loss_weights self.weighting_policy = weighting_policy # Always add regularization loss (it will return zero if not used) self.loss_funcs["reg_loss"] = RegularizationLoss().cuda() def forward(self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None): """ Detection only loss. """ # Compute the junction loss junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask) # Compute the heatmap loss heatmap_loss = self.loss_funcs["heatmap_loss"]( heatmap_pred, heatmap_target, valid_mask) # Compute the total loss. 
class TotalLoss(nn.Module):
    """ Total loss summing junction, heatmap, descriptor
        and regularization losses. """
    def __init__(self, loss_funcs, loss_weights, weighting_policy):
        super(TotalLoss, self).__init__()
        # Whether we need to compute the descriptor loss
        self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()

        self.loss_funcs = loss_funcs
        self.loss_weights = loss_weights
        self.weighting_policy = weighting_policy

        # Always add regularization loss (it will return zero if not used);
        # its forward already places the result on the weights' device
        self.loss_funcs["reg_loss"] = RegularizationLoss()

    def forward(self, junc_pred, junc_target, heatmap_pred, heatmap_target,
                valid_mask=None):
        """ Detection only loss. """
        # Compute the junction loss
        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target,
                                                 valid_mask)
        # Compute the heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            heatmap_pred, heatmap_target, valid_mask)

        # Compute the total loss
        if self.weighting_policy == "dynamic":
            # Uncertainty-style weighting: each dynamic weight w scales its
            # loss by exp(-w), and the regularization term (the sum of the
            # w's) keeps the weights from growing unboundedly
            reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
            total_loss = (
                junc_loss * torch.exp(-self.loss_weights["w_junc"])
                + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"])
                + reg_loss)

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "reg_loss": reg_loss,
                "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
                "w_heatmap": torch.exp(
                    -self.loss_weights["w_heatmap"]).item(),
            }
        elif self.weighting_policy == "static":
            total_loss = (junc_loss * self.loss_weights["w_junc"]
                          + heatmap_loss * self.loss_weights["w_heatmap"])

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss
            }
        else:
            raise ValueError("[Error] Unknown weighting policy.")

    def forward_descriptors(self, junc_map_pred1, junc_map_pred2,
                            junc_map_target1, junc_map_target2,
                            heatmap_pred1, heatmap_pred2, heatmap_target1,
                            heatmap_target2, line_points1, line_points2,
                            line_indices, desc_pred1, desc_pred2, epoch,
                            valid_mask1=None, valid_mask2=None):
        """ Loss for detection + description. """
        # Compute junction loss
        junc_loss = self.loss_funcs["junc_loss"](
            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
            torch.cat([junc_map_target1, junc_map_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0))
        # Get junction loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_junc"], nn.Parameter):
            w_junc = torch.exp(-self.loss_weights["w_junc"])
        else:
            w_junc = self.loss_weights["w_junc"]

        # Compute heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
            torch.cat([heatmap_target1, heatmap_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0))
        # Get heatmap loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
            w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
        else:
            w_heatmap = self.loss_weights["w_heatmap"]

        # Compute the descriptor loss
        descriptor_loss = self.loss_funcs["descriptor_loss"](
            desc_pred1, desc_pred2, line_points1, line_points2,
            line_indices, epoch)
        # Get descriptor loss weight (dynamic or not)
        if isinstance(self.loss_weights["w_desc"], nn.Parameter):
            w_descriptor = torch.exp(-self.loss_weights["w_desc"])
        else:
            w_descriptor = self.loss_weights["w_desc"]

        # Update the total loss
        total_loss = (junc_loss * w_junc + heatmap_loss * w_heatmap
                      + descriptor_loss * w_descriptor)
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "descriptor_loss": descriptor_loss,
            # Log plain floats: after exp(-w), the dynamic weights are
            # Tensors rather than nn.Parameters, so test against torch.Tensor
            "w_junc": (w_junc.item()
                       if isinstance(w_junc, torch.Tensor) else w_junc),
            "w_heatmap": (w_heatmap.item()
                          if isinstance(w_heatmap, torch.Tensor)
                          else w_heatmap),
            "w_desc": (w_descriptor.item()
                       if isinstance(w_descriptor, torch.Tensor)
                       else w_descriptor),
        }

        # Compute the regularization loss
        reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
        total_loss += reg_loss
        outputs.update({
            "reg_loss": reg_loss,
            "total_loss": total_loss
        })

        return outputs
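
# End-to-end sketch (illustrative; shapes, dtypes and config values are
# assumptions, not the authors' settings): build the losses on CPU and
# evaluate the detection-only branch on random inputs.
def _total_loss_example():
    cfg = _example_loss_config()
    loss_funcs, loss_weights = get_loss_and_weights(cfg,
                                                    torch.device("cpu"))
    total = TotalLoss(loss_funcs, loss_weights, cfg["weighting_policy"])

    b, h, w, gs = 2, 64, 64, cfg["grid_size"]
    # Junction predictions have gs^2 + 1 channels (cells + dust bin)
    junc_pred = torch.randn(b, gs * gs + 1, h // gs, w // gs)
    junc_target = (torch.rand(b, 1, h, w) > 0.99).to(torch.int)
    heatmap_pred = torch.randn(b, 2, h, w)
    heatmap_target = (torch.rand(b, 1, h, w) > 0.9).to(torch.int)
    valid_mask = torch.ones(b, 1, h, w)
    return total(junc_pred, junc_target, heatmap_pred, heatmap_target,
                 valid_mask)["total_loss"]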