import torch
import triton
import triton.language as tl
from torch.autograd import Function


@triton.autotune(
    configs=[
        triton.Config({"UNUSED": 1}, num_stages=num_stages, num_warps=num_warps)
        for num_stages in (1, 2, 3, 4, 5)
        for num_warps in (1, 2, 4, 8)
    ],
    key=[
        "in_features",
        "out_features",
        "num_codebooks",
        "codebook_size",
        "out_group_size",
        "in_group_size",
        "num_input_groups",
        "num_input_groups_next_power_of_2",
        "compute_in_fp32",
    ],
)
@triton.jit
def _aqlm_gemv_simple(
    input_vec_ptr,
    output_vec_ptr,
    codes_i16_ptr,
    codebooks_ptr,
    scales_ptr,
    in_features: tl.constexpr,
    out_features: tl.constexpr,
    num_codebooks: tl.constexpr,
    codebook_size: tl.constexpr,
    out_group_size: tl.constexpr,
    in_group_size: tl.constexpr,
    num_input_groups: tl.constexpr,
    num_input_groups_next_power_of_2: tl.constexpr,
    compute_in_fp32: tl.constexpr,
    UNUSED: tl.constexpr,
):
    # Each program instance computes one group of out_group_size output features.
    pid = tl.program_id(axis=0)

    # Load the input vector as [num_input_groups, 1, in_group_size]. The leading
    # dim is padded up to a power of two because tl.arange requires one; padded
    # lanes are masked out and filled with 0 so they cannot poison the sums below.
    input_vec = tl.load(
        input_vec_ptr
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] * in_group_size
        + tl.arange(0, in_group_size)[None, None, :],
        mask=tl.arange(0, num_input_groups_next_power_of_2)[:, None, None] < num_input_groups,
        other=0,
    )

    # Pointers to this output group's codes: [num_input_groups, num_codebooks]
    # int16 indices into the codebooks.
    codes_i_ptrs = (
        codes_i16_ptr
        + pid * num_input_groups * num_codebooks
        + tl.arange(0, num_input_groups_next_power_of_2)[:, None] * num_codebooks
        + tl.arange(0, num_codebooks)[None, :]
    )
    codes_i_mask_1d = tl.arange(0, num_input_groups_next_power_of_2) < num_input_groups

    codes_i = tl.load(codes_i_ptrs, mask=codes_i_mask_1d[:, None])
    if codes_i.dtype == tl.int16:
        # Codes are int16 tensors that actually hold uint16 data: map negative
        # values back into [0, codebook_size) by adding codebook_size (2 ** 16).
        codes_i = codes_i.to(tl.int32)
        codes_i += (codes_i < 0) * codebook_size
    else:
        codes_i = codes_i.to(tl.int32)

    # Shift each codebook's codes so that codebook k indexes into rows
    # [k * codebook_size, (k + 1) * codebook_size) of the flattened codebooks.
    codes_i += tl.arange(0, num_codebooks)[None, :] * codebook_size

    # Turn codes into pointers to every selected weight inside the codebooks:
    # [num_input_groups, num_codebooks, out_group_size, in_group_size]
    out_group_ix = tl.arange(0, out_group_size)[None, None, :, None]
    in_group_ix = tl.arange(0, in_group_size)[None, None, None, :]
    weight_i_ptrs = (
        codebooks_ptr
        + codes_i[:, :, None, None] * out_group_size * in_group_size
        + out_group_ix * in_group_size
        + in_group_ix
    )

    # Gather the selected codebook entries; padded input groups read as 0.
    weights_i = tl.load(weight_i_ptrs, mask=codes_i_mask_1d[:, None, None, None], other=0)
    if compute_in_fp32:
        weights_i = weights_i.to(tl.float32)
        input_vec = input_vec.to(tl.float32)

    # Additive quantization: the reconstructed weight is the sum over codebooks.
    # Result: [num_input_groups, out_group_size, in_group_size]
    weights_i = tl.sum(weights_i, axis=1)

    if out_group_size == 1:
        scale = tl.load(scales_ptr + pid).to(weights_i.dtype)
        output_i = tl.sum(weights_i * input_vec) * scale
        tl.store(output_vec_ptr + pid, output_i.to(input_vec.dtype))
    else:
        # Reduce over in_group_size (axis=2), then over input groups (axis=0),
        # leaving one dot product per unit of the output group.
        output_i = tl.sum(tl.sum(weights_i * input_vec, axis=2), axis=0)
        output_i *= tl.load(scales_ptr + pid).to(weights_i.dtype)
        tl.store(output_vec_ptr + pid * out_group_size + tl.arange(0, out_group_size), output_i.to(input_vec.dtype))


def next_power_of_2(x):
    """Return the smallest power of two that is >= x (1 for x == 0)."""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()
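

# The function below is a minimal pure-PyTorch sketch of what _aqlm_gemv_simple
# computes; it is not part of the original module, and the helper name
# _dequantize_weight_reference is ours. It is useful only as a slow but obvious
# baseline when testing the kernel.
def _dequantize_weight_reference(
    codes_i16: torch.Tensor,  # [num_out_groups, num_input_groups, num_codebooks], int16 holding uint16
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    num_out_groups, num_input_groups, _ = codes_i16.shape
    codes = codes_i16.to(torch.int64)
    codes = codes + (codes < 0).to(torch.int64) * codebook_size  # int16 -> uint16 semantics
    # Additive quantization: sum the selected entries over all codebooks
    weights = torch.zeros(
        num_out_groups, num_input_groups, out_group_size, in_group_size,
        device=codebooks.device, dtype=codebooks.dtype,
    )
    for k in range(num_codebooks):
        weights += codebooks[k][codes[..., k]]
    weights = weights * scales  # one scale per output group, broadcast over the rest
    # Rearrange groups into a dense [out_features, in_features] weight matrix
    return weights.permute(0, 2, 1, 3).reshape(
        num_out_groups * out_group_size, num_input_groups * in_group_size
    )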


def aqlm_gemv_simple(
    input_vec: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    device, dtype = codebooks.device, codebooks.dtype
    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
    in_features = input_vec.shape[1]
    out_features = codes_i16.shape[0] * out_group_size
    num_input_groups = codes_i16.shape[1]
    assert input_vec.ndim == 2 and input_vec.shape[0] == 1, "input_vec must be reshaped to (1, in_features)"
    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
    assert in_features % in_group_size == 0
    # storing codes as int16 implies exactly 2 ** 16 entries per codebook
    assert codebooks.shape[1] == 2**16

    output_vec = torch.empty(1, out_features, device=device, dtype=dtype)
    # 1D launch grid: one program per group of out_group_size output units
    grid = lambda META: (out_features // out_group_size,)
    _aqlm_gemv_simple[grid](
        input_vec,
        output_vec,
        codes_i16,
        codebooks,
        scales,
        in_features,
        out_features,
        num_codebooks,
        codebook_size,
        out_group_size,
        in_group_size,
        num_input_groups,
        next_power_of_2(num_input_groups),
        compute_in_fp32,
    )
    return output_vec
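

# A quick sanity-check sketch (illustrative, not from the original module): the
# kernel should agree with a dense matvec against the reference dequantized
# weight from _dequantize_weight_reference above, up to low-precision rounding:
#
#   weight = _dequantize_weight_reference(codes_i16, codebooks, scales)
#   torch.testing.assert_close(
#       aqlm_gemv_simple(input_vec, codes_i16, codebooks, scales),
#       input_vec @ weight.T, rtol=1e-2, atol=1e-2,
#   )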


def aqlm_gemm_stupid(
    input: torch.Tensor,
    codes_i16: torch.ShortTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    compute_in_fp32: bool = True,
):
    """Naive matrix version of aqlm_gemv_simple: one gemv launch per input row."""
    original_shape = input.shape
    input = input.reshape(-1, original_shape[-1])
    return torch.cat(
        [aqlm_gemv_simple(input_vec.unsqueeze(0), codes_i16, codebooks, scales, compute_in_fp32) for input_vec in input]
    ).reshape(original_shape[:-1] + (-1,))
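

# A minimal usage sketch (not part of the original module); the shapes below are
# illustrative assumptions, and running it requires a CUDA device with Triton.
if __name__ == "__main__":
    torch.manual_seed(0)
    in_features, out_features = 64, 16
    num_codebooks, codebook_size, out_group_size, in_group_size = 1, 2**16, 1, 8
    codebooks = torch.randn(
        num_codebooks, codebook_size, out_group_size, in_group_size,
        device="cuda", dtype=torch.float16,
    )
    # uint16 codes stored in an int16 container, as the kernel expects
    codes_i16 = torch.randint(
        0, 2**16,
        (out_features // out_group_size, in_features // in_group_size, num_codebooks),
        device="cuda", dtype=torch.int32,
    ).to(torch.int16)
    scales = torch.ones(out_features // out_group_size, 1, 1, 1, device="cuda", dtype=torch.float16)
    x = torch.randn(1, in_features, device="cuda", dtype=torch.float16)
    print(aqlm_gemv_simple(x, codes_i16, codebooks, scales).shape)  # torch.Size([1, 16])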