[CUDA] cuDNN Flash Attention (#21629)
### Description

- [x] Add cuDNN flash attention using the cuDNN frontend, and enable it in the MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA kernel is disabled by default. Enabling it requires the following:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1`, or set the `sdpa_kernel=8` CUDA provider option (a usage sketch follows the benchmark tables below).
(3) It only works on devices with compute capability >= 8.0.

Note that some combinations of parameters might be rejected due to limited support of head dimensions or sequence lengths.

Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (padding removed from inputs).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, and k/v are converted to either BSNH or BNSH format. Some experiments may show whether converting q to BNSH is better in some cases.

### Example Benchmark Results on H100

The following tests run the FP16 MultiHeadAttention operator without attention mask and attention bias.

#### Test Setting 1

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH) | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH) | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH) | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
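For reference, a minimal Python sketch (not part of this change) of how the kernel can be selected, and how the tflops column relates to the measured latency. The model path is hypothetical, and the TFLOPS formula assumes equal QK and V head sizes, which reproduces the table values above.

```python
import os
import onnxruntime as ort

# Option 1: environment variable (read when the session is created).
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option 2: sdpa_kernel CUDA provider option; 8 selects cuDNN flash attention.
session = ort.InferenceSession(
    "model_with_mha.onnx",  # hypothetical model containing a MultiHeadAttention node
    providers=[("CUDAExecutionProvider", {"sdpa_kernel": "8"}), "CPUExecutionProvider"],
)

def mha_tflops(batch, num_heads, seq_q, seq_kv, head_size, latency_seconds):
    # Two GEMMs (Q*K^T and softmax(QK^T)*V), each about 2*b*n*s_q*s_kv*h FLOPs.
    return 4.0 * batch * num_heads * seq_q * seq_kv * head_size / latency_seconds / 1e12

print(mha_tflops(16, 32, 256, 256, 128, 75e-6))  # ~229, consistent with the 0.000075 s rows in Test Setting 1
```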
This commit is contained in:
Parent: 9f7e19cedd
Commit: fbc3927231
|
@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
|
|||
CUDNN::cudnn_heuristic
|
||||
)
|
||||
endif()
|
||||
|
||||
mark_as_advanced(CUDNN_INCLUDE_DIR)
|
||||
|
|
|
@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
|||
|
||||
# GLOB pattern of file to be excluded
|
||||
set(contrib_ops_excluded_files
|
||||
"bert/cudnn_fmha/*"
|
||||
"bert/cutlass_fmha/*"
|
||||
"bert/fastertransformer_decoder_attention/*"
|
||||
"bert/flash_attention/*"
|
||||
|
|
|
@ -47,6 +47,7 @@ enum AttentionKernelType {
|
|||
AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernel_CutlassMemoryEfficientAttention,
|
||||
AttentionKernel_FlashAttention,
|
||||
AttentionKernel_CudnnFlashAttention,
|
||||
AttentionKernel_Default
|
||||
};
|
||||
|
||||
|
|
|
@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
constexpr bool use_fused_cross_attention = false;
|
||||
constexpr bool use_cudnn_flash_attention = false;
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
parameters.batch_size,
|
||||
parameters.num_heads,
|
||||
|
@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
use_flash_attention,
|
||||
use_fused_cross_attention,
|
||||
use_memory_efficient_attention,
|
||||
use_cudnn_flash_attention,
|
||||
false);
|
||||
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
|
||||
|
||||
|
@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
|
||||
}
|
||||
|
||||
return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
|
||||
cudnnHandle_t cudnn = GetCudnnHandle(context);
|
||||
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
@ -37,6 +37,7 @@ limitations under the License.
|
|||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
|
||||
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
|
||||
|
@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
|
|||
bool use_flash_attention,
|
||||
bool use_fused_cross_attention,
|
||||
bool use_memory_efficient_attention,
|
||||
bool use_cudnn_flash_attention,
|
||||
bool no_qkv_workspace) {
|
||||
// Note that q, k and v might need alignment for fused attention kernels.
|
||||
const size_t qkv_size = element_size * batch_size * num_heads *
|
||||
|
@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
|
|||
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
|
||||
}
|
||||
|
||||
if (use_cudnn_flash_attention) {
|
||||
return qkv_bytes;
|
||||
}
|
||||
|
||||
return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
|
||||
total_sequence_length);
|
||||
}
|
||||
|
@ -320,6 +326,68 @@ Status FlashAttention(
|
|||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
Status CudnnFlashAttention(
|
||||
cudnnHandle_t cudnn_handle,
|
||||
Stream* ort_stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data,
|
||||
float scale) {
|
||||
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
|
||||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
|
||||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
|
||||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
|
||||
constexpr bool is_bf16 = false;
|
||||
|
||||
T* attention_bias = const_cast<T*>(data.attention_bias);
|
||||
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);
|
||||
|
||||
cudnn_sdpa::run(
|
||||
data.output,
|
||||
data.q,
|
||||
data.k,
|
||||
data.v,
|
||||
attention_bias,
|
||||
nullptr, // (optional) mask_sequence_lengths_q
|
||||
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
|
||||
parameters.batch_size,
|
||||
parameters.num_heads, // num_heads_q,
|
||||
parameters.num_heads, // num_heads_kv,
|
||||
parameters.head_size, // head_size_qk
|
||||
parameters.v_head_size, // head_size_v
|
||||
parameters.sequence_length, // sequence_length_q
|
||||
parameters.total_sequence_length, // sequence_length_kv
|
||||
scale, // scaling factor applied prior to softmax
|
||||
parameters.is_unidirectional, // causal
|
||||
is_bf16, // True if bfloat16, otherwise float16
|
||||
parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not
|
||||
parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not
|
||||
0, // sliding window length. 0 means no sliding window.
|
||||
data.qkv_format,
|
||||
cudnn_handle,
|
||||
ort_stream,
|
||||
data.allocator);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status CudnnFlashAttention(
|
||||
cudnnHandle_t cudnn_handle,
|
||||
Stream* ort_stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<float>& data,
|
||||
float scale) {
|
||||
ORT_UNUSED_PARAMETER(cudnn_handle);
|
||||
ORT_UNUSED_PARAMETER(ort_stream);
|
||||
ORT_UNUSED_PARAMETER(parameters);
|
||||
ORT_UNUSED_PARAMETER(data);
|
||||
ORT_UNUSED_PARAMETER(scale);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
|
||||
"cudnn flash attention does not support float tensor");
|
||||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
template <typename T>
|
||||
Status EfficientAttention(
|
||||
|
@ -498,6 +566,7 @@ template <typename T>
|
|||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
cudnnHandle_t& cudnn,
|
||||
Stream* ort_stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data) {
|
||||
|
@ -512,10 +581,11 @@ Status QkvToContext(
|
|||
void* fused_runner = data.fused_runner;
|
||||
|
||||
// At most one fused kernel is enabled.
|
||||
assert((int(data.use_flash_attention) +
|
||||
int(data.use_memory_efficient_attention) +
|
||||
int(fused_runner != nullptr) +
|
||||
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
|
||||
assert((static_cast<int>(data.use_flash_attention) +
|
||||
static_cast<int>(data.use_memory_efficient_attention) +
|
||||
static_cast<int>(fused_runner != nullptr) +
|
||||
static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
|
||||
static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
|
||||
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));
|
||||
|
||||
|
@ -577,6 +647,10 @@ Status QkvToContext(
|
|||
}
|
||||
#endif
|
||||
|
||||
if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
|
||||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
if (data.use_memory_efficient_attention) {
|
||||
return EfficientAttention(device_prop, stream, parameters, data, scale);
|
||||
|
@ -594,6 +668,7 @@ template struct AttentionData<half>;
|
|||
template Status QkvToContext<float>(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
cudnnHandle_t& cudnn,
|
||||
Stream* ort_stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<float>& data);
|
||||
|
@ -601,6 +676,7 @@ template Status QkvToContext<float>(
|
|||
template Status QkvToContext<half>(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
cudnnHandle_t& cudnn,
|
||||
Stream* ort_stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<half>& data);
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include <iostream>
|
||||
#include <mutex>
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
|
|||
bool use_flash_attention,
|
||||
bool use_fused_cross_attention,
|
||||
bool use_memory_efficient_attention,
|
||||
bool use_cudnn_flash_attention,
|
||||
bool no_qkv_workspace);
|
||||
|
||||
template <typename T>
|
||||
|
@ -104,9 +106,11 @@ struct AttentionData {
|
|||
size_t workspace_bytes = 0;
|
||||
bool allow_debug_info = false;
|
||||
|
||||
// For MultiHeadAttention only.
|
||||
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
AllocatorPtr allocator = nullptr;
|
||||
bool IsUnfused() const {
|
||||
return !use_flash_attention && !use_memory_efficient_attention &&
|
||||
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
|
||||
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
|
||||
}
|
||||
|
||||
void PrintDebugInfo() const {
|
||||
|
@ -139,6 +143,7 @@ template <typename T>
|
|||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
cublasHandle_t& cublas,
|
||||
cudnnHandle_t& cudnn,
|
||||
Stream* stream,
|
||||
contrib::AttentionParameters& parameters,
|
||||
AttentionData<T>& data);
|
||||
|
|
|
@ -9,11 +9,12 @@
|
|||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
|
||||
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
|
||||
|
||||
using namespace onnxruntime::contrib::attention;
|
||||
|
||||
namespace onnxruntime {
|
||||
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
|
||||
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
|
||||
if (value > 0) {
|
||||
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
|
||||
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
|
||||
|
@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
|
|||
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
|
||||
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
|
||||
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
|
||||
|
||||
use_unfused_ = true;
|
||||
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
|
||||
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
|
||||
|
@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
|
|||
kMinSeqLenForEfficientAttentionFp32,
|
||||
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);
|
||||
|
||||
// Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
|
||||
if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
|
||||
use_cudnn_flash_attention_ = false;
|
||||
if (enable_kernel_debug_info_) {
|
||||
std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_build_flag) {
|
||||
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
|
||||
#ifndef USE_FLASH_ATTENTION
|
||||
|
@ -58,9 +68,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
|
|||
}
|
||||
|
||||
void AttentionKernelOptions::InitializeOnce(
|
||||
int sdpa_kernel, bool use_build_flag) {
|
||||
int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
|
||||
std::call_once(this->initialize_once_flag_, [&]() {
|
||||
this->Initialize(sdpa_kernel, use_build_flag);
|
||||
this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
|
||||
if (this->enable_kernel_debug_info_) {
|
||||
this->Print();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {
|
|||
|
||||
class AttentionKernelOptions {
|
||||
public:
|
||||
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
|
||||
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);
|
||||
|
||||
bool UseFlashAttention() const { return use_flash_attention_; }
|
||||
bool UseEfficientAttention() const { return use_efficient_attention_; }
|
||||
|
@ -40,7 +40,7 @@ class AttentionKernelOptions {
|
|||
protected:
|
||||
void Print() const;
|
||||
|
||||
void Initialize(int value, bool use_build_flag);
|
||||
void Initialize(int value, bool use_build_flag, bool check_cudnn_version);
|
||||
|
||||
private:
|
||||
bool use_flash_attention_{true};
|
||||
|
|
|
@ -169,7 +169,10 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
|
|||
template <typename T>
|
||||
bool NoQkvWorkspace_MHA_Cross(AttentionData<T>& data) {
|
||||
// query, key and value are passed as Q, K and V for the following conditions.
|
||||
return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr);
|
||||
return (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
|
||||
data.bias == nullptr;
|
||||
}
|
||||
|
||||
// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format)
|
||||
|
@ -190,8 +193,9 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
|
|||
const int num_heads = parameters.num_heads;
|
||||
const int qk_head_size = parameters.head_size;
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// Add bias for Q
|
||||
if (data.bias != nullptr) {
|
||||
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
|
||||
|
@ -204,9 +208,7 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
|
|||
data.k = const_cast<T*>(data.key);
|
||||
data.v = const_cast<T*>(data.value);
|
||||
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
|
||||
} else
|
||||
#endif
|
||||
{ // unfused kernel
|
||||
} else { // unfused kernel
|
||||
assert(data.IsUnfused());
|
||||
if (data.bias == nullptr) {
|
||||
// Transpose query from BSNH to BNSH
|
||||
|
@ -233,7 +235,10 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
|
|||
template <typename T>
|
||||
bool NoQkvWorkspace_MHA_NoPast(AttentionData<T>& data) {
|
||||
// query, key and value are passed as Q, K and V for the following conditions.
|
||||
return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr;
|
||||
return (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
|
||||
data.bias == nullptr;
|
||||
}
|
||||
|
||||
// For MultiHeadAttention without past state, with Q, K and V inputs
|
||||
|
@ -275,9 +280,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
|
|||
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
|
||||
data.v = nullptr;
|
||||
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
|
||||
}
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
} else if (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
if (data.bias != nullptr) {
|
||||
LaunchAddBias(stream, max_threads_per_block,
|
||||
batch_size, sequence_length, kv_sequence_length,
|
||||
|
@ -290,9 +295,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
|
|||
}
|
||||
|
||||
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
}
|
||||
#endif
|
||||
else if (data.fused_runner != nullptr) {
|
||||
} else if (data.fused_runner != nullptr) {
|
||||
assert(qk_head_size == v_head_size);
|
||||
assert(data.attention_bias == nullptr);
|
||||
|
||||
|
@ -338,7 +341,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
|
|||
|
||||
template <typename T>
|
||||
bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData<T>& data) {
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// Q, K and V redirect to query, present_k and present_v, so we do not need extra workspace for QKV.
|
||||
return data.past_key == nullptr && data.present_key != nullptr;
|
||||
}
|
||||
|
@ -377,8 +382,9 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
|
|||
data.v = data.present_value;
|
||||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// Use original Query (BSNH) since there is no bias.
|
||||
data.q = const_cast<T*>(data.query);
|
||||
|
||||
|
@ -389,9 +395,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
|
|||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, false, data.value, data.v));
|
||||
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
|
||||
} else
|
||||
#endif
|
||||
{ // unfused kernel
|
||||
} else { // unfused kernel
|
||||
assert(data.IsUnfused());
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, data.query, data.q));
|
||||
|
@ -440,8 +444,9 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
|
|||
data.v = data.present_value;
|
||||
}
|
||||
|
||||
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention ||
|
||||
data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
|
||||
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.bias, data.query, data.q);
|
||||
|
@ -460,9 +465,7 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
|
|||
data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);
|
||||
|
||||
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
|
||||
} else
|
||||
#endif
|
||||
{ // unfused kernel
|
||||
} else { // unfused kernel
|
||||
assert(data.IsUnfused());
|
||||
|
||||
constexpr int format = 0;
|
||||
|
@ -518,7 +521,8 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
|
|||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// unpack qkv to BSNH.
|
||||
constexpr int format = 4;
|
||||
T* qkv_add_bias = nullptr;
|
||||
|
@ -590,7 +594,8 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
|
|||
const int qk_head_size = parameters.head_size;
|
||||
const int v_head_size = parameters.v_head_size;
|
||||
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention) {
|
||||
if (data.use_memory_efficient_attention || data.use_flash_attention ||
|
||||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
// Note that there is no bias so we need not output query to q.
|
||||
data.q = const_cast<T*>(data.query);
|
||||
// Unpack kv to BSNH.
|
||||
|
|
|
@ -0,0 +1,405 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <cudnn.h>
|
||||
|
||||
#if CUDNN_MAJOR < 9
|
||||
namespace onnxruntime::cudnn_sdpa {
|
||||
|
||||
bool is_stable() {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_supported(const cudaDeviceProp& /*dprops*/,
|
||||
int /*num_heads_q*/,
|
||||
int /*num_heads_kv*/,
|
||||
int /*head_size_qk*/,
|
||||
int /*head_size_v*/,
|
||||
int /*sequence_length_q*/,
|
||||
int /*sequence_length_kv*/,
|
||||
bool /*is_causal*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
void run(
|
||||
void* /*output*/,
|
||||
void* /*q*/,
|
||||
void* /*k*/,
|
||||
void* /*v*/,
|
||||
void* /*bias*/,
|
||||
int* /*mask_sequence_lengths_q*/,
|
||||
int* /*mask_sequence_lengths_kv*/,
|
||||
int /*batch_size*/,
|
||||
int /*num_heads_q*/,
|
||||
int /*num_heads_kv*/,
|
||||
int /*head_size_qk*/,
|
||||
int /*head_size_v*/,
|
||||
int /*sequence_length_q*/,
|
||||
int /*sequence_length_kv*/,
|
||||
float /*scale*/,
|
||||
bool /*is_causal*/,
|
||||
bool /*is_bf16*/,
|
||||
bool /*broadcast_attn_bias_dim_0*/,
|
||||
bool /*broadcast_attn_bias_dim_1*/,
|
||||
int /*sliding_window*/,
|
||||
AttentionQkvFormat /*qkv_format*/,
|
||||
cudnnHandle_t /*handle*/,
|
||||
Stream* /*stream*/,
|
||||
AllocatorPtr /*allocator*/) {
|
||||
ORT_THROW("OnnxRuntime was not compiled with cuDNN Flash Attention.");
|
||||
}
|
||||
|
||||
} // namespace onnxruntime::cudnn_sdpa
|
||||
|
||||
#else // CUDNN_MAJOR >= 9
|
||||
|
||||
#include <cudnn_frontend.h>
|
||||
#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
|
||||
#include "core/providers/cuda/cuda_stream_handle.h"
|
||||
|
||||
namespace onnxruntime::cudnn_sdpa {
|
||||
|
||||
bool is_stable() {
|
||||
// FP16/BF16 Flash Attention support in CUDNN backend:
|
||||
// version 8903 (8.9.3):
|
||||
// Padding mask and causal mask
|
||||
// Additive bias
|
||||
// Multi-query attention (h_kv=1)
|
||||
// Both self attention and cross attention
|
||||
// (padded) variable sequence length
|
||||
// Head dimensions 64 or 128
|
||||
// version 8904 (8.9.4):
|
||||
// Alibi mask;
|
||||
// version 8907 (8.9.7):
|
||||
// Grouped Query Attention
|
||||
// version 90100 (9.1.0):
|
||||
// Head dimensions 256
|
||||
// version 90101 (9.1.1)
|
||||
// Sliding window attention
|
||||
// version 90300 (9.3.0)
|
||||
// Bug fixes; Variable sequence length supports zero-sequence-length values
|
||||
// For more information, please refer to cuDNN release notes, and the following links:
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop
|
||||
// https://github.com/NVIDIA/cudnn-frontend/blob/v1.5.2/docs/operations/Attention.md
|
||||
|
||||
// For cuDNN version < 9.3, we will disable it by default.
|
||||
return cudnnGetVersion() >= 90300;
|
||||
}
|
||||
|
||||
namespace fe = cudnn_frontend;
|
||||
|
||||
bool is_supported(const cudaDeviceProp& dprops,
|
||||
int num_heads_q,
|
||||
int num_heads_kv,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
int sequence_length_q,
|
||||
int sequence_length_kv,
|
||||
bool is_causal) {
|
||||
bool is_sm8x = dprops.major == 8 && dprops.minor >= 0;
|
||||
bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
|
||||
return (is_sm8x || is_sm90) &&
|
||||
(head_size_qk % 8 == 0) && (head_size_qk <= 256) &&
|
||||
(head_size_v % 8 == 0) && (head_size_v <= 256) &&
|
||||
(num_heads_q % num_heads_kv == 0) &&
|
||||
// Bottom right causal mask is only supported with s_q multiple of 64 and s_kv multiple of 64.
|
||||
(!is_causal || (sequence_length_q != sequence_length_kv &&
|
||||
sequence_length_q % 64 == 0 &&
|
||||
sequence_length_kv % 64 == 0));
|
||||
}
|
||||
|
||||
// A helper function to set stride for q, k, v or output tensor.
|
||||
// Strides are calculated based on logical tensor layout BNSH (batch_size, num_heads, sequence_length, head_size).
|
||||
// The physical tensor layout could be either BSNH (is_bsnh=True) or BNSH (is_bsnh=False).
|
||||
inline void set_stride(std::vector<int64_t>& stride,
|
||||
int64_t num_heads,
|
||||
int64_t sequence_length,
|
||||
int64_t head_size,
|
||||
bool is_bsnh) {
|
||||
stride = {num_heads * sequence_length * head_size, // stride for batch.
|
||||
is_bsnh ? head_size : (head_size * sequence_length), // stride for head.
|
||||
is_bsnh ? (num_heads * head_size) : head_size, // stride for sequence.
|
||||
1}; // stride for hidden dim of head, shall always be 1.
|
||||
}
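To make the stride convention above concrete, here is a small Python sketch (illustration only, not part of the change). For num_heads=3, sequence_length=4, head_size=8, a physical BSNH buffer gives strides {96, 8, 24, 1} for the logical BNSH view, while a physical BNSH buffer gives {96, 32, 8, 1}.

```python
def bnsh_strides(num_heads, sequence_length, head_size, is_bsnh):
    # Strides for the logical (batch, num_heads, sequence, head_size) layout,
    # given a physical BSNH (is_bsnh=True) or BNSH (is_bsnh=False) buffer.
    return [
        num_heads * sequence_length * head_size,                 # batch
        head_size if is_bsnh else head_size * sequence_length,   # head
        num_heads * head_size if is_bsnh else head_size,         # sequence
        1,                                                       # hidden dim of head
    ]

print(bnsh_strides(3, 4, 8, True))   # [96, 8, 24, 1]  physical BSNH
print(bnsh_strides(3, 4, 8, False))  # [96, 32, 8, 1]  physical BNSH
```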
|
||||
|
||||
// It is used as a key for hash table to store cached graphs.
|
||||
// It contains all parameters used in building the graph. Do not include data pointers that are only needed in graph execution.
|
||||
struct GraphParams {
|
||||
int batch_size;
|
||||
int num_heads_q;
|
||||
int num_heads_kv;
|
||||
int head_size_qk;
|
||||
int head_size_v;
|
||||
int sequence_length_q;
|
||||
int sequence_length_kv;
|
||||
float scale;
|
||||
bool is_causal;
|
||||
bool is_bf16; // True if bfloat16, otherwise float16
|
||||
AttentionQkvFormat qkv_format;
|
||||
cudnnHandle_t handle;
|
||||
bool has_bias;
|
||||
bool broadcast_bias_dim_0;
|
||||
bool broadcast_bias_dim_1;
|
||||
bool has_padding_mask_q;
|
||||
bool has_padding_mask_kv;
|
||||
int sliding_window;
|
||||
|
||||
bool operator==(const GraphParams& rhs) const {
|
||||
return batch_size == rhs.batch_size &&
|
||||
num_heads_q == rhs.num_heads_q &&
|
||||
num_heads_kv == rhs.num_heads_kv &&
|
||||
head_size_qk == rhs.head_size_qk &&
|
||||
head_size_v == rhs.head_size_v &&
|
||||
sequence_length_q == rhs.sequence_length_q &&
|
||||
sequence_length_kv == rhs.sequence_length_kv &&
|
||||
scale == rhs.scale &&
|
||||
is_causal == rhs.is_causal &&
|
||||
is_bf16 == rhs.is_bf16 &&
|
||||
qkv_format == rhs.qkv_format &&
|
||||
handle == rhs.handle &&
|
||||
has_bias == rhs.has_bias &&
|
||||
broadcast_bias_dim_0 == rhs.broadcast_bias_dim_0 &&
|
||||
broadcast_bias_dim_1 == rhs.broadcast_bias_dim_1 &&
|
||||
has_padding_mask_q == rhs.has_padding_mask_q &&
|
||||
has_padding_mask_kv == rhs.has_padding_mask_kv &&
|
||||
sliding_window == rhs.sliding_window;
|
||||
}
|
||||
};
|
||||
|
||||
#define Q_UID 1
|
||||
#define K_UID 2
|
||||
#define V_UID 3
|
||||
#define O_UID 4
|
||||
#define BIAS_UID 5
|
||||
#define SEQ_LEN_Q_UID 6
|
||||
#define SEQ_LEN_KV_UID 7
|
||||
|
||||
std::shared_ptr<fe::graph::Graph> build_graph(GraphParams& params) {
|
||||
int batch_size = params.batch_size;
|
||||
int num_heads_q = params.num_heads_q;
|
||||
int num_heads_kv = params.num_heads_kv;
|
||||
int head_size_qk = params.head_size_qk;
|
||||
int head_size_v = params.head_size_v;
|
||||
int sequence_length_q = params.sequence_length_q;
|
||||
int sequence_length_kv = params.sequence_length_kv;
|
||||
float scale = params.scale;
|
||||
bool is_causal = params.is_causal;
|
||||
bool is_bf16 = params.is_bf16;
|
||||
AttentionQkvFormat qkv_format = params.qkv_format;
|
||||
cudnnHandle_t handle = params.handle;
|
||||
|
||||
assert(qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH ||
|
||||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
|
||||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BNSH);
|
||||
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
mha_graph->set_io_data_type(is_bf16 ? fe::DataType_t::BFLOAT16 : fe::DataType_t::HALF)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
bool is_q_bsnh = (qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH ||
|
||||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
|
||||
bool is_kv_bsnh = qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH;
|
||||
|
||||
std::vector<int64_t> stride;
|
||||
set_stride(stride, num_heads_q, sequence_length_q, head_size_qk, is_q_bsnh);
|
||||
|
||||
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Q")
|
||||
.set_uid(Q_UID)
|
||||
.set_dim({batch_size, num_heads_q, sequence_length_q, head_size_qk}) // logical layout
|
||||
.set_stride(stride));
|
||||
|
||||
set_stride(stride, num_heads_kv, sequence_length_kv, head_size_qk, is_kv_bsnh);
|
||||
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("K")
|
||||
.set_uid(K_UID)
|
||||
.set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_qk})
|
||||
.set_stride(stride));
|
||||
|
||||
set_stride(stride, num_heads_kv, sequence_length_kv, head_size_v, is_kv_bsnh);
|
||||
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("V")
|
||||
.set_uid(V_UID)
|
||||
.set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_v})
|
||||
.set_stride(stride));
|
||||
|
||||
auto attributes = fe::graph::SDPA_attributes()
|
||||
.set_name("SDPA")
|
||||
.set_is_inference(true)
|
||||
.set_causal_mask(is_causal)
|
||||
.set_causal_mask_bottom_right(is_causal && sequence_length_q != sequence_length_kv)
|
||||
.set_attn_scale(scale);
|
||||
|
||||
if (params.sliding_window > 0) {
|
||||
attributes.set_sliding_window_length(params.sliding_window);
|
||||
}
|
||||
|
||||
if (params.has_bias) {
|
||||
std::vector<int64_t> bias_shape = {params.broadcast_bias_dim_0 ? 1 : batch_size,
|
||||
params.broadcast_bias_dim_1 ? 1 : num_heads_q,
|
||||
sequence_length_q,
|
||||
sequence_length_kv};
|
||||
stride = {bias_shape[1] * bias_shape[2] * bias_shape[3], bias_shape[2] * bias_shape[3], bias_shape[3], 1};
|
||||
auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("bias")
|
||||
.set_uid(BIAS_UID)
|
||||
.set_dim(bias_shape)
|
||||
.set_stride(stride));
|
||||
attributes.set_bias(bias);
|
||||
}
|
||||
|
||||
if (params.has_padding_mask_q || params.has_padding_mask_kv) {
|
||||
attributes.set_padding_mask(true);
|
||||
|
||||
if (params.has_padding_mask_q) {
|
||||
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("seq_q")
|
||||
.set_uid(SEQ_LEN_Q_UID)
|
||||
.set_dim({batch_size, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
attributes.set_seq_len_q(seq_q);
|
||||
}
|
||||
|
||||
if (params.has_padding_mask_kv) {
|
||||
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("seq_kv")
|
||||
.set_uid(SEQ_LEN_KV_UID)
|
||||
.set_dim({batch_size, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
attributes.set_seq_len_kv(seq_kv);
|
||||
}
|
||||
}
|
||||
|
||||
auto [O, Stats] = mha_graph->sdpa(Q, K, V, attributes);
|
||||
|
||||
constexpr bool is_output_bsnh = true;
|
||||
set_stride(stride, num_heads_q, sequence_length_q, head_size_v, is_output_bsnh);
|
||||
|
||||
O->set_output(true)
|
||||
.set_dim({batch_size, num_heads_q, sequence_length_q, head_size_v})
|
||||
.set_stride(stride)
|
||||
.set_uid(O_UID);
|
||||
|
||||
if (!mha_graph->build(handle, {fe::HeurMode_t::A}).is_good()) {
|
||||
ORT_THROW("Failed to build cuDNN graph for Flash Attention:", *mha_graph, "cudnn version:", cudnnGetVersion());
|
||||
}
|
||||
|
||||
return mha_graph;
|
||||
}
|
||||
|
||||
// Compute hash based on content in memory byte by byte. This can be moved to a common header file if needed.
|
||||
template <typename T>
|
||||
struct BytesHash {
|
||||
// Verify that Params is good to hash byte by byte.
|
||||
static_assert(std::is_standard_layout_v<T>, "Params is not standard layout");
|
||||
|
||||
size_t operator()(const T& params) const {
|
||||
auto ptr = reinterpret_cast<const uint8_t*>(¶ms);
|
||||
// Fowler–Noll–Vo hash function
|
||||
uint32_t value = 0x811C9DC5;
|
||||
constexpr size_t bytes = sizeof(T);
|
||||
for (size_t i = 0; i < bytes; ++i) {
|
||||
value ^= ptr[i];
|
||||
value *= 0x01000193;
|
||||
}
|
||||
return static_cast<size_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
|
||||
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
|
||||
thread_local
|
||||
std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;
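A minimal Python sketch of the caching scheme above (illustration only; names are hypothetical): graph parameters are hashed byte-wise with FNV-1a, and a per-thread table maps them to built graphs, since cuDNN execution plans are not guaranteed to be thread safe.

```python
import threading

def fnv1a_32(data: bytes) -> int:
    # Same Fowler-Noll-Vo (FNV-1a) variant as BytesHash above.
    value = 0x811C9DC5
    for byte in data:
        value = ((value ^ byte) * 0x01000193) & 0xFFFFFFFF
    return value

print(hex(fnv1a_32(b"serialized GraphParams bytes")))  # what BytesHash would compute

_tls = threading.local()

def get_or_build_graph(params_bytes: bytes, build_fn):
    # Thread-local cache keyed by the serialized parameters; build_fn builds a
    # new graph on a cache miss. (Python dicts hash keys themselves; the C++
    # code plugs BytesHash into std::unordered_map for the same purpose.)
    cache = getattr(_tls, "cache", None)
    if cache is None:
        cache = _tls.cache = {}
    if params_bytes not in cache:
        cache[params_bytes] = build_fn()
    return cache[params_bytes]
```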
|
||||
|
||||
void run(
|
||||
void* output,
|
||||
void* q,
|
||||
void* k,
|
||||
void* v,
|
||||
void* attn_bias,
|
||||
int* mask_sequence_lengths_q,
|
||||
int* mask_sequence_lengths_kv,
|
||||
int batch_size,
|
||||
int num_heads_q,
|
||||
int num_heads_kv,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
int sequence_length_q,
|
||||
int sequence_length_kv,
|
||||
float scale,
|
||||
bool is_causal,
|
||||
bool is_bf16,
|
||||
bool broadcast_attn_bias_dim_0,
|
||||
bool broadcast_attn_bias_dim_1,
|
||||
int sliding_window,
|
||||
AttentionQkvFormat qkv_format,
|
||||
cudnnHandle_t handle,
|
||||
Stream* stream,
|
||||
AllocatorPtr allocator) {
|
||||
|
||||
GraphParams params;
|
||||
params.batch_size = batch_size;
|
||||
params.num_heads_q = num_heads_q;
|
||||
params.num_heads_kv = num_heads_kv;
|
||||
params.head_size_qk = head_size_qk;
|
||||
params.head_size_v = head_size_v;
|
||||
params.sequence_length_q = sequence_length_q;
|
||||
params.sequence_length_kv = sequence_length_kv;
|
||||
params.scale = scale;
|
||||
params.is_causal = is_causal;
|
||||
params.is_bf16 = is_bf16;
|
||||
params.qkv_format = qkv_format;
|
||||
params.handle = handle;
|
||||
params.has_bias = attn_bias != nullptr;
|
||||
params.broadcast_bias_dim_0 = broadcast_attn_bias_dim_0;
|
||||
params.broadcast_bias_dim_1 = broadcast_attn_bias_dim_1;
|
||||
params.has_padding_mask_q = (mask_sequence_lengths_q != nullptr);
|
||||
params.has_padding_mask_kv = (mask_sequence_lengths_kv != nullptr);
|
||||
params.sliding_window = sliding_window;
|
||||
|
||||
std::shared_ptr<fe::graph::Graph> mha_graph;
|
||||
auto it = mha_graph_cache.find(params);
|
||||
if (it != mha_graph_cache.end()) {
|
||||
mha_graph = it->second;
|
||||
} else {
|
||||
mha_graph = build_graph(params);
|
||||
mha_graph_cache[params] = mha_graph;
|
||||
}
|
||||
|
||||
std::unordered_map<fe::graph::Tensor_attributes::uid_t, void*> variant_pack = {
|
||||
{Q_UID, q},
|
||||
{K_UID, k},
|
||||
{V_UID, v},
|
||||
{O_UID, output},
|
||||
};
|
||||
|
||||
if (attn_bias != nullptr) {
|
||||
variant_pack[BIAS_UID] = attn_bias;
|
||||
}
|
||||
|
||||
if (mask_sequence_lengths_q != nullptr) {
|
||||
variant_pack[SEQ_LEN_Q_UID] = mask_sequence_lengths_q;
|
||||
}
|
||||
|
||||
if (mask_sequence_lengths_kv != nullptr) {
|
||||
variant_pack[SEQ_LEN_KV_UID] = mask_sequence_lengths_kv;
|
||||
}
|
||||
|
||||
// Allocate workspace.
|
||||
auto bytes = mha_graph->get_workspace_size();
|
||||
|
||||
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(
|
||||
allocator, bytes, false, stream, WaitCudaNotificationOnDevice);
|
||||
|
||||
CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get()));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime::cudnn_sdpa
|
||||
#endif
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
using onnxruntime::Stream;
|
||||
using onnxruntime::contrib::AttentionQkvFormat;
|
||||
|
||||
namespace onnxruntime::cudnn_sdpa {
|
||||
|
||||
bool is_stable();
|
||||
|
||||
bool is_supported(const cudaDeviceProp& dprops,
|
||||
int num_heads_q,
|
||||
int num_heads_kv,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
int sequence_length_q,
|
||||
int sequence_length_kv,
|
||||
bool is_causal);
|
||||
|
||||
void run(
|
||||
void* output,
|
||||
void* q,
|
||||
void* k,
|
||||
void* v,
|
||||
void* bias, // (optional) attention bias with shape [b or 1, h_q or 1, s_q, s_kv].
|
||||
int* mask_sequence_lengths_q, // (optional) sequence lengths of q for padding mask. Shape: [batch_size]
|
||||
int* mask_sequence_lengths_kv, // (optional) sequence lengths of k or v for padding mask. Shape: [batch_size]
|
||||
int batch_size,
|
||||
int num_heads_q,
|
||||
int num_heads_kv,
|
||||
int head_size_qk,
|
||||
int head_size_v,
|
||||
int sequence_length_q,
|
||||
int sequence_length_kv,
|
||||
float scale,
|
||||
bool is_causal,
|
||||
bool is_bf16, // True if bfloat16, otherwise float16
|
||||
bool broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0
|
||||
bool broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1
|
||||
int sliding_window, // sliding window length. 0 means no sliding window.
|
||||
AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported
|
||||
cudnnHandle_t handle,
|
||||
Stream* stream,
|
||||
AllocatorPtr allocator);
|
||||
|
||||
} // namespace onnxruntime::cudnn_sdpa
|
|
@ -6,6 +6,7 @@
|
|||
#include "contrib_ops/cuda/bert/multihead_attention.h"
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
|
||||
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
|
||||
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
|
||||
|
||||
|
@ -59,6 +60,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
|
|||
|
||||
disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();
|
||||
|
||||
enable_cudnn_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseCudnnFlashAttention();
|
||||
|
||||
// Allocate cache buffers
|
||||
constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
|
||||
cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer<void>(cache_bytes);
|
||||
|
@ -148,6 +151,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Check whether we can use fused kernel
|
||||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
|
||||
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool use_flash_attention = !disable_flash_attention_ &&
|
||||
nullptr == attention_bias &&
|
||||
|
@ -173,6 +178,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.num_splits = static_cast<int>(num_splits);
|
||||
softmax_lse_accum_bytes = slse_accum_bytes;
|
||||
out_accum_bytes = o_accum_bytes;
|
||||
kernel_type = AttentionKernelType::AttentionKernel_FlashAttention;
|
||||
}
|
||||
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
|
||||
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
|
||||
|
@ -184,8 +190,23 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE ||
|
||||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
bool use_cudnn_sdpa = kernel_type == AttentionKernelType::AttentionKernel_Default &&
|
||||
enable_cudnn_flash_attention_ &&
|
||||
is_mask_none_or_1d_k_len &&
|
||||
onnxruntime::cudnn_sdpa::is_supported(device_prop,
|
||||
parameters.num_heads, // num_heads_q
|
||||
parameters.num_heads, // num_heads_kv
|
||||
parameters.head_size, // head_size_qk
|
||||
parameters.v_head_size, // head_size_v
|
||||
parameters.sequence_length, // seq_len_q
|
||||
parameters.total_sequence_length, // seq_len_kv
|
||||
is_unidirectional_);
|
||||
if (use_cudnn_sdpa) {
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention;
|
||||
}
|
||||
|
||||
bool use_fused_cross_attention =
|
||||
!use_flash_attention &&
|
||||
kernel_type == AttentionKernelType::AttentionKernel_Default &&
|
||||
!disable_fused_cross_attention_ &&
|
||||
nullptr == key_padding_mask &&
|
||||
nullptr == attention_bias &&
|
||||
|
@ -205,11 +226,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// The kernel has no limit on sequence length, and this checks whether the kernel has been loaded.
|
||||
if (fused_fp16_cross_attention_kernel_->isValid(sequence_length)) {
|
||||
fused_cross_attention_kernel = fused_fp16_cross_attention_kernel_;
|
||||
kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention;
|
||||
}
|
||||
}
|
||||
|
||||
bool use_fused_runner =
|
||||
!use_flash_attention &&
|
||||
kernel_type == AttentionKernelType::AttentionKernel_Default &&
|
||||
!disable_fused_self_attention_ &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
nullptr == attention_bias &&
|
||||
|
@ -234,6 +256,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
|
||||
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
|
||||
fused_runner = fused_fp16_runner_.get();
|
||||
// could also be AttentionKernel_TrtFlashAttention, but we don't classify it here.
|
||||
kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -244,9 +268,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.kv_sequence_length >= length_threshold;
|
||||
|
||||
bool use_memory_efficient_attention =
|
||||
!use_flash_attention &&
|
||||
fused_runner == nullptr &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
kernel_type == AttentionKernelType::AttentionKernel_Default &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
is_long_sequence &&
|
||||
// Check whether the attention bias alignment is good for memory efficient attention.
|
||||
|
@ -254,10 +276,17 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
|
||||
has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value,
|
||||
parameters.head_size, parameters.v_head_size);
|
||||
if (use_memory_efficient_attention) {
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
}
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
#endif
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
|
||||
kernel_type = AttentionKernelType::AttentionKernel_Unfused;
|
||||
}
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
AttentionData<CudaT> data;
|
||||
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
|
||||
|
@ -278,6 +307,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
|
||||
data.use_flash_attention = use_flash_attention;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
data.kernel_type = kernel_type;
|
||||
data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault);
|
||||
|
||||
// Cache of cumulated sequence length that could help when sequence length does not change (for example, image model).
|
||||
// The cache will be initialized only once, and become readonly after that.
|
||||
|
@ -305,6 +336,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
use_flash_attention,
|
||||
use_fused_cross_attention,
|
||||
use_memory_efficient_attention,
|
||||
use_cudnn_sdpa,
|
||||
no_qkv_workspace);
|
||||
auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
|
||||
|
||||
|
@ -323,6 +355,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
if (data.allow_debug_info) {
|
||||
AttentionKernelDebugInfo debug_info;
|
||||
debug_info.use_flash_attention = use_flash_attention;
|
||||
debug_info.use_cudnn_flash_attention = use_cudnn_sdpa;
|
||||
debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr;
|
||||
debug_info.use_efficient_attention = use_memory_efficient_attention;
|
||||
if (fused_fp16_runner_ != nullptr) {
|
||||
|
@ -337,8 +370,9 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
cudnnHandle_t cudnn = GetCudnnHandle(context);
|
||||
return QkvToContext<CudaT>(
|
||||
device_prop, cublas, context->GetComputeStream(), parameters, data);
|
||||
device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
@ -33,6 +33,7 @@ class MultiHeadAttention final : public CudaKernel {
|
|||
bool disable_fused_cross_attention_;
|
||||
bool disable_flash_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
bool enable_cudnn_flash_attention_;
|
||||
|
||||
// These mutable members are readonly after they are initialized so that they can be shared among multiple threads.
|
||||
// Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource.
|
||||
|
|
|
@ -179,6 +179,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
constexpr bool use_fused_cross_attention = false;
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
constexpr bool use_flash_attention = false;
|
||||
constexpr bool use_cudnn_flash_attention = false;
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
batch_size,
|
||||
parameters.num_heads,
|
||||
|
@ -191,6 +192,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
use_flash_attention,
|
||||
use_fused_cross_attention,
|
||||
use_memory_efficient_attention,
|
||||
use_cudnn_flash_attention,
|
||||
true);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
|
@ -215,7 +217,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
}
|
||||
|
||||
return QkvToContext<CudaT>(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data);
|
||||
cudnnHandle_t cudnn = GetCudnnHandle(context);
|
||||
return QkvToContext<CudaT>(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
@ -88,7 +88,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
#ifndef DISABLE_CONTRIB_OPS
|
||||
// Attention kernel options parsed from sdpa_kernel cuda provider option.
|
||||
const AttentionKernelOptions* GetAttentionKernelOptions() const {
|
||||
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true);
|
||||
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true, true);
|
||||
return &attention_kernel_options_;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -104,7 +104,8 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData&
|
|||
|
||||
data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernelType::AttentionKernel_TrtFusedAttention,
|
||||
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention};
|
||||
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention,
|
||||
AttentionKernelType::AttentionKernel_CudnnFlashAttention};
|
||||
|
||||
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.query_data", data.query_data);
|
||||
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.key_data", data.key_data);
|
||||
|
|
|
@ -367,6 +367,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
|
@ -377,6 +378,22 @@ static void RunMultiHeadAttentionKernel(
|
|||
mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
|
||||
hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml);
|
||||
}
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
|
||||
past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
|
||||
mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
|
||||
hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml);
|
||||
}
|
||||
}
|
||||
|
||||
static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) {
|
||||
|
@ -451,6 +468,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
|
|||
}
|
||||
#endif
|
||||
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data,
|
||||
data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data,
|
||||
data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data,
|
||||
data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size,
|
||||
data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda);
|
||||
}
|
||||
|
||||
kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data,
|
||||
|
|
|
@ -791,7 +791,13 @@ def run_tflops_test(
|
|||
# flash attention is available for sm >= 80
|
||||
sm = get_compute_capability()
|
||||
if sm >= 80:
|
||||
backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
|
||||
backends = [
|
||||
SdpaKernel.DEFAULT,
|
||||
SdpaKernel.FLASH_ATTENTION,
|
||||
SdpaKernel.EFFICIENT_ATTENTION,
|
||||
SdpaKernel.CUDNN_FLASH_ATTENTION,
|
||||
SdpaKernel.MATH,
|
||||
]
|
||||
else:
|
||||
backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
|
||||
else:
|
||||
|
|
|
@ -804,6 +804,10 @@ class TestMultiHeadAttention(unittest.TestCase):
|
|||
if get_compute_capability() >= 60:
|
||||
self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)
|
||||
|
||||
def run_mha_cuda_multi_threading_cudnn(self):
|
||||
if get_compute_capability() in [80, 86, 89, 90]:
|
||||
self.run_mha_cuda_multi_threading(SdpaKernel.CUDNN_FLASH_ATTENTION)
|
||||
|
||||
def run_mha_cuda_multi_threading_efficient(self):
|
||||
if comprehensive_mode and get_compute_capability() >= 60:
|
||||
self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)
|
||||
|
@ -826,6 +830,7 @@ class TestMultiHeadAttention(unittest.TestCase):
|
|||
self.run_mha_cpu()
|
||||
self.run_mha_cuda()
|
||||
self.run_mha_cuda_multi_threading_default()
|
||||
self.run_mha_cuda_multi_threading_cudnn()
|
||||
self.run_mha_cuda_multi_threading_efficient()
|
||||
self.run_mha_cuda_multi_threading_math()
|
||||
self.run_mha_cuda_multi_threading_trt()
|
||||
|
|