Spaces:
Running
Running
File size: 1,385 Bytes
8320ccc 8811cfe 8320ccc 8811cfe 8320ccc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import sys
from pathlib import Path
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
# Put the third_party directory (two levels above this file's folder) on the
# import path so the `mickey.*` imports that follow can be resolved.
mickey_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(str(mickey_path))
from mickey.config.default import cfg
from mickey.lib.models.builder import build_model
class Mickey(BaseModel):
    """Run the MicKey network (built via ``mickey.lib.models.builder``)
    behind the ``BaseModel`` interface and expose its inlier
    correspondences as two keypoint sets.
    """

    # "config_path" is resolved relative to the downloaded checkpoint's
    # directory; "model_name" is the checkpoint file name inside the model
    # repo. "max_keypoints" is not read anywhere in this class (it may be
    # consumed by BaseModel or by callers).
    default_conf = {
        "config_path": "config.yaml",
        "model_name": "mickey.ckpt",
        "max_keypoints": 3000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Download the checkpoint, merge the MicKey config and build the net.

        Args:
            conf: Configuration handed in by the caller. It is not read
                directly here; all values are taken from ``self.conf``.
        """
        # The checkpoint lives under a sub-folder named after this module.
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(
                Path(__file__).stem, self.conf["model_name"]
            ),
        )

        # TODO: config path of mickey
        # The config file is expected to sit next to the downloaded
        # checkpoint. NOTE(review): nothing in this method downloads it
        # explicitly -- confirm it is fetched together with the checkpoint.
        config_path = model_path.parent / self.conf["config_path"]

        logger.info("Loading mickey model...")
        cfg.merge_from_file(config_path)
        self.net = build_model(cfg, checkpoint=model_path)
        logger.info("Load Mickey model done.")

    def _forward(self, data):
        """Run the network on ``data`` and return the inlier keypoints.

        Args:
            data: Mapping passed straight to the network. After the call it
                is expected to provide ``"inliers_list"`` (either stored on
                it by the network or, failing that, present in the
                network's dict-like return value).

        Returns:
            dict with ``"keypoints0"`` (columns 0-1 of the inliers) and
            ``"keypoints1"`` (columns 2-3 of the inliers).

        Raises:
            KeyError: if ``"inliers_list"`` is available neither on ``data``
                nor in the network's return value.
        """
        outputs = self.net(data)

        # Fold the return value in only when it is dict-like. Unconditionally
        # unpacking it with ``**`` raises TypeError for a tuple/None return,
        # even though the result is read from ``data`` anyway.
        # NOTE(review): the network's return type is not visible from this
        # file -- confirm against third_party/mickey.
        merged = dict(outputs) if hasattr(outputs, "keys") else {}
        # Entries of ``data`` take precedence, matching {**pred, **data}.
        merged.update(data)

        # assumes inliers_list supports 2-D slicing with at least 4 columns
        # (x0, y0, x1, y1, ...) -- TODO confirm against the MicKey model.
        inliers = merged["inliers_list"]
        return {
            "keypoints0": inliers[:, :2],
            "keypoints1": inliers[:, 2:4],
        }
|