// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu

#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>

#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// Butterfly-style all-reduce across the 32 lanes of a warp: every lane ends
// up holding the maximum of the values contributed by all lanes.
__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

// Same butterfly pattern, reducing with addition instead of max.
__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}

// In-place, numerically stable softmax over the last dimension. Each warp
// owns one row of `cols` elements, so a block of 128 threads handles 4 rows.
// Each lane buffers at most 32 columns in registers, which assumes
// cols <= 1024.
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int threadidx_x = threadIdx.x / 32;  // warp index within the block
    int threadidx_y = threadIdx.x % 32;  // lane index within the warp
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    // The last lane that owns any columns may own fewer than cols_per_thread;
    // lanes beyond it own none.
    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        // Load this lane's slice of the row into registers.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }

        // Row-wise max (per lane, then across the warp) for the usual
        // max-subtraction trick that keeps exp() in range.
        float thread_max = -1 * CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        // Exponentiate and accumulate the normalizer.
        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        // Normalize and write back in place.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    // One warp per row, four warps (128 threads) per block.
    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    } else {
        // Any non-float32 input is treated as bfloat16.
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
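// ---------------------------------------------------------------------------
// Backward-pass identity used below (standard softmax Jacobian, restated here
// for reference). With y = softmax(x) computed row-wise and g = dL/dy,
//
//     dL/dx_i = sum_j g_j * dy_j/dx_i
//             = sum_j g_j * y_j * (delta_ij - y_i)
//             = y_i * (g_i - sum_j y_j * g_j),
//
// which corresponds to the `(dy_buf[i] - warp_sum) * y_buf[i]` update in the
// kernel below, with warp_sum holding the row-wise dot product sum_j y_j*g_j.
// The gradient g is not stored by the forward pass; it is rebuilt on the fly
// as d_ov @ values^T inside the kernel.
// ---------------------------------------------------------------------------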
// Backward pass for the in-place softmax above, fused with recomputation of
// the upstream gradient of the attention weights. `output` holds the softmax
// probabilities y (overwritten with dL/dx), `d_ov` holds the gradient of the
// attention output o = y @ values, and `values` holds the value matrix, so
// dL/dy = d_ov @ values^T is rebuilt here instead of being stored. As in the
// forward kernel, one warp handles one row and cols_output is assumed to be
// at most 1024.
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int threadidx_x = threadIdx.x / 32;  // warp index within the block
    int threadidx_y = threadIdx.x % 32;  // lane index within the warp
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int rows_values = cols_output;
    // values are indexed from the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;

    int last_y = (cols_output / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;

        // Recompute this lane's chunk of dL/dy = d_ov @ values^T on the fly.
        int value_row_idx = 0;
        int value_idx = 0;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        // Load the saved softmax probabilities y.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }

        // Row-wise dot product sum_j y_j * (dL/dy)_j, reduced across the warp.
        float thread_sum = 0.;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        // dL/dx_i = y_i * ((dL/dy)_i - warp_sum), written in place over the
        // stored probabilities.
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}

void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    // One warp per row, four warps (128 threads) per block.
    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        // Any non-float32 input is treated as bfloat16.
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
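// ---------------------------------------------------------------------------
// Illustrative only: a minimal sketch of how the forward launcher above might
// be driven from host code once the extension is built. The helper name and
// the shape handling are assumptions for this example, not part of the
// module's interface; the launcher itself only requires a contiguous CUDA
// tensor viewed as `rows` rows whose softmax dimension is at most 1024.
//
//     static void example_softmax_forward(at::Tensor attn_logits) {
//         long long rows = attn_logits.numel() / attn_logits.size(-1);
//         int cols = (int)attn_logits.size(-1);
//         // Overwrites attn_logits with softmax(attn_logits) along the last dim.
//         attn_softmax_inplace_forward_(attn_logits, rows, cols);
//     }
// ---------------------------------------------------------------------------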