diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index fe824a5c455..09c786daa3f 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
 import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
 import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
 import { gemm, parseGemmAttributes } from './ops/gemm';
-import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
+import { groupQueryAttention } from './ops/group-query-attention';
 import { instanceNorm } from './ops/instance-norm';
 import { layerNorm } from './ops/layer-norm';
 import { matMul } from './ops/matmul';
@@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
   ['Greater', [binaryOps.greater]],
   ['GreaterOrEqual', [binaryOps.greaterOrEqual]],
-  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
+  ['GroupQueryAttention', [groupQueryAttention]],
   ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
   ['InstanceNormalization', [instanceNorm]],
   ['LayerNormalization', [layerNorm]],
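The rule-table change above drops the separate attribute parser for GroupQueryAttention. A minimal sketch of the shape involved (the type names below are illustrative stand-ins, not the actual onnxruntime declarations): a resolve rule pairs a kernel with an optional attribute parser, and an entry without a parser hands the kernel the attributes object as-is.

```ts
// Sketch only: assumed stand-in types for illustration.
type ComputeContextLike = { inputs: unknown[] };
type RuleEntry = [
  (context: ComputeContextLike, attributes: unknown) => void,
  ((attributes: unknown) => unknown)?, // optional attribute parser
];

const demoRules = new Map<string, RuleEntry>([
  // No parser entry: the kernel receives the raw attributes object.
  ['GroupQueryAttention', [(_context, attributes) => console.log(attributes)]],
]);

demoRules.get('GroupQueryAttention')![0]({ inputs: [] }, { numHeads: 2, kvNumHeads: 1 });
```

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 832f6e13290..6a78c8ae3b1 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU
 import {
   getMaxComponents,
+  IndicesHelper,
   inputVariable,
   outputVariable,
   ShaderHelper,
@@ -65,14 +66,17 @@ export interface AttentionParameters {
   broadcastResPosBias: boolean;
   passPastInKv: boolean;
   qkvFormat: AttentionQkvFormat;
-  isPastkvBSNH?: boolean;
+  softcap?: number;
+  doRotary?: number;
+  rotaryInterleaved?: number;
+  smoothSoftmax?: number;
+  localWindowSize?: number;
 }

 export interface AttentionAttrs {
   numHeads: number;
-  kvNumHeads?: number;
-  isUnidirectional?: number;
-  maskFilterValue?: number;
+  isUnidirectional: number;
+  maskFilterValue: number;
   scale: number;
   doRotary: number;
   qkvHiddenSizes: number[];
@@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
   };
 };

-const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
-  const components = getMaxComponents(d);
+const initVarStub = (
+  seqLensInput: IndicesHelper | undefined,
+  totalSequenceLengthInput: IndicesHelper | undefined,
+  initPastSequenceLength: boolean,
+) => {
+  // In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input
+  if (totalSequenceLengthInput && seqLensInput) {
+    return `
+      let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')});
+      let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);
+      let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;
+      let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;
+      total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1;
+      var past_sequence_length: u32 = 0;
+      if (is_first_prompt == false) {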
past_sequence_length = total_sequence_length - sequence_length; + } + `; + } else { + return ` + ${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''}; + let present_sequence_length = total_sequence_length; + `; + } +}; + +const createInPlaceSoftmaxProgramInfo = ( + input: TensorView, + batchSize: number, + numHeads: number, + pastSequenceLength: number, + sequenceLength: number, + totalSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, +) => { + // Set components to 1 if seqLens is specified, i.e. GroupQueryAttention. + const components = getMaxComponents(seqLens ? 1 : totalSequenceLength); let WG = 64; - const dComp = d / components; - if (dComp < WG) { + const totalSequenceLengthComp = totalSequenceLength / components; + if (totalSequenceLengthComp < WG) { WG = 32; } - const elementsPerThread = Math.ceil(d / components / WG); + const elementsPerThread = Math.ceil(totalSequenceLength / components / WG); const programUniforms: ProgramUniform[] = [ - { type: DataType.float, data: 1 / d }, - { type: DataType.uint32, data: dComp }, + { type: DataType.uint32, data: batchSize }, + { type: DataType.uint32, data: numHeads }, + { type: DataType.uint32, data: pastSequenceLength }, + { type: DataType.uint32, data: sequenceLength }, + { type: DataType.uint32, data: totalSequenceLengthComp }, { type: DataType.uint32, data: elementsPerThread }, ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const f32Type = tensorTypeToWsglValueType(DataType.float, components); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); + const inputHelpers = [inputHelper]; + const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputHelper) { + inputHelpers.push(seqLensInputHelper); + } + + const totalSequenceLengthInputHelper = totalSequenceLengthInput + ? 
inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
+      : undefined;
+    if (totalSequenceLengthInputHelper) {
+      inputHelpers.push(totalSequenceLengthInputHelper);
+    }
     const elemValueType = tensorTypeToWsglValueType(input.dataType);
     const uniforms: UniformsArrayType = [
-      { name: 'd_inv', type: 'f32' },
-      { name: 'd_comp', type: 'u32' },
+      { name: 'batch_size', type: 'u32' },
+      { name: 'num_heads', type: 'u32' },
+      { name: 'past_sequence_length', type: 'u32' },
+      { name: 'sequence_length', type: 'u32' },
+      { name: 'total_sequence_length', type: 'u32' },
       { name: 'elements_per_thread', type: 'u32' },
     ];

     return `
      var<workgroup> thread_max: array<f32, ${WG}>;
      var<workgroup> thread_sum: array<f32, ${WG}>;
-      ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
+      ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
       ${shaderHelper.mainStart([WG, 1, 1])}
+      let batchIdx = workgroup_id.z / uniforms.num_heads;
+      let headIdx = workgroup_id.z % uniforms.num_heads;
+      let sequence_length = uniforms.sequence_length;
+      var total_sequence_length = uniforms.total_sequence_length;
+      ${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)}
       let local_offset = local_idx * uniforms.elements_per_thread;
-      let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
-
+      let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset;
+      let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'};
       var thread_max_vector = ${f32Type}(-3.402823e+38f);
-      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
        thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
       }
       thread_max[local_idx] = ${(() => {
@@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
       }

       var sum_vector = ${f32Type}(0);
-      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+      for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
        sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
       }
       thread_sum[local_idx] = ${(() => {
@@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
       }

       if (sum == 0) {
-        for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
-          x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
+        for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
+          x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
         }
       } else {
-        for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
+        for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
           var f32input = ${f32Type}(x[offset + i]);
           x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
         }
       }
+      ${
+        seqLens
+          ?
` + for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) { + x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0)); + }` + : '' + }; }`; }; @@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number name: 'AttentionProbsSoftmax', shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies }, getShaderSource, - getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }), + getRunData: () => ({ + outputs: [], + dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads }, + programUniforms, + }), }; }; @@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = ( pastKey: TensorView | undefined, attentionBias: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, pastSequenceLength: number, + seqLens: TensorView | undefined, + totalSequenceLengthInput: TensorView | undefined, ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; - const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey; + const presentKey = outputCount > 1 && pastKey; + const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads; const presentKeyShape = presentKey - ? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] + ? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize] : undefined; - + const nReps = parameters.nReps ? parameters.nReps : 1; // TODO: handle mask - const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale; const components = getMaxComponents(parameters.headSize); const vectorizedHeadSize = parameters.headSize / components; const TILE_SIZE = 12; @@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: vectorizedHeadSize }, { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: parameters.numHeads }, + { type: DataType.uint32, data: parameters.headSize }, { type: DataType.float, data: alpha }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; @@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }]; if (presentKey) { outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default }); @@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = ( if (attentionBias) { inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims)); } + const seqLensInputVariable = seqLens ? 
inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLensInputVariable) { + inputVars.push(seqLensInputVariable); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInputVariable) { + inputVars.push(totalSequenceLengthInputVariable); + } const output = outputVariable('output', q.dataType, probsShape); const outputVars = [output]; if (presentKey) { @@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'alpha', type: 'f32' as UniformDataElementType }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; @@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = ( ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} // x holds the N and y holds the M - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; + let batchIdx = workgroup_id.z / uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; - let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - ${(() => { - if (feedPastKey && presentKey) { - return ` - let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; - let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; - } else { - return ` - let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; - } - })()} - ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.N; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; + let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''}; + let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; + ${presentKey ? 
'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { @@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = ( ${(() => { if (feedPastKey && presentKey) { return ` - if (n + local_id.y < uniforms.past_sequence_length) { + if (n + local_id.y < past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; - } else { - tileK[idx] = - key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; + } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; }`; } else { - return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + return ` + if (n + local_id.y < uniforms.kv_sequence_length) { + tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + }`; } })()} ${ - presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : '' + presentKey + ? `if (n + local_id.y < present_sequence_length) { + present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx]; + }` + : '' } } workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); + value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]); } workgroupBarrier(); } - let headOffset = headIdx * uniforms.M * uniforms.N; - if (global_id.y < uniforms.M && global_id.x < uniforms.N) { + if (global_id.y < uniforms.M && global_id.x < total_sequence_length) { + let headOffset = workgroup_id.z * uniforms.M * uniforms.N; let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x; var sum: f32 = ${(() => { switch (components) { @@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = ( pastValue: TensorView | undefined, params: AttentionParameters, pastSequenceLength: number, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; - const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue; + const presentValue = outputCount > 1 && pastValue; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; const presentValueShape = presentValue - ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] + ? 
[params.batchSize, kvNumHeads, totalSequenceLength, params.headSize] : undefined; const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; const TILE_SIZE = 12; @@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: params.vHeadSize }, { type: DataType.uint32, data: params.numHeads }, + { type: DataType.uint32, data: params.headSize }, { type: DataType.uint32, data: repeatedVHiddenSize }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, + { type: DataType.uint32, data: nReps }, ]; // Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0; @@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputDependencies.push('type'); } + if (seqLens) { + inputDependencies.push('type'); + } + if (totalSequenceLengthInput) { + inputDependencies.push('type'); + } const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }]; if (presentValue) { outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default }); @@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = ( if (feedPastValue) { inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); } + const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined; + if (seqLens) { + inputVars.push(seqLensInputVariable!); + } + const totalSequenceLengthInputVariable = totalSequenceLengthInput + ? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims) + : undefined; + if (totalSequenceLengthInput) { + inputVars.push(totalSequenceLengthInputVariable!); + } const output = outputVariable('output', probs.dataType, outputShape); const outputVars = [output]; if (presentValue) { @@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, + { name: 'head_size', type: 'u32' }, { name: 'v_hidden_size', type: 'u32' }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, + { name: 'n_reps', type: 'u32' }, ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])} - let headIdx = workgroup_id.z; + let headIdx = workgroup_id.z % uniforms.num_heads; + let batchIdx = workgroup_id.z / uniforms.num_heads; + let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'}; + let kv_num_heads = ${nReps === 1 ? 
'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'}; let m = global_id.y; let n = global_id.x; - - let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; - ${(() => { - if (feedPastValue && presentValue) { - return ` - let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; - let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; - `; - } else { - return ` - let offsetB = headIdx * uniforms.N * uniforms.K + n; - `; - } - })()} - ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} + let sequence_length = uniforms.M; + var total_sequence_length = uniforms.K; + ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch + ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''}; + let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; + ${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (m < uniforms.M && w + local_id.x < uniforms.K) { @@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = ( ${(() => { if (feedPastValue && presentValue) { return ` - if (w + local_id.y < uniforms.past_sequence_length) { - tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; - } else { - tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N]; + if (w + local_id.y < past_sequence_length) { + tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; + } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N]; } `; } else { return ` - tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N]; - `; + if (w + local_id.y < uniforms.kv_sequence_length) { + tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N]; + }`; } })()} - ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''} + ${ + presentValue + ? 
` + if (w + local_id.y < present_sequence_length) { + present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx]; + }` + : '' + } } workgroupBarrier(); - for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { - value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x]; + for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) { + value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x]; } workgroupBarrier(); } // we need to transpose output from BNSH_v to BSND_v - let batchIdx = workgroup_id.z / uniforms.num_heads; - let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads; if (m < uniforms.M && n < uniforms.N) { let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size - + currentBatchHeadNumber * uniforms.N + n; + + headIdx * uniforms.N + n; output[outputIdx] = value; } }`; @@ -671,23 +804,29 @@ export const applyAttention = ( pastValue: TensorView | undefined, attentionBiasInput: TensorView | undefined, parameters: AttentionParameters, - attributes: AttentionAttrs, + seqLens: TensorView | undefined = undefined, + totalSequenceLengthInput: TensorView | undefined = undefined, ) => { - // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. + // Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists. const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0)); - const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; + const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const attentionBias = attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; const inputsK = [q, k]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { + if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { inputsK.push(pastKey); } if (attentionBias) { inputsK.push(attentionBias); } - + if (seqLens) { + inputsK.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsK.push(totalSequenceLengthInput); + } // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( @@ -697,31 +836,55 @@ export const applyAttention = ( pastKey, attentionBias, parameters, - attributes, pastSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] }, + { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] }, )[0]; // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( probs, - parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.batchSize, + parameters.numHeads, + pastSequenceLength, + parameters.sequenceLength, totalSequenceLength, + seqLens, + totalSequenceLengthInput, ), - { inputs: [probs], outputs: [] }, + { inputs: seqLens && totalSequenceLengthInput ? 
[probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] }, ); - // Run AttrionScore + // Run AttentionScore const inputsV = [probs, v]; - if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { + if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { inputsV.push(pastValue); } - context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), { - inputs: inputsV, - outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0], - }); + if (seqLens) { + inputsV.push(seqLens); + } + if (totalSequenceLengthInput) { + inputsV.push(totalSequenceLengthInput); + } + context.compute( + createVxAttentionScoreProgramInfo( + outputCount, + probs, + v, + pastValue, + parameters, + pastSequenceLength, + seqLens, + totalSequenceLengthInput, + ), + { + inputs: inputsV, + outputs: outputCount > 1 ? [0, 2] : [0], + }, + ); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { @@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs): undefined, context.inputs[5], params, - attributes, ); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 56291c037b7..bbe25460d6f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -1,31 +1,49 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { ShapeUtil } from '../../util'; import { createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types'; +import { ComputeContext } from '../types'; -import { - applyAttention, - AttentionAttrs, - AttentionMaskType, - AttentionParameters, - AttentionQkvFormat, -} from './attention'; -import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common'; +import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; -import { createTileProgramInfo } from './tile'; +import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; +export interface GroupQueryAttentionAttributes { + numHeads: number; + kvNumHeads: number; + scale: number; + softcap: number; + doRotary: number; + rotaryInterleaved: number; + smoothSoftmax: boolean; + localWindowSize: number; +} -export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { +export const validateInputs = ( + inputs: readonly TensorView[], + attributes: GroupQueryAttentionAttributes, +): AttentionParameters => { + if (attributes.doRotary && inputs.length <= 7) { + throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified'); + } const query = inputs[0]; const key = inputs[1]; const value = inputs[2]; const pastKey = inputs[3]; const pastValue = inputs[4]; - + if (attributes.localWindowSize !== -1) { + throw new Error('Local attention is not supported'); + } + if (attributes.softcap !== 0) { + throw new Error('Softcap is not 
supported');
+  }
+  if (attributes.rotaryInterleaved !== 0) {
+    throw new Error('Rotary interleaved is not supported');
+  }
+  if (attributes.smoothSoftmax) {
+    throw new Error('Smooth softmax is not supported');
+  }
   // Abbreviation and Meanings:
   //   B:    batch_size
   //   S:    sequence_length (input sequence length of query)
@@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
   const dmmhaPacking = false;
   const batchSize = query.dims[0];
   const sequenceLength = query.dims[1];
-  const hiddenSize =
+  let hiddenSize =
     query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
   let kvSequenceLength = sequenceLength;

   let pastSequenceLength = 0;
-  let maxSequenceLength = 0;
-  const headSize = Math.floor(hiddenSize / attributes.numHeads);
+  const packedQKV = !key || key.dims.length === 0;
+  const headSize = !packedQKV
+    ? Math.floor(hiddenSize / attributes.numHeads)
+    : Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads));
+  if (packedQKV) {
+    hiddenSize = headSize * attributes.numHeads;
+  }
   const hasPastKey = pastKey && pastKey.dims.length !== 0;
   const hasPastValue = pastValue && pastValue.dims.length !== 0;
-  // TODO : this should be from attributes.
-  const isPastkvBSNH = true;
+  // Currently the onnxruntime GQA specification only supports key/value BNSH format.
+  const isPastkvBSNH =
+    hasPastKey &&
+    pastKey.dims.length === 4 &&
+    pastKey.dims[0] === batchSize &&
+    pastKey.dims[1] !== attributes.kvNumHeads &&
+    pastKey.dims[2] === attributes.kvNumHeads &&
+    pastKey.dims[3] === headSize;
+
+  if (isPastkvBSNH) {
+    throw new Error('BSNH pastKey/pastValue is not supported');
+  }
   if (hasPastKey && hasPastValue) {
     if (pastKey.dims.length !== 4) {
       throw new Error('Input "past_key" is expected to have 4 dimensions');
@@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
     if (pastValue.dims.length !== 4) {
       throw new Error('Input "past_value" is expected to have 4 dimensions');
     }
-    if (isPastkvBSNH) {
-      // For BSNH
-      pastSequenceLength = pastKey.dims[1];
-      maxSequenceLength = pastKey.dims[1];
-    } else {
-      // For BNSH
-      pastSequenceLength = pastKey.dims[2];
-      maxSequenceLength = pastKey.dims[2];
-    }
+    pastSequenceLength = pastKey.dims[2];
   } else if (hasPastKey || hasPastValue) {
     throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
   }

-  let qkvFormat: AttentionQkvFormat;
-  if (key) {
+  let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH;
+  if (key && key.dims.length > 0) {
     if (query.dims.length !== 3) {
       throw new Error('Input "query" is expected to have 3 dimensions when key is given');
     }
@@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
     if (query.dims[2] % key.dims[2] !== 0) {
       throw new Error('Dimension 2 of "query" should be a multiple of "key"');
     }
-    qkvFormat = AttentionQkvFormat.qkvBSNH;
     kvSequenceLength = key.dims[1];
   } else if (key.dims.length === 5) {
     if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
@@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
     if (value) {
       throw new Error('Expect "value" be none when "key" has packed kv format.');
     }
-    qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
     kvSequenceLength = key.dims[1];
   } else {
     // key_dims.size() == 4 (cross-attention with past_key)
     if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize)
{ throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); } - - qkvFormat = AttentionQkvFormat.unknown; kvSequenceLength = key.dims[2]; } } else { @@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent const maskType: AttentionMaskType = AttentionMaskType.none; let passPastInKv = false; - let vHiddenSize = hiddenSize; - if (value) { + let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize; + if (value && value.dims.length > 0) { if (value.dims.length !== 3 && value.dims.length !== 4) { throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } @@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent passPastInKv = true; } } - const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const seqlLens = inputs.length > 4 ? inputs[5] : undefined; + if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) { + throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size'); + } + const totalSequenceLength = -1; + const maxSequenceLength = -1; const broadcastResPosBias = false; return { @@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent hiddenSize, vHiddenSize, headSize, - vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!), + vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads), numHeads: attributes.numHeads, kvNumHeads: attributes.kvNumHeads, - nReps: attributes.numHeads / attributes.kvNumHeads!, + nReps: attributes.numHeads / attributes.kvNumHeads, pastPresentShareBuffer: false, maskType, scale: attributes.scale, broadcastResPosBias, passPastInKv, qkvFormat, - isPastkvBSNH, }; }; -const createConcatProgramInfo = ( - a: TensorView, - b: TensorView | undefined, - dataType: DataType, - params: AttentionParameters, -): ProgramInfo => { - const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize]; - const component = 4; - const outputSize = ShapeUtil.size(outputShape) / component; - const presentSequenceLength = params.totalSequenceLength; - const output = outputVariable('present_kv', dataType, outputShape.length, component); - const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component); - const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined; - - const H = Math.ceil(params.headSize / component); - const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 }; - - const inputDependencies: ProgramInputTensorInfoDependency[] = b ? 
['rank', 'rank'] : ['rank']; - - const programUniforms: ProgramUniform[] = [ - { type: DataType.uint32, data: outputSize }, - { type: DataType.uint32, data: params.pastSequenceLength }, - { type: DataType.uint32, data: params.kvSequenceLength }, - { type: DataType.uint32, data: params.totalSequenceLength }, - ]; - - const inputs = [inputA]; - if (inputB) { - programUniforms.push( - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b!.dims), - ...createTensorShapeVariables(outputShape), - ); - inputs.push(inputB); - } else { - programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape)); - } - const uniforms: UniformsArrayType = [ - { name: 'output_size', type: 'u32' }, - { name: 'past_seqlen', type: 'u32' }, - { name: 'new_seqlen', type: 'u32' }, - { name: 'present_seqlen', type: 'u32' }, - ]; - - const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H; - var past_head_stride = uniforms.past_seqlen * H; - if (is_bsnh) { - past_head_stride = H; - } - let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; - present_kv[out_offset] = past_kv[in_offset];`; - const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H; - let new_row_stride = num_heads * H; - let new_head_stride = H; - let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; - present_kv[out_offset] = new_kv[in_offset];`; - const concatStr = b - ? `if (s < past_seqlen) { - ${pastStr} - } else if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }` - : `if (s < past_seqlen + uniforms.new_seqlen) { - ${newStr} - }`; - - // TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit. - const getShaderSource = (shaderHelper: ShaderHelper) => ` - - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)} - ${shaderHelper.mainStart([H, params.kvNumHeads!, 1])} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var indices = ${output.offsetToIndices('global_idx')}; - let h = local_id.x; - let n = local_id.y; - let s = workgroup_id.x; - let b = workgroup_id.y; - let num_heads = ${params.kvNumHeads!}u; - let H = ${H}u; - - let present_seqlen = uniforms.present_seqlen; - let present_batch_stride = present_seqlen * num_heads * H; - var row_stride = H; - let is_bsnh = ${params.isPastkvBSNH}; - - if (is_bsnh) { - row_stride = num_heads * H; - } - var present_head_stride = present_seqlen * H; - if (is_bsnh) { - present_head_stride = H; - } - - let past_seqlen = uniforms.past_seqlen; - - let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; - ${concatStr} - }`; - - return { - name: 'ConcatPastNew', - shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies }, - getRunData: () => ({ - outputs: [{ dims: outputShape, dataType }], - dispatchGroup: dispatch, - programUniforms, - }), - getShaderSource, - }; -}; - -export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({ ...attributes }); - const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] }); -const maybeExpandAndTransposeToBNSH = ( - context: ComputeContext, - input: TensorView, - pastKV: TensorView | undefined, - params: AttentionParameters, - outputIndex: number, -) => { +const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: 
AttentionParameters) => { let reshapedInput = input; const numHeads = params.kvNumHeads!; - const nReps = params.nReps!; if (input.dims.length === 3 && params.kvSequenceLength !== 0) { reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]); - } - - if (pastKV) { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), { - inputs: [reshapedInput, pastKV], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } else { - reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), { - inputs: [reshapedInput], - outputs: [params.isPastkvBSNH ? outputIndex : -1], - })[0]; - } - if (nReps !== 1) { - reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), { + reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], })[0]; - reshapedInput = reshapedInput.reshape([ - params.batchSize, - params.totalSequenceLength, - numHeads * nReps, - params.headSize, - ]); } - return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { - inputs: [reshapedInput], - outputs: [-1], - })[0]; + return reshapedInput; }; -export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { +export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { throw new Error('Packed QKV is not implemented'); @@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti throw new Error('Packed KV is not implemented'); } + const q = context.inputs[0]; + const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; + const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined; + const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; + const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; + const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined; + const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined; + const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; + + // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead. + + const splitAttributes: SplitAttributes = createAttributeWithCacheKey({ + axis: 2, + numOutputs: 3, + splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize], + }); + const [query, key, value] = + !k && !v + ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) + : [q, k!, v!]; + const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, - context.inputs[0], + query, undefined, 0, ); - const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; - const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? 
context.inputs[4] : undefined;
-  const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
-  const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
-  applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
+  applyAttention(
+    context,
+    Q,
+    maybeTransposeToBNSH(context, key, params),
+    maybeTransposeToBNSH(context, value, params),
+    undefined,
+    undefined,
+    pastKey,
+    pastValue,
+    undefined,
+    params,
+    seqLens,
+    totalSequenceLengthInput,
+  );
 };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
index 1a312539056..db7a4b8e68b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
   );

   if (kvBNSH) {
-    return applyAttention(
-      context,
-      Q,
-      key,
-      value,
-      keyPaddingMask,
-      undefined,
-      pastKey,
-      pastValue,
-      attentionBias,
-      params,
-      attributes,
-    );
+    return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
   }
   if (!key || !value) {
     throw new Error('key and value must be provided');
@@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
     2 * params.hiddenSize,
   );

-  applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
+  applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
 };
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 1dc3a206cf9..8c39505734e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
   }`;
 };

-const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
+export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
   const inputShape = inputs[0].dims;
   const inputSize = ShapeUtil.size(inputShape);
   const dataType = inputs[0].dataType;
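Before the new test data, a quick sanity check on why the first case's expected output equals its value input: with a single query token and a single key position, the softmax row has one logit, so the attention weight is exactly 1 and the output is the value row unchanged. A minimal sketch (plain TypeScript, not the test harness):

```ts
// Softmax over a single logit is always [1], so attention output = value row.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

console.log(softmax([123.45])); // [1]
// Hence "GroupQueryAttention 0" expects output === value: [32, 33, ..., 39].
```

diff --git a/js/web/test/data/ops/group-query-attention.jsonc b/js/web/test/data/ops/group-query-attention.jsonc
index 2a4b2650784..036069f43eb 100644
--- a/js/web/test/data/ops/group-query-attention.jsonc
+++ b/js/web/test/data/ops/group-query-attention.jsonc
@@ -1,11 +1,163 @@
 [
   {
-    "name": "GroupQueryAttention Basic",
+    "name": "GroupQueryAttention 0",
     "operator": "GroupQueryAttention",
     "opset": { "domain": "com.microsoft", "version": 1 },
     "attributes": [
-      { "name": "num_heads", "data": 4, "type": "int" },
-      { "name": "kv_num_heads", "data": 2, "type": "int" }
+      { "name": "num_heads", "data": 1, "type": "int" },
+      { "name": "kv_num_heads", "data": 1, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [0, 1, 2, 3, 4, 5, 6, 7],
+            "dims": [1, 1, 8],
+            "type": "float32"
+          },
+          // key, BS*
+          {
+            "data": [16, 17, 18, 19, 20, 21, 22, 23],
+            "dims": [1, 1, 8],
+            "type": "float32"
+          },
+          // value, BS*
+          {
+            "data": [32, 33, 34, 35, 36, 37, 38, 39],
+            "dims": [1, 1, 8],
+            "type": "float32"
+          },
+          // past key, BNSH
+          {
+            "data": [],
+            "dims": [1, 1, 0, 8],
+            "type": "float32"
+          },
+          // past value, BNSH
+          {
+            "data": [],
+            "dims": [1, 1, 0, 8],
+            "type": "float32"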
"float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 1", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7], + "dims": [1, 1, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 8], + "type": "float32" + }, + // past key, BS* + { + "data": [40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // past value, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + // seqlens_k, unimplemented + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length, unimplemented + { + "data": [2], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [48, 49, 50, 51, 52, 53, 54, 55], + "dims": [1, 1, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [40, 41, 42, 43, 44, 45, 46, 47, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 2, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 32, 33, 34, 35, 36, 37, 38, 39], + "dims": [1, 1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 2", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { @@ -13,43 +165,45 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 ], "dims": [1, 3, 16], "type": "float32" }, // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], "dims": [1, 3, 8], "type": "float32" }, // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], "dims": [1, 3, 8], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [3], 
"dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [3], "dims": [1], "type": "int32" } @@ -57,22 +211,22 @@ "outputs": [ { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, - 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21 + 72, 73, 74, 75, 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 ], "dims": [1, 3, 16], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21], - "dims": [1, 3, 2, 4], + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], "type": "float32" } ] @@ -80,86 +234,83 @@ ] }, { - "name": "GroupQueryAttention Scale", + "name": "GroupQueryAttention 3", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" }, - { "name": "scale", "data": 2.0, "type": "float" } + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } ], "cases": [ { "name": "T[0]", "inputs": [ { - "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 - ], - "dims": [1, 4, 8], + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], "type": "float32" }, + // key, BS* { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 3, 8], "type": "float32" }, + // value, BS* { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], "type": "float32" }, - // past key, BS* + // pask key, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // past value, BS* + // pask value, BNSH { - "data": null, + "data": [], + "dims": [1, 1, 0, 8], "type": "float32" }, - // seqlens_k, unimplemented + // seqlens_k { - "data": [1], + "data": [3], "dims": [1], "type": "int32" }, - // total_sequence_length, unimplemented + // total_sequence_length { - "data": [1], + "data": [3], "dims": [1], "type": "int32" } ], "outputs": [ { - "data": [ - 1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, - 1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731 - ], - "dims": [1, 4, 8], + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], "type": "float32" }, { - // present key, BS* - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present key, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 3, 8], "type": "float32" }, { - // present value, BS* - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 2, 2], + // present value, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], "type": "float32" } ] } ] }, - { - "name": "GroupQueryAttention, different sequence length", + "name": "GroupQueryAttention 4", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ @@ -172,38 +323,378 @@ "inputs": [ { "data": [ - 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 ], - "dims": [1, 4, 8], + "dims": [1, 3, 32], + "type": "float32" + }, + // key, BS* + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // value, BS* + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, + 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 16], + "type": "float32" + }, + // past key, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // past value, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57, + 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, + 76, 77, 78, 79, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 3, 32], "type": "float32" }, { - "data": [1, 9, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + // present key, BNSH + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, 32, 33, 34, 35, 36, 37, 38, 39, 8, 9, 10, 11, 12, + 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47 + ], + "dims": [1, 2, 3, 8], "type": "float32" }, { - "data": [1, 1, 1, 1, 2, 2, 2, 2], - "dims": [1, 2, 4], + // present value, BNSH + "data": [ + 48, 49, 50, 51, 52, 53, 54, 55, 64, 65, 66, 67, 68, 69, 70, 71, 80, 81, 82, 83, 84, 85, 86, 87, 56, 57, + 58, 59, 60, 61, 62, 63, 72, 73, 74, 75, 76, 77, 78, 79, 88, 89, 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 2, 3, 8], "type": "float32" - }, - // past key, BS* + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 5", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": 
"T[0]", + "inputs": [ { - "data": null, + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], "type": "float32" }, - // past value, BS* + // key, BS* { - "data": null, + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 8], "type": "float32" }, - // seqlens_k, unimplemented + // value, BS* + { + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k { "data": [1], "dims": [1], "type": "int32" }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 6", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [3], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [3], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 3, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 1, 3, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], + "dims": [1, 1, 3, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 7", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 3, 8], + "type": "float32" + }, + // key, BS* + { + "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], + "dims": [1, 3, 8], + "type": "float32" + }, + // value, BS* + { + "data": [72, 73, 74, 75, 76, 77, 78, 79, 80, 
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [96, 97, 98, 99, 100, 101, 102, 103],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [104, 105, 106, 107, 108, 109, 110, 111],
+ "dims": [1, 1, 1, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
+ "dims": [1],
+ "type": "int32"
+ }, // total_sequence_length
+ {
+ "data": [4],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108, 109, 110, 111, 104, 105, 106, 107, 108,
+ 109, 110, 111
+ ],
+ "dims": [1, 3, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [
+ 96, 97, 98, 99, 100, 101, 102, 103, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
+ 65, 66, 67, 68, 69, 70, 71
+ ],
+ "dims": [1, 1, 4, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [
+ 104, 105, 106, 107, 108, 109, 110, 111, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ 88, 89, 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 1, 4, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 8",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 4, "type": "int" },
+ { "name": "kv_num_heads", "data": 2, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
+ ],
+ "dims": [1, 1, 32],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
{
"data": [1],
"dims": [1],
@@ -213,22 +704,22 @@
"outputs": [
{
"data": [
- 1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823,
- 1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
- 1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1,
- 1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815
+ 48, 49, 50, 51, 52, 53, 54, 55, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 56, 57,
+ 58, 59, 60, 61, 62, 63
],
- "dims": [1, 4, 8],
+ "dims": [1, 1, 32],
"type": "float32"
},
{
- "data": [1, 9, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present key, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 1, 8],
"type": "float32"
},
{
- "data": [1, 1, 1, 1, 2, 2, 2, 2],
- "dims": [1, 2, 2, 2],
+ // present value, BNSH
+ "data": [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
+ "dims": [1, 2, 1, 8],
"type": "float32"
}
]
@@ -236,12 +727,505 @@
]
},
{
- "name": "GroupQueryAttention Basic, q k v same head number",
+ "name": "GroupQueryAttention 9",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", 
"version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 2, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 2, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 10", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 1, 16], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 16], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 16], + "type": "float32" + }, + { + // present value, BNSH + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + "dims": [1, 1, 1, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention 11", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 1, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [1, 2, 8], + "type": "float32" + }, + // key, BS* + { + "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 2, 8], + "type": "float32" + }, + // value, BS* + { + "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 
42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [2],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [2],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 2, 8],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
+ "dims": [1, 1, 2, 8],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 12",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
+ ],
+ "dims": [1, 1, 32],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
+ ],
+ "dims": [1, 1, 32],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ "data": [
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 1, 32],
+ "type": "float32"
+ },
+ // past key, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 32],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 1, 0, 32],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ },
+ // total_sequence_length
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 1, 32],
+ "type": "float32"
+ },
+ {
+ // present key, BNSH
+ "data": [
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
+ ],
+ "dims": [1, 1, 1, 32],
+ "type": "float32"
+ },
+ {
+ // present value, BNSH
+ "data": [
+ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
+ 90, 91, 92, 93, 94, 95
+ ],
+ "dims": [1, 1, 1, 32],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "GroupQueryAttention 13",
+ "operator": "GroupQueryAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [
+ { "name": "num_heads", "data": 1, "type": "int" },
+ { "name": "kv_num_heads", "data": 1, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31
+ ],
+ "dims": [1, 4, 8],
+ "type": "float32"
+ },
+ // key, BS*
+ {
+ "data": [
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
+ 58, 59, 60, 61, 62, 63
+ ],
+ "dims": [1, 4, 8],
+ "type": "float32"
+ },
+ // value, BS*
+ {
+ 
"data": [ + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 4, 8], + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [4], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [4], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 4, 8], + "type": "float32" + }, + { + // present key, BNSH + "data": [ + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, + 58, 59, 60, 61, 62, 63 + ], + "dims": [1, 1, 4, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [ + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95 + ], + "dims": [1, 1, 4, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention PackedQKV 14", + "operator": "GroupQueryAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "num_heads", "data": 2, "type": "int" }, + { "name": "kv_num_heads", "data": 1, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ], + "dims": [1, 1, 32], + "type": "float32" + }, + // key, BS* + { + "data": null, + "type": "float32" + }, + // value, BS* + { + "data": null, + "type": "float32" + }, + // pask key, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // pask value, BNSH + { + "data": [], + "dims": [1, 1, 0, 8], + "type": "float32" + }, + // seqlens_k + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // total_sequence_length + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [24, 25, 26, 27, 28, 29, 30, 31, 24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 16], + "type": "float32" + }, + { + // present key, BNSH + "data": [16, 17, 18, 19, 20, 21, 22, 23], + "dims": [1, 1, 1, 8], + "type": "float32" + }, + { + // present value, BNSH + "data": [24, 25, 26, 27, 28, 29, 30, 31], + "dims": [1, 1, 1, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GroupQueryAttention PackedQKV 15", "operator": "GroupQueryAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [ { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 4, "type": "int" } + { "name": "kv_num_heads", "data": 2, "type": "int" } ], "cases": [ { @@ -250,46 +1234,47 @@ { "data": [ 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, - 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4 + 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, + 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, + 22, 21, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, + 1, 3, 4, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, + 21, 2, 2, 131, 22, 21, 2, 
2, 131, 22, 21, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, +
2, 131, 22, 21
],
- "dims": [1, 3, 16],
+ "dims": [1, 3, 64],
"type": "float32"
},
- {
- "data": [
- 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
- 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
- ],
- "dims": [1, 3, 16],
- "type": "float32"
- },
- {
- "data": [
- 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
- 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
- ],
- "dims": [1, 3, 16],
- "type": "float32"
- },
- // past key, BS*
+ // key
{
"data": null,
"type": "float32"
},
- // past value, BS*
+ // value
{
"data": null,
"type": "float32"
},
- // seqlens_k, unimplemented
+ // past key, BNSH
{
- "data": [1],
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // past value, BNSH
+ {
+ "data": [],
+ "dims": [1, 2, 0, 8],
+ "type": "float32"
+ },
+ // seqlens_k
+ {
+ "data": [3],
"dims": [1],
"type": "int32"
},
- // total_sequence_length, unimplemented
+ // total_sequence_length
{
- "data": [1],
+ "data": [3],
"dims": [1],
"type": "int32"
}
@@ -297,316 +1282,29 @@
"outputs": [
{
"data": [
- 1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2,
- 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21
+ 1, 9, 1, 1, 2, 2, 2, 2, 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 1, 12, 21, 131, 22, 21, 2,
+ 2, 8, 12, 233, 4, 5, 6, 7, 8, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4,
+ 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 5, 6, 7, 8, 1, 1, 3, 4, 5, 6, 7, 8, 1, 1, 3, 4
],
- "dims": [1, 3, 16],
+ "dims": [1, 3, 32],
"type": "float32"
},
{
+ // present key, BNSH
"data": [
- 1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
- 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 2, 3, 4, 5, 6, 7, 131, 22, 21, 2, 2, 131, 22, 21, 5, 6, 7, 8, 1, 1, 3, 4,
+ 8, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 2, 2, 2
],
- "dims": [1, 3, 4, 4],
+ "dims": [1, 2, 3, 8],
"type": "float32"
},
{
+ // present value, BNSH
"data": [
- 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
- 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
+ 1, 9, 1, 1, 2, 2, 2, 2, 8, 12, 233, 4, 5, 6, 7, 8, 1, 1, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2,
+ 5, 6, 7, 8, 1, 1, 3, 4, 131, 22, 21, 2, 2, 131, 22, 21
],
- "dims": [1, 3, 4, 4],
- "type": "float32"
- }
- ]
- }
- ]
- },
- {
- "name": "GroupQueryAttention, no past kv, used as reference",
- "operator": "GroupQueryAttention",
- "opset": { "domain": "com.microsoft", "version": 1 },
- "attributes": [
- { "name": "num_heads", "data": 4, "type": "int" },
- { "name": "kv_num_heads", "data": 2, "type": "int" }
- ],
- "cases": [
- {
- "name": "T[0]",
- "inputs": [
- {
- "data": [
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
- 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
- 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
- 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
- 107, 108, 109, 110, 111, 112
- ],
- "dims": [1, 7, 16],
- "type": 
"float32" - }, - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 8], - "type": "float32" - }, - { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 8], - "type": "float32" - }, - // past key, BS* - { - "data": null, - "type": "float32" - }, - // past value, BS* - { - "data": null, - "type": "float32" - }, - // seqlens_k, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - }, - // total_sequence_length, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - } - ], - "outputs": [ - { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], - "type": "float32" - }, - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], - "type": "float32" - }, - { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], - "type": "float32" - } - ] - } - ] - }, - { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1", - "operator": "GroupQueryAttention", - "opset": { "domain": "com.microsoft", "version": 1 }, - "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } - ], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 - ], - "dims": [1, 7, 16], - "type": "float32" - }, - // new key, BS* - { - "data": [ - 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 6, 8], - "type": "float32" - }, - // new value, BS* - { - "data": [ - 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 6, 8], - "type": "float32" - }, - // past key, BS* - { - "data": [1, 2, 3, 4, 5, 6, 7, 8], - "dims": [1, 1, 2, 4], - 
"type": "float32" - }, - // past value, BS* - { - "data": [0, 1, 2, 3, 4, 5, 6, 7], - "dims": [1, 1, 2, 4], - "type": "float32" - }, - // seqlens_k, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - }, - // total_sequence_length, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - } - ], - "outputs": [ - { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], - "type": "float32" - }, - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], - "type": "float32" - }, - { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], - "type": "float32" - } - ] - } - ] - }, - { - "name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2", - "operator": "GroupQueryAttention", - "opset": { "domain": "com.microsoft", "version": 1 }, - "attributes": [ - { "name": "num_heads", "data": 4, "type": "int" }, - { "name": "kv_num_heads", "data": 2, "type": "int" } - ], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, - 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, - 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, - 107, 108, 109, 110, 111, 112 - ], - "dims": [1, 7, 16], - "type": "float32" - }, - // new key, BS* - { - "data": [ - 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, - 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 5, 8], - "type": "float32" - }, - // new value, BS* - { - "data": [ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 5, 8], - "type": "float32" - }, - // past key, BS* - { - "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], - "dims": [1, 2, 2, 4], - "type": "float32" - }, - // past value, BS* - { - "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - "dims": [1, 2, 2, 4], - "type": "float32" - }, - // seqlens_k, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - }, - // total_sequence_length, unimplemented - { - "data": [1], - "dims": [1], - "type": "int32" - } - ], - "outputs": [ - { - "data": [ - 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, - 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, - 48, 49, 50, 
51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, - 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, - 52, 53, 54, 55, 52, 53, 54, 55 - ], - "dims": [1, 7, 16], - "type": "float32" - }, - { - "data": [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, - 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56 - ], - "dims": [1, 7, 2, 4], - "type": "float32" - }, - { - "data": [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55 - ], - "dims": [1, 7, 2, 4], + "dims": [1, 2, 3, 8], "type": "float32" } ] diff --git a/onnxruntime/contrib_ops/js/bert/group_query_attention.h b/onnxruntime/contrib_ops/js/bert/group_query_attention.h index 7553883a247..dff8663133c 100644 --- a/onnxruntime/contrib_ops/js/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/js/bert/group_query_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once - +#include "contrib_ops/cpu/bert/gqa_attention_base.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { @@ -11,31 +11,29 @@ namespace js { using onnxruntime::js::JsKernel; -class GroupQueryAttention : public JsKernel { +class GroupQueryAttention : public JsKernel, GQAAttentionBase { public: explicit GroupQueryAttention(const OpKernelInfo& info) - : JsKernel(info) { - int64_t num_heads = 0; - int64_t kv_num_heads = 0; - ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); - ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); - num_heads_ = static_cast(num_heads); - kv_num_heads_ = static_cast(kv_num_heads); - scale_ = info.GetAttrOrDefault("scale", 0.0f); + : JsKernel(info), GQAAttentionBase(info, false) { JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({ "numHeads" : $1, "kvNumHeads" : $2, "scale" : $3, + "softcap" : $4, + "doRotary" : $5, + "rotaryInterleaved" : $6, + "smoothSoftmax" : $7, + "localWindowSize" : $8 }), static_cast(num_heads_), static_cast(kv_num_heads_), - static_cast(scale_)); + static_cast(scale_), + static_cast(softcap_), + static_cast(do_rotary_), + static_cast(rotary_interleaved_), + static_cast(use_smooth_softmax_), + static_cast(local_window_size_)); } - - protected: - int num_heads_; // number of attention heads - int kv_num_heads_; // number of k and v heads - float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size) }; } // namespace js
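
A note on reading the new expected values in the test data above: in the single-token cases (for example "GroupQueryAttention 5"), total_sequence_length is 1, so the causal softmax runs over a single logit and collapses to 1, and each query head simply returns the lone V row of the KV head it shares under grouped-query attention. Below is a minimal TypeScript sketch of that degenerate case; the helper gqaSingleTokenReference is illustrative only (not part of the ONNX Runtime sources) and reproduces the expected output of "GroupQueryAttention 5".

// With a single total-sequence position, softmax(QK^T) is identically 1, so
// the output of each query head is exactly the single V row of its shared
// KV head; Q and K do not affect the result in this degenerate case.
function gqaSingleTokenReference(
  v: Float32Array, // new value, [kvNumHeads * headSize], one token
  numHeads: number,
  kvNumHeads: number,
  headSize: number,
): Float32Array {
  const out = new Float32Array(numHeads * headSize);
  const groupSize = numHeads / kvNumHeads; // query heads per shared KV head
  for (let h = 0; h < numHeads; h++) {
    const kvHead = Math.floor(h / groupSize);
    for (let i = 0; i < headSize; i++) {
      out[h * headSize + i] = v[kvHead * headSize + i];
    }
  }
  return out;
}

// "GroupQueryAttention 5": num_heads = 2, kv_num_heads = 1, head_size = 8,
// value = [24..31]; expected output is [24..31, 24..31].
const v = Float32Array.from({ length: 8 }, (_, i) => 24 + i);
console.log(gqaSingleTokenReference(v, 2, 1, 8));

The multi-token cases follow the same pattern for a different reason: with the monotonically increasing test data, each row's causal softmax saturates in float32 toward the visible position with the largest key values, so the expected outputs again coincide with single rows of V (or of past value, as in "GroupQueryAttention 7", where the past key dominates every new key).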