File size: 4,882 Bytes
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5da9dd4
9223079
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys
from pathlib import Path
import subprocess
import logging
import torch
from PIL import Image
from collections import OrderedDict, namedtuple
from ..utils.base_model import BaseModel
from ..utils import do_system

# Location of the vendored SGMNet checkout, resolved relative to this file.
sgmnet_path = Path(__file__).parent / "../../third_party/SGMNet"
# The checkout is not an installed package: it must be on sys.path *before*
# the `sgmnet` import below, so the order of these statements matters.
sys.path.append(str(sgmnet_path))

from sgmnet import matcher as SGM_Model

# Module-wide compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)


class SGMNet(BaseModel):
    """Wrapper exposing the third-party SGMNet matcher through ``BaseModel``.

    On initialisation the pretrained weights are fetched (with ``gdown``) and
    unpacked under the SGMNet checkout if they are not already present, then
    loaded into the network.  ``_forward`` consumes keypoints, scores and
    descriptors of two images and returns, for every keypoint of image 0, the
    index of its match in image 1 (or ``-1``).
    """

    # Every entry is forwarded to the SGMNet matcher as a namedtuple field;
    # ``match_threshold`` is additionally read by ``match_p``.
    default_conf = {
        "name": "SGM",
        "model_name": "model_best.pth",
        "seed_top_k": [256, 256],
        "seed_radius_coe": 0.01,
        "net_channels": 128,
        "layer_num": 9,
        "head": 4,
        "seedlayer": [0, 6],
        "use_mc_seeding": True,
        "use_score_encoding": False,
        "conf_bar": [1.11, 0.1],
        "sink_iter": [10, 100],
        "detach_iter": 1000000,
        "match_threshold": 0.2,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # Download location of the weight archive, keyed by model file name.
    weight_urls = {
        "model_best.pth": "https://drive.google.com/uc?id=1Ca0WmKSSt2G6P7m8YAOlSAHEFar_TAWb&confirm=t",
    }
    # Proxy used only as a fallback when the direct download fails.
    proxy = "http://localhost:1080"

    def _init(self, conf):
        """Make sure the weights exist locally, then build and load the net.

        Args:
            conf: configuration mapping; ``conf["model_name"]`` selects the
                weight file, and all entries are passed on to the matcher.

        Raises:
            subprocess.CalledProcessError: if downloading or unpacking the
                weight archive fails.
            KeyError: if the weights are missing and ``conf["model_name"]``
                has no entry in ``weight_urls``.
        """
        sgmnet_weights = sgmnet_path / "weights/sgm/root" / conf["model_name"]
        tar_path = sgmnet_path / "weights.tar.gz"

        if not sgmnet_weights.exists():
            if not tar_path.exists():
                # Look the URL up only when a download is really needed, so
                # locally provided weights under a custom name keep working.
                self._download_archive(
                    self.weight_urls[conf["model_name"]], tar_path
                )
            cmd = ["tar", "-xvf", str(tar_path)]
            logger.info(f"Unzip model file `{cmd}`.")
            # Run tar with the checkout as working directory.  The previous
            # shell string "cd <dir> & tar ..." backgrounds the `cd` on POSIX
            # shells, so the archive was unpacked into the caller's cwd.
            subprocess.run(cmd, cwd=str(sgmnet_path), check=True)

        # The matcher expects attribute access, hence the namedtuple.
        config = namedtuple("config", conf.keys())(*conf.values())
        self.net = SGM_Model(config)
        checkpoint = torch.load(sgmnet_weights, map_location="cpu")
        state_dict = checkpoint["state_dict"]
        # Checkpoints saved from a DistributedDataParallel model prefix every
        # key with "module."; strip it so the keys match the bare network.
        if next(iter(state_dict), "").split(".")[0] == "module":
            prefix_len = len("module.")
            state_dict = OrderedDict(
                (key[prefix_len:], value) for key, value in state_dict.items()
            )
        self.net.load_state_dict(state_dict)
        logger.info("Load SGMNet model done.")

    def _download_archive(self, link, tar_path):
        """Download the weight archive, retrying through ``self.proxy``.

        Raises:
            subprocess.CalledProcessError: if both attempts fail.
        """
        direct_cmd = ["gdown", link, "-O", str(tar_path)]
        proxy_cmd = direct_cmd + ["--proxy", self.proxy]

        logger.info(f"Downloading the SGMNet model with `{direct_cmd}`.")
        try:
            subprocess.run(direct_cmd, check=True)
            return
        except subprocess.CalledProcessError:
            logger.info(f"Downloading the SGMNet model with `{proxy_cmd}`.")

        try:
            subprocess.run(proxy_cmd, check=True)
        except subprocess.CalledProcessError:
            logger.error("Failed to download the SGMNet model.")
            raise

    def _forward(self, data):
        """Match the keypoints of two images.

        Args:
            data: mapping with ``keypoints{0,1}`` (N x 2), ``scores{0,1}``,
                ``descriptors{0,1}`` (3-D, permuted to 1 x N x D here) and
                ``image{0,1}``.  Image tensors are assumed to be laid out
                B x C x H x W -- TODO confirm against the caller.

        Returns:
            dict with ``matches0`` (1 x N; index into image-1 keypoints or
            ``-1``) and ``matching_scores0`` (1 x N, all zeros).
        """
        x1 = data["keypoints0"]  # N x 2
        x2 = data["keypoints1"]
        score1 = data["scores0"].reshape(-1, 1)  # N x 1
        score2 = data["scores1"].reshape(-1, 1)
        desc1 = data["descriptors0"].permute(0, 2, 1)  # 1 x N x D
        desc2 = data["descriptors1"].permute(0, 2, 1)
        # shape[2:] is (H, W); flip it to (W, H) to line up with (x, y).
        size1 = torch.tensor(data["image0"].shape[2:]).flip(0)
        size2 = torch.tensor(data["image1"].shape[2:]).flip(0)
        norm_x1 = self.normalize_size(x1, size1)
        norm_x2 = self.normalize_size(x2, size2)

        x1 = torch.cat((norm_x1, score1), dim=-1)  # N x 3
        x2 = torch.cat((norm_x2, score2), dim=-1)
        net_input = {
            "x1": x1[None],
            "x2": x2[None],
            "desc1": desc1,
            "desc2": desc2,
        }
        net_input = {
            k: v.to(device).float() if isinstance(v, torch.Tensor) else v
            for k, v in net_input.items()
        }
        pred = self.net(net_input, test_mode=True)

        p = pred["p"]
        # The last row and column are dropped before matching (presumably
        # the "unmatched" bins of the assignment matrix -- TODO confirm).
        indices0 = self.match_p(p[0, :-1, :-1])
        return {
            "matches0": indices0.unsqueeze(0),
            # No per-match confidence is exported; keep the placeholder on
            # the same device as the matches.
            "matching_scores0": torch.zeros(
                indices0.size(0), device=indices0.device
            ).unsqueeze(0),
        }

    def match_p(self, p):
        """Extract mutual-nearest-neighbour matches from a score matrix.

        Args:
            p: 2-D tensor of matching scores (rows: image-0 keypoints,
                columns: image-1 keypoints).

        Returns:
            1-D integer tensor with one entry per row: the matched column
            index, or ``-1`` when the best score is not above
            ``conf["match_threshold"]`` or the match is not mutual.
        """
        # topk(k=1) is undefined on an empty axis: nothing can be matched.
        if p.numel() == 0:
            return torch.full(
                (p.shape[0],), -1, dtype=torch.long, device=p.device
            )

        score, index = torch.topk(p, k=1, dim=-1)  # best column per row
        _, index2 = torch.topk(p, k=1, dim=-2)  # best row per column
        mask_th = score[:, 0] > self.conf["match_threshold"]
        index = index[:, 0]
        index2 = index2.squeeze(0)
        # Mutual check: the best column of row i must pick row i back.
        # Build the row ids on p's own device rather than the global one.
        mask_mc = index2[index] == torch.arange(len(p), device=p.device)
        mask = mask_th & mask_mc
        return torch.where(mask, index, index.new_tensor(-1))

    def normalize_size(self, x, size, scale=1):
        """Centre keypoints on the image and scale by its longer side.

        Args:
            x: keypoint coordinates, last dimension ordered like ``size``.
            size: tensor holding the image extent per coordinate.
            scale: extra divisor applied on top of the longer side.

        Returns:
            ``(x - size / 2 + 0.5) / (max(size) * scale)``.
        """
        # `size` is built on the CPU by the caller while the keypoints may
        # live on the GPU; align devices to avoid a device-mismatch error.
        size = size.to(x.device)
        norm_fac = size.max()
        return (x - size / 2 + 0.5) / (norm_fac * scale)