winglian committed on
Commit
9bf854e
1 Parent(s): 797f3dd

Phi update 202311 (#876)


* add phi modeling from hf

* update for packing and use new modeling class for phi

* update e2e tests for phi to use new model name

* update example phi to also use new phi model name

* use AutoModelForCausalLM for phi lora since sample packing isn't supported
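
To make the last two bullets concrete, here is a minimal sketch of the two load paths after this change. It mirrors the config and load_model diffs below, but it is an illustration, not the exact axolotl call sites:

from transformers import AutoModelForCausalLM

from axolotl.models.phi import PhiForCausalLM

# Full fine-tune with sample packing: the config sets model_type: PhiForCausalLM,
# and axolotl loads the custom class that understands packed position_ids.
model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5")

# LoRA: sample packing isn't supported on that path, so the stock HF auto class
# (with trust_remote_code, as in the example yml) is sufficient.
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)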

examples/phi/phi-ft.yml CHANGED
@@ -1,5 +1,5 @@
  base_model: microsoft/phi-1_5
- model_type: MixFormerSequentialForCausalLM
+ model_type: PhiForCausalLM
  tokenizer_type: AutoTokenizer
  is_llama_derived_model: false
  trust_remote_code: true
src/axolotl/models/phi/__init__.py CHANGED
@@ -3,4 +3,6 @@ MixFormers model architecture used for phi models
  """

  from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
+ from .configuration_phi import PhiConfig # noqa
  from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
+ from .modeling_phi import PhiForCausalLM # noqa
src/axolotl/models/phi/configuration_phi.py ADDED
@@ -0,0 +1,65 @@
+ # pylint: skip-file
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+
+ import math
+ from typing import Optional
+
+ from transformers import PretrainedConfig
+
+
+ class PhiConfig(PretrainedConfig):
+     """Phi configuration."""
+
+     model_type = "phi"
+     attribute_map = {
+         "max_position_embeddings": "n_positions",
+         "hidden_size": "n_embd",
+         "num_attention_heads": "n_head",
+         "num_hidden_layers": "n_layer",
+     }
+
+     def __init__(
+         self,
+         vocab_size: int = 50304,
+         n_positions: int = 2048,
+         n_embd: int = 1024,
+         n_layer: int = 20,
+         n_inner: Optional[int] = None,
+         n_head: int = 16,
+         n_head_kv: Optional[int] = None,
+         rotary_dim: Optional[int] = 32,
+         activation_function: Optional[str] = "gelu_new",
+         flash_attn: bool = False,
+         flash_rotary: bool = False,
+         fused_dense: bool = False,
+         attn_pdrop: float = 0.0,
+         embd_pdrop: float = 0.0,
+         resid_pdrop: float = 0.0,
+         layer_norm_epsilon: float = 1e-5,
+         initializer_range: float = 0.02,
+         tie_word_embeddings: bool = False,
+         pad_vocab_size_multiple: int = 64,
+         **kwargs
+     ) -> None:
+         self.vocab_size = int(
+             math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
+         )
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_inner = n_inner
+         self.n_head = n_head
+         self.n_head_kv = n_head_kv
+         self.rotary_dim = min(rotary_dim, n_embd // n_head)
+         self.activation_function = activation_function
+         self.flash_attn = flash_attn
+         self.flash_rotary = flash_rotary
+         self.fused_dense = fused_dense
+         self.attn_pdrop = attn_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.resid_pdrop = resid_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
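
As a quick sanity check of the padding and clamping logic above, a minimal sketch (the non-default sizes passed here are illustrative assumptions, not values taken from this commit):

from axolotl.models.phi import PhiConfig  # exported by the updated __init__.py above

# vocab_size is rounded up to a multiple of pad_vocab_size_multiple (64 by default),
# and rotary_dim is clamped to the per-head dimension.
config = PhiConfig(vocab_size=50295, n_embd=2048, n_head=32, n_layer=24)

print(config.vocab_size)   # 50304 == ceil(50295 / 64) * 64
print(config.rotary_dim)   # 32 == min(32, 2048 // 32)
print(config.hidden_size)  # 2048, exposed via attribute_map -> n_embd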
src/axolotl/models/phi/modeling_phi.py ADDED
@@ -0,0 +1,1063 @@
1
+ # pylint: skip-file
2
+ # Copyright (c) Microsoft Corporation.
3
+ # Licensed under the MIT license.
4
+ #
5
+ # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
6
+ # Licensed under the BSD 3-Clause License.
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Dict, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange, repeat
17
+ from transformers import PretrainedConfig, PreTrainedModel
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import CausalLMOutputWithPast
20
+
21
+ from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
22
+ from .configuration_phi import PhiConfig
23
+
24
+ try:
25
+ from flash_attn.bert_padding import pad_input, unpad_input
26
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
27
+ from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
28
+ from flash_attn.ops.fused_dense import FusedDense
29
+ except: # noqa: E722
30
+ pad_input, unpad_input = None, None
31
+ FlashRotaryEmbedding = None
32
+ FlashSelfAttention, FlashCrossAttention = None, None
33
+ FusedDense = None
34
+
35
+
36
+ @dataclass
37
+ class InferenceParams:
38
+ """Inference parameters passed to model to efficiently calculate
39
+ and store context during inference.
40
+
41
+ Reference:
42
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
43
+
44
+ Args:
45
+ max_seqlen: Maximum sequence length.
46
+ max_batch_size: Maximum batch size.
47
+ seqlen_offset: Sequence length offset.
48
+ batch_size_offset: Batch size offset.
49
+ key_value_memory_dict: Key value memory dictionary.
50
+ lengths_per_sample: Lengths per sample.
51
+
52
+ """
53
+
54
+ max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
55
+
56
+ max_batch_size: int = field(metadata={"help": "Maximum batch size."})
57
+
58
+ seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
59
+
60
+ batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
61
+
62
+ key_value_memory_dict: Dict[str, Any] = field(
63
+ default_factory=dict, metadata={"help": "Key value memory dictionary."}
64
+ )
65
+
66
+ lengths_per_sample: torch.Tensor = field(
67
+ default=None, metadata={"help": "Lengths per sample."}
68
+ )
69
+
70
+
71
+ class Embedding(nn.Module):
72
+ """Token embedding with dropout."""
73
+
74
+ def __init__(self, config: PretrainedConfig) -> None:
75
+ super().__init__()
76
+
77
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
78
+ self.drop = nn.Dropout(config.embd_pdrop)
79
+
80
+ def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
81
+ input_shape = input_ids.size()
82
+ input_ids = input_ids.view(-1, input_shape[-1])
83
+
84
+ hidden_states = self.wte(input_ids)
85
+ hidden_states = self.drop(hidden_states)
86
+
87
+ return hidden_states
88
+
89
+
90
+ def _apply_rotary_emb(
91
+ x: torch.FloatTensor,
92
+ cos: torch.FloatTensor,
93
+ sin: torch.FloatTensor,
94
+ ) -> torch.FloatTensor:
95
+ _, seqlen, _, _ = x.shape
96
+ _, rotary_dim = cos.shape
97
+ rotary_dim *= 2
98
+
99
+ x_rot = x[:, :, :, :rotary_dim]
100
+ x_pass = x[:, :, :, rotary_dim:]
101
+
102
+ x1, x2 = x_rot.chunk(2, dim=-1)
103
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
104
+ sin[:seqlen], "s d -> s 1 d"
105
+ )
106
+ x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
107
+
108
+ x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
109
+
110
+ return torch.cat([x_rot, x_pass], axis=-1)
111
+
112
+
113
+ def _apply_rotary_emb_kv(
114
+ kv: torch.FloatTensor,
115
+ cos: torch.FloatTensor,
116
+ sin: torch.FloatTensor,
117
+ cos_k: Optional[torch.FloatTensor] = None,
118
+ sin_k: Optional[torch.FloatTensor] = None,
119
+ ) -> torch.FloatTensor:
120
+ _, seqlen, _, _, _ = kv.shape
121
+ _, rotary_dim = cos.shape
122
+ rotary_dim *= 2
123
+
124
+ k_rot = kv[:, :, 0, :, :rotary_dim]
125
+ k_pass = kv[:, :, 0, :, rotary_dim:]
126
+
127
+ k1, k2 = k_rot.chunk(2, dim=-1)
128
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
129
+ sin[:seqlen], "s d -> s 1 d"
130
+ )
131
+ k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
132
+
133
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
134
+
135
+ return torch.cat(
136
+ [
137
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
138
+ kv[:, :, 1:2, :, :],
139
+ ],
140
+ axis=2,
141
+ )
142
+
143
+
144
+ def _apply_rotary_emb_qkv(
145
+ qkv: torch.FloatTensor,
146
+ cos: torch.FloatTensor,
147
+ sin: torch.FloatTensor,
148
+ cos_k: Optional[torch.FloatTensor] = None,
149
+ sin_k: Optional[torch.FloatTensor] = None,
150
+ ) -> torch.FloatTensor:
151
+ _, seqlen, _, _, _ = qkv.shape
152
+ _, rotary_dim = cos.shape
153
+ rotary_dim *= 2
154
+
155
+ q_rot = qkv[:, :, 0, :, :rotary_dim]
156
+ q_pass = qkv[:, :, 0, :, rotary_dim:]
157
+
158
+ k_rot = qkv[:, :, 1, :, :rotary_dim]
159
+ k_pass = qkv[:, :, 1, :, rotary_dim:]
160
+
161
+ q1, q2 = q_rot.chunk(2, dim=-1)
162
+ k1, k2 = k_rot.chunk(2, dim=-1)
163
+ c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
164
+ sin[:seqlen], "s d -> s 1 d"
165
+ )
166
+ q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
167
+
168
+ q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
169
+ k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
170
+
171
+ return torch.cat(
172
+ [
173
+ torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
174
+ torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
175
+ qkv[:, :, 2:3, :, :],
176
+ ],
177
+ axis=2,
178
+ )
179
+
180
+
181
+ class RotaryEmbedding(nn.Module):
182
+ """Rotary positional embedding (RoPE).
183
+
184
+ Reference:
185
+ RoFormer: Enhanced Transformer with Rotary Position Embedding.
186
+ https://arxiv.org/pdf/2104.09864.pdf.
187
+
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ dim: int,
193
+ base: int = 10000,
194
+ scale_base: Optional[float] = None,
195
+ pos_idx_in_fp32: bool = True,
196
+ max_position_embeddings: int = 2048,
197
+ device: Optional[str] = None,
198
+ **kwargs,
199
+ ) -> None:
200
+ super().__init__()
201
+
202
+ if scale_base is not None:
203
+ raise NotImplementedError
204
+
205
+ self.dim = dim
206
+ self.base = float(base)
207
+ self.scale_base = scale_base
208
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
209
+ self.max_position_embeddings = max_position_embeddings
210
+ self.device = device
211
+
212
+ # Generate and save the inverse frequency buffer (non-trainable)
213
+ inv_freq = self._compute_inv_freq(device)
214
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
215
+
216
+ # Generate and save the scale buffer (non-trainable)
217
+ scale = (
218
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
219
+ / (1.4 * dim)
220
+ if scale_base is not None
221
+ else None
222
+ )
223
+ self.register_buffer("scale", scale, persistent=False)
224
+
225
+ # Initialize cached attributes since ONNX can't rely on dynamic initialization
226
+ self._update_cos_sin_cache(
227
+ max_position_embeddings, device=device, dtype=torch.float32
228
+ )
229
+
230
+ def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
231
+ return 1.0 / (
232
+ self.base
233
+ ** (
234
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
235
+ / self.dim
236
+ )
237
+ )
238
+
239
+ def _update_cos_sin_cache(
240
+ self,
241
+ seqlen: int,
242
+ device: Optional[str] = None,
243
+ dtype: Optional[torch.dtype] = None,
244
+ ) -> None:
245
+ self._seq_len_cached = seqlen
246
+
247
+ # fp32 is preferred since the output of `torch.arange` can be quite large
248
+ # and bf16 would lose a lot of precision
249
+ if self.pos_idx_in_fp32:
250
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
251
+ if self.inv_freq.dtype != torch.float32:
252
+ inv_freq = self._compute_inv_freq(device=device)
253
+ else:
254
+ inv_freq = self.inv_freq
255
+ else:
256
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
257
+ inv_freq = self.inv_freq
258
+
259
+ # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
260
+ freqs = torch.outer(t, inv_freq)
261
+ if self.scale is None:
262
+ self._cos_cached = torch.cos(freqs).to(dtype)
263
+ self._sin_cached = torch.sin(freqs).to(dtype)
264
+ else:
265
+ power = (
266
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
267
+ - seqlen // 2
268
+ ) / self.scale_base
269
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
270
+
271
+ # Force the scale multiplication to happen in fp32
272
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
273
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
274
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
275
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
276
+
277
+ def forward(
278
+ self,
279
+ qkv: torch.Tensor,
280
+ kv: Optional[torch.Tensor] = None,
281
+ seqlen_offset: int = 0,
282
+ **kwargs,
283
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
284
+ seq_start = seqlen_offset
285
+ seq_end = seq_start + qkv.shape[1]
286
+
287
+ if (
288
+ self._cos_cached.device != qkv.device
289
+ or self._cos_cached.dtype != qkv.dtype
290
+ or (self.training and self._cos_cached.is_inference())
291
+ ):
292
+ self._update_cos_sin_cache(
293
+ self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype
294
+ )
295
+
296
+ if kv is None:
297
+ return _apply_rotary_emb_qkv(
298
+ qkv,
299
+ self._cos_cached[seq_start:seq_end],
300
+ self._sin_cached[seq_start:seq_end],
301
+ )
302
+ else:
303
+ q = _apply_rotary_emb(
304
+ qkv,
305
+ self._cos_cached[seq_start:seq_end],
306
+ self._sin_cached[seq_start:seq_end],
307
+ )
308
+ kv = _apply_rotary_emb_kv(
309
+ kv,
310
+ self._cos_cached[seq_start:seq_end],
311
+ self._sin_cached[seq_start:seq_end],
312
+ )
313
+
314
+ return q, kv
315
+
316
+
317
+ class MLP(nn.Module):
318
+ """Multi-Layer Perceptron.
319
+
320
+ Reference:
321
+ Attention Is All You Need.
322
+ https://arxiv.org/pdf/1706.03762.pdf.
323
+
324
+ """
325
+
326
+ def __init__(
327
+ self,
328
+ config: PretrainedConfig,
329
+ n_inner: Optional[int] = None,
330
+ act_fn: Optional[str] = None,
331
+ ) -> None:
332
+ super().__init__()
333
+
334
+ act_fn = config.activation_function if act_fn is None else act_fn
335
+
336
+ n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
337
+ n_inner = n_inner if n_inner is not None else 4 * config.n_embd
338
+
339
+ self.fc1 = nn.Linear(config.n_embd, n_inner)
340
+ self.fc2 = nn.Linear(n_inner, config.n_embd)
341
+ self.act = ACT2FN[act_fn]
342
+
343
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
344
+ hidden_states = self.fc1(hidden_states)
345
+ hidden_states = self.act(hidden_states)
346
+ hidden_states = self.fc2(hidden_states)
347
+
348
+ return hidden_states
349
+
350
+
351
+ class SelfAttention(nn.Module):
352
+ """Self-attention layer (compatible with PyTorch).
353
+
354
+ Reference:
355
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
356
+
357
+ """
358
+
359
+ def __init__(
360
+ self,
361
+ causal: bool = True,
362
+ softmax_scale: Optional[float] = None,
363
+ attention_dropout: float = 0.0,
364
+ ) -> None:
365
+ super().__init__()
366
+
367
+ self.causal = causal
368
+ self.softmax_scale = softmax_scale
369
+ self.drop = nn.Dropout(attention_dropout)
370
+
371
+ @torch.autocast("cpu", enabled=False)
372
+ @torch.autocast("cuda", enabled=False)
373
+ def forward(
374
+ self,
375
+ qkv: torch.FloatTensor,
376
+ causal: bool = None,
377
+ key_padding_mask: Optional[torch.BoolTensor] = None,
378
+ **kwargs,
379
+ ) -> torch.FloatTensor:
380
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
381
+ q, k, v = qkv.unbind(dim=2)
382
+
383
+ q = q.to(torch.float32)
384
+ k = k.to(torch.float32)
385
+
386
+ causal = self.causal if causal is None else causal
387
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
388
+
389
+ # Autocast is manually disabled to avoid `torch.einsum` performing the operation
390
+ # using float16, which might lead to overflow
391
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
392
+
393
+ if key_padding_mask is not None:
394
+ padding_mask = torch.full(
395
+ (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
396
+ )
397
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
398
+
399
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
400
+
401
+ if causal:
402
+ causal_mask = torch.triu(
403
+ torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
404
+ )
405
+ scores = scores + causal_mask.to(dtype=scores.dtype)
406
+
407
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
408
+ attention = self.drop(attention)
409
+
410
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
411
+
412
+ return output
413
+
414
+
415
+ class CrossAttention(nn.Module):
416
+ """Cross-attention layer (compatible with PyTorch).
417
+
418
+ Reference:
419
+ https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
420
+
421
+ """
422
+
423
+ def __init__(
424
+ self,
425
+ causal: bool = True,
426
+ softmax_scale: Optional[float] = None,
427
+ attention_dropout: float = 0.0,
428
+ ) -> None:
429
+ super().__init__()
430
+
431
+ self.causal = causal
432
+ self.softmax_scale = softmax_scale
433
+ self.drop = nn.Dropout(attention_dropout)
434
+
435
+ @torch.autocast("cpu", enabled=False)
436
+ @torch.autocast("cuda", enabled=False)
437
+ def forward(
438
+ self,
439
+ q: torch.FloatTensor,
440
+ kv: torch.FloatTensor,
441
+ causal: bool = None,
442
+ key_padding_mask: Optional[torch.BoolTensor] = None,
443
+ **kwargs,
444
+ ) -> torch.FloatTensor:
445
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
446
+ seqlen_k = kv.shape[1]
447
+
448
+ if kv.shape[3] != q.shape[2]:
449
+ kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
450
+ k, v = kv.unbind(dim=2)
451
+
452
+ q = q.to(torch.float32)
453
+ k = k.to(torch.float32)
454
+
455
+ causal = self.causal if causal is None else causal
456
+ softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
457
+
458
+ # Autocast is manually disabled to avoid `torch.einsum` performing the operation
459
+ # using float16, which might lead to overflow
460
+ scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
461
+
462
+ if key_padding_mask is not None:
463
+ padding_mask = torch.full(
464
+ (batch_size, seqlen_k),
465
+ -10000.0,
466
+ dtype=scores.dtype,
467
+ device=scores.device,
468
+ )
469
+ padding_mask.masked_fill_(key_padding_mask, 0.0)
470
+
471
+ scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
472
+
473
+ if causal:
474
+ rows = rearrange(
475
+ torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
476
+ )
477
+ cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
478
+ causal_mask = cols > rows + seqlen_k - seqlen_q
479
+
480
+ scores = scores.masked_fill(causal_mask, -10000.0)
481
+
482
+ attention = torch.softmax(scores, dim=-1).to(v.dtype)
483
+ attention = self.drop(attention)
484
+
485
+ output = torch.einsum("bhts,bshd->bthd", attention, v)
486
+
487
+ return output
488
+
489
+
490
+ def _find_mha_dims(
491
+ config: PretrainedConfig,
492
+ n_head: Optional[int] = None,
493
+ n_head_kv: Optional[int] = None,
494
+ head_dim: Optional[int] = None,
495
+ ) -> Tuple[int, int]:
496
+ if n_head is None and head_dim is None:
497
+ head_dim = config.n_embd // config.n_head
498
+ n_head = config.n_head
499
+ elif n_head is None or head_dim is None:
500
+ raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
501
+
502
+ if n_head_kv is None:
503
+ n_head_kv = getattr(config, "n_head_kv", None) or n_head
504
+
505
+ return n_head, n_head_kv, head_dim
506
+
507
+
508
+ def _update_kv_cache(
509
+ kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int
510
+ ) -> torch.FloatTensor:
511
+ num_heads, head_dim = kv.shape[-2:]
512
+
513
+ if layer_idx not in inference_params.key_value_memory_dict:
514
+ kv_cache = torch.empty(
515
+ inference_params.max_batch_size,
516
+ inference_params.max_seqlen,
517
+ 2,
518
+ num_heads,
519
+ head_dim,
520
+ dtype=kv.dtype,
521
+ device=kv.device,
522
+ )
523
+ inference_params.key_value_memory_dict[layer_idx] = kv_cache
524
+ else:
525
+ kv_cache = inference_params.key_value_memory_dict[layer_idx]
526
+
527
+ batch_start = inference_params.batch_size_offset
528
+ batch_end = batch_start + kv.shape[0]
529
+
530
+ sequence_start = inference_params.seqlen_offset
531
+ sequence_end = sequence_start + kv.shape[1]
532
+
533
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
534
+ kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
535
+
536
+ return kv
537
+
538
+
539
+ class MHA(nn.Module):
540
+ """Multi-head attention layer."""
541
+
542
+ def __init__(
543
+ self,
544
+ config: PretrainedConfig,
545
+ dtype: Optional[torch.dtype] = None,
546
+ device: Optional[str] = None,
547
+ rotary_dim: Optional[int] = None,
548
+ rotary_base: float = 10000.0,
549
+ rotary_scale_base: Optional[float] = None,
550
+ n_head: Optional[int] = None,
551
+ n_head_kv: Optional[int] = None,
552
+ head_dim: Optional[int] = None,
553
+ bias: bool = True,
554
+ causal: bool = True,
555
+ softmax_scale: Optional[float] = None,
556
+ layer_idx: Optional[int] = None,
557
+ return_residual: bool = False,
558
+ checkpointing: bool = False,
559
+ ) -> None:
560
+ super().__init__()
561
+
562
+ # Rotary embedding
563
+ self.rotary_dim = (
564
+ rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
565
+ )
566
+ if self.rotary_dim > 0:
567
+ rotary_cls = (
568
+ FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
569
+ )
570
+ if rotary_cls is None:
571
+ rotary_cls = RotaryEmbedding
572
+
573
+ rotary_kwargs = {}
574
+ if rotary_cls is RotaryEmbedding:
575
+ rotary_kwargs["max_position_embeddings"] = config.n_positions
576
+
577
+ self.rotary_emb = rotary_cls(
578
+ self.rotary_dim,
579
+ base=rotary_base,
580
+ scale_base=rotary_scale_base,
581
+ device=device,
582
+ **rotary_kwargs,
583
+ )
584
+
585
+ # MLP
586
+ self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
587
+ config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim
588
+ )
589
+ op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
590
+ hidden_size = config.n_embd
591
+
592
+ linear_cls = FusedDense if config.fused_dense else nn.Linear
593
+ if linear_cls is None:
594
+ linear_cls = nn.Linear
595
+
596
+ self.Wqkv = linear_cls(
597
+ hidden_size, op_size, bias=bias, device=device, dtype=dtype
598
+ )
599
+ self.out_proj = linear_cls(
600
+ hidden_size, hidden_size, bias=bias, device=device, dtype=dtype
601
+ )
602
+
603
+ # Attention
604
+ attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention
605
+ if attn_cls is None:
606
+ attn_cls = SelfAttention
607
+
608
+ cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention
609
+ if cross_attn_cls is None:
610
+ cross_attn_cls = CrossAttention
611
+
612
+ self.inner_attn = attn_cls(
613
+ causal=causal,
614
+ softmax_scale=softmax_scale,
615
+ attention_dropout=config.attn_pdrop,
616
+ )
617
+ self.inner_cross_attn = cross_attn_cls(
618
+ causal=causal,
619
+ softmax_scale=softmax_scale,
620
+ attention_dropout=config.attn_pdrop,
621
+ )
622
+
623
+ self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
624
+ self.layer_idx = layer_idx
625
+ self.return_residual = return_residual
626
+ self.checkpointing = checkpointing
627
+
628
+ def _forward_self_attn(
629
+ self,
630
+ x: torch.FloatTensor,
631
+ key_padding_mask: Optional[torch.BoolTensor],
632
+ cu_seqlens: Optional[torch.LongTensor] = None,
633
+ max_seqlen: Optional[int] = None,
634
+ ) -> torch.FloatTensor:
635
+ qkv = self.Wqkv(x)
636
+ qkv = rearrange(
637
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
638
+ )
639
+
640
+ if self.rotary_dim > 0:
641
+ qkv = self.rotary_emb(qkv)
642
+
643
+ if self.flash_attn:
644
+ batch_size, seqlen = qkv.shape[0], qkv.shape[1]
645
+
646
+ if (
647
+ key_padding_mask is not None
648
+ and cu_seqlens is None
649
+ and max_seqlen is None
650
+ ):
651
+ # If `key_padding_mask` is supplied, we need to unpad the input and retrieve
652
+ # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
653
+ qkv, indices, cu_seqlens, max_seqlen = unpad_input(
654
+ qkv, key_padding_mask
655
+ )
656
+
657
+ if self.checkpointing:
658
+ attn_output = torch.utils.checkpoint.checkpoint(
659
+ self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
660
+ )
661
+ else:
662
+ attn_output = self.inner_attn(
663
+ qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
664
+ ).to(qkv.device)
665
+
666
+ # If `key_padding_mask` is supplied, we need to pad the output back to the original shape
667
+ return (
668
+ pad_input(attn_output, indices, batch_size, seqlen)
669
+ if key_padding_mask is not None
670
+ else attn_output
671
+ )
672
+
673
+ if self.checkpointing:
674
+ return torch.utils.checkpoint.checkpoint(
675
+ self.inner_attn, qkv, key_padding_mask=key_padding_mask
676
+ )
677
+
678
+ return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
679
+
680
+ def _forward_cross_attn(
681
+ self,
682
+ x: torch.FloatTensor,
683
+ past_key_values: Optional[InferenceParams],
684
+ key_padding_mask: Optional[torch.BoolTensor],
685
+ ) -> torch.FloatTensor:
686
+ batch_size = x.shape[0]
687
+
688
+ qkv = self.Wqkv(x)
689
+
690
+ q = qkv[..., : self.n_head * self.head_dim]
691
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
692
+
693
+ kv = qkv[..., self.n_head * self.head_dim :]
694
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
695
+
696
+ seqlen_offset = (
697
+ past_key_values.seqlen_offset if past_key_values is not None else 0
698
+ )
699
+ causal = None if seqlen_offset == 0 else False
700
+ if self.rotary_dim > 0:
701
+ q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
702
+
703
+ if past_key_values is not None:
704
+ kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
705
+
706
+ if self.flash_attn:
707
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
708
+ seqlen_k = kv.shape[1]
709
+
710
+ cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = (
711
+ None,
712
+ None,
713
+ None,
714
+ None,
715
+ )
716
+ if key_padding_mask is not None:
717
+ kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
718
+
719
+ if seqlen_q == 1:
720
+ key_padding_mask = torch.ones(batch_size, 1, device=q.device)
721
+ elif seqlen_q != seqlen_k:
722
+ key_padding_mask = key_padding_mask[:, -seqlen_q:]
723
+
724
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
725
+ q, key_padding_mask
726
+ )
727
+
728
+ if self.checkpointing:
729
+ attn_output = torch.utils.checkpoint.checkpoint(
730
+ self.inner_cross_attn,
731
+ q,
732
+ kv,
733
+ causal=causal,
734
+ cu_seqlens=cu_seqlens_q,
735
+ max_seqlen=max_seqlen_q,
736
+ cu_seqlens_k=cu_seqlens_k,
737
+ max_seqlen_k=max_seqlen_k,
738
+ )
739
+ else:
740
+ attn_output = self.inner_cross_attn(
741
+ q,
742
+ kv,
743
+ causal=causal,
744
+ cu_seqlens=cu_seqlens_q,
745
+ max_seqlen=max_seqlen_q,
746
+ cu_seqlens_k=cu_seqlens_k,
747
+ max_seqlen_k=max_seqlen_k,
748
+ )
749
+
750
+ return (
751
+ pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
752
+ if key_padding_mask is not None
753
+ else attn_output
754
+ )
755
+
756
+ if self.checkpointing:
757
+ return torch.utils.checkpoint.checkpoint(
758
+ self.inner_cross_attn,
759
+ q,
760
+ kv,
761
+ key_padding_mask=key_padding_mask,
762
+ causal=causal,
763
+ )
764
+
765
+ return self.inner_cross_attn(
766
+ q, kv, key_padding_mask=key_padding_mask, causal=causal
767
+ )
768
+
769
+ def forward(
770
+ self,
771
+ x: torch.FloatTensor,
772
+ past_key_values: Optional[InferenceParams] = None,
773
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
774
+ cu_seqlens: Optional[torch.LongTensor] = None,
775
+ max_seqlen: Optional[int] = None,
776
+ **kwargs,
777
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
778
+ # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
779
+ if attention_mask is not None:
780
+ attention_mask = attention_mask.bool()
781
+ else:
782
+ attention_mask = None
783
+
784
+ # MHA
785
+ if self.n_head == self.n_head_kv:
786
+ if past_key_values is None:
787
+ # If `past_key_values` are not supplied, we run self-attention
788
+ attn_output = self._forward_self_attn(
789
+ x, attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
790
+ )
791
+ else:
792
+ # If `past_key_values` are supplied, it means that we might have cached values and
793
+ # could take advantage of cross-attention
794
+ attn_output = self._forward_cross_attn(
795
+ x,
796
+ past_key_values,
797
+ attention_mask,
798
+ cu_seqlens=cu_seqlens,
799
+ max_seqlen=max_seqlen,
800
+ )
801
+ # MQA / GQA
802
+ else:
803
+ # Regardless of `past_key_values` being supplied or not, it always use cross-attention
804
+ # because `q` and `kv` lengths might be different
805
+ attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
806
+
807
+ output = rearrange(attn_output, "... h d -> ... (h d)")
808
+ output = self.out_proj(output)
809
+
810
+ return output if not self.return_residual else (output, x)
811
+
812
+
813
+ class ParallelBlock(nn.Module):
814
+ """Parallel block.
815
+
816
+ This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
817
+
818
+ """
819
+
820
+ def __init__(
821
+ self,
822
+ config: PretrainedConfig,
823
+ block_idx: Optional[int] = None,
824
+ ) -> None:
825
+ super().__init__()
826
+
827
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
828
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
829
+ self.block_idx = block_idx
830
+
831
+ self.mixer = MHA(config, layer_idx=block_idx)
832
+ self.mlp = MLP(config)
833
+
834
+ def forward(
835
+ self,
836
+ hidden_states: torch.FloatTensor,
837
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
838
+ attention_mask: Optional[torch.BoolTensor] = None,
839
+ **kwargs,
840
+ ) -> torch.FloatTensor:
841
+ residual = hidden_states
842
+ hidden_states = self.ln(hidden_states)
843
+
844
+ attn_outputs = self.mixer(
845
+ hidden_states,
846
+ past_key_values=past_key_values,
847
+ attention_mask=attention_mask,
848
+ )
849
+ if isinstance(attn_outputs, tuple):
850
+ attn_outputs = attn_outputs[0]
851
+
852
+ attn_outputs = self.resid_dropout(attn_outputs)
853
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
854
+
855
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
856
+
857
+ return hidden_states
858
+
859
+
860
+ class CausalLMHead(nn.Module):
861
+ """Causal Language Modeling head.
862
+
863
+ Reference:
864
+ Improving Language Understanding by Generative Pre-Training.
865
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
866
+
867
+ """
868
+
869
+ def __init__(self, config: PretrainedConfig) -> None:
870
+ super().__init__()
871
+
872
+ self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
873
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
874
+
875
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
876
+ hidden_states = self.ln(hidden_states)
877
+ logits = self.linear(hidden_states).to(torch.float32)
878
+
879
+ return logits
880
+
881
+
882
+ class CausalLMLoss(nn.Module):
883
+ """Causal Language Modeling loss.
884
+
885
+ Reference:
886
+ Improving Language Understanding by Generative Pre-Training.
887
+ https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
888
+
889
+ """
890
+
891
+ def __init__(self, shift_labels: bool = True) -> None:
892
+ super().__init__()
893
+
894
+ self.shift_labels = shift_labels
895
+ self.loss_fct = nn.CrossEntropyLoss()
896
+
897
+ def forward(
898
+ self, logits: torch.FloatTensor, labels: torch.LongTensor
899
+ ) -> torch.FloatTensor:
900
+ if self.shift_labels:
901
+ logits = logits[..., :-1, :].contiguous()
902
+ labels = labels[..., 1:].contiguous()
903
+
904
+ loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
905
+
906
+ return loss
907
+
908
+
909
+ class PhiPreTrainedModel(PreTrainedModel):
910
+ """Phi pre-trained model."""
911
+
912
+ config_class = PhiConfig
913
+ base_model_prefix = "transformer"
914
+ supports_gradient_checkpointing = False
915
+ _no_split_modules = ["ParallelBlock"]
916
+
917
+ def __init__(self, *inputs, **kwargs) -> None:
918
+ super().__init__(*inputs, **kwargs)
919
+
920
+ def _init_weights(self, module: nn.Module) -> None:
921
+ if isinstance(module, (nn.Linear,)):
922
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
923
+ if module.bias is not None:
924
+ module.bias.data.zero_()
925
+ elif isinstance(module, nn.Embedding):
926
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
927
+ if module.padding_idx is not None:
928
+ module.weight.data[module.padding_idx].zero_()
929
+ elif isinstance(module, nn.LayerNorm):
930
+ if module.bias is not None:
931
+ module.bias.data.zero_()
932
+ module.weight.data.fill_(1.0)
933
+
934
+ def prepare_inputs_for_generation(
935
+ self,
936
+ input_ids: torch.LongTensor,
937
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
938
+ attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
939
+ **kwargs,
940
+ ) -> Dict[str, Any]:
941
+ if past_key_values is None or not (
942
+ isinstance(past_key_values, InferenceParams)
943
+ ):
944
+ past_key_values = InferenceParams(
945
+ max_seqlen=self.config.n_positions,
946
+ max_batch_size=input_ids.shape[0],
947
+ seqlen_offset=0,
948
+ batch_size_offset=0,
949
+ key_value_memory_dict={},
950
+ lengths_per_sample=None,
951
+ )
952
+ else:
953
+ # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
954
+ past_key_values.seqlen_offset = len(input_ids[0]) - 1
955
+ input_ids = input_ids[:, -1].unsqueeze(-1)
956
+
957
+ return {
958
+ "input_ids": input_ids,
959
+ "past_key_values": past_key_values,
960
+ "attention_mask": attention_mask,
961
+ }
962
+
963
+
964
+ class PhiModel(PhiPreTrainedModel):
965
+ """Phi model."""
966
+
967
+ _keys_to_ignore_on_load_missing = [""]
968
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
969
+
970
+ def __init__(self, config: PhiConfig) -> None:
971
+ super().__init__(config)
972
+
973
+ self.embd = Embedding(config)
974
+ self.h = nn.ModuleList(
975
+ [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
976
+ )
977
+ self.gradient_checkpointing = False
978
+ self.post_init()
979
+
980
+ def get_input_embeddings(self) -> nn.Embedding:
981
+ return self.embd.wte
982
+
983
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
984
+ self.embd.wte = new_embeddings
985
+
986
+ def forward(
987
+ self,
988
+ input_ids: torch.LongTensor,
989
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
990
+ attention_mask: Optional[torch.BoolTensor] = None,
991
+ cu_seqlens: Optional[torch.LongTensor] = None,
992
+ max_seqlen: Optional[int] = None,
993
+ ) -> torch.FloatTensor:
994
+ hidden_states = self.embd(input_ids)
995
+
996
+ for layer in self.h:
997
+ hidden_states = layer(
998
+ hidden_states,
999
+ past_key_values=past_key_values,
1000
+ attention_mask=attention_mask,
1001
+ cu_seqlens=cu_seqlens,
1002
+ max_seqlen=max_seqlen,
1003
+ )
1004
+
1005
+ return hidden_states
1006
+
1007
+
1008
+ class PhiForCausalLM(PhiPreTrainedModel):
1009
+ """Phi for Causal Language Modeling."""
1010
+
1011
+ _keys_to_ignore_on_load_missing = [""]
1012
+ _keys_to_ignore_on_load_unexpected = [
1013
+ r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
1014
+ ]
1015
+
1016
+ def __init__(self, config: PhiConfig) -> None:
1017
+ super().__init__(config)
1018
+
1019
+ self.transformer = PhiModel(config)
1020
+ self.lm_head = CausalLMHead(config)
1021
+ self.loss = CausalLMLoss()
1022
+
1023
+ self.post_init()
1024
+
1025
+ def get_output_embeddings(self) -> nn.Linear:
1026
+ return self.lm_head.linear
1027
+
1028
+ def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
1029
+ self.lm_head.linear = new_embeddings
1030
+
1031
+ def forward(
1032
+ self,
1033
+ input_ids: torch.LongTensor,
1034
+ past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
1035
+ attention_mask: Optional[torch.BoolTensor] = None,
1036
+ labels: Optional[torch.LongTensor] = None,
1037
+ position_ids: Optional[torch.LongTensor] = None,
1038
+ **kwargs,
1039
+ ) -> CausalLMOutputWithPast:
1040
+ cu_seqlens: Optional[torch.LongTensor] = None
1041
+ max_seqlen: Optional[int] = None
1042
+ if position_ids is not None:
1043
+ batch_size, seq_length = input_ids.shape
1044
+ position_ids = position_ids.view(-1, seq_length).long()
1045
+ cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
1046
+ cu_seqlens = cu_seqlens.squeeze()
1047
+
1048
+ hidden_states = self.transformer(
1049
+ input_ids,
1050
+ past_key_values=past_key_values,
1051
+ attention_mask=attention_mask,
1052
+ cu_seqlens=cu_seqlens,
1053
+ max_seqlen=max_seqlen,
1054
+ )
1055
+ lm_logits = self.lm_head(hidden_states)
1056
+
1057
+ loss = None
1058
+ if labels is not None:
1059
+ loss = self.loss(lm_logits, labels)
1060
+
1061
+ return CausalLMOutputWithPast(
1062
+ loss=loss, logits=lm_logits, past_key_values=past_key_values
1063
+ )
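
A minimal smoke-test sketch for the classes added above, assuming axolotl and its dependencies are importable and flash-attn is not installed (the tiny dimensions are illustrative, not phi-1_5's real sizes):

import torch

from axolotl.models.phi import PhiConfig, PhiForCausalLM

# Tiny config so the forward pass runs on CPU; real checkpoints would instead be
# loaded with PhiForCausalLM.from_pretrained(...) as in load_model below.
config = PhiConfig(
    vocab_size=128, n_positions=64, n_embd=64, n_layer=2, n_head=4, rotary_dim=8
)
model = PhiForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
output = model(input_ids=input_ids, labels=input_ids)

# CausalLMOutputWithPast: scalar shifted-label CE loss, logits of shape (2, 16, 128)
print(output.loss, output.logits.shape)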
src/axolotl/utils/models.py CHANGED
@@ -288,10 +288,10 @@ def load_model(
  # device=cfg.device,
  # )
  # model.train() # sets to train instead of eval mode
- elif model_type == "MixFormerSequentialForCausalLM":
-     from axolotl.models.phi import MixFormerSequentialForCausalLM
+ elif model_type == "PhiForCausalLM":
+     from axolotl.models.phi import PhiForCausalLM
 
-     model = MixFormerSequentialForCausalLM.from_pretrained(
+     model = PhiForCausalLM.from_pretrained(
          base_model,
          load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
          load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
tests/e2e/test_phi.py CHANGED
@@ -31,7 +31,7 @@ class TestPhi(unittest.TestCase):
      {
          "base_model": "microsoft/phi-1_5",
          "trust_remote_code": True,
-         "model_type": "MixFormerSequentialForCausalLM",
+         "model_type": "PhiForCausalLM",
          "tokenizer_type": "AutoTokenizer",
          "sequence_len": 512,
          "sample_packing": False,
@@ -76,7 +76,7 @@ class TestPhi(unittest.TestCase):
      {
          "base_model": "microsoft/phi-1_5",
          "trust_remote_code": True,
-         "model_type": "MixFormerSequentialForCausalLM",
+         "model_type": "PhiForCausalLM",
          "tokenizer_type": "AutoTokenizer",
          "sequence_len": 512,
          "sample_packing": True,