import copy
import re
import warnings
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.detectors.dino import DINO
from mmdet.models.detectors.glip import (create_positive_map,
                                         create_positive_map_label_to_token)
from mmdet.models.layers import SinePositionalEncoding
from mmdet.models.layers.transformer.grounding_dino_layers import (
    GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder)
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType
from mmengine.runner.amp import autocast
from torch import Tensor

try:
    import os

    import nltk

    download_dir = os.path.expanduser("~/nltk_data")
    nltk.download("punkt", download_dir=download_dir, quiet=True)
    nltk.download(
        "averaged_perceptron_tagger", download_dir=download_dir, quiet=True)
except ImportError:
    raise RuntimeError(
        "nltk is not installed, please install it by: pip install nltk.")


def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = "NP: {<DT>?<JJ.*>*<NN.*>+}"
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == "NP":
            noun_phrases.append(" ".join(t[0] for t in subtree.leaves()))

    return noun_phrases


def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed.
    """
    punctuation = [
        "|", ":", ";", "@", "(", ")", "[", "]", "{", "}", "^", "'", '"',
        "’", "`", "?", "$", "%", "#", "!", "&", "*", "+", ",", "."
    ]
    for p in punctuation:
        text = text.replace(p, "")
    return text.strip()
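
# Illustrative doctest-style sketch of the expected behavior:
#   >>> remove_punctuation('a cat!')
#   'a cat'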


def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[list, list]: A tuple containing the tokens and noun phrases.
            - tokens_positive (list): A list of token positions.
            - noun_phrases (list): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != ""]
    relevant_phrases = noun_phrases
    labels = noun_phrases

    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # Search all occurrences and mark them as different entities.
            # ``re.escape`` keeps regex metacharacters in the phrase from
            # being interpreted as a pattern.
            for m in re.finditer(re.escape(entity), caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print("noun entities:", noun_phrases)
            print("entity:", entity)
            print("caption:", caption.lower())

    return tokens_positive, noun_phrases
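
# Illustrative sketch (exact spans depend on the installed NLTK models):
#   >>> run_ner('There is a cat.')
#   ([[[9, 14]]], ['a cat'])  # 'a cat' spans characters 9-14 of the caption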


def clean_label_name(name: str) -> str:
    """Normalize a label name: drop parenthesised parts, replace underscores
    with spaces and collapse double spaces."""
    name = re.sub(r"\(.*\)", "", name)
    name = re.sub(r"_", " ", name)
    name = re.sub(r"  ", " ", name)
    return name
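
# Illustrative sketch: parenthesised qualifiers are dropped and underscores
# become spaces:
#   >>> clean_label_name('traffic_light')
#   'traffic light'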


def chunks(lst: list, n: int) -> list:
    """Split ``lst`` into successive n-sized chunks and return them as a list
    (the last chunk may be shorter)."""
    all_ = []
    for i in range(0, len(lst), n):
        data_index = lst[i:i + n]
        all_.append(data_index)
    counter = 0
    for i in all_:
        counter += len(i)
    assert counter == len(lst)
    return all_
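
# Illustrative sketch:
#   >>> chunks([1, 2, 3, 4, 5], 2)
#   [[1, 2], [3, 4], [5]]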


class GroundingDINO(DINO):
    """Implementation of `Grounding DINO: Marrying DINO with Grounded
    Pre-Training for Open-Set Object Detection
    <https://arxiv.org/abs/2303.05499>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/GroundingDINO>`_.
    """

    def __init__(self,
                 language_model,
                 *args,
                 use_autocast=False,
                 **kwargs) -> None:
        self.language_model_cfg = language_model
        self._special_tokens = ". "
        self.use_autocast = use_autocast
        super().__init__(*args, **kwargs)

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = GroundingDinoTransformerEncoder(**self.encoder)
        self.decoder = GroundingDinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, (
            f"embed_dims should be exactly 2 times of num_feats. "
            f"Found {self.embed_dims} and {num_feats}.")
        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

        # text modules
        self.language_model = MODELS.build(self.language_model_cfg)
        self.text_feat_map = nn.Linear(
            self.language_model.language_backbone.body.language_dim,
            self.embed_dims,
            bias=True)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        super().init_weights()
        nn.init.constant_(self.text_feat_map.bias.data, 0)
        nn.init.xavier_uniform_(self.text_feat_map.weight.data)

    def to_enhance_text_prompts(self, original_caption,
                                enhanced_text_prompts):
        """Compose the caption string and per-entity positive token spans,
        optionally wrapping words with user-provided prefix/name/suffix."""
        caption_string = ""
        tokens_positive = []
        for idx, word in enumerate(original_caption):
            if word in enhanced_text_prompts:
                enhanced_text_dict = enhanced_text_prompts[word]
                if "prefix" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["prefix"]
                start_i = len(caption_string)
                if "name" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["name"]
                else:
                    caption_string += word
                end_i = len(caption_string)
                tokens_positive.append([[start_i, end_i]])
                if "suffix" in enhanced_text_dict:
                    caption_string += enhanced_text_dict["suffix"]
            else:
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word
            caption_string += self._special_tokens
        return caption_string, tokens_positive
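
    # Illustrative sketch with a hypothetical prompt dict: enhancing 'cat'
    # with a prefix and a suffix yields
    #   >>> self.to_enhance_text_prompts(
    #   ...     ['cat'], {'cat': {'prefix': 'a ', 'suffix': ' lying'}})
    #   ('a cat lying. ', [[[2, 5]]])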

    def to_plain_text_prompts(self, original_caption):
        """Join entity words into one caption string separated by
        ``self._special_tokens`` and record each word's character span."""
        caption_string = ""
        tokens_positive = []
        for idx, word in enumerate(original_caption):
            tokens_positive.append(
                [[len(caption_string),
                  len(caption_string) + len(word)]])
            caption_string += word
            caption_string += self._special_tokens
        return caption_string, tokens_positive
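
    # Illustrative sketch:
    #   >>> self.to_plain_text_prompts(['cat', 'dog'])
    #   ('cat. dog. ', [[[0, 3]], [[5, 8]]])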

    def get_tokens_and_prompts(
        self,
        original_caption: Union[str, list, tuple],
        custom_entities: bool = False,
        enhanced_text_prompts: Optional[ConfigType] = None
    ) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption."""
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            if custom_entities and isinstance(original_caption, str):
                original_caption = original_caption.strip(self._special_tokens)
                original_caption = original_caption.split(self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            original_caption = [clean_label_name(i) for i in original_caption]

            if custom_entities and enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption, enhanced_text_prompts)
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption)

            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [caption_string],
                padding="max_length"
                if self.language_model.pad_to_max else "longest",
                return_tensors="pt")
            entities = original_caption
        else:
            if not original_caption.endswith("."):
                original_caption = original_caption + self._special_tokens
            # NOTE: Tokenizer in Grounding DINO is different from
            # that in GLIP. The tokenizer in GLIP will pad the
            # caption_string to max_length, while the tokenizer
            # in Grounding DINO will not.
            tokenized = self.language_model.tokenizer(
                [original_caption],
                padding="max_length"
                if self.language_model.pad_to_max else "longest",
                return_tensors="pt")
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities
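
    # Note: the caption may arrive either as pre-split entities (e.g.
    # ['bench', 'car'], or 'bench . car .' with custom_entities=True) or as
    # free-form text (e.g. 'there is a bench'), in which case the entities
    # are mined with ``run_ner`` above.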

    def get_positive_map(self, tokenized, tokens_positive):
        """Build the (num_entities, max_text_len) map from entities to the
        text-token positions they cover, plus its label-to-token dict."""
        positive_map = create_positive_map(
            tokenized,
            tokens_positive,
            max_num_entities=self.bbox_head.cls_branches[
                self.decoder.num_layers].max_text_len)
        positive_map_label_to_token = create_positive_map_label_to_token(
            positive_map, plus=1)
        return positive_map_label_to_token, positive_map
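
    # Each row of ``positive_map`` is a mask over text tokens for one entity,
    # and ``plus=1`` shifts the dict keys so entity ids are numbered from 1
    # (matching the docstring of ``get_tokens_positive_and_prompts`` below).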

    def get_tokens_positive_and_prompts(
        self,
        original_caption: Union[str, list, tuple],
        custom_entities: bool = False,
        enhanced_text_prompt: Optional[ConfigType] = None,
        tokens_positive: Optional[list] = None
    ) -> Tuple[dict, str, Tensor, list]:
        """Get the tokens positive and prompts for the caption.

        Args:
            original_caption (str): The original caption, e.g. 'bench . car .'
            custom_entities (bool, optional): Whether to use custom entities.
                If ``True``, the ``original_caption`` should be a list of
                strings, each of which is a word. Defaults to False.

        Returns:
            Tuple[dict, str, Tensor, list]: The dict maps each entity id,
            which is numbered from 1, to its positive token ids, the str is
            the final prompt string, the Tensor is the positive map, and the
            list contains the entity names.
        """
        if tokens_positive is not None:
            if tokens_positive == -1:
                if not original_caption.endswith("."):
                    original_caption = original_caption + self._special_tokens
                return None, original_caption, None, original_caption
            else:
                if not original_caption.endswith("."):
                    original_caption = original_caption + self._special_tokens
                tokenized = self.language_model.tokenizer(
                    [original_caption],
                    padding="max_length"
                    if self.language_model.pad_to_max else "longest",
                    return_tensors="pt")
                positive_map_label_to_token, positive_map = \
                    self.get_positive_map(tokenized, tokens_positive)

                entities = []
                for token_positive in tokens_positive:
                    instance_entities = []
                    for t in token_positive:
                        instance_entities.append(original_caption[t[0]:t[1]])
                    entities.append(" / ".join(instance_entities))
                return positive_map_label_to_token, original_caption, \
                    positive_map, entities

        chunked_size = self.test_cfg.get("chunked_size", -1)
        if not self.training and chunked_size > 0:
            assert isinstance(original_caption,
                              (list, tuple)) or custom_entities is True
            all_output = self.get_tokens_positive_and_prompts_chunked(
                original_caption, enhanced_text_prompt)
            positive_map_label_to_token, caption_string, positive_map, \
                entities = all_output
        else:
            tokenized, caption_string, tokens_positive, entities = \
                self.get_tokens_and_prompts(original_caption, custom_entities,
                                            enhanced_text_prompt)
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)
        return positive_map_label_to_token, caption_string, positive_map, \
            entities

    def get_tokens_positive_and_prompts_chunked(
            self,
            original_caption: Union[list, tuple],
            enhanced_text_prompts: Optional[ConfigType] = None):
        """Chunked version of ``get_tokens_positive_and_prompts`` for very
        long class lists at test time; returns per-chunk outputs."""
        chunked_size = self.test_cfg.get("chunked_size", -1)
        original_caption = [clean_label_name(i) for i in original_caption]

        original_caption_chunked = chunks(original_caption, chunked_size)
        ids_chunked = chunks(
            list(range(1, len(original_caption) + 1)), chunked_size)

        positive_map_label_to_token_chunked = []
        caption_string_chunked = []
        positive_map_chunked = []
        entities_chunked = []

        for i in range(len(ids_chunked)):
            if enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption_chunked[i], enhanced_text_prompts)
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption_chunked[i])
            tokenized = self.language_model.tokenizer(
                [caption_string], return_tensors="pt")
            if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                warnings.warn("Inputting a text that is too long will result "
                              "in poor prediction performance. "
                              "Please reduce the --chunked-size.")
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)

            caption_string_chunked.append(caption_string)
            positive_map_label_to_token_chunked.append(
                positive_map_label_to_token)
            positive_map_chunked.append(positive_map)
            entities_chunked.append(original_caption_chunked[i])

        return positive_map_label_to_token_chunked, \
            caption_string_chunked, \
            positive_map_chunked, \
            entities_chunked

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        text_dict: Dict,
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(
            **encoder_inputs_dict, text_dict=text_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)

        decoder_inputs_dict.update(tmp_dec_in)
        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, spatial_shapes: Tensor,
                        level_start_index: Tensor, valid_ratios: Tensor,
                        text_dict: Dict) -> Dict:
        text_token_mask = text_dict["text_token_mask"]
        memory, memory_text = self.encoder(
            query=feat,
            query_pos=feat_pos,
            key_padding_mask=feat_mask,  # for self_attn
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            # for text encoder
            memory_text=text_dict["embedded"],
            text_attention_mask=~text_token_mask,
            position_ids=text_dict["position_ids"],
            text_self_attention_masks=text_dict["masks"])
        encoder_outputs_dict = dict(
            memory=memory,
            memory_mask=feat_mask,
            spatial_shapes=spatial_shapes,
            memory_text=memory_text,
            text_token_mask=text_token_mask)
        return encoder_outputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        memory_text: Tensor,
        text_token_mask: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        bs, _, c = memory.shape

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)

        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](output_memory, memory_text,
                                     text_token_mask)
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].max_text_len
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE In DINO, the top-k proposals are selected by the scores of
        # multi-class classification, while in DeformDETR they are selected
        # by the scores of binary classification, `enc_outputs_class[..., 0]`.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]

        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords = topk_coords_unact.sigmoid()
        topk_coords_unact = topk_coords_unact.detach()

        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask,
            memory_text=memory_text,
            text_attention_mask=~text_token_mask,
        )
        # NOTE DINO calculates encoder losses on scores and coordinates
        # of the selected top-k encoder queries, while DeformDETR computes
        # them on all encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        # append text_feats to head_inputs_dict
        head_inputs_dict["memory_text"] = memory_text
        head_inputs_dict["text_token_mask"] = text_token_mask
        return decoder_inputs_dict, head_inputs_dict

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        gt_labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]

        if "tokens_positive" in batch_data_samples[0]:
            tokens_positive = [
                data_samples.tokens_positive
                for data_samples in batch_data_samples
            ]
            positive_maps = []
            for token_positive, text_prompt, gt_label in zip(
                    tokens_positive, text_prompts, gt_labels):
                tokenized = self.language_model.tokenizer(
                    [text_prompt],
                    padding="max_length"
                    if self.language_model.pad_to_max else "longest",
                    return_tensors="pt")
                new_tokens_positive = [
                    token_positive[label.item()] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
            new_text_prompts = text_prompts
        else:
            new_text_prompts = []
            positive_maps = []
            if len(set(text_prompts)) == 1:
                # All the text prompts are the same,
                # so there is no need to calculate them multiple times.
                tokenized, caption_string, tokens_positive, _ = \
                    self.get_tokens_and_prompts(text_prompts[0], True)
                new_text_prompts = [caption_string] * len(batch_inputs)
                for gt_label in gt_labels:
                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive)
                    positive_maps.append(positive_map)
            else:
                for text_prompt, gt_label in zip(text_prompts, gt_labels):
                    tokenized, caption_string, tokens_positive, _ = \
                        self.get_tokens_and_prompts(text_prompt, True)
                    new_tokens_positive = [
                        tokens_positive[label] for label in gt_label
                    ]
                    _, positive_map = self.get_positive_map(
                        tokenized, new_tokens_positive)
                    positive_maps.append(positive_map)
                    new_text_prompts.append(caption_string)

        text_dict = self.language_model(new_text_prompts)
        if self.text_feat_map is not None:
            text_dict["embedded"] = self.text_feat_map(text_dict["embedded"])

        for i, data_samples in enumerate(batch_data_samples):
            positive_map = positive_maps[i].to(
                batch_inputs.device).bool().float()
            text_token_mask = text_dict["text_token_mask"][i]
            data_samples.gt_instances.positive_maps = positive_map
            data_samples.gt_instances.text_token_mask = \
                text_token_mask.unsqueeze(0).repeat(len(positive_map), 1)

        if self.use_autocast:
            with autocast(enabled=True):
                visual_features = self.extract_feat(batch_inputs)
        else:
            visual_features = self.extract_feat(batch_inputs)

        head_inputs_dict = self.forward_transformer(visual_features,
                                                    text_dict,
                                                    batch_data_samples)

        losses = self.bbox_head.loss(
            **head_inputs_dict, batch_data_samples=batch_data_samples)
        return losses

    def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
        """Predict results from a batch of inputs and data samples with
        post-processing."""
        text_prompts = []
        enhanced_text_prompts = []
        tokens_positives = []
        for data_samples in batch_data_samples:
            text_prompts.append(data_samples.text)
            if "caption_prompt" in data_samples:
                enhanced_text_prompts.append(data_samples.caption_prompt)
            else:
                enhanced_text_prompts.append(None)
            tokens_positives.append(data_samples.get("tokens_positive", None))

        if "custom_entities" in batch_data_samples[0]:
            # Assume that the `custom_entities` flag inside a batch is
            # always the same, as for single-image inference.
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False
        if len(text_prompts) == 1:
            # All the text prompts are the same,
            # so there is no need to calculate them multiple times.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompts[0],
                                                     custom_entities,
                                                     enhanced_text_prompts[0],
                                                     tokens_positives[0])
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompt,
                                                     custom_entities,
                                                     enhanced_text_prompt,
                                                     tokens_positive)
                for text_prompt, enhanced_text_prompt, tokens_positive in zip(
                    text_prompts, enhanced_text_prompts, tokens_positives)
            ]
        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)

        # image feature extraction
        visual_feats = self.extract_feat(batch_inputs)

        if isinstance(text_prompts[0], list):
            # chunked text prompts, only bs=1 is supported
            assert len(batch_inputs) == 1
            count = 0
            results_list = []

            entities = [[item for lst in entities[0] for item in lst]]

            for b in range(len(text_prompts[0])):
                text_prompts_once = [text_prompts[0][b]]
                token_positive_maps_once = token_positive_maps[0][b]
                text_dict = self.language_model(text_prompts_once)
                # text feature map layer
                if self.text_feat_map is not None:
                    text_dict["embedded"] = self.text_feat_map(
                        text_dict["embedded"])

                batch_data_samples[0].token_positive_map = \
                    token_positive_maps_once

                head_inputs_dict = self.forward_transformer(
                    copy.deepcopy(visual_feats), text_dict,
                    batch_data_samples)
                pred_instances = self.bbox_head.predict(
                    **head_inputs_dict,
                    rescale=rescale,
                    batch_data_samples=batch_data_samples)[0]

                if len(pred_instances) > 0:
                    pred_instances.labels += count
                count += len(token_positive_maps_once)
                results_list.append(pred_instances)
            results_list = [results_list[0].cat(results_list)]
            is_rec_tasks = [False] * len(results_list)
        else:
            # extract text feats
            text_dict = self.language_model(list(text_prompts))
            # text feature map layer
            if self.text_feat_map is not None:
                text_dict["embedded"] = self.text_feat_map(
                    text_dict["embedded"])

            is_rec_tasks = []
            for i, data_samples in enumerate(batch_data_samples):
                if token_positive_maps[i] is not None:
                    is_rec_tasks.append(False)
                else:
                    is_rec_tasks.append(True)
                data_samples.token_positive_map = token_positive_maps[i]

            head_inputs_dict = self.forward_transformer(
                visual_feats, text_dict, batch_data_samples)
            results_list = self.bbox_head.predict(
                **head_inputs_dict,
                rescale=rescale,
                batch_data_samples=batch_data_samples)

        for data_sample, pred_instances, entity, is_rec_task in zip(
                batch_data_samples, results_list, entities, is_rec_tasks):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if is_rec_task:
                        label_names.append(entity)
                        continue
                    if labels >= len(entity):
                        warnings.warn(
                            "The unexpected output indicates an issue with "
                            "named entity recognition. You can try "
                            "setting custom_entities=True and running "
                            "again to see if it helps.")
                        label_names.append("unobject")
                    else:
                        label_names.append(entity[labels])
                # for visualization
                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
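

# A minimal, illustrative inference sketch. The config/checkpoint paths are
# placeholders; the actual files ship with MMDetection's Grounding DINO
# project, and ``DetInferencer`` forwards ``texts`` to this detector:
#
#   from mmdet.apis import DetInferencer
#
#   inferencer = DetInferencer(
#       model='path/to/grounding_dino_config.py',         # placeholder
#       weights='path/to/grounding_dino_checkpoint.pth')  # placeholder
#   inferencer('demo.jpg', texts='bench . car .')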