# NOTE: the six lines below were artifacts captured from the code-hosting web
# UI, not Python code; they are preserved here as comments:
#   Vincentqyw
#   fix: roma
#   8b973ee
#   raw / history blame / 20.7 kB
"""
This file implements the evaluation metrics.
"""
import torch
import torch.nn.functional as F
import numpy as np
from torchvision.ops.boxes import batched_nms
from ..misc.geometry_utils import keypoints_to_grid
class Metrics(object):
    """Metric evaluation calculator.

    Aggregates junction, heatmap, precision-recall and descriptor metrics
    behind a single `evaluate` call.  Results are stored in
    `self.metric_results`, keyed by metric name.
    """

    def __init__(
        self,
        detection_thresh,
        prob_thresh,
        grid_size,
        junc_metric_lst=None,
        heatmap_metric_lst=None,
        pr_metric_lst=None,
        desc_metric_lst=None,
    ):
        """
        Args:
            detection_thresh: threshold applied to the junction map.
            prob_thresh: threshold applied to the line heatmap.
            grid_size: descriptor grid size (forwarded to matching_score).
            junc_metric_lst: junction metrics to compute (None = all).
            heatmap_metric_lst: heatmap metrics to compute (None = all).
            pr_metric_lst: PR-curve metrics to compute (None = all).
            desc_metric_lst: descriptor metrics (None = none, "all" = all).
        """
        # List of supported metric names, per family.
        self.supported_junc_metrics = [
            "junc_precision",
            "junc_precision_nms",
            "junc_recall",
            "junc_recall_nms",
        ]
        self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"]
        self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
        self.supported_desc_metrics = ["matching_score"]

        # If a metric list is None, default to all supported metrics.
        if junc_metric_lst is None:
            self.junc_metric_lst = self.supported_junc_metrics
        else:
            self.junc_metric_lst = junc_metric_lst
        if heatmap_metric_lst is None:
            self.heatmap_metric_lst = self.supported_heatmap_metrics
        else:
            self.heatmap_metric_lst = heatmap_metric_lst
        if pr_metric_lst is None:
            self.pr_metric_lst = self.supported_pr_metrics
        else:
            self.pr_metric_lst = pr_metric_lst
        # For the descriptors, the default None assumes no desc metric at all.
        if desc_metric_lst is None:
            self.desc_metric_lst = []
        elif desc_metric_lst == "all":
            self.desc_metric_lst = self.supported_desc_metrics
        else:
            self.desc_metric_lst = desc_metric_lst

        if not self._check_metrics():
            raise ValueError("[Error] Some elements in the metric_lst are invalid.")

        # Metric mapping table: metric name -> callable evaluator.
        self.metric_table = {
            "junc_precision": junction_precision(detection_thresh),
            "junc_precision_nms": junction_precision(detection_thresh),
            "junc_recall": junction_recall(detection_thresh),
            "junc_recall_nms": junction_recall(detection_thresh),
            "heatmap_precision": heatmap_precision(prob_thresh),
            "heatmap_recall": heatmap_recall(prob_thresh),
            "junc_pr": junction_pr(),
            "junc_nms_pr": junction_pr(),
            "matching_score": matching_score(grid_size),
        }

        # Initialize all results to zero.
        self.metric_results = {key: 0.0 for key in self.metric_table}

    def evaluate(
        self,
        junc_pred,
        junc_pred_nms,
        junc_gt,
        heatmap_pred,
        heatmap_gt,
        valid_mask,
        line_points1=None,
        line_points2=None,
        desc_pred1=None,
        desc_pred2=None,
        valid_points=None,
    ):
        """Perform evaluation and store the results in self.metric_results.

        Args:
            junc_pred / junc_pred_nms: raw and NMS-ed junction predictions.
            junc_gt: ground-truth junction map.
            heatmap_pred / heatmap_gt: predicted / ground-truth line heatmaps.
            valid_mask: mask restricting evaluation to valid pixels.
            line_points1/2, desc_pred1/2, valid_points: descriptor inputs,
                only required when descriptor metrics are requested.
        """
        for metric in self.junc_metric_lst:
            # Metrics containing "nms" consume the NMS-ed junction map.
            junc_pred_input = junc_pred_nms if "nms" in metric else junc_pred
            self.metric_results[metric] = self.metric_table[metric](
                junc_pred_input, junc_gt, valid_mask
            )
        for metric in self.heatmap_metric_lst:
            self.metric_results[metric] = self.metric_table[metric](
                heatmap_pred, heatmap_gt, valid_mask
            )
        for metric in self.pr_metric_lst:
            junc_pred_input = junc_pred_nms if "nms" in metric else junc_pred
            self.metric_results[metric] = self.metric_table[metric](
                junc_pred_input, junc_gt, valid_mask
            )
        for metric in self.desc_metric_lst:
            self.metric_results[metric] = self.metric_table[metric](
                line_points1, line_points2, desc_pred1, desc_pred2, valid_points
            )

    def _check_metrics(self):
        """Check that every requested metric name is supported.

        BUGFIX: the original version never validated `pr_metric_lst`, so an
        invalid PR metric name slipped through and caused a KeyError later
        in `evaluate`.
        """
        metric_families = [
            (self.junc_metric_lst, self.supported_junc_metrics),
            (self.heatmap_metric_lst, self.supported_heatmap_metrics),
            (self.pr_metric_lst, self.supported_pr_metrics),
            (self.desc_metric_lst, self.supported_desc_metrics),
        ]
        return all(
            metric in supported
            for requested, supported in metric_families
            for metric in requested
        )
