Spaces:
Running
Running
File size: 1,819 Bytes
9223079 6cb641c 9223079 3c77caa 9223079 3c77caa 9223079 6cb641c 3c77caa 9223079 4c930ba 9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import copy
import warnings

import torch
from kornia.feature import LoFTR as LoFTR_
from kornia.feature.loftr.loftr import default_cfg

from hloc import logger

from ..utils.base_model import BaseModel
class LoFTR(BaseModel):
    """Matcher that wraps kornia's LoFTR network behind the BaseModel interface.

    Configuration keys (see ``default_conf``):
        weights: name passed to kornia's ``LoFTR`` as ``pretrained``.
        match_threshold: written to ``match_coarse.thr`` of the LoFTR config.
        sinkhorn_iterations: written to ``match_coarse.skh_iters``.
        max_keypoints: keep at most this many matches, ranked by confidence.
            ``None`` or any negative value (the default ``-1``) keeps them all.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the LoFTR network from ``conf``."""
        # Work on a private copy: ``default_cfg`` is a module-level dict, so
        # writing into it directly would leak these settings into every other
        # user of the default configuration.
        cfg = copy.deepcopy(default_cfg)
        cfg["match_coarse"]["thr"] = conf["match_threshold"]
        cfg["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
        self.net = LoFTR_(pretrained=conf["weights"], config=cfg)
        logger.info(f"Loaded LoFTR with weights {conf['weights']}")

    def _forward(self, data):
        """Match the two images in ``data`` and return the network prediction.

        The returned dict is the network output with the image-0/image-1 keys
        swapped back, ``"confidence"`` removed, and the (possibly truncated)
        confidences stored under ``"scores"``.
        """
        # For consistency with hloc pairs, we refine kpts in image0!
        rename = {
            "keypoints0": "keypoints1",
            "keypoints1": "keypoints0",
            "image0": "image1",
            "image1": "image0",
            "mask0": "mask1",
            "mask1": "mask0",
        }
        # Keys outside the swap table are forwarded unchanged instead of
        # raising KeyError, mirroring how the output is renamed below.
        data_ = {rename.get(k, k): v for k, v in data.items()}
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            pred = self.net(data_)

        scores = pred["confidence"]
        top_k = self.conf["max_keypoints"]
        # A negative limit means "unlimited". Without the lower bound, the
        # default -1 would slice [:-1] and silently drop the weakest match.
        if top_k is not None and 0 <= top_k < len(scores):
            keep = torch.argsort(scores, descending=True)[:top_k]
            pred["keypoints0"], pred["keypoints1"] = (
                pred["keypoints0"][keep],
                pred["keypoints1"][keep],
            )
            scores = scores[keep]

        # Switch back indices
        pred = {rename.get(k, k): v for k, v in pred.items()}
        pred["scores"] = scores
        del pred["confidence"]
        return pred
|