[WebGPU EP] Support GroupQueryAttention (#22658)
### Description
<!-- Describe your changes. -->
Support the GroupQueryAttention operator in the native WebGPU EP.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This is required to run inference on some LLMs.
This commit is contained in:
Родитель
6c2ff5fc55
Коммит
e8bf46a70e
|
@ -11,18 +11,19 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace group_query_attention_helper {
|
||||
|
||||
Status CheckInputs(const Tensor* query,
|
||||
const Tensor* key,
|
||||
const Tensor* value,
|
||||
const Tensor* past_key,
|
||||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
template <typename T = Tensor>
|
||||
Status CheckInputs(const T* query,
|
||||
const T* key,
|
||||
const T* value,
|
||||
const T* past_key,
|
||||
const T* past_value,
|
||||
const T* cos_cache,
|
||||
const T* sin_cache,
|
||||
void* parameters,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
const Tensor* seqlens_k,
|
||||
const Tensor* total_seqlen,
|
||||
const T* seqlens_k,
|
||||
const T* total_seqlen,
|
||||
float scale,
|
||||
float softcap) {
|
||||
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
|
||||
|
@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckInputs(const Tensor* query,
|
||||
const Tensor* key,
|
||||
const Tensor* value,
|
||||
const Tensor* past_key,
|
||||
const Tensor* past_value,
|
||||
const Tensor* cos_cache,
|
||||
const Tensor* sin_cache,
|
||||
template <typename T = Tensor>
|
||||
Status CheckInputs(const T* query,
|
||||
const T* key,
|
||||
const T* value,
|
||||
const T* past_key,
|
||||
const T* past_value,
|
||||
const T* cos_cache,
|
||||
const T* sin_cache,
|
||||
void* parameters,
|
||||
int num_heads,
|
||||
int kv_num_heads,
|
||||
const Tensor* seqlens_k,
|
||||
const Tensor* total_seqlen,
|
||||
const T* seqlens_k,
|
||||
const T* total_seqlen,
|
||||
float scale,
|
||||
float softcap,
|
||||
int max_threads_per_block) {
|
||||
|
|
|
@ -0,0 +1,459 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/webgpu/bert/attention.h"
|
||||
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
#include "contrib_ops/webgpu/bert/multihead_attention.h"
|
||||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
using namespace onnxruntime::webgpu;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::contrib::multihead_attention_helper;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
// Emits the WGSL body for the copy kernel: every invocation maps its flat output index to
// 4-D output indices, rebuilds the matching flat input offset from the stride uniforms, and
// writes the input element (plus the bias element when a bias tensor is bound).
Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Registration order is kept exactly as before: input, output, then the optional bias.
  shader.AddInput("qkv_input", ShaderUsage::UseUniform);
  const auto& out_var = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices);
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }

  auto& body = shader.MainFunctionBody();
  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
       << "let output_indices = " << out_var.OffsetToIndices("global_idx") << ";\n"
       << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *"
       << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n";

  if (has_bias_) {
    // The bias index is the position inside one sequence step, shifted by the caller-provided offset.
    body << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n";
  }

  body << "qkv_output[global_idx] = qkv_input[input_offset_idx]"
       << (has_bias_ ? " + bias[bias_offset_idx];\n" : ";\n");

  return Status::OK();
}
|
||||
|
||||
// Runs TransferBSDToBNSHProgram to copy a rank-3 `input_tensor` into a rank-4 `output_tensor`,
// optionally adding `bias` (when non-null) read at `bias_offset`.
//
// The strides passed as uniforms are derived from `num_heads`, `sequence_length` and `head_size`:
//   batch_offset    = num_heads * sequence_length * head_size
//   sequence_offset = num_heads * head_size
//   head_offset     = head_size
// NOTE(review): the names suggest a (batch, sequence, hidden) -> (batch, heads, sequence, head_size)
// re-layout; only the ranks (3 and 4) are actually enforced here.
//
// Returns the status of ComputeContext::RunProgram.
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
                         int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) {
  ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3);
  ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4);

  // One invocation per output element; SafeInt guards the narrowing to 32 bits.
  const uint32_t data_size = SafeInt<uint32_t>(output_tensor->Shape().Size());
  const int batch_offset = num_heads * sequence_length * head_size;
  const int sequence_offset = num_heads * head_size;
  const int head_offset = head_size;
  const bool has_bias = bias != nullptr;

  TransferBSDToBNSHProgram program{has_bias};
  program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{data_size},
                            {static_cast<uint32_t>(batch_offset)},
                            {static_cast<uint32_t>(sequence_offset)},
                            {static_cast<uint32_t>(head_offset)},
                            {static_cast<uint32_t>(bias_offset)}});

  // The bias is bound last so that it follows the primary input, matching the shader's AddInput order.
  if (has_bias) {
    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
  }

  return context.RunProgram(program);
}
|
||||
|
||||
// Appends the WGSL prologue that defines `past_sequence_length` for the attention shaders.
//
// Without `seqlen_k` the value is read straight from the uniform. With `seqlen_k`, the
// previously declared `total_sequence_length` variable is overwritten per batch from the
// `seqlen_k` buffer (+1), and `past_sequence_length` becomes either 0 (first prompt) or
// `total_sequence_length - sequence_length`. The emitted code relies on the caller having
// already declared `batch_idx`, `sequence_length` and a mutable `total_sequence_length`.
void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
  if (seqlen_k == nullptr) {
    ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
    return;
  }

  const char* past_length_expr = is_first_prompt ? "0" : "total_sequence_length - sequence_length";
  ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
  ss << "var past_sequence_length: u32 = " << past_length_expr << ";\n";
}
|
||||
|
||||
// Emits the tiled WGSL kernel that multiplies Q by K (row-by-row dot products), scales the
// result by `uniforms.alpha`, optionally adds `attention_bias`, and writes it to `output`.
// While loading K tiles the kernel can also read from the past-key buffer and populate
// `present_key`, depending on the feed/present/share-buffer flags.
Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Bindings: order matters and mirrors the order used when the program's tensors are added.
  shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (feed_past_key_) {
    shader.AddInput("past_key", ShaderUsage::UseUniform);
  }
  if (has_attention_bias_) {
    shader.AddInput("attention_bias", ShaderUsage::UseUniform);
  }
  if (seqlen_k_ != nullptr) {
    shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
  }
  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (has_present_key_) {
    shader.AddOutput("present_key", ShaderUsage::UseUniform);
  }

  const int tile_elements = tile_size_ * tile_size_;
  const char* f32_vector_type = components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32");
  const char* horizontal_sum = components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value");
  // Keys are taken from the past buffer (or the shared present buffer) before the new keys.
  const bool reads_past_keys = (feed_past_key_ && has_present_key_) || past_present_share_buffer_;

  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_elements << ">;\n"
                                    << "var<workgroup> tileK: array<key_value_t, " << tile_elements << ">;\n"
                                    << "alias f32_val_t = " << f32_vector_type << ";\n";

  auto& body = shader.MainFunctionBody();
  body << "// x holds the N and y holds the M\n"
       << "let m = workgroup_id.y * TILE_SIZE;\n"
       << "let n = workgroup_id.x * TILE_SIZE;\n"
       << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
       << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
       << "let sequence_length = uniforms.M;\n"
       << "var total_sequence_length = uniforms.N;\n";

  // Defines past_sequence_length (and possibly overrides total_sequence_length).
  std::ostringstream prologue;
  InitVarStub(prologue, seqlen_k_, is_first_prompt_);
  body << prologue.str();

  // n_reps_ is baked into the source: several query heads share one key/value head.
  body << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
  if (has_present_key_) {
    body << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
  }

  body << "var value = f32_val_t(0);\n"
       << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
       << " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
       << " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
       << " }\n"
       << " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
       << " var idx = TILE_SIZE * local_id.y + local_id.x;\n";

  if (reads_past_keys) {
    const char* past_source = past_present_share_buffer_ ? "present_key" : "past_key";
    body << " if (n + local_id.y < past_sequence_length) {\n"
         << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
         << " tileK[idx] = " << past_source << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
         << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
         << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
         << " }\n";
  } else {
    body << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
         << " tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
         << " }\n";
  }

  if (has_present_key_) {
    // With a shared buffer only the newly appended rows are written back.
    body << (past_present_share_buffer_
                 ? " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n"
                 : " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n");
    body << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
         << " }\n";
  }

  body << " }\n"
       << " workgroupBarrier();\n"
       << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
       << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n"
       << " }\n"
       << " workgroupBarrier();\n"
       << "}\n";

  body << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
       << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
       << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
       << " var sum: f32 = " << horizontal_sum << ";\n";

  body << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
  if (has_attention_bias_) {
    body << " + attention_bias[outputIdx]";
  }
  body << ";\n"
       << "}\n";

  return Status::OK();
}
|
||||
|
||||
// Configures and runs AttentionProbsProgram, which fills `probs` with scaled Q x K scores
// (plus `attention_bias` when given) and, when requested, writes `present_key`.
//
// Parameters:
//   output_count          - number of outputs the caller produces; present_key is written only when > 1.
//   Q, K                  - query / key tensors, bound with `components`-wide vector access.
//   past_key              - optional past keys; bound as an input only when it is non-empty and the
//                           past/present buffers are not shared.
//   attention_bias        - optional additive bias, bound when non-null.
//   probs, present_key    - outputs.
//   parameters            - attention dimensions and flags.
//   past_sequence_length, total_sequence_length - forwarded as uniforms / dispatch extents.
//   seqlen_k              - optional per-batch lengths buffer, bound when non-null.
//
// Returns the status of ComputeContext::RunProgram.
Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
                             const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
                             WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length,
                             const Tensor* seqlen_k) {
  // A scale of exactly 0 means "not specified": fall back to 1/sqrt(head_size).
  const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
                                                : parameters.scale_;

  const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
  const bool has_present_key = output_count > 1 && past_key;
  const bool has_attention_bias = attention_bias != nullptr;
  constexpr int tile_size = 12;
  // Widest vector width (4, 2 or 1) that evenly divides the head size.
  const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);

  AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
                                components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
  program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                     {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
  if (feed_past_key) {
    program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
  }
  if (has_attention_bias) {
    program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
  }
  if (seqlen_k != nullptr) {
    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
  }
  program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
  if (has_present_key) {
    program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components});
  }

  const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
  program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
                               (parameters.sequence_length_ + tile_size - 1) / tile_size,
                               parameters.batch_size_ * parameters.num_heads_)
      .SetWorkgroupSize(tile_size, tile_size)
      // The generated WGSL embeds n_reps as a literal divisor, so it has to be part of the cache
      // hint as well; otherwise dispatches that differ only in n_reps would map to the same key.
      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_, parameters.n_reps)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                            {static_cast<uint32_t>(vectorized_head_size)},
                            {static_cast<uint32_t>(total_sequence_length)},
                            {static_cast<uint32_t>(parameters.num_heads_)},
                            {static_cast<uint32_t>(parameters.head_size_)},
                            {static_cast<float>(alpha)},
                            {static_cast<uint32_t>(past_sequence_length)},
                            {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                            // Row stride of present_key: the full cache length when seqlen_k drives the lengths.
                            {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
                            {static_cast<uint32_t>(parameters.n_reps)}})
      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Emits the WGSL kernel that applies softmax in place to one row of `x` per workgroup.
