[JS/WebGPU] GroupQueryAttention rewrite (#20946)

### Description
Implement JSEP GroupQueryAttention



### Motivation and Context
Required to enable certain LLM models to run using WebGPU.
This commit is contained in:
Satya Kumar Jandhyala 2024-10-23 10:14:09 -07:00 коммит произвёл GitHub
Родитель 33e2f6ad8d
Коммит fd8ee4894d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
7 изменённых файлов: 1482 добавлений и 725 удалений

Просмотреть файл

@ -19,7 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
import { groupQueryAttention } from './ops/group-query-attention';
import { instanceNorm } from './ops/instance-norm';
import { layerNorm } from './ops/layer-norm';
import { matMul } from './ops/matmul';
@ -104,7 +104,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
['GroupQueryAttention', [groupQueryAttention]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],

Просмотреть файл

@ -8,6 +8,7 @@ import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramU
import {
getMaxComponents,
IndicesHelper,
inputVariable,
outputVariable,
ShaderHelper,
@ -65,14 +66,17 @@ export interface AttentionParameters {
broadcastResPosBias: boolean;
passPastInKv: boolean;
qkvFormat: AttentionQkvFormat;
isPastkvBSNH?: boolean;
softcap?: number;
doRotary?: number;
rotaryInterLeaved?: number;
sommoothSoftmax?: number;
localWindowsSize?: number;
}
export interface AttentionAttrs {
numHeads: number;
kvNumHeads?: number;
isUnidirectional?: number;
maskFilterValue?: number;
isUnidirectional: number;
maskFilterValue: number;
scale: number;
doRotary: number;
qkvHiddenSizes: number[];
@ -258,41 +262,106 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
};
};
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
const components = getMaxComponents(d);
const initVarStub = (
seqLensInput: IndicesHelper | undefined,
totalSequenceLengthInput: IndicesHelper | undefined,
initPastSequenceLength: boolean,
) => {
// In the case of GQA, redefine total_sequence_length, present_sequence_length and past_sequence_length based on seqlen_k input
if (totalSequenceLengthInput && seqLensInput) {
return `
let total_sequence_length_input = u32(${totalSequenceLengthInput.getByOffset('0')});
let present_sequence_length = max(total_sequence_length_input, uniforms.past_sequence_length);
let is_subsequent_prompt: bool = sequence_length > 1 && sequence_length != total_sequence_length_input;
let is_first_prompt: bool = is_subsequent_prompt == false && sequence_length == total_sequence_length_input;
total_sequence_length = u32(${seqLensInput?.getByOffset('batchIdx')}) + 1;
var past_sequence_length: u32 = 0;
if (is_first_prompt == false) {
past_sequence_length = total_sequence_length - sequence_length;
}
`;
} else {
return `
${initPastSequenceLength ? 'let past_sequence_length = uniforms.past_sequence_length' : ''};
let present_sequence_length = total_sequence_length;
`;
}
};
const createInPlaceSoftmaxProgramInfo = (
input: TensorView,
batchSize: number,
numHeads: number,
pastSequenceLength: number,
sequenceLength: number,
totalSequenceLength: number,
seqLens: TensorView | undefined,
totalSequenceLengthInput: TensorView | undefined,
) => {
// Set components to 1 if seqLens is specified, i.e. GroupQueryAttention.
const components = getMaxComponents(seqLens ? 1 : totalSequenceLength);
let WG = 64;
const dComp = d / components;
if (dComp < WG) {
const totalSequenceLengthComp = totalSequenceLength / components;
if (totalSequenceLengthComp < WG) {
WG = 32;
}
const elementsPerThread = Math.ceil(d / components / WG);
const elementsPerThread = Math.ceil(totalSequenceLength / components / WG);
const programUniforms: ProgramUniform[] = [
{ type: DataType.float, data: 1 / d },
{ type: DataType.uint32, data: dComp },
{ type: DataType.uint32, data: batchSize },
{ type: DataType.uint32, data: numHeads },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: sequenceLength },
{ type: DataType.uint32, data: totalSequenceLengthComp },
{ type: DataType.uint32, data: elementsPerThread },
];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
const f32Type = tensorTypeToWsglValueType(DataType.float, components);
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
if (seqLens) {
inputDependencies.push('type');
}
if (totalSequenceLengthInput) {
inputDependencies.push('type');
}
const getShaderSource = (shaderHelper: ShaderHelper) => {
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
const inputHelpers = [inputHelper];
const seqLensInputHelper = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
if (seqLensInputHelper) {
inputHelpers.push(seqLensInputHelper);
}
const totalSequenceLengthInputHelper = totalSequenceLengthInput
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
: undefined;
if (totalSequenceLengthInputHelper) {
inputHelpers.push(totalSequenceLengthInputHelper);
}
const elemValueType = tensorTypeToWsglValueType(input.dataType);
const uniforms: UniformsArrayType = [
{ name: 'd_inv', type: 'f32' },
{ name: 'd_comp', type: 'u32' },
{ name: 'batch_size', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'sequence_length', type: 'u32' },
{ name: 'total_sequence_length', type: 'u32' },
{ name: 'elements_per_thread', type: 'u32' },
];
return `
var<workgroup> thread_max: array<f32, ${WG}>;
var<workgroup> thread_sum: array<f32, ${WG}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputHelpers)}
${shaderHelper.mainStart([WG, 1, 1])}
let batchIdx = workgroup_id.z / uniforms.num_heads;
let headIdx = workgroup_id.z % uniforms.num_heads;
let sequence_length = uniforms.sequence_length;
var total_sequence_length = uniforms.total_sequence_length;
${initVarStub(seqLensInputHelper, totalSequenceLengthInputHelper, false)}
let local_offset = local_idx * uniforms.elements_per_thread;
let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
let offset = (global_idx / ${WG}) * uniforms.total_sequence_length + local_offset;
let seq_causal_length = ${seqLens ? 'u32(past_sequence_length + workgroup_id.y + 1)' : 'total_sequence_length'};
var thread_max_vector = ${f32Type}(-3.402823e+38f);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
}
thread_max[local_idx] = ${(() => {
@ -315,7 +384,7 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
}
var sum_vector = ${f32Type}(0);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
}
thread_sum[local_idx] = ${(() => {
@ -338,15 +407,23 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
}
if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(1.0) / ${elemValueType}(seq_causal_length));
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
var f32input = ${f32Type}(x[offset + i]);
x[offset + i] = ${inputHelper.type.value}(exp(f32input - max_value) / sum);
}
}
${
seqLens
? `
for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length; total_seq_id++) {
x[offset + total_seq_id] = ${inputHelper.type.value}(${elemValueType}(0));
}`
: ''
};
}`;
};
@ -354,7 +431,11 @@ const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number
name: 'AttentionProbsSoftmax',
shaderCache: { hint: `${WG};${dataType};${components}`, inputDependencies },
getShaderSource,
getRunData: () => ({ outputs: [], dispatchGroup: { x: n }, programUniforms }),
getRunData: () => ({
outputs: [],
dispatchGroup: { x: Math.ceil(totalSequenceLength / WG), y: sequenceLength, z: batchSize * numHeads },
programUniforms,
}),
};
};
@ -365,19 +446,21 @@ const createAttentionProbsProgramInfo = (
pastKey: TensorView | undefined,
attentionBias: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
pastSequenceLength: number,
seqLens: TensorView | undefined,
totalSequenceLengthInput: TensorView | undefined,
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
const presentKey = outputCount > 1 && pastKey;
const kvNumHeads = parameters.kvNumHeads ? parameters.kvNumHeads : parameters.numHeads;
const presentKeyShape = presentKey
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
? [parameters.batchSize, kvNumHeads, totalSequenceLength, parameters.headSize]
: undefined;
const nReps = parameters.nReps ? parameters.nReps : 1;
// TODO: handle mask
const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
const alpha = parameters.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : parameters.scale;
const components = getMaxComponents(parameters.headSize);
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
@ -391,9 +474,11 @@ const createAttentionProbsProgramInfo = (
{ type: DataType.uint32, data: vectorizedHeadSize },
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: parameters.numHeads },
{ type: DataType.uint32, data: parameters.headSize },
{ type: DataType.float, data: alpha },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
{ type: DataType.uint32, data: nReps },
];
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
@ -404,6 +489,12 @@ const createAttentionProbsProgramInfo = (
if (attentionBias) {
inputDependencies.push('type');
}
if (seqLens) {
inputDependencies.push('type');
}
if (totalSequenceLengthInput) {
inputDependencies.push('type');
}
const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
if (presentKey) {
outputs.push({ dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default });
@ -419,6 +510,16 @@ const createAttentionProbsProgramInfo = (
if (attentionBias) {
inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
}
const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
if (seqLensInputVariable) {
inputVars.push(seqLensInputVariable);
}
const totalSequenceLengthInputVariable = totalSequenceLengthInput
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
: undefined;
if (totalSequenceLengthInputVariable) {
inputVars.push(totalSequenceLengthInputVariable);
}
const output = outputVariable('output', q.dataType, probsShape);
const outputVars = [output];
if (presentKey) {
@ -431,9 +532,11 @@ const createAttentionProbsProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'alpha', type: 'f32' as UniformDataElementType },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
{ name: 'n_reps', type: 'u32' },
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
@ -443,21 +546,20 @@ const createAttentionProbsProgramInfo = (
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
// x holds the N and y holds the M
let headIdx = workgroup_id.z;
let headIdx = workgroup_id.z % uniforms.num_heads;
let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
let batchIdx = workgroup_id.z / uniforms.num_heads;
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
${(() => {
if (feedPastKey && presentKey) {
return `
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
} else {
return `
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
}
})()}
${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.N;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx;
let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
${feedPastKey && presentKey ? 'let pastKeyOffset = absKvHeadIdx * uniforms.past_sequence_length * uniforms.K;' : ''};
let kOffset = absKvHeadIdx * uniforms.kv_sequence_length * uniforms.K;
${presentKey ? 'let presentKeyOffset = absKvHeadIdx * uniforms.N * uniforms.K;' : ''}
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
@ -468,31 +570,37 @@ const createAttentionProbsProgramInfo = (
${(() => {
if (feedPastKey && presentKey) {
return `
if (n + local_id.y < uniforms.past_sequence_length) {
if (n + local_id.y < past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
} else {
tileK[idx] =
key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
return `
if (n + local_id.y < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
}`;
}
})()}
${
presentKey ? 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : ''
presentKey
? `if (n + local_id.y < present_sequence_length) {
present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
}`
: ''
}
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
value += ${f32Type}(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
}
workgroupBarrier();
}
let headOffset = headIdx * uniforms.M * uniforms.N;
if (global_id.y < uniforms.M && global_id.x < uniforms.N) {
if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {
let headOffset = workgroup_id.z * uniforms.M * uniforms.N;
let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;
var sum: f32 = ${(() => {
switch (components) {
@ -530,13 +638,16 @@ const createVxAttentionScoreProgramInfo = (
pastValue: TensorView | undefined,
params: AttentionParameters,
pastSequenceLength: number,
seqLens: TensorView | undefined = undefined,
totalSequenceLengthInput: TensorView | undefined = undefined,
) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
const presentValue = outputCount > 1 && pastValue;
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
const presentValueShape = presentValue
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
? [params.batchSize, kvNumHeads, totalSequenceLength, params.headSize]
: undefined;
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
const TILE_SIZE = 12;
@ -551,9 +662,11 @@ const createVxAttentionScoreProgramInfo = (
{ type: DataType.uint32, data: totalSequenceLength },
{ type: DataType.uint32, data: params.vHeadSize },
{ type: DataType.uint32, data: params.numHeads },
{ type: DataType.uint32, data: params.headSize },
{ type: DataType.uint32, data: repeatedVHiddenSize },
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
{ type: DataType.uint32, data: nReps },
];
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
@ -561,6 +674,12 @@ const createVxAttentionScoreProgramInfo = (
if (feedPastValue) {
inputDependencies.push('type');
}
if (seqLens) {
inputDependencies.push('type');
}
if (totalSequenceLengthInput) {
inputDependencies.push('type');
}
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
if (presentValue) {
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
@ -572,6 +691,16 @@ const createVxAttentionScoreProgramInfo = (
if (feedPastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
const seqLensInputVariable = seqLens ? inputVariable('seq_lens', seqLens.dataType, seqLens.dims) : undefined;
if (seqLens) {
inputVars.push(seqLensInputVariable!);
}
const totalSequenceLengthInputVariable = totalSequenceLengthInput
? inputVariable('total_sequence_length_input', totalSequenceLengthInput.dataType, totalSequenceLengthInput.dims)
: undefined;
if (totalSequenceLengthInput) {
inputVars.push(totalSequenceLengthInputVariable!);
}
const output = outputVariable('output', probs.dataType, outputShape);
const outputVars = [output];
if (presentValue) {
@ -582,34 +711,32 @@ const createVxAttentionScoreProgramInfo = (
{ name: 'K', type: 'u32' },
{ name: 'N', type: 'u32' },
{ name: 'num_heads', type: 'u32' },
{ name: 'head_size', type: 'u32' },
{ name: 'v_hidden_size', type: 'u32' },
{ name: 'past_sequence_length', type: 'u32' },
{ name: 'kv_sequence_length', type: 'u32' },
{ name: 'n_reps', type: 'u32' },
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileV: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([TILE_SIZE, TILE_SIZE, 1])}
let headIdx = workgroup_id.z;
let headIdx = workgroup_id.z % uniforms.num_heads;
let batchIdx = workgroup_id.z / uniforms.num_heads;
let kvHeadIdx = ${nReps === 1 ? 'headIdx' : 'headIdx / uniforms.n_reps'};
let kv_num_heads = ${nReps === 1 ? 'uniforms.num_heads' : 'uniforms.num_heads / uniforms.n_reps'};
let m = global_id.y;
let n = global_id.x;
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
${(() => {
if (feedPastValue && presentValue) {
return `
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
`;
} else {
return `
let offsetB = headIdx * uniforms.N * uniforms.K + n;
`;
}
})()}
${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.K;
${initVarStub(seqLensInputVariable, totalSequenceLengthInputVariable, true)}
let offsetA = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;
let absKvHeadIdx = batchIdx * kv_num_heads + kvHeadIdx; // kvHeadIdx is relative to the batch
${feedPastValue && presentValue ? 'let pastValueOffset = absKvHeadIdx * uniforms.N * uniforms.past_sequence_length + n;' : ''};
let vOffset = absKvHeadIdx * uniforms.N * uniforms.kv_sequence_length + n;
${presentValue ? 'let presentValueOffset = absKvHeadIdx * uniforms.N * uniforms.K + n;' : ''}
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (m < uniforms.M && w + local_id.x < uniforms.K) {
@ -620,33 +747,39 @@ const createVxAttentionScoreProgramInfo = (
${(() => {
if (feedPastValue && presentValue) {
return `
if (w + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
} else {
tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
if (w + local_id.y < past_sequence_length) {
tileV[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
} else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
tileV[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];
}
`;
} else {
return `
tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
`;
if (w + local_id.y < uniforms.kv_sequence_length) {
tileV[idx] = v[vOffset + (w + local_id.y) * uniforms.N];
}`;
}
})()}
${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
${
presentValue
? `
if (w + local_id.y < present_sequence_length) {
present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileV[idx];
}`
: ''
}
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
for (var k: u32 = 0u; k < TILE_SIZE && w+k < total_sequence_length; k++) {
value += tileQ[TILE_SIZE * local_id.y + k] * tileV[TILE_SIZE * k + local_id.x];
}
workgroupBarrier();
}
// we need to transpose output from BNSH_v to BSND_v
let batchIdx = workgroup_id.z / uniforms.num_heads;
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
if (m < uniforms.M && n < uniforms.N) {
let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + m * uniforms.v_hidden_size
+ currentBatchHeadNumber * uniforms.N + n;
+ headIdx * uniforms.N + n;
output[outputIdx] = value;
}
}`;
@ -671,23 +804,29 @@ export const applyAttention = (
pastValue: TensorView | undefined,
attentionBiasInput: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
seqLens: TensorView | undefined = undefined,
totalSequenceLengthInput: TensorView | undefined = undefined,
) => {
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
const inputsK = [q, k];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
if (attentionBias) {
inputsK.push(attentionBias);
}
if (seqLens) {
inputsK.push(seqLens);
}
if (totalSequenceLengthInput) {
inputsK.push(totalSequenceLengthInput);
}
// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
@ -697,31 +836,55 @@ export const applyAttention = (
pastKey,
attentionBias,
parameters,
attributes,
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
),
{ inputs: inputsK, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [-1, 1] : [-1] },
{ inputs: inputsK, outputs: outputCount > 1 ? [-1, 1] : [-1] },
)[0];
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
probs,
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
parameters.batchSize,
parameters.numHeads,
pastSequenceLength,
parameters.sequenceLength,
totalSequenceLength,
seqLens,
totalSequenceLengthInput,
),
{ inputs: [probs], outputs: [] },
{ inputs: seqLens && totalSequenceLengthInput ? [probs, seqLens, totalSequenceLengthInput] : [probs], outputs: [] },
);
// Run AttrionScore
// Run AttentionScore
const inputsV = [probs, v];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
inputs: inputsV,
outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
});
if (seqLens) {
inputsV.push(seqLens);
}
if (totalSequenceLengthInput) {
inputsV.push(totalSequenceLengthInput);
}
context.compute(
createVxAttentionScoreProgramInfo(
outputCount,
probs,
v,
pastValue,
parameters,
pastSequenceLength,
seqLens,
totalSequenceLengthInput,
),
{
inputs: inputsV,
outputs: outputCount > 1 ? [0, 2] : [0],
},
);
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
@ -857,6 +1020,5 @@ export const attention = (context: ComputeContext, attributes: AttentionAttrs):
undefined,
context.inputs[5],
params,
attributes,
);
};

Просмотреть файл

@ -1,31 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import { ComputeContext } from '../types';
import {
applyAttention,
AttentionAttrs,
AttentionMaskType,
AttentionParameters,
AttentionQkvFormat,
} from './attention';
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
import { applyAttention, AttentionMaskType, AttentionParameters, AttentionQkvFormat } from './attention';
import { maybeTransposeToBNSHAndAddBias } from './multihead-attention';
import { createTileProgramInfo } from './tile';
import { createSplitProgramInfo, SplitAttributes } from './split';
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
export interface GroupQueryAttentionAttributes {
numHeads: number;
kvNumHeads: number;
scale: number;
softcap: number;
doRotary: number;
rotaryInterleaved: number;
smoothSoftmax: boolean;
localWindowSize: number;
}
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
export const validateInputs = (
inputs: readonly TensorView[],
attributes: GroupQueryAttentionAttributes,
): AttentionParameters => {
if (attributes.doRotary && inputs.length <= 7) {
throw new Error('cos_cache and sin_cache inputs are required if do_rotary is specified');
}
const query = inputs[0];
const key = inputs[1];
const value = inputs[2];
const pastKey = inputs[3];
const pastValue = inputs[4];
if (attributes.localWindowSize !== -1) {
throw new Error('Local attention is not supported');
}
if (attributes.softcap !== 0) {
throw new Error('Softcap is not supported');
}
if (attributes.rotaryInterleaved !== 0) {
throw new Error('Rotary interleaved is not supported');
}
if (attributes.smoothSoftmax) {
throw new Error('Smooth softmax is not supported');
}
// Abbreviation and Meanings:
// B: batch_size
// S: sequence_length (input sequence length of query)
@ -62,17 +80,32 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
const dmmhaPacking = false;
const batchSize = query.dims[0];
const sequenceLength = query.dims[1];
const hiddenSize =
let hiddenSize =
query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
let kvSequenceLength = sequenceLength;
let pastSequenceLength = 0;
let maxSequenceLength = 0;
const headSize = Math.floor(hiddenSize / attributes.numHeads);
const packedQKV = !key || key.dims.length === 0;
const headSize = !packedQKV
? Math.floor(hiddenSize / attributes.numHeads)
: Math.floor(hiddenSize / (attributes.numHeads + 2 * attributes.kvNumHeads));
if (packedQKV) {
hiddenSize = headSize * attributes.numHeads;
}
const hasPastKey = pastKey && pastKey.dims.length !== 0;
const hasPastValue = pastValue && pastValue.dims.length !== 0;
// TODO : this should be from attributes.
const isPastkvBSNH = true;
// Currenly the onnxruntime GQA specification only support key/value BNSH format.
const isPastkvBSNH =
hasPastKey &&
pastKey.dims.length === 4 &&
pastKey.dims[0] === batchSize &&
pastKey.dims[1] !== attributes.kvNumHeads &&
pastKey.dims[2] === attributes.kvNumHeads &&
pastKey.dims[3] === headSize;
if (isPastkvBSNH) {
throw new Error('BSNH pastKey/pastValue is not supported');
}
if (hasPastKey && hasPastValue) {
if (pastKey.dims.length !== 4) {
throw new Error('Input "past_key" is expected to have 4 dimensions');
@ -80,21 +113,13 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (pastValue.dims.length !== 4) {
throw new Error('Input "past_value" is expected to have 4 dimensions');
}
if (isPastkvBSNH) {
// For BSNH
pastSequenceLength = pastKey.dims[1];
maxSequenceLength = pastKey.dims[1];
} else {
// For BNSH
pastSequenceLength = pastKey.dims[2];
maxSequenceLength = pastKey.dims[2];
}
pastSequenceLength = pastKey.dims[2];
} else if (hasPastKey || hasPastValue) {
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
}
let qkvFormat: AttentionQkvFormat;
if (key) {
let qkvFormat: AttentionQkvFormat = AttentionQkvFormat.qkvBNSH;
if (key && key.dims.length > 0) {
if (query.dims.length !== 3) {
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
}
@ -109,7 +134,6 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (query.dims[2] % key.dims[2] !== 0) {
throw new Error('Dimension 2 of "query" should be a multiple of "key"');
}
qkvFormat = AttentionQkvFormat.qkvBSNH;
kvSequenceLength = key.dims[1];
} else if (key.dims.length === 5) {
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
@ -118,15 +142,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
if (value) {
throw new Error('Expect "value" be none when "key" has packed kv format.');
}
qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
kvSequenceLength = key.dims[1];
} else {
// key_dims.size() == 4 (cross-attention with past_key)
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
}
qkvFormat = AttentionQkvFormat.unknown;
kvSequenceLength = key.dims[2];
}
} else {
@ -143,8 +164,8 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
const maskType: AttentionMaskType = AttentionMaskType.none;
let passPastInKv = false;
let vHiddenSize = hiddenSize;
if (value) {
let vHiddenSize = attributes.kvNumHeads ? headSize * attributes.kvNumHeads : hiddenSize;
if (value && value.dims.length > 0) {
if (value.dims.length !== 3 && value.dims.length !== 4) {
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
}
@ -166,7 +187,12 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
passPastInKv = true;
}
}
const totalSequenceLength = pastSequenceLength + kvSequenceLength;
const seqlLens = inputs.length > 4 ? inputs[5] : undefined;
if (seqlLens && seqlLens.dims.length !== 1 && seqlLens.dims[0] !== batchSize) {
throw new Error('Input "seqlens" is expected to have 1 dimension and the same dim 0 as batch_size');
}
const totalSequenceLength = -1;
const maxSequenceLength = -1;
const broadcastResPosBias = false;
return {
@ -180,181 +206,36 @@ export const validateInputs = (inputs: readonly TensorView[], attributes: Attent
hiddenSize,
vHiddenSize,
headSize,
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads),
numHeads: attributes.numHeads,
kvNumHeads: attributes.kvNumHeads,
nReps: attributes.numHeads / attributes.kvNumHeads!,
nReps: attributes.numHeads / attributes.kvNumHeads,
pastPresentShareBuffer: false,
maskType,
scale: attributes.scale,
broadcastResPosBias,
passPastInKv,
qkvFormat,
isPastkvBSNH,
};
};
const createConcatProgramInfo = (
a: TensorView,
b: TensorView | undefined,
dataType: DataType,
params: AttentionParameters,
): ProgramInfo => {
const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
const component = 4;
const outputSize = ShapeUtil.size(outputShape) / component;
const presentSequenceLength = params.totalSequenceLength;
const output = outputVariable('present_kv', dataType, outputShape.length, component);
const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
const H = Math.ceil(params.headSize / component);
const dispatch = { x: presentSequenceLength, y: a.dims[0], z: 1 };
const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: params.pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
{ type: DataType.uint32, data: params.totalSequenceLength },
];
const inputs = [inputA];
if (inputB) {
programUniforms.push(
...createTensorShapeVariables(a.dims),
...createTensorShapeVariables(b!.dims),
...createTensorShapeVariables(outputShape),
);
inputs.push(inputB);
} else {
programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
}
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'past_seqlen', type: 'u32' },
{ name: 'new_seqlen', type: 'u32' },
{ name: 'present_seqlen', type: 'u32' },
];
const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H;
var past_head_stride = uniforms.past_seqlen * H;
if (is_bsnh) {
past_head_stride = H;
}
let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
present_kv[out_offset] = past_kv[in_offset];`;
const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H;
let new_row_stride = num_heads * H;
let new_head_stride = H;
let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
present_kv[out_offset] = new_kv[in_offset];`;
const concatStr = b
? `if (s < past_seqlen) {
${pastStr}
} else if (s < past_seqlen + uniforms.new_seqlen) {
${newStr}
}`
: `if (s < past_seqlen + uniforms.new_seqlen) {
${newStr}
}`;
// TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit.
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
${shaderHelper.mainStart([H, params.kvNumHeads!, 1])}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var indices = ${output.offsetToIndices('global_idx')};
let h = local_id.x;
let n = local_id.y;
let s = workgroup_id.x;
let b = workgroup_id.y;
let num_heads = ${params.kvNumHeads!}u;
let H = ${H}u;
let present_seqlen = uniforms.present_seqlen;
let present_batch_stride = present_seqlen * num_heads * H;
var row_stride = H;
let is_bsnh = ${params.isPastkvBSNH};
if (is_bsnh) {
row_stride = num_heads * H;
}
var present_head_stride = present_seqlen * H;
if (is_bsnh) {
present_head_stride = H;
}
let past_seqlen = uniforms.past_seqlen;
let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
${concatStr}
}`;
return {
name: 'ConcatPastNew',
shaderCache: { hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType }],
dispatchGroup: dispatch,
programUniforms,
}),
getShaderSource,
};
};
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
createAttributeWithCacheKey({ ...attributes });
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({ perm: [0, 2, 1, 3] });
const maybeExpandAndTransposeToBNSH = (
context: ComputeContext,
input: TensorView,
pastKV: TensorView | undefined,
params: AttentionParameters,
outputIndex: number,
) => {
const maybeTransposeToBNSH = (context: ComputeContext, input: TensorView, params: AttentionParameters) => {
let reshapedInput = input;
const numHeads = params.kvNumHeads!;
const nReps = params.nReps!;
if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
}
if (pastKV) {
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params), {
inputs: [reshapedInput, pastKV],
outputs: [params.isPastkvBSNH ? outputIndex : -1],
})[0];
} else {
reshapedInput = context.compute(createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params), {
inputs: [reshapedInput],
outputs: [params.isPastkvBSNH ? outputIndex : -1],
})[0];
}
if (nReps !== 1) {
reshapedInput = context.compute(createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {
reshapedInput = context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
inputs: [reshapedInput],
outputs: [-1],
})[0];
reshapedInput = reshapedInput.reshape([
params.batchSize,
params.totalSequenceLength,
numHeads * nReps,
params.headSize,
]);
}
return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
inputs: [reshapedInput],
outputs: [-1],
})[0];
return reshapedInput;
};
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
@ -364,19 +245,49 @@ export const groupQueryAttention = (context: ComputeContext, attributes: Attenti
throw new Error('Packed KV is not implemented');
}
const q = context.inputs[0];
const k = context.inputs[1] && context.inputs[1].dims.length > 0 ? context.inputs[1] : undefined;
const v = context.inputs[2] && context.inputs[2].dims.length > 0 ? context.inputs[2] : undefined;
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
const seqLens = context.inputs.length > 4 ? context.inputs[5] : undefined;
const totalSequenceLengthInput = context.inputs.length > 5 ? context.inputs[6] : undefined;
const kvNumHeads = params.kvNumHeads ? params.kvNumHeads : params.numHeads;
// TODO Remove explicit split operation and use indexing in Attention implementation to avoid overhead.
const splitAttributes: SplitAttributes = createAttributeWithCacheKey({
axis: 2,
numOutputs: 3,
splitSizes: [params.numHeads * params.headSize, kvNumHeads * params.headSize, kvNumHeads * params.headSize],
});
const [query, key, value] =
!k && !v
? context.compute(createSplitProgramInfo([q], splitAttributes), { inputs: [q], outputs: [-1, -1, -1] })
: [q, k!, v!];
const Q = maybeTransposeToBNSHAndAddBias(
context,
params.batchSize,
params.numHeads,
params.sequenceLength,
params.headSize,
context.inputs[0],
query,
undefined,
0,
);
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
applyAttention(
context,
Q,
maybeTransposeToBNSH(context, key, params),
maybeTransposeToBNSH(context, value, params),
undefined,
undefined,
pastKey,
pastValue,
undefined,
params,
seqLens,
totalSequenceLengthInput,
);
};

