Remove explicit split operator in GQA when QKV packed.

This commit is contained in:
Satya Jandhyala 2024-10-28 12:35:20 -07:00
Родитель 7ad78733e6
Коммит 25199fc1e4
3 изменённых файла: 38 добавлений и 46 удалений

Просмотреть файл

@ -71,6 +71,7 @@ export interface AttentionParameters {
rotaryInterLeaved?: number;
sommoothSoftmax?: number;
localWindowsSize?: number;
packedQKV?: boolean;
}
export interface AttentionAttrs {
@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = (
const createAttentionProbsProgramInfo = (
outputCount: number,
q: TensorView,
key: TensorView,
key: TensorView | undefined,
pastKey: TensorView | undefined,
attentionBias: TensorView | undefined,
parameters: AttentionParameters,
pastSequenceLength: number,
seqLens: TensorView | undefined,
totalSequenceLengthInput: TensorView | undefined,
packedQKV: boolean,
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
@ -474,7 +476,6 @@ const createAttentionProbsProgramInfo = (
{ type: DataType.uint32, data: vectorizedHeadSize },
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: parameters.numHeads },
{ type: DataType.uint32, data: parameters.headSize },
{ type: DataType.float, data: alpha },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
@ -482,7 +483,10 @@ const createAttentionProbsProgramInfo = (
];
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
if (key) {
inputDependencies.push('type');
}
if (feedPastKey) {
inputDependencies.push('type');
}
@ -501,8 +505,11 @@ const createAttentionProbsProgramInfo = (
}
const getShaderSource = (shaderHelper: ShaderHelper) => {
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
const inputVars = [qInput];
if (key) {
const kInput = inputVariable('key', key.dataType, key.dims, components);
inputVars.push(kInput);
}
if (feedPastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
@ -532,7 +539,6 @@ const createAttentionProbsProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'alpha', type: 'f32' as UniformDataElementType },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
@ -555,10 +561,11 @@ const createAttentionProbsProgramInfo = (
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.N;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
let qOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : 'workgroup_id.z * uniforms.M * uniforms.K'} + m * uniforms.K;
${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
let kOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : 'absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K'};
${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@ -573,12 +580,12 @@ const createAttentionProbsProgramInfo = (
if (n + local_id.y < past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
return `
if (n + local_id.y < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
}`;
}
})()}
@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = (
pastSequenceLength: number,
seqLens: TensorView | undefined = undefined,
totalSequenceLengthInput: TensorView | undefined = undefined,
packedQKV: boolean,
) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
@ -662,7 +670,6 @@ const createVxAttentionScoreProgramInfo = (
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: params.vHeadSize },
{ type: DataType.uint32, data: params.numHeads },
{ type: DataType.uint32, data: params.headSize },
{ type: DataType.uint32, data: repeatedVHiddenSize },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
@ -711,7 +718,6 @@ const createVxAttentionScoreProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'v_hidden_size', type: 'u32' },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
@ -732,10 +738,11 @@ const createVxAttentionScoreProgramInfo = (
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.K;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
let vOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : 'absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length'} + n;
${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = (
export const applyAttention = (
context: ComputeContext,
q: TensorView,
k: TensorView,
v: TensorView,
k: TensorView | undefined,
v: TensorView | undefined,
_maskIndex: TensorView | undefined,
_past: TensorView | undefined,
pastKey: TensorView | undefined,
@ -814,7 +821,10 @@ export const applyAttention = (
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
const inputsK = [q, k];
const inputsK = [q];
if (k) {
inputsK.push(k);
}
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
@ -839,6 +849,7 @@ export const applyAttention = (
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
parameters.packedQKV === true,
),
{ inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
)[0];
@ -859,7 +870,7 @@ export const applyAttention = (
);
// Run AttentionScore
const inputsV = [probs, v];
const inputsV = [probs, parameters.packedQKV ? q : v!];
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
@ -873,12 +884,13 @@ export const applyAttention = (
createVxAttentionScoreProgramInfo(
outputCount,
probs,
v,
parameters.packedQKV ? q : v!,
pastValue,
parameters,
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
parameters.packedQKV === true,
),
{
inputs: inputsV,

Просмотреть файл

@ -7,7 +7,6 @@ import { ComputeContext } from '../types';
import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
export interface GroupQueryAttentionAttributes {
numHeads: number;
@ -216,6 +215,7 @@ export const validateInputs = (
broadcastResPosBias,
passPastInKv,
qkvFormat,
packedQKV,
};
};
@ -237,39 +237,19 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
}
if (context.inputs[1]?.dims.length === 5) {
throw new Error('Packed KV is not implemented');
}
const q = context.inputs[0];
const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
const query = context.inputs[0];
const key = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
const value = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
// TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
axis: 2,
numOutputs: 3,
splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
});
const [query, key, value] =
!k && !v
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
: [q, k!, v!];
const Q = maybeTransposeToBNSHAndAddBias(
context,
params.batchSize,
params.numHeads,
params.packedQKV ? params.numHeads + 2 * attributes.kvNumHeads : params.numHeads,
params.sequenceLength,
params.headSize,
query,
@ -279,8 +259,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
applyAttention(
context,
Q,
maybeTransposeToBNSH(context, key, params),
maybeTransposeToBNSH(context, value, params),
key ? maybeTransposeToBNSH(context, key, params) : undefined,
value ? maybeTransposeToBNSH(context, value, params) : undefined,
undefined,
undefined,
pastKey,

Просмотреть файл

@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
}`;
};
export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const inputSize = ShapeUtil.size(inputShape);
const dataType = inputs[0].dataType;