#include "rope.cuh"
#include "../util.cuh"
#include "../matrix.cuh"

// Thread-block shape: X spans columns of a row, Y spans rows.
const int THREADS_X = 32;
const int THREADS_Y = 4;
const int MAX_POS_EMBEDDINGS = 32768;   // Actual number doesn't matter

// Signature shared by every instantiation of rope_cuda_kernel, so the host can
// select one at runtime through a plain function pointer.
typedef void (*fp_rope_cuda_kernel)
(
    half*,
    const half*,
    const half*,
    int,
    int,
    int,
    int
);

// Applies a rotary position embedding in place to the rows of x.
//
// Expected launch layout (see rope_cuda):
//   blockDim = (THREADS_X, THREADS_Y, 1)
//   grid.x   covers the first half of each row (head_dim / 2 columns); with
//            use_half2 every thread handles two adjacent columns
//   grid.y   covers rows_per_batch rows
//   grid.z   is the batch index
//
// For each handled column c in [0, head_dim / 2), with l = x[row][c] and
// r = x[row][c + head_dim / 2], the kernel writes
//   x[row][c]                = l * cos[p][c]                - r * sin[p][c]
//   x[row][c + head_dim / 2] = r * cos[p][c + head_dim / 2] + l * sin[p][c + head_dim / 2]
// where p = past_len + row / num_heads (integer division).
//
// Preconditions: num_heads must be non-zero. head_dim is assumed to be even.
// NOTE(review): the half2 path presumably also needs head_dim / 2 to be even so
// that the half2 accesses at column + half_dim stay aligned -- confirm against
// the MatrixView helpers.
template <bool use_half2>
__global__ void rope_cuda_kernel
(
    half* __restrict__ x,
    const half* __restrict__ sin,
    const half* __restrict__ cos,
    int rows_per_batch,
    int head_dim,
    int num_heads,
    int past_len
)
{
    // These heights aren't used so it's okay if they're wrong.
    MatrixView_half_rw x_(x, rows_per_batch, head_dim);
    MatrixView_half sin_(sin, MAX_POS_EMBEDDINGS, head_dim);
    MatrixView_half cos_(cos, MAX_POS_EMBEDDINGS, head_dim);

    // Each thread owns one column (or one pair of columns when use_half2) in
    // the first half of the row, plus the matching column(s) in the second half.
    int column = (blockIdx.x * THREADS_X + threadIdx.x);
    if constexpr (use_half2) column *= 2;
    int half_dim = head_dim / 2;
    if (column >= half_dim) return;

    int row = blockIdx.y * THREADS_Y + threadIdx.y;
    if (row >= rows_per_batch) return;
    int batch_offset = blockIdx.z * rows_per_batch;
    int row_offset = batch_offset + row;

    // Get sin and cos

    int sincos_row = past_len + row / num_heads;

    if constexpr (use_half2)
    {
        half2 cos2_l = cos_.item_half2(sincos_row, column);
        half2 cos2_r = cos_.item_half2(sincos_row, column + half_dim);
        half2 sin2_l = sin_.item_half2(sincos_row, column);
        half2 sin2_r = sin_.item_half2(sincos_row, column + half_dim);
        sin2_l = __hneg2(sin2_l);   // the left half receives -sin * right

        // Apply embedding to row

        half2 item2_l = x_.item_half2(row_offset, column);
        half2 item2_r = x_.item_half2(row_offset, column + half_dim);
        half2 item2_ls = __hmul2(item2_r, sin2_l);
        half2 item2_rs = __hmul2(item2_l, sin2_r);
        item2_l = __hfma2(item2_l, cos2_l, item2_ls);
        item2_r = __hfma2(item2_r, cos2_r, item2_rs);
        x_.set_half2(row_offset, column, item2_l);
        x_.set_half2(row_offset, column + half_dim, item2_r);
    }
    else
    {
        half cos_l = cos_.item(sincos_row, column);
        half cos_r = cos_.item(sincos_row, column + half_dim);
        half sin_l = sin_.item(sincos_row, column);
        half sin_r = sin_.item(sincos_row, column + half_dim);
        sin_l = __hneg(sin_l);      // the left half receives -sin * right

        // Apply embedding to row

        half item_l = x_.item(row_offset, column);
        half item_r = x_.item(row_offset, column + half_dim);
        half item_ls = __hmul(item_r, sin_l);
        half item_rs = __hmul(item_l, sin_r);
        item_l = __hfma(item_l, cos_l, item_ls);
        item_r = __hfma(item_r, cos_r, item_rs);
        x_.set(row_offset, column, item_l);
        x_.set(row_offset, column + half_dim, item_r);
    }
}

// Single source of truth for whether the half2 kernel variant is used. Kernel
// selection and grid sizing must agree on this: launching the scalar kernel
// with a grid sized for the half2 variant would leave half the columns
// untouched.
static inline bool rope_use_half2(const ExLlamaTuning* tuningParams)
{
    return !tuningParams->rope_no_half2;
}

// Returns the kernel instantiation (<bool use_half2>) matching the tuning flags.
fp_rope_cuda_kernel rope_cuda_kernel_pick(ExLlamaTuning* tuningParams)
{
    if (rope_use_half2(tuningParams))
    {
        return rope_cuda_kernel<true>;
    }
    else
    {
        return rope_cuda_kernel<false>;
    }
}

// Launches the rotary embedding kernel on alt_stream, modifying x in place.
//
//   tuningParams    supplies rope_no_half2, which selects the scalar or half2 kernel
//   x               device pointer to the rows to rotate (bsz * rows_per_batch rows
//                   of head_dim halves, as indexed by the kernel)
//   sin, cos        device pointers to the sin/cos tables, indexed by
//                   (past_len + row / num_heads, column)
//   bsz             batch size (grid.z)
//   rows_per_batch  rows per batch entry (covered by grid.y)
//   head_dim        row width; assumed even
//   num_heads       divisor mapping a row to its position; must be non-zero
//   past_len        offset added to the position index
//   alt_stream      stream the kernel is enqueued on
//
// The launch is asynchronous and no error check is performed here.
void rope_cuda
(
    ExLlamaTuning* tuningParams,
    half* x,
    const half* sin,
    const half* cos,
    const int bsz,
    const int rows_per_batch,
    const int head_dim,
    const int num_heads,
    const int past_len,
    cudaStream_t alt_stream
)
{
    const int half_dim = head_dim / 2;

    // Nothing to rotate; also avoids launching with a zero-sized grid dimension.
    if (bsz <= 0 || rows_per_batch <= 0 || half_dim <= 0) return;

    // Each thread covers one column of the first half of a row, or two adjacent
    // columns on the half2 path. Round up at every step so that small head
    // sizes still get at least one block.
    const int columns_per_thread = rope_use_half2(tuningParams) ? 2 : 1;
    const int threads_needed_x = (half_dim + columns_per_thread - 1) / columns_per_thread;

    dim3 threads(THREADS_X, THREADS_Y, 1);
    dim3 blocks
    (
        (threads_needed_x + THREADS_X - 1) / THREADS_X,
        (rows_per_batch + THREADS_Y - 1) / THREADS_Y,
        int(bsz)
    );

    fp_rope_cuda_kernel kernel = rope_cuda_kernel_pick(tuningParams);
    kernel<<<blocks, threads, 0, alt_stream>>>(x, sin, cos, rows_per_batch, head_dim, num_heads, past_len);
}