File size: 1,289 Bytes
4bde5d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import sys
from pathlib import Path
import torch
from .. import DEVICE, logger
from ..utils.base_model import BaseModel
# Directory holding vendored third-party projects. It is appended to sys.path
# so that the `pram` package (imported just below for GML) can be resolved;
# IMP._init also reuses it to locate third_party/pram/weights/.
tp_path = Path(__file__).parent / "../../third_party"
sys.path.append(str(tp_path))
from pram.nets.gml import GML
class IMP(BaseModel):
    """Feature matcher backed by the GML network from the third-party PRAM code.

    Pretrained weights are read from ``third_party/pram/weights/<model_name>``
    and matching is delegated to ``GML.produce_matches``.
    """

    default_conf = {
        # Threshold forwarded to ``GML.produce_matches`` as its ``p`` argument.
        "match_threshold": 0.2,
        "features": "sfd2",
        # Checkpoint file name under third_party/pram/weights/.
        "model_name": "imp_gml.920.pth",
        "sinkhorn_iterations": 20,
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Build the GML network and load its pretrained weights.

        Args:
            conf: User overrides, merged on top of ``default_conf``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            KeyError: If the checkpoint has no ``"model"`` entry.
            RuntimeError: If the checkpoint's state dict does not exactly
                match the network (``strict=True``).
        """
        # Merge locally so self.conf always holds defaults + overrides,
        # independent of what the (not visible) base class already did.
        self.conf = {**self.default_conf, **conf}
        weight_path = tp_path / "pram" / "weights" / self.conf["model_name"]
        self.net = GML(self.conf).eval().to(DEVICE)
        # The checkpoint is a dict whose weights live under "model". It is
        # mapped to CPU so loading works whatever device it was saved from;
        # load_state_dict then copies the values into the network on DEVICE.
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints.
        self.net.load_state_dict(
            torch.load(weight_path, map_location="cpu")["model"], strict=True
        )
        logger.info("Load IMP model done.")

    def _forward(self, data):
        """Match the two feature sets described by ``data``.

        Args:
            data: Mapping containing every key in ``required_inputs``. It is
                not modified.

        Returns:
            Whatever ``GML.produce_matches`` returns for the prepared inputs.
        """
        # Work on a shallow copy: overwriting the caller's descriptors would
        # transpose them a second time if the same dict were matched again.
        # NOTE(review): assumes extractor descriptors are (B, D, N) and GML
        # wants (B, N, D) - confirm against the feature extractor.
        data = {
            **data,
            "descriptors0": data["descriptors0"].transpose(2, 1).float(),
            "descriptors1": data["descriptors1"].transpose(2, 1).float(),
        }
        # Honour the configured threshold instead of a hard-coded 0.2
        # (0.2 remains the default via default_conf).
        return self.net.produce_matches(data, p=self.conf["match_threshold"])
|