amazingvince committed
Commit 6a9c34e · verified · 1 parent: 7a29a57

Update modeling_diff_llama.py

Files changed (1): modeling_diff_llama.py (+518, -0)
modeling_diff_llama.py CHANGED
@@ -0,0 +1,518 @@
+ import math
+ from typing import Optional, Tuple, Union, List, Dict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+ from transformers import PreTrainedModel, LlamaConfig
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.models.llama.modeling_llama import (
+     LlamaRMSNorm,
+     LlamaRotaryEmbedding,
+     LlamaLinearScalingRotaryEmbedding,
+     LlamaDynamicNTKScalingRotaryEmbedding,
+     LlamaMLP,
+     apply_rotary_pos_emb,
+     repeat_kv,
+ )
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
+
+
+ class DiffLLaMAConfig(LlamaConfig):
+     """
+     Configuration class for the DiffLLaMA model.
+     Inherits from LlamaConfig and can be extended with additional parameters.
+     """
+     model_type = "diff_llama"
+
+     def __init__(
+         self,
+         num_kv_heads: int = 8,
+         intermediate_size: int = 3072,
+         rope_scaling: Optional[Dict[str, Union[str, float]]] = None,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.num_kv_heads = num_kv_heads
+         self.intermediate_size = intermediate_size
+         self.rope_scaling = rope_scaling or {"type": "linear", "factor": 1.0}
+         # Add any custom configuration parameters here
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         """
+         Load configuration from a pretrained model.
+         """
+         # get_config_dict returns (config_dict, remaining_kwargs); unpack both
+         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+         return cls.from_dict(config_dict, **kwargs)
+
+
+ def init_method(tensor):
+     """Initialize tensor with Kaiming uniform initialization."""
+     nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))
+
+ def lambda_init_fn(depth):
+     """Compute lambda initialization value based on layer depth."""
+     return 0.8 - 0.6 * math.exp(-0.3 * depth)
+
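+ # For reference: lambda_init_fn(0) = 0.2, lambda_init_fn(1) ≈ 0.356, lambda_init_fn(11) ≈ 0.778,
+ # and the value approaches 0.8 as depth grows.
+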
+ class MultiheadDiffAttn(nn.Module):
+     def __init__(self, config: DiffLLaMAConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_kv_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.scaling = self.head_dim ** -0.5
+
+         # _init_rope sets self.rotary_emb and returns nothing, so call it once rather than assigning its result
+         self._init_rope()
+
+         self.lambda_init = lambda_init_fn(layer_idx if layer_idx is not None else 0)
+         self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+         self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+         self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+         self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+
+         self.subln = nn.LayerNorm(self.num_heads * self.head_dim, elementwise_affine=False)
+
+     def _init_rope(self):
+         if not hasattr(self.config, 'rope_scaling') or self.config.rope_scaling is None:
+             self.rotary_emb = LlamaRotaryEmbedding(
+                 self.head_dim,
+                 max_position_embeddings=self.max_position_embeddings,
+                 base=self.rope_theta,
+             )
+         else:
+             scaling_type = self.config.rope_scaling.get("type", "linear")
+             scaling_factor = self.config.rope_scaling.get("factor", 1.0)
+             if scaling_type == "linear":
+                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.rope_theta,
+                 )
+             elif scaling_type == "dynamic":
+                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.rope_theta,
+                 )
+             else:
+                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         batch_size, seq_length, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         kv_seq_len = key_states.shape[-2]
+         if past_key_value is not None:
+             kv_seq_len += past_key_value[0].shape[-2]
+         cos, sin = self.rotary_emb(value_states, position_ids)
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+         if past_key_value is not None:
+             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+         past_key_value = (key_states, value_states) if use_cache else None
+
+         # Repeat k/v heads if n_kv_heads < n_heads
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2))
+         attn_weights = attn_weights * self.scaling
+
+         lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1))
+         lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2))
+         lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+         # Apply differential attention: subtract lambda-scaled scores of the next key position from each score
+         attn_weights_diff = attn_weights[:, :, :, :-1] - lambda_full * attn_weights[:, :, :, 1:]
+         attn_weights = torch.cat([attn_weights_diff, attn_weights[:, :, :, -1:]], dim=-1)
+
+         if attention_mask is not None:
+             # Expand attention_mask to (batch, num_heads, query_len, key_len)
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+             attention_mask = attention_mask.expand(batch_size, self.num_heads, seq_length, attention_mask.size(-1))
+             attention_mask = attention_mask.to(dtype=attn_weights.dtype)  # Convert to same dtype as attn_weights
+
+             # Use a large negative number instead of negative infinity
+             attn_weights = attn_weights + (1.0 - attention_mask) * -10000.0
+
+         attn_weights = F.softmax(attn_weights, dim=-1)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.num_heads * self.head_dim)
+
+         attn_output = self.subln(attn_output)
+         attn_output = attn_output * (1 - self.lambda_init)
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weights = None
+
+         return attn_output, attn_weights, past_key_value
+
+
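+ # Shape summary for MultiheadDiffAttn.forward (B = batch, S = query length, S_kv = key length incl. cache):
+ #   query_states is (B, num_heads, S, head_dim); key/value states start as (B, num_kv_heads, S_kv, head_dim)
+ #   and are expanded to num_heads by repeat_kv; attn_weights is (B, num_heads, S, S_kv);
+ #   the returned attn_output is (B, S, hidden_size).
+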
+ class DiffLLaMALayer(nn.Module):
+     """
+     A single layer of the DiffLLaMA model, consisting of multi-head differential attention and a feed-forward network.
+     """
+     def __init__(self, config: DiffLLaMAConfig, layer_idx: int):
+         super().__init__()
+         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.self_attn = MultiheadDiffAttn(
+             config=config,
+             layer_idx=layer_idx
+         )
+         self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.mlp = LlamaMLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+             cache_position=cache_position,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+ class DiffLLaMAModel(PreTrainedModel):
+     """
+     DiffLLaMAModel is a variant of LLaMA with differential attention mechanisms.
+     """
+     config_class = DiffLLaMAConfig
+
+     def __init__(self, config: DiffLLaMAConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([
+             DiffLLaMALayer(config, layer_idx=i) for i in range(config.num_hidden_layers)
+         ])
+         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.rotary_emb = LlamaRotaryEmbedding(
+             dim=config.hidden_size // config.num_attention_heads,
+             max_position_embeddings=config.max_position_embeddings,
+             base=config.rope_theta,
+         )
+
+         self.gradient_checkpointing = False
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         """Return the token embedding module."""
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         """Set the token embedding module."""
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         """
+         Forward pass for the DiffLLaMAModel.
+
+         Args:
+             input_ids: Input token IDs.
+             attention_mask: Attention mask.
+             position_ids: Position IDs.
+             past_key_values: Past key and value tensors for caching.
+             inputs_embeds: Input embeddings.
+             use_cache: Whether to return present key and value tensors for caching.
+             output_attentions: Whether to output attention weights.
+             output_hidden_states: Whether to output hidden states.
+             return_dict: Whether to return a dict.
+             cache_position: Position IDs for caching.
+
+         Returns:
+             Model output, either as a tuple or a BaseModelOutputWithPast.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+             position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         # Rotary position embeddings are computed inside each attention layer
+         hidden_states = inputs_embeds
+
+         # Default to an all-ones (no padding) attention mask
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
+
+         # Initialize containers for optional outputs
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_cache = () if use_cache else None
+
+         for idx, layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             layer_outputs = layer(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values[idx] if past_key_values is not None else None,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 cache_position=cache_position,
+             )
+
+             # Unpack layer_outputs based on the requested outputs
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 present_key_value = layer_outputs[-1]
+                 next_cache += (present_key_value,)
+
+             if output_attentions:
+                 self_attn_weights = layer_outputs[1]
+                 all_self_attns += (self_attn_weights,)
+
+         hidden_states = self.norm(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         # Convert a Cache object to the legacy tuple format; plain tuples are returned as-is
+         if use_cache and isinstance(next_cache, Cache):
+             next_cache = next_cache.to_legacy_cache()
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+ class DiffLLaMAForCausalLM(PreTrainedModel):
+     """
+     DiffLLaMA model with a causal language modeling head.
+     The language modeling loss is computed in full precision when the logits are in float16.
+     """
+     config_class = DiffLLaMAConfig
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: DiffLLaMAConfig):
+         super().__init__(config)
+         self.model = DiffLLaMAModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         """Return input embeddings."""
+         return self.model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         """Set input embeddings."""
+         self.model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         """Return output embeddings (language modeling head)."""
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """Set output embeddings (language modeling head)."""
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         """Set the decoder model."""
+         self.model = decoder
+
+     def get_decoder(self):
+         """Get the decoder model."""
+         return self.model
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         """
+         Forward pass for DiffLLaMAForCausalLM.
+
+         Args:
+             input_ids: Input token IDs.
+             attention_mask: Attention mask.
+             position_ids: Position IDs.
+             past_key_values: Past key and value tensors for caching.
+             inputs_embeds: Input embeddings.
+             labels: Labels for computing the loss.
+             use_cache: Whether to return past key and value tensors.
+             output_attentions: Whether to output attention weights.
+             output_hidden_states: Whether to output hidden states.
+             return_dict: Whether to return a dict.
+             cache_position: Position IDs for caching.
+
+         Returns:
+             CausalLMOutputWithPast or tuple containing loss and logits.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Get outputs from the model
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Compute the loss in full precision when the logits are float16
+             if shift_logits.dtype == torch.float16:
+                 with torch.cuda.amp.autocast(enabled=False):
+                     loss = loss_fct(shift_logits, shift_labels)
+             else:
+                 loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             if use_cache:
+                 return ((loss, logits) + outputs[1:]) if loss is not None else (logits,) + outputs[1:]
+             else:
+                 return (loss, logits) if loss is not None else (logits,)
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
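
For a quick sanity check, here is a minimal usage sketch with the classes above in scope (e.g., imported from modeling_diff_llama). It assumes a transformers version that still exports LlamaLinearScalingRotaryEmbedding (roughly 4.38 to 4.43); the tiny configuration values are illustrative only, not tuned defaults.

import torch

# Illustrative tiny configuration (values chosen only so the example runs quickly)
config = DiffLLaMAConfig(
    vocab_size=1000,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_kv_heads=4,
    max_position_embeddings=512,
)

model = DiffLLaMAForCausalLM(config)

# Random token IDs stand in for tokenizer output
input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = model(input_ids=input_ids, labels=input_ids)

print(outputs.loss)          # scalar language-modeling loss
print(outputs.logits.shape)  # torch.Size([2, 16, 1000])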