from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.inference_kernels.triton_kernel import aqlm_gemm_stupid as triton_gemm
from src.utils import _dequantize_weight, unpack_int_data


def forward_pass_quantized_linear(
    input: torch.Tensor,
    codes: torch.IntTensor,
    codebooks: torch.Tensor,
    scales: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
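    """Forward pass of a linear layer with AQLM-quantized weights.

    On CUDA, multiplies the input against the packed codes with the fused
    Triton kernel; elsewhere, falls back to materializing the dequantized
    weight matrix and calling F.linear.
    """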
    if input.is_cuda:
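        # Fast path: fused Triton GEMM that dequantizes the codes on the fly.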
        matmul_result = triton_gemm(input, codes, codebooks, scales)
        if bias is not None:
            matmul_result += bias
        return matmul_result
    else:
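        # Fallback path: unpack the integer codes, reconstruct the dense
        # weight matrix, and run a regular linear layer.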
        dequantized_weight = _dequantize_weight(
            # codebooks has shape [num_codebooks, codebook_size, ...], so the
            # bit length of the codebook size minus one is the bits per code.
            unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
            codebooks,
            scales,
        )
        return F.linear(input, dequantized_weight, bias)
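

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a self-contained toy
# version of the CPU fallback above, restricted to a single codebook and no
# output grouping. The real unpack_int_data / _dequantize_weight pair in
# src.utils handles the full AQLM layout (multiple additive codebooks plus
# out/in group sizes); this only demonstrates the dequantize-then-F.linear
# pattern the fallback branch relies on. All names and shapes below are
# hypothetical.
def _toy_dequantized_linear(
    input: torch.Tensor,  # [batch, in_features]
    codes: torch.LongTensor,  # [out_features, in_features // group_size]
    codebook: torch.Tensor,  # [codebook_size, group_size]
    scales: torch.Tensor,  # [out_features, 1]
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    out_features = codes.shape[0]
    # Look up each code's centroid, stitch the groups back into full weight
    # rows, then rescale each output channel.
    weight = codebook[codes].reshape(out_features, -1) * scales
    return F.linear(input, weight, bias)


# Example: 64 output rows built from a 16-entry codebook of groups of 8
# (i.e. 256 input features):
#   x = torch.randn(2, 256)
#   codes = torch.randint(16, (64, 32))
#   out = _toy_dequantized_linear(x, codes, torch.randn(16, 8), torch.ones(64, 1))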