# ola_vlm/model/aux_heads/oneformer_head.py
import math
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from transformers import OneFormerModel
from transformers.models.oneformer.modeling_oneformer import (
    OneFormerForUniversalSegmentationOutput,
    OneFormerModelOutput,
    OneFormerPixelLevelModule,
    OneFormerPixelLevelModuleOutput,
)

from ola_vlm.model.multimodal_projector.resampler import Resampler, TaskTokenResampler


class AuxOneFormerPixelLevelModule(OneFormerPixelLevelModule):
def __init__(self, config):
super().__init__(config)
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool = False,
        last_backbone_feats: Optional[Tensor] = None,
        all_backbone_features: Optional[Tensor] = None,
        return_features: bool = False,
        return_all_features: bool = False,
    ):
        """Backbone/pixel-decoder forward with hooks for caching or injecting features."""
        if all_backbone_features is None:
            # Run the vision backbone unless cached multi-scale features were provided.
            features = self.encoder(pixel_values).feature_maps
            if return_all_features:
                return features
        else:
            features = all_backbone_features
        if last_backbone_feats is not None:
            # Replace the last (lowest-resolution) backbone map with the predicted one
            # and resize the remaining pyramid levels to match its spatial size.
            features = list(features)
            last_backbone_feats = F.interpolate(last_backbone_feats, size=features[-1].shape[-2:], mode='bilinear', align_corners=False)
            features[-1] = last_backbone_feats
            for i in range(3):
                features[i] = F.interpolate(features[i], size=features[-1].shape[-2:], mode='bilinear', align_corners=False)
            features = tuple(features)
        elif return_features:
            # Return only the last backbone map, resized to the 24x24 grid used as a regression target.
            return F.interpolate(features[-1], size=(24, 24), mode='bilinear', align_corners=False)
decoder_output = self.decoder(features, output_hidden_states=output_hidden_states)
return OneFormerPixelLevelModuleOutput(
encoder_features=tuple(features),
decoder_features=decoder_output.multi_scale_features,
decoder_last_feature=decoder_output.mask_features,
)
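
# Illustrative sketch (not part of the original module) of the three ways the forward
# above can be called; variable names are hypothetical and shapes depend on the backbone,
# only the 24x24 size of the `return_features` path is fixed by the code.
# >>> module = AuxOneFormerPixelLevelModule(config)
# >>> pyramid = module(pixel_values, return_all_features=True)           # multi-scale backbone maps
# >>> target = module(pixel_values, return_features=True)                # (B, C, 24, 24) last map
# >>> decoded = module(pixel_values, last_backbone_feats=predicted_map)  # full pixel-level output
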
class OneFormerHead(OneFormerModel):
def __init__(self, config):
super().__init__(config)
self.pixel_level_module = AuxOneFormerPixelLevelModule(config)
def forward_features(
self,
pixel_values: Tensor,
task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
        """Return the last backbone feature map for `pixel_values`, resized to 24x24."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
backbone_last_feature = self.pixel_level_module(pixel_values, output_hidden_states, return_features=True)
return backbone_last_feature
def get_backbone_feats(
self,
pixel_values: Tensor,
task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
        """Return the multi-scale backbone feature maps for `pixel_values`."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
        backbone_features = self.pixel_level_module(pixel_values, output_hidden_states, return_all_features=True)
        return backbone_features
def get_masks(
self,
pixel_values: Tensor,
task_inputs: Tensor,
        text_inputs: Optional[Tensor] = None,
        pixel_mask: Optional[Tensor] = None,
        backbone_last_feature: Optional[Tensor] = None,
        all_backbone_features: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
        """Decode masks and class logits, optionally substituting a predicted last backbone
        feature and/or reusing cached backbone features."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
        # The predicted `backbone_last_feature` (if any) replaces the real last backbone map.
        pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states, backbone_last_feature, all_backbone_features)
multi_scale_features = pixel_level_module_output.decoder_features
mask_features = pixel_level_module_output.decoder_last_feature
task_token = self.task_encoder(task_inputs.to(self.dtype))
if self.is_training:
text_queries = self.text_mapper(text_inputs)
else:
text_queries = None
transformer_module_output = self.transformer_module(
multi_scale_features=multi_scale_features,
mask_features=mask_features,
task_token=task_token,
output_attentions=output_attentions,
)
queries = transformer_module_output.object_queries
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_features
pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,)
for f in pixel_level_module_output.decoder_features:
pixel_decoder_hidden_states += (f,)
transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions
outputs = OneFormerModelOutput(
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
transformer_decoder_object_queries=queries,
transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits,
transformer_decoder_mask_predictions=transformer_module_output.prediction_masks,
transformer_decoder_class_predictions=transformer_module_output.prediction_class,
transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions,
text_queries=text_queries,
task_token=task_token,
attentions=transformer_module_output.attentions,
)
class_queries_logits = outputs.transformer_decoder_class_predictions
masks_queries_logits = outputs.transformer_decoder_mask_predictions
contrastive_queries_logits = outputs.transformer_decoder_contrastive_queries
auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions
text_queries = outputs.text_queries
output = OneFormerForUniversalSegmentationOutput(
class_queries_logits=class_queries_logits,
masks_queries_logits=masks_queries_logits,
auxiliary_predictions=auxiliary_predictions,
loss=None,
**outputs,
)
return output
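
# Illustrative usage sketch (an assumption, not part of the original file): in OLA-VLM the
# head is typically driven in two steps, caching backbone features first and then decoding
# masks with a predicted last-level embedding injected in place of the real one. The
# checkpoint name and variable names below are examples only.
# >>> head = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large")
# >>> feats = head.get_backbone_feats(pixel_values, task_inputs)    # multi-scale backbone maps
# >>> target = head.forward_features(pixel_values, task_inputs)     # (B, C, 24, 24) target map
# >>> seg_out = head.get_masks(pixel_values, task_inputs,
# ...                          backbone_last_feature=predicted_embedding,
# ...                          all_backbone_features=feats)
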
class OneFormerSegHead(nn.Module):
def __init__(
self,
        proj_config: Optional[dict] = None,
llm_hidden_size: int = 4096,
) -> None:
super().__init__()
self.projector = Resampler(
dim=proj_config["output_dim"],
depth=proj_config["depth"],
dim_head=proj_config["dim_head"],
heads=proj_config["num_heads"],
num_queries=proj_config["num_tokens"],
embedding_dim=llm_hidden_size,
output_dim=proj_config["output_dim"],
ff_mult=proj_config["ff_mult"],
)
def forward(
self,
llm_feats: torch.Tensor,
):
visual_feats = self.projector(llm_feats)
b, n, c = visual_feats.shape
b = int(b)
c = int(c)
        # The resampled tokens are assumed to form a square grid (e.g. 576 tokens -> 24x24).
        h = w = int(math.sqrt(int(n)))
visual_feats = visual_feats.permute(0, 2, 1)
image_embeddings = visual_feats.reshape(b, c, h, w)
return image_embeddings
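
# Shape sketch (assumptions: a hypothetical proj_config such as
# {"output_dim": 256, "depth": 2, "dim_head": 64, "num_heads": 8, "num_tokens": 576, "ff_mult": 4},
# and a Resampler that emits `num_tokens` queries of size `output_dim`). With num_tokens = 576
# the head turns (B, N_llm, 4096) LLM features into a (B, 256, 24, 24) grid, matching the
# 24x24 OneFormer backbone target above. Actual values come from the OLA-VLM config.
# >>> seg_head = OneFormerSegHead(proj_config, llm_hidden_size=4096)
# >>> emb = seg_head(torch.randn(2, 64, 4096))   # -> torch.Size([2, 256, 24, 24])
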
class OneFormerTaskTokenSegHead(nn.Module):
def __init__(
self,
        proj_config: Optional[dict] = None,
llm_hidden_size: int = 4096,
) -> None:
super().__init__()
self.projector = TaskTokenResampler(
dim=proj_config["output_dim"],
depth=proj_config["depth"],
dim_head=proj_config["dim_head"],
heads=proj_config["num_heads"],
num_queries=proj_config["num_tokens"],
embedding_dim=llm_hidden_size,
output_dim=proj_config["output_dim"],
ff_mult=proj_config["ff_mult"],
)
def forward(
self,
llm_feats: torch.Tensor,
latents: torch.Tensor,
):
visual_feats = self.projector(llm_feats, latents)
b, n, c = visual_feats.shape
b = int(b)
c = int(c)
        # The resampled tokens are assumed to form a square grid (e.g. 576 tokens -> 24x24).
        h = w = int(math.sqrt(int(n)))
visual_feats = visual_feats.permute(0, 2, 1)
image_embeddings = visual_feats.reshape(b, c, h, w)
return image_embeddings
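
# Sketch (assumption): the task-token variant additionally conditions the resampler on
# externally supplied latents (e.g. derived from the task token); names are hypothetical.
# >>> tok_head = OneFormerTaskTokenSegHead(proj_config, llm_hidden_size=4096)
# >>> emb = tok_head(llm_feats, latents=task_token_latents)   # -> (B, output_dim, 24, 24)
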
def build_mlp(in_hidden_size, hidden_size):
    return nn.Sequential(
        nn.Linear(in_hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, hidden_size),
    )
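
# Example (sketch; sizes are assumptions): build_mlp gives a two-layer GELU MLP, e.g. for
# projecting 4096-d LLM hidden states into a 256-d embedding space.
# >>> mlp = build_mlp(4096, 256)
# >>> y = mlp(torch.randn(2, 576, 4096))   # -> torch.Size([2, 576, 256])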