File size: 1,316 Bytes
63f3cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   r2d2 -> gm
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   25/05/2023 10:09
=================================================='''
import torch
from localization.base_model import BaseModel
from nets.gm import GM as GMatcher


class GM(BaseModel):
    """Inference wrapper around the ``nets.gm.GM`` matcher network.

    ``_init`` builds the network in eval mode and loads its weights from the
    checkpoint file given by ``conf['weight_path']``; ``_forward`` runs the
    network with gradient tracking disabled.
    """

    # Default option values for the matcher.
    # NOTE(review): how these defaults are merged with the user-supplied conf
    # is decided by BaseModel, which is not visible here — confirm there.
    default_config = {
        'descriptor_dim': 128,
        'hidden_dim': 256,
        'weights': 'indoor',
        'keypoint_encoder': [32, 64, 128, 256],
        # 9 repetitions of the (self, cross) pair -> 18 entries in total.
        'GNN_layers': ['self', 'cross'] * 9,
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
        'with_pose': False,
        'n_layers': 9,
        'n_min_tokens': 256,
        'with_sinkhorn': True,

        'ac_fn': 'relu',
        'norm_fn': 'bn',
        # No checkpoint is configured by default; _init rejects None, so a
        # real path has to be supplied by the caller.
        'weight_path': None,
    }

    # Keys of the input dict this model works with.
    # NOTE(review): presumably validated by BaseModel before _forward — verify.
    required_inputs = [
        'image0', 'keypoints0', 'scores0', 'descriptors0',
        'image1', 'keypoints1', 'scores1', 'descriptors1',
    ]

    def _init(self, conf):
        """Build the matcher network and load its checkpoint weights.

        Args:
            conf: Mapping of options. It is forwarded unchanged to the
                network constructor as ``config`` and must contain
                ``'weight_path'``, the checkpoint file to load.

        Raises:
            ValueError: If ``conf['weight_path']`` is ``None``.
        """
        weight_path = conf['weight_path']
        if weight_path is None:
            # Fail early with a clear message instead of letting torch.load
            # choke on None deep inside its file-opening code.
            raise ValueError(
                "GM requires conf['weight_path'] to point to a checkpoint "
                "file, but it is None"
            )

        self.net = GMatcher(config=conf).eval()

        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(weight_path, map_location='cpu')
        # The state dict is stored under the 'model' key of the checkpoint;
        # strict=True makes any missing or unexpected parameter an error.
        self.net.load_state_dict(checkpoint['model'], strict=True)

    def _forward(self, data):
        """Run the matcher on ``data`` with gradient tracking disabled.

        Args:
            data: Input passed straight through to the network.

        Returns:
            Whatever the wrapped network returns for ``data``.
        """
        with torch.no_grad():
            return self.net(data)