//
// Each thread reduces `uniforms.elements_per_thread` consecutive elements; per-thread maxima and
// sums are shared through workgroup arrays and combined by every thread. When the sum is zero the
// row is filled with a uniform distribution (1 / seq_causal_length). With `seqlen_k` bound, only
// the first `past_sequence_length + workgroup_id.y + 1` elements take part and the remainder of
// the row is zeroed.
Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
  if (seqlen_k_) {
    shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
  }
  shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  const char* f32_vector_type = components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32");
  const char* lane_max = components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector");
  const char* lane_sum = components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector");

  shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
                                    << "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
                                    << "alias f32_val_t = " << f32_vector_type << ";\n";
  shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
                            << "let sequence_length = uniforms.sequence_length;\n"
                            << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";

  // Defines past_sequence_length (and possibly overrides total_sequence_length).
  std::ostringstream oss;
  InitVarStub(oss, seqlen_k_, is_first_prompt_);
  shader.MainFunctionBody() << oss.str()
                            << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
                            << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
                            << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
                            // Pass 1: row maximum, for numerical stability.
                            << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                            << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
                            << "}\n"
                            << "thread_max[local_idx] = " << lane_max << ";\n"
                            << "workgroupBarrier();\n"
                            << "var max_value = f32(-3.402823e+38f);\n"
                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                            << " max_value = max(thread_max[i], max_value);\n"
                            << "}\n"
                            // Pass 2: sum of exp(x - max).
                            << "var sum_vector = f32_val_t(0);\n"
                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                            << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
                            << "}\n"
                            << "thread_sum[local_idx] = " << lane_sum << ";\n"
                            << "workgroupBarrier();\n"
                            << "var sum: f32 = 0;\n"
                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                            // Fixed: the literal used to end in "\n;" (terminator and newline transposed).
                            << " sum += thread_sum[i];\n"
                            << "}\n"
                            // Pass 3: normalise, or fall back to a uniform distribution when sum == 0.
                            << "if (sum == 0) {\n"
                            << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                            << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
                            << " }\n"
                            << "} else {\n"
                            << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
                            << " var f32input = f32_val_t(x[offset + i]);\n"
                            << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
                            << " }\n"
                            << "}\n";
  if (seqlen_k_) {
    // Zero everything past the causal window of this row.
    shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
                              << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
                              << "}\n";
  }

  return Status::OK();
}
|
||||
|
||||
// Configures and runs InPlaceSoftmaxProgram over `probs`.
//
// One workgroup is dispatched per (row, batch*head). The vector width is 1 whenever `seqlen_k`
// is supplied, otherwise the widest of 4/2/1 that divides `total_sequence_length`. The workgroup
// size is 64, reduced to 32 for rows shorter than 64 (vectorised) elements.
//
// Returns the status of ComputeContext::RunProgram.
Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
                             const Tensor* seqlen_k, bool is_first_prompt) {
  const bool has_seqlen_k = seqlen_k != nullptr;
  const int components = has_seqlen_k ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
  int work_group_size = 64;
  const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
  if (total_sequence_length_comp < work_group_size) {
    work_group_size = 32;
  }
  const int elements_per_thread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
  if (has_seqlen_k) {
    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
  }
  program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
      // The generated source also varies with the vector width and with whether seqlen_k is bound;
      // list both explicitly, as ComputeAttentionProbs does, instead of relying on tensor metadata.
      .CacheHint(work_group_size, is_first_prompt, components, has_seqlen_k)
      .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
      .SetWorkgroupSize(work_group_size)
      .AddUniformVariables({{static_cast<uint32_t>(batch_size)},
                            {static_cast<uint32_t>(num_heads)},
                            {static_cast<uint32_t>(past_sequence_length)},
                            {static_cast<uint32_t>(sequence_length)},
                            {static_cast<uint32_t>(total_sequence_length_comp)},
                            {static_cast<uint32_t>(elements_per_thread)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Emits the tiled WGSL kernel that multiplies the attention probabilities by V and writes the
// result to `output` (transposed, see the comment emitted into the shader). While loading V tiles
// the kernel can also read from the past-value buffer and populate `present_value`, depending on
// the feed/present/share-buffer flags.
Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Bindings: order matters and mirrors the order used when the program's tensors are added.
  shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (feed_past_value_) {
    shader.AddInput("past_value", ShaderUsage::UseUniform);
  }
  if (seqlen_k_) {
    shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
  }
  shader.AddOutput("output", ShaderUsage::UseUniform);
  if (has_present_value_) {
    shader.AddOutput("present_value", ShaderUsage::UseUniform);
  }

  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << ">;\n"
                                    << "var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << ">;\n";
  shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
                            << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
                            << "let m = global_id.y;\n"
                            << "let n = global_id.x;\n"
                            << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
                            << "let sequence_length = uniforms.M;\n"
                            << "var total_sequence_length = uniforms.K;\n";

  // Defines the shader-local past_sequence_length (and possibly overrides total_sequence_length).
  std::ostringstream oss;
  InitVarStub(oss, seqlen_k_, is_first_prompt_);
  shader.MainFunctionBody() << oss.str();

  // n_reps_ is baked into the source: several query heads share one key/value head.
  shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
  if (has_present_value_) {
    shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
  }

  shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
                            << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
                            << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
                            << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
                            << " }\n"
                            << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
                            << " var idx = TILE_SIZE * local_id.y + local_id.x;\n";

  if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
    shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
                              << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
                              << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
                              << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
                              // Fixed: the row index into `v` must subtract the same shader-local
                              // past_sequence_length that the guard above tests (as the key-side kernel
                              // does). The uniform only equals it when seqlen_k is not bound.
                              << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n"
                              << " }\n";
  } else {
    shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n"
                              << " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n"
                              << " }\n";
  }

  if (has_present_value_) {
    // With a shared buffer only the newly appended rows are written back.
    if (past_present_share_buffer_) {
      shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
    } else {
      shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
    }
    shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
                              << " }\n";
  }

  shader.MainFunctionBody() << " }\n"
                            << " workgroupBarrier();\n"
                            << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n"
                            << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
                            << " }\n"
                            << " workgroupBarrier();\n"
                            << "}\n";

  shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
                            << "if (m < uniforms.M && n < uniforms.N) {\n"
                            << " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
                            << " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
                            << " output[outputIdx] = value;\n"
                            << "}\n";

  return Status::OK();
}
|
||||
|
||||
// Configures and runs VxAttentionScoreProgram, which multiplies `probs` by `V` into `output`
// and, when requested, writes `present_value`.
//
// Parameters:
//   output_count          - number of outputs the caller produces; present_value is written only when > 1.
//   probs, V              - attention probabilities and value tensor.
//   past_value            - optional past values; bound as an input only when it is non-empty and the
//                           past/present buffers are not shared.
//   output, present_value - outputs.
//   parameters            - attention dimensions and flags.
//   past_sequence_length, total_sequence_length - forwarded as uniforms.
//   seqlen_k              - optional per-batch lengths buffer, bound when non-null.
//
// Returns the status of ComputeContext::RunProgram.
Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count,
                               const Tensor* probs,
                               const Tensor* V,
                               const Tensor* past_value,
                               Tensor* output,
                               Tensor* present_value,
                               WebgpuAttentionParameters& parameters,
                               int past_sequence_length,
                               int total_sequence_length,
                               const Tensor* seqlen_k) {
  const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
  const bool has_present_value = output_count > 1 && past_value != nullptr;
  constexpr int tile_size = 12;

  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
  program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
                     {V, ProgramTensorMetadataDependency::TypeAndRank}});
  if (feed_past_value) {
    program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
  }
  if (seqlen_k != nullptr) {
    program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
  }
  program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
  if (has_present_value) {
    program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank});
  }

  program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
                               (parameters.sequence_length_ + tile_size - 1) / tile_size,
                               parameters.batch_size_ * parameters.num_heads_)
      // The generated WGSL embeds n_reps as a literal divisor, so it has to be part of the cache
      // hint as well; otherwise dispatches that differ only in n_reps would map to the same key.
      .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_, parameters.n_reps)
      .SetWorkgroupSize(tile_size, tile_size)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
                            {static_cast<uint32_t>(total_sequence_length)},
                            {static_cast<uint32_t>(parameters.v_head_size_)},
                            {static_cast<uint32_t>(parameters.num_heads_)},
                            {static_cast<uint32_t>(parameters.head_size_)},
                            {static_cast<uint32_t>(parameters.v_hidden_size_ * parameters.n_reps)},
                            {static_cast<uint32_t>(past_sequence_length)},
                            {static_cast<uint32_t>(parameters.kv_sequence_length_)},
                            // Row stride of present_value: the full cache length when seqlen_k drives the lengths.
                            {static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
                            {static_cast<uint32_t>(parameters.n_reps)}})
      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Runs the three attention stages back to back on the GPU:
