#!/usr/bin/env python3

# Test that gguf.quants exactly matches the C implementation of the (de)quantization functions

# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
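#
# Example invocation (assuming this file lives under gguf-py/tests/ in the llama.cpp
# source tree and that libggml has already been built, e.g. with CMake):
#
#   python3 gguf-py/tests/test_quants.py --libggml build/ggml/src/libggml.so
#
# Pass --quick to skip quantizing with C when it is not strictly needed for a comparison.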

from __future__ import annotations

import argparse
from math import prod
import os
import sys
from pathlib import Path
import ctypes
import logging
import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))

import gguf
from gguf.constants import GGMLQuantizationType


logger = logging.getLogger("test-quants")


c_float_p = ctypes.POINTER(ctypes.c_float)


class ggml_init_params(ctypes.Structure):
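    # Mirrors `struct ggml_init_params` from ggml.h (mem_size, mem_buffer, no_alloc).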
    _fields_ = [
        ("mem_size", ctypes.c_size_t),
        ("mem_buffer", ctypes.c_void_p),
        ("no_alloc", ctypes.c_bool),
    ]


class GGMLQuants:
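    """Thin ctypes wrapper around libggml's quantization and dequantization entry points."""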
    libggml: ctypes.CDLL

    def __init__(self, libggml: Path):
        self.libggml = ctypes.CDLL(str(libggml))
        self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
        # enum ggml_type   type,
        #    const float * src,
        #           void * dst,
        #        int64_t   start,
        #        int64_t   nrows,
        #        int64_t   n_per_row,
        #    const float * imatrix) {
        self.libggml.ggml_quantize_chunk.argtypes = (
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.POINTER(ctypes.c_float),
        )

        self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
        self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)

        for t in (
            "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
            "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
            "tq1_0", "tq2_0",
            "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
            "iq4_nl", "iq4_xs",
        ):
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
            dequant_func.restype = None
            dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_fp16_to_fp32_row.restype = None
        self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
        self.libggml.ggml_bf16_to_fp32_row.restype = None
        self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)

        self.libggml.ggml_init.argtypes = (ggml_init_params,)

        # Initialize ggml's internal state before using any of its (de)quantization functions.
        self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))

    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
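        """Dequantize a packed tensor of `qtype` back to float32 using the C library."""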
        result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
        if qtype == GGMLQuantizationType.F32:
            # no-op
            result = tensor.view(np.float32)
        elif qtype == GGMLQuantizationType.F16:
            self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        elif qtype == GGMLQuantizationType.BF16:
            self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
        else:
            lw_qname = qtype.name.lower()
            if lw_qname[-1] == "k":
                lw_qname = lw_qname[:-1] + "K"
            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
            dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
        return result

    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
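        """Quantize float32 `data` to the packed byte representation of `qtype` using the C library."""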
        result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
            # TODO: is a column-wise sum of squares appropriate?
            qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
        else:
            qw = ctypes.cast(0, c_float_p)
        result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
        assert result.size == result_size
        return result


def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
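    """Return True if the tensors match (bitwise or numerically); otherwise log
    per-block bit-difference statistics and return False."""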
    same = np.array_equal(t1, t2)
    if same:
        return True
    else:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
        if t1.dtype == np.float32:
            t1 = t1.reshape((-1, block_size))
            t2 = t2.reshape((-1, block_size))
        else:
            t1 = t1.reshape((-1, type_size))
            t2 = t2.reshape((-1, type_size))
        x = t1.view(np.uint8) ^ t2.view(np.uint8)
        diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
        num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
        if num_bad_blocks == 0 and t1.shape == t2.shape:
            logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
            return True
        logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
        bad_block_id = np.argmax(diff_bits, axis=0)
        logger.debug(f"Worst block id: {bad_block_id}")
        logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")

        sum_diff_bits = np.sum(diff_bits)
        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)")
        return False


def do_test(libggml_path: Path, quick: bool = False):
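    """Compare the Python (de)quantization in gguf.quants against the C
    implementations from libggml for every supported quantization type."""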
    ggml_quants = GGMLQuants(libggml_path)

    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)

    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
        has_dequantize = False
        has_quantize = False

        # Probe whether the local Python gguf package implements (de)quantization for
        # this type; NotImplementedError simply means the type is not supported yet.
        try:
            gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
            has_dequantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e
        try:
            gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
            has_quantize = True
        except (NotImplementedError, AssertionError) as e:
            if isinstance(e, AssertionError):
                logger.error(f"Error with {qtype.name}: {e}")
                raise e

        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")

        pyq = None
        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                pyq = pyq.view(np.uint8)
            quant_equal = compare_tensors(pyq, ggq, qtype)

            if not quant_equal:
                logger.error(f"Quantization to {qtype.name} does not match ❌")
            else:
                logger.info(f"Quantization to {qtype.name} matches exactly βœ…")

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                dequant_equal = compare_tensors(pydq, ggdq, qtype)

                if not dequant_equal:
                    logger.error(f"Dequantization from {qtype.name} does not match ❌")
                else:
                    logger.info(f"Dequantization from {qtype.name} matches exactly βœ…")

            # Also check dequantization of arbitrary (not round-tripped) bit patterns,
            # using random f16 data reinterpreted as quantized blocks.
            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            dequant_equal = compare_tensors(pydq, ggdq, qtype)

            if not dequant_equal:
                logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
            else:
                logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly βœ…")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)

    do_test(args.libggml, args.quick)