#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

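// Host-callable wrapper around get_batch_per_block so callers can query how
// many batches each CUDA block processes for a given problem size.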
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads)
{
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor)
{
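    // input is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len];
    // mask is [pad_batches, 1, query_seq_len, key_seq_len] with pad_batches either 1 or batches.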
    const int batches = input.size(0);
    const int pad_batches = mask.size(0);
    const int attn_heads = input.size(1);
    const int query_seq_len = input.size(2);
    const int key_seq_len = input.size(3);
    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
    TORCH_INTERNAL_ASSERT(query_seq_len > 1);
    TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
    TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
    TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
    TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

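    // Output tensor holding the softmax probabilities.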
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

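    // Raw pointers handed to the kernel dispatcher.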
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* mask_ptr = static_cast<void*>(mask.data_ptr());
    void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

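    // Launch the fused scaled-masked-softmax forward kernel, dispatching on the
    // half/bfloat16 element type of the input.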
    DISPATCH_HALF_AND_BFLOAT(input.scalar_type(),
                             "dispatch_scaled_masked_softmax_forward",
                             dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
                                 reinterpret_cast<scalar_t*>(softmax_results_ptr),
                                 reinterpret_cast<const scalar_t*>(input_ptr),
                                 reinterpret_cast<const uint8_t*>(mask_ptr),
                                 scale_factor,
                                 query_seq_len,
                                 key_seq_len,
                                 batches,
                                 attn_heads,
                                 pad_batches););

    return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor)
{
    auto output_grads = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

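    // output_grads is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len].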
    const int batches = output_grads.size(0);
    const int attn_heads = output_grads.size(1);
    const int query_seq_len = output_grads.size(2);
    const int key_seq_len = output_grads.size(3);

    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

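    // Launch the softmax backward kernel; output_grads_ptr is passed as both the
    // gradient input and the gradient output, so the gradients are computed in place.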
    DISPATCH_HALF_AND_BFLOAT(output_grads_.scalar_type(),
                             "dispatch_scaled_masked_softmax_backward",
                             dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
                                 reinterpret_cast<scalar_t*>(output_grads_ptr),
                                 reinterpret_cast<scalar_t*>(output_grads_ptr),
                                 reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
                                 scale_factor,
                                 query_seq_len,
                                 key_seq_len,
                                 batches,
                                 attn_heads););

    return output_grads;
}

}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn