|
"""Flash attention monkey patch for mistral model""" |
|
|
|
|
|
import logging |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import transformers |
|
from einops import rearrange |
|
from flash_attn.bert_padding import pad_input, unpad_input |
|
from flash_attn.flash_attn_interface import ( |
|
flash_attn_kvpacked_func, |
|
flash_attn_varlen_kvpacked_func, |
|
flash_attn_varlen_qkvpacked_func, |
|
) |
|
from transformers.modeling_outputs import BaseModelOutputWithPast |
|
from transformers.models.mistral.modeling_mistral import ( |
|
MistralAttention as OriginalMistralAttention, |
|
) |
|
from transformers.models.mistral.modeling_mistral import ( |
|
MistralDecoderLayer as OriginalMistralDecoderLayer, |
|
) |
|
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv |
|
|
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids |
|
|
|
# Module-level logger used by the patched helpers below; the name is fixed so
# log filtering does not depend on how this module is imported.
LOG = logging.getLogger(name="axolotl.monkeypatch.mistral")
|
|
|
|
|
def replace_mistral_attn_with_flash_attn(
    packed: Optional[bool] = False,
):
    """Install the flash-attention patches into transformers' Mistral module.

    Always replaces ``MistralModel._prepare_decoder_attention_mask`` and
    ``MistralAttention.forward``. When ``packed`` is truthy, additionally swaps
    in the ``MistralDecoderLayer`` subclass and ``MistralModel.forward`` that
    thread ``cu_seqlens`` / ``max_seqlen`` through to attention (sample packing).
    """
    mistral_module = transformers.models.mistral.modeling_mistral

    mistral_module.MistralModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    mistral_module.MistralAttention.forward = flashattn_forward

    if not packed:
        return

    # Packing support: layers and model forward must pass cu_seqlens along.
    mistral_module.MistralDecoderLayer = MistralDecoderLayer
    mistral_module.MistralModel.forward = mistral_model_forward
|
|
|
|
|
@torch.jit.script
def _make_sliding_window_causal_mask(
    bsz: int,
    tgt_len: int,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: int = 4096,
):
    """
    Build an additive causal mask limited to a sliding window.

    Entry (i, j) of the square part is 0 when ``i - sliding_window < j <= i``
    and -inf otherwise. When ``past_key_values_length > 0``, that many all-zero
    columns are prepended. The result is expanded (not copied) to shape
    ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)``.
    """
    ones = torch.full(
        (tgt_len, tgt_len),
        fill_value=1,
        device=device,
    )
    # Lower-triangular (causal) intersected with a band of width sliding_window.
    band = torch.triu(torch.tril(ones, diagonal=0), diagonal=1 - sliding_window)
    # log(1) == 0 keeps a position, log(0) == -inf masks it out.
    additive = torch.log(band).to(dtype)

    if past_key_values_length > 0:
        cached_cols = torch.zeros(
            tgt_len, past_key_values_length, dtype=dtype, device=device
        )
        additive = torch.cat([cached_cols, additive], dim=-1)

    src_len = tgt_len + past_key_values_length
    return additive[None, None, :, :].expand(bsz, 1, tgt_len, src_len)
|
|
|
|
|
|
|
|
|
def _prepare_decoder_attention_mask(
    self,
    attention_mask,
    input_shape,
    inputs_embeds,
    past_key_values_length,
    sliding_window,
):  # pylint: disable=unused-argument
    """Replacement for ``MistralModel._prepare_decoder_attention_mask``.

    Returns ``None`` unchanged. Otherwise, if the query length is > 1 and the
    mask's leading (batch) dimension is 1, adds the additive sliding-window
    causal mask from ``_make_sliding_window_causal_mask``. In every other case
    the mask is returned untouched, with an info log.
    """
    if attention_mask is None:
        return None

    # The sliding-window mask is only broadcastable against the incoming mask
    # in this narrow situation; otherwise skip it instead of failing.
    can_broadcast = input_shape[-1] > 1 and attention_mask.shape[0] == 1
    if not can_broadcast:
        LOG.info("skipping sliding window mask, not broadcastable with attention mask")
        return attention_mask

    window_mask = _make_sliding_window_causal_mask(
        bsz=input_shape[0],
        tgt_len=input_shape[1],
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
        past_key_values_length=past_key_values_length,
        sliding_window=sliding_window,
    )
    return attention_mask + window_mask
