#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

namespace {
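
// Forward RoPE kernel for an (s, b, h, d) tensor addressed through explicit
// strides. One block handles one (sequence position, batch) pair; threadIdx.x
// walks the rotated channels [0, d2) and threadIdx.y walks the heads. The
// frequencies are float and indexed as freqs[s_id * d2 + d_id].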
template <typename scalar_t>
__global__ void fused_rope_forward(const int h,
                                   const int d,
                                   const int d2,
                                   const int stride_s,
                                   const int stride_b,
                                   const int stride_h,
                                   const int stride_d,
                                   const int o_stride_s,
                                   const int o_stride_b,
                                   const int o_stride_h,
                                   const int o_stride_d,
                                   const scalar_t* src,
                                   const float* freqs,
                                   scalar_t* dst)
{
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int offset_block = s_id * stride_s + b_id * stride_b;
  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    float v_cos, v_sin;
    sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      scalar_t v_src = src[offset_src];
      // rotate_half: pair channel d_id with the channel d2/2 away, negating
      // the partner that comes from the second half of the rotated span.
      scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
                                  ? -src[offset_src + (d2 / 2) * stride_d]
                                  : src[offset_src + (d2 / 2 - d2) * stride_d];
      dst[offset_dst] = v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin;
    }
  }

  // Channels beyond d2 are not rotated; copy them through unchanged.
  if (d > d2) {
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_head = offset_block + h_id * stride_h;
      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
      }
    }
  }
}
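
// Backward RoPE kernel: applies the transpose of the forward rotation to the
// incoming gradients. cos stays at the channel's own frequency, while sin is
// evaluated at the partner channel's frequency with the sign flip that the
// forward pass applied to the rotated data.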
template <typename scalar_t>
__global__ void fused_rope_backward(const int h,
                                    const int d,
                                    const int d2,
                                    const int stride_s,
                                    const int stride_b,
                                    const int stride_h,
                                    const int stride_d,
                                    const int o_stride_s,
                                    const int o_stride_b,
                                    const int o_stride_h,
                                    const int o_stride_d,
                                    const scalar_t* src,
                                    const float* freqs,
                                    scalar_t* dst)
{
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int offset_block = s_id * stride_s + b_id * stride_b;
  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    // Transposed rotation: sin is taken at the partner channel's frequency
    // and carries the sign flip instead of the rotated value.
    scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]);
    scalar_t v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
                                          : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      scalar_t v_src = src[offset_src];
      scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
                                  ? src[offset_src + (d2 / 2) * stride_d]
                                  : src[offset_src + (d2 / 2 - d2) * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }

  // Gradients for the non-rotated channels pass through unchanged.
  if (d > d2) {
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_head = offset_block + h_id * stride_h;
      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
      }
    }
  }
}
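
// Forward RoPE kernel that reads precomputed cos/sin tables of shape (s, d2)
// instead of evaluating sincosf on raw frequencies. The table element type
// (scalar_t_1) may differ from the activation type (scalar_t_0).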
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_forward(const int h,
                                          const int d,
                                          const int d2,
                                          const int stride_s,
                                          const int stride_b,
                                          const int stride_h,
                                          const int stride_d,
                                          const int o_stride_s,
                                          const int o_stride_b,
                                          const int o_stride_h,
                                          const int o_stride_d,
                                          const scalar_t_0* src,
                                          const scalar_t_1* cos,
                                          const scalar_t_1* sin,
                                          scalar_t_0* dst)
{
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int offset_block = s_id * stride_s + b_id * stride_b;
  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    // Read the cached tables and convert to the activation type.
    scalar_t_0 v_cos = cos[s_id * d2 + d_id];
    scalar_t_0 v_sin = sin[s_id * d2 + d_id];
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      scalar_t_0 v_src = src[offset_src];
      // rotate_half, as in the uncached forward kernel.
      scalar_t_0 v_src_rotate = (d_id + d2 / 2 < d2)
                                    ? -src[offset_src + (d2 / 2) * stride_d]
                                    : src[offset_src + (d2 / 2 - d2) * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }

  // Channels beyond d2 are not rotated; copy them through unchanged.
  if (d > d2) {
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_head = offset_block + h_id * stride_h;
      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
      }
    }
  }
}
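
// Backward counterpart of the cached kernel: same gradient rotation as
// fused_rope_backward, but cos/sin are looked up from the cached tables.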
template <typename scalar_t_0, typename scalar_t_1>
__global__ void fused_rope_cached_backward(const int h,
                                           const int d,
                                           const int d2,
                                           const int stride_s,
                                           const int stride_b,
                                           const int stride_h,
                                           const int stride_d,
                                           const int o_stride_s,
                                           const int o_stride_b,
                                           const int o_stride_h,
                                           const int o_stride_d,
                                           const scalar_t_0* src,
                                           const scalar_t_1* cos,
                                           const scalar_t_1* sin,
                                           scalar_t_0* dst)
{
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int offset_block = s_id * stride_s + b_id * stride_b;
  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
#pragma unroll
  for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
    // Transposed rotation using the cached tables: sin comes from the partner
    // channel and carries the sign flip applied in the forward pass.
    scalar_t_0 v_cos = cos[s_id * d2 + d_id];
    scalar_t_0 v_sin = (d_id + d2 / 2 < d2) ? sin[s_id * d2 + d_id + d2 / 2]
                                            : -sin[s_id * d2 + d_id + d2 / 2 - d2];
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      scalar_t_0 v_src = src[offset_src];
      scalar_t_0 v_src_rotate = (d_id + d2 / 2 < d2)
                                    ? src[offset_src + (d2 / 2) * stride_d]
                                    : src[offset_src + (d2 / 2 - d2) * stride_d];
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }

  // Gradients for the non-rotated channels pass through unchanged.
  if (d > d2) {
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
      int offset_head = offset_block + h_id * stride_h;
      int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
      for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
        dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
      }
    }
  }
}

}  // namespace
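
// Host-side launchers. Each launches one block per (s, b) pair with
// C10_WARP_SIZE lanes along the channel dimension and 4 or 8 warps along the
// head dimension, then verifies the launch with C10_CUDA_KERNEL_LAUNCH_CHECK.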
template <typename scalar_t>
void dispatch_fused_rope_forward(const int s,
                                 const int b,
                                 const int h,
                                 const int d,
                                 const int d2,
                                 const int stride_s,
                                 const int stride_b,
                                 const int stride_h,
                                 const int stride_d,
                                 const int o_stride_s,
                                 const int o_stride_b,
                                 const int o_stride_h,
                                 const int o_stride_d,
                                 const scalar_t* input,
                                 const float* freqs,
                                 scalar_t* output)
{
  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_forward<<<blocks, threads, 0, stream>>>(h,
                                                     d,
                                                     d2,
                                                     stride_s,
                                                     stride_b,
                                                     stride_h,
                                                     stride_d,
                                                     o_stride_s,
                                                     o_stride_b,
                                                     o_stride_h,
                                                     o_stride_d,
                                                     input,
                                                     freqs,
                                                     output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
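
// Launcher for the uncached backward kernel; takes output gradients and
// writes input gradients with the same grid/block shape as the forward pass.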
template <typename scalar_t>
void dispatch_fused_rope_backward(const int s,
                                  const int b,
                                  const int h,
                                  const int d,
                                  const int d2,
                                  const int stride_s,
                                  const int stride_b,
                                  const int stride_h,
                                  const int stride_d,
                                  const int o_stride_s,
                                  const int o_stride_b,
                                  const int o_stride_h,
                                  const int o_stride_d,
                                  const scalar_t* output_grads,
                                  const float* freqs,
                                  scalar_t* input_grads)
{
  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_backward<<<blocks, threads, 0, stream>>>(h,
                                                      d,
                                                      d2,
                                                      stride_s,
                                                      stride_b,
                                                      stride_h,
                                                      stride_d,
                                                      o_stride_s,
                                                      o_stride_b,
                                                      o_stride_h,
                                                      o_stride_d,
                                                      output_grads,
                                                      freqs,
                                                      input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
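
// Launcher for the cached forward kernel; the cos/sin tables may use a
// different element type than the activations.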
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_forward(const int s,
                                        const int b,
                                        const int h,
                                        const int d,
                                        const int d2,
                                        const int stride_s,
                                        const int stride_b,
                                        const int stride_h,
                                        const int stride_d,
                                        const int o_stride_s,
                                        const int o_stride_b,
                                        const int o_stride_h,
                                        const int o_stride_d,
                                        const scalar_t_0* input,
                                        const scalar_t_1* cos,
                                        const scalar_t_1* sin,
                                        scalar_t_0* output)
{
  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_cached_forward<<<blocks, threads, 0, stream>>>(h,
                                                            d,
                                                            d2,
                                                            stride_s,
                                                            stride_b,
                                                            stride_h,
                                                            stride_d,
                                                            o_stride_s,
                                                            o_stride_b,
                                                            o_stride_h,
                                                            o_stride_d,
                                                            input,
                                                            cos,
                                                            sin,
                                                            output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
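
// Launcher for the cached backward kernel.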
template <typename scalar_t_0, typename scalar_t_1>
void dispatch_fused_rope_cached_backward(const int s,
                                         const int b,
                                         const int h,
                                         const int d,
                                         const int d2,
                                         const int stride_s,
                                         const int stride_b,
                                         const int stride_h,
                                         const int stride_d,
                                         const int o_stride_s,
                                         const int o_stride_b,
                                         const int o_stride_h,
                                         const int o_stride_d,
                                         const scalar_t_0* output_grads,
                                         const scalar_t_1* cos,
                                         const scalar_t_1* sin,
                                         scalar_t_0* input_grads)
{
  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_cached_backward<<<blocks, threads, 0, stream>>>(h,
                                                             d,
                                                             d2,
                                                             stride_s,
                                                             stride_b,
                                                             stride_h,
                                                             stride_d,
                                                             o_stride_s,
                                                             o_stride_b,
                                                             o_stride_h,
                                                             o_stride_d,
                                                             output_grads,
                                                             cos,
                                                             sin,
                                                             input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
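
// Usage sketch (illustrative only, not part of this header): one way a .cu
// translation unit might instantiate dispatch_fused_rope_forward, assuming a
// contiguous [s, b, h, d] input, a float freqs tensor whose last dimension is
// d2, and ATen's floating-type dispatch macro. The function name, tensor
// shapes, and macro choice are assumptions made for the example.
//
//   at::Tensor fused_rope_fwd_example(const at::Tensor& input,  // [s, b, h, d]
//                                     const at::Tensor& freqs)  // [..., d2], float
//   {
//     const int s = input.size(0), b = input.size(1);
//     const int h = input.size(2), d = input.size(3);
//     const int d2 = freqs.size(-1);
//     auto output = at::empty_like(input);
//     AT_DISPATCH_FLOATING_TYPES_AND2(
//         at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(),
//         "fused_rope_forward", [&] {
//           dispatch_fused_rope_forward(
//               s, b, h, d, d2,
//               input.stride(0), input.stride(1), input.stride(2), input.stride(3),
//               output.stride(0), output.stride(1), output.stride(2), output.stride(3),
//               input.data_ptr<scalar_t>(), freqs.data_ptr<float>(),
//               output.data_ptr<scalar_t>());
//         });
//     return output;
//   }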