""" Core mathematics for Additive Quantization (AQ): initialization, reconstruction and beam search"""
import functools
import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


class FinalizedQuantizedLinear(nn.Module):
    """
    Linear layer whose weight is stored as additive-quantization codes plus codebooks.

    The weight is split into groups of [out_group_size, in_group_size]; every group is the sum of
    `num_codebooks` codebook vectors (one per codebook) multiplied by a per-output-group scale.
    Parameters are allocated with torch.empty, i.e. uninitialized: they must be filled by the caller
    (e.g. by loading a state dict) before the layer is used.

    :param in_features: size of each input vector, must be divisible by in_group_size
    :param out_features: size of each output vector, must be divisible by out_group_size
    :param in_group_size: number of consecutive input features covered by one code
    :param out_group_size: number of consecutive output features covered by one code
    :param num_codebooks: number of codebooks whose vectors are summed per weight group
    :param nbits_per_codebook: each codebook holds 2 ** nbits_per_codebook vectors
    :param bias: if True, allocate a bias vector of size [out_features]
    :param device: device for all parameters
    :param dtype: floating point dtype of codebooks, scales and bias (codes use an integer dtype)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        in_group_size: int,
        out_group_size: int,
        num_codebooks: int,
        nbits_per_codebook: int,
        bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        assert self.in_features % in_group_size == 0
        assert self.out_features % out_group_size == 0
        num_out_groups = out_features // out_group_size
        num_in_groups = in_features // in_group_size
        self.out_group_size, self.in_group_size = out_group_size, in_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.codebook_size = 2**nbits_per_codebook

        # CODES & CODEBOOKS
        self.codebooks = nn.Parameter(
            torch.empty(
                (num_codebooks, self.codebook_size, out_group_size, in_group_size),
                **factory_kwargs,
            ),
            requires_grad=True,
        )  # [num_codebooks, codebook_size, out_group_size, in_group_size]
        self.codes = nn.Parameter(
            torch.empty(
                (num_out_groups, num_in_groups, num_codebooks),
                device=device,
                dtype=get_int_dtype(nbits_per_codebook),
            ),
            requires_grad=False,
        )  # [num_out_groups, num_in_groups, num_codebooks]

        # SCALES
        self.scales = nn.Parameter(
            torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
        )  # [num_out_groups, 1, 1, 1]: one scale per output group

        # BIAS
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear transformation to `input` of shape [..., in_features]."""
        return forward_pass_quantized_linear(
            input, self.codes, self.codebooks, self.scales, self.bias
        )


def get_int_dtype(nbits: int) -> torch.dtype:
    """
    Return the narrowest signed torch integer dtype that has at least `nbits` bits.

    Codes are stored as raw bit patterns in this dtype; use unpack_int_data to read them back as
    non-negative indices.

    :raises ValueError: if nbits exceeds 64
    """
    if nbits <= 8:
        return torch.int8
    if nbits <= 16:
        return torch.int16
    if nbits <= 32:
        return torch.int32
    if nbits <= 64:
        return torch.int64
    raise ValueError(f"No dtype available for {nbits}-bit codebooks")


@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Convert signed storage produced with get_int_dtype into int64 indices in [0, 2 ** nbits)."""
    # tensor % follows the sign of the divisor, so negative (wrapped) values map back to [0, 2**nbits)
    return data.to(torch.int64) % (2**nbits)


@functools.lru_cache()
def maybe_script(fn: callable) -> callable:
    """Apply torch.jit.script to function unless one is using TPU. TPU does not support torch.jit.script."""
    # TPU_NAME is a reserved variable that must be set to TPU address
    # (e.g. grpc://11.22.33.44:1337) for TPU to function
    using_tpu = bool(os.environ.get("TPU_NAME"))
    # AQ_USE_JIT (an integer, e.g. "0" or "1") overrides the default choice
    should_script = int(os.environ.get("AQ_USE_JIT", not using_tpu))
    return torch.jit.script(fn) if should_script else fn


@maybe_script
def _dequantize_weight(
    codes: torch.Tensor, codebooks: torch.Tensor, scales: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Decode float weights from quantization codes. Differentiable.
    :param codes: tensor of integer quantization codes, shape [*dims, num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: tensor of vectors for each quantization code, [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: weight will be multiplied by this factor, must be broadcastble with [*dims, out_groups, num_in_groups, out_group_size, in_group_size]
    :return: reconstructed weight tensor of shape [*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
    """
    num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    out_features = num_out_groups * out_group_size
    in_features = num_in_groups * in_group_size
    # codebooks are flattened into a single embedding table below; offset the codes of the k-th
    # codebook by k * codebook_size so that they index into the right slice of that table
    codebook_offsets = torch.arange(
        0, num_codebooks * codebook_size, codebook_size, device=codes.device
    )  # shape: [num_codebooks]
    reconstructed_weight_flat = F.embedding_bag(
        codes.flatten(0, -2) + codebook_offsets,
        codebooks.flatten(0, 1).flatten(-2, -1),
        mode="sum",
    )  # [prod(dims) * num_out_groups * num_in_groups, out_group_size * in_group_size]

    reconstructed_weight_groupwise = reconstructed_weight_flat.view(
        list(codes.shape[:-3])
        + [num_out_groups, num_in_groups, out_group_size, in_group_size]
    )
    if scales is not None:
        reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(scales)
    # [.., out_groups, in_groups, out_size, in_size] -> [.., out_groups, out_size, in_groups, in_size]
    return reconstructed_weight_groupwise.swapaxes(-3, -2).reshape(
        list(codes.shape[:-3]) + [out_features, in_features]
    )


