File size: 5,548 Bytes
462dacf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#include "q4_matrix.cuh"
#include <vector>
#include "../util.cuh"
#include "../matrix.cuh"

using namespace std;

// Threads per block along x for make_sequential_kernel. Each thread there handles one 64-bit word,
// i.e. two adjacent uint32 columns of the packed weight matrix.
const int UNSHUF_BLOCKSIZE_X = 64;

// Block shape used when launching reconstruct_kernel
const int RECONS_THREADS_X = 64;      // Block size and thread count along columns in out, each thread converts 1 column
const int RECONS_THREADS_Y = 1;       // Block size and thread count along rows in x and out, each thread converts 8 rows

// Matrices registered through g_q4_keep_matrix(); g_q4_free_matrices() deletes every entry and empties the list
vector<Q4Matrix*> g_q4_matrices;

// Register a matrix in the global list so that a later call to
// g_q4_free_matrices() deletes it. The pointer is stored as-is; no copy is made.
void g_q4_keep_matrix(Q4Matrix* m)
{
    g_q4_matrices.insert(g_q4_matrices.end(), m);
}

// Delete every matrix previously registered with g_q4_keep_matrix() and
// leave the global list empty.
void g_q4_free_matrices()
{
    for (size_t i = 0; i < g_q4_matrices.size(); ++i)
    {
        delete g_q4_matrices[i];
    }

    g_q4_matrices.clear();
}

// Wrap an already-allocated 4-bit quantized matrix.
//
//   _height, _width  Dimensions of the matrix; height counts unpacked rows (the packed
//                    weight buffer is addressed with height / 8 rows elsewhere in this file).
//   _groups          Number of quantization groups; groupsize is derived as height / groups.
//   _qweight         Packed weights, handed to the kernels in this file as a device pointer.
//   _qzeros          Packed zero points, handed to reconstruct_kernel as a device pointer.
//   _scales          Per-group scales, handed to reconstruct_kernel as a device pointer.
//   _g_idx           Optional row -> group index table. It is dereferenced on the host by
//                    make_sequential(). When non-null, the rows of _qweight are reordered
//                    in place so that groups become contiguous.
//   _device          CUDA device made current before anything else happens.
//
// The pointers are stored, not copied.
Q4Matrix::Q4Matrix
(
    const int _height,
    const int _width,
    const int _groups,

    uint32_t* _qweight,
    uint32_t* _qzeros,
    half* _scales,
    uint32_t* _g_idx,

    const int _device
) :
    height(_height),
    width(_width),
    groups(_groups),
    device(_device)
{
    cudaSetDevice(device);

    // Rows per quantization group

    this->groupsize = this->height / this->groups;

    // Keep the caller's buffers

    this->cuda_scales = _scales;
    this->cuda_qzeros = _qzeros;
    this->cuda_qweight = _qweight;

    // Without a group index there is nothing to reorder

    if (!_g_idx) return;

    make_sequential(_g_idx);
}

// Nothing is released here: the weight, zero-point and scale buffers are supplied by the
// caller of the constructor and are not freed by this class.
// NOTE(review): the cuda_x_map buffer that make_sequential() obtains with cudaMalloc is not
// freed here either, and no other cudaFree for it is visible in this file.
Q4Matrix::~Q4Matrix() = default;

// Make sequential

// Gather rows of a packed 4-bit matrix into a new row order.
//
//   w         Source matrix, eight 4-bit rows packed per uint32, w_width uint32 per packed row.
//   w_new     Destination matrix with the same layout; must not alias w.
//   x_map     For every destination row, the index of the source row to copy from.
//   w_height  Not used by the kernel body.
//   w_width   Number of uint32 columns. The matrix is accessed as 64-bit words, so each
//             thread moves two adjacent columns and w_width >> 1 words are processed per row.
//
// Launch layout: blockDim.x is expected to equal UNSHUF_BLOCKSIZE_X (the column index is
// computed from that constant), the x dimension covers the 64-bit columns and blockIdx.y
// selects the packed destination row.
__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const uint32_t* __restrict__ x_map,
    const int w_height,
    const int w_width
)
{
    // Reinterpret both matrices as 64-bit words (two uint32 columns per word)

    const uint64_t* src_words = (uint64_t*) w;
    uint64_t* dst_words = (uint64_t*) w_new;
    const int words_per_row = w_width >> 1;

    const int word_col = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    if (word_col >= words_per_row) return;

    const int dst_row = blockIdx.y;

    // The eight map entries belonging to this packed destination row

    const uint32_t* row_map = x_map + (dst_row << 3);

    uint64_t packed = 0;

    #pragma unroll
    for (int slot = 0; slot < 8; slot++)
    {
        const int from_row = row_map[slot];

        // Packed source row, and bit offset of the wanted 4-bit value inside it

        const int from_packed_row = from_row >> 3;
        const int from_shift = (from_row & 0x07) << 2;

        uint64_t bits = src_words[from_packed_row * words_per_row + word_col];

        // Isolate one 4-bit value in each 32-bit half, i.e. one per column

        bits = (bits >> from_shift) & 0x0000000f0000000f;

        // Place it at this slot's position in the destination word

        packed |= bits << (slot << 2);
    }

    dst_words[dst_row * words_per_row + word_col] = packed;
}

