[CUDA] Add SparseAttention operator for Phi-3-small (#20216)
### Description Add CUDA implementation for block sparse attention for Phi-3-small. Block sparse attention was proposed in [Sparse Transformers](https://arxiv.org/pdf/1904.10509) by OpenAI, and also adopted in [BigBird](https://arxiv.org/pdf/2007.14062) with different sparse layout. In Phi-3-small, the sparse layout is static, and works with unidirectional (causal) attention. Compared to dense attention, the benefit of block sparse is to speed up both training and inference. It could save memory thus support longer context length. - [x] Add operator spec and shape inference - [x] Symbolic shape inference - [x] Refactor GroupQueryAttention to expose common kernels for kv cache concatenation, q/k/v transpose etc. - [x] Add cuda kernel to convert block mask to CSR format - [x] Add cuda kernel to generate position ids - [x] Add compile script and template files to convert triton kernel to cubin and dispatcher. - [x] Add triton kernel v1 for prompt - [x] Add triton kernel v2 for token generation and support padding - [x] Update IO Binding Helper to allow buffer sharing. - [x] Test relevance - [x] Test performance ### Performance Test in A100-SXM4-80GB with `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8` We compare sparse attention to corresponding GQA with local attention windows size 1024, or GQA with dense causal. Average latency in milliseconds (for fused attention kernel used in prompt prefilling): seq_len | GQA-Dense | GQA-Local | SparseAttention -- | -- | -- | -- 64 | 0.0465 | 0.0722 | 0.0641 128 | 0.0618 | 0.0787 | 0.0672 256 | 0.1086 | 0.1076 | 0.0943 512 | 0.2535 | 0.2487 | 0.1676 1024 | 0.7042 | 0.7050 | 0.3800 2048 | 2.4125 | 1.9316 | 0.8966 4096 | 8.9346 | 4.5699 | 2.1129 8192 | 40.5401 | 10.3508 | 5.1748 Average latency in milliseconds (for fused attention kernel used in token generation: past_seq_len | GQA-Dense | GQA-Local | SparseAttention -- | -- | -- | -- 64 | 0.0186 | 0.0186 | 0.0870 128 | 0.0408 | 0.0466 | 0.1165 256 | 0.0530 | 0.0592 | 0.0988 512 | 0.0445| 0.0447 | 0.1150 1024 | 0.0634 | 0.0640 | 0.1454 2048 | 0.1027 | 0.0637 | 0.1589 4096 | 0.1789 | 0.0631 | 0.1806 8192 | 0.3288 | 0.0655 | 0.2146 We can see that the kernel for token generation still have room to improve. #### Limitations Only support right-side padding and unidirectional attention. The following are not supported in the first version: (1) Packed mode like PackedMultiHeadAttention where input has been removed padding. (2) paged attention. (3) bidirectional attention. (4) GPU compute capacity that is not 8.0, 8.6 and 8.9. (5) Left side padding. Some of these limitations will be removed in the future (may be in a new operator).
This commit is contained in:
Родитель
b2481e3602
Коммит
9f0fae29e8
|
@ -4,10 +4,12 @@
|
|||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
# set all triton kernel ops that need to be compiled
|
||||
set(triton_kernel_scripts
|
||||
"onnxruntime/core/providers/rocm/math/softmax_triton.py"
|
||||
"onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
|
||||
)
|
||||
if(onnxruntime_USE_ROCM)
|
||||
set(triton_kernel_scripts
|
||||
"onnxruntime/core/providers/rocm/math/softmax_triton.py"
|
||||
"onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py"
|
||||
)
|
||||
endif()
|
||||
|
||||
function(compile_triton_kernel out_triton_kernel_obj_file out_triton_kernel_header_dir)
|
||||
# compile triton kernel, generate .a and .h files
|
||||
|
|
|
@ -46,6 +46,7 @@ set(contrib_ops_excluded_files
|
|||
"math/gemm_float8.cu"
|
||||
"math/gemm_float8.h"
|
||||
"moe/*"
|
||||
"sparse/*"
|
||||
"quantization/attention_quantization.cc"
|
||||
"quantization/attention_quantization.h"
|
||||
"quantization/attention_quantization_impl.cu"
|
||||
|
|
|
@ -102,6 +102,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
|
||||
* <a href="#com.microsoft.SkipSimplifiedLayerNormalization">com.microsoft.SkipSimplifiedLayerNormalization</a>
|
||||
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
|
||||
* <a href="#com.microsoft.SparseAttention">com.microsoft.SparseAttention</a>
|
||||
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
|
||||
* <a href="#com.microsoft.Tokenizer">com.microsoft.Tokenizer</a>
|
||||
* <a href="#com.microsoft.TorchEmbedding">com.microsoft.TorchEmbedding</a>
|
||||
|
@ -3418,7 +3419,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
Input tensors contains the hidden embedding of real tokens.
|
||||
Token_offset records the offset of token in the unpacked input.
|
||||
cumulated_token_count records cumulated length of each sequnces length.
|
||||
cumulated_token_count records cumulated length of each sequence length.
|
||||
|
||||
The operator only supports BERT like model with padding on right now.
|
||||
|
||||
|
@ -3492,7 +3493,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
The query, key and value tensors contain result of hidden embedding of real tokens after input projections.
|
||||
Token_offset records the offset of token in the unpacked input.
|
||||
cumulative_sequence_length records cumulated length of each sequnces length.
|
||||
cumulative_sequence_length records cumulated length of each sequence length.
|
||||
|
||||
The operator only supports BERT like model with padding on right now.
|
||||
|
||||
|
@ -5541,6 +5542,90 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.SparseAttention"></a><a name="com.microsoft.sparseattention">**com.microsoft.SparseAttention**</a>
|
||||
|
||||
Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
|
||||
|
||||
It is inspired by Sparse Transformers (https://arxiv.org/pdf/1904.10509) and BigBird (https://arxiv.org/pdf/2007.14062).
|
||||
|
||||
block_mask can be used to configure sparse layout for different head.
|
||||
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
|
||||
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
|
||||
|
||||
Padding shall be on the right side.
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required.
|
||||
|
||||
Only supports unidirectional attention with cache of past key and value in linear buffers.
|
||||
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>do_rotary</tt> : int</dt>
|
||||
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
|
||||
<dt><tt>kv_num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads for key and value</dd>
|
||||
<dt><tt>num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads for query</dd>
|
||||
<dt><tt>rotary_interleaved</tt> : int</dt>
|
||||
<dd>Rotary use interleaved pattern or not. Default value is 0.</dd>
|
||||
<dt><tt>scale</tt> : float</dt>
|
||||
<dd>Scaling factor applied prior to softmax. The default value is 1/sqrt(head_size)</dd>
|
||||
<dt><tt>sparse_block_size</tt> : int (required)</dt>
|
||||
<dd>Number of tokens per sparse block. Choices: 16, 32, 64, 128</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (8 - 10)
|
||||
|
||||
<dl>
|
||||
<dt><tt>query</tt> : T</dt>
|
||||
<dd>Query with shape (batch_size, sequence_length, num_heads * head_size), or packed QKV with shape is(batch_size, sequence_length, d) where d is (num_heads + 2 * kv_num_heads) * head_size.</dd>
|
||||
<dt><tt>key</tt> (optional) : T</dt>
|
||||
<dd>Key with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
|
||||
<dt><tt>value</tt> (optional) : T</dt>
|
||||
<dd>Value with shape (batch_size, sequence_length, kv_num_heads * head_size)</dd>
|
||||
<dt><tt>past_key</tt> (optional) : T</dt>
|
||||
<dd>Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>past_value</tt> (optional) : T</dt>
|
||||
<dd>Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>block_mask</tt> : M</dt>
|
||||
<dd>block mask. 1 indicates attention and 0 no attention. Its shape is (num_layout, max_blocks, max_blocks), where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.</dd>
|
||||
<dt><tt>total_sequence_length</tt> : M</dt>
|
||||
<dd>Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.</dd>
|
||||
<dt><tt>key_total_sequence_lengths</tt> : M</dt>
|
||||
<dd>1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.</dd>
|
||||
<dt><tt>cos_cache</tt> (optional) : T</dt>
|
||||
<dd>Cos cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
|
||||
<dt><tt>sin_cache</tt> (optional) : T</dt>
|
||||
<dd>Sin cache of rotary with shape (max_sequence_length, head_size / 2).</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>output</tt> : T</dt>
|
||||
<dd>3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)</dd>
|
||||
<dt><tt>present_key</tt> : T</dt>
|
||||
<dd>Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
|
||||
<dt><tt>present_value</tt> : T</dt>
|
||||
<dd>Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
|
||||
<dd>Constrain input and output to float tensors.</dd>
|
||||
<dt><tt>M</tt> : tensor(int32)</dt>
|
||||
<dd>Constrain integer type.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.SparseToDenseMatMul"></a><a name="com.microsoft.sparsetodensematmul">**com.microsoft.SparseToDenseMatMul**</a>
|
||||
|
||||
#### Version
|
||||
|
|
|
@ -906,6 +906,7 @@ Do not modify directly.*
|
|||
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SparseAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* block_mask:**M**<br> *in* total_sequence_length:**M**<br> *in* key_total_sequence_lengths:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|
||||
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
@ -113,6 +113,35 @@ struct GroupQueryAttentionParameters {
|
|||
int* zero_ptr;
|
||||
};
|
||||
|
||||
// Parameters for sparse attention.
|
||||
struct SparseAttentionParameters {
|
||||
int batch_size; // batch size
|
||||
int sequence_length; // sequence length of input query, key, value
|
||||
int hidden_size; // hidden size of query
|
||||
int num_heads; // number of heads of query
|
||||
int head_size; // hidden size per head of query, key or value
|
||||
int kv_hidden_size; // hidden size of key or value
|
||||
int kv_num_heads; // number of heads of key or value
|
||||
bool do_rotary; // whether to use rotary embedding
|
||||
bool rotary_interleaved; // whether to use interleaved rotary embedding
|
||||
int rotary_dim; // rotary embedding dimension
|
||||
int sparse_block_size; // block size for sparse attention
|
||||
int num_sparse_layout; // number of sparse layout, or the first dimension of block_mask
|
||||
float scale; // scaling factor applied prior to softmax
|
||||
bool is_packed_qkv; // whether qkv is packed
|
||||
int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys
|
||||
int max_sequence_length; // max sequence length allowed
|
||||
bool past_present_share_buffer; // whether past_key and present_key share buffer, so is past_value and present_value
|
||||
};
|
||||
|
||||
constexpr bool LAYOUT_BSNH = false;
|
||||
constexpr bool LAYOUT_BNSH = true;
|
||||
|
||||
namespace sparse_attention {
|
||||
// Environment variable to enable or disable sparse attention v1 kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_V1";
|
||||
} // namespace sparse_attention
|
||||
|
||||
namespace attention {
|
||||
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
|
||||
|
|
|
@ -5,30 +5,12 @@
|
|||
#include <string>
|
||||
#include "core/framework/ort_value.h"
|
||||
|
||||
// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc)
|
||||
#ifdef DEBUG_GENERATION
|
||||
#define DUMP_TENSOR_LEVEL 2
|
||||
#else
|
||||
#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation.
|
||||
#endif
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
#define DUMP_TENSOR_INIT() transformers::CudaTensorConsoleDumper dumper
|
||||
#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
|
||||
#else
|
||||
#define DUMP_TENSOR_INIT()
|
||||
#define DUMP_TENSOR(...)
|
||||
#endif
|
||||
#if DUMP_TENSOR_LEVEL > 1
|
||||
#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
|
||||
#else
|
||||
#define DUMP_TENSOR_D(...)
|
||||
#endif
|
||||
|
||||
class IConsoleDumper {
|
||||
public:
|
||||
IConsoleDumper() : is_enabled_(true) {}
|
||||
|
|
|
@ -129,6 +129,9 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
|
|||
const int max_threads_per_block, const bool reversed_bs, const half* input, half* output,
|
||||
int total_matrix_count = -1);
|
||||
|
||||
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);
|
||||
|
||||
Status LaunchConcatTensorToTensor(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
|
|
|
@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
|
|||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) {
|
||||
AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
|
@ -257,9 +257,9 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
|
|||
batch_size, sequence_length, num_heads, qk_head_size,
|
||||
data.query, data.bias, qkv,
|
||||
true, v_head_size, qkv_add_bias, 3);
|
||||
DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
|
||||
DUMP_TENSOR_D("q(BSNH)", data.q, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
} else {
|
||||
if (!use_fused_kernel) {
|
||||
|
@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
|
|||
AttentionData<T>& data,
|
||||
cudaStream_t stream,
|
||||
int max_threads_per_block,
|
||||
T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) {
|
||||
AttentionQkvFormat& qkv_format) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.kv_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
|
@ -301,10 +301,10 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
|
|||
const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
|
||||
LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
|
||||
batch_size, kv_sequence_length, num_heads, qk_head_size,
|
||||
data.key, kv_bias, k,
|
||||
data.key, kv_bias, data.k,
|
||||
true, v_head_size, qkv_add_bias, 2);
|
||||
DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
DUMP_TENSOR_D("k(BSNH)", data.k, batch_size, kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_TENSOR_D("v(BSNH)", data.v, batch_size, kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
} else {
|
||||
if (data.fused_cross_attention_kernel == nullptr) {
|
||||
|
@ -461,11 +461,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block,
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
} else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block,
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, data.qkv_format));
|
||||
} else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block,
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, data.qkv_format));
|
||||
} else { // multihead attention operator, no past, separated Q/K/V inputs
|
||||
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block,
|
||||
data.q, data.k, data.v, data.qkv_format));
|
||||
|
|
|
@ -298,6 +298,12 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
|
|||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const half* input, half* output, cudaStream_t stream, const int max_threads_per_block) {
|
||||
return LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, false, input, output);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -47,6 +47,9 @@ limitations under the License.
|
|||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
// Macro to help compute index of flatten 4D matrix, note that dim1 is not used so it is excluded.
|
||||
#define INDEX_4D(dim2, dim3, dim4, i, j, k, l) ((i) * (dim2) * (dim3) * (dim4) + (j) * (dim3) * (dim4) + (k) * (dim4) + (l))
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
@ -216,123 +219,162 @@ template <typename T>
|
|||
__global__ void ConcatKVInPlace(const int max_seqlen,
|
||||
T* kv_buff,
|
||||
const T* new_kv,
|
||||
const int* seqlens_k,
|
||||
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
|
||||
const int* past_seqlens_k,
|
||||
const int* total_seqlens_k,
|
||||
const bool is_past_kv_bnsh_format,
|
||||
const bool is_new_kv_bnsh_format) {
|
||||
const int h = threadIdx.x;
|
||||
const int n = threadIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
|
||||
const int new_seqlen = gridDim.x;
|
||||
const int num_heads = blockDim.y;
|
||||
const int kv_num_heads = blockDim.y;
|
||||
const int H = blockDim.x;
|
||||
|
||||
const int present_batch_stride = max_seqlen * num_heads * H;
|
||||
const int present_row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : max_seqlen * H;
|
||||
const int past_seq_len = (total_seqlens_k != nullptr)
|
||||
? (total_seqlens_k[b] - new_seqlen)
|
||||
: (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]);
|
||||
|
||||
// kv_buff: BTNH or BNTH with buffered memory for new
|
||||
// new_kv: BLNH
|
||||
int out_offset = is_past_kv_bnsh_format
|
||||
? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h)
|
||||
: INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h);
|
||||
|
||||
const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b];
|
||||
int in_offset = is_new_kv_bnsh_format
|
||||
? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h)
|
||||
: INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h);
|
||||
|
||||
int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h;
|
||||
// Note: new KV always BSNH
|
||||
const int new_batch_stride = new_seqlen * num_heads * H;
|
||||
const int new_row_stride = num_heads * H;
|
||||
const int new_head_stride = H;
|
||||
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
|
||||
kv_buff[out_offset] = new_kv[in_offset];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ConcatKVInPlaceLarge(const int max_seqlen,
|
||||
const int H,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
T* kv_buff,
|
||||
const T* new_kv,
|
||||
const int* seqlens_k,
|
||||
const bool is_bsnh) { // refers to kv buff; otherwise bnsh
|
||||
const int* past_seqlens_k,
|
||||
const int* total_seqlens_k,
|
||||
const bool is_past_kv_bnsh_format,
|
||||
const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh
|
||||
int i = threadIdx.x + (blockDim.x * blockIdx.x);
|
||||
if (i < H * num_heads) {
|
||||
if (i < H * kv_num_heads) {
|
||||
const int h = i % H;
|
||||
const int n = i / H;
|
||||
const int s = blockIdx.y;
|
||||
const int b = blockIdx.z;
|
||||
const int new_seqlen = gridDim.y;
|
||||
const int past_seq_len = (total_seqlens_k != nullptr)
|
||||
? (total_seqlens_k[b] - new_seqlen)
|
||||
: (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]);
|
||||
|
||||
const int present_batch_stride = max_seqlen * num_heads * H;
|
||||
const int present_row_stride = is_bsnh ? num_heads * H : H;
|
||||
const int present_head_stride = is_bsnh ? H : max_seqlen * H;
|
||||
int out_offset = is_past_kv_bnsh_format
|
||||
? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h)
|
||||
: INDEX_4D(max_seqlen, kv_num_heads, H, b, s + past_seq_len, n, h);
|
||||
|
||||
// kv_buff: BTNH or BNTH with buffered memory for new
|
||||
// new_kv: BLNH
|
||||
int in_offset = is_new_kv_bnsh_format
|
||||
? INDEX_4D(kv_num_heads, new_seqlen, H, b, n, s, h)
|
||||
: INDEX_4D(new_seqlen, kv_num_heads, H, b, s, n, h);
|
||||
|
||||
const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b];
|
||||
|
||||
int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h;
|
||||
// Note: new KV always BSNH
|
||||
const int new_batch_stride = new_seqlen * num_heads * H;
|
||||
const int new_row_stride = num_heads * H;
|
||||
const int new_head_stride = H;
|
||||
const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h;
|
||||
kv_buff[out_offset] = new_kv[in_offset];
|
||||
}
|
||||
}
|
||||
|
||||
// Concat new to kv buffer in place
|
||||
template <typename T>
|
||||
Status LaunchConcatKVInPlace(int batch_size,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int max_sequence_length,
|
||||
const int* past_seqlens_k,
|
||||
const int* total_seqlens_k,
|
||||
int new_seq_len,
|
||||
const T* new_key,
|
||||
const T* new_value,
|
||||
T* present_key,
|
||||
T* present_value,
|
||||
bool is_past_kv_bnsh_format,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
static_assert(sizeof(T) == 2);
|
||||
assert(head_size % 4 == 0);
|
||||
|
||||
const int H = head_size / 4;
|
||||
if (H * kv_num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(new_seq_len, batch_size, 1);
|
||||
const dim3 block(H, kv_num_heads, 1);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(max_sequence_length,
|
||||
reinterpret_cast<float2*>(present_key),
|
||||
reinterpret_cast<const float2*>(new_key),
|
||||
past_seqlens_k,
|
||||
total_seqlens_k,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(max_sequence_length,
|
||||
reinterpret_cast<float2*>(present_value),
|
||||
reinterpret_cast<const float2*>(new_value),
|
||||
past_seqlens_k,
|
||||
total_seqlens_k,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format);
|
||||
} else {
|
||||
int steps = int(ceil(float(H * kv_num_heads) / 256.0));
|
||||
const dim3 grid(steps, new_seq_len, batch_size);
|
||||
const dim3 block(256, 1, 1);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(max_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(present_key),
|
||||
reinterpret_cast<const float2*>(new_key),
|
||||
past_seqlens_k,
|
||||
total_seqlens_k,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(max_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(present_value),
|
||||
reinterpret_cast<const float2*>(new_value),
|
||||
past_seqlens_k,
|
||||
total_seqlens_k,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format);
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Concat new to kv buffer in place
|
||||
template <typename T>
|
||||
Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data,
|
||||
const void* new_key,
|
||||
const void* new_value,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int kv_sequence_length = parameters.sequence_length;
|
||||
const int present_sequence_length = parameters.seqlen_present_kv_cache;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
const int max_sequence_length = parameters.seqlen_present_kv_cache;
|
||||
const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast<const int*>(data.seqlens_k);
|
||||
|
||||
// Indicates past sequence_length of each sequence
|
||||
const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast<const int*>(data.seqlens_k);
|
||||
assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH ||
|
||||
parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
bool is_past_kv_bnsh_format = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
|
||||
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
|
||||
assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
const int H = head_size / 4;
|
||||
if (H * kv_num_heads <= max_threads_per_block) {
|
||||
const dim3 grid(kv_sequence_length, batch_size, 1);
|
||||
const dim3 block(H, kv_num_heads, 1);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(present_sequence_length,
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
reinterpret_cast<const float2*>(new_key),
|
||||
seqlens_k,
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatKVInPlace<float2><<<grid, block, 0, stream>>>(present_sequence_length,
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
reinterpret_cast<const float2*>(new_value),
|
||||
seqlens_k,
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
} else {
|
||||
int steps = int(ceil(float(H * kv_num_heads) / 256.0));
|
||||
const dim3 grid(steps, kv_sequence_length, batch_size);
|
||||
const dim3 block(256, 1, 1);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(present_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(data.present_key),
|
||||
reinterpret_cast<const float2*>(new_key),
|
||||
seqlens_k,
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
ConcatKVInPlaceLarge<float2><<<grid, block, 0, stream>>>(present_sequence_length,
|
||||
H,
|
||||
kv_num_heads,
|
||||
reinterpret_cast<float2*>(data.present_value),
|
||||
reinterpret_cast<const float2*>(new_value),
|
||||
seqlens_k,
|
||||
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
}
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
return LaunchConcatKVInPlace(parameters.batch_size,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
max_sequence_length,
|
||||
past_seqlens_k,
|
||||
nullptr, // total_seqlens_k is not available
|
||||
parameters.sequence_length,
|
||||
reinterpret_cast<const T*>(new_key),
|
||||
reinterpret_cast<const T*>(new_value),
|
||||
data.present_key,
|
||||
data.present_value,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format,
|
||||
stream,
|
||||
max_threads_per_block);
|
||||
}
|
||||
|
||||
// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh
|
||||
|
@ -474,41 +516,60 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
|
|||
}
|
||||
|
||||
// Kernel to unpack qkv from packed qkv
|
||||
template <typename T>
|
||||
template <typename T, bool output_bnsh>
|
||||
__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
|
||||
const int kv_num_heads, const int head_size, const int sequence_length,
|
||||
const int batch_size) {
|
||||
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int d = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int qkv_size = batch_size * sequence_length * d;
|
||||
const int q_size = num_heads * head_size;
|
||||
const int k_size = kv_num_heads * head_size;
|
||||
const int q_hidden = num_heads * head_size;
|
||||
const int k_hidden = kv_num_heads * head_size;
|
||||
if (tid < qkv_size) {
|
||||
int batch = tid / (d * sequence_length);
|
||||
int sequence = (tid % (d * sequence_length)) / d;
|
||||
int b = tid / (d * sequence_length);
|
||||
int s = (tid % (d * sequence_length)) / d;
|
||||
int offset = tid % d;
|
||||
if (offset < q_size) {
|
||||
int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset;
|
||||
if (output_bnsh) { // output BNSH
|
||||
int head_count = kv_num_heads;
|
||||
if (offset < q_hidden) {
|
||||
head_count = num_heads;
|
||||
} else if (offset < q_hidden + k_hidden) {
|
||||
offset -= q_hidden;
|
||||
} else {
|
||||
offset -= (q_hidden + k_hidden);
|
||||
}
|
||||
int n = offset / head_size;
|
||||
int h = offset % head_size;
|
||||
|
||||
int unpacked_i = INDEX_4D(head_count, sequence_length, head_size, b, n, s, h);
|
||||
|
||||
unpacked_q[unpacked_i] = packed_qkv[tid];
|
||||
} else if (offset < q_size + k_size) {
|
||||
int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size);
|
||||
unpacked_k[unpacked_i] = packed_qkv[tid];
|
||||
} else {
|
||||
int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size);
|
||||
unpacked_v[unpacked_i] = packed_qkv[tid];
|
||||
} else { // output BSNH
|
||||
if (offset < q_hidden) {
|
||||
int unpacked_i = b * sequence_length * num_heads * head_size + s * num_heads * head_size + offset;
|
||||
unpacked_q[unpacked_i] = packed_qkv[tid];
|
||||
} else if (offset < q_hidden + k_hidden) {
|
||||
int unpacked_i = b * sequence_length * kv_num_heads * head_size +
|
||||
s * kv_num_heads * head_size + (offset - q_hidden);
|
||||
unpacked_k[unpacked_i] = packed_qkv[tid];
|
||||
} else {
|
||||
int unpacked_i = b * sequence_length * kv_num_heads * head_size +
|
||||
s * kv_num_heads * head_size + (offset - q_hidden - k_hidden);
|
||||
unpacked_v[unpacked_i] = packed_qkv[tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpack packed qkv
|
||||
template <typename T>
|
||||
template <typename T, bool output_bnsh>
|
||||
Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
|
||||
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
|
||||
cudaStream_t stream, const int max_threads_per_block) {
|
||||
const int threads = max_threads_per_block;
|
||||
const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;
|
||||
UnpackQKV<<<blocks, threads, 0, stream>>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, batch_size);
|
||||
UnpackQKV<T, output_bnsh><<<blocks, threads, 0, stream>>>(
|
||||
packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size);
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
@ -660,8 +721,14 @@ Status EfficientAttention(
|
|||
auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
|
||||
auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
|
||||
auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
|
||||
ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, batch_size, stream, max_threads_per_block));
|
||||
|
||||
Status status = LaunchUnpackQKV<T, LAYOUT_BSNH>(
|
||||
reinterpret_cast<const T*>(data.query), q, k, v, num_heads, kv_num_heads,
|
||||
head_size, sequence_length, batch_size, stream, max_threads_per_block);
|
||||
if (status != Status::OK()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
query = reinterpret_cast<const void*>(q);
|
||||
key = reinterpret_cast<const void*>(k);
|
||||
value = reinterpret_cast<const void*>(v);
|
||||
|
@ -713,7 +780,9 @@ Status EfficientAttention(
|
|||
"Past and present kv shall share the same tensor when kv_share_buffer is on.");
|
||||
}
|
||||
// Concatenate new kv in place
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block));
|
||||
constexpr bool is_new_kv_bnsh_format = false;
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
|
||||
parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block));
|
||||
} else {
|
||||
// Not share buffer case
|
||||
if (data.past_key != nullptr && data.past_key == data.present_key) {
|
||||
|
@ -825,6 +894,51 @@ template Status QkvToContext<BFloat16>(
|
|||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<BFloat16>& data);
|
||||
|
||||
template Status LaunchUnpackQKV<half, LAYOUT_BNSH>(
|
||||
const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads,
|
||||
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
|
||||
cudaStream_t stream, const int max_threads_per_block);
|
||||
|
||||
template Status LaunchUnpackQKV<BFloat16, LAYOUT_BNSH>(
|
||||
const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads,
|
||||
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
|
||||
cudaStream_t stream, const int max_threads_per_block);
|
||||
|
||||
template Status LaunchConcatKVInPlace<half>(int batch_size,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int max_sequence_length,
|
||||
const int* past_seqlens_k,
|
||||
const int* total_seqlens_k,
|
||||
int new_seq_len,
|
||||
const half* new_key,
|
||||
const half* new_value,
|
||||
half* present_key,
|
||||
half* present_value,
|
||||
bool is_past_kv_bnsh_format,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block);
|
||||
|
||||
template Status LaunchConcatKVInPlace<BFloat16>(int batch_size,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int max_sequence_length,
|
||||
const int* past_seqlens_k,
|
||||
const int* total_seqlens_k,
|
||||
int new_seq_len,
|
||||
const BFloat16* new_key,
|
||||
const BFloat16* new_value,
|
||||
BFloat16* present_key,
|
||||
BFloat16* present_value,
|
||||
bool is_past_kv_bnsh_format,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#undef OFFSET_BNSH
|
||||
#undef OFFSET_BSNH
|
||||
|
|
|
@ -51,6 +51,28 @@ Status QkvToContext(
|
|||
contrib::GroupQueryAttentionParameters& parameters,
|
||||
GroupQueryAttentionData<T>& data);
|
||||
|
||||
template <typename T, bool output_bnsh>
|
||||
Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
|
||||
const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
|
||||
cudaStream_t stream, const int max_threads_per_block);
|
||||
|
||||
template <typename T>
|
||||
Status LaunchConcatKVInPlace(int batch_size,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int max_sequence_length, // max sequence length of present_key or present_value.
|
||||
const int* past_seqlens_k, // it is not used when total_seqlens_k is available.
|
||||
const int* total_seqlens_k, // optional, nullptr means it is not available.
|
||||
int new_seq_len,
|
||||
const T* new_key,
|
||||
const T* new_value,
|
||||
T* present_key,
|
||||
T* present_value,
|
||||
bool is_past_kv_bnsh_format,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -18,120 +18,120 @@ namespace contrib {
|
|||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
__global__ void RotaryEmbeddingBSNH(T *output, // BxSxNxH
|
||||
const T *input, // BxSxNxH
|
||||
const T *cos_cache, // Mx(H/2)
|
||||
const T *sin_cache, // Mx(H/2)
|
||||
const int64_t *position_ids, // (1) or BxS
|
||||
__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
|
||||
const T* input, // BxSxNxH
|
||||
const T* cos_cache, // Mx(H/2)
|
||||
const T* sin_cache, // Mx(H/2)
|
||||
const int64_t* position_ids, // (1) or BxS
|
||||
const int sequence_length, const int num_heads, const int head_size,
|
||||
const int rotary_embedding_dim, const int position_ids_format,
|
||||
const bool interleaved, const int batch_stride, const int seq_stride,
|
||||
const int head_stride) {
|
||||
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
|
||||
// Use .x in innermost loop to access global memory efficiently
|
||||
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
|
||||
// Use .x in innermost loop to access global memory efficiently
|
||||
|
||||
const int b = blockIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int n = blockIdx.z;
|
||||
const int b = blockIdx.y;
|
||||
const int s = blockIdx.x;
|
||||
const int n = blockIdx.z;
|
||||
|
||||
const int i = threadIdx.x;
|
||||
const int i = threadIdx.x;
|
||||
|
||||
if (i >= head_size) {
|
||||
return;
|
||||
}
|
||||
if (i >= head_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
|
||||
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
|
||||
|
||||
const T *input_data = input + block_offset;
|
||||
T *output_data = output + block_offset;
|
||||
const T* input_data = input + block_offset;
|
||||
T* output_data = output + block_offset;
|
||||
|
||||
if (i >= rotary_embedding_dim) {
|
||||
output_data[i] = input_data[i];
|
||||
return;
|
||||
}
|
||||
if (i >= rotary_embedding_dim) {
|
||||
output_data[i] = input_data[i];
|
||||
return;
|
||||
}
|
||||
|
||||
// Cache is (M, H/2)
|
||||
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
|
||||
const int position_id = (position_ids_format == 0) ? static_cast<int>(position_ids[0]) + s
|
||||
: static_cast<int>(position_ids[b * sequence_length + s]);
|
||||
const int cache_offset = position_id * half_rotary_embedding_dim;
|
||||
const T *cos_data = cos_cache + cache_offset;
|
||||
const T *sin_data = sin_cache + cache_offset;
|
||||
// Cache is (M, H/2)
|
||||
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
|
||||
const int position_id = (position_ids_format == 0) ? static_cast<int>(position_ids[0]) + s
|
||||
: static_cast<int>(position_ids[b * sequence_length + s]);
|
||||
const int cache_offset = position_id * half_rotary_embedding_dim;
|
||||
const T* cos_data = cos_cache + cache_offset;
|
||||
const T* sin_data = sin_cache + cache_offset;
|
||||
|
||||
int cache_idx = 0;
|
||||
T sign = 0;
|
||||
int j = 0;
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_embedding_dim;
|
||||
sign = (i % 2 == 0) ? -1 : 1;
|
||||
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_embedding_dim;
|
||||
sign = (i < half_rotary_embedding_dim) ? -1 : 1;
|
||||
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
|
||||
}
|
||||
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
|
||||
int cache_idx = 0;
|
||||
T sign = 0;
|
||||
int j = 0;
|
||||
if (interleaved) {
|
||||
cache_idx = (i / 2) % half_rotary_embedding_dim;
|
||||
sign = (i % 2 == 0) ? -1 : 1;
|
||||
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
|
||||
} else {
|
||||
cache_idx = i % half_rotary_embedding_dim;
|
||||
sign = (i < half_rotary_embedding_dim) ? -1 : 1;
|
||||
j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
|
||||
}
|
||||
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T *output, const T *input, const int64_t *position_ids,
|
||||
const T *cos_cache, const T *sin_cache, const int batch_size,
|
||||
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
|
||||
const T* cos_cache, const T* sin_cache, const int batch_size,
|
||||
const int sequence_length, const int num_heads, const int head_size,
|
||||
const int rotary_embedding_dim, const int /*max_sequence_length*/,
|
||||
const int position_ids_format, const bool interleaved,
|
||||
const int max_threads_per_block, const bool transposed) {
|
||||
// Note: Current implementation assumes head_size <= max_threads_per_block
|
||||
// because head_size is currently large for LLaMA-2. For smaller head_size
|
||||
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
|
||||
// instead. This will require kernel changes to support.
|
||||
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
|
||||
const int max_threads_per_block, const bool is_input_bnsh_format) {
|
||||
// Note: Current implementation assumes head_size <= max_threads_per_block
|
||||
// because head_size is currently large for LLaMA-2. For smaller head_size
|
||||
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
|
||||
// instead. This will require kernel changes to support.
|
||||
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
|
||||
|
||||
int tpb = (head_size + 31) / 32 * 32;
|
||||
int tpb = (head_size + 31) / 32 * 32;
|
||||
|
||||
const dim3 block(tpb);
|
||||
const dim3 grid(sequence_length, batch_size, num_heads);
|
||||
const dim3 block(tpb);
|
||||
const dim3 grid(sequence_length, batch_size, num_heads);
|
||||
|
||||
// Default input tensor shape is [batch, seq, hidden_size]
|
||||
int head_stride = head_size;
|
||||
int seq_stride = num_heads * head_stride;
|
||||
int batch_stride = sequence_length * seq_stride;
|
||||
if (transposed) {
|
||||
// When transposed, input tensor shape is [batch, num_heads, seq, head_size]
|
||||
seq_stride = head_size;
|
||||
head_stride = sequence_length * seq_stride;
|
||||
batch_stride = num_heads * head_stride;
|
||||
}
|
||||
// Default input tensor shape is [batch, seq, hidden_size]
|
||||
int head_stride = head_size;
|
||||
int seq_stride = num_heads * head_stride;
|
||||
int batch_stride = sequence_length * seq_stride;
|
||||
if (is_input_bnsh_format) {
|
||||
seq_stride = head_size;
|
||||
head_stride = sequence_length * seq_stride;
|
||||
batch_stride = num_heads * head_stride;
|
||||
}
|
||||
|
||||
assert(head_size <= max_threads_per_block);
|
||||
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
|
||||
num_heads, head_size, rotary_embedding_dim, position_ids_format,
|
||||
interleaved, batch_stride, seq_stride, head_stride);
|
||||
assert(head_size <= max_threads_per_block);
|
||||
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
|
||||
num_heads, head_size, rotary_embedding_dim, position_ids_format,
|
||||
interleaved, batch_stride, seq_stride, head_stride);
|
||||
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float *output, const float *input,
|
||||
const int64_t *position_ids, const float *cos_cache,
|
||||
const float *sin_cache, const int batch_size,
|
||||
template Status LaunchRotaryEmbeddingKernel<float>(cudaStream_t stream, float* output, const float* input,
|
||||
const int64_t* position_ids, const float* cos_cache,
|
||||
const float* sin_cache, const int batch_size,
|
||||
const int sequence_length, const int num_heads, const int head_size,
|
||||
const int rotary_embedding_dim, const int max_sequence_length,
|
||||
const int position_ids_format, const bool interleaved,
|
||||
const int max_threads_per_block, const bool transposed);
|
||||
const int max_threads_per_block, const bool is_input_bnsh_format);
|
||||
|
||||
template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half *output, const half *input,
|
||||
const int64_t *position_ids, const half *cos_cache,
|
||||
const half *sin_cache, const int batch_size,
|
||||
template Status LaunchRotaryEmbeddingKernel<half>(cudaStream_t stream, half* output, const half* input,
|
||||
const int64_t* position_ids, const half* cos_cache,
|
||||
const half* sin_cache, const int batch_size,
|
||||
const int sequence_length, const int num_heads, const int head_size,
|
||||
const int rotary_embedding_dim, const int max_sequence_length,
|
||||
const int position_ids_format, const bool interleaved,
|
||||
const int max_threads_per_block, const bool transposed);
|
||||
const int max_threads_per_block, const bool is_input_bnsh_format);
|
||||
|
||||
template Status LaunchRotaryEmbeddingKernel<BFloat16>(
|
||||
cudaStream_t stream, BFloat16 *output, const BFloat16 *input, const int64_t *position_ids,
|
||||
const BFloat16 *cos_cache, const BFloat16 *sin_cache, const int batch_size, const int sequence_length,
|
||||
cudaStream_t stream, BFloat16* output, const BFloat16* input, const int64_t* position_ids,
|
||||
const BFloat16* cos_cache, const BFloat16* sin_cache, const int batch_size, const int sequence_length,
|
||||
const int num_heads, const int head_size, const int rotary_embedding_dim, const int max_sequence_length,
|
||||
const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed);
|
||||
const int position_ids_format, const bool interleaved, const int max_threads_per_block,
|
||||
const bool is_input_bnsh_format);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -26,7 +26,7 @@ Status LaunchRotaryEmbeddingKernel(
|
|||
const int position_ids_format,
|
||||
const bool interleaved,
|
||||
const int max_threads_per_block,
|
||||
const bool transposed);
|
||||
const bool is_input_bnsh_format);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
@ -258,8 +258,12 @@ class FusedMultiHeadCrossAttentionKernel
|
|||
<< "\t force_unroll: " << param.force_unroll << "\n";
|
||||
}
|
||||
|
||||
int32_t getSForUnroll(Fused_multihead_attention_params_mhca const& param) const override {
|
||||
return param.s_q;
|
||||
dim3 getGridDim(const FusedMultiHeadCrossAttentionKernelMetaInfoV2& kernelMeta,
|
||||
const Fused_multihead_attention_params_mhca& params) const override {
|
||||
dim3 gridDim(params.h,
|
||||
params.b,
|
||||
params.force_unroll ? ((params.s_q + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep) : 1);
|
||||
return gridDim;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -59,6 +59,8 @@ CUDADriverWrapper::CUDADriverWrapper() {
|
|||
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*reinterpret_cast<void**>(&_cuDeviceGetAttribute) = load_sym(handle, "cuDeviceGetAttribute");
|
||||
*reinterpret_cast<void**>(&_cuFuncSetCacheConfig) = load_sym(handle, "cuFuncSetCacheConfig");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper() {
|
||||
|
@ -73,6 +75,14 @@ CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attr
|
|||
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuDeviceGetAttribute(int* pi, CUdevice_attribute attrib, CUdevice dev) const {
|
||||
return (*_cuDeviceGetAttribute)(pi, attrib, dev);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) const {
|
||||
return (*_cuFuncSetCacheConfig)(hfunc, config);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const {
|
||||
return (*_cuLinkComplete)(state, cubinOut, sizeOut);
|
||||
}
|
||||
|
@ -126,6 +136,13 @@ CUresult CUDADriverWrapper::cuLaunchKernel(
|
|||
f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra);
|
||||
}
|
||||
|
||||
// Initialize the singleton instance
|
||||
CUDADriverWrapper CUDADriverWrapper::instance;
|
||||
|
||||
const CUDADriverWrapper* CUDADriverWrapper::GetInstance() {
|
||||
return &instance;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -45,6 +45,10 @@ class CUDADriverWrapper {
|
|||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
CUresult cuDeviceGetAttribute(int* pi, CUdevice_attribute attrib, CUdevice dev) const;
|
||||
|
||||
CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) const;
|
||||
|
||||
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
|
||||
|
||||
CUresult cuModuleUnload(CUmodule hmod) const;
|
||||
|
@ -73,10 +77,14 @@ class CUDADriverWrapper {
|
|||
uint32_t blockDimY, uint32_t blockDimZ, uint32_t sharedMemBytes, CUstream hStream, void** kernelParams,
|
||||
void** extra) const;
|
||||
|
||||
static const CUDADriverWrapper* GetInstance();
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUresult (*_cuGetErrorName)(CUresult, const char**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuDeviceGetAttribute)(int*, CUdevice_attribute, CUdevice);
|
||||
CUresult (*_cuFuncSetCacheConfig)(CUfunction, CUfunc_cache);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
CUresult (*_cuLinkDestroy)(CUlinkState);
|
||||
|
@ -92,6 +100,8 @@ class CUDADriverWrapper {
|
|||
CUfunction f, uint32_t gridDimX, uint32_t gridDimY, uint32_t gridDimZ,
|
||||
uint32_t blockDimX, uint32_t blockDimY, uint32_t blockDimZ, uint32_t sharedMemBytes, CUstream hStream,
|
||||
void** kernelParams, void** extra);
|
||||
|
||||
static CUDADriverWrapper instance;
|
||||
};
|
||||
|
||||
inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line) {
|
||||
|
|
|
@ -311,8 +311,12 @@ class FusedMultiHeadFlashAttentionKernel
|
|||
<< "\t force_unroll: " << param.force_unroll << "\n";
|
||||
}
|
||||
|
||||
int32_t getSForUnroll(Fused_multihead_attention_params_v2 const& param) const override {
|
||||
return param.s;
|
||||
dim3 getGridDim(const FusedMultiHeadFlashAttentionKernelMetaInfoV2& kernelMeta,
|
||||
const Fused_multihead_attention_params_v2& params) const override {
|
||||
dim3 gridDim(params.h,
|
||||
params.b,
|
||||
params.force_unroll ? ((params.s + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep) : 1);
|
||||
return gridDim;
|
||||
}
|
||||
|
||||
uint64_t hashID(KernelMeta const& kernelMeta) const {
|
||||
|
|
|
@ -29,6 +29,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
|
||||
template <typename TKernelMeta, typename TKernelParam>
|
||||
class TSharedCubinKernel {
|
||||
public:
|
||||
|
@ -105,7 +106,7 @@ class TSharedCubinKernel {
|
|||
|
||||
virtual void dumpHashId(TKernelParam const& param, std::ostringstream& message) const = 0;
|
||||
|
||||
virtual int32_t getSForUnroll(TKernelParam const& param) const = 0;
|
||||
virtual dim3 getGridDim(const TKernelMeta& kernelMeta, const TKernelParam& params) const = 0;
|
||||
|
||||
virtual void run(TKernelParam& params, cudaStream_t ss) const {
|
||||
ORT_ENFORCE(!params.interleaved); // interleaved is for int8
|
||||
|
@ -126,16 +127,10 @@ class TSharedCubinKernel {
|
|||
CUfunction const func = findIter->second.mDeviceFunction;
|
||||
|
||||
void* kernelParams[] = {¶ms, nullptr};
|
||||
if (!params.force_unroll) {
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr),
|
||||
mDriver);
|
||||
} else {
|
||||
int32_t unroll = (getSForUnroll(params) + kernelMeta.mUnrollStep - 1) / kernelMeta.mUnrollStep;
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr),
|
||||
mDriver);
|
||||
}
|
||||
|
||||
dim3 gridDim = getGridDim(kernelMeta, params);
|
||||
cuErrCheck(mDriver.cuLaunchKernel(func, gridDim.x, gridDim.y, gridDim.z, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, ss, kernelParams, nullptr), mDriver);
|
||||
}
|
||||
|
||||
virtual ~TSharedCubinKernel() = default;
|
||||
|
|
|
@ -6,207 +6,221 @@
|
|||
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
// Macros to avoid long line length
|
||||
#define CUDA_MS_OP_CLASS_NAME(ver, name) \
|
||||
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, name)
|
||||
#define CUDA_MS_OP_TYPED_CLASS_NAME(ver, type, name) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, ver, type, name)
|
||||
#define CUDA_MS_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, type, name)
|
||||
#define CUDA_MS_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, start_ver, end_ver, name)
|
||||
|
||||
#define CUDA_ONNX_OP_TYPED_CLASS_NAME(ver, type, name) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, ver, type, name)
|
||||
#define CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, start_ver, end_ver, type, name)
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Irfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Irfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Irfft);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMulConj);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskBiasDropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasAdd);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul); // backward compatibility
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul); // backward compatibility
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul); // backward compatibility
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedMatMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FusedMatMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FusedMatMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RelativePositionBias);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RelativePositionBias);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GatedRelativePositionBias);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GatedRelativePositionBias);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RemovePadding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RemovePadding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RestorePadding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RestorePadding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Rfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Rfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Rfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Irfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Irfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Irfft);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMulConj);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMulConj);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BiasDropout);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BitmaskBiasDropout);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, NGramRepeatBlock);
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Affine);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Attention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Attention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LongformerAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedMatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QuantizeWithOrder);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DequantizeWithOrder);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Affine);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Affine);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Affine);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Attention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Attention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedMultiHeadAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedMultiHeadAttention);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BeamSearch);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, WhisperBeamSearch);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ConvTransposeWithDynamicPads);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Crop);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int64_t, DynamicSlice);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, EmbedLayerNormalization);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, EmbedLayerNormalization);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, GreedySearch);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, GroupNorm);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, NhwcConv);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, NhwcConv);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ImageScaler);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ImageScaler);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ImageScaler);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, LongformerAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, LongformerAttention);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ParametricSoftplus);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ParametricSoftplus);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ParametricSoftplus);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RotaryEmbedding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RotaryEmbedding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, RotaryEmbedding);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GemmaRotaryEmbedding);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, Sampling);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ScaledTanh);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ScaledTanh);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ScaledTanh);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, SkipGroupNorm);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_float, LayerNormalization);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, double_double_double, LayerNormalization);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_MLFloat16, LayerNormalization);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_MLFloat16, LayerNormalization);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_float, LayerNormalization);
|
||||
class CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, BFloat16_float_BFloat16, LayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization);
|
||||
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, Inverse);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulNBits);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, Trilu);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, QuantizeLinear);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, QuantizeLinear);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, DequantizeLinear);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedLayerNormalization);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedGelu);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QuantizeWithOrder);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, DequantizeWithOrder);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedAttention);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedLongformerAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedSelfAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedSelfAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedMultiHeadAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedMultiHeadAttention);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, GemmFloat8);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention);
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once
|
||||
// 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, ShrunkenGather);
|
||||
#endif
|
||||
|
||||
#if defined(ORT_USE_NCCL)
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, AllReduce);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, AllGather);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, AllToAll);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ShardedMoE);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ShardedMoE);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedMatMul);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedMatMul);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSlice);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSlice);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedReshape);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReshape);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReshape);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedExpand);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedExpand);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedExpand);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceSum);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceSum);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMax);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMax);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMean);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMean);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedUnsqueeze);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedUnsqueeze);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedUnsqueeze);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedSqueeze);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSqueeze);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSqueeze);
|
||||
#endif
|
||||
|
||||
template <>
|
||||
|
@ -218,206 +232,205 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasAdd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RemovePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RestorePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Irfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Irfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Irfft)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ComplexMulConj)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ComplexMulConj)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasAdd)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasAdd)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, QuickGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, QuickGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QuickGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RelativePositionBias)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GatedRelativePositionBias)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GatedRelativePositionBias)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RemovePadding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RemovePadding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RestorePadding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RestorePadding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Rfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Rfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Rfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Irfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Irfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Irfft)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ComplexMulConj)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ComplexMulConj)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, NGramRepeatBlock)>,
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Affine)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BeamSearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WhisperBeamSearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GreedySearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GroupNorm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, NhwcConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, LongformerAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, LongformerAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GemmaRotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, double_double_double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_MLFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, float_float_MLFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, MLFloat16_float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 16, BFloat16_float_BFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_float, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_double_double, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskBiasDropout)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Affine)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Affine)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Affine)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Attention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Attention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, PackedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PackedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BeamSearch)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, WhisperBeamSearch)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ConvTransposeWithDynamicPads)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, Crop)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int32_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, int64_t, DynamicSlice)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GreedySearch)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GroupNorm)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, NhwcConv)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, NhwcConv)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ImageScaler)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ImageScaler)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ImageScaler)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, LongformerAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, LongformerAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GemmaRotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Sampling)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, SkipGroupNorm)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, double_double_double, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_MLFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, float_float_MLFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, MLFloat16_float_float, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_VERSIONED_TYPED_CLASS_NAME(1, 16, BFloat16_float_BFloat16, LayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_float, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double_double_double, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float_float_MLFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16_float_float, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, BFloat16_float_BFloat16, SimplifiedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Inverse)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasSoftmax)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasDropout)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskDropout)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BitmaskBiasDropout)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
|
||||
// TransposedMatMul is still here for backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QuantizeWithOrder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DequantizeWithOrder)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QuantizeWithOrder)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DequantizeWithOrder)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QOrderedLongformerAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedSelfAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedSelfAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderMaskedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, GemmFloat8)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SparseAttention)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SparseAttention)>,
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once 1). compute optimizer is enabled for inference or
|
||||
// Should remove the shrunken_gather include from ENABLE_TRAINING_OPS once
|
||||
// 1). compute optimizer is enabled for inference or
|
||||
// 2). this is needed by inference for other purpose.
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, ShrunkenGather)>,
|
||||
#endif
|
||||
|
||||
#if defined(ORT_USE_NCCL)
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllReduce)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllGather)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, AllToAll)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, ShardedMoE)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, ShardedMoE)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedMatMul)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedMatMul)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSlice)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSlice)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedReshape)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReshape)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReshape)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedExpand)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedExpand)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedExpand)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceSum)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceSum)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMax)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMax)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedReduceMean)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedReduceMean)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedUnsqueeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedUnsqueeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedUnsqueeze)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedUnsqueeze)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedUnsqueeze)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedUnsqueeze)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedSqueeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSqueeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze)>,
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, int64_t, DistributedSqueeze)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DistributedSqueeze)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, DistributedSqueeze)>,
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
@ -22,7 +22,10 @@ namespace cuda {
|
|||
onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
|
||||
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
|
||||
#endif
|
||||
|
||||
template <typename T, bool IsNHWC>
|
||||
GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/sparse/block_mask.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
__global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_indices, int num_rows, int num_cols) {
|
||||
int row = threadIdx.x;
|
||||
if (row >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update input and output data pointers to the start of current head
|
||||
int head = blockIdx.x;
|
||||
mask += head * num_rows * num_cols;
|
||||
csr_row_indices += head * (num_rows + 1);
|
||||
csr_col_indices += head * num_rows * num_cols;
|
||||
|
||||
int count = 0;
|
||||
for (int col = 0; col < num_cols; col++) {
|
||||
if (mask[row * num_cols + col] == 1) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
|
||||
extern __shared__ int non_zero_counts[];
|
||||
non_zero_counts[threadIdx.x] = count;
|
||||
__syncthreads();
|
||||
|
||||
// The first thread will calculate the accumulated partial sum of non-zero counts.
|
||||
if (row == 0) {
|
||||
for (int i = 1; i < num_rows; i++) {
|
||||
non_zero_counts[i] += non_zero_counts[i - 1];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The starting index of current row in csr_col_indices
|
||||
int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
|
||||
|
||||
// Output row indices.
|
||||
csr_row_indices[row] = offset;
|
||||
if (row == 0) {
|
||||
// The first thread output the last element.
|
||||
csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
|
||||
}
|
||||
|
||||
for (int col = 0; col < num_cols; col++) {
|
||||
if (mask[row * num_cols + col] == 1) {
|
||||
csr_col_indices[offset] = col;
|
||||
offset++;
|
||||
}
|
||||
}
|
||||
|
||||
// Note that the remaining buffer in csr_col_indices are not filled with dummy value, but it's fine.
|
||||
// The last element of csr_row_indices is the total number of non-zero elements.
|
||||
}
|
||||
|
||||
void ConvertMaskToCSR(cudaStream_t stream,
|
||||
const int* mask, // input mask with shape (num_layout, num_rows, num_cols)
|
||||
int num_layout, // number of layouts
|
||||
int num_rows, // number of rows
|
||||
int num_cols, // number of columns
|
||||
int* csr_row_indices, // output CSR row indices
|
||||
int* csr_col_indices, // output CSR column indices
|
||||
int max_threads_per_block) {
|
||||
int threads_per_block = (num_rows + 31) / 32 * 32;
|
||||
|
||||
// Each thread handle one row. The kernel assumes that all rows of one head can be handled in one block.
|
||||
if (threads_per_block > max_threads_per_block) {
|
||||
ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
|
||||
}
|
||||
|
||||
MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
|
||||
mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Convert mask to compressed sparse row (CSR) format ( https://en.wikipedia.org/wiki/Sparse_matrix)
|
||||
// For example, num_layout=1, num_rows=4 and num_cols=4, and the mask is like
|
||||
// 1, 0, 0, 0
|
||||
// 1, 1, 0, 0
|
||||
// 0, 1, 1, 0
|
||||
// 0, 1, 1, 1
|
||||
// The CSR format is like:
|
||||
// csr_col_indices:
|
||||
// 0, 0, 1, 1, 2, 1, 2, 3, 0*, 0*, 0*, 0*, 0*, 0*, 0*, 0* (* is padding)
|
||||
// csr_row_indices:
|
||||
// 0, 1, 3, 5, 8
|
||||
void ConvertMaskToCSR(cudaStream_t stream,
|
||||
const int* mask, // input mask with shape (num_layout, num_rows, num_cols)
|
||||
int num_layout, // number of layout
|
||||
int num_rows, // number of rows of block_mask
|
||||
int num_cols, // number of cols of block_mask
|
||||
int* csr_row_indices, // output CSR row indices with shape (num_layout, num_rows + 1).
|
||||
int* csr_col_indices, // output CSR col indices with shape (num_layout, num_rows * num_cols).
|
||||
int max_threads_per_block);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,338 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_helper.h"
|
||||
#include "contrib_ops/cuda/sparse/block_mask.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
SparseAttention, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()) \
|
||||
.MayInplace(3, 1) \
|
||||
.MayInplace(4, 2) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 6), \
|
||||
SparseAttention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(BFloat16)
|
||||
|
||||
static inline int32_t DivUp(int32_t m, int32_t n) {
|
||||
return (m + n - 1) / n;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
SparseAttention<T>::SparseAttention(const OpKernelInfo& info)
|
||||
: CudaKernel(info) {
|
||||
int64_t num_heads = 0;
|
||||
int64_t kv_num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
|
||||
int64_t sparse_block_size = 0;
|
||||
ORT_ENFORCE(info.GetAttr("sparse_block_size", &sparse_block_size).IsOK());
|
||||
ORT_ENFORCE(sparse_block_size == 64 || sparse_block_size == 128);
|
||||
sparse_block_size_ = static_cast<int>(sparse_block_size);
|
||||
|
||||
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
|
||||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
kernel_loaded_ = false;
|
||||
|
||||
disable_v1_kernel_ = ParseEnvironmentVariableWithDefault<bool>(sparse_attention::kDisableSparseAttentionV1, false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* query = context->Input<Tensor>(0);
|
||||
const Tensor* key = context->Input<Tensor>(1);
|
||||
const Tensor* value = context->Input<Tensor>(2);
|
||||
const Tensor* past_key = context->Input<Tensor>(3);
|
||||
const Tensor* past_value = context->Input<Tensor>(4);
|
||||
const Tensor* block_mask = context->Input<Tensor>(5);
|
||||
const Tensor* total_seq_len = context->Input<Tensor>(6);
|
||||
const Tensor* seqlens_k_total = context->Input<Tensor>(7);
|
||||
const Tensor* cos_cache = context->Input<Tensor>(8);
|
||||
const Tensor* sin_cache = context->Input<Tensor>(9);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
|
||||
SparseAttentionParameters parameters;
|
||||
|
||||
// Parameters from node attribute
|
||||
parameters.sparse_block_size = sparse_block_size_;
|
||||
parameters.num_heads = num_heads_;
|
||||
parameters.kv_num_heads = kv_num_heads_;
|
||||
parameters.scale = scale_;
|
||||
parameters.do_rotary = do_rotary_;
|
||||
parameters.rotary_interleaved = rotary_interleaved_;
|
||||
|
||||
ORT_RETURN_IF_ERROR(sparse_attention_helper::CheckInputs(¶meters,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
past_key,
|
||||
past_value,
|
||||
cos_cache,
|
||||
sin_cache,
|
||||
block_mask,
|
||||
seqlens_k_total,
|
||||
total_seq_len));
|
||||
|
||||
// Some limitations of CUDA kernels
|
||||
if (!sparse_attention_v1::is_supported_sparse_attention(device_prop)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"SparseAttention only support CUDA device with compute capacity 8.*. Got ",
|
||||
device_prop.major);
|
||||
}
|
||||
if (!sparse_attention_v1::is_supported_sparse_attention(parameters.head_size, sparse_block_size_)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"SparseAttention only support head_size=128 and sparse_block_size=64. Got head_size=",
|
||||
parameters.head_size,
|
||||
",sparse_block_size=",
|
||||
sparse_block_size_);
|
||||
}
|
||||
if (device_prop.maxThreadsPerBlock > 0 && num_heads_ > device_prop.maxThreadsPerBlock) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"num_heads should be no larger than ", device_prop.maxThreadsPerBlock);
|
||||
}
|
||||
|
||||
int past_seq_len = parameters.total_sequence_length - parameters.sequence_length;
|
||||
bool is_prompt = (past_seq_len == 0);
|
||||
bool use_v2_kernel = disable_v1_kernel_ || !is_prompt;
|
||||
|
||||
// Async Copy total_k_seq_len from GPU to CPU.
|
||||
IAllocatorUniquePtr<int32_t> pinned_buffer;
|
||||
int32_t* total_k_seq_len_pinned = nullptr;
|
||||
AutoDestoryCudaEvent new_event;
|
||||
cudaEvent_t& isCopyDone = new_event.Get();
|
||||
cudaStream_t cuda_stream = Stream(context);
|
||||
if (use_v2_kernel) {
|
||||
pinned_buffer = AllocateBufferOnCPUPinned<int32_t>(parameters.batch_size);
|
||||
|
||||
total_k_seq_len_pinned = reinterpret_cast<int32_t*>(pinned_buffer.get());
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(total_k_seq_len_pinned,
|
||||
seqlens_k_total->Data<int32_t>(),
|
||||
sizeof(int32_t) * parameters.batch_size,
|
||||
cudaMemcpyDeviceToHost, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone));
|
||||
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream));
|
||||
}
|
||||
|
||||
if (!kernel_loaded_) {
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
|
||||
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_fp16();
|
||||
} else {
|
||||
sparse_attention_v1::load_sparse_attention_fp16();
|
||||
}
|
||||
} else {
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_bf16();
|
||||
} else {
|
||||
sparse_attention_v1::load_sparse_attention_bf16();
|
||||
}
|
||||
}
|
||||
kernel_loaded_ = true;
|
||||
}
|
||||
|
||||
// Compute output shape and get output tensors.
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
std::vector<int64_t> present_dims = {
|
||||
parameters.batch_size, parameters.kv_num_heads, parameters.max_sequence_length, parameters.head_size};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context->Output(1, present_shape);
|
||||
Tensor* present_value = context->Output(2, present_shape);
|
||||
|
||||
// Set input and output data.
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
SparseAttentionData<CudaT> data;
|
||||
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
|
||||
data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
|
||||
data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
|
||||
data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
|
||||
data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
|
||||
data.block_mask = block_mask->Data<int32_t>();
|
||||
data.seqlens_k_total = (nullptr == seqlens_k_total) ? nullptr : seqlens_k_total->Data<int32_t>();
|
||||
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
|
||||
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
|
||||
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
|
||||
|
||||
// Check past and present share buffer.
|
||||
parameters.past_present_share_buffer = (data.past_key != nullptr && data.past_key == data.present_key);
|
||||
if (parameters.past_present_share_buffer) {
|
||||
ORT_ENFORCE(data.past_value != nullptr && data.past_value == data.present_value);
|
||||
}
|
||||
if (!parameters.past_present_share_buffer) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"CUDA implementation of SparseAttention requires past and present points to same buffer");
|
||||
}
|
||||
|
||||
if (parameters.do_rotary) {
|
||||
data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
|
||||
data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
|
||||
}
|
||||
|
||||
// Currently, we use same block size in kernel.
|
||||
// TODO: support kernel block size that is smaller than sparse_block_size in tunable (need expand block mask).
|
||||
data.kernel_layout.block_size = parameters.sparse_block_size;
|
||||
data.kernel_layout.mask = data.block_mask;
|
||||
data.kernel_layout.num_layout = parameters.num_sparse_layout;
|
||||
data.kernel_layout.num_cols = parameters.max_sequence_length / data.kernel_layout.block_size;
|
||||
data.kernel_layout.num_rows = parameters.max_sequence_length / data.kernel_layout.block_size;
|
||||
|
||||
// Allocate buffer for CSR col and row indices.
|
||||
onnxruntime::Stream* stream = context->GetComputeStream();
|
||||
int dense_blocks = data.kernel_layout.num_layout * data.kernel_layout.num_cols * data.kernel_layout.num_rows;
|
||||
auto csr_col_indices_buffer = GetScratchBuffer<int>(static_cast<size_t>(dense_blocks), stream);
|
||||
auto csr_row_indices_buffer = GetScratchBuffer<int>(
|
||||
static_cast<size_t>(data.kernel_layout.num_layout * (data.kernel_layout.num_rows + 1)), stream);
|
||||
|
||||
data.kernel_layout.csr_col_indices = reinterpret_cast<const int*>(csr_col_indices_buffer.get());
|
||||
data.kernel_layout.csr_row_indices = reinterpret_cast<const int*>(csr_row_indices_buffer.get());
|
||||
|
||||
ConvertMaskToCSR(cuda_stream,
|
||||
data.kernel_layout.mask,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols,
|
||||
csr_row_indices_buffer.get(),
|
||||
csr_col_indices_buffer.get(),
|
||||
device_prop.maxThreadsPerBlock);
|
||||
|
||||
size_t rotary_buffer_bytes = 0;
|
||||
if (do_rotary_) {
|
||||
rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads *
|
||||
parameters.sequence_length * parameters.head_size;
|
||||
rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
|
||||
}
|
||||
auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
|
||||
data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
|
||||
|
||||
size_t transposed_q_bytes = 0;
|
||||
if (!parameters.is_packed_qkv) {
|
||||
transposed_q_bytes = parameters.batch_size * parameters.sequence_length *
|
||||
parameters.num_heads * parameters.head_size * sizeof(T);
|
||||
}
|
||||
auto transposed_q_buffer = GetScratchBuffer<void>(transposed_q_bytes, context->GetComputeStream());
|
||||
if (transposed_q_buffer) {
|
||||
data.transposed_q_buffer = reinterpret_cast<CudaT*>(transposed_q_buffer.get());
|
||||
}
|
||||
|
||||
size_t unpacked_qkv_bytes = 0;
|
||||
if (parameters.is_packed_qkv) {
|
||||
unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length *
|
||||
(parameters.num_heads + 2 * parameters.kv_num_heads) *
|
||||
parameters.head_size * sizeof(T));
|
||||
}
|
||||
auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
|
||||
if (unpacked_qkv_buffer) {
|
||||
data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
|
||||
}
|
||||
|
||||
// Prepare some v2 kernel inputs in CPU then copy to GPU.
|
||||
IAllocatorUniquePtr<int32_t> v2_kernel_inputs_pinned_buffer;
|
||||
IAllocatorUniquePtr<int32_t> v2_kernel_buffer;
|
||||
data.use_v2_kernel = use_v2_kernel;
|
||||
if (use_v2_kernel) {
|
||||
// Compute activate q blocks so that we know the size of buffer to allocate.
|
||||
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
|
||||
int active_q_blocks = 0;
|
||||
if (is_prompt) {
|
||||
for (int i = 0; i < parameters.batch_size; i++) {
|
||||
active_q_blocks += DivUp(is_prompt ? total_k_seq_len_pinned[i] : 1, data.kernel_layout.block_size);
|
||||
}
|
||||
} else { // not prompt
|
||||
assert(parameters.sequence_length == 1);
|
||||
active_q_blocks = parameters.batch_size;
|
||||
}
|
||||
|
||||
// Compute buffer size: addresses of 6 buffers for v2 kernel need to be aligned to 16.
|
||||
const size_t aligned_batch_size = DivUp(parameters.batch_size, 16) * 16;
|
||||
const size_t aligned_num_q_blocks = DivUp(active_q_blocks, 16) * 16;
|
||||
size_t v2_kernel_buffer_size = 4 * aligned_batch_size + 2 * aligned_num_q_blocks;
|
||||
|
||||
// Compute those values in CPU, then copy to GPU
|
||||
v2_kernel_inputs_pinned_buffer = AllocateBufferOnCPUPinned<int32_t>(v2_kernel_buffer_size);
|
||||
int32_t* v2_kernel_inputs_pinned = reinterpret_cast<int32_t*>(v2_kernel_inputs_pinned_buffer.get());
|
||||
int32_t* q_batch_starts = v2_kernel_inputs_pinned;
|
||||
int32_t* q_batch_ends = q_batch_starts + aligned_batch_size;
|
||||
int32_t* k_batch_starts = q_batch_ends + aligned_batch_size;
|
||||
int32_t* k_batch_ends = k_batch_starts + aligned_batch_size;
|
||||
int32_t* q_batch_ids = k_batch_ends + aligned_batch_size;
|
||||
int32_t* q_start_sids = q_batch_ids + aligned_num_q_blocks;
|
||||
|
||||
// Here assumes right-side padding
|
||||
if (is_prompt) {
|
||||
for (int i = 0; i < parameters.batch_size; i++) {
|
||||
q_batch_starts[i] = 0;
|
||||
q_batch_ends[i] = total_k_seq_len_pinned[i];
|
||||
k_batch_starts[i] = 0;
|
||||
k_batch_ends[i] = total_k_seq_len_pinned[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < parameters.batch_size; i++) {
|
||||
q_batch_starts[i] = 0;
|
||||
q_batch_ends[i] = 1;
|
||||
k_batch_starts[i] = 0;
|
||||
k_batch_ends[i] = total_k_seq_len_pinned[i];
|
||||
}
|
||||
}
|
||||
|
||||
int current_block = 0;
|
||||
for (int i = 0; i < parameters.batch_size; i++) {
|
||||
int blocks = DivUp(q_batch_ends[i] - q_batch_starts[i], data.kernel_layout.block_size);
|
||||
for (int j = 0; j < blocks; j++) {
|
||||
q_batch_ids[current_block] = i;
|
||||
q_start_sids[current_block] = j * data.kernel_layout.block_size;
|
||||
current_block++;
|
||||
}
|
||||
}
|
||||
|
||||
v2_kernel_buffer = GetScratchBuffer<int>(v2_kernel_buffer_size, context->GetComputeStream());
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(v2_kernel_buffer.get(), v2_kernel_inputs_pinned,
|
||||
sizeof(int32_t) * v2_kernel_buffer_size,
|
||||
cudaMemcpyHostToDevice, cuda_stream));
|
||||
|
||||
data.q_batch_starts = v2_kernel_buffer.get();
|
||||
data.q_batch_ends = data.q_batch_starts + aligned_batch_size;
|
||||
data.k_batch_starts = data.q_batch_ends + aligned_batch_size;
|
||||
data.k_batch_ends = data.k_batch_starts + aligned_batch_size;
|
||||
data.q_batch_ids = data.k_batch_ends + aligned_batch_size;
|
||||
data.q_start_sids = data.q_batch_ids + aligned_num_q_blocks;
|
||||
data.active_q_blocks = active_q_blocks;
|
||||
}
|
||||
|
||||
return QkvToContext<CudaT>(device_prop, stream, parameters, data);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
using namespace ::onnxruntime::cuda;
|
||||
|
||||
template <typename T>
|
||||
class SparseAttention final : public CudaKernel {
|
||||
public:
|
||||
SparseAttention(const OpKernelInfo& op_kernel_info);
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
protected:
|
||||
int num_heads_; // number of attention heads for q
|
||||
int kv_num_heads_; // number of attention heads for k and v
|
||||
float scale_; // Scaling factor applied prior to softmax.
|
||||
bool is_causal_; // unidirectional attention or not
|
||||
int sparse_block_size_; // block size for sparsity
|
||||
bool do_rotary_; // Has rotary positional embedding
|
||||
bool rotary_interleaved_; // Interleaved rotary positional embedding
|
||||
bool disable_v1_kernel_; // Disable V2 kernel
|
||||
mutable bool kernel_loaded_; // Kernel has been loaded
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,244 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace sparse_attention_helper {
|
||||
|
||||
Status CheckInputs(void* params,
|
||||
const Tensor* query,
|
||||
const Tensor* key,
|
||||
const Tensor* value,
|
||||
const Tensor* past_key,
|
||||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
const Tensor* block_mask,
|
||||
const Tensor* seqlens_k_total,
|
||||
const Tensor* total_seq_len) {
|
||||
// No packing for q/k/v:
|
||||
// query (batch_size, sequence_length, num_heads * head_size)
|
||||
// key (batch_size, kv_sequence_length, kv_num_heads * head_size)
|
||||
// value (batch_size, kv_sequence_length, kv_num_heads * head_size)
|
||||
// Packed q/k/v:
|
||||
// query (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
|
||||
// key nullptr
|
||||
// value nullptr
|
||||
// Shape for other inputs:
|
||||
// past_key (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
|
||||
// past_value (batch_size, kv_num_heads, max_sequence_length, head_size) or nullptr
|
||||
// block_mask (num_heads, max_blocks, max_blocks) or (1, max_blocks, max_blocks)
|
||||
// where max_blocks = max_sequence_length / sparse_block_size
|
||||
// seqlens_k_total (batch_size) when do_rotary is True, optional otherwise
|
||||
// total_seq_len (1)
|
||||
// cos_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
// sin_cache (max_sequence_length, rotary_dim / 2) when do_rotary is true.
|
||||
|
||||
assert(params != nullptr);
|
||||
SparseAttentionParameters* parameters = reinterpret_cast<SparseAttentionParameters*>(params);
|
||||
|
||||
// The following parameters shall be set by parsing node attributes before calling CheckInputs.
|
||||
const int num_heads = parameters->num_heads;
|
||||
const int kv_num_heads = parameters->kv_num_heads;
|
||||
const bool do_rotary = parameters->do_rotary;
|
||||
|
||||
constexpr bool is_past_bsnh = false;
|
||||
const bool is_packed_qkv = key == nullptr;
|
||||
|
||||
const auto& query_dims = query->Shape().GetDims();
|
||||
if (query_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
|
||||
query_dims.size());
|
||||
}
|
||||
|
||||
int batch_size = static_cast<int>(query_dims[0]);
|
||||
int sequence_length = static_cast<int>(query_dims[1]);
|
||||
int q_hidden_size = static_cast<int>(query_dims[2]);
|
||||
|
||||
int head_size = 0;
|
||||
int kv_hidden_size = 0;
|
||||
if (!is_packed_qkv) {
|
||||
// Check key and value when not packed
|
||||
head_size = static_cast<int>(q_hidden_size) / num_heads;
|
||||
if (head_size % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size must be a multiple of 8. Got head_size = ",
|
||||
head_size);
|
||||
}
|
||||
if (value == nullptr) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
|
||||
}
|
||||
const auto& key_dims = key->Shape().GetDims();
|
||||
if (key_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
|
||||
key_dims.size());
|
||||
}
|
||||
|
||||
if (query_dims[0] != key_dims[0]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 0 (batch size)");
|
||||
}
|
||||
if (query_dims[1] != key_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'key' shall have same dim 1 (sequence length)");
|
||||
}
|
||||
|
||||
kv_hidden_size = static_cast<int>(key_dims[2]);
|
||||
|
||||
if (key->Shape() != value->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'query' and 'value' shall have same shape");
|
||||
}
|
||||
} else {
|
||||
// packed qkv
|
||||
if (static_cast<int>(q_hidden_size) % (num_heads + 2 * kv_num_heads) != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"packed qkv hidden size= ", q_hidden_size, " does not match num_heads and kv_num_heads",
|
||||
num_heads, kv_num_heads);
|
||||
}
|
||||
|
||||
head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
|
||||
if (head_size % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size must be a multiple of 8. Got head_size = ", head_size);
|
||||
}
|
||||
|
||||
if (value != nullptr) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
|
||||
}
|
||||
|
||||
q_hidden_size = head_size * num_heads;
|
||||
kv_hidden_size = head_size * kv_num_heads;
|
||||
}
|
||||
|
||||
const auto& block_mask_dim = block_mask->Shape().GetDims();
|
||||
if (!(block_mask_dim.size() == 3 && block_mask_dim[1] == block_mask_dim[2] &&
|
||||
(static_cast<int64_t>(num_heads) % block_mask_dim[0] == 0L))) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"block_mask must have shape (num_layout, max_blocks, max_blocks) where num_heads is divisible by num_layout.");
|
||||
}
|
||||
|
||||
int max_blocks = static_cast<int>(block_mask_dim[1]);
|
||||
int max_sequence_length = max_blocks * parameters->sparse_block_size;
|
||||
|
||||
// Check past-present KV
|
||||
if (past_key != nullptr && past_value != nullptr) {
|
||||
if (past_key->Shape() != past_value->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall have same shape");
|
||||
}
|
||||
|
||||
const auto& past_key_dims = past_key->Shape().GetDims();
|
||||
if (past_key_dims.size() != 4) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' is expected to have 4 dimensions, got ",
|
||||
past_key_dims.size());
|
||||
}
|
||||
|
||||
if (past_key_dims[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 0 should be batch_size ", batch_size, ", got ",
|
||||
past_key_dims[0]);
|
||||
}
|
||||
|
||||
if (past_key_dims[is_past_bsnh ? 2 : 1] != kv_num_heads) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' shall have kv_num_heads");
|
||||
}
|
||||
|
||||
int max_cache_sequence_length = static_cast<int>(past_key_dims[is_past_bsnh ? 1 : 2]);
|
||||
if (max_cache_sequence_length != max_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'block_mask' should have the same sequence length:",
|
||||
"max_sequence_length deduced from past_key is ", max_cache_sequence_length,
|
||||
"; max_sequence_length deduced from block_mask is ", max_sequence_length);
|
||||
}
|
||||
|
||||
if (past_key_dims[3] != head_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' dimension 3 should be same as head_size, got ",
|
||||
past_key_dims[3]);
|
||||
}
|
||||
} else if (past_key != nullptr || past_value != nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall be both present or both absent.");
|
||||
}
|
||||
|
||||
// Check the shape of total_key_sequence_lengths. We do not check the values here.
|
||||
const auto& k_len_dim = seqlens_k_total->Shape().GetDims();
|
||||
if (k_len_dim.size() != 1 && k_len_dim[0] != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"key_total_sequence_lengths must have shape (batch_size).");
|
||||
}
|
||||
|
||||
if (!onnxruntime::IsScalarOr1ElementVector(total_seq_len)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"total_sequence_length tensor must be of one element.");
|
||||
}
|
||||
int total_sequence_length = *((*total_seq_len).template Data<int32_t>());
|
||||
|
||||
int rotary_dim = 0;
|
||||
if (do_rotary) {
|
||||
if (cos_cache == nullptr || sin_cache == nullptr) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache and sin_cache must be passed to SparseAttention when do_rotary = 1");
|
||||
}
|
||||
|
||||
const auto& cos_dims = cos_cache->Shape().GetDims();
|
||||
const auto& sin_dims = sin_cache->Shape().GetDims();
|
||||
|
||||
if (head_size % 16 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"head_size shall be a multiple of 16. Got head_size = ",
|
||||
head_size);
|
||||
}
|
||||
if (cos_dims[0] < max_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 0 should be of max_sequence_length.");
|
||||
}
|
||||
if (sin_dims[0] < max_sequence_length) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 0 should be of max_sequence_length.");
|
||||
}
|
||||
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
|
||||
}
|
||||
if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
|
||||
}
|
||||
if (cos_dims[1] != sin_dims[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"cos_cache and sin_cache dimension 1 must be the same.");
|
||||
}
|
||||
rotary_dim = static_cast<int>(cos_dims[1] * 2);
|
||||
}
|
||||
|
||||
parameters->batch_size = batch_size;
|
||||
parameters->sequence_length = sequence_length;
|
||||
parameters->total_sequence_length = total_sequence_length;
|
||||
parameters->max_sequence_length = max_sequence_length;
|
||||
parameters->hidden_size = q_hidden_size;
|
||||
parameters->head_size = head_size;
|
||||
parameters->kv_hidden_size = kv_hidden_size;
|
||||
parameters->rotary_dim = rotary_dim;
|
||||
parameters->is_packed_qkv = is_packed_qkv;
|
||||
parameters->num_sparse_layout = static_cast<int>(block_mask_dim[0]);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_helper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,334 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_impl.h"
|
||||
#include "contrib_ops/cuda/sparse/block_mask.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
|
||||
#include "contrib_ops/cuda/bert/group_query_attention_impl.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Convert total_seq_len_k (total key sequence length excluding paddings) to position_ids for Prompt
|
||||
__global__ void PositionIdsPrompt(const int32_t* total_seq_len_k,
|
||||
int64_t* position_ids,
|
||||
int sequence_length,
|
||||
int batch_size) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (tid < batch_size * sequence_length) {
|
||||
int b = tid / sequence_length;
|
||||
int s = tid % sequence_length;
|
||||
if (s < total_seq_len_k[b]) {
|
||||
position_ids[tid] = s;
|
||||
} else {
|
||||
// padding
|
||||
position_ids[tid] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert total_seq_len_k (total key sequence length excluding paddings) to position_ids for Token Generation
|
||||
__global__ void PositionIdsToken(const int32_t* total_seq_len_k,
|
||||
int64_t* position_ids,
|
||||
int batch_size) {
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (tid < batch_size) {
|
||||
position_ids[tid] = total_seq_len_k[tid] - 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert total_seq_len_k (total key sequence length excluding paddings) to position_ids
|
||||
Status FillPositionIds(contrib::SparseAttentionParameters& parameters,
|
||||
const int32_t* total_seq_len_k,
|
||||
int64_t* position_ids,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int bs = batch_size * sequence_length;
|
||||
|
||||
int threads = max_threads_per_block;
|
||||
if (bs <= 64) {
|
||||
threads = 64;
|
||||
} else if (bs <= 128) {
|
||||
threads = 128;
|
||||
} else if (bs <= 256) {
|
||||
threads = 256;
|
||||
} else if (bs <= 512) {
|
||||
threads = 512;
|
||||
}
|
||||
const int blocks = (bs + threads - 1) / threads;
|
||||
|
||||
if (parameters.sequence_length == parameters.total_sequence_length) { // prompt
|
||||
PositionIdsPrompt<<<blocks, threads, 0, stream>>>(total_seq_len_k, position_ids, sequence_length, batch_size);
|
||||
} else {
|
||||
PositionIdsToken<<<blocks, threads, 0, stream>>>(total_seq_len_k, position_ids, batch_size);
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Concat new key and value (BSNH format) to kv buffer (BNSH format) in place.
|
||||
template <typename T>
|
||||
Status LaunchConcatKVInPlace(contrib::SparseAttentionParameters& parameters,
|
||||
SparseAttentionData<T>& data,
|
||||
const void* new_key,
|
||||
const void* new_value,
|
||||
bool is_new_kv_bnsh_format,
|
||||
cudaStream_t stream,
|
||||
const int max_threads_per_block) {
|
||||
constexpr bool is_past_kv_bnsh_format = true;
|
||||
return LaunchConcatKVInPlace(parameters.batch_size,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.max_sequence_length,
|
||||
nullptr,
|
||||
data.seqlens_k_total,
|
||||
parameters.sequence_length,
|
||||
reinterpret_cast<const T*>(new_key),
|
||||
reinterpret_cast<const T*>(new_value),
|
||||
data.present_key,
|
||||
data.present_value,
|
||||
is_past_kv_bnsh_format,
|
||||
is_new_kv_bnsh_format,
|
||||
stream,
|
||||
max_threads_per_block);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* ort_stream,
|
||||
contrib::SparseAttentionParameters& parameters,
|
||||
SparseAttentionData<T>& data) {
|
||||
cudaStream_t stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
|
||||
const int max_threads_per_block = device_prop.maxThreadsPerBlock;
|
||||
const int batch_size = parameters.batch_size;
|
||||
const int sequence_length = parameters.sequence_length;
|
||||
// const int present_sequence_length = parameters.max_sequence_length;
|
||||
const int num_heads = parameters.num_heads;
|
||||
const int kv_num_heads = parameters.kv_num_heads;
|
||||
const int head_size = parameters.head_size;
|
||||
|
||||
const void* query;
|
||||
const void* key;
|
||||
const void* value;
|
||||
|
||||
DUMP_TENSOR_INIT();
|
||||
|
||||
if (!parameters.is_packed_qkv) {
|
||||
static_assert(sizeof(T) == 2);
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(
|
||||
batch_size, sequence_length, num_heads, head_size,
|
||||
reinterpret_cast<const half*>(data.query), reinterpret_cast<half*>(data.transposed_q_buffer),
|
||||
stream, max_threads_per_block));
|
||||
query = reinterpret_cast<const void*>(data.transposed_q_buffer);
|
||||
key = reinterpret_cast<const void*>(data.key);
|
||||
value = reinterpret_cast<const void*>(data.value);
|
||||
} else {
|
||||
size_t q_size = static_cast<size_t>(batch_size * sequence_length * num_heads * head_size);
|
||||
size_t k_size = static_cast<size_t>(batch_size * sequence_length * kv_num_heads * head_size);
|
||||
auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
|
||||
auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
|
||||
auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
|
||||
Status status = LaunchUnpackQKV<T, LAYOUT_BNSH>(data.query, q, k, v, num_heads, kv_num_heads, head_size,
|
||||
sequence_length, batch_size, stream, max_threads_per_block);
|
||||
if (status != Status::OK()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
query = reinterpret_cast<const void*>(q);
|
||||
key = reinterpret_cast<const void*>(k);
|
||||
value = reinterpret_cast<const void*>(v);
|
||||
}
|
||||
|
||||
constexpr bool q_layout = LAYOUT_BNSH;
|
||||
bool kv_layout = parameters.is_packed_qkv ? LAYOUT_BNSH : LAYOUT_BSNH;
|
||||
|
||||
DUMP_TENSOR("query", reinterpret_cast<const T*>(query), batch_size, num_heads, sequence_length, head_size);
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
if (LAYOUT_BNSH == kv_layout) {
|
||||
DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, kv_num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, kv_num_heads, sequence_length, head_size);
|
||||
} else {
|
||||
DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, sequence_length, kv_num_heads, head_size);
|
||||
DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, sequence_length, kv_num_heads, head_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (parameters.do_rotary) {
|
||||
size_t bsh = static_cast<size_t>(parameters.batch_size * parameters.sequence_length * parameters.head_size);
|
||||
size_t q_size = bsh * static_cast<size_t>(parameters.num_heads);
|
||||
size_t k_size = bsh * static_cast<size_t>(parameters.kv_num_heads);
|
||||
auto q_buffer = reinterpret_cast<T*>(data.rotary_buffer);
|
||||
auto k_buffer = q_buffer + q_size;
|
||||
auto position_ids_buff = reinterpret_cast<int64_t*>(k_buffer + k_size);
|
||||
ORT_RETURN_IF_ERROR(FillPositionIds(parameters, data.seqlens_k_total, position_ids_buff, stream,
|
||||
max_threads_per_block));
|
||||
|
||||
DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length);
|
||||
|
||||
// Launch rotary embedding kernel. This requires separated Q, K and V
|
||||
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, q_buffer, reinterpret_cast<const T*>(query),
|
||||
position_ids_buff, data.cos_cache, data.sin_cache,
|
||||
parameters.batch_size, parameters.sequence_length,
|
||||
parameters.num_heads, parameters.head_size,
|
||||
parameters.rotary_dim, parameters.max_sequence_length,
|
||||
/*position_ids_format*/ 1, parameters.rotary_interleaved,
|
||||
max_threads_per_block, q_layout));
|
||||
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<T>(stream, k_buffer, reinterpret_cast<const T*>(key),
|
||||
position_ids_buff, data.cos_cache, data.sin_cache,
|
||||
parameters.batch_size, parameters.sequence_length,
|
||||
parameters.kv_num_heads, parameters.head_size,
|
||||
parameters.rotary_dim, parameters.max_sequence_length,
|
||||
/*position_ids_format*/ 1, parameters.rotary_interleaved,
|
||||
max_threads_per_block, kv_layout));
|
||||
query = reinterpret_cast<const void*>(q_buffer);
|
||||
key = reinterpret_cast<const void*>(k_buffer);
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR("query after rotary", reinterpret_cast<const T*>(query),
|
||||
batch_size, num_heads, sequence_length, head_size);
|
||||
if (LAYOUT_BNSH == kv_layout) {
|
||||
DUMP_TENSOR("key after rotary", reinterpret_cast<const T*>(key),
|
||||
batch_size, kv_num_heads, sequence_length, head_size);
|
||||
} else {
|
||||
DUMP_TENSOR("key after rotary", reinterpret_cast<const T*>(key),
|
||||
batch_size, sequence_length, kv_num_heads, head_size);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Concat new key and value to kv buffers (in BNSH format) in place
|
||||
ORT_ENFORCE(parameters.past_present_share_buffer);
|
||||
ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(
|
||||
parameters, data, key, value, kv_layout, stream, max_threads_per_block));
|
||||
|
||||
// TODO: only dump to total sequence length instead of max sequence length.
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR("key cache", data.present_key, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
|
||||
DUMP_TENSOR("value cache", data.present_value, batch_size, kv_num_heads, parameters.max_sequence_length, head_size);
|
||||
|
||||
DUMP_TENSOR("block_mask",
|
||||
data.kernel_layout.mask,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols);
|
||||
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
data.kernel_layout.csr_col_indices,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows,
|
||||
data.kernel_layout.num_cols);
|
||||
|
||||
DUMP_TENSOR("csr_row_indices",
|
||||
data.kernel_layout.csr_row_indices,
|
||||
data.kernel_layout.num_layout,
|
||||
data.kernel_layout.num_rows + 1);
|
||||
|
||||
printf(
|
||||
"batch_size=%d, sequence_length=%d, num_heads=%d, kv_num_heads=%d head_size=%d, "
|
||||
"total_sequence_length=%d, max_sequence_length=%d scale=%f block_size=%d "
|
||||
"row_stride=%d col_stride=%d num_layout=%d\n",
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size,
|
||||
data.kernel_layout.num_rows + 1,
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols,
|
||||
data.kernel_layout.num_layout);
|
||||
#endif
|
||||
|
||||
if (data.use_v2_kernel) {
|
||||
sparse_attention_v2::SparseAttentionParams params(
|
||||
ort_stream,
|
||||
data.output,
|
||||
reinterpret_cast<const void*>(query),
|
||||
reinterpret_cast<const void*>(data.present_key),
|
||||
reinterpret_cast<const void*>(data.present_value),
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // skip past_seq_len in row indices
|
||||
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
|
||||
data.kernel_layout.num_rows + 1, // stride per head in row indices
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
|
||||
data.kernel_layout.num_layout,
|
||||
data.active_q_blocks,
|
||||
data.q_batch_starts,
|
||||
data.q_batch_ends,
|
||||
data.k_batch_starts,
|
||||
data.k_batch_ends,
|
||||
data.q_batch_ids,
|
||||
data.q_start_sids);
|
||||
|
||||
if constexpr (std::is_same<T, BFloat16>::value) {
|
||||
ORT_RETURN_IF_ERROR(sparse_attention_v2::run_sparse_attention_bf16(params));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(sparse_attention_v2::run_sparse_attention_fp16(params));
|
||||
}
|
||||
} else {
|
||||
sparse_attention_v1::SparseAttentionParams params(
|
||||
ort_stream,
|
||||
data.output,
|
||||
reinterpret_cast<const void*>(query),
|
||||
reinterpret_cast<const void*>(data.present_key),
|
||||
reinterpret_cast<const void*>(data.present_value),
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.kv_num_heads,
|
||||
parameters.head_size,
|
||||
parameters.total_sequence_length,
|
||||
parameters.max_sequence_length,
|
||||
parameters.scale,
|
||||
data.kernel_layout.block_size, // kernel_block_size
|
||||
data.kernel_layout.csr_row_indices, // (num_layout, num_rows + 1)
|
||||
data.kernel_layout.csr_col_indices, // (num_layout, num_rows, num_cols)
|
||||
data.kernel_layout.num_rows + 1, // stride per head in row indices
|
||||
data.kernel_layout.num_rows * data.kernel_layout.num_cols, // stride per head in col indices
|
||||
data.kernel_layout.num_layout);
|
||||
|
||||
if constexpr (std::is_same<T, BFloat16>::value) {
|
||||
ORT_RETURN_IF_ERROR(sparse_attention_v1::run_sparse_attention_bf16(params));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(sparse_attention_v1::run_sparse_attention_fp16(params));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status QkvToContext<half>(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* ort_stream,
|
||||
contrib::SparseAttentionParameters& parameters,
|
||||
SparseAttentionData<half>& data);
|
||||
|
||||
template Status QkvToContext<BFloat16>(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* ort_stream,
|
||||
contrib::SparseAttentionParameters& parameters,
|
||||
SparseAttentionData<BFloat16>& data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <cuda_fp16.h>
|
||||
#include <cublas_v2.h>
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/providers/cuda/tunable/cuda_tunable.h"
|
||||
|
||||
using onnxruntime::cuda::tunable::CudaTuningContext;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
struct BlockLayout {
|
||||
const int32_t* mask; // shape (num_layout, num_rows, num_cols), where num_rows = num_cols = max_seq_len / block_size.
|
||||
int num_layout;
|
||||
int block_size; // kernel block size, which is <= sparse_block_size
|
||||
|
||||
const int* csr_col_indices;
|
||||
const int* csr_row_indices;
|
||||
int num_rows;
|
||||
int num_cols;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct SparseAttentionData {
|
||||
// Input Tensors
|
||||
const T* query = nullptr;
|
||||
const T* key = nullptr;
|
||||
const T* value = nullptr;
|
||||
const T* past_key = nullptr;
|
||||
const T* past_value = nullptr;
|
||||
const T* cos_cache = nullptr;
|
||||
const T* sin_cache = nullptr;
|
||||
|
||||
const int32_t* block_mask = nullptr;
|
||||
const int32_t* seqlens_k_total = nullptr;
|
||||
|
||||
// Temporary buffers
|
||||
T* transposed_q_buffer = nullptr;
|
||||
T* rotary_buffer = nullptr;
|
||||
T* unpacked_qkv_buffer = nullptr;
|
||||
|
||||
// This is sparse layout used in kernel.
|
||||
BlockLayout kernel_layout;
|
||||
|
||||
// Output Tensors
|
||||
T* output = nullptr;
|
||||
T* present_key = nullptr;
|
||||
T* present_value = nullptr;
|
||||
|
||||
// Data for sparse attention v2 kernel.
|
||||
bool use_v2_kernel = false;
|
||||
int* q_batch_starts = nullptr; // shape (batch_size)
|
||||
int* q_batch_ends = nullptr; // shape (batch_size)
|
||||
int* k_batch_starts = nullptr; // shape (batch_size)
|
||||
int* k_batch_ends = nullptr; // shape (batch_size)
|
||||
int* q_batch_ids = nullptr; // shape (G)
|
||||
int* q_start_sids = nullptr; // shape (G)
|
||||
int active_q_blocks = 0; // G: number of blocks in q that are not masked out
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const cudaDeviceProp& device_prop,
|
||||
Stream* ort_stream,
|
||||
contrib::SparseAttentionParameters& parameters,
|
||||
SparseAttentionData<T>& data);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,126 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
|
||||
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
|
||||
# python3 -m pip install triton==2.3.0
|
||||
# python3 compile_sparse_attention.py | sh
|
||||
#
|
||||
# Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.
|
||||
|
||||
import math
|
||||
from itertools import product
|
||||
|
||||
|
||||
def generate_triton_compile_shell_script(dtype="fp16"):
|
||||
assert dtype in ["fp16", "bf16"]
|
||||
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
|
||||
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
|
||||
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")
|
||||
|
||||
# Modify the compile.py to use custom template file template_h.txt and template_c.txt in current directory.
|
||||
# Also pass block_m to the template.
|
||||
print(
|
||||
'python -c "'
|
||||
"import sys;lines=sys.stdin.read();"
|
||||
"lines=lines.replace('template_path = Path(__file__).parent / f\\\"compile.{ext}\\\"','template_path = f\\\"compile_template_kernel_{ext}.txt\\\"');"
|
||||
'lines=lines.replace(\'\\"_placeholder\\": \\"\\",\', \'\\"_placeholder\\": \\"\\",\\n \\"block_m\\": list(constants.values())[0],\');'
|
||||
'print(lines)"'
|
||||
"< ${TRITON_ROOT}/triton/tools/compile.py > compile.py"
|
||||
)
|
||||
|
||||
out_dir = f"trition_cubin_{dtype}"
|
||||
print(f"rm -rf {out_dir}")
|
||||
print(f"mkdir -p {out_dir}")
|
||||
|
||||
# Note that block_n * num_block_d is the head_size. We support head_size = 128 for now.
|
||||
block_n_values = [64]
|
||||
block_d_values = [64]
|
||||
num_block_d_values = [2]
|
||||
even_m_values = [True, False]
|
||||
even_n_values = [True, False]
|
||||
|
||||
# Use triton compiler to compile the kernel of different combinations of constant parameters.
|
||||
for block_n, block_d, num_blocks_d, even_m, even_n in product(
|
||||
block_n_values, block_d_values, num_block_d_values, even_m_values, even_n_values
|
||||
):
|
||||
block_m_values = [16, block_n] if block_n != 16 else [block_n]
|
||||
for block_m in block_m_values:
|
||||
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
|
||||
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
|
||||
prefix = "python compile.py sparse_attention_triton.py"
|
||||
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm${{SM}}"
|
||||
name = f"sparse_attention_{dtype}_sm${{SM}}"
|
||||
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
|
||||
num_stages = 2
|
||||
# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
|
||||
print(
|
||||
f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
|
||||
f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
|
||||
)
|
||||
|
||||
# Generate the dispatcher.
|
||||
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm${{SM}}"
|
||||
print(f"cd {out_dir}")
|
||||
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v1_*.h -o {dispatcher}")
|
||||
print("rm *.h")
|
||||
|
||||
# Remove signature hash in code.
|
||||
suffix = "0d1d2d3d4d5d678910d11d12d13d14d15d16d17d18d19d20d21d222324"
|
||||
print(f"for file in *.c; do sed -i 's/_{suffix}//g' \"$file\"; done")
|
||||
|
||||
# Recover signature hash in kernel name that is removed in previous step. Kernel name shall not be changed.
|
||||
print(
|
||||
f"for file in *.c; do sed -i 's/block_sparse_attention_kernel/block_sparse_attention_kernel_{suffix}/g' \"$file\"; done"
|
||||
)
|
||||
|
||||
# Remove signature hash from filename since we use same signature for all kernels except constants.
|
||||
# and we have constants in filename so that we can distinguish files without the hash.
|
||||
print('for file in sparse_attention_v1_*.c; do mv -- "$file" "$(echo $file | cut -f 1 -d \'.\').c"; done')
|
||||
|
||||
# Change function parameters and return type. If you change the kernel interface, you will need to modify this part.
|
||||
source1 = "CUstream stream, CUdeviceptr out, CUdeviceptr Q, CUdeviceptr K, CUdeviceptr V, CUdeviceptr layout_csr_row_indices, CUdeviceptr layout_csr_col_indices, int32_t layout_csr_row_stride_h, int32_t layout_csr_col_stride_h, int32_t num_layout, float softmax_scale, int32_t stride_qb, int32_t stride_qh, int32_t stride_qm, int32_t stride_kb, int32_t stride_kh, int32_t stride_kn, int32_t stride_vb, int32_t stride_vh, int32_t stride_vn, int32_t stride_ob, int32_t stride_oh, int32_t stride_om, int32_t num_heads, int32_t num_kv_heads, int32_t total_seq_len"
|
||||
target1 = "SparseAttentionParams& params"
|
||||
source2 = "stream, out, Q, K, V, layout_csr_row_indices, layout_csr_col_indices, layout_csr_row_stride_h, layout_csr_col_stride_h, num_layout, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om, num_heads, num_kv_heads, total_seq_len"
|
||||
target2 = "params"
|
||||
print(
|
||||
f"python -c \"import sys;lines=sys.stdin.read();lines=lines.replace('{source1}', '{target1}');"
|
||||
f'lines=lines.replace(\'{source2}\', \'{target2}\');print(lines)" < "{dispatcher}.c" > "{dispatcher}.h"'
|
||||
)
|
||||
print(f"sed -i 's/CUresult/Status/g' \"{dispatcher}.h\"")
|
||||
|
||||
# Remove parameter checking since we moved the validation logic to SparseAttentionParams
|
||||
print(f"sed -i '/if /d' \"{dispatcher}.h\"")
|
||||
print(f"sed -i '/CUDA_ERROR_INVALID_VALUE/d' \"{dispatcher}.h\"")
|
||||
print(f"sed -i '/#include/d' \"{dispatcher}.h\"")
|
||||
|
||||
print(f"rm {dispatcher}.c")
|
||||
|
||||
# Use a template file to add namespace and includes to the dispatcher file.
|
||||
print(
|
||||
'python -c "'
|
||||
"from pathlib import Path;"
|
||||
"template=Path('../compile_template_dispatcher_h.txt').read_text();"
|
||||
f"code=Path('{dispatcher}.h').read_text();"
|
||||
"text=template.replace('PLACEHOLDER', code); print(text)\" "
|
||||
f"> ../{dispatcher}.h"
|
||||
)
|
||||
# rename *.c to *.cc
|
||||
print('for file in *.c; do mv -- "$file" "${file%.c}.cc"; done')
|
||||
|
||||
# Move kernel files to parent directory. This might overwrite existing files in repository.
|
||||
print("mv sparse_attention_v1_* ../")
|
||||
|
||||
# Clean up
|
||||
print("cd ..")
|
||||
print("rm compile.py")
|
||||
print(f"rm -rf {out_dir}")
|
||||
|
||||
print("echo Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
generate_triton_compile_shell_script(dtype)
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention.py using triton AoT compiler
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
PLACEHOLDER
|
||||
|
||||
} // namespace sparse_attention_v1
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {{
|
||||
namespace contrib {{
|
||||
namespace cuda {{
|
||||
namespace sparse_attention_v1 {{
|
||||
|
||||
// This file is generated by compile_sparse_attention.py using triton AoT compiler
|
||||
// {kernel_docstring}
|
||||
// cubin_size = {bin_size}
|
||||
// shared_mem_bytes = {shared}
|
||||
// threads_per_cta = {num_warps} * 32
|
||||
// kernel_name = {triton_kernel_name}
|
||||
|
||||
unsigned char {kernel_name}_cubin[] = {{ {bin_data} }};
|
||||
|
||||
CUmodule {kernel_name}_mod = NULL;
|
||||
CUfunction {kernel_name}_func = NULL;
|
||||
|
||||
void unload_{kernel_name}(void) {{
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuModuleUnload({kernel_name}_mod), driver);
|
||||
}}
|
||||
|
||||
void load_{kernel_name}(void) {{
|
||||
void *bin = (void *)&{kernel_name}_cubin;
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuModuleLoadData(&{kernel_name}_mod, bin), driver);
|
||||
CU_CHECK(driver->cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"), driver);
|
||||
constexpr int shared = {shared};
|
||||
if constexpr (shared > 49152) {{
|
||||
SetKernelSharedMemory(driver, {kernel_name}_func);
|
||||
}}
|
||||
}}
|
||||
|
||||
Status {kernel_name}(SparseAttentionParams& params) {{
|
||||
return params.LaunchKernel({kernel_name}_func, {block_m}, {num_warps} * 32, {shared});
|
||||
}}
|
||||
|
||||
}} // namespace sparse_attention_v1
|
||||
}} // namespace cuda
|
||||
}} // namespace contrib
|
||||
}} // namespace onnxruntime
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention.py using triton AoT compiler
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {{
|
||||
namespace contrib {{
|
||||
namespace cuda {{
|
||||
namespace sparse_attention_v1 {{
|
||||
|
||||
void unload_{kernel_name}(void);
|
||||
|
||||
void load_{kernel_name}(void);
|
||||
|
||||
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
|
||||
Status{_placeholder} {kernel_name}(SparseAttentionParams& params);
|
||||
|
||||
}} // namespace sparse_attention_v1
|
||||
}} // namespace cuda
|
||||
}} // namespace contrib
|
||||
}} // namespace onnxruntime
|
|
@ -0,0 +1,203 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cudaDriverWrapper.h"
|
||||
|
||||
#define CU_CHECK(expr, driver) cuErrCheck(expr, *driver)
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
struct SparseAttentionParams {
|
||||
onnxruntime::Stream* ort_stream;
|
||||
void* output;
|
||||
const void* q;
|
||||
const void* k;
|
||||
const void* v;
|
||||
|
||||
int batch_size;
|
||||
int num_heads;
|
||||
int kv_num_heads;
|
||||
int head_size;
|
||||
|
||||
int sequence_length;
|
||||
int past_sequence_length;
|
||||
int total_sequence_length;
|
||||
int max_sequence_length;
|
||||
|
||||
float scale;
|
||||
|
||||
int kernel_block_size;
|
||||
|
||||
// CSR format of block mask
|
||||
const int* layout_csr_row_indices;
|
||||
const int* layout_csr_col_indices;
|
||||
int layout_row_stride_h;
|
||||
int layout_col_stride_h;
|
||||
int num_layout;
|
||||
|
||||
// strides
|
||||
int stride_qb;
|
||||
int stride_qh;
|
||||
int stride_qm;
|
||||
int stride_kb;
|
||||
int stride_kh;
|
||||
int stride_kn;
|
||||
int stride_vb;
|
||||
int stride_vh;
|
||||
int stride_vn;
|
||||
int stride_ob;
|
||||
int stride_oh;
|
||||
int stride_om;
|
||||
|
||||
SparseAttentionParams(
|
||||
onnxruntime::Stream* ort_stream,
|
||||
void* output,
|
||||
const void* q,
|
||||
const void* k,
|
||||
const void* v,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int total_sequence_length,
|
||||
int max_sequence_length,
|
||||
float scale,
|
||||
int kernel_block_size,
|
||||
const int* layout_csr_row_indices,
|
||||
const int* layout_csr_col_indices,
|
||||
int layout_row_stride_h,
|
||||
int layout_col_stride_h,
|
||||
int num_layout) {
|
||||
this->ort_stream = ort_stream;
|
||||
this->output = output;
|
||||
this->q = q;
|
||||
this->k = k;
|
||||
this->v = v;
|
||||
this->batch_size = batch_size;
|
||||
this->sequence_length = sequence_length;
|
||||
this->num_heads = num_heads;
|
||||
this->kv_num_heads = kv_num_heads;
|
||||
this->head_size = head_size;
|
||||
this->past_sequence_length = total_sequence_length - sequence_length;
|
||||
this->total_sequence_length = total_sequence_length;
|
||||
this->max_sequence_length = max_sequence_length;
|
||||
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
|
||||
this->kernel_block_size = kernel_block_size;
|
||||
this->layout_csr_row_indices = layout_csr_row_indices;
|
||||
this->layout_csr_col_indices = layout_csr_col_indices;
|
||||
this->layout_row_stride_h = layout_row_stride_h;
|
||||
this->layout_col_stride_h = layout_col_stride_h;
|
||||
this->num_layout = num_layout;
|
||||
|
||||
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
|
||||
this->stride_qh = this->sequence_length * this->head_size;
|
||||
this->stride_qm = this->head_size;
|
||||
|
||||
// When kv buffer has max sequence length, stride should match max sequence length.
|
||||
int kv_buffer_sequence_length = max_sequence_length;
|
||||
|
||||
// KV cache is in BNSH format
|
||||
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kn = this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vn = this->head_size;
|
||||
|
||||
// Output is BSNH format
|
||||
this->stride_ob = this->sequence_length * this->num_heads * this->head_size;
|
||||
this->stride_oh = this->head_size;
|
||||
this->stride_om = this->num_heads * this->head_size;
|
||||
}
|
||||
|
||||
Status LaunchKernel(CUfunction f, int block_m, int threads_per_block, unsigned int sharedMemBytes) {
|
||||
ORT_ENFORCE(f != nullptr, "Kernel shall be loaded before calling LaunchKernel.");
|
||||
|
||||
if (!Valididate()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseAttentionParams is not valid.");
|
||||
}
|
||||
|
||||
void* args[26] = {
|
||||
&output, &q, &k, &v,
|
||||
&layout_csr_row_indices, &layout_csr_col_indices, &layout_row_stride_h, &layout_col_stride_h, &num_layout, &scale,
|
||||
&stride_qb, &stride_qh, &stride_qm, &stride_kb, &stride_kh, &stride_kn,
|
||||
&stride_vb, &stride_vh, &stride_vn, &stride_ob, &stride_oh, &stride_om,
|
||||
&num_heads, &kv_num_heads, &total_sequence_length, &past_sequence_length};
|
||||
|
||||
unsigned int gridDimX = (sequence_length + block_m - 1) / block_m;
|
||||
unsigned int gridDimY = batch_size * num_heads;
|
||||
constexpr unsigned int gridDimZ = 1;
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
layout_csr_col_indices,
|
||||
num_layout,
|
||||
layout_col_stride_h);
|
||||
|
||||
DUMP_TENSOR("csr_row_indices",
|
||||
layout_csr_row_indices,
|
||||
num_layout,
|
||||
layout_row_stride_h);
|
||||
printf(
|
||||
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
|
||||
"stride_qb=%d, stride_qh=%d, stride_qm=%d, stride_kb=%d, stride_kh=%d, stride_kn=%d,\n"
|
||||
"stride_vb=%d, stride_vh=%d, stride_vn=%d, stride_ob=%d, stride_oh=%d, stride_om=%d,\n"
|
||||
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
|
||||
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n",
|
||||
layout_row_stride_h, layout_col_stride_h, num_layout, scale,
|
||||
stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn,
|
||||
stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om,
|
||||
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
|
||||
output, q, k, v, layout_csr_row_indices, layout_csr_col_indices);
|
||||
|
||||
printf("block_m=%d gridDimX=%d gridDimY=%d threads_per_block=%d sharedMemBytes=%d\n",
|
||||
block_m, gridDimX, gridDimY, threads_per_block, sharedMemBytes);
|
||||
#endif
|
||||
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, threads_per_block, 1, 1, sharedMemBytes,
|
||||
static_cast<CUstream>(this->ort_stream->GetHandle()),
|
||||
args, NULL),
|
||||
driver);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Valididate() {
|
||||
return (reinterpret_cast<size_t>(output) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(k) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(v) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(layout_csr_col_indices) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(layout_csr_row_indices) % 16 == 0 &&
|
||||
this->head_size % 16 == 0 &&
|
||||
this->past_sequence_length == 0); // This kernel is for prompt only.
|
||||
}
|
||||
};
|
||||
} // namespace sparse_attention_v1
|
||||
|
||||
inline void SetKernelSharedMemory(const CUDADriverWrapper* driver, CUfunction func) {
|
||||
int device = 0;
|
||||
CUDA_CALL_THROW(cudaGetDevice(&device));
|
||||
|
||||
int shared_optin = 0;
|
||||
CU_CHECK(driver->cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device), driver);
|
||||
if (shared_optin > 49152) {
|
||||
CU_CHECK(driver->cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED), driver);
|
||||
CU_CHECK(driver->cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin), driver);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,216 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention.py using triton AoT compiler
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
Status sparse_attention_bf16_sm80_ba65ff9c(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_ba65ff9c(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
void load_sparse_attention_bf16_sm80_ba65ff9c();
|
||||
void load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
|
||||
load_sparse_attention_bf16_sm80_ba65ff9c();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
void unload_sparse_attention_bf16_sm80_ba65ff9c();
|
||||
void unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_ba65ff9c();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
Status sparse_attention_bf16_sm80_f951a16d(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_f951a16d(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
void load_sparse_attention_bf16_sm80_f951a16d();
|
||||
void load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
|
||||
load_sparse_attention_bf16_sm80_f951a16d();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
void unload_sparse_attention_bf16_sm80_f951a16d();
|
||||
void unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_f951a16d();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
Status sparse_attention_bf16_sm80_646fefc8(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_646fefc8(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
void load_sparse_attention_bf16_sm80_646fefc8();
|
||||
void load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
|
||||
load_sparse_attention_bf16_sm80_646fefc8();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
void unload_sparse_attention_bf16_sm80_646fefc8();
|
||||
void unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_646fefc8();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
Status sparse_attention_bf16_sm80_21cac990(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_21cac990(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
void load_sparse_attention_bf16_sm80_21cac990();
|
||||
void load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
|
||||
load_sparse_attention_bf16_sm80_21cac990();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
void unload_sparse_attention_bf16_sm80_21cac990();
|
||||
void unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_21cac990();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
Status sparse_attention_bf16_sm80_31acb592(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_31acb592(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
void load_sparse_attention_bf16_sm80_31acb592();
|
||||
void load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
|
||||
load_sparse_attention_bf16_sm80_31acb592();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
void unload_sparse_attention_bf16_sm80_31acb592();
|
||||
void unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_31acb592();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
Status sparse_attention_bf16_sm80_d55ab166(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_d55ab166(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
void load_sparse_attention_bf16_sm80_d55ab166();
|
||||
void load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
|
||||
load_sparse_attention_bf16_sm80_d55ab166();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
void unload_sparse_attention_bf16_sm80_d55ab166();
|
||||
void unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_d55ab166();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
Status sparse_attention_bf16_sm80_b0560d11(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_b0560d11(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
void load_sparse_attention_bf16_sm80_b0560d11();
|
||||
void load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
|
||||
load_sparse_attention_bf16_sm80_b0560d11();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
void unload_sparse_attention_bf16_sm80_b0560d11();
|
||||
void unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_b0560d11();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
Status sparse_attention_bf16_sm80_c777f3f5(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80_c777f3f5(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
void load_sparse_attention_bf16_sm80_c777f3f5();
|
||||
void load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {
|
||||
load_sparse_attention_bf16_sm80_c777f3f5();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
void unload_sparse_attention_bf16_sm80_c777f3f5();
|
||||
void unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_bf16_sm80_c777f3f5();
|
||||
}
|
||||
|
||||
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
|
||||
kernel_func_t sparse_attention_bf16_sm80_kernels[] = {
|
||||
sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2,
|
||||
sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2,
|
||||
sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2,
|
||||
sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2,
|
||||
sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2,
|
||||
sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2,
|
||||
sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2,
|
||||
sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2,
|
||||
};
|
||||
|
||||
int sparse_attention_bf16_sm80_get_num_algos(void) {
|
||||
return (int)sizeof(sparse_attention_bf16_sm80_kernels);
|
||||
}
|
||||
|
||||
Status sparse_attention_bf16_sm80(SparseAttentionParams& params, int algo_id) {
|
||||
assert(algo_id < (int)sizeof(sparse_attention_bf16_sm80_kernels));
|
||||
return sparse_attention_bf16_sm80_kernels[algo_id](params);
|
||||
}
|
||||
|
||||
void load_sparse_attention_bf16_sm80(void) {
|
||||
load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
|
||||
load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
|
||||
load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
|
||||
load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
|
||||
load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
|
||||
load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
|
||||
load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
|
||||
load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
|
||||
}
|
||||
|
||||
void unload_sparse_attention_bf16_sm80(void) {
|
||||
unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
|
||||
unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
|
||||
unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
|
||||
unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
|
||||
unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
|
||||
unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
|
||||
unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
|
||||
unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
|
||||
}
|
||||
|
||||
Status sparse_attention_bf16_sm80_default(SparseAttentionParams& params) {
|
||||
return sparse_attention_bf16_sm80(params, 0);
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v1
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,216 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention.py using triton AoT compiler
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
Status sparse_attention_fp16_sm80_3d26f9b3(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_3d26f9b3(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
void load_sparse_attention_fp16_sm80_3d26f9b3();
|
||||
void load_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2() {
|
||||
load_sparse_attention_fp16_sm80_3d26f9b3();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2
|
||||
void unload_sparse_attention_fp16_sm80_3d26f9b3();
|
||||
void unload_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_3d26f9b3();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
Status sparse_attention_fp16_sm80_bfb8dd1f(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_bfb8dd1f(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
void load_sparse_attention_fp16_sm80_bfb8dd1f();
|
||||
void load_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2() {
|
||||
load_sparse_attention_fp16_sm80_bfb8dd1f();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2
|
||||
void unload_sparse_attention_fp16_sm80_bfb8dd1f();
|
||||
void unload_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_bfb8dd1f();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
Status sparse_attention_fp16_sm80_5fdf5cf7(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_5fdf5cf7(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
void load_sparse_attention_fp16_sm80_5fdf5cf7();
|
||||
void load_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2() {
|
||||
load_sparse_attention_fp16_sm80_5fdf5cf7();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2
|
||||
void unload_sparse_attention_fp16_sm80_5fdf5cf7();
|
||||
void unload_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_5fdf5cf7();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
Status sparse_attention_fp16_sm80_35b9b6eb(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_35b9b6eb(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
void load_sparse_attention_fp16_sm80_35b9b6eb();
|
||||
void load_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2() {
|
||||
load_sparse_attention_fp16_sm80_35b9b6eb();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2
|
||||
void unload_sparse_attention_fp16_sm80_35b9b6eb();
|
||||
void unload_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_35b9b6eb();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
Status sparse_attention_fp16_sm80_bef12fb0(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_bef12fb0(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
void load_sparse_attention_fp16_sm80_bef12fb0();
|
||||
void load_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2() {
|
||||
load_sparse_attention_fp16_sm80_bef12fb0();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2
|
||||
void unload_sparse_attention_fp16_sm80_bef12fb0();
|
||||
void unload_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_bef12fb0();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
Status sparse_attention_fp16_sm80_30cd91a1(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_30cd91a1(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
void load_sparse_attention_fp16_sm80_30cd91a1();
|
||||
void load_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2() {
|
||||
load_sparse_attention_fp16_sm80_30cd91a1();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2
|
||||
void unload_sparse_attention_fp16_sm80_30cd91a1();
|
||||
void unload_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_30cd91a1();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
Status sparse_attention_fp16_sm80_72b7bd79(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_72b7bd79(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
void load_sparse_attention_fp16_sm80_72b7bd79();
|
||||
void load_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2() {
|
||||
load_sparse_attention_fp16_sm80_72b7bd79();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2
|
||||
void unload_sparse_attention_fp16_sm80_72b7bd79();
|
||||
void unload_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_72b7bd79();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
Status sparse_attention_fp16_sm80_d7f3a63f(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80_d7f3a63f(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
void load_sparse_attention_fp16_sm80_d7f3a63f();
|
||||
void load_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2() {
|
||||
load_sparse_attention_fp16_sm80_d7f3a63f();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2
|
||||
void unload_sparse_attention_fp16_sm80_d7f3a63f();
|
||||
void unload_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2() {
|
||||
unload_sparse_attention_fp16_sm80_d7f3a63f();
|
||||
}
|
||||
|
||||
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
|
||||
kernel_func_t sparse_attention_fp16_sm80_kernels[] = {
|
||||
sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2,
|
||||
sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2,
|
||||
sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2,
|
||||
sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2,
|
||||
sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2,
|
||||
sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2,
|
||||
sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2,
|
||||
sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2,
|
||||
};
|
||||
|
||||
int sparse_attention_fp16_sm80_get_num_algos(void) {
|
||||
return (int)sizeof(sparse_attention_fp16_sm80_kernels);
|
||||
}
|
||||
|
||||
Status sparse_attention_fp16_sm80(SparseAttentionParams& params, int algo_id) {
|
||||
assert(algo_id < (int)sizeof(sparse_attention_fp16_sm80_kernels));
|
||||
return sparse_attention_fp16_sm80_kernels[algo_id](params);
|
||||
}
|
||||
|
||||
void load_sparse_attention_fp16_sm80(void) {
|
||||
load_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2();
|
||||
load_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2();
|
||||
load_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2();
|
||||
load_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2();
|
||||
load_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2();
|
||||
load_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2();
|
||||
load_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2();
|
||||
load_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2();
|
||||
}
|
||||
|
||||
void unload_sparse_attention_fp16_sm80(void) {
|
||||
unload_sparse_attention_fp16_sm80_16x0x64x0x64x2_warps1xstages2();
|
||||
unload_sparse_attention_fp16_sm80_16x0x64x1x64x2_warps1xstages2();
|
||||
unload_sparse_attention_fp16_sm80_16x1x64x0x64x2_warps1xstages2();
|
||||
unload_sparse_attention_fp16_sm80_16x1x64x1x64x2_warps1xstages2();
|
||||
unload_sparse_attention_fp16_sm80_64x0x64x0x64x2_warps4xstages2();
|
||||
unload_sparse_attention_fp16_sm80_64x0x64x1x64x2_warps4xstages2();
|
||||
unload_sparse_attention_fp16_sm80_64x1x64x0x64x2_warps4xstages2();
|
||||
unload_sparse_attention_fp16_sm80_64x1x64x1x64x2_warps4xstages2();
|
||||
}
|
||||
|
||||
Status sparse_attention_fp16_sm80_default(SparseAttentionParams& params) {
|
||||
return sparse_attention_fp16_sm80(params, 0);
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v1
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,166 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# This kernel is for prompt only and assume that past sequence length is 0. It only supports right padding.
|
||||
@triton.jit
|
||||
def block_sparse_attention_kernel(
|
||||
out, # output [B, H, M, D]. Note that B is batch_size, H is num_heads, M is q_seq_len, and D is head_size
|
||||
Q, # query [B, H, M, D]
|
||||
K, # key [B, H_kv, N, D]. Note that N is max_seq_len for kv cache, H_kv is num_kv_heads
|
||||
V, # value [B, H_kv, N, D]
|
||||
layout_csr_row_indices, # block mask CSR format. Shape is [L, num_rows + 1] where num_rows = max_seq_len / BLOCK_M
|
||||
layout_csr_col_indices, # block mask CSR format. Shape is [L, num_rows * num_cols] where num_cols = max_seq_len / BLOCK_N
|
||||
layout_csr_row_stride_h, # stride per head for csr_row_indices, i.e. num_rows + 1
|
||||
layout_csr_col_stride_h, # stride per head for csr_col_indices, i.e. num_rows * num_cols
|
||||
num_layout, # number of sparse layout (L)
|
||||
softmax_scale,
|
||||
stride_qb,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_kb,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_vb,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_ob,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
total_seq_len, # Total sequence length including past sequence length and query sequence length.
|
||||
BLOCK_M: tl.constexpr, # block size for q_seq_len
|
||||
EVEN_M: tl.constexpr, # whether q_seq_len % BLOCK_M == 0
|
||||
BLOCK_N: tl.constexpr, # block size for k_seq_len
|
||||
EVEN_N: tl.constexpr, # whether k_seq_len % BLOCK_N == 0
|
||||
BLOCK_D: tl.constexpr, # block size for D
|
||||
NUM_D_BLOCKS: tl.constexpr, # number of data blocks = D / BLOCK_D
|
||||
):
|
||||
tl.static_print(f"{BLOCK_M=} {BLOCK_N=} {BLOCK_D=} {EVEN_M=} {EVEN_N=} {NUM_D_BLOCKS=}")
|
||||
|
||||
# Past sequence length is 0 since this kernel is for prompt only.
|
||||
q_seq_len = total_seq_len
|
||||
|
||||
# Grid is [CDiv(q_seq_len, BLOCK_M), batch_size * num_heads]
|
||||
start_m = tl.program_id(0)
|
||||
off_bh = tl.program_id(1)
|
||||
|
||||
off_h = off_bh % num_heads
|
||||
off_b = off_bh // num_heads
|
||||
|
||||
# For group query attention, map the query head index to the corresponding one for key and value.
|
||||
head_groups = num_heads // num_kv_heads
|
||||
off_h_kv = off_h // head_groups
|
||||
|
||||
Q += off_b * stride_qb + off_h * stride_qh
|
||||
K += off_b * stride_kb + off_h_kv * stride_kh
|
||||
V += off_b * stride_vb + off_h_kv * stride_vh
|
||||
|
||||
# Initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] # [BLOCK_M, BLOCK_D]
|
||||
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] # [BLOCK_D, BLOCK_N]
|
||||
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] # [BLOCK_N, BLOCK_D]
|
||||
|
||||
# Initialize pointers to query, key, value
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
|
||||
# Initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
acc2 = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
# Load q: it will stay in SRAM throughout
|
||||
if EVEN_M:
|
||||
q = tl.load(q_ptrs)
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
q2 = tl.load(q_ptrs + BLOCK_D)
|
||||
else:
|
||||
q = tl.load(q_ptrs, mask=offs_m[:, None] < q_seq_len)
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
q2 = tl.load(q_ptrs + BLOCK_D, mask=offs_m[:, None] < q_seq_len)
|
||||
|
||||
layout_h = off_h % num_layout
|
||||
|
||||
# This assumes that past sequence length is 0, otherwise need + (past_seq_len + 1) // BLOCK_M.
|
||||
layout_ptr = layout_csr_row_indices + layout_h * layout_csr_row_stride_h + start_m
|
||||
start_l = tl.load(layout_ptr).to(tl.int32)
|
||||
end_l = tl.load(layout_ptr + 1).to(tl.int32)
|
||||
|
||||
# Loop over k, v and update accumulator
|
||||
for col_idx_idx in range(start_l, end_l):
|
||||
col_idx = tl.load(layout_csr_col_indices + layout_h * layout_csr_col_stride_h + col_idx_idx).to(tl.int32)
|
||||
start_n = col_idx * BLOCK_N
|
||||
# -- compute qk ----
|
||||
if EVEN_N:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < total_seq_len)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
if EVEN_N:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D, mask=offs_n[None, :] + start_n < total_seq_len)
|
||||
qk += tl.dot(q2, k)
|
||||
|
||||
qk *= softmax_scale
|
||||
|
||||
# This assumes that past sequence length is 0, otherwise need offs_m[:, None] + past_seq_len >= ...
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
acc2 = acc2 * acc_scale[:, None]
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
# update acc
|
||||
if EVEN_N:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < total_seq_len)
|
||||
acc += tl.dot(p, v)
|
||||
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
if EVEN_N:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D, mask=offs_n[:, None] + start_n < total_seq_len)
|
||||
acc2 += tl.dot(p, v)
|
||||
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
off_o = off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]
|
||||
out_ptrs = out + off_o
|
||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < q_seq_len)
|
||||
if NUM_D_BLOCKS >= 2:
|
||||
tl.store(out_ptrs + BLOCK_D, acc2, mask=offs_m[:, None] < q_seq_len)
|
|
@ -0,0 +1,99 @@
|
|||
#include <cuda.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_v1_api.h"
|
||||
|
||||
// Dispatcher files are generated.
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_fp16_sm80.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm80.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
int get_algo_id(SparseAttentionParams& params) {
|
||||
int block_n = params.kernel_block_size;
|
||||
int block_m = block_n;
|
||||
bool even_m = (params.sequence_length % block_m == 0);
|
||||
bool even_n = (params.total_sequence_length % block_n == 0);
|
||||
|
||||
if (params.head_size == 128) {
|
||||
if (block_m == 16) {
|
||||
if (!even_m) {
|
||||
return even_n ? 1 : 0;
|
||||
} else {
|
||||
return even_n ? 3 : 2;
|
||||
}
|
||||
} else if (block_m == 64) {
|
||||
if (!even_m) {
|
||||
return even_n ? 5 : 4;
|
||||
} else {
|
||||
return even_n ? 7 : 6;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool is_supported_sparse_attention(const cudaDeviceProp& dprops) {
|
||||
return dprops.major == 8;
|
||||
}
|
||||
|
||||
bool is_supported_sparse_attention(int head_size, int sparse_block_size) {
|
||||
return head_size == 128 && sparse_block_size == 64;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// FP16
|
||||
|
||||
Status run_sparse_attention_fp16(SparseAttentionParams& params) {
|
||||
int algo_id = get_algo_id(params);
|
||||
if (algo_id < 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
|
||||
}
|
||||
|
||||
// Right now we only support sm_8x.
|
||||
// If we want to support more architectures, we need to dispatch according to SM.
|
||||
return sparse_attention_fp16_sm80(params, algo_id);
|
||||
}
|
||||
|
||||
static std::once_flag load_sparse_attention_fp16_flag;
|
||||
|
||||
void load_sparse_attention_fp16(void) {
|
||||
// Right now we only support sm_8x.
|
||||
// If we want to support more architectures, we need to dispatch according to SM.
|
||||
std::call_once(load_sparse_attention_fp16_flag, load_sparse_attention_fp16_sm80);
|
||||
}
|
||||
|
||||
void unload_sparse_attention_fp16(void) {
|
||||
unload_sparse_attention_fp16_sm80();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// BF16
|
||||
|
||||
Status run_sparse_attention_bf16(SparseAttentionParams& params) {
|
||||
int algo_id = get_algo_id(params);
|
||||
if (algo_id < 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
|
||||
}
|
||||
|
||||
return sparse_attention_bf16_sm80(params, algo_id);
|
||||
}
|
||||
|
||||
static std::once_flag load_sparse_attention_bf16_flag;
|
||||
|
||||
void load_sparse_attention_bf16(void) {
|
||||
std::call_once(load_sparse_attention_bf16_flag, load_sparse_attention_fp16_sm80);
|
||||
}
|
||||
|
||||
void unload_sparse_attention_bf16(void) {
|
||||
unload_sparse_attention_bf16_sm80();
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v1
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,25 @@
|
|||
#include <cuda.h>
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
using onnxruntime::Status;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v1 {
|
||||
|
||||
bool is_supported_sparse_attention(const cudaDeviceProp& dprops);
|
||||
bool is_supported_sparse_attention(int head_size, int sparse_block_size);
|
||||
|
||||
Status run_sparse_attention_fp16(SparseAttentionParams& params);
|
||||
void load_sparse_attention_fp16();
|
||||
void unload_sparse_attention_fp16();
|
||||
|
||||
Status run_sparse_attention_bf16(SparseAttentionParams& params);
|
||||
void load_sparse_attention_bf16();
|
||||
void unload_sparse_attention_bf16();
|
||||
|
||||
} // namespace sparse_attention_v1
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,140 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
|
||||
# Example to use this script (Tested with CUDA 12.3 in Ubuntu 20.04):
|
||||
# python3 -m pip install triton==2.3.0
|
||||
# python3 compile_sparse_attention_v2.py | sh
|
||||
#
|
||||
# Note that sparse_attention_v2_*.cc and sparse_attention_v2_dispatcher_*.h are the generated files.
|
||||
|
||||
from itertools import product
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
def generate_triton_compile_shell_script(dtype="fp16"):
|
||||
assert dtype in ["fp16", "bf16"]
|
||||
print("export TRITON_ROOT=$(pip show triton | grep Location | cut -d' ' -f2)")
|
||||
print('export ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader|head -n 1)"')
|
||||
print("export SM=$(echo $ARCH | sed -e 's/\\.//g')")
|
||||
|
||||
# Modify the compile.py to use custom template files compile_template_kernel_v2_c/h.txt in current directory.
|
||||
print(
|
||||
'python -c "'
|
||||
"import sys;lines=sys.stdin.read();"
|
||||
"lines=lines.replace('template_path = Path(__file__).parent / f\\\"compile.{ext}\\\"','template_path = f\\\"compile_template_kernel_v2_{ext}.txt\\\"');"
|
||||
'print(lines)"'
|
||||
"< ${TRITON_ROOT}/triton/tools/compile.py > compile.py"
|
||||
)
|
||||
|
||||
out_dir = f"trition_cubin_{dtype}"
|
||||
print(f"rm -rf {out_dir}")
|
||||
print(f"mkdir -p {out_dir}")
|
||||
|
||||
# All combinations of parameters for kernel.
|
||||
has_batch_dim_values = [True]
|
||||
head_size_values = [128]
|
||||
block_size_values = [64]
|
||||
is_prompt_values = [True, False]
|
||||
|
||||
# Use triton compiler to compile the kernel of different combinations of constant parameters.
|
||||
for has_batch_dim, head_size, block_size, is_prompt in product(
|
||||
has_batch_dim_values, head_size_values, block_size_values, is_prompt_values
|
||||
):
|
||||
# Constant parameters for triton kernel.
|
||||
block_d = triton.next_power_of_2(head_size)
|
||||
block_m = block_size
|
||||
block_n = block_size
|
||||
block_m_loading = 16 if not is_prompt else block_size
|
||||
even_d = block_d == head_size
|
||||
m_lt_n = block_m < block_n
|
||||
num_warps = 1 if not is_prompt else 4
|
||||
num_stages = 3
|
||||
|
||||
# There are 4 float and 8 int32 buffer pointers, and they are assumed to be aligned to 16 bytes.
|
||||
tensor_params = f"*{dtype}:16," * 4 + "*i32:16," * 8
|
||||
|
||||
# The strides for Q, K, V, and Out are multiples of 16 since head_size is 128.
|
||||
scalar_params = ("i32," * 2) + ("i32:16," * 12) + "i32,i32,fp32,"
|
||||
|
||||
constant_params = f"{int(has_batch_dim)},{head_size},{block_m},{block_n},{block_d},{block_m_loading},{int(even_d)},{int(m_lt_n)}"
|
||||
signature = f"{tensor_params}{scalar_params}{constant_params}"
|
||||
prefix = "python compile.py sparse_attention_v2_triton.py"
|
||||
|
||||
# output filename
|
||||
filename = f"sparse_attention_v2_{dtype}_d{head_size}_m{block_m}_{block_m_loading}_n{block_n}_b{int(has_batch_dim)}_sm${{SM}}"
|
||||
|
||||
# function name
|
||||
name = f"sparse_attention_v2_{dtype}_sm${{SM}}"
|
||||
|
||||
print(
|
||||
f"{prefix} -n block_sparse_attention -o {out_dir}/{filename} --out-name {name} "
|
||||
f'-w {num_warps} -ns {num_stages} -s "{signature}" -g "query_blocks, num_heads, 1"'
|
||||
)
|
||||
|
||||
# Generate the dispatcher.
|
||||
dispatcher = f"sparse_attention_v2_dispatcher_{dtype}_sm${{SM}}"
|
||||
print(f"cd {out_dir}")
|
||||
print(f"python ${{TRITON_ROOT}}/triton/tools/link.py sparse_attention_v2_*.h -o {dispatcher}")
|
||||
print("rm *.h")
|
||||
|
||||
# Remove signature in code.
|
||||
suffix = "0d1d2d3d4d5d6d7d8d9d10d11d121314d15d16d17d18d19d20d21d22d23d24d25d262728"
|
||||
print(f"for file in *.c; do sed -i 's/_{suffix}//g' \"$file\"; done")
|
||||
|
||||
# Recover signature in kernel name that is removed in previous step. Kernel name shall not be changed.
|
||||
print(f"for file in *.c; do sed -i 's/block_sparse_attention/block_sparse_attention_{suffix}/g' \"$file\"; done")
|
||||
|
||||
# Remove signature from filename since we use same signature for all kernels except constants.
|
||||
# and we have constants in filename so that we can distinguish files without the hash.
|
||||
print(f'for file in sparse_attention_v2_{dtype}_*.c; do mv "$file" "$(echo $file | cut -f 1 -d \'.\').c"; done')
|
||||
|
||||
# Change function parameters and return type. If you change the kernel interface, you will need to modify this part.
|
||||
source1 = "CUstream stream, CUdeviceptr Out, CUdeviceptr Q, CUdeviceptr K, CUdeviceptr V, CUdeviceptr q_batch_starts, CUdeviceptr q_batch_ends, CUdeviceptr k_batch_starts, CUdeviceptr k_batch_ends, CUdeviceptr q_batch_ids, CUdeviceptr q_start_sids, CUdeviceptr layout_crow_ptr, CUdeviceptr layout_col_ptr, int32_t layout_crow_stride_h, int32_t layout_col_stride_h, int32_t stride_qb, int32_t stride_qt, int32_t stride_qh, int32_t stride_kb, int32_t stride_kt, int32_t stride_kh, int32_t stride_vb, int32_t stride_vt, int32_t stride_vh, int32_t stride_ob, int32_t stride_ot, int32_t stride_oh, int32_t q_k_ratio, int32_t num_layout, float softmax_scale"
|
||||
target1 = "SparseAttentionParams& params"
|
||||
source2 = "stream, Out, Q, K, V, q_batch_starts, q_batch_ends, k_batch_starts, k_batch_ends, q_batch_ids, q_start_sids, layout_crow_ptr, layout_col_ptr, layout_crow_stride_h, layout_col_stride_h, stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh, stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh, q_k_ratio, num_layout, softmax_scale"
|
||||
target2 = "params"
|
||||
print(
|
||||
f"python -c \"import sys;lines=sys.stdin.read();lines=lines.replace('{source1}', '{target1}');"
|
||||
f'lines=lines.replace(\'{source2}\', \'{target2}\');print(lines)" < "{dispatcher}.c" > "{dispatcher}.h"'
|
||||
)
|
||||
print(f"sed -i 's/CUresult/Status/g' \"{dispatcher}.h\"")
|
||||
|
||||
# Remove parameter checking since we moved the validation logic to SparseAttentionParams
|
||||
print(f"sed -i '/if /d' \"{dispatcher}.h\"")
|
||||
print(f"sed -i '/CUDA_ERROR_INVALID_VALUE/d' \"{dispatcher}.h\"")
|
||||
print(f"sed -i '/#include/d' \"{dispatcher}.h\"")
|
||||
|
||||
print(f"rm {dispatcher}.c")
|
||||
|
||||
# Use a template file to add namespace and includes to the dispatcher file.
|
||||
print(
|
||||
'python -c "'
|
||||
"from pathlib import Path;"
|
||||
"template=Path('../compile_template_dispatcher_v2_h.txt').read_text();"
|
||||
f"code=Path('{dispatcher}.h').read_text();"
|
||||
"text=template.replace('PLACEHOLDER', code); print(text)\" "
|
||||
f"> ../{dispatcher}.h"
|
||||
)
|
||||
# rename *.c to *.cc
|
||||
print('for file in *.c; do mv -- "$file" "${file%.c}.cc"; done')
|
||||
|
||||
# Move kernel files to parent directory. This might overwrite existing files in repository.
|
||||
print("echo Generated files:")
|
||||
print("ls sparse_attention_v2_*")
|
||||
print(f"mv -f sparse_attention_v2_{dtype}_* ../")
|
||||
|
||||
# Clean up
|
||||
print("cd ..")
|
||||
print("rm compile.py")
|
||||
print(f"rm -rf {out_dir}")
|
||||
|
||||
print(f"echo compiling {dtype} is done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
generate_triton_compile_shell_script(dtype)
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention_v2.py
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
PLACEHOLDER
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
namespace onnxruntime {{
|
||||
namespace contrib {{
|
||||
namespace cuda {{
|
||||
namespace sparse_attention_v2 {{
|
||||
|
||||
// This file is generated by compile_sparse_attention_v2.py
|
||||
// {kernel_docstring}
|
||||
// cubin_size = {bin_size}
|
||||
// shared_mem_bytes = {shared}
|
||||
// threads_per_cta = {num_warps} * 32
|
||||
// kernel_name = {triton_kernel_name}
|
||||
|
||||
unsigned char {kernel_name}_cubin[] = {{ {bin_data} }};
|
||||
|
||||
CUmodule {kernel_name}_mod = NULL;
|
||||
CUfunction {kernel_name}_func = NULL;
|
||||
|
||||
void unload_{kernel_name}(void) {{
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuModuleUnload({kernel_name}_mod), driver);
|
||||
}}
|
||||
|
||||
void load_{kernel_name}(void) {{
|
||||
void *bin = (void *)&{kernel_name}_cubin;
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuModuleLoadData(&{kernel_name}_mod, bin), driver);
|
||||
CU_CHECK(driver->cuModuleGetFunction(&{kernel_name}_func, {kernel_name}_mod, "{triton_kernel_name}"), driver);
|
||||
constexpr int shared = {shared};
|
||||
if constexpr (shared > 49152) {{
|
||||
SetKernelSharedMemory(driver, {kernel_name}_func);
|
||||
}}
|
||||
}}
|
||||
|
||||
Status {kernel_name}(SparseAttentionParams& params) {{
|
||||
return params.LaunchKernel({kernel_name}_func, {num_warps} * 32, {shared});
|
||||
}}
|
||||
|
||||
}} // namespace sparse_attention_v2
|
||||
}} // namespace cuda
|
||||
}} // namespace contrib
|
||||
}} // namespace onnxruntime
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention_v2.py
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
namespace onnxruntime {{
|
||||
namespace contrib {{
|
||||
namespace cuda {{
|
||||
namespace sparse_attention_v2 {{
|
||||
|
||||
void unload_{kernel_name}(void);
|
||||
|
||||
void load_{kernel_name}(void);
|
||||
|
||||
// tt-linker: {kernel_name}:{full_signature}:{algo_info}
|
||||
Status{_placeholder} {kernel_name}(SparseAttentionParams& params);
|
||||
|
||||
}} // namespace sparse_attention_v2
|
||||
}} // namespace cuda
|
||||
}} // namespace contrib
|
||||
}} // namespace onnxruntime
|
|
@ -0,0 +1,78 @@
|
|||
#include <cuda.h>
|
||||
#include <stdint.h>
|
||||
#include <assert.h>
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_api.h"
|
||||
|
||||
// Dispatcher files are generated.
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_dispatcher_fp16_sm80.h"
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_dispatcher_bf16_sm80.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
int get_algo_id(SparseAttentionParams& params) {
|
||||
return (params.past_sequence_length > 0 && params.sequence_length <= 16) ? 0 : 1;
|
||||
}
|
||||
|
||||
bool is_supported_sparse_attention(const cudaDeviceProp& dprops) {
|
||||
return dprops.major == 8;
|
||||
}
|
||||
|
||||
bool is_supported_sparse_attention(int head_size, int sparse_block_size) {
|
||||
return head_size == 128 && sparse_block_size == 64;
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// FP16
|
||||
|
||||
Status run_sparse_attention_fp16(SparseAttentionParams& params) {
|
||||
int algo_id = get_algo_id(params);
|
||||
if (algo_id < 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
|
||||
}
|
||||
|
||||
// Right now we only support sm_8x.
|
||||
// If we want to support more architectures, we need to dispatch according to SM.
|
||||
return sparse_attention_v2_fp16_sm80(params, algo_id);
|
||||
}
|
||||
|
||||
static std::once_flag load_sparse_attention_v2_fp16_flag;
|
||||
|
||||
void load_sparse_attention_fp16(void) {
|
||||
// Right now we only support sm_8x.
|
||||
// If we want to support more architectures, we need to dispatch according to SM.
|
||||
std::call_once(load_sparse_attention_v2_fp16_flag, load_sparse_attention_v2_fp16_sm80);
|
||||
}
|
||||
|
||||
void unload_sparse_attention_fp16(void) {
|
||||
unload_sparse_attention_v2_fp16_sm80();
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// BF16
|
||||
|
||||
Status run_sparse_attention_bf16(SparseAttentionParams& params) {
|
||||
int algo_id = get_algo_id(params);
|
||||
if (algo_id < 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "no algo found for the parameters");
|
||||
}
|
||||
|
||||
return sparse_attention_v2_bf16_sm80(params, algo_id);
|
||||
}
|
||||
|
||||
static std::once_flag load_sparse_attention_v2_bf16_flag;
|
||||
|
||||
void load_sparse_attention_bf16(void) {
|
||||
std::call_once(load_sparse_attention_v2_bf16_flag, load_sparse_attention_v2_bf16_sm80);
|
||||
}
|
||||
|
||||
void unload_sparse_attention_bf16(void) {
|
||||
unload_sparse_attention_v2_bf16_sm80();
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,25 @@
|
|||
#include <cuda.h>
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
using onnxruntime::Status;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
bool is_supported_sparse_attention(const cudaDeviceProp& dprops);
|
||||
bool is_supported_sparse_attention(int head_size, int sparse_block_size);
|
||||
|
||||
Status run_sparse_attention_fp16(SparseAttentionParams& params);
|
||||
void load_sparse_attention_fp16();
|
||||
void unload_sparse_attention_fp16();
|
||||
|
||||
Status run_sparse_attention_bf16(SparseAttentionParams& params);
|
||||
void load_sparse_attention_bf16();
|
||||
void unload_sparse_attention_bf16();
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,232 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
struct SparseAttentionParams {
|
||||
onnxruntime::Stream* ort_stream;
|
||||
void* output;
|
||||
const void* q;
|
||||
const void* k;
|
||||
const void* v;
|
||||
|
||||
int batch_size;
|
||||
int num_heads;
|
||||
int kv_num_heads;
|
||||
int head_size;
|
||||
|
||||
int sequence_length;
|
||||
int past_sequence_length;
|
||||
int total_sequence_length;
|
||||
int max_sequence_length;
|
||||
|
||||
float scale;
|
||||
|
||||
int kernel_block_size;
|
||||
|
||||
// CSR format of block mask
|
||||
const int* layout_csr_row_indices;
|
||||
const int* layout_csr_col_indices;
|
||||
int layout_row_stride_h;
|
||||
int layout_col_stride_h;
|
||||
int num_layout;
|
||||
|
||||
// strides
|
||||
int stride_qb;
|
||||
int stride_qt;
|
||||
int stride_qh;
|
||||
int stride_kb;
|
||||
int stride_kt;
|
||||
int stride_kh;
|
||||
int stride_vb;
|
||||
int stride_vt;
|
||||
int stride_vh;
|
||||
int stride_ob;
|
||||
int stride_ot;
|
||||
int stride_oh;
|
||||
|
||||
int q_k_ratio;
|
||||
|
||||
int active_q_blocks;
|
||||
const int* q_batch_starts;
|
||||
const int* q_batch_ends;
|
||||
const int* k_batch_starts;
|
||||
const int* k_batch_ends;
|
||||
const int* q_batch_ids;
|
||||
const int* q_start_sids;
|
||||
|
||||
SparseAttentionParams(
|
||||
onnxruntime::Stream* ort_stream,
|
||||
void* output,
|
||||
const void* q,
|
||||
const void* k,
|
||||
const void* v,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
int head_size,
|
||||
int total_sequence_length,
|
||||
int max_sequence_length,
|
||||
float scale,
|
||||
int kernel_block_size,
|
||||
const int* layout_csr_row_indices,
|
||||
const int* layout_csr_col_indices,
|
||||
int layout_row_stride_h,
|
||||
int layout_col_stride_h,
|
||||
int num_layout,
|
||||
int active_q_blocks,
|
||||
const int* q_batch_starts,
|
||||
const int* q_batch_ends,
|
||||
const int* k_batch_starts,
|
||||
const int* k_batch_ends,
|
||||
const int* q_batch_ids,
|
||||
const int* q_start_sids) {
|
||||
this->ort_stream = ort_stream;
|
||||
this->output = output;
|
||||
this->q = q;
|
||||
this->k = k;
|
||||
this->v = v;
|
||||
this->batch_size = batch_size;
|
||||
this->sequence_length = sequence_length;
|
||||
this->num_heads = num_heads;
|
||||
this->kv_num_heads = kv_num_heads;
|
||||
this->head_size = head_size;
|
||||
this->past_sequence_length = total_sequence_length - sequence_length;
|
||||
this->total_sequence_length = total_sequence_length;
|
||||
this->max_sequence_length = max_sequence_length;
|
||||
this->scale = scale == 0.0f ? 1.0f / sqrtf(static_cast<float>(head_size)) : scale;
|
||||
this->kernel_block_size = kernel_block_size;
|
||||
this->layout_csr_row_indices = layout_csr_row_indices;
|
||||
this->layout_csr_col_indices = layout_csr_col_indices;
|
||||
this->layout_row_stride_h = layout_row_stride_h;
|
||||
this->layout_col_stride_h = layout_col_stride_h;
|
||||
this->num_layout = num_layout;
|
||||
|
||||
// Q is in BNSH format
|
||||
this->stride_qb = this->num_heads * this->sequence_length * this->head_size;
|
||||
this->stride_qh = this->sequence_length * this->head_size;
|
||||
this->stride_qt = this->head_size;
|
||||
|
||||
// When kv buffer has max sequence length, stride should match max sequence length.
|
||||
int kv_buffer_sequence_length = max_sequence_length;
|
||||
|
||||
// KV cache is in BNSH format
|
||||
this->stride_kb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_kt = this->head_size;
|
||||
this->stride_vb = this->kv_num_heads * kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vh = kv_buffer_sequence_length * this->head_size;
|
||||
this->stride_vt = this->head_size;
|
||||
|
||||
// Output is BSNH format
|
||||
this->stride_ob = this->sequence_length * this->num_heads * this->head_size;
|
||||
this->stride_oh = this->head_size;
|
||||
this->stride_ot = this->num_heads * this->head_size;
|
||||
|
||||
this->q_k_ratio = this->num_heads / this->kv_num_heads;
|
||||
|
||||
this->active_q_blocks = active_q_blocks;
|
||||
this->q_batch_starts = q_batch_starts;
|
||||
this->q_batch_ends = q_batch_ends;
|
||||
this->k_batch_starts = k_batch_starts;
|
||||
this->k_batch_ends = k_batch_ends;
|
||||
this->q_batch_ids = q_batch_ids;
|
||||
this->q_start_sids = q_start_sids;
|
||||
}
|
||||
|
||||
Status LaunchKernel(CUfunction f, int threads_per_block, unsigned int sharedMemBytes) {
|
||||
ORT_ENFORCE(f != nullptr, "Kernel shall be loaded before calling LaunchKernel.");
|
||||
|
||||
if (!Valididate()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SparseAttentionParams is not valid.");
|
||||
}
|
||||
|
||||
void* args[29] = {
|
||||
&output, &q, &k, &v,
|
||||
&q_batch_starts, &q_batch_ends, &k_batch_starts, &k_batch_ends, &q_batch_ids, &q_start_sids,
|
||||
&layout_csr_row_indices, &layout_csr_col_indices, &layout_row_stride_h, &layout_col_stride_h,
|
||||
&stride_qb, &stride_qt, &stride_qh, &stride_kb, &stride_kt, &stride_kh,
|
||||
&stride_vb, &stride_vt, &stride_vh, &stride_ob, &stride_ot, &stride_oh,
|
||||
&q_k_ratio, &num_layout, &scale};
|
||||
|
||||
unsigned int gridDimX = active_q_blocks;
|
||||
unsigned int gridDimY = num_heads;
|
||||
constexpr unsigned int gridDimZ = 1;
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
DUMP_TENSOR_INIT();
|
||||
DUMP_TENSOR("q", reinterpret_cast<const half*>(q), batch_size, num_heads, sequence_length, head_size);
|
||||
DUMP_TENSOR("k", reinterpret_cast<const half*>(k), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("v", reinterpret_cast<const half*>(v), batch_size, kv_num_heads, max_sequence_length, head_size);
|
||||
DUMP_TENSOR("csr_col_indices",
|
||||
layout_csr_col_indices,
|
||||
num_layout,
|
||||
layout_col_stride_h);
|
||||
|
||||
DUMP_TENSOR("csr_row_indices",
|
||||
layout_csr_row_indices,
|
||||
num_layout,
|
||||
layout_row_stride_h);
|
||||
|
||||
DUMP_TENSOR("q_batch_starts", q_batch_starts, 1, batch_size);
|
||||
DUMP_TENSOR("q_batch_ends", q_batch_ends, 1, batch_size);
|
||||
DUMP_TENSOR("k_batch_starts", k_batch_starts, 1, batch_size);
|
||||
DUMP_TENSOR("k_batch_ends", k_batch_ends, 1, batch_size);
|
||||
DUMP_TENSOR("q_batch_ids", q_batch_ids, 1, active_q_blocks);
|
||||
DUMP_TENSOR("q_start_sids", q_start_sids, 1, active_q_blocks);
|
||||
|
||||
printf(
|
||||
"layout_row_stride_h=%d, layout_col_stride_h=%d, num_layout=%d, scale=%f,\n"
|
||||
"stride_qb=%d, stride_qt=%d, stride_qh=%d, stride_kb=%d, stride_kt=%d, stride_kh=%d,\n"
|
||||
"stride_vb=%d, stride_vt=%d, stride_vh=%d, stride_ob=%d, stride_ot=%d, stride_oh=%d,\n"
|
||||
"num_heads=%d, kv_num_heads=%d, total_sequence_length=%d, past_sequence_length=%d\n"
|
||||
"output=%p, q=%p, k=%p, v=%p, layout_csr_row_indices=%p, layout_csr_col_indices=%p\n"
|
||||
"q_batch_starts=%p, q_batch_ends=%p, k_batch_starts=%p, k_batch_ends=%p, q_batch_ids=%p, q_start_sids=%p active_q_blocks=%d\n",
|
||||
layout_row_stride_h, layout_col_stride_h, num_layout, scale,
|
||||
stride_qb, stride_qt, stride_qh, stride_kb, stride_kt, stride_kh,
|
||||
stride_vb, stride_vt, stride_vh, stride_ob, stride_ot, stride_oh,
|
||||
num_heads, kv_num_heads, total_sequence_length, past_sequence_length,
|
||||
output, q, k, v, layout_csr_row_indices, layout_csr_col_indices,
|
||||
q_batch_starts, q_batch_ends, k_batch_starts, k_batch_ends, q_batch_ids, q_start_sids, active_q_blocks);
|
||||
|
||||
printf("gridDimX=%d gridDimY=%d threads_per_block=%d sharedMemBytes=%d\n",
|
||||
gridDimX, gridDimY, threads_per_block, sharedMemBytes);
|
||||
#endif
|
||||
|
||||
const CUDADriverWrapper* driver = CUDADriverWrapper::GetInstance();
|
||||
CU_CHECK(driver->cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, threads_per_block, 1, 1, sharedMemBytes,
|
||||
static_cast<CUstream>(this->ort_stream->GetHandle()),
|
||||
args, NULL),
|
||||
driver);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool Valididate() {
|
||||
// Check pointers are aligned to 16 bytes (we used that to hint the compiler to generate aligned loads/stores)
|
||||
return (reinterpret_cast<size_t>(output) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(k) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(v) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(layout_csr_col_indices) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(layout_csr_row_indices) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q_batch_starts) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q_batch_ends) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(k_batch_starts) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(k_batch_ends) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q_batch_ids) % 16 == 0 &&
|
||||
reinterpret_cast<size_t>(q_start_sids) % 16 == 0 &&
|
||||
this->head_size % 16 == 0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention_v2.py
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
// launcher for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
Status sparse_attention_v2_bf16_sm80_0aafaf4a(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_bf16_sm80_0aafaf4a(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
void load_sparse_attention_v2_bf16_sm80_0aafaf4a();
|
||||
void load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
|
||||
load_sparse_attention_v2_bf16_sm80_0aafaf4a();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
void unload_sparse_attention_v2_bf16_sm80_0aafaf4a();
|
||||
void unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
|
||||
unload_sparse_attention_v2_bf16_sm80_0aafaf4a();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
Status sparse_attention_v2_bf16_sm80_8b0ce70d(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_bf16_sm80_8b0ce70d(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
void load_sparse_attention_v2_bf16_sm80_8b0ce70d();
|
||||
void load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
|
||||
load_sparse_attention_v2_bf16_sm80_8b0ce70d();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
void unload_sparse_attention_v2_bf16_sm80_8b0ce70d();
|
||||
void unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
|
||||
unload_sparse_attention_v2_bf16_sm80_8b0ce70d();
|
||||
}
|
||||
|
||||
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
|
||||
kernel_func_t sparse_attention_v2_bf16_sm80_kernels[] = {
|
||||
sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3,
|
||||
sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3,
|
||||
};
|
||||
|
||||
int sparse_attention_v2_bf16_sm80_get_num_algos(void) {
|
||||
return (int)sizeof(sparse_attention_v2_bf16_sm80_kernels);
|
||||
}
|
||||
|
||||
Status sparse_attention_v2_bf16_sm80(SparseAttentionParams& params, int algo_id) {
|
||||
assert(algo_id < (int)sizeof(sparse_attention_v2_bf16_sm80_kernels));
|
||||
return sparse_attention_v2_bf16_sm80_kernels[algo_id](params);
|
||||
}
|
||||
|
||||
void load_sparse_attention_v2_bf16_sm80(void) {
|
||||
load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
|
||||
load_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
|
||||
}
|
||||
|
||||
void unload_sparse_attention_v2_bf16_sm80(void) {
|
||||
unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
|
||||
unload_sparse_attention_v2_bf16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
|
||||
}
|
||||
|
||||
Status sparse_attention_v2_bf16_sm80_default(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_bf16_sm80(params, 0);
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// This file is generated by compile_sparse_attention_v2.py
|
||||
|
||||
#pragma once
|
||||
#include "contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
namespace sparse_attention_v2 {
|
||||
|
||||
// launcher for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
Status sparse_attention_v2_fp16_sm80_a6bdc951(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_fp16_sm80_a6bdc951(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
void load_sparse_attention_v2_fp16_sm80_a6bdc951();
|
||||
void load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
|
||||
load_sparse_attention_v2_fp16_sm80_a6bdc951();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3
|
||||
void unload_sparse_attention_v2_fp16_sm80_a6bdc951();
|
||||
void unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3() {
|
||||
unload_sparse_attention_v2_fp16_sm80_a6bdc951();
|
||||
}
|
||||
|
||||
// launcher for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
Status sparse_attention_v2_fp16_sm80_ca298032(SparseAttentionParams& params);
|
||||
|
||||
Status sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_fp16_sm80_ca298032(params);
|
||||
}
|
||||
|
||||
// load for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
void load_sparse_attention_v2_fp16_sm80_ca298032();
|
||||
void load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
|
||||
load_sparse_attention_v2_fp16_sm80_ca298032();
|
||||
}
|
||||
|
||||
// unload for: sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3
|
||||
void unload_sparse_attention_v2_fp16_sm80_ca298032();
|
||||
void unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3() {
|
||||
unload_sparse_attention_v2_fp16_sm80_ca298032();
|
||||
}
|
||||
|
||||
typedef Status (*kernel_func_t)(SparseAttentionParams& params);
|
||||
kernel_func_t sparse_attention_v2_fp16_sm80_kernels[] = {
|
||||
sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3,
|
||||
sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3,
|
||||
};
|
||||
|
||||
int sparse_attention_v2_fp16_sm80_get_num_algos(void) {
|
||||
return (int)sizeof(sparse_attention_v2_fp16_sm80_kernels);
|
||||
}
|
||||
|
||||
Status sparse_attention_v2_fp16_sm80(SparseAttentionParams& params, int algo_id) {
|
||||
assert(algo_id < (int)sizeof(sparse_attention_v2_fp16_sm80_kernels));
|
||||
return sparse_attention_v2_fp16_sm80_kernels[algo_id](params);
|
||||
}
|
||||
|
||||
void load_sparse_attention_v2_fp16_sm80(void) {
|
||||
load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
|
||||
load_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
|
||||
}
|
||||
|
||||
void unload_sparse_attention_v2_fp16_sm80(void) {
|
||||
unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x16x1x0_warps1xstages3();
|
||||
unload_sparse_attention_v2_fp16_sm80_1x128x64x64x128x64x1x0_warps4xstages3();
|
||||
}
|
||||
|
||||
Status sparse_attention_v2_fp16_sm80_default(SparseAttentionParams& params) {
|
||||
return sparse_attention_v2_fp16_sm80(params, 0);
|
||||
}
|
||||
|
||||
} // namespace sparse_attention_v2
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,215 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def block_sparse_attention(
|
||||
Out, # output [B, M, H, D]. Note that B is batch_size, M is q_seq_len, H is num_heads, and D is head_size
|
||||
Q, # query [B, M, H, D]
|
||||
K, # key [B, N, H_kv, D]. Note that N is max_seq_len for kv cache, H_kv is num_kv_heads
|
||||
V, # value [B, N, H_kv, D]
|
||||
q_batch_starts, # [B], start position (excluding the past) of query in the sequence for each batch
|
||||
q_batch_ends, # [B], end position (excluding the past) of query in the sequence for each batch
|
||||
k_batch_starts, # [B], start position (including the past) of key in the sequence for each batch
|
||||
k_batch_ends, # [B], end position (including the past) of key in the sequence for each batch
|
||||
q_batch_ids, # [G], batch id for each query block; G is the total number of query blocks
|
||||
q_start_sids, # [G], start position (excluding the past) of each query block
|
||||
layout_crow_ptr, # block mask CSR format. Shape is [H, num_rows + 1] where num_rows = max_seq_len / BLOCK_M
|
||||
layout_col_ptr, # block mask CSR format. Shape is [H, num_rows * num_cols] where num_cols = max_seq_len / BLOCK_N
|
||||
layout_crow_stride_h, # stride per head for csr_row_indices, i.e. num_rows + 1
|
||||
layout_col_stride_h, # stride per head for csr_col_indices, i.e. num_rows * num_cols
|
||||
stride_qb,
|
||||
stride_qt,
|
||||
stride_qh, # strides for query (excluding the stride for last hidden dim, which is always 1)
|
||||
stride_kb,
|
||||
stride_kt,
|
||||
stride_kh, # strides for key (excluding the stride for last hidden dim, which is always 1)
|
||||
stride_vb,
|
||||
stride_vt,
|
||||
stride_vh, # strides for value (excluding the stride for last hidden dim, which is always 1)
|
||||
stride_ob,
|
||||
stride_ot,
|
||||
stride_oh, # strides for output (excluding the stride for last hidden dim, which is always 1)
|
||||
q_k_ratio, # num_heads / num_kv_heads
|
||||
num_layout, # number of sparse layout (H)
|
||||
softmax_scale, # scaling factor applied prior to softmax
|
||||
HAS_BATCH_DIM: tl.constexpr, # whether batch dim is present
|
||||
D_HEAD: tl.constexpr, # head size
|
||||
BLOCK_M: tl.constexpr, # block size for q_seq_len
|
||||
BLOCK_N: tl.constexpr, # block size for k_seq_len
|
||||
BLOCK_D: tl.constexpr, # block size for D
|
||||
BLOCK_M_LOADING: tl.constexpr, # block size for loading q
|
||||
EVEN_D: tl.constexpr, # whether D is divisible by BLOCK_D
|
||||
M_LT_N: tl.constexpr, # whether BLOCK_M < BLOCK_N
|
||||
):
|
||||
tl.static_print(
|
||||
f"{HAS_BATCH_DIM=} {D_HEAD=} {BLOCK_M=} {BLOCK_N=} {BLOCK_D=} {BLOCK_M_LOADING=} {EVEN_D=} {M_LT_N=}"
|
||||
)
|
||||
# The grid is [G, num_heads] where G is number of query blocks.
|
||||
off_g = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
|
||||
off_h_for_kv = off_h // q_k_ratio
|
||||
off_b = tl.load(q_batch_ids + off_g).to(tl.int32)
|
||||
q_start_sid = tl.load(q_start_sids + off_g)
|
||||
start_m = q_start_sid // BLOCK_M
|
||||
|
||||
if HAS_BATCH_DIM:
|
||||
Q += off_b * stride_qb
|
||||
K += off_b * stride_kb
|
||||
V += off_b * stride_vb
|
||||
Out += off_b * stride_ob
|
||||
|
||||
# offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_D)
|
||||
|
||||
q_cu_start = tl.load(q_batch_starts + off_b).to(tl.int32)
|
||||
q_seqlen = tl.load(q_batch_ends + off_b).to(tl.int32) - q_cu_start
|
||||
|
||||
k_cu_start = tl.load(k_batch_starts + off_b).to(tl.int32)
|
||||
k_seqlen = tl.load(k_batch_ends + off_b).to(tl.int32) - k_cu_start
|
||||
|
||||
past_len = k_seqlen - q_seqlen
|
||||
|
||||
Q += q_cu_start * stride_qt + off_h * stride_qh
|
||||
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
|
||||
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
|
||||
Out += q_cu_start * stride_ot + off_h * stride_oh
|
||||
|
||||
if EVEN_D:
|
||||
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :], mask=offs_m[:, None] < q_seqlen)
|
||||
else:
|
||||
q = tl.load(
|
||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :],
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
other=0,
|
||||
)
|
||||
|
||||
q_row = (past_len + q_start_sid) // BLOCK_M
|
||||
|
||||
layout_h = off_h % num_layout
|
||||
sparse_crow_ptr = layout_crow_ptr + layout_h * layout_crow_stride_h + q_row
|
||||
|
||||
# TODO: load at once, supported in new Triton
|
||||
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
|
||||
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
|
||||
|
||||
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None]
|
||||
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :]
|
||||
|
||||
for k_block_col_idx in range(k_block_start, k_block_end - 1):
|
||||
k_block_id = tl.load(layout_col_ptr + layout_h * layout_col_stride_h + k_block_col_idx).to(tl.int32)
|
||||
start_n = k_block_id * BLOCK_N
|
||||
|
||||
# -- compute qk ----
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_d[:, None] < D_HEAD)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= softmax_scale
|
||||
if M_LT_N:
|
||||
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
|
||||
# update acc
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_d[None, :] < D_HEAD)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
|
||||
# Process the last k block
|
||||
k_block_col_idx = k_block_end - 1
|
||||
k_block_id = tl.load(layout_col_ptr + layout_h * layout_col_stride_h + k_block_col_idx).to(tl.int32)
|
||||
start_n = k_block_id * BLOCK_N
|
||||
# -- compute qk ----
|
||||
if EVEN_D:
|
||||
k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen)
|
||||
else:
|
||||
# mask = mask & (offs_d[:, ])
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD)
|
||||
)
|
||||
|
||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
qk *= softmax_scale
|
||||
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
# update acc
|
||||
if EVEN_D:
|
||||
v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD)
|
||||
)
|
||||
|
||||
acc += tl.dot(p, v)
|
||||
# l_i = l_i_new
|
||||
# m_i = m_i_new
|
||||
|
||||
# write output
|
||||
if EVEN_D:
|
||||
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :], acc, mask=offs_m[:, None] < q_seqlen)
|
||||
else:
|
||||
tl.store(
|
||||
Out + offs_m[:, None] * stride_ot + offs_d[None, :],
|
||||
acc,
|
||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||
)
|
|
@ -7,6 +7,22 @@
|
|||
#include "core/framework/ort_value.h"
|
||||
#include "contrib_ops/cpu/utils/console_dumper.h"
|
||||
|
||||
#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation.
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 0
|
||||
#define DUMP_TENSOR_INIT() onnxruntime::contrib::cuda::transformers::CudaTensorConsoleDumper dumper
|
||||
#define DUMP_TENSOR(...) dumper.Print(__VA_ARGS__)
|
||||
#else
|
||||
#define DUMP_TENSOR_INIT()
|
||||
#define DUMP_TENSOR(...)
|
||||
#endif
|
||||
|
||||
#if DUMP_TENSOR_LEVEL > 1
|
||||
#define DUMP_TENSOR_D(...) dumper.Print(__VA_ARGS__)
|
||||
#else
|
||||
#define DUMP_TENSOR_D(...)
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
|
|
@ -228,18 +228,13 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
|
|||
}
|
||||
}
|
||||
|
||||
void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
|
||||
// Output 0 has shape (batch_size, sequence_length, hidden_size)
|
||||
|
||||
// Q, K and V:
|
||||
// Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 (key) has shape (batch_size, kv_sequence_length, kv_hidden_size)
|
||||
// Input 2 (value) has shape (batch_size, kv_sequence_length, kv_hidden_size)
|
||||
|
||||
// Type inference
|
||||
// Type and shape inference for group query attention and sparse attention.
|
||||
void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
|
||||
int past_key_index = -1,
|
||||
int use_max_past_present_buffer = -1) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
// Shape inference
|
||||
int64_t kv_sequence_length = -1;
|
||||
if (hasInputShape(ctx, 0)) {
|
||||
auto& query_shape = getInputShape(ctx, 0);
|
||||
auto& query_dims = query_shape.dim();
|
||||
|
@ -249,18 +244,22 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
|
|||
}
|
||||
|
||||
if (hasInputShape(ctx, 2)) {
|
||||
// Input 0 (query) has shape (batch_size, sequence_length, num_heads * head_size)
|
||||
// Input 1 (key) has shape (batch_size, kv_sequence_length, kv_num_heads * head_size)
|
||||
// Input 2 (value) has shape (batch_size, kv_sequence_length, kv_num_heads * head_size)
|
||||
// Output 0 has shape (batch_size, sequence_length, num_heads * head_size)
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
auto& value_shape = getInputShape(ctx, 2);
|
||||
auto& value_dims = value_shape.dim();
|
||||
if (value_dims.size() != 3) {
|
||||
fail_shape_inference("Inputs 2 (value) shall be 3 dimensions");
|
||||
if (value_dims.size() == 3 && value_dims[1].has_dim_value()) {
|
||||
kv_sequence_length = value_dims[1].dim_value();
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
||||
*output_shape.add_dim() = query_dims[0];
|
||||
*output_shape.add_dim() = query_dims[1];
|
||||
*output_shape.add_dim() = query_dims[2];
|
||||
updateOutputShape(ctx, 0, output_shape);
|
||||
} else {
|
||||
// Packed QKV:
|
||||
// Input 0 (query) has shape (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
|
||||
// Input 1 (key) is not present
|
||||
// Input 2 (value) is not present
|
||||
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
||||
int64_t num_heads = getAttribute(ctx, "num_heads", 0);
|
||||
int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
|
||||
|
@ -270,25 +269,64 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
|
|||
*output_shape.add_dim() = query_dims[1];
|
||||
output_shape.add_dim()->set_dim_value(head_size * num_heads);
|
||||
updateOutputShape(ctx, 0, output_shape);
|
||||
|
||||
if (query_dims[1].has_dim_value()) {
|
||||
kv_sequence_length = query_dims[1].dim_value();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx.getNumOutputs() > 1) { // has present output
|
||||
if (hasInputShape(ctx, past_key_index)) {
|
||||
// auto& query_shape = getInputShape(ctx, 0);
|
||||
// auto& query_dims = query_shape.dim();
|
||||
// copy the type from query to present key
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1);
|
||||
|
||||
// copy the type from query to present value
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2);
|
||||
|
||||
if (past_key_index >= 0 && hasInputShape(ctx, past_key_index)) {
|
||||
auto& past_shape = getInputShape(ctx, past_key_index);
|
||||
auto& past_dims = past_shape.dim();
|
||||
|
||||
// past key has shape (batch_size, kv_num_heads, max_sequence_length, head_size)
|
||||
if (past_dims.size() != 4) {
|
||||
fail_shape_inference("The past_key input shall be 4 dimensions");
|
||||
}
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1);
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
|
||||
// TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
|
||||
|
||||
if (use_max_past_present_buffer == 1) {
|
||||
// When past and present use max buffer, they have the same shape
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast<size_t>(past_key_index) + 1, 2);
|
||||
} else if (use_max_past_present_buffer == 0) {
|
||||
if (kv_sequence_length > 0 && past_dims[2].has_dim_value()) {
|
||||
int64_t total_sequence_length = kv_sequence_length + past_dims[2].dim_value();
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto present_shape;
|
||||
for (auto& dim : past_dims) {
|
||||
*present_shape.add_dim() = dim;
|
||||
}
|
||||
|
||||
// shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
|
||||
present_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
|
||||
|
||||
updateOutputShape(ctx, 1, present_shape);
|
||||
updateOutputShape(ctx, 2, present_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
|
||||
// TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
|
||||
constexpr int use_max_past_present_buffer = -1;
|
||||
BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
|
||||
}
|
||||
|
||||
void SparseAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
|
||||
constexpr int use_max_past_present_buffer = 1;
|
||||
BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
|
||||
}
|
||||
|
||||
constexpr const char* Attention_ver1_doc = R"DOC(
|
||||
Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
|
||||
|
||||
|
@ -432,7 +470,7 @@ An input as above will be packed into 3 tensors like below:
|
|||
|
||||
Input tensors contains the hidden embedding of real tokens.
|
||||
Token_offset records the offset of token in the unpacked input.
|
||||
cumulated_token_count records cumulated length of each sequnces length.
|
||||
cumulated_token_count records cumulated length of each sequence length.
|
||||
|
||||
The operator only supports BERT like model with padding on right now.
|
||||
|
||||
|
@ -561,7 +599,7 @@ An input as above will be packed into 3 tensors like below:
|
|||
|
||||
The query, key and value tensors contain result of hidden embedding of real tokens after input projections.
|
||||
Token_offset records the offset of token in the unpacked input.
|
||||
cumulative_sequence_length records cumulated length of each sequnces length.
|
||||
cumulative_sequence_length records cumulated length of each sequence length.
|
||||
|
||||
The operator only supports BERT like model with padding on right now.
|
||||
)DOC";
|
||||
|
@ -1103,6 +1141,103 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
GroupQueryAttentionTypeAndShapeInference(ctx, 3);
|
||||
}));
|
||||
|
||||
constexpr const char* SparseAttention_ver1_doc = R"DOC(
|
||||
Block Sparse Attention used in Phi-3-small (https://arxiv.org/pdf/2404.14219).
|
||||
|
||||
It is inspired by Sparse Transformers (https://arxiv.org/pdf/1904.10509) and BigBird (https://arxiv.org/pdf/2007.14062).
|
||||
|
||||
block_mask can be used to configure sparse layout for different head.
|
||||
When number of sparse layout is 1, all heads have same sparse layout. Otherwise, different layouts are used cyclically.
|
||||
For example, given 4 layouts (S0, S1, S2, S3), 8 heads will have layouts like (S0, S1, S2, S3, S0, S1, S2, S3).
|
||||
|
||||
Padding shall be on the right side.
|
||||
|
||||
When do_rotary is True, cos_cache and sin_cache are required.
|
||||
|
||||
Only supports unidirectional attention with cache of past key and value in linear buffers.
|
||||
For performance, past_key and present_key share same memory buffer, and past_value and present_value too.
|
||||
)DOC";
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
SparseAttention, 1,
|
||||
OpSchema()
|
||||
.SetDoc(SparseAttention_ver1_doc)
|
||||
.Attr("num_heads", "Number of attention heads for query", AttributeProto::INT)
|
||||
.Attr("kv_num_heads", "Number of attention heads for key and value", AttributeProto::INT)
|
||||
.Attr("scale", "Scaling factor applied prior to softmax. The default value is 1/sqrt(head_size)", AttributeProto::FLOAT,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("sparse_block_size", "Number of tokens per sparse block. Choices: 16, 32, 64, 128", AttributeProto::INT)
|
||||
.Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("rotary_interleaved", "Rotary use interleaved pattern or not. Default value is 0.", AttributeProto::INT,
|
||||
OPTIONAL_VALUE)
|
||||
.Input(0,
|
||||
"query",
|
||||
"Query with shape (batch_size, sequence_length, num_heads * head_size), or packed QKV with shape is"
|
||||
"(batch_size, sequence_length, d) where d is (num_heads + 2 * kv_num_heads) * head_size.",
|
||||
"T")
|
||||
.Input(1,
|
||||
"key",
|
||||
"Key with shape (batch_size, sequence_length, kv_num_heads * head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(2,
|
||||
"value",
|
||||
"Value with shape (batch_size, sequence_length, kv_num_heads * head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(3,
|
||||
"past_key",
|
||||
"Key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(4,
|
||||
"past_value",
|
||||
"Value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size)",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(5,
|
||||
"block_mask",
|
||||
"block mask. 1 indicates attention and 0 no attention. "
|
||||
"Its shape is (num_layout, max_blocks, max_blocks), "
|
||||
"where num_heads is divisible by num_layout, and max_blocks is max_sequence_length / sparse_block_size.",
|
||||
"M")
|
||||
.Input(6,
|
||||
"total_sequence_length",
|
||||
"Scalar tensor of maximum total sequence length (past_sequence_length + sequence_length) among keys.",
|
||||
"M")
|
||||
.Input(7,
|
||||
"key_total_sequence_lengths",
|
||||
"1D tensor with shape (batch_size) where each value is total sequence length of key excluding paddings.",
|
||||
"M")
|
||||
.Input(8,
|
||||
"cos_cache",
|
||||
"Cos cache of rotary with shape (max_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Input(9,
|
||||
"sin_cache",
|
||||
"Sin cache of rotary with shape (max_sequence_length, head_size / 2).",
|
||||
"T",
|
||||
OpSchema::Optional)
|
||||
.Output(0,
|
||||
"output",
|
||||
"3D output tensor with shape (batch_size, sequence_length, num_heads * head_size)",
|
||||
"T")
|
||||
.Output(1,
|
||||
"present_key",
|
||||
"Updated key cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
|
||||
"T")
|
||||
.Output(2,
|
||||
"present_value",
|
||||
"Updated value cache with shape (batch_size, kv_num_heads, max_sequence_length, head_size).",
|
||||
"T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain integer type.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
SparseAttentionTypeAndShapeInference(ctx, 3);
|
||||
}));
|
||||
|
||||
constexpr const char* Longformer_Attention_doc = R"DOC(
|
||||
Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token
|
||||
attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens
|
||||
|
|
|
@ -104,6 +104,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding);
|
||||
|
@ -216,6 +217,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseAttention)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul)>());
|
||||
|
|
|
@ -2320,6 +2320,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
#endif
|
||||
|
||||
#ifdef ENABLE_CUDA_NHWC_OPS
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaNhwcContribKernels(kernel_registry));
|
||||
#endif
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaNhwcKernels(kernel_registry));
|
||||
#endif
|
||||
|
||||
|
|
|
@ -11,210 +11,132 @@
|
|||
|
||||
#include "core/providers/cuda/cuda_nhwc_kernels.h"
|
||||
|
||||
// Macros to avoid long line length
|
||||
#define CUDA_NHWC_OP_CLASS_NAME(ver, name) \
|
||||
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, ver, name)
|
||||
#define CUDA_NHWC_OP_TYPED_CLASS_NAME(ver, type, name) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, ver, type, name)
|
||||
#define CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(start_ver, end_ver, type, name) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, \
|
||||
start_ver, end_ver, type, name)
|
||||
#define CUDA_NHWC_OP_VERSIONED_CLASS_NAME(start_ver, end_ver, name) \
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, start_ver, end_ver, name)
|
||||
|
||||
namespace onnxruntime::cuda {
|
||||
|
||||
// When adding new supported NHWC operations make sure to also integrate them into: ConvertNodeLayout
|
||||
// in onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
|
||||
Conv);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16,
|
||||
Conv);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float,
|
||||
ConvTranspose);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16,
|
||||
ConvTranspose);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float,
|
||||
AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16,
|
||||
AveragePool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16,
|
||||
GlobalAveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, MLFloat16, GlobalMaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float,
|
||||
AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16,
|
||||
AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, Conv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16,
|
||||
ConvTranspose);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, MLFloat16, AveragePool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16,
|
||||
MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, int8_t, MaxPool);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, uint8_t, MaxPool);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16,
|
||||
BatchNormalization);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, MLFloat16, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, float, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, double, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, MLFloat16, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, Conv);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, Conv);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, ConvTranspose);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, ConvTranspose);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, float, AveragePool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, MLFloat16, AveragePool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalAveragePool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalAveragePool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, float, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, MLFloat16, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, float, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, MLFloat16, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalMaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalMaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, AveragePool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, AveragePool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, Conv);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, Conv);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, ConvTranspose);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, ConvTranspose);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, AveragePool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, AveragePool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, float, BatchNormalization);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, double, BatchNormalization);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(15, MLFloat16, BatchNormalization);
|
||||
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 10, DepthToSpace);
|
||||
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(11, 12, DepthToSpace);
|
||||
class CUDA_NHWC_OP_CLASS_NAME(13, DepthToSpace);
|
||||
class CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 12, SpaceToDepth);
|
||||
class CUDA_NHWC_OP_CLASS_NAME(13, SpaceToDepth);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, float, LRN);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, double, LRN);
|
||||
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, MLFloat16, LRN);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, float, LRN);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, double, LRN);
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(13, MLFloat16, LRN);
|
||||
|
||||
Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
|
||||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15,
|
||||
double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider,
|
||||
kMSInternalNHWCDomain, 1, 10, float, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
float, Conv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
MLFloat16, Conv)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 9, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
|
||||
float, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
|
||||
MLFloat16, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 7, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 8, 9, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
|
||||
float, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1,
|
||||
MLFloat16, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 10, 10, MLFloat16, MaxPool)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
float, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
|
||||
float, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
|
||||
MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
|
||||
int8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12,
|
||||
uint8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11,
|
||||
MLFloat16, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, MLFloat16, ConvTranspose)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
|
||||
1, 10, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
|
||||
11, 12, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
|
||||
13, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
|
||||
1, 12, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain,
|
||||
13, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, float, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, double, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, float, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, double, LRN)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
|
||||
kCudaExecutionProvider, kMSInternalNHWCDomain, 13, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 8, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(9, 13, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(14, 14, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, MLFloat16, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, float, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(15, double, BatchNormalization)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, Conv)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, Conv)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, Conv)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, Conv)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(7, 9, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalAveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 7, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(8, 9, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, float, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(1, MLFloat16, GlobalMaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(10, 10, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, AveragePool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(11, 11, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, float, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, MLFloat16, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, int8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(12, uint8_t, MaxPool)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(11, MLFloat16, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, float, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 10, MLFloat16, ConvTranspose)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 10, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(11, 12, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_CLASS_NAME(13, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_CLASS_NAME(1, 12, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_CLASS_NAME(13, SpaceToDepth)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, float, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, double, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(1, 12, MLFloat16, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, float, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, double, LRN)>,
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(13, MLFloat16, LRN)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : nhwc_function_table) {
|
||||
|
@ -226,4 +148,28 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
|
|||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime::cuda
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
namespace onnxruntime::contrib::cuda {
|
||||
|
||||
class CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample);
|
||||
|
||||
onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
|
||||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : nhwc_function_table) {
|
||||
KernelCreateInfo info = function_table_entry();
|
||||
if (info.kernel_def != nullptr) { // filter disabled entries where type is void
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime::contrib::cuda
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
|
@ -11,3 +11,11 @@ namespace onnxruntime::cuda {
|
|||
onnxruntime::common::Status RegisterCudaNhwcKernels(onnxruntime::KernelRegistry& kernel_registry);
|
||||
|
||||
} // namespace onnxruntime::cuda
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
namespace onnxruntime::contrib::cuda {
|
||||
|
||||
onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry);
|
||||
|
||||
} // namespace onnxruntime::contrib::cuda
|
||||
#endif
|
||||
|
|
|
@ -83,7 +83,7 @@ void TryToLoadKernel() {
|
|||
// get all kernel symbols from curret lib.so
|
||||
size_t size = sizeof(kernel_infos) / sizeof(kernel_infos[0]);
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto k_i = kernel_infos[i];
|
||||
|
||||
void* buff;
|
||||
|
|
|
@ -206,6 +206,7 @@ class SymbolicShapeInference:
|
|||
"GemmFloat8": self._infer_GemmFloat8,
|
||||
"GroupNorm": self._infer_GroupNorm,
|
||||
"GroupQueryAttention": self._infer_GroupQueryAttention,
|
||||
"SparseAttention": self._infer_SparseAttention,
|
||||
"SkipGroupNorm": self._infer_SkipGroupNorm,
|
||||
"LayerNormalization": self._infer_LayerNormalization,
|
||||
"LongformerAttention": self._infer_LongformerAttention,
|
||||
|
@ -473,6 +474,7 @@ class SymbolicShapeInference:
|
|||
"MultiHeadAttention",
|
||||
"GroupNorm",
|
||||
"GroupQueryAttention",
|
||||
"SparseAttention",
|
||||
"SkipGroupNorm",
|
||||
"BiasSplitGelu",
|
||||
"BiasAdd",
|
||||
|
@ -2449,6 +2451,8 @@ class SymbolicShapeInference:
|
|||
|
||||
past_shape = self._try_get_shape(node, 3)
|
||||
if past_shape is not None:
|
||||
# When past and present has the maximum sequence length, we can propagate the shape from past to present.
|
||||
# Note that GQA also supports different sequence lengths for past and present, but it is rarely used.
|
||||
vi = self.known_vi_[node.output[1]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
|
||||
vi = self.known_vi_[node.output[2]]
|
||||
|
@ -2470,6 +2474,9 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
|
||||
|
||||
def _infer_SparseAttention(self, node): # noqa: N802
|
||||
self._infer_GroupQueryAttention(node)
|
||||
|
||||
def _infer_SkipGroupNorm(self, node): # noqa: N802
|
||||
self._propagate_shape_and_type(node, 0, 0)
|
||||
if len(node.output) > 1:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
@ -224,11 +224,44 @@ class CudaSession:
|
|||
self.output_tensors = OrderedDict()
|
||||
self.device = device
|
||||
|
||||
# Pairs of input and output names that share the same buffer.
|
||||
self.buffer_sharing: Dict[str, str] = {}
|
||||
|
||||
def set_buffer_sharing(self, input_name: str, output_name: str):
|
||||
assert input_name in self.input_names
|
||||
assert output_name in self.output_names
|
||||
self.buffer_sharing[input_name] = output_name
|
||||
self.buffer_sharing[output_name] = input_name
|
||||
|
||||
def __del__(self):
|
||||
del self.input_tensors
|
||||
del self.output_tensors
|
||||
del self.io_binding
|
||||
|
||||
def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
|
||||
device_id = tensor.device.index if tensor.device.index is not None else 0
|
||||
tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)
|
||||
|
||||
self.io_binding.bind_input(
|
||||
name,
|
||||
tensor.device.type,
|
||||
device_id,
|
||||
self.io_name_to_numpy_type[name],
|
||||
tensor_shape,
|
||||
tensor.data_ptr(),
|
||||
)
|
||||
|
||||
if name in self.buffer_sharing:
|
||||
self.io_binding.bind_output(
|
||||
self.buffer_sharing[name],
|
||||
tensor.device.type,
|
||||
device_id,
|
||||
self.io_name_to_numpy_type[name],
|
||||
tensor_shape,
|
||||
tensor.data_ptr(),
|
||||
)
|
||||
self.output_tensors[self.buffer_sharing[name]] = tensor
|
||||
|
||||
def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
|
||||
"""Allocate tensors for I/O Binding"""
|
||||
if self.enable_cuda_graph:
|
||||
|
@ -245,15 +278,7 @@ class CudaSession:
|
|||
device=self.device
|
||||
)
|
||||
self.input_tensors[name] = tensor
|
||||
|
||||
self.io_binding.bind_input(
|
||||
name,
|
||||
tensor.device.type,
|
||||
tensor.device.index,
|
||||
numpy_dtype,
|
||||
list(tensor.size()),
|
||||
tensor.data_ptr(),
|
||||
)
|
||||
self.bind_input_and_buffer_sharing(name, tensor)
|
||||
|
||||
for name, shape in shape_dict.items():
|
||||
if name in self.output_names:
|
||||
|
@ -261,6 +286,9 @@ class CudaSession:
|
|||
if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
|
||||
continue
|
||||
|
||||
if name in self.buffer_sharing:
|
||||
continue
|
||||
|
||||
numpy_dtype = self.io_name_to_numpy_type[name]
|
||||
tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to(
|
||||
device=self.device
|
||||
|
@ -270,7 +298,7 @@ class CudaSession:
|
|||
self.io_binding.bind_output(
|
||||
name,
|
||||
tensor.device.type,
|
||||
tensor.device.index,
|
||||
tensor.device.index if tensor.device.index is not None else 0,
|
||||
numpy_dtype,
|
||||
list(tensor.size()),
|
||||
tensor.data_ptr(),
|
||||
|
@ -287,14 +315,7 @@ class CudaSession:
|
|||
assert tensor.device.type == "cuda"
|
||||
self.input_tensors[name].copy_(tensor)
|
||||
else:
|
||||
self.io_binding.bind_input(
|
||||
name,
|
||||
tensor.device.type,
|
||||
tensor.device.index,
|
||||
TypeHelper.torch_type_to_numpy_type(tensor.dtype),
|
||||
[1] if len(tensor.shape) == 0 else list(tensor.shape),
|
||||
tensor.data_ptr(),
|
||||
)
|
||||
self.bind_input_and_buffer_sharing(name, tensor)
|
||||
|
||||
# Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU.
|
||||
if synchronize:
|
||||
|
@ -330,8 +351,13 @@ class GpuBinding(CudaSession):
|
|||
enable_gpu_graph: bool = False,
|
||||
gpu_graph_id: int = -1,
|
||||
stream: int = 0,
|
||||
buffer_sharing: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
super().__init__(ort_session, device, enable_gpu_graph)
|
||||
if buffer_sharing:
|
||||
for input_name, output_name in buffer_sharing.items():
|
||||
self.set_buffer_sharing(input_name, output_name)
|
||||
|
||||
self.allocate_buffers(shape_dict)
|
||||
self.gpu_graph_id = gpu_graph_id
|
||||
# For cuda graph, we need to keep a copy of shape_dict to check if the shape is same in inference later.
|
||||
|
@ -383,6 +409,7 @@ class GpuBindingManager:
|
|||
self,
|
||||
shape_dict: Dict[str, Union[Tuple[int], List[int]]],
|
||||
use_cuda_graph: bool = False,
|
||||
buffer_sharing: Optional[Dict[str, str]] = None,
|
||||
) -> GpuBinding:
|
||||
for gpu_graph_binding in self.graph_bindings:
|
||||
# Found a cuda graph that captured with the same shape
|
||||
|
@ -392,7 +419,9 @@ class GpuBindingManager:
|
|||
# Reached the maximum number of cuda graphs. Return a binding without cuda graph.
|
||||
if len(self.graph_bindings) >= self.max_cuda_graphs or (not use_cuda_graph):
|
||||
if self.no_graph_binding is None:
|
||||
self.no_graph_binding = GpuBinding(self.ort_session, self.device, shape_dict, stream=self.stream)
|
||||
self.no_graph_binding = GpuBinding(
|
||||
self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
|
||||
)
|
||||
else:
|
||||
self.no_graph_binding.allocate_buffers(shape_dict)
|
||||
return self.no_graph_binding
|
||||
|
@ -405,6 +434,7 @@ class GpuBindingManager:
|
|||
enable_gpu_graph=True,
|
||||
gpu_graph_id=len(self.graph_bindings),
|
||||
stream=self.stream,
|
||||
buffer_sharing=buffer_sharing,
|
||||
)
|
||||
self.graph_bindings.append(gpu_graph_binding)
|
||||
return gpu_graph_binding
|
||||
|
|
|
@ -0,0 +1,680 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 8.x.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from onnx import TensorProto, helper
|
||||
from torch import Tensor
|
||||
|
||||
from onnxruntime import InferenceSession, SessionOptions
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
|
||||
|
||||
ENABLE_DEBUG = False
|
||||
|
||||
|
||||
class Config:
|
||||
batch_size = 0
|
||||
sequence_length = 0
|
||||
max_sequence_length = 0
|
||||
past_sequence_length = 0
|
||||
num_heads = 0
|
||||
kv_num_heads = 0
|
||||
head_size = 0
|
||||
sparse_block_size = 0
|
||||
num_layout = 0
|
||||
|
||||
local_blocks = 0
|
||||
vert_stride = 0
|
||||
softmax_scale = None
|
||||
share_buffer = True
|
||||
|
||||
do_rotary = False
|
||||
rotary_interleaved = False
|
||||
|
||||
# TODO: test packed qkv
|
||||
is_packed_qkv = False
|
||||
|
||||
# TODO: test bfloat16.
|
||||
is_fp16 = True # True for float16; False for bfloat16.
|
||||
|
||||
use_sparse = True # True for GroupQueryAttention; False for SparseAttention
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
max_sequence_length: int,
|
||||
past_sequence_length: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
sparse_block_size: int,
|
||||
num_layout: int,
|
||||
local_blocks: int,
|
||||
vert_stride: int,
|
||||
softmax_scale=None,
|
||||
do_rotary: bool = False,
|
||||
rotary_interleaved: bool = False,
|
||||
device="cuda",
|
||||
operator="SparseAttention",
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
self.sparse_block_size = sparse_block_size
|
||||
self.num_layout = num_layout
|
||||
self.local_blocks = local_blocks
|
||||
self.vert_stride = vert_stride
|
||||
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
|
||||
|
||||
# Derived values
|
||||
self.total_sequence_length = sequence_length + past_sequence_length
|
||||
self.past_buffer_length = max_sequence_length if self.share_buffer else past_sequence_length
|
||||
self.present_buffer_length = (
|
||||
max_sequence_length if self.share_buffer else (past_sequence_length + sequence_length)
|
||||
)
|
||||
self.max_blocks = max_sequence_length // sparse_block_size
|
||||
|
||||
self.do_rotary = do_rotary
|
||||
self.rotary_interleaved = rotary_interleaved
|
||||
self.device = device
|
||||
self.operator = operator
|
||||
|
||||
def block_mask(self):
|
||||
return get_block_mask(self.num_layout, self.max_blocks, self.local_blocks, self.vert_stride).to(self.device)
|
||||
|
||||
def dense_mask(self):
|
||||
expand_block_mask = self.block_mask()
|
||||
dense_mask = get_dense_mask(
|
||||
expand_block_mask, self.total_sequence_length, self.sequence_length, self.sparse_block_size
|
||||
)
|
||||
return dense_mask.repeat(self.batch_size, self.num_heads // self.num_layout, 1, 1).to(self.device)
|
||||
|
||||
def shape_dict(self):
|
||||
shape_dict = {
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"block_mask": (self.num_layout, self.max_blocks, self.max_blocks),
|
||||
"total_sequence_length": (1,),
|
||||
"key_total_sequence_lengths": (self.batch_size,),
|
||||
"seqlens_k": (self.batch_size,),
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
}
|
||||
|
||||
if self.operator == "SparseAttention":
|
||||
del shape_dict["seqlens_k"]
|
||||
else:
|
||||
assert self.operator == "GroupQueryAttention"
|
||||
del shape_dict["key_total_sequence_lengths"]
|
||||
del shape_dict["block_mask"]
|
||||
return shape_dict
|
||||
|
||||
def get_cos_sin_cache(self, dtype=torch.float32):
|
||||
rotary_fraction = 1.0
|
||||
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
|
||||
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=dtype)
|
||||
sin = torch.sin(angle).to(dtype=dtype)
|
||||
return cos.to(device=self.device), sin.to(device=self.device)
|
||||
|
||||
def random_inputs(self, dtype=torch.float16):
|
||||
device = self.device
|
||||
shape_dict = self.shape_dict()
|
||||
k_seqlens = torch.ones((self.batch_size,), device=device, dtype=torch.int32) * self.total_sequence_length
|
||||
|
||||
torch.manual_seed(123)
|
||||
feeds = {
|
||||
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"block_mask": self.block_mask(),
|
||||
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
|
||||
"key_total_sequence_lengths": k_seqlens,
|
||||
"seqlens_k": k_seqlens - 1,
|
||||
}
|
||||
|
||||
if self.do_rotary:
|
||||
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
|
||||
feeds["cos_cache"] = cos_cache
|
||||
feeds["sin_cache"] = sin_cache
|
||||
|
||||
if "seqlens_k" not in shape_dict:
|
||||
del feeds["seqlens_k"]
|
||||
else:
|
||||
del feeds["key_total_sequence_lengths"]
|
||||
del feeds["block_mask"]
|
||||
return feeds
|
||||
|
||||
|
||||
def get_block_mask(num_layout, max_blocks, local_blocks, vert_stride):
|
||||
q_pos = torch.arange(max_blocks)[None, :, None]
|
||||
k_pos = torch.arange(max_blocks)[None, None]
|
||||
head_sliding_step = max(1, int(vert_stride / num_layout))
|
||||
mask_vert_strided = [
|
||||
(torch.arange(max_blocks) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(num_layout)
|
||||
]
|
||||
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
||||
block_mask = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)
|
||||
block_mask = block_mask.to(torch.int32)
|
||||
|
||||
if ENABLE_DEBUG:
|
||||
torch.set_printoptions(profile="full")
|
||||
torch.set_printoptions(edgeitems=100)
|
||||
torch.set_printoptions(linewidth=200)
|
||||
print(block_mask)
|
||||
|
||||
return block_mask
|
||||
|
||||
|
||||
def get_dense_mask(block_mask, total_seq_len, query_seq_len, block_size):
|
||||
dense_mask = torch.kron(block_mask, block_mask.new_ones((block_size, block_size)))[
|
||||
:, :total_seq_len, :total_seq_len
|
||||
]
|
||||
causal_mask = torch.tril(torch.ones(total_seq_len, total_seq_len)).type_as(dense_mask)
|
||||
dense_mask = dense_mask * causal_mask[None]
|
||||
return dense_mask[..., -query_seq_len:, :total_seq_len]
|
||||
|
||||
|
||||
def create_sparse_attention_onnx_model(config):
|
||||
assert config.is_fp16 # python does not support bfloat16 for I/O binding.
|
||||
|
||||
float_type = TensorProto.FLOAT16
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"SparseAttention",
|
||||
[
|
||||
"query",
|
||||
"key" if not config.is_packed_qkv else "",
|
||||
"value" if not config.is_packed_qkv else "",
|
||||
"past_key",
|
||||
"past_value",
|
||||
"block_mask",
|
||||
"total_sequence_length" if config.share_buffer else "",
|
||||
"key_total_sequence_lengths",
|
||||
"cos_cache" if config.do_rotary else "",
|
||||
"sin_cache" if config.do_rotary else "",
|
||||
],
|
||||
["output", "present_key", "present_value"],
|
||||
"SparseAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
kv_num_heads=config.kv_num_heads,
|
||||
scale=config.softmax_scale,
|
||||
sparse_block_size=config.sparse_block_size,
|
||||
do_rotary=1 if config.do_rotary else 0,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
graph_input = [
|
||||
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
|
||||
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
|
||||
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
|
||||
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
|
||||
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
|
||||
helper.make_tensor_value_info("block_mask", TensorProto.INT32, list(shape_dict["block_mask"])),
|
||||
helper.make_tensor_value_info(
|
||||
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
|
||||
),
|
||||
helper.make_tensor_value_info(
|
||||
"key_total_sequence_lengths", TensorProto.INT32, list(shape_dict["key_total_sequence_lengths"])
|
||||
),
|
||||
]
|
||||
|
||||
if config.do_rotary:
|
||||
graph_input += [
|
||||
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
|
||||
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
|
||||
]
|
||||
|
||||
graph_output = [
|
||||
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
|
||||
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
|
||||
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"SparseAttention_Graph",
|
||||
graph_input,
|
||||
graph_output,
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
return model.SerializeToString()
|
||||
|
||||
|
||||
def create_group_query_attention_onnx_model(config):
|
||||
assert config.is_fp16 # python does not support bfloat16 for I/O binding.
|
||||
|
||||
float_type = TensorProto.FLOAT16
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
[
|
||||
"query",
|
||||
"key" if not config.is_packed_qkv else "",
|
||||
"value" if not config.is_packed_qkv else "",
|
||||
"past_key",
|
||||
"past_value",
|
||||
"seqlens_k",
|
||||
"total_sequence_length" if config.share_buffer else "",
|
||||
"cos_cache" if config.do_rotary else "",
|
||||
"sin_cache" if config.do_rotary else "",
|
||||
],
|
||||
["output", "present_key", "present_value"],
|
||||
"GroupQueryAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
kv_num_heads=config.kv_num_heads,
|
||||
local_window_size=config.local_blocks * config.sparse_block_size if config.use_sparse else -1,
|
||||
do_rotary=1 if config.do_rotary else 0,
|
||||
rotary_interleaved=config.rotary_interleaved,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
graph_input = [
|
||||
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
|
||||
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
|
||||
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
|
||||
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
|
||||
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
|
||||
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
|
||||
helper.make_tensor_value_info(
|
||||
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
|
||||
),
|
||||
]
|
||||
|
||||
if config.do_rotary:
|
||||
graph_input += [
|
||||
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
|
||||
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
|
||||
]
|
||||
|
||||
graph_output = [
|
||||
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
|
||||
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
|
||||
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"GroupQueryAttention_Graph",
|
||||
graph_input,
|
||||
graph_output,
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
return model.SerializeToString()
|
||||
|
||||
|
||||
def create_session(onnx_model_str, config: Config, cuda_provider_options=None) -> InferenceSession:
|
||||
session_options = SessionOptions()
|
||||
ort_session = InferenceSession(
|
||||
onnx_model_str,
|
||||
session_options,
|
||||
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
|
||||
)
|
||||
return ort_session
|
||||
|
||||
|
||||
def group_query_attention_reference(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
config: Config,
|
||||
scale: Optional[float] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
):
|
||||
if scale is None:
|
||||
scale = 1.0 / (config.head_size**0.5)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Expand key and value to have same number of heads as query
|
||||
num_key_value_groups = config.num_heads // config.kv_num_heads
|
||||
key = torch.repeat_interleave(key, dim=1, repeats=num_key_value_groups)
|
||||
value = torch.repeat_interleave(value, dim=1, repeats=num_key_value_groups)
|
||||
|
||||
# Apply multi-head attention.
|
||||
attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill((1 - mask).bool(), float("-inf"))
|
||||
attn = attn.softmax(-1)
|
||||
attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)
|
||||
|
||||
return attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class TorchGroupQueryAttention:
|
||||
"""A wrapper of Torch GroupQueryAttention to test relevance and performance."""
|
||||
|
||||
def __init__(self, device, config: Config, feed_dict):
|
||||
self.device = device
|
||||
self.config = config
|
||||
self.query = feed_dict["query"].view(
|
||||
config.batch_size, config.sequence_length, config.num_heads, config.head_size
|
||||
)
|
||||
self.key = feed_dict["key"].view(
|
||||
config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size
|
||||
)
|
||||
self.value = feed_dict["value"].view(
|
||||
config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size
|
||||
)
|
||||
self.dense_mask = config.dense_mask()
|
||||
if ENABLE_DEBUG:
|
||||
torch.set_printoptions(precision=6, edgeitems=3, linewidth=1000, profile="full", sci_mode=False)
|
||||
print("query(BNSH)", self.query.clone().transpose(1, 2))
|
||||
print("key(BNSH)", self.key.clone().transpose(1, 2))
|
||||
print("value(BNSH)", self.value.clone().transpose(1, 2))
|
||||
print("dense_mask", self.dense_mask)
|
||||
|
||||
def infer(self):
|
||||
return group_query_attention_reference(
|
||||
self.query, self.key, self.value, self.config, scale=self.config.softmax_scale, mask=self.dense_mask
|
||||
)
|
||||
|
||||
|
||||
class OrtGroupQueryAttention:
|
||||
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
|
||||
|
||||
def __init__(self, device, config: Config, feed_dict):
|
||||
cuda_provider_options = CudaSession.get_cuda_provider_options(
|
||||
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
|
||||
)
|
||||
onnx_model_str = create_group_query_attention_onnx_model(config)
|
||||
self.ort_session = create_session(onnx_model_str, config, cuda_provider_options=cuda_provider_options)
|
||||
self.gpu_binding_manager = GpuBindingManager(
|
||||
ort_session=self.ort_session,
|
||||
device=device,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
max_cuda_graphs=2,
|
||||
)
|
||||
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
|
||||
self.gpu_binding = self.gpu_binding_manager.get_binding(
|
||||
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
|
||||
)
|
||||
self.feed_dict = feed_dict
|
||||
|
||||
def infer(self):
|
||||
return self.gpu_binding.infer(self.feed_dict)
|
||||
|
||||
|
||||
class OrtSparseAttention:
|
||||
"""A wrapper of ORT SparseAttention to test relevance and performance."""
|
||||
|
||||
def __init__(self, device, config: Config, feed_dict):
|
||||
cuda_provider_options = CudaSession.get_cuda_provider_options(
|
||||
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
|
||||
)
|
||||
onnx_model_str = create_sparse_attention_onnx_model(config)
|
||||
self.ort_session = create_session(onnx_model_str, config, cuda_provider_options=cuda_provider_options)
|
||||
self.gpu_binding_manager = GpuBindingManager(
|
||||
ort_session=self.ort_session,
|
||||
device=device,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
max_cuda_graphs=2,
|
||||
)
|
||||
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
|
||||
self.gpu_binding = self.gpu_binding_manager.get_binding(
|
||||
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
|
||||
)
|
||||
self.feed_dict = feed_dict
|
||||
|
||||
def infer(self):
|
||||
return self.gpu_binding.infer(self.feed_dict)
|
||||
|
||||
|
||||
def run_one_relevance_test(device, config: Config):
|
||||
dtype = torch.float16
|
||||
|
||||
# Run QGA ort
|
||||
config.use_sparse = False
|
||||
config.operator = "GroupQueryAttention"
|
||||
feed_dict = config.random_inputs(dtype=dtype)
|
||||
if config.past_sequence_length == 0:
|
||||
obj = TorchGroupQueryAttention(device, config, feed_dict)
|
||||
expected_out = obj.infer()
|
||||
else:
|
||||
obj = OrtGroupQueryAttention(device, config, feed_dict)
|
||||
ort_qga_outputs = obj.infer()
|
||||
expected_out = ort_qga_outputs["output"].view(
|
||||
config.batch_size, config.sequence_length, config.num_heads, config.head_size
|
||||
)
|
||||
|
||||
# Run SparseAttention by ORT
|
||||
config.use_sparse = True
|
||||
config.operator = "SparseAttention"
|
||||
if config.past_sequence_length != 0:
|
||||
config.local_blocks = config.max_blocks # Use dense to compare with GQA
|
||||
feed_dict = config.random_inputs(dtype=dtype)
|
||||
if ENABLE_DEBUG:
|
||||
print("block_mask", feed_dict["block_mask"])
|
||||
print("total_sequence_length", feed_dict["total_sequence_length"])
|
||||
print("key_total_sequence_lengths", feed_dict["key_total_sequence_lengths"])
|
||||
obj = OrtSparseAttention(device, config, feed_dict)
|
||||
ort_outputs = obj.infer()
|
||||
ort_output = ort_outputs["output"]
|
||||
actual_out = ort_output.view(config.batch_size, config.sequence_length, config.num_heads, config.head_size)
|
||||
|
||||
if torch.allclose(expected_out, actual_out, atol=1e-2, rtol=0):
|
||||
print(f"Relevance test passed: {vars(config)}")
|
||||
else:
|
||||
print(f"Relevance test not passed: {vars(config)}")
|
||||
print("ort_output", actual_out)
|
||||
print("expected_out", expected_out)
|
||||
print("diff", expected_out - actual_out)
|
||||
exit(1)
|
||||
|
||||
|
||||
def run_relevance_no_past(device):
|
||||
"""Test prompt prefilling without past kv cache."""
|
||||
for seq_len in [1, 64, 127, 128, 192, 256]:
|
||||
config = Config(
|
||||
batch_size=1,
|
||||
sequence_length=seq_len,
|
||||
max_sequence_length=256,
|
||||
past_sequence_length=0,
|
||||
num_heads=8,
|
||||
kv_num_heads=4,
|
||||
head_size=128,
|
||||
sparse_block_size=64,
|
||||
num_layout=2,
|
||||
local_blocks=2,
|
||||
vert_stride=2,
|
||||
softmax_scale=1.8 / (128**0.5),
|
||||
)
|
||||
run_one_relevance_test(device, config)
|
||||
|
||||
|
||||
def run_relevance_past(device):
|
||||
"""Test token generation with past kv cache."""
|
||||
for past_seq_len in [1, 63, 64, 127, 128, 511]:
|
||||
config = Config(
|
||||
batch_size=2,
|
||||
sequence_length=1,
|
||||
max_sequence_length=512,
|
||||
past_sequence_length=past_seq_len,
|
||||
num_heads=8,
|
||||
kv_num_heads=4,
|
||||
head_size=128,
|
||||
sparse_block_size=64,
|
||||
num_layout=4,
|
||||
local_blocks=2,
|
||||
vert_stride=4,
|
||||
do_rotary=True,
|
||||
rotary_interleaved=(past_seq_len % 2 == 1),
|
||||
)
|
||||
run_one_relevance_test(device, config)
|
||||
|
||||
|
||||
def run_relevance_test():
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
with torch.no_grad():
|
||||
run_relevance_no_past(device)
|
||||
run_relevance_past(device)
|
||||
|
||||
|
||||
def plot_prompt_performance(
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
sparse_block_size=64,
|
||||
local_blocks=16,
|
||||
vert_stride=8,
|
||||
num_layout=8,
|
||||
dtype=torch.float16,
|
||||
):
|
||||
import triton
|
||||
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 14)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_gqa", "ort_gqa", "ort_gqa_local", "ort_sparse_att"],
|
||||
line_names=["TORCH-GQA", "ORT-GQA-Dense", "ORT-GQA-Local", "ORT-SparseAtt"],
|
||||
styles=[("red", "-"), ("yellow", "-"), ("blue", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"prompt-batch{batch_size}-head{num_heads}-d{head_size}-local{local_blocks}-vert{vert_stride}-{dtype}",
|
||||
args={"num_heads": num_heads, "batch_size": batch_size, "head_size": head_size, "dtype": dtype},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(batch_size, num_heads, sequence_length, head_size, provider, dtype=torch.float16, device="cuda"):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
config = Config(
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
max_sequence_length=max_seq_len,
|
||||
past_sequence_length=0,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=8,
|
||||
head_size=head_size,
|
||||
sparse_block_size=sparse_block_size,
|
||||
num_layout=num_layout,
|
||||
local_blocks=local_blocks,
|
||||
vert_stride=vert_stride,
|
||||
)
|
||||
|
||||
config.use_sparse = provider in ["ort_sparse_att", "ort_gqa_local"]
|
||||
config.operator = "SparseAttention" if provider in ["ort_sparse_att"] else "GroupQueryAttention"
|
||||
feed_dict = config.random_inputs(dtype=dtype)
|
||||
|
||||
if provider in ["ort_gqa", "ort_gqa_local"]:
|
||||
obj = OrtGroupQueryAttention(device, config, feed_dict)
|
||||
elif provider == "ort_sparse_att":
|
||||
obj = OrtSparseAttention(device, config, feed_dict)
|
||||
else:
|
||||
assert provider == "torch_gqa"
|
||||
if sequence_length > 2048: # out of memory
|
||||
return 0
|
||||
obj = TorchGroupQueryAttention(device, config, feed_dict)
|
||||
|
||||
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
|
||||
return ms
|
||||
|
||||
benchmark.run(save_path=".", print_data=True)
|
||||
|
||||
|
||||
def plot_token_performance(
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
sparse_block_size=64,
|
||||
local_blocks=16,
|
||||
vert_stride=8,
|
||||
num_layout=8,
|
||||
dtype=torch.float16,
|
||||
):
|
||||
import triton
|
||||
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["past_sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_gqa", "ort_gqa", "ort_gqa_local", "ort_sparse_att"],
|
||||
line_names=["TORCH-GQA", "ORT-GQA-Dense", "ORT-GQA-Local", "ORT-SparseAtt"],
|
||||
styles=[("red", "-"), ("yellow", "-"), ("blue", "-"), ("green", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"token-batch{batch_size}-head{num_heads}-d{head_size}-local{local_blocks}-vert{vert_stride}-{dtype}",
|
||||
args={"num_heads": num_heads, "batch_size": batch_size, "head_size": head_size, "dtype": dtype},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(batch_size, num_heads, past_sequence_length, head_size, provider, dtype=torch.float16, device="cuda"):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
config = Config(
|
||||
batch_size=batch_size,
|
||||
sequence_length=1,
|
||||
max_sequence_length=max_seq_len,
|
||||
past_sequence_length=past_sequence_length,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=8,
|
||||
head_size=head_size,
|
||||
sparse_block_size=sparse_block_size,
|
||||
num_layout=num_layout,
|
||||
local_blocks=local_blocks,
|
||||
vert_stride=vert_stride,
|
||||
)
|
||||
|
||||
config.use_sparse = provider in ["ort_sparse_att", "ort_gqa_local"]
|
||||
config.operator = "SparseAttention" if provider in ["ort_sparse_att"] else "GroupQueryAttention"
|
||||
feed_dict = config.random_inputs(dtype=dtype)
|
||||
|
||||
if provider in ["ort_gqa", "ort_gqa_local"]:
|
||||
obj = OrtGroupQueryAttention(device, config, feed_dict)
|
||||
elif provider == "ort_sparse_att":
|
||||
obj = OrtSparseAttention(device, config, feed_dict)
|
||||
else:
|
||||
assert provider == "torch_gqa"
|
||||
if past_sequence_length > 2048: # out of memory
|
||||
return 0
|
||||
obj = TorchGroupQueryAttention(device, config, feed_dict)
|
||||
|
||||
ms = triton.testing.do_bench(obj.infer, warmup=warmup, rep=repeat)
|
||||
return ms
|
||||
|
||||
benchmark.run(save_path=".", print_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
run_relevance_test()
|
||||
plot_prompt_performance()
|
||||
plot_token_performance()
|
|
@ -100,6 +100,8 @@ unfixable = [
|
|||
# Eventually this list should become empty.
|
||||
"orttraining/orttraining/test/**" = ["N802"] # Function casing
|
||||
"tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"] # Too many errors to fix
|
||||
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_triton.py" = ["N806"] # use of Q, K and V in triton script
|
||||
"onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_triton.py" = ["N806"] # use of Q, K and V in triton script
|
||||
"onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"] # use of A for a matrix
|
||||
"onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"] # use of A for a matrix
|
||||
"orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"] # Long triton code from other repo.
|
||||
|
|
Загрузка…
Ссылка в новой задаче