[ROCm EP/ MIGraphx EP] matmul_nbits: Use GPU_WARP_SIZE_HOST for host side code (#22045)
### Description For ROCm device, the host side code needs to call GPU_WARP_SIZE_HOST to query warpSize of the underlying GPU device. ### Motivation and Context Fixes MatMulNBits tests on gfx1100/01 which has warpSize of 32. Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
This commit is contained in:
Родитель
4d82404544
Коммит
b800328628
|
@ -289,7 +289,7 @@ bool TryMatMul4Bits(
|
|||
return false;
|
||||
}
|
||||
dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m);
|
||||
dim3 threads(kWarpSize, kColsPerThreadBlock);
|
||||
dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock);
|
||||
int blocks_per_K = (k + block_size - 1) / block_size;
|
||||
int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock +
|
||||
(zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0);
|
||||
|
|
Загрузка…
Ссылка в новой задаче