File size: 12,262 Bytes
5c0d7ef 5edaefc 5c0d7ef 5edaefc 5c0d7ef 5edaefc 5c0d7ef 5edaefc 5c0d7ef 5edaefc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 |
""" Inference utilities for Additive Quantization (AQ): a finalized quantized linear layer, weight reconstruction from codes and codebooks, and Triton GEMV kernels"""
import functools
import os
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
class FinalizedQuantizedLinear(nn.Module):
    """Linear layer whose weight is stored in additive-quantized (AQ) form.

    The ``[out_features, in_features]`` weight is split into tiles of
    ``out_group_size x in_group_size`` elements. Each tile is described by
    ``num_codebooks`` integer codes; its value is the sum of the selected
    codebook entries, multiplied by a per-output-group scale
    (see ``forward_pass_quantized_linear``).

    All parameters are created with ``torch.empty`` and are therefore
    uninitialized; presumably they are populated from a checkpoint via
    ``load_state_dict`` -- TODO confirm with callers.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        in_group_size: int,
        out_group_size: int,
        num_codebooks: int,
        nbits_per_codebook: int,
        bias=True,
        device=None,
        dtype=None,
    ):
        """
        :param in_features: size of each input sample; must be divisible by ``in_group_size``
        :param out_features: size of each output sample; must be divisible by ``out_group_size``
        :param in_group_size: number of input columns covered by one quantized tile
        :param out_group_size: number of output rows covered by one quantized tile
        :param num_codebooks: number of codes (one per codebook) summed to form each tile
        :param nbits_per_codebook: bits per code; each codebook has ``2 ** nbits_per_codebook`` entries
        :param bias: if True, allocate a bias parameter of shape ``[out_features]``
        :param device: device for all parameters
        :param dtype: floating dtype for codebooks, scales and bias (codes use an integer dtype)
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert self.in_features % in_group_size == 0, (
            f"in_features={in_features} is not divisible by in_group_size={in_group_size}"
        )
        assert self.out_features % out_group_size == 0, (
            f"out_features={out_features} is not divisible by out_group_size={out_group_size}"
        )
        num_out_groups = out_features // out_group_size
        num_in_groups = in_features // in_group_size
        self.out_group_size, self.in_group_size = out_group_size, in_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.codebook_size = 2**nbits_per_codebook

        # CODES & CODEBOOKS
        self.codebooks = nn.Parameter(
            torch.empty(
                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
                **factory_kwargs,
            ),
            requires_grad=True,
        )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        # Codes live in the narrowest *signed* integer dtype that has nbits_per_codebook
        # bits (get_int_dtype). Values that do not fit the signed range (e.g. >= 128 when
        # nbits_per_codebook == 8) are stored wrapped around and must be unpacked before
        # use (unpack_int_data). Codes are indices, so they never receive gradients.
        self.codes = nn.Parameter(
            torch.empty(
                (num_out_groups, num_in_groups, num_codebooks),
                device=device,
                dtype=get_int_dtype(nbits_per_codebook),
            ),
            requires_grad=False,
        )  # [num_out_groups, num_in_groups, num_codebooks]

        # SCALES: one scale per output group, broadcast over every tile in that group
        self.scales = nn.Parameter(
            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
        )  # [num_out_groups, 1, 1, 1]

        # BIAS
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear map to ``input`` of shape ``[..., in_features]``."""
        return forward_pass_quantized_linear(
            input, self.codes, self.codebooks, self.scales, self.bias
        )
def get_int_dtype(nbits: int) -> torch.dtype:
    """Return the narrowest signed integer dtype that has at least ``nbits`` bits.

    :param nbits: number of bits a value must occupy
    :return: one of ``torch.int8``, ``torch.int16``, ``torch.int32``, ``torch.int64``
    :raises ValueError: if ``nbits`` is larger than 64
    """
    candidates = (
        (8, torch.int8),
        (16, torch.int16),
        (32, torch.int32),
        (64, torch.int64),
    )
    for width, int_dtype in candidates:
        if nbits <= width:
            return int_dtype
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Reinterpret signed integer codes as unsigned ``nbits``-bit values.

    The values are widened to int64 first, then reduced modulo ``2 ** nbits``.
    Negative entries (wrapped-around codes) therefore map back into
    ``[0, 2 ** nbits)``.

    :param data: integer tensor of codes
    :param nbits: number of bits per code
    :return: int64 tensor of the same shape with values in ``[0, 2 ** nbits)``
    """
    widened = data.long()
    return torch.remainder(widened, 2**nbits)
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from environment variable ``name``.

    Accepts integers ("0", "1", ...; any non-zero integer means True) as well as
    true/false, yes/no and on/off (case-insensitive, surrounding whitespace ignored).
    An unset or blank variable yields ``default``; anything else raises ValueError.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    normalized = raw.strip().lower()
    if normalized in ("true", "yes", "on"):
        return True
    if normalized in ("false", "no", "off"):
        return False
    return bool(int(normalized))


@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script.

    Scripting is controlled by the ``AQ_USE_JIT`` environment variable (see ``_env_flag``
    for accepted spellings). When it is unset, scripting is enabled unless ``TPU_NAME`` is set.
    Results are cached per function by ``lru_cache``, so the environment is consulted only
    on the first call for a given ``fn``.
    """
    # TPU_NAME is a reserved variable that must be set to the TPU address
    # (e.g. grpc://11.22.33.44:1337) for a TPU to function, so its presence signals TPU use.
    using_tpu = bool(os.environ.get("TPU_NAME"))
    should_script = _env_flag("AQ_USE_JIT", default=not using_tpu)
    return torch.jit.script(fn) if should_script else fn
@maybe_script
def _dequantize_weight(
    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.
    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks];
        each code must already be a non-negative index into its own codebook (e.g. the output of unpack_int_data)
    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastable with [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    # Row offset of each codebook inside the flattened [num_codebooks * codebook_size, ...] table,
    # so that code c of codebook k selects row k * codebook_size + c.
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size, device=codes.device
    ) # shape: [num_codebooks]
    # Each row of the 2-D index tensor (one weight tile) is a bag of num_codebooks indices;
    # mode="sum" adds up the selected codebook vectors, i.e. the additive reconstruction.
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum",
    ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]
    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3])
        + [num_out_groups, num_in_groups, out_group_size, in_group_size]
    )
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
    # [..., num_out_groups, num_in_groups, og, ig] -> [..., num_out_groups, og, num_in_groups, ig]
    # so that the reshape lays tiles out as a dense [out_features, in_features] matrix.
    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
        list(codes.shape[:-3]) + [out_features, in_features]
    )
def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply a linear layer whose weight is given in additive-quantized form.

    On CUDA inputs, the product is computed by ``aqlm_gemm_stupid`` (Triton GEMV per
    row) without materializing the weight. Otherwise, the weight is dequantized with
    ``_dequantize_weight`` and passed to ``F.linear``.

    :param input: activations of shape [..., in_features]
    :param codes: codes of shape [num_out_groups, num_in_groups, num_codebooks], stored in a
        signed integer dtype (possibly wrapped around; unpacked here on the CPU path)
    :param codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: per-output-group scales broadcastable to
        [num_out_groups, num_in_groups, out_group_size, in_group_size]
    :param bias: optional bias of shape [out_features]
    :return: tensor of shape [..., out_features]
    """
    if input.is_cuda:
        matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result

    # Bits per code come from the codebook size (dim 1 of codebooks), NOT from the
    # number of codebooks (dim 0); codebook_size is 2 ** nbits_per_codebook.
    codebook_size = codebooks.shape[1]
    nbits_per_codebook = codebook_size.bit_length() - 1
    dequantized_weight = _dequantize_weight(
        unpack_int_data(codes, nbits_per_codebook),
        codebooks,
        scales,
    )
    return F.linear(input, dequantized_weight, bias)
