[JS/WebGPU] GroupQueryAttention rewrite (#20946)
### Description
Implement JSEP GroupQueryAttention.

### Motivation and Context
Required to enable certain LLM models to run using WebGPU.
This commit is contained in:
Родитель
33e2f6ad8d
Коммит
fd8ee4894d
|
@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
|
|||
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
|
||||
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
|
||||
import { gemm, parseGemmAttributes } from './ops/gemm';
|
||||
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
|
||||
import { groupQueryAttention } from './ops/group-query-attention';
|
||||
import { instanceNorm } from './ops/instance-norm';
|
||||
import { layerNorm } from './ops/layer-norm';
|
||||
import { matMul } from './ops/matmul';
|
||||
|
@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
|
||||
['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
|
||||
['GroupQueryAttention', [groupQueryAttention]],
|
||||
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
|
||||
['InstanceNormalization', [instanceNorm]],
|
||||
['LayerNormalization', [layerNorm]],
|
||||
|
|
|
@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU
|
|||
|
||||
import {
|
||||
getMaxComponents,
|
||||
IndicesHelper,
|
||||
inputVariable,
|
||||
outputVariable,
|
||||
ShaderHelper,
|
||||
|
@ -65,14 +66,17 @@ export interface AttentionParameters {
|
|||
broadcastResPosBias: boolean;
|
||||
passPastInKv: boolean;
|
||||
qkvFormat: AttentionQkvFormat;
|
||||
isPastkvBSNH?: boolean;
|
||||
softcap?: number;
|
||||
doRotary?: number;
|
||||
rotaryInterLeaved?: number;
|
||||
sommoothSoftmax?: number;
|
||||
localWindowsSize?: number;
|
||||
}
|
||||
|
||||
export interface AttentionAttrs {
|
||||
numHeads: number;
|
||||
kvNumHeads?: number;
|
||||
isUnidirectional?: number;
|
||||
maskFilterValue?: number;
|
||||
isUnidirectional: number;
|
||||
maskFilterValue: number;
|
||||
scale: number;
|
||||
doRotary: number;
|
||||
qkvHiddenSizes: number[];
|
||||
|
@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
|
|||
};
|
||||
};
|
||||
|
||||
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
|
||||
const components = getMaxComponents(d);
|
||||
/**
 * Builds a WGSL snippet that declares the per-invocation sequence-length variables used by the
 * attention shaders in this file.
 *
 * The emitted snippet reads `sequence_length` and assigns to / reads `total_sequence_length`; in the
 * GroupQueryAttention branch it also reads `batchIdx`. All of these must already be declared by the
 * surrounding shader code before the snippet is spliced in.
 *
 * @param seqLensInput helper for the `seq_lens` shader input, or undefined when it is not bound.
 * @param totalSequenceLengthInput helper for the `total_sequence_length_input` shader input, or
 *     undefined when it is not bound.
 * @param initPastSequenceLength only consulted when either helper is missing: if true the snippet
 *     declares `past_sequence_length` from `uniforms.past_sequence_length`; if false nothing is
 *     declared and only a lone `;` is emitted in its place.
 * @returns WGSL source text. When both helpers are given it declares `total_sequence_length_input`,
 *     `present_sequence_length`, `is_subsequent_prompt`, `is_first_prompt` and `past_sequence_length`,
 *     and overwrites `total_sequence_length`. Otherwise it declares `present_sequence_length` (and
 *     optionally `past_sequence_length`).
 */
const initVarStub = (
|
||||
seqLensInput: IndicesHelper | undefined,
|
||||
totalSequenceLengthInput: IndicesHelper | undefined,
|
||||
initPastSequenceLength: boolean,
|
||||
) => {
|
||||
// In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input
|
||||
if (totalSequenceLengthInput && seqLensInput) {
|
||||
return `
|
||||
let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')});
|
||||
let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);
|
||||
let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;
|
||||
let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;
|
||||
total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1;
|
||||
var past_sequence_length: u32 = 0;
|
||||
if (is_first_prompt == false) {
|
||||
past_sequence_length = total_sequence_length - sequence_length;
|
||||
}
|
||||
`;
|
||||
} else {
// Without both GQA inputs: keep the caller's total_sequence_length and mirror it as the present length.
|
||||
return `
|
||||
${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''};
|
||||
let present_sequence_length = total_sequence_length;
|
||||
`;
|
||||
}
|
||||
};
|
||||
|
||||
const createInPlaceSoftmaxProgramInfo = (
|
||||
input: TensorView,
|
||||
batchSize: number,
|
||||
numHeads: number,
|
||||
pastSequenceLength: number,
|
||||
sequenceLength: number,
|
||||
totalSequenceLength: number,
|
||||
seqLens: TensorView | undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined,
|
||||
) => {
|
||||
// Set components to 1 if seqLens is specified, i.e. GroupQueryAttention.
|
||||
const components = getMaxComponents(seqLens ? 1 : totalSequenceLength);
|
||||
let WG = 64;
|
||||
const dComp = d / components;
|
||||
if (dComp < WG) {
|
||||
const totalSequenceLengthComp = totalSequenceLength / components;
|
||||
if (totalSequenceLengthComp < WG) {
|
||||
WG = 32;
|
||||
}
|
||||
const elementsPerThread = Math.ceil(d / components / WG);
|
||||
const elementsPerThread = Math.ceil(totalSequenceLength / components / WG);
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.float, data: 1 / d },
|
||||
{ type: DataType.uint32, data: dComp },
|
||||
{ type: DataType.uint32, data: batchSize },
|
||||
{ type: DataType.uint32, data: numHeads },
|
||||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: sequenceLength },
|
||||
{ type: DataType.uint32, data: totalSequenceLengthComp },
|
||||
{ type: DataType.uint32, data: elementsPerThread },
|
||||
];
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
|
||||
const f32Type = tensorTypeToWsglValueType(DataType.float, components);
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
|
||||
if (seqLens) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (totalSequenceLengthInput) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
|
||||
const inputHelpers = [inputHelper];
|
||||
const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
|
||||
if (seqLensInputHelper) {
|
||||
inputHelpers.push(seqLensInputHelper);
|
||||
}
|
||||
|
||||
const totalSequenceLengthInputHelper = totalSequenceLengthInput
|
||||
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
|
||||
: undefined;
|
||||
if (totalSequenceLengthInputHelper) {
|
||||
inputHelpers.push(totalSequenceLengthInputHelper);
|
||||
}
|
||||
const elemValueType = tensorTypeToWsglValueType(input.dataType);
|
||||
const uniforms: UniformsArrayType = [
|
||||
{ name: 'd_inv', type: 'f32' },
|
||||
{ name: 'd_comp', type: 'u32' },
|
||||
{ name: 'batch_size', type: 'u32' },
|
||||
{ name: 'num_heads', type: 'u32' },
|
||||
{ name: 'past_sequence_length', type: 'u32' },
|
||||
{ name: 'sequence_length', type: 'u32' },
|
||||
{ name: 'total_sequence_length', type: 'u32' },
|
||||
{ name: 'elements_per_thread', type: 'u32' },
|
||||
];
|
||||
|
||||
return `
|
||||
var<workgroup> thread_max: array<f32, ${WG}>;
|
||||
var<workgroup> thread_sum: array<f32, ${WG}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
|
||||
${shaderHelper.mainStart([WG, 1, 1])}
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let headIdx = workgroup_id.z % uniforms.num_heads;
|
||||
let sequence_length = uniforms.sequence_length;
|
||||
var total_sequence_length = uniforms.total_sequence_length;
|
||||
${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)}
|
||||
let local_offset = local_idx * uniforms.elements_per_thread;
|
||||
let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
|
||||
|
||||
let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset;
|
||||
let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'};
|
||||
var thread_max_vector = ${f32Type}(-3.402823e+38f);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
|
||||
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
|
||||
}
|
||||
thread_max[local_idx] = ${(() => {
|
||||
|
@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
|
|||
}
|
||||
|
||||
var sum_vector = ${f32Type}(0);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
|
||||
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
|
||||
}
|
||||
thread_sum[local_idx] = ${(() => {
|
||||
|
@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
|
|||
}
|
||||
|
||||
if (sum == 0) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
|
||||
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
|
||||
}
|
||||
} else {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
|
||||
var f32input = ${f32Type}(x[offset + i]);
|
||||
x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
|
||||
}
|
||||
}
|
||||
${
|
||||
seqLens
|
||||
? `
|
||||
for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
|
||||
x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
|
||||
}`
|
||||
: ''
|
||||
};
|
||||
}`;
|
||||
};
|
||||
|
||||
|
@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
|
|||
name: 'AttentionProbsSoftmax',
|
||||
shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
|
||||
getShaderSource,
|
||||
getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }),
|
||||
getRunData: () => ({
|
||||
outputs: [],
|
||||
dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads },
|
||||
programUniforms,
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = (
|
|||
pastKey: TensorView | undefined,
|
||||
attentionBias: TensorView | undefined,
|
||||
parameters: AttentionParameters,
|
||||
attributes: AttentionAttrs,
|
||||
pastSequenceLength: number,
|
||||
seqLens: TensorView | undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined,
|
||||
) => {
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
|
||||
const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
|
||||
const presentKey = outputCount > 1 && pastKey;
|
||||
const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads;
|
||||
const presentKeyShape = presentKey
|
||||
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
|
||||
? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize]
|
||||
: undefined;
|
||||
|
||||
const nReps = parameters.nReps ? parameters.nReps : 1;
|
||||
// TODO: handle mask
|
||||
|
||||
const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
|
||||
const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale;
|
||||
const components = getMaxComponents(parameters.headSize);
|
||||
const vectorizedHeadSize = parameters.headSize / components;
|
||||
const TILE_SIZE = 12;
|
||||
|
@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = (
|
|||
{ type: DataType.uint32, data: vectorizedHeadSize },
|
||||
{ type: DataType.uint32, data: totalSequenceLength },
|
||||
{ type: DataType.uint32, data: parameters.numHeads },
|
||||
{ type: DataType.uint32, data: parameters.headSize },
|
||||
{ type: DataType.float, data: alpha },
|
||||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: parameters.kvSequenceLength },
|
||||
{ type: DataType.uint32, data: nReps },
|
||||
];
|
||||
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
|
||||
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
|
||||
|
@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = (
|
|||
if (attentionBias) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (seqLens) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (totalSequenceLengthInput) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
|
||||
if (presentKey) {
|
||||
outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default });
|
||||
|
@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = (
|
|||
if (attentionBias) {
|
||||
inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
|
||||
}
|
||||
const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
|
||||
if (seqLensInputVariable) {
|
||||
inputVars.push(seqLensInputVariable);
|
||||
}
|
||||
const totalSequenceLengthInputVariable = totalSequenceLengthInput
|
||||
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
|
||||
: undefined;
|
||||
if (totalSequenceLengthInputVariable) {
|
||||
inputVars.push(totalSequenceLengthInputVariable);
|
||||
}
|
||||
const output = outputVariable('output', q.dataType, probsShape);
|
||||
const outputVars = [output];
|
||||
if (presentKey) {
|
||||
|
@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = (
|
|||
{ name: 'K', type: 'u32' },
|
||||
{ name: 'N', type: 'u32' },
|
||||
{ name: 'num_heads', type: 'u32' },
|
||||
{ name: 'head_size', type: 'u32' },
|
||||
{ name: 'alpha', type: 'f32' as UniformDataElementType },
|
||||
{ name: 'past_sequence_length', type: 'u32' },
|
||||
{ name: 'kv_sequence_length', type: 'u32' },
|
||||
{ name: 'n_reps', type: 'u32' },
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = (
|
|||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
|
||||
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
|
||||
// x holds the N and y holds the M
|
||||
let headIdx = workgroup_id.z;
|
||||
let headIdx = workgroup_id.z % uniforms.num_heads;
|
||||
let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
|
||||
let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let m = workgroup_id.y * TILE_SIZE;
|
||||
let n = workgroup_id.x * TILE_SIZE;
|
||||
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
|
||||
${(() => {
|
||||
if (feedPastKey && presentKey) {
|
||||
return `
|
||||
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
|
||||
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
|
||||
} else {
|
||||
return `
|
||||
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
|
||||
}
|
||||
})()}
|
||||
${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
|
||||
let sequence_length = uniforms.M;
|
||||
var total_sequence_length = uniforms.N;
|
||||
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
|
||||
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
|
||||
let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
|
||||
${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
|
||||
let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
|
||||
${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
|
||||
var value = ${f32Type}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
|
@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = (
|
|||
${(() => {
|
||||
if (feedPastKey && presentKey) {
|
||||
return `
|
||||
if (n + local_id.y < uniforms.past_sequence_length) {
|
||||
if (n + local_id.y < past_sequence_length) {
|
||||
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
} else {
|
||||
tileK[idx] =
|
||||
key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
|
||||
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
|
||||
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
|
||||
}`;
|
||||
} else {
|
||||
return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
|
||||
return `
|
||||
if (n + local_id.y < uniforms.kv_sequence_length) {
|
||||
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
}`;
|
||||
}
|
||||
})()}
|
||||
${
|
||||
presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : ''
|
||||
presentKey
|
||||
? `if (n + local_id.y < present_sequence_length) {
|
||||
present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
|
||||
}`
|
||||
: ''
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
|
||||
value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let headOffset = headIdx * uniforms.M * uniforms.N;
|
||||
if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
|
||||
if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {
|
||||
let headOffset = workgroup_id.z * uniforms.M * uniforms.N;
|
||||
let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
|
||||
var sum: f32 = ${(() => {
|
||||
switch (components) {
|
||||
|
@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = (
|
|||
pastValue: TensorView | undefined,
|
||||
params: AttentionParameters,
|
||||
pastSequenceLength: number,
|
||||
seqLens: TensorView | undefined = undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined = undefined,
|
||||
) => {
|
||||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const nReps = params.nReps ? params.nReps : 1;
|
||||
const repeatedVHiddenSize = params.vHiddenSize * nReps;
|
||||
const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
|
||||
const presentValue = outputCount > 1 && pastValue;
|
||||
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
|
||||
const presentValueShape = presentValue
|
||||
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
|
||||
? [params.batchSize, kvNumHeads, totalSequenceLength, params.headSize]
|
||||
: undefined;
|
||||
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
|
||||
const TILE_SIZE = 12;
|
||||
|
@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = (
|
|||
{ type: DataType.uint32, data: totalSequenceLength },
|
||||
{ type: DataType.uint32, data: params.vHeadSize },
|
||||
{ type: DataType.uint32, data: params.numHeads },
|
||||
{ type: DataType.uint32, data: params.headSize },
|
||||
{ type: DataType.uint32, data: repeatedVHiddenSize },
|
||||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: params.kvSequenceLength },
|
||||
{ type: DataType.uint32, data: nReps },
|
||||
];
|
||||
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
|
||||
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
|
||||
|
@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = (
|
|||
if (feedPastValue) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (seqLens) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (totalSequenceLengthInput) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
|
||||
if (presentValue) {
|
||||
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
|
||||
|
@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = (
|
|||
if (feedPastValue) {
|
||||
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
|
||||
}
|
||||
const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
|
||||
if (seqLens) {
|
||||
inputVars.push(seqLensInputVariable!);
|
||||
}
|
||||
const totalSequenceLengthInputVariable = totalSequenceLengthInput
|
||||
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
|
||||
: undefined;
|
||||
if (totalSequenceLengthInput) {
|
||||
inputVars.push(totalSequenceLengthInputVariable!);
|
||||
}
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
const outputVars = [output];
|
||||
if (presentValue) {
|
||||
|
@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = (
|
|||
{ name: 'K', type: 'u32' },
|
||||
{ name: 'N', type: 'u32' },
|
||||
{ name: 'num_heads', type: 'u32' },
|
||||
{ name: 'head_size', type: 'u32' },
|
||||
{ name: 'v_hidden_size', type: 'u32' },
|
||||
{ name: 'past_sequence_length', type: 'u32' },
|
||||
{ name: 'kv_sequence_length', type: 'u32' },
|
||||
{ name: 'n_reps', type: 'u32' },
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
|
||||
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
|
||||
let headIdx = workgroup_id.z;
|
||||
let headIdx = workgroup_id.z % uniforms.num_heads;
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
|
||||
let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
|
||||
let m = global_id.y;
|
||||
let n = global_id.x;
|
||||
|
||||
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
|
||||
${(() => {
|
||||
if (feedPastValue && presentValue) {
|
||||
return `
|
||||
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
|
||||
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
|
||||
`;
|
||||
} else {
|
||||
return `
|
||||
let offsetB = headIdx * uniforms.N * uniforms.K + n;
|
||||
`;
|
||||
}
|
||||
})()}
|
||||
${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
|
||||
let sequence_length = uniforms.M;
|
||||
var total_sequence_length = uniforms.K;
|
||||
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
|
||||
let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
|
||||
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
|
||||
${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
|
||||
let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
|
||||
${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
|
||||
var value = ${probsHelper.type.storage}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (m < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
|
@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = (
|
|||
${(() => {
|
||||
if (feedPastValue && presentValue) {
|
||||
return `
|
||||
if (w + local_id.y < uniforms.past_sequence_length) {
|
||||
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
|
||||
} else {
|
||||
tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
|
||||
if (w + local_id.y < past_sequence_length) {
|
||||
tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
|
||||
} else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
|
||||
tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];
|
||||
}
|
||||
`;
|
||||
} else {
|
||||
return `
|
||||
tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
|
||||
`;
|
||||
if (w + local_id.y < uniforms.kv_sequence_length) {
|
||||
tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];
|
||||
}`;
|
||||
}
|
||||
})()}
|
||||
${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
|
||||
${
|
||||
presentValue
|
||||
? `
|
||||
if (w + local_id.y < present_sequence_length) {
|
||||
present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];
|
||||
}`
|
||||
: ''
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
|
||||
for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// we need to transpose output from BNSH_v to BSND_v
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
|
||||
if (m < uniforms.M && n < uniforms.N) {
|
||||
let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
|
||||
+ currentBatchHeadNumber * uniforms.N + n;
|
||||
+ headIdx * uniforms.N + n;
|
||||
output[outputIdx] = value;
|
||||
}
|
||||
}`;
|
||||
|
@ -671,23 +804,29 @@ export const applyAttention = (
|
|||
pastValue: TensorView | undefined,
|
||||
attentionBiasInput: TensorView | undefined,
|
||||
parameters: AttentionParameters,
|
||||
attributes: AttentionAttrs,
|
||||
seqLens: TensorView | undefined = undefined,
|
||||
totalSequenceLengthInput: TensorView | undefined = undefined,
|
||||
) => {
|
||||
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
|
||||
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
|
||||
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
|
||||
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
|
||||
const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0;
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const attentionBias =
|
||||
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
|
||||
|
||||
const inputsK = [q, k];
|
||||
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
|
||||
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
|
||||
inputsK.push(pastKey);
|
||||
}
|
||||
if (attentionBias) {
|
||||
inputsK.push(attentionBias);
|
||||
}
|
||||
|
||||
if (seqLens) {
|
||||
inputsK.push(seqLens);
|
||||
}
|
||||
if (totalSequenceLengthInput) {
|
||||
inputsK.push(totalSequenceLengthInput);
|
||||
}
|
||||
// Run AttentionProbs
|
||||
const probs = context.compute(
|
||||
createAttentionProbsProgramInfo(
|
||||
|
@ -697,31 +836,55 @@ export const applyAttention = (
|
|||
pastKey,
|
||||
attentionBias,
|
||||
parameters,
|
||||
attributes,
|
||||
pastSequenceLength,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
),
|
||||
{ inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
|
||||
{ inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
|
||||
)[0];
|
||||
|
||||
// Run Softmax
|
||||
context.compute(
|
||||
createInPlaceSoftmaxProgramInfo(
|
||||
probs,
|
||||
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
|
||||
parameters.batchSize,
|
||||
parameters.numHeads,
|
||||
pastSequenceLength,
|
||||
parameters.sequenceLength,
|
||||
totalSequenceLength,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
),
|
||||
{ inputs: [probs], outputs: [] },
|
||||
{ inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] },
|
||||
);
|
||||
|
||||
// Run AttrionScore
|
||||
// Run AttentionScore
|
||||
const inputsV = [probs, v];
|
||||
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
|
||||
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
|
||||
inputsV.push(pastValue);
|
||||
}
|
||||
context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
|
||||
inputs: inputsV,
|
||||
outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
|
||||
});
|
||||
if (seqLens) {
|
||||
inputsV.push(seqLens);
|
||||
}
|
||||
if (totalSequenceLengthInput) {
|
||||
inputsV.push(totalSequenceLengthInput);
|
||||
}
|
||||
context.compute(
|
||||
createVxAttentionScoreProgramInfo(
|
||||
outputCount,
|
||||
probs,
|
||||
v,
|
||||
pastValue,
|
||||
parameters,
|
||||
pastSequenceLength,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
),
|
||||
{
|
||||
inputs: inputsV,
|
||||
outputs: outputCount > 1 ? [0, 2] : [0],
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
||||
|
@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs):
|
|||
undefined,
|
||||
context.inputs[5],
|
||||
params,
|
||||
attributes,
|
||||
);
|
||||
};
|
||||
|
|
|
@ -1,31 +1,49 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
|
||||
import { ComputeContext } from '../types';
|
||||
|
||||
import {
|
||||
applyAttention,
|
||||
AttentionAttrs,
|
||||
AttentionMaskType,
|
||||
AttentionParameters,
|
||||
AttentionQkvFormat,
|
||||
} from './attention';
|
||||
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
|
||||
import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
|
||||
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
|
||||
import { createTileProgramInfo } from './tile';
|
||||
import { createSplitProgramInfo, SplitAttributes } from './split';
|
||||
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
|
||||
export interface GroupQueryAttentionAttributes {
|
||||
numHeads: number;
|
||||
kvNumHeads: number;
|
||||
scale: number;
|
||||
softcap: number;
|
||||
doRotary: number;
|
||||
rotaryInterleaved: number;
|
||||
smoothSoftmax: boolean;
|
||||
localWindowSize: number;
|
||||
}
|
||||
|
||||
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
export const validateInputs = (
|
||||
inputs: readonly TensorView[],
|
||||
attributes: GroupQueryAttentionAttributes,
|
||||
): AttentionParameters => {
|
||||
if (attributes.doRotary && inputs.length <= 7) {
|
||||
throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
|
||||
}
|
||||
const query = inputs[0];
|
||||
const key = inputs[1];
|
||||
const value = inputs[2];
|
||||
const pastKey = inputs[3];
|
||||
const pastValue = inputs[4];
|
||||
|
||||
if (attributes.localWindowSize !== -1) {
|
||||
throw new Error('Local attention is not supported');
|
||||
}
|
||||
if (attributes.softcap !== 0) {
|
||||
throw new Error('Softcap is not supported');
|
||||
}
|
||||
if (attributes.rotaryInterleaved !== 0) {
|
||||
throw new Error('Rotary interleaved is not supported');
|
||||
}
|
||||
if (attributes.smoothSoftmax) {
|
||||
throw new Error('Smooth softmax is not supported');
|
||||
}
|
||||
// Abbreviation and Meanings:
|
||||
// B: batch_size
|
||||
// S: sequence_length (input sequence length of query)
|
||||
|
@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
const dmmhaPacking = false;
|
||||
const batchSize = query.dims[0];
|
||||
const sequenceLength = query.dims[1];
|
||||
const hiddenSize =
|
||||
let hiddenSize =
|
||||
query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
|
||||
let kvSequenceLength = sequenceLength;
|
||||
|
||||
let pastSequenceLength = 0;
|
||||
let maxSequenceLength = 0;
|
||||
const headSize = Math.floor(hiddenSize / attributes.numHeads);
|
||||
const packedQKV = !key || key.dims.length === 0;
|
||||
const headSize = !packedQKV
|
||||
? Math.floor(hiddenSize / attributes.numHeads)
|
||||
: Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads));
|
||||
if (packedQKV) {
|
||||
hiddenSize = headSize * attributes.numHeads;
|
||||
}
|
||||
const hasPastKey = pastKey && pastKey.dims.length !== 0;
|
||||
const hasPastValue = pastValue && pastValue.dims.length !== 0;
|
||||
// TODO : this should be from attributes.
|
||||
const isPastkvBSNH = true;
|
||||
// Currently the onnxruntime GQA specification only supports key/value BNSH format.
|
||||
const isPastkvBSNH =
|
||||
hasPastKey &&
|
||||
pastKey.dims.length === 4 &&
|
||||
pastKey.dims[0] === batchSize &&
|
||||
pastKey.dims[1] !== attributes.kvNumHeads &&
|
||||
pastKey.dims[2] === attributes.kvNumHeads &&
|
||||
pastKey.dims[3] === headSize;
|
||||
|
||||
if (isPastkvBSNH) {
|
||||
throw new Error('BSNH pastKey/pastValue is not supported');
|
||||
}
|
||||
if (hasPastKey && hasPastValue) {
|
||||
if (pastKey.dims.length !== 4) {
|
||||
throw new Error('Input "past_key" is expected to have 4 dimensions');
|
||||
|
@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
if (pastValue.dims.length !== 4) {
|
||||
throw new Error('Input "past_value" is expected to have 4 dimensions');
|
||||
}
|
||||
if (isPastkvBSNH) {
|
||||
// For BSNH
|
||||
pastSequenceLength = pastKey.dims[1];
|
||||
maxSequenceLength = pastKey.dims[1];
|
||||
} else {
|
||||
// For BNSH
|
||||
pastSequenceLength = pastKey.dims[2];
|
||||
maxSequenceLength = pastKey.dims[2];
|
||||
}
|
||||
pastSequenceLength = pastKey.dims[2];
|
||||
} else if (hasPastKey || hasPastValue) {
|
||||
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
|
||||
}
|
||||
|
||||
let qkvFormat: AttentionQkvFormat;
|
||||
if (key) {
|
||||
let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH;
|
||||
if (key && key.dims.length > 0) {
|
||||
if (query.dims.length !== 3) {
|
||||
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
|
||||
}
|
||||
|
@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
if (query.dims[2] % key.dims[2] !== 0) {
|
||||
throw new Error('Dimension 2 of "query" should be a multiple of "key"');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qkvBSNH;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else if (key.dims.length === 5) {
|
||||
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
|
||||
|
@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
if (value) {
|
||||
throw new Error('Expect "value" be none when "key" has packed kv format.');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else {
|
||||
// key_dims.size() == 4 (cross-attention with past_key)
|
||||
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
|
||||
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
|
||||
}
|
||||
|
||||
qkvFormat = AttentionQkvFormat.unknown;
|
||||
kvSequenceLength = key.dims[2];
|
||||
}
|
||||
} else {
|
||||
|
@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
|
||||
const maskType: AttentionMaskType = AttentionMaskType.none;
|
||||
let passPastInKv = false;
|
||||
let vHiddenSize = hiddenSize;
|
||||
if (value) {
|
||||
let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize;
|
||||
if (value && value.dims.length > 0) {
|
||||
if (value.dims.length !== 3 && value.dims.length !== 4) {
|
||||
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
|
||||
}
|
||||
|
@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
passPastInKv = true;
|
||||
}
|
||||
}
|
||||
const totalSequenceLength = pastSequenceLength + kvSequenceLength;
|
||||
const seqlLens = inputs.length > 4 ? inputs[5] : undefined;
|
||||
if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) {
|
||||
throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');
|
||||
}
|
||||
const totalSequenceLength = -1;
|
||||
const maxSequenceLength = -1;
|
||||
const broadcastResPosBias = false;
|
||||
|
||||
return {
|
||||
|
@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
|
|||
hiddenSize,
|
||||
vHiddenSize,
|
||||
headSize,
|
||||
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
|
||||
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads),
|
||||
numHeads: attributes.numHeads,
|
||||
kvNumHeads: attributes.kvNumHeads,
|
||||
nReps: attributes.numHeads / attributes.kvNumHeads!,
|
||||
nReps: attributes.numHeads / attributes.kvNumHeads,
|
||||
pastPresentShareBuffer: false,
|
||||
maskType,
|
||||
scale: attributes.scale,
|
||||
broadcastResPosBias,
|
||||
passPastInKv,
|
||||
qkvFormat,
|
||||
isPastkvBSNH,
|
||||
};
|
||||
};
|
||||
|
||||
const createConcatProgramInfo = (
|
||||
a: TensorView,
|
||||
b: TensorView | undefined,
|
||||
dataType: DataType,
|
||||
params: AttentionParameters,
|
||||
): ProgramInfo => {
|
||||
const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
|
||||
const component = 4;
|
||||
const outputSize = ShapeUtil.size(outputShape) / component;
|
||||
const presentSequenceLength = params.totalSequenceLength;
|
||||
const output = outputVariable('present_kv', dataType, outputShape.length, component);
|
||||
const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
|
||||
const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
|
||||
|
||||
const H = Math.ceil(params.headSize / component);
|
||||
const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 };
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
{ type: DataType.uint32, data: params.pastSequenceLength },
|
||||
{ type: DataType.uint32, data: params.kvSequenceLength },
|
||||
{ type: DataType.uint32, data: params.totalSequenceLength },
|
||||
];
|
||||
|
||||
const inputs = [inputA];
|
||||
if (inputB) {
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(a.dims),
|
||||
...createTensorShapeVariables(b!.dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
);
|
||||
inputs.push(inputB);
|
||||
} else {
|
||||
programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
const uniforms: UniformsArrayType = [
|
||||
{ name: 'output_size', type: 'u32' },
|
||||
{ name: 'past_seqlen', type: 'u32' },
|
||||
{ name: 'new_seqlen', type: 'u32' },
|
||||
{ name: 'present_seqlen', type: 'u32' },
|
||||
];
|
||||
|
||||
const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H;
|
||||
var past_head_stride = uniforms.past_seqlen * H;
|
||||
if (is_bsnh) {
|
||||
past_head_stride = H;
|
||||
}
|
||||
let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
|
||||
present_kv[out_offset] = past_kv[in_offset];`;
|
||||
const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H;
|
||||
let new_row_stride = num_heads * H;
|
||||
let new_head_stride = H;
|
||||
let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
|
||||
present_kv[out_offset] = new_kv[in_offset];`;
|
||||
const concatStr = b
|
||||
? `if (s < past_seqlen) {
|
||||
${pastStr}
|
||||
} else if (s < past_seqlen + uniforms.new_seqlen) {
|
||||
${newStr}
|
||||
}`
|
||||
: `if (s < past_seqlen + uniforms.new_seqlen) {
|
||||
${newStr}
|
||||
}`;
|
||||
|
||||
// TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit.
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
|
||||
${shaderHelper.mainStart([H, params.kvNumHeads!, 1])}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
var indices = ${output.offsetToIndices('global_idx')};
|
||||
let h = local_id.x;
|
||||
let n = local_id.y;
|
||||
let s = workgroup_id.x;
|
||||
let b = workgroup_id.y;
|
||||
let num_heads = ${params.kvNumHeads!}u;
|
||||
let H = ${H}u;
|
||||
|
||||
let present_seqlen = uniforms.present_seqlen;
|
||||
let present_batch_stride = present_seqlen * num_heads * H;
|
||||
var row_stride = H;
|
||||
let is_bsnh = ${params.isPastkvBSNH};
|
||||
|
||||
if (is_bsnh) {
|
||||
row_stride = num_heads * H;
|
||||
}
|
||||
var present_head_stride = present_seqlen * H;
|
||||
if (is_bsnh) {
|
||||
present_head_stride = H;
|
||||
}
|
||||
|
||||
let past_seqlen = uniforms.past_seqlen;
|
||||
|
||||
let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
|
||||
${concatStr}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'ConcatPastNew',
|
||||
shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies },
|
||||
getRunData: () => ({
|
||||
outputs: [{ dims: outputShape, dataType }],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({ ...attributes });
|
||||
|
||||
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });
|
||||
|
||||
const maybeExpandAndTransposeToBNSH = (
|
||||
context: ComputeContext,
|
||||
input: TensorView,
|
||||
pastKV: TensorView | undefined,
|
||||
params: AttentionParameters,
|
||||
outputIndex: number,
|
||||
) => {
|
||||
const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => {
|
||||
let reshapedInput = input;
|
||||
const numHeads = params.kvNumHeads!;
|
||||
const nReps = params.nReps!;
|
||||
if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
|
||||
reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
|
||||
}
|
||||
|
||||
if (pastKV) {
|
||||
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), {
|
||||
inputs: [reshapedInput, pastKV],
|
||||
outputs: [params.isPastkvBSNH ? outputIndex : -1],
|
||||
})[0];
|
||||
} else {
|
||||
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), {
|
||||
inputs: [reshapedInput],
|
||||
outputs: [params.isPastkvBSNH ? outputIndex : -1],
|
||||
})[0];
|
||||
}
|
||||
if (nReps !== 1) {
|
||||
reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {
|
||||
reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
|
||||
inputs: [reshapedInput],
|
||||
outputs: [-1],
|
||||
})[0];
|
||||
reshapedInput = reshapedInput.reshape([
|
||||
params.batchSize,
|
||||
params.totalSequenceLength,
|
||||
numHeads * nReps,
|
||||
params.headSize,
|
||||
]);
|
||||
}
|
||||
|
||||
return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
|
||||
inputs: [reshapedInput],
|
||||
outputs: [-1],
|
||||
})[0];
|
||||
return reshapedInput;
|
||||
};
|
||||
|
||||
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
|
||||
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
|
||||
const params = validateInputs(context.inputs, attributes);
|
||||
if (context.inputs[0].dims.length === 5) {
|
||||
throw new Error('Packed QKV is not implemented');
|
||||
|
@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti
|
|||
throw new Error('Packed KV is not implemented');
|
||||
}
|
||||
|
||||
const q = context.inputs[0];
|
||||
const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
|
||||
const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
|
||||
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
|
||||
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
|
||||
const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
|
||||
const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
|
||||
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
|
||||
|
||||
// TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
|
||||
|
||||
const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
|
||||
axis: 2,
|
||||
numOutputs: 3,
|
||||
splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
|
||||
});
|
||||
const [query, key, value] =
|
||||
!k && !v
|
||||
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
|
||||
: [q, k!, v!];
|
||||
|
||||
const Q = maybeTransposeToBNSHAndAddBias(
|
||||
context,
|
||||
params.batchSize,
|
||||
params.numHeads,
|
||||
params.sequenceLength,
|
||||
params.headSize,
|
||||
context.inputs[0],
|
||||
query,
|
||||
undefined,
|
||||
0,
|
||||
);
|
||||
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
|
||||
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
|
||||
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
|
||||
const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
|
||||
applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
|
||||
applyAttention(
|
||||
context,
|
||||
Q,
|
||||
maybeTransposeToBNSH(context, key, params),
|
||||
maybeTransposeToBNSH(context, value, params),
|
||||
undefined,
|
||||
undefined,
|
||||
pastKey,
|
||||
pastValue,
|
||||
undefined,
|
||||
params,
|
||||
seqLens,
|
||||
totalSequenceLengthInput,
|
||||
);
|
||||
};
|
||||
|
|
|
@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
|
|||
);
|
||||
|
||||
if (kvBNSH) {
|
||||
return applyAttention(
|
||||
context,
|
||||
Q,
|
||||
key,
|
||||
value,
|
||||
keyPaddingMask,
|
||||
undefined,
|
||||
pastKey,
|
||||
pastValue,
|
||||
attentionBias,
|
||||
params,
|
||||
attributes,
|
||||
);
|
||||
return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
|
||||
}
|
||||
if (!key || !value) {
|
||||
throw new Error('key and value must be provided');
|
||||
|
@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
|
|||
2 * params.hiddenSize,
|
||||
);
|
||||
|
||||
applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
|
||||
applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
|
||||
};
|
||||
|
|
|
@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
|
|||
}`;
|
||||
};
|
||||
|
||||
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
|
||||
export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
const inputSize = ShapeUtil.size(inputShape);
|
||||
const dataType = inputs[0].dataType;
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/bert/gqa_attention_base.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
@ -11,31 +11,29 @@ namespace js {
|
|||
|
||||
using onnxruntime::js::JsKernel;
|
||||
|
||||
class GroupQueryAttention : public JsKernel {
|
||||
class GroupQueryAttention : public JsKernel, GQAAttentionBase {
|
||||
public:
|
||||
explicit GroupQueryAttention(const OpKernelInfo& info)
|
||||
: JsKernel(info) {
|
||||
int64_t num_heads = 0;
|
||||
int64_t kv_num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
kv_num_heads_ = static_cast<int>(kv_num_heads);
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
: JsKernel(info), GQAAttentionBase(info, false) {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({
|
||||
"numHeads" : $1,
|
||||
"kvNumHeads" : $2,
|
||||
"scale" : $3,
|
||||
"softcap" : $4,
|
||||
"doRotary" : $5,
|
||||
"rotaryInterleaved" : $6,
|
||||
"smoothSoftmax" : $7,
|
||||
"localWindowSize" : $8
|
||||
}),
|
||||
static_cast<int32_t>(num_heads_),
|
||||
static_cast<int32_t>(kv_num_heads_),
|
||||
static_cast<float>(scale_));
|
||||
static_cast<float>(scale_),
|
||||
static_cast<float>(softcap_),
|
||||
static_cast<int32_t>(do_rotary_),
|
||||
static_cast<int32_t>(rotary_interleaved_),
|
||||
static_cast<int32_t>(use_smooth_softmax_),
|
||||
static_cast<int32_t>(local_window_size_));
|
||||
}
|
||||
|
||||
protected:
|
||||
int num_heads_; // number of attention heads
|
||||
int kv_num_heads_; // number of k and v heads
|
||||
float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size)
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
|
|
Загрузка…
Ссылка в новой задаче