|
""" |
|
This file implements the evaluation metrics. |
|
""" |
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
from torchvision.ops.boxes import batched_nms |
|
|
|
from ..misc.geometry_utils import keypoints_to_grid |
|
|
|
|
|
class Metrics(object): |
|
"""Metric evaluation calculator.""" |
|
|
|
def __init__( |
|
self, |
|
detection_thresh, |
|
prob_thresh, |
|
grid_size, |
|
junc_metric_lst=None, |
|
heatmap_metric_lst=None, |
|
pr_metric_lst=None, |
|
desc_metric_lst=None, |
|
): |
|
|
|
self.supported_junc_metrics = [ |
|
"junc_precision", |
|
"junc_precision_nms", |
|
"junc_recall", |
|
"junc_recall_nms", |
|
] |
|
self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] |
|
self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] |
|
self.supported_desc_metrics = ["matching_score"] |
|
|
|
|
|
if junc_metric_lst is None: |
|
self.junc_metric_lst = self.supported_junc_metrics |
|
else: |
|
self.junc_metric_lst = junc_metric_lst |
|
if heatmap_metric_lst is None: |
|
self.heatmap_metric_lst = self.supported_heatmap_metrics |
|
else: |
|
self.heatmap_metric_lst = heatmap_metric_lst |
|
if pr_metric_lst is None: |
|
self.pr_metric_lst = self.supported_pr_metrics |
|
else: |
|
self.pr_metric_lst = pr_metric_lst |
|
|
|
if desc_metric_lst is None: |
|
self.desc_metric_lst = [] |
|
elif desc_metric_lst == "all": |
|
self.desc_metric_lst = self.supported_desc_metrics |
|
else: |
|
self.desc_metric_lst = desc_metric_lst |
|
|
|
if not self._check_metrics(): |
|
raise ValueError("[Error] Some elements in the metric_lst are invalid.") |
|
|
|
|
|
self.metric_table = { |
|
"junc_precision": junction_precision(detection_thresh), |
|
"junc_precision_nms": junction_precision(detection_thresh), |
|
"junc_recall": junction_recall(detection_thresh), |
|
"junc_recall_nms": junction_recall(detection_thresh), |
|
"heatmap_precision": heatmap_precision(prob_thresh), |
|
"heatmap_recall": heatmap_recall(prob_thresh), |
|
"junc_pr": junction_pr(), |
|
"junc_nms_pr": junction_pr(), |
|
"matching_score": matching_score(grid_size), |
|
} |
|
|
|
|
|
self.metric_results = {} |
|
for key in self.metric_table.keys(): |
|
self.metric_results[key] = 0.0 |
|
|
|
def evaluate( |
|
self, |
|
junc_pred, |
|
junc_pred_nms, |
|
junc_gt, |
|
heatmap_pred, |
|
heatmap_gt, |
|
valid_mask, |
|
line_points1=None, |
|
line_points2=None, |
|
desc_pred1=None, |
|
desc_pred2=None, |
|
valid_points=None, |
|
): |
|
"""Perform evaluation.""" |
|
for metric in self.junc_metric_lst: |
|
|
|
if "nms" in metric: |
|
junc_pred_input = junc_pred_nms |
|
|
|
else: |
|
junc_pred_input = junc_pred |
|
self.metric_results[metric] = self.metric_table[metric]( |
|
junc_pred_input, junc_gt, valid_mask |
|
) |
|
|
|
for metric in self.heatmap_metric_lst: |
|
self.metric_results[metric] = self.metric_table[metric]( |
|
heatmap_pred, heatmap_gt, valid_mask |
|
) |
|
|
|
for metric in self.pr_metric_lst: |
|
if "nms" in metric: |
|
self.metric_results[metric] = self.metric_table[metric]( |
|
junc_pred_nms, junc_gt, valid_mask |
|
) |
|
else: |
|
self.metric_results[metric] = self.metric_table[metric]( |
|
junc_pred, junc_gt, valid_mask |
|
) |
|
|
|
for metric in self.desc_metric_lst: |
|
self.metric_results[metric] = self.metric_table[metric]( |
|
line_points1, line_points2, desc_pred1, desc_pred2, valid_points |
|
) |
|
|
|
def _check_metrics(self): |
|
"""Check if all input metrics are valid.""" |
|
        flag = True
        for metric in self.junc_metric_lst:
            if metric not in self.supported_junc_metrics:
                flag = False
                break
        for metric in self.heatmap_metric_lst:
            if metric not in self.supported_heatmap_metrics:
                flag = False
                break
        # Also validate the PR metrics.
        for metric in self.pr_metric_lst:
            if metric not in self.supported_pr_metrics:
                flag = False
                break
        for metric in self.desc_metric_lst:
            if metric not in self.supported_desc_metrics:
                flag = False
                break
|
|
|
return flag |
|
|
|
|
|
class AverageMeter(object): |
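    """Running averages of the evaluation metrics and losses."""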
|
def __init__( |
|
self, |
|
junc_metric_lst=None, |
|
heatmap_metric_lst=None, |
|
is_training=True, |
|
desc_metric_lst=None, |
|
): |
|
|
|
self.supported_junc_metrics = [ |
|
"junc_precision", |
|
"junc_precision_nms", |
|
"junc_recall", |
|
"junc_recall_nms", |
|
] |
|
self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"] |
|
self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"] |
|
self.supported_desc_metrics = ["matching_score"] |
|
|
|
|
|
self.supported_loss = [ |
|
"junc_loss", |
|
"heatmap_loss", |
|
"descriptor_loss", |
|
"total_loss", |
|
] |
|
|
|
self.is_training = is_training |
|
|
|
|
|
if junc_metric_lst is None: |
|
self.junc_metric_lst = self.supported_junc_metrics |
|
else: |
|
self.junc_metric_lst = junc_metric_lst |
|
if heatmap_metric_lst is None: |
|
self.heatmap_metric_lst = self.supported_heatmap_metrics |
|
else: |
|
self.heatmap_metric_lst = heatmap_metric_lst |
|
|
|
if desc_metric_lst is None: |
|
self.desc_metric_lst = [] |
|
elif desc_metric_lst == "all": |
|
self.desc_metric_lst = self.supported_desc_metrics |
|
else: |
|
self.desc_metric_lst = desc_metric_lst |
|
|
|
if not self._check_metrics(): |
|
raise ValueError("[Error] Some elements in the metric_lst are invalid.") |
|
|
|
|
|
self.metric_results = {} |
|
for key in ( |
|
self.supported_junc_metrics |
|
+ self.supported_heatmap_metrics |
|
+ self.supported_loss |
|
+ self.supported_desc_metrics |
|
): |
|
self.metric_results[key] = 0.0 |
|
        for key in self.supported_pr_metrics:
            # Each counter needs its own list: a single shared list object
            # would be aliased by all six entries and corrupt the updates.
            self.metric_results[key] = {
                "tp": [0] * 50,
                "tn": [0] * 50,
                "fp": [0] * 50,
                "fn": [0] * 50,
                "precision": [0] * 50,
                "recall": [0] * 50,
            }
|
|
|
|
|
self.count = 0 |
|
|
|
def update(self, metrics, loss_dict=None, num_samples=1): |
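        """Accumulate one batch of metric results (and losses, in training
        mode), weighted by num_samples."""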
|
|
|
if self.is_training and (loss_dict is None): |
|
raise ValueError("[Error] loss info should be given in the training mode.") |
|
|
|
|
|
self.count += num_samples |
|
|
|
|
|
for met in ( |
|
self.supported_junc_metrics |
|
+ self.supported_heatmap_metrics |
|
+ self.supported_desc_metrics |
|
): |
|
self.metric_results[met] += num_samples * metrics.metric_results[met] |
|
|
|
|
|
        # Update all the losses (only provided in training mode).
        if loss_dict is not None:
            for loss in loss_dict.keys():
                self.metric_results[loss] += num_samples * loss_dict[loss]
|
|
|
|
|
for pr_met in self.supported_pr_metrics: |
|
|
|
for key in metrics.metric_results[pr_met].keys(): |
|
|
|
for idx in range(len(self.metric_results[pr_met][key])): |
|
self.metric_results[pr_met][key][idx] += ( |
|
num_samples * metrics.metric_results[pr_met][key][idx] |
|
) |
|
|
|
def average(self): |
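        """Return all metrics averaged over the samples seen so far."""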
|
results = {} |
|
for met in self.metric_results.keys(): |
|
|
|
            if met not in self.supported_pr_metrics:
|
results[met] = self.metric_results[met] / self.count |
|
|
|
else: |
|
met_results = { |
|
"tp": self.metric_results[met]["tp"], |
|
"tn": self.metric_results[met]["tn"], |
|
"fp": self.metric_results[met]["fp"], |
|
"fn": self.metric_results[met]["fn"], |
|
"precision": [], |
|
"recall": [], |
|
} |
|
for idx in range(len(self.metric_results[met]["precision"])): |
|
met_results["precision"].append( |
|
self.metric_results[met]["precision"][idx] / self.count |
|
) |
|
met_results["recall"].append( |
|
self.metric_results[met]["recall"][idx] / self.count |
|
) |
|
|
|
results[met] = met_results |
|
|
|
return results |
|
|
|
def _check_metrics(self): |
|
"""Check if all input metrics are valid.""" |
|
flag = True |
|
        for metric in self.junc_metric_lst:
            if metric not in self.supported_junc_metrics:
                flag = False
                break
        for metric in self.heatmap_metric_lst:
            if metric not in self.supported_heatmap_metrics:
                flag = False
                break
        for metric in self.desc_metric_lst:
            if metric not in self.supported_desc_metrics:
                flag = False
                break
|
|
|
return flag |
|
|
|
|
|
class junction_precision(object): |
|
"""Junction precision.""" |
|
|
|
def __init__(self, detection_thresh): |
|
self.detection_thresh = detection_thresh |
|
|
|
|
|
def __call__(self, junc_pred, junc_gt, valid_mask): |
|
|
|
        junc_pred = (junc_pred >= self.detection_thresh).astype(int)
|
junc_pred = junc_pred * valid_mask.squeeze() |
|
|
|
|
|
if np.sum(junc_pred) > 0: |
|
precision = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_pred) |
|
else: |
|
precision = 0 |
|
|
|
return float(precision) |
|
|
|
|
|
class junction_recall(object): |
|
"""Junction recall.""" |
|
|
|
def __init__(self, detection_thresh): |
|
self.detection_thresh = detection_thresh |
|
|
|
|
|
def __call__(self, junc_pred, junc_gt, valid_mask): |
|
|
|
        junc_pred = (junc_pred >= self.detection_thresh).astype(int)
|
junc_pred = junc_pred * valid_mask.squeeze() |
|
|
|
|
|
if np.sum(junc_gt): |
|
recall = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_gt) |
|
else: |
|
recall = 0 |
|
|
|
return float(recall) |
|
|
|
|
|
class junction_pr(object): |
|
"""Junction precision-recall info.""" |
|
|
|
def __init__(self, num_threshold=50): |
|
        # Sweep the detection threshold from self.max down to self.min
        # (descending order, so recall grows along the resulting lists).
        self.max = 0.4
        step = self.max / num_threshold
        self.min = step
        self.intervals = np.flip(np.arange(self.min, self.max + step, step))
|
|
|
def __call__(self, junc_pred_raw, junc_gt, valid_mask): |
|
tp_lst = [] |
|
fp_lst = [] |
|
tn_lst = [] |
|
fn_lst = [] |
|
precision_lst = [] |
|
recall_lst = [] |
|
|
|
valid_mask = valid_mask.squeeze() |
|
|
|
for thresh in list(self.intervals): |
|
|
|
            junc_pred = (junc_pred_raw >= thresh).astype(int)
|
junc_pred = junc_pred * valid_mask |
|
|
|
|
|
junc_gt = junc_gt.squeeze() |
|
tp = np.sum(junc_pred * junc_gt) |
|
            tn = np.sum(
                (junc_pred == 0).astype(float)
                * (junc_gt == 0).astype(float)
                * valid_mask
            )
            fp = np.sum(
                (junc_pred == 1).astype(float)
                * (junc_gt == 0).astype(float)
                * valid_mask
            )
            fn = np.sum(
                (junc_pred == 0).astype(float)
                * (junc_gt == 1).astype(float)
                * valid_mask
            )
|
|
|
tp_lst.append(tp) |
|
tn_lst.append(tn) |
|
fp_lst.append(fp) |
|
fn_lst.append(fn) |
|
            # Guard against division by zero when there are no positives.
            precision_lst.append(tp / (tp + fp) if (tp + fp) > 0 else 0.0)
            recall_lst.append(tp / (tp + fn) if (tp + fn) > 0 else 0.0)
|
|
|
return { |
|
"tp": np.array(tp_lst), |
|
"tn": np.array(tn_lst), |
|
"fp": np.array(fp_lst), |
|
"fn": np.array(fn_lst), |
|
"precision": np.array(precision_lst), |
|
"recall": np.array(recall_lst), |
|
} |
|
|
|
|
|
class heatmap_precision(object): |
|
"""Heatmap precision.""" |
|
|
|
def __init__(self, prob_thresh): |
|
self.prob_thresh = prob_thresh |
|
|
|
def __call__(self, heatmap_pred, heatmap_gt, valid_mask): |
|
|
|
heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) |
|
heatmap_pred = heatmap_pred * valid_mask.squeeze() |
|
|
|
|
|
if np.sum(heatmap_pred) > 0: |
|
precision = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum( |
|
heatmap_pred |
|
) |
|
else: |
|
precision = 0.0 |
|
|
|
return precision |
|
|
|
|
|
class heatmap_recall(object): |
|
"""Heatmap recall.""" |
|
|
|
def __init__(self, prob_thresh): |
|
self.prob_thresh = prob_thresh |
|
|
|
def __call__(self, heatmap_pred, heatmap_gt, valid_mask): |
|
|
|
heatmap_pred = np.squeeze(heatmap_pred > self.prob_thresh) |
|
heatmap_pred = heatmap_pred * valid_mask.squeeze() |
|
|
|
|
|
if np.sum(heatmap_gt) > 0: |
|
recall = np.sum(heatmap_pred * heatmap_gt.squeeze()) / np.sum(heatmap_gt) |
|
else: |
|
recall = 0.0 |
|
|
|
return recall |
|
|
|
|
|
class matching_score(object): |
|
"""Descriptors matching score.""" |
|
|
|
def __init__(self, grid_size): |
|
self.grid_size = grid_size |
|
|
|
def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices): |
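        """Compute the matching score: the fraction of valid line points whose
        descriptors are mutual nearest neighbors across the two images."""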
|
b_size, _, Hc, Wc = desc_pred1.size() |
|
img_size = (Hc * self.grid_size, Wc * self.grid_size) |
|
device = desc_pred1.device |
|
|
|
|
|
        # Keep only the points that lie on valid lines in both images.
        n_points = line_indices.size(1)
        valid_points = line_indices.bool().flatten()
        n_correct_points = torch.sum(valid_points).item()
|
if n_correct_points == 0: |
|
return torch.tensor(0.0, dtype=torch.float, device=device) |
|
|
|
|
|
        # Convert keypoint positions to normalized grids for F.grid_sample.
        grid1 = keypoints_to_grid(points1, img_size)
        grid2 = keypoints_to_grid(points2, img_size)
|
|
|
|
|
        # Sample and L2-normalize the descriptors at the line points.
        # align_corners=False is the grid_sample default, stated explicitly.
        desc1 = (
            F.grid_sample(desc_pred1, grid1, align_corners=False)
            .permute(0, 2, 3, 1)
            .reshape(b_size * n_points, -1)[valid_points]
        )
        desc1 = F.normalize(desc1, dim=1)
        desc2 = (
            F.grid_sample(desc_pred2, grid2, align_corners=False)
            .permute(0, 2, 3, 1)
            .reshape(b_size * n_points, -1)[valid_points]
        )
        desc2 = F.normalize(desc2, dim=1)
|
        # For unit-norm descriptors, ||d1 - d2||^2 = 2 - 2 * <d1, d2>.
        desc_dists = 2 - 2 * (desc1 @ desc2.t())
|
|
|
|
|
        # Mutual nearest neighbors: point i in image 1 is correctly matched
        # if its nearest descriptor in image 2 points back to i.
        matches0 = torch.min(desc_dists, dim=1)[1]
        matches1 = torch.min(desc_dists, dim=0)[1]
        matching_score = matches1[matches0] == torch.arange(
            len(matches0), device=device
        )
        matching_score = matching_score.float().mean()
|
return matching_score |
|
|
|
|
|
def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0): |
|
"""Non-maximum suppression adapted from SuperPoint.""" |
|
|
|
im_h = prob_predictions.shape[1] |
|
im_w = prob_predictions.shape[2] |
|
output_lst = [] |
|
for i in range(prob_predictions.shape[0]): |
|
|
|
prob_pred = prob_predictions[i, ...] |
|
|
|
        # Filter the points using prob_thresh (coordinates in HW format).
        coord = np.where(prob_pred >= prob_thresh)
        points = np.concatenate(
            (coord[0][..., None], coord[1][..., None]), axis=1
        )
|
|
|
|
|
prob_score = prob_pred[points[:, 0], points[:, 1]] |
|
|
|
|
|
|
|
        # Switch to xy format and stack the scores; prob_score needs an extra
        # axis to be concatenated with the (N, 1) coordinate columns.
        in_points = np.concatenate(
            (coord[1][..., None], coord[0][..., None], prob_score[..., None]),
            axis=1,
        ).T
|
keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh) |
|
|
|
        # Flip the surviving corners back to HW format and round to pixels.
        keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T)
        keep_score = keep_points_[-1, :].T
|
|
|
|
|
        # Optionally keep only the top_k highest-scoring points.
        if top_k is not None and top_k > 0:
            k = min(keep_points.shape[0], top_k)
            keep_points = keep_points[:k, :]
            keep_score = keep_score[:k]
|
|
|
|
|
        # Re-compose the probability map from the surviving points.
        output_map = np.zeros([im_h, im_w])
        output_map[
            keep_points[:, 0].astype(int), keep_points[:, 1].astype(int)
        ] = keep_score.squeeze()
|
|
|
output_lst.append(output_map[None, ...]) |
|
|
|
return np.concatenate(output_lst, axis=0) |
|
|
|
|
|
def nms_fast(in_corners, H, W, dist_thresh): |
|
""" |
|
Run a faster approximate Non-Max-Suppression on numpy corners shaped: |
|
3xN [x_i,y_i,conf_i]^T |
|
|
|
Algo summary: Create a grid sized HxW. Assign each corner location a 1, |
|
rest are zeros. Iterate through all the 1's and convert them to -1 or 0. |
|
Suppress points by setting nearby values to 0. |
|
|
|
Grid Value Legend: |
|
-1 : Kept. |
|
0 : Empty or suppressed. |
|
     1 : To be processed (converted to either kept or suppressed).
|
|
|
NOTE: The NMS first rounds points to integers, so NMS distance might not |
|
be exactly dist_thresh. It also assumes points are within image boundary. |
|
|
|
Inputs |
|
in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T. |
|
H - Image height. |
|
W - Image width. |
|
      dist_thresh - Distance to suppress, measured as an infinity-norm distance.
|
Returns |
|
nmsed_corners - 3xN numpy matrix with surviving corners. |
|
nmsed_inds - N length numpy vector with surviving corner indices. |
|
""" |
|
grid = np.zeros((H, W)).astype(int) |
|
inds = np.zeros((H, W)).astype(int) |
|
|
|
    # Sort corners by confidence (descending) and round coordinates to ints.
    inds1 = np.argsort(-in_corners[2, :])
    corners = in_corners[:, inds1]
    rcorners = corners[:2, :].round().astype(int)
|
|
|
    # Edge cases: no corners at all, or a single corner.
    if rcorners.shape[1] == 0:
        return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
    if rcorners.shape[1] == 1:
        out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
        return out, np.zeros((1)).astype(int)
|
|
|
    # Initialize the grid: mark every corner cell with 1 and remember its
    # index in the confidence-sorted order.
    for i, rc in enumerate(rcorners.T):
        grid[rcorners[1, i], rcorners[0, i]] = 1
        inds[rcorners[1, i], rcorners[0, i]] = i
|
|
|
    # Pad the grid border so that suppression windows near the image
    # boundary stay in bounds.
    pad = dist_thresh
    grid = np.pad(grid, ((pad, pad), (pad, pad)), mode="constant")
|
|
|
    # Iterate through the points in descending confidence order, keep each
    # unsuppressed point, and zero out its (2 * pad + 1)^2 neighborhood.
    count = 0
    for i, rc in enumerate(rcorners.T):
        pt = (rc[0] + pad, rc[1] + pad)
        if grid[pt[1], pt[0]] == 1:
            grid[pt[1] - pad : pt[1] + pad + 1, pt[0] - pad : pt[0] + pad + 1] = 0
            grid[pt[1], pt[0]] = -1
            count += 1
|
|
|
    # Recover the kept points, undo the padding offset, and return them
    # sorted by confidence together with their original indices.
    keepy, keepx = np.where(grid == -1)
    keepy, keepx = keepy - pad, keepx - pad
    inds_keep = inds[keepy, keepx]
    out = corners[:, inds_keep]
    values = out[-1, :]
    inds2 = np.argsort(-values)
    out = out[:, inds2]
    out_inds = inds1[inds_keep[inds2]]
|
return out, out_inds |
|
|