Spaces:
Running
Running
File size: 4,064 Bytes
4d9207d 9223079 4d9207d 9223079 489cb4a 9223079 2b78237 2246920 9223079 2b78237 9223079 2b78237 9223079 2b78237 9223079 aa49562 2b78237 9223079 2b78237 aa49562 2b78237 9223079 489cb4a 9223079 2b78237 94cb1cc 2b78237 9223079 2b78237 9223079 4d9207d 9223079 fc7dfd7 00edd17 9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import subprocess
import sys
from pathlib import Path
import torch
from hloc import logger
from hloc.utils.base_model import BaseModel
# Make the vendored third_party directory importable so that the
# `ASpanFormer` package imported below can be resolved; this append must run
# before those imports.
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config
# Root of the vendored ASpanFormer checkout. The model reads its config files
# from here, downloads the weights archive into it, and loads checkpoints
# from its `weights/` subdirectory.
aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"
class ASpanFormer(BaseModel):
    """hloc matcher wrapping the ASpanFormer network.

    On first use, the pretrained weights archive is fetched with ``gdown``.
    The first attempt is direct; if that fails, a second attempt goes through
    ``proxy``. The archive is unpacked into the vendored ASpanFormer checkout,
    and the checkpoint selected by ``conf["weights"]`` is loaded into the
    network.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": 2048,
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # HTTP proxy, used only for the retry after a direct download fails.
    proxy = "http://localhost:1080"
    # Weights archive name -> download link passed to `gdown`.
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _init(self, conf):
        """Build the ASpanFormer network and load its pretrained checkpoint.

        Args:
            conf: model configuration. Reads ``weights`` (checkpoint stem),
                ``model_name`` (weights archive), ``config_path``,
                ``match_threshold`` and ``sinkhorn_iterations``.

        Raises:
            subprocess.CalledProcessError: if the checkpoint is missing and
                the weights archive cannot be downloaded (with or without the
                proxy) or unpacked.
        """
        model_path = aspanformer_path / "weights" / (conf["weights"] + ".ckpt")
        if not model_path.exists():
            self._download_and_extract(conf["model_name"])

        config = get_cfg_defaults()
        config.merge_from_file(conf["config_path"])
        _config = lower_config(config)
        # Override the coarse-matching parameters from the hloc conf.
        _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"]
        _config["aspan"]["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]

        self.net = _ASpanFormer(config=_config["aspan"])
        # NOTE(review): torch.load unpickles the checkpoint, which can run
        # arbitrary code; only load archives from a trusted source.
        state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"]
        self.net.load_state_dict(state_dict, strict=False)
        logger.info("Loaded Aspanformer model")

    def _download_and_extract(self, model_name):
        """Download archive ``model_name`` if absent, then untar it into the checkout.

        The direct download is tried first; if it fails, the download is
        retried through ``self.proxy``. If both attempts fail, the last
        ``CalledProcessError`` is re-raised. Extraction is not attempted on a
        missing archive.
        """
        tar_path = aspanformer_path / model_name
        if not tar_path.exists():
            link = self.aspanformer_models[model_name]
            cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
            cmd_with_proxy = cmd_wo_proxy + ["--proxy", self.proxy]
            logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
            try:
                subprocess.run(cmd_wo_proxy, check=True)
            except subprocess.CalledProcessError as e:
                logger.warning(f"Downloading failed {e}.")
                logger.info(f"Downloading the Aspanformer model with `{cmd_with_proxy}`.")
                try:
                    subprocess.run(cmd_with_proxy, check=True)
                except subprocess.CalledProcessError as e_proxy:
                    logger.error(f"Failed to download the Aspanformer model: {e_proxy}")
                    # Without an archive, untarring can only fail with a
                    # misleading error, so surface the real cause instead.
                    raise
        cmd = ["tar", "-xvf", str(tar_path), "-C", str(aspanformer_path)]
        logger.info(f"Unzip model file `{cmd}`.")
        subprocess.run(cmd, check=True)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns:
            dict with ``keypoints0`` / ``keypoints1`` (the network's
            ``mkpts0_f`` / ``mkpts1_f`` match coordinates) and ``mconf``
            (per-match confidence). When ``max_keypoints`` is set and there
            are more matches than that, only the ``max_keypoints``
            highest-confidence matches are kept, ordered by descending
            confidence.
        """
        data_ = {"image0": data["image0"], "image1": data["image1"]}
        # The network writes its outputs into the dict it is given.
        self.net(data_, online_resize=True)
        kpts0 = data_["mkpts0_f"]
        kpts1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            keep = torch.argsort(scores, descending=True)[:top_k]
            kpts0, kpts1, scores = kpts0[keep], kpts1[keep], scores[keep]
        return {"keypoints0": kpts0, "keypoints1": kpts1, "mconf": scores}
|