File size: 4,948 Bytes
dbf8b7e
 
 
 
 
 
 
 
358ab8f
dbf8b7e
358ab8f
 
 
dbf8b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358ab8f
 
 
dbf8b7e
 
 
 
 
358ab8f
dbf8b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358ab8f
dbf8b7e
 
358ab8f
dbf8b7e
358ab8f
 
 
 
dbf8b7e
 
358ab8f
 
 
dbf8b7e
 
 
358ab8f
 
 
 
dbf8b7e
358ab8f
 
dbf8b7e
358ab8f
 
dbf8b7e
358ab8f
 
dbf8b7e
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import numpy as np
import tqdm
from roma.datasets import MegadepthBuilder
from roma.utils import warp_kpts
from torch.utils.data import ConcatDataset
import roma


class MegadepthDenseBenchmark:
    """Dense-warp accuracy benchmark on the MegaDepth ``test_loftr`` split.

    For a random subset of image pairs, the model's dense matches are compared
    with the target coordinates returned by ``warp_kpts`` (computed from depth
    maps, relative pose and intrinsics). Reported metrics are the mean
    end-point error and the fraction of valid pixels whose error is below
    1, 3 and 5 pixels, each averaged over batches.
    """

    def __init__(
        self, data_root="data/megadepth", h=384, w=512, num_samples=2000
    ) -> None:
        """Build the concatenated test dataset.

        Args:
            data_root: Dataset root handed to ``MegadepthBuilder``.
            h: Image height requested from the builder (``ht``).
            w: Image width requested from the builder (``wt``).
            num_samples: Number of image pairs drawn per ``benchmark`` call.
        """
        mega = MegadepthBuilder(data_root=data_root)
        self.dataset = ConcatDataset(
            mega.build_scenes(split="test_loftr", ht=h, wt=w)
        )  # resolution defaults to 384 x 512
        self.num_samples = num_samples

    @staticmethod
    def _to_pixel_coords(coords, h, w):
        """Map normalized [-1, 1] (x, y) coordinates to pixel units on an h x w grid."""
        return torch.stack(
            (w * (coords[..., 0] + 1) / 2, h * (coords[..., 1] + 1) / 2), dim=-1
        )

    @staticmethod
    def _save_debug_visualizations(
        model, idx, im_A, im_B, matches, certainty, path="vis"
    ):
        """Save the certainty-blended warp of im_B and both inputs for each pair in the batch."""
        # Imports are local so the normal (non-debug) path does not need them.
        import os

        import torch.nn.functional as F
        from roma.utils.utils import tensor_to_pil

        H, W = model.get_output_resolution()
        # Use the real batch size: the last batch of an epoch may hold fewer
        # than ``batch_size`` pairs.
        n = im_A.shape[0]
        white_im = torch.ones((n, 1, H, W), device="cuda")
        warp_im = F.grid_sample(
            im_B.cuda(),
            matches[:, :, :W, 2:],  # only the first W columns of the match grid
            mode="bilinear",
            align_corners=False,
        )
        c_b = certainty[:, None]
        # Fade low-certainty pixels towards white.
        vis_im = c_b * warp_im + (1 - c_b) * white_im
        for b in range(n):
            out_dir = f"{path}/{model.name}/{idx}_{b}_{H}_{W}"
            os.makedirs(out_dir, exist_ok=True)
            tensor_to_pil(vis_im[b], unnormalize=True).save(f"{out_dir}/warp.jpg")
            tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
                f"{out_dir}/im_A.jpg"
            )
            tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
                f"{out_dir}/im_B.jpg"
            )

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
        """Compute per-pixel end-point error of a dense warp and PCK scores.

        Args:
            depth1: Depth of image A, forwarded to ``warp_kpts``.
            depth2: Depth of image B, forwarded to ``warp_kpts``.
            T_1to2: Relative pose from A to B, forwarded to ``warp_kpts``.
            K1: Intrinsics of image A, forwarded to ``warp_kpts``.
            K2: Intrinsics of image B, forwarded to ``warp_kpts``.
            dense_matches: Tensor of shape ``(b, h1, w1, d)``. The first two
                channels are the source coordinates, the remaining channels
                the predicted target coordinates; both are converted as
                normalized [-1, 1] coordinates.

        Returns:
            Tuple ``(gd, pck_1, pck_3, pck_5, prob)`` where ``gd`` holds the
            pixel errors at valid locations only, ``pck_k`` is the fraction of
            those errors below ``k`` pixels, and ``prob`` is the ``(b, h1, w1)``
            float validity mask derived from ``warp_kpts``.
        """
        b, h1, w1, _ = dense_matches.shape
        with torch.no_grad():
            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
            # Everything is cast to double before being handed to warp_kpts.
            mask, x2 = warp_kpts(
                x1.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            x2 = self._to_pixel_coords(x2, h1, w1)
            prob = mask.float().reshape(b, h1, w1)
        x2_hat = self._to_pixel_coords(dense_matches[..., 2:], h1, w1)
        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
        # Keep only the locations that warp_kpts marked as valid.
        # NOTE(review): if no location is valid the means below are NaN.
        gd = gd[prob == 1]
        pck_1 = (gd < 1.0).float().mean()
        pck_3 = (gd < 3.0).float().mean()
        pck_5 = (gd < 5.0).float().mean()
        return gd, pck_1, pck_3, pck_5, prob

    def benchmark(self, model, batch_size=8, num_workers=None):
        """Evaluate ``model`` on a random subset of the test pairs.

        Args:
            model: Matcher exposing ``train(False)`` and
                ``match(im_A, im_B, batched=True)`` returning
                ``(matches, certainty)``; in debug mode also
                ``get_output_resolution()`` and ``name``.
            batch_size: Number of pairs per batch.
            num_workers: DataLoader worker count; defaults to ``batch_size``
                (the previous hard-wired behaviour) when ``None``.

        Returns:
            Dict with keys ``"epe"``, ``"mega_pck_1"``, ``"mega_pck_3"`` and
            ``"mega_pck_5"``, each the mean of the per-batch value.
        """
        model.train(False)
        if num_workers is None:
            num_workers = batch_size
        # Sampling without replacement cannot draw more items than exist, so
        # cap the request at the dataset size instead of failing mid-iteration.
        num_samples = min(self.num_samples, len(self.dataset))
        with torch.no_grad():
            gd_tot = 0.0
            pck_1_tot = 0.0
            pck_3_tot = 0.0
            pck_5_tot = 0.0
            # Uniform weights + no replacement == a random subset of the data.
            sampler = torch.utils.data.WeightedRandomSampler(
                torch.ones(len(self.dataset)),
                replacement=False,
                num_samples=num_samples,
            )
            dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                sampler=sampler,
            )
            # Wrap the loader (not enumerate) so tqdm can read its length.
            progress = tqdm.tqdm(dataloader, disable=roma.RANK > 0)
            for idx, data in enumerate(progress):
                im_A, im_B = data["im_A"], data["im_B"]
                depth1 = data["im_A_depth"].cuda()
                depth2 = data["im_B_depth"].cuda()
                T_1to2 = data["T_1to2"].cuda()
                K1 = data["K1"].cuda()
                K2 = data["K2"].cuda()
                matches, certainty = model.match(im_A, im_B, batched=True)
                gd, pck_1, pck_3, pck_5, _ = self.geometric_dist(
                    depth1, depth2, T_1to2, K1, K2, matches
                )
                if roma.DEBUG_MODE:
                    self._save_debug_visualizations(
                        model, idx, im_A, im_B, matches, certainty
                    )

                gd_tot = gd_tot + gd.mean()
                pck_1_tot = pck_1_tot + pck_1
                pck_3_tot = pck_3_tot + pck_3
                pck_5_tot = pck_5_tot + pck_5
        num_batches = len(dataloader)
        return {
            "epe": gd_tot.item() / num_batches,
            "mega_pck_1": pck_1_tot.item() / num_batches,
            "mega_pck_3": pck_3_tot.item() / num_batches,
            "mega_pck_5": pck_5_tot.item() / num_batches,
        }