Implement 2D tiled MatMulNBits specialized for prefill (#23058)
### Description

This change implements MatMulNBits (4-bit) with tiling for both A and B. This benefits prefill scenarios on Intel integrated GPUs, because every row of A is multiplied against the same shared set of rows of B, so tiles of B can be loaded once and reused across rows of A. This should improve core occupancy, and model_benchmark does indicate improvements for prefill. The same shader is not used for generation: when A has only a single row, the other threads in the workgroup go unused, which hurts performance.

```
-- Baseline run on an Alderlake GPU --
C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128

Prompt processing (time to first token):
        avg (us):       1.72338e+07
        avg (tokens/s): 29.0707  <<
        p50 (us):       1.72548e+07
        stddev (us):    57012.8
        n:              5 * 501 token(s)

Token generation:
        avg (us):       79227.5
        avg (tokens/s): 12.6219
        p50 (us):       79284.4
        stddev (us):    2109.72
        n:              635 * 1 token(s)

Token sampling:
        avg (us):       15.8198
        avg (tokens/s): 63211.8
        p50 (us):       14.3
        stddev (us):    8.67178
        n:              640 * 1 token(s)

E2E generation (entire generation loop):
        avg (ms):       27297.8
        p50 (ms):       27269.8
        stddev (ms):    89.4322
        n:              5

Peak working set size (bytes): 5490987008
WebGPU device lost (2): Device was destroyed.

-- With Prefill Optimization --
C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128

Prompt processing (time to first token):
        avg (us):       1.2135e+07
        avg (tokens/s): 41.2856  <<
        p50 (us):       1.21288e+07
        stddev (us):    21282.1
        n:              5 * 501 token(s)

Token generation:
        avg (us):       78945.3
        avg (tokens/s): 12.667
        p50 (us):       78900.7
        stddev (us):    2232.43
        n:              635 * 1 token(s)

Token sampling:
        avg (us):       20.5608
        avg (tokens/s): 48636.3
        p50 (us):       18.7
        stddev (us):    19.0409
        n:              640 * 1 token(s)

E2E generation (entire generation loop):
        avg (ms):       22163.8
        p50 (ms):       22160.1
        stddev (ms):    31.3122
        n:              5

Peak working set size (bytes): 5478862848
WebGPU device lost (2): Device was destroyed.
```
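For reference, the computation this kernel parallelizes can be sketched in scalar C++. This is a minimal illustration only, assuming `block_size == 32`, 4-bit weights with the default zero point of 8, and no zero-points input; the helper name and layout comments are mine, not part of the commit:

```
#include <cstdint>
#include <vector>

// Hypothetical scalar reference (not from the commit): Y = A * dequant(B)^T.
// Bq stores each row of B as K/8 uint32 words, 8 nibbles per word from the
// LSB; each 4-bit weight dequantizes as (q - 8) * scale with one scale per
// block of 32 weights.
std::vector<float> MatMulNBitsRef(const std::vector<float>& A,       // M x K
                                  const std::vector<uint32_t>& Bq,   // N x (K / 8)
                                  const std::vector<float>& scales,  // N x (K / 32)
                                  int M, int N, int K) {
  std::vector<float> Y(static_cast<size_t>(M) * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const uint32_t word = Bq[static_cast<size_t>(n) * (K / 8) + k / 8];
        const uint32_t q = (word >> ((k % 8) * 4)) & 0xFu;  // k-th 4-bit weight
        const float scale = scales[static_cast<size_t>(n) * (K / 32) + k / 32];
        acc += A[static_cast<size_t>(m) * K + k] * (static_cast<float>(q) - 8.0f) * scale;
      }
      Y[static_cast<size_t>(m) * N + n] = acc;
    }
  }
  return Y;
}

int main() {
  // 1x32 times 32x1: A = all ones, every nibble of B = 9, which dequantizes
  // to (9 - 8) * 0.5 = 0.5, so the single output is 32 * 0.5 = 16.
  std::vector<float> A(32, 1.0f);
  std::vector<uint32_t> Bq(4, 0x99999999u);
  std::vector<float> scales(1, 0.5f);
  return MatMulNBitsRef(A, Bq, scales, 1, 1, 32)[0] == 16.0f ? 0 : 1;
}
```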
Parent: d8de3c4096
Commit: 8800830a44
@@ -39,6 +39,7 @@ std::string QuantizedDataType(int components) {
  }
}

constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16;

}  // namespace

ONNX_OPERATOR_KERNEL_EX(
@@ -321,6 +322,121 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
  return Status::OK();
}

Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("scales", ShaderUsage::UseUniform);
  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
  // This shader uses uniforms with the M, N, K convention from traditional matrix multiplication:
  // M is the number of rows in A and in the output.
  // N is the number of columns in B and in the output.
  // K is the hidden/shared dimension: the number of columns in A and rows in B.
  // Note that in matmulnbits the B matrix is already transposed; nevertheless, in the shader
  // below M describes A, N describes B, and K is the hidden/shared dimension.
  // K4/K8 are simply K divided by 4 or 8, respectively.
  shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
// so we use BLOCKS_PER_CYCLE of 2u, i.e. we process weights 2 blocks at a time.
// This uses all 16 lanes on 12th-gen Intel chips.
const BLOCKS_PER_CYCLE : u32 = 2u;
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
const VECTORIZED_QUANTIZATION_BLOCK_SIZE : u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4
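// Illustrative example (numbers assumed, not from the commit): with K = 3072,
// each row of B holds 3072/32 = 96 quantization blocks, and each outer step
// below consumes BLOCKS_PER_CYCLE = 2 of them, i.e. 64 scalar weights fetched
// as 16 parallel vec4 loads.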

// Shared memory
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;

fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
{
  if (a_global >= uniforms.M) {
    return;
  }
  let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
  tile_A[slot][parallel_id] = local_A;
}
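// Each parallel_id lane above loads one vec4 of A, so the 16 lanes assigned to
// a tile row together fetch 64 scalars of A per step.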

fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
{
  // Scales are output_value_t, one for every 32 weight values. Each vec_step_idx
  // covers 64 weights at a time, i.e. 2 scales per step.
  let scale_offset = vec_step_idx*2;
  let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
  return scales[idx+scale_idx];
}
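// Illustrative example (K = 3072 assumed): there are 96 scales per row of B,
// so at vec_step_idx = 5 the two covered blocks use scales 10 and 11 of that
// row, selected by scale_idx = parallel_id / VECTORIZED_QUANTIZATION_BLOCK_SIZE.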

fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
{
  if (b_global >= uniforms.N) {
    return;
  }
  let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE));
  let idx:u32 = parallel_id;
  if (idx % 2 == 0)
  {
    // Weights are u32s holding 8 values each; each step (vec_step_idx) jumps over 64 weights
    // at a time. The weight_offset for the current step would therefore be vec_step_idx * 64
    // if each weight element held one value. Since each element holds 8 values, the offset
    // becomes vec_step_idx * 64/8, i.e. vec_step_idx * 8.
    var weight_offset:u32 = (vec_step_idx*8) + u32(idx/2);
    let b_value = input_b[b_global*uniforms.K8+weight_offset];
    let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
    let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
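    // Each u32 packs 8 weights nibble-by-nibble from the LSB, so b_value_lower
    // holds weights 0,2,4,6 and b_value_upper holds weights 1,3,5,7 of this
    // group; the writes below restore their original order across two vec4s.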
    tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale;
    tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale;
    tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale;
    tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale;
    tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0) * scale;
    tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0) * scale;
    tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0) * scale;
    tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0) * scale;
  }
}

fn computeDotProduct(slot_a: u32, slot_b: u32) -> output_value_t
{
  var sum:output_value_t = 0;
  for (var idx:u32 = 0; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++)
  {
    sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]);
  }
  return sum;
}
)INIT_SECTION";

  shader.MainFunctionBody() << R"MAIN_FN(
  // Indexing with idx/idy instead of using a 2D dispatch of TILE_SIZE x TILE_SIZE
  // appears to give a performance win on the Intel Gen12LP architecture.
  // This is likely because of locality of memory access: with this indexing, idy
  // is the same as the subgroup/lane id, while idx is the wave id, so the work
  // distribution keeps the memory accesses of a single wave close together.
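  // With a workgroup of TILE_SIZE * TILE_SIZE = 256 threads, local_idx 0..255
  // maps to idx (tile row) 0..15 and idy (lane within the row) 0..15.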
  let idx = u32(local_idx / TILE_SIZE);
  let idy = u32(local_idx % TILE_SIZE);
  let a_global_base = workgroup_id.x * TILE_SIZE;
  let b_global_base = workgroup_id.y * TILE_SIZE;
  let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
  for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
  {
    workgroupBarrier();
    loadA(idx, a_global_base+idx, vec_step, idy);
    loadB(idx, b_global_base+idx, vec_step, idy);
    workgroupBarrier();
    let result = computeDotProduct(idx, idy);
    tile_O[idx][idy] += result;
  }
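  // tile_O needs no explicit zeroing: WGSL zero-initializes workgroup-address-space
  // variables, so the += accumulation above starts from 0.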
  workgroupBarrier();
  if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
    output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
  }
)MAIN_FN";
  return Status::OK();
}

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* a = context.Input(0);
  const Tensor* b = context.Input(1);
@@ -360,38 +476,65 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
      context.AdapterInfo().architecture == std::string_view{"gen-12lp"} &&
      block_size == 32;
  const bool has_zero_points = zero_points != nullptr;

  if (use_block32 && batch_count == 1 &&
      components_a == 4 && components_b == 4 &&
      !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
    MatMulNBitsProgramPrefill program;
    constexpr int32_t tile_size = 16;
    // subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
    // MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions; the subgroup
    // size just helps with optimal lane usage in the shader.
    constexpr int32_t subgroup_size = 16;
    program.SetWorkgroupSize(tile_size * subgroup_size);
    program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
                                 (N + tile_size - 1) / tile_size,
                                 1);
    program
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
                    {scales, ProgramTensorMetadataDependency::None}})
        .AddUniformVariables({{static_cast<uint32_t>(M)},
                              {static_cast<uint32_t>(N)},
                              {static_cast<uint32_t>(K)},
                              {static_cast<uint32_t>(K / 4)},
                              {static_cast<uint32_t>(K / 8)}})
        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
    return context.RunProgram(program);
  }

  // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
  // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
  constexpr uint32_t output_number = 1;
  MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

  TensorShape reshaped_a_shape{batch_count, M, K / components_a};
  TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
  TensorShape reshaped_y_shape{batch_count, M, N / components};

  if (use_block32) {
    components = 1;
    constexpr uint32_t workgroup_size = 128;
    const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
                                                             : 1;
    const uint32_t workgroup_x = workgroup_size / workgroup_y;
    program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
    program.SetDispatchGroupSize(data_size / components / workgroup_y);
  } else {
    program.SetDispatchGroupSize(data_size / components / output_number);
  }

  program
      .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
                  {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32, which includes 4 uint8, so we multiply by 4. */)},
                  {scales, ProgramTensorMetadataDependency::None}})
      .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
      .AddUniformVariable({block_size})
      .CacheHint(std::to_string(output_number));
  if (has_zero_points) {
    program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
  }
  return context.RunProgram(program);
}

}  // namespace webgpu
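To make the prefill dispatch above concrete: each workgroup of tile_size * subgroup_size = 256 threads produces one 16x16 tile of the output, and the grid ceil-divides M and N by the tile size. A minimal standalone sketch (the M and N values are assumed for illustration, loosely matching the Phi-3.5 benchmark above):

```
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t tile_size = 16, subgroup_size = 16;
  const uint32_t M = 501, N = 3072;  // example prefill shapes (assumed)
  // Ceil-divide so partial tiles at the edges still get a workgroup;
  // out-of-range threads are masked off inside the shader.
  const uint32_t groups_x = (M + tile_size - 1) / tile_size;
  const uint32_t groups_y = (N + tile_size - 1) / tile_size;
  std::printf("threads per workgroup: %u\n", static_cast<unsigned>(tile_size * subgroup_size));  // 256
  std::printf("dispatch: %u x %u workgroups\n", static_cast<unsigned>(groups_x),
              static_cast<unsigned>(groups_y));  // 32 x 192
  return 0;
}
```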
@@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
  bool use_block32_;
};

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
 public:
  MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"M", ProgramUniformVariableDataType::Uint32},
      {"N", ProgramUniformVariableDataType::Uint32},
      {"K", ProgramUniformVariableDataType::Uint32},
      {"K4", ProgramUniformVariableDataType::Uint32},
      {"K8", ProgramUniformVariableDataType::Uint32});
};

class MatMulNBits final : public WebGpuKernel {
 public:
  MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {