from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.inference_kernels.triton_kernel import aqlm_gemm_stupid as triton_gemm
from src.utils import _dequantize_weight, unpack_int_data

|
def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    if input.is_cuda:
        # GPU path: the Triton kernel consumes the quantized representation
        # (codes, codebooks, scales) directly; bias is added afterwards.
        matmul_result = triton_gemm(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result
    else:
        # CPU path: reconstruct the dense weight, then apply a regular linear layer.
        # The per-code bit width is log2 of the codebook size, i.e. codebooks.shape[1].
        dequantized_weight = _dequantize_weight(
            unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
            codebooks,
            scales,
        )
        return F.linear(input, dequantized_weight, bias)
|
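
# A minimal CPU-path sanity check (a sketch, not part of this module's API).
# It assumes the usual AQLM tensor layout: codes of shape
# [num_out_groups, num_in_groups, num_codebooks], codebooks of shape
# [num_codebooks, codebook_size, out_group_size, in_group_size], and scales
# broadcastable to [num_out_groups, 1, 1, 1]. All concrete shapes below are
# illustrative assumptions.
if __name__ == "__main__":
    out_features, in_features = 16, 32     # assumed toy dimensions
    codebook_size, in_group_size = 256, 8  # 8-bit codes over 8-wide input groups
    codes = torch.randint(
        codebook_size,
        (out_features, in_features // in_group_size, 1),
        dtype=torch.int32,
    )
    codebooks = torch.randn(1, codebook_size, 1, in_group_size)
    scales = torch.ones(out_features, 1, 1, 1)
    x = torch.randn(3, in_features)
    y = forward_pass_quantized_linear(x, codes, codebooks, scales, bias=None)
    print(y.shape)  # expected: torch.Size([3, 16])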