diff --git a/stl/inc/bit b/stl/inc/bit index d98eb56c5..658a7a150 100644 --- a/stl/inc/bit +++ b/stl/inc/bit @@ -96,23 +96,6 @@ _NODISCARD constexpr _Ty rotr(const _Ty _Val, const int _Rotation) noexcept { } } -// Implementation of popcount without using specialized CPU instructions. -// Used at compile time and when said instructions are not supported. -template -_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept { - constexpr int _Digits = numeric_limits<_Ty>::digits; - // we static_cast these bit patterns in order to truncate them to the correct size - _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & static_cast<_Ty>(0x5555'5555'5555'5555ull))); - _Val = static_cast<_Ty>((_Val & static_cast<_Ty>(0x3333'3333'3333'3333ull)) - + ((_Val >> 2) & static_cast<_Ty>(0x3333'3333'3333'3333ull))); - _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full)); - for (int _Shift_digits = 8; _Shift_digits < _Digits; _Shift_digits <<= 1) { - _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift_digits)); - } - // we want the bottom "slot" that's big enough to store _Digits - return static_cast(_Val & static_cast<_Ty>(_Digits + _Digits - 1)); -} - #if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) extern "C" { @@ -184,32 +167,8 @@ _NODISCARD int _Checked_x86_x64_countl_zero(const _Ty _Val) noexcept { } #endif // __AVX2__ } - -template -_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept { - constexpr int _Digits = numeric_limits<_Ty>::digits; -#ifndef __AVX__ - const bool _Definitely_have_popcnt = __isa_available >= __ISA_AVAILABLE_SSE42; - if (!_Definitely_have_popcnt) { - return _Popcount_fallback(_Val); - } -#endif // !defined(__AVX__) - - if constexpr (_Digits <= 16) { - return static_cast(__popcnt16(_Val)); - } else if constexpr (_Digits == 32) { - return static_cast(__popcnt(_Val)); - } else { -#ifdef _M_IX86 - return static_cast(__popcnt(_Val >> 32) + __popcnt(static_cast(_Val))); -#else // ^^^ _M_IX86 / !_M_IX86 vvv - return static_cast(__popcnt64(_Val)); -#endif // _M_IX86 - } -} #endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) - #if defined(_M_ARM) || defined(_M_ARM64) #ifdef __clang__ // TRANSITION, GH-1586 _NODISCARD constexpr int _Clang_arm_arm64_countl_zero(const unsigned short _Val) { @@ -283,14 +242,9 @@ _NODISCARD constexpr int countr_one(const _Ty _Val) noexcept { return _Countr_zero(static_cast<_Ty>(~_Val)); } -template , int> _Enabled = 0> +template , int> = 0> _NODISCARD constexpr int popcount(const _Ty _Val) noexcept { -#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) - if (!_STD is_constant_evaluated()) { - return _Checked_x86_x64_popcount(_Val); - } -#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC)) - return _Popcount_fallback(_Val); + return _Popcount(_Val); } enum class endian { little = 0, big = 1, native = little }; diff --git a/stl/inc/limits b/stl/inc/limits index bfe6ee028..5a693b939 100644 --- a/stl/inc/limits +++ b/stl/inc/limits @@ -1040,6 +1040,23 @@ _NODISCARD constexpr int _Countr_zero_fallback(const _Ty _Val) noexcept { return _Digits - _Countl_zero_fallback(static_cast<_Ty>(static_cast<_Ty>(~_Val) & static_cast<_Ty>(_Val - 1))); } +// Implementation of popcount without using specialized CPU instructions. +// Used at compile time and when said instructions are not supported. +template +_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept { + constexpr int _Digits = numeric_limits<_Ty>::digits; + // we static_cast these bit patterns in order to truncate them to the correct size + _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & static_cast<_Ty>(0x5555'5555'5555'5555ull))); + _Val = static_cast<_Ty>((_Val & static_cast<_Ty>(0x3333'3333'3333'3333ull)) + + ((_Val >> 2) & static_cast<_Ty>(0x3333'3333'3333'3333ull))); + _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full)); + for (int _Shift_digits = 8; _Shift_digits < _Digits; _Shift_digits <<= 1) { + _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift_digits)); + } + // we want the bottom "slot" that's big enough to store _Digits + return static_cast(_Val & static_cast<_Ty>(_Digits + _Digits - 1)); +} + #if defined(_M_IX86) || defined(_M_X64) extern "C" { extern int __isa_available; @@ -1092,6 +1109,38 @@ _NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept { #undef _TZCNT_U64 #endif // defined(_M_IX86) || defined(_M_X64) +#if (defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))) && !defined(_M_CEE_PURE) && !defined(__CUDACC__) \ + && !defined(__INTEL_COMPILER) +#define _HAS_POPCNT_INTRINSICS 1 +#else // ^^^ intrinsics available ^^^ / vvv intrinsics unavailable vvv +#define _HAS_POPCNT_INTRINSICS 0 +#endif // ^^^ intrinsics unavailable ^^^ + +#if _HAS_POPCNT_INTRINSICS +template +_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept { + constexpr int _Digits = numeric_limits<_Ty>::digits; +#ifndef __AVX__ + const bool _Definitely_have_popcnt = __isa_available >= __ISA_AVAILABLE_SSE42; + if (!_Definitely_have_popcnt) { + return _Popcount_fallback(_Val); + } +#endif // !defined(__AVX__) + + if constexpr (_Digits <= 16) { + return static_cast(__popcnt16(_Val)); + } else if constexpr (_Digits == 32) { + return static_cast(__popcnt(_Val)); + } else { +#ifdef _M_IX86 + return static_cast(__popcnt(_Val >> 32) + __popcnt(static_cast(_Val))); +#else // ^^^ _M_IX86 / !_M_IX86 vvv + return static_cast(__popcnt64(_Val)); +#endif // _M_IX86 + } +} +#endif // _HAS_POPCNT_INTRINSICS + template constexpr bool _Is_standard_unsigned_integer = _Is_any_of_v, unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>; @@ -1103,12 +1152,27 @@ _NODISCARD constexpr int _Countr_zero(const _Ty _Val) noexcept { if (!_STD is_constant_evaluated()) { return _Checked_x86_x64_countr_zero(_Val); } -#endif // defined(__cpp_lib_is_constant_evaluated) +#endif // __cpp_lib_is_constant_evaluated #endif // defined(_M_IX86) || defined(_M_X64) // C++17 constexpr gcd() calls this function, so it should be constexpr unless we detect runtime evaluation. return _Countr_zero_fallback(_Val); } +template , int> _Enabled = 0> +_NODISCARD _CONSTEXPR20 int _Popcount(const _Ty _Val) noexcept { +#if _HAS_POPCNT_INTRINSICS +#ifdef __cpp_lib_is_constant_evaluated + if (!_STD is_constant_evaluated()) +#endif // __cpp_lib_is_constant_evaluated + { + return _Checked_x86_x64_popcount(_Val); + } +#endif // _HAS_POPCNT_INTRINSICS + return _Popcount_fallback(_Val); +} + +#undef _HAS_POPCNT_INTRINSICS + _STD_END #pragma pop_macro("new") _STL_RESTORE_CLANG_WARNINGS diff --git a/stl/inc/vector b/stl/inc/vector index 52fef59fc..bd5b76820 100644 --- a/stl/inc/vector +++ b/stl/inc/vector @@ -3123,8 +3123,8 @@ _INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_iterator<_Alloc>, _RequiresMutabl template _INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_const_iterator<_Alloc>, false> = true; -template -_CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) { +template +_CONSTEXPR20 void _Fill_vbool(_VbIt _First, _VbIt _Last, const _Ty& _Val) { // Set [_First, _Last) to _Val if (_First == _Last) { return; @@ -3172,6 +3172,92 @@ _CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) { *_VbFirst = (*_VbFirst & _LastDestMask) | (_FillVal & _LastSourceMask); } } + +template +_NODISCARD _CONSTEXPR20 _VbIt _Find_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) { + // Find _Val in [_First, _Last) + if (_First == _Last) { + return _First; + } + + const _Vbase* _VbFirst = _First._Myptr; + const _Vbase* const _VbLast = _Last._Myptr; + + const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff; + + if (_VbFirst == _VbLast) { + // We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe + const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff); + const auto _SourceMask = _FirstSourceMask & _LastSourceMask; + const auto _SelectVal = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask; + const auto _Count = _Countr_zero(_SelectVal); + return _Count == _VBITS ? _Last : _First + static_cast(_Count - _First._Myoff); + } + + const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask; + const auto _FirstCount = _Countr_zero(_FirstVal); + if (_FirstCount != _VBITS) { + return _First + static_cast(_FirstCount - _First._Myoff); + } + ++_VbFirst; + + _Iter_diff_t<_VbIt> _TotalCount = static_cast(_VBITS - _First._Myoff); + for (; _VbFirst != _VbLast; ++_VbFirst, _TotalCount += _VBITS) { + const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst; + const auto _Count = _Countr_zero(_SelectVal); + if (_Count != _VBITS) { + return _First + (_TotalCount + _Count); + } + } + + if (_Last._Myoff != 0) { + const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff); + const auto _LastVal = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask; + const auto _Count = _Countr_zero(_LastVal); + if (_Count != _VBITS) { + return _First + (_TotalCount + _Count); + } + } + + return _Last; +} + +template +_NODISCARD _CONSTEXPR20 _Iter_diff_t<_VbIt> _Count_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) { + if (_First == _Last) { + return 0; + } + + const _Vbase* _VbFirst = _First._Myptr; + const _Vbase* const _VbLast = _Last._Myptr; + + const auto _FirstSourceMask = static_cast<_Vbase>(-1) << _First._Myoff; + + if (_VbFirst == _VbLast) { + // We already excluded _First == _Last, so here _Last._Myoff > 0 and the shift is safe + const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff); + const auto _SourceMask = _FirstSourceMask & _LastSourceMask; + const auto _SelectVal = (_Val ? *_VbFirst : ~*_VbFirst) & _SourceMask; + return _Popcount(_SelectVal); + } + + const auto _FirstVal = (_Val ? *_VbFirst : ~*_VbFirst) & _FirstSourceMask; + _Iter_diff_t<_VbIt> _Count = _Popcount(_FirstVal); + ++_VbFirst; + + for (; _VbFirst != _VbLast; ++_VbFirst) { + const auto _SelectVal = _Val ? *_VbFirst : ~*_VbFirst; + _Count += _Popcount(_SelectVal); + } + + if (_Last._Myoff != 0) { + const auto _LastSourceMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff); + const auto _LastVal = (_Val ? *_VbFirst : ~*_VbFirst) & _LastSourceMask; + _Count += _Popcount(_LastVal); + } + + return _Count; +} _STD_END #pragma pop_macro("new") diff --git a/stl/inc/xutility b/stl/inc/xutility index 68fd05a74..9091ec0bf 100644 --- a/stl/inc/xutility +++ b/stl/inc/xutility @@ -5345,8 +5345,12 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(const _InIt _First, const _InIt _L template _NODISCARD _CONSTEXPR20 _InIt find(_InIt _First, const _InIt _Last, const _Ty& _Val) { // find first matching _Val _Adl_verify_range(_First, _Last); - _Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val)); - return _First; + if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) { + return _Find_vbool(_First, _Last, _Val); + } else { + _Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val)); + return _First; + } } #if _HAS_CXX17 @@ -5432,17 +5436,21 @@ template _NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InIt _Last, const _Ty& _Val) { // count elements that match _Val _Adl_verify_range(_First, _Last); - auto _UFirst = _Get_unwrapped(_First); - const auto _ULast = _Get_unwrapped(_Last); - _Iter_diff_t<_InIt> _Count = 0; + if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) { + return _Count_vbool(_First, _Last, _Val); + } else { + auto _UFirst = _Get_unwrapped(_First); + const auto _ULast = _Get_unwrapped(_Last); + _Iter_diff_t<_InIt> _Count = 0; - for (; _UFirst != _ULast; ++_UFirst) { - if (*_UFirst == _Val) { - ++_Count; + for (; _UFirst != _ULast; ++_UFirst) { + if (*_UFirst == _Val) { + ++_Count; + } } - } - return _Count; + return _Count; + } } #if _HAS_CXX17 diff --git a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp index 3f26088c8..4b6f623d1 100644 --- a/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp +++ b/tests/std/tests/GH_000625_vector_bool_optimization/test.cpp @@ -3,6 +3,7 @@ #include #include +#include #include using namespace std; @@ -96,6 +97,124 @@ bool test_fill() { return true; } +void test_find_helper(const size_t length) { + // No offset + { + vector input_true(length + 3, false); + input_true.resize(length + 6, true); + input_true[length + 1].flip(); + const auto result_true = find(input_true.cbegin(), prev(input_true.cend(), 3), true); + assert(result_true == next(input_true.cbegin(), static_cast(length + 1))); + + vector input_false(length + 3, true); + input_false.resize(length + 6, false); + input_false[length + 1].flip(); + const auto result_false = find(input_false.cbegin(), prev(input_false.cend(), 3), false); + assert(result_false == next(input_false.cbegin(), static_cast(length + 1))); + } + + // With offset + { + vector input_true(length + 3, false); + input_true.resize(length + 6, true); + input_true[length + 1].flip(); + input_true[0].flip(); + const auto result_true = find(next(input_true.cbegin()), prev(input_true.cend(), 3), true); + assert(result_true == next(input_true.cbegin(), static_cast(length + 1))); + + vector input_false(length + 3, true); + input_false.resize(length + 6, false); + input_false[length + 1].flip(); + input_false[0].flip(); + const auto result_false = find(next(input_false.cbegin()), prev(input_false.cend(), 3), false); + assert(result_false == next(input_false.cbegin(), static_cast(length + 1))); + } +} + +bool test_find() { + // Empty range + test_find_helper(0); + + // One block, ends within block + test_find_helper(15); + + // One block, ends at block boundary + test_find_helper(blockSize); + + // Multiple blocks, within block + test_find_helper(3 * blockSize + 5); + + // Multiple blocks, ends at block boundary + test_find_helper(4 * blockSize); + return true; +} + +void test_count_helper(const ptrdiff_t length) { + // This test data is not random, but irregular enough to ensure confidence in the tests + // clang-format off + vector source = { true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false, + true, false, true, false, true, true, true, false + }; + const int counts_true[8] = { 0, 1, 1, 2, 2, 3, 4, 5 }; + const int counts_false[8] = { 0, 0, 1, 1, 2, 2, 2, 2 }; + // clang-format on + const auto expected = div(static_cast(length), 8); + // No offset + { + const auto result_true = static_cast(count(source.cbegin(), next(source.cbegin(), length), true)); + assert(result_true == expected.quot * 5 + counts_true[expected.rem]); + + const auto result_false = static_cast(count(source.cbegin(), next(source.cbegin(), length), false)); + assert(result_false == expected.quot * 3 + counts_false[expected.rem]); + } + + // With offset + { + const auto result_true = + static_cast(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), true)); + assert(result_true == expected.quot * 5 + counts_true[expected.rem]); + + const auto result_false = + static_cast(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), false)); + assert(result_false == expected.quot * 3 + counts_false[expected.rem]); + } +} + +bool test_count() { + // Empty range + test_count_helper(0); + + // One block, ends within block + test_count_helper(15); + + // One block, ends at block boundary + test_count_helper(blockSize); + + // Multiple blocks, within block + test_count_helper(3 * blockSize + 8); + + // Multiple blocks, ends at block boundary + test_count_helper(4 * blockSize); + return true; +} + int main() { test_fill(); + test_find(); + test_count(); }