@triton.autotune(
    configs=[
        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
        for num_stages in (1, 2, 3, 4, 5)
        for num_warps in (1, 2, 4, 8)
    ],
    key=[
        "in_features",
        "out_features",
        "num_codebooks",
        "codebook_size",
        "out_group_size",
        "in_group_size",
        "num_input_groups",
        "num_input_groups_next_power_of_2",
        "compute_in_fp32",
    ],
)
@triton.jit
def _aqlm_gemv_simple(
    input_vec_ptr,
    output_vec_ptr,
    codes_i16_ptr,
    codebooks_ptr,
    scales_ptr,
    in_features: tl.constexpr,
    out_features: tl.constexpr,
    num_codebooks: tl.constexpr,
    codebook_size: tl.constexpr,
    out_group_size: tl.constexpr,
    in_group_size: tl.constexpr,
    num_input_groups: tl.constexpr,
    num_input_groups_next_power_of_2: tl.constexpr,
    compute_in_fp32: tl.constexpr,
    UNUSED: tl.constexpr,
):
    """Compute one output group of ``W @ input`` for an AQ-compressed weight ``W``.

    One program per output group (``pid``). Each program:
    1. loads the whole input vector, padded to a power-of-two number of groups;
    2. loads its row of codes;
    3. gathers and sums the selected codebook entries;
    4. multiplies them with the input, applies the group scale and stores
       ``out_group_size`` outputs.

    The pointer arithmetic assumes dense row-major tensors. ``num_codebooks``,
    ``in_group_size`` and ``out_group_size`` are used as ``tl.arange`` extents,
    which Triton requires to be powers of two. ``UNUSED`` appears to exist only so
    that the autotune configs have a meta-parameter to set.
    """
    # variables ending with "_i" mean "for i-th output unit"
    pid = tl.program_id(axis=0)  # index of the output group, [0, out_features // out_group_size)

    # Stage 1: load input data
    # [num_input_groups_next_power_of_2, 1, in_group_size]; padding groups are loaded as 0 so
    # that they contribute nothing to the dot product (otherwise their values are undefined,
    # and 0 * NaN/inf from garbage would poison the sum).
    input_vec = tl.load(
        input_vec_ptr
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
        + tl.arange(0, in_group_size)[None, None, :],
        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None]
        < num_input_groups,
        other=0.0,
    )
    # Note: we could simply load input_vec then reshape
    # input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
    # input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring

    # Stage 2: load integer codes for the active row
    # [num_input_groups_next_power_of_2, num_codebooks]
    codes_i_ptrs = (
        codes_i16_ptr
        + pid * num_input_groups * num_codebooks
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
        + tl.arange(0, num_codebooks)[None, :]
    )
    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
    # Padding rows get code 0 so the derived codebook pointers stay in bounds.
    codes_i = tl.load(
        codes_i_ptrs, mask=codes_i_mask_1d[:, None], other=0
    )  # [num_input_groups_next_power_of_2, num_codebooks]
    if codes_i.dtype == tl.int16:
        codes_i = codes_i.to(tl.int32)
        codes_i = (codes_i) + (
            codes_i < 0
        ) * codebook_size  # aka 2 ** nbits_per_codebook
        # ^-- (because codes are int16 tensors that contain uint data)
        # The following alternative does not work:
        # codes_i = codes_i.to(tl.int32) % codebook_size # aka 2 ** nbits_per_codebook
    else:
        codes_i = codes_i.to(tl.int32)

    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
    codes_i += (
        tl.arange(0, num_codebooks)[None, :] * codebook_size
    )  # aka 2 ** nbits_per_codebook
    # ^-- [num_input_groups_next_power_of_2, num_codebooks]

    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
    # [num_input_groups_next_power_of_2, num_codebooks, out_group_size, in_group_size]
    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
    weight_i_ptrs = (
        codebooks_ptr
        + codes_i[:, :, None, None] * out_group_size * in_group_size
        + out_group_ix * in_group_size
        + in_group_ix
    )

    # Stage 4: reconstruct weights, multiply by inputs and write out
    weights_i = tl.load(
        weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
    )
    if compute_in_fp32:
        weights_i = weights_i.to(tl.float32)
        input_vec = input_vec.to(tl.float32)
    # ^-- [num_input_groups_next_power_of_2, num_codebooks, out_group_size, in_group_size]
    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
    # ^-- [num_input_groups_next_power_of_2, out_group_size, in_group_size]

    if out_group_size == 1:
        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
        output_i = tl.sum(weights_i * input_vec) * scale
        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
    else:
        output_i = tl.sum(
            tl.sum(weights_i * input_vec, axis=2), axis=0
        )  # [out_group_size]
        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
        tl.store(
            output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
            output_i.to(input_vec.dtype),
        )
