AMD Kernel Compatibility Fixes (#3180)

* Guard against APIs not available on AMD in reduction_utils, code cleanup

* More API alignment simplification

* Int conversion fix

* Syntax
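
For reviewers, a minimal sketch of the guard pattern the first bullet describes (a standalone analogue, not code from this diff; it borrows the repo's DS_D_INLINE qualifier and the hw_warp_size constant from ds_kernel_utils.h, and the helper name is hypothetical):

DS_D_INLINE int warp_rank_portable()
{
#ifdef __HIP_PLATFORM_HCC__
    // ROCm builds sidestep the cooperative-groups meta-group queries and
    // recompute the warp index from built-in thread/block indices.
    const int thread_rank =
        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
    return thread_rank / hw_warp_size;
#else
    cg::thread_block tb = cg::this_thread_block();
    return cg::tiled_partition<hw_warp_size>(tb).meta_group_rank();
#endif
}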

---------

Co-authored-by: Logan Adams <loadams@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
Connor Holmes 2023-09-08 09:54:57 -07:00 committed by GitHub
Parent e194956571
Commit 542dc0d5cb
3 changed files with 46 additions and 33 deletions

View file

@@ -7,7 +7,6 @@
 #include "ds_kernel_utils.h"
-#include <cuda_fp16.h>
 #include <stdint.h>
 
 #ifdef BF16_AVAILABLE

View file

@@ -21,6 +21,7 @@ used throughout the codebase.
 constexpr int hw_warp_size = 64;
 #define HALF_PRECISION_AVAILABLE = 1
 #include <hip/hip_cooperative_groups.h>
+#include <hip/hip_fp16.h>
 #else // !__HIP_PLATFORM_HCC__
@@ -37,6 +38,7 @@ constexpr int hw_warp_size = 32;
 #endif // __CUDA_ARCH__ >= 800
 #include <cooperative_groups.h>
+#include <cuda_fp16.h>
 #endif //__HIP_PLATFORM_HCC__
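
Together these two hunks make ds_kernel_utils.h the single place that picks the half-precision header per platform, which is what lets the first file in this commit drop its direct <cuda_fp16.h> include. A hypothetical consumer (a sketch, not part of the diff) now only needs the shared header:

#include "ds_kernel_utils.h"  // pulls in <cuda_fp16.h> or <hip/hip_fp16.h> as appropriate

__global__ void scale_half2(__half2* x, const __half2 s)
{
    // __hmul2 is provided by both platforms' fp16 headers
    x[threadIdx.x] = __hmul2(x[threadIdx.x], s);
}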

View file

@@ -145,6 +145,13 @@ of reduce should be straightforward (can just wrap the sum reduction) and
 would be a good extension of the header.
 */
 
+DS_D_INLINE int _warp_rank()
+{
+    const int thread_rank =
+        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
+    return thread_rank / hw_warp_size;
+}
+
 /* Float element reduce implementations */
 template <>
 DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
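
The added _warp_rank() helper recovers the warp's index within the block from built-in indices alone. As a worked example (block shape chosen only for illustration): with blockDim = (128, 2, 1) and hw_warp_size = 64, the thread at threadIdx = (5, 1, 0) has flattened rank 5 + 1 * 128 = 133, so _warp_rank() returns 133 / 64 = 2, the same warp index warp_arg.meta_group_rank() reports for an unpartitioned thread block.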
@@ -273,22 +280,34 @@ DS_D_INLINE __half init<ROpType::Max>()
 template <>
 DS_D_INLINE __half2 init<ROpType::Add>()
 {
+#ifdef __HIP_PLATFORM_HCC__
+    return __half2{_Float16_2{0x0000, 0x0000}};
+#else
     constexpr __half2_raw zero = {0x0000, 0x0000};
     return __half2(zero);
+#endif
 }
 
 template <>
 DS_D_INLINE __half2 init<ROpType::Min>()
 {
+#ifdef __HIP_PLATFORM_HCC__
+    return __half2{_Float16_2{0x7C00, 0x7C00}};
+#else
     constexpr __half2_raw inf = {0x7C00, 0x7C00};
     return __half2(inf);
+#endif
 }
 
 template <>
 DS_D_INLINE __half2 init<ROpType::Max>()
 {
+#ifdef __HIP_PLATFORM_HCC__
+    return __half2{_Float16_2{0xFC00, 0xFC00}};
+#else
     constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
     return __half2(neg_inf);
+#endif
 }
 
 template <ROpType Op, typename T>
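
The literals above are IEEE binary16 bit patterns chosen as reduction identities: 0x0000 is +0 (identity for a sum), 0x7C00 is +infinity (identity for a min), and 0xFC00 is -infinity (identity for a max). On ROCm, __half2 is backed by the _Float16_2 vector type, so the pair is brace-initialized directly rather than routed through CUDA's __half2_raw. A host-side sanity check of the patterns (hypothetical test code, plain C++, no GPU required):

#include <cassert>
#include <cstdint>

int main()
{
    const uint16_t pos_inf = 0x7C00;
    const uint16_t neg_inf = 0xFC00;
    // binary16 infinity: all-ones 5-bit exponent, zero 10-bit mantissa
    assert((pos_inf >> 10) == 0x1F && (pos_inf & 0x3FF) == 0);
    // negative infinity is the same pattern with the sign bit set
    assert((neg_inf & 0x7FFF) == pos_inf && (neg_inf >> 15) == 1);
    return 0;
}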
@@ -379,23 +398,15 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
 Implementation for primary block reduction that serves both `block` and
 `partitioned_block`.
 
-`local_warp_rank` refers to the warp's location within the partition, so
-for an unpartitioned threadblock this will be equivalent to
-`warp_arg.meta_group_rank()`.
-
-Similarly, the warp offset is the `local_warp_rank` of the warp with the
-lowest rank in the partition. In the case of an 8 warp block with a
-4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4].
-
-Partition size is the number of warps per partition (equal to the thread
-block in the default case). This enables us to only perform the warp reduction
-when able to.
-
 Total warps refers to the reduction width of the reduction, not
 the number of warps in the block (which may exceed that
 if the block is partitioned or if we do a conservative bound at
 compile time).
 */
 template <int total_warps, ROpType... Ops>
 DS_D_INLINE void _block(cg::thread_block& tb,
                         cg::thread_block_tile<hw_warp_size>& warp_arg,
-                        float* data,
-                        int warp_offset)
+                        float* data)
 {
     constexpr int elems = sizeof...(Ops);
     // Separated for now in case this no longer is true
@@ -403,24 +414,30 @@ DS_D_INLINE void _block(cg::thread_block& tb,
     // Unused when `partition_size == 1` or total_warps == 1
     __shared__ float reduce_buffer[max_warps * elems];
 
+#ifdef __HIP_PLATFORM_HCC__
+    const int total_threads = blockDim.x * blockDim.y * blockDim.z;
+    const int running_warps = total_threads / hw_warp_size;
+#else
+    const int running_warps = warp_arg.meta_group_size();
+#endif
+
     // Always perform warp-scope reduction
     _warp<Ops...>(warp_arg, data);
 
     // If max_warps == 1 let's skip the runtime check
-    if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
+    if (total_warps != 1) {
         if (warp_arg.thread_rank() == 0) {
 #pragma unroll
             for (int i = 0; i < elems; i++) {
-                mem_access::store_shared<bytes>(
-                    reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
+                mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
             }
         }
 
         // Synchronization inside block-uniform conditional is safe
         tb.sync();
 
-        if (warp_arg.meta_group_rank() == 0) {
-            if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
+        if (_warp_rank() == 0) {
+            if (warp_arg.thread_rank() < running_warps) {
 #pragma unroll
                 for (int i = 0; i < elems; i++) {
                     mem_access::load_shared<bytes>(
@@ -444,8 +461,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
 #pragma unroll
         for (int i = 0; i < elems; i++) {
-            mem_access::load_shared<bytes>(data + i,
-                                           reduce_buffer + warp_arg.meta_group_rank() * elems + i);
+            mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
         }
     }
 }
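
These two hunks are the heart of the fix: on HIP the live-warp count comes from blockDim arithmetic instead of warp_arg.meta_group_size(), and shared-memory indexing goes through the new _warp_rank() helper. For orientation, a simplified standalone analogue of the two-phase pattern _block implements (hypothetical kernel, written with raw shuffles instead of the repo's mem_access helpers; it assumes blockDim.x is a multiple of warpSize):

__global__ void block_sum_sketch(const float* in, float* out)
{
    __shared__ float partials[32];             // one slot per possible warp
    const int lane = threadIdx.x % warpSize;
    const int wid = threadIdx.x / warpSize;    // plays the role of _warp_rank()
    const int nwarps = blockDim.x / warpSize;  // plays the role of running_warps

    // Phase 1: every warp reduces its own values in registers.
    float v = in[blockIdx.x * blockDim.x + threadIdx.x];
    for (int off = warpSize / 2; off > 0; off /= 2) {
        v += __shfl_down_sync(0xffffffff, v, off);  // HIP spells this __shfl_down
    }

    // Warp leaders publish their partials to shared memory.
    if (lane == 0) partials[wid] = v;
    __syncthreads();

    // Phase 2: warp 0 alone combines the per-warp partials.
    if (wid == 0) {
        v = (lane < nwarps) ? partials[lane] : 0.0f;
        for (int off = warpSize / 2; off > 0; off /= 2) {
            v += __shfl_down_sync(0xffffffff, v, off);
        }
        if (lane == 0) out[blockIdx.x] = v;
    }
}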
@@ -460,7 +476,7 @@ us to obfuscate the details of the partitioned implementation.
 template <ROpType Op, int warp_bound>
 DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
 {
-    _block<warp_bound, Op>(tb, warp, &val, 0);
+    _block<warp_bound, Op>(tb, warp, &val);
 }
 
 template <ROpType Op1, ROpType Op2, int warp_bound>
@@ -470,7 +486,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
                        float& val2)
 {
     float data[2] = {val1, val2};
-    _block<warp_bound, Op1, Op2>(tb, warp, data, 0);
+    _block<warp_bound, Op1, Op2>(tb, warp, data);
     val1 = data[0];
     val2 = data[1];
 }
@@ -483,7 +499,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
                        float& val3)
 {
     float data[3] = {val1, val2, val3};
-    _block<warp_bound, Op1, Op2, Op3>(tb, warp, data, 0);
+    _block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
     val1 = data[0];
     val2 = data[1];
     val3 = data[2];
@@ -498,7 +514,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
                        float& val4)
 {
     float data[4] = {val1, val2, val3, val4};
-    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data, 0);
+    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
     val1 = data[0];
     val2 = data[1];
     val3 = data[2];
@@ -518,8 +534,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
         _warp<Op, num_threads>(warp, &val);
     } else {
         constexpr int num_warps = num_threads / hw_warp_size;
-        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
-        _block<num_warps, Op>(tb, warp, &val, warp_offset);
+        _block<num_warps, Op>(tb, warp, &val);
     }
 }
@@ -535,8 +550,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
         _warp<Op1, Op2, num_threads>(warp, data);
     } else {
         constexpr int num_warps = num_threads / hw_warp_size;
-        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
-        _block<num_warps, Op1, Op2>(tb, warp, data, warp_offset);
+        _block<num_warps, Op1, Op2>(tb, warp, data);
     }
 
     val1 = data[0];
@@ -556,8 +570,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
         _warp<Op1, Op2, Op3, num_threads>(warp, data);
     } else {
         constexpr int num_warps = num_threads / hw_warp_size;
-        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
-        _block<num_warps, Op1, Op2, Op3>(tb, warp, data, warp_offset);
+        _block<num_warps, Op1, Op2, Op3>(tb, warp, data);
    }
 
     val1 = data[0];
@@ -579,8 +592,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
         _warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
     } else {
         constexpr int num_warps = num_threads / hw_warp_size;
-        const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
-        _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data, warp_offset);
+        _block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
     }
 
     val1 = data[0];
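
After the cleanup, every public entry point takes just the block, the warp tile, and the values to reduce. A usage sketch of the simplified signatures (hypothetical kernel; the reduce:: namespace qualifier and the max_warps bound are assumptions based on the surrounding file rather than lines shown in this diff):

__global__ void max_and_sum(const float* in, float* out_max, float* out_sum)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    float max_val = in[blockIdx.x * blockDim.x + tb.thread_rank()];
    float sum_val = max_val;
    // Fused two-op block reduction; note there is no trailing warp_offset argument.
    reduce::block<ROpType::Max, ROpType::Add, max_warps>(tb, warp, max_val, sum_val);

    if (tb.thread_rank() == 0) {
        out_max[blockIdx.x] = max_val;
        out_sum[blockIdx.x] = sum_val;
    }
}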