[js/webgpu] Optimize grouped conv (#21892)

### Description
#21618

This PR optimizes grouped conv by 1) making GPU memory accesses more sequential and 2) reusing fetched input data to reduce the number of global memory accesses.
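The reuse in 2) pairs with a vectorization of the output channel axis. A minimal TypeScript sketch of the idea (illustrative only, not the actual WGSL; it assumes NHWC layout and an output channel count divisible by 4): each shader invocation now produces four adjacent output channels, so every input value fetched from global memory feeds four multiply-accumulates instead of one.

```ts
// Illustrative sketch: accumulation for one output pixel of a grouped conv.
// `x` holds the input values the kernel window touches; `w` holds matching weights.

// Before: one output channel per invocation; x is re-fetched for every channel.
function scalarAccumulate(x: Float32Array, w: Float32Array): number {
  let value = 0;
  for (let k = 0; k < x.length; ++k) {
    value += x[k] * w[k]; // one global read of x per multiply
  }
  return value;
}

// After: four adjacent output channels per invocation; each x[k] is fetched once
// and reused four times, and the weights are read as contiguous vec4-sized chunks.
function vec4Accumulate(x: Float32Array, w: Float32Array): Float32Array {
  const value = new Float32Array(4);
  for (let k = 0; k < x.length; ++k) {
    const xVal = x[k]; // single global read, reused for 4 output channels
    for (let c = 0; c < 4; ++c) {
      value[c] += xVal * w[k * 4 + c];
    }
  }
  return value;
}
```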

With this change, the `Conv|GroupedConv` op in
[Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base-960h) drops from
1058 ms to 92 ms on an iGPU with 32 EUs.

For whole models on the same iGPU with 32 EUs:
- wav2vec2: 1942 ms → 982 ms
- squeezebert-uncased: 431.77 ms → 71.86 ms


Commit 3580e01348 (parent 30f07758a2), authored by Jiajia Qin on 2024-09-05 08:16:35 +08:00 and committed via GitHub.
3 changed files: 156 additions and 78 deletions.

#### js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts

```diff
@@ -15,7 +15,7 @@ import {
   tensorTypeToWsglStorageType,
   UniformsArrayType,
 } from './common';
-import { calculateOutputShape, ConvAttributes } from './conv';
+import { ConvAttributes } from './conv';
 import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from './fuse-utils';

 /**
```
```diff
@@ -25,24 +25,19 @@ import { appendActivationUniforms, appendActivationUniformsData, getActivationSnippet } from './fuse-utils';
 export const createGroupedConvProgramInfo = (
   inputs: readonly TensorView[],
   attributes: ConvAttributes,
+  outputShape: readonly number[],
   squeezeOutputShapeFunction?: (shape: readonly number[]) => number[],
 ): ProgramInfo => {
   const hasBias = inputs.length > 2;
   const processBias = hasBias ? 'value += b[output_channel];' : '';
   const xShape = inputs[0].dims;
   const wShape = inputs[1].dims;
-  const outputChannelsPerGroup = wShape[0] / attributes.group;
   const isChannelLast = attributes.format === 'NHWC';
-  const outputShape = calculateOutputShape(
-    xShape,
-    wShape,
-    attributes.dilations,
-    attributes.pads,
-    attributes.strides,
-    isChannelLast,
-  );
-  const outputSize = ShapeUtil.size(outputShape);
+  const outputChannels = isChannelLast ? outputShape[3] : outputShape[1];
+  const outputChannelsPerGroup = outputChannels / attributes.group;
+  const components = isChannelLast && outputChannelsPerGroup >= 4 ? getMaxComponents(outputChannels) : 1;
+  const outputSize = ShapeUtil.size(outputShape) / components;

   const programUniforms: ProgramUniform[] = [
     { type: DataType.uint32, data: outputSize },
```
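`components` is chosen by `getMaxComponents` from `./common`. Paraphrased below as a sketch of its contract (an assumption, not the exact source): return the widest WGSL vector width that evenly divides the size.

```ts
// Sketch (assumed behavior) of getMaxComponents from ./common:
// pick the widest vector width (4, 2, or 1) that evenly divides `size`.
const getMaxComponentsSketch = (size: number): 1 | 2 | 4 => {
  if (size % 4 === 0) {
    return 4;
  }
  return size % 2 === 0 ? 2 : 1;
};

// With NHWC and outputChannelsPerGroup >= 4, e.g. 32 output channels vectorize as vec4:
console.assert(getMaxComponentsSketch(32) === 4);
```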
```diff
@@ -52,23 +47,23 @@ export const createGroupedConvProgramInfo = (
     { type: DataType.uint32, data: outputChannelsPerGroup },
   ];
   appendActivationUniformsData(attributes, programUniforms);
-  programUniforms.push(...createTensorShapeVariables(xShape, wShape));
-  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
+  programUniforms.push(
+    ...createTensorShapeVariables(xShape, [wShape[0], wShape[1], wShape[2], wShape[3] / components]),
+  );
+  const inputDependencies: ProgramInputTensorInfoDependency[] = hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank'];
   if (hasBias) {
     programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
-    inputDependencies.push('rank');
   }
-  programUniforms.push(...createTensorShapeVariables(outputShape));
+  programUniforms.push(
+    ...createTensorShapeVariables([outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]),
+  );

   const getShaderSource = (shaderHelper: ShaderHelper) => {
-    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
+    const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
     const baseType = tensorTypeToWsglStorageType(output.type.tensor);
     const applyActivation = getActivationSnippet(attributes, output.type.value, baseType);
     const x = inputVariable('x', inputs[0].dataType, xShape.length);
-    const w = inputVariable('w', inputs[1].dataType, wShape.length);
+    const w = inputVariable('w', inputs[1].dataType, wShape.length, components);
     const inputVars = [x, w];
     if (hasBias) {
-      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims.length));
+      inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components));
     }

     const uniforms: UniformsArrayType = [
```
```diff
@@ -79,6 +74,54 @@ export const createGroupedConvProgramInfo = (
       { name: 'output_channels_per_group', type: 'u32' },
     ];
     appendActivationUniforms(attributes, uniforms);
+
+    const calculateResult = isChannelLast
+      ? `
+      for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) {
+        let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+        if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) {
+          continue;
+        }
+        for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) {
+          let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+          if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) {
+            continue;
+          }
+          for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) {
+            let input_channel = in_channel_offset + wInChannel;
+            let xVal = ${x.get('batch', 'xHeight', 'xWidth', 'input_channel')};
+            let wVal = ${w.get('wHeight', 'wWidth', 'wInChannel', 'output_channel')};
+            value += xVal * wVal;
+          }
+        }
+      }
+      `
+      : `
+      for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
+        let input_channel = in_channel_offset + wInChannel;
+        for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
+          let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
+          if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) {
+            continue;
+          }
+          for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
+            let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
+            if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) {
+              continue;
+            }
+            let xVal = ${x.get('batch', 'input_channel', 'xHeight', 'xWidth')};
+            let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
+            value += xVal * wVal;
+          }
+        }
+      }
+      `;
+
     return `
   ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
```
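In the NHWC branch above, the innermost loop now walks the input channel, which is the fastest-varying axis of a flattened NHWC buffer; this is what optimization 1) means by sequential access, since consecutive iterations read consecutive addresses. A tiny sketch of that flattening arithmetic (the helper is hypothetical, for illustration):

```ts
// Hypothetical helper: linear offset of x[n, h, w, c] in a flattened NHWC buffer.
const nhwcOffset = (n: number, h: number, w: number, c: number, H: number, W: number, C: number): number =>
  ((n * H + h) * W + w) * C + c;

// Fixing (n, h, w) and stepping c (the inner wInChannel loop above) touches
// adjacent addresses, i.e. sequential global memory reads:
console.log([0, 1, 2, 3].map((c) => nhwcOffset(0, 1, 1, c, 3, 3, 8))); // [32, 33, 34, 35]
```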
```diff
@@ -91,34 +134,11 @@ export const createGroupedConvProgramInfo = (
       let xRCCorner: vec2<u32> = vec2<u32>(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${
         isChannelLast ? 2 : 3
       }]) * uniforms.strides - uniforms.pads;
-      let group_id: u32 = output_channel / uniforms.output_channels_per_group;
+      let group_id: u32 = output_channel * ${components} / uniforms.output_channels_per_group;
+      var in_channel_offset = group_id * uniforms.w_shape[${isChannelLast ? 2 : 1}];

       var value: ${output.type.value} = ${output.type.value}(0);
-      for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {
-        let input_channel = group_id * uniforms.w_shape[1] + wInChannel;
-        for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {
-          let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];
-          if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) {
-            continue;
-          }
-          for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {
-            let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];
-            if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) {
-              continue;
-            }
-            let xVal = ${
-              isChannelLast
-                ? x.get('batch', 'xHeight', 'xWidth', 'input_channel')
-                : x.get('batch', 'input_channel', 'xHeight', 'xWidth')
-            };
-            let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')};
-            value += xVal*wVal;
-          }
-        }
-      }
+      ${calculateResult}
       ${processBias}
       ${applyActivation}
       ${output.setByOffset('global_idx', 'value')}
```
```diff
@@ -126,7 +146,7 @@ export const createGroupedConvProgramInfo = (
   };
   return {
     name: 'GroupedConv',
-    shaderCache: { hint: attributes.cacheKey, inputDependencies },
+    shaderCache: { hint: `${attributes.cacheKey}_${components}`, inputDependencies },
     getRunData: () => ({
       outputs: [
         {
```

#### js/web/lib/wasm/jsep/webgpu/ops/conv.ts

```diff
@@ -162,7 +162,33 @@ const conv2d = (
   // const hasPreluActivationWeights = false; /* TODO: add support for prelu activation weights */
   const isChannelsLast = attributes.format === 'NHWC';
+  const outputShape = calculateOutputShape(
+    inputs[0].dims,
+    inputs[1].dims,
+    attributes.dilations,
+    attributes.pads,
+    attributes.strides,
+    isChannelsLast,
+  );
   if (attributes.group !== 1) {
+    const convInputs = [inputs[0]];
+    if (isChannelsLast) {
+      const transposedWeight =
+        (context.kernelCustomData.wT as TensorView | undefined) ??
+        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
+          inputs: [1],
+          outputs: [attributes.wIsConst ? -2 : -1],
+        })[0];
+      if (attributes.wIsConst && !context.kernelCustomData.wT) {
+        context.kernelCustomData.wT = transposedWeight;
+      }
+      convInputs.push(transposedWeight);
+    } else {
+      convInputs.push(inputs[1]);
+    }
+    if (inputs.length === 3) {
+      convInputs.push(inputs[2]);
+    }
     // NVIDIA GPU with ampere architecture fails with below 2 cases, but we couldn't repro them with any other
     // GPUs. So just disable vectorize on NVIDIA ampere to ensure always correct outputs.
     // [webgpu]Conv - conv - vectorize group - B
```
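For NHWC, the weight is now transposed once up front (and cached in `kernelCustomData.wT` when it is a constant initializer) so that `output_channel` ends up as the innermost weight axis, matching the shader's `w.get('wHeight', 'wWidth', 'wInChannel', 'output_channel')` access. `weightTransposeAttribute` is defined elsewhere in this file; assuming its perm is `[2, 3, 1, 0]` (OIHW to HWIO), the shape mapping works out like this:

```ts
// Assumed perm of weightTransposeAttribute: [2, 3, 1, 0], i.e. OIHW -> HWIO.
const perm = [2, 3, 1, 0];
const oihw = [8, 4, 2, 2]; // [outChannels, inChannelsPerGroup, kH, kW]
const hwio = perm.map((axis) => oihw[axis]);
console.log(hwio); // [2, 2, 4, 8]: output channels innermost, contiguous for vec4 loads
```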
```diff
@@ -176,33 +202,14 @@ const conv2d = (
       attributes.dilations[0] === 1 &&
       attributes.dilations[1] === 1
     ) {
-      const outputShape = calculateOutputShape(
-        inputs[0].dims,
-        inputs[1].dims,
-        attributes.dilations,
-        attributes.pads,
-        attributes.strides,
-        isChannelsLast,
-      );
-      const transposedWeight =
-        (context.kernelCustomData.wT as TensorView | undefined) ??
-        context.compute(createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {
-          inputs: [1],
-          outputs: [attributes.wIsConst ? -2 : -1],
-        })[0];
-      if (attributes.wIsConst && !context.kernelCustomData.wT) {
-        context.kernelCustomData.wT = transposedWeight;
-      }
-      const convInputs = [inputs[0], transposedWeight];
-      if (inputs.length === 3) {
-        convInputs.push(inputs[2]);
-      }
       context.compute(
         createGroupedConvVectorizeProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction),
         { inputs: convInputs },
       );
     } else {
-      context.compute(createGroupedConvProgramInfo(inputs, attributes, squeezeOutputShapeFunction));
+      context.compute(createGroupedConvProgramInfo(convInputs, attributes, outputShape, squeezeOutputShapeFunction), {
+        inputs: convInputs,
+      });
     }
     return;
   }
```
```diff
@@ -214,14 +221,6 @@ const conv2d = (
   const weightHeight = inputs[1].dims[2];
   const weightWidth = inputs[1].dims[3];

-  const outputShape = calculateOutputShape(
-    inputs[0].dims,
-    inputs[1].dims,
-    attributes.dilations,
-    attributes.pads,
-    attributes.strides,
-    isChannelsLast,
-  );
   const outHeight = outputShape[isChannelsLast ? 1 : 2];
   const outWidth = outputShape[isChannelsLast ? 2 : 3];
   const outChannels = outputShape[isChannelsLast ? 3 : 1];
```

#### js/web/test/data/ops/conv.jsonc

```diff
@@ -298,6 +298,65 @@
       }
     ]
   },
+  {
+    "name": "conv - group - NHWC",
+    "operator": "Conv",
+    "inputShapeDefinitions": "rankOnly",
+    "opset": { "domain": "", "version": 17 },
+    "attributes": [
+      { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+      { "name": "group", "data": 2, "type": "int" }
+    ],
+    "cases": [
+      {
+        "name": "T[0]",
+        "inputs": [
+          {
+            "data": [
+              0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
+              19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
+              14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
+              10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0
+            ],
+            "dims": [1, 8, 3, 3],
+            "type": "float32"
+          },
+          {
+            "data": [
+              1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+              9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+              9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0,
+              5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+              1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0
+            ],
+            "dims": [8, 4, 2, 2],
+            "type": "float32"
+          },
+          {
+            "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
+            "dims": [8],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              1224.0999755859375, 1312.0999755859375, 936.0999755859375, 1024.0999755859375, 872.2000122070312,
+              976.2000122070312, 1016.2000122070312, 1120.199951171875, 1224.300048828125, 1312.300048828125,
+              936.2999877929688, 1024.300048828125, 872.4000244140625, 976.4000244140625, 1016.4000244140625,
+              1120.4000244140625, 859.5, 947.5, 594.5, 682.5, 1075.5999755859375, 1179.5999755859375, 866.5999755859375,
+              970.5999755859375, 859.7000122070312, 947.7000122070312, 594.7000122070312, 682.7000122070312,
+              1075.800048828125, 1179.800048828125, 866.7999877929688, 970.7999877929688
+            ],
+            "dims": [1, 8, 2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
   {
     "name": "conv - vectorize group - A",
     "operator": "Conv",
```