def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x`` for a non-negative int ``x``.

    ``0`` maps to ``1``; exact powers of two map to themselves.
    """
    if x == 0:
        return 1
    exponent = (x - 1).bit_length()
    return 1 << exponent
def aqlm_gemv_simple(
    input_vec: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    """Multiply one input row by an AQ-compressed weight with the ``_aqlm_gemv_simple`` kernel.

    :param input_vec: input of shape [1, in_features]
    :param codes_i16: codes of shape [num_out_groups, num_in_groups, num_codebooks]; int16
        codes are treated as unsigned 16-bit values by the kernel
    :param codebooks: [num_codebooks, 2 ** 16, out_group_size, in_group_size]
    :param scales: per-output-group scales of shape [num_out_groups, 1, 1, 1]
    :param compute_in_fp32: if True, weights and inputs are upcast to float32 inside the kernel
    :return: output of shape [1, num_out_groups * out_group_size] with the dtype and
        device of ``codebooks``
    """
    # Check the input rank first so a 1-D input fails with the intended message
    # instead of an IndexError on input_vec.shape[1].
    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
    assert codes_i16.ndim == 3, f"codes must be 3-D, got shape {tuple(codes_i16.shape)}"
    device, dtype = codebooks.device, codebooks.dtype
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    num_out_groups, num_input_groups = codes_i16.shape[0], codes_i16.shape[1]
    in_features = input_vec.shape[1]
    out_features = num_out_groups * out_group_size
    assert scales.shape == (num_out_groups, 1, 1, 1)
    assert in_features % in_group_size == 0
    # The kernel indexes input and codes by num_input_groups / num_codebooks taken from
    # these shapes; a mismatch would make it read out of bounds.
    assert num_input_groups * in_group_size == in_features, (
        f"codes cover {num_input_groups * in_group_size} input features, input has {in_features}"
    )
    assert codes_i16.shape[2] == num_codebooks, (
        f"codes have {codes_i16.shape[2]} codebooks, codebooks tensor has {num_codebooks}"
    )
    assert codebooks.shape[1] == 2**16

    # The kernel uses raw pointer arithmetic that assumes dense row-major layouts;
    # .contiguous() returns the tensor itself when it already is contiguous.
    input_vec = input_vec.contiguous()
    codes_i16 = codes_i16.contiguous()
    codebooks = codebooks.contiguous()
    scales = scales.contiguous()

    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
    # 1D launch: one program per output group
    grid = lambda META: (num_out_groups,)
    _aqlm_gemv_simple[grid](
        input_vec,
        output_vec,
        codes_i16,
        codebooks,
        scales,
        in_features,
        out_features,
        num_codebooks,
        codebook_size,
        out_group_size,
        in_group_size,
        num_input_groups,
        next_power_of_2(num_input_groups),
        compute_in_fp32,
    )
    return output_vec
def aqlm_gemm_stupid(
    input: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    """Batched quantized matmul done as one ``aqlm_gemv_simple`` launch per input row.

    :param input: activations of shape [..., in_features]
    :param codes_i16: codes of shape [num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: per-output-group scales of shape [num_out_groups, 1, 1, 1]
    :param compute_in_fp32: forwarded to ``aqlm_gemv_simple``
    :return: tensor of shape [..., num_out_groups * out_group_size] with the dtype and
        device of ``codebooks``
    """
    original_shape = input.shape
    out_features = codes_i16.shape[0] * codebooks.shape[2]
    output_shape = tuple(original_shape[:-1]) + (out_features,)
    if input.numel() == 0:
        # torch.cat([]) raises and reshape(..., -1) is ambiguous for zero-sized
        # tensors, so build the (empty / all-zero) result directly.
        return torch.zeros(output_shape, device=codebooks.device, dtype=codebooks.dtype)
    rows = input.reshape(-1, original_shape[-1])
    per_row_outputs = [
        aqlm_gemv_simple(row.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32)
        for row in rows
    ]
    return torch.cat(per_row_outputs).reshape(output_shape)
|