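"""Coarse-level local feature transformer: self-/cross-attention encoder layers
that aggregate (downsample) features before attention and optionally apply RoPE
positional encoding."""
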
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear_attention import Attention, crop_feature, pad_feature
from einops.einops import rearrange
from collections import OrderedDict
from ..utils.position_encoding import RoPEPositionEncodingSine
import numpy as np
from loguru import logger


class AG_RoPE_EncoderLayer(nn.Module):
    """Encoder layer that aggregates (downsamples) tokens before multi-head
    attention, optionally applies RoPE positional encoding, and upsamples back."""

    def __init__(self,
                 d_model,
                 nhead,
                 agg_size0=4,
                 agg_size1=4,
                 no_flash=False,
                 rope=False,
                 npe=None,
                 fp32=False,
                 ):
        super(AG_RoPE_EncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead
        self.agg_size0, self.agg_size1 = agg_size0, agg_size1
        self.rope = rope

        # feature aggregation and position encoding
        self.aggregate = nn.Conv2d(d_model, d_model, kernel_size=agg_size0, padding=0,
                                   stride=agg_size0, bias=False, groups=d_model) if self.agg_size0 != 1 else nn.Identity()
        self.max_pool = torch.nn.MaxPool2d(kernel_size=self.agg_size1, stride=self.agg_size1) if self.agg_size1 != 1 else nn.Identity()
        self.rope_pos_enc = RoPEPositionEncodingSine(d_model, max_shape=(256, 256), npe=npe, ropefp16=True)

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = Attention(no_flash, self.nhead, self.dim, fp32)
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2, bias=False),
            nn.LeakyReLU(inplace=True),
            nn.Linear(d_model * 2, d_model, bias=False),
        )

        # norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, C, H0, W0]
            source (torch.Tensor): [N, C, H1, W1]
            x_mask (torch.Tensor): [N, H0, W0] (optional) (L = H0*W0)
            source_mask (torch.Tensor): [N, H1, W1] (optional) (S = H1*W1)
        """
        bs, C, H0, W0 = x.size()
        H1, W1 = source.size(-2), source.size(-1)

        # Aggregate features (padding masks are expected to be handled upstream by cropping)
        assert x_mask is None and source_mask is None
        query, source = self.norm1(self.aggregate(x).permute(0, 2, 3, 1)), \
                        self.norm1(self.max_pool(source).permute(0, 2, 3, 1))  # [N, H, W, C]
        if x_mask is not None:
            x_mask, source_mask = map(lambda x: self.max_pool(x.float()).bool(), [x_mask, source_mask])
        query, key, value = self.q_proj(query), self.k_proj(source), self.v_proj(source)

        # Positional encoding
        if self.rope:
            query = self.rope_pos_enc(query)
            key = self.rope_pos_enc(key)

        # Multi-head attention (handles the padding mask if one is provided)
        m = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)
        m = self.merge(m.reshape(bs, -1, self.nhead * self.dim))  # [N, L, C]

        # Upsample feature
        m = rearrange(m, 'b (h w) c -> b c h w', h=H0 // self.agg_size0, w=W0 // self.agg_size0)  # [N, C, H0//agg_size0, W0//agg_size0]
        if self.agg_size0 != 1:
            m = torch.nn.functional.interpolate(m, scale_factor=self.agg_size0, mode='bilinear', align_corners=False)  # [N, C, H0, W0]

        # Feed-forward network
        m = self.mlp(torch.cat([x, m], dim=1).permute(0, 2, 3, 1))  # [N, H0, W0, C]
        m = self.norm2(m).permute(0, 3, 1, 2)  # [N, C, H0, W0]

        return x + m
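
# Shape walk-through for AG_RoPE_EncoderLayer.forward, assuming d_model=256 and
# agg_size0 = agg_size1 = 4 (example numbers only, not fixed by this file):
#   x:      [1, 256, 64, 64] --aggregate (stride-4 depthwise conv)--> [1, 16, 16, 256]
#   source: [1, 256, 64, 64] --max_pool (stride 4)------------------> [1, 16, 16, 256]
#   attention runs over 16*16 = 256 query tokens and 256 key/value tokens
#   m: [1, 256, 256] --rearrange--> [1, 256, 16, 16] --bilinear x4--> [1, 256, 64, 64]
#   output: x + residual from the MLP                             --> [1, 256, 64, 64]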


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, config):
        super(LocalFeatureTransformer, self).__init__()

        self.full_config = config
        self.fp32 = not (config['mp'] or config['half'])
        config = config['coarse']
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        self.agg_size0, self.agg_size1 = config['agg_size0'], config['agg_size1']
        self.rope = config['rope']

        self_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'],
                                          config['no_flash'], config['rope'], config['npe'], self.fp32)
        cross_layer = AG_RoPE_EncoderLayer(config['d_model'], config['nhead'], config['agg_size0'], config['agg_size1'],
                                           config['no_flash'], False, config['npe'], self.fp32)
        self.layers = nn.ModuleList([copy.deepcopy(self_layer) if _ == 'self' else copy.deepcopy(cross_layer)
                                     for _ in self.layer_names])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None, data=None):
        """
        Args:
            feat0 (torch.Tensor): [N, C, H, W]
            feat1 (torch.Tensor): [N, C, H, W]
            mask0 (torch.Tensor): [N, H, W] (optional)
            mask1 (torch.Tensor): [N, H, W] (optional)
        """
        H0, W0, H1, W1 = feat0.size(-2), feat0.size(-1), feat1.size(-2), feat1.size(-1)
        bs = feat0.shape[0]

        # With batch size 1, the padded border can be cropped away up front (instead of
        # masked inside attention); it is padded back with zeros after the layers.
        feature_cropped = False
        if bs == 1 and mask0 is not None and mask1 is not None:
            mask_H0, mask_W0, mask_H1, mask_W1 = mask0.size(-2), mask0.size(-1), mask1.size(-2), mask1.size(-1)
            mask_h0, mask_w0, mask_h1, mask_w1 = mask0[0].sum(-2)[0], mask0[0].sum(-1)[0], mask1[0].sum(-2)[0], mask1[0].sum(-1)[0]
            mask_h0, mask_w0, mask_h1, mask_w1 = mask_h0 // self.agg_size0 * self.agg_size0, mask_w0 // self.agg_size0 * self.agg_size0, \
                                                 mask_h1 // self.agg_size1 * self.agg_size1, mask_w1 // self.agg_size1 * self.agg_size1
            feat0 = feat0[:, :, :mask_h0, :mask_w0]
            feat1 = feat1[:, :, :mask_h1, :mask_w1]
            feature_cropped = True

        for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
            if feature_cropped:
                mask0, mask1 = None, None
            if name == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
            elif name == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                feat1 = layer(feat1, feat0, mask1, mask0)
            else:
                raise KeyError

        if feature_cropped:
            # pad the cropped features back to the original (pre-crop) resolution
            bs, c, mask_h0, mask_w0 = feat0.size()
            if mask_h0 != mask_H0:
                feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0 - mask_h0, mask_W0, device=feat0.device, dtype=feat0.dtype)], dim=-2)
            elif mask_w0 != mask_W0:
                feat0 = torch.cat([feat0, torch.zeros(bs, c, mask_H0, mask_W0 - mask_w0, device=feat0.device, dtype=feat0.dtype)], dim=-1)

            bs, c, mask_h1, mask_w1 = feat1.size()
            if mask_h1 != mask_H1:
                feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1 - mask_h1, mask_W1, device=feat1.device, dtype=feat1.dtype)], dim=-2)
            elif mask_w1 != mask_W1:
                feat1 = torch.cat([feat1, torch.zeros(bs, c, mask_H1, mask_W1 - mask_w1, device=feat1.device, dtype=feat1.dtype)], dim=-1)

        return feat0, feat1
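

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only).  The config keys below mirror the
# ones read in LocalFeatureTransformer.__init__; the concrete values
# (d_model, nhead, layer pattern, npe) are assumptions for illustration, not
# this repository's official configuration.
#
#   config = {
#       'mp': False,
#       'half': False,
#       'coarse': {
#           'd_model': 256,
#           'nhead': 8,
#           'layer_names': ['self', 'cross'] * 4,
#           'agg_size0': 4,
#           'agg_size1': 4,
#           'no_flash': False,
#           'rope': True,
#           'npe': None,  # placeholder; the expected format is defined by RoPEPositionEncodingSine
#       },
#   }
#   matcher = LocalFeatureTransformer(config)
#   feat0 = torch.randn(1, 256, 64, 64)
#   feat1 = torch.randn(1, 256, 64, 64)
#   feat0, feat1 = matcher(feat0, feat1)  # outputs keep the [N, C, H, W] layout
# ---------------------------------------------------------------------------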