# ------------------------------------------------------------------------ # Grounding DINO # url: https://github.com/IDEA-Research/GroundingDINO # Copyright (c) 2023 IDEA. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from: # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/functions/ms_deform_attn_func.py # https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py # https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/multi_scale_deform_attn.py # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import import math import warnings from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.init import constant_, xavier_uniform_ #try: # from groundingdino import _C # import MultiScaleDeformableAttention as _C import MultiScaleDeformableAttention as _C #except: # warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!") # helpers def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n - 1) == 0) and n != 0 class MultiScaleDeformableAttnFunction(Function): @staticmethod def forward( ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step, ): ctx.im2col_step = im2col_step output = _C.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step, ) ctx.save_for_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ) return output @staticmethod @once_differentiable def backward(ctx, grad_output): ( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ) = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = _C.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step, ) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def multi_scale_deformable_attn_pytorch( value: torch.Tensor, value_spatial_shapes: torch.Tensor, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, ) -> torch.Tensor: bs, _, num_heads, embed_dims = value.shape _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for level, (H_, W_) in enumerate(value_spatial_shapes): # bs, H_*W_, num_heads, embed_dims -> # bs, H_*W_, num_heads*embed_dims -> # bs, num_heads*embed_dims, H_*W_ -> # bs*num_heads, embed_dims, H_, W_ value_l_ = ( value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) ) # bs, num_queries, num_heads, num_points, 2 -> # bs, num_heads, num_queries, num_points, 2 -> # bs*num_heads, num_queries, num_points, 2 sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) # bs*num_heads, embed_dims, num_queries, num_points sampling_value_l_ = F.grid_sample( value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False ) sampling_value_list.append(sampling_value_l_) # (bs, num_queries, num_heads, num_levels, num_points) -> # (bs, num_heads, num_queries, num_levels, num_points) -> # (bs, num_heads, 1, num_queries, num_levels*num_points) attention_weights = attention_weights.transpose(1, 2).reshape( bs * num_heads, 1, num_queries, num_levels * num_points ) output = ( (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) .sum(-1) .view(bs, num_heads * embed_dims, num_queries) ) return output.transpose(1, 2).contiguous() class MultiScaleDeformableAttention(nn.Module): """Multi-Scale Deformable Attention Module used in Deformable-DETR `Deformable DETR: Deformable Transformers for End-to-End Object Detection. `_. Args: embed_dim (int): The embedding dimension of Attention. Default: 256. num_heads (int): The number of attention heads. Default: 8. num_levels (int): The number of feature map used in Attention. Default: 4. num_points (int): The number of sampling points for each query in each head. Default: 4. img2col_steps (int): The step used in image_to_column. Defualt: 64. dropout (float): Dropout layer used in output. Default: 0.1. batch_first (bool): if ``True``, then the input and output tensor will be provided as `(bs, n, embed_dim)`. Default: False. `(n, bs, embed_dim)` """ def __init__( self, embed_dim: int = 256, num_heads: int = 8, num_levels: int = 4, num_points: int = 4, img2col_step: int = 64, batch_first: bool = False, ): super().__init__() if embed_dim % num_heads != 0: raise ValueError( "embed_dim must be divisible by num_heads, but got {} and {}".format( embed_dim, num_heads ) ) head_dim = embed_dim // num_heads self.batch_first = batch_first if not _is_power_of_2(head_dim): warnings.warn( """ You'd better set d_model in MSDeformAttn to make sure that each dim of the attention head a power of 2, which is more efficient. """ ) self.im2col_step = img2col_step self.embed_dim = embed_dim self.num_heads = num_heads self.num_levels = num_levels self.num_points = num_points self.sampling_offsets = nn.Linear(embed_dim, num_heads * num_levels * num_points * 2) self.attention_weights = nn.Linear(embed_dim, num_heads * num_levels * num_points) self.value_proj = nn.Linear(embed_dim, embed_dim) self.output_proj = nn.Linear(embed_dim, embed_dim) self.init_weights() def _reset_parameters(self): return self.init_weights() def init_weights(self): """ Default initialization for Parameters of Module. """ constant_(self.sampling_offsets.weight.data, 0.0) thetas = torch.arange(self.num_heads, dtype=torch.float32) * ( 2.0 * math.pi / self.num_heads ) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) .view(self.num_heads, 1, 1, 2) .repeat(1, self.num_levels, self.num_points, 1) ) for i in range(self.num_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.0) constant_(self.attention_weights.bias.data, 0.0) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.0) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.0) def freeze_sampling_offsets(self): self.sampling_offsets.weight.requires_grad = False self.sampling_offsets.bias.requires_grad = False def freeze_attention_weights(self): self.attention_weights.weight.requires_grad = False self.attention_weights.bias.requires_grad = False def forward( self, query: torch.Tensor, key: Optional[torch.Tensor] = None, value: Optional[torch.Tensor] = None, query_pos: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, reference_points: Optional[torch.Tensor] = None, spatial_shapes: Optional[torch.Tensor] = None, level_start_index: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: """Forward Function of MultiScaleDeformableAttention Args: query (torch.Tensor): Query embeddings with shape `(num_query, bs, embed_dim)` key (torch.Tensor): Key embeddings with shape `(num_key, bs, embed_dim)` value (torch.Tensor): Value embeddings with shape `(num_key, bs, embed_dim)` query_pos (torch.Tensor): The position embedding for `query`. Default: None. key_padding_mask (torch.Tensor): ByteTensor for `query`, with shape `(bs, num_key)`, indicating which elements within `key` to be ignored in attention. reference_points (torch.Tensor): The normalized reference points with shape `(bs, num_query, num_levels, 2)`, all elements is range in [0, 1], top-left (0, 0), bottom-right (1, 1), including padding are. or `(N, Length_{query}, num_levels, 4)`, add additional two dimensions `(h, w)` to form reference boxes. spatial_shapes (torch.Tensor): Spatial shape of features in different levels. With shape `(num_levels, 2)`, last dimension represents `(h, w)`. level_start_index (torch.Tensor): The start index of each level. A tensor with shape `(num_levels, )` which can be represented as `[0, h_0 * w_0, h_0 * w_0 + h_1 * w_1, ...]`. Returns: torch.Tensor: forward results with shape `(num_query, bs, embed_dim)` """ if value is None: value = query if query_pos is not None: query = query + query_pos if not self.batch_first: # change to (bs, num_query ,embed_dims) query = query.permute(1, 0, 2) value = value.permute(1, 0, 2) bs, num_query, _ = query.shape bs, num_value, _ = value.shape assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value value = self.value_proj(value) if key_padding_mask is not None: value = value.masked_fill(key_padding_mask[..., None], float(0)) value = value.view(bs, num_value, self.num_heads, -1) sampling_offsets = self.sampling_offsets(query).view( bs, num_query, self.num_heads, self.num_levels, self.num_points, 2 ) attention_weights = self.attention_weights(query).view( bs, num_query, self.num_heads, self.num_levels * self.num_points ) attention_weights = attention_weights.softmax(-1) attention_weights = attention_weights.view( bs, num_query, self.num_heads, self.num_levels, self.num_points, ) # bs, num_query, num_heads, num_levels, num_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) sampling_locations = ( reference_points[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :] ) elif reference_points.shape[-1] == 4: sampling_locations = ( reference_points[:, :, None, :, None, :2] + sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5 ) else: raise ValueError( "Last dim of reference_points must be 2 or 4, but get {} instead.".format( reference_points.shape[-1] ) ) if torch.cuda.is_available() and value.is_cuda: halffloat = False if value.dtype == torch.float16: halffloat = True value = value.float() sampling_locations = sampling_locations.float() attention_weights = attention_weights.float() output = MultiScaleDeformableAttnFunction.apply( value, spatial_shapes, level_start_index, sampling_locations, attention_weights, self.im2col_step, ) if halffloat: output = output.half() else: output = multi_scale_deformable_attn_pytorch( value, spatial_shapes, sampling_locations, attention_weights ) output = self.output_proj(output) if not self.batch_first: output = output.permute(1, 0, 2) return output def create_dummy_class(klass, dependency, message=""): """ When a dependency of a class is not available, create a dummy class which throws ImportError when used. Args: klass (str): name of the class. dependency (str): name of the dependency. message: extra message to print Returns: class: a class object """ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass) if message: err = err + " " + message class _DummyMetaClass(type): # throw error on class attribute access def __getattr__(_, __): # noqa: B902 raise ImportError(err) class _Dummy(object, metaclass=_DummyMetaClass): # throw error on constructor def __init__(self, *args, **kwargs): raise ImportError(err) return _Dummy def create_dummy_func(func, dependency, message=""): """ When a dependency of a function is not available, create a dummy function which throws ImportError when used. Args: func (str): name of the function. dependency (str or list[str]): name(s) of the dependency. message: extra message to print Returns: function: a function object """ err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func) if message: err = err + " " + message if isinstance(dependency, (list, tuple)): dependency = ",".join(dependency) def _dummy(*args, **kwargs): raise ImportError(err) return _dummy