class AverageMeter(object):
    """Running average accumulator for metrics (and losses in training mode)."""

    def __init__(
        self,
        junc_metric_lst=None,
        heatmap_metric_lst=None,
        is_training=True,
        desc_metric_lst=None,
    ):
        """
        Args:
            junc_metric_lst: junction metrics to track (None = all).
            heatmap_metric_lst: heatmap metrics to track (None = all).
            is_training: when True, `update` requires a loss_dict.
            desc_metric_lst: descriptor metrics (None = none, "all" = all).
        """
        # List of supported metric names, per family.
        self.supported_junc_metrics = [
            "junc_precision",
            "junc_precision_nms",
            "junc_recall",
            "junc_recall_nms",
        ]
        self.supported_heatmap_metrics = ["heatmap_precision", "heatmap_recall"]
        self.supported_pr_metrics = ["junc_pr", "junc_nms_pr"]
        self.supported_desc_metrics = ["matching_score"]
        # Loss keys are always registered; they stay at 0.0 in eval mode.
        self.supported_loss = [
            "junc_loss",
            "heatmap_loss",
            "descriptor_loss",
            "total_loss",
        ]
        self.is_training = is_training

        # If a metric list is None, default to all supported metrics.
        if junc_metric_lst is None:
            self.junc_metric_lst = self.supported_junc_metrics
        else:
            self.junc_metric_lst = junc_metric_lst
        if heatmap_metric_lst is None:
            self.heatmap_metric_lst = self.supported_heatmap_metrics
        else:
            self.heatmap_metric_lst = heatmap_metric_lst
        # For the descriptors, the default None assumes no desc metric at all.
        if desc_metric_lst is None:
            self.desc_metric_lst = []
        elif desc_metric_lst == "all":
            self.desc_metric_lst = self.supported_desc_metrics
        else:
            self.desc_metric_lst = desc_metric_lst

        if not self._check_metrics():
            raise ValueError("[Error] Some elements in the metric_lst are invalid.")

        # Initialize the scalar accumulators.
        self.metric_results = {}
        for key in (
            self.supported_junc_metrics
            + self.supported_heatmap_metrics
            + self.supported_loss
            + self.supported_desc_metrics
        ):
            self.metric_results[key] = 0.0
        # PR accumulators: 50 threshold bins per field.
        # BUGFIX: the original reused one list object for all six fields of a
        # PR dict, so incrementing "tp" also mutated "tn", "fp", "fn",
        # "precision" and "recall". Build a fresh list per field instead.
        for key in self.supported_pr_metrics:
            self.metric_results[key] = {
                field: [0 for _ in range(50)]
                for field in ["tp", "tn", "fp", "fn", "precision", "recall"]
            }
        # Total number of accumulated samples.
        self.count = 0

    def update(self, metrics, loss_dict=None, num_samples=1):
        """Accumulate one batch of metric results (and losses when training).

        Args:
            metrics: object exposing `metric_results` (e.g. a Metrics instance).
            loss_dict: mapping of loss name -> value; mandatory when training.
            num_samples: batch size used to weight the running sums.
        """
        # Loss info must be given in the training mode.
        if self.is_training and (loss_dict is None):
            raise ValueError("[Error] loss info should be given in the training mode.")
        # Update the total sample count.
        self.count += num_samples
        # Update all the scalar metrics.
        for met in (
            self.supported_junc_metrics
            + self.supported_heatmap_metrics
            + self.supported_desc_metrics
        ):
            self.metric_results[met] += num_samples * metrics.metric_results[met]
        # Update all the losses.
        # BUGFIX: guard against loss_dict=None (eval mode) — the original
        # unconditionally called loss_dict.keys() and crashed.
        if loss_dict is not None:
            for loss, value in loss_dict.items():
                self.metric_results[loss] += num_samples * value
        # Update all PR counts: tp, tn, fp, fn, precision, and recall,
        # one entry per threshold bin.
        for pr_met in self.supported_pr_metrics:
            for key in metrics.metric_results[pr_met].keys():
                for idx in range(len(self.metric_results[pr_met][key])):
                    self.metric_results[pr_met][key][idx] += (
                        num_samples * metrics.metric_results[pr_met][key][idx]
                    )

    def average(self):
        """Return the sample-weighted average of all accumulated metrics."""
        results = {}
        for met in self.metric_results.keys():
            if not met in self.supported_pr_metrics:
                # Scalar metrics / losses: plain running mean.
                results[met] = self.metric_results[met] / self.count
            else:
                # PR metrics: raw counts pass through; only precision and
                # recall are averaged per threshold bin.
                met_results = {
                    "tp": self.metric_results[met]["tp"],
                    "tn": self.metric_results[met]["tn"],
                    "fp": self.metric_results[met]["fp"],
                    "fn": self.metric_results[met]["fn"],
                    "precision": [],
                    "recall": [],
                }
                for idx in range(len(self.metric_results[met]["precision"])):
                    met_results["precision"].append(
                        self.metric_results[met]["precision"][idx] / self.count
                    )
                    met_results["recall"].append(
                        self.metric_results[met]["recall"][idx] / self.count
                    )
                results[met] = met_results
        return results

    def _check_metrics(self):
        """Check that every requested metric name is supported."""
        metric_families = [
            (self.junc_metric_lst, self.supported_junc_metrics),
            (self.heatmap_metric_lst, self.supported_heatmap_metrics),
            (self.desc_metric_lst, self.supported_desc_metrics),
        ]
        return all(
            metric in supported
            for requested, supported in metric_families
            for metric in requested
        )
