|
import sys |
|
from pathlib import Path |
|
from ..utils.base_model import BaseModel |
|
from .. import logger |
|
|
|
# Location of the vendored LightGlue checkout, built relative to this file's
# directory. It is also reused by LightGlue._init below to locate checkpoint
# files under its "weights" subdirectory.
lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
# Make the vendored package importable. This must execute before the import
# that follows. Because the entry is appended (not prepended), any `lightglue`
# package found earlier on sys.path wins — NOTE(review): confirm that is intended.
sys.path.append(str(lightglue_path))
# Aliased to LG so it does not clash with the wrapper class of the same name
# defined in this module.
from lightglue import LightGlue as LG  # noqa: E402
|
|
|
|
|
class LightGlue(BaseModel):
    """Matcher wrapper around the vendored LightGlue network (imported as ``LG``).

    The wrapper takes per-image keypoints and descriptors produced by an
    upstream feature extractor. It repackages them into the nested
    ``{"image0": {...}, "image1": {...}}`` dict passed to ``LG`` and returns
    the network's output unchanged.
    """

    default_conf = {
        # Matcher-level threshold; copied into "filter_threshold" in _init.
        "match_threshold": 0.2,
        # Overwritten by "match_threshold" in _init, so setting it alone has no effect.
        "filter_threshold": 0.2,
        "width_confidence": 0.99,
        "depth_confidence": 0.95,
        "features": "superpoint",
        # Checkpoint file name, looked up under <lightglue_path>/weights.
        "model_name": "superpoint_lightglue.pth",
        "flash": True,
        "mp": False,
    }

    # NOTE(review): "scores0"/"scores1" are required here but _forward does not
    # pass them to the network — confirm whether they are needed.
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Construct the LightGlue network from ``conf``.

        ``conf`` is updated in place before being forwarded to ``LG``:

        * ``"weights"`` is set to the string path
          ``<lightglue_path>/weights/<conf["model_name"]>``.
        * ``"filter_threshold"`` is overwritten with ``conf["match_threshold"]``.

        The dict is mutated rather than copied so that callers holding a
        reference to it (possibly the base class's stored config — confirm in
        BaseModel) keep seeing the same values as before.
        """
        conf["weights"] = str(lightglue_path / "weights" / conf["model_name"])
        # The wrapper exposes a single "match_threshold" knob; mirror it into the
        # key read by LightGlue. This overrides any explicit "filter_threshold".
        conf["filter_threshold"] = conf["match_threshold"]
        # Every conf entry is forwarded, including wrapper-only keys such as
        # "model_name"; presumably LG accepts extra keys — confirm.
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    def _forward(self, data):
        """Run LightGlue on one image pair.

        Args:
            data: mapping holding "image{i}", "keypoints{i}" and
                "descriptors{i}" for i in {0, 1}. Descriptors must be 3-D
                tensors supporting ``.permute``.

        Returns:
            Whatever ``self.net`` returns for the assembled pair.
        """
        pair = {}
        # Build image0 before image1, matching the original key/evaluation order.
        for idx in ("0", "1"):
            pair[f"image{idx}"] = {
                "image": data[f"image{idx}"],
                "keypoints": data[f"keypoints{idx}"],
                # Swap the last two axes; presumably converts (B, D, N) from the
                # extractor into the (B, N, D) layout LightGlue expects — confirm.
                "descriptors": data[f"descriptors{idx}"].permute(0, 2, 1),
            }
        return self.net(pair)
|
|