|
|
|
|
|
def flashattn_forward(
    self: OriginalMistralAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for ``MistralAttention.forward``.

    Dispatches to one of three flash-attn paths:

    * packed samples (``cu_seqlens`` is 1-D and ``max_seqlen`` is given):
      ``flash_attn_varlen_qkvpacked_func`` over the flattened batch;
    * query and key of equal shape: unpad with ``attention_mask`` via
      ``generate_qkv`` and call ``flash_attn_varlen_qkvpacked_func``;
    * otherwise (e.g. decoding with a KV cache): ``flash_attn_kvpacked_func``
      when no position is masked, else the varlen kv-packed kernel.

    A sliding window is used only when the config defines a non-None
    ``sliding_window`` and the key length exceeds it.

    Returns:
        ``(attn_output, attn_weights, past_key_value)``. ``attn_weights`` is
        always ``None`` because flash attention never materializes the
        attention matrix, even if ``output_attentions`` is requested.
        ``past_key_value`` is the updated ``(key, value)`` cache when
        ``use_cache`` is set, else ``None``.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    # `sliding_window` may be missing from the config or explicitly None
    # (window disabled); both mean "no sliding window". Comparing an int
    # against None/missing would otherwise raise.
    sliding_window = getattr(self.config, "sliding_window", None)
    use_sliding_windows = sliding_window is not None and kv_seq_len > sliding_window

    if use_sliding_windows:
        window_size = (sliding_window, sliding_window)
    else:
        window_size = (-1, -1)

    if past_key_value is not None:
        # Activate slicing cache only if the config has a sliding window and
        # the cache has grown past it.
        if use_sliding_windows:
            slicing_tokens = kv_seq_len - sliding_window

            past_key = past_key_value[0]
            past_value = past_key_value[1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != sliding_window - 1:
                raise ValueError(
                    f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                    f" {past_key.shape}"
                )

            past_key_value = (past_key, past_value) if use_cache else None

    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # Repeat k/v heads when the model uses grouped-query attention.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if self.training:
        # During training q, k, v always have the same sequence length.
        assert key_states.shape == query_states.shape
        is_causal = True
    else:
        # Only the first autoregressive step has equal q/k lengths; after that
        # the single new query must see every cached key, so causal is off.
        is_causal = key_states.shape == query_states.shape

    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)

    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
        # Sample packing: sequences are delimited by cu_seqlens.
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
        qkv = rearrange(qkv, "b s ... -> (b s) ...")

        output = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_seqlens,
            max_seqlen,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=True,
            window_size=window_size,
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    elif query_states.shape == key_states.shape:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
            query_states,
            key_states,
            value_states,
            qkvpacked=True,
            # The padding mask is applied by unpadding; the causal mask is
            # handled by flash-attn itself.
            key_padding_mask=attention_mask,
            query_padding_mask=attention_mask[:, -query_states.size(1) :]
            if attention_mask is not None
            else None,
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            qkv_unpad,
            cu_seqlens_q,
            max_seqlen_q,
            dropout_p=dropout_rate,
            softmax_scale=None,
            causal=is_causal,
            window_size=window_size,
        )
        output = output_pad_fn(output_unpad)
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        if attention_mask is None or attention_mask.all().item():
            output = flash_attn_kvpacked_func(
                query_states,
                torch.stack([key_states, value_states], 2),
                dropout_p=dropout_rate,
                causal=is_causal,
                window_size=window_size,
            )
        else:
            (  # pylint: disable=unbalanced-tuple-unpacking
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                _,
                _,
                output_pad_fn,
            ) = generate_qkv(
                query_states,
                key_states,
                value_states,
                kvpacked=True,
                key_padding_mask=attention_mask,
                query_padding_mask=attention_mask[:, -query_states.size(1) :]
                if attention_mask is not None
                else None,
            )
            if q_unpad.dtype != kv_unpad.dtype:
                kv_unpad = kv_unpad.to(q_unpad.dtype)
            output_unpad = flash_attn_varlen_kvpacked_func(
                q_unpad,
                kv_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                dropout_p=dropout_rate,
                softmax_scale=None,
                causal=is_causal,
                window_size=window_size,
            )
            output = output_pad_fn(output_unpad)

    attn_output = output
    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")

    attn_output = self.o_proj(attn_output)

    # Flash attention does not produce attention probabilities, so there is
    # nothing to return regardless of output_attentions.
    attn_weights = None

    return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
