// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)
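// Warp-level butterfly all-reduce: after log2(32) = 5 XOR-shuffle steps,
// every lane of the warp holds the warp-wide max / sum.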
__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}
__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
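// In-place softmax over the last dimension of a [rows, cols] view.
// Each block runs 128 threads = 4 warps, one warp per row. Within a warp,
// lane (threadIdx.x % 32) handles a contiguous chunk of ceil(cols / 32)
// columns. The per-lane buffer below holds 32 floats, so the kernel assumes
// cols <= 32 * 32 = 1024.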
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
) {
    int threadidx_x = threadIdx.x / 32;   // warp index within the block (0..3)
    int threadidx_y = threadIdx.x % 32;   // lane index within the warp
    // Widen before multiplying so the row index cannot overflow 32-bit
    // arithmetic for very large grids.
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);
    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }

        float thread_max = -1 * CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
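// Host-side launcher: the caller supplies the flattened view, with `rows`
// equal to the product of all leading dimensions and `cols` the softmax
// dimension. Any dtype other than fp32 is assumed to be bfloat16.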
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
) {
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    } else {
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
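// Backward pass for softmax followed by an attention-weighted sum o = y @ V,
// where y = softmax(x) is the saved forward output. For each row the kernel
// first recomputes dy = d_ov @ V^T on the fly (one dot product per attention
// column), then applies the softmax Jacobian:
//     dx_i = y_i * (dy_i - sum_j y_j * dy_j)
// Each contiguous group of `cols_output` rows shares one
// [cols_output, cols_values] value block (the "leaf matrix" below).
// The result is written back into `output` in place, overwriting y.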
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
) {
    int threadidx_x = threadIdx.x / 32;   // warp index within the block (0..3)
    int threadidx_y = threadIdx.x % 32;   // lane index within the warp
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int rows_values = cols_output;
    // values are set to the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;

    int last_y = (cols_output / cols_per_thread);
    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    } else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;

        // Compute a chunk of the output gradient on the fly
        int value_row_idx = 0;
        int value_idx = 0;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
        }

        float thread_sum = 0.;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] = static_cast<T>(
                (dy_buf[i] - warp_sum) * y_buf[i]
            );
        }
    }
}
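// Host-side launcher for the backward kernel: `rows` and `cols_output`
// describe the flattened softmax output, `cols_values` is the trailing
// dimension of the value tensor. As in the forward, non-fp32 tensors are
// assumed to be bfloat16.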
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
) {
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    } else {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
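// Usage sketch (illustrative only, not part of this translation unit): the two
// launchers above would typically be exposed to Python through a small
// pybind11 binding file built alongside this kernel. The binding file name and
// exported function names below are assumptions, not something defined here.
//
//   // softmax_cuda.cpp (hypothetical binding TU)
//   #include <torch/extension.h>
//
//   void attn_softmax_inplace_forward_(at::Tensor input, long long rows, int cols);
//   void attn_softmax_inplace_backward_(at::Tensor output, at::Tensor d_ov,
//                                       at::Tensor values, long long rows,
//                                       int cols_output, int cols_values);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("forward_", &attn_softmax_inplace_forward_, "in-place attention softmax forward");
//       m.def("backward_", &attn_softmax_inplace_backward_, "in-place attention softmax backward");
//   }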