Spaces:
Running
Running
File size: 4,064 Bytes
e64cfb1 9223079 e64cfb1 9223079 4cffcfe 9223079 f269db9 121523b 9223079 f269db9 9223079 f269db9 9223079 f269db9 9223079 0160434 f269db9 9223079 f269db9 0160434 f269db9 9223079 4cffcfe 9223079 f269db9 63932be f269db9 9223079 f269db9 9223079 e64cfb1 9223079 5715373 a4927e4 9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import subprocess
import sys
from pathlib import Path
import torch
from hloc import logger
from hloc.utils.base_model import BaseModel
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config
# Root of the vendored ASpanFormer checkout; the config file, the extracted
# "weights" directory and the downloaded weights tarball all live under it.
aspanformer_path = Path(__file__).parent.joinpath("../../third_party/ASpanFormer")
class ASpanFormer(BaseModel):
    """hloc model wrapping the third-party ASpanFormer matcher.

    On first use, if the checkpoint ``weights/<conf["weights"]>.ckpt`` is not
    present under ``aspanformer_path``, a weights tarball is fetched with
    ``gdown`` (directly, then via ``proxy`` as a fallback) and extracted
    there with ``tar``. The checkpoint's ``state_dict`` is then loaded into
    the network.
    """

    default_conf = {
        # Checkpoint stem: loads aspanformer_path/weights/<weights>.ckpt.
        "weights": "outdoor",
        # Written to the network config as match_coarse.thr.
        "match_threshold": 0.2,
        # Written to the network config as match_coarse.skh_iters.
        "sinkhorn_iterations": 20,
        # Upper bound on returned matches (highest mconf kept); None = no cap.
        "max_keypoints": 2048,
        # Config file merged over get_cfg_defaults().
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        # Tarball name; also the key into aspanformer_models.
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # Proxy used only for the retry when the direct download fails.
    proxy = "http://localhost:1080"
    # Tarball name -> download URL passed to gdown.
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _download_model(self, conf):
        """Fetch (if needed) and extract the weights tarball into aspanformer_path.

        The tarball is downloaded only when it is not already on disk. A
        direct ``gdown`` download is tried first; on failure it is retried
        through ``self.proxy``.

        Args:
            conf: model configuration; ``conf["model_name"]`` selects the tarball.

        Raises:
            subprocess.CalledProcessError: if both download attempts fail,
                or if ``tar`` fails to extract the archive.
        """
        tar_path = aspanformer_path / conf["model_name"]
        if not tar_path.exists():
            link = self.aspanformer_models[conf["model_name"]]
            cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
            cmd = cmd_wo_proxy + ["--proxy", self.proxy]
            logger.info(
                f"Downloading the Aspanformer model with `{cmd_wo_proxy}`."
            )
            try:
                subprocess.run(cmd_wo_proxy, check=True)
            except subprocess.CalledProcessError as e:
                logger.info(f"Downloading failed {e}.")
                logger.info(
                    f"Downloading the Aspanformer model with `{cmd}`."
                )
                try:
                    subprocess.run(cmd, check=True)
                except subprocess.CalledProcessError as e:
                    logger.error(
                        f"Failed to download the Aspanformer model: {e}"
                    )
                    # Without the tarball the tar step below could only fail
                    # with a misleading error; surface the real cause instead.
                    raise
        cmd = ["tar", "-xvf", str(tar_path), "-C", str(aspanformer_path)]
        logger.info(f"Unzip model file `{cmd}`.")
        subprocess.run(cmd, check=True)

    def _init(self, conf):
        """Ensure the checkpoint exists, build the network and load its weights.

        Args:
            conf: model configuration (see ``default_conf``).

        Raises:
            subprocess.CalledProcessError: if downloading or extraction fails.
            FileNotFoundError: if the checkpoint is still missing after extraction.
        """
        model_path = (
            aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
        )
        if not model_path.exists():
            self._download_model(conf)
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Aspanformer checkpoint {model_path} not found after "
                    f"extracting {conf['model_name']}."
                )

        config = get_cfg_defaults()
        config.merge_from_file(conf["config_path"])
        _config = lower_config(config)
        # Override the coarse-matching hyper-parameters from the hloc conf.
        _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"]
        _config["aspan"]["match_coarse"]["skh_iters"] = conf[
            "sinkhorn_iterations"
        ]
        self.net = _ASpanFormer(config=_config["aspan"])
        state_dict = torch.load(str(model_path), map_location="cpu")[
            "state_dict"
        ]
        # NOTE(review): strict=False tolerates key mismatches between the
        # checkpoint and the module — presumably intentional; confirm.
        self.net.load_state_dict(state_dict, strict=False)
        logger.info("Loaded Aspanformer model")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns:
            dict with ``keypoints0``, ``keypoints1`` and ``mconf`` (per-match
            confidence). When ``max_keypoints`` is set and exceeded, only
            the highest-confidence matches are kept, in descending
            confidence order; otherwise the network's order is preserved.
        """
        data_ = {
            "image0": data["image0"],
            "image1": data["image1"],
        }
        # The network writes its predictions into the dict it is given.
        self.net(data_, online_resize=True)
        keypoints0 = data_["mkpts0_f"]
        keypoints1 = data_["mkpts1_f"]
        scores = data_["mconf"]

        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            keep = torch.argsort(scores, descending=True)[:top_k]
            keypoints0 = keypoints0[keep]
            keypoints1 = keypoints1[keep]
            scores = scores[keep]

        return {
            "keypoints0": keypoints0,
            "keypoints1": keypoints1,
            "mconf": scores,
        }
|