Remove the explicit Split operator in GQA when QKV is packed.
This commit is contained in:
Родитель
7ad78733e6
Коммит
25199fc1e4
|
@ -71,6 +71,7 @@ export interface AttentionParameters {
|
|||
rotaryInterLeaved?: number;
|
||||
sommoothSoftmax?: number;
|
||||
localWindowsSize?: number;
|
||||
packedQKV?: boolean;
|
||||
}
|
||||
|
||||
export interface AttentionAttrs {
|
||||
|
@ -442,13 +443,14 @@ const createInPlaceSoftmaxProgramInfo = (
|
|||
const createAttentionProbsProgramInfo = (
|
||||
outputCount: number,
|
||||
q: TensorView,
|
||||
key: TensorView,
|
||||
key: TensorView | undefined,
|
||||
pastKey: TensorView | undefined,
|
||||
attentionBias: TensorView | undefined,
|
||||
parameters: AttentionParameters,
|
||||
pastSequenceLength: number,
|
||||
seqLens: TensorView | undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined,
|
||||
packedQKV: boolean,
|
||||
) => {
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
|
||||
|
@ -474,7 +476,6 @@ const createAttentionProbsProgramInfo = (
|
|||
{ type: DataType.uint32, data: vectorizedHeadSize },
|
||||
{ type: DataType.uint32, data: totalSequenceLength },
|
||||
{ type: DataType.uint32, data: parameters.numHeads },
|
||||
{ type: DataType.uint32, data: parameters.headSize },
|
||||
{ type: DataType.float, data: alpha },
|
||||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: parameters.kvSequenceLength },
|
||||
|
@ -482,7 +483,10 @@ const createAttentionProbsProgramInfo = (
|
|||
];
|
||||
// Feed pastKey to the shader code only if it is non-empty and presentKey is being produced
|
||||
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
|
||||
if (key) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (feedPastKey) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
|
@ -501,8 +505,11 @@ const createAttentionProbsProgramInfo = (
|
|||
}
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const inputVars = [qInput, kInput];
|
||||
const inputVars = [qInput];
|
||||
if (key) {
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
inputVars.push(kInput);
|
||||
}
|
||||
if (feedPastKey) {
|
||||
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
|
||||
inputVars.push(pastKeyInput);
|
||||
|
@ -532,7 +539,6 @@ const createAttentionProbsProgramInfo = (
|
|||
{ name: 'K', type: 'u32' },
|
||||
{ name: 'N', type: 'u32' },
|
||||
{ name: 'num_heads', type: 'u32' },
|
||||
{ name: 'head_size', type: 'u32' },
|
||||
{ name: 'alpha', type: 'f32' as UniformDataElementType },
|
||||
{ name: 'past_sequence_length', type: 'u32' },
|
||||
{ name: 'kv_sequence_length', type: 'u32' },
|
||||
|
@ -555,10 +561,11 @@ const createAttentionProbsProgramInfo = (
|
|||
let sequence_length = uniforms.M;
|
||||
var total_sequence_length = uniforms.N;
|
||||
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
|
||||
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
|
||||
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
|
||||
let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
|
||||
let qOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + headIdx * uniforms.M * uniforms.K' : 'workgroup_id.z * uniforms.M * uniforms.K'} + m * uniforms.K;
|
||||
${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
|
||||
let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
|
||||
let kOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kvHeadIdx) * uniforms.kv_sequence_length * uniforms.K' : 'absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K'};
|
||||
${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
|
||||
var value = ${f32Type}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
|
@ -573,12 +580,12 @@ const createAttentionProbsProgramInfo = (
|
|||
if (n + local_id.y < past_sequence_length) {
|
||||
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
|
||||
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
|
||||
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
|
||||
}`;
|
||||
} else {
|
||||
return `
|
||||
if (n + local_id.y < uniforms.kv_sequence_length) {
|
||||
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
tileK[idx] = ${packedQKV ? 'q' : 'key'}[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
}`;
|
||||
}
|
||||
})()}
|
||||
|
@ -640,6 +647,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
pastSequenceLength: number,
|
||||
seqLens: TensorView | undefined = undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined = undefined,
|
||||
packedQKV: boolean,
|
||||
) => {
|
||||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const nReps = params.nReps ? params.nReps : 1;
|
||||
|
@ -662,7 +670,6 @@ const createVxAttentionScoreProgramInfo = (
|
|||
{ type: DataType.uint32, data: totalSequenceLength },
|
||||
{ type: DataType.uint32, data: params.vHeadSize },
|
||||
{ type: DataType.uint32, data: params.numHeads },
|
||||
{ type: DataType.uint32, data: params.headSize },
|
||||
{ type: DataType.uint32, data: repeatedVHiddenSize },
|
||||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: params.kvSequenceLength },
|
||||
|
@ -711,7 +718,6 @@ const createVxAttentionScoreProgramInfo = (
|
|||
{ name: 'K', type: 'u32' },
|
||||
{ name: 'N', type: 'u32' },
|
||||
{ name: 'num_heads', type: 'u32' },
|
||||
{ name: 'head_size', type: 'u32' },
|
||||
{ name: 'v_hidden_size', type: 'u32' },
|
||||
{ name: 'past_sequence_length', type: 'u32' },
|
||||
{ name: 'kv_sequence_length', type: 'u32' },
|
||||
|
@ -732,10 +738,11 @@ const createVxAttentionScoreProgramInfo = (
|
|||
let sequence_length = uniforms.M;
|
||||
var total_sequence_length = uniforms.K;
|
||||
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
|
||||
let packed_batch_stride = (uniforms.num_heads + 2 * kv_num_heads) * uniforms.M * uniforms.K;
|
||||
let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
|
||||
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
|
||||
${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
|
||||
let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
|
||||
let vOffset = ${packedQKV ? 'batchIdx * packed_batch_stride + (uniforms.num_heads + kv_num_heads + kvHeadIdx) * uniforms.N * uniforms.kv_sequence_length' : 'absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length'} + n;
|
||||
${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
|
||||
var value = ${probsHelper.type.storage}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
|
@ -796,8 +803,8 @@ const createVxAttentionScoreProgramInfo = (
|
|||
export const applyAttention = (
|
||||
context: ComputeContext,
|
||||
q: TensorView,
|
||||
k: TensorView,
|
||||
v: TensorView,
|
||||
k: TensorView | undefined,
|
||||
v: TensorView | undefined,
|
||||
_maskIndex: TensorView | undefined,
|
||||
_past: TensorView | undefined,
|
||||
pastKey: TensorView | undefined,
|
||||
|
@ -814,7 +821,10 @@ export const applyAttention = (
|
|||
const attentionBias =
|
||||
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
|
||||
|
||||
const inputsK = [q, k];
|
||||
const inputsK = [q];
|
||||
if (k) {
|
||||
inputsK.push(k);
|
||||
}
|
||||
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
|
||||
inputsK.push(pastKey);
|
||||
}
|
||||
|
@ -839,6 +849,7 @@ export const applyAttention = (
|
|||
pastSequenceLength,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
parameters.packedQKV === true,
|
||||
),
|
||||
{ inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
|
||||
)[0];
|
||||
|
@ -859,7 +870,7 @@ export const applyAttention = (
|
|||
);
|
||||
|
||||
// Run AttentionScore
|
||||
const inputsV = [probs, v];
|
||||
const inputsV = [probs, parameters.packedQKV ? q : v!];
|
||||
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
|
||||
inputsV.push(pastValue);
|
||||
}
|
||||
|
@ -873,12 +884,13 @@ export const applyAttention = (
|
|||
createVxAttentionScoreProgramInfo(
|
||||
outputCount,
|
||||
probs,
|
||||
v,
|
||||
parameters.packedQKV ? q : v!,
|
||||
pastValue,
|
||||
parameters,
|
||||
pastSequenceLength,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
parameters.packedQKV === true,
|
||||
),
|
||||
{
|
||||
inputs: inputsV,
|
||||
|
|
|
@ -7,7 +7,6 @@ import { ComputeContext } from '../types';
|
|||
|
||||
import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
|
||||
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
|
||||
import { createSplitProgramInfo, SplitAttributes } from './split';
|
||||
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
|
||||
export interface GroupQueryAttentionAttributes {
|
||||
numHeads: number;
|
||||
|
@ -216,6 +215,7 @@ export const validateInputs = (
|
|||
broadcastResPosBias,
|
||||
passPastInKv,
|
||||
qkvFormat,
|
||||
packedQKV,
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -237,39 +237,19 @@ const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params
|
|||
|
||||
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
|
||||
const params = validateInputs(context.inputs, attributes);
|
||||
if (context.inputs[0].dims.length === 5) {
|
||||
throw new Error('Packed QKV is not implemented');
|
||||
}
|
||||
|
||||
if (context.inputs[1]?.dims.length === 5) {
|
||||
throw new Error('Packed KV is not implemented');
|
||||
}
|
||||
|
||||
const q = context.inputs[0];
|
||||
const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
|
||||
const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
|
||||
const query = context.inputs[0];
|
||||
const key = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
|
||||
const value = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
|
||||
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
|
||||
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
|
||||
const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
|
||||
const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
|
||||
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
|
||||
|
||||
// TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
|
||||
|
||||
const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
|
||||
axis: 2,
|
||||
numOutputs: 3,
|
||||
splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
|
||||
});
|
||||
const [query, key, value] =
|
||||
!k && !v
|
||||
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
|
||||
: [q, k!, v!];
|
||||
|
||||
const Q = maybeTransposeToBNSHAndAddBias(
|
||||
context,
|
||||
params.batchSize,
|
||||
params.numHeads,
|
||||
params.packedQKV ? params.numHeads + 2 * attributes.kvNumHeads : params.numHeads,
|
||||
params.sequenceLength,
|
||||
params.headSize,
|
||||
query,
|
||||
|
@ -279,8 +259,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
|
|||
applyAttention(
|
||||
context,
|
||||
Q,
|
||||
maybeTransposeToBNSH(context, key, params),
|
||||
maybeTransposeToBNSH(context, value, params),
|
||||
key ? maybeTransposeToBNSH(context, key, params) : undefined,
|
||||
value ? maybeTransposeToBNSH(context, value, params) : undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
pastKey,
|
||||
|
|
|
@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
|
|||
}`;
|
||||
};
|
||||
|
||||
export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
|
||||
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
const inputSize = ShapeUtil.size(inputShape);
|
||||
const dataType = inputs[0].dataType;
|
||||
|
|
Загрузка…
Ссылка в новой задаче