Vectorize `find_end`, make sure ASan passes (#5042)

Alex Guteniev 2024-10-30 07:34:50 -07:00 committed by GitHub
Parent 493c14608d
Commit cb1e359f93
No key matching this signature was found
GPG key ID: B5690EEEBB952194
2 changed files with 322 additions and 0 deletions
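As a quick illustration of what this change vectorizes, here is a minimal sketch (names and values are illustrative, not taken from the diff) of a find_end call over contiguous 1-byte elements with the default equality predicate, which the separately compiled helper can now handle with SSE4.2 when available:

#include <algorithm>
#include <string_view>

int main() {
    constexpr std::string_view haystack = "abracadabra, abracadabra";
    constexpr std::string_view needle   = "abra";
    // find_end returns an iterator to the start of the last occurrence,
    // or haystack.end() if the needle does not occur at all.
    const auto it = std::find_end(haystack.begin(), haystack.end(), needle.begin(), needle.end());
    return it == haystack.begin() + 20 ? 0 : 1; // the last "abra" starts at index 20
}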

View File

@@ -59,6 +59,11 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;
const void* __stdcall __std_find_end_1(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_find_end_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
@@ -189,6 +194,19 @@ _Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val
}
}
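// Dispatches to the separately compiled vectorized find_end implementation matching the element size.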
template <class _Ty1, class _Ty2>
_Ty1* _Find_end_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_1(_First1, _Last1, _First2, _Count2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_2(_First1, _Last1, _First2, _Count2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}
template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
@@ -3194,6 +3212,26 @@ _NODISCARD _CONSTEXPR20 _FwdIt1 find_end(
if constexpr (_Is_ranges_random_iter_v<_FwdIt1> && _Is_ranges_random_iter_v<_FwdIt2>) {
const _Iter_diff_t<_FwdIt2> _Count2 = _ULast2 - _UFirst2;
if (_Count2 > 0 && _Count2 <= _ULast1 - _UFirst1) {
#if _USE_STD_VECTOR_ALGORITHMS
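// Take the vectorized path only when the elements can be compared as raw bytes and the predicate is plain equality.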
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_UFirst1);
const auto _Ptr_res1 = _STD _Find_end_vectorized(
_Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count2));
if constexpr (is_pointer_v<decltype(_UFirst1)>) {
_UFirst1 = _Ptr_res1;
} else {
_UFirst1 += _Ptr_res1 - _Ptr1;
}
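// On no match the vectorized helper returns the end pointer, so _UFirst1 already holds the required "not found" result.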
_STD _Seek_wrapped(_First1, _UFirst1);
return _First1;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (auto _UCandidate = _ULast1 - static_cast<_Iter_diff_t<_FwdIt1>>(_Count2);; --_UCandidate) {
if (_STD _Equal_rev_pred_unchecked(_UCandidate, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) {
_STD _Seek_wrapped(_First1, _UCandidate);
@@ -3297,6 +3335,34 @@ namespace ranges {
if (_Count2 > 0 && _Count2 <= _Count1) {
const auto _Count2_as1 = static_cast<iter_difference_t<_It1>>(_Count2);
#if _USE_STD_VECTOR_ALGORITHMS
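// Same vectorized path as the non-ranges overload; the projections must be identity so raw memory can be compared.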
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
const auto _Ptr1 = _STD to_address(_First1);
const auto _Ptr2 = _STD to_address(_First2);
const auto _Ptr_last1 = _Ptr1 + _Count1;
const auto _Ptr_res1 =
_STD _Find_end_vectorized(_Ptr1, _Ptr_last1, _Ptr2, static_cast<size_t>(_Count2));
if constexpr (is_pointer_v<_It1>) {
if (_Ptr_res1 != _Ptr_last1) {
return {_Ptr_res1, _Ptr_res1 + _Count2};
} else {
return {_Ptr_res1, _Ptr_res1};
}
} else {
_First1 += _Ptr_res1 - _Ptr1;
if (_Ptr_res1 != _Ptr_last1) {
return {_First1, _First1 + _Count2_as1};
} else {
return {_First1, _First1};
}
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (auto _Candidate = _First1 + (_Count1 - _Count2_as1);; --_Candidate) {
auto _Match_and_mid1 =

View File

@@ -3633,6 +3633,252 @@ namespace {
return _Last1;
}
}
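// Vectorized find_end for 1- and 2-byte elements: the haystack is scanned from the back,
// so the first match found is the last occurrence.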
template <class _Traits, class _Ty>
const void* __stdcall __std_find_end_impl(
const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
if (_Count2 == 0) {
return _Last1;
}
if (_Count2 == 1) {
return __std_find_last_trivial_impl<_Traits>(_First1, _Last1, *static_cast<const _Ty*>(_First2));
}
const size_t _Size_bytes_1 = _Byte_length(_First1, _Last1);
const size_t _Size_bytes_2 = _Count2 * sizeof(_Ty);
if (_Size_bytes_1 < _Size_bytes_2) {
return _Last1;
}
#ifndef _M_ARM64EC
if (_Use_sse42() && _Size_bytes_1 >= 16) {
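// _SIDD_CMP_EQUAL_ORDERED makes pcmpestrm report, for each position in the 16-byte haystack block,
// whether the needle matches starting there (substring search).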
constexpr int _Op = (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ORDERED;
constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8;
static constexpr int8_t _Low_part_mask[] = {//
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, //
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
if (_Size_bytes_2 <= 16) {
const int _Size_el_2 = static_cast<int>(_Count2);
constexpr unsigned int _Whole_mask = (1 << _Part_size_el) - 1;
const unsigned int _Needle_fit_mask = (1 << (_Part_size_el - _Size_el_2 + 1)) - 1;
const unsigned int _Needle_unfit_mask = _Whole_mask ^ _Needle_fit_mask;
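// A "fit" position leaves room for the whole needle inside the current 16-byte block, so the ordered
// compare is conclusive there; "unfit" positions need an extra load to verify the part that crosses
// the block boundary.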
const void* _Stop1 = _First1;
_Advance_bytes(_Stop1, _Size_bytes_1 & 0xF);
alignas(16) uint8_t _Tmp2[16];
memcpy(_Tmp2, _First2, _Size_bytes_2);
const __m128i _Data2 = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));
const void* _Mid1 = _Last1;
_Rewind_bytes(_Mid1, 16);
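// _Check_fit: a match at a "fit" position is fully confirmed already; record the rightmost one by advancing _Mid1 to it.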
const auto _Check_fit = [&_Mid1, _Needle_fit_mask](const unsigned int _Match) noexcept {
const unsigned int _Fit_match = _Match & _Needle_fit_mask;
if (_Fit_match != 0) {
unsigned long _Match_last_pos;
// CodeQL [SM02313] Result is always initialized: we just tested that _Fit_match is non-zero.
_BitScanReverse(&_Match_last_pos, _Fit_match);
_Advance_bytes(_Mid1, _Match_last_pos * sizeof(_Ty));
return true;
}
return false;
};
#pragma warning(push)
#pragma warning(disable : 4324) // structure was padded due to alignment specifier
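// _Check_unfit: walk "unfit" matches from the rightmost down; reload at each candidate and verify
// the whole needle with a masked compare.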
const auto _Check_unfit = [=, &_Mid1](const unsigned int _Match) noexcept {
long _Unfit_match = _Match & _Needle_unfit_mask;
while (_Unfit_match != 0) {
const void* _Tmp1 = _Mid1;
unsigned long _Match_last_pos;
// CodeQL [SM02313] Result is always initialized: we just tested that _Unfit_match is non-zero.
_BitScanReverse(&_Match_last_pos, _Unfit_match);
_Advance_bytes(_Tmp1, _Match_last_pos * sizeof(_Ty));
const __m128i _Match_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Tmp1));
const __m128i _Cmp_result = _mm_xor_si128(_Data2, _Match_data);
const __m128i _Data_mask =
_mm_loadu_si128(reinterpret_cast<const __m128i*>(_Low_part_mask + 16 - _Size_bytes_2));
if (_mm_testz_si128(_Cmp_result, _Data_mask)) {
_Mid1 = _Tmp1;
return true;
}
_bittestandreset(&_Unfit_match, _Match_last_pos);
}
return false;
};
#pragma warning(pop)
// TRANSITION, DevCom-10689455, the code below could test with _mm_cmpestrc,
// if it has been fused with _mm_cmpestrm.
// The very last part; for any match here, the needle must fit, otherwise it is a false match
const __m128i _Data1_last = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const auto _Match_last = _mm_cmpestrm(_Data2, _Size_el_2, _Data1_last, _Part_size_el, _Op);
const unsigned int _Match_last_val = _mm_cvtsi128_si32(_Match_last);
if (_Check_fit(_Match_last_val)) {
return _Mid1;
}
// The middle part, fit and unfit needle
while (_Mid1 != _Stop1) {
_Rewind_bytes(_Mid1, 16);
const __m128i _Data1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const auto _Match = _mm_cmpestrm(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op);
const unsigned int _Match_val = _mm_cvtsi128_si32(_Match);
if (_Match_val != 0 && (_Check_unfit(_Match_val) || _Check_fit(_Match_val))) {
return _Mid1;
}
}
// The first part, fit and unfit needle, mask out already processed positions
if (const size_t _Tail_bytes_1 = _Size_bytes_1 & 0xF; _Tail_bytes_1 != 0) {
_Mid1 = _First1;
const __m128i _Data1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const auto _Match = _mm_cmpestrm(_Data2, _Size_el_2, _Data1, _Part_size_el, _Op);
const unsigned int _Match_val = _mm_cvtsi128_si32(_Match) & ((1 << _Tail_bytes_1) - 1);
if (_Match_val != 0 && (_Check_unfit(_Match_val) || _Check_fit(_Match_val))) {
return _Mid1;
}
}
return _Last1;
} else {
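// Needle longer than 16 bytes: match its first 16 bytes with the ordered compare, then verify the remainder with memcmp.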
const __m128i _Data2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First2));
const void* _Tail2 = _First2;
_Advance_bytes(_Tail2, 16);
const void* _Mid1 = _Last1;
_Rewind_bytes(_Mid1, _Size_bytes_2);
const size_t _Size_diff_bytes = _Size_bytes_1 - _Size_bytes_2;
const void* _Stop1 = _First1;
_Advance_bytes(_Stop1, _Size_diff_bytes & 0xF);
#pragma warning(push)
#pragma warning(disable : 4324) // structure was padded due to alignment specifier
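// _Check: walk matches from the rightmost down; confirm the first 16 needle bytes (reloading unless
// the match starts at the block base), then memcmp the tail.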
const auto _Check = [=, &_Mid1](long _Match) noexcept {
while (_Match != 0) {
const void* _Tmp1 = _Mid1;
unsigned long _Match_last_pos;
// CodeQL [SM02313] Result is always initialized: we just tested that _Match is non-zero.
_BitScanReverse(&_Match_last_pos, _Match);
bool _Match_1st_16 = true;
if (_Match_last_pos != 0) {
_Advance_bytes(_Tmp1, _Match_last_pos * sizeof(_Ty));
const __m128i _Match_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Tmp1));
const __m128i _Cmp_result = _mm_xor_si128(_Data2, _Match_data);
if (!_mm_testz_si128(_Cmp_result, _Cmp_result)) {
_Match_1st_16 = false;
}
}
if (_Match_1st_16) {
const void* _Tail1 = _Tmp1;
_Advance_bytes(_Tail1, 16);
if (memcmp(_Tail1, _Tail2, _Size_bytes_2 - 16) == 0) {
_Mid1 = _Tmp1;
return true;
}
}
_bittestandreset(&_Match, _Match_last_pos);
}
return false;
};
#pragma warning(pop)
// The very last part: just compare, as a true match must start at the first position
const __m128i _Data1_last = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const __m128i _Match_last = _mm_xor_si128(_Data2, _Data1_last);
if (_mm_testz_si128(_Match_last, _Match_last)) {
// Matched 16 bytes, check the rest
const void* _Tail1 = _Mid1;
_Advance_bytes(_Tail1, 16);
if (memcmp(_Tail1, _Tail2, _Size_bytes_2 - 16) == 0) {
return _Mid1;
}
}
// TRANSITION, DevCom-10689455, the code below could test with _mm_cmpestrc,
// if it has been fused with _mm_cmpestrm.
// The main part, match all characters
while (_Mid1 != _Stop1) {
_Rewind_bytes(_Mid1, 16);
const __m128i _Data1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const auto _Match = _mm_cmpestrm(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op);
const unsigned int _Match_val = _mm_cvtsi128_si32(_Match);
if (_Match_val != 0 && _Check(_Match_val)) {
return _Mid1;
}
}
// The first part, mask out already processed positions
if (const size_t _Tail_bytes_1 = _Size_diff_bytes & 0xF; _Tail_bytes_1 != 0) {
_Mid1 = _First1;
const __m128i _Data1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_Mid1));
const auto _Match = _mm_cmpestrm(_Data2, _Part_size_el, _Data1, _Part_size_el, _Op);
const unsigned int _Match_val = _mm_cvtsi128_si32(_Match) & ((1 << _Tail_bytes_1) - 1);
if (_Match_val != 0 && _Check(_Match_val)) {
return _Mid1;
}
}
return _Last1;
}
} else
#endif // !defined(_M_ARM64EC)
{
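// Scalar fallback: try candidates from the back and compare element by element.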
auto _Ptr1 = static_cast<const _Ty*>(_Last1) - _Count2;
const auto _Ptr2 = static_cast<const _Ty*>(_First2);
for (;;) {
if (*_Ptr1 == *_Ptr2) {
bool _Equal = true;
for (size_t _Idx = 1; _Idx != _Count2; ++_Idx) {
if (_Ptr1[_Idx] != _Ptr2[_Idx]) {
_Equal = false;
break;
}
}
if (_Equal) {
return _Ptr1;
}
}
if (_Ptr1 == _First1) {
return _Last1;
}
--_Ptr1;
}
}
}
} // unnamed namespace
extern "C" {
@@ -3757,6 +4003,16 @@ const void* __stdcall __std_search_2(
return __std_search_impl<_Find_traits_2, uint16_t>(_First1, _Last1, _First2, _Count2);
}
const void* __stdcall __std_find_end_1(
const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
return __std_find_end_impl<_Find_traits_1, uint8_t>(_First1, _Last1, _First2, _Count2);
}
const void* __stdcall __std_find_end_2(
const void* const _First1, const void* const _Last1, const void* const _First2, const size_t _Count2) noexcept {
return __std_find_end_impl<_Find_traits_2, uint16_t>(_First1, _Last1, _First2, _Count2);
}
__declspec(noalias) size_t __stdcall __std_mismatch_1(
const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return __std_mismatch_impl<_Find_traits_1, uint8_t>(_First1, _First2, _Count);