[js/webgpu] Add activation for conv3d naive (#21466)
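### Description <!-- Describe your changes. --> Adds fused-activation support to the naive WebGPU Conv3D program: the activation uniforms are appended to the program uniforms, and the activation snippet is applied to the accumulated value (after the optional bias add) before it is written to the output. Also adds the `fused-conv3dncdhw.jsonc` operator tests covering Relu, Clip and HardSigmoid. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Before this change the naive Conv3D shader wrote the raw convolution result to the output without applying the activation carried in `ConvAttributes`.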
This commit is contained in:
Родитель
dbff0cd098
Коммит
5bc12bf209
|
@ -26,6 +26,9 @@ import {ShapeUtil} from '../../../util';
|
|||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
|
||||
import {ConvAttributes} from '../conv';
|
||||
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
|
||||
|
||||
import {typeSnippet} from './activation_util';
|
||||
|
||||
const arrayProduct = (arr: number[]) => {
|
||||
let product = 1;
|
||||
|
@ -218,8 +221,8 @@ export const computeConv3DInfo =
|
|||
export const createConv3DNaiveProgramInfo =
|
||||
(inputs: readonly TensorView[], attributes: ConvAttributes, outputShape: readonly number[],
|
||||
filterDims: readonly number[], pads: readonly number[], dataFormat: string): ProgramInfo => {
|
||||
const isChannelsLast = dataFormat === 'channelsLast';
|
||||
const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1];
|
||||
const isChannelLast = dataFormat === 'channelsLast';
|
||||
const inChannels = isChannelLast ? inputs[0].dims[3] : inputs[0].dims[1];
|
||||
// TODO: enable vec4.
|
||||
const isVec4 = false;
|
||||
const workGroupSize: [number, number, number] = [64, 1, 1];
|
||||
|
@ -228,13 +231,14 @@ export const createConv3DNaiveProgramInfo =
|
|||
|
||||
LOG_DEBUG('verbose', () => `[conv3d_naive_webgpu] dispatch = ${dispatch}`);
|
||||
|
||||
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
|
||||
const innerElementSize = isVec4 ? (isChannelLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: filterDims},
|
||||
{type: DataType.uint32, data: pads}, {type: DataType.uint32, data: attributes.strides},
|
||||
{type: DataType.uint32, data: attributes.dilations}
|
||||
];
|
||||
appendActivationUniformsData(attributes, programUniforms);
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
const hasBias = inputs.length === 3;
|
||||
|
@ -251,6 +255,7 @@ export const createConv3DNaiveProgramInfo =
|
|||
{name: 'strides', type: 'u32', length: attributes.strides.length},
|
||||
{name: 'dilations', type: 'u32', length: attributes.dilations.length}
|
||||
];
|
||||
appendActivationUniforms(attributes, uniforms);
|
||||
// TODO: support component 2, 3.
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const t = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
@ -266,10 +271,12 @@ export const createConv3DNaiveProgramInfo =
|
|||
inputVariables.push(bias);
|
||||
declareFunctions += `
|
||||
fn getBiasByOutputCoords(coords : array<u32, 5>) -> ${isVec4 ? `vec4<${t}>` : t} {
|
||||
return bias[${isChannelsLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
|
||||
return bias[${isChannelLast ? getElementAt('coords', 4, 5) : getElementAt('coords', 1, 5)}${
|
||||
isVec4 ? '/ 4' : ''}];
|
||||
}`;
|
||||
}
|
||||
const resType = typeSnippet(innerElementSize, t);
|
||||
const applyActivation = getActivationSnippet(attributes, resType, t);
|
||||
|
||||
return `
|
||||
${declareFunctions}
|
||||
|
@ -287,28 +294,28 @@ export const createConv3DNaiveProgramInfo =
|
|||
let coords = ${output.offsetToIndices('global_idx')};
|
||||
let batch = ${getElementAt('coords', 0, x.rank)};
|
||||
let d2 = ${
|
||||
isChannelsLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
|
||||
isChannelLast ? getElementAt('coords', x.rank - 1, x.rank) : getElementAt('coords', 1, x.rank)};
|
||||
let xFRCCorner = vec3<u32>(${
|
||||
isChannelsLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
|
||||
${isChannelsLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
|
||||
isChannelLast ? getElementAt('coords', 1, x.rank) : getElementAt('coords', 2, x.rank)},
|
||||
${isChannelLast ? getElementAt('coords', 2, x.rank) : getElementAt('coords', 3, x.rank)},
|
||||
${
|
||||
isChannelsLast ? getElementAt('coords', 3, x.rank) :
|
||||
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
|
||||
isChannelLast ? getElementAt('coords', 3, x.rank) :
|
||||
getElementAt('coords', 4, x.rank)}) * uniforms.strides - uniforms.pads;
|
||||
let xFCorner = xFRCCorner.x;
|
||||
let xRCorner = xFRCCorner.y;
|
||||
let xCCorner = xFRCCorner.z;
|
||||
let xShapeY = ${
|
||||
isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)};
|
||||
isChannelLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)};
|
||||
let xShapeZ = ${
|
||||
isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)};
|
||||
isChannelLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)};
|
||||
let xShapeW = ${
|
||||
isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)};
|
||||
isChannelLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)};
|
||||
let xShapeU = ${
|
||||
isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)};
|
||||
isChannelLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)};
|
||||
let inputDepthNearestVec4 = (xShapeU / 4) * 4;
|
||||
let inputDepthVec4Remainder = xShapeU % 4;
|
||||
|
||||
var dotProd = 0.0;
|
||||
var value = 0.0;
|
||||
for (var wF = 0u; wF < uniforms.filter_dims[0]; wF++) {
|
||||
let xF = xFCorner + wF * uniforms.dilations[0];
|
||||
if (xF < 0 || xF >= xShapeY) {
|
||||
|
@ -329,13 +336,13 @@ export const createConv3DNaiveProgramInfo =
|
|||
|
||||
for (var d1 = 0u; d1 < inputDepthNearestVec4; d1 += 4) {
|
||||
${
|
||||
isChannelsLast ? `let xValues = vec4<f32>(
|
||||
isChannelLast ? `let xValues = vec4<f32>(
|
||||
getX(batch, xF, xR, xC, d1),
|
||||
getX(batch, xF, xR, xC, d1 + 1),
|
||||
getX(batch, xF, xR, xC, d1 + 2),
|
||||
getX(batch, xF, xR, xC, d1 + 3));
|
||||
` :
|
||||
`let xValues = vec4<f32>(
|
||||
`let xValues = vec4<f32>(
|
||||
getX(batch, d1, xF, xR, xC),
|
||||
getX(batch, d1 + 1, xF, xR, xC),
|
||||
getX(batch, d1 + 2, xF, xR, xC),
|
||||
|
@ -346,36 +353,36 @@ export const createConv3DNaiveProgramInfo =
|
|||
getW(d2, d1 + 1, wF, wR, wC),
|
||||
getW(d2, d1 + 2, wF, wR, wC),
|
||||
getW(d2, d1 + 3, wF, wR, wC));
|
||||
dotProd += dot(xValues, wValues);
|
||||
value += dot(xValues, wValues);
|
||||
}
|
||||
if (inputDepthVec4Remainder == 1) {
|
||||
${
|
||||
isChannelsLast ? `dotProd += getX(batch, xF, xR, xC, inputDepthNearestVec4)
|
||||
isChannelLast ? `value += getX(batch, xF, xR, xC, inputDepthNearestVec4)
|
||||
* getW(d2, inputDepthNearestVec4, wF, wR, wC);` :
|
||||
`dotProd += getX(batch, inputDepthNearestVec4, xF, xR, xC)
|
||||
`value += getX(batch, inputDepthNearestVec4, xF, xR, xC)
|
||||
* getW(d2, inputDepthNearestVec4, wF, wR, wC);`}
|
||||
} else if (inputDepthVec4Remainder == 2) {
|
||||
${
|
||||
isChannelsLast ? `let xValues = vec2<f32>(
|
||||
isChannelLast ? `let xValues = vec2<f32>(
|
||||
getX(batch, xF, xR, xC, inputDepthNearestVec4),
|
||||
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1));
|
||||
` :
|
||||
`let xValues = vec2<f32>(
|
||||
`let xValues = vec2<f32>(
|
||||
getX(batch, inputDepthNearestVec4, xF, xR, xC),
|
||||
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC));
|
||||
`}
|
||||
let wValues = vec2<f32>(
|
||||
getW(d2, inputDepthNearestVec4, wF, wR, wC),
|
||||
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC));
|
||||
dotProd += dot(xValues, wValues);
|
||||
value += dot(xValues, wValues);
|
||||
} else if (inputDepthVec4Remainder == 3) {
|
||||
${
|
||||
isChannelsLast ? `let xValues = vec3<f32>(
|
||||
isChannelLast ? `let xValues = vec3<f32>(
|
||||
getX(batch, xF, xR, xC, inputDepthNearestVec4),
|
||||
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 1),
|
||||
getX(batch, xF, xR, xC, inputDepthNearestVec4 + 2));
|
||||
` :
|
||||
`let xValues = vec3<f32>(
|
||||
`let xValues = vec3<f32>(
|
||||
getX(batch, inputDepthNearestVec4, xF, xR, xC),
|
||||
getX(batch, inputDepthNearestVec4 + 1, xF, xR, xC),
|
||||
getX(batch, inputDepthNearestVec4 + 2, xF, xR, xC));
|
||||
|
@ -384,19 +391,20 @@ export const createConv3DNaiveProgramInfo =
|
|||
getW(d2, inputDepthNearestVec4, wF, wR, wC),
|
||||
getW(d2, inputDepthNearestVec4 + 1, wF, wR, wC),
|
||||
getW(d2, inputDepthNearestVec4 + 2, wF, wR, wC));
|
||||
dotProd += dot(xValues, wValues);
|
||||
value += dot(xValues, wValues);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
${hasBias ? 'dotProd = dotProd + getBiasByOutputCoords(coords)' : ''};
|
||||
result[global_idx] = f32(dotProd);
|
||||
${hasBias ? 'value = value + getBiasByOutputCoords(coords)' : ''};
|
||||
${applyActivation}
|
||||
result[global_idx] = f32(value);
|
||||
}`;
|
||||
};
|
||||
return {
|
||||
name: 'Conv3DNaive',
|
||||
shaderCache:
|
||||
{hint: `${attributes.cacheKey};${isChannelsLast};${innerElementSize};${hasBias}`, inputDependencies},
|
||||
{hint: `${attributes.cacheKey};${isChannelLast};${innerElementSize};${hasBias}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
[
|
||||
{
|
||||
"name": "fused conv3d with relu, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, relu",
|
||||
"operator": "FusedConv",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "activation", "data": "Relu", "type": "string" },
|
||||
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
|
||||
{ "name": "auto_pad", "data": "VALID", "type": "string" },
|
||||
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
|
||||
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.25, 0.5, 0.75, 1],
|
||||
"dims": [1, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [-0.125, -0.25, -0.375, 0.5, 0.625, -0.75, -0.875, -1],
|
||||
"dims": [2, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.0625, 0],
|
||||
"dims": [1, 2, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "fused conv3d with clip",
|
||||
"operator": "FusedConv",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "activation", "data": "Clip", "type": "string" },
|
||||
{ "name": "activation_params", "data": [1.0, 3.0], "type": "floats" },
|
||||
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
|
||||
{ "name": "auto_pad", "data": "VALID", "type": "string" },
|
||||
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
|
||||
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.25, 0.5, 0.75, 1],
|
||||
"dims": [1, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1],
|
||||
"dims": [2, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2.1875],
|
||||
"dims": [1, 2, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "fused conv3d with HardSigmoid, x=[1, 1, 2, 1, 2], f=[2, 1, 2, 1, 2], s=1, d=1, p=valid, hardsigmoid",
|
||||
"operator": "FusedConv",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "activation", "data": "HardSigmoid", "type": "string" },
|
||||
{ "name": "activation_params", "data": [0.1, 0.3], "type": "floats" },
|
||||
{ "name": "kernel_shape", "data": [2, 1, 2], "type": "ints" },
|
||||
{ "name": "auto_pad", "data": "VALID", "type": "string" },
|
||||
{ "name": "strides", "data": [1, 1, 1], "type": "ints" },
|
||||
{ "name": "dilations", "data": [1, 1, 1], "type": "ints" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.25, 0.5, 0.75, 1],
|
||||
"dims": [1, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1],
|
||||
"dims": [2, 1, 2, 1, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.39375001192092896, 0.518750011920929],
|
||||
"dims": [1, 2, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
|
@ -1358,6 +1358,7 @@
|
|||
"fast-gelu.jsonc",
|
||||
"floor.jsonc",
|
||||
"fused-conv.jsonc",
|
||||
"fused-conv3dncdhw.jsonc",
|
||||
"gather-elements.jsonc",
|
||||
"gemm.jsonc",
|
||||
"global-average-pool.jsonc",
|
||||
|
|
Загрузка…
Ссылка в новой задаче