Vincentqyw
fix: roma
358ab8f
raw
history blame
7.58 kB
import torch
import numpy as np
def batch_episym(x1, x2, F):
    """Return the per-point symmetric epipolar distance of correspondences.

    For every pair ``(x1, x2)`` the result is
    ``sqrt((x2^T F x1)^2 * (1 / |(F x1)_xy|^2 + 1 / |(F^T x2)_xy|^2))``
    where points are lifted to homogeneous coordinates by appending a 1.

    Args:
        x1: tensor of shape (batch, num_pts, 2) with points of the first view.
        x2: tensor of shape (batch, num_pts, 2) with the matching points.
        F: tensor holding one 3x3 matrix per batch element; any shape that
            reshapes to (-1, 1, 3, 3) is accepted (e.g. (batch, 3, 3) or
            (batch, 9)).

    Returns:
        Tensor of shape (batch, num_pts) with the distance of every pair.
    """
    batch_size, num_pts = x1.shape[0], x1.shape[1]

    def _homogeneous(pts):
        # Append a column of ones and view each point as a 3x1 column vector.
        ones = pts.new_ones(batch_size, num_pts, 1)
        return torch.cat([pts, ones], dim=-1).reshape(batch_size, num_pts, 3, 1)

    p1 = _homogeneous(x1)
    p2 = _homogeneous(x2)

    # One copy of the matrix per point so the products below are plain
    # batched 3x3 @ 3x1 multiplications.
    F_per_pt = F.reshape(-1, 1, 3, 3).repeat(1, num_pts, 1, 1)

    line_in_2 = torch.matmul(F_per_pt, p1)  # F x1
    line_in_1 = torch.matmul(F_per_pt.transpose(2, 3), p2)  # F^T x2
    algebraic = torch.matmul(p2.transpose(2, 3), line_in_2).reshape(
        batch_size, num_pts
    )  # x2^T F x1

    line_in_2 = line_in_2.reshape(batch_size, num_pts, 3)
    line_in_1 = line_in_1.reshape(batch_size, num_pts, 3)

    # The small epsilon keeps the reciprocals finite for degenerate lines.
    inv_norm_2 = 1.0 / (line_in_2[:, :, 0] ** 2 + line_in_2[:, :, 1] ** 2 + 1e-15)
    inv_norm_1 = 1.0 / (line_in_1[:, :, 0] ** 2 + line_in_1[:, :, 1] ** 2 + 1e-15)

    return (algebraic**2 * (inv_norm_2 + inv_norm_1)).sqrt()
def CELoss(seed_x1, seed_x2, e, confidence, inlier_th, batch_mask=1):
    """Balanced cross-entropy on seed confidences, plus precision and recall.

    Seeds are labelled positive when their ``batch_episym`` distance under
    ``e`` is at most ``inlier_th`` and negative otherwise. Positive and
    negative log-loss terms are each divided by the number of seeds of that
    class (clamped to at least 1) and weighted by 0.5.

    Args:
        seed_x1, seed_x2: seed coordinates, b*k*2.
        e: matrix (one per batch element) forwarded to ``batch_episym``.
        confidence: predicted confidence per seed; compared against 0.5 for
            the metrics. Presumably shaped (b, k) with values in [0, 1] --
            TODO confirm against the caller.
        inlier_th: distance threshold separating positives from negatives.
        batch_mask: multiplier applied to the per-sample loss before the
            final mean (default 1, i.e. no masking).

    Returns:
        Tuple ``(classif_loss, precision, recall)``, each averaged over the
        batch.
    """
    # seed_x: b*k*2
    residual = batch_episym(seed_x1, seed_x2, e)
    inlier = (residual <= inlier_th).float()
    outlier = (residual > inlier_th).float()

    # relu(n - 1) + 1 == max(n, 1): avoids dividing by zero for empty classes.
    n_inlier = torch.relu(torch.sum(inlier, dim=1) - 1.0) + 1.0
    n_outlier = torch.relu(torch.sum(outlier, dim=1) - 1.0) + 1.0

    nll_inlier = -torch.log(abs(confidence) + 1e-8) * inlier
    nll_outlier = -torch.log(abs(1 - confidence) + 1e-8) * outlier

    per_seed = (
        nll_inlier * 0.5 / n_inlier.unsqueeze(-1)
        + nll_outlier * 0.5 / n_outlier.unsqueeze(-1)
    )
    per_sample = torch.mean(per_seed, dim=-1)
    classif_loss = (per_sample * batch_mask).mean()

    # Seeds the model calls inliers, cast to the dtype of ``confidence``.
    predicted = (confidence > 0.5).type(confidence.type())
    true_pos = torch.sum(predicted * inlier, dim=1)

    precision = torch.mean(true_pos / (torch.sum(predicted, dim=1) + 1e-8))
    recall = torch.mean(true_pos / n_inlier)
    return classif_loss, precision, recall