//   1. ComputeAttentionProbs   - scaled Q x K scores into a temporary `probs` tensor,
//   2. ComputeInPlaceSoftmax   - softmax over each row of `probs`,
//   3. ComputeVxAttentionScore - probs x V into `output`.
// Any stage failing aborts the sequence and returns that stage's status.
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                      const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
                      WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
  // Present outputs are only counted when the matching past tensor was supplied.
  const int past_tensor_count = (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0);
  const int output_count = std::min({context.OutputCount(), 1 + past_tensor_count});

  int past_sequence_length = 0;
  if (output_count > 1) {
    past_sequence_length = parameters.past_sequence_length_;
  }
  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;

  // Scratch tensor holding the scores: {batch, num_heads, sequence, total_sequence}.
  const TensorShapeVector score_dims({parameters.batch_size_, parameters.num_heads_,
                                      parameters.sequence_length_, total_sequence_length});
  const TensorShape score_shape(score_dims);
  Tensor probs = context.CreateGPUTensor(Q->DataType(), score_shape);

  ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
                                            parameters, past_sequence_length, total_sequence_length, seqlen_k));

  // Note: the softmax stage receives parameters.past_sequence_length_, not the local value above.
  ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
                                            parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_,
                                            parameters.sequence_length_, total_sequence_length, seqlen_k,
                                            parameters.is_first_prompt_));

  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
                                              parameters, past_sequence_length, total_sequence_length, seqlen_k));

  return Status::OK();
}
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,123 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/webgpu/compute_context.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
using namespace onnxruntime::webgpu;
|
||||
|
||||
// WebGPU program that copies a rank-3 tensor into a rank-4 tensor using the stride uniforms
// declared below, optionally adding a bias tensor.
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
 public:
  // `has_bias` selects whether the generated shader binds and adds the "bias" input.
  // Marked explicit so that a bare bool never converts silently into a program object.
  explicit TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms, in the order the host must supply them.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
                                          {"batch_offset", ProgramUniformVariableDataType::Uint32},
                                          {"sequence_offset", ProgramUniformVariableDataType::Uint32},
                                          {"head_offset", ProgramUniformVariableDataType::Uint32},
                                          {"bias_offset", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;
};
|
||||
|
||||
// WebGPU program producing the scaled Q x K attention scores. The constructor flags select
// which optional tensors the generated shader binds and how keys are read or written.
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
 public:
  AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
                        bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
      : Program{kernel_name},
        feed_past_key_(feed_past_key),
        has_present_key_(has_present_key),
        has_attention_bias_(has_attention_bias),
        tile_size_(tile_size),
        components_(components),
        n_reps_(n_reps),
        seqlen_k_(seqlen_k),
        past_present_share_buffer_(past_present_share_buffer),
        is_first_prompt_(is_first_prompt) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms, in the order the host must supply them.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
                                          {"K", ProgramUniformVariableDataType::Uint32},
                                          {"N", ProgramUniformVariableDataType::Uint32},
                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
                                          {"head_size", ProgramUniformVariableDataType::Uint32},
                                          {"alpha", ProgramUniformVariableDataType::Float32},
                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"n_reps", ProgramUniformVariableDataType::Uint32});

  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

 private:
  bool feed_past_key_;              // bind "past_key" as an input
  bool has_present_key_;            // bind and write "present_key"
  bool has_attention_bias_;         // bind and add "attention_bias"
  int tile_size_;                   // edge length of the square workgroup tiles
  int components_;                  // vector width (4, 2 or 1) used for q/key
  int n_reps_;                      // literal divisor applied to workgroup_id.z for key offsets
  const Tensor* seqlen_k_;          // non-null: bind "seqlen_k" and derive lengths from it
  bool past_present_share_buffer_;  // read past keys from "present_key" instead of "past_key"
  bool is_first_prompt_;            // with seqlen_k: past_sequence_length is emitted as 0
};
|
||||
|
||||
// Program description for the in-place softmax WebGPU kernel.
// Holds the workgroup size, vector component count and optional seqlen_k tensor that
// GenerateShaderCode() (defined out of line) may consult.
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
 public:
  // kernel_name is forwarded to the Program base; seqlen_k is optional and not owned.
  InPlaceSoftmaxProgram(const std::string& kernel_name,
                        int work_group_size,
                        int components,
                        bool is_first_prompt,
                        const Tensor* seqlen_k = nullptr)
      : Program{kernel_name},
        work_group_size_{work_group_size},
        components_{components},
        seqlen_k_{seqlen_k},
        is_first_prompt_{is_first_prompt} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms, in the order in which callers must supply their values.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch_size", ProgramUniformVariableDataType::Uint32},
      {"num_heads", ProgramUniformVariableDataType::Uint32},
      {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
      {"sequence_length", ProgramUniformVariableDataType::Uint32},
      {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
      {"elements_per_thread", ProgramUniformVariableDataType::Uint32});

 private:
  int work_group_size_;
  int components_;
  const Tensor* seqlen_k_;  // non-owning; may be nullptr (constructor default)
  bool is_first_prompt_;
};
|
||||
|
||||
// Program description for the "V x attention score" WebGPU kernel.
// The constructor stores the configuration; the shader text is produced by
// GenerateShaderCode(), which is defined out of line.
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
 public:
  // kernel_name is forwarded to the Program base; n_reps, seqlen_k and
  // past_present_share_buffer are optional.
  VxAttentionScoreProgram(const std::string& kernel_name,
                          bool feed_past_value,
                          bool has_present_value,
                          int tile_size,
                          bool is_first_prompt,
                          int n_reps = 1,
                          const Tensor* seqlen_k = nullptr,
                          bool past_present_share_buffer = false)
      : Program{kernel_name},
        feed_past_value_{feed_past_value},
        has_present_value_{has_present_value},
        tile_size_{tile_size},
        n_reps_{n_reps},
        seqlen_k_{seqlen_k},
        past_present_share_buffer_{past_present_share_buffer},
        is_first_prompt_{is_first_prompt} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms, in the order in which callers must supply their values.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"M", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"num_heads", ProgramUniformVariableDataType::Uint32},
      {"head_size", ProgramUniformVariableDataType::Uint32},
      {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
      {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
      {"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
      {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
      {"n_reps", ProgramUniformVariableDataType::Uint32});

  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

 private:
  bool feed_past_value_;
  bool has_present_value_;
  int tile_size_;
  int n_reps_;
  const Tensor* seqlen_k_;  // non-owning; may be nullptr (constructor default)
  bool past_present_share_buffer_;
  bool is_first_prompt_;
};
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,130 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/webgpu/compute_context.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
struct WebgpuAttentionParameters {
|
||||
explicit WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false),
|
||||
batch_size_(parameters.batch_size),
|
||||
sequence_length_(parameters.sequence_length),
|
||||
kv_sequence_length_(parameters.kv_sequence_length),
|
||||
past_sequence_length_(parameters.past_sequence_length),
|
||||
total_sequence_length_(parameters.total_sequence_length),
|
||||
max_sequence_length_(parameters.max_sequence_length),
|
||||
input_hidden_size_(parameters.input_hidden_size),
|
||||
hidden_size_(parameters.hidden_size),
|
||||
head_size_(parameters.head_size),
|
||||
v_hidden_size_(parameters.v_hidden_size),
|
||||
v_head_size_(parameters.v_head_size),
|
||||
num_heads_(parameters.num_heads),
|
||||
is_unidirectional_(parameters.is_unidirectional),
|
||||
past_present_share_buffer_(parameters.past_present_share_buffer),
|
||||
do_rotary_(parameters.do_rotary),
|
||||
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0),
|
||||
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1),
|
||||
mask_filter_value_(parameters.mask_filter_value),
|
||||
scale_(parameters.scale),
|
||||
mask_type_(parameters.mask_type),
|
||||
qkv_format_(parameters.qkv_format) {
|
||||
}
|
||||
|
||||
explicit WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true),
|
||||
batch_size_(parameters.batch_size),
|
||||
sequence_length_(parameters.sequence_length),
|
||||
kv_sequence_length_(parameters.sequence_length),
|
||||
past_sequence_length_(parameters.seqlen_past_kv_cache),
|
||||
total_sequence_length_(parameters.total_sequence_length),
|
||||
hidden_size_(parameters.hidden_size),
|
||||
head_size_(parameters.head_size),
|
||||
v_hidden_size_(parameters.kv_hidden_size),
|
||||
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
|
||||
num_heads_(parameters.num_heads),
|
||||
do_rotary_(parameters.do_rotary),
|
||||
scale_(parameters.scale),
|
||||
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
|
||||
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache),
|
||||
kv_hidden_size_(parameters.kv_hidden_size),
|
||||
kv_num_heads_(parameters.kv_num_heads),
|
||||
num_splits_(parameters.num_splits),
|
||||
rotary_dim_(parameters.rotary_dim),
|
||||
is_packed_qkv_(parameters.is_packed_qkv),
|
||||
is_subsequent_prompt_(parameters.is_subsequent_prompt),
|
||||
is_first_prompt_(parameters.is_first_prompt),
|
||||
rotary_interleaved_(parameters.rotary_interleaved),
|
||||
use_smooth_softmax_(parameters.use_smooth_softmax),
|
||||
softcap_(parameters.softcap),
|
||||
zeros_count_(parameters.zeros_count),
|
||||
zero_ptr_(parameters.zero_ptr),
|
||||
n_reps(parameters.num_heads / parameters.kv_num_heads),
|
||||
qkv_format_(parameters.qkv_format) {
|
||||
}
|
||||
|
||||
bool is_gqa_;
|
||||
int batch_size_ = 0;
|
||||
int sequence_length_ = 0;
|
||||
int kv_sequence_length_ = 0; // input sequence length of K or V
|
||||
int past_sequence_length_ = 0; // sequence length in past state of K or V
|
||||
int total_sequence_length_ = 0; // total sequence length of K or V
|
||||
int max_sequence_length_ = 0; // max sequence length from 4D mask
|
||||
int input_hidden_size_ = 0; // first dimension of weights for input projection
|
||||
int hidden_size_ = 0; // hidden size of Q or K
|
||||
int head_size_ = 0; // hidden size per head of Q or K
|
||||
int v_hidden_size_ = 0; // hidden size of V
|
||||
int v_head_size_ = 0; // hidden size per head of V
|
||||
int num_heads_ = 0;
|
||||
int rotary_embedding_ = 0;
|
||||
bool is_unidirectional_ = false;
|
||||
bool past_present_share_buffer_ = false;
|
||||
bool do_rotary_ = false;
|
||||
bool broadcast_attn_bias_dim_0_ = false;
|
||||
bool broadcast_attn_bias_dim_1_ = false;
|
||||
float mask_filter_value_ = -10000.0f;
|
||||
float scale_ = 0.0f;
|
||||
bool use_tf32_ = false;
|
||||
;
|
||||
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters
|
||||
// and not in onnxruntime::contrib::AttentionParameters
|
||||
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor
|
||||
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor
|
||||
int kv_hidden_size_ = 0;
|
||||
int kv_num_heads_ = 0;
|
||||
int num_splits_ = 0; // number of splits for splitkv
|
||||
int rotary_dim_ = 0; // rotary embedding dimension
|
||||
int local_window_size_ = 0;
|
||||
bool kv_share_buffer_ = false;
|
||||
bool is_packed_qkv_ = false;
|
||||
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1
|
||||
bool is_first_prompt_ = false; // indicates whether this is first decoding step
|
||||
bool rotary_interleaved_ = false;
|
||||
bool use_smooth_softmax_ = false;
|
||||
float softcap_ = 0.0;
|
||||
int zeros_count_ = 0;
|
||||
;
|
||||
int* zero_ptr_ = nullptr;
|
||||
// Computed values
|
||||
int n_reps = 1;
|
||||
AttentionMaskType mask_type_ = MASK_NONE;
|
||||
AttentionQkvFormat qkv_format_ = UNKNOWN;
|
||||
};
|
||||
|
||||
// Runs a WebGPU program that writes `input_tensor` (optionally combined with `bias`) into
// `output_tensor`. `bias` may be nullptr; `bias_offset` is passed to the program as a uniform.
// NOTE(review): judging by the name this is a BSD -> BNSH layout change — confirm against the
// definition.
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context,
                         int num_heads,
                         int sequence_length,
                         int head_size,
                         const Tensor* input_tensor,
                         const Tensor* bias,
                         int bias_offset,
                         Tensor* output_tensor);
|
||||
|
||||
// Entry point shared by the WebGPU attention kernels; defined out of line.
// `seqlen_k` is optional and defaults to nullptr.
// NOTE(review): which of the pointer arguments may be null is decided by the definition —
// confirm there before relying on it.
Status ApplyAttention(const Tensor* Q,
                      const Tensor* K,
                      const Tensor* V,
                      const Tensor* attention_bias,
                      const Tensor* past_key,
                      const Tensor* past_value,
                      Tensor* output,
                      Tensor* present_key,
                      Tensor* present_value,
                      WebgpuAttentionParameters& parameters,
                      onnxruntime::webgpu::ComputeContext& context,
                      const Tensor* seqlen_k = nullptr);
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
|
||||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
#include "contrib_ops/webgpu/bert/group_query_attention.h"
|
||||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
|
||||
|
||||
#include "core/providers/webgpu/webgpu_supported_types.h"
|
||||
|
||||
using namespace onnxruntime::webgpu;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::contrib::group_query_attention_helper;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
// Registers GroupQueryAttention (kMSDomain, version 1) for the WebGPU execution provider.
// Inputs 3 and 4 may be reused in place for outputs 1 and 2, and input 6 is requested in
// CPU memory.
ONNX_OPERATOR_KERNEL_EX(
    GroupQueryAttention,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .MayInplace(3, 1)
        .MayInplace(4, 2)
        .InputMemoryType(OrtMemTypeCPUInput, 6),
    GroupQueryAttention);
