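"""Auxiliary OneFormer-based segmentation heads.

Wraps Hugging Face's OneFormer so that its backbone features can be exposed
directly or replaced by externally provided features, and defines
resampler-based projection heads that map LLM hidden states into spatial
image embeddings of the same layout.
"""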
import math
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import OneFormerModel
from transformers.models.oneformer.modeling_oneformer import (
    OneFormerForUniversalSegmentationOutput,
    OneFormerModelOutput,
    OneFormerPixelLevelModule,
    OneFormerPixelLevelModuleOutput,
)

from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler
class AuxOneFormerPixelLevelModule(OneFormerPixelLevelModule):
    """OneFormer pixel-level module that can expose or accept backbone features."""

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool = False,
        last_backbone_feats: Optional[Tensor] = None,
        all_backbone_features: Optional[Tensor] = None,
        return_features: bool = False,
        return_all_features: bool = False,
    ):
        if all_backbone_features is None:
            # Run the backbone; optionally short-circuit and return every feature map.
            features = self.encoder(pixel_values).feature_maps
            if return_all_features:
                return features
        else:
            # Reuse backbone features computed earlier (or provided externally).
            features = all_backbone_features

        if last_backbone_feats is not None:
            # Swap in the provided tensor for the last backbone feature map,
            # then resize the earlier maps to the same spatial size.
            features = list(features)
            last_backbone_feats = F.interpolate(
                last_backbone_feats, size=features[-1].shape[-2:], mode='bilinear', align_corners=False
            )
            features[-1] = last_backbone_feats
            for i in range(3):
                features[i] = F.interpolate(
                    features[i], size=features[-1].shape[-2:], mode='bilinear', align_corners=False
                )
            features = tuple(features)
        elif return_features:
            # Only the last backbone feature is needed, resized to a 24x24 grid.
            return F.interpolate(features[-1], size=(24, 24), mode='bilinear', align_corners=False)

        decoder_output = self.decoder(features, output_hidden_states=output_hidden_states)
        return OneFormerPixelLevelModuleOutput(
            encoder_features=tuple(features),
            decoder_features=decoder_output.multi_scale_features,
            decoder_last_feature=decoder_output.mask_features,
        )
class OneFormerHead(OneFormerModel):
    """OneFormer model whose pixel-level module can expose or accept backbone features."""

    def __init__(self, config):
        super().__init__(config)
        self.pixel_level_module = AuxOneFormerPixelLevelModule(config)

    def forward_features(
        self,
        pixel_values: Tensor,
        task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Return the last backbone feature map, resized to a 24x24 grid."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, _, height, width = pixel_values.shape
        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)

        backbone_last_feature = self.pixel_level_module(pixel_values, output_hidden_states, return_features=True)
        return backbone_last_feature
    def get_backbone_feats(
        self,
        pixel_values: Tensor,
        task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Return every backbone feature map without running the pixel decoder."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, _, height, width = pixel_values.shape
        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)

        backbone_features = self.pixel_level_module(pixel_values, output_hidden_states, return_all_features=True)
        return backbone_features
    def get_masks(
        self,
        pixel_values: Tensor,
        task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
        backbone_last_feature: Optional[Tensor] = None,
        all_backbone_features: Optional[Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the full OneFormer forward pass, optionally reusing or overriding backbone features."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, _, height, width = pixel_values.shape
        if pixel_mask is None:
            pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)

        # Pixel-level module: backbone (or cached/overridden features) followed by the pixel decoder.
        pixel_level_module_output = self.pixel_level_module(
            pixel_values, output_hidden_states, backbone_last_feature, all_backbone_features
        )
        multi_scale_features = pixel_level_module_output.decoder_features
        mask_features = pixel_level_module_output.decoder_last_feature

        # Task-conditioned transformer decoder.
        task_token = self.task_encoder(task_inputs.to(self.dtype))
        if self.is_training:
            text_queries = self.text_mapper(text_inputs)
        else:
            text_queries = None

        transformer_module_output = self.transformer_module(
            multi_scale_features=multi_scale_features,
            mask_features=mask_features,
            task_token=task_token,
            output_attentions=output_attentions,
        )
        queries = transformer_module_output.object_queries

        encoder_hidden_states = None
        pixel_decoder_hidden_states = None
        transformer_decoder_hidden_states = None
        if output_hidden_states:
            encoder_hidden_states = pixel_level_module_output.encoder_features
            pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,)
            for f in pixel_level_module_output.decoder_features:
                pixel_decoder_hidden_states += (f,)
            transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions

        outputs = OneFormerModelOutput(
            encoder_hidden_states=encoder_hidden_states,
            pixel_decoder_hidden_states=pixel_decoder_hidden_states,
            transformer_decoder_hidden_states=transformer_decoder_hidden_states,
            transformer_decoder_object_queries=queries,
            transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits,
            transformer_decoder_mask_predictions=transformer_module_output.prediction_masks,
            transformer_decoder_class_predictions=transformer_module_output.prediction_class,
            transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions,
            text_queries=text_queries,
            task_token=task_token,
            attentions=transformer_module_output.attentions,
        )

        class_queries_logits = outputs.transformer_decoder_class_predictions
        masks_queries_logits = outputs.transformer_decoder_mask_predictions
        auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions

        return OneFormerForUniversalSegmentationOutput(
            class_queries_logits=class_queries_logits,
            masks_queries_logits=masks_queries_logits,
            auxiliary_predictions=auxiliary_predictions,
            loss=None,
            **outputs,
        )
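# Illustrative call pattern (an assumption drawn from the method signatures above,
# not part of the original module): backbone features can be computed once with
# `get_backbone_feats` and reused in `get_masks`, optionally overriding the last
# feature map with an externally predicted embedding, e.g.
#
#   feats = head.get_backbone_feats(pixel_values, task_inputs)       # tuple of backbone maps
#   out = head.get_masks(
#       pixel_values,
#       task_inputs,
#       backbone_last_feature=predicted_embedding,  # hypothetical tensor from another head
#       all_backbone_features=feats,
#   )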
class OneFormerSegHead(nn.Module):
    """Projects LLM hidden states into a spatial image-embedding grid via a Resampler."""

    def __init__(
        self,
        proj_config: dict = None,
        llm_hidden_size: int = 4096,
    ) -> None:
        super().__init__()
        self.projector = Resampler(
            dim=proj_config["output_dim"],
            depth=proj_config["depth"],
            dim_head=proj_config["dim_head"],
            heads=proj_config["num_heads"],
            num_queries=proj_config["num_tokens"],
            embedding_dim=llm_hidden_size,
            output_dim=proj_config["output_dim"],
            ff_mult=proj_config["ff_mult"],
        )

    def forward(
        self,
        llm_feats: torch.Tensor,
    ):
        # Resample LLM features into `num_tokens` query tokens of width `output_dim`.
        visual_feats = self.projector(llm_feats)
        b, n, c = visual_feats.shape
        # Reshape the token sequence into a square (b, c, h, w) feature map;
        # `num_tokens` is expected to be a perfect square.
        h = w = int(math.sqrt(int(n)))
        visual_feats = visual_feats.permute(0, 2, 1)
        image_embeddings = visual_feats.reshape(int(b), int(c), h, w)
        return image_embeddings
class OneFormerTaskTokenSegHead(nn.Module):
    """Like OneFormerSegHead, but conditions the resampler on task-token latents."""

    def __init__(
        self,
        proj_config: dict = None,
        llm_hidden_size: int = 4096,
    ) -> None:
        super().__init__()
        self.projector = TaskTokenResampler(
            dim=proj_config["output_dim"],
            depth=proj_config["depth"],
            dim_head=proj_config["dim_head"],
            heads=proj_config["num_heads"],
            num_queries=proj_config["num_tokens"],
            embedding_dim=llm_hidden_size,
            output_dim=proj_config["output_dim"],
            ff_mult=proj_config["ff_mult"],
        )

    def forward(
        self,
        llm_feats: torch.Tensor,
        latents: torch.Tensor,
    ):
        # Resample LLM features, conditioned on the provided latents.
        visual_feats = self.projector(llm_feats, latents)
        b, n, c = visual_feats.shape
        # Reshape the token sequence into a square (b, c, h, w) feature map.
        h = w = int(math.sqrt(int(n)))
        visual_feats = visual_feats.permute(0, 2, 1)
        image_embeddings = visual_feats.reshape(int(b), int(c), h, w)
        return image_embeddings
def build_mlp(in_hidden_size, hidden_size):
    """Two-layer MLP with a GELU non-linearity in between."""
    return nn.Sequential(
        nn.Linear(in_hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, hidden_size),
    )
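

if __name__ == "__main__":
    # Minimal shape-check sketch for OneFormerSegHead. The config values below are
    # illustrative assumptions (not taken from any released checkpoint); the only
    # real constraint is that `num_tokens` be a perfect square so the reshape to
    # (h, w) works. It also assumes the ola_vlm Resampler returns tensors of shape
    # (batch, num_tokens, output_dim), which the reshape in `forward` implies.
    proj_config = {
        "output_dim": 256,
        "depth": 2,
        "dim_head": 64,
        "num_heads": 8,
        "num_tokens": 576,  # 576 tokens -> a 24 x 24 grid
        "ff_mult": 4,
    }
    head = OneFormerSegHead(proj_config=proj_config, llm_hidden_size=4096)
    llm_feats = torch.randn(2, 128, 4096)  # (batch, seq_len, llm_hidden_size)
    image_embeddings = head(llm_feats)
    print(image_embeddings.shape)  # expected under these assumptions: torch.Size([2, 256, 24, 24])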