appledora committed on
Commit
877b677
1 Parent(s): f3b5355

Upload recastmlp_llama/modeling_recastmlp_llama.py with huggingface_hub

recastmlp_llama/modeling_recastmlp_llama.py ADDED
@@ -0,0 +1,738 @@
# filename: modeling_recastmlp_llama.py
from .configuration_recastmlp_llama import RECASTMLP_llama
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

logger = logging.get_logger(__name__)


class MLPTemplateBank(nn.Module):
    def __init__(self, config, num_templates):
        """
        Initialize a template bank shared by a group of MLP layers.
        Args:
            config: LlamaConfig instance
            num_templates: Number of templates in the bank
        """
        super().__init__()
        self.num_templates = num_templates
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Create templates for the gate, up and down projections
        self.gate_templates = nn.Parameter(
            torch.stack(
                [
                    torch.empty(self.intermediate_size, self.hidden_size)
                    for _ in range(self.num_templates)
                ]
            )
        )

        self.up_templates = nn.Parameter(
            torch.stack(
                [
                    torch.empty(self.intermediate_size, self.hidden_size)
                    for _ in range(self.num_templates)
                ]
            )
        )

        self.down_templates = nn.Parameter(
            torch.stack(
                [
                    torch.empty(self.hidden_size, self.intermediate_size)
                    for _ in range(self.num_templates)
                ]
            )
        )

        # Initialize templates
        for i in range(self.num_templates):
            nn.init.kaiming_normal_(self.gate_templates[i])
            nn.init.kaiming_normal_(self.up_templates[i])
            nn.init.kaiming_normal_(self.down_templates[i])

        self.coefficient_shape = (self.num_templates, 1, 1)

    def forward(self, gate_coeffs, up_coeffs, down_coeffs):
        """Generate projection weights as coefficient-weighted sums of the templates."""
        gate_weights = (self.gate_templates * gate_coeffs).sum(0)
        up_weights = (self.up_templates * up_coeffs).sum(0)
        down_weights = (self.down_templates * down_coeffs).sum(0)
        return gate_weights, up_weights, down_weights

    def __repr__(self):
        return (
            f"MLPTemplateBank(num_templates={self.num_templates}, "
            f"hidden_size={self.hidden_size}, "
            f"intermediate_size={self.intermediate_size})"
        )


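# Illustration (not executed): a sketch of how the bank builds concrete projection
# weights. Shapes below assume a hypothetical config with num_templates=4,
# hidden_size=4096 and intermediate_size=14336; each weight is simply a
# coefficient-weighted sum over the template stack.
#
#   bank = MLPTemplateBank(config, config.num_templates)
#   gate_c = torch.randn(bank.coefficient_shape)   # (num_templates, 1, 1)
#   up_c = torch.randn(bank.coefficient_shape)
#   down_c = torch.randn(bank.coefficient_shape)
#   gate_w, up_w, down_w = bank(gate_c, up_c, down_c)
#   # gate_w: (intermediate_size, hidden_size) == sum_i gate_c[i] * gate_templates[i]

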
class SharedLlamaMLP(nn.Module):
    def __init__(self, config, bank):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.bank = bank
        num_cf = config.num_cf

        # Coefficients for template bank
        self.gate_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )
        self.up_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )
        self.down_coefficients = nn.ParameterList(
            [nn.Parameter(torch.zeros(bank.coefficient_shape)) for _ in range(num_cf)]
        )

        # Initialize coefficients
        for cf in self.gate_coefficients:
            nn.init.orthogonal_(cf)
        for cf in self.up_coefficients:
            nn.init.orthogonal_(cf)
        for cf in self.down_coefficients:
            nn.init.orthogonal_(cf)

        # Biases
        self.gate_bias = (
            nn.Parameter(torch.zeros(self.intermediate_size))
            if config.mlp_bias
            else None
        )
        self.up_bias = (
            nn.Parameter(torch.zeros(self.intermediate_size))
            if config.mlp_bias
            else None
        )
        self.down_bias = (
            nn.Parameter(torch.zeros(self.hidden_size)) if config.mlp_bias else None
        )

        # Activation
        self.act_fn = F.silu

    def forward(self, x):
        # Generate weights using coefficients
        gate_weights = []
        up_weights = []
        down_weights = []

        for i in range(len(self.gate_coefficients)):
            gate, up, down = self.bank(
                self.gate_coefficients[i],
                self.up_coefficients[i],
                self.down_coefficients[i],
            )
            gate_weights.append(gate)
            up_weights.append(up)
            down_weights.append(down)

        gate_weights = torch.stack(gate_weights).mean(0)
        up_weights = torch.stack(up_weights).mean(0)
        down_weights = torch.stack(down_weights).mean(0)

        # Apply MLP operations
        gate_output = F.linear(x, gate_weights, self.gate_bias)
        up_output = F.linear(x, up_weights, self.up_bias)

        # Apply activation and down projection
        hidden_states = self.act_fn(gate_output) * up_output
        output = F.linear(hidden_states, down_weights, self.down_bias)

        return output

    def __repr__(self):
        return (
            f"SharedLlamaMLP(hidden_size={self.hidden_size}, "
            f"intermediate_size={self.intermediate_size}, "
            f"gate_coefficients={len(self.gate_coefficients)}, "
            f"up_coefficients={len(self.up_coefficients)}, "
            f"down_coefficients={len(self.down_coefficients)})"
        )


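# Illustration (not executed): a SharedLlamaMLP follows the same input/output contract
# as a standard LlamaMLP; its gate/up/down weights are generated on the fly from the
# shared bank and averaged over the num_cf coefficient sets. The `config` below is a
# hypothetical RECASTMLP_llama config carrying num_templates, num_cf and mlp_bias.
#
#   bank = MLPTemplateBank(config, config.num_templates)
#   mlp = SharedLlamaMLP(config, bank=bank)
#   x = torch.randn(2, 16, config.hidden_size)
#   y = mlp(x)   # shape (2, 16, config.hidden_size), like LlamaMLP.forward
#   # Several SharedLlamaMLP instances can point at the same bank; only the small
#   # per-layer coefficient tensors differ.

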
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
):
    # Sum over tokens and normalize by the caller-supplied count when it is given,
    # otherwise fall back to a plain mean.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss


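# Illustration (not executed): with num_items_in_batch provided, the summed loss is
# normalized by that count; this matches a mean over the non-ignored positions when
# the caller passes the number of label tokens that are not -100. Values below are
# made up for the sketch.
#
#   logits = torch.randn(5, 32)                 # 5 flattened positions, vocab size 32
#   labels = torch.tensor([3, 7, -100, 1, 9])   # one ignored position
#   loss = fixed_cross_entropy(logits, labels, num_items_in_batch=4)

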
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRotaryEmbedding,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
)
from transformers.modeling_outputs import BaseModelOutputWithPast


class RECASTMLP_llamaModel(PreTrainedModel):
    config_class = RECASTMLP_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        # Initialize rotary embeddings
        rope_config = config.rope_scaling
        if rope_config:
            rope_type = rope_config.get("rope_type", "default")
            scaling_factor = rope_config.get("factor", 1.0)
        else:
            rope_type = "default"
            scaling_factor = None
        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.1-8b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create template banks first
        self.banks = []
        layers_per_group = config.num_hidden_layers // config.num_groups
        for _ in range(config.num_groups):
            bank = MLPTemplateBank(config, config.num_templates)
            self.banks.append(bank)

        # Create layers using LlamaDecoderLayer but replace MLPs
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            # Create standard LlamaDecoderLayer
            decoder_layer = LlamaDecoderLayer(config, layer_idx)

            # Replace its MLP with our SharedLlamaMLP
            group_idx = layer_idx // layers_per_group
            group_bank = self.banks[group_idx]
            decoder_layer.mlp = SharedLlamaMLP(config, bank=group_bank)

            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        # Default cache_position; it is required by the causal-mask helpers below.
        if cache_position is None:
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # Create position ids and position embeddings shared across the decoder layers
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
        hidden_states = inputs_embeds

        # Get updated causal mask
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        # Initialize outputs
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Process through layers
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            # Load from a local checkpoint
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            logger.info("Loading from hub")
            # Load from the hub using the parent's from_pretrained
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we rely on its `is_causal` argument instead of its
        # `attn_mask` argument, in order to dispatch on Flash Attention 2. This feature
        # is not compatible with static cache, as SDPA will fail to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output_attentions is True, the sdpa implementation falls back to the eager path
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows of the causal_mask, for example
            # the relevant first rows when using left padding. This is required by the
            # F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # The mask already comes in inverted 4D form and needs no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask


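# Illustration (not executed): how layers map onto shared banks. With a hypothetical
# config of num_hidden_layers=32 and num_groups=8, layers_per_group is 4, so layers
# 0-3 share banks[0], layers 4-7 share banks[1], and so on; every layer keeps its own
# attention weights, only the MLP weights are generated from the group's bank.
#
#   model = RECASTMLP_llamaModel(config)
#   assert model.layers[0].mlp.bank is model.layers[3].mlp.bank
#   out = model(input_ids=torch.randint(0, config.vocab_size, (1, 8)))
#   out.last_hidden_state.shape   # (1, 8, config.hidden_size)

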
class RECASTMLP_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECASTMLP_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = RECASTMLP_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        # Upcast to float to avoid potential precision issues when computing the loss
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Calculate batch size for loss function
            num_items_in_batch = (
                input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            )
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # If `inputs_embeds` are passed, we only want to use them in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            logger.info("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            logger.info("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )