// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "onnxruntime_c_api.h" #if ORT_API_VERSION >= 16 #include "onnxruntime_float16.h" #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 #include #endif namespace Ort { namespace Custom { // MFloat16 struct MFloat16 : onnxruntime_float16::Float16Impl { private: constexpr explicit MFloat16(uint16_t v) noexcept { val = v; } public: using Base = onnxruntime_float16::Float16Impl; MFloat16() = default; constexpr static MFloat16 FromBits(uint16_t v) noexcept { return MFloat16(v); } explicit MFloat16(float v) noexcept { val = Base::ToUint16Impl(v); } float ToFloat() const noexcept { return Base::ToFloatImpl(); } using Base::Abs; using Base::AreZero; using Base::IsFinite; using Base::IsInfinity; using Base::IsNaN; using Base::IsNaNOrZero; using Base::IsNegative; using Base::IsNegativeInfinity; using Base::IsNormal; using Base::IsPositiveInfinity; using Base::IsSubnormal; using Base::Negate; explicit operator float() const noexcept { return ToFloat(); } using Base::operator==; using Base::operator!=; using Base::operator<; }; #if defined(__CUDACC__) || defined(__HIPCC__) #define ORTC_HOST_DEVICE __host__ __device__ #else #define ORTC_HOST_DEVICE #endif // BFloat16 struct BFloat16 : onnxruntime_float16::BFloat16Impl { using Base = onnxruntime_float16::BFloat16Impl; #if defined(__HIP__) ORTC_HOST_DEVICE BFloat16() = default; #else BFloat16() = default; #endif struct FromBitsT {}; static constexpr ORTC_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); } constexpr ORTC_HOST_DEVICE BFloat16(unsigned short bits, FromBitsT) noexcept { val = bits; } static constexpr ORTC_HOST_DEVICE BFloat16 FromBits(uint16_t bits) noexcept { return BFloat16(bits, FromBits()); } inline ORTC_HOST_DEVICE BFloat16(float v) noexcept { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 val = __bfloat16_as_ushort(__float2bfloat16(v)); #elif defined(__HIP__) // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment. if (v != v) { // isnan val = UINT16_C(0x7FC0); } else { union { uint32_t U32; float F32; }; F32 = v; uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); val = static_cast((U32 + rounding_bias) >> 16); } #else // Use C isnan to work both in host and device if (std::isnan(v)) { val = kPositiveQNaNBits; } else { auto get_msb_half = [](float fl) { uint16_t result; if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) { std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); } else { std::memcpy(&result, &fl, sizeof(uint16_t)); } return result; }; uint16_t upper_bits = get_msb_half(v); union { uint32_t U32; float F32; }; F32 = v; U32 += (upper_bits & 1) + kRoundToNearest; val = get_msb_half(F32); } #endif } inline ORTC_HOST_DEVICE float ToFloat() const noexcept { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 return __bfloat162float(*reinterpret_cast(&val)); #elif defined(__HIP__) // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment. float result = 0; uint32_t tmp = val; tmp <<= 16; float* tempRes = reinterpret_cast(&tmp); result = *tempRes; return result; #else if (IsNaNHostDevice()) { return std::numeric_limits::quiet_NaN(); } float result = 0; char* const first = reinterpret_cast(&result); if constexpr (onnxruntime_float16::detail::endian::native == onnxruntime_float16::detail::endian::little) { char* const second = first + sizeof(uint16_t); std::memcpy(second, &val, sizeof(uint16_t)); } else { std::memcpy(first, &val, sizeof(uint16_t)); } return result; #endif } static const BFloat16 NaN; static const BFloat16 NegativeNaN; static const BFloat16 Infinity; static const BFloat16 NegativeInfinity; static const BFloat16 Epsilon; static const BFloat16 MinValue; static const BFloat16 MaxValue; static const BFloat16 Zero; static const BFloat16 One; static const BFloat16 MinusOne; using Base::IsNegative; using Base::IsNaN; using Base::IsFinite; using Base::IsPositiveInfinity; using Base::IsNegativeInfinity; using Base::IsInfinity; using Base::IsNaNOrZero; using Base::IsNormal; using Base::IsSubnormal; using Base::Abs; using Base::Negate; ORTC_HOST_DEVICE operator float() const noexcept { return ToFloat(); } #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 ORTC_HOST_DEVICE BFloat16(const __nv_bfloat16& value) { val = *reinterpret_cast(&value); } explicit ORTC_HOST_DEVICE operator __nv_bfloat16() const { return *reinterpret_cast(&val); } #endif ORTC_HOST_DEVICE bool operator==(const BFloat16& rhs) const noexcept { if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) { // IEEE defines that NaN is not equal to anything, including itself. return false; } return val == rhs.val; } ORTC_HOST_DEVICE bool operator!=(const BFloat16& rhs) const noexcept { return !(*this == rhs); } ORTC_HOST_DEVICE bool operator<(const BFloat16& rhs) const noexcept { if (IsNaNHostDevice() || rhs.IsNaNHostDevice()) { // IEEE defines that NaN is unordered with respect to everything, including itself. return false; } const bool left_is_negative = IsNegativeHostDevice(); if (left_is_negative != rhs.IsNegativeHostDevice()) { // When the signs of left and right differ, we know that left is less than right if it is // the negative value. The exception to this is if both values are zero, in which case IEEE // says they should be equal, even if the signs differ. return left_is_negative && !AreZeroHostDevice(*this, rhs); } return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); } ORTC_HOST_DEVICE bool IsNegativeHostDevice() const noexcept { return (val & kSignMask) != 0; } ORTC_HOST_DEVICE bool IsNaNHostDevice() const noexcept { return static_cast(val & ~kSignMask) > kPositiveInfinityBits; } ORTC_HOST_DEVICE static bool AreZeroHostDevice(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { // IEEE defines that positive and negative zero are equal, this gives us a quick equality check // for two values by or'ing the private bits together and stripping the sign. They are both zero, // and therefore equivalent, if the resulting value is still zero. return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; } }; } // namespace Custom } // namespace Ort #endif