Spaces:
Running
Running
File size: 3,507 Bytes
dbf8b7e 358ab8f dbf8b7e 358ab8f dbf8b7e 358ab8f dbf8b7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import torch
import numpy as np
import tqdm
from dkm.datasets import MegadepthBuilder
from dkm.utils import warp_kpts
from torch.utils.data import ConcatDataset
class MegadepthDenseBenchmark:
    """Dense-matching benchmark on the MegaDepth ``test_loftr`` split.

    Compares a model's dense matches against correspondences obtained by
    warping with ground-truth depth, pose and intrinsics, and reports PCK
    (fraction of valid pixels whose end-point error is below 1, 3 and 5 px).
    """

    def __init__(
        self, data_root="data/megadepth", h=384, w=512, num_samples=2000, device=None
    ) -> None:
        """Build the test scenes and store sampling/device settings.

        Args:
            data_root: root directory handed to ``MegadepthBuilder``.
            h: image height passed to ``build_scenes`` as ``ht``.
            w: image width passed to ``build_scenes`` as ``wt``.
            num_samples: number of pairs drawn (without replacement) in each
                ``benchmark`` call.
            device: device the ground-truth tensors are moved to; defaults to
                CUDA when available, otherwise CPU.
        """
        mega = MegadepthBuilder(data_root=data_root)
        # All scenes are built at the same (h, w) resolution (384x512 by default).
        self.dataset = ConcatDataset(
            mega.build_scenes(split="test_loftr", ht=h, wt=w)
        )
        self.num_samples = num_samples
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
        """Measure pixel error of predicted dense matches against warped GT.

        Args:
            depth1, depth2, T_1to2, K1, K2: depth maps, relative pose and
                intrinsics; cast to float64 and forwarded to ``warp_kpts``.
            dense_matches: tensor of shape ``(b, h1, w1, d)``. ``[..., :2]``
                are source coordinates and ``[..., 2:]`` predicted target
                coordinates. The target coordinates are treated as
                normalized to [-1, 1] (mapped to pixels via ``w * (x + 1) / 2``).
                Presumably ``d == 4`` -- TODO confirm.

        Returns:
            Tuple ``(gd, pck_1, pck_3, pck_5)``:
            - ``gd``: mean end-point error in pixels over pixels that
              ``warp_kpts`` marks valid.
            - ``pck_*``: fraction of those pixels with error below 1/3/5 px.
            All four are NaN when no pixel is valid (mean of an empty tensor).

        Note:
            Both ground-truth and predicted targets are scaled with the
            match-grid size ``(w1, h1)``. This assumes image 2 has the same
            resolution as the grid.
        """
        b, h1, w1, _ = dense_matches.shape
        with torch.no_grad():
            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
            # Warp in float64 for numerical stability of the projection.
            mask, x2 = warp_kpts(
                x1.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            # Normalized [-1, 1] -> pixel coordinates.
            x2 = torch.stack(
                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
            )
            prob = mask.float().reshape(b, h1, w1)
            x2_hat = dense_matches[..., 2:]
            x2_hat = torch.stack(
                (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
            )
            gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
            # Only score pixels with a valid ground-truth correspondence.
            gd = gd[prob == 1]
            pck_1 = (gd < 1.0).float().mean()
            pck_3 = (gd < 3.0).float().mean()
            pck_5 = (gd < 5.0).float().mean()
            gd = gd.mean()
        return gd, pck_1, pck_3, pck_5

    def benchmark(self, model, batch_size=8, num_workers=8):
        """Evaluate ``model`` on ``self.num_samples`` random test pairs.

        Args:
            model: object exposing ``train(bool)`` and
                ``match(im1, im2, batched=True) -> (matches, certainty)``.
            batch_size: number of pairs per DataLoader batch.
            num_workers: DataLoader worker processes. Previously this value
                was (mistakenly) taken from ``batch_size`` while the batch
                size was hard-coded to 8. The default of 8 keeps the old
                default behaviour.

        Returns:
            dict with keys ``"mega_pck_1"``, ``"mega_pck_3"`` and
            ``"mega_pck_5"``. Each value is the per-batch PCK averaged over
            all batches. A batch with no valid pixels contributes NaN.

        Note:
            The model is switched to eval mode and is not switched back.
        """
        model.train(False)
        with torch.no_grad():
            pck_1_tot = pck_3_tot = pck_5_tot = 0.0
            # Uniform sampling of num_samples distinct pairs.
            sampler = torch.utils.data.WeightedRandomSampler(
                torch.ones(len(self.dataset)),
                replacement=False,
                num_samples=self.num_samples,
            )
            dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                sampler=sampler,
            )
            for data in tqdm.tqdm(dataloader):
                # Images go to the model as-is; only GT tensors are moved here.
                im1, im2 = data["query"], data["support"]
                depth1 = data["query_depth"].to(self.device)
                depth2 = data["support_depth"].to(self.device)
                T_1to2 = data["T_1to2"].to(self.device)
                K1 = data["K1"].to(self.device)
                K2 = data["K2"].to(self.device)
                matches, _ = model.match(im1, im2, batched=True)
                _, pck_1, pck_3, pck_5 = self.geometric_dist(
                    depth1, depth2, T_1to2, K1, K2, matches
                )
                pck_1_tot = pck_1_tot + pck_1
                pck_3_tot = pck_3_tot + pck_3
                pck_5_tot = pck_5_tot + pck_5
        num_batches = len(dataloader)
        return {
            "mega_pck_1": pck_1_tot.item() / num_batches,
            "mega_pck_3": pck_3_tot.item() / num_batches,
            "mega_pck_5": pck_5_tot.item() / num_batches,
        }
|