Realcat
fix: eloftr
63f3cf2
raw
history blame
7.13 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> segnetvit
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:52
=================================================='''
import torch
from torch import nn
import torch.nn.functional as F
from nets.utils import normalize_keypoints
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of the last dimension.

    The last dimension is read as consecutive ``(a, b)`` pairs and every pair
    is replaced by ``(-b, a)``. The result has the same shape as ``x``; the
    last dimension must have even length for the pairing to work.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (pair, 2) axes back into the original feature axis.
    return rotated.flatten(start_dim=-2)
def apply_cached_rotary_emb(
        freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Apply a precomputed rotary embedding to ``t``.

    ``freqs[0]`` scales ``t`` itself and ``freqs[1]`` scales the pair-rotated
    copy of ``t`` produced by ``rotate_half``; the two products are summed.
    """
    direct = t * freqs[0]
    rotated = rotate_half(t) * freqs[1]
    return direct + rotated
class LearnableFourierPositionalEncoding(nn.Module):
    """Learnable Fourier-feature encoding of position vectors.

    A bias-free linear layer projects ``M``-dimensional positions to
    ``F_dim // 2`` values; ``forward`` returns the cosines and sines of that
    projection.
    """

    def __init__(self, M: int, dim: int, F_dim: int = None,
                 gamma: float = 1.0) -> None:
        """Build the projection layer.

        Args:
            M: size of the input position vectors.
            dim: used as ``F_dim`` when ``F_dim`` is ``None``.
            F_dim: the projection has ``F_dim // 2`` outputs.
            gamma: the projection weights are re-initialised from a normal
                distribution with mean 0 and std ``gamma ** -2``.
        """
        super().__init__()
        if F_dim is None:
            F_dim = dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        init_std = self.gamma ** -2
        nn.init.normal_(self.Wr.weight.data, mean=0, std=init_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode position vectors.

        Returns a tensor whose leading axis has size 2 (index 0: cosines,
        index 1: sines), with a singleton axis inserted at position -3 and
        every value of the last axis repeated twice in place.
        """
        angles = self.Wr(x)
        stacked = torch.stack([torch.cos(angles), torch.sin(angles)], 0)
        return stacked.unsqueeze(-3).repeat_interleave(2, dim=-1)
class KeypointEncoder(nn.Module):
    """MLP that lifts keypoint coordinates (optionally with scores) to 256-D.

    The network is Linear -> LayerNorm -> GELU repeated over widths
    32, 64 and 128, followed by a final Linear to 256.
    """

    def __init__(self, with_score: bool = False):
        """Build the MLP.

        Args:
            with_score: when True the first layer accepts 3 input channels
                (2 coordinates + 1 score) so that ``forward`` can be called
                with ``scores``. The default (False) keeps the original
                2-channel layout, so existing checkpoints still load.
        """
        super().__init__()
        # Previously the input width was fixed at 2, which made the
        # ``scores`` branch of ``forward`` fail with a shape mismatch.
        in_dim = 3 if with_score else 2
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, 32),
            nn.LayerNorm(32, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(32, 64),
            nn.LayerNorm(64, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(64, 128),
            nn.LayerNorm(128, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(128, 256),
        )

    def forward(self, kpts, scores=None):
        """Encode keypoints.

        Args:
            kpts: keypoint coordinates, last dimension of size 2
                (e.g. [B, N, 2]).
            scores: optional per-keypoint scores with one dimension fewer
                than ``kpts`` (e.g. [B, N]). Requires an encoder built with
                ``with_score=True``; otherwise the first linear layer raises
                a shape error, as it did before.

        Returns:
            Tensor with the leading dimensions of ``kpts`` and a last
            dimension of 256.
        """
        if scores is None:
            return self.encoder(kpts)
        # [B, N, 2] + [B, N, 1]; unsqueeze(-1) equals unsqueeze(2) for
        # batched scores and also works for unbatched input.
        inputs = [kpts, scores.unsqueeze(-1)]
        return self.encoder(torch.cat(inputs, dim=-1))
class Attention(nn.Module):
    """Parameter-free scaled dot-product attention."""

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        """Return the attention-weighted sum of ``v``.

        Query/key dot products are scaled by ``q.shape[-1] ** -0.5``,
        normalised with a softmax over the key axis, and used to mix ``v``.
        Leading (batch/head) dimensions are passed through via ``...``.
        """
        scale = q.shape[-1] ** -0.5
        logits = torch.einsum('...id,...jd->...ij', q, k) * scale
        weights = F.softmax(logits, -1)
        return torch.einsum('...ij,...jd->...id', weights, v)
class SelfMultiHeadAttention(nn.Module):
    """Multi-head self-attention block with an MLP and a residual connection.

    The input is projected to ``hidden_dim``-wide queries, keys and values,
    split into ``num_heads`` heads, attended, projected, concatenated with
    the input and passed through an MLP whose output is added to the input.
    """

    def __init__(self, feat_dim: int, hidden_dim: int, num_heads: int):
        """Build the projections and the MLP.

        Args:
            feat_dim: width of the input and output features.
            hidden_dim: width of each of the q/k/v projections and of the
                attention message.
            num_heads: number of attention heads; must divide both
                ``feat_dim`` and ``hidden_dim``.
        """
        super().__init__()
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        assert feat_dim % num_heads == 0
        # The heads are carved out of the hidden_dim-wide q/k/v projections,
        # so it is hidden_dim that has to split evenly across heads.
        assert hidden_dim % num_heads == 0, \
            'hidden_dim must be divisible by num_heads'
        # Per-head width actually produced by the unflatten in forward().
        self.head_dim = hidden_dim // num_heads
        self.qkv = nn.Linear(feat_dim, hidden_dim * 3)
        self.attn = Attention()
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(feat_dim + hidden_dim, feat_dim * 2),
            nn.LayerNorm(feat_dim * 2, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(feat_dim * 2, feat_dim)
        )

    def forward(self, x, encoding=None):
        """Run one attention + MLP step.

        Args:
            x: input features; the ``transpose(1, 2)`` below swaps the token
                and head axes, which presumes a [B, N, feat_dim] layout.
            encoding: optional rotary encoding applied to queries and keys
                via ``apply_cached_rotary_emb``; skipped when ``None``.

        Returns:
            ``x`` plus the MLP output, same shape as ``x``.
        """
        qkv = self.qkv(x)
        # [B, N, 3 * hidden] -> [B, N, H, D, 3] -> [B, H, N, D, 3]
        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        if encoding is not None:
            q = apply_cached_rotary_emb(encoding, q)
            k = apply_cached_rotary_emb(encoding, k)
        attn = self.attn(q, k, v)
        # Bring heads back next to their channels and merge them: [B, N, H * D].
        message = self.proj(attn.transpose(1, 2).flatten(start_dim=-2))
        return x + self.mlp(torch.cat([x, message], -1))
class SegGNNViT(nn.Module):
    """Stack of ``n_layers`` ``SelfMultiHeadAttention`` blocks run in sequence."""

    def __init__(self, feature_dim: int, n_layers: int, hidden_dim: int = 256, num_heads: int = 4, **kwargs):
        """Create the layer stack; extra keyword arguments are accepted and ignored."""
        super(SegGNNViT, self).__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                SelfMultiHeadAttention(feat_dim=feature_dim, hidden_dim=hidden_dim, num_heads=num_heads)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, desc, encoding=None):
        """Pass ``desc`` through every layer, handing ``encoding`` to each one."""
        out = desc
        for block in self.layers:
            # No extra residual here: each attention block already adds its
            # input to its output.
            out = block(out, encoding)
        return out
class SegNetViT(nn.Module):
    """Per-keypoint classifier built on a stack of self-attention layers.

    Descriptors are projected to ``hidden_dim``, refined by ``SegGNNViT``
    using an encoding of the normalised keypoint positions, and mapped to
    ``n_class`` outputs per keypoint. With ``with_sc`` enabled a second head
    with 3 outputs per keypoint is added.
    """

    # Defaults merged with the user config in __init__; user values win.
    default_config = {
        'descriptor_dim': 256,
        'output_dim': 1024,
        'n_class': 512,
        'keypoint_encoder': [32, 64, 128, 256],
        'n_layers': 15,
        'num_heads': 4,
        'hidden_dim': 256,
        'with_score': False,
        'with_global': False,
        'with_cls': False,
        'with_sc': False,
    }

    def __init__(self, config=None):
        """Build the network.

        Args:
            config: optional dict overriding entries of ``default_config``.
                ``None`` (or an empty dict) uses the defaults unchanged.
        """
        super(SegNetViT, self).__init__()
        # ``None`` default instead of a shared mutable ``{}``.
        self.config = {**self.default_config, **(config or {})}
        cfg = self.config

        self.with_cls = cfg['with_cls']
        self.with_sc = cfg['with_sc']
        self.n_layers = cfg['n_layers']

        self.gnn = SegGNNViT(
            feature_dim=cfg['hidden_dim'],
            n_layers=cfg['n_layers'],
            hidden_dim=cfg['hidden_dim'],
            num_heads=cfg['num_heads'],
        )

        self.with_score = cfg['with_score']
        # The positional encoding is sized per attention head.
        head_dim = cfg['hidden_dim'] // cfg['num_heads']
        self.kenc = LearnableFourierPositionalEncoding(2, head_dim, head_dim)

        self.input_proj = nn.Linear(in_features=cfg['descriptor_dim'],
                                    out_features=cfg['hidden_dim'])

        self.seg = nn.Sequential(
            nn.Linear(in_features=cfg['hidden_dim'], out_features=cfg['output_dim']),
            nn.LayerNorm(cfg['output_dim'], elementwise_affine=True),
            nn.GELU(),
            nn.Linear(cfg['output_dim'], cfg['n_class'])
        )

        if self.with_sc:
            # Read from the merged config: the raw ``config`` argument may
            # not contain 'hidden_dim', which used to raise KeyError here.
            self.sc = nn.Sequential(
                nn.Linear(in_features=cfg['hidden_dim'], out_features=cfg['output_dim']),
                nn.LayerNorm(cfg['output_dim'], elementwise_affine=True),
                nn.GELU(),
                nn.Linear(cfg['output_dim'], 3)
            )

    def preprocess(self, data):
        """Extract descriptors and the positional encoding from ``data``.

        Args:
            data: mapping with 'seg_descriptors' and either 'norm_keypoints'
                (used as-is) or both 'keypoints' and 'image' (keypoints are
                normalised with ``normalize_keypoints`` and the image shape).

        Returns:
            Tuple ``(descriptors, encoding)``.

        Raises:
            ValueError: if neither 'norm_keypoints' nor 'image' is present.
        """
        desc0 = data['seg_descriptors']

        if 'norm_keypoints' in data:
            norm_kpts0 = data['norm_keypoints']
        elif 'image' in data:
            kpts0 = data['keypoints']
            norm_kpts0 = normalize_keypoints(kpts0, data['image'].shape)
        else:
            raise ValueError('Require image shape for keypoint coordinate normalization')

        enc0 = self.kenc(norm_kpts0)
        return desc0, enc0

    def forward(self, data):
        """Run the network on one batch.

        Returns:
            Dict with 'prediction' (output of the segmentation head) and,
            when ``with_sc`` is enabled, 'sc' (output of the 3-channel head).
        """
        desc, enc = self.preprocess(data=data)
        desc = self.input_proj(desc)
        desc = self.gnn(desc, enc)

        seg_output = self.seg(desc)  # [B, N, C]
        output = {
            'prediction': seg_output,
        }

        if self.with_sc:
            output['sc'] = self.sc(desc)

        return output