import math
from copy import deepcopy
from dataclasses import fields, dataclass, replace
from enum import Enum
from typing import List, Optional, Tuple, Union, Dict, Any, Sequence, Callable, cast, MutableMapping

import torch
from transformers import PreTrainedModel, GenerationConfig, add_start_docstrings
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto import AutoModelForCausalLM
from torch import nn
from transformers.utils import logging

from .config_molmo import MolmoConfig, MolmoVisionConfig
from torch.nn import functional as F

logger = logging.get_logger(__name__)

MOLMO_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MolmoConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Molmo Model outputting raw hidden-states without any specific head on top.",
    MOLMO_START_DOCSTRING,
)
class MolmoPreTrainedModel(PreTrainedModel):
    config_class = MolmoConfig
    base_model_prefix = "model"
    _no_split_modules = ["MolmoBlock", "MolmoeBlock", "MolmoVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    # supports_gradient_checkpointing = True
    # _supports_cache_class = True
    # _supports_static_cache = False

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear,)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)


class MolmoRotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
    """

    def __init__(self, dim, max_position_embeddings=2048, rope_theta=10000, full_precision=True, device=None):
        super().__init__()
        self.dim = dim
        self.rope_theta = rope_theta
        self.full_precision = full_precision
        self.max_position_embeddings = max_position_embeddings

        # Cache sin/cos embeddings
        dim = self.dim
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
        seq = torch.arange(self.max_position_embeddings, device=device, dtype=torch.float)
        freqs = torch.einsum("i , j -> i j", seq, inv_freq)
        positions = torch.cat((freqs, freqs), dim=-1)
        pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
        self.register_buffer("rope_pos_sin", pos_sin, persistent=False)
        self.register_buffer("rope_pos_cos", pos_cos, persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return (t * pos_cos) + (self.rotate_half(t) * pos_sin)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.full_precision:
            q_, k_ = q.float(), k.float()
        else:
            q_, k_ = q, k

        with torch.autocast(q.device.type, enabled=False):
            batch_size = q_.shape[0]
            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
            if position_ids is not None:
                freqs_cis_len = self.max_position_embeddings
            else:
                freqs_cis_len = key_len
            # self.get_rotary_embedding(freqs_cis_len, q_.device)
            pos_sin = self.rope_pos_sin[:, :, :freqs_cis_len, :].type_as(q_)
            pos_cos = self.rope_pos_cos[:, :, :freqs_cis_len, :].type_as(q_)
            if position_ids is not None:
                assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
                pos_sin = pos_sin[0, 0][position_ids].view(
                    (batch_size, 1, key_len, pos_sin.shape[-1])
                )
                pos_cos = pos_cos[0, 0][position_ids].view(
                    (batch_size, 1, key_len, pos_cos.shape[-1])
                )
            q_ = self.apply_rotary_pos_emb(
                pos_sin[:, :, key_len - query_len: key_len, :],
                pos_cos[:, :, key_len - query_len: key_len, :],
                q_,
            )
            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)


class MolmoAttention(nn.Module):
    def __init__(
        self,
        config: MolmoConfig,
        device=None
    ):
        super().__init__()
        self.config = config
        self.rotary_emb = MolmoRotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            config.max_position_embeddings,
            config.rope_theta,
            device=device)

        self.k_norm: Optional[nn.Module] = None
        self.q_norm: Optional[nn.Module] = None
        self.hidden_size = config.intermediate_size
        if config.qk_layer_norm:
            if config.num_key_value_heads is None:
                config.num_key_value_heads = config.num_attention_heads
            self.q_norm = MolmoRmsLayerNorm(
                config,
                size=config.hidden_size,
                eps=config.layer_norm_eps
            )
            # The keys only span `num_key_value_heads` heads, so their norm is sized accordingly.
            self.k_norm = MolmoRmsLayerNorm(
                config,
                size=config.num_key_value_heads * (config.hidden_size // config.num_attention_heads),
                eps=config.layer_norm_eps
            )

        # Attention output projection.
        input_dim = config.hidden_size
        head_dim = config.hidden_size // config.num_attention_heads
        self.fused_dims = (
            config.hidden_size,
            config.num_key_value_heads * head_dim,
            config.num_key_value_heads * head_dim,
        )
        self.att_proj = nn.Linear(
            config.hidden_size,
            sum(self.fused_dims),
            bias=config.qkv_bias,
        )
        self.attn_out = nn.Linear(
            input_dim,
            config.hidden_size,
            bias=False,
        )

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        drop_mask: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        B, T, C = q.size()  # batch size, sequence length, hidden_size
        dtype = k.dtype

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Move head forward to be next to the batch dim.
        # shape: (B, nh, T, hs)
        q = q.view(B, T, self.config.num_attention_heads, C // self.config.num_attention_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, T, self.config.num_key_value_heads, C // self.config.num_attention_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, T, self.config.num_key_value_heads, C // self.config.num_attention_heads).transpose(1, 2)

        # Apply rotary embeddings
        q, k = self.rotary_emb(q, k, position_ids=position_ids)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key.to(k.device), k), dim=-2)
            v = torch.cat((past_value.to(v.device), v), dim=-2)

        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

        if attention_mask is not None and attention_mask.dim() == 4:
            # Slice the 4D mask down to the current window; a 2D padding mask
            # (used by the flash path) is passed through unchanged.
            attention_mask = attention_mask[:, :, key_len - query_len: key_len, :key_len]

        # if attention_bias is not None:
        #     attention_bias = self._cast_attn_bias(
        #         attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)

        # Get the attention scores.
        # shape: (B, nh, T, hs)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attention_mask=attention_mask,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=attention_mask is None,
        )

        # Re-assemble all head outputs side-by-side.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Apply output projection.
        return self.attn_out(att), present

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        if attention_mask is not None:
            attention_mask = attention_mask.to(q.device)

        if self.config.attention_type == "sdpa":
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                assert num_q_heads % num_kv_heads == 0
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)

            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
        elif self.config.attention_type == "flash":
            # Downcast in case we are running with fp32 hidden states.
            # `_flash_attention_forward` expects a 2D padding mask, so collapse a
            # 4D boolean mask if one was provided.
            if attention_mask is not None and attention_mask.dim() == 4:
                valid_mask = attention_mask.any(dim=-1)[0]
            else:
                valid_mask = attention_mask
            attn_output = _flash_attention_forward(
                q.transpose(1, 2).to(torch.bfloat16),
                k.transpose(1, 2).to(torch.bfloat16),
                v.transpose(1, 2).to(torch.bfloat16),
                attention_mask=valid_mask,
                query_length=q.shape[2],
                is_causal=True,
            )
            # Flash returns (B, T, nh, hs); transpose back to (B, nh, T, hs) to match the SDPA path.
            return attn_output.transpose(1, 2)
        else:
            raise NotImplementedError(self.config.attention_type)

    def forward(
        self,
        x,
        attention_mask,
        position_ids,
        layer_past,
        use_cache
    ):
        qkv = self.att_proj(x)

        q, k, v = qkv.split(self.fused_dims, dim=-1)

        # Get attention scores.
        att, cache = self.attention(
            q, k, v,
            attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache
        )
        return att, cache
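

# Minimal, self-contained sketch (not called by the model) of how the rotate-half RoPE
# convention implemented by `MolmoRotaryEmbedding` is applied to query/key tensors.
# All shapes below are toy values chosen for illustration.
def _rope_demo() -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: rotate toy queries/keys with `MolmoRotaryEmbedding`."""
    rope = MolmoRotaryEmbedding(dim=8, max_position_embeddings=16)
    q = torch.randn(1, 2, 4, 8)  # (batch, n_heads, seq_len, head_dim)
    k = torch.randn(1, 2, 4, 8)
    # Each position is rotated by position-dependent angles; shapes are unchanged.
    return rope(q, k)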


class MolmoMlp(nn.Module):
    def __init__(self, input_dim, hidden_size, activation_fn, include_bias=False):
        super().__init__()
        self.ff_proj = nn.Linear(input_dim, hidden_size, bias=include_bias)
        self.ff_out = nn.Linear(hidden_size // 2, input_dim, bias=include_bias)
        self.act = ACT2FN[activation_fn]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ff_proj(x)
        x, gate = x.chunk(2, dim=-1)
        x = self.act(gate) * x
        x = self.ff_out(x)
        return x


class MolmoBlock(nn.Module):
    def __init__(self, config: MolmoConfig, device=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.intermediate_size
        self.dropout = nn.Dropout(config.residual_dropout)
        self.attn = MolmoAttention(config)
        self.attn_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = MolmoMlp(config.hidden_size, config.intermediate_size, config.activation_type)
        self.ff_norm = MolmoRmsLayerNorm(config)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        if not self.config.norm_after:
            atten_in = self.attn_norm(x)
        else:
            atten_in = x

        att, cache = self.attn(
            atten_in,
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache
        )

        if self.config.norm_after:
            att = self.attn_norm(att)

        x = x + self.dropout(att)

        og_x = x
        if not self.config.norm_after:
            x = self.ff_norm(x)
        x = self.mlp(x)
        if self.config.norm_after:
            x = self.ff_norm(x)
        x = self.dropout(x)
        x = og_x + x

        return x, cache
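

# Sketch of the fused SwiGLU convention used by `MolmoMlp` above: `ff_proj` packs the
# value and gate halves into a single matmul, the *second* chunk acts as the gate, and
# `ff_out` maps the gated half back to the input width. Toy sizes, illustrative only.
def _swiglu_mlp_demo() -> torch.Tensor:
    """Illustrative only: run `MolmoMlp` with a hypothetical hidden size of 16."""
    mlp = MolmoMlp(input_dim=4, hidden_size=16, activation_fn="silu")
    x = torch.randn(2, 3, 4)  # (batch, seq_len, input_dim)
    return mlp(x)             # (batch, seq_len, input_dim)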


class MolmoeMLP(nn.Module):
    def __init__(self, input_dim, hidden_size, activation):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_size, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_size, bias=False)
        self.down_proj = nn.Linear(hidden_size, input_dim, bias=False)
        self.act_fn = ACT2FN[activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MolmoeMlpExpert(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_top_k
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [MolmoeMLP(config.hidden_size, config.intermediate_size // 2, config.activation_type)
             for _ in range(self.num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # hidden_states = self.ff_norm(hidden_states)
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be selected
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
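

# Toy illustration of the top-k routing performed by `MolmoeMlpExpert` above: router
# logits are softmaxed, the top-k experts are selected per token, and each token's
# output is the weighted sum of those experts. The expert MLPs are replaced by the
# identity here, so this is illustrative only.
def _moe_routing_demo() -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: dense equivalent of the sparse dispatch done with `index_add_`."""
    num_experts, top_k = 4, 2
    tokens = torch.randn(6, 8)                   # (n_tokens, hidden_dim)
    router_logits = torch.randn(6, num_experts)  # stand-in for `self.gate(hidden_states)`
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # With identity "experts", each token is just scaled by its routing weights and summed.
    out = (routing_weights.unsqueeze(-1) * tokens.unsqueeze(1).expand(-1, top_k, -1)).sum(dim=1)
    return out, selected_experts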


class MolmoeBlock(nn.Module):
    def __init__(self, config: MolmoConfig):
        super().__init__()
        self.attn = MolmoAttention(config)
        self.attn_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
        assert config.moe_num_experts > 0
        self.ff_norm = MolmoRmsLayerNorm(config, size=config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = MolmoeMlpExpert(config)
        self.config = config
        self.hidden_size = config.intermediate_size
        self.dropout = nn.Dropout(config.residual_dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        if not self.config.norm_after:
            atten_in = self.attn_norm(x)
        else:
            atten_in = x

        att, cache = self.attn(
            atten_in,
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            use_cache=use_cache
        )

        if self.config.norm_after:
            att = self.attn_norm(att)

        x = x + self.dropout(att)

        og_x = x
        if not self.config.norm_after:
            x = self.ff_norm(x)
        x, _ = self.mlp(x)
        if self.config.norm_after:
            x = self.ff_norm(x)
        x = self.dropout(x)
        x = og_x + x

        return x, cache


class Embedding(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        num_new_embeddings: int,
        features: int,
        device: Union[str, torch.device] = None,
        initializer_range: float = 0.02,
        new_embed_initializer_range: float = 0.02,
    ):
        super().__init__()
        self.initializer_range = initializer_range
        self.new_embed_initializer_range = new_embed_initializer_range
        self.embedding = nn.Parameter(
            torch.zeros(num_embeddings, features, device=device),
        )
        # We keep the special token embedding separate from the embedding from the LM so we can
        # use a separate learning rate for them during training.
        self.new_embedding = nn.Parameter(torch.zeros(num_new_embeddings, features, device=device))

    def reset_parameters(self):
        nn.init.normal_(self.embedding, std=self.initializer_range)
        nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))


def _expand_token(token, batch_size: int):
    return token.view(1, 1, -1).expand(batch_size, -1, -1)


class VisionMlp(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, hidden_act: str, device=None):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=True, device=device)
        self.act = ACT2FN[hidden_act]
        self.w2 = nn.Linear(hidden_dim, dim, bias=True, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act(self.w1(x)))


class MolmoVisionBlock(nn.Module):
    def __init__(self, config: MolmoVisionConfig, attention_type, device=None):
        super().__init__()
        self.attention = VisionAttention(config, device=device, attention_type=attention_type)
        self.feed_forward = VisionMlp(
            config.image_emb_dim, config.image_mlp_dim, config.image_mlp_activations, device)
        self.attention_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
            device=device,
        )
        self.ffn_norm = nn.LayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
            device=device,
        )

    def reset_parameters(self):
        self.attention.reset_parameters()
        self.feed_forward.reset_parameters()
        self.attention_norm.reset_parameters()
        self.ffn_norm.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attention_norm(x))
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class VisionPreLayerNorm(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32),
                         self.bias.to(torch.float32), self.eps)
        return x.to(orig_type)


class VisionTransformer(nn.Module):
    def __init__(self, config: MolmoVisionConfig, attention_type, device=None):
        super().__init__()
        self.config = config

        # class embeddings and positional embeddings
        self.scale = config.image_emb_dim ** -0.5
        self.class_embedding = nn.Parameter(
            torch.zeros(config.image_emb_dim, device=device))
        self.positional_embedding = nn.Parameter(
            torch.zeros(config.image_num_pos, config.image_emb_dim, device=device))

        image_patch_size = config.image_patch_size
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.image_emb_dim,
            bias=False,
            device=device
        )

        self.pre_ln = VisionPreLayerNorm(
            config.image_emb_dim,
            eps=config.image_norm_eps,
        )

        self.blocks = nn.ModuleList([
            MolmoVisionBlock(config, attention_type=attention_type, device=device)
            for _ in range(config.image_num_layers)
        ])

    def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
        cls_emb = self.positional_embedding[0:1]
        pos_emb = self.positional_embedding[1:]

        pos_emb = pos_emb.reshape(
            (int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])
        )

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            # antialias: default True in jax.image.resize
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
        return x

    def forward(self, x: torch.Tensor, patch_num: int = None) -> List[torch.Tensor]:
        if patch_num is None:
            patch_num = self.config.image_num_patch
        B, N, D = x.shape

        x = self.patch_embedding(x)

        # class embeddings and positional embeddings
        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
        x = self.add_pos_emb(x, patch_num)

        x = self.pre_ln(x)

        hidden_states = []
        for r in self.blocks:
            x = r(x)
            hidden_states.append(x)
        return hidden_states
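

# Sketch of the positional-embedding resizing performed in `VisionTransformer.add_pos_emb`
# above: a flat (N, D) table is viewed as a square grid, bicubically resized to the target
# patch grid, then flattened back. Toy sizes, illustrative only.
def _pos_emb_resize_demo() -> torch.Tensor:
    """Illustrative only: resize a 16x16 positional grid to 24x24."""
    pos_emb = torch.randn(16 * 16, 32)                                # (N, D)
    grid = pos_emb.reshape(16, 16, 32).permute(2, 0, 1).unsqueeze(0)  # (1, D, 16, 16)
    grid = F.interpolate(grid, size=(24, 24), mode="bicubic", align_corners=False, antialias=True)
    return grid.squeeze(0).permute(1, 2, 0).reshape(-1, 32)           # (24 * 24, D)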


class VisionAttention(nn.Module):
    def __init__(self, config: MolmoVisionConfig, use_bias: bool = True,
                 embed_dim: int = None, device=None, attention_type: str = "sdpa"):
        super().__init__()
        self.config = config
        self.embed_dim = config.image_emb_dim
        self.num_heads = config.image_num_heads
        self.head_dim = config.image_head_dim
        self.num_key_value_heads = config.image_num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.initializer_range = config.initializer_range
        self.attention_type = attention_type

        embed_dim = embed_dim if embed_dim else config.image_emb_dim

        self.wq = nn.Linear(
            embed_dim,
            self.num_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wk = nn.Linear(
            embed_dim,
            self.num_key_value_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wv = nn.Linear(
            embed_dim,
            self.num_key_value_heads * self.head_dim,
            bias=use_bias,
            device=device,
        )
        self.wo = nn.Linear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=use_bias,
            device=device,
        )
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
        return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

    def _merge_heads(self, hidden_states) -> torch.Tensor:
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
        if inputs_kv is not None:
            inputs_k = inputs_kv
            inputs_v = inputs_kv
        else:
            inputs_k = inputs_q
            inputs_v = inputs_q

        xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)

        xq = self._split_heads(xq, self.num_heads)
        xk = self._split_heads(xk, self.num_key_value_heads)
        xv = self._split_heads(xv, self.num_key_value_heads)

        if self.num_heads != self.num_key_value_heads:
            xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
            xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)

        og_dtype = xq.dtype

        if self.config.float32_attention:
            xq = xq.to(torch.float)
            xk = xk.to(torch.float)

        if self.attention_type == "direct":
            attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
        elif self.attention_type == "sdpa":
            if self.config.float32_attention and not torch.is_autocast_enabled():
                xv = xv.to(torch.float32)
            attn_output = F.scaled_dot_product_attention(
                xq.transpose(1, 2).contiguous(),
                xk.transpose(1, 2).contiguous(),
                xv.transpose(1, 2).contiguous(),
                is_causal=False,
            ).transpose(1, 2)
        elif self.attention_type == "flash":
            assert not self.config.float32_attention
            # Downcast in case we are running with fp32 hidden states.
            # `xq`/`xk`/`xv` are already (batch, seq_len, num_heads, head_dim), which is the
            # layout `_flash_attention_forward` expects, so no transpose is needed here.
            attn_output = _flash_attention_forward(
                xq.to(torch.bfloat16),
                xk.to(torch.bfloat16),
                xv.to(torch.bfloat16),
                attention_mask=None,
                query_length=inputs_q.shape[1],
                is_causal=False,
            )
        else:
            raise NotImplementedError(self.attention_type)
        attn_output = attn_output.to(og_dtype)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.wo(attn_output)
        attn_output = self.residual_dropout(attn_output)

        return attn_output


class MolmoImageProjector(nn.Module):
    def __init__(self, input_dim: int, hidden_dim, output_dim, act_fn="silu", device=None):
        super().__init__()
        self.w1 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.w2 = nn.Linear(hidden_dim, output_dim, bias=False, device=device)
        self.w3 = nn.Linear(input_dim, hidden_dim, bias=False, device=device)
        self.act_fn = ACT2FN[act_fn]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))


class OLMoVisionBackbone(nn.Module):
    def __init__(self, config: MolmoConfig):
        super().__init__()
        self.config = config
        self.image_vit = VisionTransformer(config.vision_config, config.attention_type)

        self.image_pooling_2d = VisionAttention(
            config.vision_config,
            embed_dim=len(config.vit_layers) * config.vision_config.image_emb_dim,
            attention_type=config.attention_type
        )

        # `MLP` assumes the activation takes two inputs, so it must be a 'llama' version.
        if config.activation_type == "swiglu":
            mlp_config = replace(config, activation_type="llama_swiglu")
        elif config.activation_type == "gelu":
            raise NotImplementedError()
        else:
            mlp_config = config
        self.image_projector = MolmoImageProjector(
            config.vision_config.image_emb_dim,
            config.intermediate_size // 2,  # // 2 since `intermediate_size` includes both the gate and up-projection parts
            config.hidden_size,
            act_fn=config.activation_type
        )
        self.image_feature_dropout = nn.Dropout(config.image_feature_dropout)
        self.num_prefix_tokens = 1

        self.pad_embed = None
        if config.image_padding_embed:
            image_dim = config.vision_config.image_emb_dim * len(self.config.vit_layers)
            if config.image_padding_embed == "pad_and_partial_pad":
                self.pad_embed = nn.Parameter(torch.zeros((2, image_dim)))
            else:
                raise ValueError(config.image_padding_embed)

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        v_cfg = self.config.vision_config
        B, T, N, D = images.shape

        mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)

        # Output all hidden states
        # n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)

        if cfg.vit_layers is not None:
            features = []
            for layer in cfg.vit_layers:
                features.append(image_features[layer])
            image_features = torch.cat(features, dim=-1)
        else:
            image_features = image_features[-1]

        cls_embed: torch.Tensor = None
        if self.num_prefix_tokens > 0:
            cls_embed = image_features[:, 0]
            image_features = image_features[:, 1:]

        image_features = image_features * mask
        image_features = image_features.view(B, T, N, -1)

        cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None

        return image_features, cls_embed

    def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        cfg = self.config

        # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, num_image = images.shape[:2]
        image_features, cls_embed = self.encode_image(images)

        if cfg.image_padding_embed:
            assert image_masks is not None
            if cfg.image_padding_embed == "pad_embed":
                all_pad = (image_masks == 0).to(dtype=torch.float32)
                pad_embed = self.pad_embed[None, None, None, :]
                image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
            elif cfg.image_padding_embed == "regress":
                pad_embed = self.pad_embed[None, None, None, :]
                image_features = image_features + pad_embed * torch.unsqueeze(
                    torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
            elif cfg.image_padding_embed == "pad_and_partial_pad":
                pad_embed = self.pad_embed[:, None, None, None, :]
                all_pad = image_masks == 0
                partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
                all_pad = all_pad.to(dtype=image_features.dtype)
                image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
                image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
            else:
                raise ValueError(cfg.image_padding_embed)

        image_features = self.image_feature_dropout(image_features)
        if cls_embed is not None:
            cls_embed = self.image_feature_dropout(cls_embed)

        image_features = image_features.reshape(
            (batch_size, num_image) + cfg.image_num_patch + (-1,))

        # transpose to get 2x2 feature squares [n_patches, 4, n_features]
        batch, n_crops, h, w, c = image_features.shape
        image_features = torch.reshape(image_features, [batch * n_crops, h // 2, 2, w // 2, 2, c])
        image_features = torch.permute(image_features, [0, 1, 3, 2, 4, 5])
        image_features = torch.reshape(image_features, [batch * n_crops * h // 2 * w // 2, 2 * 2, c])

        query = image_features.mean(-2, keepdim=True)
        image_features = self.image_pooling_2d(query, image_features)

        h = self.config.vision_config.image_num_patch[0] // 2
        w = self.config.vision_config.image_num_patch[1] // 2
        image_features = image_features.reshape(batch_size, num_image, h * w, -1)

        # MLP layer to map the feature.
        image_features = self.image_projector(image_features)

        # image_features: (batch_size, num_image, num_patch, hidden_size)
        # cls_embed: (batch_size, num_image, hidden_size)
        return image_features, cls_embed
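

# Sketch of the 2x2 neighborhood grouping that `OLMoVisionBackbone.forward` performs before
# `image_pooling_2d`: an (h, w) grid of patch features is regrouped so every non-overlapping
# 2x2 square becomes a group of four tokens that attention can pool into one. Toy sizes,
# illustrative only.
def _patch_pool_reshape_demo() -> torch.Tensor:
    """Illustrative only: regroup a (batch, h, w, c) grid into (batch * h/2 * w/2, 4, c)."""
    batch, h, w, c = 2, 4, 6, 8
    features = torch.randn(batch, h, w, c)
    features = features.reshape(batch, h // 2, 2, w // 2, 2, c)
    features = features.permute(0, 1, 3, 2, 4, 5)  # (batch, h/2, w/2, 2, 2, c)
    return features.reshape(batch * (h // 2) * (w // 2), 2 * 2, c)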


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    att_bias = torch.triu(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
        diagonal=1,
    )
    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore


class MolmoRmsLayerNorm(nn.Module):
    """
    RMS layer norm, a simplified :class:`LayerNorm` implementation
    """

    def __init__(
        self,
        config: MolmoConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.config = config
        self.eps = self.config.layer_norm_eps or eps
        self.normalized_shape = (size or config.hidden_size,)
        if elementwise_affine or (elementwise_affine is None):
            self.weight = nn.Parameter(torch.ones(self.normalized_shape))
            use_bias = self.config.bias_for_layer_norm
            if use_bias:
                self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("bias", None)
            self.register_parameter("weight", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(enabled=False, device_type=x.device.type):
            og_dtype = x.dtype
            x = x.to(torch.float32)
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            x = x.to(og_dtype)

        if self.weight is not None:
            if self.bias is not None:
                return self.weight * x + self.bias
            else:
                return self.weight * x
        else:
            return x


class MolmoModel(MolmoPreTrainedModel):
    def __init__(self, config: MolmoConfig, init_params: bool = True):
        super().__init__(config)

        if self.config.additional_vocab_size is not None:
            wte = Embedding(
                config.vocab_size,
                config.additional_vocab_size,
                config.hidden_size,
            )
        else:
            wte = nn.Embedding(config.vocab_size, config.hidden_size)

        self.transformer = nn.ModuleDict(
            dict(
                wte=wte,
                emb_drop=nn.Dropout(config.embedding_dropout),
                ln_f=MolmoRmsLayerNorm(config),
            )
        )

        if config.moe_num_experts > 0:
            blocks = [MolmoeBlock(config) for i in range(config.num_hidden_layers)]
        else:
            blocks = [MolmoBlock(config) for i in range(config.num_hidden_layers)]
        self.transformer.update({"blocks": nn.ModuleList(blocks)})

        if not config.weight_tying:
            self.transformer.update(
                {
                    "ff_out": nn.Linear(
                        config.hidden_size,
                        config.vocab_size,
                        bias=False,
                    )
                }
            )

        self.vision_backbone: Optional[OLMoVisionBackbone] = None
        if config.vision_config is not None:
            self.vision_backbone = OLMoVisionBackbone(config)

    def reset_parameters(self):
        if self.vision_backbone is not None:
            self.vision_backbone.reset_parameters()
        self.reset_non_vision_parameters()

    def reset_non_vision_parameters(self):
        self.transformer.wte.reset_parameters()
        if hasattr(self.transformer.wte, "new_embedding"):
            nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)

        if hasattr(self.transformer, "wpe"):
            nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0)

        self.transformer.ln_f.reset_parameters()  # type: ignore

        if hasattr(self.transformer, "ff_out"):
            nn.init.normal_(self.transformer.ff_out.weight, mean=0.0, std=0.02)

        for block in self.transformer.blocks:
            block.reset_parameters()

    def forward(
        self,
        input_ids: torch.LongTensor,
        input_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_masks: Optional[torch.Tensor] = None,
        image_input_idx: Optional[torch.Tensor] = None,
        subsegment_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        last_logits_only: bool = False,
        output_hidden_states: Optional[bool] = None,
        append_last_valid_logits: Optional[torch.Tensor] = None,
    ) -> ModelOutput:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        :param input_embeddings: A tensor of shape `(batch_size, seq_len, hidden_size)` with input
            embeddings. When provided, it is treated as the output of the input embedding layer.
        :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
            which input IDs are masked. A `1` value in the mask means that
            the corresponding input ID should *not* be ignored. A `0` means
            that the corresponding input ID is masked.

            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
            library.
        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
            `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
            to introduce causal or other biases.

            If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
            indicates that the i-th element in the sequence is allowed to attend to the j-th
            element in the sequence.

            If the tensor is a float tensor, it will just be added to the attention
            scores before the softmax.

            The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
        :param response_mask: A tensor of shape `(batch_size, seq_len)` that indicates
            the response mask. A `1` value in the mask means that the corresponding token
            is a response token. A `0` means that the corresponding token is not
            a response token.
        :param past_key_values: Pre-computed keys and values for each attention block.
            Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        :param use_cache: If `True`, return key and value tensors for each block.
        :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
            This can speed up decoding when you only care about the next token.
        """
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if past_key_values:
            assert len(past_key_values) == self.config.num_hidden_layers

        has_image = images is not None

        assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
        assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images."

        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
        if past_key_values is None:
            past_length = 0
        else:
            past_length = past_key_values[0][0].size(-2)

        if attention_mask is None:
            attention_mask = input_ids != -1

        if subsegment_ids is not None:
            raise NotImplementedError()
        else:
            if position_ids is None:
                position_ids = torch.clamp(
                    torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
                    min=0,
                ).broadcast_to((batch_size, attention_mask.shape[-1]))

        # Get embeddings of input.
        # shape: (batch_size, seq_len, hidden_size)
        if input_ids is not None:
            input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

        num_image: Optional[int] = None
        if images is not None:
            # shape: (batch_size, num_image, num_patch, hidden_size)
            # cls_embed: (batch_size, num_image, hidden_size)
            image_features, cls_embed = self.vision_backbone(images, image_masks)
            num_image, num_patch = image_features.shape[1:3]
            assert image_input_idx.shape == (batch_size, num_image, num_patch)

            # insert the image features into the embedding.
            image_features = image_features.view(batch_size, num_image * num_patch, -1)
            image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)

            valid = image_input_idx >= 0
            batch_idx = torch.arange(batch_size, device=x.device)
            batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])

            # For hf demo/endpoint
            image_features = image_features.to(x.device)

            x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]

        # Add input + positional embeddings and apply dropout.
        # shape: (batch_size, seq_len, hidden_size)
        x = self.transformer.emb_drop(x)  # type: ignore

        # normalized
        if self.config.normalize_input_embeds:
            x = x * (self.config.hidden_size ** 0.5)

        # Merge attention mask with attention bias.
        # FIXME we are ignoring the attention mask input parameter
        if self.config.attention_type == "flash":
            attention_mask = input_ids != -1
        elif (
            attention_mask is not None
            or past_key_values is not None
        ):
            total_len = (past_length + seq_len)
            attention_mask = torch.tril(torch.ones(total_len, total_len, device=x.device, dtype=torch.bool))
            attention_mask = attention_mask.view(1, 1, total_len, total_len)

        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # decoder layers
        all_hidden_states = []

        # Apply blocks one-by-one.
        for block_idx, block in enumerate(self.transformer.blocks):
            if output_hidden_states:
                # add hidden states
                all_hidden_states.append(x)

            layer_past = None if past_key_values is None else past_key_values[block_idx]
            x, cache = block(
                x,
                attention_mask=attention_mask,
                position_ids=position_ids,
                layer_past=layer_past,
                use_cache=use_cache,
            )

            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)

        if last_logits_only:
            # shape: (batch_size, 1, hidden_size)
            if append_last_valid_logits is not None:
                last_valid_output = x[
                    torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)]
                x = last_valid_output.unsqueeze(1)
            else:
                x = x[:, -1, :].unsqueeze(1)

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, hidden_size)
        x = self.transformer.ln_f(x)  # type: ignore
        if output_hidden_states:
            # add final hidden state post-final-layernorm, following HuggingFace's convention
            all_hidden_states.append(x)

        # Get logits.
        # shape: (batch_size, seq_len or 1, vocab_size)
        if self.config.weight_tying:
            logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
        else:
            logits = self.transformer.ff_out(x)  # type: ignore
        if self.config.scale_logits:
            logits.mul_(1 / math.sqrt(self.config.hidden_size))

        if not last_logits_only and append_last_valid_logits is not None:
            last_valid_logit = logits[
                torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits]
            logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)

        return ModelOutput(
            logits=logits,
            attn_key_values=attn_key_values,
            hidden_states=tuple(all_hidden_states) if output_hidden_states else None,
        )  # type: ignore[arg-type]
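

# Sketch of how `MolmoModel.forward` above injects image features into the token stream:
# `image_input_idx` maps every pooled image patch to a position in the input sequence,
# patches with a negative index are dropped, and the rest are added to the corresponding
# token embeddings in place. Toy shapes, illustrative only.
def _image_scatter_demo() -> torch.Tensor:
    """Illustrative only: scatter-add toy image features into toy token embeddings."""
    batch_size, seq_len, hidden = 2, 5, 4
    x = torch.zeros(batch_size, seq_len, hidden)             # token embeddings
    image_features = torch.randn(batch_size, 3, hidden)      # three patches per example
    image_input_idx = torch.tensor([[1, 2, -1], [0, 3, 4]])  # -1 marks an unused patch
    valid = image_input_idx >= 0
    batch_idx = torch.arange(batch_size).unsqueeze(1).expand_as(image_input_idx)
    x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
    return x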


class MolmoForCausalLM(MolmoPreTrainedModel):
    def __init__(self, config: MolmoConfig, model: Optional[MolmoModel] = None, init_params: bool = False):
        super().__init__(config)

        if not model:
            self.model = MolmoModel(config, init_params=init_params)
        else:
            self.model = model
        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        response_mask: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_masks: Optional[torch.Tensor] = None,
        image_input_idx: Optional[torch.Tensor] = None,
        subsegment_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_masks: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        last_logits_only: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        append_last_valid_logits: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in Molmo")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            images=images,
            image_masks=image_masks,
            image_input_idx=image_input_idx,
            subsegment_ids=subsegment_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            last_logits_only=last_logits_only,
            output_hidden_states=output_hidden_states,
            append_last_valid_logits=append_last_valid_logits,
        )

        logits = outputs.logits
        hidden_states = outputs.hidden_states

        loss = None
        if labels is not None:
            if loss_masks is not None:
                loss_masks = loss_masks * (loss_masks > 0)
                batch_size_in_tokens = max(loss_masks.sum().item(), 1)
                labels = labels.long()
                labels.masked_fill_(~(loss_masks > 0), -100)
                labels = labels.view(-1)
                logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
                loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none')
                loss = loss_fct(logits_for_loss, labels)
                loss = loss.view(input_ids.shape[0], -1)
                loss = loss * loss_masks
                loss = loss.sum() / batch_size_in_tokens
                use_zloss = getattr(self.config, "softmax_auxiliary_loss", False)
                if use_zloss:
                    z_squared = logits_for_loss.logsumexp(-1).pow(2)
                    z_loss = self.config.softmax_auxiliary_loss_scale * z_squared
                    z_loss = z_loss.view(input_ids.shape[0], -1)
                    z_loss = z_loss * loss_masks
                    z_loss = z_loss.sum() / batch_size_in_tokens
                    loss += z_loss
            else:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = torch.nn.CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        return True

    @torch.no_grad()
    def generate_from_batch(
        self,
        batch: Dict[str, Any],
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        if generation_config is not None:
            assert generation_config.use_cache

        images = batch.get("images")
        image_masks = batch.get("image_masks")
        image_input_idx = batch.get("image_input_idx")

        # Validate inputs.
        input_ids = batch["input_ids"]
        batch_size, seq_len = input_ids.shape
        attention_mask = batch.get("attention_mask", None)
        max_new_tokens = generation_config.max_new_tokens
        assert max_new_tokens is not None
        mask_len = seq_len + max_new_tokens
        position_ids: Optional[torch.Tensor] = None
        append_last_valid_logits: Optional[torch.Tensor] = None
        if attention_mask is None:
            attention_mask = input_ids != -1
            position_ids = torch.clamp(
                torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
                min=0
            )
            append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
                dim=1,
            )
        if attention_mask is not None:
            assert attention_mask.shape == (batch_size, mask_len)

        out = super().generate(
            batch["input_ids"],
            generation_config,
            attention_mask=attention_mask,
            images=images,
            image_masks=image_masks,
            image_input_idx=image_input_idx,
            position_ids=position_ids,
            append_last_valid_logits=append_last_valid_logits,
            **kwargs,
        )

        return out

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[Tuple]] = None,
        **kwargs
    ):
        if past_key_values:
            # This is because we want the model to only process the last generated token.
            input_ids = input_ids[:, -1:]

        attention_mask = kwargs.get("attention_mask")
        images = kwargs.get("images")
        image_masks = kwargs.get("image_masks")
        image_input_idx = kwargs.get("image_input_idx")
        position_ids = kwargs.get("position_ids")
        append_last_valid_logits = kwargs.get("append_last_valid_logits")
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
            "last_logits_only": True,
        }
        if past_key_values is None:
            model_inputs["images"] = images
            model_inputs["image_masks"] = image_masks
            model_inputs["image_input_idx"] = image_input_idx
            model_inputs["append_last_valid_logits"] = append_last_valid_logits

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        if "append_last_valid_logits" in model_kwargs:
            del model_kwargs["append_last_valid_logits"]
        if "images" in model_kwargs:
            del model_kwargs["images"]
            del model_kwargs["image_masks"]
            del model_kwargs["image_input_idx"]
        cache_name, cache = super()._extract_past_from_model_output(outputs)
        model_kwargs[cache_name] = cache
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
        return model_kwargs


# Always register for multi-modal features
AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
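

# Example usage (sketch). The checkpoint name, the companion `AutoProcessor`, and its
# `process()` API live outside this file and are assumptions here; consult the model card
# of a released Molmo checkpoint for the authoritative recipe.
#
#   from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
#
#   processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-D-0924", trust_remote_code=True)
#   model = AutoModelForCausalLM.from_pretrained("allenai/Molmo-7B-D-0924", trust_remote_code=True, torch_dtype="auto")
#   inputs = processor.process(images=[image], text="Describe this image.")
#   inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
#   output = model.generate_from_batch(
#       inputs,
#       GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
#       tokenizer=processor.tokenizer,
#   )
#   text = processor.tokenizer.decode(output[0, inputs["input_ids"].size(1):], skip_special_tokens=True)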