def CorrLoss(desc_mat, batch_num_corr, batch_num_incorr1, batch_num_incorr2):
    """Compute matching losses and accuracies from assignment matrices.

    For each sample ``i`` the indexing below reads ``desc_mat[i]`` as follows:

    * the first ``num_corr`` diagonal entries are the correct matches;
    * rows ``num_corr : num_corr + num_incorr1`` are expected to put their
      mass in the last column, and columns
      ``num_corr : num_corr + num_incorr2`` in the last row (presumably the
      "unmatched" slot of the matrix -- TODO confirm with the model code);
    * accuracies are measured on the matrix without its last row and column.

    Args:
        desc_mat: tensor of shape (batch, rows, cols) with non-negative
            assignment scores (``abs`` is applied before the log).
        batch_num_corr: per-sample number of correct correspondences.
        batch_num_incorr1: per-sample number of unmatched rows.
        batch_num_incorr2: per-sample number of unmatched columns.

    Returns:
        Tuple ``(loss_corr, loss_incorr, acc_corr, acc_incorr)``, each the
        mean over the batch. A sample with ``num_corr == 0`` (or no
        unmatched entries) contributes NaN, because the mean of an empty
        slice is NaN.
    """
    total_loss_corr, total_loss_incorr = 0, 0
    total_acc_corr, total_acc_incorr = 0, 0
    batch_size = desc_mat.shape[0]
    # Index tensors are created on the input's device so the function works
    # on CPU and on any GPU, not only the current CUDA device.
    device = desc_mat.device
    log_p = torch.log(abs(desc_mat) + 1e-8)

    for i in range(batch_size):
        cur_log_p = log_p[i]
        num_corr = batch_num_corr[i]
        num_incorr1, num_incorr2 = batch_num_incorr1[i], batch_num_incorr2[i]

        # Negative log-likelihood of the correct matches (diagonal entries).
        loss_corr = -torch.diag(cur_log_p)[:num_corr].mean()
        # Negative log-likelihood of unmatched rows/columns landing in the
        # last column/row, averaged over both directions.
        loss_incorr = (
            -cur_log_p[num_corr : num_corr + num_incorr1, -1].mean()
            - cur_log_p[-1, num_corr : num_corr + num_incorr2].mean()
        ) / 2

        # Best score and its index per row / per column, last slot excluded.
        value_row, row_index = torch.max(desc_mat[i, :-1, :-1], dim=-1)
        value_col, col_index = torch.max(desc_mat[i, :-1, :-1], dim=-2)

        # An unmatched row/column counts as right when its best score is
        # below 0.2.
        acc_incorr = (
            (value_row[num_corr : num_corr + num_incorr1] < 0.2).float().mean()
            + (value_col[num_corr : num_corr + num_incorr2] < 0.2).float().mean()
        ) / 2

        # A correct match counts only if it is the argmax of both its row
        # and its column (mutual best).
        diag_ids = torch.arange(num_corr, device=device)
        acc_row_mask = row_index[:num_corr] == diag_ids
        acc_col_mask = col_index[:num_corr] == diag_ids
        acc = (acc_col_mask & acc_row_mask).float().mean()

        total_loss_corr += loss_corr
        total_loss_incorr += loss_incorr
        total_acc_corr += acc
        total_acc_incorr += acc_incorr

    total_acc_corr /= batch_size
    total_acc_incorr /= batch_size
    total_loss_corr /= batch_size
    total_loss_incorr /= batch_size
    return total_loss_corr, total_loss_incorr, total_acc_corr, total_acc_incorr
