Vectorize `basic_string::rfind` (the single character overload) (#5087)

Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
This commit is contained in:
Alex Guteniev 2024-11-19 01:44:15 -08:00 коммит произвёл GitHub
Родитель 64d143da95
Коммит 1711bc35aa
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
5 изменённых файлов: 107 добавлений и 45 удалений

Просмотреть файл

@ -7,25 +7,38 @@
#include <cstdint>
#include <cstdlib>
#include <ranges>
#include <string>
#include <type_traits>
#include <vector>
#include "skewed_allocator.hpp"
enum class Op {
FindSized,
FindUnsized,
Count,
StringFind,
StringRFind,
};
using namespace std;
template <class T, Op Operation>
template <class T, template <class> class Alloc, Op Operation>
void bm(benchmark::State& state) {
const auto size = static_cast<size_t>(state.range(0));
const auto pos = static_cast<size_t>(state.range(1));
vector<T> a(size, T{'0'});
using Container = conditional_t<Operation == Op::StringFind || Operation == Op::StringRFind,
basic_string<T, char_traits<T>, Alloc<T>>, vector<T, Alloc<T>>>;
Container a(size, T{'0'});
if (pos < size) {
a[pos] = T{'1'};
if constexpr (Operation == Op::StringRFind) {
a[size - pos - 1] = T{'1'};
} else {
a[pos] = T{'1'};
}
} else {
if constexpr (Operation == Op::FindUnsized) {
abort();
@ -39,6 +52,10 @@ void bm(benchmark::State& state) {
benchmark::DoNotOptimize(ranges::find(a.begin(), unreachable_sentinel, T{'1'}));
} else if constexpr (Operation == Op::Count) {
benchmark::DoNotOptimize(ranges::count(a.begin(), a.end(), T{'1'}));
} else if constexpr (Operation == Op::StringFind) {
benchmark::DoNotOptimize(a.find(T{'1'}));
} else if constexpr (Operation == Op::StringRFind) {
benchmark::DoNotOptimize(a.rfind(T{'1'}));
}
}
}
@ -50,17 +67,28 @@ void common_args(auto bm) {
}
BENCHMARK(bm<uint8_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<char, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char, highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);
BENCHMARK(bm<char, highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);
BENCHMARK(bm<uint16_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint16_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint16_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint16_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<wchar_t, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<wchar_t, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);
BENCHMARK(bm<uint32_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint32_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint32_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint32_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<char32_t, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char32_t, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);
BENCHMARK(bm<uint64_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint64_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint64_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint64_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK_MAIN();

Просмотреть файл

@ -724,7 +724,24 @@ constexpr size_t _Traits_rfind_ch(_In_reads_(_Hay_size) const _Traits_ptr_t<_Tra
return static_cast<size_t>(-1);
}
for (auto _Match_try = _Haystack + (_STD min)(_Start_at, _Hay_size - 1);; --_Match_try) {
const size_t _Actual_start_at = (_STD min)(_Start_at, _Hay_size - 1);
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Is_implementation_handled_char_traits<_Traits>) {
if (!_STD _Is_constant_evaluated()) {
const auto _End = _Haystack + _Actual_start_at + 1;
const auto _Ptr = _STD _Find_last_vectorized(_Haystack, _End, _Ch);
if (_Ptr != _End) {
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (auto _Match_try = _Haystack + _Actual_start_at;; --_Match_try) {
if (_Traits::eq(*_Match_try, _Ch)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}

Просмотреть файл

@ -54,11 +54,6 @@ _Min_max_element_t __stdcall __std_minmax_element_8(const void* _First, const vo
_Min_max_element_t __stdcall __std_minmax_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;
const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
@ -162,33 +157,6 @@ auto _Minmax_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
}
}
template <class _Ty, class _TVal>
_Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noexcept {
if constexpr (is_pointer_v<_TVal> || is_null_pointer_v<_TVal>) {
#ifdef _WIN64
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, reinterpret_cast<uint64_t>(_Val))));
#else
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, reinterpret_cast<uint32_t>(_Val))));
#endif
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_1(_First, _Last, static_cast<uint8_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_2(_First, _Last, static_cast<uint16_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 4) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, static_cast<uint32_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 8) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, static_cast<uint64_t>(_Val))));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}
template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {

Просмотреть файл

@ -93,6 +93,11 @@ const void* __stdcall __std_find_trivial_2(const void* _First, const void* _Last
const void* __stdcall __std_find_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
@ -217,6 +222,33 @@ _Ty* _Find_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noe
}
}
template <class _Ty, class _TVal>
_Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noexcept {
if constexpr (is_pointer_v<_TVal> || is_null_pointer_v<_TVal>) {
#ifdef _WIN64
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, reinterpret_cast<uint64_t>(_Val))));
#else
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, reinterpret_cast<uint32_t>(_Val))));
#endif
} else if constexpr (sizeof(_Ty) == 1) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_1(_First, _Last, static_cast<uint8_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_2(_First, _Last, static_cast<uint16_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 4) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, static_cast<uint32_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 8) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, static_cast<uint64_t>(_Val))));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}
// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

Просмотреть файл

@ -1128,6 +1128,22 @@ void test_case_string_rfind_str(const basic_string<T>& input_haystack, const bas
assert(expected == actual);
}
template <class T>
void test_case_string_rfind_ch(const basic_string<T>& input_haystack, const T value) {
ptrdiff_t expected;
const auto expected_iter = last_known_good_find_last(input_haystack.begin(), input_haystack.end(), value);
if (expected_iter != input_haystack.end()) {
expected = expected_iter - input_haystack.begin();
} else {
expected = -1;
}
const auto actual = static_cast<ptrdiff_t>(input_haystack.rfind(value));
assert(expected == actual);
}
template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
@ -1144,6 +1160,7 @@ void test_basic_string_dis(mt19937_64& gen, D& dis) {
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);
test_case_string_rfind_ch(input_haystack, static_cast<T>(dis(gen)));
for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));