|
import sys |
|
import warnings |
|
from copy import deepcopy |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
# Make the vendored third-party code (e.g. EfficientLoFTR, imported below)
# importable by appending its directory to the module search path.
tp_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(f"{tp_path}")
|
|
|
from EfficientLoFTR.src.loftr import LoFTR as ELoFTR_ |
|
from EfficientLoFTR.src.loftr import ( |
|
full_default_cfg, |
|
opt_default_cfg, |
|
reparameter, |
|
) |
|
|
|
from hloc import logger |
|
|
|
from ..utils.base_model import BaseModel |
|
|
|
|
|
class ELoFTR(BaseModel):
    """hloc matcher wrapper around the EfficientLoFTR dense matcher.

    Configuration keys (see ``default_conf``):
        weights: checkpoint path, relative to ``third_party/EfficientLoFTR``.
        match_threshold: written into ``match_coarse.thr`` of the model config.
        max_keypoints: keep at most this many matches, highest ``mconf``
            first. ``None`` or a negative value (the default ``-1``) keeps
            every match.
        model_type: ``"full"`` or ``"opt"``; selects ``full_default_cfg`` or
            ``opt_default_cfg`` as the starting config.
        precision: ``"fp32"``, ``"mp"`` (sets the config's ``mp`` flag) or
            ``"fp16"`` (sets the config's ``half`` flag and casts both the
            network and the floating-point inputs to half precision). Any
            other value leaves the config untouched.
    """

    default_conf = {
        "weights": "weights/eloftr_outdoor.ckpt",
        "match_threshold": 0.2,
        "max_keypoints": -1,
        "model_type": "full",
        "precision": "fp32",
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the EfficientLoFTR network and load its checkpoint.

        Raises:
            ValueError: if ``model_type`` is neither ``"full"`` nor ``"opt"``.
        """
        model_type = self.conf["model_type"]
        if model_type == "full":
            cfg = deepcopy(full_default_cfg)
        elif model_type == "opt":
            cfg = deepcopy(opt_default_cfg)
        else:
            # Without this branch the config variable would be unbound and
            # the failure would surface as an obscure UnboundLocalError.
            raise ValueError(
                f"Unknown EfficientLoFTR model_type {model_type!r}; "
                "expected 'full' or 'opt'."
            )

        precision = self.conf["precision"]
        if precision == "mp":
            cfg["mp"] = True
        elif precision == "fp16":
            cfg["half"] = True
        cfg["match_coarse"]["thr"] = conf["match_threshold"]

        model_path = tp_path / "EfficientLoFTR" / self.conf["weights"]
        # NOTE(review): torch.load unpickles the checkpoint; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
        matcher = ELoFTR_(config=cfg)
        matcher.load_state_dict(state_dict)
        # Upstream helper applied after the weights are loaded; presumably it
        # re-parameterizes the backbone for inference — confirm upstream.
        self.net = reparameter(matcher)

        if precision == "fp16":
            self.net = self.net.half()
        logger.info(f"Loaded Efficient LoFTR with weights {conf['weights']}")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Image 0 and image 1 (and their keypoints/masks) are swapped before
        running the network, and the outputs are swapped back. The returned
        ``keypoints0`` therefore belongs to the caller's image0.

        Returns:
            dict with ``keypoints0``/``keypoints1`` (the network's
            ``mkpts*_f`` outputs) and ``scores`` (the network's ``mconf``),
            optionally truncated to the ``max_keypoints`` best matches.
        """
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        # Keys without a 0/1 counterpart are passed through unchanged instead
        # of raising KeyError.
        data_ = {rename.get(k, k): v for k, v in data.items()}

        if self.conf["precision"] == "fp16":
            # The network weights were cast with .half() in _init, so
            # floating-point inputs must match or the convolutions fail on a
            # dtype mismatch. Non-float tensors (e.g. masks) are left as-is.
            data_ = {
                k: (
                    v.half()
                    if torch.is_tensor(v) and v.is_floating_point()
                    else v
                )
                for k, v in data_.items()
            }

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # The network stores its results (mkpts*_f, mconf) in data_.
            self.net(data_)

        net_kpts0 = data_["mkpts0_f"]
        net_kpts1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        # Negative values (the default -1) mean "no limit". Previously -1
        # reached the slice [:-1] and silently dropped the weakest match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            net_kpts0 = net_kpts0[keep]
            net_kpts1 = net_kpts1[keep]
            scores = scores[keep]

        # Undo the input swap: the network's image0 is the caller's image1.
        return {
            "keypoints0": net_kpts1,
            "keypoints1": net_kpts0,
            "scores": scores,
        }
|
|