Vectorize `std::search` for 1-byte and 2-byte elements with `pcmpestri` (#4745)

Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
This commit is contained in:
Alex Guteniev 2024-09-09 12:29:36 -07:00 коммит произвёл GitHub
Родитель c7c5ca7d6a
Коммит e93126188e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
6 изменённых файлов: 303 добавлений и 14 удалений

Просмотреть файл

@ -2,12 +2,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <algorithm>
#include <array>
#include <benchmark/benchmark.h>
#include <cstdint>
#include <cstring>
#include <functional>
#include <string>
#include <string_view>
#include <vector>
using namespace std::string_view_literals;
const char src_haystack[] =
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nullam mollis imperdiet massa, at dapibus elit interdum "
@ -40,9 +43,14 @@ const char src_haystack[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquet malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";
const char src_needle[] = "aliquet";
// Needles searched for in src_haystack. Each benchmark selects one entry by index through
// state.range(), so the benchmark arguments are expected to be 0 .. patterns.size() - 1.
// Both entries occur in the visible tail of src_haystack ("Nullam aliquet malesuada est ...").
constexpr std::array patterns = {
"aliquet"sv, // 7 characters
"aliquet malesuada"sv, // 17 characters
};
void c_strstr(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
const std::string haystack(std::begin(src_haystack), std::end(src_haystack));
const std::string needle(std::begin(src_needle), std::end(src_needle));
@ -56,6 +64,8 @@ void c_strstr(benchmark::State& state) {
template <class T>
void classic_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@ -69,6 +79,8 @@ void classic_search(benchmark::State& state) {
template <class T>
void ranges_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@ -82,6 +94,8 @@ void ranges_search(benchmark::State& state) {
template <class T>
void search_default_searcher(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];
const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));
@ -93,22 +107,26 @@ void search_default_searcher(benchmark::State& state) {
}
}
BENCHMARK(c_strstr);
// Registers one benchmark run per entry of `patterns`; the run argument (read back through
// state.range()) is the index of the needle to search for.
//
// DenseRange with step 1 is used instead of Range: Range generates a sparse sequence spaced by
// the range multiplier (8 by default), e.g. Range(0, 3) yields 0, 1, 3, so some needle indices
// would silently never be benchmarked once `patterns` has more than three entries. For the
// current two entries both calls generate exactly the arguments 0 and 1.
void common_args(auto bm) {
    bm->DenseRange(0, patterns.size() - 1, 1);
}
BENCHMARK(classic_search<std::uint8_t>);
BENCHMARK(classic_search<std::uint16_t>);
BENCHMARK(classic_search<std::uint32_t>);
BENCHMARK(classic_search<std::uint64_t>);
BENCHMARK(c_strstr)->Apply(common_args);
BENCHMARK(ranges_search<std::uint8_t>);
BENCHMARK(ranges_search<std::uint16_t>);
BENCHMARK(ranges_search<std::uint32_t>);
BENCHMARK(ranges_search<std::uint64_t>);
BENCHMARK(classic_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint64_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint8_t>);
BENCHMARK(search_default_searcher<std::uint16_t>);
BENCHMARK(search_default_searcher<std::uint32_t>);
BENCHMARK(search_default_searcher<std::uint64_t>);
BENCHMARK(ranges_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint64_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint32_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint64_t>)->Apply(common_args);
BENCHMARK_MAIN();

Просмотреть файл

@ -2107,6 +2107,26 @@ _NODISCARD _CONSTEXPR20 _FwdItHaystack search(_FwdItHaystack _First1, _FwdItHays
if constexpr (_Is_ranges_random_iter_v<_FwdItHaystack> && _Is_ranges_random_iter_v<_FwdItPat>) {
const _Iter_diff_t<_FwdItPat> _Count2 = _ULast2 - _UFirst2;
if (_ULast1 - _UFirst1 >= _Count2) {
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_UFirst1);
const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count2));
if constexpr (is_pointer_v<decltype(_UFirst1)>) {
_UFirst1 = _Ptr_res1;
} else {
_UFirst1 += _Ptr_res1 - _Ptr1;
}
_STD _Seek_wrapped(_Last1, _UFirst1);
return _Last1;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
const auto _Last_possible = _ULast1 - static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2);
for (;; ++_UFirst1) {
if (_STD _Equal_rev_pred_unchecked(_UFirst1, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) {

Просмотреть файл

@ -2459,6 +2459,29 @@ _CONSTEXPR20 pair<_FwdItHaystack, _FwdItHaystack> _Search_pair_unchecked(
_Iter_diff_t<_FwdItHaystack> _Count1 = _Last1 - _First1;
_Iter_diff_t<_FwdItPat> _Count2 = _Last2 - _First2;
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_FwdItHaystack, _FwdItPat, _Pred_eq>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_First1);
const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_Last1), _STD _To_address(_First2), static_cast<size_t>(_Count2));
if constexpr (is_pointer_v<_FwdItHaystack>) {
_First1 = _Ptr_res1;
} else {
_First1 += _Ptr_res1 - _Ptr1;
}
if (_First1 != _Last1) {
return {_First1, _First1 + static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2)};
} else {
return {_Last1, _Last1};
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (; _Count2 <= _Count1; ++_First1, (void) --_Count1) { // room for match, try it
_FwdItHaystack _Mid1 = _First1;
for (_FwdItPat _Mid2 = _First2;; ++_Mid1, (void) ++_Mid2) {

Просмотреть файл

@ -102,6 +102,11 @@ const void* __stdcall __std_find_first_of_trivial_4(
const void* __stdcall __std_find_first_of_trivial_8(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
// Vectorized std::search entry points for 1-byte (__std_search_1) and 2-byte (__std_search_2) elements.
// [_First1, _Last1) is the haystack; the needle is _Count2 elements starting at _First2.
// Both return the address of the first occurrence of the needle in the haystack, or _Last1 if
// there is none; an empty needle (_Count2 == 0) yields _First1.
const void* __stdcall __std_search_1(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_search_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
@ -231,6 +236,18 @@ _Ty1* _Find_first_of_vectorized(
}
}
// Typed front end for the untyped __std_search_N entry points.
// Selects the entry point that matches the element size, then converts the returned
// const void* back into a pointer to the haystack element type.
// _First1/_Last1 delimit the haystack; the needle is _Count2 elements starting at _First2.
template <class _Ty1, class _Ty2>
_Ty1* _Search_vectorized(_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
    // Haystack and needle elements must have the same width.
    _STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));

    if constexpr (sizeof(_Ty1) == 1) {
        const void* const _Found = ::__std_search_1(_First1, _Last1, _First2, _Count2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Found));
    } else if constexpr (sizeof(_Ty1) == 2) {
        const void* const _Found = ::__std_search_2(_First1, _Last1, _First2, _Count2);
        return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Found));
    } else {
        _STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
    }
}
template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
@ -5411,6 +5428,11 @@ template <class _Iter1, class _Iter2, class _Pr>
constexpr bool _Equal_memcmp_is_safe =
_Equal_memcmp_is_safe_helper<remove_const_t<_Iter1>, remove_const_t<_Iter2>, remove_const_t<_Pr>>;
// Can we activate the vector algorithms for std::search?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_search_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr> // can search bitwise
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible element size
template <class _CtgIt1, class _CtgIt2>
_NODISCARD int _Memcmp_count(_CtgIt1 _First1, _CtgIt2 _First2, const size_t _Count) {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Iter_value_t<_CtgIt1>) == sizeof(_Iter_value_t<_CtgIt2>));
@ -6788,6 +6810,41 @@ namespace ranges {
_STL_INTERNAL_CHECK(_RANGES distance(_First1, _Last1) == _Count1);
_STL_INTERNAL_CHECK(_RANGES distance(_First2, _Last2) == _Count2);
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
const auto _Ptr1 = _STD to_address(_First1);
const auto _Ptr2 = _STD to_address(_First2);
remove_const_t<decltype(_Ptr1)> _Ptr_last1;
if constexpr (is_same_v<_It1, _Se1>) {
_Ptr_last1 = _STD to_address(_Last1);
} else {
_Ptr_last1 = _Ptr1 + _Count1;
}
const auto _Ptr_res1 =
_STD _Search_vectorized(_Ptr1, _Ptr_last1, _Ptr2, static_cast<size_t>(_Count2));
if constexpr (is_pointer_v<_It1>) {
if (_Ptr_res1 != _Ptr_last1) {
return {_Ptr_res1, _Ptr_res1 + _Count2};
} else {
return {_Ptr_res1, _Ptr_res1};
}
} else {
_First1 += _Ptr_res1 - _Ptr1;
if (_First1 != _Last1) {
return {_First1, _First1 + static_cast<iter_difference_t<_It1>>(_Count2)};
} else {
return {_First1, _First1};
}
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (; _Count1 >= _Count2; ++_First1, (void) --_Count1) {
auto _Match_and_mid1 = _RANGES _Equal_rev_pred(_First1, _First2, _Last2, _Pred, _Proj1, _Proj2);
if (_Match_and_mid1.first) {

Просмотреть файл

@ -3261,6 +3261,154 @@ namespace {
return _Result;
}
// Searches the haystack [_First1, _Last1) for the first occurrence of a needle made of _Count2
// elements of type _Ty starting at _First2.
// Returns a pointer to the start of the first occurrence, or _Last1 when there is none.
// _Traits is only forwarded to __std_find_trivial_impl for the single-element needle case.
// Three strategies are used:
//   * needle of at most 16 bytes: SSE4.2 string compare of the whole needle against 16-byte blocks;
//   * longer needle: SSE4.2 string compare of the first 16 needle bytes as a filter, then memcmp;
//   * scalar element-by-element fallback otherwise.
template <class _Traits, class _Ty>
const void* __stdcall __std_search_impl(
const void* _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
// An empty needle matches at the very beginning of the haystack.
if (_Count2 == 0) {
return _First1;
}
// A one-element needle is an ordinary find of that element.
if (_Count2 == 1) {
return __std_find_trivial_impl<_Traits>(_First1, _Last1, *static_cast<const _Ty*>(_First2));
}
const size_t _Size_bytes_1 = _Byte_length(_First1, _Last1);
const size_t _Size_bytes_2 = _Count2 * sizeof(_Ty);
// The needle cannot occur in a haystack shorter than itself.
if (_Size_bytes_1 < _Size_bytes_2) {
return _Last1;
}
#ifndef _M_ARM64EC
// The vector path always starts with a full 16-byte load from the haystack, hence the size requirement.
if (_Use_sse42() && _Size_bytes_1 >= 16) {
// Compare mode: unsigned bytes or unsigned words depending on the element size, "equal ordered" aggregation.
constexpr int _Op = (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ORDERED;
// Number of elements in one 16-byte block.
constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8;
if (_Size_bytes_2 <= 16) {
// Short needle: the whole needle fits into a single 16-byte register.
const int _Size_el_2 = static_cast<int>(_Size_bytes_2 / sizeof(_Ty));
// Highest element position within a 16-byte block at which the entire needle still lies inside the block.
const int _Max_full_match_pos = _Part_size_el - _Size_el_2;
// Copy the needle into a 16-byte buffer so that it can be loaded as a whole register without
// reading past the end of the needle. Only the first _Size_bytes_2 bytes are written; the
// compare intrinsics below are given the explicit needle length _Size_el_2.
alignas(16) uint8_t _Tmp2[16];
memcpy(_Tmp2, _First2, _Size_bytes_2);
const __m128i _Data2 = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));
// Last haystack address from which a full 16-byte load stays inside the haystack.
const void* _Stop1 = _First1;
_Advance_bytes(_Stop1, _Size_bytes_1 - 16);
do {
const __m128i _Data1 = _mm_loadu_si128(static_cast<const __m128i*>(_First1));
if (!_mm_cmpestrc(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op)) {
_Advance_bytes(_First1, 16); // No matches, next.
} else {
// _Pos is the element index of the reported candidate within this block; move to it.
const int _Pos = _mm_cmpestri(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op);
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
if (_Pos <= _Max_full_match_pos) {
// Full match. Return this match.
return _First1;
}
// Partial match. Search again from the match start. Will return it if it is full.
}
} while (_First1 <= _Stop1);
// Fewer than 16 haystack bytes remain, so a full 16-byte load is no longer possible.
const size_t _Size_bytes_1_tail = _Byte_length(_First1, _Last1);
if (_Size_bytes_1_tail != 0) {
const int _Size_el_1_tail = static_cast<int>(_Size_bytes_1_tail / sizeof(_Ty));
// Copy the remaining bytes into a 16-byte buffer and compare with the explicit tail length.
alignas(16) uint8_t _Tmp1[16];
memcpy(_Tmp1, _First1, _Size_bytes_1_tail);
const __m128i _Data1 = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp1));
if (_mm_cmpestrc(_Data2, _Size_el_2, _Data1, _Size_el_1_tail, _Op)) {
const int _Pos = _mm_cmpestri(_Data2, _Size_el_2, _Data1, _Size_el_1_tail, _Op);
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
// Full match because size is less than 16. Return this match.
return _First1;
}
}
} else {
// Long needle: its first 16 bytes are used as a filter, the remainder is verified with memcmp.
const __m128i _Data2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First2));
// Byte offset, and then address, of the last haystack position at which the whole needle still fits.
const size_t _Max_pos = _Size_bytes_1 - _Size_bytes_2;
const void* _Stop1 = _First1;
_Advance_bytes(_Stop1, _Max_pos);
// Start of the needle part that follows its first 16 bytes.
const void* _Tail2 = _First2;
_Advance_bytes(_Tail2, 16);
do {
// _First1 <= _Stop1 here and the needle is longer than 16 bytes, so this load stays inside the haystack.
const __m128i _Data1 = _mm_loadu_si128(static_cast<const __m128i*>(_First1));
if (!_mm_cmpestrc(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op)) {
_Advance_bytes(_First1, 16); // No matches, next.
} else {
const int _Pos = _mm_cmpestri(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op);
// Matched first 16 or less
if (_Pos != 0) {
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
if (_First1 > _Stop1) {
break; // Oops, doesn't fit
}
// Match not from the first byte, check 16 symbols
// Reload 16 bytes from the candidate start; the XOR is all zero only if they equal the needle's first 16 bytes.
const __m128i _Match1 = _mm_loadu_si128(static_cast<const __m128i*>(_First1));
const __m128i _Cmp = _mm_xor_si128(_Data2, _Match1);
if (!_mm_testz_si128(_Cmp, _Cmp)) {
// Start from the next element
_Advance_bytes(_First1, sizeof(_Ty));
continue; // in a do-while this jumps to the loop condition below
}
}
// Matched first 16, check the rest
const void* _Tail1 = _First1;
_Advance_bytes(_Tail1, 16);
if (memcmp(_Tail1, _Tail2, _Size_bytes_2 - 16) == 0) {
return _First1;
}
// Start from the next element
_Advance_bytes(_First1, sizeof(_Ty));
}
} while (_First1 <= _Stop1);
}
// Neither vector branch found an occurrence.
return _Last1;
} else
#endif // !defined(_M_ARM64EC)
{
// Scalar fallback: used when the SSE4.2 path above is not taken (or is compiled out).
// Here _Stop1 is an exclusive bound: one element past the last position at which the needle still fits.
const size_t _Max_pos = _Size_bytes_1 - _Size_bytes_2 + sizeof(_Ty);
auto _Ptr1 = static_cast<const _Ty*>(_First1);
const auto _Ptr2 = static_cast<const _Ty*>(_First2);
const void* _Stop1 = _Ptr1;
_Advance_bytes(_Stop1, _Max_pos);
for (; _Ptr1 != _Stop1; ++_Ptr1) {
// Quick rejection on the first element before comparing the rest of the needle.
if (*_Ptr1 != *_Ptr2) {
continue;
}
bool _Equal = true;
for (size_t _Idx = 1; _Idx != _Count2; ++_Idx) {
if (_Ptr1[_Idx] != _Ptr2[_Idx]) {
_Equal = false;
break;
}
}
if (_Equal) {
return _Ptr1;
}
}
return _Last1;
}
}
} // unnamed namespace
extern "C" {
@ -3365,6 +3513,16 @@ const void* __stdcall __std_find_first_of_trivial_8(
return __std_find_first_of::_Impl_4_8<__std_find_first_of::_Traits_8>(_First1, _Last1, _First2, _Last2);
}
// Exported entry point for searching 1-byte elements: forwards to the shared implementation
// instantiated for uint8_t.
const void* __stdcall __std_search_1(
    const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
    const void* const _Found = __std_search_impl<_Find_traits_1, uint8_t>(_First1, _Last1, _First2, _Count2);
    return _Found;
}
// Exported entry point for searching 2-byte elements: forwards to the shared implementation
// instantiated for uint16_t.
const void* __stdcall __std_search_2(
    const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
    const void* const _Found = __std_search_impl<_Find_traits_2, uint16_t>(_First1, _Last1, _First2, _Count2);
    return _Found;
}
__declspec(noalias) size_t
__stdcall __std_mismatch_1(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_1, uint8_t>(_First1, _First2, _Count);

Просмотреть файл

@ -320,8 +320,10 @@ void test_search(mt19937_64& gen) {
uniform_int_distribution<TD> dis('0', '9');
vector<T> input_haystack;
vector<T> input_needle;
vector<T> temp;
input_haystack.reserve(haystackDataCount);
input_needle.reserve(needleDataCount);
temp.reserve(needleDataCount);
for (;;) {
input_needle.clear();
@ -330,6 +332,17 @@ void test_search(mt19937_64& gen) {
for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
test_case_search(input_haystack, input_needle);
// For large needles the chance of a match is low, so test a guaranteed match
if (input_haystack.size() > input_needle.size() * 2) {
uniform_int_distribution<size_t> pos_dis(0, input_haystack.size() - input_needle.size());
const size_t pos = pos_dis(gen);
const auto overwritten_first = input_haystack.begin() + static_cast<ptrdiff_t>(pos);
temp.assign(overwritten_first, overwritten_first + static_cast<ptrdiff_t>(input_needle.size()));
copy(input_needle.begin(), input_needle.end(), overwritten_first);
test_case_search(input_haystack, input_needle);
copy(temp.begin(), temp.end(), overwritten_first);
}
}
if (input_haystack.size() == haystackDataCount) {