|
||||
|
||||
// Computes GroupQueryAttention on the WebGPU execution provider.
//
// Steps:
//  1. Fetch inputs 0..8 and validate them with group_query_attention_helper::CheckInputs,
//     which fills a GroupQueryAttentionParameters struct.
//  2. Reject packed QKV input, which is not implemented here.
//  3. Allocate output 0 as (batch, sequence, hidden) and outputs 1 and 2 (present key/value)
//     as (batch, kv_num_heads, seqlen_present_kv_cache, head_size).
//  4. Rearrange the query into a temporary 4-D GPU tensor with TransferBSDToBNSH; do the same
//     for key and value unless qkv_format_ is Q_K_V_BSNH_BNSH_BNSH.
//  5. Delegate to ApplyAttention.
Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* query = context.Input<Tensor>(0);
  const Tensor* key = context.Input<Tensor>(1);
  const Tensor* value = context.Input<Tensor>(2);
  const Tensor* past_key = context.Input<Tensor>(3);
  const Tensor* past_value = context.Input<Tensor>(4);
  const Tensor* seqlen_k = context.Input<Tensor>(5);
  const Tensor* total_seqlen_tensor = context.Input<Tensor>(6);
  const Tensor* cos_cache = context.Input<Tensor>(7);
  const Tensor* sin_cache = context.Input<Tensor>(8);

  GroupQueryAttentionParameters params;
  ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
                                                                key,
                                                                value,
                                                                past_key,
                                                                past_value,
                                                                cos_cache,
                                                                sin_cache,
                                                                &params,
                                                                num_heads_,
                                                                kv_num_heads_,
                                                                seqlen_k,
                                                                total_seqlen_tensor,
                                                                scale_,
                                                                softcap_));
  WebgpuAttentionParameters parameters(params);
  if (parameters.is_packed_qkv_) {
    ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep.");
  }

  TensorShapeVector output_shape(3);
  output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
  output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
  output_shape[2] = static_cast<int64_t>(parameters.hidden_size_);
  Tensor* output = context.Output(0, output_shape);

  // Present key and present value share one shape, so a single vector serves both outputs.
  std::vector<int64_t> present_kv_shape{
      parameters.batch_size_,
      kv_num_heads_,
      parameters.seqlen_present_kv_cache_,
      parameters.head_size_};
  Tensor* present_key = context.Output(1, present_kv_shape);
  Tensor* present_value = context.Output(2, present_kv_shape);

  // Past and present are treated as sharing a buffer only when all four tensors exist and the
  // raw data pointers of each past/present pair are identical.
  parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr &&
                                          past_key != nullptr && past_value != nullptr &&
                                          past_key->DataRaw() == present_key->DataRaw() &&
                                          past_value->DataRaw() == present_value->DataRaw();

  TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
                                parameters.sequence_length_, parameters.head_size_});
  TensorShape q_new_shape(q_new_dims);
  Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
      context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));

  if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
    return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key,
                          present_value, parameters, context, seqlen_k);
  }

  TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
                                parameters.kv_sequence_length_, parameters.head_size_});
  TensorShape k_new_shape(k_new_dims);
  Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
                                        parameters.head_size_, key, nullptr, 0, &K));

  TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
                                parameters.kv_sequence_length_, parameters.v_head_size_});
  TensorShape v_new_shape(v_new_dims);
  Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
                                        parameters.v_head_size_, value, nullptr, 0, &V));

  return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key,
                        present_value, parameters, context, seqlen_k);
}
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/webgpu/compute_context.h"
|
||||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace webgpu {
|
||||
|
||||
using namespace onnxruntime::webgpu;
|
||||
|
||||
class GroupQueryAttention final : public WebGpuKernel {
|
||||
public:
|
||||
GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) {
|
||||
int64_t num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
|
||||
int64_t kv_num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
|
||||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);
|
||||
|
||||
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
|
||||
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
|
||||
|
||||
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;
|
||||
|
||||
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
|
||||
}
|
||||
|
||||
int num_heads_; // number of attention heads of Q
|
||||
int kv_num_heads_; // number of attention heads of K or V
|
||||
float scale_; // the scaling factor applied before softmax
|
||||
float softcap_;
|
||||
bool do_rotary_; // whether or not to use rotary embeddings
|
||||
bool rotary_interleaved_;
|
||||
int local_window_size_;
|
||||
|
||||
bool use_smooth_softmax_;
|
||||
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
#include "contrib_ops/webgpu/bert/attention_common.h"
|
||||
#include "contrib_ops/webgpu/bert/multihead_attention.h"
|
||||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
|
||||
|
||||
|
@ -25,392 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
|
||||
MultiHeadAttention);
|
||||
|
||||
// Emits the WGSL for the layout-transfer program: each invocation turns its global index into
// output indices, computes the matching source offset from the batch/head/sequence offset
// uniforms, and copies one element, adding the bias element when has_bias_ is set.
Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("qkv_input", ShaderUsage::UseUniform);
  const auto& out_var = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices);
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }

  // Index computation shared by both variants.
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
                            << "let output_indices = " << out_var.OffsetToIndices("global_idx") << ";\n"
                            << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *"
                            << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n";

  // The store statement differs only in whether the bias term is appended.
  if (has_bias_) {
    shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"
                              << "qkv_output[global_idx] = qkv_input[input_offset_idx]"
                              << " + bias[bias_offset_idx];\n";
  } else {
    shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"
                              << ";\n";
  }

  return Status::OK();
}
|
||||
|
||||
// Dispatches TransferBSDToBNSHProgram to copy `input_tensor` (rank 3, asserted) into
// `output_tensor` (rank 4, asserted), optionally adding `bias`.
//
// The source offsets handed to the shader are:
//   batch_offset    = num_heads * sequence_length * head_size
//   sequence_offset = num_heads * head_size
//   head_offset     = head_size
// One invocation is launched per output element. `bias` may be nullptr, in which case no bias
// input is bound; `bias_offset` is passed through as a uniform either way.
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
                         int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) {
  assert(input_tensor->Shape().GetDims().size() == 3);
  assert(output_tensor->Shape().GetDims().size() == 4);

  uint32_t data_size = gsl::narrow<uint32_t>(output_tensor->Shape().Size());
  const int batch_offset = num_heads * sequence_length * head_size;
  const int sequence_offset = num_heads * head_size;
  const int head_offset = head_size;
  bool has_bias = bias != nullptr;

  TransferBSDToBNSHProgram program{has_bias};
  program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
      .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{data_size},
                            {static_cast<uint32_t>(batch_offset)},
                            {static_cast<uint32_t>(sequence_offset)},
                            {static_cast<uint32_t>(head_offset)},
                            {static_cast<uint32_t>(bias_offset)}});

  // The bias is bound after the main input, matching the input order declared by the shader.
  if (has_bias) {
    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
  }

  return context.RunProgram(program);
}
|
||||
|
||||
// Emits the WGSL for the attention-probabilities kernel: a tiled product of `q` rows with `key`
// rows accumulated in f32, scaled by uniforms.alpha and optionally offset by attention_bias.
// When both feed_past_key_ and has_present_key_ are set, key rows below past_sequence_length
// are read from `past_key`; when has_present_key_ is set, the loaded key tile is also written
// to `present_key`.
Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Variants selected by the constructor arguments.
  const bool read_past_key = feed_past_key_ && has_present_key_;
  const int tile_elements = tile_size_ * tile_size_;

  // f32 accumulator alias and its horizontal-sum expression for the vector width in use.
  const char* f32_alias = "f32";
  const char* horizontal_sum = "value";
  if (components_ == 4) {
    f32_alias = "vec4<f32>";
    horizontal_sum = "value.x + value.y + value.z + value.w";
  } else if (components_ == 2) {
    f32_alias = "vec2<f32>";
    horizontal_sum = "value.x + value.y";
  }

  // Bindings: inputs first, then outputs; optional ones are appended only when enabled.
  shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (feed_past_key_) {
    shader.AddInput("past_key", ShaderUsage::UseUniform);
  }
  if (has_attention_bias_) {
    shader.AddInput("attention_bias", ShaderUsage::UseUniform);
  }
  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (has_present_key_) {
    shader.AddOutput("present_key", ShaderUsage::UseUniform);
  }

  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_elements << ">;\n"
                                    << "var<workgroup> tileK: array<key_value_t, " << tile_elements << ">;\n"
                                    << "alias f32_val_t = " << f32_alias << ";\n";

  // Per-workgroup offsets.
  shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
                               "let headIdx = workgroup_id.z;\n"
                               "let m = workgroup_id.y * TILE_SIZE;\n"
                               "let n = workgroup_id.x * TILE_SIZE;\n"
                               "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n";
  if (read_past_key) {
    shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n"
                              << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n";
  } else {
    shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n";
  }
  if (has_present_key_) {
    shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n";
  }

  // Tile loop: load a Q tile and a K tile, then accumulate their product.
  shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
                               "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
                               " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
                               " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
                               " }\n"
                               " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
                               " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
  if (read_past_key) {
    shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n"
                                 " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
                                 " } else {\n"
                                 " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n"
                                 " }\n";
  } else {
    shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n";
  }
  if (has_present_key_) {
    shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n";
  }
  shader.MainFunctionBody() << " }\n"
                            << " workgroupBarrier();\n"
                            << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
                            << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n"
                            << " }\n"
                            << " workgroupBarrier();\n"
                            << "}\n";

  // Reduce the vector accumulator to a scalar, scale it, and store the result.
  shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n"
                            << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n"
                            << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
                            << " var sum: f32 = " << horizontal_sum << ";\n";
  shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
  if (has_attention_bias_) {
    shader.MainFunctionBody() << " + attention_bias[outputIdx]";
  }
  shader.MainFunctionBody() << ";\n"
                            << "}\n";

  return Status::OK();
}
|
||||
|
||||
// Configures and runs AttentionProbsProgram.
//  - alpha is parameters.scale, or 1/sqrt(head_size) when scale is 0.
//  - past_key is bound as an input only when present_key exists and past_key is non-empty.
//  - present_key is bound as an output only when output_count > 1 and past_key is non-null.
//  - Q, K, past_key and present_key are bound with a vector width of 4, 2 or 1, chosen from the
//    divisibility of head_size.
// The dispatch grid is (ceil(total_seq / tile), ceil(seq / tile), batch * num_heads).
Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
                             const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
                             AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) {
  constexpr int tile_size = 12;

  float alpha = parameters.scale;
  if (alpha == 0.0f) {
    alpha = 1.f / sqrt(static_cast<float>(parameters.head_size));
  }

  const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
  const bool has_present_key = output_count > 1 && past_key != nullptr;
  const bool has_attention_bias = attention_bias != nullptr;

  int components = 1;
  if (parameters.head_size % 4 == 0) {
    components = 4;
  } else if (parameters.head_size % 2 == 0) {
    components = 2;
  }
  const uint32_t vectorized_head_size = parameters.head_size / components;

  AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
                                components};

  // Inputs, in the order the shader declares them.
  program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
                     {K, ProgramTensorMetadataDependency::TypeAndRank, components}});
  if (feed_past_key) {
    program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
  }
  if (has_attention_bias) {
    program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
  }

  // Outputs.
  program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
  if (has_present_key) {
    program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components});
  }

  program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
                               (parameters.sequence_length + tile_size - 1) / tile_size,
                               parameters.batch_size * parameters.num_heads)
      .SetWorkgroupSize(tile_size, tile_size)
      .CacheHint(std::to_string(tile_size))
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
                            {static_cast<uint32_t>(vectorized_head_size)},
                            {static_cast<uint32_t>(total_sequence_length)},
                            {static_cast<uint32_t>(parameters.num_heads)},
                            {static_cast<float>(alpha)},
                            {static_cast<uint32_t>(past_sequence_length)},
                            {static_cast<uint32_t>(parameters.kv_sequence_length)}})
      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Emits the WGSL for an in-place softmax over `x`. Each invocation handles up to
