#include "q4_mlp.cuh"
#include "q4_matmul.cuh"
#include "rope.cuh"
#include "rms_norm.cuh"
#include "half_matmul.cuh"
#include "../cuda_buffers.cuh"
#include "../util.cuh"
#include "../matrix.cuh"
#if defined(USE_ROCM)
#include "../hip_compat.cuh"
#endif

// Thread-block shape and per-thread work for update_cache_kernel
const int THREADS_X = 32;
const int THREADS_Y = 1;
const int THREADS_Z = 4;
const int BLOCKSIZE_X = 2;  // halves copied per thread along head_dim (2*half == 1*uint32_t)
const int BLOCKSIZE_Z = 4;  // kv heads handled per thread; the kernel bounds-checks every head,
                            // so num_kv_heads does not need to be a multiple of this

// Copies the key/value states of q_len new tokens into the key/value caches at sequence
// positions [past_len, past_len + q_len).
//
// Layouts (implied by the offset arithmetic below, row-major):
//   key_states / value_states : (q_len, num_kv_heads, head_dim)
//   key_cache  / value_cache  : (num_kv_heads, max_seq_len, head_dim)
//
// Launch: block = (THREADS_X, THREADS_Y, THREADS_Z). The grid covers head_dim in steps of
// THREADS_X * BLOCKSIZE_X, q_len in steps of THREADS_Y, and num_kv_heads in steps of
// THREADS_Z * BLOCKSIZE_Z. Each thread copies one uint32_t (two halves) for up to BLOCKSIZE_Z
// consecutive kv heads. No shared memory is used.
//
// Preconditions: head_dim is even and all four base pointers are 4-byte aligned, since the
// copy is done through uint32_t.
__global__ void update_cache_kernel
(
    const half* __restrict__ key_states,
    const half* __restrict__ value_states,
    half* __restrict__ key_cache,
    half* __restrict__ value_cache,
    const int head_dim,
    const int num_kv_heads,
    const int q_len,
    const int max_seq_len,
    const int past_len
)
{
    const int x = (blockIdx.x * THREADS_X + threadIdx.x) * BLOCKSIZE_X;  // first of BLOCKSIZE_X elements within head_dim
    const int y = blockIdx.y * THREADS_Y + threadIdx.y;                  // token index within q_len
    const int z = (blockIdx.z * THREADS_Z + threadIdx.z) * BLOCKSIZE_Z;  // first of up to BLOCKSIZE_Z kv heads

    if (x >= head_dim) return;
    if (y >= q_len) return;
    if (z >= num_kv_heads) return;

    // Offsets (in halves) of this thread's first head in the states and in the cache.
    // Token y of this call is stored at cache sequence position past_len + y.
    const int state_offset = y * (num_kv_heads * head_dim) + z * head_dim + x;
    const int cache_offset = z * (max_seq_len * head_dim) + (past_len + y) * head_dim + x;

    // Distance between consecutive heads, measured in uint32_t (pairs of halves)
    const int state_head_step = head_dim / BLOCKSIZE_X;
    const int cache_head_step = (max_seq_len * head_dim) / BLOCKSIZE_X;

    const uint32_t* key_src = reinterpret_cast<const uint32_t*>(key_states + state_offset);
    const uint32_t* value_src = reinterpret_cast<const uint32_t*>(value_states + state_offset);
    uint32_t* key_dst = reinterpret_cast<uint32_t*>(key_cache + cache_offset);
    uint32_t* value_dst = reinterpret_cast<uint32_t*>(value_cache + cache_offset);

    #pragma unroll
    for (int k = 0; k < BLOCKSIZE_Z; k++)
    {
        // The final group of heads is partial when num_kv_heads % BLOCKSIZE_Z != 0
        if (z + k >= num_kv_heads) break;

        key_dst[k * cache_head_step] = key_src[k * state_head_step];
        value_dst[k * cache_head_step] = value_src[k * state_head_step];
    }
}