class SGMLoss:
    """Total training loss built from final, intermediate and seed outputs.

    Combines ``CorrLoss`` on the final assignment matrix ``result["p"]``,
    ``CorrLoss`` on every intermediate matrix in ``result["mid_p"]`` and
    ``CELoss`` on every seed-confidence output in ``result["seed_conf"]``.
    """

    def __init__(self, config, model_config):
        # config supplies inlier_th, seed_loss_weight and mid_loss_weight;
        # model_config supplies seedlayer.
        self.config = config
        self.model_config = model_config

    def run(self, data, result):
        """Compute all loss terms and metrics for one batch.

        Args:
            data: dict with "num_corr", "num_incorr1", "num_incorr2", "x1",
                "x2" and "e_gt".
            result: dict with "p", "mid_p", "seed_conf" and "seed_index".

        Returns:
            Dict of loss terms and metrics; "total_loss" is the sum of the
            final losses, the weighted seed loss and the weighted
            intermediate losses. The "*_mid" entries are already multiplied
            by ``mid_loss_weight`` and "loss_seed_conf" by
            ``seed_loss_weight``.
        """
        num_corr = data["num_corr"]
        num_incorr1, num_incorr2 = data["num_incorr1"], data["num_incorr2"]

        loss_corr, loss_incorr, acc_corr, acc_incorr = CorrLoss(
            result["p"], num_corr, num_incorr1, num_incorr2
        )

        # mid loss
        loss_mid_corr_tower, loss_mid_incorr_tower, acc_mid_tower = [], [], []
        for mid_p in result["mid_p"]:
            loss_mid_corr, loss_mid_incorr, mid_acc_corr, _ = CorrLoss(
                mid_p, num_corr, num_incorr1, num_incorr2
            )
            loss_mid_corr_tower.append(loss_mid_corr)
            loss_mid_incorr_tower.append(loss_mid_incorr)
            acc_mid_tower.append(mid_acc_corr)

        if loss_mid_corr_tower:
            loss_mid_corr_tower = torch.stack(loss_mid_corr_tower)
            loss_mid_incorr_tower = torch.stack(loss_mid_incorr_tower)
            acc_mid_tower = torch.stack(acc_mid_tower)
        else:
            # No intermediate outputs: use zero placeholders on the same
            # device as the model output instead of assuming CUDA.
            device = result["p"].device
            loss_mid_corr_tower = torch.zeros(1, device=device)
            loss_mid_incorr_tower = torch.zeros(1, device=device)
            acc_mid_tower = torch.zeros(1, device=device)

        # seed confidence loss
        seed_layers = np.asarray(self.model_config.seedlayer)
        classif_loss_tower, classif_precision_tower, classif_recall_tower = [], [], []
        for layer, confidence in enumerate(result["seed_conf"]):
            # Use the seed set of the last seeding layer at or before `layer`.
            seed_index = result["seed_index"][(seed_layers <= layer).nonzero()[0][-1]]
            # seed_index[..., 0] / [..., 1] index into x1 / x2; expand the
            # index over the two coordinate channels for gather.
            seed_x1 = data["x1"].gather(
                dim=1, index=seed_index[:, :, 0, None].expand(-1, -1, 2)
            )
            seed_x2 = data["x2"].gather(
                dim=1, index=seed_index[:, :, 1, None].expand(-1, -1, 2)
            )
            classif_loss, classif_precision, classif_recall = CELoss(
                seed_x1, seed_x2, data["e_gt"], confidence, self.config.inlier_th
            )
            classif_loss_tower.append(classif_loss)
            classif_precision_tower.append(classif_precision)
            classif_recall_tower.append(classif_recall)

        classif_loss = torch.stack(classif_loss_tower).mean()
        classif_precision_tower = torch.stack(classif_precision_tower)
        classif_recall_tower = torch.stack(classif_recall_tower)

        classif_loss = classif_loss * self.config.seed_loss_weight
        loss_mid_corr_tower = loss_mid_corr_tower * self.config.mid_loss_weight
        loss_mid_incorr_tower = loss_mid_incorr_tower * self.config.mid_loss_weight

        total_loss = (
            loss_corr
            + loss_incorr
            + classif_loss
            + loss_mid_corr_tower.sum()
            + loss_mid_incorr_tower.sum()
        )

        return {
            "loss_corr": loss_corr,
            "loss_incorr": loss_incorr,
            "acc_corr": acc_corr,
            "acc_incorr": acc_incorr,
            "loss_seed_conf": classif_loss,
            "pre_seed_conf": classif_precision_tower,
            "recall_seed_conf": classif_recall_tower,
            "loss_corr_mid": loss_mid_corr_tower,
            "loss_incorr_mid": loss_mid_incorr_tower,
            "mid_acc_corr": acc_mid_tower,
            "total_loss": total_loss,
        }
class SGLoss:
    """Training loss that uses only the final assignment matrix."""

    def __init__(self, config, model_config):
        # Both objects are stored as given; run() below reads neither.
        self.model_config = model_config
        self.config = config

    def run(self, data, result):
        """Evaluate ``CorrLoss`` on ``result["p"]`` and package the outputs.

        Args:
            data: dict with "num_corr", "num_incorr1" and "num_incorr2".
            result: dict with the assignment matrix under "p".

        Returns:
            Dict with the two loss terms, the two accuracies and
            "total_loss", the sum of the two loss terms.
        """
        matched_loss, unmatched_loss, matched_acc, unmatched_acc = CorrLoss(
            result["p"],
            data["num_corr"],
            data["num_incorr1"],
            data["num_incorr2"],
        )
        summary = dict(
            loss_corr=matched_loss,
            loss_incorr=unmatched_loss,
            acc_corr=matched_acc,
            acc_incorr=unmatched_acc,
        )
        summary["total_loss"] = matched_loss + unmatched_loss
        return summary