diff --git a/benchmarks/src/find_first_of.cpp b/benchmarks/src/find_first_of.cpp index 4697ccd3c..05731a189 100644 --- a/benchmarks/src/find_first_of.cpp +++ b/benchmarks/src/find_first_of.cpp @@ -33,18 +33,14 @@ void bm(benchmark::State& state) { } } -#define ARGS \ - Args({2, 3}) \ - ->Args({7, 4}) \ - ->Args({9, 3}) \ - ->Args({22, 5}) \ - ->Args({58, 2}) \ - ->Args({102, 4}) \ - ->Args({325, 1}) \ - ->Args({1011, 11}) \ - ->Args({3056, 7}); +void common_args(auto bm) { + bm->Args({2, 3})->Args({7, 4})->Args({9, 3})->Args({22, 5})->Args({58, 2}); + bm->Args({102, 4})->Args({325, 1})->Args({1011, 11})->Args({3056, 7}); +} -BENCHMARK(bm)->ARGS; -BENCHMARK(bm)->ARGS; +BENCHMARK(bm)->Apply(common_args); +BENCHMARK(bm)->Apply(common_args); +BENCHMARK(bm)->Apply(common_args); +BENCHMARK(bm)->Apply(common_args); BENCHMARK_MAIN(); diff --git a/stl/inc/algorithm b/stl/inc/algorithm index 6ba8bb4df..8aab11135 100644 --- a/stl/inc/algorithm +++ b/stl/inc/algorithm @@ -62,6 +62,10 @@ const void* __stdcall __std_find_first_of_trivial_1( const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept; const void* __stdcall __std_find_first_of_trivial_2( const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept; +const void* __stdcall __std_find_first_of_trivial_4( + const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept; +const void* __stdcall __std_find_first_of_trivial_8( + const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept; __declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept; __declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept; @@ -202,6 +206,12 @@ _Ty1* _Find_first_of_vectorized( } else if constexpr (sizeof(_Ty1) == 2) { return const_cast<_Ty1*>( static_cast(::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2))); + } else if constexpr (sizeof(_Ty1) == 4) { + return const_cast<_Ty1*>( + static_cast(::__std_find_first_of_trivial_4(_First1, _Last1, _First2, _Last2))); + } else if constexpr (sizeof(_Ty1) == 8) { + return const_cast<_Ty1*>( + static_cast(::__std_find_first_of_trivial_8(_First1, _Last1, _First2, _Last2))); } else { static_assert(false, "Unexpected size"); } @@ -230,9 +240,7 @@ _INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16; // Can we activate the vector algorithms for find_first_of? template -constexpr bool _Vector_alg_in_find_first_of_is_safe = - _Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison - && sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size +constexpr bool _Vector_alg_in_find_first_of_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr>; // Can we activate the vector algorithms for replace? template diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index d61f74180..6383f415e 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -42,9 +42,10 @@ namespace { }; __m256i _Avx2_tail_mask_32(const size_t _Count_in_dwords) noexcept { - // _Count_in_dwords must be within [1, 7]. - static constexpr unsigned int _Tail_masks[14] = {~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0}; - return _mm256_loadu_si256(reinterpret_cast(_Tail_masks + (7 - _Count_in_dwords))); + // _Count_in_dwords must be within [0, 8]. + static constexpr unsigned int _Tail_masks[16] = { + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0}; + return _mm256_loadu_si256(reinterpret_cast(_Tail_masks + (8 - _Count_in_dwords))); } } // namespace #endif // !defined(_M_ARM64EC) @@ -2037,83 +2038,143 @@ namespace { return _Result; } - template - const void* __stdcall __std_find_first_of_trivial_impl( - const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) noexcept { + namespace __std_find_first_of { + + template + const void* __stdcall _Fallback(const void* _First1, const void* const _Last1, const void* const _First2, + const void* const _Last2) noexcept { + auto _Ptr_haystack = static_cast(_First1); + const auto _Ptr_haystack_end = static_cast(_Last1); + const auto _Ptr_needle = static_cast(_First2); + const auto _Ptr_needle_end = static_cast(_Last2); + + for (; _Ptr_haystack != _Ptr_haystack_end; ++_Ptr_haystack) { + for (auto _Ptr = _Ptr_needle; _Ptr != _Ptr_needle_end; ++_Ptr) { + if (*_Ptr_haystack == *_Ptr) { + return _Ptr_haystack; + } + } + } + + return _Ptr_haystack; + } + + template + const void* __stdcall _Impl_pcmpestri(const void* _First1, const void* const _Last1, const void* const _First2, + const void* const _Last2) noexcept { #ifndef _M_ARM64EC - if (_Use_sse42()) { - constexpr int _Op = - (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT; - constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8; - const size_t _Needle_length = _Byte_length(_First2, _Last2); + if (_Use_sse42()) { + constexpr int _Op = (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ANY + | _SIDD_LEAST_SIGNIFICANT; + constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8; + const size_t _Needle_length = _Byte_length(_First2, _Last2); - if (_Needle_length <= 16) { - // Special handling of small needle - // The generic branch could also handle it but with slightly worse performance + if (_Needle_length <= 16) { + // Special handling of small needle + // The generic branch could also handle it but with slightly worse performance - const int _Needle_length_el = static_cast(_Needle_length / sizeof(_Ty)); + const int _Needle_length_el = static_cast(_Needle_length / sizeof(_Ty)); - alignas(16) uint8_t _Tmp1[16]; - memcpy(_Tmp1, _First2, _Needle_length); - const __m128i _Needle = _mm_load_si128(reinterpret_cast(_Tmp1)); + alignas(16) uint8_t _Tmp2[16]; + memcpy(_Tmp2, _First2, _Needle_length); + const __m128i _Data2 = _mm_load_si128(reinterpret_cast(_Tmp2)); - const size_t _Haystack_length = _Byte_length(_First1, _Last1); - const void* _Stop_at = _First1; - _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); + const size_t _Haystack_length = _Byte_length(_First1, _Last1); + const void* _Stop_at = _First1; + _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); - while (_First1 != _Stop_at) { - const __m128i _Haystack_part = _mm_loadu_si128(static_cast(_First1)); - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op); + while (_First1 != _Stop_at) { + const __m128i _Data1 = _mm_loadu_si128(static_cast(_First1)); + if (_mm_cmpestrc(_Data2, _Needle_length_el, _Data1, _Part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Needle_length_el, _Data1, _Part_size_el, _Op); + _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + return _First1; + } + + _Advance_bytes(_First1, 16); + } + + const size_t _Last_part_size = _Haystack_length & 0xF; + const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); + + alignas(16) uint8_t _Tmp1[16]; + memcpy(_Tmp1, _First1, _Last_part_size); + const __m128i _Data1 = _mm_load_si128(reinterpret_cast(_Tmp1)); + + if (_mm_cmpestrc(_Data2, _Needle_length_el, _Data1, _Last_part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Needle_length_el, _Data1, _Last_part_size_el, _Op); _Advance_bytes(_First1, _Pos * sizeof(_Ty)); return _First1; } - _Advance_bytes(_First1, 16); - } - - const size_t _Last_part_size = _Haystack_length & 0xF; - const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); - - alignas(16) uint8_t _Tmp2[16]; - memcpy(_Tmp2, _First1, _Last_part_size); - const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast(_Tmp2)); - - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op); - _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + _Advance_bytes(_First1, _Last_part_size); return _First1; - } + } else { + const void* _Last_needle = _First2; + _Advance_bytes(_Last_needle, _Needle_length & ~size_t{0xF}); - _Advance_bytes(_First1, _Last_part_size); - return _First1; - } else { - const void* _Last_needle = _First2; - _Advance_bytes(_Last_needle, _Needle_length & ~size_t{0xF}); + const int _Last_needle_length = static_cast(_Needle_length & 0xF); - const int _Last_needle_length = static_cast(_Needle_length & 0xF); + alignas(16) uint8_t _Tmp2[16]; + memcpy(_Tmp2, _Last_needle, _Last_needle_length); + const __m128i _Last_needle_val = _mm_load_si128(reinterpret_cast(_Tmp2)); + const int _Last_needle_length_el = _Last_needle_length / sizeof(_Ty); - alignas(16) uint8_t _Tmp1[16]; - memcpy(_Tmp1, _Last_needle, _Last_needle_length); - const __m128i _Last_needle_val = _mm_load_si128(reinterpret_cast(_Tmp1)); - const int _Last_needle_length_el = _Last_needle_length / sizeof(_Ty); + constexpr int _Not_found = 16; // arbitrary value greater than any found value - constexpr int _Not_found = 16; // arbitrary value greater than any found value + int _Found_pos = _Not_found; - int _Found_pos = _Not_found; + const size_t _Haystack_length = _Byte_length(_First1, _Last1); + const void* _Stop_at = _First1; + _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); - const size_t _Haystack_length = _Byte_length(_First1, _Last1); - const void* _Stop_at = _First1; - _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); + while (_First1 != _Stop_at) { + const __m128i _Data1 = _mm_loadu_si128(static_cast(_First1)); - while (_First1 != _Stop_at) { - const __m128i _Haystack_part = _mm_loadu_si128(static_cast(_First1)); + for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; + _Advance_bytes(_Cur_needle, 16)) { + const __m128i _Data2 = _mm_loadu_si128(static_cast(_Cur_needle)); + if (_mm_cmpestrc(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) { + const __m128i _Data2 = _Last_needle_val; + if (_mm_cmpestrc(_Data2, _Needle_length_el, _Data1, _Part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Needle_length_el, _Data1, _Part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + if (_Found_pos != _Not_found) { + _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); + return _First1; + } + + _Advance_bytes(_First1, 16); + } + + const size_t _Last_part_size = _Haystack_length & 0xF; + const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); + + alignas(16) uint8_t _Tmp1[16]; + memcpy(_Tmp1, _First1, _Last_part_size); + const __m128i _Data1 = _mm_load_si128(reinterpret_cast(_Tmp1)); + + _Found_pos = _Last_part_size_el; for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; _Advance_bytes(_Cur_needle, 16)) { - const __m128i _Needle = _mm_loadu_si128(static_cast(_Cur_needle)); - if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op); + const __m128i _Data2 = _mm_loadu_si128(static_cast(_Cur_needle)); + + if (_mm_cmpestrc(_Data2, _Part_size_el, _Data1, _Last_part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Part_size_el, _Data1, _Last_part_size_el, _Op); if (_Pos < _Found_pos) { _Found_pos = _Pos; } @@ -2121,77 +2182,231 @@ namespace { } if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) { - const __m128i _Needle = _Last_needle_val; - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) { - const int _Pos = - _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op); + const __m128i _Data2 = _Last_needle_val; + if (_mm_cmpestrc(_Data2, _Needle_length_el, _Data1, _Last_part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Data2, _Needle_length_el, _Data1, _Last_part_size_el, _Op); if (_Pos < _Found_pos) { _Found_pos = _Pos; } } } - if (_Found_pos != _Not_found) { - _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); - return _First1; - } - - _Advance_bytes(_First1, 16); + _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); + return _First1; } - - const size_t _Last_part_size = _Haystack_length & 0xF; - const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); - - alignas(16) uint8_t _Tmp2[16]; - memcpy(_Tmp2, _First1, _Last_part_size); - const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast(_Tmp2)); - - _Found_pos = _Last_part_size_el; - - for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; _Advance_bytes(_Cur_needle, 16)) { - const __m128i _Needle = _mm_loadu_si128(static_cast(_Cur_needle)); - - if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op); - if (_Pos < _Found_pos) { - _Found_pos = _Pos; - } - } - } - - if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) { - const __m128i _Needle = _Last_needle_val; - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) { - const int _Pos = - _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op); - if (_Pos < _Found_pos) { - _Found_pos = _Pos; - } - } - } - - _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); - return _First1; } +#endif // !_M_ARM64EC + return _Fallback<_Ty>(_First1, _Last1, _First2, _Last2); } + + struct _Traits_4 : _Find_traits_4 { + using _Ty = uint32_t; +#ifndef _M_ARM64EC + template + static __m256i _Spread_avx(__m256i _Val, const size_t _Needle_length_el) noexcept { + if constexpr (_Amount == 1) { + return _mm256_broadcastd_epi32(_mm256_castsi256_si128(_Val)); + } else if constexpr (_Amount == 2) { + return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val)); + } else if constexpr (_Amount == 4) { + if (_Needle_length_el < 4) { + _Val = _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(0, 2, 1, 0)); + } + + return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 1, 0)); + } else if constexpr (_Amount == 8) { + if (_Needle_length_el < 8) { + const __m256i _Mask = _Avx2_tail_mask_32(_Needle_length_el); + // zero unused elements in sequential permutation mask, so will be filled by 1st + const __m256i _Perm = _mm256_and_si256(_mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0), _Mask); + _Val = _mm256_permutevar8x32_epi32(_Val, _Perm); + } + + return _Val; + } else { + static_assert(false, "Unexpected amount"); + } + } + + template + static __m256i _Shuffle_avx(const __m256i _Val) noexcept { + if constexpr (_Amount == 1) { + return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(2, 3, 0, 1)); + } else if constexpr (_Amount == 2) { + return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(1, 0, 3, 2)); + } else if constexpr (_Amount == 4) { + return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 3, 2)); + } else { + static_assert(false, "Unexpected amount"); + } + } +#endif // !_M_ARM64EC + }; + + struct _Traits_8 : _Find_traits_8 { + using _Ty = uint64_t; +#ifndef _M_ARM64EC + template + static __m256i _Spread_avx(const __m256i _Val, const size_t _Needle_length_el) noexcept { + if constexpr (_Amount == 1) { + return _mm256_broadcastq_epi64(_mm256_castsi256_si128(_Val)); + } else if constexpr (_Amount == 2) { + return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 1, 0)); + } else if constexpr (_Amount == 4) { + if (_Needle_length_el < 4) { + return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(0, 2, 1, 0)); + } + + return _Val; + } else { + static_assert(false, "Unexpected amount"); + } + } + + template + static __m256i _Shuffle_avx(const __m256i _Val) noexcept { + if constexpr (_Amount == 1) { + return _mm256_shuffle_epi32(_Val, _MM_SHUFFLE(1, 0, 3, 2)); + } else if constexpr (_Amount == 2) { + return _mm256_permute4x64_epi64(_Val, _MM_SHUFFLE(1, 0, 3, 2)); + } else { + static_assert(false, "Unexpected amount"); + } + } +#endif // !_M_ARM64EC + }; + +#ifndef _M_ARM64EC + template + __m256i _Shuffle_step(const __m256i _Data1, const __m256i _Data2s0) noexcept { + __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2s0); + if constexpr (_Needle_length_el_magnitude >= 2) { + const __m256i _Data2s1 = _Traits::_Shuffle_avx<1>(_Data2s0); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s1)); + if constexpr (_Needle_length_el_magnitude >= 4) { + const __m256i _Data2s2 = _Traits::_Shuffle_avx<2>(_Data2s0); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s2)); + const __m256i _Data2s3 = _Traits::_Shuffle_avx<1>(_Data2s2); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s3)); + if constexpr (_Needle_length_el_magnitude >= 8) { + const __m256i _Data2s4 = _Traits::_Shuffle_avx<4>(_Data2s0); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s4)); + const __m256i _Data2s5 = _Traits::_Shuffle_avx<1>(_Data2s4); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s5)); + const __m256i _Data2s6 = _Traits::_Shuffle_avx<2>(_Data2s4); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s6)); + const __m256i _Data2s7 = _Traits::_Shuffle_avx<1>(_Data2s6); + _Eq = _mm256_or_si256(_Eq, _Traits::_Cmp_avx(_Data1, _Data2s7)); + } + } + } + return _Eq; + } + + template + const void* _Shuffle_impl(const void* _First1, const void* const _Last1, const void* const _First2, + const size_t _Needle_length_el) noexcept { + using _Ty = _Traits::_Ty; + const __m256i _Data2 = _mm256_maskload_epi32( + reinterpret_cast(_First2), _Avx2_tail_mask_32(_Needle_length_el * (sizeof(_Ty) / 4))); + const __m256i _Data2s0 = _Traits::_Spread_avx<_Needle_length_el_magnitude>(_Data2, _Needle_length_el); + + const size_t _Haystack_length = _Byte_length(_First1, _Last1); + + const void* _Stop1 = _First1; + _Advance_bytes(_Stop1, _Haystack_length & ~size_t{0x1F}); + + for (; _First1 != _Stop1; _Advance_bytes(_First1, 32)) { + const __m256i _Data1 = _mm256_loadu_si256(static_cast(_First1)); + const __m256i _Eq = _Shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0); + const int _Bingo = _mm256_movemask_epi8(_Eq); + + if (_Bingo != 0) { + const unsigned long _Offset = _tzcnt_u32(_Bingo); + _Advance_bytes(_First1, _Offset); + return _First1; + } + } + + if (const size_t _Haystack_tail_length = _Haystack_length & 0x1C; _Haystack_tail_length != 0) { + const __m256i _Tail_mask = _Avx2_tail_mask_32(_Haystack_tail_length >> 2); + const __m256i _Data1 = _mm256_maskload_epi32(static_cast(_First1), _Tail_mask); + const __m256i _Eq = _Shuffle_step<_Traits, _Needle_length_el_magnitude>(_Data1, _Data2s0); + const int _Bingo = _mm256_movemask_epi8(_mm256_and_si256(_Eq, _Tail_mask)); + + if (_Bingo != 0) { + const unsigned long _Offset = _tzcnt_u32(_Bingo); + _Advance_bytes(_First1, _Offset); + return _First1; + } + + _Advance_bytes(_First1, _Haystack_tail_length); + } + + return _First1; + } + #endif // !_M_ARM64EC - auto _Ptr_haystack = static_cast(_First1); - const auto _Ptr_haystack_end = static_cast(_Last1); - const auto _Ptr_needle = static_cast(_First2); - const auto _Ptr_needle_end = static_cast(_Last2); + template + const void* __stdcall _Impl_4_8(const void* const _First1, const void* const _Last1, const void* const _First2, + const void* const _Last2) noexcept { + using _Ty = _Traits::_Ty; +#ifndef _M_ARM64EC + if (_Use_avx2()) { + _Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414 - for (; _Ptr_haystack != _Ptr_haystack_end; ++_Ptr_haystack) { - for (auto _Ptr = _Ptr_needle; _Ptr != _Ptr_needle_end; ++_Ptr) { - if (*_Ptr_haystack == *_Ptr) { - return _Ptr_haystack; + const size_t _Needle_length = _Byte_length(_First2, _Last2); + const int _Needle_length_el = static_cast(_Needle_length / sizeof(_Ty)); + + // Special handling of small needle + // The generic approach could also handle it but with worse performance + if (_Needle_length_el == 0) { + return _Last1; + } else if (_Needle_length_el == 1) { + _STL_UNREACHABLE; // This is expected to be forwarded to 'find' on an upper level + } else if (_Needle_length_el == 2) { + return _Shuffle_impl<_Traits, 2>(_First1, _Last1, _First2, _Needle_length_el); + } else if (_Needle_length_el <= 4) { + return _Shuffle_impl<_Traits, 4>(_First1, _Last1, _First2, _Needle_length_el); + } else if (_Needle_length_el <= 8) { + if constexpr (sizeof(_Ty) == 4) { + return _Shuffle_impl<_Traits, 8>(_First1, _Last1, _First2, _Needle_length_el); + } } + + // Generic approach + const size_t _Needle_length_tail = _Needle_length & 0x1C; + const __m256i _Tail_mask = _Avx2_tail_mask_32(_Needle_length_tail >> 2); + + const void* _Stop2 = _First2; + _Advance_bytes(_Stop2, _Needle_length & ~size_t{0x1F}); + + for (auto _Ptr1 = static_cast(_First1); _Ptr1 != _Last1; ++_Ptr1) { + const auto _Data1 = _Traits::_Set_avx(*_Ptr1); + for (auto _Ptr2 = _First2; _Ptr2 != _Stop2; _Advance_bytes(_Ptr2, 32)) { + const __m256i _Data2 = _mm256_loadu_si256(static_cast(_Ptr2)); + const __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2); + if (!_mm256_testz_si256(_Eq, _Eq)) { + return _Ptr1; + } + } + + if (_Needle_length_tail != 0) { + const __m256i _Data2 = _mm256_maskload_epi32(static_cast(_Stop2), _Tail_mask); + const __m256i _Eq = _Traits::_Cmp_avx(_Data1, _Data2); + if (!_mm256_testz_si256(_Eq, _Tail_mask)) { + return _Ptr1; + } + } + } + + return _Last1; } +#endif // !_M_ARM64EC + return _Fallback<_Ty>(_First1, _Last1, _First2, _Last2); } - - return _Ptr_haystack; - } - + } // namespace __std_find_first_of template __declspec(noalias) size_t __stdcall __std_mismatch_impl( @@ -2354,12 +2569,22 @@ __declspec(noalias) size_t const void* __stdcall __std_find_first_of_trivial_1( const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept { - return __std_find_first_of_trivial_impl(_First1, _Last1, _First2, _Last2); + return __std_find_first_of::_Impl_pcmpestri(_First1, _Last1, _First2, _Last2); } const void* __stdcall __std_find_first_of_trivial_2( const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept { - return __std_find_first_of_trivial_impl(_First1, _Last1, _First2, _Last2); + return __std_find_first_of::_Impl_pcmpestri(_First1, _Last1, _First2, _Last2); +} + +const void* __stdcall __std_find_first_of_trivial_4( + const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept { + return __std_find_first_of::_Impl_4_8<__std_find_first_of::_Traits_4>(_First1, _Last1, _First2, _Last2); +} + +const void* __stdcall __std_find_first_of_trivial_8( + const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept { + return __std_find_first_of::_Impl_4_8<__std_find_first_of::_Traits_8>(_First1, _Last1, _First2, _Last2); } __declspec(noalias) size_t