[WebGPU EP] Support GroupQueryAttention (#22658)

### Description
Support the GroupQueryAttention operator in the native WebGPU execution provider (EP).


### Motivation and Context
This is required for running inference with some LLMs.
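
GroupQueryAttention (GQA) generalizes multi-head attention by letting several query heads share one key/value head, which shrinks the KV cache. A minimal sketch of the head-mapping arithmetic the kernels below rely on (names are illustrative, not part of the ORT API):

```cpp
#include <cassert>

// Illustrative only: how query heads map onto shared KV heads in GQA.
// With num_heads query heads and kv_num_heads KV heads, each KV head is
// reused by n_reps consecutive query heads; n_reps is exactly the factor
// the WGSL shaders below use when dividing workgroup_id.z.
int KvHeadForQueryHead(int query_head, int num_heads, int kv_num_heads) {
  assert(num_heads % kv_num_heads == 0);
  const int n_reps = num_heads / kv_num_heads;  // replication factor
  return query_head / n_reps;
}
// Example: num_heads = 32, kv_num_heads = 8 -> n_reps = 4, so query
// heads 0..3 read KV head 0, heads 4..7 read KV head 1, and so on.
```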
Satya Kumar Jandhyala, 2024-12-02 12:40:03 -08:00 (committed by GitHub)
Parent: 6c2ff5fc55
Commit: e8bf46a70e
No key found matching this signature
GPG key ID: B5690EEEBB952194
9 changed files: 921 additions and 517 deletions

View file

@@ -11,18 +11,19 @@ namespace onnxruntime {
namespace contrib {
namespace group_query_attention_helper {
-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T = Tensor>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap) {
   // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
@@ -265,18 +266,19 @@ Status CheckInputs(const Tensor* query,
   return Status::OK();
 }

-Status CheckInputs(const Tensor* query,
-                   const Tensor* key,
-                   const Tensor* value,
-                   const Tensor* past_key,
-                   const Tensor* past_value,
-                   const Tensor* cos_cache,
-                   const Tensor* sin_cache,
+template <typename T = Tensor>
+Status CheckInputs(const T* query,
+                   const T* key,
+                   const T* value,
+                   const T* past_key,
+                   const T* past_value,
+                   const T* cos_cache,
+                   const T* sin_cache,
                    void* parameters,
                    int num_heads,
                    int kv_num_heads,
-                   const Tensor* seqlens_k,
-                   const Tensor* total_seqlen,
+                   const T* seqlens_k,
+                   const T* total_seqlen,
                    float scale,
                    float softcap,
                    int max_threads_per_block) {
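
The templating above lets the same shape validation serve any tensor-like type: `T` defaults to `Tensor`, so existing call sites compile unchanged, while callers with a different tensor wrapper can instantiate the helper explicitly. A stand-alone sketch of that default-template-parameter pattern (all names below are illustrative, not from this PR):

```cpp
#include <string>

struct FakeTensor {  // hypothetical stand-in for onnxruntime::Tensor
  int rank;
};

// Defaulting T keeps old call sites source-compatible while allowing
// other tensor-like types to reuse the same validation logic.
template <typename T = FakeTensor>
std::string CheckRank(const T* t, int expected_rank) {
  if (t == nullptr) return "missing input";
  return t->rank == expected_rank ? "ok" : "bad rank";
}

int main() {
  FakeTensor q{3};
  // Looks exactly like a call to the old non-template function.
  return CheckRank(&q, 3) == "ok" ? 0 : 1;
}
```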

View file

@@ -0,0 +1,459 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/webgpu/bert/attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/webgpu/bert/multihead_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
using namespace onnxruntime::webgpu;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::contrib::multihead_attention_helper;
namespace onnxruntime {
namespace contrib {
namespace webgpu {
Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("qkv_input", ShaderUsage::UseUniform);
const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
<< "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n"
<< "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *"
<< " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n";
if (has_bias_) {
shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n";
}
shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]";
if (has_bias_) {
shader.MainFunctionBody() << " + bias[bias_offset_idx];\n";
} else {
shader.MainFunctionBody() << ";\n";
}
return Status::OK();
}
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) {
ORT_ENFORCE(input_tensor->Shape().GetDims().size() == 3);
ORT_ENFORCE(output_tensor->Shape().GetDims().size() == 4);
uint32_t data_size = SafeInt<uint32_t>(output_tensor->Shape().Size());
const int batch_offset = num_heads * sequence_length * head_size;
const int sequence_offset = num_heads * head_size;
const int head_offset = head_size;
bool has_bias = bias != nullptr;
TransferBSDToBNSHProgram program{has_bias};
program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{data_size},
{static_cast<uint32_t>(batch_offset)},
{static_cast<uint32_t>(sequence_offset)},
{static_cast<uint32_t>(head_offset)},
{static_cast<uint32_t>(bias_offset)}});
if (has_bias) {
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
}
return context.RunProgram(program);
}
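
As a reading aid, the program above performs the usual (batch, sequence, num_heads × head_size) → (batch, num_heads, sequence, head_size) transpose, with `batch_offset`, `sequence_offset`, and `head_offset` encoding the input strides. A CPU mirror of the same index arithmetic, purely illustrative (bias omitted):

```cpp
#include <vector>

std::vector<float> TransferBSDToBNSHRef(const std::vector<float>& input,
                                        int batch, int num_heads,
                                        int seq_len, int head_size) {
  std::vector<float> output(input.size());
  for (int b = 0; b < batch; ++b)
    for (int n = 0; n < num_heads; ++n)
      for (int s = 0; s < seq_len; ++s)
        for (int h = 0; h < head_size; ++h) {
          // BSD layout: (b, s, n, h) packed as b*S*N*H + s*N*H + n*H + h
          const int in_idx = b * (seq_len * num_heads * head_size) +
                             s * (num_heads * head_size) + n * head_size + h;
          // BNSH layout: (b, n, s, h)
          const int out_idx = b * (num_heads * seq_len * head_size) +
                              n * (seq_len * head_size) + s * head_size + h;
          output[out_idx] = input[in_idx];
        }
  return output;
}
```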
void InitVarStub(std::ostringstream& ss, const Tensor* seqlen_k, bool is_first_prompt) {
if (seqlen_k != nullptr) {
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";
ss << "var past_sequence_length: u32 = " << (is_first_prompt ? "0" : "total_sequence_length - sequence_length") << ";\n";
} else {
ss << "let past_sequence_length = uniforms.past_sequence_length;\n";
}
}
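
`InitVarStub` emits the per-batch sequence bookkeeping for the GQA path: when a `seqlen_k` input is present, the total sequence length comes from the per-batch tensor (which holds `total_sequence_length - 1`, hence the `+ 1`); otherwise the uniforms are used as-is. A C++ mirror of the emitted logic, illustrative only:

```cpp
#include <cstdint>

struct SeqInfo {
  uint32_t total_sequence_length;
  uint32_t past_sequence_length;
};

SeqInfo DeriveSeqInfo(const int32_t* seqlen_k,  // per-batch; may be null
                      int batch_idx, uint32_t sequence_length,
                      bool is_first_prompt, uint32_t uniform_total,
                      uint32_t uniform_past) {
  if (seqlen_k != nullptr) {
    // seqlen_k stores total_sequence_length - 1 for each batch entry.
    const uint32_t total = static_cast<uint32_t>(seqlen_k[batch_idx]) + 1;
    const uint32_t past = is_first_prompt ? 0 : total - sequence_length;
    return {total, past};
  }
  return {uniform_total, uniform_past};  // non-GQA path: uniforms only
}
```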
Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (feed_past_key_) {
shader.AddInput("past_key", ShaderUsage::UseUniform);
}
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
if (seqlen_k_ != nullptr) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (has_present_key_) {
shader.AddOutput("present_key", ShaderUsage::UseUniform);
}
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
<< "let m = workgroup_id.y * TILE_SIZE;\n"
<< "let n = workgroup_id.x * TILE_SIZE;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.N;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n";
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n";
}
shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
" if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
" }\n"
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";
if ((feed_past_key_ && has_present_key_) || past_present_share_buffer_) {
shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n"
<< " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
<< " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length) {\n"
" tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" }\n";
}
if (has_present_key_) {
if (past_present_share_buffer_) {
shader.MainFunctionBody() << " if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
} else {
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
}
shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"
<< " }\n";
}
shader.MainFunctionBody() << " }\n"
<< " workgroupBarrier();\n"
<< " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
<< " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< "}\n";
shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n"
<< " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
if (has_attention_bias_) {
shader.MainFunctionBody() << " + attention_bias[outputIdx]";
}
shader.MainFunctionBody() << ";\n"
<< "}\n";
return Status::OK();
}
Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
WebgpuAttentionParameters& parameters, int past_sequence_length, int total_sequence_length,
const Tensor* seqlen_k) {
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
const bool has_present_key = output_count > 1 && past_key;
const bool has_attention_bias = attention_bias != nullptr;
constexpr int tile_size = 12;
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
}
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
if (has_present_key) {
program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components});
}
const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.SetWorkgroupSize(tile_size, tile_size)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
{static_cast<uint32_t>(parameters.num_heads_)},
{static_cast<uint32_t>(parameters.head_size_)},
{static_cast<float>(alpha)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
return context.RunProgram(program);
}
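
One small detail: `vectorized_head_size` is now computed with a ceiling division, whereas the earlier MHA-only version divided exactly. Because `components` is chosen from `head_size % 4` / `% 2`, it always divides `head_size_` evenly, so the two forms agree and the ceil variant is simply defensive; a quick check (illustrative):

```cpp
#include <cassert>
#include <initializer_list>

int main() {
  for (int head_size : {64, 80, 96, 100, 81}) {
    const int components = head_size % 4 == 0 ? 4 : (head_size % 2 == 0 ? 2 : 1);
    const int vectorized = (head_size + components - 1) / components;
    assert(vectorized == head_size / components);  // ceil == exact here
  }
  return 0;
}
```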
Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
if (seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let sequence_length = uniforms.sequence_length;\n"
<< "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
<< "workgroupBarrier();\n"
<< "var max_value = f32(-3.402823e+38f);\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " max_value = max(thread_max[i], max_value);\n"
<< "}\n"
<< "var sum_vector = f32_val_t(0);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
<< "}\n"
<< "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n"
<< "workgroupBarrier();\n"
<< "var sum: f32 = 0;\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " sum += thread_sum[i]\n;"
<< "}\n"
<< "if (sum == 0) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
<< " }\n"
<< "} else {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " var f32input = f32_val_t(x[offset + i]);\n"
<< " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " }\n"
<< "}\n";
if (seqlen_k_) {
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
<< " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
<< "}\n";
}
return Status::OK();
}
Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
const Tensor* seqlen_k, bool is_first_prompt) {
const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
if (total_sequence_length_comp < work_group_size) {
work_group_size = 32;
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components, is_first_prompt, seqlen_k};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size, is_first_prompt)
.SetDispatchGroupSize(1, sequence_length, batch_size * num_heads)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
{static_cast<uint32_t>(num_heads)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(sequence_length)},
{static_cast<uint32_t>(total_sequence_length_comp)},
{static_cast<uint32_t>(elementsPerThread)}});
return context.RunProgram(program);
}
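
The host-side sizing above packs each softmax row into vec4/vec2 lanes when the row length allows it (vectorization is disabled when `seqlen_k` is present) and drops to a 32-thread workgroup for short rows. A worked example of that arithmetic with an assumed row length (illustrative):

```cpp
#include <cstdio>

int main() {
  const int total_sequence_length = 1024;  // assumed; no seqlen_k input
  const int components = total_sequence_length % 4 == 0
                             ? 4
                             : (total_sequence_length % 2 == 0 ? 2 : 1);  // 4
  int work_group_size = 64;
  const int total_sequence_length_comp =
      (total_sequence_length + components - 1) / components;  // 256 vec4s
  if (total_sequence_length_comp < work_group_size) {
    work_group_size = 32;  // short rows: fewer threads sit idle
  }
  const int elements_per_thread =
      (total_sequence_length_comp + work_group_size - 1) / work_group_size;  // 4
  std::printf("components=%d workgroup=%d elements_per_thread=%d\n",
              components, work_group_size, elements_per_thread);
  return 0;
}
```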
Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (feed_past_value_) {
shader.AddInput("past_value", ShaderUsage::UseUniform);
}
if (seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform);
if (has_present_value_) {
shader.AddOutput("present_value", ShaderUsage::UseUniform);
}
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << ">;\n";
shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n"
<< "let batch_idx = workgroup_id.z / uniforms.num_heads;\n"
<< "let m = global_id.y;\n"
<< "let n = global_id.x;\n"
<< "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n"
<< "let sequence_length = uniforms.M;\n"
<< "var total_sequence_length = uniforms.K;\n";
std::ostringstream oss;
InitVarStub(oss, seqlen_k_, is_first_prompt_);
shader.MainFunctionBody() << oss.str();
shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n";
if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n";
}
shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
<< "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
<< " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
<< " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
<< " }\n"
<< " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
if ((feed_past_value_ && has_present_value_) || past_present_share_buffer_) {
shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n"
<< " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n"
<< " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
<< " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length) {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y) * uniforms.N];\n"
<< " }\n";
}
if (has_present_value_) {
if (past_present_share_buffer_) {
shader.MainFunctionBody() << " if (w + local_id.y >= past_sequence_length && w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
} else {
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {\n";
}
shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"
<< " }\n";
}
shader.MainFunctionBody() << " }\n"
<< " workgroupBarrier();\n"
<< " for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {\n"
<< " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< "}\n";
shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
<< "if (m < uniforms.M && n < uniforms.N) {\n"
<< " let outputIdx = batch_idx * uniforms.M * uniforms.v_hidden_size + "
<< " m * uniforms.v_hidden_size + head_idx * uniforms.N + n;\n"
<< " output[outputIdx] = value;\n"
<< "}\n";
return Status::OK();
}
Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count,
const Tensor* probs,
const Tensor* V,
const Tensor* past_value,
Tensor* output,
Tensor* present_value,
WebgpuAttentionParameters& parameters,
int past_sequence_length,
int total_sequence_length,
const Tensor* seqlen_k) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0 && !parameters.past_present_share_buffer_;
const bool has_present_value = output_count > 1 && past_value != nullptr;
constexpr int tile_size = 12;
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
}
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
if (has_present_value) {
program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank});
}
program.SetDispatchGroupSize((parameters.v_head_size_ + tile_size - 1) / tile_size,
(parameters.sequence_length_ + tile_size - 1) / tile_size,
parameters.batch_size_ * parameters.num_heads_)
.CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length_)},
{static_cast<uint32_t>(total_sequence_length)},
{static_cast<uint32_t>(parameters.v_head_size_)},
{static_cast<uint32_t>(parameters.num_heads_)},
{static_cast<uint32_t>(parameters.head_size_)},
{static_cast<uint32_t>(parameters.v_hidden_size_ * parameters.n_reps)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{static_cast<uint32_t>(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)},
{static_cast<uint32_t>(parameters.n_reps)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
return context.RunProgram(program);
}
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, total_sequence_length});
const TensorShape probs_shape(probs_dims);
Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
parameters, past_sequence_length, total_sequence_length, seqlen_k));
ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_));
ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));
return Status::OK();
}
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
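
Stepping back, `ApplyAttention` chains the three programs in this file into standard scaled-dot-product attention, where α is `scale`, or 1/√head_size when `scale` is 0:

```latex
\begin{aligned}
P &= \alpha\, Q K^{\top} \; (+\, B_{\text{attn}})  &&\text{ComputeAttentionProbs} \\
A &= \operatorname{softmax}(P) \;\text{(row-wise, in place)} &&\text{ComputeInPlaceSoftmax} \\
O &= A\, V  &&\text{ComputeVxAttentionScore, emitted in BSND layout}
\end{aligned}
```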

View file

@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace webgpu {
using namespace onnxruntime::webgpu;
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
public:
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
{"batch_offset", ProgramUniformVariableDataType::Uint32},
{"sequence_offset", ProgramUniformVariableDataType::Uint32},
{"head_offset", ProgramUniformVariableDataType::Uint32},
{"bias_offset", ProgramUniformVariableDataType::Uint32});
private:
bool has_bias_;
};
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
private:
bool feed_past_key_;
bool has_present_key_;
bool has_attention_bias_;
int tile_size_;
int components_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components), seqlen_k_(seqlen_k), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});
private:
int work_group_size_;
int components_;
const Tensor* seqlen_k_;
bool is_first_prompt_;
};
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"head_size", ProgramUniformVariableDataType::Uint32},
{"v_hidden_size", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
private:
bool feed_past_value_;
bool has_present_value_;
int tile_size_;
int n_reps_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime

View file

@@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace webgpu {
struct WebgpuAttentionParameters {
explicit WebgpuAttentionParameters(AttentionParameters parameters) : is_gqa_(false),
batch_size_(parameters.batch_size),
sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.kv_sequence_length),
past_sequence_length_(parameters.past_sequence_length),
total_sequence_length_(parameters.total_sequence_length),
max_sequence_length_(parameters.max_sequence_length),
input_hidden_size_(parameters.input_hidden_size),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.v_hidden_size),
v_head_size_(parameters.v_head_size),
num_heads_(parameters.num_heads),
is_unidirectional_(parameters.is_unidirectional),
past_present_share_buffer_(parameters.past_present_share_buffer),
do_rotary_(parameters.do_rotary),
broadcast_attn_bias_dim_0_(parameters.broadcast_attn_bias_dim_0),
broadcast_attn_bias_dim_1_(parameters.broadcast_attn_bias_dim_1),
mask_filter_value_(parameters.mask_filter_value),
scale_(parameters.scale),
mask_type_(parameters.mask_type),
qkv_format_(parameters.qkv_format) {
}
explicit WebgpuAttentionParameters(onnxruntime::contrib::GroupQueryAttentionParameters parameters) : is_gqa_(true),
batch_size_(parameters.batch_size),
sequence_length_(parameters.sequence_length),
kv_sequence_length_(parameters.sequence_length),
past_sequence_length_(parameters.seqlen_past_kv_cache),
total_sequence_length_(parameters.total_sequence_length),
hidden_size_(parameters.hidden_size),
head_size_(parameters.head_size),
v_hidden_size_(parameters.kv_hidden_size),
v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads),
num_heads_(parameters.num_heads),
do_rotary_(parameters.do_rotary),
scale_(parameters.scale),
seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache),
seqlen_present_kv_cache_(parameters.seqlen_present_kv_cache),
kv_hidden_size_(parameters.kv_hidden_size),
kv_num_heads_(parameters.kv_num_heads),
num_splits_(parameters.num_splits),
rotary_dim_(parameters.rotary_dim),
is_packed_qkv_(parameters.is_packed_qkv),
is_subsequent_prompt_(parameters.is_subsequent_prompt),
is_first_prompt_(parameters.is_first_prompt),
rotary_interleaved_(parameters.rotary_interleaved),
use_smooth_softmax_(parameters.use_smooth_softmax),
softcap_(parameters.softcap),
zeros_count_(parameters.zeros_count),
zero_ptr_(parameters.zero_ptr),
n_reps(parameters.num_heads / parameters.kv_num_heads),
qkv_format_(parameters.qkv_format) {
}
bool is_gqa_;
int batch_size_ = 0;
int sequence_length_ = 0;
int kv_sequence_length_ = 0; // input sequence length of K or V
int past_sequence_length_ = 0; // sequence length in past state of K or V
int total_sequence_length_ = 0; // total sequence length of K or V
int max_sequence_length_ = 0; // max sequence length from 4D mask
int input_hidden_size_ = 0; // first dimension of weights for input projection
int hidden_size_ = 0; // hidden size of Q or K
int head_size_ = 0; // hidden size per head of Q or K
int v_hidden_size_ = 0; // hidden size of V
int v_head_size_ = 0; // hidden size per head of V
int num_heads_ = 0;
int rotary_embedding_ = 0;
bool is_unidirectional_ = false;
bool past_present_share_buffer_ = false;
bool do_rotary_ = false;
bool broadcast_attn_bias_dim_0_ = false;
bool broadcast_attn_bias_dim_1_ = false;
float mask_filter_value_ = -10000.0f;
float scale_ = 0.0f;
bool use_tf32_ = false;
// The following members are in onnxruntime::contrib::GroupQueryAttentionParameters
// and not in onnxruntime::contrib::AttentionParameters
int seqlen_past_kv_cache_ = 0; // sequence length of past kv tensor
int seqlen_present_kv_cache_ = 0; // sequence length of present kv tensor
int kv_hidden_size_ = 0;
int kv_num_heads_ = 0;
int num_splits_ = 0; // number of splits for splitkv
int rotary_dim_ = 0; // rotary embedding dimension
int local_window_size_ = 0;
bool kv_share_buffer_ = false;
bool is_packed_qkv_ = false;
bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1
bool is_first_prompt_ = false; // indicates whether this is first decoding step
bool rotary_interleaved_ = false;
bool use_smooth_softmax_ = false;
float softcap_ = 0.0;
int zeros_count_ = 0;
int* zero_ptr_ = nullptr;
// Computed values
int n_reps = 1;
AttentionMaskType mask_type_ = MASK_NONE;
AttentionQkvFormat qkv_format_ = UNKNOWN;
};
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor);
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime

View file

@@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
#include "contrib_ops/webgpu/bert/attention_common.h"
#include "contrib_ops/webgpu/bert/group_query_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
using namespace onnxruntime::webgpu;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::contrib::group_query_attention_helper;
namespace onnxruntime {
namespace contrib {
namespace webgpu {
ONNX_OPERATOR_KERNEL_EX(
GroupQueryAttention,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.MayInplace(3, 1)
.MayInplace(4, 2)
.InputMemoryType(OrtMemTypeCPUInput, 6),
GroupQueryAttention);
Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* query = context.Input<Tensor>(0);
const Tensor* key = context.Input<Tensor>(1);
const Tensor* value = context.Input<Tensor>(2);
const Tensor* past_key = context.Input<Tensor>(3);
const Tensor* past_value = context.Input<Tensor>(4);
const Tensor* seqlen_k = context.Input<Tensor>(5);
const Tensor* total_seqlen_tensor = context.Input<Tensor>(6);
const Tensor* cos_cache = context.Input<Tensor>(7);
const Tensor* sin_cache = context.Input<Tensor>(8);
GroupQueryAttentionParameters params;
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
value,
past_key,
past_value,
cos_cache,
sin_cache,
&params,
num_heads_,
kv_num_heads_,
seqlen_k,
total_seqlen_tensor,
scale_,
softcap_));
WebgpuAttentionParameters parameters(params);
if (parameters.is_packed_qkv_) {
ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu-ep.");
}
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
output_shape[2] = static_cast<int64_t>(parameters.hidden_size_);
Tensor* output = context.Output(0, output_shape);
std::vector<int64_t> present_dims{
parameters.batch_size_,
kv_num_heads_,
parameters.seqlen_present_kv_cache_,
parameters.head_size_};
std::vector<int64_t> present_kv_shape(present_dims);
Tensor* present_key = context.Output(1, present_kv_shape);
Tensor* present_value = context.Output(2, present_kv_shape);
parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();
TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, parameters.head_size_});
TensorShape q_new_shape(q_new_dims);
Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
return ApplyAttention(&Q, key, value, nullptr, past_key, past_value, output, present_key,
present_value, parameters, context, seqlen_k);
}
TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
parameters.kv_sequence_length_, parameters.head_size_});
TensorShape k_new_shape(k_new_dims);
Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
parameters.head_size_, key, nullptr, 0, &K));
TensorShapeVector v_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
parameters.kv_sequence_length_, parameters.v_head_size_});
TensorShape v_new_shape(v_new_dims);
Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
parameters.v_head_size_, value, nullptr, 0, &V));
return ApplyAttention(&Q, &K, &V, nullptr, past_key, past_value, output, present_key,
present_value, parameters, context, seqlen_k);
}
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
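
To make the tensor plumbing in `ComputeInternal` concrete, here is the shape flow for one assumed configuration; the dimensions are illustrative, not from the PR:

```cpp
#include <cstdio>

int main() {
  // Assumed model dimensions: batch 1, prompt length 256, 32 query heads,
  // 8 KV heads, head size 128.
  const int B = 1, S = 256, num_heads = 32, kv_num_heads = 8, H = 128;
  const int hidden = num_heads * H;        // 4096: query/output feature size
  const int kv_hidden = kv_num_heads * H;  // 1024: key/value feature size
  std::printf("query   (B,S,D)    = (%d,%d,%d)\n", B, S, hidden);
  std::printf("key/val (B,S,D_kv) = (%d,%d,%d)\n", B, S, kv_hidden);
  std::printf("Q BNSH             = (%d,%d,%d,%d)\n", B, num_heads, S, H);
  std::printf("K/V BNSH           = (%d,%d,%d,%d)\n", B, kv_num_heads, S, H);
  std::printf("output  (B,S,D)    = (%d,%d,%d)\n", B, S, hidden);
  // present_key/present_value: (B, kv_num_heads, seqlen_present_kv_cache, H)
  return 0;
}
```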

Просмотреть файл

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace webgpu {
using namespace onnxruntime::webgpu;
class GroupQueryAttention final : public WebGpuKernel {
public:
GroupQueryAttention(const OpKernelInfo& info) : WebGpuKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0);
kv_num_heads_ = static_cast<int>(kv_num_heads);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
softcap_ = info.GetAttrOrDefault<float>("softcap", 0.0f);
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
use_smooth_softmax_ = info.GetAttrOrDefault<int64_t>("smooth_softmax", 0) == 1;
local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
}
int num_heads_; // number of attention heads of Q
int kv_num_heads_; // number of attention heads of K or V
float scale_; // the scaling factor applied before softmax
float softcap_;
bool do_rotary_; // whether or not to use rotary embeddings
bool rotary_interleaved_;
int local_window_size_;
bool use_smooth_softmax_;
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
};
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime

View file

@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/webgpu/bert/attention_common.h"
#include "contrib_ops/webgpu/bert/multihead_attention.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
@@ -25,392 +26,8 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
MultiHeadAttention);
Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("qkv_input", ShaderUsage::UseUniform);
const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices);
if (has_bias_) {
shader.AddInput("bias", ShaderUsage::UseUniform);
}
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size")
<< "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n"
<< "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *"
<< " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n";
if (has_bias_) {
shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n";
}
shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]";
if (has_bias_) {
shader.MainFunctionBody() << " + bias[bias_offset_idx];\n";
} else {
shader.MainFunctionBody() << ";\n";
}
return Status::OK();
}
Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length,
int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) {
assert(input_tensor->Shape().GetDims().size() == 3);
assert(output_tensor->Shape().GetDims().size() == 4);
uint32_t data_size = gsl::narrow<uint32_t>(output_tensor->Shape().Size());
const int batch_offset = num_heads * sequence_length * head_size;
const int sequence_offset = num_heads * head_size;
const int head_offset = head_size;
bool has_bias = bias != nullptr;
TransferBSDToBNSHProgram program{has_bias};
program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{data_size},
{static_cast<uint32_t>(batch_offset)},
{static_cast<uint32_t>(sequence_offset)},
{static_cast<uint32_t>(head_offset)},
{static_cast<uint32_t>(bias_offset)}});
if (has_bias) {
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank});
}
return context.RunProgram(program);
};
Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (feed_past_key_) {
shader.AddInput("past_key", ShaderUsage::UseUniform);
}
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (has_present_key_) {
shader.AddOutput("present_key", ShaderUsage::UseUniform);
}
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
"let headIdx = workgroup_id.z;\n"
"let m = workgroup_id.y * TILE_SIZE;\n"
"let n = workgroup_id.x * TILE_SIZE;\n"
"let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n";
if (feed_past_key_ && has_present_key_) {
shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n"
<< "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n";
} else {
shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n";
}
shader.MainFunctionBody() << "var value = f32_val_t(0);\n"
"for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
" if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n"
" tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n"
" }\n"
" if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n"
" var idx = TILE_SIZE * local_id.y + local_id.x;\n";
if (feed_past_key_ && has_present_key_) {
shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n"
" tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n"
" } else {\n"
" tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n"
" }\n";
} else {
shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n";
}
if (has_present_key_) {
shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n";
}
shader.MainFunctionBody() << " }\n"
<< " workgroupBarrier();\n"
<< " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
<< " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< "}\n";
shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n"
<< "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n"
<< " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n"
<< " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n";
shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)";
if (has_attention_bias_) {
shader.MainFunctionBody() << " + attention_bias[outputIdx]";
}
shader.MainFunctionBody() << ";\n"
<< "}\n";
return Status::OK();
}
Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q,
const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key,
AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) {
const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size))
: parameters.scale;
const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0;
const bool has_present_key = output_count > 1 && past_key;
const bool has_attention_bias = attention_bias != nullptr;
constexpr int tile_size = 12;
const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1);
AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size,
components};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
{K, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (feed_past_key) {
program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components});
}
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}});
if (has_present_key) {
program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components});
}
const uint32_t vectorized_head_size = parameters.head_size / components;
program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
(parameters.sequence_length + tile_size - 1) / tile_size,
parameters.batch_size * parameters.num_heads)
.SetWorkgroupSize(tile_size, tile_size)
.CacheHint(std::to_string(tile_size))
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(total_sequence_length)},
{static_cast<uint32_t>(parameters.num_heads)},
{static_cast<float>(alpha)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length)}})
.SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
return context.RunProgram(program);
}
Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
<< "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
<< "workgroupBarrier();\n"
<< "var max_value = f32(-3.402823e+38f);\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " max_value = max(thread_max[i], max_value);\n"
<< "}\n"
<< "var sum_vector = f32_val_t(0);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
<< " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
<< "}\n"
<< "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n"
<< "workgroupBarrier();\n"
<< "var sum: f32 = 0;\n"
<< "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
<< " sum += thread_sum[i]\n;"
<< "}\n"
<< "if (sum == 0) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
<< " x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n"
<< " }\n"
<< "} else {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
<< " var f32input = f32_val_t(x[offset + i]);\n"
<< " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " }\n"
<< "}\n";
return Status::OK();
}
Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) {
const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1);
int work_group_size = 64;
const int d_comp = d / components;
if (d_comp < work_group_size) {
work_group_size = 32;
}
const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size;
InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components};
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.SetDispatchGroupSize(n)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<float>(1.f / static_cast<float>(d))},
{static_cast<uint32_t>(d_comp)},
{static_cast<uint32_t>(elementsPerThread)}});
return context.RunProgram(program);
}
Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
if (feed_past_value_) {
shader.AddInput("past_value", ShaderUsage::UseUniform);
}
shader.AddOutput("output", ShaderUsage::UseUniform);
if (has_present_value_) {
shader.AddOutput("present_value", ShaderUsage::UseUniform);
}
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<v_value_t, " << tile_size_ * tile_size_ << ">;\n";
shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
<< "let m = global_id.y;\n"
<< "let n = global_id.x;\n"
<< "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n";
if (feed_past_value_ && has_present_value_) {
shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n"
<< "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n";
} else {
shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n";
}
if (has_present_value_) {
shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n";
}
shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
<< "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
<< " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
<< " tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
<< " }\n"
<< " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
<< " var idx = TILE_SIZE * local_id.y + local_id.x;\n";
if (feed_past_value_ && has_present_value_) {
shader.MainFunctionBody() << " if (w + local_id.y < uniforms.past_sequence_length) {\n"
<< " tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
<< " } else {\n"
<< " tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n";
}
if (has_present_value_) {
shader.MainFunctionBody() << " present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n";
}
shader.MainFunctionBody() << " }\n"
<< " workgroupBarrier();\n"
<< " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
<< " value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< "}\n";
shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
<< "let batchIdx = workgroup_id.z / uniforms.num_heads;\n"
<< "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n"
<< "if (m < uniforms.M && n < uniforms.N) {\n"
<< " let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + "
<< " m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n"
<< " output[outputIdx] = value;\n"
<< "}\n";
return Status::OK();
}
Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count,
const Tensor* probs,
const Tensor* V,
const Tensor* past_value,
Tensor* output,
Tensor* present_value,
AttentionParameters& parameters,
int past_sequence_length,
int total_sequence_length) {
const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0;
const bool has_present_value = output_count > 1 && past_value != nullptr;
constexpr int tile_size = 12;
VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size};
program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
{V, ProgramTensorMetadataDependency::TypeAndRank}});
if (feed_past_value) {
program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
}
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}});
if (has_present_value) {
program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank});
}
program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size,
(parameters.sequence_length + tile_size - 1) / tile_size,
parameters.batch_size * parameters.num_heads)
.SetWorkgroupSize(tile_size, tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
{static_cast<uint32_t>(total_sequence_length)},
{static_cast<uint32_t>(parameters.v_head_size)},
{static_cast<uint32_t>(parameters.num_heads)},
{static_cast<uint32_t>(parameters.v_hidden_size)},
{static_cast<uint32_t>(past_sequence_length)},
{static_cast<uint32_t>(parameters.kv_sequence_length)}})
      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
return context.RunProgram(program);
}
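The dispatch above assigns one TILE_SIZE x TILE_SIZE workgroup to each 12x12 tile of the per-head output. As a worked example (an illustrative configuration, not one taken from this change): with batch_size = 1, num_heads = 32, sequence_length = 256 and v_head_size = 128, the grid is ceil(128/12) x ceil(256/12) x (1 * 32) = 11 x 22 x 32 workgroups of 144 invocations each.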
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0;
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;
const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads,
parameters.sequence_length, total_sequence_length});
const TensorShape probs_shape(probs_dims);
Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
parameters, past_sequence_length, total_sequence_length));
ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length));
ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length));
return Status::OK();
}
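ApplyAttention strings the three programs into softmax(alpha * Q * K^T + bias) * V, materializing only the [batch_size, num_heads, sequence_length, total_sequence_length] probs tensor in between; past state participates only when the present outputs are requested (output_count > 1). The following is a self-contained scalar reference of the same three passes for a single (batch, head), assuming no past state and no attention bias, useful only for checking the math:

// Illustrative three-pass reference: scores, in-place softmax, then scores * V.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void AttentionRef(const float* Q, const float* K, const float* V, float* out,
                  int S, int T, int head_size, int v_head_size, float alpha) {
  std::vector<float> probs(static_cast<std::size_t>(S) * T);
  for (int m = 0; m < S; ++m) {  // pass 1: probs = alpha * Q * K^T
    for (int n = 0; n < T; ++n) {
      float s = 0.0f;
      for (int k = 0; k < head_size; ++k) s += Q[m * head_size + k] * K[n * head_size + k];
      probs[m * T + n] = alpha * s;
    }
  }
  for (int m = 0; m < S; ++m) {  // pass 2: numerically stable row softmax, in place
    float mx = probs[m * T];
    for (int n = 1; n < T; ++n) mx = std::max(mx, probs[m * T + n]);
    float sum = 0.0f;
    for (int n = 0; n < T; ++n) {
      probs[m * T + n] = std::exp(probs[m * T + n] - mx);
      sum += probs[m * T + n];
    }
    for (int n = 0; n < T; ++n) probs[m * T + n] /= sum;
  }
  for (int m = 0; m < S; ++m) {  // pass 3: out = probs * V
    for (int n = 0; n < v_head_size; ++n) {
      float s = 0.0f;
      for (int k = 0; k < T; ++k) s += probs[m * T + k] * V[k * v_head_size + n];
      out[m * v_head_size + n] = s;
    }
  }
}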
MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
: WebGpuKernel(info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
: WebGpuKernel(info), AttentionBase(info, false) {
ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA is not supported by the webgpu kernel");
}
@ -434,54 +51,54 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu");
}
AttentionParameters parameters;
AttentionParameters params;
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query, key, value,
bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &parameters,
bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &params,
num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention,
context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
WebgpuAttentionParameters parameters(params);
TensorShapeVector output_shape(3);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size_);
Tensor* output = context.Output(0, output_shape);
// If optional outputs aren't needed, present_key and present_value will be null
std::vector<int64_t> present_dims{
parameters.batch_size,
parameters.num_heads,
parameters.total_sequence_length,
parameters.head_size,
parameters.batch_size_,
parameters.num_heads_,
parameters.total_sequence_length_,
parameters.head_size_,
};
TensorShape present_shape(present_dims);
Tensor* present_key = context.Output(1, present_shape);
Tensor* present_value = context.Output(2, present_shape);
TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
parameters.sequence_length, parameters.head_size});
TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_,
parameters.sequence_length_, parameters.head_size_});
TensorShape q_new_shape(q_new_dims);
Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q));
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, bias, 0, &Q));
if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context);
}
TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
parameters.kv_sequence_length, parameters.head_size});
TensorShapeVector k_new_dims({parameters.batch_size_, parameters.num_heads_,
parameters.kv_sequence_length_, parameters.head_size_});
TensorShape k_new_shape(k_new_dims);
Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
parameters.head_size, key, bias, parameters.hidden_size, &K));
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
parameters.head_size_, key, bias, parameters.hidden_size_, &K));
TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
parameters.kv_sequence_length, parameters.v_head_size});
TensorShapeVector v_new_dims({parameters.batch_size_, parameters.num_heads_,
parameters.kv_sequence_length_, parameters.v_head_size_});
TensorShape v_new_shape(v_new_dims);
Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.kv_sequence_length_,
parameters.v_head_size_, value, bias, 2 * parameters.hidden_size_, &V));
// Compute the attention score and apply the score to V
return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
                      present_value, parameters, context);
}
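The three TransferBSDToBNSH calls above pre-transpose Q, K and V and fold in the packed bias, whose Q/K/V segments begin at offsets 0, hidden_size and 2 * hidden_size respectively. Below is a sketch of the per-element index math, assuming the BSD layout implied by the shader's offsets (batch_offset = S * N * H, sequence_offset = N * H, head_offset = H); the helper and its name are illustrative only.

// Hypothetical per-element view of the BSD -> BNSH transfer; bias_base is
// 0 for Q, hidden_size for K, and 2 * hidden_size for V, as wired above.
float TransferOneElement(const float* input, const float* bias,
                         int b, int n, int s, int h,
                         int num_heads, int sequence_length, int head_size,
                         int bias_base) {
  const int sequence_offset = num_heads * head_size;           // one token's BSD row
  const int batch_offset = sequence_length * sequence_offset;  // one batch's BSD block
  const int src = b * batch_offset + s * sequence_offset + n * head_size + h;
  float value = input[src];
  if (bias != nullptr) {
    value += bias[bias_base + n * head_size + h];  // (src % sequence_offset) + bias_base
  }
  return value;  // stored at BNSH index [b, n, s, h]
}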


@ -7,6 +7,9 @@
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "contrib_ops/webgpu/bert/attention.h"
#include "contrib_ops/cpu/bert/attention_base.h"
namespace onnxruntime {
namespace contrib {
@ -14,100 +17,10 @@ namespace webgpu {
using namespace onnxruntime::webgpu;
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> {
public:
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32},
{"batch_offset", ProgramUniformVariableDataType::Uint32},
{"sequence_offset", ProgramUniformVariableDataType::Uint32},
{"head_offset", ProgramUniformVariableDataType::Uint32},
{"bias_offset", ProgramUniformVariableDataType::Uint32});
private:
bool has_bias_;
};
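For orientation, the offset uniforms describe the BSD source layout. With batch_size = 2, sequence_length = 128, num_heads = 8 and head_size = 64 (an illustrative configuration), sequence_offset = 8 * 64 = 512, head_offset = 64, batch_offset = 128 * 512 = 65536, and data_size guards 2 * 65536 = 131072 scalar elements (modulo any vectorization the kernel applies).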
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"alpha", ProgramUniformVariableDataType::Float32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
private:
bool feed_past_key_;
bool has_present_key_;
bool has_attention_bias_;
int tile_size_;
int components_;
};
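Of these, only alpha is a float: it is the scale folded into Q * K^T before the softmax. A hedged sketch of its usual derivation follows; ONNX Runtime's attention ops conventionally treat a scale attribute of 0 as "use the default 1/sqrt(head_size)", though the exact host code is not shown in this diff.

#include <cmath>

// Assumed derivation of the `alpha` uniform; scale == 0 falls back to the default.
float ComputeAlpha(float scale, int head_size) {
  return scale == 0.0f ? 1.0f / std::sqrt(static_cast<float>(head_size)) : scale;
}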
class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components)
: Program{kernel_name}, work_group_size_(work_group_size), components_(components) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32},
{"d_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32});
private:
int work_group_size_;
int components_;
};
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"v_hidden_size", ProgramUniformVariableDataType::Uint32},
{"past_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32});
private:
bool feed_past_value_;
bool has_present_value_;
int tile_size_;
};
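Note the split between uniforms and the TILE_SIZE overridable constant: WGSL fixes the size of var<workgroup> arrays and the workgroup size at pipeline-creation time, so the tile edge is baked in as an override constant while the matrix extents (M, K, N) remain runtime uniforms. That is why both AttentionProbsProgram and VxAttentionScoreProgram declare TILE_SIZE via WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS rather than as another uniform.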
class MultiHeadAttention final : public WebGpuKernel {
class MultiHeadAttention final : public WebGpuKernel, public AttentionBase {
public:
MultiHeadAttention(const OpKernelInfo& info);
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
protected:
int num_heads_;
float mask_filter_value_;
float scale_;
bool is_unidirectional_{false};
};
} // namespace webgpu


@ -42,7 +42,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,