# molmo-dense-captioner-v22-qwen2 / modeling_molmo.py
import math
from copy import deepcopy
from dataclasses import fields, dataclass, replace
from enum import Enum
from typing import List, Optional, Tuple, Union, Dict, Any, Sequence, Callable, cast, MutableMapping
import torch
from transformers import PreTrainedModel, GenerationConfig, add_start_docstrings
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto import AutoModelForCausalLM
from torch import nn
from transformers.utils import logging
from .config_molmo import MolmoConfig, MolmoVisionConfig
from torch.nn import functional as F
logger = logging.get_logger(__name__)
MOLMO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).
    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`MolmoConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Molmo Model outputting raw hidden-states without any specific head on top.",
MOLMO_START_DOCSTRING,
)
class MolmoPreTrainedModel(PreTrainedModel):
config_class = MolmoConfig
base_model_prefix = "model"
_no_split_modules = ["MolmoBlock", "MolmoeBlock", "MolmoVisionBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
# supports_gradient_checkpointing = True
# _supports_cache_class = True
# _supports_static_cache = False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear,)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
class MolmoRotaryEmbedding(nn.Module):
"""
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
"""
def __init__(self, dim, max_position_embeddings=2048, rope_theta=10000, full_precision=True, device=None):
super().__init__()
self.dim = dim
self.rope_theta = rope_theta
self.full_precision = full_precision
self.max_position_embeddings = max_position_embeddings
# Cache sin/cos embeddings
dim = self.dim
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
seq = torch.arange(self.max_position_embeddings, device=device, dtype=torch.float)
freqs = torch.einsum("i , j -> i j", seq, inv_freq)
positions = torch.cat((freqs, freqs), dim=-1)
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
self.register_buffer("rope_pos_sin", pos_sin, persistent=False)
self.register_buffer("rope_pos_cos", pos_cos, persistent=False)
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
B, nh, T, hs = x.size()
x = x.view(B, nh, T, 2, hs // 2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return (t * pos_cos) + (self.rotate_half(t) * pos_sin)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.full_precision:
q_, k_ = q.float(), k.float()
else:
q_, k_ = q, k
with torch.autocast(q.device.type, enabled=False):
batch_size = q_.shape[0]
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
if position_ids is not None:
freqs_cis_len = self.max_position_embeddings
else:
freqs_cis_len = key_len
# self.get_rotary_embedding(freqs_cis_len, q_.device)
pos_sin = self.rope_pos_sin[:, :, :freqs_cis_len, :].type_as(q_)
pos_cos = self.rope_pos_cos[:, :, :freqs_cis_len, :].type_as(q_)
if position_ids is not None:
assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
pos_sin = pos_sin[0, 0][position_ids].view(
(batch_size, 1, key_len, pos_sin.shape[-1])
)
pos_cos = pos_cos[0, 0][position_ids].view(
(batch_size, 1, key_len, pos_cos.shape[-1])
)
q_ = self.apply_rotary_pos_emb(
pos_sin[:, :, key_len - query_len : key_len, :],
pos_cos[:, :, key_len - query_len : key_len, :],
q_,
)
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
return q_.type_as(q), k_.type_as(k)
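# Shape sketch for the rotary embedding above (hedged, illustrative only): queries and keys of
# shape (batch, n_heads, seq_len, head_dim) come back with the same shape, rotated per position.
#
#     rope = MolmoRotaryEmbedding(dim=64, max_position_embeddings=128)
#     q = torch.randn(2, 8, 16, 64)
#     k = torch.randn(2, 8, 16, 64)
#     q_rot, k_rot = rope(q, k)   # both keep shape (2, 8, 16, 64)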
class MolmoAttention(nn.Module):
def __init__(
self,
config: MolmoConfig,
device=None
):
super().__init__()
self.config = config
self.rotary_emb = MolmoRotaryEmbedding(
config.hidden_size // config.num_attention_heads,
config.max_position_embeddings,
config.rope_theta, device=device)
self.k_norm: Optional[nn.Module] = None
self.q_norm: Optional[nn.Module] = None
self.hidden_size = config.intermediate_size
if config.qk_layer_norm:
if config.num_key_value_heads is None:
config.num_key_value_heads = config.num_attention_heads
self.q_norm = MolmoRmsLayerNorm(
config,
size=config.hidden_size,
eps=config.layer_norm_eps
)
self.k_norm = MolmoRmsLayerNorm(
config,
size=config.hidden_size,
eps=config.layer_norm_eps
)
# Attention output projection.
input_dim = config.hidden_size
head_dim = config.hidden_size // config.num_attention_heads
self.fused_dims = (
config.hidden_size,
config.num_key_value_heads * head_dim,
config.num_key_value_heads * head_dim,
)
self.att_proj = nn.Linear(
config.hidden_size, sum(self.fused_dims),
bias=config.qkv_bias,
)
self.attn_out = nn.Linear(
input_dim, config.hidden_size,
bias=False,
)
def attention(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
drop_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, C = q.size() # batch size, sequence length, hidden_size
dtype = k.dtype
# Optionally apply layer norm to keys and queries.
if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q).to(dtype=dtype)
k = self.k_norm(k).to(dtype=dtype)
# Move head forward to be next to the batch dim.
# shape: (B, nh, T, hs)
q = q.view(B, T, self.config.num_attention_heads, C // self.config.num_attention_heads).transpose(1, 2)
# shape: (B, n_kv_h, T, hs)
k = k.view(B, T, self.config.num_key_value_heads, C // self.config.num_attention_heads).transpose(1, 2)
# shape: (B, n_kv_h, T, hs)
v = v.view(B, T, self.config.num_key_value_heads, C // self.config.num_attention_heads).transpose(1, 2)
# Apply rotary embeddings
q, k = self.rotary_emb(q, k, position_ids=position_ids)
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key.to(k.device), k), dim=-2)
v = torch.cat((past_value.to(v.device), v), dim=-2)
present = (k, v) if use_cache else None
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
        if attention_mask is not None and attention_mask.dim() == 4:
            # Slice the 4D bias down to the current query window; the flash path passes a
            # 2D padding mask instead, which needs no slicing.
            attention_mask = attention_mask[:, :, key_len - query_len: key_len, :key_len]
# if attention_bias is not None:
# attention_bias = self._cast_attn_bias(
# attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
# Get the attention scores.
# shape: (B, nh, T, hs)
att = self._scaled_dot_product_attention(
q,
k,
v,
attention_mask=attention_mask,
dropout_p=0.0 if not self.training else self.config.attention_dropout,
is_causal=attention_mask is None,
)
# Re-assemble all head outputs side-by-side.
att = att.transpose(1, 2).contiguous().view(B, T, C)
# Apply output projection.
return self.attn_out(att), present
def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
if attention_mask is not None:
attention_mask = attention_mask.to(q.device)
if self.config.attention_type == "sdpa":
assert k.size(1) == v.size(1)
num_kv_heads = k.size(1)
num_q_heads = q.size(1)
if num_q_heads != num_kv_heads:
assert num_q_heads % num_kv_heads == 0
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
        elif self.config.attention_type == "flash":
            if attention_mask is not None and attention_mask.dim() == 4:
                # Collapse a [B, 1, N, N] bias down to a [B, N] padding mask for flash attention.
                valid_mask = attention_mask.any(dim=-1)[:, 0]
            else:
                valid_mask = attention_mask
            # Downcast in case we are running with fp32 hidden states; flash attention expects
            # (batch, seq_len, num_heads, head_dim) inputs and returns the same layout.
            attn_output = _flash_attention_forward(
                q.transpose(1, 2).to(torch.bfloat16),
                k.transpose(1, 2).to(torch.bfloat16),
                v.transpose(1, 2).to(torch.bfloat16),
                attention_mask=valid_mask,
                query_length=q.shape[2],
                is_causal=True,
            )
            # Move the heads back next to the batch dim so the caller's reshape sees (B, nh, T, hs).
            return attn_output.transpose(1, 2).to(q.dtype)
        else:
            raise NotImplementedError(self.config.attention_type)
def forward(
self,
x,
attention_mask,
position_ids,
layer_past,
use_cache
):
qkv = self.att_proj(x)
q, k, v = qkv.split(self.fused_dims, dim=-1)
# Get attention scores.
att, cache = self.attention(
q, k, v,
attention_mask,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache
)
return att, cache
class MolmoMlp(nn.Module):
def __init__(self, input_dim, hidden_size, activation_fn, include_bias=False):
super().__init__()
self.ff_proj = nn.Linear(input_dim, hidden_size, bias=include_bias)
self.ff_out = nn.Linear(hidden_size//2, input_dim, bias=include_bias)
self.act = ACT2FN[activation_fn]
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
x = self.ff_proj(x)
x, gate = x.chunk(2, dim=-1)
x = self.act(gate) * x
x = self.ff_out(x)
return x
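# Shape sketch for MolmoMlp (hedged, illustrative only): `ff_proj` expands to `hidden_size`
# features, which are chunked into a value half and a gate half (SwiGLU-style), so `ff_out`
# maps `hidden_size // 2` features back to `input_dim`.
#
#     mlp = MolmoMlp(input_dim=1024, hidden_size=8192, activation_fn="silu")
#     y = mlp(torch.randn(2, 16, 1024))   # y.shape == (2, 16, 1024)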
class MolmoBlock(nn.Module):
def __init__(self, config: MolmoConfig, device=None):
super().__init__()
self.config = config
self.hidden_size = config.intermediate_size
self.dropout = nn.Dropout(config.residual_dropout)
self.attn = MolmoAttention(config)
self.attn_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
self.mlp = MolmoMlp(config.hidden_size, config.intermediate_size, config.activation_type)
self.ff_norm = MolmoRmsLayerNorm(config)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
if not self.config.norm_after:
atten_in = self.attn_norm(x)
else:
atten_in = x
att, cache = self.attn(
atten_in,
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache
)
if self.config.norm_after:
att = self.attn_norm(att)
x = x + self.dropout(att)
og_x = x
if not self.config.norm_after:
x = self.ff_norm(x)
x = self.mlp(x)
if self.config.norm_after:
x = self.ff_norm(x)
x = self.dropout(x)
x = og_x + x
return x, cache
class MolmoeMLP(nn.Module):
def __init__(self, input_dim, hidden_size, activation):
super().__init__()
self.gate_proj = nn.Linear(input_dim, hidden_size, bias=False)
self.up_proj = nn.Linear(input_dim, hidden_size, bias=False)
self.down_proj = nn.Linear(hidden_size, input_dim, bias=False)
self.act_fn = ACT2FN[activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MolmoeMlpExpert(nn.Module):
def __init__(self, config):
super().__init__()
self.num_experts = config.moe_num_experts
self.top_k = config.moe_top_k
self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
self.experts = nn.ModuleList([MolmoeMLP(config.hidden_size, config.intermediate_size // 2, config.activation_type)
for _ in range(self.num_experts)])
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states = self.ff_norm(hidden_states)
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be selected
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
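# Routing sketch for the mixture-of-experts layer above (hedged, illustrative shapes only): with a
# config where moe_num_experts=8, moe_top_k=2 and hidden_size=1024, each token is scored by the
# gate, its two highest-probability experts are run, and their outputs are summed using the
# corresponding softmax probabilities as weights.
#
#     moe = MolmoeMlpExpert(config)                          # config as assumed above
#     out, router_logits = moe(torch.randn(2, 16, 1024))     # out: (2, 16, 1024)
#                                                            # router_logits: (2 * 16, 8)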
class MolmoeBlock(nn.Module):
def __init__(self, config: MolmoConfig):
super().__init__()
self.attn = MolmoAttention(config)
self.attn_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
assert config.moe_num_experts > 0
self.ff_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
self.mlp = MolmoeMlpExpert(config)
self.config = config
self.hidden_size = config.intermediate_size
self.dropout = nn.Dropout(config.residual_dropout)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
if not self.config.norm_after:
atten_in = self.attn_norm(x)
else:
atten_in = x
att, cache = self.attn(
atten_in,
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
use_cache=use_cache
)
if self.config.norm_after:
att = self.attn_norm(att)
x = x + self.dropout(att)
og_x = x
if not self.config.norm_after:
x = self.ff_norm(x)
x, _ = self.mlp(x)
if self.config.norm_after:
x = self.ff_norm(x)
x = self.dropout(x)
x = og_x + x
return x, cache
class Embedding(nn.Module):
def __init__(
self,
num_embeddings: int,
num_new_embeddings: int,
features: int,
device: Union[str, torch.device] = None,
initializer_range: float = 0.02,
new_embed_initializer_range: float = 0.02,
):
super().__init__()
self.initializer_range = initializer_range
self.new_embed_initializer_range = new_embed_initializer_range
self.embedding = nn.Parameter(
torch.zeros(num_embeddings, features, device=device),
)
        # We keep the special-token embeddings separate from the LM's embeddings so we can
        # apply a separate learning rate to them during training.
self.new_embedding = nn.Parameter(torch.zeros(num_new_embeddings, features, device=device))
def reset_parameters(self):
nn.init.normal_(self.embedding, std=self.initializer_range)
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
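# Indexing sketch for Embedding (hedged, illustrative only): token ids below `num_embeddings` hit
# the base table, while ids at or above it fall into the `new_embedding` table concatenated after it.
#
#     emb = Embedding(num_embeddings=100, num_new_embeddings=4, features=32)
#     emb(torch.tensor([0, 99, 100, 103])).shape   # torch.Size([4, 32])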
def _expand_token(token, batch_size: int):
return token.view(1, 1, -1).expand(batch_size, -1, -1)
class VisionMlp(nn.Module):
def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device=None):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
self.act = ACT2FN[hidden_act]
self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(self.act(self.w1(x)))
class MolmoVisionBlock(nn.Module):
def __init__(self, config: MolmoVisionConfig, attention_type, device=None):
super().__init__()
self.attention = VisionAttention(config, device=device, attention_type=attention_type)
self.feed_forward = VisionMlp(
config.image_emb_dim, config.image_mlp_dim, config.image_mlp_activations, device)
self.attention_norm = nn.LayerNorm(
config.image_emb_dim,
eps=config.image_norm_eps,
device=device,
)
self.ffn_norm = nn.LayerNorm(
config.image_emb_dim,
eps=config.image_norm_eps,
device=device,
)
def reset_parameters(self):
self.attention.reset_parameters()
self.feed_forward.reset_parameters()
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attention(self.attention_norm(x))
x = x + self.feed_forward(self.ffn_norm(x))
return x
class VisionPreLayerNorm(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32),
self.bias.to(torch.float32), self.eps)
return x.to(orig_type)
class VisionTransformer(nn.Module):
def __init__(self, config: MolmoVisionConfig, attention_type, device=None):
super().__init__()
self.config = config
# class embeddings and positional embeddings
self.scale = config.image_emb_dim ** -0.5
self.class_embedding = nn.Parameter(
torch.zeros(config.image_emb_dim, device=device))
self.positional_embedding = nn.Parameter(
torch.zeros(config.image_num_pos, config.image_emb_dim, device=device))
image_patch_size = config.image_patch_size
self.patch_embedding = nn.Linear(
image_patch_size * image_patch_size * 3,
config.image_emb_dim,
bias=False,
device=device
)
self.pre_ln = VisionPreLayerNorm(
config.image_emb_dim,
eps=config.image_norm_eps,
)
self.blocks = nn.ModuleList([
MolmoVisionBlock(config, attention_type=attention_type, device=device)
for _ in range(config.image_num_layers)
])
    def add_pos_emb(self, x: torch.Tensor, patch_num: Tuple[int, int]) -> torch.Tensor:
cls_emb = self.positional_embedding[0:1]
pos_emb = self.positional_embedding[1:]
pos_emb = pos_emb.reshape(
(int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
)
(patch_num_0, patch_num_1) = patch_num
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# antialias: default True in jax.image.resize
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
pos_emb = F.interpolate(
pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True,
)
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
return x
    def forward(self, x: torch.Tensor, patch_num: Optional[Tuple[int, int]] = None) -> List[torch.Tensor]:
if patch_num is None:
patch_num = self.config.image_num_patch
B, N, D = x.shape
x = self.patch_embedding(x)
# class embeddings and positional embeddings
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
x = self.add_pos_emb(x, patch_num)
x = self.pre_ln(x)
hidden_states = []
for r in self.blocks:
x = r(x)
hidden_states.append(x)
return hidden_states
class VisionAttention(nn.Module):
    def __init__(self, config: MolmoVisionConfig, use_bias: bool = True,
                 embed_dim: Optional[int] = None, device=None, attention_type: str = "sdpa"):
super().__init__()
self.config = config
self.embed_dim = config.image_emb_dim
self.num_heads = config.image_num_heads
self.head_dim = config.image_head_dim
self.num_key_value_heads = config.image_num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.initializer_range = config.initializer_range
self.attention_type = attention_type
embed_dim = embed_dim if embed_dim else config.image_emb_dim
self.wq = nn.Linear(
embed_dim,
self.num_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wk = nn.Linear(
embed_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wv = nn.Linear(
embed_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=device,
)
self.wo = nn.Linear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=use_bias,
device=device,
)
self.residual_dropout = nn.Dropout(config.residual_dropout)
def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def _merge_heads(self, hidden_states) -> torch.Tensor:
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
inputs_v = inputs_kv
else:
inputs_k = inputs_q
inputs_v = inputs_q
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
xq = self._split_heads(xq, self.num_heads)
xk = self._split_heads(xk, self.num_key_value_heads)
xv = self._split_heads(xv, self.num_key_value_heads)
if self.num_heads != self.num_key_value_heads:
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
og_dtype = xq.dtype
if self.config.float32_attention:
xq = xq.to(torch.float)
xk = xk.to(torch.float)
if self.attention_type == "direct":
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
elif self.attention_type == "sdpa":
if self.config.float32_attention and not torch.is_autocast_enabled():
xv = xv.to(torch.float32)
attn_output = F.scaled_dot_product_attention(
xq.transpose(1, 2).contiguous(),
xk.transpose(1, 2).contiguous(),
xv.transpose(1, 2).contiguous(),
is_causal=False,
).transpose(1, 2)
elif self.attention_type == "flash":
assert not self.config.float32_attention
# Downcast in case we are running with fp32 hidden states
attn_output = _flash_attention_forward(
xq.transpose(1, 2).to(torch.bfloat16),
xk.transpose(1, 2).to(torch.bfloat16),
xv.transpose(1, 2).to(torch.bfloat16),
attention_mask=None,
query_length=inputs_q.shape[1],
is_causal=False,
)
else:
raise NotImplementedError(self.attention_type)
attn_output = attn_output.to(og_dtype)
attn_output = self._merge_heads(attn_output)
attn_output = self.wo(attn_output)
attn_output = self.residual_dropout(attn_output)
return attn_output
class MolmoImageProjector(nn.Module):
def __init__(self, input_dim: int, hidden_dim, output_dim, act_fn="silu", device=None):
super().__init__()
self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
self.act_fn = ACT2FN[act_fn]
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(self.act_fn(self.w1(x))*self.w3(x))
class OLMoVisionBackbone(nn.Module):
def __init__(self, config: MolmoConfig):
super().__init__()
self.config = config
self.image_vit = VisionTransformer(config.vision_config, config.attention_type)
self.image_pooling_2d = VisionAttention(
config.vision_config,
embed_dim=len(config.vit_layers)*config.vision_config.image_emb_dim,
attention_type=config.attention_type
)
        # `MLP` assumes the activation takes two inputs, so it must be a 'llama' version
if config.activation_type == "swiglu":
mlp_config = replace(config, activation_type="llama_swiglu")
elif config.activation_type == "gelu":
raise NotImplementedError()
else:
mlp_config = config
self.image_projector = MolmoImageProjector(
config.vision_config.image_emb_dim,
            config.intermediate_size // 2,  # halved because `intermediate_size` counts both the gate and value halves
config.hidden_size,
act_fn=config.activation_type
)
self.image_feature_dropout = nn.Dropout(config.image_feature_dropout)
self.num_prefix_tokens = 1
self.pad_embed = None
if config.image_padding_embed:
image_dim = config.vision_config.image_emb_dim*len(self.config.vit_layers)
if config.image_padding_embed == "pad_and_partial_pad":
self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
else:
raise ValueError(config.image_padding_embed)
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
cfg = self.config
v_cfg = self.config.vision_config
B, T, N, D = images.shape
mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
# Output all hidden states
# n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
images = images.view(B * T, N, D)
image_features = self.image_vit(images)
if cfg.vit_layers is not None:
features = []
for layer in cfg.vit_layers:
features.append(image_features[layer])
image_features = torch.cat(features, dim=-1)
else:
image_features = image_features[-1]
        cls_embed: Optional[torch.Tensor] = None
if self.num_prefix_tokens > 0:
cls_embed = image_features[:, 0]
image_features = image_features[:, 1:]
image_features = image_features * mask
image_features = image_features.view(B, T, N, -1)
cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
return image_features, cls_embed
def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
cfg = self.config
        # image_features: (batch_size, num_crops(=num_image), num_patch, n_layers * image_emb_dim)
batch_size, num_image = images.shape[:2]
image_features, cls_embed = self.encode_image(images)
if cfg.image_padding_embed:
assert image_masks is not None
if cfg.image_padding_embed == "pad_embed":
all_pad = (image_masks == 0).to(dtype=torch.float32)
pad_embed = self.pad_embed[None, None, None, :]
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
elif cfg.image_padding_embed == "regress":
pad_embed = self.pad_embed[None, None, None, :]
image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
elif cfg.image_padding_embed == "pad_and_partial_pad":
pad_embed = self.pad_embed[:, None, None, None, :]
all_pad = image_masks == 0
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
all_pad = all_pad.to(dtype=image_features.dtype)
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
else:
raise ValueError(cfg.image_padding_embed)
image_features = self.image_feature_dropout(image_features)
if cls_embed is not None:
cls_embed = self.image_feature_dropout(cls_embed)
image_features = image_features.reshape(
(batch_size, num_image) + cfg.image_num_patch + (-1,))
# transpose to get 2x2 feature squares [n_patches, 4, n_features]
batch, n_crops, h, w, c = image_features.shape
image_features = torch.reshape(image_features, [batch*n_crops, h//2, 2, w//2, 2, c])
image_features = torch.permute(image_features, [0, 1, 3, 2, 4, 5])
image_features = torch.reshape(image_features, [batch*n_crops*h//2*w//2, 2*2, c])
query = image_features.mean(-2, keepdim=True)
image_features = self.image_pooling_2d(query, image_features)
h = self.config.vision_config.image_num_patch[0]//2
w = self.config.vision_config.image_num_patch[1]//2
image_features = image_features.reshape(batch_size, num_image, h * w, -1)
# MLP layer to map the feature.
image_features = self.image_projector(image_features)
# image_features: (batch_size, num_image, num_patch, hidden_size)
# cls_embed: (batch_size, num_image, hidden_size)
return image_features, cls_embed
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
class MolmoRmsLayerNorm(nn.Module):
"""
RMS layer norm, a simplified :class:`LayerNorm` implementation
"""
def __init__(
self,
config: MolmoConfig,
size: Optional[int] = None,
elementwise_affine: Optional[bool] = None,
eps: float = 1e-5,
):
super().__init__()
self.config = config
self.eps = self.config.layer_norm_eps or eps
self.normalized_shape = (size or config.hidden_size,)
if elementwise_affine or (elementwise_affine is None):
self.weight = nn.Parameter(torch.ones(self.normalized_shape))
use_bias = self.config.bias_for_layer_norm
if use_bias:
self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
else:
self.register_parameter("bias", None)
else:
self.register_parameter("bias", None)
self.register_parameter("weight", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.autocast(enabled=False, device_type=x.device.type):
og_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
x = x.to(og_dtype)
if self.weight is not None:
if self.bias is not None:
return self.weight * x + self.bias
else:
return self.weight * x
else:
return x
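# For reference (hedged summary of the module above), MolmoRmsLayerNorm computes, in float32,
#
#     y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight (+ bias, when configured)
#
# and casts the result back to the input dtype before applying the affine parameters.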
class MolmoModel(MolmoPreTrainedModel):
def __init__(self, config: MolmoConfig, init_params: bool = True):
super().__init__(config)
if self.config.additional_vocab_size is not None:
wte = Embedding(
config.vocab_size,
config.additional_vocab_size,
config.hidden_size,
)
else:
wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.transformer = nn.ModuleDict(
dict(
wte=wte,
emb_drop=nn.Dropout(config.embedding_dropout),
ln_f=MolmoRmsLayerNorm(config),
)
)
if config.moe_num_experts > 0:
blocks = [MolmoeBlock(config) for i in range(config.num_hidden_layers)]
else:
blocks = [MolmoBlock(config) for i in range(config.num_hidden_layers)]
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not config.weight_tying:
self.transformer.update(
{
"ff_out": nn.Linear(
config.hidden_size,
config.vocab_size,
bias=False,
)
}
)
self.vision_backbone: Optional[OLMoVisionBackbone] = None
if config.vision_config is not None:
self.vision_backbone = OLMoVisionBackbone(config)
def reset_parameters(self):
if self.vision_backbone is not None:
self.vision_backbone.reset_parameters()
self.reset_non_vision_parameters()
def reset_non_vision_parameters(self):
self.transformer.wte.reset_parameters()
if hasattr(self.transformer.wte, "new_embedding"):
nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)
if hasattr(self.transformer, "wpe"):
nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0)
self.transformer.ln_f.reset_parameters() # type: ignore
if hasattr(self.transformer, "ff_out"):
            nn.init.normal_(self.transformer.ff_out.weight, mean=0.0, std=0.02)
for block in self.transformer.blocks:
block.reset_parameters()
def forward(
self,
input_ids: torch.LongTensor,
input_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_masks: Optional[torch.Tensor] = None,
image_input_idx: Optional[torch.Tensor] = None,
subsegment_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
last_logits_only: bool = False,
output_hidden_states: Optional[bool] = None,
append_last_valid_logits: Optional[torch.Tensor] = None,
) -> ModelOutput:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
:param input_embeddings: A tensor of shape `(batch_size, seq_len, hidden_size)` with input
embeddings. When provided, it is treated as the output of the input embedding layer.
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
which input IDs are masked. A `1` value in the mask means that
the corresponding input ID should *not* be ignored. A `0` means
that the corresponding input ID is masked.
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
library.
        :param images: A tensor of image crops of shape `(batch_size, num_image, num_patch, patch_dim)`
            to be encoded by the vision backbone, where `patch_dim` is the flattened patch pixels.
        :param image_masks: A tensor of shape `(batch_size, num_image, num_patch)` giving, per patch,
            the fraction of the crop that is real image rather than padding; used to add the padding
            embeddings.
        :param image_input_idx: A tensor of shape `(batch_size, num_image, num_patch)` giving, for each
            image patch, the position in `input_ids` whose embedding the patch feature is added to.
            Negative indices mark patches that are dropped.
        :param position_ids: Optional explicit position IDs of shape `(batch_size, seq_len)`. When omitted,
            they are computed as a cumulative sum over the non-masked positions of `attention_mask`.
:param past_key_values: Pre-computed keys and values for each attention block.
Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
:param use_cache: If `True`, return key and value tensors for each block.
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
This can speed up decoding when you only care about the next token.
"""
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
if past_key_values:
assert len(past_key_values) == self.config.num_hidden_layers
has_image = images is not None
assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images."
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
if past_key_values is None:
past_length = 0
else:
past_length = past_key_values[0][0].size(-2)
if attention_mask is None:
attention_mask = input_ids != -1
if subsegment_ids is not None:
raise NotImplementedError()
else:
if position_ids is None:
position_ids = torch.clamp(
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
min=0,
).broadcast_to((batch_size, attention_mask.shape[-1]))
# Get embeddings of input.
# shape: (batch_size, seq_len, hidden_size)
if input_ids is not None:
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
num_image: Optional[int] = None
if images is not None:
# shape: (batch_size, num_image, num_patch, hidden_size)
# cls_embed: (batch_size, num_image, hidden_size)
image_features, cls_embed = self.vision_backbone(images, image_masks)
num_image, num_patch = image_features.shape[1:3]
assert image_input_idx.shape == (batch_size, num_image, num_patch)
            # Insert the image features into the input embeddings.
image_features = image_features.view(batch_size, num_image * num_patch, -1)
image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)
valid = image_input_idx >= 0
batch_idx = torch.arange(batch_size, device=x.device)
batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
# For hf demo/endpoint
image_features = image_features.to(x.device)
x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
# Add input + positional embeddings and apply dropout.
# shape: (batch_size, seq_len, hidden_size)
x = self.transformer.emb_drop(x) # type: ignore
# normalized
if self.config.normalize_input_embeds:
x = x * (self.config.hidden_size ** 0.5)
# Merge attention mask with attention bias.
# FIXME we are ignoring the attention mask input parameter
if self.config.attention_type == "flash":
attention_mask = input_ids != -1
elif (
attention_mask is not None
or past_key_values is not None
):
total_len = (past_length + seq_len)
attention_mask = torch.tril(torch.ones(total_len, total_len, device=x.device, dtype=torch.bool))
attention_mask = attention_mask.view(1, 1, total_len, total_len)
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
# decoder layers
all_hidden_states = []
# Apply blocks one-by-one.
for block_idx, block in enumerate(self.transformer.blocks):
if output_hidden_states:
# add hidden states
all_hidden_states.append(x)
layer_past = None if past_key_values is None else past_key_values[block_idx]
x, cache = block(x, attention_mask=attention_mask, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
if last_logits_only:
# shape: (batch_size, 1, hidden_size)
if append_last_valid_logits is not None:
last_valid_output = x[
torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)]
x = last_valid_output.unsqueeze(1)
else:
x = x[:, -1, :].unsqueeze(1)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, hidden_size)
x = self.transformer.ln_f(x) # type: ignore
if output_hidden_states:
# add final hidden state post-final-layernorm, following HuggingFace's convention
all_hidden_states.append(x)
# Get logits.
# shape: (batch_size, seq_len or 1, vocab_size)
if self.config.weight_tying:
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
else:
logits = self.transformer.ff_out(x) # type: ignore
if self.config.scale_logits:
logits.mul_(1 / math.sqrt(self.config.hidden_size))
if not last_logits_only and append_last_valid_logits is not None:
last_valid_logit = logits[
torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits]
logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
return ModelOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
class MolmoForCausalLM(MolmoPreTrainedModel):
def __init__(self, config: MolmoConfig, model: Optional[MolmoModel] = None, init_params: bool = False):
super().__init__(config)
if not model:
self.model = MolmoModel(config, init_params=init_params)
else:
self.model = model
self.post_init()
def get_input_embeddings(self) -> torch.nn.Module:
return self.model.transformer.wte
def get_output_embeddings(self):
if self.config.weight_tying:
return self.model.transformer.wte
else:
return self.model.transformer.ff_out
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
response_mask: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_masks: Optional[torch.Tensor] = None,
image_input_idx: Optional[torch.Tensor] = None,
subsegment_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
loss_masks: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
last_logits_only: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
append_last_valid_logits: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[
Cache
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
) -> Union[Tuple, CausalLMOutputWithPast]:
if use_cache is None:
use_cache = self.config.use_cache
if output_attentions:
raise ValueError("output_attentions is not yet supported in Molmo")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.forward(
input_ids=input_ids,
input_embeddings=inputs_embeds,
attention_mask=attention_mask,
images=images,
image_masks=image_masks,
image_input_idx=image_input_idx,
subsegment_ids=subsegment_ids,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
last_logits_only=last_logits_only,
output_hidden_states=output_hidden_states,
append_last_valid_logits=append_last_valid_logits,
)
logits = outputs.logits
hidden_states = outputs.hidden_states
loss = None
if labels is not None:
if loss_masks is not None:
loss_masks = loss_masks * (loss_masks > 0)
batch_size_in_tokens = max(loss_masks.sum().item(), 1)
labels = labels.long()
labels.masked_fill_(~(loss_masks > 0), -100)
labels = labels.view(-1)
logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
loss = loss_fct(logits_for_loss, labels)
loss = loss.view(input_ids.shape[0], -1)
loss = loss * loss_masks
loss = loss.sum() / batch_size_in_tokens
use_zloss = getattr(self.config, "softmax_auxiliary_loss", False)
if use_zloss:
z_squared = logits_for_loss.logsumexp(-1).pow(2)
z_loss = self.config.softmax_auxiliary_loss_scale * z_squared
z_loss = z_loss.view(input_ids.shape[0], -1)
z_loss = z_loss * loss_masks
z_loss = z_loss.sum() / batch_size_in_tokens
loss += z_loss
else:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.attn_key_values,
hidden_states=hidden_states,
)
def can_generate(self) -> bool:
return True
@torch.no_grad()
def generate_from_batch(
self,
batch: Dict[str, Any],
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
if generation_config is not None:
assert generation_config.use_cache
images = batch.get("images")
image_masks = batch.get("image_masks")
image_input_idx = batch.get("image_input_idx")
# Validate inputs.
input_ids = batch["input_ids"]
batch_size, seq_len = input_ids.shape
attention_mask = batch.get("attention_mask", None)
max_new_tokens = generation_config.max_new_tokens
assert max_new_tokens is not None
mask_len = seq_len + max_new_tokens
position_ids: Optional[torch.Tensor] = None
append_last_valid_logits: Optional[torch.Tensor] = None
if attention_mask is None:
attention_mask = input_ids != -1
position_ids = torch.clamp(
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
min=0
)
append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
dim=1,
)
if attention_mask is not None:
assert attention_mask.shape == (batch_size, mask_len)
out = super().generate(
batch["input_ids"],
generation_config,
attention_mask=attention_mask,
images=images,
image_masks=image_masks,
image_input_idx=image_input_idx,
position_ids=position_ids,
append_last_valid_logits=append_last_valid_logits,
**kwargs,
)
return out
def prepare_inputs_for_generation(
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
):
if past_key_values:
# This is because we want the model to only process the last generated token.
input_ids = input_ids[:, -1:]
attention_mask = kwargs.get("attention_mask")
images = kwargs.get("images")
image_masks = kwargs.get("image_masks")
image_input_idx = kwargs.get("image_input_idx")
position_ids = kwargs.get("position_ids")
append_last_valid_logits = kwargs.get("append_last_valid_logits")
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": True,
"last_logits_only": True,
}
if past_key_values is None:
model_inputs["images"] = images
model_inputs["image_masks"] = image_masks
model_inputs["image_input_idx"] = image_input_idx
model_inputs["append_last_valid_logits"] = append_last_valid_logits
return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
if "append_last_valid_logits" in model_kwargs:
del model_kwargs["append_last_valid_logits"]
if "images" in model_kwargs:
del model_kwargs["images"]
del model_kwargs["image_masks"]
del model_kwargs["image_input_idx"]
cache_name, cache = super()._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
return model_kwargs
# Always register for multi-modal features
AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
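# Example usage (a minimal sketch, not exercised when this module is loaded as remote code).
# The repo id and the `processor.process(...)` call follow the standard Molmo release and are
# assumptions here; adjust them to whatever preprocessor ships with this checkpoint.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor

    repo_id = "cydhsieh01/molmo-dense-captioner-v22-qwen2"  # assumed repo id
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype="auto")

    inputs = processor.process(images=[Image.open("example.jpg")], text="Describe this image.")
    # `generate_from_batch` expects batched tensors, so add a leading batch dimension.
    batch = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
    output = model.generate_from_batch(
        batch,
        GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
        tokenizer=processor.tokenizer,
    )
    generated_tokens = output[0, batch["input_ids"].shape[1]:]
    print(processor.tokenizer.decode(generated_tokens, skip_special_tokens=True))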