Просмотреть файл

@ -403,19 +403,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
);
if (kvBNSH) {
return applyAttention(
context,
Q,
key,
value,
keyPaddingMask,
undefined,
pastKey,
pastValue,
attentionBias,
params,
attributes,
);
return applyAttention(context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
}
if (!key || !value) {
throw new Error('key and value must be provided');
@ -442,5 +430,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
2 * params.hiddenSize,
);
applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params);
};

Просмотреть файл

@ -71,7 +71,7 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
}`;
};
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
export const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const inputSize = ShapeUtil.size(inputShape);
const dataType = inputs[0].dataType;

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/bert/gqa_attention_base.h"
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
@ -11,31 +11,29 @@ namespace js {
using onnxruntime::js::JsKernel;
class GroupQueryAttention : public JsKernel {
class GroupQueryAttention : public JsKernel, GQAAttentionBase {
public:
explicit GroupQueryAttention(const OpKernelInfo& info)
: JsKernel(info) {
int64_t num_heads = 0;
int64_t kv_num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
: JsKernel(info), GQAAttentionBase(info, false) {
JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({
"numHeads" : $1,
"kvNumHeads" : $2,
"scale" : $3,
"softcap" : $4,
"doRotary" : $5,
"rotaryInterleaved" : $6,
"smoothSoftmax" : $7,
"localWindowSize" : $8
}),
static_cast<int32_t>(num_heads_),
static_cast<int32_t>(kv_num_heads_),
static_cast<float>(scale_));
static_cast<float>(scale_),
static_cast<float>(softcap_),
static_cast<int32_t>(do_rotary_),
static_cast<int32_t>(rotary_interleaved_),
static_cast<int32_t>(use_smooth_softmax_),
static_cast<int32_t>(local_window_size_));
}
protected:
int num_heads_; // number of attention heads
int kv_num_heads_; // number of k and v heads
float scale_; // custom scale will be used if specified. Default value is 1/sqrt(head_size)
};
} // namespace js