| | """Backbone components for Mimi models - shared attention transformers.""" |
| |
|
| | import math |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from transformers.cache_utils import Cache, DynamicCache, StaticCache |
| | from transformers.masking_utils import create_causal_mask |
| | from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available |
| | from transformers.modeling_layers import GradientCheckpointingLayer |
| | from transformers.modeling_outputs import BaseModelOutputWithPast |
| | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| | from transformers.utils import logging |
| |
|
| | try: |
| | from .configuration_mimi import MimiConfig |
| | from .modeling_mimi_clean import ( |
| | MimiAttention, |
| | MimiMLP, |
| | MimiLayerScale, |
| | MimiRotaryEmbedding, |
| | apply_rotary_pos_emb, |
| | MIMI_ATTENTION_CLASSES |
| | ) |
| | except ImportError: |
| | from configuration_mimi import MimiConfig |
| | from modeling_mimi_clean import ( |
| | MimiAttention, |
| | MimiMLP, |
| | MimiLayerScale, |
| | MimiRotaryEmbedding, |
| | apply_rotary_pos_emb, |
| | MIMI_ATTENTION_CLASSES |
| | ) |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class CausalAttentionTransformer(nn.Module):
    """
    Standard causal (decoder-only) attention transformer consisting of *config.num_hidden_layers* layers.
    Each layer is a [`MimiTransformerLayer`] with self-attention only.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input embeddings or hidden states from the previous layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Selected in the range
                `[0, config.max_position_embeddings - 1]`.

                [What are position IDs?](../glossary#position-ids)
            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed hidden states (keys and values in the self-attention blocks) that can be used to speed up
                sequential decoding. This typically consists of the `past_key_values` returned by the model at a
                previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

                Two formats are allowed:
                - a [`~cache_utils.Cache`] instance;
                - a tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2
                  tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known
                  as the legacy cache format.

                Legacy tuples are converted to a [`~cache_utils.DynamicCache`] internally, and a `Cache` instance is
                returned whenever `use_cache=True`. If no `past_key_values` are passed, a new
                [`~cache_utils.DynamicCache`] is initialized.

                If `past_key_values` are used, the user can optionally input only the last `hidden_states` of shape
                `(batch_size, 1, hidden_size)` instead of all `hidden_states` of shape
                `(batch_size, sequence_length, hidden_size)`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if use_cache and not isinstance(past_key_values, Cache):
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the causal mask once and reuse it for every layer.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Containers for the optional per-layer outputs.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Add the hidden states of the last layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

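
# A minimal, hedged usage sketch (not part of the model API). It assumes the local MimiConfig accepts the same
# keyword arguments and defaults as `transformers.MimiConfig` (e.g. `hidden_size`, `num_hidden_layers`) and that
# the eager attention class from `modeling_mimi_clean` returns `(hidden_states, attn_weights, cache)`.
# It shows two-step incremental decoding through `CausalAttentionTransformer` with a shared `DynamicCache`.
def _demo_causal_attention_transformer() -> None:
    config = MimiConfig(num_hidden_layers=2)  # hypothetical small config for illustration
    model = CausalAttentionTransformer(config).eval()

    # Prefill with a short sequence.
    hidden = torch.randn(1, 4, config.hidden_size)
    with torch.no_grad():
        out = model(hidden, past_key_values=DynamicCache(), use_cache=True, return_dict=True)

    # Decode one extra position, reusing the cache returned by the first call.
    step = torch.randn(1, 1, config.hidden_size)
    with torch.no_grad():
        out = model(step, past_key_values=out.past_key_values, use_cache=True, return_dict=True)

    assert out.last_hidden_state.shape == (1, 1, config.hidden_size)
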
class MimiTransformerLayer(GradientCheckpointingLayer):
    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): position IDs of the input tokens.
            past_key_value (`Cache`, *optional*): cached past key and value projection states.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Feed-forward block
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class CrossAttention(nn.Module):
    """
    Cross-attention layer with monotonic masking for decoder queries attending to encoder outputs.
    Queries come from the decoder; keys and values come from the encoder.
    Supports monotonic attention, where each query can only attend to a progressive subset of keys.
    """

    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Query projection acts on the decoder stream; key/value projections act on the encoder stream.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # Rotary embedding, applied to the decoder queries only (see `forward`).
        self.rotary_emb = MimiRotaryEmbedding(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        _, kv_len, _ = encoder_hidden_states.size()

        # Queries from the decoder, keys/values from the encoder.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings are applied to the queries only; the encoder keys keep their original positions.
        if position_ids is not None:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {
                "sin": sin if position_ids is not None else None,
                "cos": cos if position_ids is not None else None,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Optional monotonic mask: restrict each query to the encoder positions it has been aligned to so far.
        if alignment_chunk_sizes is not None:
            monotonic_mask = _create_monotonic_attention_mask(
                alignment_chunk_sizes=alignment_chunk_sizes,
                query_length=q_len,
                key_length=key_states.shape[-2],  # match the (possibly cached) key length of `attn_weights`
                device=attn_weights.device,
                dtype=attn_weights.dtype,
            )
            attn_weights = attn_weights + monotonic_mask

        # Additive encoder attention mask, if provided.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Upcast attention to fp32 for numerical stability.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

class CrossAttentionLayer(GradientCheckpointingLayer):
    """
    Cross-attention transformer layer with layer normalization and MLP.
    Includes self-attention on the decoder, cross-attention to the encoder, and a feed-forward block.
    """

    def __init__(self, config: MimiConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Self-attention over the decoder stream.
        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        # Cross-attention from the decoder stream to the encoder outputs.
        self.cross_attn = CrossAttention(config=config, layer_idx=layer_idx)

        self.mlp = MimiMLP(config)

        # Pre-self-attention, pre-cross-attention, and pre-MLP layer norms.
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
        self.post_cross_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

        # Per-branch layer scales.
        self.self_attn_layer_scale = MimiLayerScale(config)
        self.cross_attn_layer_scale = MimiLayerScale(config)
        self.mlp_layer_scale = MimiLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cross_past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): decoder input of shape `(batch, seq_len, embed_dim)`
            encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch, encoder_seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): causal attention mask for self-attention
            encoder_attention_mask (`torch.FloatTensor`, *optional*): mask for encoder positions
            position_ids (`torch.LongTensor`, *optional*): position IDs for the decoder
            past_key_value (`Cache`, *optional*): cached self-attention states
            cross_past_key_value (`Cache`, *optional*): cached cross-attention states
            output_attentions (`bool`, *optional*): whether to return attention weights
            use_cache (`bool`, *optional*): whether to use caching
            cache_position (`torch.LongTensor`, *optional*): cache positions
            alignment_chunk_sizes (`torch.Tensor`, *optional*): per-step encoder alignment sizes for monotonic
                cross-attention (see [`CrossAttentionTransformer.forward`])
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self-attention on the decoder stream
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Cross-attention to the encoder outputs
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, cross_attn_weights, cross_present_key_value = self.cross_attn(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            position_ids=position_ids,
            past_key_value=cross_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            alignment_chunk_sizes=alignment_chunk_sizes,
        )
        hidden_states = residual + self.cross_attn_layer_scale(hidden_states)

        # Feed-forward block
        residual = hidden_states
        hidden_states = self.post_cross_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value, cross_present_key_value)

        return outputs

class CrossAttentionTransformer(nn.Module):
    """
    Cross-attention transformer consisting of *config.num_hidden_layers* cross-attention layers.
    Each layer performs self-attention on the decoder and cross-attention to the encoder.

    Args:
        config: MimiConfig
    """

    def __init__(self, config: MimiConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [CrossAttentionLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        cross_past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        alignment_chunk_sizes: Optional[torch.Tensor] = None,
    ) -> Union[tuple, BaseModelOutputWithPast]:
| | """ |
| | Args: |
| | hidden_states (`torch.FloatTensor`): decoder input of shape `(batch_size, decoder_sequence_length, hidden_size)` |
| | encoder_hidden_states (`torch.FloatTensor`): encoder output of shape `(batch_size, encoder_sequence_length, hidden_size)` |
| | attention_mask (`torch.Tensor`, *optional*): causal attention mask for decoder self-attention |
| | encoder_attention_mask (`torch.Tensor`, *optional*): attention mask for encoder positions |
| | position_ids (`torch.LongTensor`, *optional*): position IDs for decoder |
| | past_key_values (`Cache` or `list`, *optional*): cached self-attention states |
| | cross_past_key_values (`Cache` or `list`, *optional*): cached cross-attention states |
| | use_cache (`bool`, *optional*): whether to use caching |
| | output_attentions (`bool`, *optional*): whether to return attention weights |
| | output_hidden_states (`bool`, *optional*): whether to return hidden states |
| | return_dict (`bool`, *optional*): whether to return ModelOutput |
| | cache_position (`torch.LongTensor`, *optional*): cache positions |
| | alignment_chunk_sizes (`torch.Tensor`, *optional*): tensor of shape `(decoder_sequence_length,)` specifying |
| | how many encoder positions each decoder position can attend to cumulatively. Enables monotonic attention |
| | where decoder position i can attend to encoder positions 0 through sum(alignment_chunk_sizes[:i+1])-1. |
| | """ |
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | use_cache = use_cache if use_cache is not None else self.config.use_cache |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | if use_cache and past_key_values is None: |
| | logger.warning_once("use_cache=True was passed, but no past_key_values were given. Creating new cache.") |
| | past_key_values = DynamicCache() |
| | |
| | if use_cache and cross_past_key_values is None: |
| | cross_past_key_values = DynamicCache() |
| |
|
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Causal mask for the decoder self-attention.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # Containers for the optional per-layer outputs.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None
        next_decoder_cache = None
        next_cross_cache = None

        # The `Cache` objects track per-layer state internally (each layer knows its own `layer_idx`), so the full
        # self- and cross-attention caches are passed to every layer rather than being indexed per layer.
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    encoder_hidden_states,
                    causal_mask,
                    encoder_attention_mask,
                    position_ids,
                    past_key_values,
                    cross_past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    alignment_chunk_sizes,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=causal_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    cross_past_key_value=cross_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    alignment_chunk_sizes=alignment_chunk_sizes,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Layer outputs are (hidden_states[, self_attn, cross_attn], self_cache, cross_cache).
                if output_attentions:
                    next_decoder_cache = layer_outputs[3]
                    next_cross_cache = layer_outputs[4]
                else:
                    next_decoder_cache = layer_outputs[1]
                    next_cross_cache = layer_outputs[2]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attns += (layer_outputs[2],)

        # Add the hidden states of the last layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        next_cross_cache = next_cross_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, next_cross_cache, all_hidden_states, all_self_attns, all_cross_attns]
                if v is not None
            )

        # Note: `BaseModelOutputWithPast` has no cross-attention fields, so the cross-attention cache and weights
        # are only available through the tuple output above.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

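
# Hedged sketch of driving the cross-attention backbone with monotonic alignment. `alignment_chunk_sizes` of shape
# (batch, decoder_len) gives, per decoder step, how many new encoder positions become visible. Caching is disabled
# to keep the example self-contained; the MimiConfig defaults and the attention classes from `modeling_mimi_clean`
# are assumptions, as in the sketch above.
def _demo_cross_attention_transformer() -> None:
    config = MimiConfig(num_hidden_layers=2)  # hypothetical small config for illustration
    model = CrossAttentionTransformer(config).eval()

    decoder_states = torch.randn(1, 3, config.hidden_size)
    encoder_states = torch.randn(1, 6, config.hidden_size)
    # Decoder steps 0/1/2 may attend to the first 2/3/6 encoder positions, respectively.
    chunks = torch.tensor([[2, 1, 3]])

    with torch.no_grad():
        out = model(
            decoder_states,
            encoder_hidden_states=encoder_states,
            alignment_chunk_sizes=chunks,
            use_cache=False,
            return_dict=True,
        )
    assert out.last_hidden_state.shape == (1, 3, config.hidden_size)
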
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

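
# A tiny illustrative check (shapes chosen arbitrarily): repeat_kv expands grouped key/value heads so they line up
# with the query heads, matching torch.repeat_interleave over the head dimension.
def _demo_repeat_kv() -> None:
    kv = torch.randn(2, 4, 10, 64)  # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=3)
    assert expanded.shape == (2, 12, 10, 64)
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))
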
def _create_monotonic_attention_mask(
    alignment_chunk_sizes: torch.Tensor,
    query_length: int,
    key_length: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create a monotonic attention mask where each query can only attend to a progressive subset of keys.

    Args:
        alignment_chunk_sizes: Tensor of shape `(batch_size, query_length)` where each element gives how many
            additional keys the corresponding query can attend to; the limits are accumulated along the query axis.
        query_length: Number of queries (text tokens).
        key_length: Number of keys (speech features).
        device: Device to create the mask on.
        dtype: Data type for the mask.

    Returns:
        Attention mask of shape `(batch_size, 1, query_length, key_length)` where `-inf` masks out invalid
        positions and `0.0` allows attention.
    """
    batch_size = alignment_chunk_sizes.shape[0]

    # Cumulative number of visible keys per query, capped at the total key length.
    cumulative_positions = torch.cumsum(alignment_chunk_sizes, dim=1)
    cumulative_positions = torch.clamp(cumulative_positions, max=key_length)

    # Key indices, broadcast against the per-query limits: (1, 1, key_length) vs (batch_size, query_length, 1).
    key_positions = torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0)
    cumulative_positions = cumulative_positions.unsqueeze(2)

    # A key is visible to a query iff its index is below that query's cumulative limit.
    mask = key_positions < cumulative_positions

    # Convert the boolean mask to an additive mask and add the singleton head dimension.
    attention_mask = torch.where(mask, 0.0, float("-inf"))
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask.to(dtype)

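
# Hedged worked example of the monotonic mask: with chunk sizes [2, 1, 3], query 0 sees keys 0-1, query 1 sees
# keys 0-2, and query 2 sees keys 0-5; masked positions get -inf so they vanish after the softmax.
def _demo_monotonic_attention_mask() -> None:
    chunks = torch.tensor([[2, 1, 3]])  # (batch_size=1, query_length=3)
    mask = _create_monotonic_attention_mask(
        alignment_chunk_sizes=chunks,
        query_length=3,
        key_length=6,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    assert mask.shape == (1, 1, 3, 6)
    assert (mask[0, 0, 0, :2] == 0).all() and torch.isinf(mask[0, 0, 0, 2:]).all()
    assert (mask[0, 0, 2] == 0).all()  # the last query attends to every key
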
__all__ = [
    "CausalAttentionTransformer",
    "MimiTransformerLayer",
    "CrossAttention",
    "CrossAttentionLayer",
    "CrossAttentionTransformer",
    "_create_monotonic_attention_mask",
]