from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.inference_kernels.triton_kernel import aqlm_gemm_stupid as triton_gemm
from src.utils import _dequantize_weight, unpack_int_data


def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    if input.is_cuda:
        # CUDA path: multiply against the quantized weights directly with the Triton kernel.
        matmul_result = triton_gemm(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result
    else:
        # CPU fallback: materialize the dequantized weight, then run a regular linear layer.
        # The number of bits per code is log2(codebook_size); codebook_size is the second
        # dimension of `codebooks` (num_codebooks, codebook_size, ...), hence shape[1].
        dequantized_weight = _dequantize_weight(
            unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
            codebooks,
            scales,
        )
        return F.linear(input, dequantized_weight, bias)
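

# Usage sketch (illustrative addition, not part of the original module). Shapes assume the
# AQLM layout expected by `_dequantize_weight`: `codes` is (num_out_groups, num_in_groups,
# num_codebooks), `codebooks` is (num_codebooks, codebook_size, out_group_size, in_group_size),
# and `scales` broadcasts over the output groups. With group sizes of 1, the reconstructed
# weight has shape (out_features, in_features).
if __name__ == "__main__":
    out_features, in_features = 16, 32
    num_codebooks, codebook_size = 1, 256  # 8-bit codes
    codebooks = torch.randn(num_codebooks, codebook_size, 1, 1)
    codes = torch.randint(0, codebook_size, (out_features, in_features, num_codebooks), dtype=torch.int32)
    scales = torch.ones(out_features, 1, 1, 1)
    x = torch.randn(4, in_features)  # CPU tensor, so this exercises the dequantize-then-linear path
    y = forward_pass_quantized_linear(x, codes, codebooks, scales, bias=None)
    print(y.shape)  # expected: torch.Size([4, 16])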