Guanzheng committed on
Commit 9cfb7d3
1 Parent(s): 3a3383b

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +786 -319
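For readers skimming the diff below: the update appears to port the CLEX LLaMA modeling code to the transformers 4.36-style layout (`Cache`/`DynamicCache`, a separate `LlamaFlashAttention2` class, the shared attention-mask utilities), while keeping CLEX's detail that queries and cached keys are rotated with separate position ids. The snippet below is only a minimal sketch of that rotary step, paraphrased from the new `apply_rotary_pos_emb` in the diff; the tensor shapes in the comments are taken from the diff's own comments, not verified independently.

```python
import torch

def rotate_half(x):
    # Rotate the last dimension: (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids, unsqueeze_dim=1):
    # cos/sin: [seq_len, head_dim]; q: [bs, heads, q_len, head_dim]; k: [bs, heads, kv_len, head_dim]
    # Queries and keys use *different* position ids, so the whole (un-rotated) key cache
    # can be re-rotated with the current CLEX-scaled frequencies at every step.
    cos_q, sin_q = cos[position_ids].unsqueeze(unsqueeze_dim), sin[position_ids].unsqueeze(unsqueeze_dim)
    cos_k, sin_k = cos[key_position_ids].unsqueeze(unsqueeze_dim), sin[key_position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos_q) + (rotate_half(q) * sin_q), (k * cos_k) + (rotate_half(k) * sin_k)
```

This matches the diff's "Update KV cache before RoPE" change, which suggests keys are stored un-rotated and re-rotated on each forward pass.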
modeling_llama.py CHANGED
@@ -1,3 +1,4 @@
  # coding=utf-8
  # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  #
@@ -19,92 +20,88 @@
  # limitations under the License.
  """ PyTorch LLaMA model."""
  import math
  from typing import List, Optional, Tuple, Union

  import torch
  import torch.utils.checkpoint
  from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

  from transformers.activations import ACT2FN
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
  from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
- from .configuration_clex import CLEXLlamaConfig
- from .clex_layer import LlamaCLEXScalingRotaryEmbedding
- from einops import rearrange
- import importlib.metadata
- import importlib.util
-
-
- logger = logging.get_logger(__name__)

- def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
-     # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
-     package_exists = importlib.util.find_spec(pkg_name) is not None
-     package_version = "N/A"
-     if package_exists:
-         try:
-             package_version = importlib.metadata.version(pkg_name)
-             package_exists = True
-         except importlib.metadata.PackageNotFoundError:
-             package_exists = False
-         logger.info(f"Detected {pkg_name} version {package_version}")
-     if return_version:
-         return package_exists, package_version
-     else:
-         return package_exists

- def is_flash_attn_available():
-     if not _is_package_available("torch", return_version=True):
-         return False

-     # Let's add an extra check to see if cuda is available

-     return _is_package_available("flash_attn") and torch.cuda.is_available()


  _CONFIG_FOR_DOC = "CLEXLlamaConfig"

- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
  def _make_causal_mask(
      input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
  ):
-     """
-     Make causal mask used for bi-directional self-attention.
-     """
-     bsz, tgt_len = input_ids_shape
-     mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
-     mask_cond = torch.arange(mask.size(-1), device=device)
-     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-     mask = mask.to(dtype)
-
-     if past_key_values_length > 0:
-         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
- # Copied from transformers.models.bart.modeling_bart._expand_mask
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-     """
-     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-     """
-     bsz, src_len = mask.size()
-     tgt_len = tgt_len if tgt_len is not None else src_len
-
-     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-     inverted_mask = 1.0 - expanded_mask
-
-     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

  class LlamaRMSNorm(nn.Module):
@@ -117,48 +114,97 @@ class LlamaRMSNorm(nn.Module):
          self.variance_epsilon = eps

      def forward(self, hidden_states):
-         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
          hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

-         # convert into half-precision if necessary
-         if self.weight.dtype in [torch.float16, torch.bfloat16]:
-             hidden_states = hidden_states.to(self.weight.dtype)

-         return self.weight * hidden_states


- class LlamaRotaryEmbedding(torch.nn.Module):
      def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
          super().__init__()
-         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-         self.register_buffer("inv_freq", inv_freq)

          # Build here to make `torch.jit.trace` work.
-         self.max_seq_len_cached = max_position_embeddings
-         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
          # Different from paper, but it uses a different permutation in order to obtain the same calculation
          emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-         self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

      def forward(self, x, seq_len=None):
          # x: [bs, num_attention_heads, seq_len, head_size]
-         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
          if seq_len > self.max_seq_len_cached:
-             self.max_seq_len_cached = seq_len
-             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-             # Different from paper, but it uses a different permutation in order to obtain the same calculation
-             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-             self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-             self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
          return (
-             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
          )

  def rotate_half(x):
      """Rotates half the hidden dims of the input."""
      x1 = x[..., : x.shape[-1] // 2]
@@ -166,217 +212,608 @@ def rotate_half(x):
      return torch.cat((-x2, x1), dim=-1)


- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
      # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-     cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-     sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

-     cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-     sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
      q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
      k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
      return q_embed, k_embed

  class LlamaMLP(nn.Module):
-     def __init__(
-         self,
-         hidden_size: int,
-         intermediate_size: int,
-         hidden_act: str,
-     ):
          super().__init__()
-         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-         self.act_fn = ACT2FN[hidden_act]

      def forward(self, x):
-         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

  class LlamaAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

-     def __init__(self, config: CLEXLlamaConfig):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.max_position_embeddings = config.max_position_embeddings
-         self.log_scale = config.log_scale
          if (self.head_dim * self.num_heads) != self.hidden_size:
              raise ValueError(
                  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                  f" and `num_heads`: {self.num_heads})."
              )
-         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-         self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

      def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
          return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
224
-
225
- def flash_attn_forward(
226
  self,
227
- qkv: torch.Tensor,
228
- key_padding_mask: Optional[torch.Tensor] = None,
229
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
230
- """Input shape: Batch x Time x Channel
231
-
232
- attention_mask: [bsz, q_len]
233
- """
234
- if is_flash_attn_available():
235
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
236
- # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
237
- from flash_attn.bert_padding import unpad_input, pad_input
238
- bsz, q_len, *_ = qkv.size()
239
-
240
- if key_padding_mask is None:
241
- # qkv = rearrange(qkv, "b s ... -> (b s) ...")
242
- max_s = q_len
243
- cu_q_lens = torch.arange(
244
- 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
245
  )
246
- output = flash_attn_qkvpacked_func(
247
- qkv, 0.0, softmax_scale=None, causal=True
248
  )
249
  else:
250
- nheads = qkv.shape[-2]
251
- x = rearrange(qkv, "b s three h d -> b s (three h d)")
252
- x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
253
- x_unpad = rearrange(
254
- x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
255
- )
256
- output_unpad = flash_attn_varlen_qkvpacked_func(
257
- x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
258
  )
259
- output = rearrange(
260
- pad_input(
261
- rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
262
- ),
263
- "b s (h d) -> b s h d",
264
- h=nheads,
265
  )
266
- return self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
267
-
268
  def forward(
269
  self,
270
  hidden_states: torch.Tensor,
271
- attention_mask: Optional[torch.Tensor] = None,
272
  position_ids: Optional[torch.LongTensor] = None,
273
- pack_cos_sin = None,
274
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
275
  output_attentions: bool = False,
276
  use_cache: bool = False,
 
277
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
278
  bsz, q_len, _ = hidden_states.size()
279
 
280
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
281
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
282
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
283
 
284
- kv_seq_len = key_states.shape[-2]
285
 
 
286
  if past_key_value is not None:
287
- kv_seq_len += past_key_value[0].shape[-2]
288
- cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
289
- else:
290
- cache_key_states = key_states
291
-
292
  if pack_cos_sin is not None:
293
  cos, sin = pack_cos_sin.to(query_states.device)
294
  else:
295
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
296
- key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
297
  query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
298
 
299
- if past_key_value is not None:
300
- # reuse k, v, self_attention
301
- # key_states = torch.cat([past_key_value[0], key_states], dim=2)
302
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
303
 
304
- past_key_value = (cache_key_states, value_states) if use_cache else None
305
 
306
- use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
 
307
 
308
- if self.log_scale:
 
 
309
  log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
310
- torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
311
  query_states = query_states * log_n
312
-
313
 
314
- if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
315
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
316
 
317
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
318
- raise ValueError(
319
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
320
- f" {attn_weights.size()}"
321
- )
322
 
323
- if attention_mask is not None:
324
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
325
- raise ValueError(
326
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
327
- )
328
- attn_weights = attn_weights + attention_mask
329
- attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
330
 
331
- # upcast attention to fp32
332
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
333
- attn_output = torch.matmul(attn_weights, value_states)
334
 
335
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
336
  raise ValueError(
337
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
338
- f" {attn_output.size()}"
339
  )
340
 
341
- attn_output = attn_output.transpose(1, 2)
342
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
343
 
344
- attn_output = self.o_proj(attn_output)
 
 
 
345
 
346
- if not output_attentions:
347
- attn_weights = None
348
-
349
- return attn_output, attn_weights, past_key_value
350
- # use flash attention
351
- elif past_key_value is not None:
352
- from flash_attn.flash_attn_interface import flash_attn_with_kvcache
353
- output = flash_attn_with_kvcache(
354
- query_states.transpose(1, 2),
355
- key_states.transpose(1, 2),
356
- value_states.transpose(1, 2),
357
- cache_seqlens=kv_seq_len,
358
- causal=True,
359
- )
360
- attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
361
- else:
362
- qkv = torch.stack(
363
- [query_states, key_states, value_states], dim=2
364
- ) # [bsz, nh, 3, q_len, hd]
365
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
366
- attn_output = self.flash_attn_forward(qkv)
367
  return attn_output, None, past_key_value
368
 
369
 
370
  class LlamaDecoderLayer(nn.Module):
371
- def __init__(self, config: CLEXLlamaConfig):
372
  super().__init__()
373
  self.hidden_size = config.hidden_size
374
- self.self_attn = LlamaAttention(config=config)
375
- self.mlp = LlamaMLP(
376
- hidden_size=self.hidden_size,
377
- intermediate_size=config.intermediate_size,
378
- hidden_act=config.hidden_act,
379
- )
380
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
381
  self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
382
 
@@ -385,16 +822,18 @@ class LlamaDecoderLayer(nn.Module):
385
  hidden_states: torch.Tensor,
386
  attention_mask: Optional[torch.Tensor] = None,
387
  position_ids: Optional[torch.LongTensor] = None,
388
- pack_cos_sin=None,
389
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
390
  output_attentions: Optional[bool] = False,
391
  use_cache: Optional[bool] = False,
 
392
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
393
  """
394
  Args:
395
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
396
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
397
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
 
398
  output_attentions (`bool`, *optional*):
399
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
400
  returned tensors for more detail.
@@ -403,6 +842,10 @@ class LlamaDecoderLayer(nn.Module):
403
  (see `past_key_values`).
404
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
405
  """
406
 
407
  residual = hidden_states
408
 
@@ -411,12 +854,13 @@ class LlamaDecoderLayer(nn.Module):
411
  # Self Attention
412
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
413
  hidden_states=hidden_states,
414
- attention_mask=attention_mask,
415
  position_ids=position_ids,
416
  pack_cos_sin=pack_cos_sin,
417
  past_key_value=past_key_value,
418
  output_attentions=output_attentions,
419
  use_cache=use_cache,
 
420
  )
421
  hidden_states = residual + hidden_states
422
 
@@ -463,8 +907,10 @@ class LlamaPreTrainedModel(PreTrainedModel):
463
  base_model_prefix = "model"
464
  supports_gradient_checkpointing = True
465
  _no_split_modules = ["LlamaDecoderLayer"]
466
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
467
- _keep_in_fp32_modules = ["model.clex_layer.proj_func.ode_up_proj", "model.clex_layer.proj_func.ode_down_proj", "model.clex_layer.inv_freq"]
 
 
468
 
469
  def _init_weights(self, module):
470
  std = self.config.initializer_range
@@ -477,10 +923,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
477
  if module.padding_idx is not None:
478
  module.weight.data[module.padding_idx].zero_()
479
 
480
- def _set_gradient_checkpointing(self, module, value=False):
481
- if isinstance(module, LlamaModel):
482
- module.gradient_checkpointing = value
483
-
484
 
485
  LLAMA_INPUTS_DOCSTRING = r"""
486
  Args:
@@ -503,7 +945,7 @@ LLAMA_INPUTS_DOCSTRING = r"""
503
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
504
  [`PreTrainedTokenizer.__call__`] for details.
505
 
506
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
507
  `past_key_values`).
508
 
509
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -517,17 +959,23 @@ LLAMA_INPUTS_DOCSTRING = r"""
517
  config.n_positions - 1]`.
518
 
519
  [What are position IDs?](../glossary#position-ids)
520
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
521
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
522
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
523
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
524
-
525
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
526
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
527
-
528
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
529
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
530
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
531
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
532
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
533
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -564,15 +1012,20 @@ class LlamaModel(LlamaPreTrainedModel):
564
  self.vocab_size = config.vocab_size
565
 
566
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
567
- self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
568
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
569
- head_dim = config.hidden_size // config.num_attention_heads
570
- if config.rope_scaling["type"] == "clex":
571
- self.clex_layer = LlamaCLEXScalingRotaryEmbedding(head_dim, config.max_position_embeddings, config.rope_scaling)
572
  self.gradient_checkpointing = False
573
  # Initialize weights and apply final processing
574
  self.post_init()
575
-
 
 
 
576
 
577
  def get_input_embeddings(self):
578
  return self.embed_tokens
@@ -580,30 +1033,6 @@ class LlamaModel(LlamaPreTrainedModel):
580
  def set_input_embeddings(self, value):
581
  self.embed_tokens = value
582
 
583
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
584
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
585
- # create causal mask
586
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
587
- combined_attention_mask = None
588
- if input_shape[-1] > 1:
589
- combined_attention_mask = _make_causal_mask(
590
- input_shape,
591
- inputs_embeds.dtype,
592
- device=inputs_embeds.device,
593
- past_key_values_length=past_key_values_length,
594
- )
595
-
596
- if attention_mask is not None:
597
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
598
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
599
- inputs_embeds.device
600
- )
601
- combined_attention_mask = (
602
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
603
- )
604
-
605
- return combined_attention_mask
606
-
607
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
608
  def forward(
609
  self,
@@ -627,43 +1056,50 @@ class LlamaModel(LlamaPreTrainedModel):
627
 
628
  # retrieve input_ids and inputs_embeds
629
  if input_ids is not None and inputs_embeds is not None:
630
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
631
  elif input_ids is not None:
632
- batch_size, seq_length = input_ids.shape
633
  elif inputs_embeds is not None:
634
- batch_size, seq_length, _ = inputs_embeds.shape
635
  else:
636
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
637
 
638
- seq_length_with_past = seq_length
639
  past_key_values_length = 0
640
-
641
- if past_key_values is not None:
642
- past_key_values_length = past_key_values[0][0].shape[2]
643
- seq_length_with_past = seq_length_with_past + past_key_values_length
 
644
 
645
  if position_ids is None:
646
  device = input_ids.device if input_ids is not None else inputs_embeds.device
647
  position_ids = torch.arange(
648
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
649
  )
650
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
651
- else:
652
- position_ids = position_ids.view(-1, seq_length).long()
653
 
654
  if inputs_embeds is None:
655
  inputs_embeds = self.embed_tokens(input_ids)
656
- # embed positions
657
- if attention_mask is None:
658
- attention_mask = torch.ones(
659
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
660
- )
661
- attention_mask = self._prepare_decoder_attention_mask(
662
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
663
- )
664
- # attention_mask = None
665
 
666
 
 
667
  hidden_states = inputs_embeds
668
 
669
  if self.gradient_checkpointing and self.training:
@@ -676,34 +1112,26 @@ class LlamaModel(LlamaPreTrainedModel):
676
  # decoder layers
677
  all_hidden_states = () if output_hidden_states else None
678
  all_self_attns = () if output_attentions else None
679
- next_decoder_cache = () if use_cache else None
680
-
681
  pack_cos_sin = None
682
  if self.config.rope_scaling["type"] == "clex":
683
- pack_cos_sin = self.clex_layer(inputs_embeds.device, inputs_embeds.dtype, seq_length_with_past, self.training)
684
 
685
- for idx, decoder_layer in enumerate(self.layers):
 
686
  if output_hidden_states:
687
  all_hidden_states += (hidden_states,)
688
 
689
- past_key_value = past_key_values[idx] if past_key_values is not None else None
690
-
691
  if self.gradient_checkpointing and self.training:
692
-
693
- def create_custom_forward(module):
694
- def custom_forward(*inputs):
695
- # None for past_key_value
696
- return module(*inputs, output_attentions, None)
697
-
698
- return custom_forward
699
-
700
- layer_outputs = torch.utils.checkpoint.checkpoint(
701
- create_custom_forward(decoder_layer),
702
  hidden_states,
703
  attention_mask,
704
  position_ids,
705
  pack_cos_sin,
706
- None,
 
 
707
  )
708
  else:
709
  layer_outputs = decoder_layer(
@@ -711,7 +1139,7 @@ class LlamaModel(LlamaPreTrainedModel):
711
  attention_mask=attention_mask,
712
  position_ids=position_ids,
713
  pack_cos_sin=pack_cos_sin,
714
- past_key_value=past_key_value,
715
  output_attentions=output_attentions,
716
  use_cache=use_cache,
717
  )
@@ -719,7 +1147,7 @@ class LlamaModel(LlamaPreTrainedModel):
719
  hidden_states = layer_outputs[0]
720
 
721
  if use_cache:
722
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
723
 
724
  if output_attentions:
725
  all_self_attns += (layer_outputs[1],)
@@ -730,7 +1158,9 @@ class LlamaModel(LlamaPreTrainedModel):
730
  if output_hidden_states:
731
  all_hidden_states += (hidden_states,)
732
 
733
- next_cache = next_decoder_cache if use_cache else None
 
 
734
  if not return_dict:
735
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
736
  return BaseModelOutputWithPast(
@@ -742,10 +1172,12 @@ class LlamaModel(LlamaPreTrainedModel):
742
 
743
 
744
  class LlamaForCausalLM(LlamaPreTrainedModel):
 
 
745
  def __init__(self, config):
746
  super().__init__(config)
747
  self.model = LlamaModel(config)
748
-
749
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
750
 
751
  # Initialize weights and apply final processing
@@ -801,15 +1233,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
801
  >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
802
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
803
 
804
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
805
  >>> inputs = tokenizer(prompt, return_tensors="pt")
806
 
807
  >>> # Generate
808
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
809
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
810
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
811
  ```"""
812
-
813
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
814
  output_hidden_states = (
815
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -830,7 +1261,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
830
  )
831
 
832
  hidden_states = outputs[0]
833
- logits = self.lm_head(hidden_states)
834
 
835
  loss = None
836
  if labels is not None:
@@ -844,9 +1281,11 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
844
  # Enable model parallelism
845
  shift_labels = shift_labels.to(shift_logits.device)
846
  loss = loss_fct(shift_logits, shift_labels)
 
847
  if not return_dict:
848
  output = (logits,) + outputs[1:]
849
  return (loss,) + output if loss is not None else output
 
850
  return CausalLMOutputWithPast(
851
  loss=loss,
852
  logits=logits,
@@ -858,8 +1297,34 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
858
  def prepare_inputs_for_generation(
859
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
860
  ):
861
- if past_key_values:
862
- input_ids = input_ids[:, -1:]
863
 
864
  position_ids = kwargs.get("position_ids", None)
865
  if attention_mask is not None and position_ids is None:
@@ -867,7 +1332,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
867
  position_ids = attention_mask.long().cumsum(-1) - 1
868
  position_ids.masked_fill_(attention_mask == 0, 1)
869
  if past_key_values:
870
- position_ids = position_ids[:, -1].unsqueeze(-1)
871
 
872
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
873
  if inputs_embeds is not None and past_key_values is None:
@@ -889,7 +1354,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
889
  def _reorder_cache(past_key_values, beam_idx):
890
  reordered_past = ()
891
  for layer_past in past_key_values:
892
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
 
 
893
  return reordered_past
894
 
895
 
@@ -909,8 +1376,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
909
  LLAMA_START_DOCSTRING,
910
  )
911
  class LlamaForSequenceClassification(LlamaPreTrainedModel):
912
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
913
-
914
  def __init__(self, config):
915
  super().__init__(config)
916
  self.num_labels = config.num_labels
@@ -973,7 +1438,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
973
  sequence_lengths = -1
974
  else:
975
  if input_ids is not None:
976
- sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
 
 
977
  else:
978
  sequence_lengths = -1
979
 
 
1
+
2
  # coding=utf-8
3
  # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
  #
 
20
  # limitations under the License.
21
  """ PyTorch LLaMA model."""
22
  import math
23
+ import warnings
24
  from typing import List, Optional, Tuple, Union
25
 
26
  import torch
27
+ import torch.nn.functional as F
28
  import torch.utils.checkpoint
29
  from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
32
  from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache
34
+ from transformers.modeling_attn_mask_utils import (
35
+ AttentionMaskConverter,
36
+ _prepare_4d_attention_mask,
37
+ _prepare_4d_causal_attention_mask,
38
+ _prepare_4d_causal_attention_mask_for_sdpa,
39
+ )
40
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
41
  from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
43
+ from transformers.utils import (
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_flash_attn_2_available,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from transformers.utils.import_utils import is_torch_fx_available
52
 
53
 
54
+ if is_flash_attn_2_available():
55
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
56
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
57
 
 
58
 
59
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
60
+ # It means that the function will not be traced through and simply appear as a node in the graph.
61
+ if is_torch_fx_available():
62
+ if not is_torch_greater_or_equal_than_1_13:
63
+ import torch.fx
64
 
65
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
66
 
67
 
68
+ from .configuration_llama_clex import CLEXLlamaConfig
69
+ from .clex_layer import CLEXScalingRotaryEmbedding
70
 
71
 
72
+ logger = logging.get_logger(__name__)
73
 
74
  _CONFIG_FOR_DOC = "CLEXLlamaConfig"
75
 
76
 
77
+ def _get_unpad_data(attention_mask):
78
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
79
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
80
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
81
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
82
+ return (
83
+ indices,
84
+ cu_seqlens,
85
+ max_seqlen_in_batch,
86
+ )
87
 
88
 
89
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
90
+ warnings.warn(
91
+ "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
92
+ )
93
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
94
+
95
 
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
98
  ):
99
+ warnings.warn(
100
+ "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
101
+ )
102
+ return AttentionMaskConverter._make_causal_mask(
103
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
104
+ )
105
 
106
 
107
  class LlamaRMSNorm(nn.Module):
 
114
  self.variance_epsilon = eps
115
 
116
  def forward(self, hidden_states):
117
+ input_dtype = hidden_states.dtype
118
+ hidden_states = hidden_states.to(torch.float32)
119
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
120
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
121
+ return self.weight * hidden_states.to(input_dtype)
122
 
 
 
 
123
 
124
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
125
 
126
 
127
+ class LlamaRotaryEmbedding(nn.Module):
128
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
129
  super().__init__()
130
+
131
+ self.dim = dim
132
+ self.max_position_embeddings = max_position_embeddings
133
+ self.base = base
134
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
135
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
136
 
137
  # Build here to make `torch.jit.trace` work.
138
+ self._set_cos_sin_cache(
139
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
140
+ )
141
+
142
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
143
+ self.max_seq_len_cached = seq_len
144
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
145
+
146
+ freqs = torch.outer(t, self.inv_freq)
147
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
148
  emb = torch.cat((freqs, freqs), dim=-1)
149
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
150
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
151
 
152
  def forward(self, x, seq_len=None):
153
  # x: [bs, num_attention_heads, seq_len, head_size]
 
154
  if seq_len > self.max_seq_len_cached:
155
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
156
+
157
  return (
158
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
159
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
160
  )
161
 
162
 
163
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
164
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
165
+
166
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
167
+ self.scaling_factor = scaling_factor
168
+ super().__init__(dim, max_position_embeddings, base, device)
169
+
170
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
171
+ self.max_seq_len_cached = seq_len
172
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
173
+ t = t / self.scaling_factor
174
+
175
+ freqs = torch.outer(t, self.inv_freq)
176
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
177
+ emb = torch.cat((freqs, freqs), dim=-1)
178
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
179
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
180
+
181
+
182
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
183
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
184
+
185
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
186
+ self.scaling_factor = scaling_factor
187
+ super().__init__(dim, max_position_embeddings, base, device)
188
+
189
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
190
+ self.max_seq_len_cached = seq_len
191
+
192
+ if seq_len > self.max_position_embeddings:
193
+ base = self.base * (
194
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
195
+ ) ** (self.dim / (self.dim - 2))
196
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
197
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
198
+
199
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
200
+
201
+ freqs = torch.outer(t, self.inv_freq)
202
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
203
+ emb = torch.cat((freqs, freqs), dim=-1)
204
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
205
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
206
+
207
+
208
  def rotate_half(x):
209
  """Rotates half the hidden dims of the input."""
210
  x1 = x[..., : x.shape[-1] // 2]
 
212
  return torch.cat((-x2, x1), dim=-1)
213
 
214
 
215
+ # def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
216
+ # """Applies Rotary Position Embedding to the query and key tensors.
217
+
218
+ # Args:
219
+ # q (`torch.Tensor`): The query tensor.
220
+ # k (`torch.Tensor`): The key tensor.
221
+ # cos (`torch.Tensor`): The cosine part of the rotary embedding.
222
+ # sin (`torch.Tensor`): The sine part of the rotary embedding.
223
+ # position_ids (`torch.Tensor`):
224
+ # The position indices of the tokens corresponding to the query and key tensors. For example, this can be
225
+ # used to pass offsetted position ids when working with a KV-cache.
226
+ # unsqueeze_dim (`int`, *optional*, defaults to 1):
227
+ # The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
228
+ # sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
229
+ # that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
230
+ # k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
231
+ # cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
232
+ # the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
233
+ # Returns:
234
+ # `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
235
+ # """
236
+ # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
237
+ # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
238
+ # q_embed = (q * cos) + (rotate_half(q) * sin)
239
+ # k_embed = (k * cos) + (rotate_half(k) * sin)
240
+ # return q_embed, k_embed
241
+
242
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids, unsqueeze_dim=1):
243
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
244
+ cos_q = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
245
+ sin_q = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
 
 
246
 
247
+ cos_k = cos[key_position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
248
+ sin_k = sin[key_position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
249
  q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
250
  k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
251
  return q_embed, k_embed
252
 
253
 
254
  class LlamaMLP(nn.Module):
255
+ def __init__(self, config):
256
  super().__init__()
257
+ self.config = config
258
+ self.hidden_size = config.hidden_size
259
+ self.intermediate_size = config.intermediate_size
260
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
261
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
262
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
263
+ self.act_fn = ACT2FN[config.hidden_act]
264
 
265
  def forward(self, x):
266
+ if self.config.pretraining_tp > 1:
267
+ slice = self.intermediate_size // self.config.pretraining_tp
268
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
269
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
270
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
271
+
272
+ gate_proj = torch.cat(
273
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
274
+ )
275
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
276
+
277
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
278
+ down_proj = [
279
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
280
+ ]
281
+ down_proj = sum(down_proj)
282
+ else:
283
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
284
+
285
+ return down_proj
286
+
287
+
288
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
289
+ """
290
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
291
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
292
+ """
293
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
294
+ if n_rep == 1:
295
+ return hidden_states
296
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
297
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
298
 
299
 
300
  class LlamaAttention(nn.Module):
301
  """Multi-headed attention from 'Attention Is All You Need' paper"""
302
 
303
+ def __init__(self, config: CLEXLlamaConfig, layer_idx: Optional[int] = None):
304
  super().__init__()
305
  self.config = config
306
+ self.layer_idx = layer_idx
307
+ if layer_idx is None:
308
+ logger.warning_once(
309
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
310
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
311
+ "when creating this class."
312
+ )
313
+
314
+ self.attention_dropout = config.attention_dropout
315
  self.hidden_size = config.hidden_size
316
  self.num_heads = config.num_attention_heads
317
  self.head_dim = self.hidden_size // self.num_heads
318
+ self.num_key_value_heads = config.num_key_value_heads
319
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
320
  self.max_position_embeddings = config.max_position_embeddings
321
+ self.rope_theta = config.rope_theta
322
+ self.is_causal = True
323
+
324
  if (self.head_dim * self.num_heads) != self.hidden_size:
325
  raise ValueError(
326
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
327
  f" and `num_heads`: {self.num_heads})."
328
  )
329
+
330
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
331
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
332
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
333
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
334
+ self._init_rope()
335
+
336
+ def _init_rope(self):
337
+ if self.config.rope_scaling is None:
338
+ self.rotary_emb = LlamaRotaryEmbedding(
339
+ self.head_dim,
340
+ max_position_embeddings=self.max_position_embeddings,
341
+ base=self.rope_theta,
342
+ )
343
+ else:
344
+ scaling_type = self.config.rope_scaling["type"]
345
+ scaling_factor = self.config.rope_scaling["factor"]
346
+ if scaling_type == "linear":
347
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
348
+ self.head_dim,
349
+ max_position_embeddings=self.max_position_embeddings,
350
+ scaling_factor=scaling_factor,
351
+ base=self.rope_theta,
352
+ )
353
+ elif scaling_type == "dynamic":
354
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
355
+ self.head_dim,
356
+ max_position_embeddings=self.max_position_embeddings,
357
+ scaling_factor=scaling_factor,
358
+ base=self.rope_theta,
359
+ )
360
+ else: pass
361
+ # raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
362
 
363
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
364
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
365
+
366
+ def forward(
367
  self,
368
+ hidden_states: torch.Tensor,
369
+ attention_mask: Optional[torch.Tensor] = None,
370
+ position_ids: Optional[torch.LongTensor] = None,
371
+ pack_cos_sin: Optional[torch.Tensor] = None,
372
+ past_key_value: Optional[Cache] = None,
373
+ output_attentions: bool = False,
374
+ use_cache: bool = False,
375
+ **kwargs,
376
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
377
+ if "padding_mask" in kwargs:
378
+ warnings.warn(
379
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
380
  )
381
+
382
+ bsz, q_len, _ = hidden_states.size()
383
+
384
+ if self.config.pretraining_tp > 1:
385
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
386
+ query_slices = self.q_proj.weight.split(
387
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
388
  )
389
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
390
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
391
+
392
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
393
+ query_states = torch.cat(query_states, dim=-1)
394
+
395
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
396
+ key_states = torch.cat(key_states, dim=-1)
397
+
398
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
399
+ value_states = torch.cat(value_states, dim=-1)
400
+
401
  else:
402
+ query_states = self.q_proj(hidden_states)
403
+ key_states = self.k_proj(hidden_states)
404
+ value_states = self.v_proj(hidden_states)
405
+
406
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
407
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
408
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
409
+
410
+ kv_seq_len = key_states.shape[-2]
411
+ if past_key_value is not None:
412
+ if self.layer_idx is None:
413
+ raise ValueError(
414
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
415
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
416
+ "with a layer index."
417
+ )
418
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
419
+
420
+
421
+ if pack_cos_sin is not None:
422
+ cos, sin = pack_cos_sin.to(query_states.device)
423
+ else:
424
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
425
+ ## Update KV cache before RoPE
426
+ if past_key_value is not None:
427
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
428
+ cache_key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
429
+ else:
430
+ cache_key_states = key_states
431
+
432
+ key_position_ids = torch.arange(position_ids[:, -1].max().item() + 1, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, position_ids[:, -1].max().item() + 1)
433
+
434
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
435
+ # print(cache_key_states.size(), cos.size())
436
+ query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
437
+
438
+
439
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
440
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
441
+
442
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
443
+
444
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
445
+ raise ValueError(
446
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
447
+ f" {attn_weights.size()}"
448
  )
449
+
450
+ if attention_mask is not None:
451
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
452
+ raise ValueError(
453
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
454
+ )
455
+ attn_weights = attn_weights + attention_mask
456
+
457
+ # upcast attention to fp32
458
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
459
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
460
+ attn_output = torch.matmul(attn_weights, value_states)
461
+
462
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
463
+ raise ValueError(
464
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
465
+ f" {attn_output.size()}"
466
  )
467
+
468
+ attn_output = attn_output.transpose(1, 2).contiguous()
469
+
470
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
471
+
472
+ if self.config.pretraining_tp > 1:
473
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
474
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
475
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
476
+ else:
477
+ attn_output = self.o_proj(attn_output)
478
+
479
+ if not output_attentions:
480
+ attn_weights = None
481
+
482
+ return attn_output, attn_weights, past_key_value
483
+
484
+
485
+ class LlamaFlashAttention2(LlamaAttention):
486
+ """
487
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
488
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
489
+ flash attention and deal with padding tokens in case the input contains any of them.
490
+ """
491
+
492
+ def __init__(self, *args, **kwargs):
493
+ super().__init__(*args, **kwargs)
494
+
495
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
496
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
497
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
498
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
499
+
500
  def forward(
501
  self,
502
  hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.LongTensor] = None,
504
  position_ids: Optional[torch.LongTensor] = None,
505
+ pack_cos_sin: Optional[torch.Tensor] = None,
506
+ past_key_value: Optional[Cache] = None,
507
  output_attentions: bool = False,
508
  use_cache: bool = False,
509
+ **kwargs,
510
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
511
+ # LlamaFlashAttention2 attention does not support output_attentions
512
+ if "padding_mask" in kwargs:
513
+ warnings.warn(
514
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
515
+ )
516
+
517
+ # overwrite attention_mask with padding_mask
518
+ attention_mask = kwargs.pop("padding_mask")
519
+
520
+ output_attentions = False
521
+
522
  bsz, q_len, _ = hidden_states.size()
523
 
524
+ query_states = self.q_proj(hidden_states)
525
+ key_states = self.k_proj(hidden_states)
526
+ value_states = self.v_proj(hidden_states)
527
 
528
+ # Flash attention requires the input to have the shape
529
+ # batch_size x seq_length x head_dim x hidden_dim
530
+ # therefore we just need to keep the original shape
531
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
532
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
533
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
534
 
535
+ kv_seq_len = key_states.shape[-2]
536
  if past_key_value is not None:
537
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
538
+
 
 
 
539
  if pack_cos_sin is not None:
540
  cos, sin = pack_cos_sin.to(query_states.device)
541
  else:
542
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
543
+ ## Update KV cache before RoPE
544
+ if past_key_value is not None:
545
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
546
+ cache_key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
547
+ else:
548
+ cache_key_states = key_states
549
+
550
+ key_position_ids = torch.arange(position_ids[:, -1].max().item() + 1, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, position_ids[:, -1].max().item() + 1)
551
+
552
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
553
+ # print(cache_key_states.size(), cos.size())
554
  query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
555
 
556
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
557
+ # to be able to avoid many of these transpose/reshape/view.
558
+ query_states = query_states.transpose(1, 2)
559
+ key_states = key_states.transpose(1, 2)
560
+ value_states = value_states.transpose(1, 2)
561
+
562
+ dropout_rate = self.attention_dropout if self.training else 0.0
563
+
564
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
565
+ # so the input hidden states get silently cast to float32. Hence, we need to
566
+ # cast them back to the correct dtype just to make sure everything works as expected.
567
+ # This might slow down training and inference, so it is recommended not to cast the LayerNorms
568
+ # to fp32. (LlamaRMSNorm handles this correctly.)
569
+
570
+ input_dtype = query_states.dtype
571
+ if input_dtype == torch.float32:
572
+ # Handle the case where the model is quantized
573
+ if hasattr(self.config, "_pre_quantization_dtype"):
574
+ target_dtype = self.config._pre_quantization_dtype
575
+ else:
576
+ target_dtype = self.q_proj.weight.dtype
577
 
578
+ logger.warning_once(
579
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
580
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
581
+ f" {target_dtype}."
582
+ )
583
 
584
+ query_states = query_states.to(target_dtype)
585
+ key_states = key_states.to(target_dtype)
586
+ value_states = value_states.to(target_dtype)
587
 
588
+ if self.config.log_scale:
589
+ # naive_len = kv_seq_len if kv_seq_len < self.config.max_position_embeddings else self.config.max_position_embeddings
590
+ naive_len = self.config.max_position_embeddings
591
  log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
592
+ torch.log(torch.tensor(naive_len)).to(query_states.device, dtype=query_states.dtype)
593
  query_states = query_states * log_n
 
594
 
595
+ attn_output = self._flash_attention_forward(
596
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
597
+ )
598
+
599
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
600
+ attn_output = self.o_proj(attn_output)
601
+
602
+ if not output_attentions:
603
+ attn_weights = None
604
+
605
+ return attn_output, attn_weights, past_key_value
606
+
607
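The `log_scale` branch in the forward pass above rescales the queries by ln(kv_seq_len) / ln(max_position_embeddings), a "log-n" scaling commonly used to keep the attention distribution from flattening out when the context grows beyond the training length. A minimal standalone sketch of that computation (the function name and arguments are illustrative, not part of this file):

```python
import torch


def apply_log_n_scaling(q: torch.Tensor, kv_seq_len: int, max_position_embeddings: int) -> torch.Tensor:
    """Scale queries by log(kv_seq_len) / log(max_position_embeddings), as in the log_scale branch above."""
    log_n = torch.log(torch.tensor(float(kv_seq_len))) / torch.log(torch.tensor(float(max_position_embeddings)))
    return q * log_n.to(q.device, dtype=q.dtype)
```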
+ def _flash_attention_forward(
608
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
609
+ ):
610
+ """
611
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
612
+ first unpad the input, then computes the attention scores and pad the final attention scores.
613
+
614
+ Args:
615
+ query_states (`torch.Tensor`):
616
+ Input query states to be passed to Flash Attention API
617
+ key_states (`torch.Tensor`):
618
+ Input key states to be passed to Flash Attention API
619
+ value_states (`torch.Tensor`):
620
+ Input value states to be passed to Flash Attention API
621
+ attention_mask (`torch.Tensor`):
622
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
623
+ position of padding tokens and 1 for the position of non-padding tokens.
624
+ dropout (`float`, *optional*):
625
+ Attention dropout probability
626
+ softmax_scale (`float`, *optional*):
627
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
628
+ """
629
+ if not self._flash_attn_uses_top_left_mask:
630
+ causal = self.is_causal
631
+ else:
632
+ # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
633
+ causal = self.is_causal and query_length != 1
634
+
635
+ # Contains at least one padding token in the sequence
636
+ if attention_mask is not None:
637
+ batch_size = query_states.shape[0]
638
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
639
+ query_states, key_states, value_states, attention_mask, query_length
640
+ )
641
+
642
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
643
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
644
+
645
+ attn_output_unpad = flash_attn_varlen_func(
646
+ query_states,
647
+ key_states,
648
+ value_states,
649
+ cu_seqlens_q=cu_seqlens_q,
650
+ cu_seqlens_k=cu_seqlens_k,
651
+ max_seqlen_q=max_seqlen_in_batch_q,
652
+ max_seqlen_k=max_seqlen_in_batch_k,
653
+ dropout_p=dropout,
654
+ softmax_scale=softmax_scale,
655
+ causal=causal,
656
+ )
657
+
658
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
659
+ else:
660
+ attn_output = flash_attn_func(
661
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
662
+ )
663
+
664
+ return attn_output
665
+
666
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
667
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
668
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
669
+
670
+ key_layer = index_first_axis(
671
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
672
+ )
673
+ value_layer = index_first_axis(
674
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
675
+ )
676
+ if query_length == kv_seq_len:
677
+ query_layer = index_first_axis(
678
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
679
+ )
680
+ cu_seqlens_q = cu_seqlens_k
681
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
682
+ indices_q = indices_k
683
+ elif query_length == 1:
684
+ max_seqlen_in_batch_q = 1
685
+ cu_seqlens_q = torch.arange(
686
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
687
+ ) # There is a memcpy here, which is very bad.
688
+ indices_q = cu_seqlens_q[:-1]
689
+ query_layer = query_layer.squeeze(1)
690
+ else:
691
+ # The -q_len: slice assumes left padding.
692
+ attention_mask = attention_mask[:, -query_length:]
693
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
694
+
695
+ return (
696
+ query_layer,
697
+ key_layer,
698
+ value_layer,
699
+ indices_q,
700
+ (cu_seqlens_q, cu_seqlens_k),
701
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
702
+ )
703
+
704
+
705
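`_upad_input` above relies on a `_get_unpad_data` helper that is referenced but not shown in this hunk. For reference, the variant shipped with recent `transformers` releases derives, from the 2D padding mask, the flat indices of the non-padding tokens, the cumulative sequence lengths expected by `flash_attn_varlen_func`, and the longest sequence in the batch; a sketch along those lines:

```python
import torch
import torch.nn.functional as F


def _get_unpad_data(attention_mask: torch.Tensor):
    # attention_mask: (batch_size, kv_seq_len) with 1 for real tokens and 0 for padding
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # cu_seqlens is what flash_attn_varlen_func expects: prefix sums with a leading 0
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```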
+ class LlamaSdpaAttention(LlamaAttention):
706
+ """
707
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
708
+ `LlamaAttention`, as the weights of the module stay untouched. The only changes are on the forward pass, to adapt it to the
709
+ SDPA API.
710
+ """
711
+
712
+ # Adapted from LlamaAttention.forward
713
+ def forward(
714
+ self,
715
+ hidden_states: torch.Tensor,
716
+ attention_mask: Optional[torch.Tensor] = None,
717
+ position_ids: Optional[torch.LongTensor] = None,
718
+ pack_cos_sin: Optional[torch.Tensor] = None,
719
+ past_key_value: Optional[Cache] = None,
720
+ output_attentions: bool = False,
721
+ use_cache: bool = False,
722
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
723
+ if output_attentions:
724
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
725
+ logger.warning_once(
726
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
727
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
728
+ )
729
+ return super().forward(
730
+ hidden_states=hidden_states,
731
+ attention_mask=attention_mask,
732
+ position_ids=position_ids,
733
+ pack_cos_sin=pack_cos_sin,
+ past_key_value=past_key_value,
734
+ output_attentions=output_attentions,
735
+ use_cache=use_cache,
736
+ )
737
+
738
+ bsz, q_len, _ = hidden_states.size()
739
+
740
+ query_states = self.q_proj(hidden_states)
741
+ key_states = self.k_proj(hidden_states)
742
+ value_states = self.v_proj(hidden_states)
743
+
744
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
745
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
746
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
747
+
748
+ kv_seq_len = key_states.shape[-2]
749
+ if past_key_value is not None:
750
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
751
+ if pack_cos_sin is not None:
752
+ cos, sin = pack_cos_sin.to(query_states.device)
753
+ else:
754
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
755
+ ## Update KV cache before RoPE
756
+ if past_key_value is not None:
757
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
758
+ cache_key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
759
+ else:
760
+ cache_key_states = key_states
761
+
762
+ key_position_ids = torch.arange(position_ids[:, -1].max().item() + 1, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, position_ids[:, -1].max().item() + 1)
763
+
764
+ query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
765
 
 
 
 
 
 
766
 
 
 
 
 
 
 
 
767
 
768
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
769
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
 
770
 
771
+ if attention_mask is not None:
772
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
773
  raise ValueError(
774
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
 
775
  )
776
 
777
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
778
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
779
+ if query_states.device.type == "cuda" and attention_mask is not None:
780
+ query_states = query_states.contiguous()
781
+ key_states = key_states.contiguous()
782
+ value_states = value_states.contiguous()
783
+
784
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
785
+ query_states,
786
+ key_states,
787
+ value_states,
788
+ attn_mask=attention_mask,
789
+ dropout_p=self.attention_dropout if self.training else 0.0,
790
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
791
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
792
+ )
793
 
794
+ attn_output = attn_output.transpose(1, 2).contiguous()
795
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
796
+
797
+ attn_output = self.o_proj(attn_output)
798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
  return attn_output, None, past_key_value
800
 
801
 
802
+ LLAMA_ATTENTION_CLASSES = {
803
+ "eager": LlamaAttention,
804
+ "flash_attention_2": LlamaFlashAttention2,
805
+ "sdpa": LlamaSdpaAttention,
806
+ }
807
+
808
+
809
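`LLAMA_ATTENTION_CLASSES` maps `config._attn_implementation` to one of the three attention variants above, and each decoder layer below looks up its attention class in this dict. From the user side, the implementation is normally chosen at load time; a usage sketch, with a placeholder checkpoint path:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder path; CLEX checkpoints load this modeling code via trust_remote_code.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/clex-llama-checkpoint",
    torch_dtype=torch.bfloat16,               # flash-attn kernels require fp16/bf16 inputs
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager", i.e. the keys of LLAMA_ATTENTION_CLASSES
)
```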
  class LlamaDecoderLayer(nn.Module):
810
+ def __init__(self, config: CLEXLlamaConfig, layer_idx: int):
811
  super().__init__()
812
  self.hidden_size = config.hidden_size
813
+
814
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
815
+
816
+ self.mlp = LlamaMLP(config)
 
 
817
  self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
818
  self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
819
 
 
822
  hidden_states: torch.Tensor,
823
  attention_mask: Optional[torch.Tensor] = None,
824
  position_ids: Optional[torch.LongTensor] = None,
825
+ pack_cos_sin: Optional[torch.Tensor] = None,
826
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
827
  output_attentions: Optional[bool] = False,
828
  use_cache: Optional[bool] = False,
829
+ **kwargs,
830
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
831
  """
832
  Args:
833
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
834
+ attention_mask (`torch.FloatTensor`, *optional*):
835
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
836
+ query_sequence_length, key_sequence_length)` if default attention is used.
837
  output_attentions (`bool`, *optional*):
838
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
839
  returned tensors for more detail.
 
842
  (see `past_key_values`).
843
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
844
  """
845
+ if "padding_mask" in kwargs:
846
+ warnings.warn(
847
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
848
+ )
849
 
850
  residual = hidden_states
851
 
 
854
  # Self Attention
855
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
856
  hidden_states=hidden_states,
857
+ attention_mask=attention_mask,
858
  position_ids=position_ids,
859
  pack_cos_sin=pack_cos_sin,
860
  past_key_value=past_key_value,
861
  output_attentions=output_attentions,
862
  use_cache=use_cache,
863
+ **kwargs,
864
  )
865
  hidden_states = residual + hidden_states
866
 
 
907
  base_model_prefix = "model"
908
  supports_gradient_checkpointing = True
909
  _no_split_modules = ["LlamaDecoderLayer"]
910
+ _skip_keys_device_placement = "past_key_values"
911
+ _supports_flash_attn_2 = True
912
+ _supports_sdpa = True
913
+ _supports_cache_class = True
914
 
915
  def _init_weights(self, module):
916
  std = self.config.initializer_range
 
923
  if module.padding_idx is not None:
924
  module.weight.data[module.padding_idx].zero_()
925
 
 
 
 
 
926
 
927
  LLAMA_INPUTS_DOCSTRING = r"""
928
  Args:
 
945
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
946
  [`PreTrainedTokenizer.__call__`] for details.
947
 
948
+ If `past_key_values` is used, only the last `input_ids` need to be provided (see
949
  `past_key_values`).
950
 
951
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
 
959
  config.n_positions - 1]`.
960
 
961
  [What are position IDs?](../glossary#position-ids)
962
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
963
+ Pre-computed hidden states (keys and values in the self-attention blocks and in the cross-attention
964
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
965
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
966
+
967
+ Two formats are allowed:
968
+ - a [`~cache_utils.Cache`] instance;
969
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
970
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
971
+ cache format.
972
+
973
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
974
+ legacy cache format will be returned.
975
+
976
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
977
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
978
+ of shape `(batch_size, sequence_length)`.
979
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
980
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
981
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
 
1012
  self.vocab_size = config.vocab_size
1013
 
1014
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1015
+ self.layers = nn.ModuleList(
1016
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1017
+ )
1018
+ self._use_sdpa = config._attn_implementation == "sdpa"
1019
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1020
  self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1021
+
 
 
1022
  self.gradient_checkpointing = False
1023
  # Initialize weights and apply final processing
1024
  self.post_init()
1025
+ head_dim = config.hidden_size // config.num_attention_heads
1026
+ if config.rope_scaling["type"] == "clex":
1027
+ self.clex_layer = CLEXScalingRotaryEmbedding(head_dim, config.max_position_embeddings, config.rope_scaling, config.rope_theta)
1028
+
1029
 
1030
  def get_input_embeddings(self):
1031
  return self.embed_tokens
 
1033
  def set_input_embeddings(self, value):
1034
  self.embed_tokens = value
1035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1036
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1037
  def forward(
1038
  self,
 
1056
 
1057
  # retrieve input_ids and inputs_embeds
1058
  if input_ids is not None and inputs_embeds is not None:
1059
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1060
  elif input_ids is not None:
1061
+ batch_size, seq_length = input_ids.shape[:2]
1062
  elif inputs_embeds is not None:
1063
+ batch_size, seq_length = inputs_embeds.shape[:2]
1064
  else:
1065
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1066
 
 
1067
  past_key_values_length = 0
1068
+ if use_cache:
1069
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1070
+ if use_legacy_cache:
1071
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1072
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1073
 
1074
  if position_ids is None:
1075
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1076
  position_ids = torch.arange(
1077
  past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1078
  )
1079
+ position_ids = position_ids.unsqueeze(0)
 
 
1080
 
1081
  if inputs_embeds is None:
1082
  inputs_embeds = self.embed_tokens(input_ids)
 
 
 
 
 
 
 
 
 
1083
 
1084
+ if self._use_flash_attention_2:
1085
+ # 2d mask is passed through the layers
1086
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1087
+ elif self._use_sdpa and not output_attentions:
1088
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1089
+ # the manual implementation that requires a 4D causal mask in all cases.
1090
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1091
+ attention_mask,
1092
+ (batch_size, seq_length),
1093
+ inputs_embeds,
1094
+ past_key_values_length,
1095
+ )
1096
+ else:
1097
+ # 4d mask is passed through the layers
1098
+ attention_mask = _prepare_4d_causal_attention_mask(
1099
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1100
+ )
1101
 
1102
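Depending on the backend selected above, the mask handed to the decoder layers has a different shape: the flash-attention path keeps the raw 2D padding mask (or `None` when nothing is padded), while the SDPA and eager paths expand it into a 4D causal mask. A small shape illustration using the `_prepare_4d_causal_attention_mask` helper referenced above (imported here from `transformers.modeling_attn_mask_utils`, a private module whose exact location may shift between versions):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_length, hidden_size = 2, 5, 3, 8
# 2D padding mask over past + current tokens: 1 = real token, 0 = padding
attention_mask = torch.ones(batch_size, seq_length + past_length, dtype=torch.long)
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_length
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 8]) -> (batch, 1, query_len, key_len)
```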
+ # embed positions
1103
  hidden_states = inputs_embeds
1104
 
1105
  if self.gradient_checkpointing and self.training:
 
1112
  # decoder layers
1113
  all_hidden_states = () if output_hidden_states else None
1114
  all_self_attns = () if output_attentions else None
1115
+ next_decoder_cache = None
 
1116
  pack_cos_sin = None
1117
  if self.config.rope_scaling["type"] == "clex":
1118
+ pack_cos_sin = self.clex_layer(inputs_embeds.device, inputs_embeds.dtype, seq_length + past_key_values_length, self.training)
1119
 
1120
+
1121
+ for decoder_layer in self.layers:
1122
  if output_hidden_states:
1123
  all_hidden_states += (hidden_states,)
1124
 
 
 
1125
  if self.gradient_checkpointing and self.training:
1126
+ layer_outputs = self._gradient_checkpointing_func(
1127
+ decoder_layer.__call__,
 
 
 
 
 
 
 
 
1128
  hidden_states,
1129
  attention_mask,
1130
  position_ids,
1131
  pack_cos_sin,
1132
+ past_key_values,
1133
+ output_attentions,
1134
+ use_cache,
1135
  )
1136
  else:
1137
  layer_outputs = decoder_layer(
 
1139
  attention_mask=attention_mask,
1140
  position_ids=position_ids,
1141
  pack_cos_sin=pack_cos_sin,
1142
+ past_key_value=past_key_values,
1143
  output_attentions=output_attentions,
1144
  use_cache=use_cache,
1145
  )
 
1147
  hidden_states = layer_outputs[0]
1148
 
1149
  if use_cache:
1150
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1151
 
1152
  if output_attentions:
1153
  all_self_attns += (layer_outputs[1],)
 
1158
  if output_hidden_states:
1159
  all_hidden_states += (hidden_states,)
1160
 
1161
+ next_cache = None
1162
+ if use_cache:
1163
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1164
  if not return_dict:
1165
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1166
  return BaseModelOutputWithPast(
 
1172
 
1173
 
1174
  class LlamaForCausalLM(LlamaPreTrainedModel):
1175
+ _tied_weights_keys = ["lm_head.weight"]
1176
+
1177
  def __init__(self, config):
1178
  super().__init__(config)
1179
  self.model = LlamaModel(config)
1180
+ self.vocab_size = config.vocab_size
1181
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1182
 
1183
  # Initialize weights and apply final processing
 
1233
  >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1234
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1235
 
1236
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1237
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1238
 
1239
  >>> # Generate
1240
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1241
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1242
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1243
  ```"""
 
1244
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1245
  output_hidden_states = (
1246
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
1261
  )
1262
 
1263
  hidden_states = outputs[0]
1264
+ if self.config.pretraining_tp > 1:
1265
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1266
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1267
+ logits = torch.cat(logits, dim=-1)
1268
+ else:
1269
+ logits = self.lm_head(hidden_states)
1270
+ logits = logits.float()
1271
 
1272
  loss = None
1273
  if labels is not None:
 
1281
  # Enable model parallelism
1282
  shift_labels = shift_labels.to(shift_logits.device)
1283
  loss = loss_fct(shift_logits, shift_labels)
1284
+
1285
  if not return_dict:
1286
  output = (logits,) + outputs[1:]
1287
  return (loss,) + output if loss is not None else output
1288
+
1289
  return CausalLMOutputWithPast(
1290
  loss=loss,
1291
  logits=logits,
 
1297
  def prepare_inputs_for_generation(
1298
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1299
  ):
1300
+ if past_key_values is not None:
1301
+ if isinstance(past_key_values, Cache):
1302
+ cache_length = past_key_values.get_seq_length()
1303
+ past_length = past_key_values.seen_tokens
1304
+ max_cache_length = past_key_values.get_max_length()
1305
+ else:
1306
+ cache_length = past_length = past_key_values[0][0].shape[2]
1307
+ max_cache_length = None
1308
+
1309
+ # Keep only the unprocessed tokens:
1310
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1311
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1312
+ # input)
1313
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1314
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1315
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1316
+ # input_ids based on the past_length.
1317
+ elif past_length < input_ids.shape[1]:
1318
+ input_ids = input_ids[:, past_length:]
1319
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1320
+
1321
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1322
+ if (
1323
+ max_cache_length is not None
1324
+ and attention_mask is not None
1325
+ and cache_length + input_ids.shape[1] > max_cache_length
1326
+ ):
1327
+ attention_mask = attention_mask[:, -max_cache_length:]
1328
 
1329
  position_ids = kwargs.get("position_ids", None)
1330
  if attention_mask is not None and position_ids is None:
 
1332
  position_ids = attention_mask.long().cumsum(-1) - 1
1333
  position_ids.masked_fill_(attention_mask == 0, 1)
1334
  if past_key_values:
1335
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1336
 
1337
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1338
  if inputs_embeds is not None and past_key_values is None:
 
1354
  def _reorder_cache(past_key_values, beam_idx):
1355
  reordered_past = ()
1356
  for layer_past in past_key_values:
1357
+ reordered_past += (
1358
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1359
+ )
1360
  return reordered_past
1361
 
1362
 
 
1376
  LLAMA_START_DOCSTRING,
1377
  )
1378
  class LlamaForSequenceClassification(LlamaPreTrainedModel):
 
 
1379
  def __init__(self, config):
1380
  super().__init__(config)
1381
  self.num_labels = config.num_labels
 
1438
  sequence_lengths = -1
1439
  else:
1440
  if input_ids is not None:
1441
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1442
+ logits.device
1443
+ )
1444
  else:
1445
  sequence_lengths = -1
1446