Vincentqyw
fix: roma
358ab8f
raw
history blame
17.8 kB
"""
Loss function implementations.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry import warp_perspective
from ..misc.geometry_utils import keypoints_to_grid, get_dist_mask, get_common_line_mask
def get_loss_and_weights(model_cfg, device=torch.device("cuda")):
    """Get loss functions and either static or dynamic weighting.

    Args:
        model_cfg: Model configuration mapping. "weighting_policy"
            (default "static") is read here; the per-loss builders read the
            remaining keys. A descriptor loss is built only when
            "descriptor_loss_func" is set to a non-None value.
        device: Device the loss modules are moved to.

    Returns:
        Tuple ``(loss_func, loss_weight)``. ``loss_func`` maps "junc_loss",
        "heatmap_loss" and optionally "descriptor_loss" to loss modules
        moved to ``device``. ``loss_weight`` maps "w_junc", "w_heatmap" and
        optionally "w_desc" to the weights exactly as returned by the
        builders (they are NOT moved to ``device`` here).

    Raises:
        ValueError: If the global weighting policy is neither "static" nor
            "dynamic", or if a per-loss builder rejects its configuration.
    """
    # Get the global weighting policy
    w_policy = model_cfg.get("weighting_policy", "static")
    if w_policy not in ("static", "dynamic"):
        raise ValueError("[Error] Not supported weighting policy.")

    loss_func = {}
    loss_weight = {}

    # Junction loss function and weight
    w_junc, junc_loss_func = get_junction_loss_and_weight(model_cfg, w_policy)
    loss_func["junc_loss"] = junc_loss_func.to(device)
    loss_weight["w_junc"] = w_junc

    # Heatmap loss function and weight
    w_heatmap, heatmap_loss_func = get_heatmap_loss_and_weight(
        model_cfg, w_policy, device
    )
    loss_func["heatmap_loss"] = heatmap_loss_func.to(device)
    loss_weight["w_heatmap"] = w_heatmap

    # [Optionally] descriptor loss function and weight
    if model_cfg.get("descriptor_loss_func", None) is not None:
        w_descriptor, descriptor_loss_func = get_descriptor_loss_and_weight(
            model_cfg, w_policy
        )
        loss_func["descriptor_loss"] = descriptor_loss_func.to(device)
        loss_weight["w_desc"] = w_descriptor

    return loss_func, loss_weight
def get_junction_loss_and_weight(model_cfg, global_w_policy):
    """Build the junction loss module together with its weight.

    The weighting policy is taken from
    ``model_cfg["junction_loss_cfg"]["policy"]`` when present and falls back
    to ``global_w_policy`` otherwise.

    Returns:
        Tuple ``(w_junc, junc_loss_func)``. ``w_junc`` is a float32 scalar
        tensor initialized from ``model_cfg["w_junc"]``. Under the "dynamic"
        policy it is wrapped in a trainable ``nn.Parameter``.
        ``junc_loss_func`` is a ``JunctionDetectionLoss`` configured from
        "grid_size" and "keep_border_valid".

    Raises:
        ValueError: For a policy other than "static"/"dynamic", or a
            "junction_loss_func" other than "superpoint".
    """
    policy = model_cfg.get("junction_loss_cfg", {}).get("policy", global_w_policy)
    if policy != "static" and policy != "dynamic":
        raise ValueError("[Error] Unknown weighting policy for junction loss weight.")

    init_value = torch.tensor(model_cfg["w_junc"], dtype=torch.float32)
    w_junc = (
        nn.Parameter(init_value, requires_grad=True)
        if policy == "dynamic"
        else init_value
    )

    # Only the SuperPoint-style cell classification loss is available
    if model_cfg.get("junction_loss_func", "superpoint") != "superpoint":
        raise ValueError("[Error] Not supported junction loss function.")
    junc_loss_func = JunctionDetectionLoss(
        model_cfg["grid_size"], model_cfg["keep_border_valid"]
    )
    return w_junc, junc_loss_func
def get_heatmap_loss_and_weight(model_cfg, global_w_policy, device):
    """Get the heatmap loss function and weight.

    The weighting policy is taken from
    ``model_cfg["heatmap_loss_cfg"]["policy"]`` when present and falls back
    to ``global_w_policy`` otherwise.

    Args:
        model_cfg: Model configuration mapping. Reads "w_heatmap",
            "heatmap_loss_func" (default "cross_entropy") and
            "w_heatmap_class" (default 1.0).
        global_w_policy: Fallback weighting policy.
        device: Device on which the per-class weight tensor is created.

    Returns:
        Tuple ``(w_heatmap, heatmap_loss_func)``. ``w_heatmap`` is a float32
        scalar tensor, wrapped in a trainable ``nn.Parameter`` under the
        "dynamic" policy. ``heatmap_loss_func`` is a ``HeatmapLoss`` whose
        class weights are ``[1.0, w_heatmap_class]``.

    Raises:
        ValueError: For an unknown policy or an unsupported heatmap loss.
    """
    heatmap_loss_cfg = model_cfg.get("heatmap_loss_cfg", {})
    # Get the heatmap loss weight
    w_policy = heatmap_loss_cfg.get("policy", global_w_policy)
    if w_policy == "static":
        w_heatmap = torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32)
    elif w_policy == "dynamic":
        w_heatmap = nn.Parameter(
            torch.tensor(model_cfg["w_heatmap"], dtype=torch.float32),
            requires_grad=True,
        )
    else:
        raise ValueError("[Error] Unknown weighting policy for heatmap loss weight.")

    # Get the corresponding heatmap loss based on the config
    heatmap_loss_name = model_cfg.get("heatmap_loss_func", "cross_entropy")
    if heatmap_loss_name != "cross_entropy":
        raise ValueError("[Error] Not supported heatmap loss function.")
    # Per-class weights [class 0, class 1]; always static
    heatmap_class_w = model_cfg.get("w_heatmap_class", 1.0)
    class_weight = torch.tensor(
        [1.0, heatmap_class_w], dtype=torch.float32, device=device
    )
    heatmap_loss_func = HeatmapLoss(class_weight=class_weight)
    return w_heatmap, heatmap_loss_func
def get_descriptor_loss_and_weight(model_cfg, global_w_policy):
    """Build the descriptor loss module together with its weight.

    The weighting policy is taken from
    ``model_cfg["descriptor_loss_cfg"]["policy"]`` when present and falls
    back to ``global_w_policy`` otherwise.

    Returns:
        Tuple ``(w_descriptor, descriptor_loss_func)``. ``w_descriptor`` is a
        float32 scalar tensor initialized from ``model_cfg["w_desc"]``,
        wrapped in a trainable ``nn.Parameter`` under the "dynamic" policy.
        ``descriptor_loss_func`` is a ``TripletDescriptorLoss`` built from
        the "grid_size", "dist_threshold" and "margin" entries of
        ``model_cfg["descriptor_loss_cfg"]``.

    Raises:
        ValueError: For an unknown policy or a "descriptor_loss_func" other
            than "regular_sampling".
        KeyError: If a required config entry is missing.
    """
    loss_cfg = model_cfg.get("descriptor_loss_cfg", {})
    policy = loss_cfg.get("policy", global_w_policy)
    if policy not in ("static", "dynamic"):
        raise ValueError("[Error] Unknown weighting policy for descriptor loss weight.")

    init_value = torch.tensor(model_cfg["w_desc"], dtype=torch.float32)
    if policy == "dynamic":
        w_descriptor = nn.Parameter(init_value, requires_grad=True)
    else:
        w_descriptor = init_value

    loss_name = model_cfg.get("descriptor_loss_func", "regular_sampling")
    if loss_name != "regular_sampling":
        raise ValueError("[Error] Not supported descriptor loss function.")
    grid_size = loss_cfg["grid_size"]
    dist_threshold = loss_cfg["dist_threshold"]
    margin = loss_cfg["margin"]
    return w_descriptor, TripletDescriptorLoss(grid_size, dist_threshold, margin)
def space_to_depth(input_tensor, grid_size):
    """Fold each ``grid_size x grid_size`` spatial cell into the channel axis.

    Maps an (N, C, H, W) tensor to
    (N, C * grid_size**2, H // grid_size, W // grid_size). Input channel
    ``c`` at in-cell offset (row ``i``, column ``j``) lands in output channel
    ``(i * grid_size + j) * C + c``: cell offset is the major index and the
    input channel the minor one. For C == 1 this matches
    ``torch.nn.functional.pixel_unshuffle``; for C > 1 the channel order
    differs from it. H and W must be divisible by ``grid_size``, otherwise
    the reshape fails.
    """
    n, c, h, w = input_tensor.size()
    rows, cols = h // grid_size, w // grid_size
    # (N, C, rows, g, cols, g): split H and W into cell index and in-cell offset
    cells = input_tensor.view(n, c, rows, grid_size, cols, grid_size)
    # (N, g, g, C, rows, cols): move the in-cell offsets in front of C
    cells = cells.permute(0, 3, 5, 1, 2, 4).contiguous()
    # (N, g*g*C, rows, cols)
    return cells.view(n, c * grid_size * grid_size, rows, cols)
def junction_detection_loss(
    junction_map, junc_predictions, valid_mask=None, grid_size=8, keep_border=True
):
    """Junction detection loss.

    Cell-wise classification. Every ``grid_size x grid_size`` cell of
    ``junction_map`` is assigned one label among its in-cell positions plus
    an extra "dustbin" channel appended last. ``junc_predictions`` is scored
    against those labels with cross-entropy, averaged over the valid cells.

    Args:
        junction_map: Ground-truth junction map (N, C, H, W) with H and W
            divisible by ``grid_size``. Presumably single-channel with 0/1
            values; TODO confirm with callers.
        junc_predictions: Logits with the class dimension at dim 1, fed to
            ``nn.CrossEntropyLoss`` against (N, H // grid_size,
            W // grid_size) labels.
        valid_mask: Optional mask shaped like ``junction_map``. When None,
            every pixel is valid.
        grid_size: Cell size in pixels.
        keep_border: If True, a cell is valid when any of its mask entries
            is non-zero. If False, at least ``grid_size**2`` entries must be
            non-zero (all of them for a single-channel mask).

    Returns:
        Scalar tensor: the summed loss over valid cells divided by the number
        of valid cells. This is NaN (0/0) when no cell is valid.
    """
    # Fold each cell into channels: (N, C * g^2, Hc, Wc)
    junc_map = space_to_depth(junction_map, grid_size)
    batch_size = junc_map.shape[0]
    map_h, map_w = junc_map.shape[-2:]
    dust_bin_label = torch.ones(
        [batch_size, 1, map_h, map_w], dtype=torch.int, device=junc_map.device
    )
    # Junction entries are doubled so that, for a 0/1 map, a junction pixel
    # (2) always beats the dustbin (1), which in turn beats empty pixels (0).
    # The uniform noise in [0, 0.1) only breaks ties, e.g. between several
    # junctions inside the same cell (so labels then depend on the RNG).
    junc_map = torch.cat([junc_map * 2, dust_bin_label], dim=1)
    noise = (
        torch.distributions.Uniform(0, 0.1).sample(junc_map.shape).to(junc_map.device)
    )
    labels = torch.argmax(junc_map.to(torch.float) + noise, dim=1)

    # Convert the valid mask to the same cell layout
    if valid_mask is None:
        # Build the default mask on the input's device; a CPU mask would make
        # the masked sum below fail when the inputs live on the GPU.
        valid_mask = torch.ones(junction_map.shape, device=junction_map.device)
    valid_mask = space_to_depth(valid_mask, grid_size)
    valid_count = torch.sum(
        valid_mask.to(torch.bool).to(torch.int), dim=1, keepdim=True
    )
    # Compute junction loss on the border patch or not
    if keep_border:
        valid_cells = valid_count > 0
    else:
        valid_cells = valid_count >= grid_size * grid_size
    cell_weight = torch.squeeze(valid_cells.to(torch.float), dim=1)

    # Per-cell classification loss (NCHW logits, NHW targets)
    loss_func = nn.CrossEntropyLoss(reduction="none")
    loss = loss_func(input=junc_predictions, target=labels.to(torch.long))
    # Average over the valid cells
    return torch.sum(loss * cell_weight, dim=[0, 1, 2]) / torch.sum(cell_weight)
def heatmap_loss(heatmap_gt, heatmap_pred, valid_mask=None, class_weight=None):
    """Heatmap prediction loss: masked per-pixel cross-entropy.

    Args:
        heatmap_gt: Integer-valued class labels. A singleton dim 1 is
            squeezed away before use, leaving an (N, H, W) target.
        heatmap_pred: Logits with the class dimension at dim 1.
        valid_mask: Optional per-pixel mask/weights shaped like
            ``heatmap_gt`` (also squeezed on dim 1). When None, every pixel
            counts with weight 1.
        class_weight: Optional per-class weight tensor forwarded to
            ``nn.CrossEntropyLoss``.

    Returns:
        Scalar tensor: the sum of masked per-pixel losses over the whole
        batch divided by the total mask sum. This is NaN when the mask sums
        to zero.
    """
    # Per-pixel classification loss; weight=None is CrossEntropyLoss's default
    loss_func = nn.CrossEntropyLoss(weight=class_weight, reduction="none")
    target = torch.squeeze(heatmap_gt.to(torch.long), dim=1)
    loss = loss_func(input=heatmap_pred, target=target)

    if valid_mask is None:
        # float32 (not the logits' dtype) so large pixel counts stay exact
        valid = torch.ones_like(loss, dtype=torch.float32)
    else:
        valid = torch.squeeze(valid_mask.to(torch.float32), dim=1)

    # Sum over H and W, then normalize over the whole batch
    loss_spatial_sum = torch.sum(loss * valid, dim=[1, 2])
    valid_spatial_sum = torch.sum(valid, dim=[1, 2])
    return torch.sum(loss_spatial_sum) / torch.sum(valid_spatial_sum)
class JunctionDetectionLoss(nn.Module):
    """``nn.Module`` wrapper around ``junction_detection_loss``.

    Stores the cell size and the border policy. Note that ``forward`` takes
    the prediction first, whereas the wrapped function takes the
    ground-truth map first.
    """

    def __init__(self, grid_size, keep_border):
        super().__init__()
        self.grid_size, self.keep_border = grid_size, keep_border

    def forward(self, prediction, target, valid_mask=None):
        return junction_detection_loss(
            junction_map=target,
            junc_predictions=prediction,
            valid_mask=valid_mask,
            grid_size=self.grid_size,
            keep_border=self.keep_border,
        )
class HeatmapLoss(nn.Module):
    """``nn.Module`` wrapper around ``heatmap_loss`` with fixed class weights."""

    def __init__(self, class_weight):
        super().__init__()
        # Plain tensor attribute (not a registered buffer), so Module.to() /
        # .cuda() will not move it; it must already be on the right device.
        self.class_weight = class_weight

    def forward(self, prediction, target, valid_mask=None):
        return heatmap_loss(
            heatmap_gt=target,
            heatmap_pred=prediction,
            valid_mask=valid_mask,
            class_weight=self.class_weight,
        )
class RegularizationLoss(nn.Module):
    """Module for regularization loss.

    Returns the sum of every ``nn.Parameter`` value in the loss-weight dict.
    Plain tensors (static weights) are ignored, so the result is zero when
    no weight is learnable. The result is placed on the device of
    ``loss_weights["w_junc"]``.
    """

    def __init__(self):
        super(RegularizationLoss, self).__init__()
        self.name = "regularization_loss"
        # Zero template; never modified in place.
        self.loss_init = torch.zeros([])

    def forward(self, loss_weights):
        device = loss_weights["w_junc"].device
        # .to() returns self.loss_init itself when it is already on `device`.
        # Cloning and adding out-of-place keeps each call independent instead
        # of accumulating into module state (and into old autograd graphs).
        loss = self.loss_init.to(device).clone()
        for weight in loss_weights.values():
            if isinstance(weight, nn.Parameter):
                loss = loss + weight
        return loss
def triplet_loss(
    desc_pred1,
    desc_pred2,
    points1,
    points2,
    line_indices,
    epoch,
    grid_size=8,
    dist_threshold=8,
    init_dist_threshold=64,
    margin=1,
):
    """Regular triplet loss for descriptor learning.

    Descriptors are bilinearly sampled from the two dense maps at the
    corresponding points and L2-normalized. Distances are
    ``2 - 2 * cos``. For every valid point, the positive distance (to its
    counterpart) is compared with its hardest negative (the closest
    non-banned descriptor in either direction) through a hinge with
    ``margin``.

    Args:
        desc_pred1, desc_pred2: Dense descriptor maps (B, D, Hc, Wc) at
            1/``grid_size`` of the image resolution.
        points1, points2: Corresponding keypoints in image coordinates, in
            the format expected by ``keypoints_to_grid``.
        line_indices: (B, n_points) tensor. A non-zero entry marks a valid
            point. It presumably also identifies the line the point was
            sampled from (used by ``get_common_line_mask``); TODO confirm.
        epoch: Current epoch. The negative-mining distance threshold is
            ``max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))``.
        grid_size: Descriptor-map stride in pixels.
        dist_threshold: Final (minimum) negative-mining distance threshold.
        init_dist_threshold: Threshold scale used in early epochs.
        margin: Triplet hinge margin.

    Returns:
        Tuple ``(losses, grid1, grid2, valid_points)``: the per-valid-point
        hinge losses, the two sampling grids, and the flattened validity
        mask. When no point is valid, ``losses`` is a 0-dim zero tensor and
        both grids are None.
    """
    b_size, _, Hc, Wc = desc_pred1.size()
    img_size = (Hc * grid_size, Wc * grid_size)
    device = desc_pred1.device

    # Extract valid keypoints
    n_points = line_indices.size()[1]
    valid_points = line_indices.bool().flatten()
    n_correct_points = torch.sum(valid_points).item()
    if n_correct_points == 0:
        # Same 4-tuple shape as the normal path, so callers indexing [0] keep
        # working. No grids are built since nothing would be sampled.
        zero = torch.tensor(0.0, dtype=torch.float, device=device)
        return zero, None, None, valid_points

    # Check which keypoints are too close to be matched
    # dist_threshold is decreased at each epoch for easier training
    dist_threshold = max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))
    dist_mask = get_dist_mask(points1, points2, valid_points, dist_threshold)
    # Additionally ban negative mining along the same line
    common_line_mask = get_common_line_mask(line_indices, valid_points)
    dist_mask = dist_mask | common_line_mask

    # Convert the keypoints to a grid suitable for interpolation
    grid1 = keypoints_to_grid(points1, img_size)
    grid2 = keypoints_to_grid(points2, img_size)

    # Extract the descriptors of the valid points. align_corners=False is
    # grid_sample's default; passing it explicitly only silences its warning.
    desc1 = (
        F.grid_sample(desc_pred1, grid1, align_corners=False)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc1 = F.normalize(desc1, dim=1)
    desc2 = (
        F.grid_sample(desc_pred2, grid2, align_corners=False)
        .permute(0, 2, 3, 1)
        .reshape(b_size * n_points, -1)[valid_points]
    )
    desc2 = F.normalize(desc2, dim=1)
    desc_dists = 2 - 2 * (desc1 @ desc2.t())

    # Positive distance: matching pairs lie on the diagonal
    pos_dist = torch.diag(desc_dists)

    # Negative distance: push positives and banned pairs to the maximum
    # possible distance (4 for unit vectors), then take the hardest negative
    max_dist = torch.tensor(4.0, dtype=torch.float, device=device)
    diag_idx = torch.arange(n_correct_points, dtype=torch.long, device=device)
    desc_dists[diag_idx, diag_idx] = max_dist
    desc_dists[dist_mask] = max_dist
    neg_dist = torch.min(
        torch.min(desc_dists, dim=1)[0], torch.min(desc_dists, dim=0)[0]
    )
    losses = F.relu(margin + pos_dist - neg_dist)
    return losses, grid1, grid2, valid_points
class TripletDescriptorLoss(nn.Module):
    """Triplet descriptor loss over points regularly sampled along lines.

    Args:
        grid_size: Descriptor-map stride in pixels.
        dist_threshold: Final (minimum) negative-mining distance threshold.
        margin: Triplet hinge margin.
        init_dist_threshold: Threshold scale used in early epochs; the
            effective threshold is
            ``max(dist_threshold, 2 * init_dist_threshold // (epoch + 1))``.
            Defaults to 64, the previously hard-coded value.
    """

    def __init__(self, grid_size, dist_threshold, margin, init_dist_threshold=64):
        super(TripletDescriptorLoss, self).__init__()
        self.grid_size = grid_size
        self.init_dist_threshold = init_dist_threshold
        self.dist_threshold = dist_threshold
        self.margin = margin

    def forward(self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch):
        return self.descriptor_loss(
            desc_pred1, desc_pred2, points1, points2, line_indices, epoch
        )

    # The descriptor loss based on regularly sampled points along the lines
    def descriptor_loss(
        self, desc_pred1, desc_pred2, points1, points2, line_indices, epoch
    ):
        """Mean triplet loss over the valid points."""
        result = triplet_loss(
            desc_pred1,
            desc_pred2,
            points1,
            points2,
            line_indices,
            epoch,
            self.grid_size,
            self.dist_threshold,
            self.init_dist_threshold,
            self.margin,
        )
        # triplet_loss normally returns (losses, grid1, grid2, valid_points).
        # A bare tensor (its "no valid point" shortcut) is also accepted so
        # indexing cannot fail on a 0-dim tensor.
        losses = result[0] if isinstance(result, tuple) else result
        return torch.mean(losses)
class TotalLoss(nn.Module):
    """Total loss summing junction, heatmap, descriptor
    and regularization losses.

    Args:
        loss_funcs: Dict of loss modules with keys "junc_loss",
            "heatmap_loss" and optionally "descriptor_loss". A "reg_loss"
            entry is inserted into this same dict.
        loss_weights: Dict with "w_junc", "w_heatmap" and, for
            ``forward_descriptors``, "w_desc".
        weighting_policy: "static" or "dynamic"; selects how ``forward``
            combines the losses. ``forward_descriptors`` instead decides per
            weight: an ``nn.Parameter`` w is applied as exp(-w), any other
            value is used as a multiplier directly.
    """

    def __init__(self, loss_funcs, loss_weights, weighting_policy):
        super(TotalLoss, self).__init__()
        # Whether we need to compute the descriptor loss
        self.compute_descriptors = "descriptor_loss" in loss_funcs
        # NOTE: plain dicts (not nn.ModuleDict), so the contained modules are
        # not registered as submodules of TotalLoss.
        self.loss_funcs = loss_funcs
        self.loss_weights = loss_weights
        self.weighting_policy = weighting_policy
        # Always add regularization loss (it returns zero when no weight is an
        # nn.Parameter). No .cuda(): RegularizationLoss has no parameters or
        # buffers to move and takes its device from loss_weights["w_junc"].
        self.loss_funcs["reg_loss"] = RegularizationLoss()

    @staticmethod
    def _mask_or_ones(mask, target):
        """Return ``mask``, or an all-ones float32 mask like ``target`` if None."""
        if mask is None:
            return torch.ones_like(target, dtype=torch.float32)
        return mask

    def _effective_weight(self, name):
        """Return the multiplier for weight ``name``: exp(-w) if learnable, else w."""
        weight = self.loss_weights[name]
        if isinstance(weight, nn.Parameter):
            return torch.exp(-weight)
        return weight

    def _reported_weight(self, name, effective):
        """Return ``effective`` as a float for learnable weights, unchanged otherwise."""
        if isinstance(self.loss_weights[name], nn.Parameter):
            return effective.item()
        return effective

    def forward(
        self, junc_pred, junc_target, heatmap_pred, heatmap_target, valid_mask=None
    ):
        """Detection only loss.

        A missing ``valid_mask`` is replaced by an all-ones mask shaped like
        each target.

        Returns:
            Dict with "total_loss", "junc_loss" and "heatmap_loss". The
            "dynamic" policy additionally reports "reg_loss" and the
            effective weights "w_junc"/"w_heatmap" as floats.

        Raises:
            ValueError: If ``weighting_policy`` is neither "static" nor
                "dynamic".
        """
        # Compute the junction loss
        junc_loss = self.loss_funcs["junc_loss"](
            junc_pred, junc_target, self._mask_or_ones(valid_mask, junc_target)
        )
        # Compute the heatmap loss
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            heatmap_pred,
            heatmap_target,
            self._mask_or_ones(valid_mask, heatmap_target),
        )
        # Compute the total loss.
        if self.weighting_policy == "dynamic":
            reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
            w_junc = torch.exp(-self.loss_weights["w_junc"])
            w_heatmap = torch.exp(-self.loss_weights["w_heatmap"])
            total_loss = junc_loss * w_junc + heatmap_loss * w_heatmap + reg_loss
            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "reg_loss": reg_loss,
                "w_junc": w_junc.item(),
                "w_heatmap": w_heatmap.item(),
            }
        if self.weighting_policy == "static":
            total_loss = (
                junc_loss * self.loss_weights["w_junc"]
                + heatmap_loss * self.loss_weights["w_heatmap"]
            )
            return {
                "total_loss": total_loss,
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
            }
        raise ValueError("[Error] Unknown weighting policy.")

    def forward_descriptors(
        self,
        junc_map_pred1,
        junc_map_pred2,
        junc_map_target1,
        junc_map_target2,
        heatmap_pred1,
        heatmap_pred2,
        heatmap_target1,
        heatmap_target2,
        line_points1,
        line_points2,
        line_indices,
        desc_pred1,
        desc_pred2,
        epoch,
        valid_mask1=None,
        valid_mask2=None,
    ):
        """Loss for detection + description.

        The two views are concatenated along dim 0 for the junction and
        heatmap losses. A missing ``valid_mask1``/``valid_mask2`` is replaced
        by an all-ones mask shaped like the corresponding target.

        Returns:
            Dict with "junc_loss", "heatmap_loss", "descriptor_loss",
            "reg_loss", "total_loss" and the effective weights "w_junc",
            "w_heatmap", "w_desc". Weights are Python floats for learnable
            (``nn.Parameter``) weights and the stored values otherwise.
        """
        # Compute junction loss on both views at once
        junc_loss = self.loss_funcs["junc_loss"](
            torch.cat([junc_map_pred1, junc_map_pred2], dim=0),
            torch.cat([junc_map_target1, junc_map_target2], dim=0),
            torch.cat(
                [
                    self._mask_or_ones(valid_mask1, junc_map_target1),
                    self._mask_or_ones(valid_mask2, junc_map_target2),
                ],
                dim=0,
            ),
        )
        w_junc = self._effective_weight("w_junc")

        # Compute heatmap loss on both views at once
        heatmap_loss = self.loss_funcs["heatmap_loss"](
            torch.cat([heatmap_pred1, heatmap_pred2], dim=0),
            torch.cat([heatmap_target1, heatmap_target2], dim=0),
            torch.cat(
                [
                    self._mask_or_ones(valid_mask1, heatmap_target1),
                    self._mask_or_ones(valid_mask2, heatmap_target2),
                ],
                dim=0,
            ),
        )
        w_heatmap = self._effective_weight("w_heatmap")

        # Compute the descriptor loss
        descriptor_loss = self.loss_funcs["descriptor_loss"](
            desc_pred1, desc_pred2, line_points1, line_points2, line_indices, epoch
        )
        w_descriptor = self._effective_weight("w_desc")

        # Weighted sum of the three task losses
        total_loss = (
            junc_loss * w_junc
            + heatmap_loss * w_heatmap
            + descriptor_loss * w_descriptor
        )
        outputs = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "w_junc": self._reported_weight("w_junc", w_junc),
            "w_heatmap": self._reported_weight("w_heatmap", w_heatmap),
            "descriptor_loss": descriptor_loss,
            "w_desc": self._reported_weight("w_desc", w_descriptor),
        }
        # Compute the regularization loss
        reg_loss = self.loss_funcs["reg_loss"](self.loss_weights)
        total_loss = total_loss + reg_loss
        outputs.update({"reg_loss": reg_loss, "total_loss": total_loss})
        return outputs