### Description
- [x] Add cuDNN flash attention using cudnn frontend, and enable it in
MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA kernel is disabled by default. To enable it:
(1) cuDNN 9.3 or a newer version must be installed.
(2) Set the environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1`, or
set the `sdpa_kernel=8` CUDA provider option (see the sketch below).
(3) It only works on devices with compute capability >= 8.0.
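A minimal Python sketch of both options, assuming a CUDA build of onnxruntime; the option name `sdpa_kernel` and the value 8 come from this description, and `model.onnx` is a placeholder:

```python
import os
import onnxruntime as ort

# Option A: environment variable (set it before the session is created).
os.environ["ORT_ENABLE_CUDNN_FLASH_ATTENTION"] = "1"

# Option B: CUDA provider option; sdpa_kernel=8 selects the cuDNN flash attention backend.
providers = [
    ("CUDAExecutionProvider", {"sdpa_kernel": "8"}),
    "CPUExecutionProvider",
]

# Placeholder path for a model that uses the MultiHeadAttention contrib op.
session = ort.InferenceSession("model.onnx", providers=providers)
```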

Note that some parameter combinations might be rejected due to limited
support for certain head dimensions or sequence lengths.

Future work:
(1) FP8 and BF16 APIs. Currently, only the FP16 API is exposed.
(2) Add an API to support ragged batching (inputs with padding removed).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q is converted to BSNH, while k/v are converted to either BSNH
or BNSH (see the layout sketch below). Some experiments may show whether
converting q to BNSH is better in some cases.
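For reference, BSNH and BNSH here (and in the tables below) refer to the (batch, sequence, num_heads, head_size) versus (batch, num_heads, sequence, head_size) layouts; the conversion is just a transpose of the sequence and head axes, as in this NumPy sketch with illustrative shapes:

```python
import numpy as np

batch_size, sequence_length, num_heads, head_size = 2, 8, 4, 64

# BSNH layout: (batch_size, sequence_length, num_heads, head_size)
q_bsnh = np.zeros((batch_size, sequence_length, num_heads, head_size), dtype=np.float16)

# BNSH layout: (batch_size, num_heads, sequence_length, head_size)
q_bnsh = q_bsnh.transpose(0, 2, 1, 3)

print(q_bsnh.shape, q_bnsh.shape)  # (2, 8, 4, 64) (2, 4, 8, 64)
```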

### Example Benchmark Results on H100

The following results were measured for the FP16 MultiHeadAttention operator
without attention mask or attention bias.

#### Test Setting 1
batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient
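As a sanity check on the tflops column, the numbers are consistent with counting about 4 * batch_size * num_heads * sequence_length^2 * head_size FLOPs for the two attention matmuls and dividing by the measured latency (an assumption about how the benchmark script derives the column, not verified against the script itself):

```python
batch_size, sequence_length, num_heads, head_size = 16, 256, 32, 128

# Two matmuls (Q @ K^T and P @ V), each ~2 * S * S * H FLOPs per batch and head.
flops = 4 * batch_size * num_heads * sequence_length * sequence_length * head_size

latency_s = 0.000075  # torch:flash row in the table above
print(flops / latency_s / 1e12)  # ~229 TFLOPS, matching the table
```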

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads | head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency (s) | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH)  | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH)  | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH)  | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient

This commit is contained in:
Tianlei Wu 2024-08-20 08:50:22 -07:00, committed by GitHub
Parent 9f7e19cedd
Commit fbc3927231
19 changed files: 681 additions and 50 deletions

cmake/external/cuDNN.cmake (vendored)

@ -107,5 +107,3 @@ elseif(CUDNN_MAJOR_VERSION EQUAL 9)
CUDNN::cudnn_heuristic
)
endif()
mark_as_advanced(CUDNN_INCLUDE_DIR)


@ -5,6 +5,7 @@ find_package(Python3 COMPONENTS Interpreter REQUIRED)
# GLOB pattern of file to be excluded
set(contrib_ops_excluded_files
"bert/cudnn_fmha/*"
"bert/cutlass_fmha/*"
"bert/fastertransformer_decoder_attention/*"
"bert/flash_attention/*"


@ -47,6 +47,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_Default
};


@ -246,6 +246,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@ -258,6 +259,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());
@ -294,7 +296,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
}
return QkvToContext<CudaT>(device_prop, cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}
} // namespace cuda


@ -37,6 +37,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
@ -109,6 +110,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_size = element_size * batch_size * num_heads *
@ -144,6 +146,10 @@ size_t GetAttentionWorkspaceSize(
return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}
if (use_cudnn_flash_attention) {
return qkv_bytes;
}
return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
total_sequence_length);
}
@ -320,6 +326,68 @@ Status FlashAttention(
}
#endif
template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN);
constexpr bool is_bf16 = false;
T* attention_bias = const_cast<T*>(data.attention_bias);
int* mask_sequence_lengths_kv = const_cast<int*>(data.mask_index);
cudnn_sdpa::run(
data.output,
data.q,
data.k,
data.v,
attention_bias,
nullptr, // (optional) mask_sequence_lengths_q
mask_sequence_lengths_kv, // (optional) mask_sequence_lengths_kv
parameters.batch_size,
parameters.num_heads, // num_heads_q,
parameters.num_heads, // num_heads_kv,
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // sequence_length_q
parameters.total_sequence_length, // sequence_length_kv
scale,                                // scaling factor applied prior to softmax
parameters.is_unidirectional, // causal
is_bf16, // True if bfloat16, otherwise float16
parameters.broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0 or not
parameters.broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1 or not
0, // sliding window length. 0 means no sliding window.
data.qkv_format,
cudnn_handle,
ort_stream,
data.allocator);
return Status::OK();
}
template <>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(cudnn_handle);
ORT_UNUSED_PARAMETER(ort_stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"cudnn flash attention does not support float tensor");
}
#if USE_MEMORY_EFFICIENT_ATTENTION
template <typename T>
Status EfficientAttention(
@ -498,6 +566,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data) {
@ -512,10 +581,11 @@ Status QkvToContext(
void* fused_runner = data.fused_runner;
// At most one fused kernel is enabled.
assert((int(data.use_flash_attention) +
int(data.use_memory_efficient_attention) +
int(fused_runner != nullptr) +
int(data.fused_cross_attention_kernel != nullptr)) <= 1);
assert((static_cast<int>(data.use_flash_attention) +
static_cast<int>(data.use_memory_efficient_attention) +
static_cast<int>(fused_runner != nullptr) +
static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));
@ -577,6 +647,10 @@ Status QkvToContext(
}
#endif
if (data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
return CudnnFlashAttention(cudnn, ort_stream, parameters, data, scale);
}
#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
return EfficientAttention(device_prop, stream, parameters, data, scale);
@ -594,6 +668,7 @@ template struct AttentionData<half>;
template Status QkvToContext<float>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data);
@ -601,6 +676,7 @@ template Status QkvToContext<float>(
template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* ort_stream,
contrib::AttentionParameters& parameters,
AttentionData<half>& data);


@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include "core/framework/allocator.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
@ -54,6 +55,7 @@ size_t GetAttentionWorkspaceSize(
bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
bool no_qkv_workspace);
template <typename T>
@ -104,9 +106,11 @@ struct AttentionData {
size_t workspace_bytes = 0;
bool allow_debug_info = false;
// For MultiHeadAttention only.
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
AllocatorPtr allocator = nullptr;
bool IsUnfused() const {
return !use_flash_attention && !use_memory_efficient_attention &&
(fused_runner == nullptr) && (fused_cross_attention_kernel == nullptr);
return kernel_type == AttentionKernelType::AttentionKernel_Unfused;
}
void PrintDebugInfo() const {
@ -139,6 +143,7 @@ template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cudnnHandle_t& cudnn,
Stream* stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data);


@ -9,11 +9,12 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
using namespace onnxruntime::contrib::attention;
namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
@ -28,6 +29,7 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
use_unfused_ = true;
use_trt_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableTrtFlashAttention, false);
use_trt_cross_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedCrossAttention, false);
@ -45,6 +47,14 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
kMinSeqLenForEfficientAttentionFp32,
value > 0 ? 0 : kDefaultMinSeqLenForEfficientAttentionFp32);
// Enable cuDNN flash attention only when it is stable (requires cuDNN version >= 9.3.0).
if (use_cudnn_flash_attention_ && check_cudnn_version && !::onnxruntime::cudnn_sdpa::is_stable()) {
use_cudnn_flash_attention_ = false;
if (enable_kernel_debug_info_) {
std::cout << "cuDNN Flash Attention is disabled. Requires cuDNN 9.3 or later." << std::endl;
}
}
if (use_build_flag) {
// Some kernels can be disabled at build time. If they are disabled, we should not use them.
#ifndef USE_FLASH_ATTENTION
@ -58,9 +68,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag) {
}
void AttentionKernelOptions::InitializeOnce(
int sdpa_kernel, bool use_build_flag) {
int sdpa_kernel, bool use_build_flag, bool check_cudnn_version) {
std::call_once(this->initialize_once_flag_, [&]() {
this->Initialize(sdpa_kernel, use_build_flag);
this->Initialize(sdpa_kernel, use_build_flag, check_cudnn_version);
if (this->enable_kernel_debug_info_) {
this->Print();
}


@ -21,7 +21,7 @@ struct AttentionKernelDebugInfo {
class AttentionKernelOptions {
public:
void InitializeOnce(int sdpa_kernel, bool use_build_flag);
void InitializeOnce(int sdpa_kernel, bool use_build_flag, bool check_cudnn_version = false);
bool UseFlashAttention() const { return use_flash_attention_; }
bool UseEfficientAttention() const { return use_efficient_attention_; }
@ -40,7 +40,7 @@ class AttentionKernelOptions {
protected:
void Print() const;
void Initialize(int value, bool use_build_flag);
void Initialize(int value, bool use_build_flag, bool check_cudnn_version);
private:
bool use_flash_attention_{true};


@ -169,7 +169,10 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
template <typename T>
bool NoQkvWorkspace_MHA_Cross(AttentionData<T>& data) {
// query, key and value are passed as Q, K and V for the following conditions.
return (data.use_memory_efficient_attention || data.use_flash_attention) && (data.bias == nullptr);
return (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
data.bias == nullptr;
}
// For MultiHeadAttention with cross attention (Q_K_V_BSNH_BNSH_BNSH format)
@ -190,8 +193,9 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Add bias for Q
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
@ -204,9 +208,7 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
data.k = const_cast<T*>(data.key);
data.v = const_cast<T*>(data.value);
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());
if (data.bias == nullptr) {
// Transpose query from BSNH to BNSH
@ -233,7 +235,10 @@ Status PrepareQkv_MHA_Cross(contrib::AttentionParameters& parameters,
template <typename T>
bool NoQkvWorkspace_MHA_NoPast(AttentionData<T>& data) {
// query, key and value are passed as Q, K and V for the following conditions.
return (data.use_memory_efficient_attention || data.use_flash_attention) && data.bias == nullptr;
return (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) &&
data.bias == nullptr;
}
// For MultiHeadAttention without past state, with Q, K and V inputs
@ -275,9 +280,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
data.v = nullptr;
data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
else if (data.use_memory_efficient_attention || data.use_flash_attention) {
} else if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
if (data.bias != nullptr) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
@ -290,9 +295,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
}
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else if (data.fused_runner != nullptr) {
} else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
assert(data.attention_bias == nullptr);
@ -338,7 +341,9 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
template <typename T>
bool NoQkvWorkspace_MHA_WithPast_NoBias(AttentionData<T>& data) {
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Q, K and V redirect to query, present_k and present_v, so we do not need extra workspace for QKV.
return data.past_key == nullptr && data.present_key != nullptr;
}
@ -377,8 +382,9 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
data.v = data.present_value;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Use original Query (BSNH) since there is no bias.
data.q = const_cast<T*>(data.query);
@ -389,9 +395,7 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.v));
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, data.q));
@ -440,8 +444,9 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
data.v = data.present_value;
}
#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention ||
data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Query(BxSxNxH) + Bias_Q => Q (BxSxNxH)
LaunchAddBias(stream, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size,
data.bias, data.query, data.q);
@ -460,9 +465,7 @@ Status PrepareQkv_MHA_WithPast_Bias(contrib::AttentionParameters& parameters,
data.value, data.bias + 2 * num_heads * qk_head_size, data.v, true, -1);
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
} else
#endif
{ // unfused kernel
} else { // unfused kernel
assert(data.IsUnfused());
constexpr int format = 0;
@ -518,7 +521,8 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention || data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// unpack qkv to BSNH.
constexpr int format = 4;
T* qkv_add_bias = nullptr;
@ -590,7 +594,8 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
if (data.use_memory_efficient_attention || data.use_flash_attention) {
if (data.use_memory_efficient_attention || data.use_flash_attention ||
data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
// Note that there is no bias so we need not output query to q.
data.q = const_cast<T*>(data.query);
// Unpack kv to BSNH.


@ -0,0 +1,405 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include <memory>
#include <vector>
#include <unordered_map>
#include <cudnn.h>
#if CUDNN_MAJOR < 9
namespace onnxruntime::cudnn_sdpa {
bool is_stable() {
return false;
}
bool is_supported(const cudaDeviceProp& /*dprops*/,
int /*num_heads_q*/,
int /*num_heads_kv*/,
int /*head_size_qk*/,
int /*head_size_v*/,
int /*sequence_length_q*/,
int /*sequence_length_kv*/,
bool /*is_causal*/) {
return false;
}
void run(
void* /*output*/,
void* /*q*/,
void* /*k*/,
void* /*v*/,
void* /*bias*/,
int* /*mask_sequence_lengths_q*/,
int* /*mask_sequence_lengths_kv*/,
int /*batch_size*/,
int /*num_heads_q*/,
int /*num_heads_kv*/,
int /*head_size_qk*/,
int /*head_size_v*/,
int /*sequence_length_q*/,
int /*sequence_length_kv*/,
float /*scale*/,
bool /*is_causal*/,
bool /*is_bf16*/,
bool /*broadcast_attn_bias_dim_0*/,
bool /*broadcast_attn_bias_dim_1*/,
int /*sliding_window*/,
AttentionQkvFormat /*qkv_format*/,
cudnnHandle_t /*handle*/,
Stream* /*stream*/,
AllocatorPtr /*allocator*/) {
ORT_THROW("OnnxRuntime was not compiled with cuDNN Flash Attention.");
}
} // namespace onnxruntime::cudnn_sdpa
#else // CUDNN_MAJOR >= 9
#include <cudnn_frontend.h>
#include "core/providers/cuda/shared_inc/cudnn_fe_call.h"
#include "core/providers/cuda/cuda_stream_handle.h"
namespace onnxruntime::cudnn_sdpa {
bool is_stable() {
// FP16/BF16 Flash Attention support in CUDNN backend:
// version 8903 (8.9.3):
// Padding mask and causal mask
// Additive bias
// Multi-query attention (h_kv=1)
// Both self attention and cross attention
// (padded) variable sequence length
// Head dimensions 64 or 128
// version 8904 (8.9.4):
// Alibi mask;
// version 8907 (8.9.7):
// Grouped Query Attention
// version 90100 (9.1.0):
// Head dimensions 256
// version 90101 (9.1.1)
// Sliding window attention
// version 90300 (9.3.0)
// Bug fixes; Variable sequence length supports zero-sequence-length values
// For more information, please refer to cuDNN release notes, and the following links:
// https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-flash-attention-fprop
// https://github.com/NVIDIA/cudnn-frontend/blob/v1.5.2/docs/operations/Attention.md
// For cuDNN version < 9.3, we will disable it by default.
return cudnnGetVersion() >= 90300;
}
namespace fe = cudnn_frontend;
bool is_supported(const cudaDeviceProp& dprops,
int num_heads_q,
int num_heads_kv,
int head_size_qk,
int head_size_v,
int sequence_length_q,
int sequence_length_kv,
bool is_causal) {
bool is_sm8x = dprops.major == 8 && dprops.minor >= 0;
bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
return (is_sm8x || is_sm90) &&
(head_size_qk % 8 == 0) && (head_size_qk <= 256) &&
(head_size_v % 8 == 0) && (head_size_v <= 256) &&
(num_heads_q % num_heads_kv == 0) &&
// Bottom right causal mask is only supported with s_q multiple of 64 and s_kv multiple of 64.
(!is_causal || (sequence_length_q != sequence_length_kv &&
sequence_length_q % 64 == 0 &&
sequence_length_kv % 64 == 0));
}
// A helper function to set stride for q, k, v or output tensor.
// Strides are calculated based on logical tensor layout BNSH (batch_size, num_heads, sequence_length, head_size).
// The physical tensor layout could be either BSNH (is_bsnh=True) or BNSH (is_bsnh=False).
inline void set_stride(std::vector<int64_t>& stride,
int64_t num_heads,
int64_t sequence_length,
int64_t head_size,
bool is_bsnh) {
stride = {num_heads * sequence_length * head_size, // stride for batch.
is_bsnh ? head_size : (head_size * sequence_length), // stride for head.
is_bsnh ? (num_heads * head_size) : head_size, // stride for sequence.
1}; // stride for hidden dim of head, shall always be 1.
}
// It is used as a key for hash table to store cached graphs.
// It contains all parameters used in building the graph. Do not include data pointers that are only needed in graph execution.
struct GraphParams {
int batch_size;
int num_heads_q;
int num_heads_kv;
int head_size_qk;
int head_size_v;
int sequence_length_q;
int sequence_length_kv;
float scale;
bool is_causal;
bool is_bf16; // True if bfloat16, otherwise float16
AttentionQkvFormat qkv_format;
cudnnHandle_t handle;
bool has_bias;
bool broadcast_bias_dim_0;
bool broadcast_bias_dim_1;
bool has_padding_mask_q;
bool has_padding_mask_kv;
int sliding_window;
bool operator==(const GraphParams& rhs) const {
return batch_size == rhs.batch_size &&
num_heads_q == rhs.num_heads_q &&
num_heads_kv == rhs.num_heads_kv &&
head_size_qk == rhs.head_size_qk &&
head_size_v == rhs.head_size_v &&
sequence_length_q == rhs.sequence_length_q &&
sequence_length_kv == rhs.sequence_length_kv &&
scale == rhs.scale &&
is_causal == rhs.is_causal &&
is_bf16 == rhs.is_bf16 &&
qkv_format == rhs.qkv_format &&
handle == rhs.handle &&
has_bias == rhs.has_bias &&
broadcast_bias_dim_0 == rhs.broadcast_bias_dim_0 &&
broadcast_bias_dim_1 == rhs.broadcast_bias_dim_1 &&
has_padding_mask_q == rhs.has_padding_mask_q &&
has_padding_mask_kv == rhs.has_padding_mask_kv &&
sliding_window == rhs.sliding_window;
}
};
#define Q_UID 1
#define K_UID 2
#define V_UID 3
#define O_UID 4
#define BIAS_UID 5
#define SEQ_LEN_Q_UID 6
#define SEQ_LEN_KV_UID 7
std::shared_ptr<fe::graph::Graph> build_graph(GraphParams& params) {
int batch_size = params.batch_size;
int num_heads_q = params.num_heads_q;
int num_heads_kv = params.num_heads_kv;
int head_size_qk = params.head_size_qk;
int head_size_v = params.head_size_v;
int sequence_length_q = params.sequence_length_q;
int sequence_length_kv = params.sequence_length_kv;
float scale = params.scale;
bool is_causal = params.is_causal;
bool is_bf16 = params.is_bf16;
AttentionQkvFormat qkv_format = params.qkv_format;
cudnnHandle_t handle = params.handle;
assert(qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH ||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH ||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BNSH);
auto mha_graph = std::make_shared<fe::graph::Graph>();
mha_graph->set_io_data_type(is_bf16 ? fe::DataType_t::BFLOAT16 : fe::DataType_t::HALF)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
bool is_q_bsnh = (qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH ||
qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
bool is_kv_bsnh = qkv_format == contrib::AttentionQkvFormat::Q_K_V_BSNH;
std::vector<int64_t> stride;
set_stride(stride, num_heads_q, sequence_length_q, head_size_qk, is_q_bsnh);
auto Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_uid(Q_UID)
.set_dim({batch_size, num_heads_q, sequence_length_q, head_size_qk}) // logical layout
.set_stride(stride));
set_stride(stride, num_heads_kv, sequence_length_kv, head_size_qk, is_kv_bsnh);
auto K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_uid(K_UID)
.set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_qk})
.set_stride(stride));
set_stride(stride, num_heads_kv, sequence_length_kv, head_size_v, is_kv_bsnh);
auto V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_uid(V_UID)
.set_dim({batch_size, num_heads_kv, sequence_length_kv, head_size_v})
.set_stride(stride));
auto attributes = fe::graph::SDPA_attributes()
.set_name("SDPA")
.set_is_inference(true)
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_causal && sequence_length_q != sequence_length_kv)
.set_attn_scale(scale);
if (params.sliding_window > 0) {
attributes.set_sliding_window_length(params.sliding_window);
}
if (params.has_bias) {
std::vector<int64_t> bias_shape = {params.broadcast_bias_dim_0 ? 1 : batch_size,
params.broadcast_bias_dim_1 ? 1 : num_heads_q,
sequence_length_q,
sequence_length_kv};
stride = {bias_shape[1] * bias_shape[2] * bias_shape[3], bias_shape[2] * bias_shape[3], bias_shape[3], 1};
auto bias = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("bias")
.set_uid(BIAS_UID)
.set_dim(bias_shape)
.set_stride(stride));
attributes.set_bias(bias);
}
if (params.has_padding_mask_q || params.has_padding_mask_kv) {
attributes.set_padding_mask(true);
if (params.has_padding_mask_q) {
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_uid(SEQ_LEN_Q_UID)
.set_dim({batch_size, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
attributes.set_seq_len_q(seq_q);
}
if (params.has_padding_mask_kv) {
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_uid(SEQ_LEN_KV_UID)
.set_dim({batch_size, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
attributes.set_seq_len_kv(seq_kv);
}
}
auto [O, Stats] = mha_graph->sdpa(Q, K, V, attributes);
constexpr bool is_output_bsnh = true;
set_stride(stride, num_heads_q, sequence_length_q, head_size_v, is_output_bsnh);
O->set_output(true)
.set_dim({batch_size, num_heads_q, sequence_length_q, head_size_v})
.set_stride(stride)
.set_uid(O_UID);
if (!mha_graph->build(handle, {fe::HeurMode_t::A}).is_good()) {
ORT_THROW("Failed to build cuDNN graph for Flash Attention:", *mha_graph, "cudnn version:", cudnnGetVersion());
}
return mha_graph;
}
// Compute hash based on content in memory byte by byte. This can be moved to a common header file if needed.
template <typename T>
struct BytesHash {
// Verify that Params is good to hash byte by byte.
static_assert(std::is_standard_layout_v<T>, "Params is not standard layout");
size_t operator()(const T& params) const {
auto ptr = reinterpret_cast<const uint8_t*>(&params);
// Fowler–Noll–Vo hash function
uint32_t value = 0x811C9DC5;
constexpr size_t bytes = sizeof(T);
for (size_t i = 0; i < bytes; ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return static_cast<size_t>(value);
}
};
// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
thread_local
std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;
void run(
void* output,
void* q,
void* k,
void* v,
void* attn_bias,
int* mask_sequence_lengths_q,
int* mask_sequence_lengths_kv,
int batch_size,
int num_heads_q,
int num_heads_kv,
int head_size_qk,
int head_size_v,
int sequence_length_q,
int sequence_length_kv,
float scale,
bool is_causal,
bool is_bf16,
bool broadcast_attn_bias_dim_0,
bool broadcast_attn_bias_dim_1,
int sliding_window,
AttentionQkvFormat qkv_format,
cudnnHandle_t handle,
Stream* stream,
AllocatorPtr allocator) {
GraphParams params;
params.batch_size = batch_size;
params.num_heads_q = num_heads_q;
params.num_heads_kv = num_heads_kv;
params.head_size_qk = head_size_qk;
params.head_size_v = head_size_v;
params.sequence_length_q = sequence_length_q;
params.sequence_length_kv = sequence_length_kv;
params.scale = scale;
params.is_causal = is_causal;
params.is_bf16 = is_bf16;
params.qkv_format = qkv_format;
params.handle = handle;
params.has_bias = attn_bias != nullptr;
params.broadcast_bias_dim_0 = broadcast_attn_bias_dim_0;
params.broadcast_bias_dim_1 = broadcast_attn_bias_dim_1;
params.has_padding_mask_q = (mask_sequence_lengths_q != nullptr);
params.has_padding_mask_kv = (mask_sequence_lengths_kv != nullptr);
params.sliding_window = sliding_window;
std::shared_ptr<fe::graph::Graph> mha_graph;
auto it = mha_graph_cache.find(params);
if (it != mha_graph_cache.end()) {
mha_graph = it->second;
} else {
mha_graph = build_graph(params);
mha_graph_cache[params] = mha_graph;
}
std::unordered_map<fe::graph::Tensor_attributes::uid_t, void*> variant_pack = {
{Q_UID, q},
{K_UID, k},
{V_UID, v},
{O_UID, output},
};
if (attn_bias != nullptr) {
variant_pack[BIAS_UID] = attn_bias;
}
if (mask_sequence_lengths_q != nullptr) {
variant_pack[SEQ_LEN_Q_UID] = mask_sequence_lengths_q;
}
if (mask_sequence_lengths_kv != nullptr) {
variant_pack[SEQ_LEN_KV_UID] = mask_sequence_lengths_kv;
}
// Allocate workspace.
auto bytes = mha_graph->get_workspace_size();
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(
allocator, bytes, false, stream, WaitCudaNotificationOnDevice);
CUDNN_FE_CALL_THROW(mha_graph->execute(handle, variant_pack, buffer.get()));
}
} // namespace onnxruntime::cudnn_sdpa
#endif


@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
using onnxruntime::Stream;
using onnxruntime::contrib::AttentionQkvFormat;
namespace onnxruntime::cudnn_sdpa {
bool is_stable();
bool is_supported(const cudaDeviceProp& dprops,
int num_heads_q,
int num_heads_kv,
int head_size_qk,
int head_size_v,
int sequence_length_q,
int sequence_length_kv,
bool is_causal);
void run(
void* output,
void* q,
void* k,
void* v,
void* bias, // (optional) attention bias with shape [b or 1, h_q or 1, s_q, s_kv].
int* mask_sequence_lengths_q, // (optional) sequence lengths of q for padding mask. Shape: [batch_size]
int* mask_sequence_lengths_kv, // (optional) sequence lengths of k or v for padding mask. Shape: [batch_size]
int batch_size,
int num_heads_q,
int num_heads_kv,
int head_size_qk,
int head_size_v,
int sequence_length_q,
int sequence_length_kv,
float scale,
bool is_causal,
bool is_bf16, // True if bfloat16, otherwise float16
bool broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0
bool broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1
int sliding_window, // sliding window length. 0 means no sliding window.
AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported
cudnnHandle_t handle,
Stream* stream,
AllocatorPtr allocator);
} // namespace onnxruntime::cudnn_sdpa


@ -6,6 +6,7 @@
#include "contrib_ops/cuda/bert/multihead_attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
@ -59,6 +60,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
disable_fused_cross_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtCrossAttention();
enable_cudnn_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseCudnnFlashAttention();
// Allocate cache buffers
constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer<void>(cache_bytes);
@ -148,6 +151,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Default;
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
nullptr == attention_bias &&
@ -173,6 +178,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
kernel_type = AttentionKernelType::AttentionKernel_FlashAttention;
}
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
@ -184,8 +190,23 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool is_mask_none_or_1d_k_len = parameters.mask_type == AttentionMaskType::MASK_NONE ||
parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
bool use_cudnn_sdpa = kernel_type == AttentionKernelType::AttentionKernel_Default &&
enable_cudnn_flash_attention_ &&
is_mask_none_or_1d_k_len &&
onnxruntime::cudnn_sdpa::is_supported(device_prop,
parameters.num_heads, // num_heads_q
parameters.num_heads, // num_heads_kv
parameters.head_size, // head_size_qk
parameters.v_head_size, // head_size_v
parameters.sequence_length, // seq_len_q
parameters.total_sequence_length, // seq_len_kv
is_unidirectional_);
if (use_cudnn_sdpa) {
kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention;
}
bool use_fused_cross_attention =
!use_flash_attention &&
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
nullptr == attention_bias &&
@ -205,11 +226,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
// The kernel has no limit on sequence length, and this checks whether the kernel has been loaded.
if (fused_fp16_cross_attention_kernel_->isValid(sequence_length)) {
fused_cross_attention_kernel = fused_fp16_cross_attention_kernel_;
kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention;
}
}
bool use_fused_runner =
!use_flash_attention &&
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == attention_bias &&
@ -234,6 +256,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
const int normalized_seq_len = fused_fp16_runner_->NormalizeSequenceLength(sequence_length);
if (fused_fp16_runner_->IsValid(normalized_seq_len)) {
fused_runner = fused_fp16_runner_.get();
// could also be AttentionKernel_TrtFlashAttention, but we don't classify it here.
kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention;
}
}
@ -244,9 +268,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length >= length_threshold;
bool use_memory_efficient_attention =
!use_flash_attention &&
fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
// Check whether the attention bias alignment is good for memory efficient attention.
@ -254,10 +276,17 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value,
parameters.head_size, parameters.v_head_size);
if (use_memory_efficient_attention) {
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
}
#else
constexpr bool use_memory_efficient_attention = false;
#endif
if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
kernel_type = AttentionKernelType::AttentionKernel_Unfused;
}
typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>());
@ -278,6 +307,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.kernel_type = kernel_type;
data.allocator = Info().GetAllocator(OrtMemType::OrtMemTypeDefault);
// Cache of cumulated sequence length that could help when sequence length does not change (for example, image model).
// The cache will be initialized only once, and become readonly after that.
@ -305,6 +336,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_sdpa,
no_qkv_workspace);
auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
@ -323,6 +355,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (data.allow_debug_info) {
AttentionKernelDebugInfo debug_info;
debug_info.use_flash_attention = use_flash_attention;
debug_info.use_cudnn_flash_attention = use_cudnn_sdpa;
debug_info.use_trt_cross_attention = fused_cross_attention_kernel != nullptr;
debug_info.use_efficient_attention = use_memory_efficient_attention;
if (fused_fp16_runner_ != nullptr) {
@ -337,8 +370,9 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
cublasHandle_t cublas = GetCublasHandle(context);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(
device_prop, cublas, context->GetComputeStream(), parameters, data);
device_prop, cublas, cudnn, context->GetComputeStream(), parameters, data);
}
} // namespace cuda


@ -33,6 +33,7 @@ class MultiHeadAttention final : public CudaKernel {
bool disable_fused_cross_attention_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
bool enable_cudnn_flash_attention_;
// These mutable members are readonly after they are initialized so that they can be shared among multiple threads.
// Initialization are done only once by the first thread using the resource, so use once_flag to guard each resource.


@ -179,6 +179,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_fused_cross_attention = false;
constexpr bool use_memory_efficient_attention = false;
constexpr bool use_flash_attention = false;
constexpr bool use_cudnn_flash_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
batch_size,
parameters.num_heads,
@ -191,6 +192,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
true);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
@ -215,7 +217,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
data.present = reinterpret_cast<CudaT*>(present->MutableData<T>());
}
return QkvToContext<CudaT>(GetDeviceProp(), cublas, context->GetComputeStream(), parameters, data);
cudnnHandle_t cudnn = GetCudnnHandle(context);
return QkvToContext<CudaT>(GetDeviceProp(), cublas, cudnn, context->GetComputeStream(), parameters, data);
}
} // namespace cuda


@ -88,7 +88,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
#ifndef DISABLE_CONTRIB_OPS
// Attention kernel options parsed from sdpa_kernel cuda provider option.
const AttentionKernelOptions* GetAttentionKernelOptions() const {
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true);
attention_kernel_options_.InitializeOnce(info_.sdpa_kernel, true, true);
return &attention_kernel_options_;
}
#endif


@ -104,7 +104,8 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData&
data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention,
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention};
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernelType::AttentionKernel_CudnnFlashAttention};
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.query_data", data.query_data);
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.key_data", data.key_data);


@ -367,6 +367,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
@ -377,6 +378,22 @@ static void RunMultiHeadAttentionKernel(
mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml);
}
if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kEnableCudnnFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml);
}
}
static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) {
@ -451,6 +468,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu
}
#endif
kernel_type = AttentionKernelType::AttentionKernel_CudnnFlashAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data,
data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data,
data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data,
data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size,
data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda);
}
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data,


@ -791,7 +791,13 @@ def run_tflops_test(
# flash attention is available for sm >= 80
sm = get_compute_capability()
if sm >= 80:
backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
backends = [
SdpaKernel.DEFAULT,
SdpaKernel.FLASH_ATTENTION,
SdpaKernel.EFFICIENT_ATTENTION,
SdpaKernel.CUDNN_FLASH_ATTENTION,
SdpaKernel.MATH,
]
else:
backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION, SdpaKernel.MATH]
else:


@ -804,6 +804,10 @@ class TestMultiHeadAttention(unittest.TestCase):
if get_compute_capability() >= 60:
self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)
def run_mha_cuda_multi_threading_cudnn(self):
if get_compute_capability() in [80, 86, 89, 90]:
self.run_mha_cuda_multi_threading(SdpaKernel.CUDNN_FLASH_ATTENTION)
def run_mha_cuda_multi_threading_efficient(self):
if comprehensive_mode and get_compute_capability() >= 60:
self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)
@ -826,6 +830,7 @@ class TestMultiHeadAttention(unittest.TestCase):
self.run_mha_cpu()
self.run_mha_cuda()
self.run_mha_cuda_multi_threading_default()
self.run_mha_cuda_multi_threading_cudnn()
self.run_mha_cuda_multi_threading_efficient()
self.run_mha_cuda_multi_threading_math()
self.run_mha_cuda_multi_threading_trt()