[CUDA] Add SparseAttention operator for Phi-3-small (#20216)

### Description
Add CUDA implementation for block sparse attention for Phi-3-small.

Block sparse attention was proposed in [Sparse
Transformers](https://arxiv.org/pdf/1904.10509) by OpenAI, and also
adopted in [BigBird](https://arxiv.org/pdf/2007.14062) with different
sparse layout.

In Phi-3-small, the sparse layout is static, and works with
unidirectional (causal) attention.

Compared to dense attention, block sparse attention speeds up both
training and inference. It also saves memory, which enables longer
context lengths.
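
To make the layout concrete, here is a minimal NumPy sketch of one way to build such a block mask from a local window plus strided vertical key blocks, in the spirit of Sparse Transformers. The function and the exact rule are illustrative assumptions, not the layout generator Phi-3-small actually uses; only the output shape `(num_layout, max_blocks, max_blocks)` is taken from the new operator's `block_mask` input.

```python
import numpy as np

def make_block_mask(num_layout, max_blocks, local_blocks, vert_stride):
    """Hypothetical causal block mask: local window plus strided vertical key blocks.

    Head h uses layout h % num_layout. 1 means the query block attends to the key block.
    """
    mask = np.zeros((num_layout, max_blocks, max_blocks), dtype=np.int32)
    for layout in range(num_layout):
        for i in range(max_blocks):          # query block index
            for j in range(i + 1):           # causal: only current and past key blocks
                is_local = (i - j) < local_blocks
                is_vertical = (j % vert_stride) == (layout % vert_stride)
                if is_local or is_vertical:
                    mask[layout, i, j] = 1
    return mask

# Benchmark configuration below: sparse_block_size=64, max_seq_len=8192 -> 128 blocks.
mask = make_block_mask(num_layout=8, max_blocks=8192 // 64, local_blocks=16, vert_stride=8)
print(mask.shape)  # (8, 128, 128)
```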

- [x] Add operator spec and shape inference
- [x] Symbolic shape inference
- [x] Refactor GroupQueryAttention to expose common kernels for kv cache
concatenation, q/k/v transpose etc.
- [x] Add cuda kernel to convert block mask to CSR format (see the NumPy sketch after this checklist)
- [x] Add cuda kernel to generate position ids
- [x] Add compile script and template files to convert triton kernels to
cubin and generate the dispatcher.
- [x] Add triton kernel v1 for prompt
- [x] Add triton kernel v2 for token generation and support padding
- [x] Update IO Binding Helper to allow buffer sharing.
- [x] Test relevance
- [x] Test performance
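
As referenced in the checklist above, the block-mask-to-CSR conversion can be sketched in NumPy as follows: per layout, row offsets plus the column indices of attended blocks. This is a conceptual sketch only; the actual CUDA kernel's output buffers may be laid out differently.

```python
import numpy as np

def block_mask_to_csr(mask):
    """CSR view of a (num_layout, max_blocks, max_blocks) block mask.

    Returns row_offsets of shape (num_layout, max_blocks + 1) and, per layout,
    the column indices of blocks with value 1.
    """
    num_layout, rows, _ = mask.shape
    row_offsets = np.zeros((num_layout, rows + 1), dtype=np.int32)
    col_indices = []
    for h in range(num_layout):
        cols_h = []
        for i in range(rows):
            nz = np.nonzero(mask[h, i])[0]
            cols_h.extend(int(j) for j in nz)
            row_offsets[h, i + 1] = row_offsets[h, i] + len(nz)
        col_indices.append(np.asarray(cols_h, dtype=np.int32))
    return row_offsets, col_indices
```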

### Performance
Tested on A100-SXM4-80GB with `batch_size=4, num_heads=32,
max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16,
vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with a local attention
window of size 1024, and to GQA with dense causal attention.

Average latency in milliseconds (for the fused attention kernel used in
prompt prefilling):

seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0465 | 0.0722 | 0.0641
128 | 0.0618 | 0.0787 | 0.0672
256 | 0.1086 | 0.1076 | 0.0943
512 | 0.2535 | 0.2487 | 0.1676
1024 | 0.7042 | 0.7050 | 0.3800
2048 | 2.4125 | 1.9316 | 0.8966
4096 | 8.9346 | 4.5699 | 2.1129
8192 | 40.5401 | 10.3508 | 5.1748

Average latency in milliseconds (for the fused attention kernel used in
token generation):

past_seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0186 | 0.0186 | 0.0870
128 | 0.0408 | 0.0466 | 0.1165
256 | 0.0530 | 0.0592 | 0.0988
512 | 0.0445 | 0.0447 | 0.1150
1024 | 0.0634 | 0.0640 | 0.1454
2048 | 0.1027 | 0.0637 | 0.1589
4096 | 0.1789 | 0.0631 | 0.1806
8192 | 0.3288 | 0.0655 | 0.2146

We can see that the kernel for token generation still has room for
improvement.

#### Limitations
Only right-side padding and unidirectional attention are supported.

The following are not supported in the first version:
(1) Packed mode (like PackedMultiHeadAttention) where padding has been removed from the input.
(2) Paged attention.
(3) Bidirectional attention.
(4) GPU compute capability other than 8.0, 8.6 and 8.9.
(5) Left-side padding.

Some of these limitations will be removed in the future (possibly in a new
operator).
Committed by Tianlei Wu on 2024-04-30 09:06:29 -07:00 (committed via GitHub)
Parent: b2481e3602
Commit: 9f0fae29e8
78 changed files: 6088 additions and 821 deletions


@ -4,10 +4,12 @@
find_package(Python3 COMPONENTS Interpreter REQUIRED)
# set all triton kernel ops that need to be compiled
set(triton_kernel_scripts
"onnxruntime/core/providers/rocm/math/softmax_triton.py"
"onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
)
if(onnxruntime_USE_ROCM)
set(triton_kernel_scripts
"onnxruntime/core/providers/rocm/math/softmax_triton.py"
"onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
)
endif()
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
# compile triton kernel, generate .a and .h files


@ -46,6 +46,7 @@ set(contrib_ops_excluded_files
"math/gemm_float8.cu"
"math/gemm_float8.h"
"moe/*"
"sparse/*"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"


@ -102,6 +102,7 @@ Do not modify directly.*
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
* <a href="#com.microsoft.SkipSimplifiedLayerNormalization">com.microsoft.SkipSimplifiedLayerNormalization</a>
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
* <a href="#com.microsoft.SparseAttention">com.microsoft.SparseAttention</a>
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
* <a href="#com.microsoft.Tokenizer">com.microsoft.Tokenizer</a>
* <a href="#com.microsoft.TorchEmbedding">com.microsoft.TorchEmbedding</a>
@ -3418,7 +3419,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Input tensors contains the hidden embedding of real tokens.
Token_offset records the offset of token in the unpacked input.
cumulated_token_count records cumulated length of each sequnces length.
cumulated_token_count records cumulated length of each sequence length.
The operator only supports BERT like model with padding on right now.
@ -3492,7 +3493,7 @@ This version of the operator has been available since version 1 of the 'com.micr
The query, key and value tensors contain result of hidden embedding of real tokens after input projections.
Token_offset records the offset of token in the unpacked input.
cumulative_sequence_length records cumulated length of each sequnces length.
cumulative_sequence_length records cumulated length of each sequence length.
The operator only supports BERT like model with padding on right now.
@ -5541,6 +5542,90 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.SparseAttention"></a><a name="com.microsoft.sparseattention">**com.microsoft.SparseAttention**</a>
Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
It is inspired by Sparse Transformers (https://arxiv.org/pdf/1904.10509) and BigBird (https://arxiv.org/pdf/2007.14062).
block_mask can be used to configure a sparse layout for different heads.
When the number of sparse layouts is 1, all heads share the same sparse layout. Otherwise, layouts are assigned to heads cyclically.
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will use layouts (S0, S1, S2, S3, S0, S1, S2, S3).
Padding shall be on the right side.
When do_rotary is True, cos_cache and sin_cache are required.
Only unidirectional attention is supported, with the past key and value cached in linear buffers.
For performance, past_key and present_key share the same memory buffer, as do past_value and present_value.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for key and value</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for query</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
<dd>Rotary use interleaved pattern or not. Default value is 0.</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Scaling factor applied prior to softmax. The default value is 1/sqrt(head_size)</dd>
<dt><tt>sparse_block_size</tt> : int (required)</dt>
<dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
</dl>
#### Inputs (8 - 10)
<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, num_heads * head_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads + 2 * kv_num_heads) * head_size.</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
<dt><tt>block_mask</tt> : M</dt>
<dd>Block mask, where 1 indicates attention and 0 indicates no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
<dt><tt>key_total_sequence_lengths</tt> : M</dt>
<dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
</dl>
#### Outputs
<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
<dt><tt>present_key</tt> : T</dt>
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
<dt><tt>present_value</tt> : T</dt>
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain integer type.</dd>
</dl>
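
For illustration, the spec above could be exercised by building a node with `onnx.helper`; the tensor names and attribute values below are placeholders, not requirements of the operator.

```python
from onnx import helper

# Illustrative only: tensor names and head counts are placeholders.
sparse_attention = helper.make_node(
    "SparseAttention",
    inputs=[
        "query", "key", "value",           # or pass packed QKV via "query" and leave key/value empty
        "past_key", "past_value",
        "block_mask",
        "total_sequence_length",
        "key_total_sequence_lengths",
        "cos_cache", "sin_cache",          # needed here because do_rotary=1
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
    sparse_block_size=64,
    do_rotary=1,
)
```
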
### <a name="com.microsoft.SparseToDenseMatMul"></a><a name="com.microsoft.sparsetodensematmul">**com.microsoft.SparseToDenseMatMul**</a>
#### Version


@ -906,6 +906,7 @@ Do not modify directly.*
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_mask:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|


@ -113,6 +113,35 @@ struct GroupQueryAttentionParameters {
int* zero_ptr;
};
// Parameters for sparse attention.
struct SparseAttentionParameters {
int batch_size; // batch size
int sequence_length; // sequence length of input query, key, value
int hidden_size; // hidden size of query
int num_heads; // number of heads of query
int head_size; // hidden size per head of query, key or value
int kv_hidden_size; // hidden size of key or value
int kv_num_heads; // number of heads of key or value
bool do_rotary; // whether to use rotary embedding
bool rotary_interleaved; // whether to use interleaved rotary embedding
int rotary_dim; // rotary embedding dimension
int sparse_block_size; // block size for sparse attention
int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask
float scale; // scaling factor applied prior to softmax
bool is_packed_qkv; // whether qkv is packed
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
int max_sequence_length; // max sequence length allowed
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
};
constexpr bool LAYOUT_BSNH = false;
constexpr bool LAYOUT_BNSH = true;
namespace sparse_attention {
// Environment variable to enable or disable sparse attention v1 kernel. Default is 0 (enabled).
constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_V1";
} // namespace sparse_attention
namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";


@ -5,30 +5,12 @@
#include <string>
#include "core/framework/ort_value.h"
// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
namespace onnxruntime {
namespace contrib {
namespace transformers {
// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
#ifdef DEBUG_GENERATION
#define DUMP_TENSOR_LEVEL 2
#else
#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation.
#endif
#if DUMP_TENSOR_LEVEL > 0
#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper
#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
#else
#define DUMP_TENSOR_INIT()
#define DUMP_TENSOR(...)
#endif
#if DUMP_TENSOR_LEVEL > 1
#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
#else
#define DUMP_TENSOR_D(...)
#endif
class IConsoleDumper {
public:
IConsoleDumper() : is_enabled_(true) {}


@ -129,6 +129,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output,
int total_matrix_count = -1);
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);
Status LaunchConcatTensorToTensor(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,


@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) {
AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
@ -257,9 +257,9 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
batch_size, sequence_length, num_heads, qk_head_size,
data.query, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (!use_fused_kernel) {
@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) {
AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
@ -301,10 +301,10 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
batch_size, kv_sequence_length, num_heads, qk_head_size,
data.key, kv_bias, k,
data.key, kv_bias, data.k,
true, v_head_size, qkv_add_bias, 2);
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else {
if (data.fused_cross_attention_kernel == nullptr) {
@ -461,11 +461,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
data.q, data.k, data.v, data.qkv_format));
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
data.q, data.k, data.v, data.qkv_format));
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format));
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
data.q, data.k, data.v, data.qkv_format));
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format));
} else { // multihead attention operator, no past, separated Q/K/V inputs
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
data.q, data.k, data.v, data.qkv_format));


@ -298,6 +298,12 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
return CUDA_CALL(cudaGetLastError());
}
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) {
return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads,
max_threads_per_block, false, input, output);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@ -47,6 +47,9 @@ limitations under the License.
using namespace onnxruntime::cuda;
// Macro to help compute index of flatten 4D matrix, note that dim1 is not used so it is excluded.
#define INDEX_4D(dim2, dim3, dim4, i, j, k, l) ((i) * (dim2) * (dim3) * (dim4) + (j) * (dim3) * (dim4) + (k) * (dim4) + (l))
namespace onnxruntime {
namespace contrib {
namespace cuda {
@ -216,123 +219,162 @@ template <typename T>
__global__ void ConcatKVInPlace(const int max_seqlen,
T* kv_buff,
const T* new_kv,
const int* seqlens_k,
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
const int* past_seqlens_k,
const int* total_seqlens_k,
const bool is_past_kv_bnsh_format,
const bool is_new_kv_bnsh_format) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int new_seqlen = gridDim.x;
const int num_heads = blockDim.y;
const int kv_num_heads = blockDim.y;
const int H = blockDim.x;
const int present_batch_stride = max_seqlen * num_heads * H;
const int present_row_stride = is_bsnh ? num_heads * H : H;
const int present_head_stride = is_bsnh ? H : max_seqlen * H;
const int past_seq_len = (total_seqlens_k != nullptr)
? (total_seqlens_k[b] - new_seqlen)
: (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]);
// kv_buff: BTNH or BNTH with buffered memory for new
// new_kv: BLNH
int out_offset = is_past_kv_bnsh_format
? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h)
: INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h);
const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b];
int in_offset = is_new_kv_bnsh_format
? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h)
: INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h);
int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h;
// Note: new KV always BSNH
const int new_batch_stride = new_seqlen * num_heads * H;
const int new_row_stride = num_heads * H;
const int new_head_stride = H;
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
kv_buff[out_offset] = new_kv[in_offset];
}
template <typename T>
__global__ void ConcatKVInPlaceLarge(const int max_seqlen,
const int H,
const int num_heads,
const int kv_num_heads,
T* kv_buff,
const T* new_kv,
const int* seqlens_k,
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
const int* past_seqlens_k,
const int* total_seqlens_k,
const bool is_past_kv_bnsh_format,
const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh
int i = threadIdx.x + (blockDim.x * blockIdx.x);
if (i < H * num_heads) {
if (i < H * kv_num_heads) {
const int h = i % H;
const int n = i / H;
const int s = blockIdx.y;
const int b = blockIdx.z;
const int new_seqlen = gridDim.y;
const int past_seq_len = (total_seqlens_k != nullptr)
? (total_seqlens_k[b] - new_seqlen)
: (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]);
const int present_batch_stride = max_seqlen * num_heads * H;
const int present_row_stride = is_bsnh ? num_heads * H : H;
const int present_head_stride = is_bsnh ? H : max_seqlen * H;
int out_offset = is_past_kv_bnsh_format
? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h)
: INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h);
// kv_buff: BTNH or BNTH with buffered memory for new
// new_kv: BLNH
int in_offset = is_new_kv_bnsh_format
? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h)
: INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h);
const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b];
int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h;
// Note: new KV always BSNH
const int new_batch_stride = new_seqlen * num_heads * H;
const int new_row_stride = num_heads * H;
const int new_head_stride = H;
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
kv_buff[out_offset] = new_kv[in_offset];
}
}
// Concat new to kv buffer in place
template <typename T>
Status LaunchConcatKVInPlace(int batch_size,
int kv_num_heads,
int head_size,
int max_sequence_length,
const int* past_seqlens_k,
const int* total_seqlens_k,
int new_seq_len,
const T* new_key,
const T* new_value,
T* present_key,
T* present_value,
bool is_past_kv_bnsh_format,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block) {
static_assert(sizeof(T) == 2);
assert(head_size % 4 == 0);
const int H = head_size / 4;
if (H * kv_num_heads <= max_threads_per_block) {
const dim3 grid(new_seq_len, batch_size, 1);
const dim3 block(H, kv_num_heads, 1);
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(max_sequence_length,
reinterpret_cast<float2*>(present_key),
reinterpret_cast<const float2*>(new_key),
past_seqlens_k,
total_seqlens_k,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format);
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(max_sequence_length,
reinterpret_cast<float2*>(present_value),
reinterpret_cast<const float2*>(new_value),
past_seqlens_k,
total_seqlens_k,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format);
} else {
int steps = int(ceil(float(H * kv_num_heads) / 256.0));
const dim3 grid(steps, new_seq_len, batch_size);
const dim3 block(256, 1, 1);
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(max_sequence_length,
H,
kv_num_heads,
reinterpret_cast<float2*>(present_key),
reinterpret_cast<const float2*>(new_key),
past_seqlens_k,
total_seqlens_k,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format);
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(max_sequence_length,
H,
kv_num_heads,
reinterpret_cast<float2*>(present_value),
reinterpret_cast<const float2*>(new_value),
past_seqlens_k,
total_seqlens_k,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format);
}
return CUDA_CALL(cudaGetLastError());
}
// Concat new to kv buffer in place
template <typename T>
Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data,
const void* new_key,
const void* new_value,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.sequence_length;
const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
const int max_sequence_length = parameters.seqlen_present_kv_cache;
const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast<const int*>(data.seqlens_k);
// Indicates past sequence_length of each sequence
const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast<const int*>(data.seqlens_k);
assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH ||
parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
bool is_past_kv_bnsh_format = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
const int H = head_size / 4;
if (H * kv_num_heads <= max_threads_per_block) {
const dim3 grid(kv_sequence_length, batch_size, 1);
const dim3 block(H, kv_num_heads, 1);
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(present_sequence_length,
reinterpret_cast<float2*>(data.present_key),
reinterpret_cast<const float2*>(new_key),
seqlens_k,
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(present_sequence_length,
reinterpret_cast<float2*>(data.present_value),
reinterpret_cast<const float2*>(new_value),
seqlens_k,
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
} else {
int steps = int(ceil(float(H * kv_num_heads) / 256.0));
const dim3 grid(steps, kv_sequence_length, batch_size);
const dim3 block(256, 1, 1);
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(present_sequence_length,
H,
kv_num_heads,
reinterpret_cast<float2*>(data.present_key),
reinterpret_cast<const float2*>(new_key),
seqlens_k,
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(present_sequence_length,
H,
kv_num_heads,
reinterpret_cast<float2*>(data.present_value),
reinterpret_cast<const float2*>(new_value),
seqlens_k,
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
}
return CUDA_CALL(cudaGetLastError());
return LaunchConcatKVInPlace(parameters.batch_size,
parameters.kv_num_heads,
parameters.head_size,
max_sequence_length,
past_seqlens_k,
nullptr, // total_seqlens_k is not available
parameters.sequence_length,
reinterpret_cast<const T*>(new_key),
reinterpret_cast<const T*>(new_value),
data.present_key,
data.present_value,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format,
stream,
max_threads_per_block);
}
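
Conceptually, the in-place KV concatenation above does the following (a NumPy sketch for the case where the new key/value is in BSNH layout and the cache is BNSH or BSNH); this is for illustration only and ignores the float2 vectorization used by the CUDA kernels.

```python
import numpy as np

def concat_kv_in_place(present, new_kv, past_seqlens, past_is_bnsh=True):
    """present: (B, N, max_seq, H) if past_is_bnsh else (B, max_seq, N, H); new_kv: (B, S, N, H)."""
    batch, new_len = new_kv.shape[0], new_kv.shape[1]
    for b in range(batch):
        start = int(past_seqlens[b])
        if past_is_bnsh:
            # (S, N, H) -> (N, S, H) to match the BNSH cache layout.
            present[b, :, start:start + new_len, :] = new_kv[b].transpose(1, 0, 2)
        else:
            present[b, start:start + new_len, :, :] = new_kv[b]
    return present
```
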
// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh
@ -474,41 +516,60 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
}
// Kernel to unpack qkv from packed qkv
template <typename T>
template <typename T, bool output_bnsh>
__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length,
const int batch_size) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
int d = (num_heads + 2 * kv_num_heads) * head_size;
const int qkv_size = batch_size * sequence_length * d;
const int q_size = num_heads * head_size;
const int k_size = kv_num_heads * head_size;
const int q_hidden = num_heads * head_size;
const int k_hidden = kv_num_heads * head_size;
if (tid < qkv_size) {
int batch = tid / (d * sequence_length);
int sequence = (tid % (d * sequence_length)) / d;
int b = tid / (d * sequence_length);
int s = (tid % (d * sequence_length)) / d;
int offset = tid % d;
if (offset < q_size) {
int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset;
if (output_bnsh) { // output BNSH
int head_count = kv_num_heads;
if (offset < q_hidden) {
head_count = num_heads;
} else if (offset < q_hidden + k_hidden) {
offset -= q_hidden;
} else {
offset -= (q_hidden + k_hidden);
}
int n = offset / head_size;
int h = offset % head_size;
int unpacked_i = INDEX_4D(head_count, sequence_length, head_size, b, n, s, h);
unpacked_q[unpacked_i] = packed_qkv[tid];
} else if (offset < q_size + k_size) {
int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size);
unpacked_k[unpacked_i] = packed_qkv[tid];
} else {
int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size);
unpacked_v[unpacked_i] = packed_qkv[tid];
} else { // output BSNH
if (offset < q_hidden) {
int unpacked_i = b * sequence_length * num_heads * head_size + s * num_heads * head_size + offset;
unpacked_q[unpacked_i] = packed_qkv[tid];
} else if (offset < q_hidden + k_hidden) {
int unpacked_i = b * sequence_length * kv_num_heads * head_size +
s * kv_num_heads * head_size + (offset - q_hidden);
unpacked_k[unpacked_i] = packed_qkv[tid];
} else {
int unpacked_i = b * sequence_length * kv_num_heads * head_size +
s * kv_num_heads * head_size + (offset - q_hidden - k_hidden);
unpacked_v[unpacked_i] = packed_qkv[tid];
}
}
}
}
// Unpack packed qkv
template <typename T>
template <typename T, bool output_bnsh>
Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block) {
const int threads = max_threads_per_block;
const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;
UnpackQKV<<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads,
head_size, sequence_length, batch_size);
UnpackQKV<T, output_bnsh><<<blocks, threads, 0, stream>>>(
packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size);
return CUDA_CALL(cudaGetLastError());
}
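
The unpacking done element by element by UnpackQKV above is equivalent to the following NumPy slicing for the BSNH output case (sketch only):

```python
import numpy as np

def unpack_qkv_bsnh(packed, num_heads, kv_num_heads, head_size):
    """Split packed QKV of shape (B, S, (num_heads + 2*kv_num_heads) * head_size) into BSNH Q, K, V."""
    b, s, _ = packed.shape
    q_end = num_heads * head_size
    k_end = q_end + kv_num_heads * head_size
    q = packed[:, :, :q_end].reshape(b, s, num_heads, head_size)
    k = packed[:, :, q_end:k_end].reshape(b, s, kv_num_heads, head_size)
    v = packed[:, :, k_end:].reshape(b, s, kv_num_heads, head_size)
    return q, k, v

packed = np.zeros((2, 3, (4 + 2 * 2) * 8), dtype=np.float16)
q, k, v = unpack_qkv_bsnh(packed, num_heads=4, kv_num_heads=2, head_size=8)
print(q.shape, k.shape, v.shape)  # (2, 3, 4, 8) (2, 3, 2, 8) (2, 3, 2, 8)
```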
@ -660,8 +721,14 @@ Status EfficientAttention(
auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads,
head_size, sequence_length, batch_size, stream, max_threads_per_block));
Status status = LaunchUnpackQKV<T, LAYOUT_BSNH>(
reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads,
head_size, sequence_length, batch_size, stream, max_threads_per_block);
if (status != Status::OK()) {
return status;
}
query = reinterpret_cast<const void*>(q);
key = reinterpret_cast<const void*>(k);
value = reinterpret_cast<const void*>(v);
@ -713,7 +780,9 @@ Status EfficientAttention(
"Past and present kv shall share the same tensor when kv_share_buffer is on.");
}
// Concatenate new kv in place
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block));
constexpr bool is_new_kv_bnsh_format = false;
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block));
} else {
// Not share buffer case
if (data.past_key != nullptr && data.past_key == data.present_key) {
@ -825,6 +894,51 @@ template Status QkvToContext<BFloat16>(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<BFloat16>& data);
template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(
const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block);
template Status LaunchUnpackQKV<BFloat16, LAYOUT_BNSH>(
const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block);
template Status LaunchConcatKVInPlace<half>(int batch_size,
int kv_num_heads,
int head_size,
int max_sequence_length,
const int* past_seqlens_k,
const int* total_seqlens_k,
int new_seq_len,
const half* new_key,
const half* new_value,
half* present_key,
half* present_value,
bool is_past_kv_bnsh_format,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block);
template Status LaunchConcatKVInPlace<BFloat16>(int batch_size,
int kv_num_heads,
int head_size,
int max_sequence_length,
const int* past_seqlens_k,
const int* total_seqlens_k,
int new_seq_len,
const BFloat16* new_key,
const BFloat16* new_value,
BFloat16* present_key,
BFloat16* present_value,
bool is_past_kv_bnsh_format,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#undef OFFSET_BNSH
#undef OFFSET_BSNH


@ -51,6 +51,28 @@ Status QkvToContext(
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data);
template <typename T, bool output_bnsh>
Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
cudaStream_t stream, const int max_threads_per_block);
template <typename T>
Status LaunchConcatKVInPlace(int batch_size,
int kv_num_heads,
int head_size,
int max_sequence_length, // max sequence length of present_key or present_value.
const int* past_seqlens_k, // it is not used when total_seqlens_k is available.
const int* total_seqlens_k, // optional, nullptr means it is not available.
int new_seq_len,
const T* new_key,
const T* new_value,
T* present_key,
T* present_value,
bool is_past_kv_bnsh_format,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@ -18,120 +18,120 @@ namespace contrib {
namespace cuda {
template <typename T>
__global__ void RotaryEmbeddingBSNH(T *output, // BxSxNxH
const T *input, // BxSxNxH
const T *cos_cache, // Mx(H/2)
const T *sin_cache, // Mx(H/2)
const int64_t *position_ids, // (1) or BxS
__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const T* input, // BxSxNxH
const T* cos_cache, // Mx(H/2)
const T* sin_cache, // Mx(H/2)
const int64_t* position_ids, // (1) or BxS
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
const bool interleaved, const int batch_stride, const int seq_stride,
const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
const int b = blockIdx.y;
const int s = blockIdx.x;
const int n = blockIdx.z;
const int b = blockIdx.y;
const int s = blockIdx.x;
const int n = blockIdx.z;
const int i = threadIdx.x;
const int i = threadIdx.x;
if (i >= head_size) {
return;
}
if (i >= head_size) {
return;
}
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T *input_data = input + block_offset;
T *output_data = output + block_offset;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
if (i >= rotary_embedding_dim) {
output_data[i] = input_data[i];
return;
}
if (i >= rotary_embedding_dim) {
output_data[i] = input_data[i];
return;
}
// Cache is (M, H/2)
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
const int position_id = (position_ids_format == 0) ? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_embedding_dim;
const T *cos_data = cos_cache + cache_offset;
const T *sin_data = sin_cache + cache_offset;
// Cache is (M, H/2)
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
const int position_id = (position_ids_format == 0) ? static_cast<int>(position_ids[0]) + s
: static_cast<int>(position_ids[b * sequence_length + s]);
const int cache_offset = position_id * half_rotary_embedding_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
if (interleaved) {
cache_idx = (i / 2) % half_rotary_embedding_dim;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_embedding_dim;
sign = (i < half_rotary_embedding_dim) ? -1 : 1;
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
int cache_idx = 0;
T sign = 0;
int j = 0;
if (interleaved) {
cache_idx = (i / 2) % half_rotary_embedding_dim;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_rotary_embedding_dim;
sign = (i < half_rotary_embedding_dim) ? -1 : 1;
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T *output, const T *input, const int64_t *position_ids,
const T *cos_cache, const T *sin_cache, const int batch_size,
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool transposed) {
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
const int max_threads_per_block, const bool is_input_bnsh_format) {
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
int tpb = (head_size + 31) / 32 * 32;
int tpb = (head_size + 31) / 32 * 32;
const dim3 block(tpb);
const dim3 grid(sequence_length, batch_size, num_heads);
const dim3 block(tpb);
const dim3 grid(sequence_length, batch_size, num_heads);
// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (transposed) {
// When transposed, input tensor shape is [batch, num_heads, seq, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}
// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (is_input_bnsh_format) {
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}
assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
interleaved, batch_stride, seq_stride, head_stride);
assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
interleaved, batch_stride, seq_stride, head_stride);
return CUDA_CALL(cudaGetLastError());
return CUDA_CALL(cudaGetLastError());
}
template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float *output, const float *input,
const int64_t *position_ids, const float *cos_cache,
const float *sin_cache, const int batch_size,
template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
const int64_t* position_ids, const float* cos_cache,
const float* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool transposed);
const int max_threads_per_block, const bool is_input_bnsh_format);
template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half *output, const half *input,
const int64_t *position_ids, const half *cos_cache,
const half *sin_cache, const int batch_size,
template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
const int64_t* position_ids, const half* cos_cache,
const half* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool transposed);
const int max_threads_per_block, const bool is_input_bnsh_format);
template Status LaunchRotaryEmbeddingKernel<BFloat16>(
cudaStream_t stream, BFloat16 *output, const BFloat16 *input, const int64_t *position_ids,
const BFloat16 *cos_cache, const BFloat16 *sin_cache, const int batch_size, const int sequence_length,
cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed);
const int position_ids_format, const bool interleaved, const int max_threads_per_block,
const bool is_input_bnsh_format);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
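
For reference, the per-element formula in RotaryEmbeddingBSNH above corresponds to this NumPy sketch for a single (batch, sequence, head) vector, assuming rotary_embedding_dim equals head_size:

```python
import numpy as np

def rotary_reference(x, cos_cache, sin_cache, position_id, interleaved=False):
    """x: (head_size,); cos_cache/sin_cache: (max_sequence_length, head_size // 2)."""
    head_size = x.shape[0]
    half = head_size // 2
    out = np.empty_like(x)
    for i in range(head_size):
        if interleaved:
            cache_idx = (i // 2) % half
            sign = -1.0 if i % 2 == 0 else 1.0
            j = i + 1 if i % 2 == 0 else i - 1
        else:
            cache_idx = i % half
            sign = -1.0 if i < half else 1.0
            j = (i + half) % head_size
        out[i] = x[i] * cos_cache[position_id, cache_idx] + sign * x[j] * sin_cache[position_id, cache_idx]
    return out
```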


@ -26,7 +26,7 @@ Status LaunchRotaryEmbeddingKernel(
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
const bool transposed);
const bool is_input_bnsh_format);
} // namespace cuda
} // namespace contrib


@ -258,8 +258,12 @@ class FusedMultiHeadCrossAttentionKernel
<< "\t force_unroll: " << param.force_unroll << "\n";
}
int32_t getSForUnroll(Fused_multihead_attention_params_mhca const& param) const override {
return param.s_q;
dim3 getGridDim(const FusedMultiHeadCrossAttentionKernelMetaInfoV2& kernelMeta,
const Fused_multihead_attention_params_mhca& params) const override {
dim3 gridDim(params.h,
params.b,
params.force_unroll ? ((params.s_q + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep) : 1);
return gridDim;
}
};


@ -59,6 +59,8 @@ CUDADriverWrapper::CUDADriverWrapper() {
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuDeviceGetAttribute) = load_sym(handle, "cuDeviceGetAttribute");
*reinterpret_cast<void**>(&_cuFuncSetCacheConfig) = load_sym(handle, "cuFuncSetCacheConfig");
}
CUDADriverWrapper::~CUDADriverWrapper() {
@ -73,6 +75,14 @@ CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attr
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}
CUresult CUDADriverWrapper::cuDeviceGetAttribute(int* pi, CUdevice_attribute attrib, CUdevice dev) const {
return (*_cuDeviceGetAttribute)(pi, attrib, dev);
}
CUresult CUDADriverWrapper::cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) const {
return (*_cuFuncSetCacheConfig)(hfunc, config);
}
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const {
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}
@ -126,6 +136,13 @@ CUresult CUDADriverWrapper::cuLaunchKernel(
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
}
// Initialize the singleton instance
CUDADriverWrapper CUDADriverWrapper::instance;
const CUDADriverWrapper* CUDADriverWrapper::GetInstance() {
return &instance;
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@ -45,6 +45,10 @@ class CUDADriverWrapper {
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuDeviceGetAttribute(int* pi, CUdevice_attribute attrib, CUdevice dev) const;
CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
CUresult cuModuleUnload(CUmodule hmod) const;
@ -73,10 +77,14 @@ class CUDADriverWrapper {
uint32_t blockDimY, uint32_t blockDimZ, uint32_t sharedMemBytes, CUstream hStream, void** kernelParams,
void** extra) const;
static const CUDADriverWrapper* GetInstance();
private:
void* handle;
CUresult (*_cuGetErrorName)(CUresult, const char**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuDeviceGetAttribute)(int*, CUdevice_attribute, CUdevice);
CUresult (*_cuFuncSetCacheConfig)(CUfunction, CUfunc_cache);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState);
@ -92,6 +100,8 @@ class CUDADriverWrapper {
CUfunction f, uint32_t gridDimX, uint32_t gridDimY, uint32_t gridDimZ,
uint32_t blockDimX, uint32_t blockDimY, uint32_t blockDimZ, uint32_t sharedMemBytes, CUstream hStream,
void** kernelParams, void** extra);
static CUDADriverWrapper instance;
};
inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line) {


@ -311,8 +311,12 @@ class FusedMultiHeadFlashAttentionKernel
<< "\t force_unroll: " << param.force_unroll << "\n";
}
int32_t getSForUnroll(Fused_multihead_attention_params_v2 const& param) const override {
return param.s;
dim3 getGridDim(const FusedMultiHeadFlashAttentionKernelMetaInfoV2& kernelMeta,
const Fused_multihead_attention_params_v2& params) const override {
dim3 gridDim(params.h,
params.b,
params.force_unroll ? ((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep) : 1);
return gridDim;
}
uint64_t hashID(KernelMeta const& kernelMeta) const {


@ -29,6 +29,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename TKernelMeta, typename TKernelParam>
class TSharedCubinKernel {
public:
@ -105,7 +106,7 @@ class TSharedCubinKernel {
virtual void dumpHashId(TKernelParam const& param, std::ostringstream& message) const = 0;
virtual int32_t getSForUnroll(TKernelParam const& param) const = 0;
virtual dim3 getGridDim(const TKernelMeta& kernelMeta, const TKernelParam& params) const = 0;
virtual void run(TKernelParam& params, cudaStream_t ss) const {
ORT_ENFORCE(!params.interleaved); // interleaved is for int8
@ -126,16 +127,10 @@ class TSharedCubinKernel {
CUfunction const func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr};
if (!params.force_unroll) {
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr),
mDriver);
} else {
int32_t unroll = (getSForUnroll(params) + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr),
mDriver);
}
dim3 gridDim = getGridDim(kernelMeta, params);
cuErrCheck(mDriver.cuLaunchKernel(func, gridDim.x, gridDim.y, gridDim.z, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr), mDriver);
}
virtual ~TSharedCubinKernel() = default;


@ -6,207 +6,221 @@
using namespace onnxruntime::common;
// Macros to avoid long line length
#define CUDA_MS_OP_CLASS_NAME(ver, name) \
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, name)
#define CUDA_MS_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type, name)
#define CUDA_MS_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name)
#define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name)
#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
#define CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, start_ver, end_ver, type, name)
namespace onnxruntime {
namespace contrib {
namespace cuda {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Irfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Irfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Irfft);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMulConj);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskBiasDropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasAdd);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedMatMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FusedMatMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FusedMatMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RelativePositionBias);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RelativePositionBias);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GatedRelativePositionBias);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GatedRelativePositionBias);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RemovePadding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RemovePadding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RestorePadding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RestorePadding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Rfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Rfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Rfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Irfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Irfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Irfft);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMulConj);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMulConj);
class CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax);
class CUDA_MS_OP_CLASS_NAME(1, BiasDropout);
class CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout);
class CUDA_MS_OP_CLASS_NAME(1, BitmaskBiasDropout);
class CUDA_MS_OP_CLASS_NAME(1, NGramRepeatBlock);
// These ops were experimental ops in the ONNX domain that have since been removed. We add them here as
// contrib ops to maintain backward compatibility.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Affine);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LongformerAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QuantizeWithOrder);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DequantizeWithOrder);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Affine);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Affine);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Affine);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Attention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Attention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedMultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedMultiHeadAttention);
class CUDA_MS_OP_CLASS_NAME(1, BeamSearch);
class CUDA_MS_OP_CLASS_NAME(1, WhisperBeamSearch);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ConvTransposeWithDynamicPads);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int64_t, DynamicSlice);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, EmbedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, EmbedLayerNormalization);
class CUDA_MS_OP_CLASS_NAME(1, GreedySearch);
class CUDA_MS_OP_CLASS_NAME(1, GroupNorm);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, NhwcConv);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, NhwcConv);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ImageScaler);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ImageScaler);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ImageScaler);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, LongformerAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, LongformerAttention);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ParametricSoftplus);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ParametricSoftplus);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ParametricSoftplus);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RotaryEmbedding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RotaryEmbedding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, RotaryEmbedding);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GemmaRotaryEmbedding);
class CUDA_MS_OP_CLASS_NAME(1, Sampling);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ScaledTanh);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ScaledTanh);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ScaledTanh);
class CUDA_MS_OP_CLASS_NAME(1, SkipGroupNorm);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_float, LayerNormalization);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, double_double_double, LayerNormalization);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_MLFloat16, LayerNormalization);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_MLFloat16, LayerNormalization);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_float, LayerNormalization);
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, BFloat16_float_BFloat16, LayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
class CUDA_MS_OP_CLASS_NAME(1, Inverse);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulNBits);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4);
class CUDA_MS_OP_CLASS_NAME(1, Trilu);
class CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor);
class CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, QuantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, QuantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedLayerNormalization);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedGelu);
class CUDA_MS_OP_CLASS_NAME(1, QuantizeWithOrder);
class CUDA_MS_OP_CLASS_NAME(1, DequantizeWithOrder);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedAttention);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedLongformerAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedSelfAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedSelfAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedMultiHeadAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedMultiHeadAttention);
class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention);
#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once
// 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purposes.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather);
class CUDA_MS_OP_CLASS_NAME(1, ShrunkenGather);
#endif
#if defined(ORT_USE_NCCL)
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
class CUDA_MS_OP_CLASS_NAME(1, AllReduce);
class CUDA_MS_OP_CLASS_NAME(1, AllGather);
class CUDA_MS_OP_CLASS_NAME(1, AllToAll);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ShardedMoE);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedMatMul);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSlice);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedReshape);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReshape);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedExpand);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedExpand);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceSum);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMax);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMean);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedUnsqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedUnsqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedUnsqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedSqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSqueeze);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSqueeze);
#endif
#ifdef ENABLE_CUDA_NHWC_OPS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
#endif
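The added registrations in this hunk replace the long-form `ONNX_OPERATOR_*_KERNEL_CLASS_NAME(kCudaExecutionProvider, ...)` spellings with shorthand wrappers such as `CUDA_MS_OP_TYPED_CLASS_NAME` and `CUDA_ONNX_OP_TYPED_CLASS_NAME`. Their definitions are introduced elsewhere in the file and are not visible in this hunk; judging from the removed/added pairs above, they presumably expand along the lines of the sketch below (illustrative only; NHWC and other variants omitted):

```cpp
// Assumed shorthand macros, inferred from the removed/added pairs in this hunk.
// The actual definitions in the PR may differ in spelling, but should be equivalent.
#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
  ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
#define CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, start_ver, end_ver, type, name)
#define CUDA_MS_OP_CLASS_NAME(ver, name) \
  ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, name)
#define CUDA_MS_OP_TYPED_CLASS_NAME(ver, type, name) \
  ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type, name)

// Under these expansions, the two new SparseAttention declarations above are contrib-domain
// (kMSDomain) CUDA kernels registered only for MLFloat16 and BFloat16, which matches the
// fp16/bf16-only support of the operator.
```

For example, `class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);` would then be the short spelling of `class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SparseAttention);`.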
template <>
@@ -218,206 +232,205 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list becoming empty after op reduction
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Irfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Irfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Irfft)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMulConj)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasAdd)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FusedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FusedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RelativePositionBias)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RelativePositionBias)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GatedRelativePositionBias)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GatedRelativePositionBias)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RemovePadding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RemovePadding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RestorePadding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RestorePadding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Rfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Rfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Rfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Irfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Irfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Irfft)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMulConj)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMulConj)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, NGramRepeatBlock)>,
// These ops were experimental ops in the ONNX domain that have since been removed. We add them here as
// contrib ops to maintain backward compatibility.
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Affine)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LongformerAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskBiasDropout)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Affine)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Affine)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Affine)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Attention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Attention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedMultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedMultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BeamSearch)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, WhisperBeamSearch)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ConvTransposeWithDynamicPads)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Crop)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int64_t, DynamicSlice)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, EmbedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GreedySearch)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GroupNorm)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, NhwcConv)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, NhwcConv)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ImageScaler)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ImageScaler)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ImageScaler)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, LongformerAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, LongformerAttention)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ParametricSoftplus)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ParametricSoftplus)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ParametricSoftplus)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RotaryEmbedding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GemmaRotaryEmbedding)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Sampling)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ScaledTanh)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ScaledTanh)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, SkipGroupNorm)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_float, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, double_double_double, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_MLFloat16, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_float, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, BFloat16_float_BFloat16, LayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Inverse)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasDropout)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskBiasDropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QuantizeWithOrder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DequantizeWithOrder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QuantizeWithOrder)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DequantizeWithOrder)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedLongformerAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedSelfAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedSelfAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GemmFloat8)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention)>,
#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
#ifdef ENABLE_TRAINING_OPS
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once
// 1). compute optimizer is enabled for inference or
// 2). this is needed by inference for other purposes.
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, ShrunkenGather)>,
#endif
#if defined(ORT_USE_NCCL)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllReduce)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllGather)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllToAll)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ShardedMoE)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedMatMul)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSlice)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedReshape)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReshape)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedExpand)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceSum)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMax)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMean)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedUnsqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedUnsqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedUnsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedSqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSqueeze)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSqueeze)>,
#endif
#ifdef ENABLE_CUDA_NHWC_OPS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample)>,
#endif
};


@@ -22,7 +22,10 @@ namespace cuda {
onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);
REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
#ifdef ENABLE_CUDA_NHWC_OPS
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
#endif
template <typename T, bool IsNHWC>
GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {


@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/sparse/block_mask.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
__global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_indices, int num_rows, int num_cols) {
int row = threadIdx.x;
if (row >= num_rows) {
return;
}
// Update input and output data pointers to the start of the current head
int head = blockIdx.x;
mask += head * num_rows * num_cols;
csr_row_indices += head * (num_rows + 1);
csr_col_indices += head * num_rows * num_cols;
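// Algorithm per (layout, row): (1) each thread counts the non-zero blocks in its row;
// (2) thread 0 turns the per-row counts in shared memory into an inclusive prefix sum;
// (3) each thread derives its row offset from the prefix sum and writes its column indices.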
int count = 0;
for (int col = 0; col < num_cols; col++) {
if (mask[row * num_cols + col] == 1) {
count++;
}
}
extern __shared__ int non_zero_counts[];
non_zero_counts[threadIdx.x] = count;
__syncthreads();
// The first thread computes an inclusive prefix sum of the non-zero counts.
if (row == 0) {
for (int i = 1; i < num_rows; i++) {
non_zero_counts[i] += non_zero_counts[i - 1];
}
}
__syncthreads();
// The starting index of current row in csr_col_indices
int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
// Output row indices.
csr_row_indices[row] = offset;
if (row == 0) {
// The first thread outputs the last element.
csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
}
for (int col = 0; col < num_cols; col++) {
if (mask[row * num_cols + col] == 1) {
csr_col_indices[offset] = col;
offset++;
}
}
// Note that the remaining entries of csr_col_indices are left unfilled; that is fine because
// the last element of csr_row_indices records the total number of non-zero elements.
}
void ConvertMaskToCSR(cudaStream_t stream,
const int* mask, // input mask with shape (num_layout, num_rows, num_cols)
int num_layout, // number of layouts
int num_rows, // number of rows
int num_cols, // number of columns
int* csr_row_indices, // output CSR row indices
int* csr_col_indices, // output CSR column indices
int max_threads_per_block) {
int threads_per_block = (num_rows + 31) / 32 * 32;
// Each thread handles one row. The kernel assumes that all rows of one head can be handled in one block.
if (threads_per_block > max_threads_per_block) {
ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
}
MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Convert mask to compressed sparse row (CSR) format ( https://en.wikipedia.org/wiki/Sparse_matrix)
// For example, num_layout=1, num_rows=4 and num_cols=4, and the mask is like
// 1, 0, 0, 0
// 1, 1, 0, 0
// 0, 1, 1, 0
// 0, 1, 1, 1
// The CSR format is like:
// csr_col_indices:
// 0, 0, 1, 1, 2, 1, 2, 3, 0*, 0*, 0*, 0*, 0*, 0*, 0*, 0* (* is padding)
// csr_row_indices:
// 0, 1, 3, 5, 8
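// Row i of a layout owns csr_col_indices[csr_row_indices[i] : csr_row_indices[i+1]].
// For example, row 2 above spans [3, 5), i.e. columns 1 and 2.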
void ConvertMaskToCSR(cudaStream_t stream,
const int* mask, // input mask with shape (num_layout, num_rows, num_cols)
int num_layout, // number of layouts
int num_rows, // number of rows of block_mask
int num_cols, // number of cols of block_mask
int* csr_row_indices, // output CSR row indices with shape (num_layout, num_rows + 1).
int* csr_col_indices, // output CSR col indices with shape (num_layout, num_rows * num_cols).
int max_threads_per_block);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,338 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
#include "contrib_ops/cuda/sparse/sparse_attention.h"
#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
#include "contrib_ops/cuda/sparse/block_mask.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
#include "core/platform/env_var_utils.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SparseAttention, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
.MayInplace(3, 1) \
.MayInplace(4, 2) \
.InputMemoryType(OrtMemTypeCPUInput, 6), \
SparseAttention<T>);
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
static inline int32_t DivUp(int32_t m, int32_t n) {
return (m + n - 1) / n;
}
template <typename T>
SparseAttention<T>::SparseAttention(const OpKernelInfo& info)
: CudaKernel(info) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
int64_t sparse_block_size = 0;
ORT_ENFORCE(info.GetAttr("sparse_block_size", &sparse_block_size).IsOK());
ORT_ENFORCE(sparse_block_size == 64 || sparse_block_size == 128);
sparse_block_size_ = static_cast<int>(sparse_block_size);
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
kernel_loaded_ = false;
disable_v1_kernel_ = ParseEnvironmentVariableWithDefault<bool>(sparse_attention::kDisableSparseAttentionV1, false);
}
template <typename T>
Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
const Tensor* past_key = context->Input<Tensor>(3);
const Tensor* past_value = context->Input<Tensor>(4);
const Tensor* block_mask = context->Input<Tensor>(5);
const Tensor* total_seq_len = context->Input<Tensor>(6);
const Tensor* seqlens_k_total = context->Input<Tensor>(7);
const Tensor* cos_cache = context->Input<Tensor>(8);
const Tensor* sin_cache = context->Input<Tensor>(9);
auto& device_prop = GetDeviceProp();
SparseAttentionParameters parameters;
// Parameters from node attribute
parameters.sparse_block_size = sparse_block_size_;
parameters.num_heads = num_heads_;
parameters.kv_num_heads = kv_num_heads_;
parameters.scale = scale_;
parameters.do_rotary = do_rotary_;
parameters.rotary_interleaved = rotary_interleaved_;
ORT_RETURN_IF_ERROR(sparse_attention_helper::CheckInputs(&parameters,
query,
key,
value,
past_key,
past_value,
cos_cache,
sin_cache,
block_mask,
seqlens_k_total,
total_seq_len));
// Some limitations of CUDA kernels
if (!sparse_attention_v1::is_supported_sparse_attention(device_prop)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support CUDA device with compute capacity 8.*. Got ",
device_prop.major);
}
if (!sparse_attention_v1::is_supported_sparse_attention(parameters.head_size, sparse_block_size_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support head_size=128 and sparse_block_size=64. Got head_size=",
parameters.head_size,
",sparse_block_size=",
sparse_block_size_);
}
if (device_prop.maxThreadsPerBlock > 0 && num_heads_ > device_prop.maxThreadsPerBlock) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"num_heads should be no larger than ", device_prop.maxThreadsPerBlock);
}
int past_seq_len = parameters.total_sequence_length - parameters.sequence_length;
bool is_prompt = (past_seq_len == 0);
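// Kernel selection: the v1 kernel targets prompt (prefill) requests, while the v2 kernel handles
// token generation and can also serve prompts (with padding) when v1 is disabled via environment variable.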
bool use_v2_kernel = disable_v1_kernel_ || !is_prompt;
// Async Copy total_k_seq_len from GPU to CPU.
IAllocatorUniquePtr<int32_t> pinned_buffer;
int32_t* total_k_seq_len_pinned = nullptr;
AutoDestoryCudaEvent new_event;
cudaEvent_t& isCopyDone = new_event.Get();
cudaStream_t cuda_stream = Stream(context);
if (use_v2_kernel) {
pinned_buffer = AllocateBufferOnCPUPinned<int32_t>(parameters.batch_size);
total_k_seq_len_pinned = reinterpret_cast<int32_t*>(pinned_buffer.get());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(total_k_seq_len_pinned,
seqlens_k_total->Data<int32_t>(),
sizeof(int32_t) * parameters.batch_size,
cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone));
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream));
}
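// The D2H copy above overlaps with the kernel loading and buffer allocation below; the event is
// synchronized later, right before the CPU-side preparation that reads total_k_seq_len_pinned.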
if (!kernel_loaded_) {
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_fp16();
} else {
sparse_attention_v1::load_sparse_attention_fp16();
}
} else {
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_bf16();
} else {
sparse_attention_v1::load_sparse_attention_bf16();
}
}
kernel_loaded_ = true;
}
// Compute output shape and get output tensors.
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
Tensor* output = context->Output(0, output_shape);
assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
std::vector<int64_t> present_dims = {
parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(1, present_shape);
Tensor* present_value = context->Output(2, present_shape);
// Set input and output data.
typedef typename ToCudaType<T>::MappedType CudaT;
SparseAttentionData<CudaT> data;
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
data.block_mask = block_mask->Data<int32_t>();
data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data<int32_t>();
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
// Check past and present share buffer.
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
if (parameters.past_present_share_buffer) {
ORT_ENFORCE(data.past_value != nullptr && data.past_value == data.present_value);
}
if (!parameters.past_present_share_buffer) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"CUDA implementation of SparseAttention requires past and present points to same buffer");
}
if (parameters.do_rotary) {
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
}
// Currently, the kernel uses the same block size as sparse_block_size.
// TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
data.kernel_layout.block_size = parameters.sparse_block_size;
data.kernel_layout.mask = data.block_mask;
data.kernel_layout.num_layout = parameters.num_sparse_layout;
data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;
// Allocate buffer for CSR col and row indices.
onnxruntime::Stream* stream = context->GetComputeStream();
int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
auto csr_col_indices_buffer = GetScratchBuffer<int>(static_cast<size_t>(dense_blocks), stream);
auto csr_row_indices_buffer = GetScratchBuffer<int>(
static_cast<size_t>(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);
data.kernel_layout.csr_col_indices = reinterpret_cast<const int*>(csr_col_indices_buffer.get());
data.kernel_layout.csr_row_indices = reinterpret_cast<const int*>(csr_row_indices_buffer.get());
ConvertMaskToCSR(cuda_stream,
data.kernel_layout.mask,
data.kernel_layout.num_layout,
data.kernel_layout.num_rows,
data.kernel_layout.num_cols,
csr_row_indices_buffer.get(),
csr_col_indices_buffer.get(),
device_prop.maxThreadsPerBlock);
size_t rotary_buffer_bytes = 0;
if (do_rotary_) {
rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads *
parameters.sequence_length * parameters.head_size;
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
}
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
size_t transposed_q_bytes = 0;
if (!parameters.is_packed_qkv) {
transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
parameters.num_heads * parameters.head_size * sizeof(T);
}
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, context->GetComputeStream());
if (transposed_q_buffer) {
data.transposed_q_buffer = reinterpret_cast<CudaT*>(transposed_q_buffer.get());
}
size_t unpacked_qkv_bytes = 0;
if (parameters.is_packed_qkv) {
unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length *
(parameters.num_heads + 2 * parameters.kv_num_heads) *
parameters.head_size * sizeof(T));
}
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
if (unpacked_qkv_buffer) {
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}
// Prepare some v2 kernel inputs in CPU then copy to GPU.
IAllocatorUniquePtr<int32_t> v2_kernel_inputs_pinned_buffer;
IAllocatorUniquePtr<int32_t> v2_kernel_buffer;
data.use_v2_kernel = use_v2_kernel;
if (use_v2_kernel) {
// Compute active q blocks so that we know the size of buffer to allocate.
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
int active_q_blocks = 0;
if (is_prompt) {
for (int i = 0; i < parameters.batch_size; i++) {
active_q_blocks += DivUp(is_prompt ? total_k_seq_len_pinned[i] : 1, data.kernel_layout.block_size);
}
} else { // not prompt
assert(parameters.sequence_length == 1);
active_q_blocks = parameters.batch_size;
}
// Compute buffer size: addresses of 6 buffers for v2 kernel need to be aligned to 16.
const size_t aligned_batch_size = DivUp(parameters.batch_size, 16) * 16;
const size_t aligned_num_q_blocks = DivUp(active_q_blocks, 16) * 16;
size_t v2_kernel_buffer_size = 4 * aligned_batch_size + 2 * aligned_num_q_blocks;
// Compute those values in CPU, then copy to GPU
v2_kernel_inputs_pinned_buffer = AllocateBufferOnCPUPinned<int32_t>(v2_kernel_buffer_size);
int32_t* v2_kernel_inputs_pinned = reinterpret_cast<int32_t*>(v2_kernel_inputs_pinned_buffer.get());
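// Pinned buffer layout (int32 elements); each sub-buffer is padded to a multiple of 16 elements
// so that its address in the device copy stays aligned:
// [q_batch_starts | q_batch_ends | k_batch_starts | k_batch_ends | q_batch_ids | q_start_sids]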
int32_t* q_batch_starts = v2_kernel_inputs_pinned;
int32_t* q_batch_ends = q_batch_starts + aligned_batch_size;
int32_t* k_batch_starts = q_batch_ends + aligned_batch_size;
int32_t* k_batch_ends = k_batch_starts + aligned_batch_size;
int32_t* q_batch_ids = k_batch_ends + aligned_batch_size;
int32_t* q_start_sids = q_batch_ids + aligned_num_q_blocks;
// Here assumes right-side padding
if (is_prompt) {
for (int i = 0; i < parameters.batch_size; i++) {
q_batch_starts[i] = 0;
q_batch_ends[i] = total_k_seq_len_pinned[i];
k_batch_starts[i] = 0;
k_batch_ends[i] = total_k_seq_len_pinned[i];
}
} else {
for (int i = 0; i < parameters.batch_size; i++) {
q_batch_starts[i] = 0;
q_batch_ends[i] = 1;
k_batch_starts[i] = 0;
k_batch_ends[i] = total_k_seq_len_pinned[i];
}
}
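// One entry per active q block: q_batch_ids maps a block to its batch index, and q_start_sids gives
// the block's starting token offset within that batch's query sequence.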
int current_block = 0;
for (int i = 0; i < parameters.batch_size; i++) {
int blocks = DivUp(q_batch_ends[i] - q_batch_starts[i], data.kernel_layout.block_size);
for (int j = 0; j < blocks; j++) {
q_batch_ids[current_block] = i;
q_start_sids[current_block] = j * data.kernel_layout.block_size;
current_block++;
}
}
v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, context->GetComputeStream());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
sizeof(int32_t) * v2_kernel_buffer_size,
cudaMemcpyHostToDevice, cuda_stream));
data.q_batch_starts = v2_kernel_buffer.get();
data.q_batch_ends = data.q_batch_starts + aligned_batch_size;
data.k_batch_starts = data.q_batch_ends + aligned_batch_size;
data.k_batch_ends = data.k_batch_starts + aligned_batch_size;
data.q_batch_ids = data.k_batch_ends + aligned_batch_size;
data.q_start_sids = data.q_batch_ids + aligned_num_q_blocks;
data.active_q_blocks = active_q_blocks;
}
return QkvToContext<CudaT>(device_prop, stream, parameters, data);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/cuda/cuda_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace ::onnxruntime::cuda;
template <typename T>
class SparseAttention final : public CudaKernel {
public:
SparseAttention(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;
protected:
int num_heads_; // number of attention heads for q
int kv_num_heads_; // number of attention heads for k and v
float scale_; // Scaling factor applied prior to softmax.
bool is_causal_; // unidirectional attention or not
int sparse_block_size_; // block size for sparsity
bool do_rotary_; // Has rotary positional embedding
bool rotary_interleaved_; // Interleaved rotary positional embedding
bool disable_v1_kernel_; // Disable the v1 kernel (the v2 kernel is then used for prompts as well)
mutable bool kernel_loaded_; // Kernel has been loaded
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,244 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace sparse_attention_helper {
Status CheckInputs(void* params,
const Tensor* query,
const Tensor* key,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
const Tensor* cos_cache,
const Tensor* sin_cache,
const Tensor* block_mask,
const Tensor* seqlens_k_total,
const Tensor* total_seq_len) {
// No packing for q/k/v:
// query (batch_size, sequence_length, num_heads * head_size)
// key (batch_size, kv_sequence_length, kv_num_heads * head_size)
// value (batch_size, kv_sequence_length, kv_num_heads * head_size)
// Packed q/k/v:
// query (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
// key nullptr
// value nullptr
// Shape for other inputs:
// past_key (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
// past_value (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
// block_mask (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout
// and max_blocks = max_sequence_length / sparse_block_size
// seqlens_k_total (batch_size) when do_rotary is True, optional otherwise
// total_seq_len (1)
// cos_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
// sin_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
assert(params != nullptr);
SparseAttentionParameters* parameters = reinterpret_cast<SparseAttentionParameters*>(params);
// The following parameters shall be set by parsing node attributes before calling CheckInputs.
const int num_heads = parameters->num_heads;
const int kv_num_heads = parameters->kv_num_heads;
const bool do_rotary = parameters->do_rotary;
constexpr bool is_past_bsnh = false;
const bool is_packed_qkv = key == nullptr;
const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
query_dims.size());
}
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
int q_hidden_size = static_cast<int>(query_dims[2]);
int head_size = 0;
int kv_hidden_size = 0;
if (!is_packed_qkv) {
// Check key and value when not packed
head_size = static_cast<int>(q_hidden_size) / num_heads;
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size = ",
head_size);
}
if (value == nullptr) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 0 (batch size)");
}
if (query_dims[1] != key_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
}
kv_hidden_size = static_cast<int>(key_dims[2]);
if (key->Shape() != value->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'value' shall have same shape");
}
} else {
// packed qkv
if (static_cast<int>(q_hidden_size) % (num_heads + 2 * kv_num_heads) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"packed qkv hidden size= ", q_hidden_size, " does not match num_heads and kv_num_heads",
num_heads, kv_num_heads);
}
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
if (head_size % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size must be a multiple of 8. Got head_size = ", head_size);
}
if (value != nullptr) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
}
q_hidden_size = head_size * num_heads;
kv_hidden_size = head_size * kv_num_heads;
}
const auto& block_mask_dim = block_mask->Shape().GetDims();
if (!(block_mask_dim.size() == 3 && block_mask_dim[1] == block_mask_dim[2] &&
(static_cast<int64_t>(num_heads) % block_mask_dim[0] == 0L))) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"block_mask must have shape (num_layout, max_blocks, max_blocks) where num_heads is divisible by num_layout.");
}
int max_blocks = static_cast<int>(block_mask_dim[1]);
int max_sequence_length = max_blocks * parameters->sparse_block_size;
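// For example, max_blocks = 128 with sparse_block_size = 64 gives max_sequence_length = 8192.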
// Check past-present KV
if (past_key != nullptr && past_value != nullptr) {
if (past_key->Shape() != past_value->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have same shape");
}
const auto& past_key_dims = past_key->Shape().GetDims();
if (past_key_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' is expected to have 4 dimensions, got ",
past_key_dims.size());
}
if (past_key_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
past_key_dims[0]);
}
if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
}
int max_cache_sequence_length = static_cast<int>(past_key_dims[is_past_bsnh ? 1 : 2]);
if (max_cache_sequence_length != max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'block_mask' should have the same sequence length:",
"max_sequence_length deduced from past_key is ", max_cache_sequence_length,
"; max_sequence_length deduced from block_mask is ", max_sequence_length);
}
if (past_key_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3]);
}
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}
// Check the shape of total_key_sequence_lengths. We do not check the values here.
const auto& k_len_dim = seqlens_k_total->Shape().GetDims();
if (k_len_dim.size() != 1 || k_len_dim[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"key_total_sequence_lengths must have shape (batch_size).");
}
if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"total_sequence_length tensor must be of one element.");
}
int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
int rotary_dim = 0;
if (do_rotary) {
if (cos_cache == nullptr || sin_cache == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache and sin_cache must be passed to SparseAttention when do_rotary = 1");
}
const auto& cos_dims = cos_cache->Shape().GetDims();
const auto& sin_dims = sin_cache->Shape().GetDims();
if (head_size % 16 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"head_size shall be a multiple of 16. Got head_size = ",
head_size);
}
if (cos_dims[0] < max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 0 should be of max_sequence_length.");
}
if (sin_dims[0] < max_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 0 should be of max_sequence_length.");
}
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
}
if (cos_dims[1] != sin_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"cos_cache and sin_cache dimension 1 must be the same.");
}
rotary_dim = static_cast<int>(cos_dims[1] * 2);
}
parameters->batch_size = batch_size;
parameters->sequence_length = sequence_length;
parameters->total_sequence_length = total_sequence_length;
parameters->max_sequence_length = max_sequence_length;
parameters->hidden_size = q_hidden_size;
parameters->head_size = head_size;
parameters->kv_hidden_size = kv_hidden_size;
parameters->rotary_dim = rotary_dim;
parameters->is_packed_qkv = is_packed_qkv;
parameters->num_sparse_layout = static_cast<int>(block_mask_dim[0]);
return Status::OK();
}
} // namespace sparse_attention_helper
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,334 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
#include "contrib_ops/cuda/sparse/block_mask.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Convert total_seq_len_k (total key sequence length excluding padding) to position_ids for Prompt
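// For example, total_seq_len_k = [3, 2] with sequence_length = 4 gives
// position_ids = [0, 1, 2, 1, 0, 1, 1, 1] (padded positions are set to 1).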
__global__ void PositionIdsPrompt(const int32_t* total_seq_len_k,
int64_t* position_ids,
int sequence_length,
int batch_size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < batch_size * sequence_length) {
int b = tid / sequence_length;
int s = tid % sequence_length;
if (s < total_seq_len_k[b]) {
position_ids[tid] = s;
} else {
// padding
position_ids[tid] = 1;
}
}
}
// Convert total_seq_len_k (total key sequence length excluding padding) to position_ids for Token Generation
__global__ void PositionIdsToken(const int32_t* total_seq_len_k,
int64_t* position_ids,
int batch_size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < batch_size) {
position_ids[tid] = total_seq_len_k[tid] - 1;
}
}
// Convert total_seq_len_k (total key sequence length excluding padding) to position_ids
Status FillPositionIds(contrib::SparseAttentionParameters& parameters,
const int32_t* total_seq_len_k,
int64_t* position_ids,
cudaStream_t stream,
const int max_threads_per_block) {
const int sequence_length = parameters.sequence_length;
const int batch_size = parameters.batch_size;
const int bs = batch_size * sequence_length;
int threads = max_threads_per_block;
if (bs <= 64) {
threads = 64;
} else if (bs <= 128) {
threads = 128;
} else if (bs <= 256) {
threads = 256;
} else if (bs <= 512) {
threads = 512;
}
const int blocks = (bs + threads - 1) / threads;
if (parameters.sequence_length == parameters.total_sequence_length) { // prompt
PositionIdsPrompt<<<blocks, threads, 0, stream>>>(total_seq_len_k, position_ids, sequence_length, batch_size);
} else {
PositionIdsToken<<<blocks, threads, 0, stream>>>(total_seq_len_k, position_ids, batch_size);
}
return CUDA_CALL(cudaGetLastError());
}
// Concat new key and value (BSNH or BNSH format) to the KV buffer (BNSH format) in place.
template <typename T>
Status LaunchConcatKVInPlace(contrib::SparseAttentionParameters& parameters,
SparseAttentionData<T>& data,
const void* new_key,
const void* new_value,
bool is_new_kv_bnsh_format,
cudaStream_t stream,
const int max_threads_per_block) {
constexpr bool is_past_kv_bnsh_format = true;
return LaunchConcatKVInPlace(parameters.batch_size,
parameters.kv_num_heads,
parameters.head_size,
parameters.max_sequence_length,
nullptr,
data.seqlens_k_total,
parameters.sequence_length,
reinterpret_cast<const T*>(new_key),
reinterpret_cast<const T*>(new_value),
data.present_key,
data.present_value,
is_past_kv_bnsh_format,
is_new_kv_bnsh_format,
stream,
max_threads_per_block);
}
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
contrib::SparseAttentionParameters& parameters,
SparseAttentionData<T>& data) {
cudaStream_t stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
// const int present_sequence_length = parameters.max_sequence_length;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
const void* query;
const void* key;
const void* value;
DUMP_TENSOR_INIT();
if (!parameters.is_packed_qkv) {
static_assert(sizeof(T) == 2);
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(
batch_size, sequence_length, num_heads, head_size,
reinterpret_cast<const half*>(data.query), reinterpret_cast<half*>(data.transposed_q_buffer),
stream, max_threads_per_block));
query = reinterpret_cast<const void*>(data.transposed_q_buffer);
key = reinterpret_cast<const void*>(data.key);
value = reinterpret_cast<const void*>(data.value);
} else {
size_t q_size = static_cast<size_t>(batch_size * sequence_length * num_heads * head_size);
size_t k_size = static_cast<size_t>(batch_size * sequence_length * kv_num_heads * head_size);
auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
Status status = LaunchUnpackQKV<T, LAYOUT_BNSH>(data.query, q, k, v, num_heads, kv_num_heads, head_size,
sequence_length, batch_size, stream, max_threads_per_block);
if (status != Status::OK()) {
return status;
}
query = reinterpret_cast<const void*>(q);
key = reinterpret_cast<const void*>(k);
value = reinterpret_cast<const void*>(v);
}
constexpr bool q_layout = LAYOUT_BNSH;
bool kv_layout = parameters.is_packed_qkv ? LAYOUT_BNSH : LAYOUT_BSNH;
DUMP_TENSOR("query", reinterpret_cast<const T*>(query), batch_size, num_heads, sequence_length, head_size);
#if DUMP_TENSOR_LEVEL > 0
if (LAYOUT_BNSH == kv_layout) {
DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, kv_num_heads, sequence_length, head_size);
DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, kv_num_heads, sequence_length, head_size);
} else {
DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, sequence_length, kv_num_heads, head_size);
DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, sequence_length, kv_num_heads, head_size);
}
#endif
if (parameters.do_rotary) {
size_t bsh = static_cast<size_t>(parameters.batch_size * parameters.sequence_length * parameters.head_size);
size_t q_size = bsh * static_cast<size_t>(parameters.num_heads);
size_t k_size = bsh * static_cast<size_t>(parameters.kv_num_heads);
auto q_buffer = reinterpret_cast<T*>(data.rotary_buffer);
auto k_buffer = q_buffer + q_size;
auto position_ids_buff = reinterpret_cast<int64_t*>(k_buffer + k_size);
ORT_RETURN_IF_ERROR(FillPositionIds(parameters, data.seqlens_k_total, position_ids_buff, stream,
max_threads_per_block));
DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length);
// Launch rotary embedding kernel. This requires separated Q, K and V
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, q_buffer, reinterpret_cast<const T*>(query),
position_ids_buff, data.cos_cache, data.sin_cache,
parameters.batch_size, parameters.sequence_length,
parameters.num_heads, parameters.head_size,
parameters.rotary_dim, parameters.max_sequence_length,
/*position_ids_format*/ 1, parameters.rotary_interleaved,
max_threads_per_block, q_layout));
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, k_buffer, reinterpret_cast<const T*>(key),
position_ids_buff, data.cos_cache, data.sin_cache,
parameters.batch_size, parameters.sequence_length,
parameters.kv_num_heads, parameters.head_size,
parameters.rotary_dim, parameters.max_sequence_length,
/*position_ids_format*/ 1, parameters.rotary_interleaved,
max_threads_per_block, kv_layout));
query = reinterpret_cast<const void*>(q_buffer);
key = reinterpret_cast<const void*>(k_buffer);
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR("query after rotary", reinterpret_cast<const T*>(query),
batch_size, num_heads, sequence_length, head_size);
if (LAYOUT_BNSH == kv_layout) {
DUMP_TENSOR("key after rotary", reinterpret_cast<const T*>(key),
batch_size, kv_num_heads, sequence_length, head_size);
} else {
DUMP_TENSOR("key after rotary", reinterpret_cast<const T*>(key),
batch_size, sequence_length, kv_num_heads, head_size);
}
#endif
}
// Concat new key and value to kv buffers (in BNSH format) in place
ORT_ENFORCE(parameters.past_present_share_buffer);
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
parameters, data, key, value, kv_layout, stream, max_threads_per_block));
// TODO: only dump to total sequence length instead of max sequence length.
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
DUMP_TENSOR("block_mask",
data.kernel_layout.mask,
data.kernel_layout.num_layout,
data.kernel_layout.num_rows,
data.kernel_layout.num_cols);
DUMP_TENSOR("csr_col_indices",
data.kernel_layout.csr_col_indices,
data.kernel_layout.num_layout,
data.kernel_layout.num_rows,
data.kernel_layout.num_cols);
DUMP_TENSOR("csr_row_indices",
data.kernel_layout.csr_row_indices,
data.kernel_layout.num_layout,
data.kernel_layout.num_rows + 1);
printf(
"batch_size=%d, sequence_length=%d, num_heads=%d, kv_num_heads=%d head_size=%d, "
"total_sequence_length=%d, max_sequence_length=%d scale=%f block_size=%d "
"row_stride=%d col_stride=%d num_layout=%d\n",
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
parameters.total_sequence_length,
parameters.max_sequence_length,
parameters.scale,
data.kernel_layout.block_size,
data.kernel_layout.num_rows + 1,
data.kernel_layout.num_rows * data.kernel_layout.num_cols,
data.kernel_layout.num_layout);
#endif
if (data.use_v2_kernel) {
sparse_attention_v2::SparseAttentionParams params(
ort_stream,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
reinterpret_cast<const void*>(data.present_value),
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
parameters.total_sequence_length,
parameters.max_sequence_length,
parameters.scale,
data.kernel_layout.block_size, // kernel_block_size
data.kernel_layout.csr_row_indices, // skip past_seq_len in row indices
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
data.kernel_layout.num_rows + 1, // stride per head in row indices
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
data.kernel_layout.num_layout,
data.active_q_blocks,
data.q_batch_starts,
data.q_batch_ends,
data.k_batch_starts,
data.k_batch_ends,
data.q_batch_ids,
data.q_start_sids);
if constexpr (std::is_same<T, BFloat16>::value) {
ORT_RETURN_IF_ERROR(sparse_attention_v2::run_sparse_attention_bf16(params));
} else {
ORT_RETURN_IF_ERROR(sparse_attention_v2::run_sparse_attention_fp16(params));
}
} else {
sparse_attention_v1::SparseAttentionParams params(
ort_stream,
data.output,
reinterpret_cast<const void*>(query),
reinterpret_cast<const void*>(data.present_key),
reinterpret_cast<const void*>(data.present_value),
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.kv_num_heads,
parameters.head_size,
parameters.total_sequence_length,
parameters.max_sequence_length,
parameters.scale,
data.kernel_layout.block_size, // kernel_block_size
data.kernel_layout.csr_row_indices, // (num_layout, num_rows + 1)
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
data.kernel_layout.num_rows + 1, // stride per head in row indices
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
data.kernel_layout.num_layout);
if constexpr (std::is_same<T, BFloat16>::value) {
ORT_RETURN_IF_ERROR(sparse_attention_v1::run_sparse_attention_bf16(params));
} else {
ORT_RETURN_IF_ERROR(sparse_attention_v1::run_sparse_attention_fp16(params));
}
}
return Status::OK();
}
template Status QkvToContext<half>(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
contrib::SparseAttentionParameters& parameters,
SparseAttentionData<half>& data);
template Status QkvToContext<BFloat16>(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
contrib::SparseAttentionParameters& parameters,
SparseAttentionData<BFloat16>& data);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/framework/allocator.h"
#include "core/providers/cuda/tunable/cuda_tunable.h"
using onnxruntime::cuda::tunable::CudaTuningContext;
namespace onnxruntime {
namespace contrib {
namespace cuda {
struct BlockLayout {
const int32_t* mask; // shape (num_layout, num_rows, num_cols), where num_rows = num_cols = max_seq_len / block_size.
int num_layout;
int block_size; // kernel block size, which is <= sparse_block_size
const int* csr_col_indices;
const int* csr_row_indices;
int num_rows;
int num_cols;
};
template <typename T>
struct SparseAttentionData {
// Input Tensors
const T* query = nullptr;
const T* key = nullptr;
const T* value = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;
const T* cos_cache = nullptr;
const T* sin_cache = nullptr;
const int32_t* block_mask = nullptr;
const int32_t* seqlens_k_total = nullptr;
// Temporary buffers
T* transposed_q_buffer = nullptr;
T* rotary_buffer = nullptr;
T* unpacked_qkv_buffer = nullptr;
// This is sparse layout used in kernel.
BlockLayout kernel_layout;
// Output Tensors
T* output = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
// Data for sparse attention v2 kernel.
bool use_v2_kernel = false;
int* q_batch_starts = nullptr; // shape (batch_size)
int* q_batch_ends = nullptr; // shape (batch_size)
int* k_batch_starts = nullptr; // shape (batch_size)
int* k_batch_ends = nullptr; // shape (batch_size)
int* q_batch_ids = nullptr; // shape (G)
int* q_start_sids = nullptr; // shape (G)
int active_q_blocks = 0; // G: total number of query blocks across the batch (excluding right-side padding)
};
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
contrib::SparseAttentionParameters& parameters,
SparseAttentionData<T>& data);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,126 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install triton==2.3.0
# python3 compile_sparse_attention.py | sh
#
# Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.
import math
from itertools import product
def generate_triton_compile_shell_script(dtype="fp16"):
assert dtype in ["fp16", "bf16"]
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")
# Modify the compile.py to use custom template file template_h.txt and template_c.txt in current directory.
# Also pass block_m to the template.
print(
'python -c "'
"import sys;lines=sys.stdin.read();"
"lines=lines.replace('template_path = Path(__file__).parent / f\\\"compile.{ext}\\\"','template_path = f\\\"compile_template_kernel_{ext}.txt\\\"');"
'lines=lines.replace(\'\\"_placeholder\\": \\"\\",\', \'\\"_placeholder\\": \\"\\",\\n \\"block_m\\": list(constants.values())[0],\');'
'print(lines)"'
"< ${TRITON_ROOT}/triton/tools/compile.py > compile.py"
)
out_dir = f"trition_cubin_{dtype}"
print(f"rm -rf {out_dir}")
print(f"mkdir -p {out_dir}")
# Note that block_d * num_blocks_d is the head_size. We support head_size = 128 for now.
block_n_values = [64]
block_d_values = [64]
num_block_d_values = [2]
even_m_values = [True, False]
even_n_values = [True, False]
# Use triton compiler to compile the kernel of different combinations of constant parameters.
for block_n, block_d, num_blocks_d, even_m, even_n in product(
block_n_values, block_d_values, num_block_d_values, even_m_values, even_n_values
):
block_m_values = [16, block_n] if block_n != 16 else [block_n]
for block_m in block_m_values:
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
prefix = "python compile.py sparse_attention_triton.py"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm${{SM}}"
name = f"sparse_attention_{dtype}_sm${{SM}}"
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
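# For example, block_m = block_n = block_d = 64 gives num_warps = 4, while block_m = 16 gives num_warps = 1.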
num_stages = 2
# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
print(
f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
)
# Generate the dispatcher.
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm${{SM}}"
print(f"cd {out_dir}")
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v1_*.h -o {dispatcher}")
print("rm *.h")
# Remove signature hash in code.
suffix = "0d1d2d3d4d5d678910d11d12d13d14d15d16d17d18d19d20d21d222324"
print(f"for file in *.c; do sed -i 's/_{suffix}//g' \"$file\"; done")
# Restore the signature hash in the kernel name that was removed in the previous step. The kernel name shall not change.
print(
f"for file in *.c; do sed -i 's/block_sparse_attention_kernel/block_sparse_attention_kernel_{suffix}/g' \"$file\"; done"
)
# Remove the signature hash from the filename since we use the same signature for all kernels except constants,
# and the constants are already encoded in the filename so files can be distinguished without the hash.
print('for file in sparse_attention_v1_*.c; do mv -- "$file" "$(echo $file | cut -f 1 -d \'.\').c"; done')
# Change function parameters and return type. If you change the kernel interface, you will need to modify this part.
source1 = "CUstream stream, CUdeviceptr out, CUdeviceptr Q, CUdeviceptr K, CUdeviceptr V, CUdeviceptr layout_csr_row_indices, CUdeviceptr layout_csr_col_indices, int32_t layout_csr_row_stride_h, int32_t layout_csr_col_stride_h, int32_t num_layout, float softmax_scale, int32_t stride_qb, int32_t stride_qh, int32_t stride_qm, int32_t stride_kb, int32_t stride_kh, int32_t stride_kn, int32_t stride_vb, int32_t stride_vh, int32_t stride_vn, int32_t stride_ob, int32_t stride_oh, int32_t stride_om, int32_t num_heads, int32_t num_kv_heads, int32_t total_seq_len"
target1 = "SparseAttentionParams& params"
source2 = "stream, out, Q, K, V, layout_csr_row_indices, layout_csr_col_indices, layout_csr_row_stride_h, layout_csr_col_stride_h, num_layout, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om, num_heads, num_kv_heads, total_seq_len"
target2 = "params"
print(
f"python -c \"import sys;lines=sys.stdin.read();lines=lines.replace('{source1}', '{target1}');"
f'lines=lines.replace(\'{source2}\', \'{target2}\');print(lines)" < "{dispatcher}.c" > "{dispatcher}.h"'
)
print(f"sed -i 's/CUresult/Status/g' \"{dispatcher}.h\"")
# Remove parameter checking since we moved the validation logic to SparseAttentionParams
print(f"sed -i '/if /d' \"{dispatcher}.h\"")
print(f"sed -i '/CUDA_ERROR_INVALID_VALUE/d' \"{dispatcher}.h\"")
print(f"sed -i '/#include/d' \"{dispatcher}.h\"")
print(f"rm {dispatcher}.c")
# Use a template file to add namespace and includes to the dispatcher file.
print(
'python -c "'
"from pathlib import Path;"
"template=Path('../compile_template_dispatcher_h.txt').read_text();"
f"code=Path('{dispatcher}.h').read_text();"
"text=template.replace('PLACEHOLDER', code); print(text)\" "
f"> ../{dispatcher}.h"
)
# rename *.c to *.cc
print('for file in *.c; do mv -- "$file" "${file%.c}.cc"; done')
# Move kernel files to parent directory. This might overwrite existing files in repository.
print("mv sparse_attention_v1_* ../")
# Clean up
print("cd ..")
print("rm compile.py")
print(f"rm -rf {out_dir}")
print("echo Done")
if __name__ == "__main__":
for dtype in ["fp16", "bf16"]:
generate_triton_compile_shell_script(dtype)


@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention.py using triton AoT compiler
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
PLACEHOLDER
} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {{
namespace contrib {{
namespace cuda {{
namespace sparse_attention_v1 {{
// This file is generated by compile_sparse_attention.py using triton AoT compiler
// {kernel_docstring}
// cubin_size = {bin_size}
// shared_mem_bytes = {shared}
// threads_per_cta = {num_warps} * 32
// kernel_name = {triton_kernel_name}
unsigned char {kernel_name}_cubin[] = {{ {bin_data} }};
CUmodule {kernel_name}_mod = NULL;
CUfunction {kernel_name}_func = NULL;
void unload_{kernel_name}(void) {{
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuModuleUnload({kernel_name}_mod), driver);
}}
void load_{kernel_name}(void) {{
void *bin = (void *)&{kernel_name}_cubin;
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuModuleLoadData(&{kernel_name}_mod, bin), driver);
CU_CHECK(driver->cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"), driver);
constexpr int shared = {shared};
if constexpr (shared > 49152) {{
SetKernelSharedMemory(driver, {kernel_name}_func);
}}
}}
Status {kernel_name}(SparseAttentionParams& params) {{
return params.LaunchKernel({kernel_name}_func, {block_m}, {num_warps} * 32, {shared});
}}
}} // namespace sparse_attention_v1
}} // namespace cuda
}} // namespace contrib
}} // namespace onnxruntime


@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention.py using triton AoT compiler
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {{
namespace contrib {{
namespace cuda {{
namespace sparse_attention_v1 {{
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
Status{_placeholder} {kernel_name}(SparseAttentionParams& params);
}} // namespace sparse_attention_v1
}} // namespace cuda
}} // namespace contrib
}} // namespace onnxruntime


@@ -0,0 +1,203 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h"
#define CU_CHECK(expr, driver) cuErrCheck(expr, *driver)
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
struct SparseAttentionParams {
onnxruntime::Stream* ort_stream;
void* output;
const void* q;
const void* k;
const void* v;
int batch_size;
int num_heads;
int kv_num_heads;
int head_size;
int sequence_length;
int past_sequence_length;
int total_sequence_length;
int max_sequence_length;
float scale;
int kernel_block_size;
// CSR format of block mask
const int* layout_csr_row_indices;
const int* layout_csr_col_indices;
int layout_row_stride_h;
int layout_col_stride_h;
int num_layout;
// strides
int stride_qb;
int stride_qh;
int stride_qm;
int stride_kb;
int stride_kh;
int stride_kn;
int stride_vb;
int stride_vh;
int stride_vn;
int stride_ob;
int stride_oh;
int stride_om;
SparseAttentionParams(
onnxruntime::Stream* ort_stream,
void* output,
const void* q,
const void* k,
const void* v,
int batch_size,
int sequence_length,
int num_heads,
int kv_num_heads,
int head_size,
int total_sequence_length,
int max_sequence_length,
float scale,
int kernel_block_size,
const int* layout_csr_row_indices,
const int* layout_csr_col_indices,
int layout_row_stride_h,
int layout_col_stride_h,
int num_layout) {
this->ort_stream = ort_stream;
this->output = output;
this->q = q;
this->k = k;
this->v = v;
this->batch_size = batch_size;
this->sequence_length = sequence_length;
this->num_heads = num_heads;
this->kv_num_heads = kv_num_heads;
this->head_size = head_size;
this->past_sequence_length = total_sequence_length - sequence_length;
this->total_sequence_length = total_sequence_length;
this->max_sequence_length = max_sequence_length;
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
this->kernel_block_size = kernel_block_size;
this->layout_csr_row_indices = layout_csr_row_indices;
this->layout_csr_col_indices = layout_csr_col_indices;
this->layout_row_stride_h = layout_row_stride_h;
this->layout_col_stride_h = layout_col_stride_h;
this->num_layout = num_layout;
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
this->stride_qh = this->sequence_length * this->head_size;
this->stride_qm = this->head_size;
// When kv buffer has max sequence length, stride should match max sequence length.
int kv_buffer_sequence_length = max_sequence_length;
// KV cache is in BNSH format
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
this->stride_kh = kv_buffer_sequence_length * this->head_size;
this->stride_kn = this->head_size;
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
this->stride_vh = kv_buffer_sequence_length * this->head_size;
this->stride_vn = this->head_size;
// Output is BSNH format
this->stride_ob = this->sequence_length * this->num_heads * this->head_size;
this->stride_oh = this->head_size;
this->stride_om = this->num_heads * this->head_size;
}
Status LaunchKernel(CUfunction f, int block_m, int threads_per_block, unsigned int sharedMemBytes) {
ORT_ENFORCE(f != nullptr, "Kernel shall be loaded before calling LaunchKernel.");
if (!Validate()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseAttentionParams is not valid.");
}
void* args[26] = {
&output, &q, &k, &v,
&layout_csr_row_indices, &layout_csr_col_indices, &layout_row_stride_h, &layout_col_stride_h, &num_layout, &scale,
&stride_qb, &stride_qh, &stride_qm, &stride_kb, &stride_kh, &stride_kn,
&stride_vb, &stride_vh, &stride_vn, &stride_ob, &stride_oh, &stride_om,
&num_heads, &kv_num_heads, &total_sequence_length, &past_sequence_length};
unsigned int gridDimX = (sequence_length + block_m - 1) / block_m;
unsigned int gridDimY = batch_size * num_heads;
constexpr unsigned int gridDimZ = 1;
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
DUMP_TENSOR("csr_col_indices",
layout_csr_col_indices,
num_layout,
layout_col_stride_h);
DUMP_TENSOR("csr_row_indices",
layout_csr_row_indices,
num_layout,
layout_row_stride_h);
printf(
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
"stride_qb=%d, stride_qh=%d, stride_qm=%d, stride_kb=%d, stride_kh=%d, stride_kn=%d,\n"
"stride_vb=%d, stride_vh=%d, stride_vn=%d, stride_ob=%d, stride_oh=%d, stride_om=%d,\n"
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n",
layout_row_stride_h, layout_col_stride_h, num_layout, scale,
stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om,
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
output, q, k, v, layout_csr_row_indices, layout_csr_col_indices);
printf("block_m=%d gridDimX=%d gridDimY=%d threads_per_block=%d sharedMemBytes=%d\n",
block_m, gridDimX, gridDimY, threads_per_block, sharedMemBytes);
#endif
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, threads_per_block, 1, 1, sharedMemBytes,
static_cast<CUstream>(this->ort_stream->GetHandle()),
args, NULL),
driver);
return Status::OK();
}
bool Validate() {
// Check pointers are aligned to 16 bytes (used to hint the compiler to generate aligned loads/stores).
return (reinterpret_cast<size_t>(output) % 16 == 0 &&
reinterpret_cast<size_t>(q) % 16 == 0 &&
reinterpret_cast<size_t>(k) % 16 == 0 &&
reinterpret_cast<size_t>(v) % 16 == 0 &&
reinterpret_cast<size_t>(layout_csr_col_indices) % 16 == 0 &&
reinterpret_cast<size_t>(layout_csr_row_indices) % 16 == 0 &&
this->head_size % 16 == 0 &&
this->past_sequence_length == 0); // This kernel is for prompt only.
}
};
} // namespace sparse_attention_v1
inline void SetKernelSharedMemory(const CUDADriverWrapper* driver, CUfunction func) {
int device = 0;
CUDA_CALL_THROW(cudaGetDevice(&device));
int shared_optin = 0;
CU_CHECK(driver->cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device), driver);
if (shared_optin > 49152) {
CU_CHECK(driver->cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED), driver);
CU_CHECK(driver->cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin), driver);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
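SparseAttentionParams above takes the block-sparse layout as CSR arrays: layout_csr_row_indices of shape [num_layout, num_rows + 1] and layout_csr_col_indices of shape [num_layout, num_rows * num_cols], matching layout_row_stride_h and layout_col_stride_h. A minimal NumPy sketch of that conversion, for illustration only (the operator itself performs it with a dedicated CUDA kernel):

```python
# Illustration only: convert per-layout dense block masks [L, num_rows, num_cols]
# into the CSR arrays consumed by SparseAttentionParams.
import numpy as np

def block_mask_to_csr(block_mask: np.ndarray):
    num_layout, num_rows, num_cols = block_mask.shape
    # Row pointers: [L, num_rows + 1], so layout_row_stride_h == num_rows + 1.
    row_indices = np.zeros((num_layout, num_rows + 1), dtype=np.int32)
    # Column indices padded to the worst case: [L, num_rows * num_cols].
    col_indices = np.zeros((num_layout, num_rows * num_cols), dtype=np.int32)
    for h in range(num_layout):
        count = 0
        for r in range(num_rows):
            cols = np.nonzero(block_mask[h, r])[0]
            col_indices[h, count:count + len(cols)] = cols
            count += len(cols)
            row_indices[h, r + 1] = count
    return row_indices, col_indices

# Tiny causal example: one layout, 4 x 4 blocks.
rows, cols = block_mask_to_csr(np.tril(np.ones((1, 4, 4), dtype=np.int32)))
print(rows.tolist())          # [[0, 1, 3, 6, 10]]
print(cols[0, :10].tolist())  # [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
```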

View file

@ -0,0 +1,216 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention.py using triton AoT compiler
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
// launcher for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_ba65ff9c(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_ba65ff9c(params);
}
// load for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_ba65ff9c();
void load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_ba65ff9c();
}
// unload for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_ba65ff9c();
void unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_ba65ff9c();
}
// launcher for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_f951a16d(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_f951a16d(params);
}
// load for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_f951a16d();
void load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_f951a16d();
}
// unload for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_f951a16d();
void unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_f951a16d();
}
// launcher for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_646fefc8(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_646fefc8(params);
}
// load for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_646fefc8();
void load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_646fefc8();
}
// unload for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_646fefc8();
void unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_646fefc8();
}
// launcher for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_21cac990(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_21cac990(params);
}
// load for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_21cac990();
void load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_21cac990();
}
// unload for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_21cac990();
void unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_21cac990();
}
// launcher for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_31acb592(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_31acb592(params);
}
// load for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_31acb592();
void load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_31acb592();
}
// unload for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_31acb592();
void unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_31acb592();
}
// launcher for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_d55ab166(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_d55ab166(params);
}
// load for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_d55ab166();
void load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_d55ab166();
}
// unload for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_d55ab166();
void unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_d55ab166();
}
// launcher for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_b0560d11(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_b0560d11(params);
}
// load for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_b0560d11();
void load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_b0560d11();
}
// unload for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_b0560d11();
void unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_b0560d11();
}
// launcher for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_c777f3f5(SparseAttentionParams& params);
Status sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_c777f3f5(params);
}
// load for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_c777f3f5();
void load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_c777f3f5();
}
// unload for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_c777f3f5();
void unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_c777f3f5();
}
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm80_kernels[] = {
sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2,
};
int sparse_attention_bf16_sm80_get_num_algos(void) {
return (int)(sizeof(sparse_attention_bf16_sm80_kernels) / sizeof(kernel_func_t));
}
Status sparse_attention_bf16_sm80(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)(sizeof(sparse_attention_bf16_sm80_kernels) / sizeof(kernel_func_t)));
return sparse_attention_bf16_sm80_kernels[algo_id](params);
}
void load_sparse_attention_bf16_sm80(void) {
load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}
void unload_sparse_attention_bf16_sm80(void) {
unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}
Status sparse_attention_bf16_sm80_default(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80(params, 0);
}
} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,216 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention.py using triton AoT compiler
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
// launcher for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
Status sparse_attention_fp16_sm80_3d26f9b3(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_3d26f9b3(params);
}
// load for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
void load_sparse_attention_fp16_sm80_3d26f9b3();
void load_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2() {
load_sparse_attention_fp16_sm80_3d26f9b3();
}
// unload for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
void unload_sparse_attention_fp16_sm80_3d26f9b3();
void unload_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2() {
unload_sparse_attention_fp16_sm80_3d26f9b3();
}
// launcher for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
Status sparse_attention_fp16_sm80_bfb8dd1f(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_bfb8dd1f(params);
}
// load for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
void load_sparse_attention_fp16_sm80_bfb8dd1f();
void load_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2() {
load_sparse_attention_fp16_sm80_bfb8dd1f();
}
// unload for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
void unload_sparse_attention_fp16_sm80_bfb8dd1f();
void unload_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2() {
unload_sparse_attention_fp16_sm80_bfb8dd1f();
}
// launcher for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
Status sparse_attention_fp16_sm80_5fdf5cf7(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_5fdf5cf7(params);
}
// load for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
void load_sparse_attention_fp16_sm80_5fdf5cf7();
void load_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2() {
load_sparse_attention_fp16_sm80_5fdf5cf7();
}
// unload for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
void unload_sparse_attention_fp16_sm80_5fdf5cf7();
void unload_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2() {
unload_sparse_attention_fp16_sm80_5fdf5cf7();
}
// launcher for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
Status sparse_attention_fp16_sm80_35b9b6eb(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_35b9b6eb(params);
}
// load for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
void load_sparse_attention_fp16_sm80_35b9b6eb();
void load_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2() {
load_sparse_attention_fp16_sm80_35b9b6eb();
}
// unload for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
void unload_sparse_attention_fp16_sm80_35b9b6eb();
void unload_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2() {
unload_sparse_attention_fp16_sm80_35b9b6eb();
}
// launcher for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
Status sparse_attention_fp16_sm80_bef12fb0(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_bef12fb0(params);
}
// load for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
void load_sparse_attention_fp16_sm80_bef12fb0();
void load_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2() {
load_sparse_attention_fp16_sm80_bef12fb0();
}
// unload for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
void unload_sparse_attention_fp16_sm80_bef12fb0();
void unload_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2() {
unload_sparse_attention_fp16_sm80_bef12fb0();
}
// launcher for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
Status sparse_attention_fp16_sm80_30cd91a1(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_30cd91a1(params);
}
// load for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
void load_sparse_attention_fp16_sm80_30cd91a1();
void load_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2() {
load_sparse_attention_fp16_sm80_30cd91a1();
}
// unload for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
void unload_sparse_attention_fp16_sm80_30cd91a1();
void unload_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2() {
unload_sparse_attention_fp16_sm80_30cd91a1();
}
// launcher for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
Status sparse_attention_fp16_sm80_72b7bd79(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_72b7bd79(params);
}
// load for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
void load_sparse_attention_fp16_sm80_72b7bd79();
void load_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2() {
load_sparse_attention_fp16_sm80_72b7bd79();
}
// unload for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
void unload_sparse_attention_fp16_sm80_72b7bd79();
void unload_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2() {
unload_sparse_attention_fp16_sm80_72b7bd79();
}
// launcher for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
Status sparse_attention_fp16_sm80_d7f3a63f(SparseAttentionParams& params);
Status sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80_d7f3a63f(params);
}
// load for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
void load_sparse_attention_fp16_sm80_d7f3a63f();
void load_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2() {
load_sparse_attention_fp16_sm80_d7f3a63f();
}
// unload for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
void unload_sparse_attention_fp16_sm80_d7f3a63f();
void unload_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2() {
unload_sparse_attention_fp16_sm80_d7f3a63f();
}
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_fp16_sm80_kernels[] = {
sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2,
sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2,
sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2,
sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2,
sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2,
sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2,
sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2,
sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2,
};
int sparse_attention_fp16_sm80_get_num_algos(void) {
return (int)(sizeof(sparse_attention_fp16_sm80_kernels) / sizeof(kernel_func_t));
}
Status sparse_attention_fp16_sm80(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)(sizeof(sparse_attention_fp16_sm80_kernels) / sizeof(kernel_func_t)));
return sparse_attention_fp16_sm80_kernels[algo_id](params);
}
void load_sparse_attention_fp16_sm80(void) {
load_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2();
load_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2();
load_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2();
load_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2();
load_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2();
load_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2();
load_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2();
load_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2();
}
void unload_sparse_attention_fp16_sm80(void) {
unload_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2();
unload_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2();
unload_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2();
unload_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2();
unload_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2();
unload_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2();
unload_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2();
unload_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2();
}
Status sparse_attention_fp16_sm80_default(SparseAttentionParams& params) {
return sparse_attention_fp16_sm80(params, 0);
}
} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,166 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import triton
import triton.language as tl
# This kernel is for prompt only and assumes that the past sequence length is 0. It only supports right-side padding.
@triton.jit
def block_sparse_attention_kernel(
out, # output [B, H, M, D]. Note that B is batch_size, H is num_heads, M is q_seq_len, and D is head_size
Q, # query [B, H, M, D]
K, # key [B, H_kv, N, D]. Note that N is max_seq_len for kv cache, H_kv is num_kv_heads
V, # value [B, H_kv, N, D]
layout_csr_row_indices, # block mask CSR format. Shape is [L, num_rows + 1] where num_rows = max_seq_len / BLOCK_M
layout_csr_col_indices, # block mask CSR format. Shape is [L, num_rows * num_cols] where num_cols = max_seq_len / BLOCK_N
layout_csr_row_stride_h, # stride per head for csr_row_indices, i.e. num_rows + 1
layout_csr_col_stride_h, # stride per head for csr_col_indices, i.e. num_rows * num_cols
num_layout, # number of sparse layouts (L)
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
num_heads,
num_kv_heads,
total_seq_len, # Total sequence length including past sequence length and query sequence length.
BLOCK_M: tl.constexpr, # block size for q_seq_len
EVEN_M: tl.constexpr, # whether q_seq_len % BLOCK_M == 0
BLOCK_N: tl.constexpr, # block size for k_seq_len
EVEN_N: tl.constexpr, # whether k_seq_len % BLOCK_N == 0
BLOCK_D: tl.constexpr, # block size for D
NUM_D_BLOCKS: tl.constexpr, # number of data blocks = D / BLOCK_D
):
tl.static_print(f"{BLOCK_M=} {BLOCK_N=} {BLOCK_D=} {EVEN_M=} {EVEN_N=} {NUM_D_BLOCKS=}")
# Past sequence length is 0 since this kernel is for prompt only.
q_seq_len = total_seq_len
# Grid is [CDiv(q_seq_len, BLOCK_M), batch_size * num_heads]
start_m = tl.program_id(0)
off_bh = tl.program_id(1)
off_h = off_bh % num_heads
off_b = off_bh // num_heads
# For group query attention, map the query head index to the corresponding one for key and value.
head_groups = num_heads // num_kv_heads
off_h_kv = off_h // head_groups
Q += off_b * stride_qb + off_h * stride_qh
K += off_b * stride_kb + off_h_kv * stride_kh
V += off_b * stride_vb + off_h_kv * stride_vh
# Initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] # [BLOCK_M, BLOCK_D]
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] # [BLOCK_D, BLOCK_N]
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] # [BLOCK_N, BLOCK_D]
# Initialize pointers to query, key, value
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# Initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
if NUM_D_BLOCKS >= 2:
acc2 = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
# Load q: it will stay in SRAM throughout
if EVEN_M:
q = tl.load(q_ptrs)
if NUM_D_BLOCKS >= 2:
q2 = tl.load(q_ptrs + BLOCK_D)
else:
q = tl.load(q_ptrs, mask=offs_m[:, None] < q_seq_len)
if NUM_D_BLOCKS >= 2:
q2 = tl.load(q_ptrs + BLOCK_D, mask=offs_m[:, None] < q_seq_len)
layout_h = off_h % num_layout
# This assumes that past sequence length is 0; otherwise we would need to add (past_seq_len + 1) // BLOCK_M here.
layout_ptr = layout_csr_row_indices + layout_h * layout_csr_row_stride_h + start_m
start_l = tl.load(layout_ptr).to(tl.int32)
end_l = tl.load(layout_ptr + 1).to(tl.int32)
# Loop over k, v and update accumulator
for col_idx_idx in range(start_l, end_l):
col_idx = tl.load(layout_csr_col_indices + layout_h * layout_csr_col_stride_h + col_idx_idx).to(tl.int32)
start_n = col_idx * BLOCK_N
# -- compute qk ----
if EVEN_N:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < total_seq_len)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
if NUM_D_BLOCKS >= 2:
if EVEN_N:
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D)
else:
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D, mask=offs_n[None, :] + start_n < total_seq_len)
qk += tl.dot(q2, k)
qk *= softmax_scale
# This assumes that past sequence length is 0; otherwise the comparison would need offs_m[:, None] + past_seq_len >= (start_n + offs_n[None, :]).
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
if NUM_D_BLOCKS >= 2:
acc2 = acc2 * acc_scale[:, None]
p = p.to(Q.dtype.element_ty)
# update acc
if EVEN_N:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < total_seq_len)
acc += tl.dot(p, v)
if NUM_D_BLOCKS >= 2:
if EVEN_N:
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D)
else:
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D, mask=offs_n[:, None] + start_n < total_seq_len)
acc2 += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_o = off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]
out_ptrs = out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < q_seq_len)
if NUM_D_BLOCKS >= 2:
tl.store(out_ptrs + BLOCK_D, acc2, mask=offs_m[:, None] < q_seq_len)
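The loop above implements the streaming (online) softmax used by flash-attention style kernels: each visited block updates the running max m_i and running sum l_i, rescales the accumulator, and then adds p @ v. A small NumPy reference of the same update on a dense, unmasked input, useful for sanity-checking the math (illustration only):

```python
# NumPy reference for the online-softmax update in the kernel loop above (dense, no mask).
# m_i: running row max, l_i: running row sum, acc: rescaled output accumulator.
import numpy as np

def online_softmax_attention(q, k, v, block_n=64, scale=None):
    m, d = q.shape
    scale = scale or 1.0 / np.sqrt(d)
    m_i = np.full(m, -np.inf)
    l_i = np.zeros(m)
    acc = np.zeros((m, d))
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q @ kb.T) * scale
        m_ij = qk.max(axis=1)
        p = np.exp(qk - m_ij[:, None])
        l_ij = p.sum(axis=1)
        m_new = np.maximum(m_i, m_ij)                  # updated running max
        alpha, beta = np.exp(m_i - m_new), np.exp(m_ij - m_new)
        l_new = alpha * l_i + beta * l_ij              # updated running sum
        acc = acc * (l_i / l_new * alpha)[:, None]     # acc_scale in the kernel
        acc += (p * (beta / l_new)[:, None]) @ vb      # p_scale in the kernel
        m_i, l_i = m_new, l_new
    return acc

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 128, 64))
s = q @ k.T / np.sqrt(64.0)
ref = np.exp(s - s.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
print(np.allclose(online_softmax_attention(q, k, v), ref))  # True
```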

View file

@ -0,0 +1,99 @@
#include <cuda.h>
#include <stdint.h>
#include <assert.h>
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
// Dispatcher files are generated.
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_fp16_sm80.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm80.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
int get_algo_id(SparseAttentionParams& params) {
int block_n = params.kernel_block_size;
int block_m = block_n;
bool even_m = (params.sequence_length % block_m == 0);
bool even_n = (params.total_sequence_length % block_n == 0);
if (params.head_size == 128) {
if (block_m == 16) {
if (!even_m) {
return even_n ? 1 : 0;
} else {
return even_n ? 3 : 2;
}
} else if (block_m == 64) {
if (!even_m) {
return even_n ? 5 : 4;
} else {
return even_n ? 7 : 6;
}
}
}
return -1;
}
bool is_supported_sparse_attention(const cudaDeviceProp& dprops) {
return dprops.major == 8;
}
bool is_supported_sparse_attention(int head_size, int sparse_block_size) {
return head_size == 128 && sparse_block_size == 64;
}
// -----------------------------------------------------------------------
// FP16
Status run_sparse_attention_fp16(SparseAttentionParams& params) {
int algo_id = get_algo_id(params);
if (algo_id < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
}
// Right now we only support sm_8x.
// If we want to support more architectures, we need to dispatch according to SM.
return sparse_attention_fp16_sm80(params, algo_id);
}
static std::once_flag load_sparse_attention_fp16_flag;
void load_sparse_attention_fp16(void) {
// Right now we only support sm_8x.
// If we want to support more architectures, we need to dispatch according to SM.
std::call_once(load_sparse_attention_fp16_flag, load_sparse_attention_fp16_sm80);
}
void unload_sparse_attention_fp16(void) {
unload_sparse_attention_fp16_sm80();
}
// -----------------------------------------------------------------------
// BF16
Status run_sparse_attention_bf16(SparseAttentionParams& params) {
int algo_id = get_algo_id(params);
if (algo_id < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
}
return sparse_attention_bf16_sm80(params, algo_id);
}
static std::once_flag load_sparse_attention_bf16_flag;
void load_sparse_attention_bf16(void) {
std::call_once(load_sparse_attention_bf16_flag, load_sparse_attention_bf16_sm80);
}
void unload_sparse_attention_bf16(void) {
unload_sparse_attention_bf16_sm80();
}
} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
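get_algo_id above picks an entry of the generated kernel table from the block size and the divisibility of the query/total sequence lengths. A small Python mirror of that mapping (illustration only; it assumes the same table ordering as the generated dispatchers above):

```python
# Mirror of get_algo_id in sparse_attention_v1_api.cc (illustration only).
# Table order: 16x{even_m}x64x{even_n} variants first, then 64x{even_m}x64x{even_n},
# each ordered by (even_m, even_n).
def get_algo_id(head_size, block_size, sequence_length, total_sequence_length):
    if head_size != 128 or block_size not in (16, 64):
        return -1
    even_m = sequence_length % block_size == 0
    even_n = total_sequence_length % block_size == 0
    base = 0 if block_size == 16 else 4
    return base + 2 * int(even_m) + int(even_n)

assert get_algo_id(128, 16, 33, 64) == 1     # odd M, even N  -> 16x0x64x1
assert get_algo_id(128, 64, 128, 128) == 7   # even M, even N -> 64x1x64x1
```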

View file

@ -0,0 +1,25 @@
#include <cuda.h>
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
using onnxruntime::Status;
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {
bool is_supported_sparse_attention(const cudaDeviceProp& dprops);
bool is_supported_sparse_attention(int head_size, int sparse_block_size);
Status run_sparse_attention_fp16(SparseAttentionParams& params);
void load_sparse_attention_fp16();
void unload_sparse_attention_fp16();
Status run_sparse_attention_bf16(SparseAttentionParams& params);
void load_sparse_attention_bf16();
void unload_sparse_attention_bf16();
} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,140 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Use triton AoT compiler to convert sparse_attention_v2_triton.py to C source files including cubin and dispatcher.
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install triton==2.3.0
# python3 compile_sparse_attention_v2.py | sh
#
# Note that sparse_attention_v2_*.cc and sparse_attention_v2_dispatcher_*.h are the generated files.
from itertools import product
import triton
def generate_triton_compile_shell_script(dtype="fp16"):
assert dtype in ["fp16", "bf16"]
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")
# Modify the compile.py to use custom template files compile_template_kernel_v2_c/h.txt in current directory.
print(
'python -c "'
"import sys;lines=sys.stdin.read();"
"lines=lines.replace('template_path = Path(__file__).parent / f\\\"compile.{ext}\\\"','template_path = f\\\"compile_template_kernel_v2_{ext}.txt\\\"');"
'print(lines)"'
"< ${TRITON_ROOT}/triton/tools/compile.py > compile.py"
)
out_dir = f"trition_cubin_{dtype}"
print(f"rm -rf {out_dir}")
print(f"mkdir -p {out_dir}")
# All combinations of parameters for kernel.
has_batch_dim_values = [True]
head_size_values = [128]
block_size_values = [64]
is_prompt_values = [True, False]
# Use triton compiler to compile the kernel of different combinations of constant parameters.
for has_batch_dim, head_size, block_size, is_prompt in product(
has_batch_dim_values, head_size_values, block_size_values, is_prompt_values
):
# Constant parameters for triton kernel.
block_d = triton.next_power_of_2(head_size)
block_m = block_size
block_n = block_size
block_m_loading = 16 if not is_prompt else block_size
even_d = block_d == head_size
m_lt_n = block_m < block_n
num_warps = 1 if not is_prompt else 4
num_stages = 3
# There are 4 data (fp16/bf16) buffer pointers and 8 int32 buffer pointers, all assumed to be 16-byte aligned.
tensor_params = f"*{dtype}:16," * 4 + "*i32:16," * 8
# The strides for Q, K, V, and Out are multiples of 16 since head_size is 128.
scalar_params = ("i32," * 2) + ("i32:16," * 12) + "i32,i32,fp32,"
constant_params = f"{int(has_batch_dim)},{head_size},{block_m},{block_n},{block_d},{block_m_loading},{int(even_d)},{int(m_lt_n)}"
signature = f"{tensor_params}{scalar_params}{constant_params}"
prefix = "python compile.py sparse_attention_v2_triton.py"
# output filename
filename = f"sparse_attention_v2_{dtype}_d{head_size}_m{block_m}_{block_m_loading}_n{block_n}_b{int(has_batch_dim)}_sm${{SM}}"
# function name
name = f"sparse_attention_v2_{dtype}_sm${{SM}}"
print(
f"{prefix} -n block_sparse_attention -o {out_dir}/{filename} --out-name {name} "
f'-w {num_warps} -ns {num_stages} -s "{signature}" -g "query_blocks, num_heads, 1"'
)
# Generate the dispatcher.
dispatcher = f"sparse_attention_v2_dispatcher_{dtype}_sm${{SM}}"
print(f"cd {out_dir}")
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v2_*.h -o {dispatcher}")
print("rm *.h")
# Remove signature in code.
suffix = "0d1d2d3d4d5d6d7d8d9d10d11d121314d15d16d17d18d19d20d21d22d23d24d25d262728"
print(f"for file in *.c; do sed -i 's/_{suffix}//g' \"$file\"; done")
# Recover signature in kernel name that is removed in previous step. Kernel name shall not be changed.
print(f"for file in *.c; do sed -i 's/block_sparse_attention/block_sparse_attention_{suffix}/g' \"$file\"; done")
# Remove the signature hash from the filename: all kernels share the same signature except for the constants,
# and the constants are already encoded in the filename, so files stay distinguishable without the hash.
print(f'for file in sparse_attention_v2_{dtype}_*.c; do mv "$file" "$(echo $file | cut -f 1 -d \'.\').c"; done')
# Change function parameters and return type. If you change the kernel interface, you will need to modify this part.
source1 = "CUstream stream, CUdeviceptr Out, CUdeviceptr Q, CUdeviceptr K, CUdeviceptr V, CUdeviceptr q_batch_starts, CUdeviceptr q_batch_ends, CUdeviceptr k_batch_starts, CUdeviceptr k_batch_ends, CUdeviceptr q_batch_ids, CUdeviceptr q_start_sids, CUdeviceptr layout_crow_ptr, CUdeviceptr layout_col_ptr, int32_t layout_crow_stride_h, int32_t layout_col_stride_h, int32_t stride_qb, int32_t stride_qt, int32_t stride_qh, int32_t stride_kb, int32_t stride_kt, int32_t stride_kh, int32_t stride_vb, int32_t stride_vt, int32_t stride_vh, int32_t stride_ob, int32_t stride_ot, int32_t stride_oh, int32_t q_k_ratio, int32_t num_layout, float softmax_scale"
target1 = "SparseAttentionParams& params"
source2 = "stream, Out, Q, K, V, q_batch_starts, q_batch_ends, k_batch_starts, k_batch_ends, q_batch_ids, q_start_sids, layout_crow_ptr, layout_col_ptr, layout_crow_stride_h, layout_col_stride_h, stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh, stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh, q_k_ratio, num_layout, softmax_scale"
target2 = "params"
print(
f"python -c \"import sys;lines=sys.stdin.read();lines=lines.replace('{source1}', '{target1}');"
f'lines=lines.replace(\'{source2}\', \'{target2}\');print(lines)" < "{dispatcher}.c" > "{dispatcher}.h"'
)
print(f"sed -i 's/CUresult/Status/g' \"{dispatcher}.h\"")
# Remove parameter checking since we moved the validation logic to SparseAttentionParams
print(f"sed -i '/if /d' \"{dispatcher}.h\"")
print(f"sed -i '/CUDA_ERROR_INVALID_VALUE/d' \"{dispatcher}.h\"")
print(f"sed -i '/#include/d' \"{dispatcher}.h\"")
print(f"rm {dispatcher}.c")
# Use a template file to add namespace and includes to the dispatcher file.
print(
'python -c "'
"from pathlib import Path;"
"template=Path('../compile_template_dispatcher_v2_h.txt').read_text();"
f"code=Path('{dispatcher}.h').read_text();"
"text=template.replace('PLACEHOLDER', code); print(text)\" "
f"> ../{dispatcher}.h"
)
# rename *.c to *.cc
print('for file in *.c; do mv -- "$file" "${file%.c}.cc"; done')
# Move kernel files to parent directory. This might overwrite existing files in repository.
print("echo Generated files:")
print("ls sparse_attention_v2_*")
print(f"mv -f sparse_attention_v2_{dtype}_* ../")
# Clean up
print("cd ..")
print("rm compile.py")
print(f"rm -rf {out_dir}")
print(f"echo compiling {dtype} is done")
if __name__ == "__main__":
for dtype in ["fp16", "bf16"]:
generate_triton_compile_shell_script(dtype)
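Under the parameter grid above, each dtype yields exactly two kernel variants: one for prompt prefill and one for token generation. A short sketch that reproduces the derived constants of the two variants, mirroring the loop above (names are the script's own local variables):

```python
# Reproduce the per-variant constants derived in generate_triton_compile_shell_script
# for head_size=128, block_size=64 (the only values in the grid above).
import triton

head_size, block_size = 128, 64
for is_prompt in (True, False):
    block_d = triton.next_power_of_2(head_size)        # 128
    block_m = block_n = block_size                      # 64
    block_m_loading = block_size if is_prompt else 16
    even_d = int(block_d == head_size)                  # 1
    m_lt_n = int(block_m < block_n)                     # 0
    num_warps = 4 if is_prompt else 1
    print(f"is_prompt={is_prompt}: constants="
          f"1,{head_size},{block_m},{block_n},{block_d},{block_m_loading},{even_d},{m_lt_n}, "
          f"num_warps={num_warps}, num_stages=3")
# These match the two entries per dtype in the generated v2 dispatchers, e.g.
# sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3 (prompt) and
# sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3 (token generation).
```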

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention_v2.py
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
PLACEHOLDER
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
namespace onnxruntime {{
namespace contrib {{
namespace cuda {{
namespace sparse_attention_v2 {{
// This file is generated by compile_sparse_attention_v2.py
// {kernel_docstring}
// cubin_size = {bin_size}
// shared_mem_bytes = {shared}
// threads_per_cta = {num_warps} * 32
// kernel_name = {triton_kernel_name}
unsigned char {kernel_name}_cubin[] = {{ {bin_data} }};
CUmodule {kernel_name}_mod = NULL;
CUfunction {kernel_name}_func = NULL;
void unload_{kernel_name}(void) {{
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuModuleUnload({kernel_name}_mod), driver);
}}
void load_{kernel_name}(void) {{
void *bin = (void *)&{kernel_name}_cubin;
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuModuleLoadData(&{kernel_name}_mod, bin), driver);
CU_CHECK(driver->cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"), driver);
constexpr int shared = {shared};
if constexpr (shared > 49152) {{
SetKernelSharedMemory(driver, {kernel_name}_func);
}}
}}
Status {kernel_name}(SparseAttentionParams& params) {{
return params.LaunchKernel({kernel_name}_func, {num_warps} * 32, {shared});
}}
}} // namespace sparse_attention_v2
}} // namespace cuda
}} // namespace contrib
}} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention_v2.py
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
namespace onnxruntime {{
namespace contrib {{
namespace cuda {{
namespace sparse_attention_v2 {{
void unload_{kernel_name}(void);
void load_{kernel_name}(void);
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
Status{_placeholder} {kernel_name}(SparseAttentionParams& params);
}} // namespace sparse_attention_v2
}} // namespace cuda
}} // namespace contrib
}} // namespace onnxruntime

View file

@ -0,0 +1,78 @@
#include <cuda.h>
#include <stdint.h>
#include <assert.h>
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
// Dispatcher files are generated.
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_dispatcher_fp16_sm80.h"
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_dispatcher_bf16_sm80.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
int get_algo_id(SparseAttentionParams& params) {
return (params.past_sequence_length > 0 && params.sequence_length <= 16) ? 0 : 1;
}
bool is_supported_sparse_attention(const cudaDeviceProp& dprops) {
return dprops.major == 8;
}
bool is_supported_sparse_attention(int head_size, int sparse_block_size) {
return head_size == 128 && sparse_block_size == 64;
}
// -----------------------------------------------------------------------
// FP16
Status run_sparse_attention_fp16(SparseAttentionParams& params) {
int algo_id = get_algo_id(params);
if (algo_id < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
}
// Right now we only support sm_8x.
// If we want to support more architectures, we need to dispatch according to SM.
return sparse_attention_v2_fp16_sm80(params, algo_id);
}
static std::once_flag load_sparse_attention_v2_fp16_flag;
void load_sparse_attention_fp16(void) {
// Right now we only support sm_8x.
// If we want to support more architectures, we need to dispatch according to SM.
std::call_once(load_sparse_attention_v2_fp16_flag, load_sparse_attention_v2_fp16_sm80);
}
void unload_sparse_attention_fp16(void) {
unload_sparse_attention_v2_fp16_sm80();
}
// -----------------------------------------------------------------------
// BF16
Status run_sparse_attention_bf16(SparseAttentionParams& params) {
int algo_id = get_algo_id(params);
if (algo_id < 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
}
return sparse_attention_v2_bf16_sm80(params, algo_id);
}
static std::once_flag load_sparse_attention_v2_bf16_flag;
void load_sparse_attention_bf16(void) {
std::call_once(load_sparse_attention_v2_bf16_flag, load_sparse_attention_v2_bf16_sm80);
}
void unload_sparse_attention_bf16(void) {
unload_sparse_attention_v2_bf16_sm80();
}
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,25 @@
#include <cuda.h>
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
using onnxruntime::Status;
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
bool is_supported_sparse_attention(const cudaDeviceProp& dprops);
bool is_supported_sparse_attention(int head_size, int sparse_block_size);
Status run_sparse_attention_fp16(SparseAttentionParams& params);
void load_sparse_attention_fp16();
void unload_sparse_attention_fp16();
Status run_sparse_attention_bf16(SparseAttentionParams& params);
void load_sparse_attention_bf16();
void unload_sparse_attention_bf16();
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,232 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
struct SparseAttentionParams {
onnxruntime::Stream* ort_stream;
void* output;
const void* q;
const void* k;
const void* v;
int batch_size;
int num_heads;
int kv_num_heads;
int head_size;
int sequence_length;
int past_sequence_length;
int total_sequence_length;
int max_sequence_length;
float scale;
int kernel_block_size;
// CSR format of block mask
const int* layout_csr_row_indices;
const int* layout_csr_col_indices;
int layout_row_stride_h;
int layout_col_stride_h;
int num_layout;
// strides
int stride_qb;
int stride_qt;
int stride_qh;
int stride_kb;
int stride_kt;
int stride_kh;
int stride_vb;
int stride_vt;
int stride_vh;
int stride_ob;
int stride_ot;
int stride_oh;
int q_k_ratio;
int active_q_blocks;
const int* q_batch_starts;
const int* q_batch_ends;
const int* k_batch_starts;
const int* k_batch_ends;
const int* q_batch_ids;
const int* q_start_sids;
SparseAttentionParams(
onnxruntime::Stream* ort_stream,
void* output,
const void* q,
const void* k,
const void* v,
int batch_size,
int sequence_length,
int num_heads,
int kv_num_heads,
int head_size,
int total_sequence_length,
int max_sequence_length,
float scale,
int kernel_block_size,
const int* layout_csr_row_indices,
const int* layout_csr_col_indices,
int layout_row_stride_h,
int layout_col_stride_h,
int num_layout,
int active_q_blocks,
const int* q_batch_starts,
const int* q_batch_ends,
const int* k_batch_starts,
const int* k_batch_ends,
const int* q_batch_ids,
const int* q_start_sids) {
this->ort_stream = ort_stream;
this->output = output;
this->q = q;
this->k = k;
this->v = v;
this->batch_size = batch_size;
this->sequence_length = sequence_length;
this->num_heads = num_heads;
this->kv_num_heads = kv_num_heads;
this->head_size = head_size;
this->past_sequence_length = total_sequence_length - sequence_length;
this->total_sequence_length = total_sequence_length;
this->max_sequence_length = max_sequence_length;
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
this->kernel_block_size = kernel_block_size;
this->layout_csr_row_indices = layout_csr_row_indices;
this->layout_csr_col_indices = layout_csr_col_indices;
this->layout_row_stride_h = layout_row_stride_h;
this->layout_col_stride_h = layout_col_stride_h;
this->num_layout = num_layout;
// Q is in BNSH format
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
this->stride_qh = this->sequence_length * this->head_size;
this->stride_qt = this->head_size;
// When kv buffer has max sequence length, stride should match max sequence length.
int kv_buffer_sequence_length = max_sequence_length;
// KV cache is in BNSH format
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
this->stride_kh = kv_buffer_sequence_length * this->head_size;
this->stride_kt = this->head_size;
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
this->stride_vh = kv_buffer_sequence_length * this->head_size;
this->stride_vt = this->head_size;
// Output is BSNH format
this->stride_ob = this->sequence_length * this->num_heads * this->head_size;
this->stride_oh = this->head_size;
this->stride_ot = this->num_heads * this->head_size;
this->q_k_ratio = this->num_heads / this->kv_num_heads;
this->active_q_blocks = active_q_blocks;
this->q_batch_starts = q_batch_starts;
this->q_batch_ends = q_batch_ends;
this->k_batch_starts = k_batch_starts;
this->k_batch_ends = k_batch_ends;
this->q_batch_ids = q_batch_ids;
this->q_start_sids = q_start_sids;
}
Status LaunchKernel(CUfunction f, int threads_per_block, unsigned int sharedMemBytes) {
ORT_ENFORCE(f != nullptr, "Kernel shall be loaded before calling LaunchKernel.");
if (!Validate()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseAttentionParams is not valid.");
}
void* args[29] = {
&output, &q, &k, &v,
&q_batch_starts, &q_batch_ends, &k_batch_starts, &k_batch_ends, &q_batch_ids, &q_start_sids,
&layout_csr_row_indices, &layout_csr_col_indices, &layout_row_stride_h, &layout_col_stride_h,
&stride_qb, &stride_qt, &stride_qh, &stride_kb, &stride_kt, &stride_kh,
&stride_vb, &stride_vt, &stride_vh, &stride_ob, &stride_ot, &stride_oh,
&q_k_ratio, &num_layout, &scale};
unsigned int gridDimX = active_q_blocks;
unsigned int gridDimY = num_heads;
constexpr unsigned int gridDimZ = 1;
#if DUMP_TENSOR_LEVEL > 0
DUMP_TENSOR_INIT();
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
DUMP_TENSOR("csr_col_indices",
layout_csr_col_indices,
num_layout,
layout_col_stride_h);
DUMP_TENSOR("csr_row_indices",
layout_csr_row_indices,
num_layout,
layout_row_stride_h);
DUMP_TENSOR("q_batch_starts", q_batch_starts, 1, batch_size);
DUMP_TENSOR("q_batch_ends", q_batch_ends, 1, batch_size);
DUMP_TENSOR("k_batch_starts", k_batch_starts, 1, batch_size);
DUMP_TENSOR("k_batch_ends", k_batch_ends, 1, batch_size);
DUMP_TENSOR("q_batch_ids", q_batch_ids, 1, active_q_blocks);
DUMP_TENSOR("q_start_sids", q_start_sids, 1, active_q_blocks);
printf(
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
"stride_qb=%d, stride_qt=%d, stride_qh=%d, stride_kb=%d, stride_kt=%d, stride_kh=%d,\n"
"stride_vb=%d, stride_vt=%d, stride_vh=%d, stride_ob=%d, stride_ot=%d, stride_oh=%d,\n"
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n"
"q_batch_starts=%p, q_batch_ends=%p, k_batch_starts=%p, k_batch_ends=%p, q_batch_ids=%p, q_start_sids=%p active_q_blocks=%d\n",
layout_row_stride_h, layout_col_stride_h, num_layout, scale,
stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh,
stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh,
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
output, q, k, v, layout_csr_row_indices, layout_csr_col_indices,
q_batch_starts, q_batch_ends, k_batch_starts, k_batch_ends, q_batch_ids, q_start_sids, active_q_blocks);
printf("gridDimX=%d gridDimY=%d threads_per_block=%d sharedMemBytes=%d\n",
gridDimX, gridDimY, threads_per_block, sharedMemBytes);
#endif
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
CU_CHECK(driver->cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, threads_per_block, 1, 1, sharedMemBytes,
static_cast<CUstream>(this->ort_stream->GetHandle()),
args, NULL),
driver);
return Status::OK();
}
bool Validate() {
// Check pointers are aligned to 16 bytes (we used that to hint the compiler to generate aligned loads/stores)
return (reinterpret_cast<size_t>(output) % 16 == 0 &&
reinterpret_cast<size_t>(q) % 16 == 0 &&
reinterpret_cast<size_t>(k) % 16 == 0 &&
reinterpret_cast<size_t>(v) % 16 == 0 &&
reinterpret_cast<size_t>(layout_csr_col_indices) % 16 == 0 &&
reinterpret_cast<size_t>(layout_csr_row_indices) % 16 == 0 &&
reinterpret_cast<size_t>(q_batch_starts) % 16 == 0 &&
reinterpret_cast<size_t>(q_batch_ends) % 16 == 0 &&
reinterpret_cast<size_t>(k_batch_starts) % 16 == 0 &&
reinterpret_cast<size_t>(k_batch_ends) % 16 == 0 &&
reinterpret_cast<size_t>(q_batch_ids) % 16 == 0 &&
reinterpret_cast<size_t>(q_start_sids) % 16 == 0 &&
this->head_size % 16 == 0);
}
};
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
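The v2 variant launches a grid of active_q_blocks x num_heads and uses the per-batch ranges and per-block ids above to handle variable, right-padded lengths. The sketch below shows one way such arrays could be filled on the host; the semantics are inferred from the parameter names and the launch grid, not taken from the PR's CUDA code, so treat it as a hedged illustration only:

```python
# Hedged sketch (not the PR's CUDA implementation): build the per-batch / per-block
# arrays consumed by sparse_attention_v2 for a right-padded batch. Assumes Q/K/V keep
# their batch dimension, so offsets below are relative to each sequence, not a packed buffer.
import numpy as np

def build_v2_batch_arrays(query_lens, total_lens, block_m_loading):
    q_batch_starts, q_batch_ends = [], []
    k_batch_starts, k_batch_ends = [], []
    q_batch_ids, q_start_sids = [], []
    for b, (q_len, k_len) in enumerate(zip(query_lens, total_lens)):
        q_batch_starts.append(0)
        q_batch_ends.append(q_len)      # number of new query tokens in batch row b
        k_batch_starts.append(0)
        k_batch_ends.append(k_len)      # total (past + new) tokens in batch row b
        for blk in range((q_len + block_m_loading - 1) // block_m_loading):
            q_batch_ids.append(b)                        # batch row of this active q block
            q_start_sids.append(blk * block_m_loading)   # first query token of this block
    arrays = (q_batch_starts, q_batch_ends, k_batch_starts,
              k_batch_ends, q_batch_ids, q_start_sids)
    return tuple(np.asarray(a, dtype=np.int32) for a in arrays)

# Example: batch of 2 in token generation (1 new token each) with different past lengths.
arrays = build_v2_batch_arrays(query_lens=[1, 1], total_lens=[17, 33], block_m_loading=16)
print([a.tolist() for a in arrays])
# active_q_blocks == len(arrays[4]) == 2, so gridDimX above would be 2.
```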

View file

@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention_v2.py
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
// launcher for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
Status sparse_attention_v2_bf16_sm80_0aafaf4a(SparseAttentionParams& params);
Status sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3(SparseAttentionParams& params) {
return sparse_attention_v2_bf16_sm80_0aafaf4a(params);
}
// load for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
void load_sparse_attention_v2_bf16_sm80_0aafaf4a();
void load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
load_sparse_attention_v2_bf16_sm80_0aafaf4a();
}
// unload for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
void unload_sparse_attention_v2_bf16_sm80_0aafaf4a();
void unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
unload_sparse_attention_v2_bf16_sm80_0aafaf4a();
}
// launcher for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
Status sparse_attention_v2_bf16_sm80_8b0ce70d(SparseAttentionParams& params);
Status sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_v2_bf16_sm80_8b0ce70d(params);
}
// load for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
void load_sparse_attention_v2_bf16_sm80_8b0ce70d();
void load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
load_sparse_attention_v2_bf16_sm80_8b0ce70d();
}
// unload for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
void unload_sparse_attention_v2_bf16_sm80_8b0ce70d();
void unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
unload_sparse_attention_v2_bf16_sm80_8b0ce70d();
}
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_v2_bf16_sm80_kernels[] = {
sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3,
sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3,
};
int sparse_attention_v2_bf16_sm80_get_num_algos(void) {
return (int)(sizeof(sparse_attention_v2_bf16_sm80_kernels) / sizeof(kernel_func_t));
}
Status sparse_attention_v2_bf16_sm80(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)(sizeof(sparse_attention_v2_bf16_sm80_kernels) / sizeof(kernel_func_t)));
return sparse_attention_v2_bf16_sm80_kernels[algo_id](params);
}
void load_sparse_attention_v2_bf16_sm80(void) {
load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
}
void unload_sparse_attention_v2_bf16_sm80(void) {
unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
}
Status sparse_attention_v2_bf16_sm80_default(SparseAttentionParams& params) {
return sparse_attention_v2_bf16_sm80(params, 0);
}
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is generated by compile_sparse_attention_v2.py
#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v2 {
// launcher for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
Status sparse_attention_v2_fp16_sm80_a6bdc951(SparseAttentionParams& params);
Status sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3(SparseAttentionParams& params) {
return sparse_attention_v2_fp16_sm80_a6bdc951(params);
}
// load for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
void load_sparse_attention_v2_fp16_sm80_a6bdc951();
void load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
load_sparse_attention_v2_fp16_sm80_a6bdc951();
}
// unload for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
void unload_sparse_attention_v2_fp16_sm80_a6bdc951();
void unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
unload_sparse_attention_v2_fp16_sm80_a6bdc951();
}
// launcher for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
Status sparse_attention_v2_fp16_sm80_ca298032(SparseAttentionParams& params);
Status sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_v2_fp16_sm80_ca298032(params);
}
// load for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
void load_sparse_attention_v2_fp16_sm80_ca298032();
void load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
load_sparse_attention_v2_fp16_sm80_ca298032();
}
// unload for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
void unload_sparse_attention_v2_fp16_sm80_ca298032();
void unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
unload_sparse_attention_v2_fp16_sm80_ca298032();
}
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_v2_fp16_sm80_kernels[] = {
sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3,
sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3,
};
int sparse_attention_v2_fp16_sm80_get_num_algos(void) {
return (int)(sizeof(sparse_attention_v2_fp16_sm80_kernels) / sizeof(kernel_func_t));
}
Status sparse_attention_v2_fp16_sm80(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)(sizeof(sparse_attention_v2_fp16_sm80_kernels) / sizeof(kernel_func_t)));
return sparse_attention_v2_fp16_sm80_kernels[algo_id](params);
}
void load_sparse_attention_v2_fp16_sm80(void) {
load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
}
void unload_sparse_attention_v2_fp16_sm80(void) {
unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
}
Status sparse_attention_v2_fp16_sm80_default(SparseAttentionParams& params) {
return sparse_attention_v2_fp16_sm80(params, 0);
}
} // namespace sparse_attention_v2
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

File diff is hidden because one or more lines are too long

File diff is hidden because one or more lines are too long

View file

@ -0,0 +1,215 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import triton
import triton.language as tl
@triton.jit
def block_sparse_attention(
Out, # output [B, M, H, D]. Note that B is batch_size, M is q_seq_len, H is num_heads, and D is head_size
Q, # query [B, M, H, D]
K, # key [B, N, H_kv, D]. Note that N is max_seq_len for kv cache, H_kv is num_kv_heads
V, # value [B, N, H_kv, D]
q_batch_starts, # [B], start position (excluding the past) of query in the sequence for each batch
q_batch_ends, # [B], end position (excluding the past) of query in the sequence for each batch
k_batch_starts, # [B], start position (including the past) of key in the sequence for each batch
k_batch_ends, # [B], end position (including the past) of key in the sequence for each batch
q_batch_ids, # [G], batch id for each query block; G is the total number of query blocks
q_start_sids, # [G], start position (excluding the past) of each query block
layout_crow_ptr, # block mask CSR format. Shape is [H, num_rows + 1] where num_rows = max_seq_len / BLOCK_M
layout_col_ptr, # block mask CSR format. Shape is [H, num_rows * num_cols] where num_cols = max_seq_len / BLOCK_N
layout_crow_stride_h, # stride per head for csr_row_indices, i.e. num_rows + 1
layout_col_stride_h, # stride per head for csr_col_indices, i.e. num_rows * num_cols
stride_qb,
stride_qt,
stride_qh, # strides for query (excluding the stride for last hidden dim, which is always 1)
stride_kb,
stride_kt,
stride_kh, # strides for key (excluding the stride for last hidden dim, which is always 1)
stride_vb,
stride_vt,
stride_vh, # strides for value (excluding the stride for last hidden dim, which is always 1)
stride_ob,
stride_ot,
stride_oh, # strides for output (excluding the stride for last hidden dim, which is always 1)
q_k_ratio, # num_heads / num_kv_heads
num_layout, # number of sparse layout (H)
softmax_scale, # scaling factor applied prior to softmax
HAS_BATCH_DIM: tl.constexpr, # whether batch dim is present
D_HEAD: tl.constexpr, # head size
BLOCK_M: tl.constexpr, # block size for q_seq_len
BLOCK_N: tl.constexpr, # block size for k_seq_len
BLOCK_D: tl.constexpr, # block size for D
BLOCK_M_LOADING: tl.constexpr, # block size for loading q
EVEN_D: tl.constexpr, # whether D is divisible by BLOCK_D
M_LT_N: tl.constexpr, # whether BLOCK_M < BLOCK_N
):
tl.static_print(
f"{HAS_BATCH_DIM=} {D_HEAD=} {BLOCK_M=} {BLOCK_N=} {BLOCK_D=} {BLOCK_M_LOADING=} {EVEN_D=} {M_LT_N=}"
)
# The grid is [G, num_heads] where G is number of query blocks.
off_g = tl.program_id(0)
off_h = tl.program_id(1)
off_h_for_kv = off_h // q_k_ratio
off_b = tl.load(q_batch_ids + off_g).to(tl.int32)
q_start_sid = tl.load(q_start_sids + off_g)
start_m = q_start_sid // BLOCK_M
if HAS_BATCH_DIM:
Q += off_b * stride_qb
K += off_b * stride_kb
V += off_b * stride_vb
Out += off_b * stride_ob
# offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_D)
q_cu_start = tl.load(q_batch_starts + off_b).to(tl.int32)
q_seqlen = tl.load(q_batch_ends + off_b).to(tl.int32) - q_cu_start
k_cu_start = tl.load(k_batch_starts + off_b).to(tl.int32)
k_seqlen = tl.load(k_batch_ends + off_b).to(tl.int32) - k_cu_start
past_len = k_seqlen - q_seqlen
Q += q_cu_start * stride_qt + off_h * stride_qh
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
Out += q_cu_start * stride_ot + off_h * stride_oh
if EVEN_D:
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :], mask=offs_m[:, None] < q_seqlen)
else:
q = tl.load(
Q + offs_m[:, None] * stride_qt + offs_d[None, :],
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
other=0,
)
q_row = (past_len + q_start_sid) // BLOCK_M
layout_h = off_h % num_layout
sparse_crow_ptr = layout_crow_ptr + layout_h * layout_crow_stride_h + q_row
# TODO: load at once, supported in new Triton
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None]
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :]
for k_block_col_idx in range(k_block_start, k_block_end - 1):
k_block_id = tl.load(layout_col_ptr + layout_h * layout_col_stride_h + k_block_col_idx).to(tl.int32)
start_n = k_block_id * BLOCK_N
# -- compute qk ----
if EVEN_D:
k = tl.load(k_ptrs + start_n * stride_kt)
else:
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_d[:, None] < D_HEAD)
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= softmax_scale
if M_LT_N:
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
p = p.to(Q.dtype.element_ty)
# update acc
if EVEN_D:
v = tl.load(v_ptrs + start_n * stride_vt)
else:
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_d[None, :] < D_HEAD)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# Process the last k block
k_block_col_idx = k_block_end - 1
k_block_id = tl.load(layout_col_ptr + layout_h * layout_col_stride_h + k_block_col_idx).to(tl.int32)
start_n = k_block_id * BLOCK_N
# -- compute qk ----
if EVEN_D:
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen)
else:
# mask = mask & (offs_d[:, ])
k = tl.load(
k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD)
)
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= softmax_scale
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
p = p.to(Q.dtype.element_ty)
# update acc
if EVEN_D:
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen)
else:
v = tl.load(
v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD)
)
acc += tl.dot(p, v)
# l_i = l_i_new
# m_i = m_i_new
# write output
if EVEN_D:
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :], acc, mask=offs_m[:, None] < q_seqlen)
else:
tl.store(
Out + offs_m[:, None] * stride_ot + offs_d[None, :],
acc,
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
)
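
The kernel above walks the sparse layout in CSR form: `layout_crow_ptr[h]` gives, for each row of query blocks, the range of entries in `layout_col_ptr[h]` that hold the key block ids to visit. As a reference for how those tensors can be produced, here is a minimal Python sketch (not the operator's CUDA conversion kernel) that builds a causal block mask with local and vertical-stride patterns and converts it to per-head CSR row/column indices; the helper names and the vertical-stride formula are illustrative assumptions.

```python
# Minimal sketch (assumption-labeled): build a block mask similar to the layout described in
# this PR and convert it to the CSR tensors consumed by layout_crow_ptr / layout_col_ptr.
import torch


def make_block_mask(num_layout: int, max_blocks: int, local_blocks: int, vert_stride: int) -> torch.Tensor:
    """Causal block mask [num_layout, max_blocks, max_blocks]: local blocks plus a strided vertical pattern."""
    q_pos = torch.arange(max_blocks)[None, :, None]
    k_pos = torch.arange(max_blocks)[None, None, :]
    # Assumption: each layout shifts the vertical-stride pattern by a fixed step.
    head_sliding_step = max(1, vert_stride // num_layout)
    vert = torch.vstack(
        [(torch.arange(max_blocks) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(num_layout)]
    ).unsqueeze(1)
    local = (q_pos >= k_pos) & ((q_pos - k_pos) < local_blocks)
    return ((q_pos >= k_pos) & (local | vert)).to(torch.int32)


def block_mask_to_csr(block_mask: torch.Tensor):
    """Convert a [H, R, C] 0/1 block mask to CSR row pointers [H, R + 1] and column indices [H, R * C]."""
    num_layout, num_rows, num_cols = block_mask.shape
    crow = torch.zeros(num_layout, num_rows + 1, dtype=torch.int32)
    col = torch.zeros(num_layout, num_rows * num_cols, dtype=torch.int32)
    for h in range(num_layout):
        nnz = 0
        for r in range(num_rows):
            crow[h, r] = nnz
            for c in range(num_cols):
                if block_mask[h, r, c]:
                    col[h, nnz] = c
                    nnz += 1
        crow[h, num_rows] = nnz
    return crow, col  # per-head strides are num_rows + 1 and num_rows * num_cols, matching the kernel comments


if __name__ == "__main__":
    mask = make_block_mask(num_layout=2, max_blocks=8, local_blocks=2, vert_stride=4)
    crow, col = block_mask_to_csr(mask)
    print(crow[0], col[0, : crow[0, -1]])
```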

View file

@ -7,6 +7,22 @@
#include "core/framework/ort_value.h"
#include "contrib_ops/cpu/utils/console_dumper.h"
#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if you want to enable dumping for code not in generation.
#if DUMP_TENSOR_LEVEL > 0
#define DUMP_TENSOR_INIT() onnxruntime::contrib::cuda::transformers::CudaTensorConsoleDumper dumper
#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
#else
#define DUMP_TENSOR_INIT()
#define DUMP_TENSOR(...)
#endif
#if DUMP_TENSOR_LEVEL > 1
#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
#else
#define DUMP_TENSOR_D(...)
#endif
namespace onnxruntime {
namespace contrib {
namespace cuda {

View file

@ -228,18 +228,13 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
}
}
void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
// Output 0 has shape (batch_size, sequence_length, hidden_size)
// Q, K and V:
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
// Input 1 (key) has shape (batch_size, kv_sequence_length, kv_hidden_size)
// Input 2 (value) has shape (batch_size, kv_sequence_length, kv_hidden_size)
// Type inference
// Type and shape inference for group query attention and sparse attention.
void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
int past_key_index = -1,
int use_max_past_present_buffer = -1) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
int64_t kv_sequence_length = -1;
if (hasInputShape(ctx, 0)) {
auto& query_shape = getInputShape(ctx, 0);
auto& query_dims = query_shape.dim();
@ -249,18 +244,22 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
}
if (hasInputShape(ctx, 2)) {
// Input 0 (query) has shape (batch_size, sequence_length, num_heads * head_size)
// Input 1 (key) has shape (batch_size, kv_sequence_length, kv_num_heads * head_size)
// Input 2 (value) has shape (batch_size, kv_sequence_length, kv_num_heads * head_size)
// Output 0 has shape (batch_size, sequence_length, num_heads * head_size)
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
auto& value_shape = getInputShape(ctx, 2);
auto& value_dims = value_shape.dim();
if (value_dims.size() != 3) {
fail_shape_inference("Inputs 2 (value) shall be 3 dimensions");
if (value_dims.size() == 3 && value_dims[1].has_dim_value()) {
kv_sequence_length = value_dims[1].dim_value();
}
ONNX_NAMESPACE::TensorShapeProto output_shape;
*output_shape.add_dim() = query_dims[0];
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = query_dims[2];
updateOutputShape(ctx, 0, output_shape);
} else {
// Packed QKV:
// Input 0 (query) has shape (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
// Input 1 (key) is not present
// Input 2 (value) is not present
ONNX_NAMESPACE::TensorShapeProto output_shape;
int64_t num_heads = getAttribute(ctx, "num_heads", 0);
int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
@ -270,25 +269,64 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
*output_shape.add_dim() = query_dims[1];
output_shape.add_dim()->set_dim_value(head_size * num_heads);
updateOutputShape(ctx, 0, output_shape);
if (query_dims[1].has_dim_value()) {
kv_sequence_length = query_dims[1].dim_value();
}
}
}
if (ctx.getNumOutputs() > 1) { // has present output
if (hasInputShape(ctx, past_key_index)) {
// auto& query_shape = getInputShape(ctx, 0);
// auto& query_dims = query_shape.dim();
// copy the type from query to present key
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1);
// copy the type from query to present value
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2);
if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) {
auto& past_shape = getInputShape(ctx, past_key_index);
auto& past_dims = past_shape.dim();
// past key has shape (batch_size, kv_num_heads, max_sequence_length, head_size)
if (past_dims.size() != 4) {
fail_shape_inference("The past_key input shall be 4 dimensions");
}
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
// TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
if (use_max_past_present_buffer == 1) {
// When past and present use max buffer, they have the same shape
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
} else if (use_max_past_present_buffer == 0) {
if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) {
int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value();
ONNX_NAMESPACE::TensorShapeProto present_shape;
for (auto& dim : past_dims) {
*present_shape.add_dim() = dim;
}
// shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
present_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
updateOutputShape(ctx, 1, present_shape);
updateOutputShape(ctx, 2, present_shape);
}
}
}
}
}
void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
// TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
constexpr int use_max_past_present_buffer = -1;
BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
}
void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
constexpr int use_max_past_present_buffer = 1;
BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
}
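
As a quick cross-check of the branch above, the following hedged Python sketch mirrors what the present key/value shapes come out to in the two modes (`use_max_past_present_buffer` of 1, as used by SparseAttention, versus 0); it illustrates the inference logic and is not code from this PR.

```python
# Hedged sketch: expected present_key/present_value shapes produced by the shape inference above.
from typing import Optional, Tuple


def present_kv_shape(
    past_shape: Tuple[int, int, int, int], kv_sequence_length: int, use_max_past_present_buffer: int
) -> Optional[Tuple[int, int, int, int]]:
    batch, kv_heads, past_len, head_size = past_shape  # (batch_size, kv_num_heads, buffer_length, head_size)
    if use_max_past_present_buffer == 1:
        # SparseAttention: past and present share a max_sequence_length buffer, so shapes are identical.
        return past_shape
    if use_max_past_present_buffer == 0:
        # Non-shared buffer: present grows by the new kv_sequence_length tokens.
        return (batch, kv_heads, past_len + kv_sequence_length, head_size)
    return None  # unknown (-1): leave present shapes unspecified, as GroupQueryAttention does here


print(present_kv_shape((2, 8, 8192, 128), kv_sequence_length=1, use_max_past_present_buffer=1))
```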
constexpr const char* Attention_ver1_doc = R"DOC(
Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
@ -432,7 +470,7 @@ An input as above will be packed into 3 tensors like below:
Input tensors contains the hidden embedding of real tokens.
Token_offset records the offset of token in the unpacked input.
cumulated_token_count records cumulated length of each sequnces length.
cumulated_token_count records cumulated length of each sequence length.
The operator only supports BERT like model with padding on right now.
@ -561,7 +599,7 @@ An input as above will be packed into 3 tensors like below:
The query, key and value tensors contain result of hidden embedding of real tokens after input projections.
Token_offset records the offset of token in the unpacked input.
cumulative_sequence_length records cumulated length of each sequnces length.
cumulative_sequence_length records cumulated length of each sequence length.
The operator only supports BERT like model with padding on right now.
)DOC";
@ -1103,6 +1141,103 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
GroupQueryAttentionTypeAndShapeInference(ctx, 3);
}));
constexpr const char* SparseAttention_ver1_doc = R"DOC(
Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
It is inspired by Sparse Transformers (https://arxiv.org/pdf/1904.10509) and BigBird (https://arxiv.org/pdf/2007.14062).
block_mask can be used to configure a sparse layout for different heads.
When the number of sparse layouts is 1, all heads share the same sparse layout. Otherwise, different layouts are used cyclically.
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
Padding shall be on the right side.
When do_rotary is True, cos_cache and sin_cache are required.
Only supports unidirectional attention with cache of past key and value in linear buffers.
For performance, past_key and present_key share the same memory buffer, as do past_value and present_value.
)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
SparseAttention, 1,
OpSchema()
.SetDoc(SparseAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads for query", AttributeProto::INT)
.Attr("kv_num_heads", "Number of attention heads for key and value", AttributeProto::INT)
.Attr("scale", "Scaling factor applied prior to softmax. The default value is 1/sqrt(head_size)", AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("sparse_block_size", "Number of tokens per sparse block. Choices: 16, 32, 64, 128", AttributeProto::INT)
.Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_interleaved", "Rotary use interleaved pattern or not. Default value is 0.", AttributeProto::INT,
OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, num_heads * head_size), or packed QKV with shape is"
"(batch_size, sequence_length, d) where d is (num_heads + 2 * kv_num_heads) * head_size.",
"T")
.Input(1,
"key",
"Key with shape (batch_size, sequence_length, kv_num_heads * head_size)",
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, sequence_length, kv_num_heads * head_size)",
"T",
OpSchema::Optional)
.Input(3,
"past_key",
"Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(4,
"past_value",
"Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(5,
"block_mask",
"block mask. 1 indicates attention and 0 no attention. "
"Its shape is (num_layout, max_blocks, max_blocks), "
"where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
"M")
.Input(6,
"total_sequence_length",
"Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.",
"M")
.Input(7,
"key_total_sequence_lengths",
"1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.",
"M")
.Input(8,
"cos_cache",
"Cos cache of rotary with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Input(9,
"sin_cache",
"Sin cache of rotary with shape (max_sequence_length, head_size / 2).",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)",
"T")
.Output(1,
"present_key",
"Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
"T")
.Output(2,
"present_value",
"Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
"T")
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
SparseAttentionTypeAndShapeInference(ctx, 3);
}));
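
For orientation, here is a minimal sketch of declaring a node against the schema above with the ONNX Python helper; the tensor names and attribute values are illustrative, and building a full runnable model (graph inputs, initializers) is omitted.

```python
# Minimal sketch: a SparseAttention node in the com.microsoft domain, matching the schema above.
from onnx import helper

node = helper.make_node(
    "SparseAttention",
    inputs=[
        "query",                       # (batch_size, sequence_length, num_heads * head_size)
        "key",                         # (batch_size, sequence_length, kv_num_heads * head_size)
        "value",                       # (batch_size, sequence_length, kv_num_heads * head_size)
        "past_key",                    # (batch_size, kv_num_heads, max_sequence_length, head_size)
        "past_value",                  # (batch_size, kv_num_heads, max_sequence_length, head_size)
        "block_mask",                  # (num_layout, max_blocks, max_blocks), int32
        "total_sequence_length",       # scalar int32
        "key_total_sequence_lengths",  # (batch_size,), int32
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,
    kv_num_heads=8,
    sparse_block_size=64,
)
print(node)
```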
constexpr const char* Longformer_Attention_doc = R"DOC(
Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token
attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens

View file

@ -104,6 +104,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding);
@ -216,6 +217,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul)>());

View file

@ -2320,6 +2320,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
#endif
#ifdef ENABLE_CUDA_NHWC_OPS
#ifndef DISABLE_CONTRIB_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaNhwcContribKernels(kernel_registry));
#endif
ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaNhwcKernels(kernel_registry));
#endif

View file

@ -11,210 +11,132 @@
#include "core/providers/cuda/cuda_nhwc_kernels.h"
// Macros to avoid long line length
#define CUDA_NHWC_OP_CLASS_NAME(ver, name) \
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, ver, name)
#define CUDA_NHWC_OP_TYPED_CLASS_NAME(ver, type, name) \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, ver, type, name)
#define CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, \
start_ver, end_ver, type, name)
#define CUDA_NHWC_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, start_ver, end_ver, name)
namespace onnxruntime::cuda {
// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout
// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
Conv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16,
Conv);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16,
ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float,
AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16,
AveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16,
GlobalAveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float,
MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16,
MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float,
MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16,
MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalMaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float,
AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16,
AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float,
MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16,
MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, Conv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16,
ConvTranspose);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float,
MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16,
MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
BatchNormalization);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, MLFloat16, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, float, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, double, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, MLFloat16, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, Conv);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, Conv);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, ConvTranspose);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, ConvTranspose);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, float, AveragePool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, MLFloat16, AveragePool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalAveragePool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalAveragePool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, float, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, MLFloat16, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, float, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, MLFloat16, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalMaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalMaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, AveragePool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, AveragePool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, Conv);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, Conv);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, ConvTranspose);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, ConvTranspose);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, AveragePool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, AveragePool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, float, BatchNormalization);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, double, BatchNormalization);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, MLFloat16, BatchNormalization);
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 10, DepthToSpace);
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(11, 12, DepthToSpace);
class CUDA_NHWC_OP_CLASS_NAME(13, DepthToSpace);
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 12, SpaceToDepth);
class CUDA_NHWC_OP_CLASS_NAME(13, SpaceToDepth);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, float, LRN);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, double, LRN);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, MLFloat16, LRN);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, float, LRN);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, double, LRN);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, MLFloat16, LRN);
Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
float, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,
kMSInternalNHWCDomain, 1, 10, float, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
float, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
MLFloat16, Conv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
float, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
MLFloat16, GlobalAveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
float, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
MLFloat16, GlobalMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
float, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
MLFloat16, AveragePool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
float, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
MLFloat16, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
int8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
uint8_t, MaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
MLFloat16, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
1, 10, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
11, 12, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
13, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
1, 12, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
13, SpaceToDepth)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, float, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, double, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, float, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, double, BatchNormalization)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, Conv)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, Conv)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, Conv)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, Conv)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, float, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalAveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalAveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, float, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, float, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalMaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalMaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, AveragePool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, ConvTranspose)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, ConvTranspose)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, ConvTranspose)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, ConvTranspose)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 10, DepthToSpace)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(11, 12, DepthToSpace)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_CLASS_NAME(13, DepthToSpace)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 12, SpaceToDepth)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_CLASS_NAME(13, SpaceToDepth)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, float, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, double, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, MLFloat16, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, float, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, double, LRN)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, MLFloat16, LRN)>,
};
for (auto& function_table_entry : nhwc_function_table) {
@ -226,4 +148,28 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
return Status::OK();
}
} // namespace onnxruntime::cuda
#ifndef DISABLE_CONTRIB_OPS
namespace onnxruntime::contrib::cuda {
class CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample);
onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample)>,
};
for (auto& function_table_entry : nhwc_function_table) {
KernelCreateInfo info = function_table_entry();
if (info.kernel_def != nullptr) { // filter disabled entries where type is void
ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
}
}
return Status::OK();
}
} // namespace onnxruntime::contrib::cuda
#endif
#endif

View file

@ -11,3 +11,11 @@ namespace onnxruntime::cuda {
onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry);
} // namespace onnxruntime::cuda
#ifndef DISABLE_CONTRIB_OPS
namespace onnxruntime::contrib::cuda {
onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry);
} // namespace onnxruntime::contrib::cuda
#endif

View file

@ -83,7 +83,7 @@ void TryToLoadKernel() {
// get all kernel symbols from current lib.so
size_t size = sizeof(kernel_infos) / sizeof(kernel_infos[0]);
for (int i = 0; i < size; ++i) {
for (size_t i = 0; i < size; ++i) {
auto k_i = kernel_infos[i];
void* buff;

View file

@ -206,6 +206,7 @@ class SymbolicShapeInference:
"GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"GroupQueryAttention": self._infer_GroupQueryAttention,
"SparseAttention": self._infer_SparseAttention,
"SkipGroupNorm": self._infer_SkipGroupNorm,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
@ -473,6 +474,7 @@ class SymbolicShapeInference:
"MultiHeadAttention",
"GroupNorm",
"GroupQueryAttention",
"SparseAttention",
"SkipGroupNorm",
"BiasSplitGelu",
"BiasAdd",
@ -2449,6 +2451,8 @@ class SymbolicShapeInference:
past_shape = self._try_get_shape(node, 3)
if past_shape is not None:
# When past and present have the maximum sequence length, we can propagate the shape from past to present.
# Note that GQA also supports different sequence lengths for past and present, but it is rarely used.
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
vi = self.known_vi_[node.output[2]]
@ -2470,6 +2474,9 @@ class SymbolicShapeInference:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
def _infer_SparseAttention(self, node): # noqa: N802
self._infer_GroupQueryAttention(node)
def _infer_SkipGroupNorm(self, node): # noqa: N802
self._propagate_shape_and_type(node, 0, 0)
if len(node.output) > 1:

View file

@ -1,7 +1,7 @@
import copy
import logging
from collections import OrderedDict
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy
import torch
@ -224,11 +224,44 @@ class CudaSession:
self.output_tensors = OrderedDict()
self.device = device
# Pairs of input and output names that share the same buffer.
self.buffer_sharing: Dict[str, str] = {}
def set_buffer_sharing(self, input_name: str, output_name: str):
assert input_name in self.input_names
assert output_name in self.output_names
self.buffer_sharing[input_name] = output_name
self.buffer_sharing[output_name] = input_name
def __del__(self):
del self.input_tensors
del self.output_tensors
del self.io_binding
def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
device_id = tensor.device.index if tensor.device.index is not None else 0
tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)
self.io_binding.bind_input(
name,
tensor.device.type,
device_id,
self.io_name_to_numpy_type[name],
tensor_shape,
tensor.data_ptr(),
)
if name in self.buffer_sharing:
self.io_binding.bind_output(
self.buffer_sharing[name],
tensor.device.type,
device_id,
self.io_name_to_numpy_type[name],
tensor_shape,
tensor.data_ptr(),
)
self.output_tensors[self.buffer_sharing[name]] = tensor
def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
"""Allocate tensors for I/O Binding"""
if self.enable_cuda_graph:
@ -245,15 +278,7 @@ class CudaSession:
device=self.device
)
self.input_tensors[name] = tensor
self.io_binding.bind_input(
name,
tensor.device.type,
tensor.device.index,
numpy_dtype,
list(tensor.size()),
tensor.data_ptr(),
)
self.bind_input_and_buffer_sharing(name, tensor)
for name, shape in shape_dict.items():
if name in self.output_names:
@ -261,6 +286,9 @@ class CudaSession:
if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
continue
if name in self.buffer_sharing:
continue
numpy_dtype = self.io_name_to_numpy_type[name]
tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to(
device=self.device
@ -270,7 +298,7 @@ class CudaSession:
self.io_binding.bind_output(
name,
tensor.device.type,
tensor.device.index,
tensor.device.index if tensor.device.index is not None else 0,
numpy_dtype,
list(tensor.size()),
tensor.data_ptr(),
@ -287,14 +315,7 @@ class CudaSession:
assert tensor.device.type == "cuda"
self.input_tensors[name].copy_(tensor)
else:
self.io_binding.bind_input(
name,
tensor.device.type,
tensor.device.index,
TypeHelper.torch_type_to_numpy_type(tensor.dtype),
[1] if len(tensor.shape) == 0 else list(tensor.shape),
tensor.data_ptr(),
)
self.bind_input_and_buffer_sharing(name, tensor)
# Synchronization is not needed in most cases unless different streams are used or inputs/outputs are on CPU.
if synchronize:
@ -330,8 +351,13 @@ class GpuBinding(CudaSession):
enable_gpu_graph: bool = False,
gpu_graph_id: int = -1,
stream: int = 0,
buffer_sharing: Optional[Dict[str, str]] = None,
):
super().__init__(ort_session, device, enable_gpu_graph)
if buffer_sharing:
for input_name, output_name in buffer_sharing.items():
self.set_buffer_sharing(input_name, output_name)
self.allocate_buffers(shape_dict)
self.gpu_graph_id = gpu_graph_id
# For cuda graph, we need to keep a copy of shape_dict to check if the shape is same in inference later.
@ -383,6 +409,7 @@ class GpuBindingManager:
self,
shape_dict: Dict[str, Union[Tuple[int], List[int]]],
use_cuda_graph: bool = False,
buffer_sharing: Optional[Dict[str, str]] = None,
) -> GpuBinding:
for gpu_graph_binding in self.graph_bindings:
# Found a cuda graph that captured with the same shape
@ -392,7 +419,9 @@ class GpuBindingManager:
# Reached the maximum number of cuda graphs. Return a binding without cuda graph.
if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
if self.no_graph_binding is None:
self.no_graph_binding = GpuBinding(self.ort_session, self.device, shape_dict, stream=self.stream)
self.no_graph_binding = GpuBinding(
self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
)
else:
self.no_graph_binding.allocate_buffers(shape_dict)
return self.no_graph_binding
@ -405,6 +434,7 @@ class GpuBindingManager:
enable_gpu_graph=True,
gpu_graph_id=len(self.graph_bindings),
stream=self.stream,
buffer_sharing=buffer_sharing,
)
self.graph_bindings.append(gpu_graph_binding)
return gpu_graph_binding
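
To show how the new `buffer_sharing` argument is meant to be used, here is a hedged sketch of binding a session so the past/present KV tensors reuse one device buffer; the session object, the I/O names, and the GpuBindingManager constructor arguments are assumptions that must match the exported model and the helper's actual signature.

```python
# Hedged usage sketch: share past/present KV buffers via the new buffer_sharing option.
from typing import Dict, Tuple

import torch
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import GpuBindingManager


def bind_with_kv_sharing(ort_session: InferenceSession, shape_dict: Dict[str, Tuple[int, ...]]):
    device = torch.device("cuda:0")
    # Assumption: GpuBindingManager(ort_session, device, stream=...) as defined in io_binding_helper.
    manager = GpuBindingManager(ort_session, device, stream=torch.cuda.current_stream().cuda_stream)
    return manager.get_binding(
        shape_dict,
        use_cuda_graph=False,
        buffer_sharing={"past_key": "present_key", "past_value": "present_value"},
    )
```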

View file

@ -0,0 +1,680 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Parity test and performance benchmark of SparseAttention. Requires an Nvidia GPU with compute capability 8.x.
"""
import math
from typing import Optional
import torch
from onnx import TensorProto, helper
from torch import Tensor
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
ENABLE_DEBUG = False
class Config:
batch_size = 0
sequence_length = 0
max_sequence_length = 0
past_sequence_length = 0
num_heads = 0
kv_num_heads = 0
head_size = 0
sparse_block_size = 0
num_layout = 0
local_blocks = 0
vert_stride = 0
softmax_scale = None
share_buffer = True
do_rotary = False
rotary_interleaved = False
# TODO: test packed qkv
is_packed_qkv = False
# TODO: test bfloat16.
is_fp16 = True # True for float16; False for bfloat16.
use_sparse = True  # True for SparseAttention; False for GroupQueryAttention
def __init__(
self,
batch_size: int,
sequence_length: int,
max_sequence_length: int,
past_sequence_length: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
sparse_block_size: int,
num_layout: int,
local_blocks: int,
vert_stride: int,
softmax_scale=None,
do_rotary: bool = False,
rotary_interleaved: bool = False,
device="cuda",
operator="SparseAttention",
):
self.batch_size = batch_size
self.sequence_length = sequence_length
self.max_sequence_length = max_sequence_length
self.past_sequence_length = past_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
self.sparse_block_size = sparse_block_size
self.num_layout = num_layout
self.local_blocks = local_blocks
self.vert_stride = vert_stride
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
# Derived values
self.total_sequence_length = sequence_length + past_sequence_length
self.past_buffer_length = max_sequence_length if self.share_buffer else past_sequence_length
self.present_buffer_length = (
max_sequence_length if self.share_buffer else (past_sequence_length + sequence_length)
)
self.max_blocks = max_sequence_length // sparse_block_size
self.do_rotary = do_rotary
self.rotary_interleaved = rotary_interleaved
self.device = device
self.operator = operator
def block_mask(self):
return get_block_mask(self.num_layout, self.max_blocks, self.local_blocks, self.vert_stride).to(self.device)
def dense_mask(self):
expand_block_mask = self.block_mask()
dense_mask = get_dense_mask(
expand_block_mask, self.total_sequence_length, self.sequence_length, self.sparse_block_size
)
return dense_mask.repeat(self.batch_size, self.num_heads // self.num_layout, 1, 1).to(self.device)
def shape_dict(self):
shape_dict = {
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"block_mask": (self.num_layout, self.max_blocks, self.max_blocks),
"total_sequence_length": (1,),
"key_total_sequence_lengths": (self.batch_size,),
"seqlens_k": (self.batch_size,),
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
}
if self.operator == "SparseAttention":
del shape_dict["seqlens_k"]
else:
assert self.operator == "GroupQueryAttention"
del shape_dict["key_total_sequence_lengths"]
del shape_dict["block_mask"]
return shape_dict
def get_cos_sin_cache(self, dtype=torch.float32):
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos.to(device=self.device), sin.to(device=self.device)
def random_inputs(self, dtype=torch.float16):
device = self.device
shape_dict = self.shape_dict()
k_seqlens = torch.ones((self.batch_size,), device=device, dtype=torch.int32) * self.total_sequence_length
torch.manual_seed(123)
feeds = {
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"block_mask": self.block_mask(),
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
"key_total_sequence_lengths": k_seqlens,
"seqlens_k": k_seqlens - 1,
}
if self.do_rotary:
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
feeds["cos_cache"] = cos_cache
feeds["sin_cache"] = sin_cache
if "seqlens_k" not in shape_dict:
del feeds["seqlens_k"]
else:
del feeds["key_total_sequence_lengths"]
del feeds["block_mask"]
return feeds
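# Note the different length inputs of the two operators: SparseAttention consumes key_total_sequence_lengths (past plus new tokens per batch entry), while GroupQueryAttention consumes seqlens_k, which is that value minus one.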
def get_block_mask(num_layout, max_blocks, local_blocks, vert_stride):
q_pos = torch.arange(max_blocks)[None, :, None]
k_pos = torch.arange(max_blocks)[None, None]
head_sliding_step = max(1, int(vert_stride / num_layout))
mask_vert_strided = [
(torch.arange(max_blocks) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(num_layout)
]
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
block_mask = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)
block_mask = block_mask.to(torch.int32)
if ENABLE_DEBUG:
torch.set_printoptions(profile="full")
torch.set_printoptions(edgeitems=100)
torch.set_printoptions(linewidth=200)
print(block_mask)
return block_mask
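# Worked example with illustrative (not test) parameters: num_layout=1, max_blocks=4,
# local_blocks=2, vert_stride=2 produces the causal block mask
#   1 0 0 0
#   1 1 0 0
#   0 1 1 0
#   0 1 1 1
# i.e. each query block sees the most recent `local_blocks` blocks plus every
# `vert_stride`-th earlier block, with the vertical stripes shifted per layout head.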
def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size):
dense_mask = torch.kron(block_mask, block_mask.new_ones((block_size, block_size)))[
:, :total_seq_len, :total_seq_len
]
causal_mask = torch.tril(torch.ones(total_seq_len, total_seq_len)).type_as(dense_mask)
dense_mask = dense_mask * causal_mask[None]
return dense_mask[..., -query_seq_len:, :total_seq_len]
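# torch.kron expands every block-mask entry into a block_size x block_size tile of token-level entries; the tril factor restores causality inside the diagonal blocks, and only the rows of the current query window are returned.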
def create_sparse_attention_onnx_model(config):
assert config.is_fp16 # bfloat16 is not supported by the Python I/O binding helper.
float_type = TensorProto.FLOAT16
nodes = [
helper.make_node(
"SparseAttention",
[
"query",
"key" if not config.is_packed_qkv else "",
"value" if not config.is_packed_qkv else "",
"past_key",
"past_value",
"block_mask",
"total_sequence_length" if config.share_buffer else "",
"key_total_sequence_lengths",
"cos_cache" if config.do_rotary else "",
"sin_cache" if config.do_rotary else "",
],
["output", "present_key", "present_value"],
"SparseAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
scale=config.softmax_scale,
sparse_block_size=config.sparse_block_size,
do_rotary=1 if config.do_rotary else 0,
domain="com.microsoft",
),
]
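# Optional inputs that are not used are passed as empty names ("") so the remaining inputs keep their positions.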
shape_dict = config.shape_dict()
graph_input = [
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
helper.make_tensor_value_info("block_mask", TensorProto.INT32, list(shape_dict["block_mask"])),
helper.make_tensor_value_info(
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
),
helper.make_tensor_value_info(
"key_total_sequence_lengths", TensorProto.INT32, list(shape_dict["key_total_sequence_lengths"])
),
]
if config.do_rotary:
graph_input += [
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
]
graph_output = [
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
]
graph = helper.make_graph(
nodes,
"SparseAttention_Graph",
graph_input,
graph_output,
)
model = helper.make_model(graph)
return model.SerializeToString()
def create_group_query_attention_onnx_model(config):
assert config.is_fp16 # bfloat16 is not supported by the Python I/O binding helper.
float_type = TensorProto.FLOAT16
nodes = [
helper.make_node(
"GroupQueryAttention",
[
"query",
"key" if not config.is_packed_qkv else "",
"value" if not config.is_packed_qkv else "",
"past_key",
"past_value",
"seqlens_k",
"total_sequence_length" if config.share_buffer else "",
"cos_cache" if config.do_rotary else "",
"sin_cache" if config.do_rotary else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
local_window_size=config.local_blocks * config.sparse_block_size if config.use_sparse else -1,
do_rotary=1 if config.do_rotary else 0,
rotary_interleaved=config.rotary_interleaved,
domain="com.microsoft",
),
]
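# When config.use_sparse is True, GQA runs with a sliding window of local_blocks * sparse_block_size tokens (the ORT-GQA-Local baseline in the benchmarks); local_window_size=-1 means dense causal attention.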
shape_dict = config.shape_dict()
graph_input = [
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
helper.make_tensor_value_info(
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
),
]
if config.do_rotary:
graph_input += [
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
]
graph_output = [
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
]
graph = helper.make_graph(
nodes,
"GroupQueryAttention_Graph",
graph_input,
graph_output,
)
model = helper.make_model(graph)
return model.SerializeToString()
def create_session(onnx_model_str, config: Config, cuda_provider_options=None) -> InferenceSession:
session_options = SessionOptions()
ort_session = InferenceSession(
onnx_model_str,
session_options,
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
)
return ort_session
def group_query_attention_reference(
query: Tensor,
key: Tensor,
value: Tensor,
config: Config,
scale: Optional[float] = None,
mask: Optional[Tensor] = None,
):
if scale is None:
scale = 1.0 / (config.head_size**0.5)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Expand key and value to have the same number of heads as the query.
num_key_value_groups = config.num_heads // config.kv_num_heads
key = torch.repeat_interleave(key, dim=1, repeats=num_key_value_groups)
value = torch.repeat_interleave(value, dim=1, repeats=num_key_value_groups)
# Apply multi-head attention.
attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale
if mask is not None:
attn = attn.masked_fill((1 - mask).bool(), float("-inf"))
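# mask entries are 1 for allowed and 0 for masked positions; masked scores become -inf and receive zero weight after softmax.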
attn = attn.softmax(-1)
attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)
return attn_output.transpose(1, 2).contiguous()
class TorchGroupQueryAttention:
"""A wrapper of Torch GroupQueryAttention to test relevance and performance."""
def __init__(self, device, config: Config, feed_dict):
self.device = device
self.config = config
self.query = feed_dict["query"].view(
config.batch_size, config.sequence_length, config.num_heads, config.head_size
)
self.key = feed_dict["key"].view(
config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size
)
self.value = feed_dict["value"].view(
config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size
)
self.dense_mask = config.dense_mask()
if ENABLE_DEBUG:
torch.set_printoptions(precision=6, edgeitems=3, linewidth=1000, profile="full", sci_mode=False)
print("query(BNSH)", self.query.clone().transpose(1, 2))
print("key(BNSH)", self.key.clone().transpose(1, 2))
print("value(BNSH)", self.value.clone().transpose(1, 2))
print("dense_mask", self.dense_mask)
def infer(self):
return group_query_attention_reference(
self.query, self.key, self.value, self.config, scale=self.config.softmax_scale, mask=self.dense_mask
)
class OrtGroupQueryAttention:
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
def __init__(self, device, config: Config, feed_dict):
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
onnx_model_str = create_group_query_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, config, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.feed_dict = feed_dict
def infer(self):
return self.gpu_binding.infer(self.feed_dict)
class OrtSparseAttention:
"""A wrapper of ORT SparseAttention to test relevance and performance."""
def __init__(self, device, config: Config, feed_dict):
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
onnx_model_str = create_sparse_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, config, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.feed_dict = feed_dict
def infer(self):
return self.gpu_binding.infer(self.feed_dict)
def run_one_relevance_test(device, config: Config):
dtype = torch.float16
# Run GroupQueryAttention as the reference (torch when there is no past kv cache, otherwise ORT).
config.use_sparse = False
config.operator = "GroupQueryAttention"
feed_dict = config.random_inputs(dtype=dtype)
if config.past_sequence_length == 0:
obj = TorchGroupQueryAttention(device, config, feed_dict)
expected_out = obj.infer()
else:
obj = OrtGroupQueryAttention(device, config, feed_dict)
ort_qga_outputs = obj.infer()
expected_out = ort_qga_outputs["output"].view(
config.batch_size, config.sequence_length, config.num_heads, config.head_size
)
# Run SparseAttention by ORT
config.use_sparse = True
config.operator = "SparseAttention"
if config.past_sequence_length != 0:
config.local_blocks = config.max_blocks # Use dense to compare with GQA
feed_dict = config.random_inputs(dtype=dtype)
if ENABLE_DEBUG:
print("block_mask", feed_dict["block_mask"])
print("total_sequence_length", feed_dict["total_sequence_length"])
print("key_total_sequence_lengths", feed_dict["key_total_sequence_lengths"])
obj = OrtSparseAttention(device, config, feed_dict)
ort_outputs = obj.infer()
ort_output = ort_outputs["output"]
actual_out = ort_output.view(config.batch_size, config.sequence_length, config.num_heads, config.head_size)
if torch.allclose(expected_out, actual_out, atol=1e-2, rtol=0):
print(f"Relevance test passed: {vars(config)}")
else:
print(f"Relevance test not passed: {vars(config)}")
print("ort_output", actual_out)
print("expected_out", expected_out)
print("diff", expected_out - actual_out)
exit(1)
def run_relevance_no_past(device):
"""Test prompt prefilling without past kv cache."""
for seq_len in [1, 64, 127, 128, 192, 256]:
config = Config(
batch_size=1,
sequence_length=seq_len,
max_sequence_length=256,
past_sequence_length=0,
num_heads=8,
kv_num_heads=4,
head_size=128,
sparse_block_size=64,
num_layout=2,
local_blocks=2,
vert_stride=2,
softmax_scale=1.8 / (128**0.5),
)
run_one_relevance_test(device, config)
def run_relevance_past(device):
"""Test token generation with past kv cache."""
for past_seq_len in [1, 63, 64, 127, 128, 511]:
config = Config(
batch_size=2,
sequence_length=1,
max_sequence_length=512,
past_sequence_length=past_seq_len,
num_heads=8,
kv_num_heads=4,
head_size=128,
sparse_block_size=64,
num_layout=4,
local_blocks=2,
vert_stride=4,
do_rotary=True,
rotary_interleaved=(past_seq_len % 2 == 1),
)
run_one_relevance_test(device, config)
def run_relevance_test():
device_id = torch.cuda.current_device()
device = torch.device("cuda", device_id)
with torch.no_grad():
run_relevance_no_past(device)
run_relevance_past(device)
def plot_prompt_performance(
batch_size=4,
num_heads=32,
max_seq_len=8192,
head_size=128,
sparse_block_size=64,
local_blocks=16,
vert_stride=8,
num_layout=8,
dtype=torch.float16,
):
import triton
configs = [
triton.testing.Benchmark(
x_names=["sequence_length"],
x_vals=[2**i for i in range(4, 14)],
line_arg="provider",
line_vals=["torch_gqa", "ort_gqa", "ort_gqa_local", "ort_sparse_att"],
line_names=["TORCH-GQA", "ORT-GQA-Dense", "ORT-GQA-Local", "ORT-SparseAtt"],
styles=[("red", "-"), ("yellow", "-"), ("blue", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"prompt-batch{batch_size}-head{num_heads}-d{head_size}-local{local_blocks}-vert{vert_stride}-{dtype}",
args={"num_heads": num_heads, "batch_size": batch_size, "head_size": head_size, "dtype": dtype},
)
]
@triton.testing.perf_report(configs)
def benchmark(batch_size, num_heads, sequence_length, head_size, provider, dtype=torch.float16, device="cuda"):
warmup = 15
repeat = 100
config = Config(
batch_size=batch_size,
sequence_length=sequence_length,
max_sequence_length=max_seq_len,
past_sequence_length=0,
num_heads=num_heads,
kv_num_heads=8,
head_size=head_size,
sparse_block_size=sparse_block_size,
num_layout=num_layout,
local_blocks=local_blocks,
vert_stride=vert_stride,
)
config.use_sparse = provider in ["ort_sparse_att", "ort_gqa_local"]
config.operator = "SparseAttention" if provider in ["ort_sparse_att"] else "GroupQueryAttention"
feed_dict = config.random_inputs(dtype=dtype)
if provider in ["ort_gqa", "ort_gqa_local"]:
obj = OrtGroupQueryAttention(device, config, feed_dict)
elif provider == "ort_sparse_att":
obj = OrtSparseAttention(device, config, feed_dict)
else:
assert provider == "torch_gqa"
if sequence_length > 2048: # the torch reference runs out of memory beyond this length
return 0
obj = TorchGroupQueryAttention(device, config, feed_dict)
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
return ms
benchmark.run(save_path=".", print_data=True)
def plot_token_performance(
batch_size=4,
num_heads=32,
max_seq_len=8192,
head_size=128,
sparse_block_size=64,
local_blocks=16,
vert_stride=8,
num_layout=8,
dtype=torch.float16,
):
import triton
configs = [
triton.testing.Benchmark(
x_names=["past_sequence_length"],
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
line_arg="provider",
line_vals=["torch_gqa", "ort_gqa", "ort_gqa_local", "ort_sparse_att"],
line_names=["TORCH-GQA", "ORT-GQA-Dense", "ORT-GQA-Local", "ORT-SparseAtt"],
styles=[("red", "-"), ("yellow", "-"), ("blue", "-"), ("green", "-")],
ylabel="ms",
plot_name=f"token-batch{batch_size}-head{num_heads}-d{head_size}-local{local_blocks}-vert{vert_stride}-{dtype}",
args={"num_heads": num_heads, "batch_size": batch_size, "head_size": head_size, "dtype": dtype},
)
]
@triton.testing.perf_report(configs)
def benchmark(batch_size, num_heads, past_sequence_length, head_size, provider, dtype=torch.float16, device="cuda"):
warmup = 15
repeat = 100
config = Config(
batch_size=batch_size,
sequence_length=1,
max_sequence_length=max_seq_len,
past_sequence_length=past_sequence_length,
num_heads=num_heads,
kv_num_heads=8,
head_size=head_size,
sparse_block_size=sparse_block_size,
num_layout=num_layout,
local_blocks=local_blocks,
vert_stride=vert_stride,
)
config.use_sparse = provider in ["ort_sparse_att", "ort_gqa_local"]
config.operator = "SparseAttention" if provider in ["ort_sparse_att"] else "GroupQueryAttention"
feed_dict = config.random_inputs(dtype=dtype)
if provider in ["ort_gqa", "ort_gqa_local"]:
obj = OrtGroupQueryAttention(device, config, feed_dict)
elif provider == "ort_sparse_att":
obj = OrtSparseAttention(device, config, feed_dict)
else:
assert provider == "torch_gqa"
if past_sequence_length > 2048: # the torch reference runs out of memory beyond this length
return 0
obj = TorchGroupQueryAttention(device, config, feed_dict)
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
return ms
benchmark.run(save_path=".", print_data=True)
if __name__ == "__main__":
s = torch.cuda.Stream()
with torch.cuda.stream(s):
run_relevance_test()
plot_prompt_performance()
plot_token_performance()
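# Note: the plot_* benchmarks require the triton package (imported lazily inside them); the relevance tests need just torch, onnx and a CUDA build of onnxruntime.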


@ -100,6 +100,8 @@ unfixable = [
# Eventually this list should become empty.
"orttraining/orttraining/test/**" = ["N802"] # Function casing
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_triton.py" = ["N806"] # use of Q, K and V in triton script
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_triton.py" = ["N806"] # use of Q, K and V in triton script
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.