[js/webgpu] Add activation for conv3d naive (#21466)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Xu Xing 2024-07-29 23:47:41 +08:00 коммит произвёл GitHub
Родитель dbff0cd098
Коммит 5bc12bf209
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 149 добавлений и 28 удалений

Просмотреть файл

@ -26,6 +26,9 @@ import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
import {typeSnippet} from './activation_util';
const arrayProduct = (arr: number[]) => {
let product = 1;
@ -218,8 +221,8 @@ export const computeConv3DInfo =
export const createConv3DNaiveProgramInfo =
(inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[],
filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => {
const isChannelsLast = dataFormat === 'channelsLast';
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
const isChannelLast = dataFormat === 'channelsLast';
const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1];
// TODO: enable vec4.
const isVec4 = false;
const workGroupSize: [number, number, number] = [64, 1, 1];
@ -228,13 +231,14 @@ export const createConv3DNaiveProgramInfo =
LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`);
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims},
{type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides},
{type: DataType.uint32, data: attributes.dilations}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const hasBias = inputs.length === 3;
@ -251,6 +255,7 @@ export const createConv3DNaiveProgramInfo =
{name: 'strides', type: 'u32', length: attributes.strides.length},
{name: 'dilations', type: 'u32', length: attributes.dilations.length}
];
appendActivationUniforms(attributes, uniforms);
// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
@ -266,10 +271,12 @@ export const createConv3DNaiveProgramInfo =
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : array<u32, 5>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
isVec4 ? '/ 4' : ''}];
}`;
}
const resType = typeSnippet(innerElementSize, t);
const applyActivation = getActivationSnippet(attributes, resType, t);
return `
${declareFunctions}
@ -287,28 +294,28 @@ export const createConv3DNaiveProgramInfo =
let coords = ${output.offsetToIndices('global_idx')};
let batch = ${getElementAt('coords', 0, x.rank)};
let d2 = ${
isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
let xFRCCorner = vec3<u32>(${
isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
${
isChannelsLast ? getElementAt('coords', 3, x.rank) :
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
isChannelLast ? getElementAt('coords', 3, x.rank) :
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
let xFCorner = xFRCCorner.x;
let xRCorner = xFRCCorner.y;
let xCCorner = xFRCCorner.z;
let xShapeY = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)};
let xShapeZ = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)};
let xShapeW = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)};
let xShapeU = ${
isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)};
isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)};
let inputDepthNearestVec4 = (xShapeU / 4) * 4;
let inputDepthVec4Remainder = xShapeU % 4;
var dotProd = 0.0;
var value = 0.0;
for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {
let xF = xFCorner + wF * uniforms.dilations[0];
if (xF < 0 || xF >= xShapeY) {
@ -329,13 +336,13 @@ export const createConv3DNaiveProgramInfo =
for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) {
${
isChannelsLast ? `let xValues = vec4<f32>(
isChannelLast ? `let xValues = vec4<f32>(
getX(batch, xF, xR, xC, d1),
getX(batch, xF, xR, xC, d1 + 1),
getX(batch, xF, xR, xC, d1 + 2),
getX(batch, xF, xR, xC, d1 + 3));
` :
`let xValues = vec4<f32>(
`let xValues = vec4<f32>(
getX(batch, d1, xF, xR, xC),
getX(batch, d1 + 1, xF, xR, xC),
getX(batch, d1 + 2, xF, xR, xC),
@ -346,36 +353,36 @@ export const createConv3DNaiveProgramInfo =
getW(d2, d1 + 1, wF, wR, wC),
getW(d2, d1 + 2, wF, wR, wC),
getW(d2, d1 + 3, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
}
if (inputDepthVec4Remainder == 1) {
${
isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4)
isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);` :
`dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC)
`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`}
} else if (inputDepthVec4Remainder == 2) {
${
isChannelsLast ? `let xValues = vec2<f32>(
isChannelLast ? `let xValues = vec2<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1));
` :
`let xValues = vec2<f32>(
`let xValues = vec2<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC));
`}
let wValues = vec2<f32>(
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
} else if (inputDepthVec4Remainder == 3) {
${
isChannelsLast ? `let xValues = vec3<f32>(
isChannelLast ? `let xValues = vec3<f32>(
getX(batch, xF, xR, xC, inputDepthNearestVec4),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1),
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2));
` :
`let xValues = vec3<f32>(
`let xValues = vec3<f32>(
getX(batch, inputDepthNearestVec4, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC),
getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC));
@ -384,19 +391,20 @@ export const createConv3DNaiveProgramInfo =
getW(d2, inputDepthNearestVec4, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC),
getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC));
dotProd += dot(xValues, wValues);
value += dot(xValues, wValues);
}
}
}
}
${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''};
result[global_idx] = f32(dotProd);
${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''};
${applyActivation}
result[global_idx] = f32(value);
}`;
};
return {
name: 'Conv3DNaive',
shaderCache:
{hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies},
{hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},

Просмотреть файл

@ -0,0 +1,112 @@
[
{
"name": "fused conv3d with relu, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu",
"operator": "FusedConv",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
{ "name": "auto_pad", "data": "VALID", "type": "string" },
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [0.25, 0.5, 0.75, 1],
"dims": [1, 1, 2, 1, 2],
"type": "float32"
},
{
"data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, -0.875, -1],
"dims": [2, 1, 2, 1, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.0625, 0],
"dims": [1, 2, 1, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "fused conv3d with clip",
"operator": "FusedConv",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "activation", "data": "Clip", "type": "string" },
{ "name": "activation_params", "data": [1.0, 3.0], "type": "floats" },
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
{ "name": "auto_pad", "data": "VALID", "type": "string" },
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [0.25, 0.5, 0.75, 1],
"dims": [1, 1, 2, 1, 2],
"type": "float32"
},
{
"data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1],
"dims": [2, 1, 2, 1, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2.1875],
"dims": [1, 2, 1, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "fused conv3d with HardSigmoid, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu",
"operator": "FusedConv",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "activation", "data": "HardSigmoid", "type": "string" },
{ "name": "activation_params", "data": [0.1, 0.3], "type": "floats" },
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
{ "name": "auto_pad", "data": "VALID", "type": "string" },
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [0.25, 0.5, 0.75, 1],
"dims": [1, 1, 2, 1, 2],
"type": "float32"
},
{
"data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1],
"dims": [2, 1, 2, 1, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.39375001192092896, 0.518750011920929],
"dims": [1, 2, 1, 1, 1],
"type": "float32"
}
]
}
]
}
]

Просмотреть файл

@ -1358,6 +1358,7 @@
"fast-gelu.jsonc",
"floor.jsonc",
"fused-conv.jsonc",
"fused-conv3dncdhw.jsonc",
"gather-elements.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",