File size: 3,544 Bytes
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
404d2af
 
 
 
 
8b973ee
404d2af
 
 
 
 
 
 
 
 
8b973ee
 
 
 
 
 
 
404d2af
8b973ee
 
 
 
 
 
 
404d2af
8b973ee
 
 
404d2af
8b973ee
 
 
 
 
 
 
 
404d2af
 
8b973ee
 
404d2af
 
8b973ee
 
404d2af
8b973ee
 
 
 
404d2af
 
 
 
 
8b973ee
 
 
404d2af
8b973ee
 
 
404d2af
 
 
 
 
 
8b973ee
 
404d2af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from collections import OrderedDict

import torch
import torch.nn as nn
from einops.einops import rearrange

from .backbone import build_backbone
from .modules import LocalFeatureTransformer, FinePreprocess, TopicFormer
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching


class TopicFM(nn.Module):
    """Coarse-to-fine image matcher.

    Pipeline run by ``forward``: a CNN backbone produces coarse and fine
    feature maps for both images; ``TopicFormer`` transforms the coarse
    features; ``CoarseMatching`` matches them; ``FinePreprocess`` and
    ``LocalFeatureTransformer`` process fine-level features around the
    coarse matches; ``FineMatching`` produces the final fine-level result.

    Args:
        config: Configuration mapping. The whole mapping is passed to
            ``build_backbone`` and ``FinePreprocess``; the sub-configs under
            ``"coarse"``, ``"match_coarse"`` and ``"fine"`` configure the
            coarse transformer, the coarse matcher and the fine transformer.
    """

    def __init__(self, config):
        super().__init__()
        # Misc
        self.config = config

        # Modules
        self.backbone = build_backbone(config)

        self.loftr_coarse = TopicFormer(config["coarse"])
        self.coarse_matching = CoarseMatching(config["match_coarse"])
        self.fine_preprocess = FinePreprocess(config)
        self.loftr_fine = LocalFeatureTransformer(config["fine"])
        self.fine_matching = FineMatching()

    def forward(self, data):
        """Run the full matching pipeline, writing results into ``data``.

        Args:
            data (dict): {
                'image0': (torch.Tensor): (N, 1, H, W)
                'image1': (torch.Tensor): (N, 1, H, W)
                'mask0'(optional) : (torch.Tensor): '0' indicates a padded position
                'mask1'(optional) : (torch.Tensor); required whenever 'mask0' is given
            }

        Keys written here: ``bs``, ``hw0_i``, ``hw1_i`` (input sizes),
        ``hw0_c``, ``hw1_c``, ``hw0_f``, ``hw1_f`` (coarse/fine feature-map
        sizes), ``conf_matrix`` and ``topic_matrix``. ``self.coarse_matching``
        and ``self.fine_matching`` are also given ``data`` and presumably add
        their own outputs to it (not visible from this file).

        Returns:
            None; all outputs are communicated through ``data``.
        """
        # 1. Local Feature CNN
        data.update(
            {
                "bs": data["image0"].size(0),
                "hw0_i": data["image0"].shape[2:],
                "hw1_i": data["image1"].shape[2:],
            }
        )

        if data["hw0_i"] == data["hw1_i"]:  # faster & better BN convergence
            # Run both images through the backbone as one 2N batch, then split
            # each output back into the image0 and image1 halves.
            feats_c, feats_f = self.backbone(
                torch.cat([data["image0"], data["image1"]], dim=0)
            )
            feat_c0, feat_c1 = feats_c.split(data["bs"])
            feat_f0, feat_f1 = feats_f.split(data["bs"])
        else:  # handle different input shapes
            feat_c0, feat_f0 = self.backbone(data["image0"])
            feat_c1, feat_f1 = self.backbone(data["image1"])

        data.update(
            {
                "hw0_c": feat_c0.shape[2:],
                "hw1_c": feat_c1.shape[2:],
                "hw0_f": feat_f0.shape[2:],
                "hw1_f": feat_f1.shape[2:],
            }
        )

        # 2. coarse-level transformer: flatten feature maps to token sequences
        feat_c0 = rearrange(feat_c0, "n c h w -> n (h w) c")
        feat_c1 = rearrange(feat_c1, "n c h w -> n (h w) c")

        mask_c0 = mask_c1 = None  # mask is useful in training
        if "mask0" in data:
            # NOTE(review): the masks are flattened and used directly as the
            # coarse-level masks, so they presumably must already be at the
            # coarse resolution (hw0_c / hw1_c) — confirm against the data loader.
            mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)

        feat_c0, feat_c1, conf_matrix, topic_matrix = self.loftr_coarse(
            feat_c0, feat_c1, mask_c0, mask_c1
        )
        data.update({"conf_matrix": conf_matrix, "topic_matrix": topic_matrix})

        # 3. match coarse-level
        self.coarse_matching(data)

        # 4. fine-level refinement (coarse features detached: no gradient
        #    flows back into the coarse branch through this path)
        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
            feat_f0, feat_f1, feat_c0.detach(), feat_c1.detach(), data
        )
        if feat_f0_unfold.size(0) != 0:  # at least one coarse level predicted
            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
                feat_f0_unfold, feat_f1_unfold
            )

        # 5. match fine-level
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights, stripping a leading ``"matcher."`` from parameter names.

        Checkpoints whose keys carry a ``matcher.`` prefix (presumably saved
        from a wrapper module holding this model as ``matcher``) load
        directly. Unprefixed keys are passed through unchanged. If both
        ``"matcher.x"`` and ``"x"`` are present, the prefixed entry wins.

        The caller's ``state_dict`` is not modified; a renamed copy is passed
        to ``nn.Module.load_state_dict``. Extra positional and keyword
        arguments are forwarded unchanged, and its return value is returned.
        """
        prefix = "matcher."
        # Work on a copy so the caller's mapping keeps its original keys.
        renamed = OrderedDict(
            (key, value)
            for key, value in state_dict.items()
            if not key.startswith(prefix)
        )
        # Prefixed entries are applied after unprefixed ones so they overwrite
        # a same-named unprefixed key, matching the previous in-place behaviour.
        for key, value in state_dict.items():
            if key.startswith(prefix):
                renamed[key[len(prefix):]] = value
        # nn.Module.load_state_dict reads per-module version info from the
        # ``_metadata`` attribute; carry it over to the copy.
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            renamed._metadata = metadata
        return super().load_state_dict(renamed, *args, **kwargs)