зеркало из https://github.com/microsoft/DeepSpeed.git
AMD Kernel Compatibility Fixes (#3180)
* Guard against APIs not available on AMD in reduction_utils, code cleanup * More API alignment simplification * Int conversion fix * Syntax --------- Co-authored-by: Logan Adams <loadams@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Ammar Ahmad Awan <ammar.awan@microsoft.com>
This commit is contained in:
Родитель
e194956571
Коммит
542dc0d5cb
|
@ -7,7 +7,6 @@
|
|||
|
||||
#include "ds_kernel_utils.h"
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef BF16_AVAILABLE
|
||||
|
|
|
@ -21,6 +21,7 @@ used throughout the codebase.
|
|||
// Hardware warp (wavefront) width used on the HIP side of the platform switch.
constexpr int hw_warp_size = 64;
// Feature flag for half-precision support. Defined to the plain value 1 (no
// stray `=`) so it works both under `#ifdef` and inside `#if` / expressions.
#define HALF_PRECISION_AVAILABLE 1
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#else // !__HIP_PLATFORM_HCC__
|
||||
|
||||
|
@ -37,6 +38,7 @@ constexpr int hw_warp_size = 32;
|
|||
#endif // __CUDA_ARCH__ >= 800
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#endif //__HIP_PLATFORM_HCC__
|
||||
|
||||
|
|
|
@ -145,6 +145,13 @@ of reduce should be straightforward (can just wrap the sum reduction) and
|
|||
would be a good extension of the header.
|
||||
*/
|
||||
|
||||
/*
Index of the calling thread's warp inside its thread block.

The thread's (x, y, z) coordinates are flattened into a single rank with x
varying fastest, and that rank is divided by `hw_warp_size`.
*/
DS_D_INLINE int _warp_rank()
{
    // Contribution of each dimension to the flattened thread rank.
    const unsigned int row_offset = threadIdx.y * blockDim.x;
    const unsigned int plane_offset = threadIdx.z * blockDim.x * blockDim.y;

    const int linear_tid = threadIdx.x + row_offset + plane_offset;
    return linear_tid / hw_warp_size;
}
|
||||
|
||||
/* Float element reduce implementations */
|
||||
template <>
|
||||
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
|
||||
|
@ -273,22 +280,34 @@ DS_D_INLINE __half init<ROpType::Max>()
|
|||
/*
Starting value for an Add reduction over `__half2`: zero in both lanes.
*/
template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
#ifndef __HIP_PLATFORM_HCC__
    // CUDA: construct from the raw 16-bit patterns of the two lanes.
    constexpr __half2_raw zero_bits = {0x0000, 0x0000};
    return __half2(zero_bits);
#else
    // HIP: construct directly from the native two-lane _Float16 vector.
    return __half2{_Float16_2{0x0000, 0x0000}};
#endif
}
|
||||
|
||||
/*
Starting value for a Min reduction over `__half2`: +infinity in both lanes
(fp16 bit pattern 0x7C00).
*/
template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
#ifdef __HIP_PLATFORM_HCC__
    // Reinterpret the bit pattern rather than brace-initializing `_Float16_2`
    // with integer literals: that form converts the integers numerically and
    // would produce 31744.0 per lane instead of +inf.
    const __half inf = __ushort_as_half(static_cast<unsigned short>(0x7C00));
    return __halves2half2(inf, inf);
#else
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
#endif
}
|
||||
|
||||
/*
Starting value for a Max reduction over `__half2`: -infinity in both lanes
(fp16 bit pattern 0xFC00).
*/
template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_HCC__
    // Reinterpret the bit pattern rather than brace-initializing `_Float16_2`
    // with integer literals: that form converts the integers numerically and
    // would produce +64512.0 per lane instead of -inf.
    const __half neg_inf = __ushort_as_half(static_cast<unsigned short>(0xFC00));
    return __halves2half2(neg_inf, neg_inf);
