Andrei Panferov committed
Commit 5edaefc
Parent(s): 7e4a8ff

Files changed:
- configuration_llama.py +1 -1
- inference.py +282 -8
- modeling_llama.py +348 -109
configuration_llama.py
CHANGED
@@ -3,7 +3,7 @@ from transformers import LlamaConfig as OrigLlamaConfig
 
 class LlamaConfig(OrigLlamaConfig):
     model_type = "llama_aqlm"
-
+
     def __init__(
         self,
         nbits_per_codebook: int = 16,
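The only new constructor argument visible in this hunk is nbits_per_codebook; the rest of the signature falls outside the diff context. As a quick illustration of what that number controls in inference.py below (assuming nothing beyond what that file defines): a 16-bit codebook has 2**16 = 65536 entries, and each code is stored in the narrowest signed integer dtype that has at least that many bits (the codes are unsigned values reinterpreted as signed, which the Triton kernel later undoes).

import torch

def get_int_dtype(nbits: int) -> torch.dtype:
    # same rule as inference.py: narrowest signed dtype with >= nbits bits
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")

print(2**16, get_int_dtype(16))  # 65536 codebook entries, codes stored as torch.int16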
inference.py
CHANGED
@@ -1,13 +1,13 @@
 """ Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
-import
-
+import functools
+import os
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-
-from src.utils import _dequantize_weight, ellipsis, get_int_dtype, unpack_int_data
+import triton
+import triton.language as tl
 
 
 class FinalizedQuantizedLinear(nn.Module):
@@ -39,12 +39,17 @@ class FinalizedQuantizedLinear(nn.Module):
 
         # CODES & CODEBOOKS
         self.codebooks = nn.Parameter(
-            torch.empty(
+            torch.empty(
+                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
+                **factory_kwargs,
+            ),
             requires_grad=True,
         )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
         self.codes = nn.Parameter(
             torch.empty(
-                (num_out_groups, num_in_groups, num_codebooks),
+                (num_out_groups, num_in_groups, num_codebooks),
+                device=device,
+                dtype=get_int_dtype(nbits_per_codebook),
             ),
             requires_grad=False,
         )  # [num_out_groups, num_in_groups, num_codebooks]
@@ -61,4 +66,273 @@ class FinalizedQuantizedLinear(nn.Module):
         self.register_parameter("bias", None)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return forward_pass_quantized_linear(
+        return forward_pass_quantized_linear(
+            input, self.codes, self.codebooks, self.scales, self.bias
+        )
+
+
+def get_int_dtype(nbits: int) -> torch.dtype:
+    if nbits <= 8:
+        return torch.int8
+    if nbits <= 16:
+        return torch.int16
+    if nbits <= 32:
+        return torch.int32
+    if nbits <= 64:
+        return torch.int64
+    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
+
+
+@torch.inference_mode()
+def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
+    return data.to(torch.int64) % (2**nbits)
+
+
+@functools.lru_cache()
+def maybe_script(fn: callable) -> callable:
+    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
+    using_tpu = bool(os.environ.get("TPU_NAME"))
+    # this is a reserved variable that must be set to TPU address (e.g. grpc://11.22.33.44:1337) for TPU to function
+    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
+    return torch.jit.script(fn) if should_script else fn
+
+
+@maybe_script
+def _dequantize_weight(
+    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Decode float weights from quantization codes. Differentiable.
+    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
+    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
+    :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
+    :return: reconstructed weight tensor of shape [*dims, num_in_groups*group_size]
+    """
+    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    out_features = num_out_groups * out_group_size
+    in_features = num_in_groups * in_group_size
+    codebook_offsets = torch.arange(
+        0, num_codebooks * codebook_size, codebook_size, device=codes.device
+    )  # shape: [num_codebooks]
+    reconstructed_weight_flat = F.embedding_bag(
+        codes.flatten(0, -2) + codebook_offsets,
+        codebooks.flatten(0, 1).flatten(-2, -1),
+        mode="sum",
+    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
+
+    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
+        list(codes.shape[:-3])
+        + [num_out_groups, num_in_groups, out_group_size, in_group_size]
+    )
+    if scales is not None:
+        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
+    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
+        list(codes.shape[:-3]) + [out_features, in_features]
+    )
+
+
+def forward_pass_quantized_linear(
+    input: torch.Tensor,
+    codes: torch.IntTensor,
+    codebooks: torch.Tensor,
+    scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+    if input.is_cuda:
+        matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
+        if bias is not None:
+            matmul_result += bias
+        return matmul_result
+    else:
+        dequantized_weight = _dequantize_weight(
+            unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
+            codebooks,
+            scales,
+        )
+        return F.linear(input, dequantized_weight, bias)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
+        for num_stages in (1, 2, 3, 4, 5)
+        for num_warps in (1, 2, 4, 8)
+    ],
+    key=[
+        "in_features",
+        "out_features",
+        "num_codebooks",
+        "codebook_size",
+        "out_group_size",
+        "in_group_size",
+        "num_input_groups",
+        "num_input_groups_next_power_of_2",
+        "compute_in_fp32",
+    ],
+)
+@triton.jit
+def _aqlm_gemv_simple(
+    input_vec_ptr,
+    output_vec_ptr,
+    codes_i16_ptr,
+    codebooks_ptr,
+    scales_ptr,
+    in_features: tl.constexpr,
+    out_features: tl.constexpr,
+    num_codebooks: tl.constexpr,
+    codebook_size: tl.constexpr,
+    out_group_size: tl.constexpr,
+    in_group_size: tl.constexpr,
+    num_input_groups: tl.constexpr,
+    num_input_groups_next_power_of_2: tl.constexpr,
+    compute_in_fp32: tl.constexpr,
+    UNUSED: tl.constexpr,
+):
+    # variables ending with "_i" mean "for i-th output unit"
+    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features-1}]
+
+    # Stage 1: load input data
+    input_vec = tl.load(
+        input_vec_ptr
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
+        + tl.arange(0, in_group_size)[None, None, :],
+        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None]
+        < num_input_groups,
+    )
+    # [in_features//in_group_size, 1, group_size]
+    # Note: we could simply load input_vec then reshape
+    # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
+    # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
+    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
+
+    # Stage 2: load integer codes for the active row
+    # [in_features // in_group_size, num_codebooks]
+    codes_i_ptrs = (
+        codes_i16_ptr
+        + pid * num_input_groups * num_codebooks
+        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
+        + tl.arange(0, num_codebooks)[None, :]
+    )
+    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
+
+    codes_i = tl.load(
+        codes_i_ptrs, mask=codes_i_mask_1d[:, None]
+    )  # [in_features//in_group_size, num_codebooks]
+    if codes_i.dtype == tl.int16:
+        codes_i = codes_i.to(tl.int32)
+        codes_i = (codes_i) + (
+            codes_i < 0
+        ) * codebook_size  # aka 2 ** nbits_per_codebook
+        # ^-- (because codes are int16 tensors that contain uint data)
+
+        # The following alternative does not work:
+        # codes_i = codes_i.to(tl.int32) % codebook_size  # aka 2 ** nbits_per_codebook
+    else:
+        codes_i = codes_i.to(tl.int32)
+
+    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
+    codes_i += (
+        tl.arange(0, num_codebooks)[None, :] * codebook_size
+    )  # aka 2 ** nbits_per_codebook
+    # ^-- [in_group_size, num_codebooks]
+
+    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
+    # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
+    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
+    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
+    weight_i_ptrs = (
+        codebooks_ptr
+        + codes_i[:, :, None, None] * out_group_size * in_group_size
+        + out_group_ix * in_group_size
+        + in_group_ix
+    )
+
+    # Stage 4: reconstruct weights, multiply by inputs and write out
+    weights_i = tl.load(
+        weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
+    )
+    if compute_in_fp32:
+        weights_i = weights_i.to(tl.float32)
+        input_vec = input_vec.to(tl.float32)
+    # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
+    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
+    # ^-- [in_features // in_group_size, out_group_size, in_group_size]
+
+    if out_group_size == 1:
+        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
+        output_i = tl.sum(weights_i * input_vec) * scale
+        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
+    else:
+        output_i = tl.sum(
+            tl.sum(weights_i * input_vec, axis=2), axis=0
+        )  # [out_group_size]
+        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
+        tl.store(
+            output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
+            output_i.to(input_vec.dtype),
+        )
+
+
+def next_power_of_2(x):
+    return 1 if x == 0 else 2 ** (x - 1).bit_length()
+
+
+def aqlm_gemv_simple(
+    input_vec: torch.Tensor,
+    codes_i16: torch.ShortTensor,
+    codebooks: torch.Tensor,
+    scales: torch.Tensor,
+    compute_in_fp32: bool = True,
+):
+
+    device, dtype = codebooks.device, codebooks.dtype
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    in_features = input_vec.shape[1]
+    out_features = codes_i16.shape[0] * out_group_size
+    num_input_groups = codes_i16.shape[1]
+    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
+    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
+    assert in_features % in_group_size == 0
+    assert codebooks.shape[1] == 2**16
+
+    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
+    # 1D launch kernel where each block computes output unit
+    grid = lambda META: (out_features // out_group_size,)
+    _aqlm_gemv_simple[grid](
+        input_vec,
+        output_vec,
+        codes_i16,
+        codebooks,
+        scales,
+        in_features,
+        out_features,
+        num_codebooks,
+        codebook_size,
+        out_group_size,
+        in_group_size,
+        num_input_groups,
+        next_power_of_2(num_input_groups),
+        compute_in_fp32,
+    )
+
+    return output_vec
+
+
+def aqlm_gemm_stupid(
+    input: torch.Tensor,
+    codes_i16: torch.ShortTensor,
+    codebooks: torch.Tensor,
+    scales: torch.Tensor,
+    compute_in_fp32: bool = True,
+):
+    original_shape = input.shape
+    input = input.reshape(-1, original_shape[-1])
+    return torch.cat(
+        [
+            aqlm_gemv_simple(
+                input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32
+            )
+            for input_vec in input
+        ]
+    ).reshape(original_shape[:-1] + (-1,))
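To make the additive-quantization arithmetic above easier to follow, here is a small self-contained PyTorch sketch of the same reconstruction that _dequantize_weight (and, group by group, the Triton kernel) performs: every weight group is the sum, over codebooks, of the codebook vector selected by that group's code, optionally multiplied by a per-output-group scale. All shapes and values below are made up for illustration; the real layers use codebook_size = 2**16 and int16 code storage.

import torch

num_codebooks, codebook_size = 2, 256
out_group_size, in_group_size = 1, 8
num_out_groups, num_in_groups = 4, 16

codebooks = torch.randn(num_codebooks, codebook_size, out_group_size, in_group_size)
codes = torch.randint(codebook_size, (num_out_groups, num_in_groups, num_codebooks))
scales = torch.rand(num_out_groups, 1, 1, 1)

# pick the codebook vector addressed by each code, then sum over codebooks
selected = codebooks[torch.arange(num_codebooks), codes]  # [out_g, in_g, n_books, ogs, igs]
groups = selected.sum(dim=2) * scales                     # [out_g, in_g, ogs, igs]

# interleave the groups back into a dense [out_features, in_features] weight matrix
weight = groups.permute(0, 2, 1, 3).reshape(
    num_out_groups * out_group_size, num_in_groups * in_group_size
)
print(weight.shape)  # torch.Size([4, 128]); F.linear(x, weight) mirrors the CUDA path (up to numerics)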
modeling_llama.py
CHANGED
@@ -27,27 +27,23 @@ import torch.utils.checkpoint
-from transformers.modeling_outputs import (
-
-
-    SequenceClassifierOutputWithPast,
-)
-from transformers.utils import (
-
-
-
-    logging,
-    replace_return_docstrings,
-)
-from
-from flash_attn.bert_padding import index_first_axis, pad_input,
@@ -59,7 +55,9 @@ def _get_unpad_data(padding_mask):
-    cu_seqlens = F.pad(
@@ -69,7 +67,10 @@ def _get_unpad_data(padding_mask):
-    input_ids_shape: torch.Size,
@@ -81,8 +82,18 @@ def _make_causal_mask(
-        mask = torch.cat(
-
@@ -97,7 +108,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
-    return inverted_mask.masked_fill(
@@ -127,23 +140,33 @@ class LlamaRotaryEmbedding(nn.Module):
-        inv_freq = 1.0 / (
-            seq_len=max_position_embeddings,
-        t = torch.arange(
-        self.register_buffer(
-
@@ -159,26 +182,46 @@ class LlamaRotaryEmbedding(nn.Module):
-    def __init__(
-        t = torch.arange(
-        self.register_buffer(
-
-    def __init__(
@@ -187,18 +230,27 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-                (self.scaling_factor * seq_len / self.max_position_embeddings)
-            inv_freq = 1.0 / (
-        t = torch.arange(
-        self.register_buffer(
-
@@ -225,9 +277,15 @@ class LlamaMLP(nn.Module):
-        self.gate_proj = FinalizedQuantizedLinear(
-
-
@@ -237,12 +295,25 @@ class LlamaMLP(nn.Module):
-            gate_proj = torch.cat(
-
-                F.linear(intermediate_states[i], down_proj_slices[i])
@@ -259,7 +330,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    hidden_states = hidden_states[:, :, None, :, :].expand(
@@ -283,16 +356,28 @@ class LlamaAttention(nn.Module):
-            self.hidden_size,
-            self.hidden_size,
-            self.hidden_size,
-            self.num_heads * self.head_dim,
@@ -324,7 +409,11 @@ class LlamaAttention(nn.Module):
-        return
@@ -339,20 +428,31 @@ class LlamaAttention(nn.Module):
-            key_value_slicing = (
-            query_states = [
-            key_states = [
-            value_states = [
@@ -360,15 +460,23 @@ class LlamaAttention(nn.Module):
-        query_states = query_states.view(
-
-
-        query_states, key_states = apply_rotary_pos_emb(
@@ -380,7 +488,9 @@ class LlamaAttention(nn.Module):
-        attn_weights = torch.matmul(
@@ -396,7 +506,9 @@ class LlamaAttention(nn.Module):
-        attn_weights = nn.functional.softmax(
@@ -410,9 +522,18 @@ class LlamaAttention(nn.Module):
-            attn_output = attn_output.split(
-
-
@@ -451,9 +572,15 @@ class LlamaFlashAttention2(LlamaAttention):
-        query_states = query_states.view(
-
-
@@ -461,7 +588,9 @@ class LlamaFlashAttention2(LlamaAttention):
-        query_states, key_states = apply_rotary_pos_emb(
@@ -497,7 +626,12 @@ class LlamaFlashAttention2(LlamaAttention):
-            query_states,
@@ -509,7 +643,14 @@ class LlamaFlashAttention2(LlamaAttention):
-        self,
@@ -533,7 +674,14 @@ class LlamaFlashAttention2(LlamaAttention):
-
@@ -553,27 +701,39 @@ class LlamaFlashAttention2(LlamaAttention):
-            attn_output = pad_input(
-                query_states,
-    def _upad_input(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
@@ -588,7 +748,9 @@ class LlamaFlashAttention2(LlamaAttention):
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
@@ -611,7 +773,9 @@ class LlamaDecoderLayer(nn.Module):
-        self.post_attention_layernorm = LlamaRMSNorm(
@@ -622,7 +786,9 @@ class LlamaDecoderLayer(nn.Module):
-    ) -> Tuple[
@@ -796,8 +962,12 @@ class LlamaModel(LlamaPreTrainedModel):
-        self.embed_tokens = nn.Embedding(
-
@@ -811,7 +981,9 @@ class LlamaModel(LlamaPreTrainedModel):
-    def _prepare_decoder_attention_mask(
@@ -825,11 +997,13 @@ class LlamaModel(LlamaPreTrainedModel):
-            expanded_attn_mask = _expand_mask(
-                inputs_embeds.
-            )
-                expanded_attn_mask
@@ -847,17 +1021,27 @@ class LlamaModel(LlamaPreTrainedModel):
-        output_attentions =
-            output_hidden_states
-        return_dict =
-            raise ValueError(
@@ -875,7 +1059,10 @@ class LlamaModel(LlamaPreTrainedModel):
-                past_key_values_length,
@@ -886,7 +1073,9 @@ class LlamaModel(LlamaPreTrainedModel):
-                (batch_size, seq_length_with_past),
@@ -896,7 +1085,10 @@ class LlamaModel(LlamaPreTrainedModel):
-            attention_mask,
@@ -917,19 +1109,29 @@ class LlamaModel(LlamaPreTrainedModel):
-            past_key_value =
-                        return module(
-                create_custom_forward(decoder_layer),
@@ -958,7 +1160,11 @@ class LlamaModel(LlamaPreTrainedModel):
-            return tuple(
@@ -998,7 +1204,9 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
-    @replace_return_docstrings(
@@ -1038,11 +1246,19 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
-        output_attentions =
-            output_hidden_states
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1059,8 +1275,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
-            lm_head_slices = self.lm_head.weight.split(
-
@@ -1092,7 +1313,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
-        self,
@@ -1126,7 +1352,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
-                tuple(
@@ -1182,7 +1411,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
-        return_dict =
@@ -1204,18 +1435,22 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
-            raise ValueError(
-                sequence_lengths = (
-
-                )
-        pooled_logits = logits[
@@ -1223,7 +1458,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
-            elif self.num_labels > 1 and (
@@ -1236,7 +1473,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
-                loss = loss_fct(
|
|
27 |
from torch import nn
|
28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
29 |
from transformers.activations import ACT2FN
|
30 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
31 |
+
CausalLMOutputWithPast,
|
32 |
+
SequenceClassifierOutputWithPast)
|
|
|
|
|
33 |
from transformers.modeling_utils import PreTrainedModel
|
34 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
35 |
+
from transformers.utils import (add_start_docstrings,
|
36 |
+
add_start_docstrings_to_model_forward,
|
37 |
+
is_flash_attn_available, logging,
|
38 |
+
replace_return_docstrings)
|
|
|
|
|
|
|
39 |
|
40 |
from .configuration_llama import LlamaConfig
|
41 |
+
from .inference import FinalizedQuantizedLinear
|
42 |
|
43 |
if is_flash_attn_available():
|
44 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
45 |
+
from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
|
46 |
+
unpad_input)
|
47 |
|
48 |
|
49 |
logger = logging.get_logger(__name__)
|
|
|
55 |
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
|
56 |
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
|
57 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
58 |
+
cu_seqlens = F.pad(
|
59 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
60 |
+
)
|
61 |
return (
|
62 |
indices,
|
63 |
cu_seqlens,
|
|
|
67 |
|
68 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
69 |
def _make_causal_mask(
|
70 |
+
input_ids_shape: torch.Size,
|
71 |
+
dtype: torch.dtype,
|
72 |
+
device: torch.device,
|
73 |
+
past_key_values_length: int = 0,
|
74 |
):
|
75 |
"""
|
76 |
Make causal mask used for bi-directional self-attention.
|
|
|
82 |
mask = mask.to(dtype)
|
83 |
|
84 |
if past_key_values_length > 0:
|
85 |
+
mask = torch.cat(
|
86 |
+
[
|
87 |
+
torch.zeros(
|
88 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
89 |
+
),
|
90 |
+
mask,
|
91 |
+
],
|
92 |
+
dim=-1,
|
93 |
+
)
|
94 |
+
return mask[None, None, :, :].expand(
|
95 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
96 |
+
)
|
97 |
|
98 |
|
99 |
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
|
|
108 |
|
109 |
inverted_mask = 1.0 - expanded_mask
|
110 |
|
111 |
+
return inverted_mask.masked_fill(
|
112 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
113 |
+
)
|
114 |
|
115 |
|
116 |
class LlamaRMSNorm(nn.Module):
|
|
|
140 |
self.dim = dim
|
141 |
self.max_position_embeddings = max_position_embeddings
|
142 |
self.base = base
|
143 |
+
inv_freq = 1.0 / (
|
144 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
145 |
+
)
|
146 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
147 |
|
148 |
# Build here to make `torch.jit.trace` work.
|
149 |
self._set_cos_sin_cache(
|
150 |
+
seq_len=max_position_embeddings,
|
151 |
+
device=self.inv_freq.device,
|
152 |
+
dtype=torch.get_default_dtype(),
|
153 |
)
|
154 |
|
155 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
156 |
self.max_seq_len_cached = seq_len
|
157 |
+
t = torch.arange(
|
158 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
159 |
+
)
|
160 |
|
161 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
162 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
163 |
emb = torch.cat((freqs, freqs), dim=-1)
|
164 |
+
self.register_buffer(
|
165 |
+
"cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
|
166 |
+
)
|
167 |
+
self.register_buffer(
|
168 |
+
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
|
169 |
+
)
|
170 |
|
171 |
def forward(self, x, seq_len=None):
|
172 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
|
|
182 |
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
183 |
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
184 |
|
185 |
+
def __init__(
|
186 |
+
self,
|
187 |
+
dim,
|
188 |
+
max_position_embeddings=2048,
|
189 |
+
base=10000,
|
190 |
+
device=None,
|
191 |
+
scaling_factor=1.0,
|
192 |
+
):
|
193 |
self.scaling_factor = scaling_factor
|
194 |
super().__init__(dim, max_position_embeddings, base, device)
|
195 |
|
196 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
197 |
self.max_seq_len_cached = seq_len
|
198 |
+
t = torch.arange(
|
199 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
200 |
+
)
|
201 |
t = t / self.scaling_factor
|
202 |
|
203 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
204 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
205 |
emb = torch.cat((freqs, freqs), dim=-1)
|
206 |
+
self.register_buffer(
|
207 |
+
"cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
|
208 |
+
)
|
209 |
+
self.register_buffer(
|
210 |
+
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
|
211 |
+
)
|
212 |
|
213 |
|
214 |
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
215 |
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
216 |
|
217 |
+
def __init__(
|
218 |
+
self,
|
219 |
+
dim,
|
220 |
+
max_position_embeddings=2048,
|
221 |
+
base=10000,
|
222 |
+
device=None,
|
223 |
+
scaling_factor=1.0,
|
224 |
+
):
|
225 |
self.scaling_factor = scaling_factor
|
226 |
super().__init__(dim, max_position_embeddings, base, device)
|
227 |
|
|
|
230 |
|
231 |
if seq_len > self.max_position_embeddings:
|
232 |
base = self.base * (
|
233 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings)
|
234 |
+
- (self.scaling_factor - 1)
|
235 |
) ** (self.dim / (self.dim - 2))
|
236 |
+
inv_freq = 1.0 / (
|
237 |
+
base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
238 |
+
)
|
239 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
240 |
|
241 |
+
t = torch.arange(
|
242 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
243 |
+
)
|
244 |
|
245 |
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
246 |
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
247 |
emb = torch.cat((freqs, freqs), dim=-1)
|
248 |
+
self.register_buffer(
|
249 |
+
"cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
|
250 |
+
)
|
251 |
+
self.register_buffer(
|
252 |
+
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
|
253 |
+
)
|
254 |
|
255 |
|
256 |
def rotate_half(x):
|
|
|
277 |
self.config = config
|
278 |
self.hidden_size = config.hidden_size
|
279 |
self.intermediate_size = config.intermediate_size
|
280 |
+
self.gate_proj = FinalizedQuantizedLinear(
|
281 |
+
self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
|
282 |
+
)
|
283 |
+
self.up_proj = FinalizedQuantizedLinear(
|
284 |
+
self.hidden_size, self.intermediate_size, bias=False, **config.aqlm
|
285 |
+
)
|
286 |
+
self.down_proj = FinalizedQuantizedLinear(
|
287 |
+
self.intermediate_size, self.hidden_size, bias=False, **config.aqlm
|
288 |
+
)
|
289 |
self.act_fn = ACT2FN[config.hidden_act]
|
290 |
|
291 |
def forward(self, x):
|
|
|
295 |
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
296 |
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
297 |
|
298 |
+
gate_proj = torch.cat(
|
299 |
+
[
|
300 |
+
F.linear(x, gate_proj_slices[i])
|
301 |
+
for i in range(self.config.pretraining_tp)
|
302 |
+
],
|
303 |
+
dim=-1,
|
304 |
+
)
|
305 |
+
up_proj = torch.cat(
|
306 |
+
[
|
307 |
+
F.linear(x, up_proj_slices[i])
|
308 |
+
for i in range(self.config.pretraining_tp)
|
309 |
+
],
|
310 |
+
dim=-1,
|
311 |
+
)
|
312 |
|
313 |
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
314 |
down_proj = [
|
315 |
+
F.linear(intermediate_states[i], down_proj_slices[i])
|
316 |
+
for i in range(self.config.pretraining_tp)
|
317 |
]
|
318 |
down_proj = sum(down_proj)
|
319 |
else:
|
|
|
330 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
331 |
if n_rep == 1:
|
332 |
return hidden_states
|
333 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
334 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
335 |
+
)
|
336 |
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
337 |
|
338 |
|
|
|
356 |
f" and `num_heads`: {self.num_heads})."
|
357 |
)
|
358 |
self.q_proj = FinalizedQuantizedLinear(
|
359 |
+
self.hidden_size,
|
360 |
+
self.num_heads * self.head_dim,
|
361 |
+
bias=config.attention_bias,
|
362 |
+
**config.aqlm,
|
363 |
)
|
364 |
self.k_proj = FinalizedQuantizedLinear(
|
365 |
+
self.hidden_size,
|
366 |
+
self.num_key_value_heads * self.head_dim,
|
367 |
+
bias=config.attention_bias,
|
368 |
+
**config.aqlm,
|
369 |
)
|
370 |
self.v_proj = FinalizedQuantizedLinear(
|
371 |
+
self.hidden_size,
|
372 |
+
self.num_key_value_heads * self.head_dim,
|
373 |
+
bias=config.attention_bias,
|
374 |
+
**config.aqlm,
|
375 |
)
|
376 |
self.o_proj = FinalizedQuantizedLinear(
|
377 |
+
self.num_heads * self.head_dim,
|
378 |
+
self.hidden_size,
|
379 |
+
bias=config.attention_bias,
|
380 |
+
**config.aqlm,
|
381 |
)
|
382 |
self._init_rope()
|
383 |
|
|
|
409 |
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
410 |
|
411 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
412 |
+
return (
|
413 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
414 |
+
.transpose(1, 2)
|
415 |
+
.contiguous()
|
416 |
+
)
|
417 |
|
418 |
def forward(
|
419 |
self,
|
|
|
428 |
bsz, q_len, _ = hidden_states.size()
|
429 |
|
430 |
if self.config.pretraining_tp > 1:
|
431 |
+
key_value_slicing = (
|
432 |
+
self.num_key_value_heads * self.head_dim
|
433 |
+
) // self.config.pretraining_tp
|
434 |
query_slices = self.q_proj.weight.split(
|
435 |
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
436 |
)
|
437 |
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
438 |
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
439 |
|
440 |
+
query_states = [
|
441 |
+
F.linear(hidden_states, query_slices[i])
|
442 |
+
for i in range(self.config.pretraining_tp)
|
443 |
+
]
|
444 |
query_states = torch.cat(query_states, dim=-1)
|
445 |
|
446 |
+
key_states = [
|
447 |
+
F.linear(hidden_states, key_slices[i])
|
448 |
+
for i in range(self.config.pretraining_tp)
|
449 |
+
]
|
450 |
key_states = torch.cat(key_states, dim=-1)
|
451 |
|
452 |
+
value_states = [
|
453 |
+
F.linear(hidden_states, value_slices[i])
|
454 |
+
for i in range(self.config.pretraining_tp)
|
455 |
+
]
|
456 |
value_states = torch.cat(value_states, dim=-1)
|
457 |
|
458 |
else:
|
|
|
460 |
key_states = self.k_proj(hidden_states)
|
461 |
value_states = self.v_proj(hidden_states)
|
462 |
|
463 |
+
query_states = query_states.view(
|
464 |
+
bsz, q_len, self.num_heads, self.head_dim
|
465 |
+
).transpose(1, 2)
|
466 |
+
key_states = key_states.view(
|
467 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
468 |
+
).transpose(1, 2)
|
469 |
+
value_states = value_states.view(
|
470 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
471 |
+
).transpose(1, 2)
|
472 |
|
473 |
kv_seq_len = key_states.shape[-2]
|
474 |
if past_key_value is not None:
|
475 |
kv_seq_len += past_key_value[0].shape[-2]
|
476 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
477 |
+
query_states, key_states = apply_rotary_pos_emb(
|
478 |
+
query_states, key_states, cos, sin, position_ids
|
479 |
+
)
|
480 |
|
481 |
if past_key_value is not None:
|
482 |
# reuse k, v, self_attention
|
|
|
488 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
489 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
490 |
|
491 |
+
attn_weights = torch.matmul(
|
492 |
+
query_states, key_states.transpose(2, 3)
|
493 |
+
) / math.sqrt(self.head_dim)
|
494 |
|
495 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
496 |
raise ValueError(
|
|
|
506 |
attn_weights = attn_weights + attention_mask
|
507 |
|
508 |
# upcast attention to fp32
|
509 |
+
attn_weights = nn.functional.softmax(
|
510 |
+
attn_weights, dim=-1, dtype=torch.float32
|
511 |
+
).to(query_states.dtype)
|
512 |
attn_output = torch.matmul(attn_weights, value_states)
|
513 |
|
514 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
|
522 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
523 |
|
524 |
if self.config.pretraining_tp > 1:
|
525 |
+
attn_output = attn_output.split(
|
526 |
+
self.hidden_size // self.config.pretraining_tp, dim=2
|
527 |
+
)
|
528 |
+
o_proj_slices = self.o_proj.weight.split(
|
529 |
+
self.hidden_size // self.config.pretraining_tp, dim=1
|
530 |
+
)
|
531 |
+
attn_output = sum(
|
532 |
+
[
|
533 |
+
F.linear(attn_output[i], o_proj_slices[i])
|
534 |
+
for i in range(self.config.pretraining_tp)
|
535 |
+
]
|
536 |
+
)
|
537 |
else:
|
538 |
attn_output = self.o_proj(attn_output)
|
539 |
|
|
|
572 |
# Flash attention requires the input to have the shape
|
573 |
# batch_size x seq_length x head_dime x hidden_dim
|
574 |
# therefore we just need to keep the original shape
|
575 |
+
query_states = query_states.view(
|
576 |
+
bsz, q_len, self.num_heads, self.head_dim
|
577 |
+
).transpose(1, 2)
|
578 |
+
key_states = key_states.view(
|
579 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
580 |
+
).transpose(1, 2)
|
581 |
+
value_states = value_states.view(
|
582 |
+
bsz, q_len, self.num_key_value_heads, self.head_dim
|
583 |
+
).transpose(1, 2)
|
584 |
|
585 |
kv_seq_len = key_states.shape[-2]
|
586 |
if past_key_value is not None:
|
|
|
588 |
|
589 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
590 |
|
591 |
+
query_states, key_states = apply_rotary_pos_emb(
|
592 |
+
query_states, key_states, cos, sin, position_ids
|
593 |
+
)
|
594 |
|
595 |
if past_key_value is not None:
|
596 |
# reuse k, v, self_attention
|
|
|
626 |
value_states = value_states.to(torch.float16)
|
627 |
|
628 |
attn_output = self._flash_attention_forward(
|
629 |
+
query_states,
|
630 |
+
key_states,
|
631 |
+
value_states,
|
632 |
+
padding_mask,
|
633 |
+
q_len,
|
634 |
+
dropout=dropout_rate,
|
635 |
)
|
636 |
|
637 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
|
|
643 |
return attn_output, attn_weights, past_key_value
|
644 |
|
645 |
def _flash_attention_forward(
|
646 |
+
self,
|
647 |
+
query_states,
|
648 |
+
key_states,
|
649 |
+
value_states,
|
650 |
+
padding_mask,
|
651 |
+
query_length,
|
652 |
+
dropout=0.0,
|
653 |
+
softmax_scale=None,
|
654 |
):
|
655 |
"""
|
656 |
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
|
|
674 |
# Contains at least one padding token in the sequence
|
675 |
if padding_mask is not None:
|
676 |
batch_size = query_states.shape[0]
|
677 |
+
(
|
678 |
+
query_states,
|
679 |
+
key_states,
|
680 |
+
value_states,
|
681 |
+
indices_q,
|
682 |
+
cu_seq_lens,
|
683 |
+
max_seq_lens,
|
684 |
+
) = self._upad_input(
|
685 |
query_states, key_states, value_states, padding_mask, query_length
|
686 |
)
|
687 |
|
|
|
701 |
causal=True,
|
702 |
)
|
703 |
|
704 |
+
attn_output = pad_input(
|
705 |
+
attn_output_unpad, indices_q, batch_size, query_length
|
706 |
+
)
|
707 |
else:
|
708 |
attn_output = flash_attn_func(
|
709 |
+
query_states,
|
710 |
+
key_states,
|
711 |
+
value_states,
|
712 |
+
dropout,
|
713 |
+
softmax_scale=softmax_scale,
|
714 |
+
causal=True,
|
715 |
)
|
716 |
|
717 |
return attn_output
|
718 |
|
719 |
+
def _upad_input(
|
720 |
+
self, query_layer, key_layer, value_layer, padding_mask, query_length
|
721 |
+
):
|
722 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
|
723 |
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
724 |
|
725 |
key_layer = index_first_axis(
|
726 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
727 |
+
indices_k,
|
728 |
)
|
729 |
value_layer = index_first_axis(
|
730 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
|
731 |
+
indices_k,
|
732 |
)
|
733 |
if query_length == kv_seq_len:
|
734 |
query_layer = index_first_axis(
|
735 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
|
736 |
+
indices_k,
|
737 |
)
|
738 |
cu_seqlens_q = cu_seqlens_k
|
739 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
748 |
else:
|
749 |
# The -q_len: slice assumes left padding.
|
750 |
padding_mask = padding_mask[:, -query_length:]
|
751 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
752 |
+
query_layer, padding_mask
|
753 |
+
)
|
754 |
|
755 |
return (
|
756 |
query_layer,
|
|
|
773 |
)
|
774 |
self.mlp = LlamaMLP(config)
|
775 |
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
776 |
+
self.post_attention_layernorm = LlamaRMSNorm(
|
777 |
+
config.hidden_size, eps=config.rms_norm_eps
|
778 |
+
)
|
779 |
|
780 |
def forward(
|
781 |
self,
|
|
|
786 |
output_attentions: Optional[bool] = False,
|
787 |
use_cache: Optional[bool] = False,
|
788 |
padding_mask: Optional[torch.LongTensor] = None,
|
789 |
+
) -> Tuple[
|
790 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
791 |
+
]:
|
792 |
"""
|
793 |
Args:
|
794 |
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
|
962 |
self.padding_idx = config.pad_token_id
|
963 |
self.vocab_size = config.vocab_size
|
964 |
|
965 |
+
self.embed_tokens = nn.Embedding(
|
966 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
967 |
+
)
|
968 |
+
self.layers = nn.ModuleList(
|
969 |
+
[LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
970 |
+
)
|
971 |
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
972 |
|
973 |
self.gradient_checkpointing = False
|
|
|
981 |
self.embed_tokens = value
|
982 |
|
983 |
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
984 |
+
def _prepare_decoder_attention_mask(
|
985 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
986 |
+
):
|
987 |
# create causal mask
|
988 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
989 |
combined_attention_mask = None
|
|
|
997 |
|
998 |
if attention_mask is not None:
|
999 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1000 |
+
expanded_attn_mask = _expand_mask(
|
1001 |
+
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
1002 |
+
).to(inputs_embeds.device)
|
1003 |
combined_attention_mask = (
|
1004 |
+
expanded_attn_mask
|
1005 |
+
if combined_attention_mask is None
|
1006 |
+
else expanded_attn_mask + combined_attention_mask
|
1007 |
)
|
1008 |
|
1009 |
return combined_attention_mask
|
|
|
1021 |
output_hidden_states: Optional[bool] = None,
|
1022 |
return_dict: Optional[bool] = None,
|
1023 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
1024 |
+
output_attentions = (
|
1025 |
+
output_attentions
|
1026 |
+
if output_attentions is not None
|
1027 |
+
else self.config.output_attentions
|
1028 |
+
)
|
1029 |
output_hidden_states = (
|
1030 |
+
output_hidden_states
|
1031 |
+
if output_hidden_states is not None
|
1032 |
+
else self.config.output_hidden_states
|
1033 |
)
|
1034 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1035 |
|
1036 |
+
return_dict = (
|
1037 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1038 |
+
)
|
1039 |
|
1040 |
# retrieve input_ids and inputs_embeds
|
1041 |
if input_ids is not None and inputs_embeds is not None:
|
1042 |
+
raise ValueError(
|
1043 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
1044 |
+
)
|
1045 |
elif input_ids is not None:
|
1046 |
batch_size, seq_length = input_ids.shape
|
1047 |
elif inputs_embeds is not None:
|
|
|
1059 |
if position_ids is None:
|
1060 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
1061 |
position_ids = torch.arange(
|
1062 |
+
past_key_values_length,
|
1063 |
+
seq_length + past_key_values_length,
|
1064 |
+
dtype=torch.long,
|
1065 |
+
device=device,
|
1066 |
)
|
1067 |
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
1068 |
else:
|
|
|
1073 |
# embed positions
|
1074 |
if attention_mask is None:
|
1075 |
attention_mask = torch.ones(
|
1076 |
+
(batch_size, seq_length_with_past),
|
1077 |
+
dtype=torch.bool,
|
1078 |
+
device=inputs_embeds.device,
|
1079 |
)
|
1080 |
padding_mask = None
|
1081 |
else:
|
|
|
1085 |
padding_mask = None
|
1086 |
|
1087 |
attention_mask = self._prepare_decoder_attention_mask(
|
1088 |
+
attention_mask,
|
1089 |
+
(batch_size, seq_length),
|
1090 |
+
inputs_embeds,
|
1091 |
+
past_key_values_length,
|
1092 |
)
|
1093 |
|
1094 |
hidden_states = inputs_embeds
|
|
|
1109 |
if output_hidden_states:
|
1110 |
all_hidden_states += (hidden_states,)
|
1111 |
|
1112 |
+
past_key_value = (
|
1113 |
+
past_key_values[idx] if past_key_values is not None else None
|
1114 |
+
)
|
1115 |
|
1116 |
if self.gradient_checkpointing and self.training:
|
1117 |
|
1118 |
def create_custom_forward(module):
|
1119 |
def custom_forward(*inputs):
|
1120 |
# None for past_key_value
|
1121 |
+
return module(
|
1122 |
+
*inputs,
|
1123 |
+
past_key_value,
|
1124 |
+
output_attentions,
|
1125 |
+
padding_mask=padding_mask,
|
1126 |
+
)
|
1127 |
|
1128 |
return custom_forward
|
1129 |
|
1130 |
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1131 |
+
create_custom_forward(decoder_layer),
|
1132 |
+
hidden_states,
|
1133 |
+
attention_mask,
|
1134 |
+
position_ids,
|
1135 |
)
|
1136 |
else:
|
1137 |
layer_outputs = decoder_layer(
|
|
|
1160 |
|
1161 |
next_cache = next_decoder_cache if use_cache else None
|
1162 |
if not return_dict:
|
1163 |
+
return tuple(
|
1164 |
+
v
|
1165 |
+
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
1166 |
+
if v is not None
|
1167 |
+
)
|
1168 |
return BaseModelOutputWithPast(
|
1169 |
last_hidden_state=hidden_states,
|
1170 |
past_key_values=next_cache,
|
|
|
1204 |
return self.model
|
1205 |
|
1206 |
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
1207 |
+
@replace_return_docstrings(
|
1208 |
+
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
1209 |
+
)
|
1210 |
def forward(
|
1211 |
self,
|
1212 |
input_ids: torch.LongTensor = None,
|
|
|
1246 |
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
1247 |
```"""
|
1248 |
|
1249 |
+
output_attentions = (
|
1250 |
+
output_attentions
|
1251 |
+
if output_attentions is not None
|
1252 |
+
else self.config.output_attentions
|
1253 |
+
)
|
1254 |
output_hidden_states = (
|
1255 |
+
output_hidden_states
|
1256 |
+
if output_hidden_states is not None
|
1257 |
+
else self.config.output_hidden_states
|
1258 |
+
)
|
1259 |
+
return_dict = (
|
1260 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1261 |
)
|
|
|
1262 |
|
1263 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1264 |
outputs = self.model(
|
|
|
1275 |
|
1276 |
hidden_states = outputs[0]
|
1277 |
if self.config.pretraining_tp > 1:
|
1278 |
+
lm_head_slices = self.lm_head.weight.split(
|
1279 |
+
self.vocab_size // self.config.pretraining_tp, dim=0
|
1280 |
+
)
|
1281 |
+
logits = [
|
1282 |
+
F.linear(hidden_states, lm_head_slices[i])
|
1283 |
+
for i in range(self.config.pretraining_tp)
|
1284 |
+
]
|
1285 |
logits = torch.cat(logits, dim=-1)
|
1286 |
else:
|
1287 |
logits = self.lm_head(hidden_states)
|
|
|
1313 |
)
|
1314 |
|
1315 |
def prepare_inputs_for_generation(
|
1316 |
+
self,
|
1317 |
+
input_ids,
|
1318 |
+
past_key_values=None,
|
1319 |
+
attention_mask=None,
|
1320 |
+
inputs_embeds=None,
|
1321 |
+
**kwargs,
|
1322 |
):
|
1323 |
if past_key_values:
|
1324 |
input_ids = input_ids[:, -1:]
|
|
|
1352 |
reordered_past = ()
|
1353 |
for layer_past in past_key_values:
|
1354 |
reordered_past += (
|
1355 |
+
tuple(
|
1356 |
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
1357 |
+
for past_state in layer_past
|
1358 |
+
),
|
1359 |
)
|
1360 |
return reordered_past
|
1361 |
|
|
|
1411 |
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1412 |
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1413 |
"""
|
1414 |
+
return_dict = (
|
1415 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1416 |
+
)
|
1417 |
|
1418 |
transformer_outputs = self.model(
|
1419 |
input_ids,
|
|
|
1435 |
batch_size = inputs_embeds.shape[0]
|
1436 |
|
1437 |
if self.config.pad_token_id is None and batch_size != 1:
|
1438 |
+
raise ValueError(
|
1439 |
+
"Cannot handle batch sizes > 1 if no padding token is defined."
|
1440 |
+
)
|
1441 |
if self.config.pad_token_id is None:
|
1442 |
sequence_lengths = -1
|
1443 |
else:
|
1444 |
if input_ids is not None:
|
1445 |
+
sequence_lengths = (
|
1446 |
+
torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
|
1447 |
+
).to(logits.device)
|
1448 |
else:
|
1449 |
sequence_lengths = -1
|
1450 |
|
1451 |
+
pooled_logits = logits[
|
1452 |
+
torch.arange(batch_size, device=logits.device), sequence_lengths
|
1453 |
+
]
|
1454 |
|
1455 |
loss = None
|
1456 |
if labels is not None:
|
|
|
1458 |
if self.config.problem_type is None:
|
1459 |
if self.num_labels == 1:
|
1460 |
self.config.problem_type = "regression"
|
1461 |
+
elif self.num_labels > 1 and (
|
1462 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
1463 |
+
):
|
1464 |
self.config.problem_type = "single_label_classification"
|
1465 |
else:
|
1466 |
self.config.problem_type = "multi_label_classification"
|
|
|
1473 |
loss = loss_fct(pooled_logits, labels)
|
1474 |
elif self.config.problem_type == "single_label_classification":
|
1475 |
loss_fct = CrossEntropyLoss()
|
1476 |
+
loss = loss_fct(
|
1477 |
+
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
1478 |
+
)
|
1479 |
elif self.config.problem_type == "multi_label_classification":
|
1480 |
loss_fct = BCEWithLogitsLoss()
|
1481 |
loss = loss_fct(pooled_logits, labels)
|
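For context, a minimal usage sketch that is not part of the commit: since configuration_llama.py sets model_type = "llama_aqlm" and modeling_llama.py / inference.py ship next to the weights, a repository that bundles these files can be loaded through transformers with trust_remote_code. The repository id below is a placeholder, and the prompt simply reuses the one from the docstring example above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<org>/<llama-aqlm-checkpoint>"  # placeholder repository id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,    # load configuration_llama.py / modeling_llama.py from the repo
    torch_dtype=torch.float16,
).cuda()                       # CUDA inputs take the Triton path in forward_pass_quantized_linear

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))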