import copy |
|
from typing import List |
|
import torchvision.transforms.functional as vis_F |
|
from torchvision.transforms import InterpolationMode |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from torchvision.ops.boxes import nms |
|
from torchvision.ops import roi_align |
|
from transformers import ( |
|
AutoTokenizer, |
|
BertModel, |
|
BertTokenizer, |
|
RobertaModel, |
|
RobertaTokenizerFast, |
|
) |
|
|
|
from groundingdino.util import box_ops, get_tokenlizer |
|
from groundingdino.util.misc import ( |
|
NestedTensor, |
|
accuracy, |
|
get_world_size, |
|
interpolate, |
|
inverse_sigmoid, |
|
is_dist_avail_and_initialized, |
|
nested_tensor_from_tensor_list, |
|
) |
|
from groundingdino.util.utils import get_phrases_from_posmap |
|
from groundingdino.util.visualizer import COCOVisualizer |
|
from groundingdino.util.vl_utils import create_positive_map_from_span |
|
|
|
from ..registry import MODULE_BUILD_FUNCS |
|
from .backbone import build_backbone |
|
from .bertwarper import ( |
|
BertModelWarper, |
|
generate_masks_with_special_tokens, |
|
generate_masks_with_special_tokens_and_transfer_map, |
|
) |
|
from .transformer import build_transformer |
|
from .utils import MLP, ContrastiveEmbed, sigmoid_focal_loss |
|
|
|
from .matcher import build_matcher |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.patches import Rectangle |
|
from groundingdino.util.visualizer import renorm |
|
|
|
|
|
def numpy_2_cv2(np_img): |
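    """Convert a float image in [0, 1] to a uint8 array; raise if values fall outside that range."""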
|
    if np.min(np_img) < 0:
        raise ValueError(f"image min is less than 0. Img min: {np.min(np_img)}")
    if np.max(np_img) > 1:
        raise ValueError(f"image max is greater than 1. Img max: {np.max(np_img)}")
    np_img = (np_img * 255).astype(np.uint8)

    cv2_image = np.asarray(np_img)
    return cv2_image
|
|
|
|
|
def vis_exemps(image, exemp, f_name): |
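    """Draw one exemplar box (xyxy pixel coordinates) on `image` and save the figure to `f_name`."""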
|
plt.imshow(image) |
|
plt.gca().add_patch( |
|
Rectangle( |
|
(exemp[0], exemp[1]), |
|
exemp[2] - exemp[0], |
|
exemp[3] - exemp[1], |
|
edgecolor="red", |
|
facecolor="none", |
|
lw=1, |
|
) |
|
) |
|
plt.savefig(f_name) |
|
plt.close() |
|
|
|
|
|
class GroundingDINO(nn.Module): |
|
"""This is the Cross-Attention Detector module that performs object detection""" |
|
|
|
def __init__( |
|
self, |
|
backbone, |
|
transformer, |
|
num_queries, |
|
aux_loss=False, |
|
iter_update=False, |
|
query_dim=2, |
|
num_feature_levels=1, |
|
nheads=8, |
|
|
|
two_stage_type="no", |
|
dec_pred_bbox_embed_share=True, |
|
two_stage_class_embed_share=True, |
|
two_stage_bbox_embed_share=True, |
|
num_patterns=0, |
|
dn_number=100, |
|
dn_box_noise_scale=0.4, |
|
dn_label_noise_ratio=0.5, |
|
dn_labelbook_size=100, |
|
text_encoder_type="bert-base-uncased", |
|
sub_sentence_present=True, |
|
max_text_len=256, |
|
): |
|
"""Initializes the model. |
|
Parameters: |
|
backbone: torch module of the backbone to be used. See backbone.py |
|
transformer: torch module of the transformer architecture. See transformer.py |
|
            num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
                         the model can detect in a single image. For COCO, we recommend 100 queries.
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. |
|
""" |
|
super().__init__() |
|
self.num_queries = num_queries |
|
self.transformer = transformer |
|
self.hidden_dim = hidden_dim = transformer.d_model |
|
self.num_feature_levels = num_feature_levels |
|
self.nheads = nheads |
|
self.max_text_len = max_text_len |
|
self.sub_sentence_present = sub_sentence_present |
|
|
|
|
|
self.query_dim = query_dim |
|
assert query_dim == 4 |
|
|
|
|
|
        # Project the concatenated multi-scale backbone features (assumed to be the three
        # Swin-B stages: 256 + 512 + 1024 channels) down to hidden_dim for exemplar pooling.
        self.feature_map_proj = nn.Conv2d((256 + 512 + 1024), hidden_dim, kernel_size=1)
|
|
|
|
|
self.num_patterns = num_patterns |
|
self.dn_number = dn_number |
|
self.dn_box_noise_scale = dn_box_noise_scale |
|
self.dn_label_noise_ratio = dn_label_noise_ratio |
|
self.dn_labelbook_size = dn_labelbook_size |
|
|
|
|
|
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type) |
|
self.bert = get_tokenlizer.get_pretrained_language_model(text_encoder_type) |
|
self.bert.pooler.dense.weight.requires_grad_(False) |
|
self.bert.pooler.dense.bias.requires_grad_(False) |
|
self.bert = BertModelWarper(bert_model=self.bert) |
|
|
|
self.feat_map = nn.Linear( |
|
self.bert.config.hidden_size, self.hidden_dim, bias=True |
|
) |
|
nn.init.constant_(self.feat_map.bias.data, 0) |
|
nn.init.xavier_uniform_(self.feat_map.weight.data) |
|
|
|
|
|
|
|
        # Ids of the phrase-delimiting tokens; used to build text self-attention masks and to
        # locate where exemplar tokens are spliced into the caption.
        self.specical_tokens = self.tokenizer.convert_tokens_to_ids(
            ["[CLS]", "[SEP]", ".", "?"]
        )
|
|
|
|
|
if num_feature_levels > 1: |
|
num_backbone_outs = len(backbone.num_channels) |
|
input_proj_list = [] |
|
for _ in range(num_backbone_outs): |
|
in_channels = backbone.num_channels[_] |
|
input_proj_list.append( |
|
nn.Sequential( |
|
nn.Conv2d(in_channels, hidden_dim, kernel_size=1), |
|
nn.GroupNorm(32, hidden_dim), |
|
) |
|
) |
|
for _ in range(num_feature_levels - num_backbone_outs): |
|
input_proj_list.append( |
|
nn.Sequential( |
|
nn.Conv2d( |
|
in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 |
|
), |
|
nn.GroupNorm(32, hidden_dim), |
|
) |
|
) |
|
in_channels = hidden_dim |
|
self.input_proj = nn.ModuleList(input_proj_list) |
|
else: |
|
assert ( |
|
two_stage_type == "no" |
|
), "two_stage_type should be no if num_feature_levels=1 !!!" |
|
self.input_proj = nn.ModuleList( |
|
[ |
|
nn.Sequential( |
|
nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1), |
|
nn.GroupNorm(32, hidden_dim), |
|
) |
|
] |
|
) |
|
|
|
self.backbone = backbone |
|
self.aux_loss = aux_loss |
|
self.box_pred_damping = box_pred_damping = None |
|
|
|
self.iter_update = iter_update |
|
assert iter_update, "Why not iter_update?" |
|
|
|
|
|
self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share |
|
|
|
_class_embed = ContrastiveEmbed() |
|
|
|
_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) |
|
nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0) |
|
nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0) |
|
|
|
if dec_pred_bbox_embed_share: |
|
box_embed_layerlist = [ |
|
_bbox_embed for i in range(transformer.num_decoder_layers) |
|
] |
|
else: |
|
box_embed_layerlist = [ |
|
copy.deepcopy(_bbox_embed) |
|
for i in range(transformer.num_decoder_layers) |
|
] |
|
class_embed_layerlist = [ |
|
_class_embed for i in range(transformer.num_decoder_layers) |
|
] |
|
self.bbox_embed = nn.ModuleList(box_embed_layerlist) |
|
self.class_embed = nn.ModuleList(class_embed_layerlist) |
|
self.transformer.decoder.bbox_embed = self.bbox_embed |
|
self.transformer.decoder.class_embed = self.class_embed |
|
|
|
|
|
self.two_stage_type = two_stage_type |
|
assert two_stage_type in [ |
|
"no", |
|
"standard", |
|
], "unknown param {} of two_stage_type".format(two_stage_type) |
|
if two_stage_type != "no": |
|
if two_stage_bbox_embed_share: |
|
assert dec_pred_bbox_embed_share |
|
self.transformer.enc_out_bbox_embed = _bbox_embed |
|
else: |
|
self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed) |
|
|
|
if two_stage_class_embed_share: |
|
assert dec_pred_bbox_embed_share |
|
self.transformer.enc_out_class_embed = _class_embed |
|
else: |
|
self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed) |
|
|
|
self.refpoint_embed = None |
|
|
|
self._reset_parameters() |
|
|
|
def _reset_parameters(self): |
|
|
|
for proj in self.input_proj: |
|
nn.init.xavier_uniform_(proj[0].weight, gain=1) |
|
nn.init.constant_(proj[0].bias, 0) |
|
|
|
def init_ref_points(self, use_num_queries): |
|
self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim) |
|
|
|
def add_exemplar_tokens(self, tokenized, text_dict, exemplar_tokens, labels): |
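        """Splice the visual exemplar tokens into the text token sequence right after the phrase
        selected by `labels`, then rebuild the token masks and position ids to match."""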
|
input_ids = tokenized["input_ids"] |
|
|
|
device = input_ids.device |
|
new_input_ids = [] |
|
encoded_text = text_dict["encoded_text"] |
|
new_encoded_text = [] |
|
text_token_mask = text_dict["text_token_mask"] |
|
new_text_token_mask = [] |
|
position_ids = text_dict["position_ids"] |
|
text_self_attention_masks = text_dict["text_self_attention_masks"] |
|
|
|
for sample_ind in range(len(labels)): |
|
label = labels[sample_ind][0] |
|
exemplars = exemplar_tokens[sample_ind] |
|
label_count = -1 |
|
assert len(input_ids[sample_ind]) == len(position_ids[sample_ind]) |
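
            # Walk the token sequence: count phrases (a phrase starts at a non-special token
            # preceded by a special token) and find the special token that ends the target
            # phrase; the exemplar tokens are inserted at that position.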
|
for token_ind in range(len(input_ids[sample_ind])): |
|
input_id = input_ids[sample_ind][token_ind] |
|
if (input_id not in self.specical_tokens) and ( |
|
token_ind == 0 |
|
or (input_ids[sample_ind][token_ind - 1] in self.specical_tokens) |
|
): |
|
label_count += 1 |
|
if label_count == label: |
|
|
|
ind_to_insert_exemplar = token_ind |
|
while ( |
|
input_ids[sample_ind][ind_to_insert_exemplar] |
|
not in self.specical_tokens |
|
): |
|
ind_to_insert_exemplar += 1 |
|
break |
|
|
|
|
|
if label_count == -1: |
|
ind_to_insert_exemplar = 1 |
|
|
|
            # One placeholder id (1008, assumed to be an otherwise unused vocab entry) is
            # inserted per exemplar so input_ids stays aligned with the extended encoded_text.
            new_input_ids.append(
|
torch.cat( |
|
[ |
|
input_ids[sample_ind][:ind_to_insert_exemplar], |
|
torch.tensor([1008] * exemplars.shape[0]).to(device), |
|
input_ids[sample_ind][ind_to_insert_exemplar:], |
|
] |
|
) |
|
) |
|
new_encoded_text.append( |
|
torch.cat( |
|
[ |
|
encoded_text[sample_ind][:ind_to_insert_exemplar, :], |
|
exemplars, |
|
encoded_text[sample_ind][ind_to_insert_exemplar:, :], |
|
] |
|
) |
|
) |
|
new_text_token_mask.append( |
|
torch.full((len(new_input_ids[sample_ind]),), True).to(device) |
|
) |
|
|
|
tokenized["input_ids"] = torch.stack(new_input_ids) |
|
print(tokenized["input_ids"]) |
|
|
|
( |
|
text_self_attention_masks, |
|
position_ids, |
|
_, |
|
) = generate_masks_with_special_tokens_and_transfer_map( |
|
tokenized, self.specical_tokens, None |
|
) |
|
|
|
return { |
|
"encoded_text": torch.stack(new_encoded_text), |
|
"text_token_mask": torch.stack(new_text_token_mask), |
|
"position_ids": position_ids, |
|
"text_self_attention_masks": text_self_attention_masks, |
|
} |
|
|
|
def combine_features(self, features): |
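        """Upsample every exemplar feature level to the finest level's spatial size, concatenate
        along channels, and project to hidden_dim with a 1x1 convolution."""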
|
        bs, c, h, w = features[0].decompose()[0].shape
|
|
|
x = torch.cat( |
|
[ |
|
F.interpolate( |
|
feat.decompose()[0], |
|
size=(h, w), |
|
mode="bilinear", |
|
align_corners=True, |
|
) |
|
for feat in features |
|
], |
|
dim=1, |
|
) |
|
|
|
x = self.feature_map_proj(x) |
|
|
|
return x |
|
|
|
def forward( |
|
self, |
|
samples: NestedTensor, |
|
exemplar_images: NestedTensor, |
|
exemplars: List, |
|
labels, |
|
targets: List = None, |
|
cropped=False, |
|
orig_img=None, |
|
crop_width=0, |
|
crop_height=0, |
|
**kw, |
|
): |
|
"""The forward expects a NestedTensor, which consists of: |
|
- samples.tensor: batched images, of shape [batch_size x 3 x H x W] |
|
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels |
|
|
|
It returns a dict with the following elements: |
|
- "pred_logits": the classification logits (including no-object) for all queries. |
|
Shape= [batch_size x num_queries x num_classes] |
|
- "pred_boxes": The normalized boxes coordinates for all queries, represented as |
|
(center_x, center_y, width, height). These values are normalized in [0, 1], |
|
relative to the size of each individual image (disregarding possible padding). |
|
See PostProcess for information on how to retrieve the unnormalized bounding box. |
|
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of |
|
dictionnaries containing the two above keys for each decoder layer. |
|
""" |
|
|
|
if targets is None: |
|
captions = kw["captions"] |
|
else: |
|
captions = [t["caption"] for t in targets] |
|
|
|
|
|
|
|
tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to( |
|
samples.device |
|
) |
|
|
|
one_hot_token = tokenized |
|
|
|
( |
|
text_self_attention_masks, |
|
position_ids, |
|
cate_to_token_mask_list, |
|
) = generate_masks_with_special_tokens_and_transfer_map( |
|
tokenized, self.specical_tokens, self.tokenizer |
|
) |
|
|
|
if text_self_attention_masks.shape[1] > self.max_text_len: |
|
text_self_attention_masks = text_self_attention_masks[ |
|
:, : self.max_text_len, : self.max_text_len |
|
] |
|
position_ids = position_ids[:, : self.max_text_len] |
|
tokenized["input_ids"] = tokenized["input_ids"][:, : self.max_text_len] |
|
tokenized["attention_mask"] = tokenized["attention_mask"][ |
|
:, : self.max_text_len |
|
] |
|
tokenized["token_type_ids"] = tokenized["token_type_ids"][ |
|
:, : self.max_text_len |
|
] |
|
|
|
|
|
        # With sub-sentence attention, each phrase attends only to itself (plus special
        # tokens), so BERT's flat attention mask is replaced by text_self_attention_masks.
        if self.sub_sentence_present:
|
tokenized_for_encoder = { |
|
k: v for k, v in tokenized.items() if k != "attention_mask" |
|
} |
|
tokenized_for_encoder["attention_mask"] = text_self_attention_masks |
|
tokenized_for_encoder["position_ids"] = position_ids |
|
else: |
|
tokenized_for_encoder = tokenized |
|
|
|
bert_output = self.bert(**tokenized_for_encoder) |
|
|
|
        # Project BERT hidden states to the transformer hidden_dim.
        encoded_text = self.feat_map(bert_output["last_hidden_state"])
|
text_token_mask = tokenized.attention_mask.bool() |
|
|
|
|
|
|
|
if encoded_text.shape[1] > self.max_text_len: |
|
encoded_text = encoded_text[:, : self.max_text_len, :] |
|
text_token_mask = text_token_mask[:, : self.max_text_len] |
|
position_ids = position_ids[:, : self.max_text_len] |
|
text_self_attention_masks = text_self_attention_masks[ |
|
:, : self.max_text_len, : self.max_text_len |
|
] |
|
|
|
text_dict = { |
|
"encoded_text": encoded_text, |
|
"text_token_mask": text_token_mask, |
|
"position_ids": position_ids, |
|
"text_self_attention_masks": text_self_attention_masks, |
|
} |
|
|
|
if isinstance(samples, (list, torch.Tensor)): |
|
samples = nested_tensor_from_tensor_list(samples) |
|
|
|
if not cropped: |
|
features, poss = self.backbone(samples) |
|
features_exemp, _ = self.backbone(exemplar_images) |
|
combined_features = self.combine_features(features_exemp) |
|
|
|
bs = len(exemplars) |
|
num_exemplars = exemplars[0].shape[0] |
|
|
if num_exemplars > 0: |
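                # Pool one 1x1 descriptor per exemplar box; spatial_scale=1/8 assumes the
                # combined feature map is at stride 8 relative to the input image.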
|
exemplar_tokens = ( |
|
roi_align( |
|
combined_features, |
|
boxes=exemplars, |
|
output_size=(1, 1), |
|
spatial_scale=(1 / 8), |
|
aligned=True, |
|
) |
|
.squeeze(-1) |
|
.squeeze(-1) |
|
.reshape(bs, num_exemplars, -1) |
|
) |
|
else: |
|
exemplar_tokens = None |
|
|
|
else: |
|
features, poss = self.backbone(samples) |
|
(h, w) = ( |
|
samples.decompose()[0][0].shape[1], |
|
samples.decompose()[0][0].shape[2], |
|
) |
|
(orig_img_h, orig_img_w) = orig_img.shape[1], orig_img.shape[2] |
|
bs = len(samples.decompose()[0]) |
|
|
|
exemp_imgs = [] |
|
new_exemplars = [] |
|
ind = 0 |
|
for exemp in exemplars[0]: |
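                # Re-crop a (crop_height x crop_width) window centered on each exemplar from the
                # original image, resize it to the model input size, and rescale the box to match.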
|
center_x = (exemp[0] + exemp[2]) / 2 |
|
center_y = (exemp[1] + exemp[3]) / 2 |
|
start_x = max(int(center_x - crop_width / 2), 0) |
|
end_x = min(int(center_x + crop_width / 2), orig_img_w) |
|
start_y = max(int(center_y - crop_height / 2), 0) |
|
end_y = min(int(center_y + crop_height / 2), orig_img_h) |
|
scale_x = w / (end_x - start_x) |
|
scale_y = h / (end_y - start_y) |
|
exemp_imgs.append( |
|
vis_F.resize( |
|
orig_img[:, start_y:end_y, start_x:end_x], |
|
(h, w), |
|
interpolation=InterpolationMode.BICUBIC, |
|
) |
|
) |
|
new_exemplars.append( |
|
[ |
|
(exemp[0] - start_x) * scale_x, |
|
(exemp[1] - start_y) * scale_y, |
|
(exemp[2] - start_x) * scale_x, |
|
(exemp[3] - start_y) * scale_y, |
|
] |
|
) |
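
                # Debug visualization: save each cropped exemplar and the original image with
                # its box to "<ind>.jpg" / "orig-<ind>.jpg" in the working directory.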
|
|
|
vis_exemps( |
|
renorm(exemp_imgs[-1].cpu()).permute(1, 2, 0).numpy(), |
|
[coord.item() for coord in new_exemplars[-1]], |
|
str(ind) + ".jpg", |
|
) |
|
vis_exemps( |
|
renorm(orig_img.cpu()).permute(1, 2, 0).numpy(), |
|
[coord.item() for coord in exemplars[0][ind]], |
|
"orig-" + str(ind) + ".jpg", |
|
) |
|
ind += 1 |
|
|
|
exemp_imgs = nested_tensor_from_tensor_list(exemp_imgs) |
|
features_exemp, _ = self.backbone(exemp_imgs) |
|
combined_features = self.combine_features(features_exemp) |
|
new_exemplars = [ |
|
torch.tensor(exemp).unsqueeze(0).to(samples.device) for exemp in new_exemplars |
|
] |
|
|
|
|
|
exemplar_tokens = ( |
|
roi_align( |
|
combined_features, |
|
boxes=new_exemplars, |
|
output_size=(1, 1), |
|
spatial_scale=(1 / 8), |
|
aligned=True, |
|
) |
|
.squeeze(-1) |
|
.squeeze(-1) |
|
                .reshape(len(new_exemplars), self.hidden_dim)
|
) |
|
|
|
exemplar_tokens = torch.stack([exemplar_tokens] * bs) |
|
|
|
if exemplar_tokens is not None: |
|
text_dict = self.add_exemplar_tokens( |
|
tokenized, text_dict, exemplar_tokens, labels |
|
) |
|
|
|
srcs = [] |
|
masks = [] |
|
for l, feat in enumerate(features): |
|
src, mask = feat.decompose() |
|
srcs.append(self.input_proj[l](src)) |
|
masks.append(mask) |
|
assert mask is not None |
|
if self.num_feature_levels > len(srcs): |
|
_len_srcs = len(srcs) |
|
for l in range(_len_srcs, self.num_feature_levels): |
|
if l == _len_srcs: |
|
src = self.input_proj[l](features[-1].tensors) |
|
else: |
|
src = self.input_proj[l](srcs[-1]) |
|
m = samples.mask |
|
mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to( |
|
torch.bool |
|
)[0] |
|
pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) |
|
srcs.append(src) |
|
masks.append(mask) |
|
poss.append(pos_l) |
|
|
|
input_query_bbox = input_query_label = attn_mask = dn_meta = None |
|
hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer( |
|
srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict |
|
) |
|
|
|
|
|
outputs_coord_list = [] |
|
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate( |
|
zip(reference[:-1], self.bbox_embed, hs) |
|
): |
|
layer_delta_unsig = layer_bbox_embed(layer_hs) |
|
layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig) |
|
layer_outputs_unsig = layer_outputs_unsig.sigmoid() |
|
outputs_coord_list.append(layer_outputs_unsig) |
|
outputs_coord_list = torch.stack(outputs_coord_list) |
|
|
|
outputs_class = torch.stack( |
|
[ |
|
layer_cls_embed(layer_hs, text_dict) |
|
for layer_cls_embed, layer_hs in zip(self.class_embed, hs) |
|
] |
|
) |
|
|
|
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord_list[-1]} |
|
|
|
|
|
bs, len_td = text_dict["text_token_mask"].shape |
|
out["text_mask"] = torch.zeros(bs, self.max_text_len, dtype=torch.bool).to( |
|
samples.device |
|
) |
|
        # Copy the (possibly exemplar-extended) token mask into the fixed-size text mask.
        out["text_mask"][:, :len_td] = text_dict["text_token_mask"]
|
|
|
|
|
if self.aux_loss: |
|
out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord_list) |
|
out["token"] = one_hot_token |
|
|
|
if hs_enc is not None: |
|
|
|
interm_coord = ref_enc[-1] |
|
interm_class = self.transformer.enc_out_class_embed(hs_enc[-1], text_dict) |
|
out["interm_outputs"] = { |
|
"pred_logits": interm_class, |
|
"pred_boxes": interm_coord, |
|
} |
|
out["interm_outputs_for_matching_pre"] = { |
|
"pred_logits": interm_class, |
|
"pred_boxes": init_box_proposal, |
|
            }

        return out
|
|
|
@torch.jit.unused |
|
def _set_aux_loss(self, outputs_class, outputs_coord): |
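        # Per-decoder-layer outputs as a list of dicts; @torch.jit.unused so scripting skips it.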
|
|
|
|
|
|
|
return [ |
|
{"pred_logits": a, "pred_boxes": b} |
|
for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) |
|
] |
|
|
|
|
|
class SetCriterion(nn.Module): |
|
def __init__(self, matcher, weight_dict, focal_alpha, focal_gamma, losses): |
|
"""Create the criterion. |
|
Parameters: |
|
matcher: module able to compute a matching between targets and proposals |
|
weight_dict: dict containing as key the names of the losses and as values their relative weight. |
|
losses: list of all the losses to be applied. See get_loss for list of available losses. |
|
focal_alpha: alpha in Focal Loss |
|
""" |
|
super().__init__() |
|
self.matcher = matcher |
|
self.weight_dict = weight_dict |
|
self.losses = losses |
|
self.focal_alpha = focal_alpha |
|
self.focal_gamma = focal_gamma |
|
|
|
@torch.no_grad() |
|
def loss_cardinality(self, outputs, targets, indices, num_boxes): |
|
"""Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes |
|
This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients |
|
""" |
|
|
|
pred_logits = outputs["pred_logits"] |
|
device = pred_logits.device |
|
tgt_lengths = torch.as_tensor( |
|
[len(v["labels"]) for v in targets], device=device |
|
) |
|
|
|
card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) |
|
card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) |
|
losses = {"cardinality_error": card_err} |
|
return losses |
|
|
|
def loss_boxes(self, outputs, targets, indices, num_boxes): |
|
"""Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss |
|
targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] |
|
The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. |
|
""" |
|
assert "pred_boxes" in outputs |
|
idx = self._get_src_permutation_idx(indices) |
|
src_boxes = outputs["pred_boxes"][idx] |
|
target_boxes = torch.cat( |
|
[t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0 |
|
) |
|
|
|
        # Only the (cx, cy) center coordinates contribute to the L1 term here; box size is
        # supervised through the GIoU loss below.
        loss_bbox = F.l1_loss(src_boxes[:, :2], target_boxes[:, :2], reduction="none")
|
|
|
losses = {} |
|
losses["loss_bbox"] = loss_bbox.sum() / num_boxes |
|
|
|
loss_giou = 1 - torch.diag( |
|
box_ops.generalized_box_iou( |
|
box_ops.box_cxcywh_to_xyxy(src_boxes), |
|
box_ops.box_cxcywh_to_xyxy(target_boxes), |
|
) |
|
) |
|
losses["loss_giou"] = loss_giou.sum() / num_boxes |
|
|
|
|
|
        with torch.no_grad():
            losses["loss_xy"] = loss_bbox[..., :2].sum() / num_boxes
            # loss_bbox only covers (cx, cy), so this logged term is always zero here.
            losses["loss_hw"] = loss_bbox[..., 2:].sum() / num_boxes
|
|
|
return losses |
|
|
|
def token_sigmoid_binary_focal_loss(self, outputs, targets, indices, num_boxes): |
|
pred_logits = outputs["pred_logits"] |
|
new_targets = outputs["one_hot"].to(pred_logits.device) |
|
text_mask = outputs["text_mask"] |
|
|
|
assert new_targets.dim() == 3 |
|
assert pred_logits.dim() == 3 |
|
|
|
bs, n, _ = pred_logits.shape |
|
alpha = self.focal_alpha |
|
gamma = self.focal_gamma |
|
if text_mask is not None: |
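            # Broadcast the per-token text mask over all queries and drop logits at padded
            # text positions before computing the focal loss.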
|
|
|
text_mask = text_mask.repeat(1, pred_logits.size(1)).view( |
|
outputs["text_mask"].shape[0], -1, outputs["text_mask"].shape[1] |
|
) |
|
pred_logits = torch.masked_select(pred_logits, text_mask) |
|
new_targets = torch.masked_select(new_targets, text_mask) |
|
|
|
new_targets = new_targets.float() |
|
p = torch.sigmoid(pred_logits) |
|
ce_loss = F.binary_cross_entropy_with_logits( |
|
pred_logits, new_targets, reduction="none" |
|
) |
|
p_t = p * new_targets + (1 - p) * (1 - new_targets) |
|
loss = ce_loss * ((1 - p_t) ** gamma) |
|
|
|
if alpha >= 0: |
|
alpha_t = alpha * new_targets + (1 - alpha) * (1 - new_targets) |
|
loss = alpha_t * loss |
|
|
|
total_num_pos = 0 |
|
for batch_indices in indices: |
|
total_num_pos += len(batch_indices[0]) |
|
num_pos_avg_per_gpu = max(total_num_pos, 1.0) |
|
loss = loss.sum() / num_pos_avg_per_gpu |
|
|
|
losses = {"loss_ce": loss} |
|
return losses |
|
|
|
def _get_src_permutation_idx(self, indices): |
|
|
|
batch_idx = torch.cat( |
|
[torch.full_like(src, i) for i, (src, _) in enumerate(indices)] |
|
) |
|
src_idx = torch.cat([src for (src, _) in indices]) |
|
return batch_idx, src_idx |
|
|
|
def _get_tgt_permutation_idx(self, indices): |
|
|
|
batch_idx = torch.cat( |
|
[torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] |
|
) |
|
tgt_idx = torch.cat([tgt for (_, tgt) in indices]) |
|
return batch_idx, tgt_idx |
|
|
|
def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): |
|
loss_map = { |
|
"labels": self.token_sigmoid_binary_focal_loss, |
|
"cardinality": self.loss_cardinality, |
|
"boxes": self.loss_boxes, |
|
} |
|
assert loss in loss_map, f"do you really want to compute {loss} loss?" |
|
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) |
|
|
|
def forward(self, outputs, targets, cat_list, caption, return_indices=False): |
|
"""This performs the loss computation. |
|
Parameters: |
|
outputs: dict of tensors, see the output specification of the model for the format |
|
targets: list of dicts, such that len(targets) == batch_size. |
|
                      The expected keys in each dict depend on the losses applied, see each loss' doc
|
|
|
return_indices: used for vis. if True, the layer0-5 indices will be returned as well. |
|
""" |
|
device = next(iter(outputs.values())).device |
|
one_hot = torch.zeros( |
|
outputs["pred_logits"].size(), dtype=torch.int64 |
|
) |
|
token = outputs["token"] |
|
|
|
label_map_list = [] |
|
indices = [] |
|
for j in range(len(cat_list)): |
|
label_map = [] |
|
for i in range(len(cat_list[j])): |
|
label_id = torch.tensor([i]) |
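                # 101, 102, 1012, 1029 are the BERT ids for [CLS], [SEP], "." and "?".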
|
per_label = create_positive_map_exemplar( |
|
token["input_ids"][j], label_id, [101, 102, 1012, 1029] |
|
) |
|
label_map.append(per_label) |
|
label_map = torch.stack(label_map, dim=0).squeeze(1) |
|
|
|
label_map_list.append(label_map) |
|
for j in range(len(cat_list)): |
|
for_match = { |
|
"pred_logits": outputs["pred_logits"][j].unsqueeze(0), |
|
"pred_boxes": outputs["pred_boxes"][j].unsqueeze(0), |
|
} |
|
|
|
inds = self.matcher(for_match, [targets[j]], label_map_list[j]) |
|
indices.extend(inds) |
|
|
|
|
|
|
|
|
|
|
|
tgt_ids = [v["labels"].cpu() for v in targets] |
|
|
|
for i in range(len(indices)): |
|
tgt_ids[i] = tgt_ids[i][indices[i][1]] |
|
one_hot[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to(torch.long) |
|
outputs["one_hot"] = one_hot |
|
if return_indices: |
|
indices0_copy = indices |
|
indices_list = [] |
|
|
|
|
|
num_boxes_list = [len(t["labels"]) for t in targets] |
|
num_boxes = sum(num_boxes_list) |
|
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=device) |
|
if is_dist_avail_and_initialized(): |
|
torch.distributed.all_reduce(num_boxes) |
|
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() |
|
|
|
|
|
losses = {} |
|
for loss in self.losses: |
|
losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) |
|
|
|
|
|
if "aux_outputs" in outputs: |
|
for idx, aux_outputs in enumerate(outputs["aux_outputs"]): |
|
indices = [] |
|
for j in range(len(cat_list)): |
|
aux_output_single = { |
|
"pred_logits": aux_outputs["pred_logits"][j].unsqueeze(0), |
|
"pred_boxes": aux_outputs["pred_boxes"][j].unsqueeze(0), |
|
} |
|
inds = self.matcher( |
|
aux_output_single, [targets[j]], label_map_list[j] |
|
) |
|
indices.extend(inds) |
|
one_hot_aux = torch.zeros( |
|
outputs["pred_logits"].size(), dtype=torch.int64 |
|
) |
|
tgt_ids = [v["labels"].cpu() for v in targets] |
|
for i in range(len(indices)): |
|
tgt_ids[i] = tgt_ids[i][indices[i][1]] |
|
one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to( |
|
torch.long |
|
) |
|
aux_outputs["one_hot"] = one_hot_aux |
|
aux_outputs["text_mask"] = outputs["text_mask"] |
|
if return_indices: |
|
indices_list.append(indices) |
|
for loss in self.losses: |
|
kwargs = {} |
|
l_dict = self.get_loss( |
|
loss, aux_outputs, targets, indices, num_boxes, **kwargs |
|
) |
|
l_dict = {k + f"_{idx}": v for k, v in l_dict.items()} |
|
losses.update(l_dict) |
|
|
|
|
|
if "interm_outputs" in outputs: |
|
interm_outputs = outputs["interm_outputs"] |
|
indices = [] |
|
for j in range(len(cat_list)): |
|
interm_output_single = { |
|
"pred_logits": interm_outputs["pred_logits"][j].unsqueeze(0), |
|
"pred_boxes": interm_outputs["pred_boxes"][j].unsqueeze(0), |
|
} |
|
inds = self.matcher( |
|
interm_output_single, [targets[j]], label_map_list[j] |
|
) |
|
indices.extend(inds) |
|
one_hot_aux = torch.zeros(outputs["pred_logits"].size(), dtype=torch.int64) |
|
tgt_ids = [v["labels"].cpu() for v in targets] |
|
for i in range(len(indices)): |
|
tgt_ids[i] = tgt_ids[i][indices[i][1]] |
|
one_hot_aux[i, indices[i][0]] = label_map_list[i][tgt_ids[i]].to( |
|
torch.long |
|
) |
|
interm_outputs["one_hot"] = one_hot_aux |
|
interm_outputs["text_mask"] = outputs["text_mask"] |
|
if return_indices: |
|
indices_list.append(indices) |
|
for loss in self.losses: |
|
kwargs = {} |
|
l_dict = self.get_loss( |
|
loss, interm_outputs, targets, indices, num_boxes, **kwargs |
|
) |
|
l_dict = {k + f"_interm": v for k, v in l_dict.items()} |
|
losses.update(l_dict) |
|
|
|
if return_indices: |
|
indices_list.append(indices0_copy) |
|
return losses, indices_list |
|
|
|
return losses |
|
|
|
|
|
class PostProcess(nn.Module): |
|
"""This module converts the model's output into the format expected by the coco api""" |
|
|
|
def __init__( |
|
self, |
|
num_select=100, |
|
text_encoder_type="text_encoder_type", |
|
nms_iou_threshold=-1, |
|
use_coco_eval=False, |
|
args=None, |
|
) -> None: |
|
super().__init__() |
|
self.num_select = num_select |
|
self.tokenizer = get_tokenlizer.get_tokenlizer(text_encoder_type) |
|
if args.use_coco_eval: |
|
from pycocotools.coco import COCO |
|
|
|
coco = COCO(args.coco_val_path) |
|
category_dict = coco.loadCats(coco.getCatIds()) |
|
cat_list = [item["name"] for item in category_dict] |
|
else: |
|
cat_list = args.label_list |
|
caption = " . ".join(cat_list) + " ." |
|
tokenized = self.tokenizer(caption, padding="longest", return_tensors="pt") |
|
label_list = torch.arange(len(cat_list)) |
|
pos_map = create_positive_map(tokenized, label_list, cat_list, caption) |
|
|
|
if args.use_coco_eval: |
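            # Remap contiguous 0-79 label indices to the non-contiguous COCO category ids (1-90).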
|
id_map = { |
|
0: 1, |
|
1: 2, |
|
2: 3, |
|
3: 4, |
|
4: 5, |
|
5: 6, |
|
6: 7, |
|
7: 8, |
|
8: 9, |
|
9: 10, |
|
10: 11, |
|
11: 13, |
|
12: 14, |
|
13: 15, |
|
14: 16, |
|
15: 17, |
|
16: 18, |
|
17: 19, |
|
18: 20, |
|
19: 21, |
|
20: 22, |
|
21: 23, |
|
22: 24, |
|
23: 25, |
|
24: 27, |
|
25: 28, |
|
26: 31, |
|
27: 32, |
|
28: 33, |
|
29: 34, |
|
30: 35, |
|
31: 36, |
|
32: 37, |
|
33: 38, |
|
34: 39, |
|
35: 40, |
|
36: 41, |
|
37: 42, |
|
38: 43, |
|
39: 44, |
|
40: 46, |
|
41: 47, |
|
42: 48, |
|
43: 49, |
|
44: 50, |
|
45: 51, |
|
46: 52, |
|
47: 53, |
|
48: 54, |
|
49: 55, |
|
50: 56, |
|
51: 57, |
|
52: 58, |
|
53: 59, |
|
54: 60, |
|
55: 61, |
|
56: 62, |
|
57: 63, |
|
58: 64, |
|
59: 65, |
|
60: 67, |
|
61: 70, |
|
62: 72, |
|
63: 73, |
|
64: 74, |
|
65: 75, |
|
66: 76, |
|
67: 77, |
|
68: 78, |
|
69: 79, |
|
70: 80, |
|
71: 81, |
|
72: 82, |
|
73: 84, |
|
74: 85, |
|
75: 86, |
|
76: 87, |
|
77: 88, |
|
78: 89, |
|
79: 90, |
|
} |
|
new_pos_map = torch.zeros((91, 256)) |
|
for k, v in id_map.items(): |
|
new_pos_map[v] = pos_map[k] |
|
pos_map = new_pos_map |
|
|
|
self.nms_iou_threshold = nms_iou_threshold |
|
self.positive_map = pos_map |
|
|
|
@torch.no_grad() |
|
def forward(self, outputs, target_sizes, not_to_xyxy=False, test=False): |
|
"""Perform the computation |
|
Parameters: |
|
outputs: raw outputs of the model |
|
target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch |
|
For evaluation, this must be the original image size (before any data augmentation) |
|
For visualization, this should be the image size after data augment, but before padding |
|
""" |
|
num_select = self.num_select |
|
out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] |
|
|
|
prob_to_token = out_logits.sigmoid() |
|
pos_maps = self.positive_map.to(prob_to_token.device) |
|
for label_ind in range(len(pos_maps)): |
|
if pos_maps[label_ind].sum() != 0: |
|
pos_maps[label_ind] = pos_maps[label_ind] / pos_maps[label_ind].sum() |
|
|
|
prob_to_label = prob_to_token @ pos_maps.T |
|
|
|
assert len(out_logits) == len(target_sizes) |
|
assert target_sizes.shape[1] == 2 |
|
|
|
prob = prob_to_label |
|
topk_values, topk_indexes = torch.topk( |
|
prob.view(prob.shape[0], -1), num_select, dim=1 |
|
) |
|
scores = topk_values |
|
topk_boxes = torch.div(topk_indexes, prob.shape[2], rounding_mode="trunc") |
|
labels = topk_indexes % prob.shape[2] |
|
if not_to_xyxy: |
|
boxes = out_bbox |
|
else: |
|
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) |
|
|
|
|
|
|
|
|
|
boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) |
|
|
|
|
|
img_h, img_w = target_sizes.unbind(1) |
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) |
|
boxes = boxes * scale_fct[:, None, :] |
|
|
|
if self.nms_iou_threshold > 0: |
|
item_indices = [ |
|
nms(b, s, iou_threshold=self.nms_iou_threshold) |
|
for b, s in zip(boxes, scores) |
|
] |
|
|
|
results = [ |
|
{"scores": s[i], "labels": l[i], "boxes": b[i]} |
|
for s, l, b, i in zip(scores, labels, boxes, item_indices) |
|
] |
|
else: |
|
results = [ |
|
{"scores": s, "labels": l, "boxes": b} |
|
for s, l, b in zip(scores, labels, boxes) |
|
] |
|
|
return results |
|
|
|
|
|
@MODULE_BUILD_FUNCS.registe_with_name(module_name="groundingdino") |
|
def build_groundingdino(args): |
|
device = torch.device(args.device) |
|
backbone = build_backbone(args) |
|
transformer = build_transformer(args) |
|
|
|
dn_labelbook_size = args.dn_labelbook_size |
|
dec_pred_bbox_embed_share = args.dec_pred_bbox_embed_share |
|
sub_sentence_present = args.sub_sentence_present |
|
|
|
model = GroundingDINO( |
|
backbone, |
|
transformer, |
|
num_queries=args.num_queries, |
|
aux_loss=args.aux_loss, |
|
iter_update=True, |
|
query_dim=4, |
|
num_feature_levels=args.num_feature_levels, |
|
nheads=args.nheads, |
|
dec_pred_bbox_embed_share=dec_pred_bbox_embed_share, |
|
two_stage_type=args.two_stage_type, |
|
two_stage_bbox_embed_share=args.two_stage_bbox_embed_share, |
|
two_stage_class_embed_share=args.two_stage_class_embed_share, |
|
num_patterns=args.num_patterns, |
|
dn_number=0, |
|
dn_box_noise_scale=args.dn_box_noise_scale, |
|
dn_label_noise_ratio=args.dn_label_noise_ratio, |
|
dn_labelbook_size=dn_labelbook_size, |
|
text_encoder_type=args.text_encoder_type, |
|
sub_sentence_present=sub_sentence_present, |
|
max_text_len=args.max_text_len, |
|
) |
|
|
|
matcher = build_matcher(args) |
|
|
|
|
|
weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef} |
|
weight_dict["loss_giou"] = args.giou_loss_coef |
|
clean_weight_dict_wo_dn = copy.deepcopy(weight_dict) |
|
|
|
clean_weight_dict = copy.deepcopy(weight_dict) |
|
|
|
|
|
if args.aux_loss: |
|
aux_weight_dict = {} |
|
for i in range(args.dec_layers - 1): |
|
aux_weight_dict.update( |
|
{k + f"_{i}": v for k, v in clean_weight_dict.items()} |
|
) |
|
weight_dict.update(aux_weight_dict) |
|
|
|
if args.two_stage_type != "no": |
|
interm_weight_dict = {} |
|
        no_interm_box_loss = getattr(args, "no_interm_box_loss", False)
|
_coeff_weight_dict = { |
|
"loss_ce": 1.0, |
|
"loss_bbox": 1.0 if not no_interm_box_loss else 0.0, |
|
"loss_giou": 1.0 if not no_interm_box_loss else 0.0, |
|
} |
|
        interm_loss_coef = getattr(args, "interm_loss_coef", 1.0)
|
interm_weight_dict.update( |
|
{ |
|
k + f"_interm": v * interm_loss_coef * _coeff_weight_dict[k] |
|
for k, v in clean_weight_dict_wo_dn.items() |
|
} |
|
) |
|
weight_dict.update(interm_weight_dict) |
|
|
|
|
|
losses = ["labels", "boxes"] |
|
|
|
criterion = SetCriterion( |
|
matcher=matcher, |
|
weight_dict=weight_dict, |
|
focal_alpha=args.focal_alpha, |
|
focal_gamma=args.focal_gamma, |
|
losses=losses, |
|
) |
|
criterion.to(device) |
|
postprocessors = { |
|
"bbox": PostProcess( |
|
num_select=args.num_select, |
|
text_encoder_type=args.text_encoder_type, |
|
nms_iou_threshold=args.nms_iou_threshold, |
|
args=args, |
|
) |
|
} |
|
|
|
return model, criterion, postprocessors |
|
|
|
|
|
def create_positive_map(tokenized, tokens_positive, cat_list, caption): |
|
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j""" |
|
positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) |
|
|
|
for j, label in enumerate(tokens_positive): |
|
start_ind = caption.find(cat_list[label]) |
|
end_ind = start_ind + len(cat_list[label]) - 1 |
|
beg_pos = tokenized.char_to_token(start_ind) |
|
        try:
            end_pos = tokenized.char_to_token(end_ind)
        except Exception:
            end_pos = None
        if end_pos is None:
            try:
                end_pos = tokenized.char_to_token(end_ind - 1)
                if end_pos is None:
                    end_pos = tokenized.char_to_token(end_ind - 2)
            except Exception:
                end_pos = None
|
if beg_pos is None or end_pos is None: |
|
continue |
|
if beg_pos < 0 or end_pos < 0: |
|
continue |
|
if beg_pos > end_pos: |
|
continue |
|
|
|
positive_map[j, beg_pos : end_pos + 1].fill_(1) |
|
return positive_map |
|
|
|
|
|
def create_positive_map_exemplar(input_ids, label, special_tokens): |
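    """Return a 256-dim 0/1 vector marking the token positions of the `label`-th phrase in
    `input_ids`, where phrases are delimited by `special_tokens`."""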
|
tokens_positive = torch.zeros(256, dtype=torch.float) |
|
count = -1 |
|
for token_ind in range(len(input_ids)): |
|
input_id = input_ids[token_ind] |
|
if (input_id not in special_tokens) and ( |
|
token_ind == 0 or (input_ids[token_ind - 1] in special_tokens) |
|
): |
|
count += 1 |
|
if count == label: |
|
ind_to_insert_ones = token_ind |
|
|
|
while input_ids[ind_to_insert_ones] not in special_tokens: |
|
tokens_positive[ind_to_insert_ones] = 1 |
|
ind_to_insert_ones += 1 |
|
break |
|
return tokens_positive |
|
|