diff --git a/benchmarks/src/search.cpp b/benchmarks/src/search.cpp index a7423e6c6..c6fad3d4f 100644 --- a/benchmarks/src/search.cpp +++ b/benchmarks/src/search.cpp @@ -2,12 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include +#include #include #include #include #include #include +#include #include +using namespace std::string_view_literals; const char src_haystack[] = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam mollis imperdiet massa, at dapibus elit interdum " @@ -40,9 +43,14 @@ const char src_haystack[] = "euismod eros, ut posuere ligula ullamcorper id. Nullam aliquet malesuada est at dignissim. Pellentesque finibus " "sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis."; -const char src_needle[] = "aliquet"; +constexpr std::array patterns = { + "aliquet"sv, + "aliquet malesuada"sv, +}; void c_strstr(benchmark::State& state) { + const auto& src_needle = patterns[static_cast(state.range())]; + const std::string haystack(std::begin(src_haystack), std::end(src_haystack)); const std::string needle(std::begin(src_needle), std::end(src_needle)); @@ -56,6 +64,8 @@ void c_strstr(benchmark::State& state) { template void classic_search(benchmark::State& state) { + const auto& src_needle = patterns[static_cast(state.range())]; + const std::vector haystack(std::begin(src_haystack), std::end(src_haystack)); const std::vector needle(std::begin(src_needle), std::end(src_needle)); @@ -69,6 +79,8 @@ void classic_search(benchmark::State& state) { template void ranges_search(benchmark::State& state) { + const auto& src_needle = patterns[static_cast(state.range())]; + const std::vector haystack(std::begin(src_haystack), std::end(src_haystack)); const std::vector needle(std::begin(src_needle), std::end(src_needle)); @@ -82,6 +94,8 @@ void ranges_search(benchmark::State& state) { template void search_default_searcher(benchmark::State& state) { + const auto& src_needle = patterns[static_cast(state.range())]; + const std::vector haystack(std::begin(src_haystack), std::end(src_haystack)); const std::vector needle(std::begin(src_needle), std::end(src_needle)); @@ -93,22 +107,26 @@ void search_default_searcher(benchmark::State& state) { } } -BENCHMARK(c_strstr); +void common_args(auto bm) { + bm->Range(0, patterns.size() - 1); +} -BENCHMARK(classic_search); -BENCHMARK(classic_search); -BENCHMARK(classic_search); -BENCHMARK(classic_search); +BENCHMARK(c_strstr)->Apply(common_args); -BENCHMARK(ranges_search); -BENCHMARK(ranges_search); -BENCHMARK(ranges_search); -BENCHMARK(ranges_search); +BENCHMARK(classic_search)->Apply(common_args); +BENCHMARK(classic_search)->Apply(common_args); +BENCHMARK(classic_search)->Apply(common_args); +BENCHMARK(classic_search)->Apply(common_args); -BENCHMARK(search_default_searcher); -BENCHMARK(search_default_searcher); -BENCHMARK(search_default_searcher); -BENCHMARK(search_default_searcher); +BENCHMARK(ranges_search)->Apply(common_args); +BENCHMARK(ranges_search)->Apply(common_args); +BENCHMARK(ranges_search)->Apply(common_args); +BENCHMARK(ranges_search)->Apply(common_args); + +BENCHMARK(search_default_searcher)->Apply(common_args); +BENCHMARK(search_default_searcher)->Apply(common_args); +BENCHMARK(search_default_searcher)->Apply(common_args); +BENCHMARK(search_default_searcher)->Apply(common_args); BENCHMARK_MAIN(); diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 3f54cd75e..dcec475ff 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -2107,6 +2107,26 @@ _NODISCARD _CONSTEXPR20 _FwdItHaystack search(_FwdItHaystack _First1, _FwdItHays if constexpr (_Is_ranges_random_iter_v<_FwdItHaystack> && _Is_ranges_random_iter_v<_FwdItPat>) { const _Iter_diff_t<_FwdItPat> _Count2 = _ULast2 - _UFirst2; if (_ULast1 - _UFirst1 >= _Count2) { +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Vector_alg_in_search_is_safe) { + if (!_STD _Is_constant_evaluated()) { + const auto _Ptr1 = _STD _To_address(_UFirst1); + + const auto _Ptr_res1 = _STD _Search_vectorized( + _Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), static_cast(_Count2)); + + if constexpr (is_pointer_v) { + _UFirst1 = _Ptr_res1; + } else { + _UFirst1 += _Ptr_res1 - _Ptr1; + } + + _STD _Seek_wrapped(_Last1, _UFirst1); + return _Last1; + } + } +#endif // _USE_STD_VECTOR_ALGORITHMS + const auto _Last_possible = _ULast1 - static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2); for (;; ++_UFirst1) { if (_STD _Equal_rev_pred_unchecked(_UFirst1, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) { diff --git a/stl/inc/functional b/stl/inc/functional index 5e81fc599..d06f4e129 100644 --- a/stl/inc/functional +++ b/stl/inc/functional @@ -2459,6 +2459,29 @@ _CONSTEXPR20 pair<_FwdItHaystack, _FwdItHaystack> _Search_pair_unchecked( _Iter_diff_t<_FwdItHaystack> _Count1 = _Last1 - _First1; _Iter_diff_t<_FwdItPat> _Count2 = _Last2 - _First2; +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Vector_alg_in_search_is_safe<_FwdItHaystack, _FwdItPat, _Pred_eq>) { + if (!_STD _Is_constant_evaluated()) { + const auto _Ptr1 = _STD _To_address(_First1); + + const auto _Ptr_res1 = _STD _Search_vectorized( + _Ptr1, _STD _To_address(_Last1), _STD _To_address(_First2), static_cast(_Count2)); + + if constexpr (is_pointer_v<_FwdItHaystack>) { + _First1 = _Ptr_res1; + } else { + _First1 += _Ptr_res1 - _Ptr1; + } + + if (_First1 != _Last1) { + return {_First1, _First1 + static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2)}; + } else { + return {_Last1, _Last1}; + } + } + } +#endif // _USE_STD_VECTOR_ALGORITHMS + for (; _Count2 <= _Count1; ++_First1, (void) --_Count1) { // room for match, try it _FwdItHaystack _Mid1 = _First1; for (_FwdItPat _Mid2 = _First2;; ++_Mid1, (void) ++_Mid2) { diff --git a/stl/inc/xutility b/stl/inc/xutility index 8661e79fc..c1e611c5d 100644 --- a/stl/inc/xutility +++ b/stl/inc/xutility @@ -102,6 +102,11 @@ const void* __stdcall __std_find_first_of_trivial_4( const void* __stdcall __std_find_first_of_trivial_8( const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept; +const void* __stdcall __std_search_1( + const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept; +const void* __stdcall __std_search_2( + const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept; + const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept; const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept; const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept; @@ -231,6 +236,18 @@ _Ty1* _Find_first_of_vectorized( } } +template +_Ty1* _Search_vectorized(_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept { + _STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2)); + if constexpr (sizeof(_Ty1) == 1) { + return const_cast<_Ty1*>(static_cast(::__std_search_1(_First1, _Last1, _First2, _Count2))); + } else if constexpr (sizeof(_Ty1) == 2) { + return const_cast<_Ty1*>(static_cast(::__std_search_2(_First1, _Last1, _First2, _Count2))); + } else { + _STL_INTERNAL_STATIC_ASSERT(false); // unexpected size + } +} + template _Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept { constexpr bool _Signed = is_signed_v<_Ty>; @@ -5411,6 +5428,11 @@ template constexpr bool _Equal_memcmp_is_safe = _Equal_memcmp_is_safe_helper, remove_const_t<_Iter2>, remove_const_t<_Pr>>; +// Can we activate the vector algorithms for std::search? +template +constexpr bool _Vector_alg_in_search_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr> // can search bitwise + && sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible element size + template _NODISCARD int _Memcmp_count(_CtgIt1 _First1, _CtgIt2 _First2, const size_t _Count) { _STL_INTERNAL_STATIC_ASSERT(sizeof(_Iter_value_t<_CtgIt1>) == sizeof(_Iter_value_t<_CtgIt2>)); @@ -6788,6 +6810,41 @@ namespace ranges { _STL_INTERNAL_CHECK(_RANGES distance(_First1, _Last1) == _Count1); _STL_INTERNAL_CHECK(_RANGES distance(_First2, _Last2) == _Count2); +#if _USE_STD_VECTOR_ALGORITHMS + if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity> + && is_same_v<_Pj2, identity>) { + if (!_STD is_constant_evaluated()) { + const auto _Ptr1 = _STD to_address(_First1); + const auto _Ptr2 = _STD to_address(_First2); + remove_const_t _Ptr_last1; + + if constexpr (is_same_v<_It1, _Se1>) { + _Ptr_last1 = _STD to_address(_Last1); + } else { + _Ptr_last1 = _Ptr1 + _Count1; + } + + const auto _Ptr_res1 = + _STD _Search_vectorized(_Ptr1, _Ptr_last1, _Ptr2, static_cast(_Count2)); + + if constexpr (is_pointer_v<_It1>) { + if (_Ptr_res1 != _Ptr_last1) { + return {_Ptr_res1, _Ptr_res1 + _Count2}; + } else { + return {_Ptr_res1, _Ptr_res1}; + } + } else { + _First1 += _Ptr_res1 - _Ptr1; + if (_First1 != _Last1) { + return {_First1, _First1 + static_cast>(_Count2)}; + } else { + return {_First1, _First1}; + } + } + } + } +#endif // _USE_STD_VECTOR_ALGORITHMS + for (; _Count1 >= _Count2; ++_First1, (void) --_Count1) { auto _Match_and_mid1 = _RANGES _Equal_rev_pred(_First1, _First2, _Last2, _Pred, _Proj1, _Proj2); if (_Match_and_mid1.first) { diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index c86b96871..b8a3649b6 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -3261,6 +3261,154 @@ namespace { return _Result; } + + template + const void* __stdcall __std_search_impl( + const void* _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept { + if (_Count2 == 0) { + return _First1; + } + + if (_Count2 == 1) { + return __std_find_trivial_impl<_Traits>(_First1, _Last1, *static_cast(_First2)); + } + + const size_t _Size_bytes_1 = _Byte_length(_First1, _Last1); + const size_t _Size_bytes_2 = _Count2 * sizeof(_Ty); + + if (_Size_bytes_1 < _Size_bytes_2) { + return _Last1; + } + +#ifndef _M_ARM64EC + if (_Use_sse42() && _Size_bytes_1 >= 16) { + constexpr int _Op = (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ORDERED; + constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8; + + if (_Size_bytes_2 <= 16) { + const int _Size_el_2 = static_cast(_Size_bytes_2 / sizeof(_Ty)); + + const int _Max_full_match_pos = _Part_size_el - _Size_el_2; + + alignas(16) uint8_t _Tmp2[16]; + memcpy(_Tmp2, _First2, _Size_bytes_2); + const __m128i _Data2 = _mm_load_si128(reinterpret_cast(_Tmp2)); + + const void* _Stop1 = _First1; + _Advance_bytes(_Stop1, _Size_bytes_1 - 16); + + do { + const __m128i _Data1 = _mm_loadu_si128(static_cast(_First1)); + + if (!_mm_cmpestrc(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op)) { + _Advance_bytes(_First1, 16); // No matches, next. + } else { + const int _Pos = _mm_cmpestri(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op); + _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + if (_Pos <= _Max_full_match_pos) { + // Full match. Return this match. + return _First1; + } + // Partial match. Search again from the match start. Will return it if it is full. + } + } while (_First1 <= _Stop1); + + const size_t _Size_bytes_1_tail = _Byte_length(_First1, _Last1); + if (_Size_bytes_1_tail != 0) { + const int _Size_el_1_tail = static_cast(_Size_bytes_1_tail / sizeof(_Ty)); + + alignas(16) uint8_t _Tmp1[16]; + memcpy(_Tmp1, _First1, _Size_bytes_1_tail); + const __m128i _Data1 = _mm_load_si128(reinterpret_cast(_Tmp1)); + + if (_mm_cmpestrc(_Data2, _Size_el_2, _Data1, _Size_el_1_tail, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Size_el_2, _Data1, _Size_el_1_tail, _Op); + _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + // Full match because size is less than 16. Return this match. + return _First1; + } + } + } else { + const __m128i _Data2 = _mm_loadu_si128(reinterpret_cast(_First2)); + const size_t _Max_pos = _Size_bytes_1 - _Size_bytes_2; + + const void* _Stop1 = _First1; + _Advance_bytes(_Stop1, _Max_pos); + + const void* _Tail2 = _First2; + _Advance_bytes(_Tail2, 16); + + do { + const __m128i _Data1 = _mm_loadu_si128(static_cast(_First1)); + if (!_mm_cmpestrc(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op)) { + _Advance_bytes(_First1, 16); // No matches, next. + } else { + const int _Pos = _mm_cmpestri(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op); + // Matched first 16 or less + if (_Pos != 0) { + _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + + if (_First1 > _Stop1) { + break; // Oops, doesn't fit + } + + // Match not from the first byte, check 16 symbols + const __m128i _Match1 = _mm_loadu_si128(static_cast(_First1)); + const __m128i _Cmp = _mm_xor_si128(_Data2, _Match1); + if (!_mm_testz_si128(_Cmp, _Cmp)) { + // Start from the next element + _Advance_bytes(_First1, sizeof(_Ty)); + continue; + } + } + // Matched first 16, check the rest + + const void* _Tail1 = _First1; + _Advance_bytes(_Tail1, 16); + + if (memcmp(_Tail1, _Tail2, _Size_bytes_2 - 16) == 0) { + return _First1; + } + + // Start from the next element + _Advance_bytes(_First1, sizeof(_Ty)); + } + } while (_First1 <= _Stop1); + } + + return _Last1; + } else +#endif // !defined(_M_ARM64EC) + { + const size_t _Max_pos = _Size_bytes_1 - _Size_bytes_2 + sizeof(_Ty); + + auto _Ptr1 = static_cast(_First1); + const auto _Ptr2 = static_cast(_First2); + const void* _Stop1 = _Ptr1; + _Advance_bytes(_Stop1, _Max_pos); + + for (; _Ptr1 != _Stop1; ++_Ptr1) { + if (*_Ptr1 != *_Ptr2) { + continue; + } + + bool _Equal = true; + + for (size_t _Idx = 1; _Idx != _Count2; ++_Idx) { + if (_Ptr1[_Idx] != _Ptr2[_Idx]) { + _Equal = false; + break; + } + } + + if (_Equal) { + return _Ptr1; + } + } + + return _Last1; + } + } } // unnamed namespace extern "C" { @@ -3365,6 +3513,16 @@ const void* __stdcall __std_find_first_of_trivial_8( return __std_find_first_of::_Impl_4_8<__std_find_first_of::_Traits_8>(_First1, _Last1, _First2, _Last2); } +const void* __stdcall __std_search_1( + const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept { + return __std_search_impl<_Find_traits_1, uint8_t>(_First1, _Last1, _First2, _Count2); +} + +const void* __stdcall __std_search_2( + const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept { + return __std_search_impl<_Find_traits_2, uint16_t>(_First1, _Last1, _First2, _Count2); +} + __declspec(noalias) size_t __stdcall __std_mismatch_1(const void* const _First1, const void* const _First2, const size_t _Count) noexcept { return __std_mismatch_impl<_Find_traits_1, uint8_t>(_First1, _First2, _Count); diff --git a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp index a09ed25c3..28f5c92d3 100644 --- a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp +++ b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp @@ -320,8 +320,10 @@ void test_search(mt19937_64& gen) { uniform_int_distribution dis('0', '9'); vector input_haystack; vector input_needle; + vector temp; input_haystack.reserve(haystackDataCount); input_needle.reserve(needleDataCount); + temp.reserve(needleDataCount); for (;;) { input_needle.clear(); @@ -330,6 +332,17 @@ void test_search(mt19937_64& gen) { for (size_t attempts = 0; attempts < needleDataCount; ++attempts) { input_needle.push_back(static_cast(dis(gen))); test_case_search(input_haystack, input_needle); + + // For large needles the chance of a match is low, so test a guaranteed match + if (input_haystack.size() > input_needle.size() * 2) { + uniform_int_distribution pos_dis(0, input_haystack.size() - input_needle.size()); + const size_t pos = pos_dis(gen); + const auto overwritten_first = input_haystack.begin() + static_cast(pos); + temp.assign(overwritten_first, overwritten_first + static_cast(input_needle.size())); + copy(input_needle.begin(), input_needle.end(), overwritten_first); + test_case_search(input_haystack, input_needle); + copy(temp.begin(), temp.end(), overwritten_first); + } } if (input_haystack.size() == haystackDataCount) {