// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

#include <math_constants.h>

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

// Input-validation helpers; in the upstream project these are likely defined in a
// shared header, and are reproduced here only so this listing is self-contained.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
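
// Warp-level all-reduce helpers: a butterfly reduction over the 32 lanes of a
// warp using __shfl_xor_sync, so every lane ends up holding the reduced value.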
__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
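
// In-place, numerically stable softmax over the last dimension of a [rows, cols]
// view of the attention scores. Each 128-thread block covers four rows, one warp
// per row; each lane owns a contiguous chunk of ceil(cols / 32) columns, so cols
// must not exceed 32 * 32 = 1024 (the size of the per-lane buffer below).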
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }
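
        // Numerically stable softmax: reduce the row maximum across the warp,
        // subtract it before exponentiating, then normalize by the warp-wide sum.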
        float thread_max = -1 * CUDART_INF_F;
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
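
// Host-side launcher: validates the tensor, then runs the kernel with one
// 128-thread block per four rows. Dispatches on dtype; anything that is not
// float32 is treated as bfloat16.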
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
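
// In-place backward pass for the attention softmax. For each row it first
// recomputes the gradient w.r.t. the softmax output on the fly,
// dy = d_ov . V^T (so that intermediate is never materialized), then applies
// the softmax Jacobian-vector product dx_i = y_i * (dy_i - sum_j y_j * dy_j)
// and overwrites `output` (which holds the saved softmax values y) with dx.
// Thread layout matches the forward kernel: one warp per row, four rows per
// block, at most 1024 columns.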
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;
    // values are set to the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;
    int last_y = (cols_output / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;
        // Compute a chunk of the output gradient on the fly
        int value_row_idx = 0;
        int value_idx = 0;
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }
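
        // Softmax Jacobian-vector product: every column of the row needs the
        // row-wide dot product sum_j y_j * dy_j, accumulated per lane and then
        // reduced across the warp.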
        float thread_sum = 0.;
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}
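
// Host-side launcher for the backward kernel. Same grid/block configuration and
// dtype dispatch as the forward launcher; `output` is overwritten with the
// gradient of the attention scores.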
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
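
// A minimal sketch of how these launchers could be exposed to Python in a torch
// extension (assumption: the actual binding lives in a separate translation unit
// in the original project, and the exported names below are illustrative only):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward_", &attn_softmax_inplace_forward_,
//             "In-place attention softmax (CUDA)");
//       m.def("backward_", &attn_softmax_inplace_backward_,
//             "Backward pass of the in-place attention softmax (CUDA)");
//   }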