File size: 3,879 Bytes
63f3cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> segnet
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.layers import MLP, KeypointEncoder
from nets.layers import AttentionalPropagation
from nets.utils import normalize_keypoints


class SegGNN(nn.Module):
    """Stack of residual ``AttentionalPropagation`` layers over one descriptor set.

    Every layer is called with the same tensor as both of its inputs
    (``layer(desc, desc)``), i.e. the descriptors attend to themselves, assuming
    the layer's two arguments are the attending and attended sets. The layer
    output is added back to the descriptors as a residual update.

    Args:
        feature_dim: descriptor channel width passed to every layer.
        n_layers: number of stacked layers.
        ac_fn: activation name forwarded to ``AttentionalPropagation``.
        norm_fn: normalization name forwarded to ``AttentionalPropagation``.
        **kwargs: accepted for call-site compatibility and ignored.
    """

    def __init__(self, feature_dim: int, n_layers: int, ac_fn: str = 'relu', norm_fn: str = 'bn', **kwargs):
        super().__init__()
        # NOTE(review): the constant 4 is the second positional argument of
        # AttentionalPropagation -- presumably the number of attention heads;
        # confirm against nets.layers.
        self.layers = nn.ModuleList([
            AttentionalPropagation(feature_dim, 4, ac_fn=ac_fn, norm_fn=norm_fn)
            for _ in range(n_layers)
        ])

    def forward(self, desc):
        """Apply each layer in order as ``desc <- desc + layer(desc, desc)``.

        Args:
            desc: descriptor tensor (SegNet passes it channel-first, [B, D, N]).

        Returns:
            The descriptors after all residual updates.
        """
        for layer in self.layers:
            desc = desc + layer(desc, desc)
        return desc


class SegNet(nn.Module):
    """Per-keypoint segmentation network built on ``SegGNN``.

    ``forward`` works in five steps:

    1. Take per-keypoint descriptors in channel-first layout.
    2. Add the encoding produced by ``KeypointEncoder`` from the keypoint
       coordinates (and, optionally, their scores).
    3. Refine the result with ``SegGNN``.
    4. Project it with an MLP head to ``n_class`` channels per keypoint.
    5. Optionally (``with_sc``), run a second MLP head that produces
       3 channels per keypoint.

    Configuration (keys missing from the ``config`` passed to ``__init__`` fall
    back to ``default_config``):

    - ``descriptor_dim``: channel width of descriptors, encoder output and GNN.
    - ``output_dim``: middle channel width of the MLP heads.
    - ``n_class``: last channel width of the segmentation head (presumably the
      number of classes).
    - ``keypoint_encoder``: layer widths forwarded to ``KeypointEncoder``.
    - ``n_layers``: number of layers in the GNN.
    - ``ac_fn`` / ``norm_fn``: activation / normalization names forwarded to the
      project layers (accepted values are defined in ``nets.layers``).
    - ``with_score``: if True, ``data['scores']`` is fed to the keypoint encoder,
      which is then built with ``input_dim=3`` instead of 2.
    - ``with_cls``: stored on the instance but not used by this class.
    - ``with_sc``: if True, build and run the extra 3-channel ``sc`` head.
    """

    default_config = {
        'descriptor_dim': 256,
        'output_dim': 1024,
        'n_class': 512,
        'keypoint_encoder': [32, 64, 128, 256],
        'n_layers': 9,
        'ac_fn': 'relu',
        'norm_fn': 'in',
        'with_score': False,
        'with_cls': False,
        'with_sc': False,
    }

    def __init__(self, config=None):
        """Build the sub-networks from ``default_config`` overridden by ``config``.

        Args:
            config: optional dict of overrides for ``default_config``. It is not
                mutated. ``None`` (the default) means "use the defaults".
        """
        super().__init__()
        # A None default avoids the shared-mutable-default-argument pitfall of
        # the former ``config={}``. User values override the class defaults.
        self.config = {**self.default_config, **(config or {})}
        self.with_cls = self.config['with_cls']  # stored only; unused in this class
        self.with_sc = self.config['with_sc']
        self.with_score = self.config['with_score']
        self.n_layers = self.config['n_layers']

        self.gnn = SegGNN(
            feature_dim=self.config['descriptor_dim'],
            n_layers=self.n_layers,
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        # Input is the 2 keypoint coordinates, plus the score when with_score.
        self.kenc = KeypointEncoder(
            input_dim=3 if self.with_score else 2,
            feature_dim=self.config['descriptor_dim'],
            layers=self.config['keypoint_encoder'],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        self.seg = MLP(
            channels=[self.config['descriptor_dim'],
                      self.config['output_dim'],
                      self.config['n_class']],
            ac_fn=self.config['ac_fn'],
            norm_fn=self.config['norm_fn'],
        )

        if self.with_sc:
            # Optional extra head with 3 output channels.
            self.sc = MLP(
                channels=[self.config['descriptor_dim'],
                          self.config['output_dim'],
                          3],
                ac_fn=self.config['ac_fn'],
                norm_fn=self.config['norm_fn'],
            )

    def preprocess(self, data):
        """Build the GNN inputs from a batch dict.

        Args:
            data: dict containing:

                - ``seg_descriptors``: descriptors, transposed here from
                  [B, N, D] to [B, D, N].
                - ``norm_keypoints``: pre-normalized keypoints, used when present.
                  Otherwise ``keypoints`` are normalized with
                  ``normalize_keypoints`` using ``data['image'].shape``.
                - ``scores``: keypoint scores, read only when ``with_score``.

        Returns:
            ``(desc, enc)``: the transposed descriptors and the keypoint encoding.

        Raises:
            ValueError: if ``data`` has neither ``norm_keypoints`` nor ``image``.
        """
        desc = data['seg_descriptors'].transpose(1, 2)  # [B, N, D] -> [B, D, N]

        if 'norm_keypoints' in data:
            norm_kpts = data['norm_keypoints']
        elif 'image' in data:
            norm_kpts = normalize_keypoints(data['keypoints'], data['image'].shape)
        else:
            raise ValueError('Require image shape for keypoint coordinate normalization')

        # Keypoint MLP encoder; scores are only passed when the encoder was
        # built to accept them (input_dim=3).
        scores = data['scores'] if self.with_score else None
        enc = self.kenc(norm_kpts, scores)
        return desc, enc

    def forward(self, data):
        """Run the network on a batch dict (see ``preprocess`` for its keys).

        Returns:
            dict with the following keys:

            - ``prediction``: segmentation-head output transposed from
              [B, C, N] to [B, N, C] and made contiguous.
            - ``sc`` (only when ``with_sc``): the sc head output exactly as
              returned by the MLP. Unlike ``prediction`` it is NOT transposed.
        """
        desc, enc = self.preprocess(data=data)
        desc = self.gnn(desc + enc)

        cls_output = self.seg(desc)  # [B, C, N]
        output = {
            'prediction': cls_output.transpose(-1, -2).contiguous(),
        }

        if self.with_sc:
            output['sc'] = self.sc(desc)

        return output