#else
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
#endif
}
|
||||
|
||||
template <ROpType Op, typename T>
|
||||
|
@ -379,23 +398,15 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
|
|||
Implementation for primary block reduction that serves both `block` and
|
||||
`partitioned_block`.
|
||||
|
||||
`local_warp_rank` refers to the warp's location within the partition, so
|
||||
for an unpartitioned threadblock this will be equivalent to
|
||||
`warp_arg.meta_group_rank()`.
|
||||
|
||||
Similarly, the warp offset is the `local_warp_rank` of the warp with the
|
||||
lowest rank in the partition. In the case of an 8 warp block with a
|
||||
4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4].
|
||||
|
||||
Partition size is the number of warps per partition (equal to the thread
|
||||
block in the default case). This enables us to only perform the warp reduction
|
||||
when able to.
|
||||
Total warps refers to the reduction width of the reduction, not
|
||||
the number of warps in the block (which may exceed that
|
||||
if the block is partitioned or if we do a conservative bound at
|
||||
compile time).
|
||||
*/
|
||||
template <int total_warps, ROpType... Ops>
|
||||
DS_D_INLINE void _block(cg::thread_block& tb,
|
||||
cg::thread_block_tile<hw_warp_size>& warp_arg,
|
||||
float* data,
|
||||
int warp_offset)
|
||||
float* data)
|
||||
{
|
||||
constexpr int elems = sizeof...(Ops);
|
||||
// Separated for now in case this no longer is true
|
||||
|
@ -403,24 +414,30 @@ DS_D_INLINE void _block(cg::thread_block& tb,
|
|||
// Unused when `partition_size == 1` or total_warps == 1
|
||||
__shared__ float reduce_buffer[max_warps * elems];
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
const int total_threads = blockDim.x * blockDim.y * blockDim.z;
|
||||
const int running_warps = total_threads / hw_warp_size;
|
||||
#else
|
||||
const int running_warps = warp_arg.meta_group_size();
|
||||
#endif
|
||||
|
||||
// Always perform warp-scope reduction
|
||||
_warp<Ops...>(warp_arg, data);
|
||||
|
||||
// If max_warps == 1 let's skip the runtime check
|
||||
if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
|
||||
if (total_warps != 1) {
|
||||
if (warp_arg.thread_rank() == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems; i++) {
|
||||
mem_access::store_shared<bytes>(
|
||||
reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
|
||||
mem_access::store_shared<bytes>(reduce_buffer + elems * _warp_rank() + i, data + i);
|
||||
}
|
||||
}
|
||||
|
||||
// Synchronization inside block-uniform conditional is safe
|
||||
tb.sync();
|
||||
|
||||
if (warp_arg.meta_group_rank() == 0) {
|
||||
if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
|
||||
if (_warp_rank() == 0) {
|
||||
if (warp_arg.thread_rank() < running_warps) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems; i++) {
|
||||
mem_access::load_shared<bytes>(
|
||||
|
@ -444,8 +461,7 @@ DS_D_INLINE void _block(cg::thread_block& tb,
|
|||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems; i++) {
|
||||
mem_access::load_shared<bytes>(data + i,
|
||||
reduce_buffer + warp_arg.meta_group_rank() * elems + i);
|
||||
mem_access::load_shared<bytes>(data + i, reduce_buffer + _warp_rank() * elems + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -460,7 +476,7 @@ us to obfuscate the details of the partitioned implementation.
|
|||
template <ROpType Op, int warp_bound>
|
||||
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
|
||||
{
|
||||
_block<warp_bound, Op>(tb, warp, &val, 0);
|
||||
_block<warp_bound, Op>(tb, warp, &val);
|
||||
}
|
||||
|
||||
template <ROpType Op1, ROpType Op2, int warp_bound>
|
||||
|
@ -470,7 +486,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val2)
|
||||
{
|
||||
float data[2] = {val1, val2};
|
||||
_block<warp_bound, Op1, Op2>(tb, warp, data, 0);
|
||||
_block<warp_bound, Op1, Op2>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
}
|
||||
|
@ -483,7 +499,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val3)
|
||||
{
|
||||
float data[3] = {val1, val2, val3};
|
||||
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data, 0);
|
||||
_block<warp_bound, Op1, Op2, Op3>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
val3 = data[2];
|
||||
|
@ -498,7 +514,7 @@ DS_D_INLINE void block(cg::thread_block& tb,
|
|||
float& val4)
|
||||
{
|
||||
float data[4] = {val1, val2, val3, val4};
|
||||
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data, 0);
|
||||
_block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
val1 = data[0];
|
||||
val2 = data[1];
|
||||
val3 = data[2];
|
||||
|
@ -518,8 +534,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
_warp<Op, num_threads>(warp, &val);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
|
||||
_block<num_warps, Op>(tb, warp, &val, warp_offset);
|
||||
_block<num_warps, Op>(tb, warp, &val);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -535,8 +550,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
_warp<Op1, Op2, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
|
||||
_block<num_warps, Op1, Op2>(tb, warp, data, warp_offset);
|
||||
_block<num_warps, Op1, Op2>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
@ -556,8 +570,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
_warp<Op1, Op2, Op3, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
|
||||
_block<num_warps, Op1, Op2, Op3>(tb, warp, data, warp_offset);
|
||||
_block<num_warps, Op1, Op2, Op3>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
@ -579,8 +592,7 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
|
|||
_warp<Op1, Op2, Op3, Op4, num_threads>(warp, data);
|
||||
} else {
|
||||
constexpr int num_warps = num_threads / hw_warp_size;
|
||||
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
|
||||
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data, warp_offset);
|
||||
_block<num_warps, Op1, Op2, Op3, Op4>(tb, warp, data);
|
||||
}
|
||||
|
||||
val1 = data[0];
|
||||
|
|
Загрузка…
Ссылка в новой задаче