File size: 3,386 Bytes
404d2af
 
 
 
 
 
 
8b973ee
404d2af
 
 
8b973ee
 
 
 
404d2af
 
 
 
8b973ee
404d2af
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
404d2af
 
 
 
 
 
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
 
8b973ee
 
 
 
 
 
 
 
 
 
 
404d2af
8b973ee
 
 
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright 2020 Toyota Research Institute.  All rights reserved.

import numpy as np
import torch
import torchvision.transforms as transforms
from tqdm import tqdm

from evaluation.descriptor_evaluation import compute_homography, compute_matching_score
from evaluation.detector_evaluation import compute_repeatability


def _flatten_keypoints(score, coord, desc, conf_threshold):
    """Flatten one image's network outputs into numpy rows, dropping low-confidence cells.

    Parameters
    ----------
    score, coord: torch.Tensor
        Network outputs for a single image (batch size 1). They are concatenated
        along dim 1 into 3 channels as ``[coord, score]``; column 2 of the
        flattened result is treated as the confidence (presumably coord is the
        2-channel (x, y) map and score the 1-channel probability map).
    desc: torch.Tensor
        Descriptor map of shape (1, D, Hc, Wc); its spatial cells must line up
        with those of ``score``/``coord``.
    conf_threshold: float
        Cells whose confidence is not strictly greater than this are dropped.

    Returns
    -------
    points: np.ndarray
        (N, 3) array of ``[coord..., score]`` rows that passed the threshold.
    descriptors: np.ndarray
        (N, D) array of the matching descriptors.
    """
    num_desc_channels = desc.shape[1]
    points = torch.cat([coord, score], dim=1).view(3, -1).t().cpu().numpy()
    descriptors = desc.reshape(num_desc_channels, -1).t().cpu().numpy()
    # The same mask filters both arrays so rows stay aligned.
    keep = points[:, 2] > conf_threshold
    return points[keep, :], descriptors[keep, :]


def evaluate_keypoint_net(
    data_loader, keypoint_net, output_shape=(320, 240), top_k=300, device=None
):
    """Keypoint net evaluation script.

    For every sample the network is run on ``sample["image"]`` and
    ``sample["warped_image"]``. The resulting keypoints and descriptors are
    scored with ``compute_repeatability``, ``compute_homography`` and
    ``compute_matching_score``, and the per-sample results are averaged.

    Parameters
    ----------
    data_loader: iterable of dict (e.g. torch.utils.data.DataLoader)
        Yields samples with ``"image"``, ``"warped_image"`` and ``"homography"``
        tensors. Only batch size 1 is supported.
    keypoint_net: torch.nn.Module
        Keypoint network returning ``(score, coord, desc)`` for an image batch.
    output_shape: tuple
        Original image shape. It is passed to the metric functions reversed,
        as ``image_shape`` (presumably given here as (W, H); confirm against
        the data loader).
    top_k: int
        Number of keypoints to use to compute metrics, selected based on probability.
    device: torch.device or str, optional
        Device the input images are moved to. Defaults to the device of the
        network's first parameter, or ``"cuda"`` if the network has none.

    Returns
    -------
    tuple of float
        Mean repeatability, mean localization error, the means of the three
        correctness values returned by ``compute_homography`` (collected as
        correctness1/3/5), and mean matching score. Every entry is NaN when
        ``data_loader`` yields no samples.

    Raises
    ------
    ValueError
        If a sample's image batch size is not 1 (the flattening below would
        otherwise silently mix keypoints from different images).
    """
    keypoint_net.eval()
    # Redundant with eval() for a standard nn.Module; kept in case the network
    # object customises eval().
    keypoint_net.training = False

    if device is None:
        first_param = next(keypoint_net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    conf_threshold = 0.0
    localization_err, repeatability = [], []
    correctness1, correctness3, correctness5, matching_scores = [], [], [], []

    with torch.no_grad():
        # Iterate the loader directly so tqdm can read its length and show a total.
        for sample in tqdm(data_loader, desc="Evaluate point model"):
            batch_size = sample["image"].shape[0]
            if batch_size != 1:
                raise ValueError(
                    f"evaluate_keypoint_net supports batch size 1 only, got {batch_size}"
                )

            image = sample["image"].to(device)
            warped_image = sample["warped_image"].to(device)

            score_1, coord_1, desc1 = keypoint_net(image)
            score_2, coord_2, desc2 = keypoint_net(warped_image)

            # Scores & descriptors, filtered by confidence threshold
            score_1, desc1 = _flatten_keypoints(score_1, coord_1, desc1, conf_threshold)
            score_2, desc2 = _flatten_keypoints(score_2, coord_2, desc2, conf_threshold)

            # Prepare data for eval
            data = {
                "image": sample["image"].numpy().squeeze(),
                "image_shape": output_shape[::-1],
                "warped_image": sample["warped_image"].numpy().squeeze(),
                "homography": sample["homography"].squeeze().numpy(),
                "prob": score_1,
                "warped_prob": score_2,
                "desc": desc1,
                "warped_desc": desc2,
            }

            # Compute repeatability and localization error
            _, _, rep, loc_err = compute_repeatability(
                data, keep_k_points=top_k, distance_thresh=3
            )
            repeatability.append(rep)
            localization_err.append(loc_err)

            # Compute correctness
            corr_a, corr_b, corr_c = compute_homography(data, keep_k_points=top_k)
            correctness1.append(corr_a)
            correctness3.append(corr_b)
            correctness5.append(corr_c)

            # Compute matching score
            matching_scores.append(compute_matching_score(data, keep_k_points=top_k))

    if not repeatability:
        # Empty loader: report NaN explicitly rather than letting np.mean warn six times.
        return (np.nan,) * 6

    return (
        np.mean(repeatability),
        np.mean(localization_err),
        np.mean(correctness1),
        np.mean(correctness3),
        np.mean(correctness5),
        np.mean(matching_scores),
    )