class junction_precision(object):
    """Junction precision: TP / (TP + FP) on the thresholded junction map."""

    def __init__(self, detection_thresh):
        # Probability threshold turning the junction map into detections.
        self.detection_thresh = detection_thresh

    def __call__(self, junc_pred, junc_gt, valid_mask):
        """Compute precision of `junc_pred` vs `junc_gt` inside `valid_mask`."""
        # Convert prediction to discrete detections.
        # BUGFIX: `np.int` was removed in NumPy 1.24; use the builtin `int`.
        junc_pred = (junc_pred >= self.detection_thresh).astype(int)
        junc_pred = junc_pred * valid_mask.squeeze()
        # Corner case: no detections at all -> define precision as 0.
        if np.sum(junc_pred) > 0:
            precision = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_pred)
        else:
            precision = 0
        return float(precision)
class junction_recall(object):
    """Junction recall: TP / (TP + FN) on the thresholded junction map."""

    def __init__(self, detection_thresh):
        # Probability threshold turning the junction map into detections.
        self.detection_thresh = detection_thresh

    def __call__(self, junc_pred, junc_gt, valid_mask):
        """Compute recall of `junc_pred` vs `junc_gt` inside `valid_mask`."""
        # Convert prediction to discrete detections.
        # BUGFIX: `np.int` was removed in NumPy 1.24; use the builtin `int`.
        junc_pred = (junc_pred >= self.detection_thresh).astype(int)
        junc_pred = junc_pred * valid_mask.squeeze()
        # Corner case: no ground-truth junctions -> define recall as 0.
        if np.sum(junc_gt):
            recall = np.sum(junc_pred * junc_gt.squeeze()) / np.sum(junc_gt)
        else:
            recall = 0
        return float(recall)