// uniforms.elements_per_thread elements; per-invocation maxima and sums go through workgroup
// arrays of size work_group_size_, and every invocation then scans those arrays itself.
// If the total sum is 0 the elements are set to uniforms.d_inv instead of being normalised.
Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Vector-width dependent snippets: the f32 alias plus the horizontal max / sum expressions.
  const char* f32_alias = nullptr;
  const char* horizontal_max = nullptr;
  const char* horizontal_sum = nullptr;
  switch (components_) {
    case 4:
      f32_alias = "vec4<f32>";
      horizontal_max = "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))";
      horizontal_sum = "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w";
      break;
    case 2:
      f32_alias = "vec2<f32>";
      horizontal_max = "max(thread_max_vector.x, thread_max_vector.y)";
      horizontal_sum = "sum_vector.x + sum_vector.y";
      break;
    default:
      f32_alias = "f32";
      horizontal_max = "thread_max_vector";
      horizontal_sum = "sum_vector";
      break;
  }

  shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
                                    << "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
                                    << "alias f32_val_t = " << f32_alias << ";\n";

  // Phase 1: per-invocation maximum, then the maximum across the workgroup.
  shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
                            << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n"
                            << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
                            << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
                            << "}\n"
                            << "thread_max[local_idx] = " << horizontal_max << ";\n"
                            << "workgroupBarrier();\n"
                            << "var max_value = f32(-3.402823e+38f);\n"
                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                            << " max_value = max(thread_max[i], max_value);\n"
                            << "}\n";

  // Phase 2: per-invocation sum of exp(x - max), then the sum across the workgroup.
  shader.MainFunctionBody() << "var sum_vector = f32_val_t(0);\n"
                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
                            << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
                            << "}\n"
                            << "thread_sum[local_idx] = " << horizontal_sum << ";\n"
                            << "workgroupBarrier();\n"
                            << "var sum: f32 = 0;\n"
                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
                            << " sum += thread_sum[i]\n;"
                            << "}\n";

  // Phase 3: write the result back in place.
  shader.MainFunctionBody() << "if (sum == 0) {\n"
                            << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
                            << " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n"
                            << " }\n"
                            << "} else {\n"
                            << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
                            << " var f32input = f32_val_t(x[offset + i]);\n"
                            << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
                            << " }\n"
                            << "}\n";

  return Status::OK();
}
|
||||
|
||||
// Configures and runs InPlaceSoftmaxProgram on `probs`.
//  n: number of workgroups to dispatch.
//  d: length used to derive the vector width (4, 2 or 1), the vectorised length d_comp, the
//     workgroup size (32 when d_comp < 64, otherwise 64) and the elements handled per thread.
// Uniforms passed, in order: 1/d, d_comp, elements per thread.
Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) {
  int components = 1;
  if (d % 4 == 0) {
    components = 4;
  } else if (d % 2 == 0) {
    components = 2;
  }

  const int d_comp = d / components;
  const int work_group_size = d_comp < 64 ? 32 : 64;
  const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size;

  InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components};
  program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
      .SetDispatchGroupSize(n)
      .SetWorkgroupSize(work_group_size)
      .AddUniformVariables({{static_cast<float>(1.f / static_cast<float>(d))},
                            {static_cast<uint32_t>(d_comp)},
                            {static_cast<uint32_t>(elementsPerThread)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Emits the WGSL for the probs x V kernel: a tiled product of `probs` rows with `v` columns.
// When both feed_past_value_ and has_present_value_ are set, rows below past_sequence_length
// are read from `past_value`; when has_present_value_ is set, the loaded V tile is also written
// to `present_value`. The result is stored with the head index folded into the last dimension
// (see the transpose comment emitted into the shader).
Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const bool read_past_value = feed_past_value_ && has_present_value_;
  const int tile_elements = tile_size_ * tile_size_;

  // Bindings: inputs first, then outputs; optional ones are appended only when enabled.
  shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  if (feed_past_value_) {
    shader.AddInput("past_value", ShaderUsage::UseUniform);
  }
  shader.AddOutput("output", ShaderUsage::UseUniform);
  if (has_present_value_) {
    shader.AddOutput("present_value", ShaderUsage::UseUniform);
  }

  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_elements << ">;\n"
                                    << "var<workgroup> tileK: array<v_value_t, " << tile_elements << ">;\n";

  // Per-invocation offsets.
  shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
                            << "let m = global_id.y;\n"
                            << "let n = global_id.x;\n"
                            << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n";
  if (read_past_value) {
    shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n"
                              << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n";
  } else {
    shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n";
  }
  if (has_present_value_) {
    shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n";
  }

  // Tile loop: load a probs tile and a V tile, then accumulate their product.
  shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
                            << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
                            << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
                            << " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
                            << " }\n"
                            << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
                            << " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
  if (read_past_value) {
    shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n"
                              << " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
                              << " } else {\n"
                              << " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
                              << " }\n";
  } else {
    shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n";
  }
  if (has_present_value_) {
    shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n";
  }
  shader.MainFunctionBody() << " }\n"
                            << " workgroupBarrier();\n"
                            << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
                            << " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
                            << " }\n"
                            << " workgroupBarrier();\n"
                            << "}\n";

  // Store the result.
  shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
                            << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n"
                            << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n"
                            << "if (m < uniforms.M && n < uniforms.N) {\n"
                            << " let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + "
                            << " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n"
                            << " output[outputIdx] = value;\n"
                            << "}\n";

  return Status::OK();
}
|
||||
|
||||
// Configures and runs VxAttentionScoreProgram.
//  - past_value is bound as an input only when present_value exists and past_value is non-empty.
//  - present_value is bound as an output only when output_count > 1 and past_value is non-null.
// The dispatch grid is (ceil(v_head_size / tile), ceil(seq / tile), batch * num_heads).
// Uniforms passed, in order: sequence_length, total_sequence_length, v_head_size, num_heads,
// v_hidden_size, past_sequence_length, kv_sequence_length.
Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count,
                               const Tensor* probs,
                               const Tensor* V,
                               const Tensor* past_value,
                               Tensor* output,
                               Tensor* present_value,
                               AttentionParameters& parameters,
                               int past_sequence_length,
                               int total_sequence_length) {
  const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0;
  const bool has_present_value = output_count > 1 && past_value != nullptr;
  constexpr int tile_size = 12;

  VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size};

  // Inputs, in the order the shader declares them.
  program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
                     {V, ProgramTensorMetadataDependency::TypeAndRank}});
  if (feed_past_value) {
    program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
  }

  // Outputs.
  program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
  if (has_present_value) {
    program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank});
  }

  program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size,
                               (parameters.sequence_length + tile_size - 1) / tile_size,
                               parameters.batch_size * parameters.num_heads)
      .SetWorkgroupSize(tile_size, tile_size)
      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
                            {static_cast<uint32_t>(total_sequence_length)},
                            {static_cast<uint32_t>(parameters.v_head_size)},
                            {static_cast<uint32_t>(parameters.num_heads)},
                            {static_cast<uint32_t>(parameters.v_hidden_size)},
                            {static_cast<uint32_t>(past_sequence_length)},
                            {static_cast<uint32_t>(parameters.kv_sequence_length)}})
      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

  return context.RunProgram(program);
}
|
||||
|
||||
// Runs attention as three GPU stages over a temporary `probs` tensor:
//   1. ComputeAttentionProbs   - fills `probs` from Q / K (and past_key, attention_bias)
//   2. ComputeInPlaceSoftmax   - softmax over `probs`, in place
//   3. ComputeVxAttentionScore - combines `probs` with V (and past_value) into `output`
//
// `probs` is created with shape
// {batch_size, num_heads, sequence_length, total_sequence_length} and Q's data type.
// The past sequence length is taken from `parameters` only when more than one
// output is being produced; otherwise it is treated as 0.
//
// Returns the first failing stage's status, or Status::OK().
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
                      const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
                      AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
  // One mandatory output, plus one for each past tensor that was provided.
  const int outputs_allowed_by_past = 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0);
  const int output_count = std::min({context.OutputCount(), outputs_allowed_by_past});

  int past_sequence_length = 0;
  if (output_count > 1) {
    past_sequence_length = parameters.past_sequence_length;
  }
  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;

  const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads,
                                      parameters.sequence_length, total_sequence_length});
  const TensorShape probs_shape(probs_dims);
  Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);

  // Stage 1: attention probabilities.
  ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
                                            parameters, past_sequence_length, total_sequence_length));

  // Stage 2: softmax, applied to batch_size * num_heads * sequence_length rows
  // of total_sequence_length elements each.
  const int softmax_rows = parameters.batch_size * parameters.num_heads * parameters.sequence_length;
  ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, softmax_rows, total_sequence_length));

  // Stage 3: apply the probabilities to V.
  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
                                              parameters, past_sequence_length, total_sequence_length));

  return Status::OK();
}
|
||||
|
||||
// Constructor of the WebGPU MultiHeadAttention kernel.
// NOTE(review): this span comes from a diff view and interleaves two variants
// of the constructor without +/- markers:
//   * one initializes only WebGpuKernel(info) and reads the "num_heads",
//     "mask_filter_value", "scale" and "unidirectional" attributes itself;
//   * the other initializes WebGpuKernel(info) and AttentionBase(info, false)
//     and has no attribute-reading statements of its own.
// Both end with the same ORT_ENFORCE that rejects unidirectional attention.
MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
|
||||
: WebGpuKernel(info) {
|
||||
int64_t num_heads = 0;
|
||||
// "num_heads" must be readable and strictly positive.
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
// Second variant of the initializer list (see the note at the top).
: WebGpuKernel(info), AttentionBase(info, false) {
|
||||
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel");
|
||||
}
|
||||
|
||||
|
@ -434,54 +51,54 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
|
|||
ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu");
|
||||
}
|
||||
|
||||
AttentionParameters parameters;
|
||||
AttentionParameters params;
|
||||
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query, key, value,
|
||||
bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶meters,
|
||||
bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶ms,
|
||||
num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention,
|
||||
context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
|
||||
|
||||
WebgpuAttentionParameters parameters(params);
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
|
||||
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
|
||||
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size_);
|
||||
Tensor* output = context.Output(0, output_shape);
|
||||
|
||||
// If optional outputs aren't needed, present_key and present_value will be null
|
||||
std::vector<int64_t> present_dims{
|
||||
parameters.batch_size,
|
||||
parameters.num_heads,
|
||||
parameters.total_sequence_length,
|
||||
parameters.head_size,
|
||||
parameters.batch_size_,
|
||||
parameters.num_heads_,
|
||||
parameters.total_sequence_length_,
|
||||
parameters.head_size_,
|
||||
};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present_key = context.Output(1, present_shape);
|
||||
Tensor* present_value = context.Output(2, present_shape);
|
||||
|
||||
TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
|
||||
parameters.sequence_length, parameters.head_size});
|
||||
TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.sequence_length_, parameters.head_size_});
|
||||
TensorShape q_new_shape(q_new_dims);
|
||||
Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
|
||||
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
|
||||
context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q));
|
||||
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q));
|
||||
|
||||
if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
|
||||
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
|
||||
return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
|
||||
present_value, parameters, context);
|
||||
}
|
||||
|
||||
TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
|
||||
parameters.kv_sequence_length, parameters.head_size});
|
||||
TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.kv_sequence_length_, parameters.head_size_});
|
||||
TensorShape k_new_shape(k_new_dims);
|
||||
Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
|
||||
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
|
||||
parameters.head_size, key, bias, parameters.hidden_size, &K));
|
||||
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
|
||||
parameters.head_size_, key, bias, parameters.hidden_size_, &K));
|
||||
|
||||
TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
|
||||
parameters.kv_sequence_length, parameters.v_head_size});
|
||||
TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_,
|
||||
parameters.kv_sequence_length_, parameters.v_head_size_});
|
||||
TensorShape v_new_shape(v_new_dims);
|
||||
Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
|
||||
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
|
||||
parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));
|
||||
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
|
||||
parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V));
|
||||
|
||||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
|
||||
|
|
|
@ -7,6 +7,9 @@
|
|||
#include "core/providers/webgpu/program.h"
|
||||
#include "core/providers/webgpu/shader_helper.h"
|
||||
#include "core/providers/webgpu/webgpu_kernel.h"
|
||||
#include "contrib_ops/webgpu/bert/attention.h"
|
||||
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
@ -14,100 +17,10 @@ namespace webgpu {
|
|||
|
||||
using namespace onnxruntime::webgpu;
|
||||
|
||||
// WebGPU program registered under the name "TransferBSDToBNSH".
//
// `has_bias` is stored for use by GenerateShaderCode (defined elsewhere).
// Uniforms, in declaration order: data_size, batch_offset, sequence_offset,
// head_offset, bias_offset (all Uint32).
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
 public:
  // `explicit` prevents a bare bool from being implicitly converted into a program object.
  explicit TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
                                          {"batch_offset", ProgramUniformVariableDataType::Uint32},
                                          {"sequence_offset", ProgramUniformVariableDataType::Uint32},
                                          {"head_offset", ProgramUniformVariableDataType::Uint32},
                                          {"bias_offset", ProgramUniformVariableDataType::Uint32});

 private:
  bool has_bias_;  // set once by the constructor
};
|
||||
|
||||
// WebGPU program for the attention-probabilities stage.
//
// The constructor only records its configuration flags; the shader itself is
// produced by GenerateShaderCode (defined elsewhere).
// Uniforms, in declaration order: M, K, N, num_heads (Uint32), alpha (Float32),
// past_sequence_length, kv_sequence_length (Uint32).
// Overridable constant: TILE_SIZE (Uint32).
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
 public:
  AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
                        bool has_attention_bias, int tile_size, int components)
      : Program{kernel_name},
        feed_past_key_(feed_past_key),
        has_present_key_(has_present_key),
        has_attention_bias_(has_attention_bias),
        tile_size_(tile_size),
        components_(components) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
                                          {"K", ProgramUniformVariableDataType::Uint32},
                                          {"N", ProgramUniformVariableDataType::Uint32},
                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
                                          {"alpha", ProgramUniformVariableDataType::Float32},
                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32});

  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

 private:
  // Configuration captured at construction time.
  bool feed_past_key_;
  bool has_present_key_;
  bool has_attention_bias_;
  int tile_size_;
  int components_;
};
|
||||
|
||||
// WebGPU program for the in-place softmax stage.
//
// The constructor records the workgroup size and component count; the shader
// is produced by GenerateShaderCode (defined elsewhere).
// Uniforms, in declaration order: d_inv (Float32), d_comp (Uint32),
// elements_per_thread (Uint32).
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
 public:
  InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components)
      : Program{kernel_name},
        work_group_size_(work_group_size),
        components_(components) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32},
                                          {"d_comp", ProgramUniformVariableDataType::Uint32},
                                          {"elements_per_thread", ProgramUniformVariableDataType::Uint32});

 private:
  // Configuration captured at construction time.
  int work_group_size_;
  int components_;
};
|
||||
|
||||
// WebGPU program for the stage that applies attention probabilities to V.
//
// The constructor records whether past_value is fed, whether present_value is
// produced, and the tile size; the shader is produced by GenerateShaderCode
// (defined elsewhere).
// Uniforms, in declaration order: M, K, N, num_heads, v_hidden_size,
// past_sequence_length, kv_sequence_length (all Uint32).
// Overridable constant: TILE_SIZE (Uint32).
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
 public:
  VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size)
      : Program{kernel_name},
        feed_past_value_(feed_past_value),
        has_present_value_(has_present_value),
        tile_size_(tile_size) {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
                                          {"K", ProgramUniformVariableDataType::Uint32},
                                          {"N", ProgramUniformVariableDataType::Uint32},
                                          {"num_heads", ProgramUniformVariableDataType::Uint32},
                                          {"v_hidden_size", ProgramUniformVariableDataType::Uint32},
                                          {"past_sequence_length", ProgramUniformVariableDataType::Uint32},
                                          {"kv_sequence_length", ProgramUniformVariableDataType::Uint32});

  WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});

 private:
  // Configuration captured at construction time.
  bool feed_past_value_;
  bool has_present_value_;
  int tile_size_;
};
|
||||
|
||||
// Declaration of the WebGPU MultiHeadAttention kernel.
// NOTE(review): this span comes from a diff view and shows two class heads
// without +/- markers: one deriving from WebGpuKernel only, and one deriving
// from both WebGpuKernel and AttentionBase.
class MultiHeadAttention final : public WebGpuKernel {
|
||||
class MultiHeadAttention final : public WebGpuKernel, public AttentionBase {
|
||||
public:
|
||||
MultiHeadAttention(const OpKernelInfo& info);
|
||||
// Kernel entry point; overrides the base-class ComputeInternal.
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
|
||||
|
||||
protected:
|
||||
// NOTE(review): presumably these four members belong only to the variant
// without AttentionBase - confirm against the post-change header.
int num_heads_;
|
||||
float mask_filter_value_;
|
||||
float scale_;
|
||||
bool is_unidirectional_{false};
|
||||
};
|
||||
|
||||
} // namespace webgpu
|
||||
|
|
|
@ -42,7 +42,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
|
||||
|
|
Загрузка…
Ссылка в новой задаче