### Description
<!-- Describe your changes. -->
Fix bugs in previous implementation and add more situations to go the
optimized path.

Below situations will go to the optimized path.
1. 2d inputs or squeezed 2d inputs
2. channels last or channels first transpose. For example, channel last
transpose: [1, 256, 512, 512] -> [1, 512, 512, 256]
For this case, the transpose becomes [256, 512x512] -> [512x512, 256]

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
For SD Turbo demo, the total transpose time becomes 39.98ms from
122.09ms. And the correspnding percents becomes 3.89% from 11.05% in
this demo.

This PR will also help #21618, the total transpose time in that demo
becomes 17.32 ms from 70.25 ms on my iGPUs.
This commit is contained in:
Jiajia Qin 2024-09-05 03:04:04 +08:00 коммит произвёл GitHub
Родитель 190588bb64
Коммит a80bfed5b4
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
3 изменённых файлов: 129 добавлений и 25 удалений

Просмотреть файл

@ -875,11 +875,12 @@ class ShaderHelperImpl implements ShaderHelper {
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch
? 'let global_idx = global_id.x; let local_idx = local_id.x;'
: `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
workgroupSizeX * workgroupSizeY * workgroupSizeZ
}u + local_idx;`;
? `let global_idx = global_id.x;
let local_idx = local_id.x;
let workgroup_index = workgroup_id.x;`
: `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x;
let global_idx = workgroup_index * ${workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
fn main(${paramList}) {

Просмотреть файл

@ -36,33 +36,62 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
return reverseFunc.join('\n');
};
const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => {
const newShape: number[] = [];
const newPerm: number[] = [];
for (let i = 0; i < shape.length; ++i) {
if (shape[i] !== 1) {
newShape.push(shape[i]);
}
if (shape[adjustedPerm[i]] !== 1) {
newPerm.push(adjustedPerm[i]);
}
}
return { newShape, newPerm };
};
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
const inputDataType = inputTensor.dataType;
const inputRank = inputTensor.dims.length;
const perm = getAdjustedPerm(inputRank, permAttr);
const outputShape = getOutputShape(inputTensor.dims, perm);
const output = outputVariable('output', inputDataType, outputShape.length);
const input = inputVariable('a', inputDataType, inputRank);
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
let newInputShape = useShared ? newShape : inputTensor.dims;
let newOutputShape = outputShape;
if (useShared) {
newInputShape = channelsLast
? [newShape[0], newShape[1] * newShape[2]]
: channelsFirst
? [newShape[0] * newShape[1], newShape[2]]
: newShape;
newOutputShape = [newInputShape[1], newInputShape[0]];
}
const input = inputVariable('a', inputDataType, newInputShape.length);
const output = outputVariable('output', inputDataType, newOutputShape.length);
const tileSize = 16;
let getShaderSource;
if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) {
const wgslType = output.type.value;
const workgroupSize: [number, number, number] = [16, 16, 1];
if (useShared) {
getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
var<workgroup> tile : array<array<${wgslType}, ${workgroupSize[0] + 1}>, ${workgroupSize[0]}>;
${shaderHelper.mainStart(workgroupSize)}
var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x;
var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y;
let width = uniforms.output_shape[0];
let height = uniforms.output_shape[1];
if (x < width && y < height) {
tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')};
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
${shaderHelper.mainStart([tileSize, tileSize, 1])}
let stride = (uniforms.output_shape[1] - 1) / ${tileSize} + 1;
let workgroup_id_x = workgroup_index % stride;
let workgroup_id_y = workgroup_index / stride;
let input_col = workgroup_id_y * ${tileSize}u + local_id.x;
let input_row = workgroup_id_x * ${tileSize}u + local_id.y;
if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {
tile[local_id.y][local_id.x] = ${input.getByIndices(`${input.type.indices}(input_row, input_col)`)};
}
workgroupBarrier();
x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x;
y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y;
if (x < height && y < width) {
${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')}
let output_col = workgroup_id_x * ${tileSize}u + local_id.x;
let output_row = workgroup_id_y * ${tileSize}u + local_id.y;
if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
}
}`;
} else {
@ -81,16 +110,18 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
}`;
}
return {
name: 'Transpose',
name: useShared ? 'TransposeShared' : 'Transpose',
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
dispatchGroup: useShared
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
...createTensorShapeVariables(inputTensor.dims, outputShape),
...createTensorShapeVariables(newInputShape, newOutputShape),
],
};
},

Просмотреть файл

@ -167,6 +167,78 @@
}
]
},
{
"name": "Transpose squeezed 2d - perms:[0, 2, 1, 3]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 2, 1, 3], "type": "ints" }],
"cases": [
{
"name": "T[1, 3 , 4, 1]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 3, 4, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12],
"dims": [1, 4, 3, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 4D channelsFirst - perms:[0, 3, 1, 2]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 3, 1, 2], "type": "ints" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
"dims": [1, 4, 2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 4D channelsLast - perms:[0, 2, 3, 1]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 2, 3, 1], "type": "ints" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24],
"dims": [1, 3, 4, 2],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
"operator": "Transpose",