Spaces:
Running
Running
File size: 1,316 Bytes
63f3cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File r2d2 -> gm
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 25/05/2023 10:09
=================================================='''
import torch
from localization.base_model import BaseModel
from nets.gm import GM as GMatcher
class GM(BaseModel):
    """Localization wrapper around the ``nets.gm.GM`` graph matcher.

    Builds the underlying matcher network from the configuration, loads its
    weights from ``conf['weight_path']`` and runs it in eval mode without
    gradient tracking.
    """

    default_config = {
        'descriptor_dim': 128,
        'hidden_dim': 256,
        'weights': 'indoor',
        'keypoint_encoder': [32, 64, 128, 256],
        # Alternating attention layers: 9 (self, cross) pairs -> 18 entries.
        'GNN_layers': ['self', 'cross'] * 9,
        'sinkhorn_iterations': 20,
        'match_threshold': 0.2,
        'with_pose': False,
        'n_layers': 9,
        'n_min_tokens': 256,
        'with_sinkhorn': True,
        'ac_fn': 'relu',
        'norm_fn': 'bn',
        # Path to the checkpoint file; has no usable default and must be set.
        'weight_path': None,
    }
    # Keys the input ``data`` dict is expected to carry, for both images.
    # NOTE(review): presumably enforced by BaseModel before _forward — confirm.
    required_inputs = [
        'image0', 'keypoints0', 'scores0', 'descriptors0',
        'image1', 'keypoints1', 'scores1', 'descriptors1',
    ]

    def _init(self, conf):
        """Build the matcher network and load its pretrained weights.

        Args:
            conf: configuration mapping forwarded to the matcher. Presumably
                ``default_config`` merged with user overrides by ``BaseModel``
                (TODO confirm). ``conf['weight_path']`` must name a checkpoint
                whose ``'model'`` entry is the network's state dict.

        Raises:
            ValueError: if ``conf['weight_path']`` is not set.
        """
        weight_path = conf['weight_path']
        if weight_path is None:
            # Fail early with a clear message instead of an obscure error
            # from torch.load(None).
            raise ValueError(
                "GM requires conf['weight_path'] to point to a checkpoint file; got None."
            )
        self.net = GMatcher(config=conf).eval()
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' lets GPU-saved checkpoints load on any machine.
        state_dict = torch.load(weight_path, map_location='cpu')['model']
        # strict=True: every parameter must match the network exactly.
        self.net.load_state_dict(state_dict, strict=True)

    def _forward(self, data):
        """Run the matcher on ``data`` without gradient tracking.

        Args:
            data: input dict; expected to contain the ``required_inputs`` keys.

        Returns:
            Whatever the underlying matcher network returns for ``data``.
        """
        with torch.no_grad():
            return self.net(data)
|