# [js/webgpu] Optimize grouped conv (#21892)
### Description

This PR optimizes grouped conv by 1) making GPU memory access more sequential and 2) reusing input data to reduce the number of global memory accesses. With this change, the `Conv|GroupedConv` op in [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) drops from 1058 ms to 92 ms on an iGPU with 32 EUs. For whole models on the same iGPU, wav2vec2 goes from 1942 ms to 982 ms, and squeezebert-uncased from 431.77 ms to 71.86 ms.

### Motivation and Context

See #21618.
Parent: 30f07758a2
Commit: 3580e01348
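A note on what the two optimizations mean in practice, before reading the diff. In the NHWC path the innermost loop now runs over the input channels of one group, which sit next to each other in memory, so the reads are sequential; and the weights and outputs are packed into vectors along the output-channel axis, so each input value loaded from global memory feeds `components` output channels instead of one. Below is a minimal CPU model of that accumulation, assuming vec4 packing; all names in it are illustrative, not taken from the kernel source.

```ts
type Vec4 = [number, number, number, number];

// One invocation's work: accumulate a vec4 of adjacent output channels.
// The key point: each x value is read once and reused for 4 output channels,
// because the weights (and the output) are packed as vec4 along that axis.
function accumulateVec4(
  x: (h: number, w: number, c: number) => number, // scalar input load
  w4: (kh: number, kw: number, ic: number) => Vec4, // vec4 weight load
  kH: number, kW: number, inChannelsPerGroup: number,
  h0: number, w0: number, inChannelOffset: number, // group's first input channel
): Vec4 {
  const value: Vec4 = [0, 0, 0, 0];
  for (let kh = 0; kh < kH; kh++) {
    for (let kw = 0; kw < kW; kw++) {
      // NHWC: consecutive ic values are consecutive addresses -> sequential reads.
      for (let ic = 0; ic < inChannelsPerGroup; ic++) {
        const xVal = x(h0 + kh, w0 + kw, inChannelOffset + ic); // one load...
        const wVal = w4(kh, kw, ic);
        for (let i = 0; i < 4; i++) {
          value[i] += xVal * wVal[i]; // ...four multiply-adds
        }
      }
    }
  }
  return value;
}
```

The `calculateResult` string in the diff below is, roughly, the WGSL version of this loop, with the boundary checks needed for padding added.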
Changes to the grouped conv program builder (`createGroupedConvProgramInfo`):

```diff
@@ -15,7 +15,7 @@ import {
   tensorTypeToWsglStorageType,
   UniformsArrayType,
 } from './common';
-import { calculateOutputShape, ConvAttributes } from './conv';
+import { ConvAttributes } from './conv';
 import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from './fuse-utils';
 
 /**
@@ -25,24 +25,19 @@ import { appendActivationUniforms, appendActivationUniformsData, getActivationSn
 export const createGroupedConvProgramInfo = (
   inputs: readonly TensorView[],
   attributes: ConvAttributes,
+  outputShape: readonly number[],
   squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
 ): ProgramInfo => {
   const hasBias = inputs.length > 2;
   const processBias = hasBias ? 'value += b[output_channel];' : '';
   const xShape = inputs[0].dims;
   const wShape = inputs[1].dims;
-  const outputChannelsPerGroup = wShape[0] / attributes.group;
 
   const isChannelLast = attributes.format === 'NHWC';
-  const outputShape = calculateOutputShape(
-    xShape,
-    wShape,
-    attributes.dilations,
-    attributes.pads,
-    attributes.strides,
-    isChannelLast,
-  );
-  const outputSize = ShapeUtil.size(outputShape);
+  const outputChannels = isChannelLast ? outputShape[3] : outputShape[1];
+  const outputChannelsPerGroup = outputChannels / attributes.group;
+  const components = isChannelLast && outputChannelsPerGroup >= 4 ? getMaxComponents(outputChannels) : 1;
+  const outputSize = ShapeUtil.size(outputShape) / components;
 
   const programUniforms: ProgramUniform[] = [
     { type: DataType.uint32, data: outputSize },
@@ -52,23 +47,23 @@ export const createGroupedConvProgramInfo = (
     { type: DataType.uint32, data: outputChannelsPerGroup },
   ];
   appendActivationUniformsData(attributes, programUniforms);
-  programUniforms.push(...createTensorShapeVariables(xShape, wShape));
-  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
+  programUniforms.push(
+    ...createTensorShapeVariables(xShape, [wShape[0], wShape[1], wShape[2], wShape[3] / components]),
+  );
+  const inputDependencies: ProgramInputTensorInfoDependency[] = hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'];
   if (hasBias) {
     programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
-    inputDependencies.push('rank');
   }
-  programUniforms.push(...createTensorShapeVariables(outputShape));
+  programUniforms.push(
+    ...createTensorShapeVariables([outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]),
+  );
 
   const getShaderSource = (shaderHelper: ShaderHelper) => {
-    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
+    const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
     const baseType = tensorTypeToWsglStorageType(output.type.tensor);
     const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
     const x = inputVariable('x', inputs[0].dataType, xShape.length);
-    const w = inputVariable('w', inputs[1].dataType, wShape.length);
+    const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
     const inputVars = [x, w];
     if (hasBias) {
-      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length));
+      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components));
     }
 
     const uniforms: UniformsArrayType = [
@@ -79,6 +74,54 @@ export const createGroupedConvProgramInfo = (
       { name: 'output_channels_per_group', type: 'u32' },
     ];
     appendActivationUniforms(attributes, uniforms);
+
+    const calculateResult = isChannelLast
+      ? `
+      for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) {
+        let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+
+        if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) {
+          continue;
+        }
+
+        for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) {
+          let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+          if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) {
+            continue;
+          }
+
+          for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) {
+            let input_channel = in_channel_offset + wInChannel;
+            let xVal = ${x.get('batch', 'xHeight', 'xWidth', 'input_channel')};
+            let wVal = ${w.get('wHeight', 'wWidth', 'wInChannel', 'output_channel')};
+            value += xVal * wVal;
+          }
+        }
+      }
+      `
+      : `
+      for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
+        let input_channel = in_channel_offset + wInChannel;
+        for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
+          let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+
+          if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) {
+            continue;
+          }
+
+          for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
+            let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+            if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) {
+              continue;
+            }
+
+            let xVal = ${x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
+            let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
+            value += xVal * wVal;
+          }
+        }
+      }
+      `;
     return `
   ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
@@ -91,34 +134,11 @@ export const createGroupedConvProgramInfo = (
       let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${
         isChannelLast ? 2 : 3
      }]) * uniforms.strides - uniforms.pads;
-      let group_id: u32 = output_channel / uniforms.output_channels_per_group;
+      let group_id: u32 = output_channel * ${components} / uniforms.output_channels_per_group;
+      var in_channel_offset = group_id * uniforms.w_shape[${isChannelLast ? 2 : 1}];
 
       var value: ${output.type.value} = ${output.type.value}(0);
-      for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
-        let input_channel = group_id * uniforms.w_shape[1] + wInChannel;
-        for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
-          let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
-
-          if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) {
-            continue;
-          }
-
-          for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
-            let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
-            if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) {
-              continue;
-            }
-
-            let xVal = ${
-              isChannelLast
-                ? x.get('batch', 'xHeight', 'xWidth', 'input_channel')
-                : x.get('batch', 'input_channel', 'xHeight', 'xWidth')
-            };
-            let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
-            value += xVal*wVal;
-          }
-        }
-      }
+      ${calculateResult}
      ${processBias}
      ${applyActivation}
      ${output.setByOffset('global_idx', 'value')}
@@ -126,7 +146,7 @@ export const createGroupedConvProgramInfo = (
   };
   return {
     name: 'GroupedConv',
-    shaderCache: { hint: attributes.cacheKey, inputDependencies },
+    shaderCache: { hint: `${attributes.cacheKey}_${components}`, inputDependencies },
     getRunData: () => ({
       outputs: [
         {
```
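Two details of the program builder are worth spelling out. The vector width comes from `getMaxComponents` (imported from `./common`); the helper body below is a sketch of its assumed behavior, not a quote of it, and the example numbers are made up.

```ts
// Assumed behavior: the widest of 4/2/1 that divides the size evenly.
const getMaxComponents = (size: number): number =>
  size % 4 === 0 ? 4 : size % 2 === 0 ? 2 : 1;

// Example: an NHWC output with 24 channels and group = 3.
const outputChannels = 24;
const outputChannelsPerGroup = outputChannels / 3; // 8 >= 4, so vectorization kicks in
const components = getMaxComponents(outputChannels); // 4

// In the shader, output_channel now indexes vec4 slots, so the group index is
// recovered from the scalar channel number, as in the diff:
//   group_id = output_channel * components / output_channels_per_group
// e.g. vec4 slot 2 covers scalar channels 8..11, i.e. group 1:
console.log(Math.floor((2 * components) / outputChannelsPerGroup)); // 1
```

And since `components` changes the generated WGSL, it is appended to the shader cache hint; two convs with identical attributes but different channel counts could otherwise collide on the same compiled shader.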
Changes to the `conv2d` dispatcher:

```diff
@@ -162,7 +162,33 @@ const conv2d = (
 
   // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
   const isChannelsLast = attributes.format === 'NHWC';
+  const outputShape = calculateOutputShape(
+    inputs[0].dims,
+    inputs[1].dims,
+    attributes.dilations,
+    attributes.pads,
+    attributes.strides,
+    isChannelsLast,
+  );
   if (attributes.group !== 1) {
+    const convInputs = [inputs[0]];
+    if (isChannelsLast) {
+      const transposedWeight =
+        (context.kernelCustomData.wT as TensorView | undefined) ??
+        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
+          inputs: [1],
+          outputs: [attributes.wIsConst ? -2 : -1],
+        })[0];
+      if (attributes.wIsConst && !context.kernelCustomData.wT) {
+        context.kernelCustomData.wT = transposedWeight;
+      }
+      convInputs.push(transposedWeight);
+    } else {
+      convInputs.push(inputs[1]);
+    }
+    if (inputs.length === 3) {
+      convInputs.push(inputs[2]);
+    }
     // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
     // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
     // [webgpu]Conv - conv - vectorize group - B
@@ -176,33 +202,14 @@ const conv2d = (
       attributes.dilations[0] === 1 &&
       attributes.dilations[1] === 1
     ) {
-      const outputShape = calculateOutputShape(
-        inputs[0].dims,
-        inputs[1].dims,
-        attributes.dilations,
-        attributes.pads,
-        attributes.strides,
-        isChannelsLast,
-      );
-      const transposedWeight =
-        (context.kernelCustomData.wT as TensorView | undefined) ??
-        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
-          inputs: [1],
-          outputs: [attributes.wIsConst ? -2 : -1],
-        })[0];
-      if (attributes.wIsConst && !context.kernelCustomData.wT) {
-        context.kernelCustomData.wT = transposedWeight;
-      }
-      const convInputs = [inputs[0], transposedWeight];
-      if (inputs.length === 3) {
-        convInputs.push(inputs[2]);
-      }
       context.compute(
         createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
         { inputs: convInputs },
       );
     } else {
-      context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
+      context.compute(createGroupedConvProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction), {
+        inputs: convInputs,
+      });
     }
     return;
   }
@@ -214,14 +221,6 @@ const conv2d = (
   const weightHeight = inputs[1].dims[2];
   const weightWidth = inputs[1].dims[3];
 
-  const outputShape = calculateOutputShape(
-    inputs[0].dims,
-    inputs[1].dims,
-    attributes.dilations,
-    attributes.pads,
-    attributes.strides,
-    isChannelsLast,
-  );
   const outHeight = outputShape[isChannelsLast ? 1 : 2];
   const outWidth = outputShape[isChannelsLast ? 2 : 3];
   const outChannels = outputShape[isChannelsLast ? 3 : 1];
```
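The non-vectorized grouped path now reuses the transpose-and-cache pattern that the vectorized path already used for NHWC weights. Reduced to its essentials it looks like the sketch below; the helper is hypothetical, and my reading of the output-index convention (-1 requests a temporary tensor, -2 a persistent one that outlives the run) is an assumption to verify against the jsep sources.

```ts
// Assumes conv.ts's own imports: ComputeContext, TensorView,
// createTransposeProgramInfo, weightTransposeAttribute.
function getTransposedWeight(context: ComputeContext, weight: TensorView, wIsConst: boolean): TensorView {
  // Constant weights only ever need to be transposed once per kernel.
  const cached = context.kernelCustomData.wT as TensorView | undefined;
  if (cached) {
    return cached;
  }
  const transposed = context.compute(createTransposeProgramInfo(weight, weightTransposeAttribute), {
    inputs: [1], // the transpose program reads only the weight input
    outputs: [wIsConst ? -2 : -1], // assumed: -2 persistent, -1 temporary
  })[0];
  if (wIsConst) {
    context.kernelCustomData.wT = transposed; // reused by every later run
  }
  return transposed;
}
```

The payoff is that a constant weight tensor is transposed on the first run only, and every subsequent run skips the transpose dispatch entirely.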
New operator test case in the conv test data:

```diff
@@ -298,6 +298,65 @@
       }
     ]
   },
+  {
+    "name": "conv - group - NHWC",
+    "operator": "Conv",
+    "inputShapeDefinitions": "rankOnly",
+    "opset": { "domain": "", "version": 17 },
+    "attributes": [
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 2, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
+              19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
+              14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
+              10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0
+            ],
+            "dims": [1, 8, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+              9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+              9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+              1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
+            ],
+            "dims": [8, 4, 2, 2],
+            "type": "float32"
+          },
+          {
+            "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
+            "dims": [8],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1224.0999755859375, 1312.0999755859375, 936.0999755859375, 1024.0999755859375, 872.2000122070312,
+              976.2000122070312, 1016.2000122070312, 1120.199951171875, 1224.300048828125, 1312.300048828125,
+              936.2999877929688, 1024.300048828125, 872.4000244140625, 976.4000244140625, 1016.4000244140625,
+              1120.4000244140625, 859.5, 947.5, 594.5, 682.5, 1075.5999755859375, 1179.5999755859375, 866.5999755859375,
+              970.5999755859375, 859.7000122070312, 947.7000122070312, 594.7000122070312, 682.7000122070312,
+              1075.800048828125, 1179.800048828125, 866.7999877929688, 970.7999877929688
+            ],
+            "dims": [1, 8, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
   {
     "name": "conv - vectorize group - A",
     "operator": "Conv",
```
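For sanity-checking expectations like the case above, a naive grouped conv over the flattened NCHW arrays is enough. Below is a throwaway reference sketch, not part of the test harness, assuming batch 1, stride 1, no padding, and no dilation.

```ts
// x: [1, C, H, W], w: [M, C / group, kH, kW], b: [M] -> y: [1, M, oH, oW]
function refGroupedConv(
  x: number[], xDims: number[],
  w: number[], wDims: number[],
  b: number[], group: number,
): number[] {
  const [, , H, W] = xDims;
  const [M, Cg, kH, kW] = wDims;
  const [oH, oW] = [H - kH + 1, W - kW + 1]; // stride 1, no pads/dilations
  const mPerGroup = M / group;
  const y: number[] = [];
  for (let m = 0; m < M; m++) {
    const icBase = Math.floor(m / mPerGroup) * Cg; // first input channel of m's group
    for (let oh = 0; oh < oH; oh++) {
      for (let ow = 0; ow < oW; ow++) {
        let v = b[m];
        for (let c = 0; c < Cg; c++) {
          for (let kh = 0; kh < kH; kh++) {
            for (let kw = 0; kw < kW; kw++) {
              v +=
                x[((icBase + c) * H + (oh + kh)) * W + (ow + kw)] *
                w[((m * Cg + c) * kH + kh) * kW + kw];
            }
          }
        }
        y.push(v);
      }
    }
  }
  return y;
}

// With the tensors from the case above (group = 2), this should reproduce the
// listed [1, 8, 2, 2] outputs up to float32 rounding.
```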