class junction_pr(object):
    """Junction precision-recall info across a sweep of detection thresholds."""

    def __init__(self, num_threshold=50):
        # Sweep thresholds from `max` (0.4) down to `max / num_threshold`,
        # evaluated in descending order.
        self.max = 0.4
        step = self.max / num_threshold
        self.min = step
        self.intervals = np.flip(np.arange(self.min, self.max + step, step))

    def __call__(self, junc_pred_raw, junc_gt, valid_mask):
        """Return tp/tn/fp/fn/precision/recall arrays, one entry per threshold."""
        tp_lst = []
        fp_lst = []
        tn_lst = []
        fn_lst = []
        precision_lst = []
        recall_lst = []
        valid_mask = valid_mask.squeeze()
        junc_gt = junc_gt.squeeze()
        # Iterate through all the thresholds (highest first).
        for thresh in list(self.intervals):
            # Convert prediction to discrete detections.
            # BUGFIX: `np.int` / `np.float` were removed in NumPy 1.24;
            # use the builtins instead.
            junc_pred = (junc_pred_raw >= thresh).astype(int)
            junc_pred = junc_pred * valid_mask
            # Compute tp, tn, fp, fn inside the valid region.
            tp = np.sum(junc_pred * junc_gt)
            tn = np.sum(
                (junc_pred == 0).astype(float)
                * (junc_gt == 0).astype(float)
                * valid_mask
            )
            fp = np.sum(
                (junc_pred == 1).astype(float)
                * (junc_gt == 0).astype(float)
                * valid_mask
            )
            fn = np.sum(
                (junc_pred == 0).astype(float)
                * (junc_gt == 1).astype(float)
                * valid_mask
            )
            tp_lst.append(tp)
            tn_lst.append(tn)
            fp_lst.append(fp)
            fn_lst.append(fn)
            # BUGFIX: guard the 0/0 case (no detections at high thresholds,
            # or no ground-truth positives), which produced NaN and a
            # RuntimeWarning in the original.
            precision_lst.append(tp / (tp + fp) if (tp + fp) > 0 else 0.0)
            recall_lst.append(tp / (tp + fn) if (tp + fn) > 0 else 0.0)
        return {
            "tp": np.array(tp_lst),
            "tn": np.array(tn_lst),
            "fp": np.array(fp_lst),
            "fn": np.array(fn_lst),
            "precision": np.array(precision_lst),
            "recall": np.array(recall_lst),
        }
class heatmap_precision(object):
    """Heatmap precision."""

    def __init__(self, prob_thresh):
        # Probability threshold used to binarize the heatmap.
        self.prob_thresh = prob_thresh

    def __call__(self, heatmap_pred, heatmap_gt, valid_mask):
        """Precision of the thresholded heatmap w.r.t. the ground truth."""
        # Binarize the prediction and restrict it to the valid region.
        # Inputs are assumed NxHxWx1 (both L1 and L2 heads); squeeze drops
        # the singleton dims.
        detections = np.squeeze(heatmap_pred > self.prob_thresh)
        detections = detections * valid_mask.squeeze()
        total = np.sum(detections)
        # Corner case: no detections at all -> define precision as 0.
        if total <= 0:
            return 0.0
        return np.sum(detections * heatmap_gt.squeeze()) / total
