"""Loss function implementations."""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..misc.geometry_utils import (
    keypoints_to_grid,
    get_dist_mask,
    get_common_line_mask,
)


def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
    """Get loss functions and either static or dynamic weighting."""
    w_policy = model_cfg.get("weighting_policy", "static")
    if w_policy not in ["static", "dynamic"]:
        raise ValueError("[Error] Unsupported weighting policy.")

    loss_func = {}
    loss_weight = {}

    # Junction loss
    w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
    loss_func["junc_loss"] = junc_loss_func.to(device)
    loss_weight["w_junc"] = w_junc

    # Heatmap loss
    w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
        model_cfg, w_policy, device
    )
    loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
    loss_weight["w_heatmap"] = w_heatmap

    # Descriptor loss (optional)
    if model_cfg.get("descriptor_loss_func", None) is not None:
        w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
            model_cfg, w_policy
        )
        loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
        loss_weight["w_desc"] = w_descriptor

    return loss_func, loss_weight

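# A minimal usage sketch for a detection-only setup. The keys mirror the
# accessors above; the numeric values are illustrative assumptions, not
# recommended settings:
#
#   model_cfg = {
#       "weighting_policy": "static",
#       "w_junc": 1.0,
#       "w_heatmap": 1.0,
#       "grid_size": 8,
#       "keep_border_valid": True,
#   }
#   loss_funcs, loss_weights = get_loss_and_weights(
#       model_cfg, device=torch.device("cpu")
#   )

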
def get_junction_loss_and_weight(model_cfg, global_w_policy):
    """Get the junction loss function and weight."""
    junction_loss_cfg = model_cfg.get("junction_loss_cfg", {})

    # The local policy overrides the global one if provided.
    w_policy = junction_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_junc = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_junc = nn.Parameter(
            torch.tensor(model_cfg["w_junc"], dtype=torch.float32), requires_grad=True
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")

    junc_loss_name = model_cfg.get("junction_loss_func", "superpoint")
    if junc_loss_name == "superpoint":
        junc_loss_func = JunctionDetectionLoss(
            model_cfg["grid_size"], model_cfg["keep_border_valid"]
        )
    else:
        raise ValueError("[Error] Unsupported junction loss function.")

    return w_junc, junc_loss_func


def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
    """Get the heatmap loss function and weight."""
    heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})

    # The local policy overrides the global one if provided.
    w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_heatmap = nn.Parameter(
            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
            requires_grad=True,
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for heatmap loss weight.")

    heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
    if heatmap_loss_name == "cross_entropy":
        # Optional re-weighting of the positive (line) class.
        heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0)
        class_weight = (
            torch.tensor(np.array([1.0, heatmap_class_w])).to(torch.float).to(device)
        )
        heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
    else:
        raise ValueError("[Error] Unsupported heatmap loss function.")

    return w_heatmap, heatmap_loss_func


def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
    """Get the descriptor loss function and weight."""
    descriptor_loss_cfg = model_cfg.get("descriptor_loss_cfg", {})

    # The local policy overrides the global one if provided.
    w_policy = descriptor_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_descriptor = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_descriptor = nn.Parameter(
            torch.tensor(model_cfg["w_desc"], dtype=torch.float32), requires_grad=True
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.")

    descriptor_loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling")
    if descriptor_loss_name == "regular_sampling":
        descriptor_loss_func = TripletDescriptorLoss(
            descriptor_loss_cfg["grid_size"],
            descriptor_loss_cfg["dist_threshold"],
            descriptor_loss_cfg["margin"],
        )
    else:
        raise ValueError("[Error] Unsupported descriptor loss function.")

    return w_descriptor, descriptor_loss_func


def space_to_depth(input_tensor, grid_size):
    """PixelUnshuffle for PyTorch: (N, C, H, W) ->
    (N, C * grid_size**2, H // grid_size, W // grid_size)."""
    N, C, H, W = input_tensor.size()
    # Split each spatial dimension into (coarse cell, offset within the cell).
    x = input_tensor.view(N, C, H // grid_size, grid_size, W // grid_size, grid_size)
    # Move the in-cell offsets into the channel dimension.
    x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
    x = x.view(N, C * (grid_size**2), H // grid_size, W // grid_size)
    return x

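# For single-channel input this matches torch.nn.functional.pixel_unshuffle
# (available from PyTorch 1.8); for C > 1 the channel ordering differs, as the
# in-cell offsets are placed before the original channels here. Illustrative
# check:
#
#   x = torch.arange(16.0).view(1, 1, 4, 4)
#   assert torch.equal(space_to_depth(x, 2), F.pixel_unshuffle(x, 2))

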
def junction_detection_loss(
    junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True
):
    """Junction detection loss."""
    # Convert the junction map to a SuperPoint-style channel tensor with one
    # channel per in-cell position, plus a "no junction" dustbin channel.
    # junction_map is expected to be an integer (0/1) map so that it can be
    # concatenated with the integer dustbin channel.
    junc_map = space_to_depth(junction_map, grid_size)
    map_shape = junc_map.shape[-2:]
    batch_size = junc_map.shape[0]
    dust_bin_label = (
        torch.ones([batch_size, 1, map_shape[0], map_shape[1]])
        .to(junc_map.device)
        .to(torch.int)
    )
    # Scale junctions by 2 so they beat the dustbin (value 1) after the small
    # uniform noise used to break ties between multiple junctions in a cell.
    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
    labels = torch.argmax(
        junc_map.to(torch.float)
        + torch.distributions.Uniform(0, 0.1)
        .sample(junc_map.shape)
        .to(junc_map.device),
        dim=1,
    )

    # Also convert the valid mask to a channel tensor.
    valid_mask = (
        torch.ones(junction_map.shape, device=junction_map.device)
        if valid_mask is None
        else valid_mask
    )
    valid_mask = space_to_depth(valid_mask, grid_size)

    # With keep_border, a cell counts as valid if any of its pixels is valid;
    # otherwise all of its pixels must be valid.
    if keep_border:
        valid_mask = (
            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True) > 0
        )
    else:
        valid_mask = (
            torch.sum(valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True)
            >= grid_size * grid_size
        )

    loss_func = nn.CrossEntropyLoss(reduction="none")
    loss = loss_func(input=junc_predictions, target=labels.to(torch.long))

    # Mean over the valid cells.
    loss_ = torch.sum(
        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[0, 1, 2]
    )
    loss_final = loss_ / torch.sum(torch.squeeze(valid_mask.to(torch.float), dim=1))

    return loss_final

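# Shape conventions (illustrative): a (B, 1, H, W) integer junction map with
# grid_size=8 gives (H // 8) x (W // 8) cells, so junc_predictions should be
# raw logits of shape (B, 65, H // 8, W // 8), i.e. 64 in-cell positions plus
# one dustbin channel. For example:
#
#   junc_gt = (torch.rand(2, 1, 64, 64) > 0.97).int()
#   junc_logits = torch.randn(2, 65, 8, 8)
#   loss = junction_detection_loss(junc_gt, junc_logits)

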
def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None):
    """Heatmap prediction loss."""
    if class_weight is None:
        loss_func = nn.CrossEntropyLoss(reduction="none")
    else:
        loss_func = nn.CrossEntropyLoss(class_weight, reduction="none")

    loss = loss_func(
        input=heatmap_pred, target=torch.squeeze(heatmap_gt.to(torch.long), dim=1)
    )

    # Default to an all-valid mask if none is given.
    if valid_mask is None:
        valid_mask = torch.ones_like(heatmap_gt)

    # Mean over the valid pixels.
    loss_spatial_sum = torch.sum(
        loss * torch.squeeze(valid_mask.to(torch.float), dim=1), dim=[1, 2]
    )
    valid_spatial_sum = torch.sum(
        torch.squeeze(valid_mask.to(torch.float32), dim=1), dim=[1, 2]
    )

    loss = torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)
    return loss

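# Usage sketch: heatmap_pred holds two-class logits (background vs. line) and
# heatmap_gt a per-pixel {0, 1} label map:
#
#   hm_gt = (torch.rand(2, 1, 64, 64) > 0.9).int()
#   hm_logits = torch.randn(2, 2, 64, 64)
#   loss = heatmap_loss(hm_gt, hm_logits)

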
class JunctionDetectionLoss(nn.Module):
    """Junction detection loss."""

    def __init__(self, grid_size, keep_border):
        super(JunctionDetectionLoss, self).__init__()
        self.grid_size = grid_size
        self.keep_border = keep_border

    def forward(self, prediction, target, valid_mask=None):
        return junction_detection_loss(
            target, prediction, valid_mask, self.grid_size, self.keep_border
        )


class HeatmapLoss(nn.Module):
    """Heatmap prediction loss."""

    def __init__(self, class_weight):
        super(HeatmapLoss, self).__init__()
        self.class_weight = class_weight

    def forward(self, prediction, target, valid_mask=None):
        return heatmap_loss(target, prediction, valid_mask, self.class_weight)


class RegularizationLoss(nn.Module):
    """Module for regularization loss."""

    def __init__(self):
        super(RegularizationLoss, self).__init__()
        self.name = "regularization_loss"
        self.loss_init = torch.zeros([])

    def forward(self, loss_weights):
        loss = self.loss_init.to(loss_weights["w_junc"].device)
        # Only the dynamic (learnable) weights are regularized.
        for _, val in loss_weights.items():
            if isinstance(val, nn.Parameter):
                loss += val
        return loss

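# With the "dynamic" policy, the overall objective follows the homoscedastic
# uncertainty weighting of Kendall et al. (CVPR 2018):
#
#   L_total = sum_i exp(-w_i) * L_i + sum_i w_i
#
# where each learnable w_i plays the role of a log-variance. The exp(-w_i)
# factors rescale the task losses, and this module contributes the sum_i w_i
# term that keeps the w_i from growing unboundedly (which would zero out the
# task losses).

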
def triplet_loss(
    desc_pred1,
    desc_pred2,
    points1,
    points2,
    line_indices,
    epoch,
    grid_size=8,
    dist_threshold=8,
    init_dist_threshold=64,
    margin=1,
):
    """Regular triplet loss for descriptor learning.

    Returns the per-point triplet losses together with the two sampling grids
    and the validity mask.
    """
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract the valid keypoints (zero entries of line_indices mark padding).
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        # Keep the return arity consistent with the normal path.
        return torch.tensor(0.0, dtype=torch.float, device=device), None, None, None

    # Anneal the distance threshold: start loose (init_dist_threshold) and
    # tighten it over the epochs until dist_threshold is reached.
    dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)

    # Points lying on the same line cannot be used as negatives either.
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Convert the keypoints to sampling grids for F.grid_sample.
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Sample and L2-normalize the descriptors at the keypoint locations.
    desc1 = (
        F.grid_sample(desc_pred1, grid1)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc1 = F.normalize(desc1, dim=1)
    desc2 = (
        F.grid_sample(desc_pred2, grid2)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distances lie on the diagonal (matching keypoints).
    pos_dist = torch.diag(desc_dists)

    # Negatives: hardest non-matching keypoint in either direction. Invalid
    # candidates (diagonal, spatially too close, same line) are pushed to the
    # maximum possible distance.
    max_dist = torch.tensor(4.0, dtype=torch.float, device=device)
    desc_dists[
        torch.arange(n_correct_points, dtype=torch.long),
        torch.arange(n_correct_points, dtype=torch.long),
    ] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(
        torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]
    )

    loss = F.relu(margin + pos_dist - neg_dist)
    return loss, grid1, grid2, valid_points

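# Note on the distance matrix: for L2-normalized descriptors d1 and d2,
# ||d1 - d2||^2 = 2 - 2 * <d1, d2>, so desc_dists holds squared Euclidean
# distances in [0, 4], and max_dist = 4 is the tight upper bound used to
# exclude invalid negative candidates.

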
class TripletDescriptorLoss(nn.Module):
    """Triplet descriptor loss."""

    def __init__(self, grid_size, dist_threshold, margin):
        super(TripletDescriptorLoss, self).__init__()
        self.grid_size = grid_size
        self.init_dist_threshold = 64
        self.dist_threshold = dist_threshold
        self.margin = margin

    def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch):
        return self.descriptor_loss(
            desc_pred1, desc_pred2, points1, points2, line_indices, epoch
        )

    def descriptor_loss(
        self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch
    ):
        # Mean of the per-point triplet losses returned by triplet_loss.
        return torch.mean(
            triplet_loss(
                desc_pred1,
                desc_pred2,
                points1,
                points2,
                line_indices,
                epoch,
                self.grid_size,
                self.dist_threshold,
                self.init_dist_threshold,
                self.margin,
            )[0]
        )

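# Expected inputs (inferred from triplet_loss, not enforced here): desc_pred1
# and desc_pred2 are dense descriptor maps of shape (B, D, H // grid_size,
# W // grid_size), points1 and points2 are (B, n_points, 2) keypoint
# coordinates in image space, and line_indices is a (B, n_points) tensor whose
# zero entries mark padded points.

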
class TotalLoss(nn.Module):
    """Total loss summing junction, heatmap, descriptor,
    and regularization losses."""

    def __init__(self, loss_funcs, loss_weights, weighting_policy):
        super(TotalLoss, self).__init__()
        # Whether the descriptor loss is part of the objective.
        self.compute_descriptors = "descriptor_loss" in loss_funcs.keys()

        self.loss_funcs = loss_funcs
        self.loss_weights = loss_weights
        self.weighting_policy = weighting_policy

        # Regularization loss on the dynamic weights (assumes a CUDA device).
        self.loss_funcs["reg_loss"] = RegularizationLoss().cuda()

    def forward(
        self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None
    ):
        """Detection-only loss."""
        junc_loss = self.loss_funcs["junc_loss"](junc_pred, junc_target, valid_mask)
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            heatmap_pred, heatmap_target, valid_mask
        )

        if self.weighting_policy == "dynamic":
            reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
            total_loss = (
                junc_loss * torch.exp(-self.loss_weights["w_junc"])
                + heatmap_loss * torch.exp(-self.loss_weights["w_heatmap"])
                + reg_loss
            )

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "reg_loss": reg_loss,
                "w_junc": torch.exp(-self.loss_weights["w_junc"]).item(),
                "w_heatmap": torch.exp(-self.loss_weights["w_heatmap"]).item(),
            }

        elif self.weighting_policy == "static":
            total_loss = (
                junc_loss * self.loss_weights["w_junc"]
                + heatmap_loss * self.loss_weights["w_heatmap"]
            )

            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
            }

        else:
            raise ValueError("[Error] Unknown weighting policy.")

    def forward_descriptors(
        self,
        junc_map_pred1,
        junc_map_pred2,
        junc_map_target1,
        junc_map_target2,
        heatmap_pred1,
        heatmap_pred2,
        heatmap_target1,
        heatmap_target2,
        line_points1,
        line_points2,
        line_indices,
        desc_pred1,
        desc_pred2,
        epoch,
        valid_mask1=None,
        valid_mask2=None,
    ):
        """Loss for detection + description."""
        # Junction loss over both images of the pair.
        junc_loss = self.loss_funcs["junc_loss"](
            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
            torch.cat([junc_map_target1, junc_map_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0),
        )
        if isinstance(self.loss_weights["w_junc"], nn.Parameter):
            w_junc = torch.exp(-self.loss_weights["w_junc"])
        else:
            w_junc = self.loss_weights["w_junc"]

        # Heatmap loss over both images of the pair.
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
            torch.cat([heatmap_target1, heatmap_target2], dim=0),
            torch.cat([valid_mask1, valid_mask2], dim=0),
        )
        if isinstance(self.loss_weights["w_heatmap"], nn.Parameter):
            w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
        else:
            w_heatmap = self.loss_weights["w_heatmap"]

        # Descriptor loss between the two images.
        descriptor_loss = self.loss_funcs["descriptor_loss"](
            desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch
        )
        if isinstance(self.loss_weights["w_desc"], nn.Parameter):
            w_descriptor = torch.exp(-self.loss_weights["w_desc"])
        else:
            w_descriptor = self.loss_weights["w_desc"]

        total_loss = (
            junc_loss * w_junc
            + heatmap_loss * w_heatmap
            + descriptor_loss * w_descriptor
        )
        # The effective weights are tensors under both policies (exp(-w) for
        # dynamic, plain tensors for static), so report them as Python floats.
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "w_junc": w_junc.item() if isinstance(w_junc, torch.Tensor) else w_junc,
            "w_heatmap": w_heatmap.item()
            if isinstance(w_heatmap, torch.Tensor)
            else w_heatmap,
            "descriptor_loss": descriptor_loss,
            "w_desc": w_descriptor.item()
            if isinstance(w_descriptor, torch.Tensor)
            else w_descriptor,
        }

        reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
        total_loss += reg_loss
        outputs.update({"reg_loss": reg_loss, "total_loss": total_loss})

        return outputs
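

# CPU smoke test for the detection losses (illustrative; shapes follow the
# conventions sketched above). Because of the relative import at the top, run
# it as a module, e.g. `python -m <package>.losses`:
if __name__ == "__main__":
    junc_gt = (torch.rand(2, 1, 64, 64) > 0.97).int()
    junc_logits = torch.randn(2, 65, 8, 8)
    print("junction loss:", junction_detection_loss(junc_gt, junc_logits).item())

    hm_gt = (torch.rand(2, 1, 64, 64) > 0.9).int()
    hm_logits = torch.randn(2, 2, 64, 64)
    print("heatmap loss:", heatmap_loss(hm_gt, hm_logits).item())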