|
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
    """Flatten/unpad q, k, v into the layout expected by flash-attn varlen kernels.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool; if None, no query
            position is treated as padding.
        key_padding_mask: (batch_size, seqlen), bool; if None, no key
            position is treated as padding.
        kvpacked: return k and v stacked into one tensor.
        qkvpacked: return q, k and v stacked into one tensor (requires
            nheads == nheads_k). Mutually exclusive with ``kvpacked``.

    Returns:
        qkvpacked: ``(qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)``
        kvpacked: ``(q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k, q, kv, output_pad_fn)``
        otherwise: ``(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn)``

        ``output_pad_fn`` maps an unpadded ``(total_q, nheads, d)`` kernel
        output back to ``(batch_size, seqlen_q, nheads, d)``.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        # Newer flash-attn releases return an extra trailing element from
        # unpad_input; `*_` keeps this working with both signatures.
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *_ = unpad_input(
            q, query_padding_mask
        )

        def output_pad_fn(output_unpad):
            # Scatter unpadded rows back into a zero-padded (b, s, h, d) tensor.
            return pad_input(output_unpad, indices_q, batch_size, seqlen_q)

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        # Every sequence is full length: boundaries at multiples of seqlen_q.
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        max_seqlen_q = seqlen_q

        def output_pad_fn(output_unpad):
            return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k, *_ = unpad_input(k, key_padding_mask)
        v_unpad, *_ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)

    if kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        return (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
        )

    return (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
    )
|
|
|
|
|
def mistral_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement for ``MistralModel.forward`` used when sample packing is on.

    Differs from a plain decoder forward by deriving ``cu_seqlens`` and
    ``max_seqlen`` from caller-supplied ``position_ids`` (via
    ``get_cu_seqlens_from_pos_ids``) and passing them to every decoder layer,
    both in the normal path and the gradient-checkpointing path.

    Returns a ``BaseModelOutputWithPast``, or, when ``return_dict`` resolves
    to False, a tuple of the non-None items among
    ``(hidden_states, next_cache, all_hidden_states, all_self_attns)``.

    Raises:
        ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
            are given.
    """
    # Fall back to config defaults for unspecified flags.
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        # Cached keys are read as (..., cached_len, ...) on dim 2.
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    cu_seqlens = None
    max_seqlen = None
    if position_ids is None:
        # No packing information: plain consecutive positions after the cache.
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()
        # Sequence boundaries for packed samples are derived from position_ids.
        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
        # NOTE(review): presumably cu_seqlens comes back with a leading batch
        # dim; squeeze() makes it 1-D when that dim is 1, which is the only
        # case the packed branch in flashattn_forward accepts (dim() == 1).
        cu_seqlens = cu_seqlens.squeeze()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
    attention_mask = (
        self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
            sliding_window=self.config.sliding_window,
        )
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            transformers.logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs)

                return custom_forward

            # Positional order must match MistralDecoderLayer.forward:
            # (hidden_states, attention_mask, position_ids, past_key_value,
            #  output_attentions, use_cache, cu_seqlens, max_seqlen).
            # use_cache is passed as None here.
            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                None,
                cu_seqlens,
                max_seqlen,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )

        hidden_states = layer_outputs[0]

        # Layer output is (hidden, [attn_weights if output_attentions],
        # [present_kv if use_cache]), so the cache index shifts by one when
        # attention weights are present.
        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
|
|
|
|
|
class MistralDecoderLayer(OriginalMistralDecoderLayer):
    """
    MistralDecoderLayer variant whose forward accepts ``cu_seqlens`` and
    ``max_seqlen`` and forwards them to ``self.self_attn``, so packed-sample
    boundaries computed once per model forward reach every attention call.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run one pre-norm decoder block: self-attention then MLP, each wrapped
        in a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): mask forwarded unchanged to self-attention
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to self-attention
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value projection states
            output_attentions (`bool`, *optional*): append the attention weights to the returned tuple
            use_cache (`bool`, *optional*): append the updated key/value cache to the returned tuple
            cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths when packing
            max_seqlen (`torch.Tensor`, *optional*): longest packed sequence, paired with `cu_seqlens`

        Returns:
            `(hidden_states,)` followed by the attention weights if
            `output_attentions` and then the present key/value if `use_cache`.
        """
        # Self-attention sub-block.
        attn_residual = hidden_states
        attn_out, attn_weights, present_kv = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = attn_residual + attn_out

        # Feed-forward sub-block.
        mlp_residual = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = mlp_residual + mlp_out

        extras = ()
        if output_attentions:
            extras += (attn_weights,)
        if use_cache:
            extras += (present_kv,)

        return (hidden_states,) + extras
|
|