File size: 2,329 Bytes
3ad39d6
4d9207d
3ad39d6
4d9207d
 
 
3ad39d6
4d9207d
3ad39d6
 
 
 
4d9207d
 
 
 
3ad39d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d9207d
3ad39d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b1a727
 
3ad39d6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import sys
from pathlib import Path

import numpy as np
import torch
from torchvision.transforms import ToPILImage

from ..utils.base_model import BaseModel

sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
from COTR.inference.sparse_engine import SparseEngine
from COTR.models import build_model
from COTR.options.options import *  # noqa: F403
from COTR.options.options_utils import *  # noqa: F403
from COTR.utils import utils as utils_cotr

# Import-time side effects: every statement below runs as soon as this module
# is imported.
# NOTE(review): fix_randomness(0) presumably seeds the RNGs used by COTR with
# 0 for reproducibility -- confirm against COTR.utils.utils.
utils_cotr.fix_randomness(0)
# Inference only: switch autograd gradient tracking off.
torch.set_grad_enabled(False)

# Root of the vendored COTR checkout; checkpoints are resolved relative to it.
cotr_path = Path(__file__).parent.joinpath("../../third_party/COTR")

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)


class COTR(BaseModel):
    """Matcher that wraps the third-party COTR correspondence network.

    ``_init`` builds the network from COTR's own argparse options and loads a
    checkpoint from the vendored ``third_party/COTR`` tree; ``_forward`` runs
    COTR's sparse engine on an image pair and returns the matched keypoints.
    """

    # "weights": directory (relative to ``cotr_path``) that contains
    #     ``checkpoint.pth.tar``.
    # "max_keypoints": forwarded to the sparse engine as ``max_corrs``.
    # NOTE(review): "match_threshold" is not read anywhere in this class.
    default_conf = {
        "weights": "out/default",
        "match_threshold": 0.2,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the COTR network and load its checkpoint.

        Args:
            conf: Configuration mapping; only ``conf["weights"]`` is read here.

        Sets ``self.net`` (the model in eval mode, moved to ``device``) and
        ``self.to_pil_func`` (tensor -> RGB PIL image converter).
        """
        parser = argparse.ArgumentParser()
        set_COTR_arguments(parser)  # noqa: F405
        # This module is imported from inside a package, so ``sys.argv`` may
        # carry flags that belong to the host program. ``parse_args()`` would
        # abort the whole process with SystemExit on the first flag it does
        # not know; ``parse_known_args()`` still honours COTR flags given on
        # the command line but leaves unrecognised ones alone.
        opt, _foreign_args = parser.parse_known_args()
        opt.command = " ".join(sys.argv)
        opt.load_weights_path = str(
            cotr_path / conf["weights"] / "checkpoint.pth.tar"
        )

        # ``dim_feedforward`` is derived from the selected layer name.
        # NOTE(review): going by the table's name these are the channel
        # counts of the respective backbone layers -- confirm against COTR.
        layer_2_channels = {
            "layer1": 256,
            "layer2": 512,
            "layer3": 1024,
            "layer4": 2048,
        }
        opt.dim_feedforward = layer_2_channels[opt.layer]

        model = build_model(opt)
        model = model.to(device)
        # Deserialize onto the CPU first; ``safe_load_weights`` then copies
        # the state into the model that already lives on ``device``.
        weights = torch.load(opt.load_weights_path, map_location="cpu")[
            "model_state_dict"
        ]
        utils_cotr.safe_load_weights(model, weights)
        self.net = model.eval()
        self.to_pil_func = ToPILImage(mode="RGB")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]`` with COTR.

        Only index ``0`` of each input is used (``data[...][0]``), i.e. a
        single image pair per call.

        Args:
            data: Mapping holding the tensors ``"image0"`` and ``"image1"``.
                # assumes each is a (1, 3, H, W) image batch -- TODO confirm

        Returns:
            Dict with ``"keypoints0"`` (first two columns of the
            correspondence array) and ``"keypoints1"`` (remaining columns),
            both as tensors created with ``torch.from_numpy``.
        """
        # The sparse engine is fed numpy images, so convert
        # tensor -> PIL (RGB) -> ndarray on the CPU.
        img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
        img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))
        # NOTE(review): the two positional arguments after the images are
        # presumably the zoom-in schedule (four values from 0.5 down to
        # 0.0625) and the number of convergence iterations -- confirm against
        # SparseEngine's signature.
        corrs = SparseEngine(
            self.net, 32, mode="tile"
        ).cotr_corr_multiscale_with_cycle_consistency(
            img_a,
            img_b,
            np.linspace(0.5, 0.0625, 4),
            1,
            max_corrs=self.conf["max_keypoints"],
            queries_a=None,
        )
        # presumably each row is (x_a, y_a, x_b, y_b); verify against COTR
        pred = {
            "keypoints0": torch.from_numpy(corrs[:, :2]),
            "keypoints1": torch.from_numpy(corrs[:, 2:]),
        }
        return pred