import copy
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange
from loguru import logger

from .linear_attention import Attention, crop_feature, pad_feature
from ..utils.position_encoding import RoPEPositionEncodingSine


class AG_RoPE_EncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 agg_size0=4,
                 agg_size1=4,
                 no_flash=False,
                 rope=False,
                 npe=None,
                 fp32=False,
                 ):
        super(AG_RoPE_EncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead
        self.agg_size0, self.agg_size1 = agg_size0, agg_size1
        self.rope = rope

        # aggregate and position encoding
        self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=agg_size0, padding=0,
                                   stride=agg_size0, bias=False, groups=d_model) if self.agg_size0 != 1 else nn.Identity()
        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.agg_size1, stride=self.agg_size1) if self.agg_size1 != 1 else nn.Identity()
        self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = Attention(no_flash, self.nhead, self.dim, fp32)
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(d_model * 2, d_model, bias=False),
        )

        # norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H0, W0]
            source (torch.Tensor): [N, C, H1, W1]
            x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
            source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
        """
        bs, C, H0, W0 = x.size()
        H1, W1 = source.size(-2), source.size(-1)

        # Aggregate feature
        assert x_mask is None and source_mask is None
        query, source = self.norm1(self.aggregate(x).permute(0, 2, 3, 1)), self.norm1(self.max_pool(source).permute(0, 2, 3, 1))  # [N, H, W, C]
        if x_mask is not None:
            x_mask, source_mask = map(lambda x: self.max_pool(x.float()).bool(), [x_mask, source_mask])

        query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

        # Positional encoding
        if self.rope:
            query = self.rope_pos_enc(query)
            key = self.rope_pos_enc(key)

        # multi-head attention handles the padding masks
        m = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)
        m = self.merge(m.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]

        # Upsample feature
        m = rearrange(m, 'b (h w) c -> b c h w', h=H0 // self.agg_size0, w=W0 // self.agg_size0)  # [N, C, H0/agg, W0/agg]
        if self.agg_size0 != 1:
            m = torch.nn.functional.interpolate(m, scale_factor=self.agg_size0, mode='bilinear', align_corners=False)  # [N, C, H0, W0]

        # feed-forward network
        m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1))  # [N, H0, W0, C]
        m = self.norm2(m).permute(0, 3, 1, 2)  # [N, C, H0, W0]

        return x + m
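

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model): the shape bookkeeping
# of the aggregate -> attention -> upsample path in AG_RoPE_EncoderLayer,
# using only plain torch. The sizes below (C, agg, H, W) are assumed toy
# values, not the project's configuration, and this function is never called
# by the model code.
# ---------------------------------------------------------------------------
def _agg_shape_sketch(C=256, agg=4, H=64, W=64):
    x = torch.randn(1, C, H, W)                                    # [N, C, H, W]
    aggregate = nn.Conv2d(C, C, kernel_size=agg, stride=agg,
                          padding=0, bias=False, groups=C)         # depthwise, like self.aggregate
    tokens = aggregate(x).permute(0, 2, 3, 1)                      # [N, H/agg, W/agg, C]
    m = tokens.reshape(1, -1, C)                                   # [N, L, C], L = H*W / agg**2 tokens fed to attention
    m = rearrange(m, 'b (h w) c -> b c h w', h=H // agg, w=W // agg)
    m = F.interpolate(m, scale_factor=agg, mode='bilinear', align_corners=False)
    assert m.shape == x.shape                                      # restored to [N, C, H, W] before the MLP fusion
    return m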

class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, config):
        super(LocalFeatureTransformer, self).__init__()

        self.full_config = config
        self.fp32 = not (config['mp'] or config['half'])
        config = config['coarse']
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        self.agg_size0, self.agg_size1 = config['agg_size0'], config['agg_size1']
        self.rope = config['rope']

        self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'],
                                          config['agg_size0'], config['agg_size1'],
                                          config['no_flash'], config['rope'],
                                          config['npe'], self.fp32)
        cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'],
                                           config['agg_size0'], config['agg_size1'],
                                           config['no_flash'], False,
                                           config['npe'], self.fp32)
        self.layers = nn.ModuleList(
            [copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer) for _ in self.layer_names])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
        """
        Args:
            feat0 (torch.Tensor): [N, C, H, W]
            feat1 (torch.Tensor): [N, C, H, W]
            mask0 (torch.Tensor): [N, H, W] (optional)
            mask1 (torch.Tensor): [N, H, W] (optional)
        """
        H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
        bs = feat0.shape[0]

        # For a single pair, crop away the zero-padded border indicated by the masks
        # (rounded down to a multiple of the aggregation size) so attention only runs
        # on valid regions; the features are zero-padded back to full size afterwards.
        feature_cropped = False
        if bs == 1 and mask0 is not None and mask1 is not None:
            mask_H0, mask_W0, mask_H1, mask_W1 = mask0.size(-2), mask0.size(-1), mask1.size(-2), mask1.size(-1)
            mask_h0, mask_w0, mask_h1, mask_w1 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0], mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]
            mask_h0, mask_w0 = mask_h0 // self.agg_size0 * self.agg_size0, mask_w0 // self.agg_size0 * self.agg_size0
            mask_h1, mask_w1 = mask_h1 // self.agg_size1 * self.agg_size1, mask_w1 // self.agg_size1 * self.agg_size1

            feat0 = feat0[:, :, :mask_h0, :mask_w0]
            feat1 = feat1[:, :, :mask_h1, :mask_w1]

            feature_cropped = True

        for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
            if feature_cropped:
                mask0, mask1 = None, None
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError

        if feature_cropped:
            # pad features back to their original spatial size
            bs, c, mask_h0, mask_w0 = feat0.size()
            if mask_h0 != mask_H0:
                feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0 - mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
            elif mask_w0 != mask_W0:
                feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0 - mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)

            bs, c, mask_h1, mask_w1 = feat1.size()
            if mask_h1 != mask_H1:
                feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1 - mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
            elif mask_w1 != mask_W1:
                feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1 - mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)

        return feat0, feat1
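

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not the repo's official entry point).
# The config keys mirror the ones read in LocalFeatureTransformer.__init__;
# the concrete values (d_model=256, nhead=8, the layer pattern, npe=None,
# 60x60 inputs) are placeholders, not the project's defaults. The relative
# imports at the top require running this inside the package (python -m ...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _config = {
        'mp': False, 'half': False,                  # -> self.fp32 = True
        'coarse': {
            'd_model': 256, 'nhead': 8,              # assumed sizes
            'layer_names': ['self', 'cross'] * 2,    # interleaved self-/cross-attention
            'agg_size0': 4, 'agg_size1': 4,
            'no_flash': True, 'rope': True,
            'npe': None,                             # assumed; real configs may set train/test resolutions here
        },
    }
    model = LocalFeatureTransformer(_config).eval()
    feat0 = torch.randn(1, 256, 60, 60)              # coarse feature maps; H, W divisible by the agg sizes
    feat1 = torch.randn(1, 256, 60, 60)
    with torch.no_grad():
        out0, out1 = model(feat0, feat1)
    logger.info(f"out0: {tuple(out0.shape)}, out1: {tuple(out1.shape)}")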