diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 6a78c8ae3b..e55e99b3a3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -71,6 +71,7 @@ export interface AttentionParameters { rotaryInterLeaved?: number; sommoothSoftmax?: number; localWindowsSize?: number; + packedQKV?: boolean; } export interface AttentionAttrs { @@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = ( const createAttentionProbsProgramInfo = ( outputCount: number, q: TensorView, - key: TensorView, + key: TensorView | undefined, pastKey: TensorView | undefined, attentionBias: TensorView | undefined, parameters: AttentionParameters, pastSequenceLength: number, seqLens: TensorView | undefined, totalSequenceLengthInput: TensorView | undefined, + packedQKV: boolean, ) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; @@ -474,7 +476,6 @@ const createAttentionProbsProgramInfo = ( { type: DataType.uint32, data: vectorizedHeadSize }, { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: parameters.numHeads }, - { type: DataType.uint32, data: parameters.headSize }, { type: DataType.float, data: alpha }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: parameters.kvSequenceLength }, @@ -482,7 +483,10 @@ const createAttentionProbsProgramInfo = ( ]; // Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0; - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + if (key) { + inputDependencies.push('type'); + } if (feedPastKey) { inputDependencies.push('type'); } @@ 
-501,8 +505,11 @@ const createAttentionProbsProgramInfo = ( } const getShaderSource = (shaderHelper: ShaderHelper) => { const qInput = inputVariable('q', q.dataType, q.dims, components); - const kInput = inputVariable('key', key.dataType, key.dims, components); - const inputVars = [qInput, kInput]; + const inputVars = [qInput]; + if (key) { + const kInput = inputVariable('key', key.dataType, key.dims, components); + inputVars.push(kInput); + } if (feedPastKey) { const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); inputVars.push(pastKeyInput); @@ -532,7 +539,6 @@ const createAttentionProbsProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, - { name: 'head_size', type: 'u32' }, { name: 'alpha', type: 'f32' as UniformDataElementType }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, @@ -555,10 +561,11 @@ const createAttentionProbsProgramInfo = ( let sequence_length = uniforms.M; var total_sequence_length = uniforms.N; ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K; let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; - let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; + let qOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : 'workgroup_id.z * uniforms.M * uniforms.K'} + m * uniforms.K; ${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''}; - let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K; + let kOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : 'absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K'}; ${presentKey ? 
'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { @@ -573,12 +580,12 @@ const createAttentionProbsProgramInfo = ( if (n + local_id.y < past_sequence_length) { tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) { - tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; + tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x]; }`; } else { return ` if (n + local_id.y < uniforms.kv_sequence_length) { - tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; }`; } })()} @@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = ( pastSequenceLength: number, seqLens: TensorView | undefined = undefined, totalSequenceLengthInput: TensorView | undefined = undefined, + packedQKV: boolean, ) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? 
params.nReps : 1; @@ -662,7 +670,6 @@ const createVxAttentionScoreProgramInfo = ( { type: DataType.uint32, data: totalSequenceLength }, { type: DataType.uint32, data: params.vHeadSize }, { type: DataType.uint32, data: params.numHeads }, - { type: DataType.uint32, data: params.headSize }, { type: DataType.uint32, data: repeatedVHiddenSize }, { type: DataType.uint32, data: pastSequenceLength }, { type: DataType.uint32, data: params.kvSequenceLength }, @@ -711,7 +718,6 @@ const createVxAttentionScoreProgramInfo = ( { name: 'K', type: 'u32' }, { name: 'N', type: 'u32' }, { name: 'num_heads', type: 'u32' }, - { name: 'head_size', type: 'u32' }, { name: 'v_hidden_size', type: 'u32' }, { name: 'past_sequence_length', type: 'u32' }, { name: 'kv_sequence_length', type: 'u32' }, @@ -732,10 +738,11 @@ const createVxAttentionScoreProgramInfo = ( let sequence_length = uniforms.M; var total_sequence_length = uniforms.K; ${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)} + let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.N * uniforms.kv_sequence_length; let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K; let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch ${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''}; - let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n; + let vOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : 'absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length'} + n; ${presentValue ? 
'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { @@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = ( export const applyAttention = ( context: ComputeContext, q: TensorView, - k: TensorView, - v: TensorView, + k: TensorView | undefined, + v: TensorView | undefined, _maskIndex: TensorView | undefined, _past: TensorView | undefined, pastKey: TensorView | undefined, @@ -814,7 +821,10 @@ export const applyAttention = ( const attentionBias = attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined; - const inputsK = [q, k]; + const inputsK = [q]; + if (k) { + inputsK.push(k); + } if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) { inputsK.push(pastKey); } @@ -839,6 +849,7 @@ export const applyAttention = ( pastSequenceLength, seqLens, totalSequenceLengthInput, + parameters.packedQKV === true, ), { inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] }, )[0]; @@ -859,7 +870,7 @@ export const applyAttention = ( ); // Run AttentionScore - const inputsV = [probs, v]; + const inputsV = [probs, parameters.packedQKV ? q : v!]; if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) { inputsV.push(pastValue); } @@ -873,12 +884,13 @@ export const applyAttention = ( createVxAttentionScoreProgramInfo( outputCount, probs, - v, + parameters.packedQKV ? 
q : v!, pastValue, parameters, pastSequenceLength, seqLens, totalSequenceLengthInput, + parameters.packedQKV === true, ), { inputs: inputsV, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index bbe25460d6..076e988218 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -7,7 +7,6 @@ import { ComputeContext } from '../types'; import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention'; import { maybeTransposeToBNSHAndAddBias } from './multihead-attention'; -import { createSplitProgramInfo, SplitAttributes } from './split'; import { createTransposeProgramInfo, TransposeAttributes } from './transpose'; export interface GroupQueryAttentionAttributes { numHeads: number; @@ -216,6 +215,7 @@ export const validateInputs = ( broadcastResPosBias, passPastInKv, qkvFormat, + packedQKV, }; }; @@ -237,39 +237,19 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { const params = validateInputs(context.inputs, attributes); - if (context.inputs[0].dims.length === 5) { - throw new Error('Packed QKV is not implemented'); - } - if (context.inputs[1]?.dims.length === 5) { - throw new Error('Packed KV is not implemented'); - } - - const q = context.inputs[0]; - const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; - const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined; + const query = context.inputs[0]; + const key = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined; + const value = context.inputs[2] && context.inputs[2].dims.length > 0 ? 
context.inputs[2] : undefined; const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined; const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined; const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined; const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined; - const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads; - - // TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead. - - const splitAttributes: SplitAttributes = createAttributeWithCacheKey({ - axis: 2, - numOutputs: 3, - splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize], - }); - const [query, key, value] = - !k && !v - ? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] }) - : [q, k!, v!]; const Q = maybeTransposeToBNSHAndAddBias( context, params.batchSize, - params.numHeads, + params.packedQKV ? params.numHeads + 2 * attributes.kvNumHeads : params.numHeads, params.sequenceLength, params.headSize, query, @@ -279,8 +259,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu applyAttention( context, Q, - maybeTransposeToBNSH(context, key, params), - maybeTransposeToBNSH(context, value, params), + key ? maybeTransposeToBNSH(context, key, params) : undefined, + value ? 
maybeTransposeToBNSH(context, value, params) : undefined, undefined, undefined, pastKey, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 8c39505734..1dc3a206cf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => { }`; }; -export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { +const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const inputSize = ShapeUtil.size(inputShape); const dataType = inputs[0].dataType;