File size: 858 Bytes
5c0d7ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.inference_kernels.triton_kernel import aqlm_gemm_stupid as triton_gemm
from src.utils import _dequantize_weight, unpack_int_data
def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Apply a quantized linear layer to ``input``.

    Two execution paths are used, chosen by where ``input`` lives:

    * CUDA input: ``input``, ``codes``, ``codebooks`` and ``scales`` are passed
      straight to the Triton kernel (``triton_gemm``); if ``bias`` is given it
      is added to the kernel's result in place.
    * Any other input: ``codes`` are unpacked with ``unpack_int_data``, a dense
      weight is rebuilt by ``_dequantize_weight`` and the result is computed
      with ``torch.nn.functional.linear``.

    Args:
        input: Activations fed to the layer.
        codes: Packed integer codes of the quantized weight.
        codebooks: Codebooks used to decode ``codes``.
        scales: Scales forwarded to the kernel / dequantization helper.
        bias: Optional bias; ``None`` disables the bias term.

    Returns:
        The output of the linear layer.
    """
    if not input.is_cuda:
        # For a positive size n, ``n.bit_length() - 1`` is floor(log2(n)).
        # NOTE(review): the bit width is taken from codebooks.shape[0];
        # confirm that this axis is the codebook size expected by
        # unpack_int_data and not the number of codebooks.
        code_bits = codebooks.shape[0].bit_length() - 1
        unpacked_codes = unpack_int_data(codes, code_bits)
        weight = _dequantize_weight(unpacked_codes, codebooks, scales)
        return F.linear(input, weight, bias)

    output = triton_gemm(input, codes, codebooks, scales)
    if bias is not None:
        # In-place add on the tensor returned by the kernel.
        output += bias
    return output
|