def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Compute the output of a linear layer with additively quantized weights.

    CUDA inputs go through the triton kernel; any other device reconstructs the dense weight and
    calls F.linear.

    :param input: activations of shape [..., in_features]
    :param codes: integer codes of shape [num_out_groups, num_in_groups, num_codebooks], holding
        unsigned values stored in a signed integer dtype
    :param codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: per-output-group scales of shape [num_out_groups, 1, 1, 1]
    :param bias: optional bias of shape [out_features]
    :return: tensor of shape [..., out_features]
    """
    if input.is_cuda:
        matmul_result = aqlm_gemm_stupid(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result

    # codebook_size == 2 ** nbits_per_codebook lives in dim 1 of codebooks (dim 0 is num_codebooks)
    nbits_per_codebook = codebooks.shape[1].bit_length() - 1
    dequantized_weight = _dequantize_weight(
        unpack_int_data(codes, nbits_per_codebook),
        codebooks,
        scales,
    )
    return F.linear(input, dequantized_weight, bias)


# Triton matrix-vector kernel: each program reconstructs the weights of one output group
# (out_group_size rows of the weight matrix) from codes and codebooks, multiplies them by the input
# vector and writes out_group_size outputs. The autotuner only searches over num_stages / num_warps;
# the UNUSED meta-parameter exists so that each config has a non-empty dict.
@triton.autotune(
    configs=[
        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
        for num_stages in (1, 2, 3, 4, 5)
        for num_warps in (1, 2, 4, 8)
    ],
    key=[
        "in_features",
        "out_features",
        "num_codebooks",
        "codebook_size",
        "out_group_size",
        "in_group_size",
        "num_input_groups",
        "num_input_groups_next_power_of_2",
        "compute_in_fp32",
    ],
)
@triton.jit
def _aqlm_gemv_simple(
    input_vec_ptr,
    output_vec_ptr,
    codes_i16_ptr,
    codebooks_ptr,
    scales_ptr,
    in_features: tl.constexpr,
    out_features: tl.constexpr,
    num_codebooks: tl.constexpr,
    codebook_size: tl.constexpr,
    out_group_size: tl.constexpr,
    in_group_size: tl.constexpr,
    num_input_groups: tl.constexpr,
    num_input_groups_next_power_of_2: tl.constexpr,
    compute_in_fp32: tl.constexpr,
    UNUSED: tl.constexpr,
):
    # variables ending with "_i" mean "for i-th output unit"
    # the launcher in this file uses a 1D grid with one program per output group
    pid = tl.program_id(axis=0)  # [0, 1, ... {out_features // out_group_size - 1}]

    # Stage 1: load input data
    # NOTE(review): no `other=` is passed here, so lanes beyond num_input_groups hold unspecified
    # values; they are later multiplied by weights loaded with other=0 -- confirm this is safe
    input_vec = tl.load(
        input_vec_ptr
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
        + tl.arange(0, in_group_size)[None, None, :],
        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
    )
    # [in_features//in_group_size, 1, group_size]
    # Note: we could simply load input_vec then reshape
    #     input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features))  # [in_features]
    #     input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
    # , but this does not work because tl.view may reorder elements arbitrarily; see its docstring

    # Stage 2: load integer codes for the active row
    # [in_features // in_group_size, num_codebooks]
    codes_i_ptrs = (
        codes_i16_ptr
        + pid * num_input_groups * num_codebooks
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
        + tl.arange(0, num_codebooks)[None, :]
    )
    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups

    codes_i = tl.load(
        codes_i_ptrs, mask=codes_i_mask_1d[:, None]
    )  # [in_features//in_group_size, num_codebooks]
    if codes_i.dtype == tl.int16:
        codes_i = codes_i.to(tl.int32)
        codes_i = (codes_i) + (
            codes_i < 0
        ) * codebook_size  # aka 2 ** nbits_per_codebook
        # ^-- (because codes are int16 tensors that contain uint data)

        # The following alternative does not work:
        #     codes_i = codes_i.to(tl.int32) % codebook_size # aka 2 ** nbits_per_codebook
    else:
        codes_i = codes_i.to(tl.int32)

    # shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
    codes_i += (
        tl.arange(0, num_codebooks)[None, :] * codebook_size
    )  # aka 2 ** nbits_per_codebook
    # ^-- [in_features // in_group_size, num_codebooks]

    # Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
    # [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
    weight_i_ptrs = (
        codebooks_ptr
        + codes_i[:, :, None, None] * out_group_size * in_group_size
        + out_group_ix * in_group_size
        + in_group_ix
    )

    # Stage 4: reconstruct weights, multiply by inputs and write out
    weights_i = tl.load(
        weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0
    )
    if compute_in_fp32:
        weights_i = weights_i.to(tl.float32)
        input_vec = input_vec.to(tl.float32)
    # ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
    weights_i = tl.sum(weights_i, axis=1)  # sum codebooks as per additive quantization
    # ^-- [in_features // in_group_size, out_group_size, in_group_size]

    if out_group_size == 1:
        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)  # scalar
        output_i = tl.sum(weights_i * input_vec) * scale
        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
    else:
        output_i = tl.sum(
            tl.sum(weights_i * input_vec, axis=2), axis=0
        )  # [out_group_size]
        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
        tl.store(
            output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size),
            output_i.to(input_vec.dtype),
        )


def next_power_of_2(x):
    """Return the smallest power of two that is >= x (and 1 for x == 0)."""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def aqlm_gemv_simple(
    input_vec: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    """
    Multiply a single input vector by an additively quantized weight matrix using the triton kernel.

    :param input_vec: input of shape [1, in_features]
    :param codes_i16: codes of shape [num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]; only
        codebook_size == 2 ** 16 is accepted
    :param scales: per-output-group scales of shape [num_out_groups, 1, 1, 1]
    :param compute_in_fp32: if True, the kernel casts weights and inputs to float32 before multiplying
    :return: output of shape [1, out_features], allocated with the device and dtype of codebooks
    """
    # validate the rank first so that a wrongly shaped input fails with the intended message
    # rather than an IndexError from shape[1]
    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"

    device, dtype = codebooks.device, codebooks.dtype
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    in_features = input_vec.shape[1]
    out_features = codes_i16.shape[0] * out_group_size
    num_input_groups = codes_i16.shape[1]
    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
    assert in_features % in_group_size == 0
    assert codebooks.shape[1] == 2**16

    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
    # 1D launch kernel where each block computes one output group
    grid = lambda META: (out_features // out_group_size,)
    _aqlm_gemv_simple[grid](
        input_vec,
        output_vec,
        codes_i16,
        codebooks,
        scales,
        in_features,
        out_features,
        num_codebooks,
        codebook_size,
        out_group_size,
        in_group_size,
        num_input_groups,
        # the kernel indexes input groups with tl.arange over this padded size and masks the rest
        next_power_of_2(num_input_groups),
        compute_in_fp32,
    )

    return output_vec


def aqlm_gemm_stupid(
    input: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    """
    Matrix multiplication with quantized weights, done as one aqlm_gemv_simple call per input row.

    :param input: activations of shape [..., in_features]
    :param codes_i16: codes of shape [num_out_groups, num_in_groups, num_codebooks]
    :param codebooks: [num_codebooks, codebook_size, out_group_size, in_group_size]
    :param scales: per-output-group scales of shape [num_out_groups, 1, 1, 1]
    :param compute_in_fp32: forwarded to aqlm_gemv_simple
    :return: tensor of shape [..., out_features] with the device and dtype of codebooks
    """
    original_shape = input.shape
    if 0 in original_shape[:-1]:
        # empty batch: there are no rows to multiply, and torch.cat cannot handle an empty list
        out_features = codes_i16.shape[0] * codebooks.shape[2]
        return torch.empty(
            tuple(original_shape[:-1]) + (out_features,),
            device=codebooks.device,
            dtype=codebooks.dtype,
        )

    input = input.reshape(-1, original_shape[-1])
    return torch.cat(
        [
            aqlm_gemv_simple(
                input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32
            )
            for input_vec in input
        ]
    ).reshape(original_shape[:-1] + (-1,))