[js/webgpu] Optimize transpose (#21964)
### Description <!-- Describe your changes. --> Fix bugs in previous implementation and add more situations to go the optimized path. Below situations will go to the optimized path. 1. 2d inputs or squeezed 2d inputs 2. channels last or channels first transpose. For example, channel last transpose: [1, 256, 512, 512] -> [1, 512, 512, 256] For this case, the transpose becomes [256, 512x512] -> [512x512, 256] ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> For SD Turbo demo, the total transpose time becomes 39.98ms from 122.09ms. And the correspnding percents becomes 3.89% from 11.05% in this demo. This PR will also help #21618, the total transpose time in that demo becomes 17.32 ms from 70.25 ms on my iGPUs.
This commit is contained in:
Родитель
190588bb64
Коммит
a80bfed5b4
|
@ -875,11 +875,12 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
|
||||
const globalIdxDefinition = is1DimensionDispatch
|
||||
? 'let global_idx = global_id.x; let local_idx = local_id.x;'
|
||||
: `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
|
||||
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
|
||||
workgroupSizeX * workgroupSizeY * workgroupSizeZ
|
||||
}u + local_idx;`;
|
||||
? `let global_idx = global_id.x;
|
||||
let local_idx = local_id.x;
|
||||
let workgroup_index = workgroup_id.x;`
|
||||
: `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
|
||||
workgroup_id.y * num_workgroups[0] + workgroup_id.x;
|
||||
let global_idx = workgroup_index * ${workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;
|
||||
|
||||
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
|
||||
fn main(${paramList}) {
|
||||
|
|
|
@ -36,33 +36,62 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
|
|||
return reverseFunc.join('\n');
|
||||
};
|
||||
|
||||
const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => {
|
||||
const newShape: number[] = [];
|
||||
const newPerm: number[] = [];
|
||||
for (let i = 0; i < shape.length; ++i) {
|
||||
if (shape[i] !== 1) {
|
||||
newShape.push(shape[i]);
|
||||
}
|
||||
if (shape[adjustedPerm[i]] !== 1) {
|
||||
newPerm.push(adjustedPerm[i]);
|
||||
}
|
||||
}
|
||||
return { newShape, newPerm };
|
||||
};
|
||||
|
||||
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
|
||||
const inputDataType = inputTensor.dataType;
|
||||
const inputRank = inputTensor.dims.length;
|
||||
const perm = getAdjustedPerm(inputRank, permAttr);
|
||||
const outputShape = getOutputShape(inputTensor.dims, perm);
|
||||
const output = outputVariable('output', inputDataType, outputShape.length);
|
||||
const input = inputVariable('a', inputDataType, inputRank);
|
||||
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
|
||||
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
|
||||
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
|
||||
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
|
||||
let newInputShape = useShared ? newShape : inputTensor.dims;
|
||||
let newOutputShape = outputShape;
|
||||
if (useShared) {
|
||||
newInputShape = channelsLast
|
||||
? [newShape[0], newShape[1] * newShape[2]]
|
||||
: channelsFirst
|
||||
? [newShape[0] * newShape[1], newShape[2]]
|
||||
: newShape;
|
||||
newOutputShape = [newInputShape[1], newInputShape[0]];
|
||||
}
|
||||
const input = inputVariable('a', inputDataType, newInputShape.length);
|
||||
const output = outputVariable('output', inputDataType, newOutputShape.length);
|
||||
const tileSize = 16;
|
||||
let getShaderSource;
|
||||
if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) {
|
||||
const wgslType = output.type.value;
|
||||
const workgroupSize: [number, number, number] = [16, 16, 1];
|
||||
if (useShared) {
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
var<workgroup> tile : array<array<${wgslType}, ${workgroupSize[0] + 1}>, ${workgroupSize[0]}>;
|
||||
${shaderHelper.mainStart(workgroupSize)}
|
||||
var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x;
|
||||
var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y;
|
||||
let width = uniforms.output_shape[0];
|
||||
let height = uniforms.output_shape[1];
|
||||
if (x < width && y < height) {
|
||||
tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')};
|
||||
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
|
||||
${shaderHelper.mainStart([tileSize, tileSize, 1])}
|
||||
let stride = (uniforms.output_shape[1] - 1) / ${tileSize} + 1;
|
||||
let workgroup_id_x = workgroup_index % stride;
|
||||
let workgroup_id_y = workgroup_index / stride;
|
||||
let input_col = workgroup_id_y * ${tileSize}u + local_id.x;
|
||||
let input_row = workgroup_id_x * ${tileSize}u + local_id.y;
|
||||
if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {
|
||||
tile[local_id.y][local_id.x] = ${input.getByIndices(`${input.type.indices}(input_row, input_col)`)};
|
||||
}
|
||||
workgroupBarrier();
|
||||
x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x;
|
||||
y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y;
|
||||
if (x < height && y < width) {
|
||||
${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')}
|
||||
|
||||
let output_col = workgroup_id_x * ${tileSize}u + local_id.x;
|
||||
let output_row = workgroup_id_y * ${tileSize}u + local_id.y;
|
||||
if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {
|
||||
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
|
||||
}
|
||||
}`;
|
||||
} else {
|
||||
|
@ -81,16 +110,18 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
}`;
|
||||
}
|
||||
return {
|
||||
name: 'Transpose',
|
||||
name: useShared ? 'TransposeShared' : 'Transpose',
|
||||
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
|
||||
getRunData: () => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
|
||||
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
dispatchGroup: useShared
|
||||
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
|
||||
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
programUniforms: [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
...createTensorShapeVariables(inputTensor.dims, outputShape),
|
||||
...createTensorShapeVariables(newInputShape, newOutputShape),
|
||||
],
|
||||
};
|
||||
},
|
||||
|
|
|
@ -167,6 +167,78 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose squeezed 2d - perms:[0, 2, 1, 3]",
|
||||
"operator": "Transpose",
|
||||
"attributes": [{ "name": "perm", "data": [0, 2, 1, 3], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1, 3 , 4, 1]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [1, 3, 4, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12],
|
||||
"dims": [1, 4, 3, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose 4D channelsFirst - perms:[0, 3, 1, 2]",
|
||||
"operator": "Transpose",
|
||||
"attributes": [{ "name": "perm", "data": [0, 3, 1, 2], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1, 2, 3, 4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
|
||||
"dims": [1, 2, 3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
|
||||
"dims": [1, 4, 2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose 4D channelsLast - perms:[0, 2, 3, 1]",
|
||||
"operator": "Transpose",
|
||||
"attributes": [{ "name": "perm", "data": [0, 2, 3, 1], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1, 2, 3, 4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
|
||||
"dims": [1, 2, 3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24],
|
||||
"dims": [1, 3, 4, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
|
||||
"operator": "Transpose",
|
||||
|
|
Загрузка…
Ссылка в новой задаче