import copy
import re
import warnings
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.detectors.dino import DINO
from mmdet.models.detectors.glip import (create_positive_map,
                                         create_positive_map_label_to_token)
from mmdet.models.layers import SinePositionalEncoding
from mmdet.models.layers.transformer.grounding_dino_layers import (
    GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder)
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType
from mmengine.runner.amp import autocast
from torch import Tensor

try:
    import os

    import nltk

    download_dir = os.path.expanduser("~/nltk_data")
    nltk.download("punkt", download_dir=download_dir, quiet=True)
    nltk.download(
        "averaged_perceptron_tagger", download_dir=download_dir, quiet=True
    )
except ImportError:
    raise RuntimeError(
        "nltk is not installed, please install it by: pip install nltk."
    )


def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == "NP":
            noun_phrases.append(" ".join(t[0] for t in subtree.leaves()))

    return noun_phrases


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        "|", ":", ";", "@", "(", ")", "[", "]", "{", "}", "^", "'", '"', "’",
        "`", "?", "$", "%", "#", "!", "&", "*", "+", ",", ".",
    ]
    for p in punctuation:
        text = text.replace(p, "")
    return text.strip()


def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[list, list]: A tuple containing the tokens and noun phrases.
            - tokens_positive (list): A list of character spans, one sublist
              per entity occurrence in the caption.
            - noun_phrases (list): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != ""]

    relevant_phrases = noun_phrases
    labels = noun_phrases

    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # Search all occurrences and mark them as different entities.
            # TODO: Not robust; `entity` is used as a regex pattern, so
            # phrases containing regex metacharacters may mismatch.
            for m in re.finditer(entity, caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print("noun entities:", noun_phrases)
            print("entity:", entity)
            print("caption:", caption.lower())

    return tokens_positive, noun_phrases


def clean_label_name(name: str) -> str:
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    name = re.sub(r"  ", " ", name)  # collapse double spaces
    return name


def chunks(lst: list, n: int) -> list:
    """Split ``lst`` into successive n-sized chunks and return them as a
    list."""
    all_ = []
    for i in range(0, len(lst), n):
        data_index = lst[i:i + n]
        all_.append(data_index)
    counter = 0
    for i in all_:
        counter += len(i)
    assert counter == len(lst)

    return all_


@MODELS.register_module(force=True)
class GroundingDINO(DINO):
    """Implementation of `Grounding DINO: Marrying DINO with Grounded Pre-
    Training for Open-Set Object Detection.
    <https://arxiv.org/abs/2303.05499>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/GroundingDINO>`_.
    """

    def __init__(self, language_model, *args, use_autocast=False, **kwargs) -> None:
        self.language_model_cfg = language_model
        self._special_tokens = ". "
        self.use_autocast = use_autocast
        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(**self.positional_encoding)
        self.encoder = GroundingDinoTransformerEncoder(**self.encoder)
        self.decoder = GroundingDinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, (
            f"embed_dims should be exactly 2 times of num_feats. "
            f"Found {self.embed_dims} and {num_feats}."
        )

        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims)
        )
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

        # text modules
        self.language_model = MODELS.build(self.language_model_cfg)
        self.text_feat_map = nn.Linear(
            self.language_model.language_backbone.body.language_dim,
            self.embed_dims,
            bias=True,
        )

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        nn.init.constant_(self.text_feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.text_feat_map.weight.data)

    def to_enhance_text_prompts(self, original_caption, enhanced_text_prompts):
        caption_string = ""
        tokens_positive = []
        for idx, word in enumerate(original_caption):
            if word in enhanced_text_prompts:
                enhanced_text_dict = enhanced_text_prompts[word]
                if "prefix" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["prefix"]
                start_i = len(caption_string)
                if "name" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["name"]
                else:
                    caption_string += word
                end_i = len(caption_string)
                tokens_positive.append([[start_i, end_i]])

                if "suffix" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["suffix"]
            else:
                tokens_positive.append(
                    [[len(caption_string), len(caption_string) + len(word)]]
                )
                caption_string += word
            caption_string += self._special_tokens
        return caption_string, tokens_positive

    def to_plain_text_prompts(self, original_caption):
        caption_string = ""
        tokens_positive = []
        for idx, word in enumerate(original_caption):
            tokens_positive.append(
                [[len(caption_string), len(caption_string) + len(word)]]
            )
            caption_string += word
            caption_string += self._special_tokens
        return caption_string, tokens_positive
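
    # For reference, ``to_plain_text_prompts(['person', 'dog'])`` builds the
    # caption 'person. dog. ' and returns tokens_positive
    # [[[0, 6]], [[8, 11]]], i.e. one [start, end) character span per class
    # name inside the caption; the trailing '. ' comes from
    # ``self._special_tokens``. This worked example is derived by hand from
    # the method above, not taken from an upstream test.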

    def get_tokens_and_prompts(
        self,
        original_caption: Union[str, list, tuple],
        custom_entities: bool = False,
        enhanced_text_prompts: Optional[ConfigType] = None,
    ) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption."""
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            if custom_entities and isinstance(original_caption, str):
                original_caption = original_caption.strip(self._special_tokens)
                original_caption = original_caption.split(self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption)
                )

            original_caption = [clean_label_name(i) for i in original_caption]

            if custom_entities and enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption, enhanced_text_prompts
                )
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption
                )

            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [caption_string],
                padding="max_length" if self.language_model.pad_to_max else "longest",
                return_tensors="pt",
            )
            entities = original_caption
        else:
            if not original_caption.endswith("."):
                original_caption = original_caption + self._special_tokens
            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [original_caption],
                padding="max_length" if self.language_model.pad_to_max else "longest",
                return_tensors="pt",
            )
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities

    def get_positive_map(self, tokenized, tokens_positive):
        positive_map = create_positive_map(
            tokenized,
            tokens_positive,
            max_num_entities=self.bbox_head.cls_branches[
                self.decoder.num_layers
            ].max_text_len,
        )
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1
        )
        return positive_map_label_to_token, positive_map

    def get_tokens_positive_and_prompts(
        self,
        original_caption: Union[str, list, tuple],
        custom_entities: bool = False,
        enhanced_text_prompt: Optional[ConfigType] = None,
        tokens_positive: Optional[list] = None,
    ) -> Tuple[dict, str, Tensor, list]:
        """Get the tokens positive and prompts for the caption.

        Args:
            original_caption (str): The original caption, e.g. 'bench . car .'
            custom_entities (bool, optional): Whether to use custom entities.
                If ``True``, the ``original_caption`` should be a list of
                strings, each of which is a word. Defaults to False.

        Returns:
            Tuple[dict, str, Tensor, list]: The dict maps each entity id,
            numbered from 1, to its positive token ids. The str is the
            prompt actually fed to the language model. The Tensor is the
            positive map, and the list contains the entity names.
        """
        if tokens_positive is not None:
            if tokens_positive == -1:
                if not original_caption.endswith("."):
                    original_caption = original_caption + self._special_tokens
                return None, original_caption, None, original_caption
            else:
                if not original_caption.endswith("."):
                    original_caption = original_caption + self._special_tokens
                tokenized = self.language_model.tokenizer(
                    [original_caption],
                    padding="max_length"
                    if self.language_model.pad_to_max
                    else "longest",
                    return_tensors="pt",
                )
                positive_map_label_to_token, positive_map = self.get_positive_map(
                    tokenized, tokens_positive
                )

                entities = []
                for token_positive in tokens_positive:
                    instance_entities = []
                    for t in token_positive:
                        instance_entities.append(original_caption[t[0]:t[1]])
                    entities.append(" / ".join(instance_entities))
                return (
                    positive_map_label_to_token,
                    original_caption,
                    positive_map,
                    entities,
                )

        chunked_size = self.test_cfg.get("chunked_size", -1)
        if not self.training and chunked_size > 0:
            assert (
                isinstance(original_caption, (list, tuple)) or custom_entities is True
            )
            all_output = self.get_tokens_positive_and_prompts_chunked(
                original_caption, enhanced_text_prompt
            )
            (
                positive_map_label_to_token,
                caption_string,
                positive_map,
                entities,
            ) = all_output
        else:
            (
                tokenized,
                caption_string,
                tokens_positive,
                entities,
            ) = self.get_tokens_and_prompts(
                original_caption, custom_entities, enhanced_text_prompt
            )
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive
            )
        return positive_map_label_to_token, caption_string, positive_map, entities

    def get_tokens_positive_and_prompts_chunked(
        self,
        original_caption: Union[list, tuple],
        enhanced_text_prompts: Optional[ConfigType] = None,
    ):
        chunked_size = self.test_cfg.get("chunked_size", -1)
        original_caption = [clean_label_name(i) for i in original_caption]

        original_caption_chunked = chunks(original_caption, chunked_size)
        ids_chunked = chunks(
            list(range(1, len(original_caption) + 1)), chunked_size
        )

        positive_map_label_to_token_chunked = []
        caption_string_chunked = []
        positive_map_chunked = []
        entities_chunked = []

        for i in range(len(ids_chunked)):
            if enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption_chunked[i], enhanced_text_prompts
                )
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption_chunked[i]
                )
            tokenized = self.language_model.tokenizer(
                [caption_string], return_tensors="pt"
            )
            if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                warnings.warn(
                    "Inputting a text that is too long will result "
                    "in poor prediction performance. "
                    "Please reduce the --chunked-size."
                )
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive
            )

            caption_string_chunked.append(caption_string)
            positive_map_label_to_token_chunked.append(positive_map_label_to_token)
            positive_map_chunked.append(positive_map)
            entities_chunked.append(original_caption_chunked[i])

        return (
            positive_map_label_to_token_chunked,
            caption_string_chunked,
            positive_map_chunked,
            entities_chunked,
        )
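
    # Chunked evaluation in a nutshell: with e.g. 80 class names and
    # ``test_cfg.chunked_size=40``, the caption is split into two prompts of
    # 40 names each, the model runs once per chunk, and ``predict`` offsets
    # the predicted labels by the running entity count before concatenating
    # the per-chunk results. The numbers here are illustrative; only
    # ``chunked_size`` itself comes from ``test_cfg``.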

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        text_dict: Dict,
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples
        )

        encoder_outputs_dict = self.forward_encoder(
            **encoder_inputs_dict, text_dict=text_dict
        )

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples
        )
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def forward_encoder(
        self,
        feat: Tensor,
        feat_mask: Tensor,
        feat_pos: Tensor,
        spatial_shapes: Tensor,
        level_start_index: Tensor,
        valid_ratios: Tensor,
        text_dict: Dict,
    ) -> Dict:
        text_token_mask = text_dict["text_token_mask"]
        memory, memory_text = self.encoder(
            query=feat,
            query_pos=feat_pos,
            key_padding_mask=feat_mask,  # for self_attn
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            # for text encoder
            memory_text=text_dict["embedded"],
            text_attention_mask=~text_token_mask,
            position_ids=text_dict["position_ids"],
            text_self_attention_masks=text_dict["masks"],
        )
        encoder_outputs_dict = dict(
            memory=memory,
            memory_mask=feat_mask,
            spatial_shapes=spatial_shapes,
            memory_text=memory_text,
            text_token_mask=text_token_mask,
        )
        return encoder_outputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        memory_text: Tensor,
        text_token_mask: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        bs, _, c = memory.shape

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes
        )

        enc_outputs_class = self.bbox_head.cls_branches[self.decoder.num_layers](
            output_memory, memory_text, text_token_mask
        )
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers
        ].max_text_len
        enc_outputs_coord_unact = (
            self.bbox_head.reg_branches[self.decoder.num_layers](output_memory)
            + output_proposals
        )

        # NOTE: DINO selects the top-k proposals according to the scores of
        # multi-class classification, while DeformDETR, whose input is
        # `enc_outputs_class[..., 0]`, selects according to the scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1
        )[1]

        topk_score = torch.gather(
            enc_outputs_class,
            1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features),
        )
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)
        )
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = self.dn_query_generator(
                batch_data_samples
            )
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact], dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask,
            memory_text=memory_text,
            text_attention_mask=~text_token_mask,
        )
        # NOTE: DINO calculates encoder losses on the scores and coordinates
        # of the selected top-k encoder queries, while DeformDETR uses all
        # encoder queries.
        head_inputs_dict = (
            dict(
                enc_outputs_class=topk_score,
                enc_outputs_coord=topk_coords,
                dn_meta=dn_meta,
            )
            if self.training
            else dict()
        )
        # append text_feats to head_inputs_dict
        head_inputs_dict["memory_text"] = memory_text
        head_inputs_dict["text_token_mask"] = text_token_mask
        return decoder_inputs_dict, head_inputs_dict
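
    # Shape sketch for the top-k selection in ``pre_decoder`` (the shapes
    # are inferred from the code above, not asserted upstream):
    # ``enc_outputs_class`` is (bs, num_proposals, max_text_len); taking
    # ``.max(-1)[0]`` scores each proposal by its best-matching text token,
    # ``torch.topk`` keeps ``num_queries`` proposals per image, and
    # ``torch.gather`` pulls the matching rows out of the class logits
    # (last dim max_text_len) and the unactivated box coordinates
    # (last dim 4).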

    def loss(
        self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Union[dict, list]:
        text_prompts = [data_samples.text for data_samples in batch_data_samples]

        gt_labels = [
            data_samples.gt_instances.labels for data_samples in batch_data_samples
        ]

        if "tokens_positive" in batch_data_samples[0]:
            tokens_positive = [
                data_samples.tokens_positive for data_samples in batch_data_samples
            ]
            positive_maps = []
            for token_positive, text_prompt, gt_label in zip(
                tokens_positive, text_prompts, gt_labels
            ):
                tokenized = self.language_model.tokenizer(
                    [text_prompt],
                    padding="max_length"
                    if self.language_model.pad_to_max
                    else "longest",
                    return_tensors="pt",
                )
                new_tokens_positive = [
                    token_positive[label.item()] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive
                )
                positive_maps.append(positive_map)
            new_text_prompts = text_prompts
        else:
            new_text_prompts = []
            positive_maps = []
            if len(set(text_prompts)) == 1:
                # All the text prompts are the same,
                # so there is no need to calculate them multiple times.
                (
                    tokenized,
                    caption_string,
                    tokens_positive,
                    _,
                ) = self.get_tokens_and_prompts(text_prompts[0], True)
                new_text_prompts = [caption_string] * len(batch_inputs)
                for gt_label in gt_labels:
                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive
                    )
                    positive_maps.append(positive_map)
            else:
                for text_prompt, gt_label in zip(text_prompts, gt_labels):
                    (
                        tokenized,
                        caption_string,
                        tokens_positive,
                        _,
                    ) = self.get_tokens_and_prompts(text_prompt, True)
                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive
                    )
                    positive_maps.append(positive_map)
                    new_text_prompts.append(caption_string)

        text_dict = self.language_model(new_text_prompts)
        if self.text_feat_map is not None:
            text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])

        for i, data_samples in enumerate(batch_data_samples):
            positive_map = positive_maps[i].to(batch_inputs.device).bool().float()
            text_token_mask = text_dict["text_token_mask"][i]
            data_samples.gt_instances.positive_maps = positive_map
            data_samples.gt_instances.text_token_mask = text_token_mask.unsqueeze(
                0
            ).repeat(len(positive_map), 1)

        if self.use_autocast:
            with autocast(enabled=True):
                visual_features = self.extract_feat(batch_inputs)
        else:
            visual_features = self.extract_feat(batch_inputs)

        head_inputs_dict = self.forward_transformer(
            visual_features, text_dict, batch_data_samples
        )

        losses = self.bbox_head.loss(
            **head_inputs_dict, batch_data_samples=batch_data_samples
        )
        return losses
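
    # Worked example for the positive maps built in ``loss`` (hand-derived
    # from ``to_plain_text_prompts`` and ``get_positive_map``, not an
    # upstream test): with classes ['cat', 'dog'] the prompt is
    # 'cat. dog. ' and tokens_positive is [[[0, 3]], [[5, 8]]]; a ground
    # truth label 1 then selects the span [[5, 8]], so its row of the
    # positive map puts all of its mass on the token(s) covering 'dog'.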

    def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
        text_prompts = []
        enhanced_text_prompts = []
        tokens_positives = []
        for data_samples in batch_data_samples:
            text_prompts.append(data_samples.text)
            if "caption_prompt" in data_samples:
                enhanced_text_prompts.append(data_samples.caption_prompt)
            else:
                enhanced_text_prompts.append(None)
            tokens_positives.append(data_samples.get("tokens_positive", None))

        if "custom_entities" in batch_data_samples[0]:
            # Assume that the `custom_entities` flag inside a batch is
            # always the same, which holds for single-image inference.
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False
        if len(text_prompts) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(
                    text_prompts[0],
                    custom_entities,
                    enhanced_text_prompts[0],
                    tokens_positives[0],
                )
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(
                    text_prompt, custom_entities, enhanced_text_prompt,
                    tokens_positive
                )
                for text_prompt, enhanced_text_prompt, tokens_positive in zip(
                    text_prompts, enhanced_text_prompts, tokens_positives
                )
            ]
        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts
        )
        # image feature extraction
        visual_feats = self.extract_feat(batch_inputs)

        if isinstance(text_prompts[0], list):
            # chunked text prompts, only bs=1 is supported
            assert len(batch_inputs) == 1
            count = 0
            results_list = []

            entities = [[item for lst in entities[0] for item in lst]]

            for b in range(len(text_prompts[0])):
                text_prompts_once = [text_prompts[0][b]]
                token_positive_maps_once = token_positive_maps[0][b]
                text_dict = self.language_model(text_prompts_once)
                # text feature map layer
                if self.text_feat_map is not None:
                    text_dict["embedded"] = self.text_feat_map(
                        text_dict["embedded"]
                    )

                batch_data_samples[0].token_positive_map = token_positive_maps_once

                head_inputs_dict = self.forward_transformer(
                    copy.deepcopy(visual_feats), text_dict, batch_data_samples
                )
                pred_instances = self.bbox_head.predict(
                    **head_inputs_dict,
                    rescale=rescale,
                    batch_data_samples=batch_data_samples,
                )[0]

                if len(pred_instances) > 0:
                    pred_instances.labels += count
                count += len(token_positive_maps_once)
                results_list.append(pred_instances)
            results_list = [results_list[0].cat(results_list)]
            is_rec_tasks = [False] * len(results_list)
        else:
            # extract text feats
            text_dict = self.language_model(list(text_prompts))
            # text feature map layer
            if self.text_feat_map is not None:
                text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])

            is_rec_tasks = []
            for i, data_samples in enumerate(batch_data_samples):
                if token_positive_maps[i] is not None:
                    is_rec_tasks.append(False)
                else:
                    is_rec_tasks.append(True)
                data_samples.token_positive_map = token_positive_maps[i]

            head_inputs_dict = self.forward_transformer(
                visual_feats, text_dict, batch_data_samples
            )
            results_list = self.bbox_head.predict(
                **head_inputs_dict,
                rescale=rescale,
                batch_data_samples=batch_data_samples,
            )

        for data_sample, pred_instances, entity, is_rec_task in zip(
            batch_data_samples, results_list, entities, is_rec_tasks
        ):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if is_rec_task:
                        label_names.append(entity)
                        continue
                    if labels >= len(entity):
                        warnings.warn(
                            "The unexpected output indicates an issue with "
                            "named entity recognition. You can try setting "
                            "custom_entities=True and running again to see "
                            "if it helps."
                        )
                        label_names.append("unobject")
                    else:
                        label_names.append(entity[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
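

if __name__ == "__main__":
    # Minimal smoke test of the pure-Python helpers above; it needs only
    # nltk, not a trained model. The caption is illustrative, and the exact
    # noun phrases depend on the installed nltk tagger version.
    demo_caption = "There is two cat and a remote in the picture"
    print(find_noun_phrases(demo_caption))
    # e.g. ['cat', 'a remote', 'the picture']
    print(run_ner(demo_caption))
    print(chunks(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]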