File size: 3,974 Bytes
4d9207d
9223079
 
4d9207d
 
9223079
aa49562
4d9207d
9223079
 
 
 
 
 
 
 
 
 
 
 
 
2b78237
2246920
9223079
 
 
 
 
 
 
 
 
 
2b78237
 
 
9223079
 
 
 
 
 
2b78237
 
 
 
 
 
 
 
9223079
2b78237
 
 
9223079
 
 
aa49562
2b78237
 
 
9223079
 
 
2b78237
aa49562
2b78237
9223079
 
 
 
 
 
2b78237
 
 
94cb1cc
 
 
2b78237
9223079
 
2b78237
 
 
9223079
4d9207d
9223079
 
 
 
 
 
 
fc7dfd7
 
 
 
 
00edd17
 
 
 
 
 
 
 
 
 
9223079
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import subprocess
import sys
from pathlib import Path

import torch

from .. import do_system, logger
from ..utils.base_model import BaseModel

# Make the vendored third_party packages importable before importing from them.
sys.path.append(str(Path(__file__).parent / "../../third_party"))
from ASpanFormer.src.ASpanFormer.aspanformer import ASpanFormer as _ASpanFormer
from ASpanFormer.src.config.default import get_cfg_defaults
from ASpanFormer.src.utils.misc import lower_config

# Root of the vendored ASpanFormer repository; configs and downloaded
# weights are resolved relative to this directory.
aspanformer_path = Path(__file__).parent / "../../third_party/ASpanFormer"


class ASpanFormer(BaseModel):
    """Wrapper around the ASpanFormer detector-free image matcher.

    On first use, downloads the pretrained checkpoint archive via ``gdown``
    (falling back to a local proxy if the direct download fails), extracts
    it into the vendored repository, and loads the weights. ``_forward``
    matches two images and returns the matched keypoints with confidences.
    """

    default_conf = {
        "weights": "outdoor",
        "match_threshold": 0.2,
        "sinkhorn_iterations": 20,
        "max_keypoints": 2048,
        "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
        "model_name": "weights_aspanformer.tar",
    }
    required_inputs = ["image0", "image1"]
    # Fallback HTTP proxy used only when the direct download fails.
    proxy = "http://localhost:1080"
    aspanformer_models = {
        "weights_aspanformer.tar": "https://drive.google.com/uc?id=1eavM9dTkw9nbc-JqlVVfGPU5UvTTfc6k&confirm=t"
    }

    def _init(self, conf):
        """Download the checkpoint if missing, then build and load the net.

        Args:
            conf: merged configuration dict; see ``default_conf`` for keys.
        """
        model_path = (
            aspanformer_path / "weights" / Path(conf["weights"] + ".ckpt")
        )
        # Download and extract the weights archive on first use.
        if not model_path.exists():
            tar_path = aspanformer_path / conf["model_name"]
            if not tar_path.exists():
                link = self.aspanformer_models[conf["model_name"]]
                cmd_wo_proxy = ["gdown", link, "-O", str(tar_path)]
                cmd = cmd_wo_proxy + ["--proxy", self.proxy]
                logger.info(
                    f"Downloading the Aspanformer model with `{cmd_wo_proxy}`."
                )
                try:
                    subprocess.run(cmd_wo_proxy, check=True)
                except subprocess.CalledProcessError as e:
                    # Direct download failed; retry through the proxy.
                    logger.info(f"Downloading failed {e}.")
                    logger.info(
                        f"Downloading the Aspanformer model with `{cmd}`."
                    )
                    try:
                        subprocess.run(cmd, check=True)
                    except subprocess.CalledProcessError as e:
                        logger.error(
                            f"Failed to download the Aspanformer model: {e}"
                        )

            # BUG FIX: the original used a single `&`, which backgrounds the
            # `cd` and extracts the archive in the current working directory.
            # `&&` chains the commands so tar runs inside aspanformer_path.
            do_system(f"cd {str(aspanformer_path)} && tar -xvf {str(tar_path)}")

        config = get_cfg_defaults()
        config.merge_from_file(conf["config_path"])
        _config = lower_config(config)

        # Override coarse-matching threshold and Sinkhorn iterations
        # with the values from our configuration.
        _config["aspan"]["match_coarse"]["thr"] = conf["match_threshold"]
        _config["aspan"]["match_coarse"]["skh_iters"] = conf[
            "sinkhorn_iterations"
        ]

        self.net = _ASpanFormer(config=_config["aspan"])
        state_dict = torch.load(str(model_path), map_location="cpu")[
            "state_dict"
        ]
        # strict=False: the checkpoint may contain keys (e.g. training-only
        # buffers) that the inference model does not define.
        self.net.load_state_dict(state_dict, strict=False)
        logger.info("Loaded Aspanformer model")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Returns a dict with ``keypoints0``, ``keypoints1`` and ``mconf``,
        truncated to the ``max_keypoints`` highest-confidence matches when
        that limit is set.
        """
        data_ = {
            "image0": data["image0"],
            "image1": data["image1"],
        }
        # The network writes its predictions (mkpts0_f, mkpts1_f, mconf)
        # into data_ in place.
        self.net(data_, online_resize=True)
        scores = data_["mconf"]
        pred = {
            "keypoints0": data_["mkpts0_f"],
            "keypoints1": data_["mkpts1_f"],
            "mconf": scores,
        }
        top_k = self.conf["max_keypoints"]
        if top_k is not None and len(scores) > top_k:
            # Keep only the top-k matches ranked by descending confidence.
            keep = torch.argsort(scores, descending=True)[:top_k]
            pred["keypoints0"] = pred["keypoints0"][keep]
            pred["keypoints1"] = pred["keypoints1"][keep]
            pred["mconf"] = scores[keep]
        return pred