from loguru import logger

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear_attention import LinearAttention, FullAttention


class LoFTREncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 attention='linear'):
        super(LoFTREncoderLayer, self).__init__()

        self.dim = d_model // nhead
        self.nhead = nhead

        # multi-head attention
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.GELU(),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        # normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        """
        bs = x.shape[0]
        query, key, value = x, source, source

        # multi-head attention
        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)      # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
        message = self.norm1(message)

        # feed-forward network
        message = self.mlp(torch.cat([x, message], dim=2))
        message = self.norm2(message)

        return x + message
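
# Usage sketch (illustrative only, not part of the original module; d_model=256
# and nhead=8 are assumptions, not documented settings):
#
#   layer = LoFTREncoderLayer(d_model=256, nhead=8, attention='linear')
#   x = torch.randn(2, 1200, 256)       # [N, L, C] query features
#   source = torch.randn(2, 1500, 256)  # [N, S, C] key/value features
#   out = layer(x, source)              # [N, L, C], residual update of x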


class TopicFormer(nn.Module):
    """A feature transformer that augments LoFTR-style attention with a set of
    latent topics (learnable seed tokens) and topic-wise feature refinement."""

    def __init__(self, config):
        super(TopicFormer, self).__init__()

        self.config = config
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])

        # extra self/cross-attention layers applied inside each sampled topic
        self.topic_transformers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2*config['n_topic_transformers'])]) if config['n_samples'] > 0 else None
        self.n_iter_topic_transformer = config['n_topic_transformers']

        # learnable topic (seed) embeddings; assigning an nn.Parameter attribute
        # registers it automatically, so no explicit register_parameter is needed
        self.seed_tokens = nn.Parameter(torch.randn(config['n_topics'], config['d_model']))
        self.n_samples = config['n_samples']

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def sample_topic(self, prob_topics, topics, L):
        """Sample a subset of topics shared by both views.

        Args:
            prob_topics (torch.Tensor): [N, L+S, K] soft topic assignments
            topics (torch.Tensor): [N, L+S, K] one-hot (argmax) topic assignments
            L (int): number of features in the first view
        Returns:
            (torch.Tensor, torch.Tensor): [N, L, n_samples] and [N, S, n_samples]
            binary masks of the sampled topics, or None if n_samples == 0
        """
        if self.n_samples == 0:
            return None

        prob_topics0, prob_topics1 = prob_topics[:, :L], prob_topics[:, L:]
        topics0, topics1 = topics[:, :L], topics[:, L:]

        # per-view topic distributions; their product favors topics present in both views
        theta0 = F.normalize(prob_topics0.sum(dim=1), p=1, dim=-1)
        theta1 = F.normalize(prob_topics1.sum(dim=1), p=1, dim=-1)
        theta = F.normalize(theta0 * theta1, p=1, dim=-1)

        if self.training:
            sampled_inds = torch.multinomial(theta, self.n_samples)
        else:
            _, sampled_inds = torch.topk(theta, self.n_samples, dim=-1)
        sampled_topics0 = torch.gather(topics0, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics0.shape[1], 1))
        sampled_topics1 = torch.gather(topics1, dim=-1, index=sampled_inds.unsqueeze(1).repeat(1, topics1.shape[1], 1))
        return sampled_topics0, sampled_topics1
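
    # Example (illustrative only): with K=100 topics and n_samples=6, sample_topic
    # returns two binary masks of shape [N, L, 6] and [N, S, 6]; mask[..., k] == 1
    # marks the features assigned to the k-th sampled topic.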

    def reduce_feat(self, feat, topick, N, C):
        """Pack the features assigned to one topic into a padded [N, max_len, C]
        tensor with a validity mask, so that attention can run per topic."""
        len_topic = topick.sum(dim=-1).int()
        max_len = len_topic.max().item()
        selected_ids = topick.bool()
        resized_feat = torch.zeros((N, max_len, C), dtype=torch.float32, device=feat.device)
        new_mask = torch.zeros_like(resized_feat[..., 0]).bool()
        for i in range(N):
            new_mask[i, :len_topic[i]] = True
        resized_feat[new_mask, :] = feat[selected_ids, :]
        return resized_feat, new_mask, selected_ids
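
    # Example (illustrative only): if topick marks 3 features in image 0 of the
    # batch and 5 in image 1, reduce_feat returns a [N, 5, C] tensor where row i
    # holds that image's selected features left-aligned, plus a [N, 5] mask
    # (True for valid slots) and the original boolean selection indices.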

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        """
        assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"
        N, L, S, C, K = feat0.shape[0], feat0.shape[1], feat1.shape[1], feat0.shape[2], self.config['n_topics']

        seeds = self.seed_tokens.unsqueeze(0).repeat(N, 1, 1)

        feat = torch.cat((feat0, feat1), dim=1)
        if mask0 is not None:
            mask = torch.cat((mask0, mask1), dim=-1)
        else:
            mask = None

        # interleaved topic-update ('seed') and feature-update ('feat') layers
        for layer, name in zip(self.layers, self.layer_names):
            if name == 'seed':
                # topic embeddings attend to all features of both views
                seeds = layer(seeds, feat, None, mask)
            elif name == 'feat':
                # features attend to the topic embeddings
                feat0 = layer(feat0, seeds, mask0, None)
                feat1 = layer(feat1, seeds, mask1, None)
                # keep the concatenated view in sync for subsequent 'seed'
                # layers and for the topic assignment below
                feat = torch.cat((feat0, feat1), dim=1)

        # soft and hard (argmax one-hot) topic assignments per feature
        dmatrix = torch.einsum("nmd,nkd->nmk", feat, seeds)
        prob_topics = F.softmax(dmatrix, dim=-1)
        feat_topics = torch.zeros_like(dmatrix).scatter_(-1, torch.argmax(dmatrix, dim=-1, keepdim=True), 1.0)

        if mask is not None:
            feat_topics = feat_topics * mask.unsqueeze(-1)
            prob_topics = prob_topics * mask.unsqueeze(-1)

        # warn if only a handful of topics receive a meaningful share of features
        if (feat_topics.detach().sum(dim=1).sum(dim=0) > 100).sum() <= 3:
            logger.warning("topic distribution is highly sparse!")

        sampled_topics = self.sample_topic(prob_topics.detach(), feat_topics, L)
        if sampled_topics is not None:
            updated_feat0, updated_feat1 = torch.zeros_like(feat0), torch.zeros_like(feat1)
            s_topics0, s_topics1 = sampled_topics
            # refine features within each sampled topic via self- and cross-attention
            for k in range(s_topics0.shape[-1]):
                topick0, topick1 = s_topics0[..., k], s_topics1[..., k]
                if (topick0.sum() > 0) and (topick1.sum() > 0):
                    new_feat0, new_mask0, selected_ids0 = self.reduce_feat(feat0, topick0, N, C)
                    new_feat1, new_mask1, selected_ids1 = self.reduce_feat(feat1, topick1, N, C)
                    for idt in range(self.n_iter_topic_transformer):
                        new_feat0 = self.topic_transformers[idt*2](new_feat0, new_feat0, new_mask0, new_mask0)
                        new_feat1 = self.topic_transformers[idt*2](new_feat1, new_feat1, new_mask1, new_mask1)
                        new_feat0 = self.topic_transformers[idt*2+1](new_feat0, new_feat1, new_mask0, new_mask1)
                        new_feat1 = self.topic_transformers[idt*2+1](new_feat1, new_feat0, new_mask1, new_mask0)
                    updated_feat0[selected_ids0, :] = new_feat0[new_mask0, :]
                    updated_feat1[selected_ids1, :] = new_feat1[new_mask1, :]

            # features outside the sampled topics keep their pre-refinement values
            feat0 = (1 - s_topics0.sum(dim=-1, keepdim=True)) * feat0 + updated_feat0
            feat1 = (1 - s_topics1.sum(dim=-1, keepdim=True)) * feat1 + updated_feat1

        conf_matrix = torch.einsum("nlc,nsc->nls", feat0, feat1) / C**.5
        if self.training:
            topic_matrix = torch.einsum("nlk,nsk->nls", prob_topics[:, :L], prob_topics[:, L:])
            outlier_mask = torch.einsum("nlk,nsk->nls", feat_topics[:, :L], feat_topics[:, L:])
        else:
            topic_matrix = {"img0": feat_topics[:, :L], "img1": feat_topics[:, L:]}
            outlier_mask = torch.ones_like(conf_matrix)
        if mask0 is not None:
            outlier_mask = outlier_mask * mask0[..., None] * mask1[:, None]
        # suppress pairs whose features fall in different topics, then dual-softmax
        conf_matrix.masked_fill_(~outlier_mask.bool(), -1e9)
        conf_matrix = F.softmax(conf_matrix, 1) * F.softmax(conf_matrix, 2)

        return feat0, feat1, conf_matrix, topic_matrix
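
# Usage sketch (illustrative only; config keys follow this file, but the values
# are assumptions, not the authors' settings):
#
#   conf = {'d_model': 256, 'nhead': 8, 'attention': 'linear',
#           'layer_names': ['seed', 'feat'] * 2, 'n_topics': 100,
#           'n_samples': 8, 'n_topic_transformers': 2}
#   model = TopicFormer(conf)
#   feat0, feat1 = torch.randn(2, 1200, 256), torch.randn(2, 1500, 256)
#   feat0, feat1, conf_matrix, topic_matrix = model(feat0, feat1)
#   # conf_matrix: [2, 1200, 1500] dual-softmax matching scores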


class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, config):
        super(LocalFeatureTransformer, self).__init__()

        self.config = config
        self.d_model = config['d_model']
        self.nhead = config['nhead']
        self.layer_names = config['layer_names']
        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
        # two sequential cross-attention layers
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(2)])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        """
        assert self.d_model == feat0.shape[2], "the feature number of src and transformer must be equal"

        feat0 = self.layers[0](feat0, feat1, mask0, mask1)
        feat1 = self.layers[1](feat1, feat0, mask1, mask0)

        return feat0, feat1
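
# Usage sketch (illustrative only; d_model=128 and the window shapes are
# assumptions, not documented settings):
#
#   conf = {'d_model': 128, 'nhead': 8, 'attention': 'linear',
#           'layer_names': ['cross', 'cross']}
#   ft = LocalFeatureTransformer(conf)
#   f0, f1 = torch.randn(64, 25, 128), torch.randn(64, 25, 128)
#   f0, f1 = ft(f0, f1)  # mutually refined features, same shapes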