|
import subprocess |
|
import sys |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from .. import logger |
|
from ..utils.base_model import BaseModel |
|
|
|
# Directory holding the vendored third-party checkouts (two levels above this
# file's package). It is put on ``sys.path`` so the MicKey sources under
# ``third_party/mickey`` can be imported as the top-level ``mickey`` package.
mickey_path = Path(__file__).parent.joinpath("../../third_party")
sys.path.append(f"{mickey_path}")
|
|
|
from mickey.config.default import cfg |
|
from mickey.lib.models.builder import build_model |
|
|
|
# Use the current CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class Mickey(BaseModel):
    """Matcher wrapping the MicKey network vendored under ``third_party/mickey``.

    On first use the pretrained weights archive is downloaded and extracted
    next to the vendored sources. The checkpoint is then loaded with MicKey's
    own ``build_model``.
    """

    default_conf = {
        # YAML model config, resolved next to the checkpoint file.
        "config_path": "config.yaml",
        # Checkpoint file name inside the extracted ``mickey_weights`` directory.
        "model_name": "mickey.ckpt",
        # NOTE(review): not read by this class — confirm whether anything
        # downstream is expected to honour it.
        "max_keypoints": 3000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # Archive containing the pretrained checkpoint and its config.
    weight_urls = "https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip"

    def _init(self, conf):
        """Fetch the weights if necessary, then build the MicKey network.

        Args:
            conf: Model configuration. ``conf["model_name"]`` selects the
                checkpoint and ``self.conf["config_path"]`` the YAML config.

        Raises:
            subprocess.CalledProcessError: If ``wget`` or ``unzip`` fails.
            FileNotFoundError: If the checkpoint is still missing after the
                archive has been extracted.
        """
        model_path = mickey_path / "mickey/mickey_weights" / conf["model_name"]
        zip_path = mickey_path / "mickey/mickey_weights.zip"
        config_path = model_path.parent / self.conf["config_path"]

        if not model_path.exists():
            self._download_weights(model_path, zip_path)

        logger.info("Loading mickey model...")
        # NOTE: this merges into the ``cfg`` object imported from the MicKey
        # package, not into a private copy.
        cfg.merge_from_file(config_path)
        self.net = build_model(cfg, checkpoint=model_path)
        logger.info("Load Mickey model done.")

    def _download_weights(self, model_path, zip_path):
        """Download (when absent) and extract the MicKey weights archive.

        Args:
            model_path: Expected checkpoint location after extraction.
            zip_path: Where the downloaded archive is kept.

        Raises:
            subprocess.CalledProcessError: If ``wget`` or ``unzip`` fails.
            FileNotFoundError: If extraction did not produce ``model_path``.
        """
        model_path.parent.mkdir(exist_ok=True, parents=True)
        if not zip_path.exists():
            # Download to a side file and rename only after wget succeeds, so
            # an interrupted download is never mistaken for a complete archive.
            partial_path = zip_path.with_name(zip_path.name + ".part")
            cmd = ["wget", "--quiet", self.weight_urls, "-O", str(partial_path)]
            logger.info(f"Downloading the Mickey model with {cmd}.")
            subprocess.run(cmd, check=True)
            partial_path.replace(zip_path)

        # -o: overwrite files left by an earlier, interrupted extraction
        # instead of stopping to prompt for confirmation.
        cmd = ["unzip", "-o", "-d", str(model_path.parent.parent), str(zip_path)]
        logger.info(f"Running {cmd}.")
        subprocess.run(cmd, check=True)

        if not model_path.exists():
            raise FileNotFoundError(
                f"Checkpoint {model_path} not found after extracting {zip_path}."
            )

    def _forward(self, data):
        """Run MicKey on an image pair and return the matched keypoints.

        Args:
            data: Input dict containing at least ``image0`` and ``image1``.

        Returns:
            Dict with ``keypoints0`` (columns 0-1 of ``data["inliers_list"]``)
            and ``keypoints1`` (columns 2-3).
        """
        # Only the network's effect on ``data`` is used: it is presumably what
        # populates ``data["inliers_list"]`` (not one of ``required_inputs``).
        self.net(data)
        # assumes inliers_list is a 2-D tensor whose first four columns are
        # (x0, y0, x1, y1) — TODO confirm against MicKey's pose solver.
        inliers = data["inliers_list"]
        return {
            "keypoints0": inliers[:, :2],
            "keypoints1": inliers[:, 2:4],
        }
|
|