import torch
import triton
import triton.language as tl
from torch.autograd import Function
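
# The autotuner only sweeps pipeline depth (num_stages) and num_warps; there are no
# tunable tile sizes, so the config carries a dummy "UNUSED" meta-parameter because
# triton.Config expects a dict. The best launch configuration is cached per the
# problem-shape/flag key list below.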
@triton.autotune(
configs=[
triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
for num_stages in (1, 2, 3, 4, 5)
for num_warps in (1, 2, 4, 8)
],
key=[
"in_features",
"out_features",
"num_codebooks",
"codebook_size",
"out_group_size",
"in_group_size",
"num_input_groups",
"num_input_groups_next_power_of_2",
"compute_in_fp32",
],
)
@triton.jit
def _aqlm_gemv_simple(
input_vec_ptr,
output_vec_ptr,
codes_i16_ptr,
codebooks_ptr,
scales_ptr,
in_features: tl.constexpr,
out_features: tl.constexpr,
num_codebooks: tl.constexpr,
codebook_size: tl.constexpr,
out_group_size: tl.constexpr,
in_group_size: tl.constexpr,
num_input_groups: tl.constexpr,
num_input_groups_next_power_of_2: tl.constexpr,
compute_in_fp32: tl.constexpr,
UNUSED: tl.constexpr,
):
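    """
    Decode and apply one output group of an AQLM-quantized linear layer.

    Each program (pid) handles `out_group_size` output units: it gathers the int16
    codes for that row, looks up the matching codebook entries, sums them across
    codebooks (additive quantization), takes the dot product with the input vector,
    applies the per-output-group scale, and stores the result.
    `num_input_groups_next_power_of_2` pads tl.arange to a power of two, as Triton
    requires; the excess lanes are masked out.
    """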
# variables ending with "_i" mean "for i-th output unit"
pid = tl.program_id(axis=0) # [0, 1, ... {out_features-1}]
# Stage 1: load input data
input_vec = tl.load(
input_vec_ptr
+ tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
+ tl.arange(0, in_group_size)[None, None, :],
mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
)
    # [in_features//in_group_size, 1, in_group_size]
# Note: we could simply load input_vec then reshape
# input_vec = tl.load(input_vec_ptr + tl.arange(0, in_features)) # [in_features]
# input_vec = tl.view(input_vec, [num_input_groups, 1, in_group_size])
# , but this does not work because tl.view may reorder elements arbitrarily; see its docstring
# Stage 2: load integer codes for the active row
# [in_features // in_group_size, num_codebooks]
codes_i_ptrs = (
codes_i16_ptr
+ pid * num_input_groups * num_codebooks
+ tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
+ tl.arange(0, num_codebooks)[None, :]
)
codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups
codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None]) # [in_features//in_group_size, num_codebooks]
if codes_i.dtype == tl.int16:
codes_i = codes_i.to(tl.int32)
codes_i = (codes_i) + (codes_i < 0) * codebook_size # aka 2 ** nbits_per_codebook
# ^-- (because codes are int16 tensors that contain uint data)
# The following alternative does not work:
# codes_i = codes_i.to(tl.int32) % codebook_size # aka 2 ** nbits_per_codebook
else:
codes_i = codes_i.to(tl.int32)
# shift codes_i so that codebooks after 0th point to correct indices in codebooks_ptr
codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size # aka 2 ** nbits_per_codebook
    # ^-- [num_input_groups, num_codebooks]
# Stage 3: convert codes to pointers to every individual (activated) weight in codebooks
# [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
weight_i_ptrs = (
codebooks_ptr
+ codes_i[:, :, None, None] * out_group_size * in_group_size
+ out_group_ix * in_group_size
+ in_group_ix
)
# Stage 4: reconstruct weights, multiply by inputs and write out
weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
if compute_in_fp32:
weights_i = weights_i.to(tl.float32)
input_vec = input_vec.to(tl.float32)
# ^-- [in_features // in_group_size, num_codebooks, out_group_size, in_group_size]
weights_i = tl.sum(weights_i, axis=1) # sum codebooks as per additive quantization
# ^-- [in_features // in_group_size, out_group_size, in_group_size]
if out_group_size == 1:
scale = tl.load(scales_ptr + pid).to(weights_i.dtype) # scalar
output_i = tl.sum(weights_i * input_vec) * scale
tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
else:
output_i = tl.sum(tl.sum(weights_i * input_vec, axis=2), axis=0) # [out_group_size]
output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))
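

# Triton's tl.arange requires power-of-2 extents, so the number of input groups is
# rounded up with the helper below and the excess lanes are masked inside the kernel.
# e.g. next_power_of_2(13) == 16, next_power_of_2(16) == 16.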
def next_power_of_2(x):
return 1 if x == 0 else 2 ** (x - 1).bit_length()


def aqlm_gemv_simple(
input_vec: torch.Tensor,
codes_i16: torch.ShortTensor,
codebooks: torch.Tensor,
scales: torch.Tensor,
compute_in_fp32: bool = True,
):
device, dtype = codebooks.device, codebooks.dtype
num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
in_features = input_vec.shape[1]
out_features = codes_i16.shape[0] * out_group_size
num_input_groups = codes_i16.shape[1]
assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "do reshape; now!"
assert scales.shape == (out_features // out_group_size, 1, 1, 1)
assert in_features % in_group_size == 0
    assert codebooks.shape[1] == 2**16  # codes are int16 reinterpreted as uint16, so exactly 65536 entries per codebook
output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
    # 1D launch grid where each program computes one output group (out_group_size output units)
grid = lambda META: (out_features // out_group_size,)
_aqlm_gemv_simple[grid](
input_vec,
output_vec,
codes_i16,
codebooks,
scales,
in_features,
out_features,
num_codebooks,
codebook_size,
out_group_size,
in_group_size,
num_input_groups,
next_power_of_2(num_input_groups),
compute_in_fp32,
)
return output_vec
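

# Naive batched matmul: run the single-row GEMV kernel once per input row and
# concatenate the results. Simple and correct, but launches one kernel per row.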
def aqlm_gemm_stupid(
input: torch.Tensor,
codes_i16: torch.ShortTensor,
codebooks: torch.Tensor,
scales: torch.Tensor,
compute_in_fp32: bool = True,
):
original_shape = input.shape
input = input.reshape(-1, original_shape[-1])
return torch.cat(
[aqlm_gemv_simple(input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32) for input_vec in input]
).reshape(original_shape[:-1] + (-1,))
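

if __name__ == "__main__":
    # Minimal smoke test with made-up shapes (illustrative only, not part of the
    # original kernel); assumes a CUDA device with a working Triton install.
    torch.manual_seed(0)
    num_codebooks, codebook_size, out_group_size, in_group_size = 2, 2**16, 1, 8
    in_features, out_features = 128, 64
    device, dtype = "cuda", torch.float16

    codebooks = torch.randn(
        num_codebooks, codebook_size, out_group_size, in_group_size, device=device, dtype=dtype
    )
    codes_i16 = torch.randint(
        -(2**15), 2**15, (out_features // out_group_size, in_features // in_group_size, num_codebooks), device=device
    ).to(torch.int16)
    scales = torch.randn(out_features // out_group_size, 1, 1, 1, device=device, dtype=dtype)
    input_ = torch.randn(3, in_features, device=device, dtype=dtype)

    out = aqlm_gemm_stupid(input_, codes_i16, codebooks, scales)

    # Reference: dequantize the weight explicitly in PyTorch and compare with a dense matmul.
    codes = codes_i16.to(torch.int64) % codebook_size  # int16 bits -> unsigned indices
    cb_ix = torch.arange(num_codebooks, device=device).view(1, 1, -1)
    groups = codebooks[cb_ix, codes].sum(dim=2)  # [out_groups, in_groups, out_group_size, in_group_size]
    weight = (groups * scales).permute(0, 2, 1, 3).reshape(out_features, in_features)
    ref = input_.float() @ weight.float().t()
    print("max abs error:", (out.float() - ref).abs().max().item())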