import warnings
from typing import Optional
|
|
|
import torch |
|
import torch.utils.checkpoint as checkpoint |
|
from torch import Tensor, nn |
|
|
|
from groundingdino.util.misc import inverse_sigmoid |
|
|
|
from .fuse_modules import BiAttentionBlock |
|
from .ms_deform_attn import MultiScaleDeformableAttention as MSDeformAttn |
|
from .transformer_vanilla import TransformerEncoderLayer |
|
from .utils import ( |
|
MLP, |
|
_get_activation_fn, |
|
_get_clones, |
|
gen_encoder_output_proposals, |
|
gen_sineembed_for_position, |
|
get_sine_pos_embed, |
|
) |
|
|
|
|
|
class Transformer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model=256, |
|
nhead=8, |
|
num_queries=300, |
|
num_encoder_layers=6, |
|
num_unicoder_layers=0, |
|
num_decoder_layers=6, |
|
dim_feedforward=2048, |
|
dropout=0.0, |
|
activation="relu", |
|
normalize_before=False, |
|
return_intermediate_dec=False, |
|
query_dim=4, |
|
num_patterns=0, |
|
|
|
num_feature_levels=1, |
|
enc_n_points=4, |
|
dec_n_points=4, |
|
|
|
learnable_tgt_init=False, |
|
|
|
two_stage_type="no", |
|
embed_init_tgt=False, |
|
|
|
use_text_enhancer=False, |
|
use_fusion_layer=False, |
|
use_checkpoint=False, |
|
use_transformer_ckpt=False, |
|
use_text_cross_attention=False, |
|
text_dropout=0.1, |
|
fusion_dropout=0.1, |
|
fusion_droppath=0.0, |
|
): |
|
super().__init__() |
|
self.num_feature_levels = num_feature_levels |
|
self.num_encoder_layers = num_encoder_layers |
|
self.num_unicoder_layers = num_unicoder_layers |
|
self.num_decoder_layers = num_decoder_layers |
|
self.num_queries = num_queries |
|
assert query_dim == 4 |
|
|
|
|
|
encoder_layer = DeformableTransformerEncoderLayer( |
|
d_model, |
|
dim_feedforward, |
|
dropout, |
|
activation, |
|
num_feature_levels, |
|
nhead, |
|
enc_n_points, |
|
) |
|
|
|
if use_text_enhancer: |
|
text_enhance_layer = TransformerEncoderLayer( |
|
d_model=d_model, |
|
nhead=nhead // 2, |
|
dim_feedforward=dim_feedforward // 2, |
|
dropout=text_dropout, |
|
) |
|
else: |
|
text_enhance_layer = None |
|
|
|
if use_fusion_layer: |
|
feature_fusion_layer = BiAttentionBlock( |
|
v_dim=d_model, |
|
l_dim=d_model, |
|
embed_dim=dim_feedforward // 2, |
|
num_heads=nhead // 2, |
|
dropout=fusion_dropout, |
|
drop_path=fusion_droppath, |
|
) |
|
else: |
|
feature_fusion_layer = None |
|
|
|
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        assert encoder_norm is None, "pre-norm (normalize_before=True) is not supported"
|
self.encoder = TransformerEncoder( |
|
encoder_layer, |
|
num_encoder_layers, |
|
d_model=d_model, |
|
num_queries=num_queries, |
|
text_enhance_layer=text_enhance_layer, |
|
feature_fusion_layer=feature_fusion_layer, |
|
use_checkpoint=use_checkpoint, |
|
use_transformer_ckpt=use_transformer_ckpt, |
|
) |
|
|
|
|
|
decoder_layer = DeformableTransformerDecoderLayer( |
|
d_model, |
|
dim_feedforward, |
|
dropout, |
|
activation, |
|
num_feature_levels, |
|
nhead, |
|
dec_n_points, |
|
use_text_cross_attention=use_text_cross_attention, |
|
) |
|
|
|
decoder_norm = nn.LayerNorm(d_model) |
|
self.decoder = TransformerDecoder( |
|
decoder_layer, |
|
num_decoder_layers, |
|
decoder_norm, |
|
return_intermediate=return_intermediate_dec, |
|
d_model=d_model, |
|
query_dim=query_dim, |
|
num_feature_levels=num_feature_levels, |
|
) |
|
|
|
self.d_model = d_model |
|
self.nhead = nhead |
|
self.dec_layers = num_decoder_layers |
|
self.num_queries = num_queries |
|
        self.num_patterns = num_patterns
        if not isinstance(num_patterns, int):
            warnings.warn(
                "num_patterns should be int but got {}".format(type(num_patterns))
            )
            self.num_patterns = 0
        if self.num_patterns > 0:
            # required by the num_patterns > 0 branch in forward(); mirrors DINO
            self.patterns = nn.Embedding(self.num_patterns, d_model)
|
|
|
if num_feature_levels > 1: |
|
if self.num_encoder_layers > 0: |
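                # allocated uninitialized; filled by nn.init.normal_ in _reset_parameters()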
|
self.level_embed = nn.Parameter( |
|
torch.Tensor(num_feature_levels, d_model) |
|
) |
|
else: |
|
self.level_embed = None |
|
|
|
self.learnable_tgt_init = learnable_tgt_init |
|
        assert learnable_tgt_init, "only learnable_tgt_init=True is supported"
|
self.embed_init_tgt = embed_init_tgt |
|
if (two_stage_type != "no" and embed_init_tgt) or (two_stage_type == "no"): |
|
self.tgt_embed = nn.Embedding(self.num_queries, d_model) |
|
nn.init.normal_(self.tgt_embed.weight.data) |
|
else: |
|
self.tgt_embed = None |
|
|
|
|
|
self.two_stage_type = two_stage_type |
|
        assert two_stage_type in [
            "no",
            "standard",
        ], "unknown two_stage_type: {}".format(two_stage_type)
|
if two_stage_type == "standard": |
|
|
|
self.enc_output = nn.Linear(d_model, d_model) |
|
self.enc_output_norm = nn.LayerNorm(d_model) |
|
self.two_stage_wh_embedding = None |
|
|
|
if two_stage_type == "no": |
|
self.init_ref_points(num_queries) |
|
|
|
self.enc_out_class_embed = None |
|
self.enc_out_bbox_embed = None |
|
|
|
self._reset_parameters() |
|
|
|
def _reset_parameters(self): |
|
for p in self.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
for m in self.modules(): |
|
if isinstance(m, MSDeformAttn): |
|
m._reset_parameters() |
|
if self.num_feature_levels > 1 and self.level_embed is not None: |
|
nn.init.normal_(self.level_embed) |
|
|
|
def get_valid_ratio(self, mask): |
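        # ratio of valid (non-padded) width/height per image -> [bs, 2] as (w, h)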
|
_, H, W = mask.shape |
|
valid_H = torch.sum(~mask[:, :, 0], 1) |
|
valid_W = torch.sum(~mask[:, 0, :], 1) |
|
valid_ratio_h = valid_H.float() / H |
|
valid_ratio_w = valid_W.float() / W |
|
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) |
|
return valid_ratio |
|
|
|
def init_ref_points(self, use_num_queries): |
|
self.refpoint_embed = nn.Embedding(use_num_queries, 4) |
|
|
|
def forward( |
|
self, |
|
srcs, |
|
masks, |
|
refpoint_embed, |
|
pos_embeds, |
|
tgt, |
|
attn_mask=None, |
|
text_dict=None, |
|
): |
|
""" |
|
Input: |
|
- srcs: List of multi features [bs, ci, hi, wi] |
|
- masks: List of multi masks [bs, hi, wi] |
|
- refpoint_embed: [bs, num_dn, 4]. None in infer |
|
- pos_embeds: List of multi pos embeds [bs, ci, hi, wi] |
|
- tgt: [bs, num_dn, d_model]. None in infer |
|
|
|
""" |
|
|
|
src_flatten = [] |
|
mask_flatten = [] |
|
lvl_pos_embed_flatten = [] |
|
spatial_shapes = [] |
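
        # flatten each level to [bs, hi*wi, c] and add its level embedding to the pos embed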
|
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): |
|
bs, c, h, w = src.shape |
|
spatial_shape = (h, w) |
|
spatial_shapes.append(spatial_shape) |
|
|
|
src = src.flatten(2).transpose(1, 2) |
|
mask = mask.flatten(1) |
|
pos_embed = pos_embed.flatten(2).transpose(1, 2) |
|
if self.num_feature_levels > 1 and self.level_embed is not None: |
|
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) |
|
else: |
|
lvl_pos_embed = pos_embed |
|
lvl_pos_embed_flatten.append(lvl_pos_embed) |
|
src_flatten.append(src) |
|
mask_flatten.append(mask) |
|
src_flatten = torch.cat(src_flatten, 1) |
|
mask_flatten = torch.cat(mask_flatten, 1) |
|
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) |
|
spatial_shapes = torch.as_tensor( |
|
spatial_shapes, dtype=torch.long, device=src_flatten.device |
|
) |
|
level_start_index = torch.cat( |
|
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) |
|
) |
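
        # per-image valid (non-padded) ratios for every feature level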
|
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) |
|
|
|
|
|
enc_topk_proposals = enc_refpoint_embed = None |
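
        # encoder: fuse the flattened image features with the text features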
|
|
|
|
|
|
|
|
|
|
|
memory, memory_text = self.encoder( |
|
src_flatten, |
|
pos=lvl_pos_embed_flatten, |
|
level_start_index=level_start_index, |
|
spatial_shapes=spatial_shapes, |
|
valid_ratios=valid_ratios, |
|
key_padding_mask=mask_flatten, |
|
memory_text=text_dict["encoded_text"], |
|
text_attention_mask=~text_dict["text_token_mask"], |
|
|
|
position_ids=text_dict["position_ids"], |
|
text_self_attention_masks=text_dict["text_self_attention_masks"], |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
text_dict["encoded_text"] = memory_text |
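
        # two-stage: score each encoder token against the text and keep the top
        # num_queries proposals as initial decoder references (and queries)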
|
|
|
|
|
|
|
|
|
if self.two_stage_type == "standard": |
|
output_memory, output_proposals = gen_encoder_output_proposals( |
|
memory, mask_flatten, spatial_shapes |
|
) |
|
output_memory = self.enc_output_norm(self.enc_output(output_memory)) |
|
|
|
if text_dict is not None: |
|
enc_outputs_class_unselected = self.enc_out_class_embed( |
|
output_memory, text_dict |
|
) |
|
else: |
|
enc_outputs_class_unselected = self.enc_out_class_embed(output_memory) |
|
|
|
topk_logits = enc_outputs_class_unselected.max(-1)[0] |
|
enc_outputs_coord_unselected = ( |
|
self.enc_out_bbox_embed(output_memory) + output_proposals |
|
) |
|
topk = self.num_queries |
|
|
|
topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] |
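
            # gather the coordinates and features of the selected top-k proposals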
|
|
|
|
|
refpoint_embed_undetach = torch.gather( |
|
enc_outputs_coord_unselected, |
|
1, |
|
topk_proposals.unsqueeze(-1).repeat(1, 1, 4), |
|
) |
|
refpoint_embed_ = refpoint_embed_undetach.detach() |
|
init_box_proposal = torch.gather( |
|
output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4) |
|
).sigmoid() |
|
|
|
|
|
tgt_undetach = torch.gather( |
|
output_memory, |
|
1, |
|
topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model), |
|
) |
|
if self.embed_init_tgt: |
|
tgt_ = ( |
|
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) |
|
) |
|
else: |
|
tgt_ = tgt_undetach.detach() |
|
|
|
if refpoint_embed is not None: |
|
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1) |
|
tgt = torch.cat([tgt, tgt_], dim=1) |
|
else: |
|
refpoint_embed, tgt = refpoint_embed_, tgt_ |
|
|
|
elif self.two_stage_type == "no": |
|
tgt_ = ( |
|
self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) |
|
) |
|
refpoint_embed_ = ( |
|
self.refpoint_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) |
|
) |
|
|
|
if refpoint_embed is not None: |
|
refpoint_embed = torch.cat([refpoint_embed, refpoint_embed_], dim=1) |
|
tgt = torch.cat([tgt, tgt_], dim=1) |
|
else: |
|
refpoint_embed, tgt = refpoint_embed_, tgt_ |
|
|
|
if self.num_patterns > 0: |
|
tgt_embed = tgt.repeat(1, self.num_patterns, 1) |
|
refpoint_embed = refpoint_embed.repeat(1, self.num_patterns, 1) |
|
tgt_pat = self.patterns.weight[None, :, :].repeat_interleave( |
|
self.num_queries, 1 |
|
) |
|
tgt = tgt_embed + tgt_pat |
|
|
|
init_box_proposal = refpoint_embed_.sigmoid() |
|
|
|
else: |
|
raise NotImplementedError( |
|
"unknown two_stage_type {}".format(self.two_stage_type) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hs, references = self.decoder( |
|
tgt=tgt.transpose(0, 1), |
|
memory=memory.transpose(0, 1), |
|
memory_key_padding_mask=mask_flatten, |
|
pos=lvl_pos_embed_flatten.transpose(0, 1), |
|
refpoints_unsigmoid=refpoint_embed.transpose(0, 1), |
|
level_start_index=level_start_index, |
|
spatial_shapes=spatial_shapes, |
|
valid_ratios=valid_ratios, |
|
tgt_mask=attn_mask, |
|
memory_text=text_dict["encoded_text"], |
|
text_attention_mask=~text_dict["text_token_mask"], |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.two_stage_type == "standard": |
|
hs_enc = tgt_undetach.unsqueeze(0) |
|
ref_enc = refpoint_embed_undetach.sigmoid().unsqueeze(0) |
|
else: |
|
hs_enc = ref_enc = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
return hs, references, hs_enc, ref_enc, init_box_proposal |
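
# A minimal usage sketch (illustrative only). Shapes follow the forward()
# docstring above; `srcs`, `masks`, `pos_embeds`, and `text_dict` would come
# from the backbone / text encoder and are assumed here:
#
#   hs, references, hs_enc, ref_enc, init_box_proposal = transformer(
#       srcs, masks, refpoint_embed=None, pos_embeds=pos_embeds, tgt=None,
#       attn_mask=None, text_dict=text_dict,
#   )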
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerEncoder(nn.Module): |
|
def __init__( |
|
self, |
|
encoder_layer, |
|
num_layers, |
|
d_model=256, |
|
num_queries=300, |
|
enc_layer_share=False, |
|
text_enhance_layer=None, |
|
feature_fusion_layer=None, |
|
use_checkpoint=False, |
|
use_transformer_ckpt=False, |
|
): |
|
"""_summary_ |
|
|
|
Args: |
|
encoder_layer (_type_): _description_ |
|
num_layers (_type_): _description_ |
|
norm (_type_, optional): _description_. Defaults to None. |
|
d_model (int, optional): _description_. Defaults to 256. |
|
num_queries (int, optional): _description_. Defaults to 300. |
|
enc_layer_share (bool, optional): _description_. Defaults to False. |
|
|
|
""" |
|
super().__init__() |
|
|
|
self.layers = [] |
|
self.text_layers = [] |
|
self.fusion_layers = [] |
|
if num_layers > 0: |
|
self.layers = _get_clones( |
|
encoder_layer, num_layers, layer_share=enc_layer_share |
|
) |
|
|
|
if text_enhance_layer is not None: |
|
self.text_layers = _get_clones( |
|
text_enhance_layer, num_layers, layer_share=enc_layer_share |
|
) |
|
if feature_fusion_layer is not None: |
|
self.fusion_layers = _get_clones( |
|
feature_fusion_layer, num_layers, layer_share=enc_layer_share |
|
) |
|
        else:
            # num_layers == 0: drop the layer templates; the lists above stay empty
            del encoder_layer
            if text_enhance_layer is not None:
                del text_enhance_layer
            if feature_fusion_layer is not None:
                del feature_fusion_layer
|
|
|
self.query_scale = None |
|
self.num_queries = num_queries |
|
self.num_layers = num_layers |
|
self.d_model = d_model |
|
|
|
self.use_checkpoint = use_checkpoint |
|
self.use_transformer_ckpt = use_transformer_ckpt |
|
|
|
@staticmethod |
|
def get_reference_points(spatial_shapes, valid_ratios, device): |
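        # per-level grids of normalized pixel centers, rescaled by the valid
        # ratios -> [bs, sum(hi*wi), num_level, 2]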
|
reference_points_list = [] |
|
for lvl, (H_, W_) in enumerate(spatial_shapes): |
|
ref_y, ref_x = torch.meshgrid( |
|
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), |
|
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), |
|
) |
|
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) |
|
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) |
|
ref = torch.stack((ref_x, ref_y), -1) |
|
reference_points_list.append(ref) |
|
reference_points = torch.cat(reference_points_list, 1) |
|
reference_points = reference_points[:, :, None] * valid_ratios[:, None] |
|
return reference_points |
|
|
|
def forward( |
|
self, |
|
|
|
src: Tensor, |
|
pos: Tensor, |
|
spatial_shapes: Tensor, |
|
level_start_index: Tensor, |
|
valid_ratios: Tensor, |
|
key_padding_mask: Tensor, |
|
|
|
memory_text: Tensor = None, |
|
text_attention_mask: Tensor = None, |
|
pos_text: Tensor = None, |
|
text_self_attention_masks: Tensor = None, |
|
position_ids: Tensor = None, |
|
): |
|
""" |
|
Input: |
|
- src: [bs, sum(hi*wi), 256] |
|
- pos: pos embed for src. [bs, sum(hi*wi), 256] |
|
- spatial_shapes: h,w of each level [num_level, 2] |
|
- level_start_index: [num_level] start point of level in sum(hi*wi). |
|
- valid_ratios: [bs, num_level, 2] |
|
- key_padding_mask: [bs, sum(hi*wi)] |
|
|
|
- memory_text: bs, n_text, 256 |
|
- text_attention_mask: bs, n_text |
|
False for no padding; True for padding |
|
- pos_text: bs, n_text, 256 |
|
|
|
- position_ids: bs, n_text |
|
        Intermediate:
            - reference_points: [bs, sum(hi*wi), num_level, 2]
        Outputs:
            - output: [bs, sum(hi*wi), 256]
|
""" |
|
|
|
output = src |
|
|
|
|
|
if self.num_layers > 0: |
|
reference_points = self.get_reference_points( |
|
spatial_shapes, valid_ratios, device=src.device |
|
) |
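
        # sine positional embeddings for the text tokens: prefer the tokenizer's
        # position_ids when given, otherwise a simple 0..n_text-1 ramp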
|
|
|
if self.text_layers: |
|
|
|
bs, n_text, text_dim = memory_text.shape |
|
if pos_text is None and position_ids is None: |
|
pos_text = ( |
|
torch.arange(n_text, device=memory_text.device) |
|
.float() |
|
.unsqueeze(0) |
|
.unsqueeze(-1) |
|
.repeat(bs, 1, 1) |
|
) |
|
pos_text = get_sine_pos_embed( |
|
pos_text, num_pos_feats=256, exchange_xy=False |
|
) |
|
if position_ids is not None: |
|
pos_text = get_sine_pos_embed( |
|
position_ids[..., None], num_pos_feats=256, exchange_xy=False |
|
) |
|
|
|
|
|
for layer_id, layer in enumerate(self.layers): |
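            # per layer: (1) image-text fusion, (2) text self-attention,
            # (3) deformable self-attention over the image tokens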
|
|
|
|
|
|
|
if self.fusion_layers: |
|
if self.use_checkpoint: |
|
output, memory_text = checkpoint.checkpoint( |
|
self.fusion_layers[layer_id], |
|
output, |
|
memory_text, |
|
key_padding_mask, |
|
text_attention_mask, |
|
) |
|
else: |
|
output, memory_text = self.fusion_layers[layer_id]( |
|
v=output, |
|
l=memory_text, |
|
attention_mask_v=key_padding_mask, |
|
attention_mask_l=text_attention_mask, |
|
) |
|
|
|
if self.text_layers: |
|
memory_text = self.text_layers[layer_id]( |
|
src=memory_text.transpose(0, 1), |
|
src_mask=~text_self_attention_masks, |
|
src_key_padding_mask=text_attention_mask, |
|
pos=(pos_text.transpose(0, 1) if pos_text is not None else None), |
|
).transpose(0, 1) |
|
|
|
|
|
if self.use_transformer_ckpt: |
|
output = checkpoint.checkpoint( |
|
layer, |
|
output, |
|
pos, |
|
reference_points, |
|
spatial_shapes, |
|
level_start_index, |
|
key_padding_mask, |
|
) |
|
else: |
|
output = layer( |
|
src=output, |
|
pos=pos, |
|
reference_points=reference_points, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
key_padding_mask=key_padding_mask, |
|
) |
|
|
|
return output, memory_text |
|
|
|
|
|
class TransformerDecoder(nn.Module): |
|
def __init__( |
|
self, |
|
decoder_layer, |
|
num_layers, |
|
norm=None, |
|
return_intermediate=False, |
|
d_model=256, |
|
query_dim=4, |
|
num_feature_levels=1, |
|
): |
|
super().__init__() |
|
if num_layers > 0: |
|
self.layers = _get_clones(decoder_layer, num_layers) |
|
else: |
|
self.layers = [] |
|
self.num_layers = num_layers |
|
self.norm = norm |
|
self.return_intermediate = return_intermediate |
|
assert return_intermediate, "support return_intermediate only" |
|
self.query_dim = query_dim |
|
assert query_dim in [2, 4], "query_dim should be 2/4 but {}".format(query_dim) |
|
self.num_feature_levels = num_feature_levels |
|
|
|
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2) |
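        # maps the (query_dim/2)*d_model sine embedding of the reference box
        # down to a d_model query positional embedding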
|
self.query_pos_sine_scale = None |
|
|
|
self.query_scale = None |
|
self.bbox_embed = None |
|
self.class_embed = None |
|
|
|
self.d_model = d_model |
|
|
|
self.ref_anchor_head = None |
|
|
|
def forward( |
|
self, |
|
tgt, |
|
memory, |
|
tgt_mask: Optional[Tensor] = None, |
|
memory_mask: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None, |
|
refpoints_unsigmoid: Optional[Tensor] = None, |
|
|
|
level_start_index: Optional[Tensor] = None, |
|
spatial_shapes: Optional[Tensor] = None, |
|
valid_ratios: Optional[Tensor] = None, |
|
|
|
memory_text: Optional[Tensor] = None, |
|
text_attention_mask: Optional[Tensor] = None, |
|
): |
|
""" |
|
Input: |
|
- tgt: nq, bs, d_model |
|
- memory: hw, bs, d_model |
|
- pos: hw, bs, d_model |
|
- refpoints_unsigmoid: nq, bs, 2/4 |
|
- valid_ratios/spatial_shapes: bs, nlevel, 2 |
|
""" |
|
output = tgt |
|
|
|
intermediate = [] |
|
reference_points = refpoints_unsigmoid.sigmoid() |
|
ref_points = [reference_points] |
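
        # each layer attends using the current (sigmoided) boxes as reference
        # points, then refines them with bbox_embed when it is set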
|
|
|
for layer_id, layer in enumerate(self.layers): |
|
if reference_points.shape[-1] == 4: |
|
reference_points_input = ( |
|
reference_points[:, :, None] |
|
* torch.cat([valid_ratios, valid_ratios], -1)[None, :] |
|
) |
|
else: |
|
assert reference_points.shape[-1] == 2 |
|
reference_points_input = ( |
|
reference_points[:, :, None] * valid_ratios[None, :] |
|
) |
|
query_sine_embed = gen_sineembed_for_position( |
|
reference_points_input[:, :, 0, :] |
|
) |
|
|
|
|
|
raw_query_pos = self.ref_point_head(query_sine_embed) |
|
pos_scale = self.query_scale(output) if self.query_scale is not None else 1 |
|
query_pos = pos_scale * raw_query_pos |
|
|
|
|
|
|
|
|
|
|
|
output = layer( |
|
tgt=output, |
|
tgt_query_pos=query_pos, |
|
tgt_query_sine_embed=query_sine_embed, |
|
tgt_key_padding_mask=tgt_key_padding_mask, |
|
tgt_reference_points=reference_points_input, |
|
memory_text=memory_text, |
|
text_attention_mask=text_attention_mask, |
|
memory=memory, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
memory_level_start_index=level_start_index, |
|
memory_spatial_shapes=spatial_shapes, |
|
memory_pos=pos, |
|
self_attn_mask=tgt_mask, |
|
cross_attn_mask=memory_mask, |
|
) |
|
            if output.isnan().any() or output.isinf().any():
                print(f"output of layer_id {layer_id} contains NaN/Inf")
                try:
                    num_nan = output.isnan().sum().item()
                    num_inf = output.isinf().sum().item()
                    print(f"num_nan {num_nan}, num_inf {num_inf}")
                except Exception as e:
                    print(e)
|
|
|
|
|
|
|
|
|
if self.bbox_embed is not None: |
|
|
|
|
|
|
|
|
|
reference_before_sigmoid = inverse_sigmoid(reference_points) |
|
delta_unsig = self.bbox_embed[layer_id](output) |
|
outputs_unsig = delta_unsig + reference_before_sigmoid |
|
new_reference_points = outputs_unsig.sigmoid() |
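
                # detach: the refined boxes do not backpropagate across decoder
                # layers; gradients reach bbox_embed only through ref_points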
|
|
|
reference_points = new_reference_points.detach() |
|
|
|
ref_points.append(new_reference_points) |
|
|
|
intermediate.append(self.norm(output)) |
|
|
|
|
|
|
|
return [ |
|
[itm_out.transpose(0, 1) for itm_out in intermediate], |
|
[itm_refpoint.transpose(0, 1) for itm_refpoint in ref_points], |
|
] |
|
|
|
|
|
class DeformableTransformerEncoderLayer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model=256, |
|
d_ffn=1024, |
|
dropout=0.1, |
|
activation="relu", |
|
n_levels=4, |
|
n_heads=8, |
|
n_points=4, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.self_attn = MSDeformAttn( |
|
embed_dim=d_model, |
|
num_levels=n_levels, |
|
num_heads=n_heads, |
|
num_points=n_points, |
|
batch_first=True, |
|
) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.norm1 = nn.LayerNorm(d_model) |
|
|
|
|
|
self.linear1 = nn.Linear(d_model, d_ffn) |
|
self.activation = _get_activation_fn(activation, d_model=d_ffn) |
|
self.dropout2 = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(d_ffn, d_model) |
|
self.dropout3 = nn.Dropout(dropout) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
|
|
@staticmethod |
|
def with_pos_embed(tensor, pos): |
|
return tensor if pos is None else tensor + pos |
|
|
|
def forward_ffn(self, src): |
|
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) |
|
src = src + self.dropout3(src2) |
|
src = self.norm2(src) |
|
return src |
|
|
|
def forward( |
|
self, |
|
src, |
|
pos, |
|
reference_points, |
|
spatial_shapes, |
|
level_start_index, |
|
key_padding_mask=None, |
|
): |
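
        # deformable multi-scale self-attention, then the FFN block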
|
|
|
|
|
src2 = self.self_attn( |
|
query=self.with_pos_embed(src, pos), |
|
reference_points=reference_points, |
|
value=src, |
|
spatial_shapes=spatial_shapes, |
|
level_start_index=level_start_index, |
|
key_padding_mask=key_padding_mask, |
|
) |
|
src = src + self.dropout1(src2) |
|
src = self.norm1(src) |
|
|
|
|
|
src = self.forward_ffn(src) |
|
|
|
return src |
|
|
|
|
|
class DeformableTransformerDecoderLayer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model=256, |
|
d_ffn=1024, |
|
dropout=0.1, |
|
activation="relu", |
|
n_levels=4, |
|
n_heads=8, |
|
n_points=4, |
|
use_text_feat_guide=False, |
|
use_text_cross_attention=False, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.cross_attn = MSDeformAttn( |
|
embed_dim=d_model, |
|
num_levels=n_levels, |
|
num_heads=n_heads, |
|
num_points=n_points, |
|
batch_first=True, |
|
) |
|
self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
self.norm1 = nn.LayerNorm(d_model) |
|
|
|
|
|
if use_text_cross_attention: |
|
self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) |
|
self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
self.catext_norm = nn.LayerNorm(d_model) |
|
|
|
|
|
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) |
|
self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
self.norm2 = nn.LayerNorm(d_model) |
|
|
|
|
|
self.linear1 = nn.Linear(d_model, d_ffn) |
|
self.activation = _get_activation_fn(activation, d_model=d_ffn, batch_dim=1) |
|
self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
self.linear2 = nn.Linear(d_ffn, d_model) |
|
self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() |
|
self.norm3 = nn.LayerNorm(d_model) |
|
|
|
self.key_aware_proj = None |
|
self.use_text_feat_guide = use_text_feat_guide |
|
assert not use_text_feat_guide |
|
self.use_text_cross_attention = use_text_cross_attention |
|
|
|
def rm_self_attn_modules(self): |
|
self.self_attn = None |
|
self.dropout2 = None |
|
self.norm2 = None |
|
|
|
@staticmethod |
|
def with_pos_embed(tensor, pos): |
|
return tensor if pos is None else tensor + pos |
|
|
|
def forward_ffn(self, tgt): |
|
with torch.cuda.amp.autocast(enabled=False): |
|
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) |
|
tgt = tgt + self.dropout4(tgt2) |
|
tgt = self.norm3(tgt) |
|
return tgt |
|
|
|
def forward( |
|
self, |
|
|
|
tgt: Optional[Tensor], |
|
tgt_query_pos: Optional[Tensor] = None, |
|
tgt_query_sine_embed: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
tgt_reference_points: Optional[Tensor] = None, |
|
memory_text: Optional[Tensor] = None, |
|
text_attention_mask: Optional[Tensor] = None, |
|
|
|
memory: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
memory_level_start_index: Optional[Tensor] = None, |
|
memory_spatial_shapes: Optional[Tensor] = None, |
|
memory_pos: Optional[Tensor] = None, |
|
|
|
self_attn_mask: Optional[Tensor] = None, |
|
cross_attn_mask: Optional[Tensor] = None, |
|
): |
|
""" |
|
Input: |
|
- tgt/tgt_query_pos: nq, bs, d_model |
|
- |
|
""" |
|
assert cross_attn_mask is None |
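
        # layer order: self-attention -> (optional) text cross-attention ->
        # deformable image cross-attention -> FFN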
|
|
|
|
|
if self.self_attn is not None: |
|
|
|
q = k = self.with_pos_embed(tgt, tgt_query_pos) |
|
tgt2 = self.self_attn(q, k, tgt, attn_mask=self_attn_mask)[0] |
|
tgt = tgt + self.dropout2(tgt2) |
|
tgt = self.norm2(tgt) |
|
|
|
if self.use_text_cross_attention: |
|
tgt2 = self.ca_text( |
|
self.with_pos_embed(tgt, tgt_query_pos), |
|
memory_text.transpose(0, 1), |
|
memory_text.transpose(0, 1), |
|
key_padding_mask=text_attention_mask, |
|
)[0] |
|
tgt = tgt + self.catext_dropout(tgt2) |
|
tgt = self.catext_norm(tgt) |
|
|
|
tgt2 = self.cross_attn( |
|
query=self.with_pos_embed(tgt, tgt_query_pos).transpose(0, 1), |
|
reference_points=tgt_reference_points.transpose(0, 1).contiguous(), |
|
value=memory.transpose(0, 1), |
|
spatial_shapes=memory_spatial_shapes, |
|
level_start_index=memory_level_start_index, |
|
key_padding_mask=memory_key_padding_mask, |
|
).transpose(0, 1) |
|
tgt = tgt + self.dropout1(tgt2) |
|
tgt = self.norm1(tgt) |
|
|
|
|
|
tgt = self.forward_ffn(tgt) |
|
|
|
return tgt |
|
|
|
|
|
def build_transformer(args): |
|
return Transformer( |
|
d_model=args.hidden_dim, |
|
dropout=args.dropout, |
|
nhead=args.nheads, |
|
num_queries=args.num_queries, |
|
dim_feedforward=args.dim_feedforward, |
|
num_encoder_layers=args.enc_layers, |
|
num_decoder_layers=args.dec_layers, |
|
normalize_before=args.pre_norm, |
|
return_intermediate_dec=True, |
|
query_dim=args.query_dim, |
|
activation=args.transformer_activation, |
|
num_patterns=args.num_patterns, |
|
num_feature_levels=args.num_feature_levels, |
|
enc_n_points=args.enc_n_points, |
|
dec_n_points=args.dec_n_points, |
|
learnable_tgt_init=True, |
|
|
|
two_stage_type=args.two_stage_type, |
|
embed_init_tgt=args.embed_init_tgt, |
|
use_text_enhancer=args.use_text_enhancer, |
|
use_fusion_layer=args.use_fusion_layer, |
|
use_checkpoint=args.use_checkpoint, |
|
use_transformer_ckpt=args.use_transformer_ckpt, |
|
use_text_cross_attention=args.use_text_cross_attention, |
|
text_dropout=args.text_dropout, |
|
fusion_dropout=args.fusion_dropout, |
|
fusion_droppath=args.fusion_droppath, |
|
) |
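

# Hypothetical sketch of the argparse namespace that build_transformer expects.
# The field names are the ones read above; the values are illustrative, not
# the official defaults:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       hidden_dim=256, dropout=0.0, nheads=8, num_queries=900,
#       dim_feedforward=2048, enc_layers=6, dec_layers=6, pre_norm=False,
#       query_dim=4, transformer_activation="relu", num_patterns=0,
#       num_feature_levels=4, enc_n_points=4, dec_n_points=4,
#       two_stage_type="standard", embed_init_tgt=True,
#       use_text_enhancer=True, use_fusion_layer=True, use_checkpoint=False,
#       use_transformer_ckpt=False, use_text_cross_attention=True,
#       text_dropout=0.0, fusion_dropout=0.0, fusion_droppath=0.1,
#   )
#   transformer = build_transformer(args)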
|
|