Specialize find and count for vector<bool> (#1131)

Co-authored-by: Stephan T. Lavavej <stl@microsoft.com>
This commit is contained in:
Michael Schellenberger Costa 2021-06-12 04:30:53 +02:00 коммит произвёл GitHub
Родитель 1219eb3936
Коммит 264765b3c3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 292 добавлений и 61 удалений

Просмотреть файл

@ -96,23 +96,6 @@ _NODISCARD constexpr _Ty rotr(const _Ty _Val, const int _Rotation) noexcept {
}
}
// Portable bit-counting routine built only from shifts, masks, and additions.
// Serves constant evaluation and CPUs that lack a dedicated popcount instruction.
template <class _Ty>
_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept {
    constexpr int _Digits = numeric_limits<_Ty>::digits;

    // each 64-bit pattern is narrowed with static_cast so that it has exactly the width of _Ty
    constexpr _Ty _Pair_mask   = static_cast<_Ty>(0x5555'5555'5555'5555ull);
    constexpr _Ty _Nibble_mask = static_cast<_Ty>(0x3333'3333'3333'3333ull);
    constexpr _Ty _Byte_mask   = static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full);

    // after this step every 2-bit field holds the number of bits that were set in it
    _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & _Pair_mask));
    // merge neighbouring 2-bit counts into 4-bit counts
    _Val = static_cast<_Ty>((_Val & _Nibble_mask) + ((_Val >> 2) & _Nibble_mask));
    // merge neighbouring 4-bit counts into one count per byte
    _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & _Byte_mask);

    // add the per-byte counts together; the total accumulates in the least significant byte
    int _Shift_digits = 8;
    while (_Shift_digits < _Digits) {
        _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift_digits));
        _Shift_digits *= 2;
    }

    // only the bottom "slot" that is wide enough to store _Digits carries the answer
    constexpr _Ty _Result_mask = static_cast<_Ty>(_Digits + _Digits - 1);
    return static_cast<int>(_Val & _Result_mask);
}
#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
extern "C" {
@ -184,32 +167,8 @@ _NODISCARD int _Checked_x86_x64_countl_zero(const _Ty _Val) noexcept {
}
#endif // __AVX2__
}
// Runtime popcount for x86/x64 built on the __popcnt* intrinsics.
// Unless __AVX__ is defined, __isa_available is checked first and _Popcount_fallback is used
// when it reports a level below __ISA_AVAILABLE_SSE42.
template <class _Ty>
_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept {
#ifndef __AVX__
    if (__isa_available < __ISA_AVAILABLE_SSE42) {
        // the popcnt instruction is not definitely available, so count the bits in software
        return _Popcount_fallback(_Val);
    }
#endif // !defined(__AVX__)

    constexpr int _Digits = numeric_limits<_Ty>::digits;

    if constexpr (_Digits == 32) {
        return static_cast<int>(__popcnt(_Val));
    } else if constexpr (_Digits <= 16) {
        return static_cast<int>(__popcnt16(_Val));
    } else {
#ifdef _M_IX86
        // on 32-bit x86 the two 32-bit halves are counted separately and summed
        const auto _High_count = __popcnt(_Val >> 32);
        const auto _Low_count  = __popcnt(static_cast<unsigned int>(_Val));
        return static_cast<int>(_High_count + _Low_count);
#else // ^^^ _M_IX86 / !_M_IX86 vvv
        return static_cast<int>(__popcnt64(_Val));
#endif // _M_IX86
    }
}
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
#if defined(_M_ARM) || defined(_M_ARM64)
#ifdef __clang__ // TRANSITION, GH-1586
_NODISCARD constexpr int _Clang_arm_arm64_countl_zero(const unsigned short _Val) {
@ -283,14 +242,9 @@ _NODISCARD constexpr int countr_one(const _Ty _Val) noexcept {
return _Countr_zero(static_cast<_Ty>(~_Val));
}
template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> _Enabled = 0>
template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> = 0>
_NODISCARD constexpr int popcount(const _Ty _Val) noexcept {
#if defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
if (!_STD is_constant_evaluated()) {
return _Checked_x86_x64_popcount(_Val);
}
#endif // defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))
return _Popcount_fallback(_Val);
return _Popcount(_Val);
}
// Byte-order tags; native is defined as an alias of little here.
enum class endian {
    little = 0,
    big    = 1,
    native = little,
};

Просмотреть файл

@ -1040,6 +1040,23 @@ _NODISCARD constexpr int _Countr_zero_fallback(const _Ty _Val) noexcept {
return _Digits - _Countl_zero_fallback(static_cast<_Ty>(static_cast<_Ty>(~_Val) & static_cast<_Ty>(_Val - 1)));
}
// Portable bit-counting routine built only from shifts, masks, and additions.
// Serves constant evaluation and CPUs that lack a dedicated popcount instruction.
template <class _Ty>
_NODISCARD constexpr int _Popcount_fallback(_Ty _Val) noexcept {
    constexpr int _Digits = numeric_limits<_Ty>::digits;

    // each 64-bit pattern is narrowed with static_cast so that it has exactly the width of _Ty
    constexpr _Ty _Pair_mask   = static_cast<_Ty>(0x5555'5555'5555'5555ull);
    constexpr _Ty _Nibble_mask = static_cast<_Ty>(0x3333'3333'3333'3333ull);
    constexpr _Ty _Byte_mask   = static_cast<_Ty>(0x0F0F'0F0F'0F0F'0F0Full);

    // after this step every 2-bit field holds the number of bits that were set in it
    _Val = static_cast<_Ty>(_Val - ((_Val >> 1) & _Pair_mask));
    // merge neighbouring 2-bit counts into 4-bit counts
    _Val = static_cast<_Ty>((_Val & _Nibble_mask) + ((_Val >> 2) & _Nibble_mask));
    // merge neighbouring 4-bit counts into one count per byte
    _Val = static_cast<_Ty>((_Val + (_Val >> 4)) & _Byte_mask);

    // add the per-byte counts together; the total accumulates in the least significant byte
    int _Shift_digits = 8;
    while (_Shift_digits < _Digits) {
        _Val = static_cast<_Ty>(_Val + static_cast<_Ty>(_Val >> _Shift_digits));
        _Shift_digits *= 2;
    }

    // only the bottom "slot" that is wide enough to store _Digits carries the answer
    constexpr _Ty _Result_mask = static_cast<_Ty>(_Digits + _Digits - 1);
    return static_cast<int>(_Val & _Result_mask);
}
#if defined(_M_IX86) || defined(_M_X64)
extern "C" {
extern int __isa_available;
@ -1092,6 +1109,38 @@ _NODISCARD int _Checked_x86_x64_countr_zero(const _Ty _Val) noexcept {
#undef _TZCNT_U64
#endif // defined(_M_IX86) || defined(_M_X64)
// _HAS_POPCNT_INTRINSICS is 1 only when the target is x86, or x64 that is not ARM64EC, and none of
// _M_CEE_PURE, __CUDACC__, or __INTEL_COMPILER is defined; in every other configuration it is 0.
// The popcount code that follows tests this macro before using the __popcnt* intrinsics.
#if (defined(_M_IX86) || (defined(_M_X64) && !defined(_M_ARM64EC))) && !defined(_M_CEE_PURE) && !defined(__CUDACC__) \
&& !defined(__INTEL_COMPILER)
#define _HAS_POPCNT_INTRINSICS 1
#else // ^^^ intrinsics available ^^^ / vvv intrinsics unavailable vvv
#define _HAS_POPCNT_INTRINSICS 0
#endif // ^^^ intrinsics unavailable ^^^
#if _HAS_POPCNT_INTRINSICS
// Runtime popcount for x86/x64 built on the __popcnt* intrinsics.
// Unless __AVX__ is defined, __isa_available is checked first and _Popcount_fallback is used
// when it reports a level below __ISA_AVAILABLE_SSE42.
template <class _Ty>
_NODISCARD int _Checked_x86_x64_popcount(const _Ty _Val) noexcept {
#ifndef __AVX__
    if (__isa_available < __ISA_AVAILABLE_SSE42) {
        // the popcnt instruction is not definitely available, so count the bits in software
        return _Popcount_fallback(_Val);
    }
#endif // !defined(__AVX__)

    constexpr int _Digits = numeric_limits<_Ty>::digits;

    if constexpr (_Digits == 32) {
        return static_cast<int>(__popcnt(_Val));
    } else if constexpr (_Digits <= 16) {
        return static_cast<int>(__popcnt16(_Val));
    } else {
#ifdef _M_IX86
        // on 32-bit x86 the two 32-bit halves are counted separately and summed
        const auto _High_count = __popcnt(_Val >> 32);
        const auto _Low_count  = __popcnt(static_cast<unsigned int>(_Val));
        return static_cast<int>(_High_count + _Low_count);
#else // ^^^ _M_IX86 / !_M_IX86 vvv
        return static_cast<int>(__popcnt64(_Val));
#endif // _M_IX86
    }
}
#endif // _HAS_POPCNT_INTRINSICS
// True when _Ty, with cv-qualifiers removed, is one of unsigned char, unsigned short, unsigned int,
// unsigned long, or unsigned long long; false for every other type.
template <class _Ty>
constexpr bool _Is_standard_unsigned_integer =
_Is_any_of_v<remove_cv_t<_Ty>, unsigned char, unsigned short, unsigned int, unsigned long, unsigned long long>;
@ -1103,12 +1152,27 @@ _NODISCARD constexpr int _Countr_zero(const _Ty _Val) noexcept {
if (!_STD is_constant_evaluated()) {
return _Checked_x86_x64_countr_zero(_Val);
}
#endif // defined(__cpp_lib_is_constant_evaluated)
#endif // __cpp_lib_is_constant_evaluated
#endif // defined(_M_IX86) || defined(_M_X64)
// C++17 constexpr gcd() calls this function, so it should be constexpr unless we detect runtime evaluation.
return _Countr_zero_fallback(_Val);
}
// Counts the set bits of _Val, choosing between the intrinsic-based and the portable implementation.
// The portable _Popcount_fallback is used when _HAS_POPCNT_INTRINSICS is 0 and during constant evaluation;
// otherwise the call is forwarded to _Checked_x86_x64_popcount.
template <class _Ty, enable_if_t<_Is_standard_unsigned_integer<_Ty>, int> _Enabled = 0>
_NODISCARD _CONSTEXPR20 int _Popcount(const _Ty _Val) noexcept {
#if !_HAS_POPCNT_INTRINSICS
    return _Popcount_fallback(_Val);
#elif defined(__cpp_lib_is_constant_evaluated)
    // is_constant_evaluated() must stay directly in the condition: stored in a const bool it would always be true
    if (_STD is_constant_evaluated()) {
        return _Popcount_fallback(_Val);
    }

    return _Checked_x86_x64_popcount(_Val);
#else // ^^^ constant evaluation detectable / not detectable vvv
    return _Checked_x86_x64_popcount(_Val);
#endif // _HAS_POPCNT_INTRINSICS / __cpp_lib_is_constant_evaluated
}
#undef _HAS_POPCNT_INTRINSICS
_STD_END
#pragma pop_macro("new")
_STL_RESTORE_CLANG_WARNINGS

Просмотреть файл

@ -3123,8 +3123,8 @@ _INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_iterator<_Alloc>, _RequiresMutabl
template <class _Alloc>
_INLINE_VAR constexpr bool _Is_vb_iterator<_Vb_const_iterator<_Alloc>, false> = true;
template <class _FwdIt, class _Ty>
_CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
template <class _VbIt, class _Ty>
_CONSTEXPR20 void _Fill_vbool(_VbIt _First, _VbIt _Last, const _Ty& _Val) {
// Set [_First, _Last) to _Val
if (_First == _Last) {
return;
@ -3172,6 +3172,92 @@ _CONSTEXPR20 void _Fill_vbool(_FwdIt _First, _FwdIt _Last, const _Ty& _Val) {
*_VbFirst = (*_VbFirst & _LastDestMask) | (_FillVal & _LastSourceMask);
}
}
template <class _VbIt, class _Ty>
_NODISCARD _CONSTEXPR20 _VbIt _Find_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) {
    // Return an iterator to the first bit equal to _Val in [_First, _Last), or _Last if there is none.
    // Each word is taken as-is when searching for true and complemented when searching for false,
    // so a set bit always means "match"; a _Countr_zero result of _VBITS is treated as "no match in this word".
    if (_First == _Last) {
        return _First;
    }

    const _Vbase* _Word          = _First._Myptr;
    const _Vbase* const _EndWord = _Last._Myptr;

    // clears the bits of the first word that lie before _First
    const auto _HeadMask = static_cast<_Vbase>(-1) << _First._Myoff;

    if (_Word == _EndWord) {
        // Both ends are in the same word. _First != _Last was established above, so _Last._Myoff > 0
        // and the shift count below stays smaller than _VBITS.
        const auto _TailMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
        const auto _Matches  = (_Val ? *_Word : ~*_Word) & (_HeadMask & _TailMask);
        const auto _BitIndex = _Countr_zero(_Matches);
        if (_BitIndex == _VBITS) {
            return _Last;
        }

        return _First + static_cast<ptrdiff_t>(_BitIndex - _First._Myoff);
    }

    // first word: only the bits at or after _First take part
    const auto _HeadMatches  = (_Val ? *_Word : ~*_Word) & _HeadMask;
    const auto _HeadBitIndex = _Countr_zero(_HeadMatches);
    if (_HeadBitIndex != _VBITS) {
        return _First + static_cast<ptrdiff_t>(_HeadBitIndex - _First._Myoff);
    }

    ++_Word;

    // distance from _First to the first bit of the word currently being examined
    _Iter_diff_t<_VbIt> _BitsSkipped = static_cast<ptrdiff_t>(_VBITS - _First._Myoff);

    // complete words between the first and the last one
    while (_Word != _EndWord) {
        const auto _Matches  = _Val ? *_Word : ~*_Word;
        const auto _BitIndex = _Countr_zero(_Matches);
        if (_BitIndex != _VBITS) {
            return _First + (_BitsSkipped + _BitIndex);
        }

        ++_Word;
        _BitsSkipped += _VBITS;
    }

    // last word: examined only if _Last does not sit exactly on a word boundary
    if (_Last._Myoff != 0) {
        const auto _TailMask    = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
        const auto _TailMatches = (_Val ? *_Word : ~*_Word) & _TailMask;
        const auto _BitIndex    = _Countr_zero(_TailMatches);
        if (_BitIndex != _VBITS) {
            return _First + (_BitsSkipped + _BitIndex);
        }
    }

    return _Last;
}
template <class _VbIt, class _Ty>
_NODISCARD _CONSTEXPR20 _Iter_diff_t<_VbIt> _Count_vbool(_VbIt _First, const _VbIt _Last, const _Ty& _Val) {
    // Count the bits equal to _Val in [_First, _Last).
    // Each word is taken as-is when counting true and complemented when counting false,
    // so the answer is the sum of _Popcount over the (masked) words of the range.
    if (_First == _Last) {
        return 0;
    }

    const _Vbase* _Word          = _First._Myptr;
    const _Vbase* const _EndWord = _Last._Myptr;

    // clears the bits of the first word that lie before _First
    const auto _HeadMask = static_cast<_Vbase>(-1) << _First._Myoff;

    if (_Word == _EndWord) {
        // Both ends are in the same word. _First != _Last was established above, so _Last._Myoff > 0
        // and the shift count below stays smaller than _VBITS.
        const auto _TailMask = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
        const auto _Matches  = (_Val ? *_Word : ~*_Word) & (_HeadMask & _TailMask);
        return _Popcount(_Matches);
    }

    // first word: only the bits at or after _First take part
    const auto _HeadMatches    = (_Val ? *_Word : ~*_Word) & _HeadMask;
    _Iter_diff_t<_VbIt> _Total = _Popcount(_HeadMatches);
    ++_Word;

    // complete words between the first and the last one
    while (_Word != _EndWord) {
        const auto _Matches = _Val ? *_Word : ~*_Word;
        _Total += _Popcount(_Matches);
        ++_Word;
    }

    // last word: examined only if _Last does not sit exactly on a word boundary
    if (_Last._Myoff != 0) {
        const auto _TailMask    = static_cast<_Vbase>(-1) >> (_VBITS - _Last._Myoff);
        const auto _TailMatches = (_Val ? *_Word : ~*_Word) & _TailMask;
        _Total += _Popcount(_TailMatches);
    }

    return _Total;
}
_STD_END
#pragma pop_macro("new")

Просмотреть файл

@ -5345,8 +5345,12 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(const _InIt _First, const _InIt _L
template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt find(_InIt _First, const _InIt _Last, const _Ty& _Val) { // find first matching _Val
_Adl_verify_range(_First, _Last);
_Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val));
return _First;
if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
return _Find_vbool(_First, _Last, _Val);
} else {
_Seek_wrapped(_First, _Find_unchecked(_Get_unwrapped(_First), _Get_unwrapped(_Last), _Val));
return _First;
}
}
#if _HAS_CXX17
@ -5432,17 +5436,21 @@ template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InIt _Last, const _Ty& _Val) {
// count elements that match _Val
_Adl_verify_range(_First, _Last);
auto _UFirst = _Get_unwrapped(_First);
const auto _ULast = _Get_unwrapped(_Last);
_Iter_diff_t<_InIt> _Count = 0;
if constexpr (_Is_vb_iterator<_InIt> && is_same_v<_Ty, bool>) {
return _Count_vbool(_First, _Last, _Val);
} else {
auto _UFirst = _Get_unwrapped(_First);
const auto _ULast = _Get_unwrapped(_Last);
_Iter_diff_t<_InIt> _Count = 0;
for (; _UFirst != _ULast; ++_UFirst) {
if (*_UFirst == _Val) {
++_Count;
for (; _UFirst != _ULast; ++_UFirst) {
if (*_UFirst == _Val) {
++_Count;
}
}
}
return _Count;
return _Count;
}
}
#if _HAS_CXX17

Просмотреть файл

@ -3,6 +3,7 @@
#include <algorithm>
#include <assert.h>
#include <cstdlib>
#include <vector>
using namespace std;
@ -96,6 +97,124 @@ bool test_fill() {
return true;
}
void test_find_helper(const size_t length) {
    // Builds a vector<bool> of length + 6 elements for each needle value (true, then false):
    //   [0, length + 3)           hold !needle, except index length + 1, which is flipped to needle
    //   [length + 3, length + 6)  hold needle, but lie past the end of the searched range
    // find() must therefore report index length + 1.
    // The whole scenario runs twice: first searching from cbegin(), then searching from the second
    // element after index 0 has been flipped to needle as well (a decoy outside the searched range).
    const auto expected_index = static_cast<ptrdiff_t>(length + 1);

    for (const bool with_offset : {false, true}) {
        for (const bool needle : {true, false}) {
            vector<bool> input(length + 3, !needle);
            input.resize(length + 6, needle);
            input[length + 1].flip();
            if (with_offset) {
                input[0].flip();
            }

            const auto search_begin = with_offset ? next(input.cbegin()) : input.cbegin();
            const auto search_end   = prev(input.cend(), 3);

            const auto result = find(search_begin, search_end, needle);
            assert(result == next(input.cbegin(), expected_index));
        }
    }
}
bool test_find() {
// Empty range
test_find_helper(0);
// One block, ends within block
test_find_helper(15);
// One block, ends at block boundary
test_find_helper(blockSize);
// Multiple blocks, within block
test_find_helper(3 * blockSize + 5);
// Multiple blocks, ends at block boundary
test_find_helper(4 * blockSize);
return true;
}
void test_count_helper(const ptrdiff_t length) {
    // Checks count() on vector<bool> for a range of `length` elements, once starting at index 0 and once
    // starting at index 2, for both true and false.
    //
    // The data is not random, but irregular enough to ensure confidence in the tests: it repeats an
    // 8-element period that contains 5 trues and 3 falses.
    //
    // Precondition: 0 <= length <= source.size() - 2, because the offset range ends at index length + 2.
    constexpr int period      = 8;
    constexpr int repetitions = 17;
    const bool pattern[period] = {true, false, true, false, true, true, true, false};

    vector<bool> source;
    for (int rep = 0; rep < repetitions; ++rep) {
        source.insert(source.end(), pattern, pattern + period);
    }

    assert(length >= 0 && static_cast<size_t>(length) + 2 <= source.size());

    // clang-format off
    // Number of matches among the first `rem` elements of a period when the range starts at index 0.
    const int counts_true[period]  = { 0, 1, 1, 2, 2, 3, 4, 5 };
    const int counts_false[period] = { 0, 0, 1, 1, 2, 2, 2, 2 };
    // Same, but for a range starting at index 2: the period is rotated by two elements
    // (true, false, true, true, true, false, true, false), so the partial counts differ for rem 4 and 5.
    const int counts_true_offset[period]  = { 0, 1, 1, 2, 3, 4, 4, 5 };
    const int counts_false_offset[period] = { 0, 0, 1, 1, 1, 1, 2, 2 };
    // clang-format on

    // quot = number of complete periods in the range, rem = elements in the trailing partial period
    const auto expected = div(static_cast<int>(length), period);

    // No offset
    {
        const auto result_true = static_cast<int>(count(source.cbegin(), next(source.cbegin(), length), true));
        assert(result_true == expected.quot * 5 + counts_true[expected.rem]);

        const auto result_false = static_cast<int>(count(source.cbegin(), next(source.cbegin(), length), false));
        assert(result_false == expected.quot * 3 + counts_false[expected.rem]);
    }

    // With offset
    {
        const auto result_true =
            static_cast<int>(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), true));
        assert(result_true == expected.quot * 5 + counts_true_offset[expected.rem]);

        const auto result_false =
            static_cast<int>(count(next(source.cbegin(), 2), next(source.cbegin(), length + 2), false));
        assert(result_false == expected.quot * 3 + counts_false_offset[expected.rem]);
    }
}
bool test_count() {
// Empty range
test_count_helper(0);
// One block, ends within block
test_count_helper(15);
// One block, ends at block boundary
test_count_helper(blockSize);
// Multiple blocks, within block
test_count_helper(3 * blockSize + 8);
// Multiple blocks, ends at block boundary
test_count_helper(4 * blockSize);
return true;
}
int main() {
// Runs each test group in turn. Failures are reported by the assert calls inside the tests;
// the bool values the test functions return are not inspected here.
test_fill();
test_find();
test_count();
}