// Patch summary (leading text before the first `diff --git` header is ignored by patch tools).
//
// Files touched: stl/inc/limits and
// tests/std/tests/GH_001103_countl_zero_correctness/test.cpp.
//
// stl/inc/limits:
//   * The existing body of _Checked_x86_x64_countr_zero is renamed to _Countr_zero_tzcnt.
//     Its `#ifndef __AVX2__` prologue (which returned _Digits for a zero input when
//     __isa_available was below __ISA_AVAILABLE_AVX2) is removed, and the _M_IX86 64-bit
//     path now calls _TZCNT_U32 directly on the low or high half instead of recursing
//     into _Checked_x86_x64_countr_zero.
//   * A new _Countr_zero_bsf is added. It uses _BitScanForward (and _BitScanForward64 on
//     the non-_M_IX86 path) and returns _Digits whenever the scan intrinsic returns zero.
//     For _Digits <= 32 it ORs the value with ~_Max before scanning, as the in-code
//     comment explains.
//   * A new _Checked_x86_x64_countr_zero is added as a dispatcher: it calls
//     _Countr_zero_tzcnt when __AVX2__ is defined or when
//     __isa_available >= __ISA_AVAILABLE_AVX2, and _Countr_zero_bsf otherwise.
//
// test.cpp:
//   * Adds five groups of four asserts on _Countr_zero_bsf; the zero-input case in each
//     group expects 8, 16, 32, 32 and 64 respectively.
//
// NOTE(review): this copy of the patch appears to have lost its line breaks and some
// `<...>` spans (e.g. `template` with no parameter list, `static_cast(` with no target
// type) -- presumably extraction damage. Confirm against the upstream commit before
// attempting to apply it.
diff --git a/stl/inc/limits b/stl/inc/limits index b764f7dbe..9421c6d30 100644 --- a/stl/inc/limits +++ b/stl/inc/limits @@ -1064,22 +1064,10 @@ extern int __isa_available; } template -_NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept { +_NODISCARD int _Countr_zero_tzcnt(const _Ty _Val) noexcept { constexpr int _Digits = numeric_limits<_Ty>::digits; constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); -#ifndef __AVX2__ - // Because the widening done below will always give a non-0 value, checking for tzcnt - // is not required for 8-bit and 16-bit since the only difference in behavior between - // bsf and tzcnt is when the value is 0. - if constexpr (_Digits > 16) { - const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2; - if (!_Definitely_have_tzcnt && _Val == 0) { - return _Digits; - } - } -#endif // __AVX2__ - if constexpr (_Digits <= 32) { // Intended widening to int. This operation means that a narrow 0 will widen // to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros @@ -1087,18 +1075,68 @@ _NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept { return static_cast(_TZCNT_U32(static_cast(~_Max | _Val))); } else { #ifdef _M_IX86 - const unsigned int _High = _Val >> 32; - const unsigned int _Low = static_cast(_Val); + const auto _Low = static_cast(_Val); if (_Low == 0) { - return 32 + _Checked_x86_x64_countr_zero(_High); + const unsigned int _High = _Val >> 32; + return static_cast(32 + _TZCNT_U32(_High)); } else { - return _Checked_x86_x64_countr_zero(_Low); + return static_cast(_TZCNT_U32(_Low)); } #else // ^^^ _M_IX86 / !_M_IX86 vvv return static_cast(_TZCNT_U64(_Val)); #endif // _M_IX86 } } + +template +_NODISCARD int _Countr_zero_bsf(const _Ty _Val) noexcept { + constexpr int _Digits = numeric_limits<_Ty>::digits; + constexpr _Ty _Max = (numeric_limits<_Ty>::max) (); + + unsigned long _Result; + if constexpr (_Digits <= 32) { + // Intended widening to int. 
This operation means that a narrow 0 will widen + // to 0xFFFF....FFFF0... instead of 0. We need this to avoid counting all the zeros + // of the wider type. + if (!_BitScanForward(&_Result, static_cast(~_Max | _Val))) { + return _Digits; + } + } else { +#ifdef _M_IX86 + const auto _Low = static_cast(_Val); + if (_BitScanForward(&_Result, _Low)) { + return static_cast(_Result); + } + + const unsigned int _High = _Val >> 32; + if (!_BitScanForward(&_Result, _High)) { + return _Digits; + } else { + return static_cast(_Result + 32); + } +#else // ^^^ _M_IX86 / !_M_IX86 vvv + if (!_BitScanForward64(&_Result, _Val)) { + return _Digits; + } +#endif // _M_IX86 + } + return static_cast(_Result); +} + +template +_NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept { +#ifdef __AVX2__ + return _Countr_zero_tzcnt(_Val); +#else // __AVX2__ + const bool _Definitely_have_tzcnt = __isa_available >= __ISA_AVAILABLE_AVX2; + if (_Definitely_have_tzcnt) { + return _Countr_zero_tzcnt(_Val); + } else { + return _Countr_zero_bsf(_Val); + } +#endif // __AVX2__ +} + #undef _TZCNT_U32 #undef _TZCNT_U64 #endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) diff --git a/tests/std/tests/GH_001103_countl_zero_correctness/test.cpp b/tests/std/tests/GH_001103_countl_zero_correctness/test.cpp index e7d2d9236..f9b389658 100644 --- a/tests/std/tests/GH_001103_countl_zero_correctness/test.cpp +++ b/tests/std/tests/GH_001103_countl_zero_correctness/test.cpp @@ -39,5 +39,30 @@ int main() { assert(_Countl_zero_bsr(static_cast(0x0000'0000'0000'0013)) == 59); assert(_Countl_zero_bsr(static_cast(0x8000'0000'0000'0003)) == 0); assert(_Countl_zero_bsr(static_cast(0xF000'0000'0000'0008)) == 0); + + assert(_Countr_zero_bsf(static_cast(0x00)) == 8); + assert(_Countr_zero_bsf(static_cast(0x13)) == 0); + assert(_Countr_zero_bsf(static_cast(0x80)) == 7); + assert(_Countr_zero_bsf(static_cast(0xF8)) == 3); + + assert(_Countr_zero_bsf(static_cast(0x0000)) == 16); + 
assert(_Countr_zero_bsf(static_cast(0x0013)) == 0); + assert(_Countr_zero_bsf(static_cast(0x8000)) == 15); + assert(_Countr_zero_bsf(static_cast(0xF008)) == 3); + + assert(_Countr_zero_bsf(static_cast(0x0000'0000)) == 32); + assert(_Countr_zero_bsf(static_cast(0x0000'0013)) == 0); + assert(_Countr_zero_bsf(static_cast(0x8000'0000)) == 31); + assert(_Countr_zero_bsf(static_cast(0xF000'0008)) == 3); + + assert(_Countr_zero_bsf(static_cast(0x0000'0000)) == 32); + assert(_Countr_zero_bsf(static_cast(0x0000'0013)) == 0); + assert(_Countr_zero_bsf(static_cast(0x8000'0000)) == 31); + assert(_Countr_zero_bsf(static_cast(0xF000'0008)) == 3); + + assert(_Countr_zero_bsf(static_cast(0x0000'0000'0000'0000)) == 64); + assert(_Countr_zero_bsf(static_cast(0x0000'0000'0000'0013)) == 0); + assert(_Countr_zero_bsf(static_cast(0x8000'0000'0000'0000)) == 63); + assert(_Countr_zero_bsf(static_cast(0xF000'0000'0000'0008)) == 3); #endif // ^^^ defined(_M_IX86) || defined(_M_X64) ^^^ }