|
from loguru import logger

import torch
import torch.nn as nn
|
|
|
|
|
class ASpanLoss(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.loss_config = config["aspan"]["loss"]
        self.match_type = self.config["aspan"]["match_coarse"]["match_type"]
        self.sparse_spvs = self.config["aspan"]["match_coarse"]["sparse_spvs"]
        self.flow_weight = self.loss_config["flow_weight"]

        # fine-level correctness threshold and coarse-level pos / neg weights
        self.correct_thr = self.loss_config["fine_correct_thr"]
        self.c_pos_w = self.loss_config["pos_weight"]
        self.c_neg_w = self.loss_config["neg_weight"]

        # fine-level loss variant: 'l2_with_std' or 'l2'
        self.fine_type = self.loss_config["fine_type"]
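    # For reference, the config slice consumed above covers (an assumed sketch;
    # the authoritative values live in the ASpanFormer training configs):
    #   config["aspan"]["loss"]: coarse_type, coarse_weight, focal_alpha,
    #       focal_gamma, pos_weight, neg_weight, fine_type, fine_weight,
    #       fine_correct_thr, flow_weight
    #   config["aspan"]["match_coarse"]: match_type ('sinkhorn' or a
    #       dual-softmax variant), sparse_spvs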
|
|
|
    def compute_flow_loss(self, coarse_corr_gt, flow_list, h0, w0, h1, w1):
        """Compute the symmetric flow loss over both matching directions.

        Args:
            coarse_corr_gt (list): [b_ids, i_ids, j_ids] gt coarse matches
            flow_list (list): predicted flow for both directions, each of shape
                (layer_num, N, H, W, 4): 2 flow + 2 confidence channels
            h0, w0, h1, w1 (int): coarse grid sizes (heights are unused here)
        """
        # direction 0 -> 1: gt targets live on the image-1 grid (width w1)
        loss1 = self.flow_loss_worker(
            flow_list[0], coarse_corr_gt[0], coarse_corr_gt[1], coarse_corr_gt[2], w1
        )
        # direction 1 -> 0: gt targets live on the image-0 grid (width w0)
        loss2 = self.flow_loss_worker(
            flow_list[1], coarse_corr_gt[0], coarse_corr_gt[2], coarse_corr_gt[1], w0
        )
        total_loss = (loss1 + loss2) / 2
        return total_loss
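    # Both directions share the same gt matches; only the roles of the i/j
    # indices (and the grid width used to decode them into coordinates) swap.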
|
|
|
    def flow_loss_worker(self, flow, batch_indicies, self_indicies, cross_indicies, w):
        layer_num, bs = flow.shape[0], flow.shape[1]
        flow = flow.view(layer_num, bs, -1, 4)
        # gt flow targets: (x, y) grid coordinates of the cross-view matches
        gt_flow = torch.stack([cross_indicies % w, cross_indicies // w], dim=1)

        total_loss_list = []
        for layer_index in range(layer_num):
            cur_flow_list = flow[layer_index]
            # first two channels: predicted flow; last two: predicted log-variance
            spv_flow = cur_flow_list[batch_indicies, self_indicies][:, :2]
            spv_conf = cur_flow_list[batch_indicies, self_indicies][:, 2:]
            l2_flow_dis = (gt_flow - spv_flow) ** 2
            # uncertainty-weighted regression:
            # log(sigma^2) + exp(-log(sigma^2)) * ||error||^2
            total_loss = spv_conf + torch.exp(-spv_conf) * l2_flow_dis
            total_loss_list.append(total_loss.mean())
        total_loss = torch.stack(total_loss_list, dim=-1) * self.flow_weight
        return total_loss
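    # A worked note on the term above, assuming spv_conf stores log(sigma^2):
    # for a fixed squared error e, f(s) = s + exp(-s) * e is minimized at
    # s = log(e), so the network is pushed to predict its own squared error as
    # the variance (Kendall & Gal-style heteroscedastic regression), and a
    # large predicted variance down-weights the flow error instead of zeroing it.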
|
|
|
    def compute_coarse_loss(self, conf, conf_gt, weight=None):
        """Point-wise CE / focal loss with 0 / 1 confidence as gt.

        Args:
            conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
            conf_gt (torch.Tensor): (N, HW0, HW1)
            weight (torch.Tensor): (N, HW0, HW1)
        """
        pos_mask, neg_mask = conf_gt == 1, conf_gt == 0
        c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w

        # corner case: no positive (or negative) gt match in the whole batch.
        # Assign one dummy element with zero weight so that every DDP rank
        # still produces a gradient and training does not deadlock.
        if not pos_mask.any():
            pos_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.0
            c_pos_w = 0.0
        if not neg_mask.any():
            neg_mask[0, 0, 0] = True
            if weight is not None:
                weight[0, 0, 0] = 0.0
            c_neg_w = 0.0
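        # Toy numbers for the focal branch below (assuming alpha=0.25, gamma=2):
        # an easy positive with conf=0.9 contributes
        # -0.25 * 0.1**2 * log(0.9) ~= 2.6e-4, while a hard positive with
        # conf=0.1 contributes -0.25 * 0.9**2 * log(0.1) ~= 0.47, so the loss
        # concentrates on hard examples.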
|
|
|
if self.loss_config["coarse_type"] == "cross_entropy": |
|
assert ( |
|
not self.sparse_spvs |
|
), "Sparse Supervision for cross-entropy not implemented!" |
|
conf = torch.clamp(conf, 1e-6, 1 - 1e-6) |
|
loss_pos = -torch.log(conf[pos_mask]) |
|
loss_neg = -torch.log(1 - conf[neg_mask]) |
|
if weight is not None: |
|
loss_pos = loss_pos * weight[pos_mask] |
|
loss_neg = loss_neg * weight[neg_mask] |
|
return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() |
|
        elif self.loss_config["coarse_type"] == "focal":
            conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
            alpha = self.loss_config["focal_alpha"]
            gamma = self.loss_config["focal_gamma"]

            if self.sparse_spvs:
                # with sinkhorn matching, the last row / column of conf hold
                # the dustbin scores, so positives come from the inner block
                pos_conf = (
                    conf[:, :-1, :-1][pos_mask]
                    if self.match_type == "sinkhorn"
                    else conf[pos_mask]
                )
                loss_pos = -alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()

                if self.match_type == "sinkhorn":
                    # unmatched rows / columns are supervised via the dustbin
                    neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
                    neg_conf = torch.cat(
                        [conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0
                    )
                    loss_neg = -alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
                else:
                    # dual-softmax has no dustbin, so there is no
                    # negative-sample loss under sparse supervision
                    pass

                if weight is not None:
                    # mask out the loss contributed by padded regions
                    loss_pos = loss_pos * weight[pos_mask]
                    if self.match_type == "sinkhorn":
                        neg_w0 = (weight.sum(-1) != 0)[neg0]
                        neg_w1 = (weight.sum(1) != 0)[neg1]
                        neg_mask = torch.cat([neg_w0, neg_w1], 0)
                        loss_neg = loss_neg[neg_mask]

                loss = (
                    c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
                    if self.match_type == "sinkhorn"
                    else c_pos_w * loss_pos.mean()
                )
                return loss
            else:
                # dense supervision: focal loss over all pos / neg entries
                loss_pos = (
                    -alpha
                    * torch.pow(1 - conf[pos_mask], gamma)
                    * (conf[pos_mask]).log()
                )
                loss_neg = (
                    -alpha
                    * torch.pow(conf[neg_mask], gamma)
                    * (1 - conf[neg_mask]).log()
                )
                if weight is not None:
                    loss_pos = loss_pos * weight[pos_mask]
                    loss_neg = loss_neg * weight[neg_mask]
                return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
        else:
            raise ValueError(
                f"Unknown coarse loss: {self.loss_config['coarse_type']}"
            )
|
|
|
    def compute_fine_loss(self, expec_f, expec_f_gt):
        if self.fine_type == "l2_with_std":
            return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
        elif self.fine_type == "l2":
            return self._compute_fine_loss_l2(expec_f, expec_f_gt)
        else:
            raise NotImplementedError()
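    # 'l2_with_std' re-weights each match by the inverse of its predicted std;
    # plain 'l2' treats all matches equally.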
|
|
|
    def _compute_fine_loss_l2(self, expec_f, expec_f_gt):
        """
        Args:
            expec_f (torch.Tensor): [M, 2] <x, y>
            expec_f_gt (torch.Tensor): [M, 2] <x, y>
        """
        correct_mask = (
            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
        )
        if not correct_mask.any():
            if self.training:
                logger.warning("assign a false supervision to avoid ddp deadlock")
                correct_mask[0] = True
            else:
                return None
        flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask]) ** 2).sum(-1)
        return flow_l2.mean()
|
|
|
    def _compute_fine_loss_l2_std(self, expec_f, expec_f_gt):
        """
        Args:
            expec_f (torch.Tensor): [M, 3] <x, y, std>
            expec_f_gt (torch.Tensor): [M, 2] <x, y>
        """
        # correct_mask selects the matches whose gt offset is within threshold
        correct_mask = (
            torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
        )

        # use the inverse std as a certainty weight; detach so the network
        # cannot shrink the loss simply by inflating the predicted std
        std = expec_f[:, 2]
        inverse_std = 1.0 / torch.clamp(std, min=1e-10)
        weight = (inverse_std / torch.mean(inverse_std)).detach()

        # corner case: no correct match at all
        if not correct_mask.any():
            if self.training:
                logger.warning("assign a false supervision to avoid ddp deadlock")
                correct_mask[0] = True
                weight[0] = 0.0
            else:
                return None

        # l2 loss with std weighting
        flow_l2 = ((expec_f_gt[correct_mask] - expec_f[correct_mask, :2]) ** 2).sum(-1)
        loss = (flow_l2 * weight[correct_mask]).mean()

        return loss
|
|
|
    @torch.no_grad()
    def compute_c_weight(self, data):
        """Compute element-wise weights for the coarse-level loss."""
        if "mask0" in data:
            c_weight = (
                data["mask0"].flatten(-2)[..., None]
                * data["mask1"].flatten(-2)[:, None]
            ).float()
        else:
            c_weight = None
        return c_weight
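    # Shape sketch (assumed sizes): mask0 (N, H0, W0) flattens to (N, HW0);
    # broadcasting (N, HW0, 1) * (N, 1, HW1) yields c_weight of shape
    # (N, HW0, HW1), which is zero wherever either view is padded.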
|
|
|
    def forward(self, data):
        """
        Update:
            data (dict): update{
                'loss': [1] the reduced loss across a batch,
                'loss_scalars' (dict): loss scalars for tensorboard_record
            }
        """
        loss_scalars = {}
        # 0. compute element-wise loss weight (masks out padded regions)
        c_weight = self.compute_c_weight(data)

        # 1. coarse-level loss
        loss_c = self.compute_coarse_loss(
            data["conf_matrix_with_bin"]
            if self.sparse_spvs and self.match_type == "sinkhorn"
            else data["conf_matrix"],
            data["conf_matrix_gt"],
            weight=c_weight,
        )
        loss = loss_c * self.loss_config["coarse_weight"]
        loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
|
|
|
|
|
        # 2. fine-level loss
        loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"])
        if loss_f is not None:
            loss += loss_f * self.loss_config["fine_weight"]
            loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
        else:
            # compute_fine_loss returns None only at evaluation time, when no
            # correct match survives the threshold
            assert self.training is False
            loss_scalars.update({"loss_f": torch.tensor(1.0)})  # placeholder
|
|
|
|
|
coarse_corr = [data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]] |
|
loss_flow = self.compute_flow_loss( |
|
coarse_corr, |
|
data["predict_flow"], |
|
data["hw0_c"][0], |
|
data["hw0_c"][1], |
|
data["hw1_c"][0], |
|
data["hw1_c"][1], |
|
) |
|
loss_flow = loss_flow * self.flow_weight |
|
for index, loss_off in enumerate(loss_flow): |
|
loss_scalars.update( |
|
{"loss_flow_" + str(index): loss_off.clone().detach().cpu()} |
|
) |
|
conf = data["predict_flow"][0][:, :, :, :, 2:] |
|
layer_num = conf.shape[0] |
|
for layer_index in range(layer_num): |
|
loss_scalars.update( |
|
{ |
|
"conf_" |
|
+ str(layer_index): conf[layer_index] |
|
.mean() |
|
.clone() |
|
.detach() |
|
.cpu() |
|
} |
|
) |
|
|
|
loss += loss_flow.sum() |
|
|
|
loss_scalars.update({"loss": loss.clone().detach().cpu()}) |
|
data.update({"loss": loss, "loss_scalars": loss_scalars}) |
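

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module: it exercises the
    # dense focal / dual-softmax path with random tensors. All config values
    # and shapes here are assumptions; real ones come from the ASpanFormer
    # training configs.
    _cfg = {
        "aspan": {
            "loss": {
                "coarse_type": "focal",
                "coarse_weight": 1.0,
                "focal_alpha": 0.25,
                "focal_gamma": 2.0,
                "pos_weight": 1.0,
                "neg_weight": 1.0,
                "fine_type": "l2_with_std",
                "fine_weight": 1.0,
                "fine_correct_thr": 1.0,
                "flow_weight": 0.1,
            },
            "match_coarse": {"match_type": "dual_softmax", "sparse_spvs": False},
        }
    }
    h = w = 4
    n, layers, hw = 1, 2, h * w
    conf_gt = torch.zeros(n, hw, hw)
    conf_gt[0, 0, 0] = 1  # a single ground-truth coarse match
    data = {
        "conf_matrix": torch.rand(n, hw, hw).clamp(0.05, 0.95),
        "conf_matrix_gt": conf_gt,
        "expec_f": torch.rand(3, 3),           # <x, y, std> per fine match
        "expec_f_gt": torch.rand(3, 2) * 0.5,  # within fine_correct_thr
        "spv_b_ids": torch.zeros(1, dtype=torch.long),
        "spv_i_ids": torch.zeros(1, dtype=torch.long),
        "spv_j_ids": torch.zeros(1, dtype=torch.long),
        "predict_flow": [
            torch.rand(layers, n, h, w, 4),
            torch.rand(layers, n, h, w, 4),
        ],
        "hw0_c": (h, w),
        "hw1_c": (h, w),
    }
    ASpanLoss(_cfg)(data)
    print({k: round(float(v), 6) for k, v in data["loss_scalars"].items()})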
|
|