// Computes out = (x @ a) @ b for one LoRA adapter.
// Shapes: x is (rows, in_dim), a is (in_dim, rank), b is (rank, out_dim).
// lora_temp receives the (rows, rank) intermediate, so it must hold at least rows * rank halves.
static void q4_attn_lora_project
(
    ExLlamaTuning* tuningParams,
    cublasHandle_t handle,
    half* x,
    const half* a,
    const half* b,
    half* lora_temp,
    half* out,
    const int rows,
    const int in_dim,
    const int rank,
    const int out_dim
)
{
    half_matmul_cublas_cuda(tuningParams, x, a, lora_temp, rows, in_dim, rank, handle);
    half_matmul_cublas_cuda(tuningParams, lora_temp, b, out, rows, rank, out_dim, handle);
}

// First half of the attention block. Steps:
//   1. RMS-normalize x into a scratch slice of buffers->temp_state.
//   2. Project the result to query/key/value states with Q4 matmuls. When a LoRA adapter is
//      given, its output is written first and the Q4 matmul is told (via its bool flag) to
//      build on it rather than overwrite.
//   3. Apply rotary embeddings to q and k.
//   4. Append k and v to the caches at position past_len.
//
// Output layouts:
//   query_states          : q_len rows of num_heads * head_dim halves
//   key/value_states      : q_len rows of num_kv_heads * head_dim halves
//   key_cache/value_cache : (num_kv_heads, max_seq_len, head_dim)
//
// Other parameters:
//   q_a/q_b, k_a/k_b, v_a/v_b : optional LoRA factors; null disables the adapter.
//   lora_temp                 : scratch for the LoRA intermediate, presumably sized for
//                               q_len * max(rank) halves (TODO confirm with caller).
//   stream                    : NOTE(review) currently unused. Work goes to the default stream,
//                               or to the buffers' alt streams when
//                               tuningParams->concurrent_streams is set.
void q4_attn_cuda
(
    ExLlamaTuning* tuningParams,
    cudaStream_t stream,
    cublasHandle_t handle,
    half* x,
    const half* rms_norm_weight,    // shape == (x.shape[1],) == (dim,)
    float epsilon,
    half* query_states,
    half* key_states,
    half* value_states,
    Q4Matrix* q_proj,
    Q4Matrix* k_proj,
    Q4Matrix* v_proj,
    half* sin,
    half* cos,
    const int bsz,
    const int q_len,
    const int dim,
    const int head_dim,
    const int num_heads,
    const int num_kv_heads,
    const int past_len,
    half* key_cache,
    half* value_cache,
    const half* q_a,
    const half* q_b,
    const int q_rank,
    const half* k_a,
    const half* k_b,
    const int k_rank,
    const half* v_a,
    const half* v_b,
    const int v_rank,
    half* lora_temp,
    const int max_seq_len,
    const int device_index
)
{
    CudaBuffers* buffers = get_buffers(device_index);

    // Projection output widths. These differ from dim for k/v on grouped-query models,
    // where num_kv_heads < num_heads.
    const int q_dim = num_heads * head_dim;
    const int kv_dim = num_kv_heads * head_dim;

    // Rows of head_dim elements that RoPE rotates, for q and for k
    const int rows_per_batch = q_len * num_heads;
    const int rows_per_batch_kv = q_len * num_kv_heads;

    // Cache update grid: ceil-divide each extent by the elements one block covers on that axis
    dim3 threads(THREADS_X, THREADS_Y, THREADS_Z);
    dim3 blocks
    (
        (head_dim + THREADS_X * BLOCKSIZE_X - 1) / (THREADS_X * BLOCKSIZE_X),
        (q_len + THREADS_Y - 1) / THREADS_Y,
        (num_kv_heads + THREADS_Z * BLOCKSIZE_Z - 1) / (THREADS_Z * BLOCKSIZE_Z)
    );

    // Layernorm into scratch
    half* temp_x = buffers->temp_state + q_len * dim;
    rms_norm_cuda(tuningParams, x, rms_norm_weight, temp_x, epsilon, q_len, dim, device_index);

    // Adapters write the projection outputs first; the Q4 matmuls below then build on them
    if (q_a) q4_attn_lora_project(tuningParams, handle, temp_x, q_a, q_b, lora_temp, query_states, q_len, dim, q_rank, q_dim);
    if (k_a) q4_attn_lora_project(tuningParams, handle, temp_x, k_a, k_b, lora_temp, key_states, q_len, dim, k_rank, kv_dim);
    if (v_a) q4_attn_lora_project(tuningParams, handle, temp_x, v_a, v_b, lora_temp, value_states, q_len, dim, v_rank, kv_dim);

    if (!tuningParams->concurrent_streams)
    {
        // Project q, k, v
        q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, q_a != nullptr);
        q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, k_a != nullptr);
        q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, v_a != nullptr);

        // Positional embeddings for q, k
        rope_cuda(tuningParams, query_states, sin, cos, bsz, rows_per_batch, head_dim, num_heads, past_len);
        rope_cuda(tuningParams, key_states, sin, cos, bsz, rows_per_batch_kv, head_dim, num_kv_heads, past_len);

        // Update cache tensors with projected k, v
        update_cache_kernel<<<blocks, threads>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
    }
    else
    {
        // Project q, k, v, add positional embeddings to q, k, and update cache tensors with
        // projected k, v, spread across three streams.
        //
        // NOTE(review): there is no explicit event wait on the work already queued above
        // (rms_norm, LoRA matmuls). This presumably relies on the alt streams being blocking
        // streams, which are implicitly ordered with the legacy default stream. Confirm this
        // if the streams are ever created non-blocking or per-thread default streams are
        // enabled.
        cudaStream_t str_1 = buffers->alt_stream_1;
        cudaStream_t str_2 = buffers->alt_stream_2;
        cudaStream_t str_3 = buffers->alt_stream_3;
        cudaEvent_t sync_1 = buffers->alt_stream_1_done;
        cudaEvent_t sync_2 = buffers->alt_stream_2_done;
        cudaEvent_t sync_3 = buffers->alt_stream_3_done;

        // str_1: project q, positions q (q rows are grouped by num_heads, as in the sync path), sync
        q4_matmul_cuda(tuningParams, temp_x, q_len, q_proj, query_states, q_a != nullptr, str_1);
        rope_cuda(tuningParams, query_states, sin, cos, bsz, rows_per_batch, head_dim, num_heads, past_len, str_1);
        cudaEventRecord(sync_1, str_1);

        // str_2: project k, positions k, sync
        q4_matmul_cuda(tuningParams, temp_x, q_len, k_proj, key_states, k_a != nullptr, str_2);
        rope_cuda(tuningParams, key_states, sin, cos, bsz, rows_per_batch_kv, head_dim, num_kv_heads, past_len, str_2);
        cudaEventRecord(sync_2, str_2);

        // str_3: project v, wait for rotated k from str_2, copy (k, v) to cache, sync
        q4_matmul_cuda(tuningParams, temp_x, q_len, v_proj, value_states, v_a != nullptr, str_3);
        cudaStreamWaitEvent(str_3, sync_2, 0);
        update_cache_kernel<<<blocks, threads, 0, str_3>>>(key_states, value_states, key_cache, value_cache, head_dim, num_kv_heads, q_len, max_seq_len, past_len);
        cudaEventRecord(sync_3, str_3);

        // Default stream waits for str_1 and str_3 (str_3 already covers str_2)
        cudaStreamWaitEvent(NULL, sync_1, 0);
        cudaStreamWaitEvent(NULL, sync_3, 0);
    }
}

// Second half of the attention block: the output projection of attn_output into x
// (height rows).
//
// Both the LoRA second stage and q4_matmul_cuda are called with a trailing `true`. This
// appears to request accumulation into x (a residual add) rather than overwriting it;
// confirm against those functions' signatures.
void q4_attn_2_cuda
(
    ExLlamaTuning* tuningParams,
    cublasHandle_t handle,
    half* x,
    half* attn_output,
    Q4Matrix* o_proj,
    const int height,
    const half* o_a,
    const half* o_b,
    const int o_rank,
    half* lora_temp
)
{
    if (o_a)
    {
        // NOTE(review): o_proj->height is used as both the LoRA input and output width,
        // i.e. o_proj is assumed square (true when hidden size == num_heads * head_dim).
        const int dim = o_proj->height;
        half_matmul_cublas_cuda(tuningParams, attn_output, o_a, lora_temp, height, dim, o_rank, handle);
        half_matmul_cublas_cuda(tuningParams, lora_temp, o_b, x, height, o_rank, dim, handle, true);
    }

    q4_matmul_cuda(tuningParams, attn_output, height, o_proj, x, true);
}