Spaces:
Running
Running
File size: 3,974 Bytes
4d9207d 9223079 4d9207d 9223079 aa49562 4d9207d 9223079 2b78237 2246920 9223079 2b78237 9223079 2b78237 9223079 2b78237 9223079 aa49562 2b78237 9223079 2b78237 aa49562 2b78237 9223079 2b78237 94cb1cc 2b78237 9223079 2b78237 9223079 4d9207d 9223079 fc7dfd7 00edd17 9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import subprocess
import sys
from pathlib import Path
import torch
from .. import do_system, logger
from ..utils.base_model import BaseModel
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config
# Root of the vendored ASpanFormer checkout: <this file's dir>/../../third_party/ASpanFormer
# (the same third_party directory that is appended to sys.path above the imports).
aspanformer_path = Path(__file__).parent.joinpath("..", "..", "third_party", "ASpanFormer")
class ASpanFormer(BaseModel):
    """Image-pair matcher wrapping the third-party ASpanFormer network.

    If the checkpoint ``weights/<conf["weights"]>.ckpt`` is not present under
    ``aspanformer_path``, the archive named by ``conf["model_name"]`` is
    downloaded with ``gdown``: first directly, then through ``proxy``. The
    archive is then unpacked into ``aspanformer_path`` before the weights are
    loaded.
    """

    default_conf = {
        # Stem of the checkpoint file under aspanformer_path / "weights".
        "weights": "outdoor",
        # Written into the coarse matcher's config as "thr".
        "match_threshold": 0.2,
        # Written into the coarse matcher's config as "skh_iters".
        "sinkhorn_iterations": 20,
        # Keep at most this many highest-confidence matches (None = keep all).
        "max_keypoints": 2048,
        # ASpanFormer config file merged over get_cfg_defaults().
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        # Archive to download/extract when the checkpoint is missing.
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # HTTP proxy used only as a fallback when the direct download fails.
    proxy = "http://localhost:1080"
    # Archive file name -> download link passed to gdown.
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _init(self, conf):
        """Build the ASpanFormer network and load its pretrained weights.

        Args:
            conf: configuration dict; reads the keys listed in ``default_conf``.

        Raises:
            subprocess.CalledProcessError: if the weights archive cannot be
                downloaded (directly or via ``proxy``) or cannot be extracted.
        """
        model_path = aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
        if not model_path.exists():
            tar_path = aspanformer_path / conf["model_name"]
            if not tar_path.exists():
                self._download_archive(
                    self.aspanformer_models[conf["model_name"]], tar_path
                )
            self._extract_archive(tar_path)

        config = get_cfg_defaults()
        config.merge_from_file(conf["config_path"])
        _config = lower_config(config)
        # Override the coarse-matching settings with the user-facing options.
        _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"]
        _config["aspan"]["match_coarse"]["skh_iters"] = conf["sinkhorn_iterations"]
        self.net = _ASpanFormer(config=_config["aspan"])
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(str(model_path), map_location="cpu")["state_dict"]
        # strict=False: missing/unexpected checkpoint keys do not raise.
        self.net.load_state_dict(state_dict, strict=False)
        logger.info("Loaded Aspanformer model")

    def _download_archive(self, link, tar_path):
        """Download ``link`` to ``tar_path`` with gdown, retrying via ``self.proxy``.

        Raises:
            subprocess.CalledProcessError: if the proxied retry also fails.
        """
        cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
        cmd = [*cmd_wo_proxy, "--proxy", self.proxy]
        logger.info(f"Downloading the Aspanformer model with `{cmd_wo_proxy}`.")
        try:
            subprocess.run(cmd_wo_proxy, check=True)
        except subprocess.CalledProcessError as e:
            logger.info(f"Downloading failed {e}.")
        else:
            return
        logger.info(f"Downloading the Aspanformer model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to download the Aspanformer model: {e}")
            # Stop here: extracting a missing archive would only hide the cause.
            raise

    @staticmethod
    def _extract_archive(tar_path):
        """Unpack the tar archive ``tar_path`` into ``aspanformer_path``.

        Raises:
            subprocess.CalledProcessError: if ``tar`` exits with an error.
        """
        # Use cwd= instead of a shell "cd ... & tar": a single "&" backgrounds
        # the cd, so tar would extract into the process's working directory.
        # The archive path is resolved first because cwd changes for tar.
        subprocess.run(
            ["tar", "-xvf", str(Path(tar_path).resolve())],
            cwd=str(aspanformer_path),
            check=True,
        )

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns:
            dict with ``keypoints0``/``keypoints1`` (the network's
            ``mkpts0_f``/``mkpts1_f``) and ``mconf`` (its per-match
            confidence). When ``conf["max_keypoints"]`` is not None, only that
            many highest-confidence matches are kept, in descending-confidence
            order.
        """
        # The network stores its outputs back into the dict it is given.
        data_ = {"image0": data["image0"], "image1": data["image1"]}
        self.net(data_, online_resize=True)
        kpts0, kpts1, scores = data_["mkpts0_f"], data_["mkpts1_f"], data_["mconf"]
        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            keep = torch.argsort(scores, descending=True)[:top_k]
            kpts0, kpts1, scores = kpts0[keep], kpts1[keep], scores[keep]
        return {"keypoints0": kpts0, "keypoints1": kpts1, "mconf": scores}
|