Spaces:
Running
Running
File size: 4,352 Bytes
10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f 10b4a5f 358ab8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from argparse import Namespace
import copy
import os

import torch
import cv2

from .base import Viz
from src.utils.metrics import compute_symmetrical_epipolar_errors, compute_pose_errors
from third_party.loftr.src.loftr import LoFTR, default_cfg
class VizLoFTR(Viz):
    """Run a LoFTR model on image pairs and draw the resulting matches.

    The constructor builds the network from ``default_cfg`` with a custom
    coarse-matching threshold and loads a checkpoint; ``match_and_draw`` runs
    the model on one batch and saves a match visualization.

    NOTE(review): ``self.device``, ``self.time_stats``, ``self.eval_stats`` and
    ``self.draw_matches`` are not defined in this class; presumably they come
    from the ``Viz`` base class -- confirm.
    """

    def __init__(self, args):
        """Build the LoFTR model and load its weights.

        Args:
            args: Either a ``dict`` or an attribute-style object (such as an
                ``argparse.Namespace``) providing ``match_threshold`` (written
                to the ``match_coarse.thr`` config entry) and ``ckpt`` (path of
                the checkpoint file, which must contain a ``"state_dict"``
                entry).
        """
        super().__init__()
        if isinstance(args, dict):
            args = Namespace(**args)

        self.match_threshold = args.match_threshold

        # Load model
        # Deep-copy the defaults: a shallow ``dict(default_cfg)`` shares the
        # nested "match_coarse" dict, so setting the threshold would silently
        # modify the module-level ``default_cfg`` for every other user.
        conf = copy.deepcopy(dict(default_cfg))
        conf["match_coarse"]["thr"] = self.match_threshold
        print(conf)
        self.model = LoFTR(config=conf)

        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # Loading onto the CPU first keeps GPU-saved checkpoints usable on
        # hosts without CUDA; the model is moved to the target device below.
        ckpt_dict = torch.load(args.ckpt, map_location="cpu")
        self.model.load_state_dict(ckpt_dict["state_dict"])
        self.model = self.model.eval().to(self.device)

        # Name the method
        self.name = "LoFTR"

        print(f"Initialize {self.name}")

    def match_and_draw(
        self,
        data_dict,
        root_dir=None,
        ground_truth=False,
        measure_time=False,
        viz_matches=True,
    ):
        """Run the model on ``data_dict`` and optionally save a match figure.

        Args:
            data_dict: Batch dictionary handed to the model. After the forward
                pass this method reads ``"mkpts0_f"``, ``"mkpts1_f"``,
                ``"pair_names"`` and ``"dataset_name"``; without ground truth
                it also reads ``"mconf"``, and with ground truth ``"epi_errs"``,
                ``"R_errs"``, ``"t_errs"``, ``"inliers"``, ``"m_bids"`` and
                ``"image0"``. (Presumably the match keys are filled in by the
                model and the error keys by the metric helpers -- confirm.)
            root_dir: Directory the image names are joined to; the output
                folder ``"<name>_viz_matches"`` is created inside it. Must be
                a path whenever ``viz_matches`` is true.
            ground_truth: When true, compute epipolar and pose errors, color
                the drawing by epipolar error and append a metrics record to
                ``self.eval_stats``. Otherwise the drawing uses
                ``1 - mconf`` with ``conf_thr=0.4``.
            measure_time: When true, time the forward pass with CUDA events
                and append the value returned by ``Event.elapsed_time`` to
                ``self.time_stats``. Requires a CUDA-capable environment.
            viz_matches: When false, only the (optionally timed) forward pass
                is executed.

        Returns:
            None.
        """
        if measure_time:
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
        self.model(data_dict)
        if measure_time:
            torch.cuda.synchronize()
            end.record()
            torch.cuda.synchronize()
            self.time_stats.append(start.elapsed_time(end))

        # Everything below needs the output path, which only exists when a
        # visualization is requested. Returning here avoids reading images
        # (and touching ``root_dir``) for timing-only calls.
        # NOTE(review): confirm that no metrics are expected when
        # visualization is switched off.
        if not viz_matches:
            return

        kpts0 = data_dict["mkpts0_f"].cpu().numpy()
        kpts1 = data_dict["mkpts1_f"].cpu().numpy()

        # Transpose the pair-name lists into one tuple per sample; only the
        # first pair's images are loaded for drawing.
        rel_pair_names = list(zip(*data_dict["pair_names"]))
        img_name0, img_name1 = rel_pair_names[0]
        img0 = cv2.imread(os.path.join(root_dir, img_name0))
        img1 = cv2.imread(os.path.join(root_dir, img_name1))
        if str(data_dict["dataset_name"][0]).lower() == "scannet":
            img0 = cv2.resize(img0, (640, 480))
            img1 = cv2.resize(img1, (640, 480))

        # Output file: "<stem0>_<stem1>.png" inside "<root_dir>/<name>_viz_matches".
        saved_name = "_".join(
            [
                img_name0.split("/")[-1].split(".")[0],
                img_name1.split("/")[-1].split(".")[0],
            ]
        )
        folder_matches = os.path.join(root_dir, "{}_viz_matches".format(self.name))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(folder_matches, exist_ok=True)
        path_to_save_matches = os.path.join(
            folder_matches, "{}.png".format(saved_name)
        )

        if not ground_truth:
            m_conf = 1 - data_dict["mconf"].cpu().numpy()
            self.draw_matches(
                kpts0,
                kpts1,
                img0,
                img1,
                m_conf,
                path=path_to_save_matches,
                conf_thr=0.4,
            )
            return

        # compute epi_errs for each match
        compute_symmetrical_epipolar_errors(data_dict)
        # compute R_errs, t_errs, pose_errs for each pair
        compute_pose_errors(data_dict)

        epi_errors = data_dict["epi_errs"].cpu().numpy()
        R_errors, t_errors = data_dict["R_errs"][0], data_dict["t_errs"][0]

        self.draw_matches(
            kpts0,
            kpts1,
            img0,
            img1,
            epi_errors,
            path=path_to_save_matches,
            R_errs=R_errors,
            t_errs=t_errors,
        )

        bs = data_dict["image0"].size(0)
        metrics = {
            # to filter duplicate pairs caused by DistributedSampler
            "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
            "epi_errs": [
                data_dict["epi_errs"][data_dict["m_bids"] == b].cpu().numpy()
                for b in range(bs)
            ],
            "R_errs": data_dict["R_errs"],
            "t_errs": data_dict["t_errs"],
            "inliers": data_dict["inliers"],
        }
        self.eval_stats.append({"metrics": metrics})
|