|
#include "half_matmul.cuh" |
|
#include "../util.cuh" |
|
#include "../matrix.cuh" |
|
#include "../cuda_compat.cuh" |
|
#if defined(USE_ROCM) |
|
#include "../hip_compat.cuh" |
|
#endif |
|
|
|
|
|
|
|
// Launch geometry for half_matmul_kernel.
//   THREADS_X : threads per block along x; each thread owns two adjacent output columns.
//   THREADS_Y : threads per block along y; one output row per thread.
//   BLOCKSIZE : length of the reduction-dimension slice handled by one blockIdx.z.
constexpr int THREADS_X = 32;
constexpr int THREADS_Y = 8;
constexpr int BLOCKSIZE = 256;
|
|
|
// Split-K half-precision matmul kernel.
// Adds x[row, k0:k0+BLOCKSIZE] @ w[k0:k0+BLOCKSIZE, column:column+2] into out[row, column:column+2],
// where k0 = blockIdx.z * BLOCKSIZE. The slice is clamped to dim.
//
// Launch layout (as configured by half_matmul_cuda):
//   blockDim = (THREADS_X, THREADS_Y, 1)
//   Each thread owns two adjacent output columns, so blockIdx.x spans THREADS_X * 2 columns.
//   blockIdx.y spans THREADS_Y rows.
//   blockIdx.z selects one BLOCKSIZE-long slice of dim.
//
// Partial sums from different dim slices are combined with atomicAdd. So `out` must already hold
// the value to accumulate onto (typically zero) before launch. The summation order across slices
// is not deterministic.
//
// Preconditions (NOTE(review): inferred from the half2 reinterpret casts, confirm with callers):
//   * dim and width are even, and x/w/out are 4-byte aligned, so every half2 access is aligned.
//     With an odd dim, the last element of the final slice is not included.
//   * atomicAdd on half2 needs compute capability 6.0 or newer on CUDA.
__global__ void half_matmul_kernel
(
    const half* __restrict__ x,
    const half* __restrict__ w,
    half* __restrict__ out,
    const int height,
    const int dim,
    const int width
)
{
    const int column = (blockIdx.x * THREADS_X + threadIdx.x) * 2;
    const int row = blockIdx.y * THREADS_Y + threadIdx.y;
    const int k0 = blockIdx.z * BLOCKSIZE;

    if (row >= height) return;
    if (column >= width) return;
    if (k0 >= dim) return;

    MatrixView_half x_(x, height, dim);
    MatrixView_half w_(w, dim, width);
    MatrixView_half_rw out_(out, height, width);

    // Clamp the slice to the end of dim. The last z-slice is partial whenever dim is not a multiple
    // of BLOCKSIZE, and reading a full BLOCKSIZE there would run past the end of x's row and past the
    // last row of w.
    const int k_count = (dim - k0 < BLOCKSIZE) ? (dim - k0) : BLOCKSIZE;
    const int pair_count = k_count / 2;

    const half2* x_ptr = (const half2*) x_.item_ptr(row, k0);
    const half2* w_ptr = (const half2*) w_.item_ptr(k0, column);
    const int w_row_stride = w_.width / 2;   // one row of w, measured in half2 elements

    half2 acc = {};

    #pragma unroll 8
    for (int p = 0; p < pair_count; p++)
    {
        // Two consecutive x values (k, k+1), each broadcast to scale a pair of adjacent w columns.
        half2 x_item = *x_ptr++;
        half2 x_item_0 = __half2half2(x_item.x);
        half2 x_item_1 = __half2half2(x_item.y);

        half2 w_item_0 = *w_ptr; w_ptr += w_row_stride;
        half2 w_item_1 = *w_ptr; w_ptr += w_row_stride;

        acc = __hfma2(x_item_0, w_item_0, acc);
        acc = __hfma2(x_item_1, w_item_1, acc);
    }

    atomicAdd((half2*)out_.item_ptr(row, column), acc);
}
|
|
|
// Host launcher for half_matmul_kernel on `alt_stream`.
// Accumulates x (height x dim) @ w (dim x width) into out (height x width). The kernel indexes
// all three matrices row-major through MatrixView_half.
//
// The kernel adds into `out` with atomicAdd, so the caller must prepare `out` first (typically
// zero it). The launch is asynchronous, and no error status is checked here.
void half_matmul_cuda
(
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    cudaStream_t alt_stream
)
{
    // Empty shapes: nothing to compute, and a zero grid dimension would be an invalid launch
    // configuration. With dim == 0 the product is zero, so leaving `out` untouched is exactly
    // "out += 0".
    if (height <= 0 || dim <= 0 || width <= 0) return;

    dim3 threads(THREADS_X, THREADS_Y, 1);

    // Each thread writes two adjacent columns, so one block spans THREADS_X * 2 columns.
    // Ceil-divide by that full span. Dividing ceil(width / THREADS_X) by 2 afterwards would round
    // down and leave trailing columns without a block.
    const int columns_per_block = THREADS_X * 2;

    dim3 blocks
    (
        (width + columns_per_block - 1) / columns_per_block,
        (height + THREADS_Y - 1) / THREADS_Y,
        (dim + BLOCKSIZE - 1) / BLOCKSIZE
    );

    half_matmul_kernel<<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width);
}
|
|
|
|
|
|
|
// Largest reduction dimension (dim) that half_matmul_cublas_cuda routes to the small kernel.
// It also sizes the small kernel's shared-memory reduction buffer.
constexpr int MAX_DIM_SMALL = 8192;
|
|
|
// Computes out = x @ w, or out += x @ w when no_zero is true.
//   x is (height x dim), w is (dim x width), out is (height x width), all half and row-major.
//
// Dispatch:
//   height < 4 and dim <= MAX_DIM_SMALL -> half_matmul_small_cuda
//   otherwise                           -> cublasHgemm
// When alt_stream is non-null, the handle's stream is temporarily switched to alt_stream for the
// GEMM and then restored. cuBLAS status codes are not checked.
void half_matmul_cublas_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    cublasHandle_t handle,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    const bool few_rows_short_dim = (height < 4) && (dim <= MAX_DIM_SMALL);
    if (few_rows_short_dim)
    {
        half_matmul_small_cuda(tuningParams, x, w, out, height, dim, width, no_zero, alt_stream);
        return;
    }

    // beta == 1 keeps the existing contents of out (accumulate); beta == 0 overwrites them.
    const half one = __float2half(1.0f);
    const half zero = __float2half(0.0f);
    const half alpha = one;
    const half beta = no_zero ? one : zero;

    const bool swap_stream = (alt_stream != nullptr);
    cudaStream_t saved_stream = nullptr;

    if (swap_stream)
    {
        cublasGetStream(handle, &saved_stream);
        cublasSetStream(handle, alt_stream);
    }

    // cuBLAS is column-major. Row-major x, w and out read column-major are x^T, w^T and out^T,
    // and out^T = w^T * x^T. So w is operand A (width x dim, ld = width), x is operand B
    // (dim x height, ld = dim), and out is C (width x height, ld = width).
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, w, width, x, dim, &beta, out, width);

    if (swap_stream) cublasSetStream(handle, saved_stream);
}
|
|
|
|
|
|
|
// Launch geometry for half_matmul_small_kernel.
//   S_THREADS_X : output columns per block (threadIdx.x).
//   S_THREADS_Z : output rows per block (threadIdx.z).
//   S_BLOCKSIZE : length of the dim chunk reduced by each threadIdx.y.
// With dim == MAX_DIM_SMALL, blockDim.y = MAX_DIM_SMALL / S_BLOCKSIZE = 1024 / S_THREADS_X.
// The block then has S_THREADS_X * blockDim.y = 1024 threads.
constexpr int S_THREADS_X = 8;
constexpr int S_THREADS_Z = 1;
constexpr int S_BLOCKSIZE = (MAX_DIM_SMALL / 1024) * S_THREADS_X;
|
|
|
// Split-K kernel for very short x (routed here when height < 4, see half_matmul_cublas_cuda).
// Computes out[row, column] = x[row, :] . w[:, column], plus out[row, column] when no_zero is true.
//
// Launch layout (as configured by half_matmul_small_cuda):
//   blockDim = (S_THREADS_X, ceil(dim / S_BLOCKSIZE), 1)
//   gridDim  = (ceil(width / S_THREADS_X), 1, height)
// threadIdx.x selects the output column, blockIdx.z the output row, and threadIdx.y one
// S_BLOCKSIZE-long chunk of dim. The threadIdx.y == 0 threads combine the per-chunk partial sums
// through shared memory.
//
// Template parameters:
//   use_half2 : when true (and !odd_rank), multiply-add x in half2 pairs.
//   odd_rank  : true when dim is not a multiple of 4; uses a scalar, one-element-per-step loop.
//
// Preconditions: dim <= MAX_DIM_SMALL, so blockDim.y fits the `accum` shared array.
// All accumulation is done in half precision.
template<bool use_half2, bool odd_rank>
__global__ void half_matmul_small_kernel
(
    const half* __restrict__ x,
    const half* __restrict__ w,
    half* __restrict__ out,
    const int height,
    const int dim,
    const int width,
    bool no_zero
)
{
    int column = blockIdx.x * S_THREADS_X + threadIdx.x;
    int row = blockIdx.z * S_THREADS_Z + threadIdx.z;
    int k = threadIdx.y * S_BLOCKSIZE;

    // Every thread in the block must reach the __syncthreads() below. Out-of-range threads (the
    // tail columns when width is not a multiple of S_THREADS_X) are therefore masked instead of
    // returning early. They contribute a zero partial and never load or store.
    const bool active = (row < height) && (column < width);

    half partial = {};

    if (active)
    {
        MatrixView_half x_(x, height, dim);
        MatrixView_half w_(w, dim, width);

        int k_end = k + S_BLOCKSIZE;
        if (k_end > dim) k_end = dim;

        const half* x_ptr = x_.item_ptr(row, k);
        const half* x_ptr_end = x_.item_ptr(row, k_end);
        const half* w_ptr = w_.item_ptr(k, column);

        if constexpr (use_half2 && !odd_rank)
        {
            // dim % 4 == 0 here, so each chunk holds a whole number of 4-element steps.
            const half2* x_ptr2 = (const half2*) x_ptr;
            const half2* x_ptr2_end = (const half2*) x_ptr_end;

            half2 r = {};

            while (x_ptr2 < x_ptr2_end)
            {
                half2 x_01 = *x_ptr2++;
                half2 x_23 = *x_ptr2++;
                half w_0 = *w_ptr; w_ptr += width;
                half w_1 = *w_ptr; w_ptr += width;
                half w_2 = *w_ptr; w_ptr += width;
                half w_3 = *w_ptr; w_ptr += width;
                half2 w_01 = __halves2half2(w_0, w_1);
                half2 w_23 = __halves2half2(w_2, w_3);
                r = __hfma2(x_01, w_01, r);
                r = __hfma2(x_23, w_23, r);
            }

            partial = __hadd(r.x, r.y);
        }
        else
        {
            half r = {};

            while (x_ptr < x_ptr_end)
            {
                if constexpr (odd_rank)
                {
                    half x_item = *x_ptr++;
                    half w_item = *w_ptr; w_ptr += width;
                    r = __hfma(x_item, w_item, r);
                }
                else
                {
                    #pragma unroll
                    for (int i = 0; i < 4; ++i)
                    {
                        half x_item = *x_ptr++;
                        half w_item = *w_ptr; w_ptr += width;
                        r = __hfma(x_item, w_item, r);
                    }
                }
            }

            partial = r;
        }
    }

    // Gather one partial per dim chunk (threadIdx.y), then let the threadIdx.y == 0 thread of
    // each column add them up.
    __shared__ half accum[MAX_DIM_SMALL / S_BLOCKSIZE][S_THREADS_X];
    accum[threadIdx.y][threadIdx.x] = partial;
    __syncthreads();

    if (threadIdx.y == 0 && active)
    {
        MatrixView_half_rw out_(out, height, width);
        half* out_ptr = out_.item_ptr(row, column);

        half acc = accum[0][threadIdx.x];
        for (int i = 1; i < (int) blockDim.y; ++i) acc = __hadd(accum[i][threadIdx.x], acc);
        if (no_zero) acc = __hadd(acc, *out_ptr);
        *out_ptr = acc;
    }
}
|
|
|
// Host launcher for half_matmul_small_kernel on `alt_stream`.
// Computes out = x @ w, or out += x @ w when no_zero is true.
//
// Kernel variant selection:
//   dim not a multiple of 4          -> <use_half2 = false, odd_rank = true>  (scalar loop)
//   otherwise, half2 allowed         -> <use_half2 = true,  odd_rank = false>
//   otherwise (matmul_no_half2 set)  -> <use_half2 = false, odd_rank = false>
// NOTE(review): blockDim.y grows with dim, so callers should keep dim <= MAX_DIM_SMALL (as
// half_matmul_cublas_cuda does). The launch is asynchronous, and no error status is checked.
void half_matmul_small_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const half* w,
    half* out,
    const int height,
    const int dim,
    const int width,
    bool no_zero,
    cudaStream_t alt_stream
)
{
    const bool allow_half2 = !tuningParams->matmul_no_half2;

    // threadIdx.y enumerates S_BLOCKSIZE-long chunks of dim; blockIdx.z enumerates output rows.
    const int k_chunks = (dim + S_BLOCKSIZE - 1) / S_BLOCKSIZE;
    const int column_groups = (width + S_THREADS_X - 1) / S_THREADS_X;

    dim3 threads(S_THREADS_X, k_chunks, 1);
    dim3 blocks(column_groups, 1, height);

    const bool dim_not_multiple_of_4 = (dim & 0x03) != 0;

    if (dim_not_multiple_of_4)
    {
        half_matmul_small_kernel<false, true> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
        return;
    }

    if (allow_half2)
    {
        half_matmul_small_kernel<true, false> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
    }
    else
    {
        half_matmul_small_kernel<false, false> <<<blocks, threads, 0, alt_stream>>>(x, w, out, height, dim, width, no_zero);
    }
}
|
|
|
|