class heatmap_recall(object):
    """Heatmap recall."""

    def __init__(self, prob_thresh):
        # Probability threshold used to binarize the heatmap.
        self.prob_thresh = prob_thresh

    def __call__(self, heatmap_pred, heatmap_gt, valid_mask):
        """Recall of the thresholded heatmap w.r.t. the ground truth."""
        # Binarize the prediction and restrict it to the valid region.
        # Inputs are assumed NxHxWx1 (both L1 and L2 heads); squeeze drops
        # the singleton dims.
        detections = np.squeeze(heatmap_pred > self.prob_thresh)
        detections = detections * valid_mask.squeeze()
        gt_total = np.sum(heatmap_gt)
        # Corner case: empty ground truth -> define recall as 0.
        if gt_total <= 0:
            return 0.0
        return np.sum(detections * heatmap_gt.squeeze()) / gt_total
class matching_score(object):
    """Descriptors matching score."""

    def __init__(self, grid_size):
        # Down-sampling factor between the image and the descriptor map.
        self.grid_size = grid_size

    def __call__(self, points1, points2, desc_pred1, desc_pred2, line_indices):
        """Ratio of mutual nearest-neighbor matches among valid line points."""
        batch, _, h_c, w_c = desc_pred1.size()
        image_size = (h_c * self.grid_size, w_c * self.grid_size)
        device = desc_pred1.device

        # Keep only the keypoints flagged as valid by `line_indices`.
        num_points = line_indices.size()[1]
        keep = line_indices.bool().flatten()
        if torch.sum(keep).item() == 0:
            # No valid points -> score is trivially 0.
            return torch.tensor(0.0, dtype=torch.float, device=device)

        # Sample both descriptor maps at the (grid-normalized) keypoint
        # locations and L2-normalize each descriptor.
        sampled = []
        for desc_map, pts in ((desc_pred1, points1), (desc_pred2, points2)):
            grid = keypoints_to_grid(pts, image_size)
            desc = F.grid_sample(desc_map, grid).permute(0, 2, 3, 1)
            desc = desc.reshape(batch * num_points, -1)[keep]
            sampled.append(F.normalize(desc, dim=1))
        desc1, desc2 = sampled

        # Squared Euclidean distance between unit-norm descriptors.
        dist = 2 - 2 * (desc1 @ desc2.t())

        # A pair counts as a match when it is a mutual nearest neighbor.
        nn12 = torch.min(dist, dim=1)[1]
        nn21 = torch.min(dist, dim=0)[1]
        mutual = nn21[nn12] == torch.arange(len(nn12)).to(device)
        return mutual.float().mean()
def super_nms(prob_predictions, dist_thresh, prob_thresh=0.01, top_k=0):
    """Non-maximum suppression adapted from SuperPoint.

    Args:
        prob_predictions: batch of probability maps (batch on dim 0, height
            on dim 1, width on dim 2).
        dist_thresh: suppression radius in pixels (infinity-norm distance).
        prob_thresh: minimum probability for a candidate point.
        top_k: if > 0, keep only the top_k highest-scoring points; 0 or
            None keeps everything.

    Returns:
        np.ndarray with the input's spatial size where the surviving point
        scores are painted onto otherwise-zero maps.
    """
    im_h = prob_predictions.shape[1]
    im_w = prob_predictions.shape[2]
    output_lst = []
    # Iterate through the batch dimension.
    for i in range(prob_predictions.shape[0]):
        prob_pred = prob_predictions[i, ...]
        # Filter the points using prob_thresh.
        coord = np.where(prob_pred >= prob_thresh)  # HW format
        points = np.concatenate(
            (coord[0][..., None], coord[1][..., None]), axis=1
        )  # HW format
        # Get the probability score.
        prob_score = prob_pred[points[:, 0], points[:, 1]]
        # Perform super nms: in_points uses xy format (instead of HW format).
        # NOTE(review): this concatenation assumes `prob_score` carries a
        # trailing singleton channel (i.e. the input maps are NxHxWx1) —
        # confirm against the callers.
        in_points = np.concatenate(
            (coord[1][..., None], coord[0][..., None], prob_score), axis=1
        ).T
        keep_points_, keep_inds = nms_fast(in_points, im_h, im_w, dist_thresh)
        # Remember to flip outputs back to HW format.
        keep_points = np.round(np.flip(keep_points_[:2, :], axis=0).T)
        keep_score = keep_points_[-1, :].T
        # Optionally keep only the top_k highest-scoring points.
        # BUGFIX: the original condition `(top_k > 0) or (top_k is None)`
        # raised a TypeError for top_k=None (None > 0 is invalid in Py3),
        # and `min([..., None])` would have failed anyway; treat None like
        # 0, i.e. keep everything.
        if top_k is not None and top_k > 0:
            k = min(keep_points.shape[0], top_k)
            keep_points = keep_points[:k, :]
            keep_score = keep_score[:k]
        # Re-compose the probability map.
        # BUGFIX: `np.int` was removed in NumPy 1.24; use the builtin `int`.
        output_map = np.zeros([im_h, im_w])
        output_map[
            keep_points[:, 0].astype(int), keep_points[:, 1].astype(int)
        ] = keep_score.squeeze()
        output_lst.append(output_map[None, ...])
    return np.concatenate(output_lst, axis=0)
