|
import copy
import sys
import warnings
from pathlib import Path

import torch

from ..utils.base_model import BaseModel
|
|
|
# Put the repository's vendored third-party directory on the import path so
# that `TopicFM` can be imported as a top-level package. This must run before
# the TopicFM imports below.
sys.path.append(str(Path(__file__).parent / "../../third_party"))

# Imported under an underscore alias so the wrapper class in this module can
# reuse the public name `TopicFM`.
from TopicFM.src.models.topic_fm import TopicFM as _TopicFM
from TopicFM.src import get_model_cfg
|
|
|
# Root of the vendored TopicFM checkout; the wrapper's pretrained checkpoint
# path is built relative to this directory.
topicfm_path = Path(__file__).parent.joinpath("..", "..", "third_party", "TopicFM")
|
|
|
|
|
class TopicFM(BaseModel):
    """Image-pair matcher backed by the vendored TopicFM network.

    Configuration keys (see ``default_conf``):
        weights: Present in the defaults but not read by this wrapper; the
            checkpoint is always ``pretrained/model_best.ckpt`` under
            ``topicfm_path``.
        match_threshold: Written into TopicFM's ``match_coarse.thr``.
        n_sampling_topics: Written into TopicFM's ``coarse.n_samples``.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "n_sampling_topics": 4,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the TopicFM network from ``conf`` and load its checkpoint.

        Args:
            conf: Configuration dict; ``match_threshold`` and
                ``n_sampling_topics`` are read from it.
        """
        # Deep copy: the overrides below mutate the nested "match_coarse" and
        # "coarse" sections, which a plain dict() copy would still share with
        # the object returned by get_model_cfg().
        _conf = copy.deepcopy(dict(get_model_cfg()))
        _conf["match_coarse"]["thr"] = conf["match_threshold"]
        _conf["coarse"]["n_samples"] = conf["n_sampling_topics"]

        # NOTE(review): conf["weights"] is not consulted; the checkpoint path
        # is fixed.
        weight_path = topicfm_path / "pretrained/model_best.ckpt"
        self.net = _TopicFM(config=_conf)
        # torch.load unpickles the file, so only point this at trusted
        # checkpoints. Loading onto the CPU keeps it device-agnostic.
        ckpt_dict = torch.load(weight_path, map_location="cpu")
        self.net.load_state_dict(ckpt_dict["state_dict"])

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Args:
            data: Dict containing at least ``"image0"`` and ``"image1"``.

        Returns:
            Dict with ``"keypoints0"`` and ``"keypoints1"`` (TopicFM's
            ``mkpts0_f`` / ``mkpts1_f`` outputs) and ``"mconf"``.
        """
        # The network writes its results into the dict it is given, so hand
        # it a fresh dict holding only the images (presumably to leave the
        # caller's `data` untouched).
        data_ = {"image0": data["image0"], "image1": data["image1"]}
        self.net(data_)
        return {
            "keypoints0": data_["mkpts0_f"],
            "keypoints1": data_["mkpts1_f"],
            "mconf": data_["mconf"],
        }
|
|