// Reorder the rows of cuda_qweight so that all rows belonging to the same quantization
// group become contiguous, and record the permutation on the device in cuda_x_map.
//
//   cpu_g_idx  Host array of `height` entries; entry i is the group index of row i.
//              Every entry is used as an index into a table of `groups` counters.
//              NOTE(review): values are not range-checked here; confirm the caller
//              guarantees cpu_g_idx[i] < groups.
//
// The reordering is stable: rows of one group keep their relative order. The new weights
// are built in a temporary device buffer and then copied back over cuda_qweight. The
// function blocks until all device work has finished.
//
// NOTE(review): return codes of the CUDA runtime calls are not inspected, as in the rest
// of this file.
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
    // Compute byte counts in size_t so large matrices cannot overflow int arithmetic

    const size_t qweight_bytes = (size_t) (height / 8) * (size_t) width * sizeof(uint32_t);
    const size_t x_map_bytes = (size_t) height * sizeof(uint32_t);

    uint32_t* cuda_new_qweight = NULL;
    cudaMalloc(&cuda_new_qweight, qweight_bytes);
    cudaMalloc(&cuda_x_map, x_map_bytes);  // TODO: Should probably be allocated in PyTorch

    // Host scratch buffers; std::vector releases them on every exit path

    std::vector<uint32_t> cpu_g_idx_map((size_t) groups, 0u);
    std::vector<uint32_t> cpu_x_map((size_t) height);
    std::vector<uint32_t> cpu_x_map_inv((size_t) height);

    // Group histogram

    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;

    // Group map: exclusive prefix sum turns each count into the first target row of its group.
    // The count is kept at full width; a narrower temporary would truncate large groups.

    uint32_t acc = 0;
    for (int i = 0; i < groups; i++)
    {
        const uint32_t count = cpu_g_idx_map[i];
        cpu_g_idx_map[i] = acc;
        acc += count;
    }

    // X map (inverse): source row -> target row

    for (int row = 0; row < height; row++)
    {
        const uint32_t target_group = cpu_g_idx[row];
        cpu_x_map_inv[row] = cpu_g_idx_map[target_group]++;
    }

    // X map: target row -> source row

    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;

    // Move to CUDA

    cudaMemcpyAsync(cuda_x_map, cpu_x_map.data(), x_map_bytes, cudaMemcpyHostToDevice);

    // Rearrange rows in w. Each thread handles two uint32 columns, hence the factor 2.

    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
    dim3 blocks
    (
        (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
        height / 8,
        1
    );

    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);

    // Replace qweights

    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, qweight_bytes, cudaMemcpyDeviceToDevice);

    // Cleanup: wait for the device before the temporary and the host buffers go away

    cudaDeviceSynchronize();
    cudaFree(cuda_new_qweight);
}

// Dequantize a packed 4-bit matrix into half precision.
//
//   w          Packed weights, read through MatrixView_q4_column.
//   out        Destination, written through MatrixView_half_rw.
//   w_scales   Per-group scales, indexed by (group, column).
//   w_zeros    Packed per-group zero points, indexed by (group, column); 1 is added to the
//              stored value before it is subtracted from each weight.
//   height     Height value forwarded to the matrix views (Q4Matrix::reconstruct passes
//              its height / 8 here).
//   width      Number of columns; threads at or beyond this column exit immediately.
//   groupsize  Rows per quantization group.
//
// Launch layout: the column index is computed from RECONS_THREADS_X and the row index from
// RECONS_THREADS_Y, so the block shape is expected to match those constants. Each thread
// reads one uint32 and writes eight consecutive rows of a single output column.
__global__ void reconstruct_kernel
(
    const uint32_t* __restrict__ w,
    half* __restrict__ out,  // (y)
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int width,
    const int groupsize
)
{
    // Position of this thread: one column, first of eight rows

    const int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;
    const int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;
    if (column >= width) return;

    // Views

    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, height / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);

    // Quantization parameters of the group containing this row

    const int group = row / groupsize;

    const half w_scale = w_scales_.item(group, column);
    const uint32_t w_zero = w_zeros_.item(group, column) + 1;

    // Eight packed 4-bit weights and the first output element

    const uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);

    #pragma unroll
    for (int k = 0; k < 8; k++)
    {
        const int shift = k * 4;

        // (weight - zero) * scale; the subtraction mixes int and uint32_t exactly as before

        const half w_item = __hmul(__int2half_rn((int)((w_read >> shift) & 0x0f) - w_zero), w_scale);

        // Consecutive weights belong to consecutive rows of the same column

        out_ptr[k * out_.width] = w_item;
    }
}

// Launch reconstruct_kernel to dequantize this matrix into `out`.
//
//   out  Destination buffer handed to the kernel as its output pointer.
//
// The launch is not followed by a synchronization or an error query here.
void Q4Matrix::reconstruct(half* out)
{
    // Eight 4-bit rows are packed into every uint32 row

    const int packed_rows = height / 8;

    const dim3 block_dim(RECONS_THREADS_X, RECONS_THREADS_Y, 1);

    // Round up so partial blocks at the edges are still covered

    const dim3 grid_dim
    (
        (width + block_dim.x - 1) / block_dim.x,
        (packed_rows + block_dim.y - 1) / block_dim.y,
        1
    );

    reconstruct_kernel<<<grid_dim, block_dim>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, packed_rows, width, groupsize);
}