"""PyTorch TraVisionLM"""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from transformers import AutoModel, AutoModelForCausalLM, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa
from transformers.utils import ModelOutput, add_start_docstrings, logging

from .configuration_travisionlm import TraVisionLMConfig

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "TraVisionLMConfig"

@dataclass
class TraVisionCausalLMOutputWithPast(ModelOutput):
    """
    Base class for TraVision language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
            sequence_length, hidden_size)`.

            Hidden-states of the model produced by the vision encoder.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None

class TraVisionMultiModalProjector(nn.Module):
    """
    Multimodal projector that casts the image features into the same dimensional space as the language model embeddings.
    """

    def __init__(self, config: TraVisionLMConfig, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.vision_config.projection_dim, 4 * config.vision_config.projection_dim, bias=True),
            nn.GELU(),
            nn.Linear(4 * config.vision_config.projection_dim, config.hidden_size, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, image_features):
        hidden_states = self.net(image_features).to(image_features.dtype)
        return hidden_states

TRAVISIONLM_START_DOCSTRING = r"""
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
etc.)
|
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
|
and behavior.
|
|
|
|
Parameters:
|
|
config ([`TraVisionLMConfig`]):
|
|
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
|
load the weights associated with the model, only the configuration. Check out the
|
|
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
"""
|
|
|
|
@add_start_docstrings(
    "The bare TraVision Model outputting raw hidden-states without any specific head on top.",
    TRAVISIONLM_START_DOCSTRING,
)
class TraVisionPreTrainedModel(PreTrainedModel):
    config_class = TraVisionLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TraVisionMultiModalProjector"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports SDPA or not.
        """
        return self.language_model._supports_sdpa

@add_start_docstrings(
    """The TraVisionLM model which consists of a vision backbone and a language model.""",
    TRAVISIONLM_START_DOCSTRING,
)
class TraVisionForCausalLM(TraVisionPreTrainedModel):
    def __init__(self, config: TraVisionLMConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.pad_token_id = -1 if config.pad_token_id is None else config.pad_token_id
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False

        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.vision_projector = TraVisionMultiModalProjector(config)

        language_model = AutoModelForCausalLM.from_config(
            config=config.text_config, attn_implementation=self._attn_implementation
        )

        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]

        self.language_model = language_model

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

""" !!! Two significant modifications are made to the original code:
|
|
------> 1) The pad and eos tokens are set to be the same in TraVisionProcessor. Hence, only the features corresponding to the padding mask are filtered out
|
|
using the attention mask.
|
|
------> 2) The features corresponding to both the prompts (called prefixes in PaliGemma) and labels (called suffixes in PaliGemma) are added the final embedding tensor
|
|
and the tokens of both the prompts and labels are applied causal attention mask. All the image tokens are attended using full-attention mask.
|
|
NOTE: In the original PaliGemma implementation, only the suffix tokens are applied causal masking. Check out [PaliGemma arXiv Paper](https://arxiv.org/pdf/2407.07726)
|
|
for the details.
|
|
"""
|
|
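    # Resulting attention layout: every query token can attend to all image tokens (columns where token_type_ids == 0
    # are unmasked), prompt and label tokens are attended to causally, and padding columns are masked out for every query.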
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
    ):
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device
        min_dtype = torch.finfo(dtype).min

        scaled_image_features = image_features / (self.config.hidden_size**0.5)
        final_embedding = torch.zeros(
            batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

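        # Partition the sequence into text, image and padding positions. Since pad and eos share the same token id,
        # the attention mask is needed to tell genuine eos tokens (attended to) apart from padding (not attended to).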
        text_mask = (input_ids != self.config.image_token_index) & (
            attention_mask.bool() | (input_ids != self.config.text_config.pad_token_id)
        )
        image_mask = input_ids == self.config.image_token_index
        pad_mask = (attention_mask == 0) & (input_ids == self.config.text_config.pad_token_id)

        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)

        final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        final_embedding = final_embedding.masked_scatter(
            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
        )
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        if attention_mask is not None:
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
        else:
            position_ids = None

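        # Build the 4D additive attention mask: causal over prompt and label tokens, full attention towards the image
        # tokens, and fully masked padding columns.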
        if token_type_ids is not None:
            target_length = cache_position[-1] + 1
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )

                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

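            # Mask the labels of padding positions with `ignore_index` so that they do not contribute to the loss.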
            final_labels = None
            if labels is not None:
                final_labels = torch.full(
                    (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
                )
                final_labels = torch.where(
                    attention_mask.bool() | (input_ids != self.config.text_config.pad_token_id), labels, final_labels
                )
        else:
            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
            causal_mask = torch.where(causal_mask == 0, min_dtype, 0).to(dtype)
            final_labels = None

        return final_embedding, causal_mask, final_labels, position_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TraVisionCausalLMOutputWithPast]:
        if labels is not None:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.language_model.transformer.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0)

        input_attention_mask = attention_mask

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

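        # Compute position embeddings. At the prefill step with an image, image tokens (and positions holding token id 0)
        # are mapped to a dummy position index and the remaining positions are shifted down so that the first text token
        # after the image gets position 0; otherwise the running position ids are used directly.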
        if pixel_values is not None and inputs_embeds.shape[1] != 1:
            position_ids_mask = torch.where(input_ids != self.config.image_token_index, position_ids, 1)
            position_ids_mask[:, :-1] = torch.where(input_ids[:, :-1] != 0, position_ids_mask[:, :-1], 1)
            first_position_embed_locs = torch.sum(position_ids_mask == 1, dim=1)
            position_ids_mask.sub_(first_position_embed_locs[:, None])
            position_emb_ids = torch.where(position_ids_mask >= 0, position_ids_mask, 1)
            position_embeds = self.language_model.transformer.wpe(position_emb_ids)
        else:
            pos_emb_ind = position_ids.view(batch_size, -1)
            position_embeds = self.language_model.transformer.wpe(pos_emb_ind)

        hidden_states = inputs_embeds + position_embeds

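        # Prefill step with an image: encode the pixel values, project them into the language model's embedding space
        # and merge them with the text embeddings (this also builds the 4D attention mask and the final labels).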
        if pixel_values is not None and input_ids.shape[1] != 1:
            if pixel_values.dim() == 3:
                pixel_values = pixel_values.unsqueeze(dim=0)

            image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.vision_projector(selected_image_feature)

            if cache_position is None:
                cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            hidden_states, attention_mask, labels, _ = self._merge_input_ids_with_image_features(
                image_features, hidden_states, input_ids, attention_mask, labels, token_type_ids, cache_position
            )
        else:
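            # Decoding step with a populated cache: no image merging is needed, only the 2D attention mask has to be
            # brought into the format expected by the language model blocks.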
            if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False
                if attention_mask is not None:
                    attention_mask = attention_mask.view(batch_size, -1)
                    if self._attn_implementation == "flash_attention_2":
                        attention_mask = attention_mask if 0 in attention_mask else None
                    elif _use_sdpa:
                        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                            attention_mask=attention_mask,
                            input_shape=(batch_size, input_shape[-1]),
                            inputs_embeds=inputs_embeds,
                            past_key_values_length=past_length,
                        )
                    else:
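                        # Fall back to an explicit additive mask: make the 2D mask broadcastable to
                        # [batch_size, num_heads, seq_length, seq_length] and convert it so that positions to attend
                        # to get 0.0 and masked positions get the dtype minimum, which removes them from the softmax.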
                        attention_mask = attention_mask[:, None, None, :]
                        attention_mask = attention_mask.to(dtype=self.dtype)
                        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.dtype)

        hidden_states = self.language_model.transformer.drop(hidden_states)
        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
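        # Run the merged embeddings through the language model's transformer blocks, collecting the present key/value
        # caches, hidden states and attention weights as requested.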
        for i, (block, layer_past) in enumerate(zip(self.language_model.transformer.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.language_model.transformer.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.language_model.lm_head(hidden_states)
        logits = logits.float()
        loss = None
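        # Compute the next-token prediction loss: shift logits and labels by one position and drop positions that are
        # masked out in the attention mask.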
        if labels is not None:
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if input_attention_mask is not None:
                shift_attention_mask = input_attention_mask[..., 1:]
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()

            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits, presents, all_hidden_states, all_self_attentions)
            return (loss,) + output if loss is not None else output

        return TraVisionCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        **kwargs,
    ):
        if attention_mask is not None and position_ids is None:
            if past_key_values:
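                # Derive the position id of the new token from the number of real (non-image, non-padding) tokens in
                # the sequence so far, one row per batch element.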
                position_ids_mask = input_ids != self.config.image_token_index
                position_ids_mask[:, :-1] &= input_ids[:, :-1] != self.config.text_config.pad_token_id
                last_index = position_ids_mask.sum(dim=1) - 1
                position_ids = torch.stack(
                    [
                        torch.arange(start, start + cache_position.shape[0], device=input_ids.device)
                        for start in last_index
                    ]
                )

        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

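        # If `inputs_embeds` are passed, only use them in the first generation step; later steps continue from `input_ids`.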
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "cache_position": cache_position,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "token_type_ids": token_type_ids,
            }
        )
        return model_inputs