Commit 756dc76 by jon-tow
Parent(s): d3cd371

feat: add flash_attn and dropout support
configuration_stablelm_epoch.py CHANGED

@@ -64,6 +64,8 @@ class StableLMEpochConfig(PretrainedConfig):
             (not used by all models). Only relevant if `config.is_decoder=True`.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
     """
     model_type = "stablelm_epoch"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -86,6 +88,7 @@ class StableLMEpochConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        attention_dropout: float = 0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -102,6 +105,7 @@ class StableLMEpochConfig(PretrainedConfig):
         self.norm_eps = norm_eps
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
+        self.attention_dropout = attention_dropout
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
modeling_stablelm_epoch.py CHANGED

@@ -19,23 +19,48 @@
 """ PyTorch StableLM Epoch model. """
 from typing import Optional, Tuple, Union
 import math
+import warnings
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
+from transformers.utils import logging, is_flash_attn_greater_or_equal_2_10
+
 from .configuration_stablelm_epoch import StableLMEpochConfig
 
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+except:
+    flash_attn_func, flash_attn_varlen_func = None, None
+    index_first_axis, pad_input, unpad_input = None, None, None
+
 
 logger = logging.get_logger(__name__)
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size,
@@ -165,6 +190,8 @@ class Attention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
+        self.is_causal = True
+        self.attention_dropout = config.attention_dropout
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -248,6 +275,7 @@ class Attention(nn.Module):
 
         # Upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -269,10 +297,202 @@ class Attention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+class FlashAttention2(Attention):
+    """
+    Reference: https://github.com/huggingface/transformers/blob/5d36025ca13d05151b7a0c761e90d429c4644a30/src/transformers/models/llama/modeling_llama.py#L456
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        query_rot = query_states[..., : self.rotary_ndims]
+        query_pass = query_states[..., self.rotary_ndims :]
+        key_rot = key_states[..., : self.rotary_ndims]
+        key_pass = key_states[..., self.rotary_ndims :]
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, num_heads, seq_len, head_dim]
+        query_states = torch.cat((query_states, query_pass), dim=-1)
+        key_states = torch.cat((key_states, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            # Reuse k, v, self_attention
+            key_states = torch.cat((past_key_value[0], key_states), dim=2)
+            value_states = torch.cat((past_key_value[1], value_states), dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+        )
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+ATTENTION_CLASSES = {
+    "eager": Attention,
+    "flash_attention_2": FlashAttention2,
+}
+
+
 class DecoderLayer(nn.Module):
     def __init__(self, config: StableLMEpochConfig):
         super().__init__()
-        self.self_attn = Attention(config)
+        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config=config)
         self.mlp = MLP(config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
@@ -328,6 +548,7 @@ class StableLMEpochPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -355,6 +576,7 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
 
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -428,10 +650,6 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
-        if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[2]
-            seq_length_with_past = seq_length_with_past + past_key_values_length
-
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(
@@ -447,18 +665,22 @@ class StableLMEpochModel(StableLMEpochPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # Embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past),
-                dtype=torch.bool,
-                device=inputs_embeds.device,
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past),
+                    dtype=torch.bool,
+                    device=inputs_embeds.device,
+                )
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
 
         hidden_states = inputs_embeds
 
@@ -643,8 +865,17 @@ class StableLMEpochForCausalLM(StableLMEpochPreTrainedModel):
         **kwargs,
     ):
         # Trim decoder_input_ids if past is used
-        if past_key_values and past_key_values[0] is not None:
-            input_ids = input_ids[:, -1:]
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
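Usage sketch (not part of the diff): opting into the new FlashAttention-2 path at load time. This assumes transformers >= 4.36 (where the `attn_implementation` kwarg is routed into `config._attn_implementation`), a CUDA device, the `flash-attn` package installed, and the same placeholder repo id as above:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",            # placeholder repo id hosting these remote-code files
    trust_remote_code=True,                    # loads modeling_stablelm_epoch.py from the Hub
    attn_implementation="flash_attention_2",   # selects FlashAttention2 via ATTENTION_CLASSES
    torch_dtype=torch.bfloat16,                # flash-attn kernels require fp16/bf16 inputs
    device_map="cuda",
)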