def nms_fast(in_corners, H, W, dist_thresh):
    """
    Approximate grid-based non-max suppression on corners.

    Corners come in as a 3xN array [x_i, y_i, conf_i]^T. Each point is
    rounded to an integer pixel and painted onto an HxW grid; sweeping the
    points from highest to lowest confidence, every kept point zeroes out
    its (2 * dist_thresh + 1)^2 neighborhood.
    Grid value legend:
      -1 : kept.
       0 : empty or suppressed.
       1 : to be processed (converted to either kept or suppressed).
    NOTE: because coordinates are rounded first, the effective suppression
    distance is approximate, and all points are assumed to lie inside the
    image boundary.

    Inputs
      in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
      H - Image height.
      W - Image width.
      dist_thresh - Distance to suppress, measured as an infinite distance.
    Returns
      nmsed_corners - 3xN numpy matrix with surviving corners.
      nmsed_inds - N length numpy vector with surviving corner indices.
    """
    occupancy = np.zeros((H, W)).astype(int)  # NMS state per cell.
    owner = np.zeros((H, W)).astype(int)  # Point index stored per cell.
    # Order by decreasing confidence, then round coordinates to pixels.
    order = np.argsort(-in_corners[2, :])
    sorted_corners = in_corners[:, order]
    px = sorted_corners[:2, :].round().astype(int)
    n = px.shape[1]
    # Edge cases: zero corners, or a single corner.
    if n == 0:
        return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
    if n == 1:
        out = np.vstack((px, in_corners[2])).reshape(3, 1)
        return out, np.zeros((1)).astype(int)
    # Mark every candidate cell as pending and remember its owner index.
    for idx in range(n):
        occupancy[px[1, idx], px[0, idx]] = 1
        owner[px[1, idx], px[0, idx]] = idx
    # Pad the grid so border points can suppress a full neighborhood.
    pad = dist_thresh
    occupancy = np.pad(occupancy, ((pad, pad), (pad, pad)), mode="constant")
    # Sweep from highest to lowest confidence, suppressing neighbors.
    for idx in range(n):
        col, row = px[0, idx] + pad, px[1, idx] + pad
        if occupancy[row, col] == 1:  # Still pending -> keep this point.
            occupancy[row - pad : row + pad + 1, col - pad : col + pad + 1] = 0
            occupancy[row, col] = -1
    # Collect the kept cells and return them sorted by confidence.
    kept_rows, kept_cols = np.where(occupancy == -1)
    kept_rows, kept_cols = kept_rows - pad, kept_cols - pad
    kept_ids = owner[kept_rows, kept_cols]
    survivors = sorted_corners[:, kept_ids]
    resort = np.argsort(-survivors[-1, :])
    survivors = survivors[:, resort]
    surviving_inds = order[kept_ids[resort]]
    return survivors, surviving_inds