diff --git a/.gitignore b/.gitignore index 1b8e6af..3973a16 100644 --- a/.gitignore +++ b/.gitignore @@ -331,3 +331,6 @@ ASALocalRun/ # Visual Studio Code directory .vscode/ + +# CMake/Build output +build/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..626814c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.11) + +# Set by build server to speed up build/reduce file/object size +option(FAST_BUILD "Sets options to speed up build/reduce obj/executable size" OFF) + +if (NOT DEFINED WIL_BUILD_VERSION) + set(WIL_BUILD_VERSION "0.0.0") +endif() + +add_subdirectory(packaging) +add_subdirectory(tests) diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index c2a1afb..eee568b 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -32,3 +32,29 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +Catch2 + +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/cmake/common_build_flags.cmake b/cmake/common_build_flags.cmake new file mode 100644 index 0000000..ce84c1a --- /dev/null +++ b/cmake/common_build_flags.cmake @@ -0,0 +1,62 @@ + +# E.g. replace_cxx_flag("/W[0-4]", "/W4") +macro(replace_cxx_flag pattern text) + foreach (flag + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + + string(REGEX REPLACE "${pattern}" "${text}" ${flag} "${${flag}}") + + endforeach() +endmacro() + +macro(append_cxx_flag text) + foreach (flag + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + + string(APPEND ${flag} " ${text}") + + endforeach() +endmacro() + +# Fixup default compiler settings + +# Be as strict as reasonably possible, since we want to support consumers using strict warning levels +replace_cxx_flag("/W[0-4]" "/W4") +append_cxx_flag("/WX") + +# We want to be as conformant as possible, so tell MSVC to not be permissive (note that this has no effect on clang-cl) +append_cxx_flag("/permissive-") + +# wistd::function has padding due to alignment. This is expected +append_cxx_flag("/wd4324") + +if (${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + # Ignore a few Clang warnings. We may want to revisit in the future to see if any of these can/should be removed + append_cxx_flag("-Wno-switch") + append_cxx_flag("-Wno-invalid-noreturn") + append_cxx_flag("-Wno-c++17-compat-mangling") + append_cxx_flag("-Wno-missing-field-initializers") + + # For tests, we want to be able to test self assignment, so disable this warning + append_cxx_flag("-Wno-self-assign-overloaded") + + # clang-cl does not understand the /permissive- flag (or at least it opts to ignore it). We can achieve similar + # results through the following flags. + # TODO: https://github.com/Microsoft/wil/issues/10 - not yet clean enough to have this on by default + # append_cxx_flag("-fno-delayed-template-parsing") + + # NOTE: Windows headers not clean enough for us to realistically attempt to start fixing these errors yet. That + # said, errors that originate from WIL headers may benefit + # append_cxx_flag("-fno-ms-compatibility") +else() + # Flags that are either ignored or unrecognized by clang-cl + # TODO: https://github.com/Microsoft/wil/issues/6 + # append_cxx_flag("/experimental:preprocessor") + + # CRT headers are not yet /experimental:preprocessor clean, so work around the known issues + # append_cxx_flag("/Wv:18") + + append_cxx_flag("/bigobj") +endif() diff --git a/include/wil/com.h b/include/wil/com.h index 6db2fd0..a19ff96 100644 --- a/include/wil/com.h +++ b/include/wil/com.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_COM_INCLUDED #define __WIL_COM_INCLUDED diff --git a/include/wil/common.h b/include/wil/common.h index 26fc7ce..156e4f3 100644 --- a/include/wil/common.h +++ b/include/wil/common.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_COMMON_INCLUDED #define __WIL_COMMON_INCLUDED @@ -165,25 +174,6 @@ Three exception modes are available: //! This macro is for use in other macros to paste two tokens together, such as a constant and the __LINE__ macro. #define WI_PASTE(a, b) __WI_PASTE_imp(a, b) -//! This macro is used to invoke another function-like macro with the specified argument list. This is primarily useful -//! when the invocation does not begin with a '(' character, e.g. when using the `WI_FLATTEN` macro. For example: -//! ~~~ -//! #define CALL_WITH_BAR_SUFFIX(prefix, ...) WI_PASTE(prefix, Bar) WI_FLATTEN((__VA_ARGS__)) -//! -//! // The macro invocation gets expanded to 'FooBar(x)' and does not get processed any further -//! #define FooBar(x) foo_bar(x) -//! CALL_WITH_BAR_SUFFIX(Foo, x); -//! ~~~ -//! Instead, you should do: -//! ~~~ -//! #define CALL_WITH_BAR_SUFFIX(prefix, ...) WI_MACRO_INVOKE(WI_PASTE(prefix, Bar), WI_FLATTEN((__VA_ARGS__))) -//! -//! // The macro invocation gets expanded to 'foo_bar(x)' -//! #define FooBar(x) foo_bar(x) -//! CALL_WITH_BAR_SUFFIX(Foo, x); -//! ~~~ -#define WI_MACRO_INVOKE(fn, argsList) fn argsList - /// @cond #define __WI_ARGS_COUNT1(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, \ A30, A31, A32, A33, A34, A35, A36, A37, A38, A39, A40, A41, A42, A43, A44, A45, A46, A47, A48, A49, A50, A51, A52, A53, A54, A55, A56, A57, A58, A59, \ @@ -200,108 +190,108 @@ Three exception modes are available: /// @cond #define __WI_FOR_imp0( fn) -#define __WI_FOR_imp1( fn, arg, ...) __WI_FOR_impN( 0, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp2( fn, arg, ...) __WI_FOR_impN( 1, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp3( fn, arg, ...) __WI_FOR_impN( 2, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp4( fn, arg, ...) __WI_FOR_impN( 3, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp5( fn, arg, ...) __WI_FOR_impN( 4, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp6( fn, arg, ...) __WI_FOR_impN( 5, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp7( fn, arg, ...) __WI_FOR_impN( 6, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp8( fn, arg, ...) __WI_FOR_impN( 7, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp9( fn, arg, ...) __WI_FOR_impN( 8, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp10(fn, arg, ...) __WI_FOR_impN( 9, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp11(fn, arg, ...) __WI_FOR_impN(10, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp12(fn, arg, ...) __WI_FOR_impN(11, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp13(fn, arg, ...) __WI_FOR_impN(12, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp14(fn, arg, ...) __WI_FOR_impN(13, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp15(fn, arg, ...) __WI_FOR_impN(14, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp16(fn, arg, ...) __WI_FOR_impN(15, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp17(fn, arg, ...) __WI_FOR_impN(16, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp18(fn, arg, ...) __WI_FOR_impN(17, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp19(fn, arg, ...) __WI_FOR_impN(18, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp20(fn, arg, ...) __WI_FOR_impN(19, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp21(fn, arg, ...) __WI_FOR_impN(20, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp22(fn, arg, ...) __WI_FOR_impN(21, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp23(fn, arg, ...) __WI_FOR_impN(22, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp24(fn, arg, ...) __WI_FOR_impN(23, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp25(fn, arg, ...) __WI_FOR_impN(24, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp26(fn, arg, ...) __WI_FOR_impN(25, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp27(fn, arg, ...) __WI_FOR_impN(26, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp28(fn, arg, ...) __WI_FOR_impN(27, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp29(fn, arg, ...) __WI_FOR_impN(28, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp30(fn, arg, ...) __WI_FOR_impN(29, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp31(fn, arg, ...) __WI_FOR_impN(30, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp32(fn, arg, ...) __WI_FOR_impN(31, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp33(fn, arg, ...) __WI_FOR_impN(32, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp34(fn, arg, ...) __WI_FOR_impN(33, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp35(fn, arg, ...) __WI_FOR_impN(34, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp36(fn, arg, ...) __WI_FOR_impN(35, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp37(fn, arg, ...) __WI_FOR_impN(36, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp38(fn, arg, ...) __WI_FOR_impN(37, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp39(fn, arg, ...) __WI_FOR_impN(38, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp40(fn, arg, ...) __WI_FOR_impN(39, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp41(fn, arg, ...) __WI_FOR_impN(40, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp42(fn, arg, ...) __WI_FOR_impN(41, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp43(fn, arg, ...) __WI_FOR_impN(42, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp44(fn, arg, ...) __WI_FOR_impN(43, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp45(fn, arg, ...) __WI_FOR_impN(44, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp46(fn, arg, ...) __WI_FOR_impN(45, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp47(fn, arg, ...) __WI_FOR_impN(46, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp48(fn, arg, ...) __WI_FOR_impN(47, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp49(fn, arg, ...) __WI_FOR_impN(48, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp50(fn, arg, ...) __WI_FOR_impN(49, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp51(fn, arg, ...) __WI_FOR_impN(50, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp52(fn, arg, ...) __WI_FOR_impN(51, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp53(fn, arg, ...) __WI_FOR_impN(52, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp54(fn, arg, ...) __WI_FOR_impN(53, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp55(fn, arg, ...) __WI_FOR_impN(54, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp56(fn, arg, ...) __WI_FOR_impN(55, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp57(fn, arg, ...) __WI_FOR_impN(56, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp58(fn, arg, ...) __WI_FOR_impN(57, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp59(fn, arg, ...) __WI_FOR_impN(58, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp60(fn, arg, ...) __WI_FOR_impN(59, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp61(fn, arg, ...) __WI_FOR_impN(60, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp62(fn, arg, ...) __WI_FOR_impN(61, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp63(fn, arg, ...) __WI_FOR_impN(62, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp64(fn, arg, ...) __WI_FOR_impN(63, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp65(fn, arg, ...) __WI_FOR_impN(64, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp66(fn, arg, ...) __WI_FOR_impN(65, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp67(fn, arg, ...) __WI_FOR_impN(66, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp68(fn, arg, ...) __WI_FOR_impN(67, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp69(fn, arg, ...) __WI_FOR_impN(68, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp70(fn, arg, ...) __WI_FOR_impN(69, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp71(fn, arg, ...) __WI_FOR_impN(70, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp72(fn, arg, ...) __WI_FOR_impN(71, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp73(fn, arg, ...) __WI_FOR_impN(72, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp74(fn, arg, ...) __WI_FOR_impN(73, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp75(fn, arg, ...) __WI_FOR_impN(74, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp76(fn, arg, ...) __WI_FOR_impN(75, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp77(fn, arg, ...) __WI_FOR_impN(76, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp78(fn, arg, ...) __WI_FOR_impN(77, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp79(fn, arg, ...) __WI_FOR_impN(78, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp80(fn, arg, ...) __WI_FOR_impN(79, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp81(fn, arg, ...) __WI_FOR_impN(80, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp82(fn, arg, ...) __WI_FOR_impN(81, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp83(fn, arg, ...) __WI_FOR_impN(82, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp84(fn, arg, ...) __WI_FOR_impN(83, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp85(fn, arg, ...) __WI_FOR_impN(84, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp86(fn, arg, ...) __WI_FOR_impN(85, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp87(fn, arg, ...) __WI_FOR_impN(86, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp88(fn, arg, ...) __WI_FOR_impN(87, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp89(fn, arg, ...) __WI_FOR_impN(88, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp90(fn, arg, ...) __WI_FOR_impN(89, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp91(fn, arg, ...) __WI_FOR_impN(90, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp92(fn, arg, ...) __WI_FOR_impN(91, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp93(fn, arg, ...) __WI_FOR_impN(92, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp94(fn, arg, ...) __WI_FOR_impN(93, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp95(fn, arg, ...) __WI_FOR_impN(94, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp96(fn, arg, ...) __WI_FOR_impN(95, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp97(fn, arg, ...) __WI_FOR_impN(96, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp98(fn, arg, ...) __WI_FOR_impN(97, fn, arg, fn, ##__VA_ARGS__) -#define __WI_FOR_imp99(fn, arg, ...) __WI_FOR_impN(98, fn, arg, fn, ##__VA_ARGS__) +#define __WI_FOR_imp1( fn, arg, ...) __WI_FOR_impN( 0, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp2( fn, arg, ...) __WI_FOR_impN( 1, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp3( fn, arg, ...) __WI_FOR_impN( 2, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp4( fn, arg, ...) __WI_FOR_impN( 3, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp5( fn, arg, ...) __WI_FOR_impN( 4, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp6( fn, arg, ...) __WI_FOR_impN( 5, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp7( fn, arg, ...) __WI_FOR_impN( 6, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp8( fn, arg, ...) __WI_FOR_impN( 7, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp9( fn, arg, ...) __WI_FOR_impN( 8, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp10(fn, arg, ...) __WI_FOR_impN( 9, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp11(fn, arg, ...) __WI_FOR_impN(10, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp12(fn, arg, ...) __WI_FOR_impN(11, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp13(fn, arg, ...) __WI_FOR_impN(12, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp14(fn, arg, ...) __WI_FOR_impN(13, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp15(fn, arg, ...) __WI_FOR_impN(14, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp16(fn, arg, ...) __WI_FOR_impN(15, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp17(fn, arg, ...) __WI_FOR_impN(16, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp18(fn, arg, ...) __WI_FOR_impN(17, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp19(fn, arg, ...) __WI_FOR_impN(18, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp20(fn, arg, ...) __WI_FOR_impN(19, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp21(fn, arg, ...) __WI_FOR_impN(20, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp22(fn, arg, ...) __WI_FOR_impN(21, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp23(fn, arg, ...) __WI_FOR_impN(22, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp24(fn, arg, ...) __WI_FOR_impN(23, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp25(fn, arg, ...) __WI_FOR_impN(24, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp26(fn, arg, ...) __WI_FOR_impN(25, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp27(fn, arg, ...) __WI_FOR_impN(26, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp28(fn, arg, ...) __WI_FOR_impN(27, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp29(fn, arg, ...) __WI_FOR_impN(28, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp30(fn, arg, ...) __WI_FOR_impN(29, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp31(fn, arg, ...) __WI_FOR_impN(30, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp32(fn, arg, ...) __WI_FOR_impN(31, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp33(fn, arg, ...) __WI_FOR_impN(32, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp34(fn, arg, ...) __WI_FOR_impN(33, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp35(fn, arg, ...) __WI_FOR_impN(34, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp36(fn, arg, ...) __WI_FOR_impN(35, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp37(fn, arg, ...) __WI_FOR_impN(36, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp38(fn, arg, ...) __WI_FOR_impN(37, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp39(fn, arg, ...) __WI_FOR_impN(38, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp40(fn, arg, ...) __WI_FOR_impN(39, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp41(fn, arg, ...) __WI_FOR_impN(40, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp42(fn, arg, ...) __WI_FOR_impN(41, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp43(fn, arg, ...) __WI_FOR_impN(42, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp44(fn, arg, ...) __WI_FOR_impN(43, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp45(fn, arg, ...) __WI_FOR_impN(44, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp46(fn, arg, ...) __WI_FOR_impN(45, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp47(fn, arg, ...) __WI_FOR_impN(46, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp48(fn, arg, ...) __WI_FOR_impN(47, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp49(fn, arg, ...) __WI_FOR_impN(48, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp50(fn, arg, ...) __WI_FOR_impN(49, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp51(fn, arg, ...) __WI_FOR_impN(50, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp52(fn, arg, ...) __WI_FOR_impN(51, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp53(fn, arg, ...) __WI_FOR_impN(52, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp54(fn, arg, ...) __WI_FOR_impN(53, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp55(fn, arg, ...) __WI_FOR_impN(54, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp56(fn, arg, ...) __WI_FOR_impN(55, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp57(fn, arg, ...) __WI_FOR_impN(56, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp58(fn, arg, ...) __WI_FOR_impN(57, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp59(fn, arg, ...) __WI_FOR_impN(58, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp60(fn, arg, ...) __WI_FOR_impN(59, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp61(fn, arg, ...) __WI_FOR_impN(60, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp62(fn, arg, ...) __WI_FOR_impN(61, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp63(fn, arg, ...) __WI_FOR_impN(62, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp64(fn, arg, ...) __WI_FOR_impN(63, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp65(fn, arg, ...) __WI_FOR_impN(64, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp66(fn, arg, ...) __WI_FOR_impN(65, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp67(fn, arg, ...) __WI_FOR_impN(66, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp68(fn, arg, ...) __WI_FOR_impN(67, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp69(fn, arg, ...) __WI_FOR_impN(68, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp70(fn, arg, ...) __WI_FOR_impN(69, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp71(fn, arg, ...) __WI_FOR_impN(70, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp72(fn, arg, ...) __WI_FOR_impN(71, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp73(fn, arg, ...) __WI_FOR_impN(72, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp74(fn, arg, ...) __WI_FOR_impN(73, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp75(fn, arg, ...) __WI_FOR_impN(74, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp76(fn, arg, ...) __WI_FOR_impN(75, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp77(fn, arg, ...) __WI_FOR_impN(76, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp78(fn, arg, ...) __WI_FOR_impN(77, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp79(fn, arg, ...) __WI_FOR_impN(78, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp80(fn, arg, ...) __WI_FOR_impN(79, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp81(fn, arg, ...) __WI_FOR_impN(80, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp82(fn, arg, ...) __WI_FOR_impN(81, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp83(fn, arg, ...) __WI_FOR_impN(82, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp84(fn, arg, ...) __WI_FOR_impN(83, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp85(fn, arg, ...) __WI_FOR_impN(84, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp86(fn, arg, ...) __WI_FOR_impN(85, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp87(fn, arg, ...) __WI_FOR_impN(86, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp88(fn, arg, ...) __WI_FOR_impN(87, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp89(fn, arg, ...) __WI_FOR_impN(88, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp90(fn, arg, ...) __WI_FOR_impN(89, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp91(fn, arg, ...) __WI_FOR_impN(90, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp92(fn, arg, ...) __WI_FOR_impN(91, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp93(fn, arg, ...) __WI_FOR_impN(92, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp94(fn, arg, ...) __WI_FOR_impN(93, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp95(fn, arg, ...) __WI_FOR_impN(94, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp96(fn, arg, ...) __WI_FOR_impN(95, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp97(fn, arg, ...) __WI_FOR_impN(96, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp98(fn, arg, ...) __WI_FOR_impN(97, fn, arg, fn, __VA_ARGS__) +#define __WI_FOR_imp99(fn, arg, ...) __WI_FOR_impN(98, fn, arg, fn, __VA_ARGS__) #define __WI_FOR_impN(n, fn, arg, ...) \ fn(arg) \ - WI_MACRO_INVOKE(WI_PASTE(__WI_FOR_imp, n), WI_FLATTEN((__VA_ARGS__))) + WI_PASTE(__WI_FOR_imp, n) WI_FLATTEN((__VA_ARGS__)) #define __WI_FOR_imp(n, fnAndArgs) WI_PASTE(__WI_FOR_imp, n) fnAndArgs /// @endcond @@ -372,14 +362,14 @@ Three exception modes are available: //! Set a single compile-time constant `flag` in the variable `var`. #define WI_SetFlag(var, flag) WI_SetAllFlags(var, WI_StaticAssertSingleBitSet(flag)) //! Conditionally sets a single compile-time constant `flag` in the variable `var` only if `condition` is true. -#define WI_SetFlagIf(var, flag, condition) do { if (wil::verify_bool(condition)) { WI_SetFlag(var, flag); } } while (0, 0) +#define WI_SetFlagIf(var, flag, condition) do { if (wil::verify_bool(condition)) { WI_SetFlag(var, flag); } } while ((void)0, 0) //! Clear zero or more bitflags specified by `flags` from the variable `var`. #define WI_ClearAllFlags(var, flags) ((var) &= ~(flags)) //! Clear a single compile-time constant `flag` from the variable `var`. #define WI_ClearFlag(var, flag) WI_ClearAllFlags(var, WI_StaticAssertSingleBitSet(flag)) //! Conditionally clear a single compile-time constant `flag` in the variable `var` only if `condition` is true. -#define WI_ClearFlagIf(var, flag, condition) do { if (wil::verify_bool(condition)) { WI_ClearFlag(var, flag); } } while (0, 0) +#define WI_ClearFlagIf(var, flag, condition) do { if (wil::verify_bool(condition)) { WI_ClearFlag(var, flag); } } while ((void)0, 0) //! Changes a single compile-time constant `flag` in the variable `var` to be set if `isFlagSet` is true or cleared if `isFlagSet` is false. #define WI_UpdateFlag(var, flag, isFlagSet) (wil::verify_bool(isFlagSet) ? WI_SetFlag(var, flag) : WI_ClearFlag(var, flag)) @@ -435,7 +425,6 @@ WI_HEADER_INITITALIZATION_FUNCTION(InitializeDesktopFamilyApis, [] { g_pfnGetModuleName = GetCurrentModuleName; g_pfnFailFastInLoaderCallout = FailFastInLoaderCallout; - g_pfnRtlNtStatusToDosErrorNoTeb = RtlNtStatusToDosErrorNoTeb; return 1; }); #endif @@ -457,7 +446,7 @@ doing it with global function pointers and header initialization allows a runtim #endif -/** All Windows Internal Library classes and functions are located within the "wil" namespace. +/** All Windows Implementation Library classes and functions are located within the "wil" namespace. The 'wil' namespace is an intentionally short name as the intent is for code to be able to reference the namespace directly (example: `wil::srwlock lock;`) without a using statement. Resist adding a using statement for wil to avoid introducing potential name collisions between wil and other namespaces. */ @@ -578,7 +567,7 @@ namespace wil } template - __forceinline bool verify_bool(T val) + __forceinline bool verify_bool(T /*val*/) { static_assert(!wistd::is_same::value, "Wrong Type: bool/BOOL/BOOLEAN/boolean expected"); } diff --git a/include/wil/cppwinrt.h b/include/wil/cppwinrt.h index 8a234b8..61c3a57 100644 --- a/include/wil/cppwinrt.h +++ b/include/wil/cppwinrt.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_CPPWINRT_INCLUDED #define __WIL_CPPWINRT_INCLUDED @@ -28,7 +37,7 @@ #define WINRT_EXTERNAL_CATCH_CLAUSE \ catch (const wil::ResultException& e) \ { \ - return hresult_error(e.GetErrorCode(), to_hstring(e.what())).to_abi(); \ + return winrt::hresult_error(e.GetErrorCode(), winrt::to_hstring(e.what())).to_abi(); \ } namespace wil::details diff --git a/include/wil/filesystem.h b/include/wil/filesystem.h index c194bb9..e223290 100644 --- a/include/wil/filesystem.h +++ b/include/wil/filesystem.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_FILESYSTEM_INCLUDED #define __WIL_FILESYSTEM_INCLUDED @@ -477,7 +486,7 @@ namespace wil struct folder_change_reader_state { folder_change_reader_state(bool isRecursive, FolderChangeEvents filter, wistd::function &&callback) - : m_isRecursive(isRecursive), m_filter(filter), m_callback(wistd::move(callback)) + : m_callback(wistd::move(callback)), m_isRecursive(isRecursive), m_filter(filter) { } @@ -525,7 +534,7 @@ namespace wil { // This operation does not have the usual semantic of returning // ERROR_IO_PENDING. - // NT_ASSERT(hr != HRESULT_FROM_WIN32(ERROR_IO_PENDING)); + // WI_ASSERT(hr != HRESULT_FROM_WIN32(ERROR_IO_PENDING)); // If the operation failed for whatever reason, ensure the TP // ref counts are accurate. @@ -585,7 +594,7 @@ namespace wil ULONG result, ULONG_PTR /* BytesTransferred */, TP_IO * /* Io */) { auto readerState = static_cast(context); - // NT_ASSERT(overlapped == &readerState->m_overlapped); + // WI_ASSERT(overlapped == &readerState->m_overlapped); bool requeue = true; if (result == ERROR_SUCCESS) diff --git a/include/wil/registry.h b/include/wil/registry.h index 909e68b..0ca7f92 100644 --- a/include/wil/registry.h +++ b/include/wil/registry.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_REGISTRY_INCLUDED #define __WIL_REGISTRY_INCLUDED @@ -54,7 +63,7 @@ namespace wil struct registry_watcher_state { registry_watcher_state(unique_hkey &&keyToWatch, bool isRecursive, wistd::function &&callback) - : m_keyToWatch(wistd::move(keyToWatch)), m_callback(wistd::move(callback)), m_isRecursive(isRecursive) + : m_callback(wistd::move(callback)), m_keyToWatch(wistd::move(keyToWatch)), m_isRecursive(isRecursive) { } wistd::function m_callback; diff --git a/include/wil/resource.h b/include/wil/resource.h index 17b00b7..ee860e3 100644 --- a/include/wil/resource.h +++ b/include/wil/resource.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #include "result_macros.h" #include "wistd_functional.h" #include "wistd_memory.h" @@ -461,7 +470,7 @@ namespace wil } lambda_call_log(lambda_call_log&& other) WI_NOEXCEPT : - m_info(other.m_info), m_address(other.m_address), m_lambda(wistd::move(other.m_lambda)), m_call(other.m_call) + m_address(other.m_address), m_info(other.m_info), m_lambda(wistd::move(other.m_lambda)), m_call(other.m_call) { other.m_call = false; } @@ -1783,7 +1792,7 @@ namespace std { size_t operator()(wil::unique_any_t const &val) const { - return (hash::pointer>()(val.get())); + return (hash::pointer>()(val.get())); } }; } @@ -1884,7 +1893,7 @@ namespace wil { void reset(wistd::nullptr_t) WI_NOEXCEPT { - static_assert(wistd::is_same::value, "reset(nullptr): valid only for handle types using nullptr as the invalid value"); + static_assert(wistd::is_same::value, "reset(nullptr): valid only for handle types using nullptr as the invalid value"); reset(); } @@ -1909,14 +1918,9 @@ namespace wil { return m_ptr.use_count(); } - bool unique() const WI_NOEXCEPT - { - return m_ptr.unique(); - } - private: template - friend class weak_any; + friend class ::wil::weak_any; std::shared_ptr m_ptr; }; @@ -1945,7 +1949,7 @@ namespace wil { shared_any_t(wistd::nullptr_t) WI_NOEXCEPT { - static_assert(wistd::is_same::value, "nullptr constructor: valid only for handle types using nullptr as the invalid value"); + static_assert(wistd::is_same::value, "nullptr constructor: valid only for handle types using nullptr as the invalid value"); } shared_any_t(shared_any_t &&other) WI_NOEXCEPT : @@ -1986,7 +1990,7 @@ namespace wil { shared_any_t& operator=(wistd::nullptr_t) WI_NOEXCEPT { - static_assert(wistd::is_same::value, "nullptr assignment: valid only for handle types using nullptr as the invalid value"); + static_assert(wistd::is_same::value, "nullptr assignment: valid only for handle types using nullptr as the invalid value"); storage_t::reset(); return (*this); } @@ -2005,14 +2009,14 @@ namespace wil { pointer_storage *operator&() { - static_assert(wistd::is_same::value, "operator & is not available for this handle"); + static_assert(wistd::is_same::value, "operator & is not available for this handle"); storage_t::reset(); return storage_t::addressof(); } pointer get() const WI_NOEXCEPT { - static_assert(!wistd::is_same::value, "get(): the raw handle value is not available for this resource class"); + static_assert(!wistd::is_same::value, "get(): the raw handle value is not available for this resource class"); return storage_t::get(); } @@ -2038,14 +2042,14 @@ namespace wil { template bool operator==(const shared_any_t& left, wistd::nullptr_t) WI_NOEXCEPT { - static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); + static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); return !left; } template bool operator==(wistd::nullptr_t, const shared_any_t& right) WI_NOEXCEPT { - static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); + static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); return !right; } @@ -2058,14 +2062,14 @@ namespace wil { template bool operator!=(const shared_any_t& left, wistd::nullptr_t) WI_NOEXCEPT { - static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); + static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); return !!left; } template bool operator!=(wistd::nullptr_t, const shared_any_t& right) WI_NOEXCEPT { - static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); + static_assert(wistd::is_same::policy::pointer_invalid, wistd::nullptr_t>::value, "the resource class does not use nullptr as an invalid value"); return !!right; } @@ -2108,7 +2112,7 @@ namespace wil { } weak_any(const shared_t &other) WI_NOEXCEPT : - m_weakPtr(other.m_ptr) + m_weakPtr(other.m_ptr) { } @@ -2175,7 +2179,7 @@ namespace std { size_t operator()(wil::shared_any_t const &val) const { - return (hash::pointer>()(val.get())); + return (hash::pointer>()(val.get())); } }; } @@ -2660,11 +2664,9 @@ namespace wil class slim_event_t { public: - slim_event_t() - { - } + slim_event_t() WI_NOEXCEPT = default; - slim_event_t(bool isSignaled) : + slim_event_t(bool isSignaled) WI_NOEXCEPT : m_isSignaled(isSignaled ? TRUE : FALSE) { } @@ -3465,7 +3467,6 @@ namespace wil unique_event_nothrow m_event; // The thread pool must be last to ensure that the other members are valid // when it is destructed as it will reference them. - // See http://osgvsowi/2224623 unique_threadpool_wait m_threadPoolWait; }; diff --git a/include/wil/result.h b/include/wil/result.h index 2bf0983..fbd6306 100644 --- a/include/wil/result.h +++ b/include/wil/result.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_RESULT_INCLUDED #define __WIL_RESULT_INCLUDED diff --git a/include/wil/result_macros.h b/include/wil/result_macros.h index 98ebc09..bf4021f 100644 --- a/include/wil/result_macros.h +++ b/include/wil/result_macros.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_RESULTMACROS_INCLUDED #define __WIL_RESULTMACROS_INCLUDED @@ -91,6 +100,15 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "1") #else WI_ODR_PRAGMA("WIL_FreeMemory", "0") #endif + +// It would appear as though the C++17 "noexcept is part of the type system" update in MSVC has "infected" the behavior +// when compiling with C++14 (the default...), however the updated behavior for decltype understanding noexcept is _not_ +// present... So, work around it +#if __WI_LIBCPP_STD_VER >= 17 +#define WI_PFN_NOEXCEPT WI_NOEXCEPT +#else +#define WI_PFN_NOEXCEPT +#endif /// @endcond #if defined(__cplusplus) && !defined(__WIL_MIN_KERNEL) && !defined(WIL_KERNEL_MODE) @@ -544,23 +562,23 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #endif // end-of-repeated fail-fast handling macros // Helpers for return macros -#define __RETURN_HR_MSG(hr, str, fmt, ...) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_HrMsg)(__R_INFO(str) __hr, fmt, __VA_ARGS__); } return __hr; } while (0, 0) -#define __RETURN_HR_MSG_FAIL(hr, str, fmt, ...) do { HRESULT __hr = (hr); __R_FN(Return_HrMsg)(__R_INFO(str) __hr, fmt, __VA_ARGS__); return __hr; } while (0, 0) -#define __RETURN_WIN32_MSG(err, str, fmt, ...) do { DWORD __err = (err); if (FAILED_WIN32(__err)) { return __R_FN(Return_Win32Msg)(__R_INFO(str) __err, fmt, __VA_ARGS__); } return S_OK; } while (0, 0) -#define __RETURN_WIN32_MSG_FAIL(err, str, fmt, ...) do { DWORD __err = (err); return __R_FN(Return_Win32Msg)(__R_INFO(str) __err, fmt, __VA_ARGS__); } while (0, 0) +#define __RETURN_HR_MSG(hr, str, fmt, ...) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_HrMsg)(__R_INFO(str) __hr, fmt, __VA_ARGS__); } return __hr; } while ((void)0, 0) +#define __RETURN_HR_MSG_FAIL(hr, str, fmt, ...) do { HRESULT __hr = (hr); __R_FN(Return_HrMsg)(__R_INFO(str) __hr, fmt, __VA_ARGS__); return __hr; } while ((void)0, 0) +#define __RETURN_WIN32_MSG(err, str, fmt, ...) do { DWORD __err = (err); if (FAILED_WIN32(__err)) { return __R_FN(Return_Win32Msg)(__R_INFO(str) __err, fmt, __VA_ARGS__); } return S_OK; } while ((void)0, 0) +#define __RETURN_WIN32_MSG_FAIL(err, str, fmt, ...) do { DWORD __err = (err); return __R_FN(Return_Win32Msg)(__R_INFO(str) __err, fmt, __VA_ARGS__); } while ((void)0, 0) #define __RETURN_GLE_MSG_FAIL(str, fmt, ...) return __R_FN(Return_GetLastErrorMsg)(__R_INFO(str) fmt, __VA_ARGS__) -#define __RETURN_NTSTATUS_MSG(status, str, fmt, ...) do { NTSTATUS __status = (status); if(FAILED_NTSTATUS(__status)) { return __R_FN(Return_NtStatusMsg)(__R_INFO(str) __status, fmt, __VA_ARGS__); } return S_OK; } while (0, 0) -#define __RETURN_NTSTATUS_MSG_FAIL(status, str, fmt, ...) do { NTSTATUS __status = (status); return __R_FN(Return_NtStatusMsg)(__R_INFO(str) __status, fmt, __VA_ARGS__); } while (0, 0) -#define __RETURN_HR(hr, str) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_Hr)(__R_INFO(str) __hr); } return __hr; } while (0, 0) -#define __RETURN_HR_NOFILE(hr, str) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_Hr)(__R_INFO_NOFILE(str) __hr); } return __hr; } while (0, 0) -#define __RETURN_HR_FAIL(hr, str) do { HRESULT __hr = (hr); __R_FN(Return_Hr)(__R_INFO(str) __hr); return __hr; } while (0, 0) -#define __RETURN_HR_FAIL_NOFILE(hr, str) do { HRESULT __hr = (hr); __R_FN(Return_Hr)(__R_INFO_NOFILE(str) __hr); return __hr; } while (0, 0) -#define __RETURN_WIN32(err, str) do { DWORD __err = (err); if (FAILED_WIN32(__err)) { return __R_FN(Return_Win32)(__R_INFO(str) __err); } return S_OK; } while (0, 0) -#define __RETURN_WIN32_FAIL(err, str) do { DWORD __err = (err); return __R_FN(Return_Win32)(__R_INFO(str) __err); } while (0, 0) +#define __RETURN_NTSTATUS_MSG(status, str, fmt, ...) do { NTSTATUS __status = (status); if(FAILED_NTSTATUS(__status)) { return __R_FN(Return_NtStatusMsg)(__R_INFO(str) __status, fmt, __VA_ARGS__); } return S_OK; } while ((void)0, 0) +#define __RETURN_NTSTATUS_MSG_FAIL(status, str, fmt, ...) do { NTSTATUS __status = (status); return __R_FN(Return_NtStatusMsg)(__R_INFO(str) __status, fmt, __VA_ARGS__); } while ((void)0, 0) +#define __RETURN_HR(hr, str) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_Hr)(__R_INFO(str) __hr); } return __hr; } while ((void)0, 0) +#define __RETURN_HR_NOFILE(hr, str) do { HRESULT __hr = (hr); if (FAILED(__hr)) { __R_FN(Return_Hr)(__R_INFO_NOFILE(str) __hr); } return __hr; } while ((void)0, 0) +#define __RETURN_HR_FAIL(hr, str) do { HRESULT __hr = (hr); __R_FN(Return_Hr)(__R_INFO(str) __hr); return __hr; } while ((void)0, 0) +#define __RETURN_HR_FAIL_NOFILE(hr, str) do { HRESULT __hr = (hr); __R_FN(Return_Hr)(__R_INFO_NOFILE(str) __hr); return __hr; } while ((void)0, 0) +#define __RETURN_WIN32(err, str) do { DWORD __err = (err); if (FAILED_WIN32(__err)) { return __R_FN(Return_Win32)(__R_INFO(str) __err); } return S_OK; } while ((void)0, 0) +#define __RETURN_WIN32_FAIL(err, str) do { DWORD __err = (err); return __R_FN(Return_Win32)(__R_INFO(str) __err); } while ((void)0, 0) #define __RETURN_GLE_FAIL(str) return __R_FN(Return_GetLastError)(__R_INFO_ONLY(str)) #define __RETURN_GLE_FAIL_NOFILE(str) return __R_FN(Return_GetLastError)(__R_INFO_NOFILE_ONLY(str)) -#define __RETURN_NTSTATUS(status, str) do { NTSTATUS __status = (status); if(FAILED_NTSTATUS(__status)) { return __R_FN(Return_NtStatus)(__R_INFO(str) __status); } return S_OK; } while (0, 0) -#define __RETURN_NTSTATUS_FAIL(status, str) do { NTSTATUS __status = (status); return __R_FN(Return_NtStatus)(__R_INFO(str) __status); } while (0, 0) +#define __RETURN_NTSTATUS(status, str) do { NTSTATUS __status = (status); if(FAILED_NTSTATUS(__status)) { return __R_FN(Return_NtStatus)(__R_INFO(str) __status); } return S_OK; } while ((void)0, 0) +#define __RETURN_NTSTATUS_FAIL(status, str) do { NTSTATUS __status = (status); return __R_FN(Return_NtStatus)(__R_INFO(str) __status); } while ((void)0, 0) /// @endcond //***************************************************************************** @@ -574,15 +592,15 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #define RETURN_NTSTATUS(status) __RETURN_NTSTATUS(status, #status) // Conditionally returns failures (HRESULT) - always logs failures -#define RETURN_IF_FAILED(hr) do { HRESULT __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_FAIL(__hrRet, #hr); }} while (0, 0) -#define RETURN_IF_WIN32_BOOL_FALSE(win32BOOL) do { BOOL __boolRet = wil::verify_BOOL(win32BOOL); if (!__boolRet) { __RETURN_GLE_FAIL(#win32BOOL); }} while (0, 0) -#define RETURN_IF_WIN32_ERROR(win32err) do { DWORD __errRet = (win32err); if (FAILED_WIN32(__errRet)) { __RETURN_WIN32_FAIL(__errRet, #win32err); }} while (0, 0) -#define RETURN_IF_NULL_ALLOC(ptr) do { if ((ptr) == nullptr) { __RETURN_HR_FAIL(E_OUTOFMEMORY, #ptr); }} while (0, 0) -#define RETURN_HR_IF(hr, condition) do { if (wil::verify_bool(condition)) { __RETURN_HR(wil::verify_hresult(hr), #condition); }} while (0, 0) -#define RETURN_HR_IF_NULL(hr, ptr) do { if ((ptr) == nullptr) { __RETURN_HR(wil::verify_hresult(hr), #ptr); }} while (0, 0) -#define RETURN_LAST_ERROR_IF(condition) do { if (wil::verify_bool(condition)) { __RETURN_GLE_FAIL(#condition); }} while (0, 0) -#define RETURN_LAST_ERROR_IF_NULL(ptr) do { if ((ptr) == nullptr) { __RETURN_GLE_FAIL(#ptr); }} while (0, 0) -#define RETURN_IF_NTSTATUS_FAILED(status) do { NTSTATUS __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { __RETURN_NTSTATUS_FAIL(__statusRet, #status); }} while (0, 0) +#define RETURN_IF_FAILED(hr) do { HRESULT __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_FAIL(__hrRet, #hr); }} while ((void)0, 0) +#define RETURN_IF_WIN32_BOOL_FALSE(win32BOOL) do { BOOL __boolRet = wil::verify_BOOL(win32BOOL); if (!__boolRet) { __RETURN_GLE_FAIL(#win32BOOL); }} while ((void)0, 0) +#define RETURN_IF_WIN32_ERROR(win32err) do { DWORD __errRet = (win32err); if (FAILED_WIN32(__errRet)) { __RETURN_WIN32_FAIL(__errRet, #win32err); }} while ((void)0, 0) +#define RETURN_IF_NULL_ALLOC(ptr) do { if ((ptr) == nullptr) { __RETURN_HR_FAIL(E_OUTOFMEMORY, #ptr); }} while ((void)0, 0) +#define RETURN_HR_IF(hr, condition) do { if (wil::verify_bool(condition)) { __RETURN_HR(wil::verify_hresult(hr), #condition); }} while ((void)0, 0) +#define RETURN_HR_IF_NULL(hr, ptr) do { if ((ptr) == nullptr) { __RETURN_HR(wil::verify_hresult(hr), #ptr); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF(condition) do { if (wil::verify_bool(condition)) { __RETURN_GLE_FAIL(#condition); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF_NULL(ptr) do { if ((ptr) == nullptr) { __RETURN_GLE_FAIL(#ptr); }} while ((void)0, 0) +#define RETURN_IF_NTSTATUS_FAILED(status) do { NTSTATUS __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { __RETURN_NTSTATUS_FAIL(__statusRet, #status); }} while ((void)0, 0) // Always returns a known failure (HRESULT) - always logs a var-arg message on failure #define RETURN_HR_MSG(hr, fmt, ...) __RETURN_HR_MSG(wil::verify_hresult(hr), #hr, fmt, __VA_ARGS__) @@ -591,26 +609,26 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #define RETURN_NTSTATUS_MSG(status, fmt, ...) __RETURN_NTSTATUS_MSG(status, #status, fmt, __VA_ARGS__) // Conditionally returns failures (HRESULT) - always logs a var-arg message on failure -#define RETURN_IF_FAILED_MSG(hr, fmt, ...) do { auto __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_MSG_FAIL(__hrRet, #hr, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_IF_WIN32_BOOL_FALSE_MSG(win32BOOL, fmt, ...) do { if (!wil::verify_BOOL(win32BOOL)) { __RETURN_GLE_MSG_FAIL(#win32BOOL, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_IF_WIN32_ERROR_MSG(win32err, fmt, ...) do { auto __errRet = (win32err); if (FAILED_WIN32(__errRet)) { __RETURN_WIN32_MSG_FAIL(__errRet, #win32err, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_IF_NULL_ALLOC_MSG(ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_HR_MSG_FAIL(E_OUTOFMEMORY, #ptr, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_HR_IF_MSG(hr, condition, fmt, ...) do { if (wil::verify_bool(condition)) { __RETURN_HR_MSG(wil::verify_hresult(hr), #condition, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_HR_IF_NULL_MSG(hr, ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_HR_MSG(wil::verify_hresult(hr), #ptr, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_LAST_ERROR_IF_MSG(condition, fmt, ...) do { if (wil::verify_bool(condition)) { __RETURN_GLE_MSG_FAIL(#condition, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_LAST_ERROR_IF_NULL_MSG(ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_GLE_MSG_FAIL(#ptr, fmt, __VA_ARGS__); }} while (0, 0) -#define RETURN_IF_NTSTATUS_FAILED_MSG(status, fmt, ...) do { NTSTATUS __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { __RETURN_NTSTATUS_MSG_FAIL(__statusRet, #status, fmt, __VA_ARGS__); }} while (0, 0) +#define RETURN_IF_FAILED_MSG(hr, fmt, ...) do { auto __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_MSG_FAIL(__hrRet, #hr, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_IF_WIN32_BOOL_FALSE_MSG(win32BOOL, fmt, ...) do { if (!wil::verify_BOOL(win32BOOL)) { __RETURN_GLE_MSG_FAIL(#win32BOOL, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_IF_WIN32_ERROR_MSG(win32err, fmt, ...) do { auto __errRet = (win32err); if (FAILED_WIN32(__errRet)) { __RETURN_WIN32_MSG_FAIL(__errRet, #win32err, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_IF_NULL_ALLOC_MSG(ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_HR_MSG_FAIL(E_OUTOFMEMORY, #ptr, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_HR_IF_MSG(hr, condition, fmt, ...) do { if (wil::verify_bool(condition)) { __RETURN_HR_MSG(wil::verify_hresult(hr), #condition, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_HR_IF_NULL_MSG(hr, ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_HR_MSG(wil::verify_hresult(hr), #ptr, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF_MSG(condition, fmt, ...) do { if (wil::verify_bool(condition)) { __RETURN_GLE_MSG_FAIL(#condition, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF_NULL_MSG(ptr, fmt, ...) do { if ((ptr) == nullptr) { __RETURN_GLE_MSG_FAIL(#ptr, fmt, __VA_ARGS__); }} while ((void)0, 0) +#define RETURN_IF_NTSTATUS_FAILED_MSG(status, fmt, ...) do { NTSTATUS __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { __RETURN_NTSTATUS_MSG_FAIL(__statusRet, #status, fmt, __VA_ARGS__); }} while ((void)0, 0) // Conditionally returns failures (HRESULT) - use for failures that are expected in common use - failures are not logged - macros are only for control flow pattern -#define RETURN_IF_FAILED_EXPECTED(hr) do { auto __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { return __hrRet; }} while (0, 0) -#define RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(win32BOOL) do { if (!wil::verify_BOOL(win32BOOL)) { return wil::details::GetLastErrorFailHr(); }} while (0, 0) -#define RETURN_IF_WIN32_ERROR_EXPECTED(win32err) do { auto __errRet = (win32err); if (FAILED_WIN32(__errRet)) { return HRESULT_FROM_WIN32(__errRet); }} while (0, 0) -#define RETURN_IF_NULL_ALLOC_EXPECTED(ptr) do { if ((ptr) == nullptr) { return E_OUTOFMEMORY; }} while (0, 0) -#define RETURN_HR_IF_EXPECTED(hr, condition) do { if (wil::verify_bool(condition)) { return wil::verify_hresult(hr); }} while (0, 0) -#define RETURN_HR_IF_NULL_EXPECTED(hr, ptr) do { if ((ptr) == nullptr) { return wil::verify_hresult(hr); }} while (0, 0) -#define RETURN_LAST_ERROR_IF_EXPECTED(condition) do { if (wil::verify_bool(condition)) { return wil::details::GetLastErrorFailHr(); }} while (0, 0) -#define RETURN_LAST_ERROR_IF_NULL_EXPECTED(ptr) do { if ((ptr) == nullptr) { return wil::details::GetLastErrorFailHr(); }} while (0, 0) -#define RETURN_IF_NTSTATUS_FAILED_EXPECTED(status) do { auto __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { return wil::details::NtStatusToHr(__statusRet); }} while (0, 0) +#define RETURN_IF_FAILED_EXPECTED(hr) do { auto __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { return __hrRet; }} while ((void)0, 0) +#define RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(win32BOOL) do { if (!wil::verify_BOOL(win32BOOL)) { return wil::details::GetLastErrorFailHr(); }} while ((void)0, 0) +#define RETURN_IF_WIN32_ERROR_EXPECTED(win32err) do { auto __errRet = (win32err); if (FAILED_WIN32(__errRet)) { return HRESULT_FROM_WIN32(__errRet); }} while ((void)0, 0) +#define RETURN_IF_NULL_ALLOC_EXPECTED(ptr) do { if ((ptr) == nullptr) { return E_OUTOFMEMORY; }} while ((void)0, 0) +#define RETURN_HR_IF_EXPECTED(hr, condition) do { if (wil::verify_bool(condition)) { return wil::verify_hresult(hr); }} while ((void)0, 0) +#define RETURN_HR_IF_NULL_EXPECTED(hr, ptr) do { if ((ptr) == nullptr) { return wil::verify_hresult(hr); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF_EXPECTED(condition) do { if (wil::verify_bool(condition)) { return wil::details::GetLastErrorFailHr(); }} while ((void)0, 0) +#define RETURN_LAST_ERROR_IF_NULL_EXPECTED(ptr) do { if ((ptr) == nullptr) { return wil::details::GetLastErrorFailHr(); }} while ((void)0, 0) +#define RETURN_IF_NTSTATUS_FAILED_EXPECTED(status) do { auto __statusRet = (status); if (FAILED_NTSTATUS(__statusRet)) { return wil::details::NtStatusToHr(__statusRet); }} while ((void)0, 0) #define __WI_OR_IS_EXPECTED_HRESULT(e) || (__hrRet == wil::verify_hresult(e)) #define RETURN_IF_FAILED_WITH_EXPECTED(hr, hrExpected, ...) \ @@ -626,7 +644,7 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") __RETURN_HR_FAIL(__hrRet, #hr); \ } \ } \ - while (0, 0) + while ((void)0, 0) //***************************************************************************** // Macros for logging failures (ignore or pass-through) @@ -744,7 +762,7 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #define FAIL_FAST_IMMEDIATE_IF_NTSTATUS_FAILED(status) __RFF_FN(FailFastImmediate_IfNtStatusFailed)(status) // Specializations -#define FAIL_FAST_IMMEDIATE_IF_IN_LOADER_CALLOUT() do { if (wil::details::g_pfnFailFastInLoaderCallout != nullptr) { wil::details::g_pfnFailFastInLoaderCallout(); } } while (0, 0) +#define FAIL_FAST_IMMEDIATE_IF_IN_LOADER_CALLOUT() do { if (wil::details::g_pfnFailFastInLoaderCallout != nullptr) { wil::details::g_pfnFailFastInLoaderCallout(); } } while ((void)0, 0) //***************************************************************************** @@ -837,14 +855,14 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #define WI_USAGE_ASSERT_STOP(condition) WI_ASSERT(condition) #endif #ifdef RESULT_DEBUG -#define WI_USAGE_ERROR(msg, ...) do { LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE), msg, __VA_ARGS__); WI_USAGE_ASSERT_STOP(false); } while (0, 0) -#define WI_USAGE_ERROR_FORWARD(msg, ...) do { ReportFailure_ReplaceMsg(__R_FN_CALL_FULL, FailureType::Log, HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE), msg, __VA_ARGS__); WI_USAGE_ASSERT_STOP(false); } while (0, 0) +#define WI_USAGE_ERROR(msg, ...) do { LOG_HR_MSG(HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE), msg, __VA_ARGS__); WI_USAGE_ASSERT_STOP(false); } while ((void)0, 0) +#define WI_USAGE_ERROR_FORWARD(msg, ...) do { ReportFailure_ReplaceMsg(__R_FN_CALL_FULL, FailureType::Log, HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE), msg, __VA_ARGS__); WI_USAGE_ASSERT_STOP(false); } while ((void)0, 0) #else -#define WI_USAGE_ERROR(msg, ...) do { LOG_HR(HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE)); WI_USAGE_ASSERT_STOP(false); } while (0, 0) -#define WI_USAGE_ERROR_FORWARD(msg, ...) do { ReportFailure_Hr(__R_FN_CALL_FULL, FailureType::Log, HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE)); WI_USAGE_ASSERT_STOP(false); } while (0, 0) +#define WI_USAGE_ERROR(msg, ...) do { LOG_HR(HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE)); WI_USAGE_ASSERT_STOP(false); } while ((void)0, 0) +#define WI_USAGE_ERROR_FORWARD(msg, ...) do { ReportFailure_Hr(__R_FN_CALL_FULL, FailureType::Log, HRESULT_FROM_WIN32(ERROR_ASSERTION_FAILURE)); WI_USAGE_ASSERT_STOP(false); } while ((void)0, 0) #endif -#define WI_USAGE_VERIFY(condition, msg, ...) do { auto __passed = wil::verify_bool(condition); if (!__passed) { WI_USAGE_ERROR(msg, __VA_ARGS__); }} while (0, 0) -#define WI_USAGE_VERIFY_FORWARD(condition, msg, ...) do { auto __passed = wil::verify_bool(condition); if (!__passed) { WI_USAGE_ERROR_FORWARD(msg, __VA_ARGS__); }} while (0, 0) +#define WI_USAGE_VERIFY(condition, msg, ...) do { auto __passed = wil::verify_bool(condition); if (!__passed) { WI_USAGE_ERROR(msg, __VA_ARGS__); }} while ((void)0, 0) +#define WI_USAGE_VERIFY_FORWARD(condition, msg, ...) do { auto __passed = wil::verify_bool(condition); if (!__passed) { WI_USAGE_ERROR_FORWARD(msg, __VA_ARGS__); }} while ((void)0, 0) #ifdef RESULT_DEBUG #define WI_USAGE_ASSERT(condition, msg, ...) WI_USAGE_VERIFY(condition, msg, __VA_ARGS__) #else @@ -866,12 +884,12 @@ WI_ODR_PRAGMA("WIL_FreeMemory", "0") #define __WIL_PRIVATE_FAIL_FAST_HR(hr) FAIL_FAST_HR(hr) #define __WIL_PRIVATE_LOG_HR(hr) LOG_HR(hr) #else -#define __WIL_PRIVATE_RETURN_IF_FAILED(hr) do { HRESULT __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_FAIL_NOFILE(__hrRet, #hr); }} while (0, 0) -#define __WIL_PRIVATE_RETURN_HR_IF(hr, cond) do { if (wil::verify_bool(cond)) { __RETURN_HR_NOFILE(wil::verify_hresult(hr), #cond); }} while (0, 0) -#define __WIL_PRIVATE_RETURN_LAST_ERROR_IF(cond) do { if (wil::verify_bool(cond)) { __RETURN_GLE_FAIL_NOFILE(#cond); }} while (0, 0) -#define __WIL_PRIVATE_RETURN_IF_WIN32_BOOL_FALSE(win32BOOL) do { BOOL __boolRet = wil::verify_BOOL(win32BOOL); if (!__boolRet) { __RETURN_GLE_FAIL_NOFILE(#win32BOOL); }} while (0, 0) -#define __WIL_PRIVATE_RETURN_LAST_ERROR_IF_NULL(ptr) do { if ((ptr) == nullptr) { __RETURN_GLE_FAIL_NOFILE(#ptr); }} while (0, 0) -#define __WIL_PRIVATE_RETURN_IF_NULL_ALLOC(ptr) do { if ((ptr) == nullptr) { __RETURN_HR_FAIL_NOFILE(E_OUTOFMEMORY, #ptr); }} while (0, 0) +#define __WIL_PRIVATE_RETURN_IF_FAILED(hr) do { HRESULT __hrRet = wil::verify_hresult(hr); if (FAILED(__hrRet)) { __RETURN_HR_FAIL_NOFILE(__hrRet, #hr); }} while ((void)0, 0) +#define __WIL_PRIVATE_RETURN_HR_IF(hr, cond) do { if (wil::verify_bool(cond)) { __RETURN_HR_NOFILE(wil::verify_hresult(hr), #cond); }} while ((void)0, 0) +#define __WIL_PRIVATE_RETURN_LAST_ERROR_IF(cond) do { if (wil::verify_bool(cond)) { __RETURN_GLE_FAIL_NOFILE(#cond); }} while ((void)0, 0) +#define __WIL_PRIVATE_RETURN_IF_WIN32_BOOL_FALSE(win32BOOL) do { BOOL __boolRet = wil::verify_BOOL(win32BOOL); if (!__boolRet) { __RETURN_GLE_FAIL_NOFILE(#win32BOOL); }} while ((void)0, 0) +#define __WIL_PRIVATE_RETURN_LAST_ERROR_IF_NULL(ptr) do { if ((ptr) == nullptr) { __RETURN_GLE_FAIL_NOFILE(#ptr); }} while ((void)0, 0) +#define __WIL_PRIVATE_RETURN_IF_NULL_ALLOC(ptr) do { if ((ptr) == nullptr) { __RETURN_HR_FAIL_NOFILE(E_OUTOFMEMORY, #ptr); }} while ((void)0, 0) #define __WIL_PRIVATE_RETURN_LAST_ERROR() __RETURN_GLE_FAIL_NOFILE(nullptr) #define __WIL_PRIVATE_FAIL_FAST_HR_IF(hr, condition) __RFF_FN(FailFast_HrIf)(__RFF_INFO_NOFILE(#condition) wil::verify_hresult(hr), wil::verify_bool(condition)) #define __WIL_PRIVATE_FAIL_FAST_HR(hr) __RFF_FN(FailFast_Hr)(__RFF_INFO_NOFILE(#hr) wil::verify_hresult(hr)) @@ -971,7 +989,7 @@ namespace wil // [optionally] Plug in error logging // Note: This callback is deprecated. Please use SetResultTelemetryFallback for telemetry or // SetResultLoggingCallback for observation. - extern "C" __declspec(selectany) void(__stdcall *g_pfnResultLoggingCallback)(_Inout_ wil::FailureInfo *pFailure, _Inout_updates_opt_z_(cchDebugMessage) PWSTR pszDebugMessage, _Pre_satisfies_(cchDebugMessage > 0) size_t cchDebugMessage) WI_NOEXCEPT = nullptr; + extern "C" __declspec(selectany) void(__stdcall *g_pfnResultLoggingCallback)(_Inout_ wil::FailureInfo *pFailure, _Inout_updates_opt_z_(cchDebugMessage) PWSTR pszDebugMessage, _Pre_satisfies_(cchDebugMessage > 0) size_t cchDebugMessage) WI_PFN_NOEXCEPT = nullptr; // [optional] // This can be explicitly set to control whether or not error messages will be output to OutputDebugString. It can also @@ -981,13 +999,13 @@ namespace wil // [optionally] Allows application to specify a debugger to detect whether a debugger is present. // Useful for processes that can only be debugged under kernel debuggers where IsDebuggerPresent returns // false. - __declspec(selectany) bool(__stdcall *g_pfnIsDebuggerPresent)() WI_NOEXCEPT = nullptr; + __declspec(selectany) bool(__stdcall *g_pfnIsDebuggerPresent)() WI_PFN_NOEXCEPT = nullptr; // [optionally] Allows forcing WIL to believe a debugger is present. Useful for when a kernel debugger is attached and ::IsDebuggerPresent returns false __declspec(selectany) bool g_fIsDebuggerPresent = false; // [optionally] Plug in additional exception-type support (return S_OK when *unable* to remap the exception) - __declspec(selectany) HRESULT(__stdcall *g_pfnResultFromCaughtException)() WI_NOEXCEPT = nullptr; + __declspec(selectany) HRESULT(__stdcall *g_pfnResultFromCaughtException)() WI_PFN_NOEXCEPT = nullptr; // [optionally] Use to configure fast fail of unknown exceptions (turn them off). __declspec(selectany) bool g_fResultFailFastUnknownExceptions = true; @@ -1002,7 +1020,7 @@ namespace wil __declspec(selectany) bool g_fBreakOnFailure = false; // [optionally] customize failfast behavior - __declspec(selectany) bool(__stdcall *g_pfnWilFailFast)(const wil::FailureInfo& info) WI_NOEXCEPT = nullptr; + __declspec(selectany) bool(__stdcall *g_pfnWilFailFast)(const wil::FailureInfo& info) WI_PFN_NOEXCEPT = nullptr; /// @cond namespace details @@ -1139,31 +1157,34 @@ namespace wil }; // Fallback telemetry provider callback (set with wil::SetResultTelemetryFallback) - __declspec(selectany) void(__stdcall *g_pfnTelemetryCallback)(bool alreadyReported, wil::FailureInfo const &failure) WI_NOEXCEPT = nullptr; + __declspec(selectany) void(__stdcall *g_pfnTelemetryCallback)(bool alreadyReported, wil::FailureInfo const &failure) WI_PFN_NOEXCEPT = nullptr; // Result.h plug-in (WIL use only) - __declspec(selectany) void(__stdcall *g_pfnGetContextAndNotifyFailure)(_Inout_ FailureInfo *pFailure, _Out_writes_(callContextStringLength) _Post_z_ PSTR callContextString, _Pre_satisfies_(callContextStringLength > 0) size_t callContextStringLength) WI_NOEXCEPT = nullptr; + __declspec(selectany) void(__stdcall *g_pfnGetContextAndNotifyFailure)(_Inout_ FailureInfo *pFailure, _Out_writes_(callContextStringLength) _Post_z_ PSTR callContextString, _Pre_satisfies_(callContextStringLength > 0) size_t callContextStringLength) WI_PFN_NOEXCEPT = nullptr; // Observe all errors flowing through the system with this callback (set with wil::SetResultLoggingCallback); use with custom logging - __declspec(selectany) void(__stdcall *g_pfnLoggingCallback)(wil::FailureInfo const &failure) WI_NOEXCEPT = nullptr; + __declspec(selectany) void(__stdcall *g_pfnLoggingCallback)(wil::FailureInfo const &failure) WI_PFN_NOEXCEPT = nullptr; // Desktop/System Only: Module fetch function (automatically setup) - __declspec(selectany) PCSTR(__stdcall *g_pfnGetModuleName)() WI_NOEXCEPT = nullptr; + __declspec(selectany) PCSTR(__stdcall *g_pfnGetModuleName)() WI_PFN_NOEXCEPT = nullptr; // Desktop/System Only: Retrieve address offset and modulename - __declspec(selectany) bool(__stdcall *g_pfnGetModuleInformation)(void* address, _Out_opt_ unsigned int* addressOffset, _Out_writes_bytes_opt_(size) char* name, size_t size) WI_NOEXCEPT = nullptr; + __declspec(selectany) bool(__stdcall *g_pfnGetModuleInformation)(void* address, _Out_opt_ unsigned int* addressOffset, _Out_writes_bytes_opt_(size) char* name, size_t size) WI_PFN_NOEXCEPT = nullptr; - // Desktop/System Only: Private module load fail fast function (automatically setup) - __declspec(selectany) void(__stdcall *g_pfnFailFastInLoaderCallout)() WI_NOEXCEPT = nullptr; + // Called with the expectation that the program will terminate when called inside of a loader callout. + // Desktop/System Only: Automatically setup when building Windows (BUILD_WINDOWS defined) + __declspec(selectany) void(__stdcall *g_pfnFailFastInLoaderCallout)() WI_PFN_NOEXCEPT = nullptr; - // Desktop/System Only: Private module load convert NtStatus to HResult (automatically setup) - __declspec(selectany) ULONG(__stdcall *g_pfnRtlNtStatusToDosErrorNoTeb)(NTSTATUS) WI_NOEXCEPT = nullptr; + // Called to translate an NTSTATUS value to a Win32 error code + // Desktop/System Only: Automatically setup when building Windows (BUILD_WINDOWS defined) + __declspec(selectany) ULONG(__stdcall *g_pfnRtlNtStatusToDosErrorNoTeb)(NTSTATUS) WI_PFN_NOEXCEPT = nullptr; - // Desktop/System Only: Private module load to call debug break - __declspec(selectany) void(__stdcall *g_pfnDebugBreak)() WI_NOEXCEPT = nullptr; + // Desktop/System Only: Call to DebugBreak + __declspec(selectany) void(__stdcall *g_pfnDebugBreak)() WI_PFN_NOEXCEPT = nullptr; - // Private API to determine whether or not termination is happening - __declspec(selectany) BOOLEAN(__stdcall *g_pfnRtlDllShutdownInProgress)() WI_NOEXCEPT = nullptr; + // Called to determine whether or not termination is happening + // Desktop/System Only: Automatically setup when building Windows (BUILD_WINDOWS defined) + __declspec(selectany) BOOLEAN(__stdcall *g_pfnDllShutdownInProgress)() WI_PFN_NOEXCEPT = nullptr; __declspec(selectany) bool g_processShutdownInProgress = false; // On Desktop/System WINAPI family: dynalink RaiseFailFastException because we may encounter modules @@ -1174,15 +1195,15 @@ namespace wil __declspec(selectany) HRESULT(__stdcall *g_pfnRunFunctorWithExceptionFilter)(IFunctor& functor, IFunctorHost& host, void* returnAddress) = nullptr; __declspec(selectany) void(__stdcall *g_pfnRethrow)() = nullptr; __declspec(selectany) void(__stdcall *g_pfnThrowResultException)(const FailureInfo& failure) = nullptr; - extern "C" __declspec(selectany) HRESULT(__stdcall *g_pfnResultFromCaughtExceptionInternal)(_Out_writes_opt_(debugStringChars) PWSTR debugString, _When_(debugString != nullptr, _Pre_satisfies_(debugStringChars > 0)) size_t debugStringChars, _Out_ bool* isNormalized) WI_NOEXCEPT = nullptr; + extern "C" __declspec(selectany) HRESULT(__stdcall *g_pfnResultFromCaughtExceptionInternal)(_Out_writes_opt_(debugStringChars) PWSTR debugString, _When_(debugString != nullptr, _Pre_satisfies_(debugStringChars > 0)) size_t debugStringChars, _Out_ bool* isNormalized) WI_PFN_NOEXCEPT = nullptr; // C++/cx compiled additions extern "C" __declspec(selectany) void(__stdcall *g_pfnThrowPlatformException)(FailureInfo const &failure, PCWSTR debugString) = nullptr; - extern "C" __declspec(selectany) _Always_(_Post_satisfies_(return < 0)) HRESULT(__stdcall *g_pfnResultFromCaughtException_WinRt)(_Inout_updates_opt_(debugStringChars) PWSTR debugString, _When_(debugString != nullptr, _Pre_satisfies_(debugStringChars > 0)) size_t debugStringChars, _Out_ bool* isNormalized) WI_NOEXCEPT = nullptr; + extern "C" __declspec(selectany) _Always_(_Post_satisfies_(return < 0)) HRESULT(__stdcall *g_pfnResultFromCaughtException_WinRt)(_Inout_updates_opt_(debugStringChars) PWSTR debugString, _When_(debugString != nullptr, _Pre_satisfies_(debugStringChars > 0)) size_t debugStringChars, _Out_ bool* isNormalized) WI_PFN_NOEXCEPT = nullptr; __declspec(selectany) _Always_(_Post_satisfies_(return < 0)) HRESULT(__stdcall *g_pfnResultFromKnownExceptions_WinRt)(const DiagnosticsInfo& diagnostics, void* returnAddress, SupportedExceptions supported, IFunctor& functor) = nullptr; // Plugin to call RoOriginateError (WIL use only) - __declspec(selectany) void(__stdcall *g_pfnOriginateCallback)(wil::FailureInfo const& failure) WI_NOEXCEPT = nullptr; + __declspec(selectany) void(__stdcall *g_pfnOriginateCallback)(wil::FailureInfo const& failure) WI_PFN_NOEXCEPT = nullptr; enum class ReportFailureOptions { @@ -1202,7 +1223,7 @@ namespace wil TFunctor&& functor; functor_wrapper_void(TFunctor&& functor_) : functor(wistd::forward(functor_)) { } #pragma warning(push) - #pragma warning(disable:4702) /* https://microsoft.visualstudio.com/OS/_workitems?id=15917057&fullScreen=false&_a=edit */ + #pragma warning(disable:4702) // https://github.com/Microsoft/wil/issues/2 HRESULT Run() override { functor(); @@ -1229,7 +1250,7 @@ namespace wil TReturn& retVal; functor_wrapper_other(TFunctor& functor_, TReturn& retval_) : functor(wistd::forward(functor_)), retVal(retval_) { } #pragma warning(push) - #pragma warning(disable:4702) /* https://microsoft.visualstudio.com/OS/_workitems?id=15917057&fullScreen=false&_a=edit */ + #pragma warning(disable:4702) // https://github.com/Microsoft/wil/issues/2 HRESULT Run() override { retVal = functor(); @@ -1316,8 +1337,8 @@ namespace wil //***************************************************************************** /// @cond - #define __FAIL_FAST_ASSERT__(condition) do { if (!(condition)) { __RFF_FN(FailFast_Unexpected)(__RFF_INFO_ONLY(#condition)); } } while (0, 0) - #define __FAIL_FAST_IMMEDIATE_ASSERT__(condition) do { if (!(condition)) { wil::FailureInfo failure {}; wil::details::WilFailFast(failure); } } while (0, 0) + #define __FAIL_FAST_ASSERT__(condition) do { if (!(condition)) { __RFF_FN(FailFast_Unexpected)(__RFF_INFO_ONLY(#condition)); } } while ((void)0, 0) + #define __FAIL_FAST_IMMEDIATE_ASSERT__(condition) do { if (!(condition)) { wil::FailureInfo failure {}; wil::details::WilFailFast(failure); } } while ((void)0, 0) #define __FAIL_FAST_ASSERT_WIN32_BOOL_FALSE__(condition) __RFF_FN(FailFast_IfWin32BoolFalse)(__RFF_INFO(#condition) wil::verify_BOOL(condition)) // A simple ref-counted buffer class. The interface is very similar to shared_ptr<>, only it manages @@ -1719,7 +1740,7 @@ namespace wil // All successful status codes have only one hresult equivalent, S_OK return S_OK; } - if (status == STATUS_NO_MEMORY) + if (status == static_cast(STATUS_NO_MEMORY)) { // RtlNtStatusToDosErrorNoTeb maps STATUS_NO_MEMORY to the less popular of two Win32 no memory error codes resulting in an unexpected mapping return E_OUTOFMEMORY; @@ -1843,9 +1864,7 @@ namespace wil } #pragma warning(pop) - #pragma warning(push) - #pragma warning(disable : 4100) // Unused parameter (pszDest) - _Post_satisfies_(cchDest > 0 && cchDest <= cchMax) static STRSAFEAPI WilStringValidateDestA(_In_reads_opt_(cchDest) STRSAFE_PCNZCH pszDest, _In_ size_t cchDest, _In_ const size_t cchMax) + _Post_satisfies_(cchDest > 0 && cchDest <= cchMax) static STRSAFEAPI WilStringValidateDestA(_In_reads_opt_(cchDest) STRSAFE_PCNZCH /*pszDest*/, _In_ size_t cchDest, _In_ const size_t cchMax) { HRESULT hr = S_OK; if ((cchDest == 0) || (cchDest > cchMax)) @@ -1854,7 +1873,6 @@ namespace wil } return hr; } - #pragma warning(pop) static STRSAFEAPI WilStringVPrintfWorkerA(_Out_writes_(cchDest) _Always_(_Post_z_) STRSAFE_LPSTR pszDest, _In_ _In_range_(1, STRSAFE_MAX_CCH) size_t cchDest, _Always_(_Out_opt_ _Deref_out_range_(<=, cchDest - 1)) size_t* pcchNewDestLength, _In_ _Printf_format_string_ STRSAFE_LPCSTR pszFormat, _In_ va_list argList) { @@ -1909,7 +1927,7 @@ namespace wil return hr; } - STRSAFEAPI StringCchPrintfA( _Out_writes_(cchDest) _Always_(_Post_z_) STRSAFE_LPSTR pszDest, _In_ size_t cchDest, _In_ _Printf_format_string_ STRSAFE_LPCSTR pszFormat, ...) + __inline HRESULT StringCchPrintfA( _Out_writes_(cchDest) _Always_(_Post_z_) STRSAFE_LPSTR pszDest, _In_ size_t cchDest, _In_ _Printf_format_string_ STRSAFE_LPCSTR pszFormat, ...) { HRESULT hr; hr = wil::details::WilStringValidateDestA(pszDest, cchDest, STRSAFE_MAX_CCH); @@ -2030,7 +2048,7 @@ namespace wil //! Call this method to determine if process shutdown is in progress (allows avoiding work during dll unload). inline bool ProcessShutdownInProgress() { - return (details::g_processShutdownInProgress || (details::g_pfnRtlDllShutdownInProgress ? details::g_pfnRtlDllShutdownInProgress() : false)); + return (details::g_processShutdownInProgress || (details::g_pfnDllShutdownInProgress ? details::g_pfnDllShutdownInProgress() : false)); } /** Use this object to wrap an object that wants to prevent its destructor from being run when the process is shutting down, @@ -2067,7 +2085,7 @@ namespace wil } private: - unsigned char m_raw[sizeof(T)]; + alignas(T) unsigned char m_raw[sizeof(T)]; }; /** Use this object to wrap an object that wants to prevent its destructor from being run when the process is shutting down. @@ -2123,11 +2141,11 @@ namespace wil } private: - unsigned char m_raw[sizeof(T)]; + alignas(T) unsigned char m_raw[sizeof(T)]; }; /** Forward your DLLMain to this function so that WIL can have visibility into whether a DLL unload is because - of termination or normal unload. Note that when private API usage is enabled, WIL attempts to make this + of termination or normal unload. Note that when g_pfnDllShutdownInProgress is set, WIL attempts to make this determination on its own without this callback. Suppressing private APIs requires use of this. */ inline void DLLMain(HINSTANCE, DWORD reason, _In_opt_ LPVOID reserved) { @@ -3143,8 +3161,7 @@ namespace wil // Returns true if a debugger should be considered to be connected. // Modules can force this on through setting g_fIsDebuggerPresent explicitly (useful for live debugging), // they can provide a callback function by setting g_pfnIsDebuggerPresent (useful for kernel debbugging), - // and finally the user-mode check (IsDebuggerPrsent) is checked. IsDebuggerPresent is a fast call as it - // returns NtCurrentPeb()->BeingDebugged. + // and finally the user-mode check (IsDebuggerPrsent) is checked. IsDebuggerPresent is a fast call inline bool IsDebuggerPresent() { return g_fIsDebuggerPresent || ((g_pfnIsDebuggerPresent != nullptr) ? g_pfnIsDebuggerPresent() : (::IsDebuggerPresent() != FALSE)); @@ -3299,7 +3316,6 @@ namespace wil er.ExceptionInformation[0] = FAST_FAIL_FATAL_APP_EXIT; // see winnt.h, generated from minkernel\published\base\ntrtl_x.w if (failure.returnAddress == 0) // FailureInfo does not have _ReturnAddress, have RaiseFailFastException generate it { - // http://osgvsowi/17364039 - confirm with !analyze team that this is the best we can do in this case // passing ExceptionCode 0xC0000409 and one param with FAST_FAIL_APP_EXIT will use existing // !analyze functionality to crawl the stack looking for the HRESULT // don't pass a 0 HRESULT in param 1 because that will result in worse bucketing. diff --git a/include/wil/result_originate.h b/include/wil/result_originate.h index b40f688..bca7407 100644 --- a/include/wil/result_originate.h +++ b/include/wil/result_originate.h @@ -1,3 +1,14 @@ +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* + // Note: When origination is enabled by including this file, origination is done as part of the RETURN_* and THROW_* macros. Before originating // a new error we will observe whether there is already an error payload associated with the current thread. If there is, and the HRESULTs match, // then a new error will not be originated. Otherwise we will overwrite it with a new origination. The ABI boundary for WinRT APIs will check the diff --git a/include/wil/safecast.h b/include/wil/safecast.h index 88094e8..996d577 100644 --- a/include/wil/safecast.h +++ b/include/wil/safecast.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_SAFECAST_INCLUDED #define __WIL_SAFECAST_INCLUDED @@ -291,9 +300,9 @@ namespace wil typename OldT, wistd::enable_if_t, int> = 0 > - NewT safe_cast_nothrow(const OldT var) + NewT safe_cast_nothrow(const OldT /*var*/) { - static_assert(false, "This cast has the potential to fail, use the two parameter safe_cast_nothrow instead"); + static_assert(!wistd::is_same_v, "This cast has the potential to fail, use the two parameter safe_cast_nothrow instead"); } // This conversion is always safe, therefore a static_cast is fine. diff --git a/include/wil/stl.h b/include/wil/stl.h index b7c19d0..56e95f4 100644 --- a/include/wil/stl.h +++ b/include/wil/stl.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_STL_INCLUDED #define __WIL_STL_INCLUDED diff --git a/include/wil/token_helpers.h b/include/wil/token_helpers.h index 45c97f2..87ce582 100644 --- a/include/wil/token_helpers.h +++ b/include/wil/token_helpers.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_TOKEN_HELPERS_INCLUDED #define __WIL_TOKEN_HELPERS_INCLUDED diff --git a/include/wil/win32_helpers.h b/include/wil/win32_helpers.h index b573357..9eb8717 100644 --- a/include/wil/win32_helpers.h +++ b/include/wil/win32_helpers.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_WIN32_HELPERS_INCLUDED #define __WIL_WIN32_HELPERS_INCLUDED @@ -158,7 +167,7 @@ namespace wil // This function does not work beyond the default stack buffer size (255). // Needs to to retry in a loop similar to wil::GetModuleFileNameExW - // These updates and unit tests are tracked by //osgvsowi/14483919 + // These updates and unit tests are tracked by https://github.com/Microsoft/wil/issues/3 template HRESULT QueryFullProcessImageNameW(HANDLE processHandle, _In_ DWORD flags, string_type& result) WI_NOEXCEPT { @@ -401,7 +410,7 @@ namespace wil } #pragma warning(push) - #pragma warning(disable:4702) /* https://microsoft.visualstudio.com/OS/_workitems?id=15917057&fullScreen=false&_a=edit */ + #pragma warning(disable:4702) // https://github.com/Microsoft/wil/issues/2 void success() { m_flags = 0; diff --git a/include/wil/winrt.h b/include/wil/winrt.h index 68271c2..1766f6a 100644 --- a/include/wil/winrt.h +++ b/include/wil/winrt.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_WINRT_INCLUDED #define __WIL_WINRT_INCLUDED @@ -75,68 +84,6 @@ namespace wil template struct hstring_compare { - template - static auto compare(LhsT&& lhs, RhsT&& rhs) -> - decltype(get_buffer(lhs, wistd::declval()), get_buffer(rhs, wistd::declval()), int()) - { - UINT32 lhsLength; - UINT32 rhsLength; - auto lhsBuffer = get_buffer(wistd::forward(lhs), &lhsLength); - auto rhsBuffer = get_buffer(wistd::forward(rhs), &rhsLength); - - const auto result = ::CompareStringOrdinal( - lhsBuffer, - lhsLength, - rhsBuffer, - rhsLength, - IgnoreCase ? TRUE : FALSE); - NT_ASSERT(result != 0); - - return result; - } - - template - static auto equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_EQUAL; - } - - template - static auto not_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_EQUAL; - } - - template - static auto less(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_LESS_THAN; - } - - template - static auto less_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_GREATER_THAN; - } - - template - static auto greater(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_GREATER_THAN; - } - - template - static auto greater_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> - decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) - { - return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_LESS_THAN; - } - // get_buffer returns the string buffer and length for the supported string types static const wchar_t* get_buffer(HSTRING hstr, UINT32* length) WI_NOEXCEPT { @@ -208,6 +155,68 @@ namespace wil return str.c_str(); } #endif + + template + static auto compare(LhsT&& lhs, RhsT&& rhs) -> + decltype(get_buffer(lhs, wistd::declval()), get_buffer(rhs, wistd::declval()), int()) + { + UINT32 lhsLength; + UINT32 rhsLength; + auto lhsBuffer = get_buffer(wistd::forward(lhs), &lhsLength); + auto rhsBuffer = get_buffer(wistd::forward(rhs), &rhsLength); + + const auto result = ::CompareStringOrdinal( + lhsBuffer, + lhsLength, + rhsBuffer, + rhsLength, + IgnoreCase ? TRUE : FALSE); + WI_ASSERT(result != 0); + + return result; + } + + template + static auto equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_EQUAL; + } + + template + static auto not_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_EQUAL; + } + + template + static auto less(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_LESS_THAN; + } + + template + static auto less_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_GREATER_THAN; + } + + template + static auto greater(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) == CSTR_GREATER_THAN; + } + + template + static auto greater_equals(LhsT&& lhs, RhsT&& rhs) WI_NOEXCEPT -> + decltype(compare(wistd::forward(lhs), wistd::forward(rhs)), bool()) + { + return compare(wistd::forward(lhs), wistd::forward(rhs)) != CSTR_LESS_THAN; + } }; } /// @endcond @@ -426,7 +435,7 @@ namespace wil template struct MapToSmartType { #pragma warning(push) - #pragma warning(disable:4702) /* https://microsoft.visualstudio.com/OS/_workitems?id=15917057&fullScreen=false&_a=edit */ + #pragma warning(disable:4702) // https://github.com/Microsoft/wil/issues/2 struct type // T holder { type() {}; @@ -439,6 +448,7 @@ namespace wil // In case of absense of T::operator=(T&&) a call to T::operator=(const T&) will happen T&& Get() { return wistd::move(m_value); } + HRESULT CopyTo(T* result) const { *result = m_value; return S_OK; } T* GetAddressOf() { return &m_value; } T* ReleaseAndGetAddressOf() { return &m_value; } T* operator&() { return &m_value; } @@ -727,8 +737,8 @@ namespace wil vector_range_nothrow& operator=(const vector_range_nothrow&) = delete; vector_range_nothrow(vector_range_nothrow&& other) : - m_v(other.m_v), m_result(other.m_result), m_resultStorage(other.m_resultStorage), - m_size(other.m_size), m_currentElement(wistd::move(other.m_currentElement)) + m_v(other.m_v), m_size(other.m_size), m_result(other.m_result), m_resultStorage(other.m_resultStorage), + m_currentElement(wistd::move(other.m_currentElement)) { } @@ -897,7 +907,7 @@ namespace wil } // for end() - iterable_iterator(int currentIndex) : m_i(-1) + iterable_iterator(int /*currentIndex*/) : m_i(-1) { } @@ -1244,6 +1254,7 @@ namespace details return wistd::forward(func)(wistd::forward(args)...); } +#ifdef WIL_ENABLE_EXCEPTIONS template ::value, int>::type = 0> HRESULT CallAndHandleErrorsWithReturnType(TFunc&& func, Args&&... args) @@ -1255,6 +1266,7 @@ namespace details CATCH_RETURN(); return S_OK; } +#endif template HRESULT CallAndHandleErrors(TFunc&& func, Args&&... args) @@ -2041,8 +2053,8 @@ namespace details #define WI_MakeUniqueWinRtEventToken(_event, _object, _handler) \ wil::details::make_unique_winrt_event_token( \ _object, \ - &wistd::remove_pointer::type::add_##_event, \ - &wistd::remove_pointer::type::remove_##_event, \ + &wistd::remove_pointer::type::add_##_event, \ + &wistd::remove_pointer::type::remove_##_event, \ _handler) #endif // WIL_ENABLE_EXCEPTIONS @@ -2050,16 +2062,16 @@ namespace details #define WI_MakeUniqueWinRtEventTokenNoThrow(_event, _object, _handler, _token_reference) \ wil::details::make_unique_winrt_event_token( \ _object, \ - &wistd::remove_pointer::type::add_##_event, \ - &wistd::remove_pointer::type::remove_##_event, \ + &wistd::remove_pointer::type::add_##_event, \ + &wistd::remove_pointer::type::remove_##_event, \ _handler, \ _token_reference) #define WI_MakeUniqueWinRtEventTokenFailFast(_event, _object, _handler) \ wil::details::make_unique_winrt_event_token( \ _object, \ - &wistd::remove_pointer::type::add_##_event, \ - &wistd::remove_pointer::type::remove_##_event, \ + &wistd::remove_pointer::type::add_##_event, \ + &wistd::remove_pointer::type::remove_##_event, \ _handler) #pragma endregion // EventRegistrationToken RAII wrapper diff --git a/include/wil/wistd_config.h b/include/wil/wistd_config.h index 0348a59..30b0ea9 100644 --- a/include/wil/wistd_config.h +++ b/include/wil/wistd_config.h @@ -1,3 +1,13 @@ +// -*- C++ -*- +//===--------------------------- __config ---------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source Licenses. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + // STL common functionality // // Some aspects of STL are core language concepts that should be used from all C++ code, regardless diff --git a/include/wil/wistd_functional.h b/include/wil/wistd_functional.h index f2b8210..e6bb087 100644 --- a/include/wil/wistd_functional.h +++ b/include/wil/wistd_functional.h @@ -37,6 +37,7 @@ // DO NOT add *any* additional includes to this file -- there should be no dependencies from its usage #include "wistd_memory.h" +#include // For __fastfail #if !defined(__WI_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER) #pragma GCC system_header @@ -128,7 +129,7 @@ namespace wistd // ("Windows Implementation" std) __WI_LIBCPP_NORETURN inline __WI_LIBCPP_INLINE_VISIBILITY void __throw_bad_function_call() { - __fastfail(FAST_FAIL_FATAL_APP_EXIT); + __fastfail(7); // FAST_FAIL_FATAL_APP_EXIT } template class __WI_LIBCPP_TEMPLATE_VIS function; // undefined @@ -191,7 +192,7 @@ namespace wistd // ("Windows Implementation" std) __WI_LIBCPP_INLINE_VISIBILITY __base() {} __WI_LIBCPP_INLINE_VISIBILITY virtual ~__base() {} virtual void __clone(__base*) const = 0; - virtual void __move(__base*) WI_NOEXCEPT = 0; + virtual void __move(__base*) = 0; virtual void destroy() WI_NOEXCEPT = 0; virtual _Rp operator()(_ArgTypes&& ...) = 0; }; @@ -213,7 +214,7 @@ namespace wistd // ("Windows Implementation" std) : __f_(__f) {} virtual void __clone(__base<_Rp(_ArgTypes...)>*) const; - virtual void __move(__base<_Rp(_ArgTypes...)>*) WI_NOEXCEPT; + virtual void __move(__base<_Rp(_ArgTypes...)>*); virtual void destroy() WI_NOEXCEPT; virtual _Rp operator()(_ArgTypes&& ... __arg); }; @@ -227,7 +228,7 @@ namespace wistd // ("Windows Implementation" std) template void - __func<_Fp, _Rp(_ArgTypes...)>::__move(__base<_Rp(_ArgTypes...)>* __p) WI_NOEXCEPT + __func<_Fp, _Rp(_ArgTypes...)>::__move(__base<_Rp(_ArgTypes...)>* __p) { ::new (__p) __func(wistd::move(__f_)); } @@ -301,12 +302,12 @@ namespace wistd // ("Windows Implementation" std) __WI_LIBCPP_INLINE_VISIBILITY function(nullptr_t) WI_NOEXCEPT : __f_(0) {} function(const function&); - function(function&&) WI_NOEXCEPT; + function(function&&); template> function(_Fp); function& operator=(const function&); - function& operator=(function&&) WI_NOEXCEPT; + function& operator=(function&&); function& operator=(nullptr_t) WI_NOEXCEPT; template> function& operator=(_Fp&&); @@ -314,7 +315,7 @@ namespace wistd // ("Windows Implementation" std) ~function(); // function modifiers: - void swap(function&) WI_NOEXCEPT; + void swap(function&); // function capacity: __WI_LIBCPP_INLINE_VISIBILITY @@ -346,7 +347,7 @@ namespace wistd // ("Windows Implementation" std) } template - function<_Rp(_ArgTypes...)>::function(function&& __f) WI_NOEXCEPT + function<_Rp(_ArgTypes...)>::function(function&& __f) { if (__f.__f_ == 0) __f_ = 0; @@ -388,7 +389,7 @@ namespace wistd // ("Windows Implementation" std) template function<_Rp(_ArgTypes...)>& - function<_Rp(_ArgTypes...)>::operator=(function&& __f) WI_NOEXCEPT + function<_Rp(_ArgTypes...)>::operator=(function&& __f) { *this = nullptr; if (__f.__f_) @@ -438,7 +439,7 @@ namespace wistd // ("Windows Implementation" std) template void - function<_Rp(_ArgTypes...)>::swap(function& __f) WI_NOEXCEPT + function<_Rp(_ArgTypes...)>::swap(function& __f) { if (wistd::addressof(__f) == this) return; @@ -506,13 +507,13 @@ namespace wistd // ("Windows Implementation" std) template inline __WI_LIBCPP_INLINE_VISIBILITY void - swap(function<_Rp(_ArgTypes...)>& __x, function<_Rp(_ArgTypes...)>& __y) WI_NOEXCEPT + swap(function<_Rp(_ArgTypes...)>& __x, function<_Rp(_ArgTypes...)>& __y) {return __x.swap(__y);} template inline __WI_LIBCPP_INLINE_VISIBILITY void - swap_wil(function<_Rp(_ArgTypes...)>& __x, function<_Rp(_ArgTypes...)>& __y) WI_NOEXCEPT + swap_wil(function<_Rp(_ArgTypes...)>& __x, function<_Rp(_ArgTypes...)>& __y) {return __x.swap(__y);} #else // __WI_LIBCPP_CXX03_LANG diff --git a/include/wil/wrl.h b/include/wil/wrl.h index 3ac54e5..e7f7c6c 100644 --- a/include/wil/wrl.h +++ b/include/wil/wrl.h @@ -1,4 +1,13 @@ - +//********************************************************* +// +// Copyright (c) Microsoft. All rights reserved. +// This code is licensed under the MIT License. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +// PARTICULAR PURPOSE AND NONINFRINGEMENT. +// +//********************************************************* #ifndef __WIL_WRL_INCLUDED #define __WIL_WRL_INCLUDED diff --git a/packaging/CMakeLists.txt b/packaging/CMakeLists.txt new file mode 100644 index 0000000..1f56884 --- /dev/null +++ b/packaging/CMakeLists.txt @@ -0,0 +1,2 @@ + +add_subdirectory(nuget) diff --git a/packaging/nuget/CMakeLists.txt b/packaging/nuget/CMakeLists.txt new file mode 100644 index 0000000..f3fd938 --- /dev/null +++ b/packaging/nuget/CMakeLists.txt @@ -0,0 +1,20 @@ + +file(TO_NATIVE_PATH "${CMAKE_BINARY_DIR}/build_tools/nuget.exe" nuget_exe) +file(TO_NATIVE_PATH "${CMAKE_CURRENT_BINARY_DIR}" nupkg_dir) +file(TO_NATIVE_PATH "${nupkg_dir}/Microsoft.Windows.ImplementationLibrary.${WIL_BUILD_VERSION}.nupkg" wil_nupkg) + +# The build servers don't have an up-to-date version of nuget, so pull it down ourselves... +file(DOWNLOAD https://dist.nuget.org/win-x86-commandline/latest/nuget.exe ${nuget_exe}) + +file(GLOB_RECURSE wil_headers ${CMAKE_SOURCE_DIR}/include/*.h) + +add_custom_command(OUTPUT ${wil_nupkg} + COMMAND ${nuget_exe} pack ${CMAKE_CURRENT_SOURCE_DIR}/Microsoft.Windows.ImplementationLibrary.nuspec -OutputDirectory ${nupkg_dir} -Version ${WIL_BUILD_VERSION} -NonInteractive + DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/Microsoft.Windows.ImplementationLibrary.nuspec + ${CMAKE_CURRENT_SOURCE_DIR}/Microsoft.Windows.ImplementationLibrary.targets + ${wil_headers} + ${CMAKE_SOURCE_DIR}/LICENSE + ${CMAKE_SOURCE_DIR}/ThirdPartyNotices.txt) + +add_custom_target(make_wil_nupkg DEPENDS ${wil_nupkg}) diff --git a/packaging/nuget/Microsoft.Windows.ImplementationLibrary.nuspec b/packaging/nuget/Microsoft.Windows.ImplementationLibrary.nuspec new file mode 100644 index 0000000..be709b9 --- /dev/null +++ b/packaging/nuget/Microsoft.Windows.ImplementationLibrary.nuspec @@ -0,0 +1,21 @@ + + + + Microsoft.Windows.ImplementationLibrary + 0.0.0 + Windows Implementation Library + Microsoft + false + The Windows Implementation Libraries (wil) were created to improve productivity and solve problems commonly seen by Windows developers. + windows utility wil native + © Microsoft Corporation. All rights reserved. + LICENSE + https://github.com/Microsoft/wil + + + + + + + + \ No newline at end of file diff --git a/packaging/nuget/Microsoft.Windows.ImplementationLibrary.targets b/packaging/nuget/Microsoft.Windows.ImplementationLibrary.targets new file mode 100644 index 0000000..29d756a --- /dev/null +++ b/packaging/nuget/Microsoft.Windows.ImplementationLibrary.targets @@ -0,0 +1,8 @@ + + + + + $(MSBuildThisFileDirectory)..\..\include\;%(AdditionalIncludeDirectories) + + + diff --git a/scripts/build_all.cmd b/scripts/build_all.cmd new file mode 100644 index 0000000..32d88f4 --- /dev/null +++ b/scripts/build_all.cmd @@ -0,0 +1,51 @@ +@echo off +setlocal EnableDelayedExpansion + +set BUILD_ROOT=%~dp0\..\build + +if "%Platform%"=="x64" ( + set BUILD_ARCH=64 +) else if "%Platform%"=="x86" ( + set BUILD_ARCH=32 +) else if [%Platform%]==[] ( + echo ERROR: The build_all.cmd script must be run from a Visual Studio command window + exit /B 1 +) else ( + echo ERROR: Unrecognized/unsupported platform %Platform% + exit /B 1 +) + +call :build clang debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build clang release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build clang relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build clang minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +call :build msvc debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build msvc release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build msvc relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :build msvc minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +echo All build completed successfully! + +goto :eof + +:: build [compiler] [type] +:build +set BUILD_DIR=%BUILD_ROOT%\%1%BUILD_ARCH%%2 +if not exist %BUILD_DIR% ( + goto :eof +) + +pushd %BUILD_DIR% +echo Building from %CD% +ninja +popd +goto :eof diff --git a/scripts/init.cmd b/scripts/init.cmd new file mode 100644 index 0000000..3f83b74 --- /dev/null +++ b/scripts/init.cmd @@ -0,0 +1,154 @@ +@echo off +setlocal +setlocal EnableDelayedExpansion + +:: Globals +set BUILD_ROOT=%~dp0\..\build + +goto :init + +:usage + echo USAGE: + echo init.cmd [--help] [-c^|--compiler ^] [-g^|--generator ^] [--fast] + echo [-b^|--build-type ^] [-v^|--version X.Y.Z] + goto :eof + +:init + :: Initialize values as empty so that we can identify if we are using defaults later for error purposes + set COMPILER= + set GENERATOR= + set BUILD_TYPE= + set CMAKE_ARGS= + set BITNESS= + set VERSION= + set FAST_BUILD=0 + +:parse + if /I "%~1"=="" goto :execute + + if /I "%~1"=="--help" call :usage & goto :eof + + set COMPILER_SET=0 + if /I "%~1"=="-c" set COMPILER_SET=1 + if /I "%~1"=="--compiler" set COMPILER_SET=1 + if %COMPILER_SET%==1 ( + if "%COMPILER%" NEQ "" echo ERROR: Compiler already specified & exit /B 1 + + if /I "%~2"=="clang" set COMPILER=clang + if /I "%~2"=="msvc" set COMPILER=msvc + if "!COMPILER!"=="" echo ERROR: Unrecognized/missing compiler %~2 & exit /B 1 + + shift + shift + goto :parse + ) + + set GENERATOR_SET=0 + if /I "%~1"=="-g" set GENERATOR_SET=1 + if /I "%~1"=="--generator" set GENERATOR_SET=1 + if %GENERATOR_SET%==1 ( + if "%GENERATOR%" NEQ "" echo ERROR: Generator already specified & exit /B 1 + + if /I "%~2"=="ninja" set GENERATOR=ninja + if /I "%~2"=="msbuild" set GENERATOR=msbuild + if "!GENERATOR!"=="" echo ERROR: Unrecognized/missing generator %~2 & exit /B 1 + + shift + shift + goto :parse + ) + + set BUILD_TYPE_SET=0 + if /I "%~1"=="-b" set BUILD_TYPE_SET=1 + if /I "%~1"=="--build-type" set BUILD_TYPE_SET=1 + if %BUILD_TYPE_SET%==1 ( + if "%BUILD_TYPE%" NEQ "" echo ERROR: Build type already specified & exit /B 1 + + if /I "%~2"=="debug" set BUILD_TYPE=debug + if /I "%~2"=="release" set BUILD_TYPE=release + if /I "%~2"=="relwithdebinfo" set BUILD_TYPE=relwithdebinfo + if /I "%~2"=="minsizerel" set BUILD_TYPE=minsizerel + if "!BUILD_TYPE!"=="" echo ERROR: Unrecognized/missing build type %~2 & exit /B 1 + + shift + shift + goto :parse + ) + + set VERSION_SET=0 + if /I "%~1"=="-v" set VERSION_SET=1 + if /I "%~1"=="--version" set VERSION_SET=1 + if %VERSION_SET%==1 ( + if "%VERSION%" NEQ "" echo ERROR: Version alread specified & exit /B 1 + if /I "%~2"=="" echo ERROR: Version string missing & exit /B 1 + + set VERSION=%~2 + + shift + shift + goto :parse + ) + + if /I "%~1"=="--fast" ( + if %FAST_BUILD% NEQ 0 echo ERROR: Fast build already specified + set FAST_BUILD=1 + shift + goto :parse + ) + + echo ERROR: Unrecognized argument %~1 + exit /B 1 + +:execute + :: Check for conflicting arguments + if "%COMPILER%"=="clang" ( + if "%GENERATOR%"=="msbuild" echo ERROR: Cannot use Clang with MSBuild & exit /B 1 + ) + + :: Select defaults + if "%GENERATOR%"=="" set GENERATOR=ninja + if %GENERATOR%==msbuild set COMPILER=msvc + + if "%COMPILER%"=="" set COMPILER=clang + + if "%BUILD_TYPE%"=="" set BUILD_TYPE=debug + + :: Formulate CMake arguments + if %GENERATOR%==ninja set CMAKE_ARGS=%CMAKE_ARGS% -G Ninja + + if %COMPILER%==clang set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl + if %COMPILER%==msvc set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=cl + + if %BUILD_TYPE%==debug set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_BUILD_TYPE=Debug + if %BUILD_TYPE%==release set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_BUILD_TYPE=Release + if %BUILD_TYPE%==relwithdebinfo set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_BUILD_TYPE=RelWithDebInfo + if %BUILD_TYPE%==minsizerel set CMAKE_ARGS=%CMAKE_ARGS% -DCMAKE_BUILD_TYPE=MinSizeRel + + if "%VERSION%" NEQ "" set CMAKE_ARGS=%CMAKE_ARGS% -DWIL_BUILD_VERSION=%VERSION% + + if %FAST_BUILD%==1 set CMAKE_ARGS=%CMAKE_ARGS% -DFAST_BUILD=ON + + :: Figure out the platform + if "%Platform%"=="" echo ERROR: The init.cmd script must be run from a Visual Studio command window & exit /B 1 + if "%Platform%"=="x86" ( + set BITNESS=32 + if %COMPILER%==clang set CFLAGS=-m32 & set CXXFLAGS=-m32 + ) + if "%Platform%"=="x64" set BITNESS=64 + if "%BITNESS%"=="" echo ERROR: Unrecognized/unsupported platform %Platform% & exit /B 1 + + :: Set up the build directory + set BUILD_DIR=%BUILD_ROOT%\%COMPILER%%BITNESS%%BUILD_TYPE% + mkdir %BUILD_DIR% > NUL 2>&1 + + :: Run CMake + pushd %BUILD_DIR% + echo Using compiler....... %COMPILER% + echo Using architecture... %Platform% + echo Using build type..... %BUILD_TYPE% + echo Using build root..... %CD% + echo. + cmake %CMAKE_ARGS% ..\.. + popd + + goto :eof diff --git a/scripts/init_all.cmd b/scripts/init_all.cmd new file mode 100644 index 0000000..f068eda --- /dev/null +++ b/scripts/init_all.cmd @@ -0,0 +1,17 @@ +@echo off + +:: NOTE: Architecture is picked up from the command window, so we can't control that here :( + +:: TODO: https://github.com/Microsoft/wil/issues/7 - There's currently a bug where Clang and/or the linker chokes when +:: trying to compile the tests for 32-bit debug, so skip for now +if "%Platform%"=="x86" goto :skip_clang_x86_debug +call %~dp0\init.cmd -c clang -g ninja -b debug %* +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +:skip_clang_x86_debug +call %~dp0\init.cmd -c clang -g ninja -b relwithdebinfo %* +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +call %~dp0\init.cmd -c msvc -g ninja -b debug %* +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call %~dp0\init.cmd -c msvc -g ninja -b relwithdebinfo %* +if %ERRORLEVEL% NEQ 0 ( goto :eof ) diff --git a/scripts/runtests.cmd b/scripts/runtests.cmd new file mode 100644 index 0000000..e9282fb --- /dev/null +++ b/scripts/runtests.cmd @@ -0,0 +1,67 @@ +@echo off +setlocal EnableDelayedExpansion + +set BUILD_ROOT=%~dp0\..\build + +:: Unlike building, we don't need to limit ourselves to the Platform of the command window +call :execute_tests clang64debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang64release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang64relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang64minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +call :execute_tests clang32debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang32release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang32relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests clang32minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +call :execute_tests msvc64debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc64release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc64relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc64minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +call :execute_tests msvc32debug +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc32release +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc32relwithdebinfo +if %ERRORLEVEL% NEQ 0 ( goto :eof ) +call :execute_tests msvc32minsizerel +if %ERRORLEVEL% NEQ 0 ( goto :eof ) + +goto :eof + +:execute_tests +set BUILD_DIR=%BUILD_ROOT%\%1 +if not exist %BUILD_DIR% ( goto :eof ) + +pushd %BUILD_DIR% +echo Running tests from %CD% +call :execute_test app witest.app.exe +if %ERRORLEVEL% NEQ 0 ( popd && goto :eof ) +call :execute_test cpplatest witest.cpplatest.exe +if %ERRORLEVEL% NEQ 0 ( popd && goto :eof ) +call :execute_test noexcept witest.noexcept.exe +if %ERRORLEVEL% NEQ 0 ( popd && goto :eof ) +call :execute_test normal witest.exe +if %ERRORLEVEL% NEQ 0 ( popd && goto :eof ) +popd + +goto :eof + +:execute_test +if not exist tests\%1\%2 ( goto :eof ) +echo Running %1 tests... +tests\%1\%2 +goto :eof diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000..a9c0fae --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,19 @@ + +include(${CMAKE_SOURCE_DIR}/cmake/common_build_flags.cmake) + +# All projects need to reference the WIL headers +include_directories(${CMAKE_SOURCE_DIR}/include) + +# TODO: Might be worth trying to conditionally do this on SDK version, assuming there's a semi-easy way to detect that +include_directories(BEFORE SYSTEM ./workarounds/wrl) + +# The build pipelines have limitations that local development environments do not, so turn a few knobs +if (${FAST_BUILD}) + replace_cxx_flag("/GR" "/GR-") # Disables RTTI + add_definitions(-DCATCH_CONFIG_FAST_COMPILE -DWIL_FAST_BUILD) +endif() + +add_subdirectory(app) +add_subdirectory(cpplatest) +add_subdirectory(noexcept) +add_subdirectory(normal) diff --git a/tests/ComTests.cpp b/tests/ComTests.cpp new file mode 100644 index 0000000..87d3552 --- /dev/null +++ b/tests/ComTests.cpp @@ -0,0 +1,2676 @@ + +#include // Bring in IObjectWithSite + +#include +#include + +#include "common.h" + +using namespace Microsoft::WRL; + +// avoid including #include , it fails to compile in noprivateapis +class DECLSPEC_UUID("00021401-0000-0000-C000-000000000046") ShellLink; + +// Uncomment this line to do a more exhaustive test of the concepts covered by this file. By +// default we don't fully compile every combination of tests as this test can substantially impact +// build times with template expansion. + +// #define WIL_EXHAUSTIVE_TEST + +// Helper objects / functions +class __declspec(uuid("a817e7a2-43fa-11d0-9e44-00aa00b6770a")) +IUnknownFake : public IUnknown +{ +public: + STDMETHOD_(ULONG, AddRef)() + { + AddRefCounter++; + return 0; + } + STDMETHOD_(ULONG, Release)() + { + ReleaseCounter++; + return 0; + } + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) + { + if (riid == __uuidof(IUnknown)) + { + *ppvObject = this; + return S_OK; + } + *ppvObject = nullptr; + return E_NOINTERFACE; + } + bool ReturnTRUE() + { + return true; + } + static void Clear() + { + AddRefCounter = 0; + ReleaseCounter = 0; + } + static int GetAddRef() + { + int res = AddRefCounter; + AddRefCounter = 0; + return res; + } + static int GetRelease() + { + int res = ReleaseCounter; + ReleaseCounter = 0; + return res; + } +protected: + static int AddRefCounter; + static int ReleaseCounter; +}; + +int IUnknownFake::AddRefCounter = 0; +int IUnknownFake::ReleaseCounter = 0; + +class __declspec(uuid("a817e7a2-43fa-11d0-9e44-00aa00b6770b")) +IUnknownFake2 : public IUnknownFake {}; + +TEST_CASE("ComTests::Test_Constructors", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + + SECTION("Null/default construction") + { + wil::com_ptr_nothrow ptr; //default constructor + REQUIRE(ptr.get() == nullptr); + + wil::com_ptr_nothrow ptr2(nullptr); //default explicit null constructor + REQUIRE(ptr2.get() == nullptr); + + IUnknown* nullPtr = nullptr; + wil::com_ptr_nothrow ptr3(nullPtr); + REQUIRE(ptr3.get() == nullptr); + } + + SECTION("Valid pointer construction") + { + wil::com_ptr_nothrow ptr(&helper); // explicit + REQUIRE(IUnknownFake::GetAddRef() == 1); + REQUIRE(ptr.get() == &helper); + } + + SECTION("Copy construction") + { + wil::com_ptr_nothrow ptr(&helper); + wil::com_ptr_nothrow ptrCopy(ptr); // assign the same pointer + REQUIRE(IUnknownFake::GetAddRef() == 2); + REQUIRE(ptrCopy.get() == ptr.get()); + + IUnknownFake2 helper2; + wil::com_ptr_nothrow ptr2(&helper2); + wil::com_ptr_nothrow ptrCopy2(ptr2); + REQUIRE(IUnknownFake::GetAddRef() == 2); + REQUIRE(ptrCopy2.get() == &helper2); + } + + SECTION("Move construction") + { + IUnknownFake helper3; + wil::com_ptr_nothrow ptr(&helper3); + wil::com_ptr_nothrow ptrMove(reinterpret_cast&&>(ptr)); + REQUIRE(IUnknownFake::GetAddRef() == 1); + REQUIRE(ptrMove.get() == &helper3); + REQUIRE(ptr.get() == nullptr); + + IUnknownFake2 helper4; + wil::com_ptr_nothrow ptr2(&helper4); + wil::com_ptr_nothrow ptrMove2(reinterpret_cast&&>(ptr2)); + REQUIRE(IUnknownFake::GetAddRef() == 1); + REQUIRE(ptrMove2.get() == &helper4); + REQUIRE(ptr2.get() == nullptr); + } +} + +TEST_CASE("ComTests::Test_Assign", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + + SECTION("Null pointer assignment") + { + wil::com_ptr_nothrow ptr(&helper); + ptr = nullptr; + REQUIRE(ptr.get() == nullptr); + REQUIRE(IUnknownFake::GetRelease() == 1); + } + + IUnknownFake::Clear(); + IUnknownFake helper2; + + SECTION("Different pointer assignment") + { + wil::com_ptr_nothrow ptr(&helper); + wil::com_ptr_nothrow ptr2(&helper2); + + ptr = static_cast&>(ptr2); + + REQUIRE(ptr.get() == &helper2); + REQUIRE(ptr2.get() == &helper2); + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 3); + } + + SECTION("Self assignment") + { + wil::com_ptr_nothrow ptr(&helper); + + IUnknownFake::Clear(); + ptr = ptr; + + REQUIRE(ptr.get() == &helper); + // wil::com_ptr can do self-assignment without blowing up -- and chooses NOT to preserve the this comparison for performance + // as this should be a rare/never operation... + // REQUIRE(IUnknownFake::GetRelease() == 0); + // REQUIRE(IUnknownFake::GetAddRef() == 0); + } + + IUnknownFake2 helper3; + + SECTION("Assign pointer with different interface") + { + wil::com_ptr_nothrow ptr(&helper); + wil::com_ptr_nothrow ptr2(&helper3); + + IUnknownFake::Clear(); + ptr = static_cast&>(ptr2); + + REQUIRE(ptr.get() == &helper3); + REQUIRE(ptr2.get() == &helper3); + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 1); + } + + SECTION("Move assignment") + { + wil::com_ptr_nothrow ptr(&helper); + wil::com_ptr_nothrow ptr2(&helper2); + + IUnknownFake::Clear(); + ptr = static_cast&&>(ptr2); + + REQUIRE(ptr.get() == &helper2); + REQUIRE(ptr2.get() == nullptr); + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 0); + } + + SECTION("Move assign with different interface") + { + wil::com_ptr_nothrow ptr(&helper); + wil::com_ptr_nothrow ptr2(&helper3); + + IUnknownFake::Clear(); + ptr = static_cast&&>(ptr2); + + REQUIRE(ptr.get() == &helper3); + REQUIRE(ptr2.get() == nullptr); + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 0); + } +} + +TEST_CASE("ComTests::Test_Operators", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + IUnknownFake helper2; + IUnknownFake2 helper3; + + wil::com_ptr_nothrow ptrNULL; //NULL one + wil::com_ptr_nothrow ptrLT(&helper); + wil::com_ptr_nothrow ptrGT(&helper2); + wil::com_ptr_nothrow ptrDiff(&helper3); + + SECTION("equal operator") + { + REQUIRE_FALSE(ptrNULL == ptrLT); + REQUIRE(ptrNULL == ptrNULL); + REQUIRE(ptrLT == ptrLT); + REQUIRE_FALSE(ptrDiff == ptrLT); + REQUIRE_FALSE(ptrLT == ptrGT); + } + + SECTION("not equals operator") + { + REQUIRE(ptrNULL != ptrLT); + REQUIRE_FALSE(ptrNULL != ptrNULL); + REQUIRE_FALSE(ptrLT != ptrLT); + REQUIRE(ptrDiff != ptrLT); + REQUIRE(ptrLT != ptrGT); + } + + SECTION("less-than operator") + { + REQUIRE_FALSE(ptrNULL < ptrNULL); + REQUIRE(ptrNULL < ptrLT); + REQUIRE(ptrNULL < ptrLT); + + if (ptrLT.get() < ptrGT.get()) + { + REQUIRE(ptrLT < ptrGT); + } + else + { + REQUIRE(ptrGT < ptrLT); + } + } +} + +TEST_CASE("ComTests::Test_Conversion", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + + wil::com_ptr_nothrow nullPtr; + wil::com_ptr_nothrow ptr(&helper); + + REQUIRE_FALSE(nullPtr); + REQUIRE(ptr); +} + +TEST_CASE("ComTests::Test_Address", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + + IUnknownFake** pFakePtr; + SECTION("addressof") + { + wil::com_ptr_nothrow ptr(&helper); + + IUnknownFake::Clear(); + pFakePtr = ptr.addressof(); + REQUIRE(IUnknownFake::GetRelease() == 0); + REQUIRE(IUnknownFake::GetAddRef() == 0); + REQUIRE((*pFakePtr) == &helper); + } + + SECTION("Address operator") + { + wil::com_ptr_nothrow ptr(&helper); + IUnknownFake::Clear(); + + pFakePtr = &ptr; + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 0); + REQUIRE((*pFakePtr) == nullptr); + REQUIRE(ptr == nullptr); + } +} + +TEST_CASE("ComTests::Test_Helpers", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + IUnknownFake helper2; + IUnknownFake *ptrHelper; + wil::com_ptr_nothrow ptr(&helper); + + SECTION("detach") + { + IUnknownFake::Clear(); //clear addref counter + ptrHelper = ptr.detach(); + REQUIRE(ptr.get() == nullptr); + REQUIRE(ptrHelper == &helper); + REQUIRE(IUnknownFake::GetAddRef() == 0); + } + + SECTION("attach") + { + ptrHelper = &helper; + wil::com_ptr_nothrow ptr2(&helper2); //have some non null pointer + + IUnknownFake::Clear(); //clear addref counter + ptr2.attach(ptrHelper); + REQUIRE(ptr2.get() == ptrHelper); + REQUIRE(IUnknownFake::GetRelease() == 1); + REQUIRE(IUnknownFake::GetAddRef() == 0); + } + + SECTION("get") + { + wil::com_ptr_nothrow ptr2; + REQUIRE(ptr2.get() == nullptr); + + IUnknownFake helper3; + wil::com_ptr_nothrow ptr4(&helper3); + REQUIRE(ptr4.get() == &helper3); + } + + SECTION("l-value swap") + { + wil::com_ptr_nothrow ptr2(&helper); + wil::com_ptr_nothrow ptr3(&helper2); + + ptr2.swap(ptr3); + REQUIRE(ptr2.get() == &helper2); + REQUIRE(ptr3.get() == &helper); + } + + SECTION("r-value swap") + { + wil::com_ptr_nothrow ptr2(&helper); + wil::com_ptr_nothrow ptr3(&helper2); + + ptr2.swap(wistd::move(ptr3)); + REQUIRE(ptr2.get() == &helper2); + REQUIRE(ptr3.get() == &helper); + } +} + +TEST_CASE("ComTests::Test_As", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + + IUnknownFake helper; + wil::com_ptr_nothrow ptr(&helper); + + SECTION("query by IID") + { + wil::com_ptr_nothrow ptr2; + // REQUIRE(S_OK == ptr.AsIID(__uuidof(IUnknown), &ptr2)); + REQUIRE(S_OK == ptr.query_to(__uuidof(IUnknown), reinterpret_cast(&ptr2))); + REQUIRE(ptr2 != nullptr); + } + + SECTION("query by invalid IID") + { + wil::com_ptr_nothrow ptr2; + // REQUIRE(S_OK != ptr.AsIID(__uuidof(IDispatch), &ptr2)); + REQUIRE(S_OK != ptr.query_to(__uuidof(IDispatch), reinterpret_cast(&ptr2))); + REQUIRE(ptr2 == nullptr); + } + + SECTION("same interface query") + { + // wil::com_ptr optimizes same-type assignment to just call AddRef + IUnknownFake2 helper2; + wil::com_ptr_nothrow ptr2(&helper2); + wil::com_ptr_nothrow ptr3; + REQUIRE(S_OK == ptr2.query_to(&ptr3)); + REQUIRE(ptr3 != nullptr); + } + + SECTION("base interface query") + { + IUnknownFake2 helper2; + wil::com_ptr_nothrow ptr2(&helper2); + wil::com_ptr_nothrow ptr3; + REQUIRE(S_OK == ptr2.query_to(&ptr3)); + REQUIRE(ptr3 != nullptr); + } +} + +TEST_CASE("ComTests::Test_CopyTo", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + + IUnknownFake helper; + IUnknownFake2 helper2; + wil::com_ptr_nothrow ptr(&helper); + + SECTION("copy by IID") + { + wil::com_ptr_nothrow ptr2; + REQUIRE(S_OK == ptr.copy_to(__uuidof(IUnknown), reinterpret_cast(&ptr2))); + REQUIRE(ptr2 != nullptr); + } + + SECTION("copy by invalid IID") + { + wil::com_ptr_nothrow ptr2; + REQUIRE(S_OK != ptr.copy_to(__uuidof(IDispatch), reinterpret_cast(&ptr2))); + REQUIRE(ptr2 == nullptr); + } + + SECTION("same interface copy") + { + wil::com_ptr_nothrow ptr2(&helper2); + wil::com_ptr_nothrow ptr3; + REQUIRE(S_OK == ptr2.copy_to(&ptr3)); + REQUIRE(ptr3 != nullptr); + } + + SECTION("base interface copy") + { + wil::com_ptr_nothrow ptr2(&helper2); + wil::com_ptr_nothrow ptr3; + REQUIRE(S_OK == ptr2.copy_to(ptr3.addressof())); + REQUIRE(ptr3 != nullptr); + } +} + +// Helper used to verify correctness of IID_PPV_ARGS support +void IID_PPV_ARGS_Test_Helper(REFIID iid, void** pv) +{ + __analysis_assume(pv != nullptr); + REQUIRE(pv != nullptr); + REQUIRE(*pv == nullptr); + *pv = reinterpret_cast(0x01); // Set check value + + REQUIRE(iid == __uuidof(IUnknown)); +} + +TEST_CASE("ComTests::Test_IID_PPV_ARGS", "[com][com_ptr]") +{ + wil::com_ptr_nothrow unk; + IID_PPV_ARGS_Test_Helper(IID_PPV_ARGS(&unk)); + //Test if we got the correct check value back + REQUIRE(unk.get() == reinterpret_cast(0x01)); + // Make sure that we will not try to release some garbage + auto avoidWarning = unk.detach(); + (void)avoidWarning; +} + +// Helps with testing wil::com_ptr configuration when the operator -> is used +class ExtensionHelper +{ +public: + HRESULT Extend() const + { + return S_OK; + } + STDMETHOD_(ULONG, AddRef)() const + { + return 0; + } + STDMETHOD_(ULONG, Release)() const + { + return 0; + } +}; + +TEST_CASE("ComTests::Test_ConstPointer", "[com][com_ptr]") +{ + IUnknownFake::Clear(); + IUnknownFake helper; + + const wil::com_ptr_nothrow spUnk(&helper); + wil::com_ptr_nothrow spUnkHelper; + wil::com_ptr_nothrow spInspectable; + + REQUIRE(spUnk.get() != nullptr); + REQUIRE(spUnk); + spUnk.addressof(); + spUnk.copy_to(spUnkHelper.addressof()); + spUnk.copy_to(spInspectable.addressof()); + spUnk.copy_to(IID_PPV_ARGS(&spInspectable)); + + spUnk.query_to(&spUnkHelper); + spUnk.query_to(&spInspectable); + spUnk.query_to(__uuidof(IUnknown), reinterpret_cast(&spUnkHelper)); + + const ExtensionHelper extHelper; + wil::com_ptr_nothrow spExt(&extHelper); + REQUIRE(spExt->Extend() == S_OK); +} + +// Make sure that the pointer can be defined just with forward declaration of the class +TEST_CASE("ComTests::Test_ComPtrWithForwardDeclaration", "[com][com_ptr]") +{ + class MyClass; + + wil::com_ptr_nothrow spClass; + + class MyClass : public IUnknown + { + public: + STDMETHOD_(ULONG, AddRef)() + { + return 0; + } + STDMETHOD_(ULONG, Release)() + { + return 0; + } + }; +} + +//***************************************************************************** +// various com_ptr tests +//***************************************************************************** + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a00")) +ITest : public IUnknown +{ + STDMETHOD_(void, Test)() = 0; +}; + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a01")) +IDerivedTest : public ITest +{ + STDMETHOD_(void, TestDerived)() = 0; +}; + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a02")) +ITestInspectable : public IInspectable +{ + STDMETHOD_(void, TestInspctable)() = 0; +}; + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a03")) +IDerivedTestInspectable : public ITestInspectable +{ + STDMETHOD_(void, TestInspctableDerived)() = 0; +}; + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a04")) +INever : public IUnknown +{ + STDMETHOD_(void, Never)() = 0; +}; + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20a05")) +IAlways : public IUnknown +{ + STDMETHOD_(void, Always)() = 0; +}; + +class __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20b00")) // non-implemented to allow QI for the class to be attempted (and fail) +ComObject : witest::AllocatedObject, + public Microsoft::WRL::RuntimeClass, + ITest, IDerivedTest, IAlways>{ +public: + IFACEMETHODIMP_(void) Test() {} + IFACEMETHODIMP_(void) TestDerived() {} + IFACEMETHODIMP_(void) Always() {} +}; + +class __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20b01")) // non-implemented to allow QI for the class to be attempted (and fail) +WinRtObject : witest::AllocatedObject, + public Microsoft::WRL::RuntimeClass, + ITest, IDerivedTest, ITestInspectable, IDerivedTestInspectable, IAlways, Microsoft::WRL::FtmBase> +{ +public: + IFACEMETHODIMP_(void) Test() {} + IFACEMETHODIMP_(void) TestDerived() {} + IFACEMETHODIMP_(void) TestInspctable() {} + IFACEMETHODIMP_(void) TestInspctableDerived() {} + IFACEMETHODIMP_(void) Always() {} +}; + +class NoCom : witest::AllocatedObject +{ +public: + ULONG __stdcall AddRef() + { + return m_ref++; + } + + ULONG __stdcall Release() + { + auto retVal = (--m_ref); + if (retVal == 0) + { + delete this; + } + return retVal; + } + +private: + ULONG m_ref = 1; +}; + +template >> +T* cast_object(U*) +{ + FAIL_FAST(); +} + +template +T* cast_object(T* ptr) +{ + return ptr; +} + +template +static IFace* make_object() +{ + auto obj = Microsoft::WRL::Make(); + + IFace* result = nullptr; + if (FAILED(obj.Get()->QueryInterface(__uuidof(IFace), reinterpret_cast(&result)))) + { + // The QI only fails when we're asking for a CFoo from a CFoo (equivalent types)... in this + // case just return the original pointer -- the reinterpret_cast is needed as the code is shared + // and the other (nonuniform) cases also compile it (but do not execute it). + result = cast_object(obj.Detach()); + } + + return result; +} + +template <> +NoCom* make_object() +{ + return new NoCom(); +} + +template +void TestSmartPointer(const Ptr& ptr1, const Ptr& ptr2) +{ + SECTION("swap (method and global)") + { + auto p1 = ptr1; + auto p2 = ptr2; + p1.swap(p2); // l-value + REQUIRE(((p1 == ptr2) && (p2 == ptr1))); + p1.swap(wistd::move(p2)); // r-value + REQUIRE(((p1 == ptr1) && (p2 == ptr2))); + wil::swap(p1, p2); + REQUIRE(((p1 == ptr2) && (p2 == ptr1))); + } + + SECTION("WRL swap (method and global)") + { + auto p1 = ptr1; + Microsoft::WRL::ComPtr p2 = ptr2.get(); + p1.swap(p2); // l-value + REQUIRE(((p1 == ptr2) && (p2 == ptr1))); + p1.swap(wistd::move(p2)); // r-value + REQUIRE(((p1 == ptr1) && (p2 == ptr2))); + wil::swap(p1, p2); + REQUIRE(((p1 == ptr2) && (p2 == ptr1))); + wil::swap(p2, p1); + REQUIRE(((p1 == ptr1) && (p2 == ptr2))); + } + + SECTION("reset") + { + auto p = ptr1; + p.reset(); + REQUIRE_FALSE(p); + p = ptr1; + p.reset(nullptr); + REQUIRE_FALSE(p); + } + + SECTION("attach / detach") + { + auto p1 = ptr1; + auto p2 = ptr2; + p1.attach(p2.detach()); + REQUIRE(((p1.get() == ptr2.get()) && !p2)); + } + + SECTION("addressof") + { + auto p1 = ptr1; + auto p2 = ptr2; + p1.addressof(); // Doesn't reset + REQUIRE(p1.get() == ptr1.get()); + p1.reset(); + *(p1.addressof()) = p2.detach(); + REQUIRE(p1.get() == ptr2.get()); + } + + SECTION("operator&") + { + auto p1 = ptr1; + auto p2 = ptr2; + &p1; + REQUIRE_FALSE(p1); + *(&p1) = p2.detach(); + REQUIRE(p1.get() == ptr2.get()); + } + + SECTION("exercise const methods on the const param (ensure const)") + { + auto address = ptr1.addressof(); + REQUIRE(*address == ptr1.get()); + (void)static_cast(ptr1); + ptr1.get(); + auto deref = ptr1.operator->(); + (void)deref; + if (ptr1) + { + auto& ref = ptr1.operator*(); + (void)ref; + } + } +} + +template +static void TestPointerCombination(IFace* p1, IFace* p2) +{ +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointer(wil::com_ptr(p1), wil::com_ptr(p2)); +#endif + TestSmartPointer(wil::com_ptr_failfast(p1), wil::com_ptr_failfast(p2)); + TestSmartPointer(wil::com_ptr_nothrow(p1), wil::com_ptr_nothrow(p2)); +} + +template +static void TestPointer() +{ + auto p1 = make_object(); + auto p2 = make_object(); + IFace* nullPtr = nullptr; + TestPointerCombination(p1, p2); + TestPointerCombination(nullPtr, p2); + TestPointerCombination(p1, nullPtr); + TestPointerCombination(nullPtr, nullPtr); + TestPointerCombination(p1, p1); // same object + + p1->Release(); + p2->Release(); +} + +TEST_CASE("ComTests::Test_MemberFunctions", "[com][com_ptr]") +{ + // avoid overwhelming debug logging, perhaps the COM helpers are over reporting + auto restoreDebugString = wil::g_fResultOutputDebugString; + wil::g_fResultOutputDebugString = false; + + TestPointer(); + + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + TestPointer(); + + REQUIRE_FALSE(witest::g_objectCount.Leaked()); + wil::g_fResultOutputDebugString = restoreDebugString; +} + +template +static void TestSmartPointerConversion(const Ptr1& ptr1, const Ptr2& ptr2) +{ + const Microsoft::WRL::ComPtr wrl1 = ptr1.get(); + const Microsoft::WRL::ComPtr wrl2 = ptr2.get(); + + SECTION("global comparison operators") + { + auto p1 = ptr1.get(); + auto p2 = ptr2.get(); + + // com_ptr to com_ptr + REQUIRE((ptr1 == ptr2) == (p1 == p2)); + REQUIRE((ptr1 != ptr2) == (p1 != p2)); + REQUIRE((ptr1 < ptr2) == (p1 < p2)); + REQUIRE((ptr1 <= ptr2) == (p1 <= p2)); + REQUIRE((ptr1 > ptr2) == (p1 > p2)); + REQUIRE((ptr1 >= ptr2) == (p1 >= p2)); + + // com_ptr to ComPtr + REQUIRE((wrl1 == ptr2) == (p1 == p2)); + REQUIRE((wrl1 != ptr2) == (p1 != p2)); + REQUIRE((wrl1 < ptr2) == (p1 < p2)); + REQUIRE((wrl1 <= ptr2) == (p1 <= p2)); + REQUIRE((wrl1 > ptr2) == (p1 > p2)); + REQUIRE((wrl1 >= ptr2) == (p1 >= p2)); + + REQUIRE((ptr1 == wrl2) == (p1 == p2)); + REQUIRE((ptr1 != wrl2) == (p1 != p2)); + REQUIRE((ptr1 < wrl2) == (p1 < p2)); + REQUIRE((ptr1 <= wrl2) == (p1 <= p2)); + REQUIRE((ptr1 > wrl2) == (p1 > p2)); + REQUIRE((ptr1 >= wrl2) == (p1 >= p2)); + + // com_ptr to raw pointer + REQUIRE((ptr1 == p2) == (p1 == p2)); + REQUIRE((ptr1 != p2) == (p1 != p2)); + REQUIRE((ptr1 < p2) == (p1 < p2)); + REQUIRE((ptr1 <= p2) == (p1 <= p2)); + REQUIRE((ptr1 > p2) == (p1 > p2)); + REQUIRE((ptr1 >= p2) == (p1 >= p2)); + + REQUIRE((p1 == ptr2) == (p1 == p2)); + REQUIRE((p1 != ptr2) == (p1 != p2)); + REQUIRE((p1 < ptr2) == (p1 < p2)); + REQUIRE((p1 <= ptr2) == (p1 <= p2)); + REQUIRE((p1 > ptr2) == (p1 > p2)); + REQUIRE((p1 >= ptr2) == (p1 >= p2)); + } + + SECTION("construct from raw pointer") + { + Ptr1 p1(ptr2.get()); + Ptr1 p2 = ptr2.get(); + REQUIRE(((p1 == ptr2) && (p2 == ptr2))); + } + + SECTION("construct from com_ptr ref<>") + { + Ptr1 p1(ptr2); + Ptr1 p2 = (ptr2); + REQUIRE(((p1 == ptr2) && (p2 == ptr2))); + } + + SECTION("r-value construct from com_ptr ref<>") + { + auto move1 = ptr2; + auto move2 = ptr2; + Ptr1 p1(wistd::move(move1)); + Ptr1 p2 = wistd::move(move2); + REQUIRE(((p1 == ptr2) && (p2 == ptr2))); + } + + SECTION("assign from raw pointer") + { + Ptr1 p = ptr1; + p = (ptr2.get()); + REQUIRE(p == ptr2); + } + + SECTION("assign from com_ptr ref<>") + { + Ptr1 p = ptr1; + p = ptr2; + REQUIRE(p == ptr2); + } + + SECTION("r-value assign from com_ptr ref<>") + { + Ptr1 p = ptr1; + p = Ptr2(ptr2); + REQUIRE(p == ptr2); + } + + SECTION("construct from ComPtr ref<>") + { + Ptr1 p1(wrl2); + Ptr1 p2 = (wrl2); + REQUIRE(((p1 == wrl2) && (p2 == wrl2))); + } + + SECTION("r-value construct from ComPtr ref<>") + { + auto move1 = wrl2; + auto move2 = wrl2; + Ptr1 p1(wistd::move(move1)); + Ptr1 p2 = wistd::move(move2); + REQUIRE(((p1 == wrl2) && (p2 == wrl2))); + } + + SECTION("assign from ComPtr ref<>") + { + Ptr1 p = ptr1; + p = wrl2; + REQUIRE(p == wrl2); + } + + SECTION("r-value assign from ComPtr ref<>") + { + Ptr1 p = ptr1; + p = decltype(wrl2)(wrl2); + REQUIRE(p == wrl2); + } +} + +template +static void TestPointerConversionCombination(IFace1* p1, IFace2* p2) +{ +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerConversion(wil::com_ptr(p1), wil::com_ptr(p2)); + TestSmartPointerConversion(wil::com_ptr(p1), wil::com_ptr_failfast(p2)); + TestSmartPointerConversion(wil::com_ptr(p1), wil::com_ptr_nothrow(p2)); +#endif + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerConversion(wil::com_ptr_failfast(p1), wil::com_ptr(p2)); +#endif + TestSmartPointerConversion(wil::com_ptr_failfast(p1), wil::com_ptr_failfast(p2)); + TestSmartPointerConversion(wil::com_ptr_failfast(p1), wil::com_ptr_nothrow(p2)); + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerConversion(wil::com_ptr_nothrow(p1), wil::com_ptr(p2)); +#endif + TestSmartPointerConversion(wil::com_ptr_nothrow(p1), wil::com_ptr_failfast(p2)); + TestSmartPointerConversion(wil::com_ptr_nothrow(p1), wil::com_ptr_nothrow(p2)); +} + +template +static void TestPointerConversion() +{ + auto p1 = make_object(); + auto p2 = make_object(); + IFace1* nullPtr1 = nullptr; + IFace2* nullPtr2 = nullptr; + TestPointerConversionCombination(p1, p2); + TestPointerConversionCombination(nullPtr1, p2); + TestPointerConversionCombination(p1, nullPtr2); + TestPointerConversionCombination(nullPtr1, nullPtr2); + TestPointerConversionCombination(static_cast(p2), p2); // same object + + p1->Release(); + p2->Release(); +} + +TEST_CASE("ComTests::Test_PointerConversion", "[com][com_ptr]") +{ + // avoid overwhelming debug logging, perhaps the COM helpers are over reporting + auto restoreDebugString = wil::g_fResultOutputDebugString; + wil::g_fResultOutputDebugString = false; + + TestPointerConversion(); + + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + TestPointerConversion(); + + REQUIRE_FALSE(witest::g_objectCount.Leaked()); + wil::g_fResultOutputDebugString = restoreDebugString; +} + +template +void TestGlobalQueryIidPpv(wistd::true_type, const Ptr& source) // interface +{ + using DestPtr = wil::com_ptr_nothrow; + wil::com_ptr_nothrow never; + + SECTION("com_query_to(iid, ppv)") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_query_to(source, IID_PPV_ARGS(&dest1)); + REQUIRE_ERROR(wil::com_query_to(source, IID_PPV_ARGS(&never))); + REQUIRE((dest1 && !never)); +#endif + + DestPtr dest2, dest3; + wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2)); + REQUIRE_ERROR(wil::com_query_to_failfast(source, IID_PPV_ARGS(&never))); + wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3)); + REQUIRE_ERROR(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&never))); + REQUIRE((dest2 && dest3 && !never)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_CRASH(wil::com_query_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE_CRASH(wil::com_query_to(source, IID_PPV_ARGS(&never))); +#endif + + DestPtr dest2, dest3; + REQUIRE_CRASH(wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2))); + REQUIRE_CRASH(wil::com_query_to_failfast(source, IID_PPV_ARGS(&never))); + REQUIRE_CRASH(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3))); + REQUIRE_CRASH(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&never))); + } + } + + SECTION("try_com_query_to(iid, ppv)") + { + if (source) + { + DestPtr dest1; + REQUIRE(wil::try_com_query_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE_FALSE(wil::try_com_query_to(source, IID_PPV_ARGS(&never))); + REQUIRE((dest1 && !never)); + } + else + { + DestPtr dest1; + REQUIRE_CRASH(wil::try_com_query_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE_CRASH(wil::try_com_query_to(source, IID_PPV_ARGS(&never))); + } + } + + SECTION("com_copy_to(iid, ppv)") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_copy_to(source, IID_PPV_ARGS(&dest1)); + REQUIRE_ERROR(wil::com_copy_to(source, IID_PPV_ARGS(&never))); + REQUIRE((dest1 && !never)); +#endif + + DestPtr dest2, dest3; + wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2)); + REQUIRE_ERROR(wil::com_copy_to_failfast(source, IID_PPV_ARGS(&never))); + wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3)); + REQUIRE_ERROR(wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&never))); + REQUIRE((dest2 && dest3 && !never)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_copy_to(source, IID_PPV_ARGS(&dest1)); + wil::com_copy_to(source, IID_PPV_ARGS(&never)); +#endif + + DestPtr dest2, dest3; + wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2)); + wil::com_copy_to_failfast(source, IID_PPV_ARGS(&never)); + wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3)); + wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&never)); + } + } + + SECTION("try_com_copy_to(iid, ppv)") + { + if (source) + { + DestPtr dest1; + REQUIRE(wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&never))); + REQUIRE((dest1 && !never)); + } + else + { + DestPtr dest1; + REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE_FALSE(wil::try_com_copy_to(source, IID_PPV_ARGS(&never))); + } + } +} + +template +void TestGlobalQueryIidPpv(wistd::false_type, const Ptr&) // class +{ + // we can't compile against iid, ppv with a class +} + +template +static void TestGlobalQuery(const Ptr& source) +{ + using DestPtr = wil::com_ptr_nothrow; + wil::com_ptr_nothrow never; + + SECTION("com_query") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(wil::com_query(source)); + REQUIRE_ERROR(wil::com_query(source)); +#endif + + REQUIRE(wil::com_query_failfast(source)); + REQUIRE_ERROR(wil::com_query_failfast(source)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_CRASH(wil::com_query(source)); + REQUIRE_CRASH(wil::com_query(source)); +#endif + + REQUIRE_CRASH(wil::com_query_failfast(source)); + REQUIRE_CRASH(wil::com_query_failfast(source)); + } + } + + SECTION("com_query_to(U**)") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_query_to(source, &dest1); + REQUIRE_ERROR(wil::com_query_to(source, &never)); + REQUIRE((dest1 && !never)); +#endif + + DestPtr dest2, dest3; + wil::com_query_to_failfast(source, &dest2); + REQUIRE_ERROR(wil::com_query_to_failfast(source, &never)); + wil::com_query_to_nothrow(source, &dest3); + REQUIRE_ERROR(wil::com_query_to_nothrow(source, &never)); + REQUIRE((dest2 && dest3 && !never)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_CRASH(wil::com_query_to(source, &dest1)); + REQUIRE_CRASH(wil::com_query_to(source, &never)); +#endif + + DestPtr dest2, dest3; + REQUIRE_CRASH(wil::com_query_to_failfast(source, &dest2)); + REQUIRE_CRASH(wil::com_query_to_failfast(source, &never)); + REQUIRE_CRASH(wil::com_query_to_nothrow(source, &dest3)); + REQUIRE_CRASH(wil::com_query_to_nothrow(source, &never)); + } + } + + SECTION("try_com_query") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(wil::try_com_query(source)); + REQUIRE_FALSE(wil::try_com_query(source)); +#endif + + REQUIRE(wil::try_com_query_failfast(source)); + REQUIRE_FALSE(wil::try_com_query_failfast(source)); + REQUIRE(wil::try_com_query_nothrow(source)); + REQUIRE_FALSE(wil::try_com_query_nothrow(source)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_CRASH(wil::try_com_query(source)); + REQUIRE_CRASH(wil::try_com_query(source)); +#endif + + REQUIRE_CRASH(wil::try_com_query_failfast(source)); + REQUIRE_CRASH(wil::try_com_query_failfast(source)); + REQUIRE_CRASH(wil::try_com_query_nothrow(source)); + REQUIRE_CRASH(wil::try_com_query_nothrow(source)); + } + } + + SECTION("try_com_query_to(U**)") + { + if (source) + { + DestPtr dest1; + REQUIRE(wil::try_com_query_to(source, &dest1)); + REQUIRE_FALSE(wil::try_com_query_to(source, &never)); + REQUIRE((dest1 && !never)); + } + else + { + DestPtr dest1; + REQUIRE_CRASH(wil::try_com_query_to(source, &dest1)); + REQUIRE_CRASH(wil::try_com_query_to(source, &never)); + } + } + + SECTION("com_copy") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(wil::com_copy(source)); + REQUIRE_ERROR(wil::com_copy(source)); +#endif + + REQUIRE(wil::com_copy_failfast(source)); + REQUIRE_ERROR(wil::com_copy_failfast(source)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_FALSE(wil::com_copy(source)); + REQUIRE_FALSE(wil::com_copy(source)); +#endif + + REQUIRE_FALSE(wil::com_copy_failfast(source)); + REQUIRE_FALSE(wil::com_copy_failfast(source)); + } + } + + SECTION("com_copy_to(U**)") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_copy_to(source, &dest1); + REQUIRE_ERROR(wil::com_copy_to(source, &never)); + REQUIRE((dest1 && !never)); +#endif + + DestPtr dest2, dest3; + wil::com_copy_to_failfast(source, &dest2); + REQUIRE_ERROR(wil::com_copy_to_failfast(source, &never)); + wil::com_copy_to_nothrow(source, &dest3); + REQUIRE_ERROR(wil::com_copy_to_nothrow(source, &never)); + REQUIRE((dest2 && dest3 && !never)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + wil::com_copy_to(source, &dest1); + wil::com_copy_to(source, &never); +#endif + + DestPtr dest2, dest3; + wil::com_copy_to_failfast(source, &dest2); + wil::com_copy_to_failfast(source, &never); + wil::com_copy_to_nothrow(source, &dest3); + wil::com_copy_to_nothrow(source, &never); + } + } + + SECTION("try_com_copy") + { + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(wil::try_com_copy(source)); + REQUIRE_FALSE(wil::try_com_copy(source)); +#endif + REQUIRE(wil::try_com_copy_failfast(source)); + REQUIRE_FALSE(wil::try_com_copy_failfast(source)); + REQUIRE(wil::try_com_copy_nothrow(source)); + REQUIRE_FALSE(wil::try_com_copy_nothrow(source)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_FALSE(wil::try_com_copy(source)); + REQUIRE_FALSE(wil::try_com_copy(source)); +#endif + REQUIRE_FALSE(wil::try_com_copy_failfast(source)); + REQUIRE_FALSE(wil::try_com_copy_failfast(source)); + REQUIRE_FALSE(wil::try_com_copy_nothrow(source)); + REQUIRE_FALSE(wil::try_com_copy_nothrow(source)); + } + } + + SECTION("try_com_copy_to(U**)") + { + if (source) + { + DestPtr dest1; + REQUIRE(wil::try_com_copy_to(source, &dest1)); + REQUIRE_FALSE(wil::try_com_copy_to(source, &never)); + REQUIRE((dest1 && !never)); + } + else + { + DestPtr dest1; + REQUIRE_FALSE(wil::try_com_copy_to(source, &dest1)); + REQUIRE_FALSE(wil::try_com_copy_to(source, &never)); + } + } + + TestGlobalQueryIidPpv(typename wistd::is_abstract::type(), source); +} + +// Test fluent query functions for types that support them (exception and fail fast) +template +void TestSmartPointerQueryFluent(wistd::true_type, const Ptr& source) // void return (non-error based) +{ + using element_type = typename DestPtr::element_type; + + SECTION("query") + { + if (source) + { + REQUIRE(source.template query()); + REQUIRE_ERROR(source.template query()); + } + else + { + REQUIRE_CRASH(source.template query()); + REQUIRE_CRASH(source.template query()); + } + } + + SECTION("copy") + { + if (source) + { + REQUIRE(source.template copy()); + REQUIRE_ERROR(source.template copy()); + } + else + { + REQUIRE_FALSE(source.template copy()); + REQUIRE_FALSE(source.template copy()); + } + } +} + +// "Test" fluent query functions for error-based types (by doing nothing) +template +void TestSmartPointerQueryFluent(wistd::false_type, const Ptr& /*source*/) // error-code based return +{ + // error code based code cannot call the fluent error methods +} + +// Test iid, ppv queries for types that support them (interfaces yes, classes no) +template +void TestSmartPointerQueryIidPpv(wistd::true_type, const Ptr& source) // interface +{ + wil::com_ptr_nothrow never; + + SECTION("query_to(iid, ppv)") + { + if (source) + { + DestPtr dest; + source.query_to(IID_PPV_ARGS(&dest)); + REQUIRE_ERROR(source.query_to(IID_PPV_ARGS(&never))); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE_CRASH(source.query_to(IID_PPV_ARGS(&dest))); + REQUIRE_CRASH(source.query_to(IID_PPV_ARGS(&never))); + REQUIRE((!dest && !never)); + } + } + + SECTION("try_query_to(iid, ppv)") + { + if (source) + { + DestPtr dest; + REQUIRE(source.try_query_to(IID_PPV_ARGS(&dest))); + REQUIRE(!source.try_query_to(IID_PPV_ARGS(&never))); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE_CRASH(source.try_query_to(IID_PPV_ARGS(&dest))); + REQUIRE_CRASH(source.try_query_to(IID_PPV_ARGS(&never))); + REQUIRE((!dest && !never)); + } + } + + SECTION("copy_to(iid, ppv)") + { + if (source) + { + DestPtr dest; + source.copy_to(IID_PPV_ARGS(&dest)); + REQUIRE_ERROR(source.copy_to(IID_PPV_ARGS(&never))); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + source.copy_to(IID_PPV_ARGS(&dest)); + source.copy_to(IID_PPV_ARGS(&never)); + REQUIRE((!dest && !never)); + } + } + + SECTION("try_copy_to(iid, ppv)") + { + if (source) + { + DestPtr dest; + REQUIRE(source.try_copy_to(IID_PPV_ARGS(&dest))); + REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&never))); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&dest))); + REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&never))); + REQUIRE((!dest && !never)); + } + } +} + +// "Test" iid, ppv queries for types that support them for a class (unsupported same (interfaces yes, classes no) +template +void TestSmartPointerQueryIidPpv(wistd::false_type, const Ptr& /*source*/) // class +{ + // we can't compile against iid, ppv with a class +} + +// Test the various query and copy methods against the given source pointer (trying produce the given dest pointer) +template +void TestSmartPointerQuery(const Ptr& source) +{ + wil::com_ptr_nothrow never; + using element_type = typename DestPtr::element_type; + + SECTION("query_to(U**)") + { + if (source) + { + DestPtr dest; + source.query_to(&dest); + REQUIRE_ERROR(source.query_to(&never)); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE_CRASH(source.query_to(&dest)); + REQUIRE_CRASH(source.query_to(&never)); + REQUIRE((!dest && !never)); + } + } + + SECTION("try_query") + { + if (source) + { + REQUIRE(source.template try_query()); + REQUIRE_FALSE(source.template try_query()); + } + else + { + REQUIRE_CRASH(source.template try_query()); + REQUIRE_CRASH(source.template try_query()); + } + } + + SECTION("try_query_to(U**)") + { + if (source) + { + DestPtr dest; + REQUIRE(source.try_query_to(&dest)); + REQUIRE_FALSE(source.try_query_to(&never)); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE_CRASH(source.try_query_to(&dest)); + REQUIRE_CRASH(source.try_query_to(&never)); + REQUIRE((!dest && !never)); + } + } + + SECTION("copy_to(U**)") + { + if (source) + { + DestPtr dest; + source.copy_to(&dest); + REQUIRE_ERROR(source.copy_to(&never)); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + source.copy_to(&dest); + source.copy_to(&never); + REQUIRE((!dest && !never)); + } + } + + SECTION("try_copy") + { + if (source) + { + REQUIRE(source.template try_copy()); + REQUIRE_FALSE(source.template try_copy()); + } + else + { + REQUIRE_FALSE(source.template try_copy()); + REQUIRE_FALSE(source.template try_copy()); + } + } + + SECTION("try_copy_to(U**)") + { + if (source) + { + DestPtr dest; + REQUIRE(source.try_copy_to(&dest)); + REQUIRE_FALSE(source.try_copy_to(&never)); + REQUIRE((dest && !never)); + } + else + { + DestPtr dest; + REQUIRE_FALSE(source.try_copy_to(&dest)); + REQUIRE_FALSE(source.try_copy_to(&never)); + REQUIRE((!dest && !never)); + } + } + + TestSmartPointerQueryFluent(typename wistd::is_same::type(), source); + + TestSmartPointerQueryIidPpv(typename wistd::is_abstract::type(), source); +} + +template +static void TestQueryCombination(IFace* ptr) +{ + TestGlobalQuery(ptr); +#ifdef WIL_ENABLE_EXCEPTIONS + TestGlobalQuery(wil::com_ptr(ptr)); +#endif + TestGlobalQuery(wil::com_ptr_failfast(ptr)); + TestGlobalQuery(wil::com_ptr_nothrow(ptr)); + TestGlobalQuery(Microsoft::WRL::ComPtr(ptr)); + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerQuery>(wil::com_ptr(ptr)); + TestSmartPointerQuery>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQuery>(wil::com_ptr_nothrow(ptr)); + + TestSmartPointerQuery>(wil::com_ptr(ptr)); +#endif + TestSmartPointerQuery>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQuery>(wil::com_ptr_nothrow(ptr)); + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerQuery>(wil::com_ptr(ptr)); +#endif + TestSmartPointerQuery>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQuery>(wil::com_ptr_nothrow(ptr)); +} + +template +static void TestQuery(IFace* ptr) +{ + IFace* nullPtr = nullptr; + TestQueryCombination(ptr); + TestQueryCombination(nullPtr); +} + +template +static void TestQuery() +{ + auto ptr = make_object(); + TestQuery(ptr); + ptr->Release(); +} + +TEST_CASE("ComTests::Test_Query", "[com][com_ptr]") +{ + // avoid overwhelming debug logging, perhaps the COM helpers are over reporting + auto restoreDebugString = wil::g_fResultOutputDebugString; + wil::g_fResultOutputDebugString = false; + + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + + // This adds a significant amount of time to the compilation duration, so most tests are disabled by default... +#ifdef WIL_EXHAUSTIVE_TEST + TestQuery(); // ComObject + TestQuery(); + TestQuery(); // IUnknown + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // ITest + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IDerivedTest + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IAlways + TestQuery(); + TestQuery(); + TestQuery(); + + TestQuery(); // WinRtObject + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IUnknown + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IInspectable + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // ITest + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IDerivedTest + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // ITestInspectable + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IDerivedTestInspectable + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); // IAlways + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); + TestQuery(); +#endif + + REQUIRE_FALSE(witest::g_objectCount.Leaked()); + wil::g_fResultOutputDebugString = restoreDebugString; +} + +#if (NTDDI_VERSION >= NTDDI_WINBLUE) +template +void TestAgile(const Ptr& source) +{ + bool source_valid = (source != nullptr); + + if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + auto agile1 = wil::com_agile_query(source); + REQUIRE(agile1); +#endif + + auto agile2 = wil::com_agile_query_failfast(source); + wil::com_agile_ref_nothrow agile3; + REQUIRE_SUCCEEDED(wil::com_agile_query_nothrow(source, &agile3)); + REQUIRE((agile2 && agile3)); + } + else + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_CRASH(wil::com_agile_query(source)); +#endif + + REQUIRE_CRASH(wil::com_agile_query_failfast(source)); + wil::com_agile_ref_nothrow agile3; + REQUIRE_CRASH(wil::com_agile_query_nothrow(source, &agile3)); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + auto agile1 = wil::com_agile_copy(source); + REQUIRE(static_cast(agile1) == source_valid); +#endif + + auto agile2 = wil::com_agile_copy_failfast(source); + wil::com_agile_ref_nothrow agile3; + REQUIRE_SUCCEEDED(wil::com_agile_copy_nothrow(source, &agile3)); + REQUIRE(static_cast(agile2) == source_valid); + REQUIRE(static_cast(agile3) == source_valid); +} + +template +void TestAgileCombinations() +{ + auto ptr = make_object(); + + REQUIRE_SUCCEEDED(::CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED)); + auto exit = wil::scope_exit([] { ::CoUninitialize(); }); + + TestAgile(ptr); + TestAgile(wil::com_ptr_nothrow(ptr)); + TestAgile(Microsoft::WRL::ComPtr(ptr)); + + auto agilePtr = wil::com_agile_query_failfast(ptr); + TestQuery(agilePtr.get()); + TestQuery(agilePtr.get()); + TestQuery(agilePtr.get()); +#ifdef WIL_EXHAUSTIVE_TEST + TestQuery(agilePtr.get()); + TestQuery(agilePtr.get()); + TestQuery(agilePtr.get()); + TestQuery(agilePtr.get()); +#endif + + ptr->Release(); +} + +TEST_CASE("ComTests::Test_Agile", "[com][com_agile_ref]") +{ + // TestAgileCombinations(); + TestAgileCombinations(); + TestAgileCombinations(); + TestAgileCombinations(); +#ifdef WIL_EXHAUSTIVE_TEST + TestAgileCombinations(); + TestAgileCombinations(); + TestAgileCombinations(); + TestAgileCombinations(); +#endif + + REQUIRE_FALSE(witest::g_objectCount.Leaked()); +} +#endif + +template +void TestWeak(const Ptr& source) +{ + bool supports_weak = (source && (wil::try_com_query_nothrow(source))); + + if (supports_weak && source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + auto weak1 = wil::com_weak_query(source); + REQUIRE(weak1); +#endif + + auto weak2 = wil::com_weak_query_failfast(source); + wil::com_weak_ref_nothrow weak3; + REQUIRE_SUCCEEDED(wil::com_weak_query_nothrow(source, &weak3)); + REQUIRE((weak2 && weak3)); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto weak1copy = wil::com_weak_copy(source); + REQUIRE(weak1copy); +#endif + auto weak2copy = wil::com_weak_copy_failfast(source); + wil::com_weak_ref_nothrow weak3copy; + REQUIRE_SUCCEEDED(wil::com_weak_copy_nothrow(source, &weak3copy)); + REQUIRE((weak2copy && weak3copy)); + } + else if (source) + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_ERROR(wil::com_weak_query(source)); +#endif + + REQUIRE_ERROR(wil::com_weak_query_failfast(source)); + wil::com_weak_ref_nothrow weak3err; + REQUIRE_ERROR(wil::com_weak_query_nothrow(source, &weak3err)); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_ERROR(wil::com_weak_copy(source)); +#endif + + REQUIRE_ERROR(wil::com_weak_copy_failfast(source)); + wil::com_weak_ref_nothrow weak3; + REQUIRE_ERROR(wil::com_weak_copy_nothrow(source, &weak3)); + } + else // !source + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_CRASH(wil::com_weak_query(source)); +#endif + + REQUIRE_CRASH(wil::com_weak_query_failfast(source)); + wil::com_weak_ref_nothrow weak3crash; + REQUIRE_CRASH(wil::com_weak_query_nothrow(source, &weak3crash)); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto weak1 = wil::com_weak_copy(source); + REQUIRE(!weak1); +#endif + + auto weak2 = wil::com_weak_copy_failfast(source); + wil::com_weak_ref_nothrow weak3; + REQUIRE_SUCCEEDED(wil::com_weak_copy_nothrow(source, &weak3)); + REQUIRE((!weak2 && !weak3)); + } +} + +template +void TestGlobalQueryWithFailedResolve(const Ptr& source) +{ + // No need to test the null source and wrong interface query + // since that's covered in the TestGlobalQuery. + using DestPtr = wil::com_ptr_nothrow; + + SECTION("com_query") + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_ERROR(wil::com_query(source)); +#endif + REQUIRE_ERROR(wil::com_query_failfast(source)); + } + + SECTION("com_query_to(U**)") + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_ERROR(wil::com_query_to(source, &dest1)); + REQUIRE(!dest1); +#endif + + DestPtr dest2, dest3; + REQUIRE_ERROR(wil::com_query_to_failfast(source, &dest2)); + REQUIRE_ERROR(wil::com_query_to_nothrow(source, &dest3)); + REQUIRE((!dest2 && !dest3)); + } + + SECTION("try_com_query") + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(!wil::try_com_query(source)); +#endif + REQUIRE(!wil::try_com_query_failfast(source)); + REQUIRE(!wil::try_com_query_nothrow(source)); + } + + SECTION("try_com_query_to(U**)") + { + DestPtr dest1; + REQUIRE(!wil::try_com_query_to(source, &dest1)); + REQUIRE(!dest1); + } + + SECTION("com_copy") + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_ERROR(wil::com_copy(source)); +#endif + REQUIRE_ERROR(wil::com_copy_failfast(source)); + } + + SECTION("com_copy_to(U**)") + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_ERROR(wil::com_copy_to(source, &dest1)); + REQUIRE(!dest1); +#endif + + DestPtr dest2, dest3; + REQUIRE_ERROR(wil::com_copy_to_failfast(source, &dest2)); + REQUIRE_ERROR(wil::com_copy_to_nothrow(source, &dest3)); + REQUIRE((!dest2 && !dest3)); + } + + + SECTION("try_com_copy") + { +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(!wil::try_com_copy(source)); +#endif + REQUIRE(!wil::try_com_copy_failfast(source)); + REQUIRE(!wil::try_com_copy_nothrow(source)); + } + + SECTION("try_com_copy_to(U**)") + { + DestPtr dest1; + REQUIRE(!wil::try_com_copy_to(source, &dest1)); + REQUIRE(!dest1); + } + + if (wistd::is_abstract::value) + { + SECTION("com_query_to(iid, ppv)") + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_ERROR(wil::com_query_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE(!dest1); +#endif + + DestPtr dest2, dest3; + REQUIRE_ERROR(wil::com_query_to_failfast(source, IID_PPV_ARGS(&dest2))); + REQUIRE_ERROR(wil::com_query_to_nothrow(source, IID_PPV_ARGS(&dest3))); + REQUIRE((!dest2 && !dest3)); + } + + SECTION("try_com_query_to(iid, ppv)") + { + DestPtr dest1; + REQUIRE(!wil::try_com_query_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE(!dest1); + } + + SECTION("com_copy_to(iid, ppv)") + { +#ifdef WIL_ENABLE_EXCEPTIONS + DestPtr dest1; + REQUIRE_ERROR(wil::com_copy_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE(!dest1); +#endif + + DestPtr dest2, dest3; + REQUIRE_ERROR(wil::com_copy_to_failfast(source, IID_PPV_ARGS(&dest2))); + REQUIRE_ERROR(wil::com_copy_to_nothrow(source, IID_PPV_ARGS(&dest3))); + REQUIRE((!dest2 && !dest3)); + } + + SECTION("try_com_copy_to(iid, ppv)") + { + DestPtr dest1; + REQUIRE(!wil::try_com_copy_to(source, IID_PPV_ARGS(&dest1))); + REQUIRE(!dest1); + } + } +} + +template +void TestSmartPointerQueryFluentWithFailedResolve(wistd::false_type, const Ptr& /*source*/) +{ +} + + template +void TestSmartPointerQueryFluentWithFailedResolve(wistd::true_type, const Ptr& source) +{ + using element_type = typename TargetIFace::element_type; + + REQUIRE_ERROR(source.template query()); + REQUIRE_ERROR(source.template copy()); +} + +template +void TestSmartPointerQueryWithFailedResolve(const Ptr source) +{ + using element_type = typename TargetIFace::element_type; + + SECTION("query_to(U**)") + { + TargetIFace dest; + REQUIRE_ERROR(source.query_to(&dest)); + REQUIRE(!dest); + } + + SECTION("try_query") + { + REQUIRE(!source.template try_query()); + } + + SECTION("try_query_to(U**)") + { + TargetIFace dest; + REQUIRE(!source.try_query_to(&dest)); + REQUIRE(!dest); + } + + SECTION("copy_to(U**)") + { + TargetIFace dest; + REQUIRE_ERROR(source.copy_to(&dest)); + REQUIRE(!dest); + } + + SECTION("try_copy") + { + REQUIRE(!source.template try_copy()); + } + + SECTION("try_copy_to(U**)") + { + TargetIFace dest; + REQUIRE(!source.try_copy_to(&dest)); + REQUIRE(!dest); + } + + TestSmartPointerQueryFluentWithFailedResolve(typename wistd::is_same::type(), source); + + if (wistd::is_abstract::value) + { + SECTION("query_to(iid, ppv)") + { + TargetIFace dest; + REQUIRE_ERROR(source.query_to(IID_PPV_ARGS(&dest))); + REQUIRE(!dest); + } + + SECTION("try_query_to(iid, ppv)") + { + TargetIFace dest; + REQUIRE(!source.try_query_to(IID_PPV_ARGS(&dest))); + REQUIRE(!dest); + } + + SECTION("copy_to(iid, ppv)") + { + TargetIFace dest; + REQUIRE_ERROR(source.copy_to(IID_PPV_ARGS(&dest))); + REQUIRE(!dest); + } + + SECTION("try_copy_to(iid, ppv)") + { + TargetIFace dest; + REQUIRE(!source.try_copy_to(IID_PPV_ARGS(&dest))); + REQUIRE(!dest); + } + } +} + +template +void TestQueryWithFailedResolve(IFace* ptr) +{ + TestGlobalQueryWithFailedResolve(ptr); +#ifdef WIL_ENABLE_EXCEPTIONS + TestGlobalQueryWithFailedResolve(wil::com_ptr(ptr)); +#endif + TestGlobalQueryWithFailedResolve(wil::com_ptr_failfast(ptr)); + TestGlobalQueryWithFailedResolve(wil::com_ptr_nothrow(ptr)); + TestGlobalQueryWithFailedResolve(Microsoft::WRL::ComPtr(ptr)); + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr(ptr)); + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_nothrow(ptr)); + + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr(ptr)); +#endif + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_nothrow(ptr)); + +#ifdef WIL_ENABLE_EXCEPTIONS + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr(ptr)); +#endif + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_failfast(ptr)); + TestSmartPointerQueryWithFailedResolve>(wil::com_ptr_nothrow(ptr)); +} + +template +void TestWeakCombinations() +{ + auto ptr = make_object(); + + TestWeak(ptr); +#ifdef WIL_ENABLE_EXCEPTIONS + TestWeak(wil::com_ptr(ptr)); +#endif + TestWeak(Microsoft::WRL::ComPtr(ptr)); + + auto weakPtr = wil::com_weak_query_failfast(ptr); + TestQuery(weakPtr.get()); + TestQuery(weakPtr.get()); + TestQuery(weakPtr.get()); + +#ifdef WIL_EXHAUSTIVE_TEST + TestQuery(weakPtr.get()); + TestQuery(weakPtr.get()); + TestQuery(weakPtr.get()); + TestQuery(weakPtr.get()); +#endif + + // On the final release of the pointer, the weak reference will no longer resolve + ptr->Release(); + TestQueryWithFailedResolve(weakPtr.get()); + TestQueryWithFailedResolve(weakPtr.get()); + TestQueryWithFailedResolve(weakPtr.get()); +} + +TEST_CASE("ComTests::Test_Weak", "[com][com_weak_ref]") +{ + // TestWeakCombinations(); + TestWeakCombinations(); + TestWeakCombinations(); + TestWeakCombinations(); +#ifdef WIL_EXHAUSTIVE_TEST + TestWeakCombinations(); + TestWeakCombinations(); + TestWeakCombinations(); + TestWeakCombinations(); +#endif + + REQUIRE_FALSE(witest::g_objectCount.Leaked()); +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +TEST_CASE("ComTests::VerifyCoCreate", "[com][CoCreateInstance]") +{ + auto init = wil::CoInitializeEx_failfast(); + + // success cases +#ifdef WIL_ENABLE_EXCEPTIONS + auto link1 = wil::CoCreateInstance(); +#endif + auto link2 = wil::CoCreateInstanceFailFast(); + auto link3 = wil::CoCreateInstanceNoThrow(); + + // failure +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_THROWS((wil::CoCreateInstance())); +#endif + // skip this test, assume testing the exception based version is sufficient. + // auto link2 = wil::CoCreateInstanceFailFast(); + REQUIRE_FALSE(static_cast(wil::CoCreateInstanceNoThrow().get())); +} + +TEST_CASE("ComTests::VerifyCoGetClassObject", "[com][CoGetClassObject]") +{ + auto init = wil::CoInitializeEx_failfast(); + + // success cases +#ifdef WIL_ENABLE_EXCEPTIONS + auto linkFactory1 = wil::CoGetClassObject(); +#endif + auto linkFactory2 = wil::CoGetClassObjectFailFast(); + auto linkFactory3 = wil::CoGetClassObjectNoThrow(); + + // failure +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_THROWS((wil::CoGetClassObject())); +#endif + // skip this test, assume testing the exception based version is sufficient. + // auto linkFactory2 = wil::CoGetClassObjectFailFast(); + REQUIRE_FALSE(static_cast(wil::CoGetClassObjectNoThrow())); +} +#endif + +#ifdef __IObjectWithSite_INTERFACE_DEFINED__ +TEST_CASE("ComTests::VerifyComSetSiteNullIsMoveOnly", "[com][com_set_site]") +{ + wil::unique_set_site_null_call call1; + + // intentional compilation errors for copy construction/assignment + // wil::unique_set_site_null_call call2 = call1; + // call2 = call1; + + auto siteSetter = wil::com_set_site(nullptr, nullptr); + auto siteSetter2 = std::move(siteSetter); // Move construction + siteSetter2 = std::move(siteSetter); // Move assignment +} + +TEST_CASE("ComTests::VerifyComSetSite", "[com][com_set_site]") +{ + class ObjectWithSite WrlFinal : public RuntimeClass, IObjectWithSite> + { + public: + STDMETHODIMP SetSite(IUnknown* val) noexcept override + { + m_site = val; + return S_OK; + } + + STDMETHODIMP GetSite(REFIID riid, void** ppv) noexcept override + { + m_site.try_copy_to(riid, ppv); + return S_OK; + } + + private: + wil::com_ptr_nothrow m_site; + }; + + class ServiceObject WrlFinal : public RuntimeClass, IServiceProvider> + { + public: + ServiceObject(IServiceProvider* site = nullptr) + { + m_site = site; + } + + STDMETHODIMP QueryService(REFIID /*sid*/, REFIID /*riid*/, void** ppv) noexcept override + { + *ppv = nullptr; + return E_NOTIMPL; + } + private: + wil::com_ptr_nothrow m_site; + }; + + auto objWithSite = Make(); + auto serviceObj = Make(); + auto serviceObj2 = Make(serviceObj.Get()); + + { + auto cleanupSite = wil::com_set_site(objWithSite.Get(), serviceObj2.Get()); + + wil::com_ptr_nothrow site; + REQUIRE_SUCCEEDED(objWithSite->GetSite(IID_PPV_ARGS(&site))); + REQUIRE(static_cast(site)); + + auto siteCount = 0; + wil::for_each_site(objWithSite.Get(), [&](IUnknown* /*site*/) + { + siteCount++; + }); + REQUIRE(siteCount == 2); + } + + wil::com_ptr_nothrow site; + REQUIRE_SUCCEEDED(objWithSite->GetSite(IID_PPV_ARGS(&site))); + REQUIRE_FALSE(static_cast(site)); +} +#endif + +class FakeStream : public IStream +{ +public: + + STDMETHOD(QueryInterface)(REFIID riid, PVOID* ppv) override + { + if ((riid == __uuidof(IStream)) || + (riid == __uuidof(ISequentialStream)) || + (riid == __uuidof(IUnknown))) + { + *ppv = static_cast(this); + return S_OK; + } + + return E_NOTIMPL; + } + + STDMETHOD_(ULONG, AddRef)() override + { + return 2; + } + + STDMETHOD_(ULONG, Release)() override + { + return 1; + } + + unsigned long long Position = 0; + unsigned long long PositionMax = 0; + unsigned long MaxReadSize = 0; + unsigned long MaxWriteSize = 0; + unsigned long long TotalSize = 0; + + // ISequentialStream + STDMETHOD(Read)(_Out_writes_bytes_to_(cb, *pcbRead) void *pv, _In_ ULONG cb, _Out_opt_ ULONG *pcbRead) override + { + if (pcbRead) + { + *pcbRead = min(MaxReadSize, cb); + } + + ZeroMemory(pv, cb); + return (MaxReadSize <= cb) ? S_OK : S_FALSE; + } + + STDMETHOD(Write)(_In_reads_bytes_(cb) const void *, _In_ ULONG cb, _Out_opt_ ULONG *pcbWritten) override + { + if (pcbWritten) + { + *pcbWritten = min(MaxWriteSize, cb); + } + + return (MaxWriteSize <= cb) ? S_OK : S_FALSE; + } + + // IStream + STDMETHOD(Seek)(LARGE_INTEGER dlibMove, DWORD dwOrigin, _Out_opt_ ULARGE_INTEGER *plibNewPosition) + { + if (dwOrigin == STREAM_SEEK_CUR) + { + if ((dlibMove.QuadPart < 0) && (static_cast(-dlibMove.QuadPart) > Position)) + { + Position = 0; + } + else + { + Position += dlibMove.QuadPart; + } + } + else if (dwOrigin == STREAM_SEEK_SET) + { + Position = static_cast(dlibMove.QuadPart); + } + else if (dwOrigin == STREAM_SEEK_END) + { + if ((dlibMove.QuadPart < 0) && (static_cast(-dlibMove.QuadPart) > Position)) + { + Position = 0; + } + else + { + Position = PositionMax + dlibMove.QuadPart; + } + } + + Position = min(Position, PositionMax); + + if (plibNewPosition) + { + plibNewPosition->QuadPart = Position; + } + + return S_OK; + } + + STDMETHOD(Stat)(__RPC__out STATSTG *pstatstg, DWORD) override + { + *pstatstg = {}; + pstatstg->cbSize.QuadPart = TotalSize; + return S_OK; + } + + STDMETHOD(Revert)(void) override + { + return E_NOTIMPL; + } + + STDMETHOD(SetSize)(ULARGE_INTEGER) override + { + return E_NOTIMPL; + } + + STDMETHOD(Clone)(__RPC__deref_out_opt IStream **ppstm) override + { + *ppstm = this; + return S_OK; + } + + STDMETHOD(Commit)(DWORD) override + { + return E_NOTIMPL; + } + + STDMETHOD(CopyTo)(_In_ IStream *pstm, ULARGE_INTEGER cb, _Out_opt_ ULARGE_INTEGER *pcbRead, _Out_opt_ ULARGE_INTEGER *pcbWritten) override + { + unsigned long didWrite; + unsigned long didRead; + + FAIL_FAST_IF(cb.HighPart != 0); + RETURN_IF_FAILED(this->Read(nullptr, cb.LowPart, &didRead)); + RETURN_IF_FAILED(pstm->Write(nullptr, didRead, &didWrite)); + + pcbRead->QuadPart = didRead; + pcbWritten->QuadPart = didWrite; + + return S_OK; + } + + STDMETHOD(LockRegion)(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override + { + return E_NOTIMPL; + } + + STDMETHOD(UnlockRegion)(ULARGE_INTEGER, ULARGE_INTEGER, DWORD) override + { + return E_NOTIMPL; + } + + void SetPosition(unsigned long long position, unsigned long long positionMax) + { + Position = position; + PositionMax = positionMax; + } + + void SetPosition(unsigned long long position) + { + return SetPosition(position, position); + } +}; + +TEST_CASE("StreamTests::ReadPartial", "[com][IStream]") +{ + FakeStream stream; + stream.MaxReadSize = 16; + BYTE buffer[32]; + ULONG readSize; + + // Reading more than what's available is OK + REQUIRE_SUCCEEDED(wil::stream_read_partial_nothrow(&stream, buffer, 32, &readSize)); + REQUIRE(stream.MaxReadSize == readSize); + + // Reading less than what's available is OK + REQUIRE_SUCCEEDED(wil::stream_read_partial_nothrow(&stream, buffer, 5, &readSize)); + REQUIRE(5 == readSize); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(stream.MaxReadSize == wil::stream_read_partial(&stream, buffer, 32)); + REQUIRE(5ULL == wil::stream_read_partial(&stream, buffer, 5)); +#endif +} + +TEST_CASE("StreamTests::Read", "[com][IStream]") +{ + FakeStream stream; + stream.MaxReadSize = 10; + BYTE buffer[32]; + + // Reading less than available is OK + REQUIRE_SUCCEEDED(wil::stream_read_nothrow(&stream, buffer, 5)); + + // Reading more is not. + REQUIRE(stream.MaxReadSize < sizeof(buffer)); + REQUIRE_FAILED(wil::stream_read_nothrow(&stream, buffer, sizeof(buffer))); + + struct Header + { + ULONG Flags; + ULONG Other; + } header; + + // Reading a POD when there's not enough fails + stream.MaxReadSize = sizeof(header) - 1; + REQUIRE_FAILED(wil::stream_read_nothrow(&stream, &header)); + + // Reading a POD when there is is OK (and prove that the read happened) + header.Flags = 1; + header.Other = 2; + stream.MaxReadSize = sizeof(header); + REQUIRE_SUCCEEDED(wil::stream_read_nothrow(&stream, &header)); + REQUIRE(0UL == header.Flags); + REQUIRE(0UL == header.Other); + +#ifdef WIL_ENABLE_EXCEPTIONS + // Reading less than available is OK + REQUIRE_NOTHROW(wil::stream_read(&stream, buffer, 5)); + REQUIRE_THROWS(wil::stream_read(&stream, buffer, sizeof(buffer))); + + // Reading a POD when there's not enough fails + stream.MaxReadSize = sizeof(Header) - 1; + REQUIRE_THROWS(wil::stream_read
(&stream)); + + // Reading a POD when there is is OK (and prove that the read happened) + stream.MaxReadSize = sizeof(Header); + header = wil::stream_read
(&stream); + REQUIRE(0UL == header.Flags); + REQUIRE(0UL == header.Other); +#endif +} + +TEST_CASE("StreamTests::Write", "[com][IStream]") +{ + FakeStream stream; + BYTE buffer[16]; + + stream.MaxWriteSize = sizeof(buffer) + 1; + REQUIRE_SUCCEEDED(wil::stream_write_nothrow(&stream, buffer, sizeof(buffer))); + + stream.MaxWriteSize = sizeof(buffer) - 1; + REQUIRE_FAILED(wil::stream_write_nothrow(&stream, buffer, sizeof(buffer))); + + struct Header + { + ULONG Flags; + ULONG Other; + } header = { 1, 2 }; + + stream.MaxWriteSize = sizeof(header) + 1; + REQUIRE_SUCCEEDED(wil::stream_write_nothrow(&stream, header)); + + stream.MaxWriteSize = sizeof(header) - 1; + REQUIRE_FAILED(wil::stream_write_nothrow(&stream, header)); + +#ifdef WIL_ENABLE_EXCEPTIONS + stream.MaxWriteSize = sizeof(buffer) + 1; + REQUIRE_NOTHROW(wil::stream_write(&stream, buffer, sizeof(buffer))); + + stream.MaxWriteSize = sizeof(buffer) - 1; + REQUIRE_THROWS(wil::stream_write(&stream, buffer, sizeof(buffer))); + + header = { 1, 2 }; + stream.MaxWriteSize = sizeof(header) + 1; + REQUIRE_NOTHROW(wil::stream_write(&stream, header)); + + stream.MaxWriteSize = sizeof(header) - 1; + REQUIRE_THROWS(wil::stream_write(&stream, header)); +#endif +} + +TEST_CASE("StreamTests::Size", "[com][IStream]") +{ + FakeStream stream; + unsigned long long size; + + stream.TotalSize = 150; + REQUIRE_SUCCEEDED(wil::stream_size_nothrow(&stream, &size)); + REQUIRE(stream.TotalSize == size); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(stream.TotalSize == wil::stream_size(&stream)); +#endif +} + +TEST_CASE("StreamTests::SeekStart", "[com][IStream]") +{ + FakeStream stream; + unsigned long long landed; + + // Seek within the stream + stream.SetPosition(100, 1000); + REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 10)); + REQUIRE(10ULL == stream.Position); + + // Seek and get the landing position + REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 11, &landed)); + REQUIRE(11ULL == stream.Position); + REQUIRE(11ULL == landed); + + // Seek past the end + REQUIRE_SUCCEEDED(wil::stream_set_position_nothrow(&stream, 5000, &landed)); + REQUIRE(stream.PositionMax == landed); + + // Seek to the start + REQUIRE_SUCCEEDED(wil::stream_reset_nothrow(&stream)); + REQUIRE(0ULL == stream.Position); + +#ifdef WIL_ENABLE_EXCEPTIONS + // Seek within the stream + stream.SetPosition(100, 1000); + REQUIRE(10ULL == wil::stream_set_position(&stream, 10)); + + // Seek past the end + REQUIRE(stream.PositionMax == wil::stream_set_position(&stream, 5000)); + + // Seek to the start + REQUIRE_NOTHROW(wil::stream_reset(&stream)); + REQUIRE(0ULL == stream.Position); +#endif +} + +TEST_CASE("StreamTests::SeekCur", "[com][IStream]") +{ + FakeStream stream; + unsigned long long landed; + + stream.SetPosition(100, 5000); + REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, 10, &landed)); + REQUIRE(110ULL == landed); + + REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, -10, &landed)); + REQUIRE(100ULL == landed); + + REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, -1000, &landed)); + REQUIRE(0ULL == landed); + + REQUIRE_SUCCEEDED(wil::stream_seek_from_current_position_nothrow(&stream, 6000, &landed)); + REQUIRE(5000ULL == landed); + +#ifdef WIL_ENABLE_EXCEPTIONS + stream.SetPosition(100, 5000); + + REQUIRE(110ULL == wil::stream_seek_from_current_position(&stream, 10)); + + REQUIRE(100ULL == wil::stream_seek_from_current_position(&stream, -10)); + + REQUIRE(0ULL == wil::stream_seek_from_current_position(&stream, -1000)); + + REQUIRE(5000ULL == wil::stream_seek_from_current_position(&stream, 6000)); +#endif +} + +TEST_CASE("StreamTests::GetPosition", "[com][IStream]") +{ + FakeStream stream; + unsigned long long landed; + + stream.SetPosition(50); + REQUIRE_SUCCEEDED(wil::stream_get_position_nothrow(&stream, &landed)); + REQUIRE(stream.Position == landed); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(stream.Position == wil::stream_get_position(&stream)); +#endif +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("StreamTests::Saver", "[com][IStream]") +{ + FakeStream first; + FakeStream second; + + first.SetPosition(200); + { + auto saved = wil::stream_position_saver(&first); + first.SetPosition(250); + } + REQUIRE(200ULL == first.Position); + + first.SetPosition(200); + { + auto saved = wil::stream_position_saver(&first); + first.SetPosition(250); + saved.reset(); + REQUIRE(200ULL == first.Position); + } + + first.SetPosition(200); + { + auto saved = wil::stream_position_saver(&first); + first.SetPosition(250); + saved.dismiss(); + } + REQUIRE(250ULL == first.Position); + + first.SetPosition(200); + second.SetPosition(250); + { + auto saved = wil::stream_position_saver(&first); + first.SetPosition(210); + saved.reset(&second); + REQUIRE(200ULL == first.Position); + + second.SetPosition(300); + saved.reset(); + REQUIRE(250ULL == second.Position); + } +} +#endif diff --git a/tests/CommonTests.cpp b/tests/CommonTests.cpp new file mode 100644 index 0000000..d8eb4b4 --- /dev/null +++ b/tests/CommonTests.cpp @@ -0,0 +1,243 @@ + +#include +#include +#include + +#include "common.h" + +TEST_CASE("CommonTests::OutParamHelpers", "[common]") +{ + int i = 2; + int *pOutTest = &i; + int *pNullTest = nullptr; + + SECTION("Value type") + { + wil::assign_to_opt_param(pNullTest, 3); + wil::assign_to_opt_param(pOutTest, 3); + REQUIRE(*pOutTest == 3); + } + + SECTION("Pointer to value type") + { + int **ppOutTest = &pOutTest; + int **ppNullTest = nullptr; + wil::assign_null_to_opt_param(ppNullTest); + wil::assign_null_to_opt_param(ppOutTest); + REQUIRE(*ppOutTest == nullptr); + } + + SECTION("COM out pointer") + { + Microsoft::WRL::ComPtr spUnk; + IUnknown **ppunkNull = nullptr; + IUnknown *pUnk = reinterpret_cast(1); + IUnknown **ppUnkValid = &pUnk; + + wil::detach_to_opt_param(ppunkNull, spUnk); + wil::detach_to_opt_param(ppUnkValid, spUnk); + REQUIRE(*ppUnkValid == nullptr); + } +} + +TEST_CASE("CommonTests::TypeValidation", "[common]") +{ + std::unique_ptr boolCastClass; + std::vector noBoolCastClass; + HRESULT hr = S_OK; + BOOL bigBool = true; + bool smallBool = true; + DWORD dword = 1; + Microsoft::WRL::ComPtr comPtr; + (void)dword; + + // NOTE: The commented out verify* calls should give compilation errors + SECTION("verify_bool") + { + REQUIRE(wil::verify_bool(smallBool)); + REQUIRE(wil::verify_bool(bigBool)); + REQUIRE_FALSE(wil::verify_bool(boolCastClass)); + REQUIRE_FALSE(wil::verify_bool(comPtr)); + //wil::verify_bool(noBoolCastClass); + //wil::verify_bool(dword); + //wil::verify_bool(hr); + } + + SECTION("verify_hresult") + { + //wil::verify_hresult(smallBool); + //wil::verify_hresult(bigBool); + //wil::verify_hresult(boolCastClass); + //wil::verify_hresult(noBoolCastClass); + //wil::verify_hresult(dword); + //wil::verify_hresult(comPtr); + REQUIRE(wil::verify_hresult(hr) == S_OK); + } + + SECTION("verify_BOOL") + { + //wil::verify_BOOL(smallBool); + REQUIRE(wil::verify_BOOL(bigBool)); + //wil::verify_BOOL(boolCastClass); + //wil::verify_BOOL(noBoolCastClass); + //wil::verify_BOOL(dword); + //wil::verify_BOOL(comPtr); + //wil::verify_BOOL(hr); + } +} + +template +static void FlagsMacrosNonStatic(T none, T one, T two, T three, T four) +{ + T eval = one | four; + + REQUIRE(WI_AreAllFlagsSet(MDEC(eval), MDEC(one | four))); + REQUIRE_FALSE(WI_AreAllFlagsSet(eval, one | three)); + REQUIRE_FALSE(WI_AreAllFlagsSet(eval, three | two)); + REQUIRE(WI_AreAllFlagsSet(eval, none)); + + REQUIRE(WI_IsAnyFlagSet(MDEC(eval), MDEC(one))); + REQUIRE(WI_IsAnyFlagSet(eval, one | four | three)); + REQUIRE_FALSE(WI_IsAnyFlagSet(eval, two)); + + REQUIRE(WI_AreAllFlagsClear(MDEC(eval), MDEC(three))); + REQUIRE(WI_AreAllFlagsClear(eval, three | two)); + REQUIRE_FALSE(WI_AreAllFlagsClear(eval, one | four)); + REQUIRE_FALSE(WI_AreAllFlagsClear(eval, one | three)); + + REQUIRE(WI_IsAnyFlagClear(MDEC(eval), MDEC(three))); + REQUIRE(WI_IsAnyFlagClear(eval, three | two)); + REQUIRE(WI_IsAnyFlagClear(eval, four | three)); + REQUIRE_FALSE(WI_IsAnyFlagClear(eval, one)); + REQUIRE_FALSE(WI_IsAnyFlagClear(eval, one | four)); + + REQUIRE_FALSE(WI_IsSingleFlagSet(MDEC(eval))); + REQUIRE(WI_IsSingleFlagSet(eval & one)); + + REQUIRE(WI_IsSingleFlagSetInMask(MDEC(eval), MDEC(one))); + REQUIRE(WI_IsSingleFlagSetInMask(eval, one | three)); + REQUIRE_FALSE(WI_IsSingleFlagSetInMask(eval, three)); + REQUIRE_FALSE(WI_IsSingleFlagSetInMask(eval, one | four)); + + REQUIRE_FALSE(WI_IsClearOrSingleFlagSet(MDEC(eval))); + REQUIRE(WI_IsClearOrSingleFlagSet(eval & one)); + REQUIRE(WI_IsClearOrSingleFlagSet(none)); + + REQUIRE(WI_IsClearOrSingleFlagSetInMask(MDEC(eval), MDEC(one))); + REQUIRE(WI_IsClearOrSingleFlagSetInMask(eval, one | three)); + REQUIRE(WI_IsClearOrSingleFlagSetInMask(eval, three)); + REQUIRE_FALSE(WI_IsClearOrSingleFlagSetInMask(eval, one | four)); + + eval = none; + WI_SetAllFlags(MDEC(eval), MDEC(one)); + REQUIRE(eval == one); + WI_SetAllFlags(eval, one | two); + REQUIRE(eval == (one | two)); + + eval = one | two; + WI_ClearAllFlags(MDEC(eval), one); + REQUIRE(eval == two); + WI_ClearAllFlags(eval, two); + REQUIRE(eval == none); + + eval = one | two; + WI_UpdateFlagsInMask(MDEC(eval), MDEC(two | three), MDEC(three | four)); + REQUIRE(eval == (one | three)); + + eval = one; + WI_ToggleAllFlags(MDEC(eval), MDEC(one | two)); + REQUIRE(eval == two); +} + +enum class EClassTest +{ + None = 0x0, + One = 0x1, + Two = 0x2, + Three = 0x4, + Four = 0x8, +}; +DEFINE_ENUM_FLAG_OPERATORS(EClassTest); + +enum ERawTest +{ + ER_None = 0x0, + ER_One = 0x1, + ER_Two = 0x2, + ER_Three = 0x4, + ER_Four = 0x8, +}; +DEFINE_ENUM_FLAG_OPERATORS(ERawTest); + +TEST_CASE("CommonTests::FlagsMacros", "[common]") +{ + SECTION("Integral types") + { + FlagsMacrosNonStatic(static_cast(0), static_cast(0x1), static_cast(0x2), static_cast(0x4), static_cast(0x40)); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x80u); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x4000); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x8000u); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x80000000ul); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x80000000ul); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x8000000000000000ull); + FlagsMacrosNonStatic(0, 0x1, 0x2, 0x4, 0x8000000000000000ull); + } + + SECTION("Raw enum") + { + FlagsMacrosNonStatic(ER_None, ER_One, ER_Two, ER_Three, ER_Four); + } + + SECTION("Enum class") + { + FlagsMacrosNonStatic(EClassTest::None, EClassTest::One, EClassTest::Two, EClassTest::Three, EClassTest::Four); + + EClassTest eclass = EClassTest::One | EClassTest::Two; + REQUIRE(WI_IsFlagSet(MDEC(eclass), EClassTest::One)); + REQUIRE(WI_IsFlagSet(eclass, EClassTest::Two)); + REQUIRE_FALSE(WI_IsFlagSet(eclass, EClassTest::Three)); + + REQUIRE(WI_IsFlagClear(MDEC(eclass), EClassTest::Three)); + REQUIRE_FALSE(WI_IsFlagClear(eclass, EClassTest::One)); + + REQUIRE_FALSE(WI_IsSingleFlagSet(MDEC(eclass))); + REQUIRE(WI_IsSingleFlagSet(eclass & EClassTest::One)); + + eclass = EClassTest::None; + WI_SetFlag(MDEC(eclass), EClassTest::One); + REQUIRE(eclass == EClassTest::One); + + eclass = EClassTest::None; + WI_SetFlagIf(eclass, EClassTest::One, false); + REQUIRE(eclass == EClassTest::None); + WI_SetFlagIf(eclass, EClassTest::One, true); + REQUIRE(eclass == EClassTest::One); + + eclass = EClassTest::None; + WI_SetFlagIf(eclass, EClassTest::One, false); + REQUIRE(eclass == EClassTest::None); + WI_SetFlagIf(eclass, EClassTest::One, true); + REQUIRE(eclass == EClassTest::One); + + eclass = EClassTest::One | EClassTest::Two; + WI_ClearFlag(eclass, EClassTest::Two); + REQUIRE(eclass == EClassTest::One); + + eclass = EClassTest::One | EClassTest::Two; + WI_ClearFlagIf(eclass, EClassTest::One, false); + REQUIRE(eclass == (EClassTest::One | EClassTest::Two)); + WI_ClearFlagIf(eclass, EClassTest::One, true); + REQUIRE(eclass == EClassTest::Two); + + eclass = EClassTest::None; + WI_UpdateFlag(eclass, EClassTest::One, true); + REQUIRE(eclass == EClassTest::One); + WI_UpdateFlag(eclass, EClassTest::One, false); + REQUIRE(eclass == EClassTest::None); + + eclass = EClassTest::One; + WI_ToggleFlag(eclass, EClassTest::One); + WI_ToggleFlag(eclass, EClassTest::Two); + REQUIRE(eclass == EClassTest::Two); + } +} diff --git a/tests/CppWinRTTests.cpp b/tests/CppWinRTTests.cpp new file mode 100644 index 0000000..a2e69e0 --- /dev/null +++ b/tests/CppWinRTTests.cpp @@ -0,0 +1,74 @@ + +#include "catch.hpp" + +#include + +TEST_CASE("CppWinRTTests::WilToCppWinRTExceptionTranslationTest", "[cppwinrt]") +{ + auto test = [](HRESULT hr) + { + try + { + THROW_HR(hr); + } + catch (...) + { + REQUIRE(hr == winrt::to_hresult()); + } + }; + + test(E_UNEXPECTED); + test(E_ACCESSDENIED); + test(E_INVALIDARG); + test(E_HANDLE); + test(E_OUTOFMEMORY); +} + +TEST_CASE("CppWinRTTests::CppWinRTToWilExceptionTranslationTest", "[cppwinrt]") +{ + auto test = [](HRESULT hr) + { + try + { + winrt::check_hresult(hr); + } + catch (...) + { + REQUIRE(hr == wil::ResultFromCaughtException()); + } + }; + + test(E_UNEXPECTED); + test(E_ACCESSDENIED); + test(E_INVALIDARG); + test(E_HANDLE); + test(E_OUTOFMEMORY); +} + +TEST_CASE("CppWinRTTests::ResultFromExceptionDebugTest", "[cppwinrt]") +{ + auto test = [](HRESULT hr, wil::SupportedExceptions supportedExceptions) + { + auto result = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, supportedExceptions, [&]() + { + winrt::check_hresult(hr); + }); + REQUIRE(hr == result); + }; + + // Anything from SupportedExceptions::Known or SupportedExceptions::All should give back the same HRESULT + test(E_UNEXPECTED, wil::SupportedExceptions::Known); + test(E_ACCESSDENIED, wil::SupportedExceptions::Known); + test(E_INVALIDARG, wil::SupportedExceptions::All); + test(E_HANDLE, wil::SupportedExceptions::All); + + // OOM gets translated to bad_alloc, which should always give back E_OUTOFMEMORY + test(E_OUTOFMEMORY, wil::SupportedExceptions::All); + test(E_OUTOFMEMORY, wil::SupportedExceptions::Known); + test(E_OUTOFMEMORY, wil::SupportedExceptions::ThrownOrAlloc); + + // Uncomment any of the following to validate SEH failfast + //test(E_UNEXPECTED, wil::SupportedExceptions::None); + //test(E_ACCESSDENIED, wil::SupportedExceptions::Thrown); + //test(E_INVALIDARG, wil::SupportedExceptions::ThrownOrAlloc); +} diff --git a/tests/FakeWinRTTypes.h b/tests/FakeWinRTTypes.h new file mode 100644 index 0000000..70e607a --- /dev/null +++ b/tests/FakeWinRTTypes.h @@ -0,0 +1,400 @@ +#pragma once + +#include +#include +#include +#include + +template +struct WinRTStorage +{ + T value = {}; + + HRESULT CopyTo(T* result) + { + *result = value; + return S_OK; + } + + HRESULT Set(T val) + { + value = val; + return S_OK; + } + + static void Destroy(T&) + { + } + + void Reset() + { + } + + bool Equals(T val) + { + // NOTE: Padding can through this off, but this isn't intended to be a robust solution... + return memcmp(&value, &val, sizeof(T)) == 0; + } +}; + +template <> +struct WinRTStorage +{ + Microsoft::WRL::Wrappers::HString value; + + HRESULT CopyTo(HSTRING* result) + { + return value.CopyTo(result); + } + + HRESULT Set(HSTRING val) + { + return value.Set(val); + } + + static void Destroy(HSTRING& val) + { + ::WindowsDeleteString(val); + val = nullptr; + } + + void Reset() + { + value = {}; + } + + bool Equals(HSTRING val) + { + return value == val; + } +}; + +template +struct WinRTStorage +{ + Microsoft::WRL::ComPtr value; + + HRESULT CopyTo(T** result) + { + *result = Microsoft::WRL::ComPtr(value).Detach(); + return S_OK; + } + + HRESULT Set(T* val) + { + value = val; + return S_OK; + } + + static void Destroy(T*& val) + { + val->Release(); + val = nullptr; + } + + void Reset() + { + value.Reset(); + } + + bool Equals(T* val) + { + return value.Get() == val; + } +}; + +// Very minimal IAsyncOperation implementation that gives calling tests control over when it completes +template +struct FakeAsyncOperation : Microsoft::WRL::RuntimeClass< + ABI::Windows::Foundation::IAsyncInfo, + ABI::Windows::Foundation::IAsyncOperation> +{ + using Handler = ABI::Windows::Foundation::IAsyncOperationCompletedHandler; + + // IAsyncInfo + IFACEMETHODIMP get_Id(unsigned int*) override + { + return E_NOTIMPL; + } + + IFACEMETHODIMP get_Status(AsyncStatus* status) override + { + auto lock = m_lock.lock_shared(); + *status = m_status; + return S_OK; + } + + IFACEMETHODIMP get_ErrorCode(HRESULT* errorCode) override + { + auto lock = m_lock.lock_shared(); + *errorCode = m_result; + return S_OK; + } + + IFACEMETHODIMP Cancel() override + { + return E_NOTIMPL; + } + + IFACEMETHODIMP Close() override + { + return E_NOTIMPL; + } + + // IAsyncOperation + IFACEMETHODIMP put_Completed(Handler* handler) override + { + bool invoke = false; + { + auto lock = m_lock.lock_exclusive(); + if (m_handler) + { + return E_FAIL; + } + + m_handler = handler; + invoke = m_status != ABI::Windows::Foundation::AsyncStatus::Started; + } + + if (invoke) + { + handler->Invoke(this, m_status); + } + + return S_OK; + } + + IFACEMETHODIMP get_Completed(Handler** handler) override + { + auto lock = m_lock.lock_shared(); + *handler = Microsoft::WRL::ComPtr(m_handler).Detach(); + return S_OK; + } + + IFACEMETHODIMP GetResults(Abi* results) override + { + return m_storage.CopyTo(results); + } + + // Test functions + void Complete(HRESULT hr, Abi result) + { + using namespace ABI::Windows::Foundation; + Handler* handler = nullptr; + { + auto lock = m_lock.lock_exclusive(); + if (m_status == AsyncStatus::Started) + { + m_result = hr; + m_storage.Set(result); + m_status = SUCCEEDED(hr) ? AsyncStatus::Completed : AsyncStatus::Error; + handler = m_handler.Get(); + } + } + + if (handler) + { + handler->Invoke(this, m_status); + } + } + +private: + + wil::srwlock m_lock; + Microsoft::WRL::ComPtr m_handler; + ABI::Windows::Foundation::AsyncStatus m_status = ABI::Windows::Foundation::AsyncStatus::Started; + HRESULT m_result = S_OK; + WinRTStorage m_storage; +}; + +template +struct FakeVector : Microsoft::WRL::RuntimeClass< + ABI::Windows::Foundation::Collections::IVector, + ABI::Windows::Foundation::Collections::IVectorView> +{ + // IVector + IFACEMETHODIMP GetAt(unsigned index, Abi* item) override + { + if (index >= m_size) + { + return E_BOUNDS; + } + + return m_data[index].CopyTo(item); + } + + IFACEMETHODIMP get_Size(unsigned* size) override + { + *size = static_cast(m_size); + return S_OK; + } + + IFACEMETHODIMP GetView(ABI::Windows::Foundation::Collections::IVectorView** view) override + { + this->AddRef(); + *view = this; + return S_OK; + } + + IFACEMETHODIMP IndexOf(Abi value, unsigned* index, boolean* found) override + { + for (size_t i = 0; i < m_size; ++i) + { + if (m_data[i].Equals(value)) + { + *index = static_cast(i); + *found = true; + return S_OK; + } + } + + *index = 0; + *found = false; + return S_OK; + } + + IFACEMETHODIMP SetAt(unsigned index, Abi item) override + { + if (index >= m_size) + { + return E_BOUNDS; + } + + return m_data[index].Set(item); + } + + IFACEMETHODIMP InsertAt(unsigned index, Abi item) override + { + // Insert at the end and swap it into place + if (index > m_size) + { + return E_BOUNDS; + } + + auto hr = Append(item); + if (SUCCEEDED(hr)) + { + for (size_t i = m_size - 1; i > index; --i) + { + wistd::swap_wil(m_data[i], m_data[i - 1]); + } + } + + return hr; + } + + IFACEMETHODIMP RemoveAt(unsigned index) override + { + if (index >= m_size) + { + return E_BOUNDS; + } + + for (size_t i = index + 1; i < m_size; ++i) + { + wistd::swap_wil(m_data[i - 1], m_data[i]); + } + + m_data[--m_size].Reset(); + return S_OK; + } + + IFACEMETHODIMP Append(Abi item) override + { + if (m_size > MaxSize) + { + return E_OUTOFMEMORY; + } + + auto hr = m_data[m_size].Set(item); + if (SUCCEEDED(hr)) + { + ++m_size; + } + + return hr; + } + + IFACEMETHODIMP RemoveAtEnd() override + { + if (m_size == 0) + { + return E_BOUNDS; + } + + m_data[--m_size].Reset(); + return S_OK; + } + + IFACEMETHODIMP Clear() override + { + for (size_t i = 0; i < m_size; ++i) + { + m_data[i].Reset(); + } + + m_size = 0; + return S_OK; + } + + IFACEMETHODIMP GetMany(unsigned startIndex, unsigned capacity, Abi* value, unsigned* actual) override + { + *actual = 0; + if (startIndex >= m_size) + { + return S_OK; + } + + auto count = m_size - startIndex; + count = (count > capacity) ? capacity : count; + + HRESULT hr = S_OK; + unsigned i = 0; + for (; (i < count) && SUCCEEDED(hr); ++i) + { + hr = m_data[startIndex + i].CopyTo(value + i); + } + + if (SUCCEEDED(hr)) + { + *actual = static_cast(count); + } + else + { + while (i--) + { + WinRTStorage::Destroy(value[i]); + } + } + + return hr; + } + + IFACEMETHODIMP ReplaceAll(unsigned count, Abi* value) override + { + if (count > MaxSize) + { + return E_OUTOFMEMORY; + } + + Clear(); + + HRESULT hr = S_OK; + for (size_t i = 0; (i < count) && SUCCEEDED(hr); ++i) + { + hr = m_data[i].Set(value[i]); + } + + if (FAILED(hr)) + { + Clear(); + } + + return hr; + } + +private: + + size_t m_size = 0; + WinRTStorage m_data[MaxSize]; +}; diff --git a/tests/FileSystemTests.cpp b/tests/FileSystemTests.cpp new file mode 100644 index 0000000..6477b1f --- /dev/null +++ b/tests/FileSystemTests.cpp @@ -0,0 +1,448 @@ + +#include +#include // For wil::unique_hstring + +#include +#ifdef WIL_ENABLE_EXCEPTIONS +#include +#endif + +// TODO: str_raw_ptr is not two-phase name lookup clean (https://github.com/Microsoft/wil/issues/8) +namespace wil +{ + PCWSTR str_raw_ptr(HSTRING); +#ifdef WIL_ENABLE_EXCEPTIONS + PCWSTR str_raw_ptr(const std::wstring&); +#endif +} + +#include + +#ifdef WIL_ENABLE_EXCEPTIONS +#include // For std::wstring string_maker +#endif + +#include "common.h" + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + +bool DirectoryExists(_In_ PCWSTR path) +{ + DWORD dwAttrib = GetFileAttributesW(path); + + return (dwAttrib != INVALID_FILE_ATTRIBUTES && + (dwAttrib & FILE_ATTRIBUTE_DIRECTORY)); +} + +TEST_CASE("FileSystemTests::CreateDirectory", "[filesystem]") +{ + wchar_t basePath[MAX_PATH]; + REQUIRE(GetTempPathW(ARRAYSIZE(basePath), basePath)); + REQUIRE_SUCCEEDED(PathCchAppend(basePath, ARRAYSIZE(basePath), L"FileSystemTests")); + + REQUIRE_FALSE(DirectoryExists(basePath)); + REQUIRE(SUCCEEDED(wil::CreateDirectoryDeepNoThrow(basePath))); + REQUIRE(DirectoryExists(basePath)); + + auto scopeGuard = wil::scope_exit([&] + { + REQUIRE_SUCCEEDED(wil::RemoveDirectoryRecursiveNoThrow(basePath)); + }); + + PCWSTR relativeTestPath = L"folder1\\folder2\\folder3\\folder4\\folder5\\folder6\\folder7\\folder8"; + wchar_t absoluteTestPath[MAX_PATH]; + REQUIRE_SUCCEEDED(StringCchCopyW(absoluteTestPath, ARRAYSIZE(absoluteTestPath), basePath)); + REQUIRE_SUCCEEDED(PathCchAppend(absoluteTestPath, ARRAYSIZE(absoluteTestPath), relativeTestPath)); + REQUIRE_FALSE(DirectoryExists(absoluteTestPath)); + REQUIRE_SUCCEEDED(wil::CreateDirectoryDeepNoThrow(absoluteTestPath)); + + PCWSTR invalidCharsPath = L"Bad?Char|"; + wchar_t absoluteInvalidPath[MAX_PATH]; + REQUIRE_SUCCEEDED(StringCchCopyW(absoluteInvalidPath, ARRAYSIZE(absoluteInvalidPath), basePath)); + REQUIRE_SUCCEEDED(PathCchAppend(absoluteInvalidPath, ARRAYSIZE(absoluteInvalidPath), invalidCharsPath)); + REQUIRE_FALSE(DirectoryExists(absoluteInvalidPath)); + REQUIRE_FALSE(SUCCEEDED(wil::CreateDirectoryDeepNoThrow(absoluteInvalidPath))); + + PCWSTR testPath3 = L"folder1\\folder2\\folder3"; + wchar_t absoluteTestPath3[MAX_PATH]; + REQUIRE_SUCCEEDED(StringCchCopyW(absoluteTestPath3, ARRAYSIZE(absoluteTestPath3), basePath)); + REQUIRE_SUCCEEDED(PathCchAppend(absoluteTestPath3, ARRAYSIZE(absoluteTestPath3), testPath3)); + REQUIRE(DirectoryExists(absoluteTestPath3)); + + PCWSTR testPath4 = L"folder1\\folder2\\folder3\\folder4"; + wchar_t absoluteTestPath4[MAX_PATH]; + REQUIRE_SUCCEEDED(StringCchCopyW(absoluteTestPath4, ARRAYSIZE(absoluteTestPath4), basePath)); + REQUIRE_SUCCEEDED(PathCchAppend(absoluteTestPath4, ARRAYSIZE(absoluteTestPath4), testPath4)); + REQUIRE(DirectoryExists(absoluteTestPath4)); + + REQUIRE_SUCCEEDED(wil::RemoveDirectoryRecursiveNoThrow(absoluteTestPath3, wil::RemoveDirectoryOptions::KeepRootDirectory)); + REQUIRE(DirectoryExists(absoluteTestPath3)); + REQUIRE_FALSE(DirectoryExists(absoluteTestPath4)); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +// Learn about the Win32 API normalization here: https://blogs.msdn.microsoft.com/jeremykuhne/2016/04/21/path-normalization/ +// This test verifies the ability of RemoveDirectoryRecursive to be able to delete files +// that are in the non-normalized form. +TEST_CASE("FileSystemTests::VerifyRemoveDirectoryRecursiveCanDeleteFoldersWithNonNormalizedNames", "[filesystem]") +{ + // Extended length paths can access files with non-normalized names. + // This function creates a path with that ability. + auto CreatePathThatCanAccessNonNormalizedNames = [](PCWSTR root, PCWSTR name) + { + wil::unique_hlocal_string path; + THROW_IF_FAILED(PathAllocCombine(root, name, PATHCCH_DO_NOT_NORMALIZE_SEGMENTS | PATHCCH_ENSURE_IS_EXTENDED_LENGTH_PATH, &path)); + REQUIRE(wil::is_extended_length_path(path.get())); + return path; + }; + + // Regular paths are normalized in the Win32 APIs thus can't address files in the non-normalized form. + // This function creates a regular path form but preserves the non-normalized parts of the input (for testing) + auto CreateRegularPath = [](PCWSTR root, PCWSTR name) + { + wil::unique_hlocal_string path; + THROW_IF_FAILED(PathAllocCombine(root, name, PATHCCH_DO_NOT_NORMALIZE_SEGMENTS, &path)); + REQUIRE_FALSE(wil::is_extended_length_path(path.get())); + return path; + }; + + struct TestCases + { + PCWSTR CreateWithName; + PCWSTR DeleteWithName; + wil::unique_hlocal_string (*CreatePathFunction)(PCWSTR root, PCWSTR name); + HRESULT ExpectedResult; + }; + + PCWSTR NormalizedName = L"Foo"; + PCWSTR NonNormalizedName = L"Foo."; // The dot at the end is what makes this non-normalized. + const auto PathNotFoundError = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND); + + TestCases tests[] = + { + { NormalizedName, NormalizedName, CreateRegularPath, S_OK }, + { NonNormalizedName, NormalizedName, CreateRegularPath, PathNotFoundError }, + { NormalizedName, NonNormalizedName, CreateRegularPath, S_OK }, + { NonNormalizedName, NonNormalizedName, CreateRegularPath, PathNotFoundError }, + { NormalizedName, NormalizedName, CreatePathThatCanAccessNonNormalizedNames, S_OK }, + { NonNormalizedName, NormalizedName, CreatePathThatCanAccessNonNormalizedNames, PathNotFoundError }, + { NormalizedName, NonNormalizedName, CreatePathThatCanAccessNonNormalizedNames, PathNotFoundError }, + { NonNormalizedName, NonNormalizedName, CreatePathThatCanAccessNonNormalizedNames, S_OK }, + }; + + auto folderRoot = wil::ExpandEnvironmentStringsW(LR"(%TEMP%)"); + REQUIRE_FALSE(wil::is_extended_length_path(folderRoot.get())); + + auto EnsureFolderWithNonCanonicalNameAndContentsExists = [&](const TestCases& test) + { + const auto enableNonNormalized = PATHCCH_ENSURE_IS_EXTENDED_LENGTH_PATH | PATHCCH_DO_NOT_NORMALIZE_SEGMENTS; + + wil::unique_hlocal_string targetFolder; + // Create a folder for testing using the extended length form to enable + // access to non-normalized forms of the path + THROW_IF_FAILED(PathAllocCombine(folderRoot.get(), test.CreateWithName, enableNonNormalized, &targetFolder)); + + // This ensures the folder is there and won't fail if it already exists (common when testing). + wil::CreateDirectoryDeep(targetFolder.get()); + + // Create a file in that folder with a non-normalized name (with the dot at the end). + wil::unique_hlocal_string extendedFilePath; + THROW_IF_FAILED(PathAllocCombine(targetFolder.get(), L"NonNormalized.", enableNonNormalized, &extendedFilePath)); + wil::unique_hfile fileHandle(CreateFileW(extendedFilePath.get(), FILE_WRITE_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, + CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr)); + THROW_LAST_ERROR_IF(!fileHandle); + }; + + for (auto const& test : tests) + { + // remove remnants from previous test that will cause failures + wil::RemoveDirectoryRecursiveNoThrow(CreatePathThatCanAccessNonNormalizedNames(folderRoot.get(), NormalizedName).get()); + wil::RemoveDirectoryRecursiveNoThrow(CreatePathThatCanAccessNonNormalizedNames(folderRoot.get(), NonNormalizedName).get()); + + EnsureFolderWithNonCanonicalNameAndContentsExists(test); + auto deleteWithPath = test.CreatePathFunction(folderRoot.get(), test.DeleteWithName); + + const auto hr = wil::RemoveDirectoryRecursiveNoThrow(deleteWithPath.get()); + REQUIRE(test.ExpectedResult == hr); + } +} +#endif + +// real paths to test +const wchar_t c_variablePath[] = L"%systemdrive%\\Windows\\System32\\Windows.Storage.dll"; +const wchar_t c_expandedPath[] = L"c:\\Windows\\System32\\Windows.Storage.dll"; + +// // paths that should not exist on the system +const wchar_t c_missingVariable[] = L"%doesnotexist%\\doesnotexist.dll"; +const wchar_t c_missingPath[] = L"c:\\Windows\\System32\\doesnotexist.dll"; + +const int c_stackBufferLimitTest = 5; + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("FileSystemTests::VerifyGetCurrentDirectory", "[filesystem]") +{ + auto pwd = wil::GetCurrentDirectoryW(); + REQUIRE(*pwd.get() != L'\0'); +} + +TEST_CASE("FileSystemTests::VerifyGetFullPathName", "[filesystem]") +{ + PCWSTR fileName = L"ReadMe.txt"; + auto result = wil::GetFullPathNameW(fileName, nullptr); + + PCWSTR fileNameResult; + result = wil::GetFullPathNameW(fileName, &fileNameResult); + REQUIRE(wcscmp(fileName, fileNameResult) == 0); + auto result2 = wil::GetFullPathNameW(fileName, &fileNameResult); + REQUIRE(wcscmp(fileName, fileNameResult) == 0); + REQUIRE(wcscmp(result.get(), result2.get()) == 0); + + // The only negative test case I've found is a path > 32k. + std::wstring big(1024 * 32, L'a'); + wil::unique_hstring output; + auto hr = wil::GetFullPathNameW(big.c_str(), output, nullptr); + REQUIRE(hr == HRESULT_FROM_WIN32(ERROR_FILENAME_EXCED_RANGE)); +} + +TEST_CASE("FileSystemTests::VerifyGetFinalPathNameByHandle", "[filesystem]") +{ + wil::unique_hfile fileHandle(CreateFileW(c_expandedPath, FILE_READ_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, nullptr)); + THROW_LAST_ERROR_IF(!fileHandle); + + auto name = wil::GetFinalPathNameByHandleW(fileHandle.get()); + auto name2 = wil::GetFinalPathNameByHandleW(fileHandle.get()); + REQUIRE(wcscmp(name.get(), name2.get()) == 0); + + std::wstring path; + auto hr = wil::GetFinalPathNameByHandleW(nullptr, path); + REQUIRE(hr == E_HANDLE); // should be a usage error so be a fail fast. + // A more legitimate case is a non file handler like a drive volume. + + wil::unique_hfile volumeHandle(CreateFileW(LR"(\\?\C:)", FILE_READ_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, + FILE_FLAG_BACKUP_SEMANTICS, nullptr)); + THROW_LAST_ERROR_IF(!volumeHandle); + const auto hr2 = wil::GetFinalPathNameByHandleW(volumeHandle.get(), path); + REQUIRE(hr2 == HRESULT_FROM_WIN32(ERROR_INVALID_FUNCTION)); +} + +TEST_CASE("FileSystemTests::VerifyTrySearchPathW", "[filesystem]") +{ + auto pathToTest = wil::TrySearchPathW(nullptr, c_expandedPath, nullptr); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + pathToTest = wil::TrySearchPathW(nullptr, c_missingPath, nullptr); + REQUIRE(wil::string_get_not_null(pathToTest)[0] == L'\0'); +} +#endif + +// Simple test to expand an environmental string +TEST_CASE("FileSystemTests::VerifyExpandEnvironmentStringsW", "[filesystem]") +{ + wil::unique_cotaskmem_string pathToTest; + REQUIRE_SUCCEEDED(wil::ExpandEnvironmentStringsW(c_variablePath, pathToTest)); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + // This should effectively be a no-op + REQUIRE_SUCCEEDED(wil::ExpandEnvironmentStringsW(c_expandedPath, pathToTest)); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + // Environment variable does not exist, but the call should still succeed + REQUIRE_SUCCEEDED(wil::ExpandEnvironmentStringsW(c_missingVariable, pathToTest)); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_missingVariable, -1, TRUE) == CSTR_EQUAL); +} + +TEST_CASE("FileSystemTests::VerifySearchPathW", "[filesystem]") +{ + wil::unique_cotaskmem_string pathToTest; + REQUIRE_SUCCEEDED(wil::SearchPathW(nullptr, c_expandedPath, nullptr, pathToTest)); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + REQUIRE(HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == wil::SearchPathW(nullptr, c_missingPath, nullptr, pathToTest)); +} + +TEST_CASE("FileSystemTests::VerifyExpandEnvAndSearchPath", "[filesystem]") +{ + wil::unique_cotaskmem_string pathToTest; + REQUIRE_SUCCEEDED(wil::ExpandEnvAndSearchPath(c_variablePath, pathToTest)); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + // This test will exercise the case where AdaptFixedSizeToAllocatedResult will need to + // reallocate the initial buffer to fit the final string. + // This test is sufficient to test both wil::ExpandEnvironmentStringsW and wil::SeachPathW + REQUIRE_SUCCEEDED((wil::ExpandEnvAndSearchPath(c_variablePath, pathToTest))); + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, c_expandedPath, -1, TRUE) == CSTR_EQUAL); + + pathToTest.reset(); + REQUIRE(HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) == wil::ExpandEnvAndSearchPath(c_missingVariable, pathToTest)); + REQUIRE(pathToTest.get() == nullptr); +} + +TEST_CASE("FileSystemTests::VerifyGetSystemDirectoryW", "[filesystem]") +{ + wil::unique_cotaskmem_string pathToTest; + REQUIRE_SUCCEEDED(wil::GetSystemDirectoryW(pathToTest)); + + // allocate based on the string that wil::GetSystemDirectoryW returned + size_t length = wcslen(pathToTest.get()) + 1; + auto trueSystemDir = wil::make_cotaskmem_string_nothrow(nullptr, length); + REQUIRE(GetSystemDirectoryW(trueSystemDir.get(), static_cast(length)) > 0); + + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, trueSystemDir.get(), -1, TRUE) == CSTR_EQUAL); + + // Force AdaptFixed* to realloc. Test stack boundary with small initial buffer limit, c_stackBufferLimitTest + REQUIRE_SUCCEEDED((wil::GetSystemDirectoryW(pathToTest))); + + // allocate based on the string that wil::GetSystemDirectoryW returned + length = wcslen(pathToTest.get()) + 1; + trueSystemDir = wil::make_cotaskmem_string_nothrow(nullptr, length); + REQUIRE(GetSystemDirectoryW(trueSystemDir.get(), static_cast(length)) > 0); + + REQUIRE(CompareStringOrdinal(pathToTest.get(), -1, trueSystemDir.get(), -1, TRUE) == CSTR_EQUAL); +} + +struct has_operator_pcwstr +{ + PCWSTR value; + operator PCWSTR() const + { + return value; + } +}; + +struct has_operator_pwstr +{ + PWSTR value; + operator PWSTR() const + { + return value; + } +}; + +#ifdef WIL_ENABLE_EXCEPTIONS +struct has_operator_wstr +{ + std::wstring value; + operator const std::wstring&() const + { + return value; + } +}; +#endif + +TEST_CASE("FileSystemTests::VerifyStrConcat", "[filesystem]") +{ + SECTION("Concat with multiple strings") + { + PCWSTR test1 = L"Test1"; +#ifdef WIL_ENABLE_EXCEPTIONS + std::wstring test2 = L"Test2"; +#else + PCWSTR test2 = L"Test2"; +#endif + WCHAR test3[6] = L"Test3"; + wil::unique_cotaskmem_string test4 = wil::make_unique_string_nothrow(L"test4"); + wil::unique_hstring test5 = wil::make_unique_string_nothrow(L"test5"); + + has_operator_pcwstr test6{ L"Test6" }; + WCHAR test7Buffer[] = L"Test7"; + has_operator_pwstr test7{ test7Buffer }; + +#ifdef WIL_ENABLE_EXCEPTIONS + has_operator_wstr test8{ L"Test8" }; +#else + PCWSTR test8 = L"Test8"; +#endif + PCWSTR expectedStr = L"Test1Test2Test3Test4Test5Test6Test7Test8"; + +#ifdef WIL_ENABLE_EXCEPTIONS + auto combinedString = wil::str_concat(test1, test2, test3, test4, test5, test6, test7, test8); + REQUIRE(CompareStringOrdinal(combinedString.get(), -1, expectedStr, -1, TRUE) == CSTR_EQUAL); +#endif + + wil::unique_cotaskmem_string combinedStringNT; + REQUIRE_SUCCEEDED(wil::str_concat_nothrow(combinedStringNT, test1, test2, test3, test4, test5, test6, test7, test8)); + REQUIRE(CompareStringOrdinal(combinedStringNT.get(), -1, expectedStr, -1, TRUE) == CSTR_EQUAL); + + auto combinedStringFF = wil::str_concat_failfast(test1, test2, test3, test4, test5, test6, test7, test8); + REQUIRE(CompareStringOrdinal(combinedStringFF.get(), -1, expectedStr, -1, TRUE) == CSTR_EQUAL); + } + + SECTION("Concat with single string") + { + PCWSTR test1 = L"Test1"; + +#ifdef WIL_ENABLE_EXCEPTIONS + auto combinedString = wil::str_concat(test1); + REQUIRE(CompareStringOrdinal(combinedString.get(), -1, test1, -1, TRUE) == CSTR_EQUAL); +#endif + + wil::unique_cotaskmem_string combinedStringNT; + REQUIRE_SUCCEEDED(wil::str_concat_nothrow(combinedStringNT, test1)); + REQUIRE(CompareStringOrdinal(combinedStringNT.get(), -1, test1, -1, TRUE) == CSTR_EQUAL); + + auto combinedStringFF = wil::str_concat_failfast(test1); + REQUIRE(CompareStringOrdinal(combinedStringFF.get(), -1, test1, -1, TRUE) == CSTR_EQUAL); + } + + SECTION("Concat with existing string") + { + std::wstring test2 = L"Test2"; + WCHAR test3[6] = L"Test3"; + PCWSTR expectedStr = L"Test1Test2Test3"; + + wil::unique_cotaskmem_string combinedStringNT = wil::make_unique_string_nothrow(L"Test1"); + REQUIRE_SUCCEEDED(wil::str_concat_nothrow(combinedStringNT, test2.c_str(), test3)); + REQUIRE(CompareStringOrdinal(combinedStringNT.get(), -1, expectedStr, -1, TRUE) == CSTR_EQUAL); + } +} + +TEST_CASE("FileSystemTests::VerifyStrPrintf", "[filesystem]") +{ +#ifdef WIL_ENABLE_EXCEPTIONS + auto formattedString = wil::str_printf(L"Test %s %c %d %4.2f", L"String", L'c', 42, 6.28); + REQUIRE(CompareStringOrdinal(formattedString.get(), -1, L"Test String c 42 6.28", -1, TRUE) == CSTR_EQUAL); +#endif + + wil::unique_cotaskmem_string formattedStringNT; + REQUIRE_SUCCEEDED(wil::str_printf_nothrow(formattedStringNT, L"Test %s %c %d %4.2f", L"String", L'c', 42, 6.28)); + REQUIRE(CompareStringOrdinal(formattedStringNT.get(), -1, L"Test String c 42 6.28", -1, TRUE) == CSTR_EQUAL); + + auto formattedStringFF = wil::str_printf_failfast(L"Test %s %c %d %4.2f", L"String", L'c', 42, 6.28); + REQUIRE(CompareStringOrdinal(formattedStringFF.get(), -1, L"Test String c 42 6.28", -1, TRUE) == CSTR_EQUAL); +} + +TEST_CASE("FileSystemTests::VerifyGetModuleFileNameW", "[filesystem]") +{ + wil::unique_cotaskmem_string path; + REQUIRE_SUCCEEDED(wil::GetModuleFileNameW(nullptr, path)); + auto len = wcslen(path.get()); + REQUIRE(((len >= 4) && (wcscmp(path.get() + len - 4, L".exe") == 0))); + + // Call again, but force multiple retries through a small initial buffer + wil::unique_cotaskmem_string path2; + REQUIRE_SUCCEEDED((wil::GetModuleFileNameW(nullptr, path2))); + REQUIRE(wcscmp(path.get(), path2.get()) == 0); + + REQUIRE_FAILED(wil::GetModuleFileNameW((HMODULE)INVALID_HANDLE_VALUE, path)); +} + +TEST_CASE("FileSystemTests::VerifyGetModuleFileNameExW", "[filesystem]") +{ + wil::unique_cotaskmem_string path; + REQUIRE_SUCCEEDED(wil::GetModuleFileNameExW(nullptr, nullptr, path)); + auto len = wcslen(path.get()); + REQUIRE(((len >= 4) && (wcscmp(path.get() + len - 4, L".exe") == 0))); + + // Call again, but force multiple retries through a small initial buffer + wil::unique_cotaskmem_string path2; + REQUIRE_SUCCEEDED((wil::GetModuleFileNameExW(nullptr, nullptr, path2))); + REQUIRE(wcscmp(path.get(), path2.get()) == 0); + + REQUIRE_FAILED(wil::GetModuleFileNameExW(nullptr, (HMODULE)INVALID_HANDLE_VALUE, path)); +} + +#endif // WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) diff --git a/tests/MallocSpy.h b/tests/MallocSpy.h new file mode 100644 index 0000000..b2d50aa --- /dev/null +++ b/tests/MallocSpy.h @@ -0,0 +1,145 @@ +#pragma once + +#include "catch.hpp" +#include +#include +#include + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM) + +// IMallocSpy requires you to implement all methods, but we often only want one or two... +struct MallocSpy : Microsoft::WRL::RuntimeClass, IMallocSpy> +{ + wistd::function PreAllocCallback; + virtual SIZE_T STDMETHODCALLTYPE PreAlloc(SIZE_T requestSize) override + { + if (PreAllocCallback) + { + return PreAllocCallback(requestSize); + } + + return requestSize; + } + + wistd::function PostAllocCallback; + virtual void* STDMETHODCALLTYPE PostAlloc(void* ptr) override + { + if (PostAllocCallback) + { + return PostAllocCallback(ptr); + } + + return ptr; + } + + wistd::function PreFreeCallback; + virtual void* STDMETHODCALLTYPE PreFree(void* ptr, BOOL wasSpyed) override + { + if (wasSpyed && PreFreeCallback) + { + return PreFreeCallback(ptr); + } + + return ptr; + } + + virtual void STDMETHODCALLTYPE PostFree(BOOL /*wasSpyed*/) override + { + } + + wistd::function PreReallocCallback; + virtual SIZE_T STDMETHODCALLTYPE PreRealloc(void* ptr, SIZE_T requestSize, void** newPtr, BOOL wasSpyed) override + { + *newPtr = ptr; + if (wasSpyed && PreReallocCallback) + { + return PreReallocCallback(ptr, requestSize, newPtr); + } + + return requestSize; + } + + wistd::function PostReallocCallback; + virtual void* STDMETHODCALLTYPE PostRealloc(void* ptr, BOOL wasSpyed) override + { + if (wasSpyed && PostReallocCallback) + { + return PostReallocCallback(ptr); + } + + return ptr; + } + + wistd::function PreGetSizeCallback; + virtual void* STDMETHODCALLTYPE PreGetSize(void* ptr, BOOL wasSpyed) override + { + if (wasSpyed && PreGetSizeCallback) + { + return PreGetSizeCallback(ptr); + } + + return ptr; + } + + wistd::function PostGetSizeCallback; + virtual SIZE_T STDMETHODCALLTYPE PostGetSize(SIZE_T size, BOOL wasSpyed) override + { + if (wasSpyed && PostGetSizeCallback) + { + return PostGetSizeCallback(size); + } + + return size; + } + + wistd::function PreDidAllocCallback; + virtual void* STDMETHODCALLTYPE PreDidAlloc(void* ptr, BOOL wasSpyed) override + { + if (wasSpyed && PreDidAllocCallback) + { + return PreDidAllocCallback(ptr); + } + + return ptr; + } + + virtual int STDMETHODCALLTYPE PostDidAlloc(void* /*ptr*/, BOOL /*wasSpyed*/, int result) override + { + return result; + } + + virtual void STDMETHODCALLTYPE PreHeapMinimize() override + { + } + + virtual void STDMETHODCALLTYPE PostHeapMinimize() override + { + } +}; + +Microsoft::WRL::ComPtr MakeSecureDeleterMallocSpy() +{ + using namespace Microsoft::WRL; + auto result = Make(); + REQUIRE(result); + + result->PreFreeCallback = [](void* ptr) + { + ComPtr malloc; + if (SUCCEEDED(::CoGetMalloc(1, &malloc))) + { + auto size = malloc->GetSize(ptr); + auto buffer = static_cast(ptr); + for (size_t i = 0; i < size; ++i) + { + REQUIRE(buffer[i] == 0); + } + } + + return ptr; + }; + + return result; +} + +#endif diff --git a/tests/ResourceTests.cpp b/tests/ResourceTests.cpp new file mode 100644 index 0000000..50edf6f --- /dev/null +++ b/tests/ResourceTests.cpp @@ -0,0 +1,598 @@ + +// Included first and then again later to ensure that we're able to "light up" new functionality based off new includes +#include + +// Headers to "light up" functionality in resource.h +#include +#include +#include + +#include +#include +#include +#include + +#include "common.h" + +TEST_CASE("ResourceTests::TestScopeExit", "[resource][scope_exit]") +{ + int count = 0; + auto validate = [&](int expected) { REQUIRE(count == expected); count = 0; }; + + { + auto foo = wil::scope_exit([&] { count++; }); + } + validate(1); + + { + auto foo = wil::scope_exit([&] { count++; }); + foo.release(); + foo.reset(); + } + validate(0); + + { + auto foo = wil::scope_exit([&] { count++; }); + foo.reset(); + foo.reset(); + validate(1); + } + validate(0); + +#ifdef WIL_ENABLE_EXCEPTIONS + { + auto foo = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { count++; THROW_HR(E_FAIL); }); + } + validate(1); + + { + auto foo = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { count++; THROW_HR(E_FAIL); }); + foo.release(); + foo.reset(); + } + validate(0); + + { + auto foo = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { count++; THROW_HR(E_FAIL); }); + foo.reset(); + foo.reset(); + validate(1); + } + validate(0); +#endif // WIL_ENABLE_EXCEPTIONS +} + +interface __declspec(uuid("ececcc6a-5193-4d14-b38e-ed1460c20b00")) +ITest : public IUnknown +{ + STDMETHOD_(void, Test)() = 0; +}; + +class PointerTestObject : witest::AllocatedObject, + public Microsoft::WRL::RuntimeClass, ITest> +{ +public: + STDMETHOD_(void, Test)() {}; +}; + +TEST_CASE("ResourceTests::TestOperationsOnGenericSmartPointerClasses", "[resource]") +{ +#ifdef WIL_ENABLE_EXCEPTIONS + { + // wil::unique_any_t example + wil::unique_event ptr2(wil::EventOptions::ManualReset); + // wil::com_ptr + wil::com_ptr ptr3 = Microsoft::WRL::Make(); + // wil::shared_any_t example + wil::shared_event ptr4(wil::EventOptions::ManualReset); + // wistd::unique_ptr example + auto ptr5 = wil::make_unique_failfast(); + + static_assert(wistd::is_same::pointer, HANDLE>::value, "type-mismatch"); + static_assert(wistd::is_same::pointer, PointerTestObject*>::value, "type-mismatch"); + + auto p2 = wil::detach_from_smart_pointer(ptr2); + auto p3 = wil::detach_from_smart_pointer(ptr3); + // auto p4 = wil::detach_from_smart_pointer(ptr4); // wil::shared_any_t and std::shared_ptr do not support release(). + HANDLE p4{}; + auto p5 = wil::detach_from_smart_pointer(ptr5); + + REQUIRE((!ptr2 && !ptr3)); + REQUIRE((p2 && p3)); + + wil::attach_to_smart_pointer(ptr2, p2); + wil::attach_to_smart_pointer(ptr3, p3); + wil::attach_to_smart_pointer(ptr4, p4); + wil::attach_to_smart_pointer(ptr5, p5); + + p2 = nullptr; + p3 = nullptr; + p4 = nullptr; + p5 = nullptr; + + wil::detach_to_opt_param(&p2, ptr2); + wil::detach_to_opt_param(&p3, ptr3); + + REQUIRE((!ptr2 && !ptr3)); + REQUIRE((p2 && p3)); + + wil::attach_to_smart_pointer(ptr2, p2); + wil::attach_to_smart_pointer(ptr3, p3); + p2 = nullptr; + p3 = nullptr; + + wil::detach_to_opt_param(&p2, ptr2); + wil::detach_to_opt_param(&p3, ptr3); + REQUIRE((!ptr2 && !ptr3)); + REQUIRE((p2 && p3)); + + [&](decltype(p2)* ptr) { *ptr = p2; } (wil::out_param(ptr2)); + [&](decltype(p3)* ptr) { *ptr = p3; } (wil::out_param(ptr3)); + [&](decltype(p4)* ptr) { *ptr = p4; } (wil::out_param(ptr4)); + [&](decltype(p5)* ptr) { *ptr = p5; } (wil::out_param(ptr5)); + + REQUIRE((ptr2 && ptr3)); + + // Validate R-Value compilation + wil::detach_to_opt_param(&p2, decltype(ptr2){}); + wil::detach_to_opt_param(&p3, decltype(ptr3){}); + } +#endif + + std::unique_ptr ptr1(new int(1)); + Microsoft::WRL::ComPtr ptr4 = Microsoft::WRL::Make(); + + static_assert(wistd::is_same::pointer, int*>::value, "type-mismatch"); + static_assert(wistd::is_same::pointer, PointerTestObject*>::value, "type-mismatch"); + + auto p1 = wil::detach_from_smart_pointer(ptr1); + auto p4 = wil::detach_from_smart_pointer(ptr4); + + REQUIRE((!ptr1 && !ptr4)); + REQUIRE((p1 && p4)); + + wil::attach_to_smart_pointer(ptr1, p1); + wil::attach_to_smart_pointer(ptr4, p4); + + REQUIRE((ptr1 && ptr4)); + + p1 = nullptr; + p4 = nullptr; + + int** pNull = nullptr; + wil::detach_to_opt_param(pNull, ptr1); + REQUIRE(ptr1); + + wil::detach_to_opt_param(&p1, ptr1); + wil::detach_to_opt_param(&p4, ptr4); + + REQUIRE((!ptr1 && !ptr4)); + REQUIRE((p1 && p4)); + + [&](decltype(p1)* ptr) { *ptr = p1; } (wil::out_param(ptr1)); + [&](decltype(p4)* ptr) { *ptr = p4; } (wil::out_param(ptr4)); + + REQUIRE((ptr1 && ptr4)); + + p1 = wil::detach_from_smart_pointer(ptr1); + [&](int** ptr) { *ptr = p1; } (wil::out_param_ptr(ptr1)); + REQUIRE(ptr1); +} + +// Compilation only test... +void StlAdlTest() +{ + // This test has exposed some Argument Dependent Lookup issues in wistd / stl. Primarily we're + // just looking for clean compilation. + + std::vector> v; + v.emplace_back(new int{ 1 }); + v.emplace_back(new int{ 2 }); + v.emplace_back(new int{ 3 }); + std::rotate(begin(v), begin(v) + 1, end(v)); + + REQUIRE(*v[0] == 1); + REQUIRE(*v[1] == 3); + REQUIRE(*v[2] == 2); + + decltype(v) v2; + v2 = std::move(v); + REQUIRE(*v2[0] == 1); + REQUIRE(*v2[1] == 3); + REQUIRE(*v2[2] == 2); + + decltype(v) v3; + std::swap(v2, v3); + REQUIRE(*v3[0] == 1); + REQUIRE(*v3[1] == 3); + REQUIRE(*v3[2] == 2); +} + +// Compilation only test... +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +void UniqueProcessInfo() +{ + wil::unique_process_information process; + CreateProcessW(nullptr, nullptr, nullptr, nullptr, FALSE, 0, nullptr, nullptr, nullptr, &process); + ResumeThread(process.hThread); + WaitForSingleObject(process.hProcess, INFINITE); + wil::unique_process_information other(wistd::move(process)); +} +#endif + +struct FakeComInterface +{ + void AddRef() + { + refs++; + } + void Release() + { + refs--; + } + + HRESULT __stdcall Close() + { + closes++; + return S_OK; + } + + size_t refs = 0; + size_t closes = 0; + + bool called() + { + auto old = closes; + closes = 0; + return (old > 0); + } + + bool has_ref() + { + return (refs > 0); + } +}; + +static void __stdcall CloseFakeComInterface(FakeComInterface* fake) +{ + fake->Close(); +} + +using unique_fakeclose_call = wil::unique_com_call; + +TEST_CASE("ResourceTests::VerifyUniqueComCall", "[resource][unique_com_call]") +{ + unique_fakeclose_call call1; + unique_fakeclose_call call2; + + // intentional compilation errors + // unique_fakeclose_call call3 = call1; + // call2 = call1; + + FakeComInterface fake1; + unique_fakeclose_call call4(&fake1); + REQUIRE(fake1.has_ref()); + + unique_fakeclose_call call5(wistd::move(call4)); + REQUIRE(!call4); + REQUIRE(call5); + REQUIRE(fake1.has_ref()); + + call4 = wistd::move(call5); + REQUIRE(call4); + REQUIRE(!call5); + REQUIRE(fake1.has_ref()); + REQUIRE(!fake1.called()); + + FakeComInterface fake2; + { + unique_fakeclose_call scoped(&fake2); + } + REQUIRE(!fake2.has_ref()); + REQUIRE(fake2.called()); + + call4.reset(&fake2); + REQUIRE(fake1.called()); + REQUIRE(!fake1.has_ref()); + call4.reset(); + REQUIRE(!fake2.has_ref()); + REQUIRE(fake2.called()); + + call1.reset(&fake1); + call2.swap(call1); + REQUIRE((call2 && !call1)); + + call2.release(); + REQUIRE(!fake1.called()); + REQUIRE(!fake1.has_ref()); + REQUIRE(!call2); + + REQUIRE(*call1.addressof() == nullptr); + + call1.reset(&fake1); + + fake2.closes = 0; + fake2.refs = 1; + *(&call1) = &fake2; + REQUIRE(!fake1.has_ref()); + REQUIRE(fake1.called()); + REQUIRE(fake2.has_ref()); + call1.reset(); + REQUIRE(!fake2.has_ref()); + REQUIRE(fake2.called()); +} + +static bool g_called = false; +static bool called() +{ + auto call = g_called; + g_called = false; + return (call); +} + +static void __stdcall FakeCall() +{ + g_called = true; +} + +using unique_fake_call = wil::unique_call; + +TEST_CASE("ResourceTests::VerifyUniqueCall", "[resource][unique_call]") +{ + unique_fake_call call1; + unique_fake_call call2; + + // intentional compilation errors + // unique_fake_call call3 = call1; + // call2 = call1; + + unique_fake_call call4; + REQUIRE(!called()); + + unique_fake_call call5(wistd::move(call4)); + REQUIRE(!call4); + REQUIRE(call5); + + call4 = wistd::move(call5); + REQUIRE(call4); + REQUIRE(!call5); + REQUIRE(!called()); + + { + unique_fake_call scoped; + } + REQUIRE(called()); + + call4.reset(); + REQUIRE(called()); + call4.reset(); + REQUIRE(!called()); + + call1.release(); + REQUIRE((!call1 && call2)); + call2.swap(call1); + REQUIRE((call1 && !call2)); + + call2.release(); + REQUIRE(!called()); + REQUIRE(!call2); + +#ifdef __WIL__ROAPI_H_APPEXCEPTIONAL + { + auto call = wil::RoInitialize(); + } +#endif +#ifdef __WIL__ROAPI_H_APP + { + wil::unique_rouninitialize_call uninit; + uninit.release(); + + auto call = wil::RoInitialize_failfast(); + } +#endif +#ifdef __WIL__COMBASEAPI_H_APPEXCEPTIONAL + { + auto call = wil::CoInitializeEx(); + } +#endif +#ifdef __WIL__COMBASEAPI_H_APP + { + wil::unique_couninitialize_call uninit; + uninit.release(); + + auto call = wil::CoInitializeEx_failfast(); + } +#endif +} + +void UniqueCallCompilationTest() +{ +#ifdef __WIL__COMBASEAPI_H_EXCEPTIONAL + { + auto call = wil::CoImpersonateClient(); + } +#endif +#ifdef __WIL__COMBASEAPI_H_ + { + wil::unique_coreverttoself_call uninit; + uninit.release(); + + auto call = wil::CoImpersonateClient_failfast(); + } +#endif +} + +template +static void TestStringMaker(VerifyContents&& verifyContents) +{ + PCWSTR values[] = + { + L"", + L"value", + // 300 chars + L"0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + L"0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + L"0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + }; + + for (const auto& value : values) + { + auto const valueLength = wcslen(value); + + // Direct construction case. + wil::details::string_maker maker; + THROW_IF_FAILED(maker.make(value, valueLength)); + auto result = maker.release(); + verifyContents(value, valueLength, result); + + // Two phase construction case. + THROW_IF_FAILED(maker.make(nullptr, valueLength)); + REQUIRE(maker.buffer() != nullptr); + // In the case of the wil::unique_hstring and the empty string the buffer is in a read only + // section and can't be written to, so StringCchCopy(maker.buffer(), valueLength + 1, value) will fault adding the nul terminator. + // Use memcpy_s specifying exact size that will be zero in this case instead. + memcpy_s(maker.buffer(), valueLength * sizeof(*value), value, valueLength * sizeof(*value)); + result = maker.release(); + verifyContents(value, valueLength, result); + + { + // no promote, ensure no leaks (not tested here, inspect in the debugger) + wil::details::string_maker maker2; + THROW_IF_FAILED(maker2.make(value, valueLength)); + } + } +} + +#ifdef WIL_ENABLE_EXCEPTIONS +template +static void VerifyMakeUniqueString(bool nullValueSupported = true) +{ + if (nullValueSupported) + { + auto value0 = wil::make_unique_string(nullptr, 5); + } + + struct + { + PCWSTR expectedValue; + PCWSTR testValue; + // this is an optional parameter + size_t testLength = static_cast(-1); + } + const testCaseEntries[] = + { + { L"value", L"value", 5 }, + { L"value", L"value" }, + { L"va", L"va\0ue", 5 }, + { L"v", L"value", 1 }, + { L"\0", L"", 5 }, + { L"\0", nullptr, 5 }, + }; + + using maker = wil::details::string_maker; + for (auto const &entry : testCaseEntries) + { + bool shouldSkipNullString = ((wcscmp(entry.expectedValue, L"\0") == 0) && !nullValueSupported); + if (!shouldSkipNullString) + { + auto desiredValue = wil::make_unique_string(entry.expectedValue); + auto stringValue = wil::make_unique_string(entry.testValue, entry.testLength); + auto stringValueNoThrow = wil::make_unique_string_nothrow(entry.testValue, entry.testLength); + auto stringValueFailFast = wil::make_unique_string_failfast(entry.testValue, entry.testLength); + REQUIRE(wcscmp(maker::get(desiredValue), maker::get(stringValue)) == 0); + REQUIRE(wcscmp(maker::get(desiredValue), maker::get(stringValueNoThrow)) == 0); + REQUIRE(wcscmp(maker::get(desiredValue), maker::get(stringValueFailFast)) == 0); + } + } +} + +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerCoTaskMem", "[resource][string_maker]") +{ + VerifyMakeUniqueString(); + TestStringMaker( + [](PCWSTR value, size_t /*valueLength*/, const wil::unique_cotaskmem_string& result) + { + REQUIRE(wcscmp(value, result.get()) == 0); + }); +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerLocalAlloc", "[resource][string_maker]") +{ + VerifyMakeUniqueString(); + TestStringMaker( + [](PCWSTR value, size_t /*valueLength*/, const wil::unique_hlocal_string& result) + { + REQUIRE(wcscmp(value, result.get()) == 0); + }); +} + +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerGlobalAlloc", "[resource][string_maker]") +{ + VerifyMakeUniqueString(); + TestStringMaker( + [](PCWSTR value, size_t /*valueLength*/, const wil::unique_hglobal_string& result) + { + REQUIRE(wcscmp(value, result.get()) == 0); + }); +} + +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerProcessHeap", "[resource][string_maker]") +{ + VerifyMakeUniqueString(); + TestStringMaker( + [](PCWSTR value, size_t /*valueLength*/, const wil::unique_process_heap_string& result) + { + REQUIRE(wcscmp(value, result.get()) == 0); + }); +} +#endif + +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerHString", "[resource][string_maker]") +{ + wil::unique_hstring value; + value.reset(static_cast(nullptr)); + + VerifyMakeUniqueString(false); + + TestStringMaker( + [](PCWSTR value, size_t valueLength, const wil::unique_hstring& result) + { + UINT32 length; + REQUIRE(wcscmp(value, WindowsGetStringRawBuffer(result.get(), &length)) == 0); + REQUIRE(valueLength == length); + }); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("UniqueStringAndStringMakerTests::VerifyStringMakerStdWString", "[resource][string_maker]") +{ + std::string s; + wil::details::string_maker maker; + + TestStringMaker( + [](PCWSTR value, size_t valueLength, const std::wstring& result) + { + REQUIRE(wcscmp(value, result.c_str()) == 0); + REQUIRE(result == value); + REQUIRE(result.size() == valueLength); + }); +} +#endif + +TEST_CASE("UniqueStringAndStringMakerTests::VerifyLegacySTringMakers", "[resource][string_maker]") +{ +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + auto l = wil::make_hlocal_string(L"value"); + l = wil::make_hlocal_string_nothrow(L"value"); + l = wil::make_hlocal_string_failfast(L"value"); + + auto p = wil::make_process_heap_string(L"value"); + p = wil::make_process_heap_string_nothrow(L"value"); + p = wil::make_process_heap_string_failfast(L"value"); +#endif + auto c = wil::make_cotaskmem_string(L"value"); + c = wil::make_cotaskmem_string_nothrow(L"value"); + c = wil::make_cotaskmem_string_failfast(L"value"); +} +#endif diff --git a/tests/ResultTests.cpp b/tests/ResultTests.cpp new file mode 100644 index 0000000..f991589 --- /dev/null +++ b/tests/ResultTests.cpp @@ -0,0 +1,574 @@ + +#include +#include +#include +#include + +#include "common.h" + +static volatile long objectCount = 0; +struct SharedObject +{ + SharedObject() + { + ::InterlockedIncrement(&objectCount); + } + + ~SharedObject() + { + ::InterlockedDecrement(&objectCount); + } + + void ProcessShutdown() + { + } + + int value; +}; + +TEST_CASE("ResultTests::SemaphoreValue", "[result]") +{ + auto TestValue = [&](auto start, auto end) + { + wil::details_abi::SemaphoreValue semaphore; + for (auto index = start; index <= end; index++) + { + semaphore.Destroy(); + REQUIRE(SUCCEEDED(semaphore.CreateFromValue(L"test", index))); + + auto num1 = index; + auto num2 = index; + REQUIRE(SUCCEEDED(semaphore.TryGetValue(L"test", &num1))); + REQUIRE(SUCCEEDED(semaphore.TryGetValue(L"test", &num2))); + REQUIRE(num1 == index); + REQUIRE(num2 == index); + } + }; + + // Test 32-bit values (edge cases) + TestValue(0u, 10u); + TestValue(250u, 260u); + TestValue(0x7FFFFFF0u, 0x7FFFFFFFu); + + // Test 64-bit values (edge cases) + TestValue(0ull, 10ull); + TestValue(250ull, 260ull); + TestValue(0x000000007FFFFFF0ull, 0x000000008000000Full); + TestValue(0x00000000FFFFFFF0ull, 0x000000010000000Full); + TestValue(0x00000000FFFFFFF0ull, 0x000000010000000Full); + TestValue(0x3FFFFFFFFFFFFFF0ull, 0x3FFFFFFFFFFFFFFFull); + + // Test pointer values + wil::details_abi::SemaphoreValue semaphore; + void* address = &semaphore; + REQUIRE(SUCCEEDED(semaphore.CreateFromPointer(L"test", address))); + void* ptr; + REQUIRE(SUCCEEDED(semaphore.TryGetPointer(L"test", &ptr))); + REQUIRE(ptr == address); +} + +TEST_CASE("ResultTests::ProcessLocalStorage", "[result]") +{ + // Test process local storage memory and ref-counting + { + wil::details_abi::ProcessLocalStorage obj1("ver1"); + wil::details_abi::ProcessLocalStorage obj2("ver1"); + + auto& o1 = *obj1.GetShared(); + auto& o2 = *obj2.GetShared(); + + REQUIRE(o1.value == 0); + REQUIRE(o2.value == 0); + o1.value = 42; + REQUIRE(o2.value == 42); + REQUIRE(objectCount == 1); + + wil::details_abi::ProcessLocalStorage obj3("ver3"); + auto& o3 = *obj3.GetShared(); + + REQUIRE(o3.value == 0); + REQUIRE(objectCount == 2); + } + + REQUIRE(objectCount == 0); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +#pragma warning(push) +#pragma warning(disable: 4702) // Unreachable code +TEST_CASE("ResultTests::ExceptionHandling", "[result]") +{ + witest::TestFailureCache failures; + + SECTION("Test 'what()' implementation on ResultException") + { + auto swap = witest::AssignTemporaryValue(&wil::g_fResultThrowPlatformException, false); + try + { + THROW_HR(E_INVALIDARG); + FAIL("Expected an exception"); + } + catch (const std::exception& exception) + { + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == E_INVALIDARG); + auto what = exception.what(); + REQUIRE((what && *what)); + REQUIRE(strstr(what, "Exception") != nullptr); + } + } + failures.clear(); + + SECTION("Test messaging from an unhandled std exception") + { + // #pragma warning(suppress: 28931) // unused assignment -- it IS being used... seems like a tool issue. + auto hr = []() + { + try + { + throw std::runtime_error("runtime"); + } + catch (...) + { + RETURN_CAUGHT_EXCEPTION(); + } + }(); + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == HRESULT_FROM_WIN32(ERROR_UNHANDLED_EXCEPTION)); + REQUIRE(wcsstr(failures[0].pszMessage, L"runtime") != nullptr); // should get the exception what() string... + REQUIRE(hr == HRESULT_FROM_WIN32(ERROR_UNHANDLED_EXCEPTION)); + } + failures.clear(); + + SECTION("Test messaging from bad_alloc") + { + auto hr = []() -> HRESULT + { + try + { + throw std::bad_alloc(); + } + catch (...) + { + RETURN_CAUGHT_EXCEPTION(); + } + }(); + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == E_OUTOFMEMORY); + REQUIRE(wcsstr(failures[0].pszMessage, L"alloc") != nullptr); // should get the exception what() string... + REQUIRE(hr == E_OUTOFMEMORY); + } + failures.clear(); + + SECTION("Test messaging from a WIL exception") + { + auto hr = []() -> HRESULT + { + try + { + THROW_HR(E_INVALIDARG); + } + catch (...) + { + RETURN_CAUGHT_EXCEPTION(); + } + return S_OK; + }(); + REQUIRE(failures.size() == 2); + REQUIRE(failures[0].hr == E_INVALIDARG); + REQUIRE(failures[0].pszMessage == nullptr); + REQUIRE(failures[1].hr == E_INVALIDARG); + REQUIRE(wcsstr(failures[1].pszMessage, L"Exception") != nullptr); // should get the exception debug string... + REQUIRE(hr == E_INVALIDARG); + } + failures.clear(); + + SECTION("Fail fast an unknown exception") + { + REQUIRE(witest::DoesCodeCrash([]() + { + try + { + throw E_INVALIDARG; // bad throw... (long) + } + catch (...) + { + RETURN_CAUGHT_EXCEPTION(); + } + })); + } + failures.clear(); + + SECTION("Log test (returns hr)") + { + HRESULT hr = S_OK; + try + { + throw std::bad_alloc(); + } + catch (...) + { + hr = LOG_CAUGHT_EXCEPTION(); + auto hrDirect = wil::ResultFromCaughtException(); + REQUIRE(hr == hrDirect); + } + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == E_OUTOFMEMORY); + REQUIRE(wcsstr(failures[0].pszMessage, L"alloc") != nullptr); // should get the exception what() string... + REQUIRE(hr == E_OUTOFMEMORY); + } + failures.clear(); + + SECTION("Fail-fast test") + { + REQUIRE_CRASH([]() + { + try + { + throw std::bad_alloc(); + } + catch (...) + { + FAIL_FAST_CAUGHT_EXCEPTION(); + } + }()); + } + failures.clear(); + + SECTION("Exception test (different exception type thrown...)") + { + auto swap = witest::AssignTemporaryValue(&wil::g_fResultThrowPlatformException, false); + size_t line = 0; + try + { + try + { + throw std::bad_alloc(); + } + catch (...) + { + line = __LINE__; THROW_NORMALIZED_CAUGHT_EXCEPTION(); + } + } + catch (const wil::ResultException& exception) + { + REQUIRE(exception.GetFailureInfo().uLineNumber == line); // should have thrown new, so we should have the rethrow line number + REQUIRE(exception.GetErrorCode() == E_OUTOFMEMORY); + } + catch (...) + { + FAIL(); + } + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == E_OUTOFMEMORY); + REQUIRE(wcsstr(failures[0].pszMessage, L"alloc") != nullptr); // should get the exception what() string... + } + failures.clear(); + + SECTION("Exception test (rethrow same exception type...)") + { + auto swap = witest::AssignTemporaryValue(&wil::g_fResultThrowPlatformException, false); + size_t line = 0; + try + { + try + { + line = __LINE__; THROW_HR(E_OUTOFMEMORY); + } + catch (...) + { + THROW_NORMALIZED_CAUGHT_EXCEPTION(); + } + } + catch (const wil::ResultException& exception) + { + REQUIRE(exception.GetFailureInfo().uLineNumber == line); // should have re-thrown the original exception (with the original line number) + } + catch (...) + { + FAIL(); + } + } + failures.clear(); + + SECTION("Test catch message") + { + try + { + throw std::bad_alloc(); + } + catch (...) + { + LOG_CAUGHT_EXCEPTION_MSG("train: %d", 42); + } + REQUIRE(failures.size() == 1); + REQUIRE(failures[0].hr == E_OUTOFMEMORY); + REQUIRE(wcsstr(failures[0].pszMessage, L"alloc") != nullptr); // should get the exception what() string... + REQUIRE(wcsstr(failures[0].pszMessage, L"train") != nullptr); // should *also* get the message... + REQUIRE(wcsstr(failures[0].pszMessage, L"42") != nullptr); + } + failures.clear(); + + SECTION("Test messaging from a WIL exception") + { + auto hr = []() -> HRESULT + { + try + { + throw std::bad_alloc(); + } + catch (...) + { + RETURN_CAUGHT_EXCEPTION_EXPECTED(); + } + }(); + REQUIRE(failures.size() == 0); + REQUIRE(hr == E_OUTOFMEMORY); + } + failures.clear(); + + SECTION("Test ResultFromException...") + { + auto hrOk = wil::ResultFromException([&] + { + }); + REQUIRE(hrOk == S_OK); + + auto hr = wil::ResultFromException([&] + { + throw std::bad_alloc(); + }); + REQUIRE(failures.size() == 0); + REQUIRE(hr == E_OUTOFMEMORY); + } + failures.clear(); + + SECTION("Explicit failfast for unrecognized") + { + REQUIRE_CRASH(wil::ResultFromException([&] + { + throw E_FAIL; + })); + } + failures.clear(); + + SECTION("Manual debug-only validation of the SEH failfast") + { + auto hr1 = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, [&]() + { + // Uncomment to test SEH fail-fast + // throw E_FAIL; + }); + REQUIRE(hr1 == S_OK); + + auto hr2 = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, wil::SupportedExceptions::Thrown, [&] + { + // Uncomment to test SEH fail-fast + // throw std::range_error("range"); + }); + REQUIRE(hr2 == S_OK); + + wil::FailFastException(WI_DIAGNOSTICS_INFO, [&] + { + // Uncomment to test SEH fail-fast + // THROW_HR(E_FAIL); + }); + } + failures.clear(); + + SECTION("Standard") + { + auto line = __LINE__; auto hr = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, [&] + { + THROW_HR(E_INVALIDARG); + }); + REQUIRE(failures.size() == 2); + REQUIRE(static_cast(failures[1].uLineNumber) == line); + REQUIRE(hr == E_INVALIDARG); + } + failures.clear(); + + SECTION("bad_alloc") + { + auto hr = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, [&] + { + throw std::bad_alloc(); + }); + REQUIRE(failures.size() == 1); + REQUIRE(hr == E_OUTOFMEMORY); + } + failures.clear(); + + SECTION("std::exception") + { + auto hr = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, [&] + { + throw std::range_error("range"); + }); + REQUIRE(failures.size() == 1); + REQUIRE(wcsstr(failures[0].pszMessage, L"range") != nullptr); + REQUIRE(hr == HRESULT_FROM_WIN32(ERROR_UNHANDLED_EXCEPTION)); + } +} + +void ExceptionHandlingCompilationTest() +{ + []{ try { throw std::bad_alloc(); } CATCH_RETURN(); }(); + []{ try { throw std::bad_alloc(); } CATCH_RETURN_MSG("train: %d", 42); }(); + []{ try { throw std::bad_alloc(); } CATCH_RETURN_EXPECTED(); }(); + []{ try { throw std::bad_alloc(); } catch (...) { RETURN_CAUGHT_EXCEPTION(); } }(); + []{ try { throw std::bad_alloc(); } catch (...) { RETURN_CAUGHT_EXCEPTION_MSG("train: %d", 42); } }(); + []{ try { throw std::bad_alloc(); } catch (...) { RETURN_CAUGHT_EXCEPTION_EXPECTED(); } }(); + + try { throw std::bad_alloc(); } CATCH_LOG(); + try { throw std::bad_alloc(); } CATCH_LOG_MSG("train: %d", 42); + try { throw std::bad_alloc(); } catch (...) { LOG_CAUGHT_EXCEPTION(); } + try { throw std::bad_alloc(); } catch (...) { LOG_CAUGHT_EXCEPTION_MSG("train: %d", 42); } + + try { throw std::bad_alloc(); } CATCH_FAIL_FAST(); + try { throw std::bad_alloc(); } CATCH_FAIL_FAST_MSG("train: %d", 42); + try { throw std::bad_alloc(); } catch (...) { FAIL_FAST_CAUGHT_EXCEPTION(); } + try { throw std::bad_alloc(); } catch (...) { FAIL_FAST_CAUGHT_EXCEPTION_MSG("train: %d", 42); } + + try { try { throw std::bad_alloc(); } CATCH_THROW_NORMALIZED(); } catch (...) {} + try { try { throw std::bad_alloc(); } CATCH_THROW_NORMALIZED_MSG("train: %d", 42); } catch (...) {} + try { try { throw std::bad_alloc(); } catch (...) { THROW_NORMALIZED_CAUGHT_EXCEPTION(); } } catch (...) {} + try { try { throw std::bad_alloc(); } catch (...) { THROW_NORMALIZED_CAUGHT_EXCEPTION_MSG("train: %d", 42); } } catch (...) {} + + HRESULT hr = wil::ResultFromExceptionDebug(WI_DIAGNOSTICS_INFO, wil::SupportedExceptions::All, [&] + { + THROW_HR(E_FAIL); + }); + + hr = wil::ResultFromException(WI_DIAGNOSTICS_INFO, wil::SupportedExceptions::None, [&] + { + }); + + hr = wil::ResultFromException([&] + { + }); + + wil::FailFastException(WI_DIAGNOSTICS_INFO, [&] + { + }); +} +#pragma warning(pop) +#endif + +TEST_CASE("ResultTests::ErrorMacros", "[result]") +{ + REQUIRE_ERROR(FAIL_FAST()); + REQUIRE_ERROR(FAIL_FAST_IF(true)); + REQUIRE_ERROR(FAIL_FAST_IF_NULL(nullptr)); + + REQUIRE_NOERROR(FAIL_FAST_IF(false)); + REQUIRE_NOERROR(FAIL_FAST_IF_NULL(_ReturnAddress())); + + REQUIRE_ERROR(FAIL_FAST_MSG("%d", 42)); + REQUIRE_ERROR(FAIL_FAST_IF_MSG(true, "%d", 42)); + REQUIRE_ERROR(FAIL_FAST_IF_NULL_MSG(nullptr, "%d", 42)); + + REQUIRE_NOERROR(FAIL_FAST_IF_MSG(false, "%d", 42)); + REQUIRE_NOERROR(FAIL_FAST_IF_NULL_MSG(_ReturnAddress(), "%d", 42)); + + //wil::g_pfnResultLoggingCallback = ResultMacrosLoggingCallback; + SetLastError(ERROR_PRINTER_ALREADY_EXISTS); + REQUIRE_ERROR(__FAIL_FAST_ASSERT_WIN32_BOOL_FALSE__(FALSE)); + REQUIRE_NOERROR(__FAIL_FAST_ASSERT_WIN32_BOOL_FALSE__(TRUE)); +} + +// The originate helper isn't compatible with CX so don't test it in that mode. +#ifndef __cplusplus_winrt +TEST_CASE("ResultTests::NoOriginationByDefault", "[result]") +{ + ::wil::SetOriginateErrorCallback(nullptr); + wil::com_ptr_nothrow restrictedErrorInformation; + + // We can't guarantee test order, so clear the error payload prior to starting + SetRestrictedErrorInfo(nullptr); + + []() -> HRESULT + { + RETURN_HR(S_OK); + }(); + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); + +#ifdef WIL_ENABLE_EXCEPTIONS + try + { + THROW_HR(E_FAIL); + } + catch (...) {} + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); +#endif // WIL_ENABLE_EXCEPTIONS + + []() -> HRESULT + { + RETURN_HR(E_FAIL); + }(); + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); + + []() -> HRESULT + { + RETURN_IF_FAILED_EXPECTED(E_ACCESSDENIED); + return S_OK; + }(); + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); +} + +TEST_CASE("ResultTests::AutomaticOriginationOnFailure", "[result]") +{ + ::wil::SetOriginateErrorCallback(::wil::details::RaiseRoOriginateOnWilExceptions); + wil::com_ptr_nothrow restrictedErrorInformation; + + // Make sure we don't start with an error payload + SetRestrictedErrorInfo(nullptr); + + // Success codes shouldn't originate. + []() + { + RETURN_HR(S_OK); + }(); + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); + + auto validateOriginatedError = [&](HRESULT hrExpected) + { + wil::unique_bstr descriptionUnused; + HRESULT existingHr = S_OK; + wil::unique_bstr restrictedDescriptionUnused; + wil::unique_bstr capabilitySidUnused; + REQUIRE_SUCCEEDED(restrictedErrorInformation->GetErrorDetails(&descriptionUnused, &existingHr, &restrictedDescriptionUnused, &capabilitySidUnused)); + REQUIRE(hrExpected == existingHr); + }; + +#ifdef WIL_ENABLE_EXCEPTIONS + // Throwing an error should originate. + constexpr HRESULT thrownErrorCode = TYPE_E_ELEMENTNOTFOUND; + try + { + THROW_HR(thrownErrorCode); + } + catch (...) {} + REQUIRE(S_OK == GetRestrictedErrorInfo(&restrictedErrorInformation)); + validateOriginatedError(thrownErrorCode); +#endif // WIL_ENABLE_EXCEPTIONS + + // Returning an error code should originate. + static constexpr HRESULT returnedErrorCode = REGDB_E_CLASSNOTREG; + []() + { + RETURN_HR(returnedErrorCode); + }(); + REQUIRE(S_OK == GetRestrictedErrorInfo(&restrictedErrorInformation)); + validateOriginatedError(returnedErrorCode); + + // _EXPECTED errors should NOT originate. + static constexpr HRESULT expectedErrorCode = E_ACCESSDENIED; + []() + { + RETURN_IF_FAILED_EXPECTED(expectedErrorCode); + return S_OK; + }(); + REQUIRE(S_FALSE == GetRestrictedErrorInfo(&restrictedErrorInformation)); +} +#endif // __cplusplus_winrt diff --git a/tests/SafeCastTests.cpp b/tests/SafeCastTests.cpp new file mode 100644 index 0000000..a33ad3a --- /dev/null +++ b/tests/SafeCastTests.cpp @@ -0,0 +1,571 @@ + +#include "common.h" + +#include + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("SafeCastTests::SafeCastThrowsTemplateCheck", "[safecast]") +{ + // In all cases, a value of '1' should be cast-able to any signed or unsigned integral type without error + SECTION("Unqualified char") + { + char orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + // wil::safe_cast (orig); // No available conversion in intsafe + wil::safe_cast (orig); + wil::safe_cast (orig); + // wil::safe_cast (orig); // No available conversion in intsafe + wil::safe_cast (orig); + wil::safe_cast (orig); + // wil::safe_cast (orig); // No available conversion in intsafe + wil::safe_cast (orig); + // wil::safe_cast (orig); // No available conversion in intsafe + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + // wil::safe_cast (orig); // No available conversion in intsafe + wil::safe_cast<__int3264> (orig); + // wil::safe_cast(orig); // No available conversion in intsafe + // wil::safe_cast (orig); // No available conversion in intsafe + } + + SECTION("Signed char") + { + signed char orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned char") + { + unsigned char orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unqualified short") + { + short orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Signed short") + { + signed short orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned short") + { + unsigned short orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unqualified int") + { + int orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Signed int") + { + signed int orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned int") + { + unsigned int orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unqualified long") + { + long orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned log") + { + unsigned long orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unqualified int64") + { + __int64 orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Signed int64") + { + signed __int64 orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned int64") + { + unsigned __int64 orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unqualified int3264") + { + __int3264 orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("Unsigned int3264") + { + unsigned __int3264 orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } + + SECTION("wchar_t") + { + wchar_t orig = 1; + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int64> (orig); + wil::safe_cast (orig); + wil::safe_cast (orig); + wil::safe_cast<__int3264> (orig); + wil::safe_cast(orig); + wil::safe_cast (orig); + } +} +#endif + +TEST_CASE("SafeCastTests::SafeCastFailFastSyntaxCheck", "[safecast]") +{ + SECTION("safe_cast_failfast safe") + { + INT i = INT_MAX; + LONG l = wil::safe_cast_failfast(i); + REQUIRE(l == INT_MAX); + } + + SECTION("safe_cast_failfast unsafe") + { + INT i = 0; + SHORT s = wil::safe_cast_failfast(i); + REQUIRE(s == 0); + } + + SECTION("safe_cast_failfast unsafe to wchar_t") + { + INT i = 0; + wchar_t wc = wil::safe_cast_failfast(i); + REQUIRE(wc == 0); + } + + SECTION("safe_cast_failfast unsafe from wchar_t") + { + wchar_t wc = 0; + unsigned char uc = wil::safe_cast_failfast(wc); + REQUIRE(uc == 0); + } +} + +TEST_CASE("SafeCastTests::SafeCastNoThrowSyntaxCheck", "[safecast]") +{ + SECTION("safe_cast_nothrow safe") + { + INT i = INT_MAX; + LONG l = wil::safe_cast_nothrow(i); + REQUIRE(l == INT_MAX); + } + + // safe_cast_nothrow one parameter unsafe, throws compiler error as expected + // { + // __int64 i64 = 0; + // int i = wil::safe_cast_nothrow(i64); + // } + + SECTION("safe_cast_nothrow two parameter potentially unsafe due to usage of variable sized types") + { + SIZE_T st = 0; + UINT ui; + auto result = wil::safe_cast_nothrow(st, &ui); + REQUIRE_SUCCEEDED(result); + REQUIRE(ui == 0); + } + + // safe_cast_nothrow two parameter known safe, throws compiler error as expected + // { + // unsigned char uc = 0; + // unsigned short us; + // auto result = wil::safe_cast_nothrow(uc, &us); + // } + + SECTION("safe_cast_nothrow unsafe") + { + INT i = 0; + SHORT s; + auto result = wil::safe_cast_nothrow(i, &s); + REQUIRE_SUCCEEDED(result); + REQUIRE(s == 0); + } + + SECTION("safe_cast_nothrow unsafe to wchar_t") + { + INT i = 0; + wchar_t wc; + auto result = wil::safe_cast_nothrow(i, &wc); + REQUIRE_SUCCEEDED(result); + REQUIRE(wc == 0); + } + + SECTION("safe_cast_nothrow unsafe from wchar_t") + { + wchar_t wc = 0; + unsigned char uc; + auto result = wil::safe_cast_nothrow(wc, &uc); + REQUIRE_SUCCEEDED(result); + REQUIRE(uc == 0); + } +} + +TEST_CASE("SafeCastTests::SafeCastNoFailures", "[safecast]") +{ + SECTION("INT -> LONG") + { + INT i = INT_MAX; + LONG l = wil::safe_cast_nothrow(i); + REQUIRE(l == INT_MAX); + } + + SECTION("LONG -> INT") + { + LONG l = LONG_MAX; + INT i = wil::safe_cast_nothrow(l); + REQUIRE(i == LONG_MAX); + } + + SECTION("INT -> UINT") + { + INT i = INT_MAX; + UINT ui = wil::safe_cast_failfast(i); + REQUIRE(ui == INT_MAX); + } + + SECTION("SIZE_T -> SIZE_T") + { + SIZE_T st = SIZE_T_MAX; + SIZE_T st2 = wil::safe_cast_failfast(st); + REQUIRE(st2 == SIZE_T_MAX); + } + + SECTION("wchar_t -> uint") + { + wchar_t wc = 0; + UINT ui = wil::safe_cast_failfast(wc); + REQUIRE(ui == 0); + } + + SECTION("wchar_t -> unsigned char") + { + wchar_t wc = 0; + unsigned char uc = wil::safe_cast_failfast(wc); + REQUIRE(uc == 0); + auto result = wil::safe_cast_nothrow(wc, &uc); + REQUIRE_SUCCEEDED(result); + } + + SECTION("uint -> wchar_t") + { + UINT ui = 0; + wchar_t wc = wil::safe_cast_failfast(ui); + REQUIRE(wc == 0); + auto result = wil::safe_cast_nothrow(ui, &wc); + REQUIRE_SUCCEEDED(result); + } + +#ifndef _WIN64 + SECTION("SIZE_T -> UINT") + { + SIZE_T st = SIZE_T_MAX; + UINT ui = wil::safe_cast_nothrow(st); + REQUIRE(ui == SIZE_T_MAX); + } +#endif +} + +TEST_CASE("SafeCastTests::SafeCastNoThrowFail", "[safecast]") +{ + SECTION("size_t -> short") + { + size_t st = SIZE_T_MAX; + short s; + REQUIRE_FAILED(wil::safe_cast_nothrow(st, &s)); + } +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("SafeCastTests::SafeCastExpectFailFast", "[safecast]") +{ + // Template for safe_cast fail fast tests, fill out more instances needed + witest::TestFailureCache failures; + + failures.clear(); + { + size_t st = SIZE_T_MAX; + REQUIRE_CRASH(wil::safe_cast_failfast(st)); + REQUIRE(failures.size() == 1); + } + + failures.clear(); + { + size_t st = SIZE_T_MAX; + REQUIRE_THROWS(wil::safe_cast(st)); + REQUIRE(failures.size() == 1); + } +} +#endif diff --git a/tests/StlTests.cpp b/tests/StlTests.cpp new file mode 100644 index 0000000..e006202 --- /dev/null +++ b/tests/StlTests.cpp @@ -0,0 +1,47 @@ + +#include + +#include "common.h" + +#ifndef WIL_ENABLE_EXCEPTIONS +#error STL tests require exceptions +#endif + +struct dummy +{ + char value; +}; + +// Specialize std::allocator<> so that we don't actually allocate/deallocate memory +dummy g_memoryBuffer[256]; +namespace std +{ + template <> + struct allocator + { + using value_type = dummy; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + + dummy* allocate(std::size_t count) + { + REQUIRE(count <= std::size(g_memoryBuffer)); + return g_memoryBuffer; + } + + void deallocate(dummy* ptr, std::size_t count) + { + for (std::size_t i = 0; i < count; ++i) + { + REQUIRE(ptr[i].value == 0); + } + } + }; +} + +TEST_CASE("StlTests::TestSecureAllocator", "[stl][secure_allocator]") +{ + { + wil::secure_vector sensitiveBytes(32, dummy{ 'a' }); + } +} diff --git a/tests/TokenHelpersTests.cpp b/tests/TokenHelpersTests.cpp new file mode 100644 index 0000000..a92f7d1 --- /dev/null +++ b/tests/TokenHelpersTests.cpp @@ -0,0 +1,288 @@ + +#include + +#include "common.h" + +TEST_CASE("TokenHelpersTests::VerifyOpenCurrentAccessTokenNoThrow", "[token_helpers]") +{ + // Open current thread/process access token + wil::unique_handle token; + REQUIRE_SUCCEEDED(wil::open_current_access_token_nothrow(&token)); + REQUIRE(token != nullptr); + + REQUIRE_SUCCEEDED(wil::open_current_access_token_nothrow(&token, TOKEN_READ)); + REQUIRE(token != nullptr); + + REQUIRE_SUCCEEDED(wil::open_current_access_token_nothrow(&token, TOKEN_READ, wil::OpenThreadTokenAs::Current)); + REQUIRE(token != nullptr); + + REQUIRE_SUCCEEDED(wil::open_current_access_token_nothrow(&token, TOKEN_READ, wil::OpenThreadTokenAs::Self)); + REQUIRE(token != nullptr); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("TokenHelpersTests::VerifyOpenCurrentAccessToken", "[token_helpers]") +{ + // Open current thread/process access token + wil::unique_handle token(wil::open_current_access_token()); + REQUIRE(token != nullptr); + + token = wil::open_current_access_token(TOKEN_READ); + REQUIRE(token != nullptr); + + token = wil::open_current_access_token(TOKEN_READ, wil::OpenThreadTokenAs::Current); + REQUIRE(token != nullptr); + + token = wil::open_current_access_token(TOKEN_READ, wil::OpenThreadTokenAs::Self); + REQUIRE(token != nullptr); +} +#endif + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInformationNoThrow", "[token_helpers]") +{ + SECTION("Passing a null token") + { + wistd::unique_ptr tokenInfo; + REQUIRE_SUCCEEDED(wil::get_token_information_nothrow(tokenInfo, nullptr)); + REQUIRE(tokenInfo != nullptr); + } + + SECTION("Passing a non null token, since it a fake token there is no tokenInfo and hence should fail, code path is correct") + { + HANDLE faketoken = GetStdHandle(STD_INPUT_HANDLE); + wistd::unique_ptr tokenInfo; + REQUIRE_FAILED(wil::get_token_information_nothrow(tokenInfo, faketoken)); + } +} + +// Pseudo tokens can be passed to token APIs and avoid the handle allocations +// making use more efficient. +TEST_CASE("TokenHelpersTests::DemonstrateUseWithPseudoTokens", "[token_helpers]") +{ + wistd::unique_ptr tokenInfo; + REQUIRE_SUCCEEDED(wil::get_token_information_nothrow(tokenInfo, GetCurrentProcessToken())); + REQUIRE(tokenInfo != nullptr); + + REQUIRE_SUCCEEDED(wil::get_token_information_nothrow(tokenInfo, GetCurrentThreadEffectiveToken())); + REQUIRE(tokenInfo != nullptr); + + // No thread token by default, this should fail + REQUIRE_FAILED(wil::get_token_information_nothrow(tokenInfo, GetCurrentThreadToken())); + REQUIRE(tokenInfo == nullptr); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("TokenHelpersTests::VerifyGetTokenInformation", "[token_helpers]") +{ + // Passing a null token + wistd::unique_ptr tokenInfo(wil::get_token_information(nullptr)); + REQUIRE(tokenInfo != nullptr); +} +#endif + +// This fails with 'ERROR_NO_SUCH_LOGON_SESSION' on the CI machines, so disable +#ifndef WIL_FAST_BUILD +TEST_CASE("TokenHelpersTests::VerifyLinkedToken", "[token_helpers]") +{ + wil::unique_token_linked_token theToken; + REQUIRE_SUCCEEDED(wil::get_token_information_nothrow(theToken, nullptr)); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_NOTHROW(wil::get_linked_token_information()); +#endif +} +#endif + +bool IsImpersonating() +{ + wil::unique_handle token; + if (!::OpenThreadToken(GetCurrentThread(), TOKEN_QUERY, TRUE, &token)) + { + WI_ASSERT(::GetLastError() == ERROR_NO_TOKEN); + return false; + } + + return true; +} + +wil::unique_handle GetTokenToImpersonate() +{ + wil::unique_handle processToken; + FAIL_FAST_IF_WIN32_BOOL_FALSE(::OpenProcessToken(GetCurrentProcess(), TOKEN_ALL_ACCESS, &processToken)); + + wil::unique_handle impersonateToken; + FAIL_FAST_IF_WIN32_BOOL_FALSE(::DuplicateToken(processToken.get(), SecurityImpersonation, &impersonateToken)); + + return impersonateToken; +} + +TEST_CASE("TokenHelpersTests::VerifyResetThreadTokenNoThrow", "[token_helpers]") +{ + auto impersonateToken = GetTokenToImpersonate(); + + // Set the thread into a known state - no token. + wil::unique_token_reverter clearThreadToken; + REQUIRE_SUCCEEDED(wil::run_as_self_nothrow(clearThreadToken)); + REQUIRE_FALSE(IsImpersonating()); + + // Set a token on the thread - the process token, guaranteed to be friendly + wil::unique_token_reverter setThreadToken1; + REQUIRE_SUCCEEDED(wil::impersonate_token_nothrow(impersonateToken.get(), setThreadToken1)); + REQUIRE(IsImpersonating()); + + SECTION("Clear the token again, should be not impersonating, explicit reset") + { + wil::unique_token_reverter clearThreadAgain; + REQUIRE_SUCCEEDED(wil::run_as_self_nothrow(clearThreadAgain)); + REQUIRE_FALSE(IsImpersonating()); + clearThreadAgain.reset(); + REQUIRE(IsImpersonating()); + } + + SECTION("Clear the token again, should be not impersonating, dtor reset") + { + wil::unique_token_reverter clearThreadAgain; + REQUIRE_SUCCEEDED(wil::run_as_self_nothrow(clearThreadAgain)); + REQUIRE_FALSE(IsImpersonating()); + } + REQUIRE(IsImpersonating()); + + // Clear what we were impersonating + setThreadToken1.reset(); + REQUIRE_FALSE(IsImpersonating()); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("TokenHelpersTests::VerifyResetThreadToken", "[token_helpers]") +{ + auto impersonateToken = GetTokenToImpersonate(); + + // Set the thread into a known state - no token. + auto clearThreadToken = wil::run_as_self(); + REQUIRE_FALSE(IsImpersonating()); + + // Set a token on the thread - the process token, guaranteed to be friendly + auto setThreadToken1 = wil::impersonate_token(impersonateToken.get()); + REQUIRE(IsImpersonating()); + + SECTION("Clear the token again, should be not impersonating, explicit reset") + { + auto clearThreadAgain = wil::run_as_self(); + REQUIRE_FALSE(IsImpersonating()); + clearThreadAgain.reset(); + REQUIRE(IsImpersonating()); + } + + SECTION("Clear the token again, should be not impersonating, dtor reset") + { + auto clearThreadAgain = wil::run_as_self(); + REQUIRE_FALSE(IsImpersonating()); + } + REQUIRE(IsImpersonating()); + + // Clear what we were impersonating + setThreadToken1.reset(); + REQUIRE_FALSE(IsImpersonating()); +} +#endif // WIL_ENABLE_EXCEPTIONS + +template ::FixedSize>* = nullptr> +void TestGetTokenInfoForCurrentThread() +{ + wistd::unique_ptr tokenInfo; + const auto hr = wil::get_token_information_nothrow(tokenInfo, nullptr); + REQUIRE(S_OK == hr); +} + +template ::FixedSize>* = nullptr> +void TestGetTokenInfoForCurrentThread() +{ + T tokenInfo{}; + const auto hr = wil::get_token_information_nothrow(&tokenInfo, nullptr); + REQUIRE(S_OK == hr); +} + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInformation2", "[token_helpers]") +{ + // Variable sized cases + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + + // Fixed size and reports size using ERROR_INSUFFICIENT_BUFFER (perf opportunity, ignore second allocation) + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); + TestGetTokenInfoForCurrentThread(); +} + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInformationBadLength", "[token_helpers]") +{ + // Fixed size and reports size using ERROR_BAD_LENGTH (bug) + TestGetTokenInfoForCurrentThread(); +} + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInformationSecurityImpersonationLevelErrorCases", "[token_helpers]") +{ + SECURITY_IMPERSONATION_LEVEL tokenInfo{}; + + // SECURITY_IMPERSONATION_LEVEL does not support the effective token when it is implicit. + // Demonstrate the error return in that case. + REQUIRE(E_INVALIDARG == wil::get_token_information_nothrow(&tokenInfo, GetCurrentThreadEffectiveToken())); + + // Using an explicit token is supported but returns ERROR_NO_TOKEN when there is no + // impersonation token be sure to use RETURN_IF_FAILED_EXPECTED() and don't use + // the exception forms if this case is not expected. + REQUIRE(HRESULT_FROM_WIN32(ERROR_NO_TOKEN) == wil::get_token_information_nothrow(&tokenInfo, GetCurrentThreadToken())); + + // Setup the impersonation token that SECURITY_IMPERSONATION_LEVEL requires. + FAIL_FAST_IF_WIN32_BOOL_FALSE(ImpersonateSelf(SecurityIdentification)); + TestGetTokenInfoForCurrentThread(); + + REQUIRE(S_OK == wil::get_token_information_nothrow(&tokenInfo, GetCurrentThreadToken())); + + RevertToSelf(); +} + +#ifdef WIL_ENABLE_EXCEPTIONS + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInfo", "[token_helpers]") +{ + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + REQUIRE_NOTHROW(wil::get_token_information()); + + // check a non-pointer size value to make sure the whole struct is returned. + DWORD resultSize{}; + TOKEN_SOURCE ts{}; + auto tokenSource = wil::get_token_information(); + GetTokenInformation(GetCurrentThreadEffectiveToken(), TokenSource, &ts, sizeof(ts), &resultSize); + REQUIRE(memcmp(&ts, &tokenSource, sizeof(ts)) == 0); +} + +TEST_CASE("TokenHelpersTests::VerifyGetTokenInfoFailFast", "[token_helpers]") +{ + // fixed size + REQUIRE_NOTHROW(wil::get_token_information_failfast()); + // variable size + REQUIRE_NOTHROW(wil::get_token_information_failfast()); +} + +TEST_CASE("TokenHelpersTests::Verify_impersonate_token", "[token_helpers]") +{ + auto impersonationToken = wil::impersonate_token(); + REQUIRE_NOTHROW(wil::get_token_information()); +} +#endif // WIL_ENABLE_EXCEPTIONS diff --git a/tests/UniqueWinRTEventTokenTests.cpp b/tests/UniqueWinRTEventTokenTests.cpp new file mode 100644 index 0000000..936e747 --- /dev/null +++ b/tests/UniqueWinRTEventTokenTests.cpp @@ -0,0 +1,266 @@ + +#include +#include + +#include "common.h" + +using namespace ABI::Windows::Foundation; +using namespace Microsoft::WRL; + +namespace wiltest +{ + class AbiTestEventSender WrlFinal : public RuntimeClass< + RuntimeClassFlags, + IClosable, + IMemoryBufferReference, + FtmBase> + { + public: + + // IMemoryBufferReference + IFACEMETHODIMP get_Capacity(_Out_ UINT32* value) + { + *value = 0; + return S_OK; + } + + IFACEMETHODIMP add_Closed( + _In_ ITypedEventHandler* handler, + _Out_ ::EventRegistrationToken* token) + { + return m_closedEvent.Add(handler, token); + } + + IFACEMETHODIMP remove_Closed(::EventRegistrationToken token) + { + return m_closedEvent.Remove(token); + } + + // IClosable + IFACEMETHODIMP Close() + { + RETURN_IF_FAILED(m_closedEvent.InvokeAll(this, nullptr)); + return S_OK; + } + + private: + Microsoft::WRL::EventSource> m_closedEvent; + }; +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenEventSubscribe", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + { + wil::unique_winrt_event_token token; + REQUIRE_SUCCEEDED(WI_MakeUniqueWinRtEventTokenNoThrow(Closed, testEventSender.Get(), handler.Get(), &token)); + REQUIRE(static_cast(token)); + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); + } + + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenEarlyReset", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + wil::unique_winrt_event_token token; + REQUIRE_SUCCEEDED(WI_MakeUniqueWinRtEventTokenNoThrow(Closed, testEventSender.Get(), handler.Get(), &token)); + REQUIRE(static_cast(token)); + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); + + token.reset(); + + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenMoveTokenToDifferentScope", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + wil::unique_winrt_event_token outerToken; + REQUIRE_FALSE(static_cast(outerToken)); + { + wil::unique_winrt_event_token token; + REQUIRE_SUCCEEDED(WI_MakeUniqueWinRtEventTokenNoThrow(Closed, testEventSender.Get(), handler.Get(), &token)); + REQUIRE(static_cast(token)); + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); + + outerToken = std::move(token); + REQUIRE_FALSE(static_cast(token)); + REQUIRE(static_cast(outerToken)); + } + + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 2); +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenMoveConstructor", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + wil::unique_winrt_event_token firstToken; + REQUIRE_SUCCEEDED(WI_MakeUniqueWinRtEventTokenNoThrow(Closed, testEventSender.Get(), handler.Get(), &firstToken)); + REQUIRE(static_cast(firstToken)); + closable->Close(); + REQUIRE(timesInvoked == 1); + + wil::unique_winrt_event_token secondToken(std::move(firstToken)); + REQUIRE_FALSE(static_cast(firstToken)); + REQUIRE(static_cast(secondToken)); + + closable->Close(); + REQUIRE(timesInvoked == 2); + + firstToken.reset(); + closable->Close(); + REQUIRE(timesInvoked == 3); + + secondToken.reset(); + closable->Close(); + REQUIRE(timesInvoked == 3); +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenReleaseAndReattachToNewWrapper", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + wil::unique_winrt_event_token firstToken; + REQUIRE_SUCCEEDED(WI_MakeUniqueWinRtEventTokenNoThrow(Closed, testEventSender.Get(), handler.Get(), &firstToken)); + REQUIRE(static_cast(firstToken)); + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 1); + + ::EventRegistrationToken rawToken = firstToken.release(); + REQUIRE_FALSE(static_cast(firstToken)); + REQUIRE(rawToken.value != 0); + + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 2); + + wil::unique_winrt_event_token secondToken( + rawToken, testEventSender.Get(), &IMemoryBufferReference::remove_Closed); + + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 3); + + secondToken.reset(); + REQUIRE_SUCCEEDED(closable->Close()); + REQUIRE(timesInvoked == 3); +} + +TEST_CASE("UniqueWinRTEventTokenTests::AbiUniqueWinrtEventTokenPolicyVariants", "[winrt][unique_winrt_event_token]") +{ + ComPtr testEventSender; + REQUIRE_SUCCEEDED(MakeAndInitialize(&testEventSender)); + ComPtr closable; + testEventSender.As(&closable); + + int timesInvoked = 0; + auto handler = Callback, + ITypedEventHandler, FtmBase>> + ([×Invoked](IInspectable*, IInspectable*) + { + timesInvoked++; + return S_OK; + }); + REQUIRE(timesInvoked == 0); + + { +#ifdef WIL_ENABLE_EXCEPTIONS + auto exceptionPolicyToken = WI_MakeUniqueWinRtEventToken(Closed, testEventSender.Get(), handler.Get()); + REQUIRE(static_cast(exceptionPolicyToken)); +#endif + + auto failFastPolicyToken = WI_MakeUniqueWinRtEventTokenFailFast(Closed, testEventSender.Get(), handler.Get()); + REQUIRE(static_cast(failFastPolicyToken)); + + REQUIRE_SUCCEEDED(closable->Close()); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(timesInvoked == 2); +#else + REQUIRE(timesInvoked == 1); +#endif + } + + REQUIRE_SUCCEEDED(closable->Close()); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE(timesInvoked == 2); +#else + REQUIRE(timesInvoked == 1); +#endif +} diff --git a/tests/WatcherTests.cpp b/tests/WatcherTests.cpp new file mode 100644 index 0000000..5c02120 --- /dev/null +++ b/tests/WatcherTests.cpp @@ -0,0 +1,516 @@ + +#include // For shared_event_watcher +#include // filesystem.h includes PathCch.h which includes winnt.h, which will complain about missing target architecture +#include +#include +#include + +#include "common.h" + +TEST_CASE("EventWatcherTests::Construction", "[resource][event_watcher]") +{ + SECTION("Create unique_event_watcher_nothrow without event") + { + auto watcher = wil::make_event_watcher_nothrow([]{}); + REQUIRE(watcher != nullptr); + } + + SECTION("Create unique_event_watcher_nothrow with unique_event_nothrow") + { + wil::unique_event_nothrow eventToPass; + FAIL_FAST_IF_FAILED(eventToPass.create(wil::EventOptions::None)); + auto watcher = wil::make_event_watcher_nothrow(wistd::move(eventToPass), []{}); + REQUIRE(watcher != nullptr); + REQUIRE(eventToPass.get() == nullptr); // move construction must take it + } + + SECTION("Create unique_event_watcher_nothrow with handle") + { + wil::unique_event_nothrow eventToDupe; + FAIL_FAST_IF_FAILED(eventToDupe.create(wil::EventOptions::None)); + auto watcher = wil::make_event_watcher_nothrow(eventToDupe.get(), []{}); + REQUIRE(watcher != nullptr); + REQUIRE(eventToDupe.get() != nullptr); // handle duped in this case + } + +#ifdef WIL_ENABLE_EXCEPTIONS + SECTION("Create unique_event_watcher_nothrow with unique_event") + { + wil::unique_event eventToPass(wil::EventOptions::None); + auto watcher = wil::make_event_watcher_nothrow(wistd::move(eventToPass), []{}); + REQUIRE(watcher != nullptr); + REQUIRE(eventToPass.get() == nullptr); // move construction must take it + } + + SECTION("Create unique_event_watcher without event") + { + auto watcher = wil::make_event_watcher([]{}); + } + + SECTION("Create unique_event_watcher with unique_event_nothrow") + { + wil::unique_event_nothrow eventToPass; + THROW_IF_FAILED(eventToPass.create(wil::EventOptions::None)); + auto watcher = wil::make_event_watcher(wistd::move(eventToPass), []{}); + REQUIRE(eventToPass.get() == nullptr); // move construction must take it + } + + SECTION("Create unique_event_watcher with unique_event") + { + wil::unique_event eventToPass(wil::EventOptions::None); + auto watcher = wil::make_event_watcher(wistd::move(eventToPass), []{}); + REQUIRE(eventToPass.get() == nullptr); // move construction must take it + } + + SECTION("Create unique_event_watcher with handle") + { + wil::unique_event eventToDupe(wil::EventOptions::None); + auto watcher = wil::make_event_watcher(eventToDupe.get(), []{}); + REQUIRE(eventToDupe.get() != nullptr); // handle duped in this case + } + + SECTION("Create unique_event_watcher shared watcher") + { + wil::shared_event_watcher sharedWatcher = wil::make_event_watcher([]{}); + } +#endif +} + +static auto make_event(wil::EventOptions options = wil::EventOptions::None) +{ + wil::unique_event_nothrow result; + FAIL_FAST_IF_FAILED(result.create(options)); + return result; +} + +TEST_CASE("EventWatcherTests::VerifyDelivery", "[resource][event_watcher]") +{ + auto notificationReceived = make_event(); + + int volatile countObserved = 0; + auto watcher = wil::make_event_watcher_nothrow([&] + { + countObserved++; + notificationReceived.SetEvent(); + }); + REQUIRE(watcher != nullptr); + + watcher.SetEvent(); + REQUIRE(notificationReceived.wait(5000)); // 5 second max wait + + watcher.SetEvent(); + REQUIRE(notificationReceived.wait(5000)); // 5 second max wait + REQUIRE(countObserved == 2); +} + +TEST_CASE("EventWatcherTests::VerifyLastChangeObserved", "[resource][event_watcher]") +{ + wil::EventOptions const eventOptions[] = + { + wil::EventOptions::None, + wil::EventOptions::ManualReset, + wil::EventOptions::Signaled, + wil::EventOptions::ManualReset | wil::EventOptions::Signaled, + }; + + for (auto const &eventOption : eventOptions) + { + auto allChangesMade = make_event(wil::EventOptions::ManualReset); // ManualReset to avoid hang in case where 2 callbacks are generated (a test failure). + auto processedChange = make_event(); + + DWORD volatile stateToObserve = 0; + DWORD volatile lastObservedState = 0; + int volatile countObserved = 0; + auto watcher = wil::make_event_watcher_nothrow(make_event(eventOption), [&] + { + allChangesMade.wait(); + countObserved++; + lastObservedState = stateToObserve; + processedChange.SetEvent(); + }); + REQUIRE(watcher != nullptr); + + stateToObserve = 1; + watcher.SetEvent(); + stateToObserve = 2; + watcher.SetEvent(); + + allChangesMade.SetEvent(); + REQUIRE(processedChange.wait(5000)); + + REQUIRE((countObserved == 1 || countObserved == 2)); // ensure the race worked how we wanted it to + REQUIRE(lastObservedState == stateToObserve); + } +} + +#define ROOT_KEY_PAIR HKEY_CURRENT_USER, L"Software\\Microsoft\\RegistryWatcherTest" + +TEST_CASE("RegistryWatcherTests::Construction", "[registry][registry_watcher]") +{ + SECTION("Create unique_registry_watcher_nothrow with string") + { + auto watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind){}); + REQUIRE(watcher); + } + + SECTION("Create unique_registry_watcher_nothrow with unique_hkey") + { + wil::unique_hkey keyToMove; + REQUIRE_SUCCEEDED(HRESULT_FROM_WIN32(RegCreateKeyExW(ROOT_KEY_PAIR, 0, nullptr, 0, KEY_NOTIFY, nullptr, &keyToMove, nullptr))); + + auto watcher = wil::make_registry_watcher_nothrow(wistd::move(keyToMove), true, [&](wil::RegistryChangeKind){}); + REQUIRE(watcher); + REQUIRE(keyToMove.get() == nullptr); // ownership is transferred + } + + SECTION("Create unique_registry_watcher_nothrow with handle") + { + // construct with just an open registry key + wil::unique_hkey rootKey; + REQUIRE_SUCCEEDED(HRESULT_FROM_WIN32(RegCreateKeyExW(ROOT_KEY_PAIR, 0, nullptr, 0, KEY_NOTIFY, nullptr, &rootKey, nullptr))); + + auto watcher = wil::make_registry_watcher_nothrow(rootKey.get(), L"", true, [&](wil::RegistryChangeKind){}); + REQUIRE(watcher); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + SECTION("Create unique_registry_watcher with string") + { + REQUIRE_NOTHROW(wil::make_registry_watcher(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind){})); + } + + SECTION("Create unique_registry_watcher with unique_hkey") + { + wil::unique_hkey keyToMove; + THROW_IF_FAILED(HRESULT_FROM_WIN32(RegCreateKeyExW(ROOT_KEY_PAIR, 0, nullptr, 0, KEY_NOTIFY, nullptr, &keyToMove, nullptr))); + + REQUIRE_NOTHROW(wil::make_registry_watcher(wistd::move(keyToMove), true, [&](wil::RegistryChangeKind){})); + REQUIRE(keyToMove.get() == nullptr); // ownership is transferred + } +#endif +} + +void SetRegistryValue( + _In_ HKEY hKey, + _In_opt_ LPCWSTR lpSubKey, + _In_opt_ LPCWSTR lpValueName, + _In_ DWORD dwType, + _In_reads_bytes_opt_(cbData) LPCVOID lpData, + _In_ DWORD cbData) +{ + wil::unique_hkey key; + REQUIRE(RegOpenKeyExW(hKey, lpSubKey, 0, KEY_WRITE, &key) == ERROR_SUCCESS); + REQUIRE(RegSetValueExW(key.get(), lpValueName, 0, dwType, static_cast(lpData), cbData) == ERROR_SUCCESS); +} + +TEST_CASE("RegistryWatcherTests::VerifyDelivery", "[registry][registry_watcher]") +{ + RegDeleteTreeW(ROOT_KEY_PAIR); // So that we get the 'Modify' event + auto notificationReceived = make_event(); + + int volatile countObserved = 0; + auto volatile observedChangeType = wil::RegistryChangeKind::Delete; + auto watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind changeType) + { + countObserved++; + observedChangeType = changeType; + notificationReceived.SetEvent(); + }); + REQUIRE(watcher); + + DWORD value = 1; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + REQUIRE(notificationReceived.wait(5000)); + REQUIRE(observedChangeType == wil::RegistryChangeKind::Modify); + + value++; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + REQUIRE(notificationReceived.wait(5000)); + REQUIRE(countObserved == 2); + REQUIRE(observedChangeType == wil::RegistryChangeKind::Modify); +} + +TEST_CASE("RegistryWatcherTests::VerifyLastChangeObserved", "[registry][registry_watcher]") +{ + RegDeleteTreeW(ROOT_KEY_PAIR); + auto allChangesMade = make_event(wil::EventOptions::ManualReset); // ManualReset for the case where both registry operations result in a callback. + auto processedChange = make_event(); + + DWORD volatile stateToObserve = 0; + DWORD volatile lastObservedState = 0; + DWORD volatile lastObservedValue = 0; + int volatile countObserved = 0; + auto watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&, called = false](wil::RegistryChangeKind) mutable + { + // This callback may be called more than once (since we modify the key twice), but we're holding references to + // local variables. Therefore, bail out if this is not the first time we're called + if (called) + { + return; + } + called = true; + + allChangesMade.wait(); + countObserved++; + lastObservedState = stateToObserve; + DWORD value, cbValue = sizeof(value); + RegGetValueW(ROOT_KEY_PAIR, L"value", RRF_RT_REG_DWORD, nullptr, &value, &cbValue); + lastObservedValue = value; + processedChange.SetEvent(); + }); + REQUIRE(watcher); + + DWORD value; + // make 2 changes and verify that only the last gets observed + stateToObserve = 1; + value = 0; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + + stateToObserve = 2; + value = 1; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + + allChangesMade.SetEvent(); + REQUIRE(processedChange.wait(5000)); + + REQUIRE(countObserved >= 1); // Sometimes 2 events are observed, see if this can be eliminated. + REQUIRE(lastObservedState == stateToObserve); + REQUIRE(lastObservedValue == 1); +} + +TEST_CASE("RegistryWatcherTests::VerifyDeleteBehavior", "[registry][registry_watcher]") +{ + auto notificationReceived = make_event(); + + int volatile countObserved = 0; + auto volatile observedChangeType = wil::RegistryChangeKind::Modify; + auto watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind changeType) + { + countObserved++; + observedChangeType = changeType; + notificationReceived.SetEvent(); + }); + REQUIRE(watcher); + + RegDeleteTreeW(ROOT_KEY_PAIR); // delete the key to signal the watcher with the special error case + REQUIRE(notificationReceived.wait(5000)); + REQUIRE(countObserved == 1); + REQUIRE(observedChangeType == wil::RegistryChangeKind::Delete); +} + +TEST_CASE("RegistryWatcherTests::VerifyResetInCallback", "[registry][registry_watcher]") +{ + auto notificationReceived = make_event(); + + wil::unique_registry_watcher_nothrow watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, TRUE, [&](wil::RegistryChangeKind) + { + watcher.reset(); + DWORD value = 2; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + notificationReceived.SetEvent(); + }); + REQUIRE(watcher); + + DWORD value = 1; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + REQUIRE(notificationReceived.wait(5000)); +} + +// Stress test, disabled by default +TEST_CASE("RegistryWatcherTests::VerifyResetInCallbackStress", "[!hide][registry][registry_watcher][stress]") +{ + for (DWORD value = 0; value < 10000; ++value) + { + wil::srwlock lock; + auto notificationReceived = make_event(); + + wil::unique_registry_watcher_nothrow watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, TRUE, [&](wil::RegistryChangeKind) + { + { + auto al = lock.lock_exclusive(); + watcher.reset(); // get m_refCount to 1 to ensure the Release happens on the background thread + } + ++value; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + notificationReceived.SetEvent(); + }); + REQUIRE(watcher); + + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + notificationReceived.wait(); + + { + auto al = lock.lock_exclusive(); + watcher.reset(); + } + } +} + +TEST_CASE("RegistryWatcherTests::VerifyResetAfterDelete", "[registry][registry_watcher]") +{ + auto notificationReceived = make_event(); + + int volatile countObserved = 0; + auto volatile observedChangeType = wil::RegistryChangeKind::Modify; + wil::unique_registry_watcher_nothrow watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind changeType) + { + countObserved++; + observedChangeType = changeType; + notificationReceived.SetEvent(); + watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind changeType) + { + countObserved++; + observedChangeType = changeType; + notificationReceived.SetEvent(); + }); + REQUIRE(watcher); + }); + REQUIRE(watcher); + + RegDeleteTreeW(ROOT_KEY_PAIR); // delete the key to signal the watcher with the special error case + notificationReceived.wait(); + REQUIRE(countObserved == 1); + REQUIRE(observedChangeType == wil::RegistryChangeKind::Delete); + + // wait for the reset to finish. The constructor creates the registry key + notificationReceived.wait(300); + DWORD value = 1; + SetRegistryValue(ROOT_KEY_PAIR, L"value", REG_DWORD, &value, sizeof(value)); + + notificationReceived.wait(); + REQUIRE(countObserved == 2); + REQUIRE(observedChangeType == wil::RegistryChangeKind::Modify); +} + +TEST_CASE("RegistryWatcherTests::VerifyCallbackFinishesBeforeFreed", "[registry][registry_watcher]") +{ + auto notificationReceived = make_event(); + auto deleteNotification = make_event(); + + int volatile deleteObserved = 0; + auto watcher = wil::make_registry_watcher_nothrow(ROOT_KEY_PAIR, true, [&](wil::RegistryChangeKind) + { + notificationReceived.SetEvent(); + // ensure that the callback is still being executed while the watcher is reset(). + deleteNotification.wait(200); + deleteObserved++; + notificationReceived.SetEvent(); + }); + + RegDeleteTreeW(ROOT_KEY_PAIR); // delete the key to signal the watcher with the special error case + REQUIRE(notificationReceived.wait(5000)); + + watcher.reset(); + deleteNotification.SetEvent(); + REQUIRE(notificationReceived.wait(5000)); + REQUIRE(deleteObserved == 1); +} + +TEST_CASE("FileSystemWatcherTests::Construction", "[resource][folder_watcher]") +{ + SECTION("Create unique_folder_watcher_nothrow with valid path") + { + auto watcher = wil::make_folder_watcher_nothrow(L"C:\\Windows\\System32", true, wil::FolderChangeEvents::All, []{}); + REQUIRE(watcher); + } + + SECTION("Create unique_folder_watcher_nothrow with invalid path") + { + auto watcher = wil::make_folder_watcher_nothrow(L"X:\\invalid path", true, wil::FolderChangeEvents::All, []{}); + REQUIRE(!watcher); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + SECTION("Create unique_folder_watcher with valid path") + { + REQUIRE_NOTHROW(wil::make_folder_watcher(L"C:\\Windows\\System32", true, wil::FolderChangeEvents::All, []{})); + } + + SECTION("Create unique_folder_watcher with invalid path") + { + REQUIRE_THROWS(wil::make_folder_watcher(L"X:\\invalid path", true, wil::FolderChangeEvents::All, []{})); + } +#endif +} + +TEST_CASE("FileSystemWatcherTests::VerifyDelivery", "[resource][folder_watcher]") +{ + witest::TestFolder folder; + REQUIRE(folder); + + auto notificationEvent = make_event(); + int observedCount = 0; + auto watcher = wil::make_folder_watcher_nothrow(folder.Path(), true, wil::FolderChangeEvents::All, [&] + { + ++observedCount; + notificationEvent.SetEvent(); + }); + REQUIRE(watcher); + + witest::TestFile file(folder.Path(), L"file.txt"); + REQUIRE(file); + REQUIRE(notificationEvent.wait(5000)); + REQUIRE(observedCount == 1); + + witest::TestFile file2(folder.Path(), L"file2.txt"); + REQUIRE(file2); + REQUIRE(notificationEvent.wait(5000)); + REQUIRE(observedCount == 2); +} + +TEST_CASE("FolderChangeReaderTests::Construction", "[resource][folder_change_reader]") +{ + SECTION("Create folder_change_reader_nothrow with valid path") + { + auto reader = wil::make_folder_change_reader_nothrow(L"C:\\Windows\\System32", true, wil::FolderChangeEvents::All, [](auto, auto) {}); + REQUIRE(reader); + } + + SECTION("Create folder_change_reader_nothrow with invalid path") + { + auto reader = wil::make_folder_change_reader_nothrow(L"X:\\invalid path", true, wil::FolderChangeEvents::All, [](auto, auto) {}); + REQUIRE(!reader); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + SECTION("Create folder_change_reader with valid path") + { + REQUIRE_NOTHROW(wil::make_folder_change_reader(L"C:\\Windows\\System32", true, wil::FolderChangeEvents::All, [](auto, auto) {})); + } + + SECTION("Create folder_change_reader with invalid path") + { + REQUIRE_THROWS(wil::make_folder_change_reader(L"X:\\invalid path", true, wil::FolderChangeEvents::All, [](auto, auto) {})); + } +#endif +} + +TEST_CASE("FolderChangeReaderTests::VerifyDelivery", "[resource][folder_change_reader]") +{ + witest::TestFolder folder; + REQUIRE(folder); + + auto notificationEvent = make_event(); + wil::FolderChangeEvent observedEvent; + wchar_t observedFileName[MAX_PATH] = L""; + auto reader = wil::make_folder_change_reader_nothrow(folder.Path(), true, wil::FolderChangeEvents::All, + [&](wil::FolderChangeEvent event, PCWSTR fileName) + { + observedEvent = event; + StringCchCopyW(observedFileName, ARRAYSIZE(observedFileName), fileName); + notificationEvent.SetEvent(); + }); + REQUIRE(reader); + + witest::TestFile testFile(folder.Path(), L"file.txt"); + REQUIRE(testFile); + REQUIRE(notificationEvent.wait(5000)); + REQUIRE(observedEvent == wil::FolderChangeEvent::Added); + REQUIRE(wcscmp(observedFileName, L"file.txt") == 0); + + witest::TestFile testFile2(folder.Path(), L"file2.txt"); + REQUIRE(testFile2); + REQUIRE(notificationEvent.wait(5000)); + REQUIRE(observedEvent == wil::FolderChangeEvent::Added); + REQUIRE(wcscmp(observedFileName, L"file2.txt") == 0); +} diff --git a/tests/WinRTTests.cpp b/tests/WinRTTests.cpp new file mode 100644 index 0000000..80de3b6 --- /dev/null +++ b/tests/WinRTTests.cpp @@ -0,0 +1,842 @@ + +#include +#include + +#ifdef WIL_ENABLE_EXCEPTIONS +#include +#include +#endif + +#include "common.h" +#include "FakeWinRTTypes.h" +#include "test_objects.h" + +using namespace ABI::Windows::Foundation; +using namespace ABI::Windows::Foundation::Collections; +using namespace ABI::Windows::Storage; +using namespace ABI::Windows::System; +using namespace Microsoft::WRL; +using namespace Microsoft::WRL::Wrappers; + +TEST_CASE("WinRTTests::VerifyTraitsTypes", "[winrt]") +{ + static_assert(wistd::is_same_v::type>, ""); + static_assert(wistd::is_same_v::type>, ""); + + static_assert(wistd::is_same_v, ""); + static_assert(wistd::is_same_v*, decltype(wil::details::GetReturnParamPointerType(&ILauncherStatics::LaunchUriAsync))>, ""); + + static_assert(wistd::is_same_v(nullptr)))>, ""); + static_assert(wistd::is_same_v*>(nullptr)))>, ""); + static_assert(wistd::is_same_v*>(nullptr)))>, ""); +} + +template +void DoHStringComparisonTest(LhsT&& lhs, RhsT&& rhs, int relation) +{ + using compare = wil::details::hstring_compare; + + // == and != + REQUIRE(compare::equals(lhs, rhs) == (relation == 0)); + REQUIRE(compare::not_equals(lhs, rhs) == (relation != 0)); + + REQUIRE(compare::equals(rhs, lhs) == (relation == 0)); + REQUIRE(compare::not_equals(rhs, lhs) == (relation != 0)); + + // < and >= + REQUIRE(compare::less(lhs, rhs) == (relation < 0)); + REQUIRE(compare::greater_equals(lhs, rhs) == (relation >= 0)); + + REQUIRE(compare::less(rhs, lhs) == (relation > 0)); + REQUIRE(compare::greater_equals(rhs, lhs) == (relation <= 0)); + + // > and <= + REQUIRE(compare::greater(lhs, rhs) == (relation > 0)); + REQUIRE(compare::less_equals(lhs, rhs) == (relation <= 0)); + + REQUIRE(compare::greater(rhs, lhs) == (relation < 0)); + REQUIRE(compare::less_equals(rhs, lhs) == (relation >= 0)); + + // We wish to test with both const and non-const values. We can do this for free here so long as the type is + // not an array since changing the const-ness of an array may change the expected results +#pragma warning(suppress: 4127) + if (!wistd::is_array>::value && + !wistd::is_const>::value) + { + const wistd::remove_reference_t& constLhs = lhs; + DoHStringComparisonTest(constLhs, rhs, relation); + } + +#pragma warning(suppress: 4127) + if (!wistd::is_array>::value && + !wistd::is_const>::value) + { + const wistd::remove_reference_t& constRhs = rhs; + DoHStringComparisonTest(lhs, constRhs, relation); + } +} + +// The two string arguments are expected to compare equal to one another using the specified IgnoreCase argument and +// contain at least one embedded null character +template +void DoHStringSameValueComparisonTest(const wchar_t (&lhs)[Size], const wchar_t (&rhs)[Size]) +{ + wchar_t lhsNonConstArray[Size + 5]; + wchar_t rhsNonConstArray[Size + 5]; + wcsncpy_s(lhsNonConstArray, lhs, Size); + wcsncpy_s(rhsNonConstArray, rhs, Size); + + // For non-const arrays, we should never deduce length, so even though we append different values to each string, we + // do so after the last null character, so they should never be read + wcsncpy_s(lhsNonConstArray + Size + 1, 4, L"foo", 3); + wcsncpy_s(rhsNonConstArray + Size + 1, 4, L"bar", 3); + + const wchar_t* lhsCstr = lhs; + const wchar_t* rhsCstr = rhs; + + HStringReference lhsRef(lhs); + HStringReference rhsRef(rhs); + HString lhsStr; + HString rhsStr; + REQUIRE_SUCCEEDED(lhsStr.Set(lhs)); + REQUIRE_SUCCEEDED(rhsStr.Set(rhs)); + auto lhsHstr = lhsStr.Get(); + auto rhsHstr = rhsStr.Get(); + + wil::unique_hstring lhsUniqueStr; + wil::unique_hstring rhsUniqueStr; + REQUIRE_SUCCEEDED(lhsStr.CopyTo(&lhsUniqueStr)); + REQUIRE_SUCCEEDED(rhsStr.CopyTo(&rhsUniqueStr)); + + // Const array - embedded nulls are included only if InhibitArrayReferences is false + DoHStringComparisonTest(lhs, rhs, 0); + DoHStringComparisonTest(lhs, rhsNonConstArray, InhibitArrayReferences ? 0 : 1); + DoHStringComparisonTest(lhs, rhsCstr, InhibitArrayReferences ? 0 : 1); + DoHStringComparisonTest(lhs, rhsRef, InhibitArrayReferences ? -1 : 0); + DoHStringComparisonTest(lhs, rhsStr, InhibitArrayReferences ? -1 : 0); + DoHStringComparisonTest(lhs, rhsHstr, InhibitArrayReferences ? -1 : 0); + DoHStringComparisonTest(lhs, rhsUniqueStr, InhibitArrayReferences ? -1 : 0); + + // Non-const array - *never* deduce length + DoHStringComparisonTest(lhsNonConstArray, rhsNonConstArray, 0); + DoHStringComparisonTest(lhsNonConstArray, rhsCstr, 0); + DoHStringComparisonTest(lhsNonConstArray, rhsRef, -1); + DoHStringComparisonTest(lhsNonConstArray, rhsStr, -1); + DoHStringComparisonTest(lhsNonConstArray, rhsHstr, -1); + DoHStringComparisonTest(lhsNonConstArray, rhsUniqueStr, -1); + + // C string - impossible to deduce length + DoHStringComparisonTest(lhsCstr, rhsCstr, 0); + DoHStringComparisonTest(lhsCstr, rhsRef, -1); + DoHStringComparisonTest(lhsCstr, rhsStr, -1); + DoHStringComparisonTest(lhsCstr, rhsHstr, -1); + DoHStringComparisonTest(lhsCstr, rhsUniqueStr, -1); + + // HStringReference + DoHStringComparisonTest(lhsRef, rhsRef, 0); + DoHStringComparisonTest(lhsRef, rhsStr, 0); + DoHStringComparisonTest(lhsRef, rhsHstr, 0); + DoHStringComparisonTest(lhsRef, rhsUniqueStr, 0); + + // HString + DoHStringComparisonTest(lhsStr, rhsStr, 0); + DoHStringComparisonTest(lhsStr, rhsHstr, 0); + DoHStringComparisonTest(lhsStr, rhsUniqueStr, 0); + + // Raw HSTRING + DoHStringComparisonTest(lhsHstr, rhsHstr, 0); + DoHStringComparisonTest(lhsHstr, rhsUniqueStr, 0); + + // wil::unique_hstring + DoHStringComparisonTest(lhsUniqueStr, rhsUniqueStr, 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + std::wstring lhsWstr(lhs, 7); + std::wstring rhsWstr(rhs, 7); + DoHStringComparisonTest(lhsWstr, rhsWstr, 0); + DoHStringComparisonTest(lhsWstr, rhs, InhibitArrayReferences ? 1 : 0); + DoHStringComparisonTest(lhsWstr, rhsNonConstArray, 1); + DoHStringComparisonTest(lhsWstr, rhsCstr, 1); + DoHStringComparisonTest(lhsWstr, rhsRef, 0); + DoHStringComparisonTest(lhsWstr, rhsStr, 0); + DoHStringComparisonTest(lhsWstr, rhsHstr, 0); + DoHStringComparisonTest(lhsWstr, rhsUniqueStr, 0); +#endif +} + +// It's expected that the first argument (lhs) compares greater than the second argument (rhs) +template +void DoHStringDifferentValueComparisonTest(const wchar_t (&lhs)[LhsSize], const wchar_t (&rhs)[RhsSize]) +{ + wchar_t lhsNonConstArray[LhsSize]; + wchar_t rhsNonConstArray[RhsSize]; + wcsncpy_s(lhsNonConstArray, lhs, LhsSize); + wcsncpy_s(rhsNonConstArray, rhs, RhsSize); + + const wchar_t* lhsCstr = lhs; + const wchar_t* rhsCstr = rhs; + + HStringReference lhsRef(lhs); + HStringReference rhsRef(rhs); + HString lhsStr; + HString rhsStr; + REQUIRE_SUCCEEDED(lhsStr.Set(lhs)); + REQUIRE_SUCCEEDED(rhsStr.Set(rhs)); + auto lhsHstr = lhsStr.Get(); + auto rhsHstr = rhsStr.Get(); + + wil::unique_hstring lhsUniqueStr; + wil::unique_hstring rhsUniqueStr; + REQUIRE_SUCCEEDED(lhsStr.CopyTo(&lhsUniqueStr)); + REQUIRE_SUCCEEDED(rhsStr.CopyTo(&rhsUniqueStr)); + + // Const array + DoHStringComparisonTest(lhs, rhs, 1); + DoHStringComparisonTest(lhs, rhsNonConstArray, 1); + DoHStringComparisonTest(lhs, rhsCstr, 1); + DoHStringComparisonTest(lhs, rhsRef, 1); + DoHStringComparisonTest(lhs, rhsStr, 1); + DoHStringComparisonTest(lhs, rhsHstr, 1); + DoHStringComparisonTest(lhs, rhsUniqueStr, 1); + + // Non-const array + DoHStringComparisonTest(lhsNonConstArray, rhsNonConstArray, 1); + DoHStringComparisonTest(lhsNonConstArray, rhsCstr, 1); + DoHStringComparisonTest(lhsNonConstArray, rhsRef, 1); + DoHStringComparisonTest(lhsNonConstArray, rhsStr, 1); + DoHStringComparisonTest(lhsNonConstArray, rhsHstr, 1); + DoHStringComparisonTest(lhsNonConstArray, rhsUniqueStr, 1); + + // C string + DoHStringComparisonTest(lhsCstr, rhsCstr, 1); + DoHStringComparisonTest(lhsCstr, rhsRef, 1); + DoHStringComparisonTest(lhsCstr, rhsStr, 1); + DoHStringComparisonTest(lhsCstr, rhsHstr, 1); + DoHStringComparisonTest(lhsCstr, rhsUniqueStr, 1); + + // HStringReference + DoHStringComparisonTest(lhsRef, rhsRef, 1); + DoHStringComparisonTest(lhsRef, rhsStr, 1); + DoHStringComparisonTest(lhsRef, rhsHstr, 1); + DoHStringComparisonTest(lhsRef, rhsUniqueStr, 1); + + // HString + DoHStringComparisonTest(lhsStr, rhsStr, 1); + DoHStringComparisonTest(lhsStr, rhsHstr, 1); + DoHStringComparisonTest(lhsStr, rhsUniqueStr, 1); + + // Raw HSTRING + DoHStringComparisonTest(lhsHstr, rhsHstr, 1); + DoHStringComparisonTest(lhsHstr, rhsUniqueStr, 1); + + // wil::unique_hstring + DoHStringComparisonTest(lhsUniqueStr, rhsUniqueStr, 1); + +#ifdef WIL_ENABLE_EXCEPTIONS + std::wstring lhsWstr(lhs, 7); + std::wstring rhsWstr(rhs, 7); + DoHStringComparisonTest(lhsWstr, rhsWstr, 1); + DoHStringComparisonTest(lhsWstr, rhs, 1); + DoHStringComparisonTest(lhsWstr, rhsNonConstArray, 1); + DoHStringComparisonTest(lhsWstr, rhsCstr, 1); + DoHStringComparisonTest(lhsWstr, rhsRef, 1); + DoHStringComparisonTest(lhsWstr, rhsStr, 1); + DoHStringComparisonTest(lhsWstr, rhsHstr, 1); + DoHStringComparisonTest(lhsWstr, rhsUniqueStr, 1); +#endif +} + +TEST_CASE("WinRTTests::HStringComparison", "[winrt][hstring_compare]") +{ + SECTION("Don't inhibit arrays") + { + DoHStringSameValueComparisonTest(L"foo\0bar", L"foo\0bar"); + DoHStringDifferentValueComparisonTest(L"foo", L"bar"); + } + + SECTION("Inhibit arrays") + { + DoHStringSameValueComparisonTest(L"foo\0bar", L"foo\0bar"); + DoHStringDifferentValueComparisonTest(L"foo", L"bar"); + } + + SECTION("Ignore case") + { + DoHStringSameValueComparisonTest(L"foo\0bar", L"FoO\0bAR"); + DoHStringDifferentValueComparisonTest(L"Foo", L"baR"); + } + + SECTION("Empty string") + { + const wchar_t constArray[] = L""; + wchar_t nonConstArray[] = L""; + const wchar_t* cstr = constArray; + const wchar_t* nullCstr = nullptr; + + // str may end up referencing a null HSTRING. That's fine; we'll just test null HSTRING twice + HString str; + REQUIRE_SUCCEEDED(str.Set(constArray)); + HSTRING nullHstr = nullptr; + + // Const array - impossible to use null value + DoHStringComparisonTest(constArray, constArray, 0); + DoHStringComparisonTest(constArray, nonConstArray, 0); + DoHStringComparisonTest(constArray, cstr, 0); + DoHStringComparisonTest(constArray, nullCstr, 0); + DoHStringComparisonTest(constArray, str.Get(), 0); + DoHStringComparisonTest(constArray, nullHstr, 0); + + // Non-const array - impossible to use null value + DoHStringComparisonTest(nonConstArray, nonConstArray, 0); + DoHStringComparisonTest(nonConstArray, cstr, 0); + DoHStringComparisonTest(nonConstArray, nullCstr, 0); + DoHStringComparisonTest(nonConstArray, str.Get(), 0); + DoHStringComparisonTest(nonConstArray, nullHstr, 0); + + // Non-null c-string + DoHStringComparisonTest(cstr, cstr, 0); + DoHStringComparisonTest(cstr, nullCstr, 0); + DoHStringComparisonTest(cstr, str.Get(), 0); + DoHStringComparisonTest(cstr, nullHstr, 0); + + // Null c-string + DoHStringComparisonTest(nullCstr, nullCstr, 0); + DoHStringComparisonTest(nullCstr, str.Get(), 0); + DoHStringComparisonTest(nullCstr, nullHstr, 0); + + // (Possibly) non-null HSTRING + DoHStringComparisonTest(str.Get(), str.Get(), 0); + DoHStringComparisonTest(str.Get(), nullHstr, 0); + + // Null HSTRING + DoHStringComparisonTest(nullHstr, nullHstr, 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + std::wstring wstr; + DoHStringComparisonTest(wstr, wstr, 0); + DoHStringComparisonTest(wstr, constArray, 0); + DoHStringComparisonTest(wstr, nonConstArray, 0); + DoHStringComparisonTest(wstr, cstr, 0); + DoHStringComparisonTest(wstr, nullCstr, 0); + DoHStringComparisonTest(wstr, str.Get(), 0); + DoHStringComparisonTest(wstr, nullHstr, 0); +#endif + } +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("WinRTTests::HStringMapTest", "[winrt][hstring_compare]") +{ + int nextValue = 0; + std::map wstringMap; + wstringMap.emplace(L"foo", nextValue++); + wstringMap.emplace(L"bar", nextValue++); + wstringMap.emplace(std::wstring(L"foo\0bar", 7), nextValue++); + wstringMap.emplace(L"adding", nextValue++); + wstringMap.emplace(L"quite", nextValue++); + wstringMap.emplace(L"a", nextValue++); + wstringMap.emplace(L"few", nextValue++); + wstringMap.emplace(L"more", nextValue++); + wstringMap.emplace(L"values", nextValue++); + wstringMap.emplace(L"for", nextValue++); + wstringMap.emplace(L"testing", nextValue++); + wstringMap.emplace(L"", nextValue++); + + std::map hstringMap; + for (auto& pair : wstringMap) + { + HString str; + THROW_IF_FAILED(str.Set(pair.first.c_str(), static_cast(pair.first.length()))); + hstringMap.emplace(std::move(str), pair.second); + } + + // Order should be the same as the map of wstring + auto itr = hstringMap.begin(); + for (auto& pair : wstringMap) + { + REQUIRE(itr != hstringMap.end()); + REQUIRE(itr->first == HStringReference(pair.first.c_str(), static_cast(pair.first.length()))); + + // Should also be able to find the value + REQUIRE(hstringMap.find(pair.first) != hstringMap.end()); + + ++itr; + } + REQUIRE(itr == hstringMap.end()); + + const wchar_t constArray[] = L"foo\0bar"; + wchar_t nonConstArray[] = L"foo\0bar"; + const wchar_t* cstr = constArray; + + HString key; + wil::unique_hstring uniqueHstr; + THROW_IF_FAILED(key.Set(constArray)); + THROW_IF_FAILED(key.CopyTo(&uniqueHstr)); + + HStringReference ref(constArray); + std::wstring wstr(constArray, 7); + + auto verifyFunc = [&](int expectedValue, auto&& keyValue) + { + auto itr = hstringMap.find(std::forward(keyValue)); + REQUIRE(itr != hstringMap.end()); + REQUIRE(expectedValue == itr->second); + }; + + // The following should find "foo\0bar" + auto expectedValue = wstringMap[wstr]; + verifyFunc(expectedValue, uniqueHstr); + verifyFunc(expectedValue, key); + verifyFunc(expectedValue, key.Get()); + verifyFunc(expectedValue, ref); + verifyFunc(expectedValue, wstr); + + // Arrays/strings should not deduce length and should therefore find "foo" + expectedValue = wstringMap[L"foo"]; + verifyFunc(expectedValue, constArray); + verifyFunc(expectedValue, nonConstArray); + verifyFunc(expectedValue, cstr); + + // Should not ignore case + REQUIRE(hstringMap.find(L"FOO") == hstringMap.end()); + + // Should also be able to find empty values + const wchar_t constEmptyArray[] = L""; + wchar_t nonConstEmptyArray[] = L""; + const wchar_t* emptyCstr = constEmptyArray; + const wchar_t* nullCstr = nullptr; + + HString emptyStr; + HSTRING nullHstr = nullptr; + + std::wstring emptyWstr; + + expectedValue = wstringMap[L""]; + verifyFunc(expectedValue, constEmptyArray); + verifyFunc(expectedValue, nonConstEmptyArray); + verifyFunc(expectedValue, emptyCstr); + verifyFunc(expectedValue, nullCstr); + verifyFunc(expectedValue, emptyStr); + verifyFunc(expectedValue, nullHstr); + verifyFunc(expectedValue, emptyWstr); +} + +TEST_CASE("WinRTTests::HStringCaseInsensitiveMapTest", "[winrt][hstring_compare]") +{ + std::map hstringMap; + + auto emplaceFunc = [&](auto&& key, int value) + { + HString str; + THROW_IF_FAILED(str.Set(std::forward(key))); + hstringMap.emplace(std::move(str), value); + }; + + int nextValue = 0; + int fooValue = nextValue++; + emplaceFunc(L"foo", fooValue); + emplaceFunc(L"bar", nextValue++); + int foobarValue = nextValue++; + emplaceFunc(L"foo\0bar", foobarValue); + emplaceFunc(L"foobar", nextValue++); + emplaceFunc(L"adding", nextValue++); + emplaceFunc(L"some", nextValue++); + emplaceFunc(L"more", nextValue++); + emplaceFunc(L"values", nextValue++); + emplaceFunc(L"for", nextValue++); + emplaceFunc(L"testing", nextValue++); + WI_ASSERT(static_cast(nextValue) == hstringMap.size()); + + const wchar_t constArray[] = L"FoO\0BAr"; + wchar_t nonConstArray[] = L"fOo\0baR"; + const wchar_t* cstr = constArray; + + HString key; + wil::unique_hstring uniqueHstr; + THROW_IF_FAILED(key.Set(constArray)); + THROW_IF_FAILED(key.CopyTo(&uniqueHstr)); + + HStringReference ref(constArray); + std::wstring wstr(constArray, 7); + + auto verifyFunc = [&](int expectedValue, auto&& key) + { + auto itr = hstringMap.find(std::forward(key)); + REQUIRE(itr != std::end(hstringMap)); + REQUIRE(expectedValue == itr->second); + }; + + // The following should find "foo\0bar" + verifyFunc(foobarValue, uniqueHstr); + verifyFunc(foobarValue, key); + verifyFunc(foobarValue, key.Get()); + verifyFunc(foobarValue, ref); + verifyFunc(foobarValue, wstr); + + // Arrays/strings should not deduce length and should therefore find "foo" + verifyFunc(fooValue, constArray); + verifyFunc(fooValue, nonConstArray); + verifyFunc(fooValue, cstr); +} +#endif + +// This is not a test method, nor should it be called. This is a compilation-only test. +#ifdef WIL_ENABLE_EXCEPTIONS +void RunWhenCompleteCompilationTest() +{ + { + ComPtr> stringOp; + wil::run_when_complete(stringOp.Get(), [](HRESULT /* result */, HSTRING /* value */) {}); + auto result = wil::wait_for_completion(stringOp.Get()); + } + + { + ComPtr> stringOpWithProgress; + wil::run_when_complete(stringOpWithProgress.Get(), [](HRESULT /* result */, HSTRING /* value */) {}); + auto result = wil::wait_for_completion(stringOpWithProgress.Get()); + } +} +#endif + +TEST_CASE("WinRTTests::RunWhenCompleteMoveOnlyTest", "[winrt][run_when_complete]") +{ + auto op = Make>(); + REQUIRE(op); + + bool gotEvent = false; + auto hr = wil::run_when_complete_nothrow(op.Get(), [&gotEvent, enforce = cannot_copy{}](HRESULT hr, int result) + { + (void)enforce; + REQUIRE_SUCCEEDED(hr); + REQUIRE(result == 42); + gotEvent = true; + return S_OK; + }); + REQUIRE_SUCCEEDED(hr); + + op->Complete(S_OK, 42); + REQUIRE(gotEvent); +} + +TEST_CASE("WinRTTests::WaitForCompletionTimeout", "[winrt][wait_for_completion]") +{ + auto op = Make>(); + REQUIRE(op); + + // The wait_for_completion* functions don't properly deduce the "decayed" async type, so force it here + auto asyncOp = static_cast*>(op.Get()); + + bool timedOut = false; + REQUIRE_SUCCEEDED(wil::wait_for_completion_or_timeout_nothrow(asyncOp, 1, &timedOut)); + REQUIRE(timedOut); +} + +// This is not a test method, nor should it be called. This is a compilation-only test. +#pragma warning(push) +#pragma warning(disable: 4702) // Unreachable code +void WaitForCompletionCompilationTest() +{ + // Ensure the wait_for_completion variants compile + FAIL_FAST_HR_MSG(E_UNEXPECTED, "This is a compilation test, and should not be called"); + + // template + // inline HRESULT wait_for_completion_nothrow(_In_ TAsync* operation, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + IAsyncAction* action = nullptr; + wil::wait_for_completion_nothrow(action); + wil::wait_for_completion_nothrow(action, COWAIT_DEFAULT); + + // template + // HRESULT wait_for_completion_nothrow(_In_ ABI::Windows::Foundation::IAsyncOperation* operation, + // _Out_ typename wil::details::MapAsyncOpResultType::type* result, + // COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + IAsyncOperation* operation = nullptr; + wil::wait_for_completion_nothrow(operation); + wil::wait_for_completion_nothrow(operation, COWAIT_DEFAULT); + + // template + // HRESULT wait_for_completion_nothrow(_In_ ABI::Windows::Foundation::IAsyncOperationWithProgress* operation, + // _Out_ typename wil::details::MapAsyncOpProgressResultType::type* result, + // COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + + ComPtr> operationWithResult; + boolean result = false; + wil::wait_for_completion_nothrow(operationWithResult.Get(), &result); + wil::wait_for_completion_nothrow(operationWithResult.Get(), &result, COWAIT_DEFAULT); + + DWORD timeoutValue = 1000; // arbitrary + bool timedOut = false; + + // template + // inline HRESULT wait_for_completion_or_timeout_nothrow(_In_ TAsync* operation, + // DWORD timeoutValue, _Out_ bool* timedOut, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS); + wil::wait_for_completion_or_timeout_nothrow(action, timeoutValue, &timedOut); + wil::wait_for_completion_or_timeout_nothrow(action, timeoutValue, &timedOut, COWAIT_DEFAULT); + + // template + // HRESULT wait_for_completion_or_timeout_nothrow(_In_ ABI::Windows::Foundation::IAsyncOperation* operation, + // _Out_ typename wil::details::MapAsyncOpResultType::type* result, + // DWORD timeoutValue, _Out_ bool* timedOut, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS); + wil::wait_for_completion_or_timeout_nothrow(operation, timeoutValue, &timedOut); + wil::wait_for_completion_or_timeout_nothrow(operation, timeoutValue, &timedOut, COWAIT_DEFAULT); + + // template + // HRESULT wait_for_completion_or_timeout_nothrow(_In_ ABI::Windows::Foundation::IAsyncOperationWithProgress* operation, + // _Out_ typename wil::details::MapAsyncOpProgressResultType::type* result, + // DWORD timeoutValue, _Out_ bool* timedOut, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS); + wil::wait_for_completion_or_timeout_nothrow(operationWithResult.Get(), &result, timeoutValue, &timedOut); + wil::wait_for_completion_or_timeout_nothrow(operationWithResult.Get(), &result, timeoutValue, &timedOut, COWAIT_DEFAULT); + +#ifdef WIL_ENABLE_EXCEPTIONS + // template + // inline void wait_for_completion(_In_ TAsync* operation, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + wil::wait_for_completion(action); + wil::wait_for_completion(action, COWAIT_DEFAULT); + + // template ::type>::type> + // TReturn + // wait_for_completion(_In_ ABI::Windows::Foundation::IAsyncOperation* operation, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + wil::wait_for_completion(operation); + wil::wait_for_completion(operation, COWAIT_DEFAULT); + + // template ::type>::type> + // TReturn + // wait_for_completion(_In_ ABI::Windows::Foundation::IAsyncOperationWithProgress* operation, COWAIT_FLAGS flags = COWAIT_DISPATCH_CALLS, DWORD timeout = INFINITE); + result = wil::wait_for_completion(operationWithResult.Get()); + result = wil::wait_for_completion(operationWithResult.Get(), COWAIT_DEFAULT); +#endif +} +#pragma warning(pop) + +TEST_CASE("WinRTTests::TimeTTests", "[winrt][time_t]") +{ + // Verifying that converting DateTime variable set as the date that means 0 as time_t works + DateTime time1 = { wil::SecondsToStartOf1970 * wil::HundredNanoSecondsInSecond }; + __time64_t time_t1 = wil::DateTime_to_time_t(time1); + REQUIRE(time_t1 == 0); + + // Verifying that converting back to DateTime would return the same value + DateTime time2 = wil::time_t_to_DateTime(time_t1); + REQUIRE(time1.UniversalTime == time2.UniversalTime); + + // Verifying that converting to time_t for non-zero value also works + time2.UniversalTime += wil::HundredNanoSecondsInSecond * 123; + __time64_t time_t2 = wil::DateTime_to_time_t(time2); + REQUIRE(time_t2 - time_t1 == 123); + + // Verifying that converting back to DateTime for non-zero value also works + time1 = wil::time_t_to_DateTime(time_t2); + REQUIRE(time1.UniversalTime == time2.UniversalTime); +} + +ComPtr> MakeSampleInspectableVector() +{ + auto result = Make>(); + REQUIRE(result); + + ComPtr propStatics; + REQUIRE_SUCCEEDED(GetActivationFactory(HStringReference(RuntimeClass_Windows_Foundation_PropertyValue).Get(), &propStatics)); + + for (UINT32 i = 0; i < 5; ++i) + { + ComPtr myProp; + REQUIRE_SUCCEEDED(propStatics->CreateUInt32(i, &myProp)); + REQUIRE_SUCCEEDED(result->Append(myProp.Get())); + } + + return result; +} + +ComPtr> MakeSampleStringVector() +{ + auto result = Make>(); + REQUIRE(result); + + const HStringReference items[] = { HStringReference(L"one"), HStringReference(L"two"), HStringReference(L"three") }; + for (const auto& i : items) + { + REQUIRE_SUCCEEDED(result->Append(i.Get())); + } + + return result; +} + +ComPtr> MakeSamplePointVector() +{ + auto result = Make>(); + REQUIRE(result); + + for (int i = 0; i < 5; ++i) + { + auto value = static_cast(i); + REQUIRE_SUCCEEDED(result->Append(Point{ value, value })); + } + + return result; +} + +TEST_CASE("WinRTTests::VectorRangeTest", "[winrt][vector_range]") +{ + auto uninit = wil::RoInitialize_failfast(); + + auto inspectables = MakeSampleInspectableVector(); + unsigned count = 0; + REQUIRE_SUCCEEDED(inspectables->get_Size(&count)); + + unsigned idx = 0; + HRESULT success = S_OK; + for (const auto& i : wil::get_range_nothrow(inspectables.Get(), &success)) + { + // Duplications are not a typo - they verify the thing is callable twice + + UINT32 value; + ComPtr> intRef; + REQUIRE_SUCCEEDED(i.CopyTo(IID_PPV_ARGS(&intRef))); + REQUIRE_SUCCEEDED(intRef->get_Value(&value)); + REQUIRE(idx == value); + REQUIRE_SUCCEEDED(i.CopyTo(IID_PPV_ARGS(&intRef))); + REQUIRE_SUCCEEDED(intRef->get_Value(&value)); + REQUIRE(idx == value); + + ++idx; + + HString rtc; + REQUIRE_SUCCEEDED(i->GetRuntimeClassName(rtc.GetAddressOf())); + REQUIRE_SUCCEEDED(i->GetRuntimeClassName(rtc.GetAddressOf())); + } + REQUIRE_SUCCEEDED(success); + REQUIRE(count == idx); + + auto strings = MakeSampleStringVector(); + for (const auto& i : wil::get_range_nothrow(strings.Get(), &success)) + { + REQUIRE(i.Get()); + REQUIRE(i.Get()); + } + REQUIRE_SUCCEEDED(success); + + int index = 0; + auto points = MakeSamplePointVector(); + for (auto value : wil::get_range_nothrow(points.Get(), &success)) + { + REQUIRE(index++ == value.Get().X); + } + REQUIRE_SUCCEEDED(success); + + // operator-> should not clear out the pointer + auto inspRange = wil::get_range_nothrow(inspectables.Get()); + for (auto itr = inspRange.begin(); itr != inspRange.end(); ++itr) + { + REQUIRE(itr->Get()); + } + + auto strRange = wil::get_range_nothrow(strings.Get()); + for (auto itr = strRange.begin(); itr != strRange.end(); ++itr) + { + REQUIRE(itr->Get()); + } + + index = 0; + auto pointRange = wil::get_range_nothrow(points.Get()); + for (auto itr = pointRange.begin(); itr != pointRange.end(); ++itr) + { + REQUIRE(index++ == itr->Get().X); + } + +#if (defined WIL_ENABLE_EXCEPTIONS) + idx = 0; + for (const auto& i : wil::get_range(inspectables.Get())) + { + // Duplications are not a typo - they verify the thing is callable twice + + UINT32 value; + ComPtr> intRef; + REQUIRE_SUCCEEDED(i.CopyTo(IID_PPV_ARGS(&intRef))); + REQUIRE_SUCCEEDED(intRef->get_Value(&value)); + REQUIRE(idx == value); + REQUIRE_SUCCEEDED(i.CopyTo(IID_PPV_ARGS(&intRef))); + REQUIRE_SUCCEEDED(intRef->get_Value(&value)); + REQUIRE(idx == value); + + ++idx; + + HString rtc; + REQUIRE_SUCCEEDED(i->GetRuntimeClassName(rtc.GetAddressOf())); + REQUIRE_SUCCEEDED(i->GetRuntimeClassName(rtc.GetAddressOf())); + } + REQUIRE(count == idx); + + for (const auto& i : wil::get_range(strings.Get())) + { + REQUIRE(i.Get()); + REQUIRE(i.Get()); + } + + index = 0; + for (auto value : wil::get_range(points.Get())) + { + REQUIRE(index++ == value.Get().X); + } + + // operator-> should not clear out the pointer + for (auto itr = inspRange.begin(); itr != inspRange.end(); ++itr) + { + REQUIRE(itr->Get()); + } + + for (auto itr = strRange.begin(); itr != strRange.end(); ++itr) + { + REQUIRE(itr->Get()); + } + + index = 0; + for (auto itr = pointRange.begin(); itr != pointRange.end(); ++itr) + { + REQUIRE(index++ == itr->Get().X); + } +#endif +} + +unsigned long GetComObjectRefCount(IUnknown* unk) { unk->AddRef(); return unk->Release(); } + +TEST_CASE("WinRTTests::VectorRangeLeakTest", "[winrt][vector_range]") +{ + auto uninit = wil::RoInitialize_failfast(); + + auto inspectables = MakeSampleInspectableVector(); + ComPtr verifyNotLeaked; + HRESULT hr = S_OK; + for (const auto& ptr : wil::get_range_nothrow(inspectables.Get(), &hr)) + { + if (!verifyNotLeaked) + { + verifyNotLeaked = ptr; + } + } + inspectables = nullptr; // clear all refs to verifyNotLeaked + REQUIRE_SUCCEEDED(hr); + REQUIRE(GetComObjectRefCount(verifyNotLeaked.Get()) == 1); + + inspectables = MakeSampleInspectableVector(); + for (const auto& ptr : wil::get_range_failfast(inspectables.Get())) + { + if (!verifyNotLeaked) + { + verifyNotLeaked = ptr; + } + } + inspectables = nullptr; // clear all refs to verifyNotLeaked + REQUIRE(GetComObjectRefCount(verifyNotLeaked.Get()) == 1); + +#if (defined WIL_ENABLE_EXCEPTIONS) + inspectables = MakeSampleInspectableVector(); + for (const auto& ptr : wil::get_range(inspectables.Get())) + { + if (!verifyNotLeaked) + { + verifyNotLeaked = ptr; + } + } + inspectables = nullptr; // clear all refs to verifyNotLeaked + REQUIRE(GetComObjectRefCount(verifyNotLeaked.Get()) == 1); +#endif +} diff --git a/tests/WistdTests.cpp b/tests/WistdTests.cpp new file mode 100644 index 0000000..b4a6974 --- /dev/null +++ b/tests/WistdTests.cpp @@ -0,0 +1,209 @@ + +#include + +#include "common.h" +#include "test_objects.h" + +// Test methods/objects +int GetValue() +{ + return 42; +} + +int GetOtherValue() +{ + return 8; +} + +int Negate(int value) +{ + return -value; +} + +int Add(int lhs, int rhs) +{ + return lhs + rhs; +} + +TEST_CASE("WistdFunctionTests::CallOperatorTest", "[wistd]") +{ + wistd::function getValue = GetValue; + REQUIRE(GetValue() == getValue()); + + wistd::function negate = Negate; + REQUIRE(Negate(42) == negate(42)); + + wistd::function add = Add; + REQUIRE(Add(42, 8) == add(42, 8)); +} + +TEST_CASE("WistdFunctionTests::AssignmentOperatorTest", "[wistd]") +{ + wistd::function fn = GetValue; + REQUIRE(GetValue() == fn()); + + fn = GetOtherValue; + REQUIRE(GetOtherValue() == fn()); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("WistdFunctionTests::StdFunctionConstructionTest", "[wistd]") +{ + // We should be able to capture a std::function in a wistd::function + wistd::function fn; + + { + value_holder holder{ 42 }; + std::function stdFn = [holder]() + { + return holder.value; + }; + + fn = stdFn; + } + + REQUIRE(42 == fn()); +} +#endif + +TEST_CASE("WistdFunctionTests::CopyConstructionTest", "[wistd]") +{ + object_counter_state state; + { + wistd::function copyFrom = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + REQUIRE(0 == copyFrom()); + + auto copyTo = copyFrom; + REQUIRE(1 == copyTo()); + } + + REQUIRE(0 == state.instance_count()); +} + +TEST_CASE("WistdFunctionTests::CopyAssignmentTest", "[wistd]") +{ + object_counter_state state; + { + wistd::function copyTo; + { + wistd::function copyFrom = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + REQUIRE(0 == copyFrom()); + + copyTo = copyFrom; + } + + REQUIRE(1 == copyTo()); + } + + REQUIRE(0 == state.instance_count()); +} + +TEST_CASE("WistdFunctionTests::MoveConstructionTest", "[wistd]") +{ + object_counter_state state; + { + wistd::function moveFrom = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + REQUIRE(0 == moveFrom()); + + auto moveTo = std::move(moveFrom); + REQUIRE(0 == moveTo()); + + // Because we move the underlying function object, we _must_ invalidate the moved from function + REQUIRE_FALSE(moveFrom != nullptr); + } + + REQUIRE(0 == state.instance_count()); +} + +TEST_CASE("WistdFunctionTests::MoveAssignmentTest", "[wistd]") +{ + object_counter_state state; + { + wistd::function moveTo; + { + wistd::function moveFrom = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + REQUIRE(0 == moveFrom()); + + moveTo = std::move(moveFrom); + } + + REQUIRE(0 == moveTo()); + } + + REQUIRE(0 == state.instance_count()); +} + +TEST_CASE("WistdFunctionTests::SwapTest", "[wistd]") +{ + object_counter_state state; + { + wistd::function first; + wistd::function second; + + first.swap(second); + REQUIRE_FALSE(first != nullptr); + REQUIRE_FALSE(second != nullptr); + + first = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + + first.swap(second); + REQUIRE_FALSE(first != nullptr); + REQUIRE(second != nullptr); + REQUIRE(0 == second()); + + first.swap(second); + REQUIRE(first != nullptr); + REQUIRE_FALSE(second != nullptr); + REQUIRE(0 == first()); + + second = [counter = object_counter{ state }]() + { + return counter.state->copy_count; + }; + + first.swap(second); + REQUIRE(first != nullptr); + REQUIRE(second != nullptr); + REQUIRE(0 == first()); + } + + REQUIRE(0 == state.instance_count()); +} + +// MSVC's optimizer has had issues with wistd::function in the past when forwarding wistd::function objects to a +// function that accepts the arguments by value. This test exercises the workaround that we have in place. Note +// that this of course requires building with optimizations enabled +void ForwardingTest(wistd::function getValue, wistd::function negate, wistd::function add) +{ + // Previously, this would cause a runtime crash + REQUIRE(Add(GetValue(), Negate(8)) == add(getValue(), negate(8))); +} + +template +void CallForwardingTest(Args&&... args) +{ + ForwardingTest(wistd::forward(args)...); +} + +TEST_CASE("WistdFunctionTests::OptimizationRegressionTest", "[wistd]") +{ + CallForwardingTest( + wistd::function(GetValue), + wistd::function(Negate), + wistd::function(Add)); +} diff --git a/tests/app/CMakeLists.txt b/tests/app/CMakeLists.txt new file mode 100644 index 0000000..7f75306 --- /dev/null +++ b/tests/app/CMakeLists.txt @@ -0,0 +1,20 @@ + +project(witest.app) +add_executable(witest.app) + +add_definitions(-DWINAPI_FAMILY=WINAPI_FAMILY_PC_APP) + +target_sources(witest.app PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../main.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../CommonTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ComTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../FileSystemTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResourceTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResultTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../SafeCastTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../StlTests.cpp +# ${CMAKE_CURRENT_SOURCE_DIR}/../UniqueWinRTEventTokenTests.cpp +# ${CMAKE_CURRENT_SOURCE_DIR}/../WinRTTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WistdTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../wiTest.cpp + ) diff --git a/tests/catch.hpp b/tests/catch.hpp new file mode 100644 index 0000000..1850fff --- /dev/null +++ b/tests/catch.hpp @@ -0,0 +1,14934 @@ +/* + * Catch v2.7.0 + * Generated: 2019-03-07 21:34:30.252164 + * ---------------------------------------------------------- + * This file has been merged from multiple headers. Please don't edit it directly + * Copyright (c) 2019 Two Blue Cubes Ltd. All rights reserved. + * + * Distributed under the Boost Software License, Version 1.0. (See accompanying + * file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + */ +#ifndef TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +#define TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED +// start catch.hpp + + +#define CATCH_VERSION_MAJOR 2 +#define CATCH_VERSION_MINOR 7 +#define CATCH_VERSION_PATCH 0 + +#ifdef __clang__ +# pragma clang system_header +#elif defined __GNUC__ +# pragma GCC system_header +#endif + +// start catch_suppress_warnings.h + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(push) +# pragma warning(disable: 161 1682) +# else // __ICC +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpadded" +# pragma clang diagnostic ignored "-Wswitch-enum" +# pragma clang diagnostic ignored "-Wcovered-switch-default" +# endif +#elif defined __GNUC__ + // Because REQUIREs trigger GCC's -Wparentheses, and because still + // supported version of g++ have only buggy support for _Pragmas, + // Wparentheses have to be suppressed globally. +# pragma GCC diagnostic ignored "-Wparentheses" // See #674 for details + +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wunused-variable" +# pragma GCC diagnostic ignored "-Wpadded" +#endif +// end catch_suppress_warnings.h +#if defined(CATCH_CONFIG_MAIN) || defined(CATCH_CONFIG_RUNNER) +# define CATCH_IMPL +# define CATCH_CONFIG_ALL_PARTS +#endif + +// In the impl file, we want to have access to all parts of the headers +// Can also be used to sanely support PCHs +#if defined(CATCH_CONFIG_ALL_PARTS) +# define CATCH_CONFIG_EXTERNAL_INTERFACES +# if defined(CATCH_CONFIG_DISABLE_MATCHERS) +# undef CATCH_CONFIG_DISABLE_MATCHERS +# endif +# if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +# endif +#endif + +#if !defined(CATCH_CONFIG_IMPL_ONLY) +// start catch_platform.h + +#ifdef __APPLE__ +# include +# if TARGET_OS_OSX == 1 +# define CATCH_PLATFORM_MAC +# elif TARGET_OS_IPHONE == 1 +# define CATCH_PLATFORM_IPHONE +# endif + +#elif defined(linux) || defined(__linux) || defined(__linux__) +# define CATCH_PLATFORM_LINUX + +#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) || defined(__MINGW32__) +# define CATCH_PLATFORM_WINDOWS +#endif + +// end catch_platform.h + +#ifdef CATCH_IMPL +# ifndef CLARA_CONFIG_MAIN +# define CLARA_CONFIG_MAIN_NOT_DEFINED +# define CLARA_CONFIG_MAIN +# endif +#endif + +// start catch_user_interfaces.h + +namespace Catch { + unsigned int rngSeed(); +} + +// end catch_user_interfaces.h +// start catch_tag_alias_autoregistrar.h + +// start catch_common.h + +// start catch_compiler_capabilities.h + +// Detect a number of compiler features - by compiler +// The following features are defined: +// +// CATCH_CONFIG_COUNTER : is the __COUNTER__ macro supported? +// CATCH_CONFIG_WINDOWS_SEH : is Windows SEH supported? +// CATCH_CONFIG_POSIX_SIGNALS : are POSIX signals supported? +// CATCH_CONFIG_DISABLE_EXCEPTIONS : Are exceptions enabled? +// **************** +// Note to maintainers: if new toggles are added please document them +// in configuration.md, too +// **************** + +// In general each macro has a _NO_ form +// (e.g. CATCH_CONFIG_NO_POSIX_SIGNALS) which disables the feature. +// Many features, at point of detection, define an _INTERNAL_ macro, so they +// can be combined, en-mass, with the _NO_ forms later. + +#ifdef __cplusplus + +# if (__cplusplus >= 201402L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +# define CATCH_CPP14_OR_GREATER +# endif + +# if (__cplusplus >= 201703L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +# define CATCH_CPP17_OR_GREATER +# endif + +#endif + +#if defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +#endif + +#ifdef __clang__ + +# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wexit-time-destructors\"" ) \ + _Pragma( "clang diagnostic ignored \"-Wglobal-constructors\"") +# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wparentheses\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + _Pragma( "clang diagnostic push" ) \ + _Pragma( "clang diagnostic ignored \"-Wunused-variable\"" ) +# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS \ + _Pragma( "clang diagnostic pop" ) + +#endif // __clang__ + +//////////////////////////////////////////////////////////////////////////////// +// Assume that non-Windows platforms support posix signals by default +#if !defined(CATCH_PLATFORM_WINDOWS) + #define CATCH_INTERNAL_CONFIG_POSIX_SIGNALS +#endif + +//////////////////////////////////////////////////////////////////////////////// +// We know some environments not to support full POSIX signals +#if defined(__CYGWIN__) || defined(__QNX__) || defined(__EMSCRIPTEN__) || defined(__DJGPP__) + #define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS +#endif + +#ifdef __OS400__ +# define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS +# define CATCH_CONFIG_COLOUR_NONE +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Android somehow still does not support std::to_string +#if defined(__ANDROID__) +# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Not all Windows environments support SEH properly +#if defined(__MINGW32__) +# define CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH +#endif + +//////////////////////////////////////////////////////////////////////////////// +// PS4 +#if defined(__ORBIS__) +# define CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Cygwin +#ifdef __CYGWIN__ + +// Required for some versions of Cygwin to declare gettimeofday +// see: http://stackoverflow.com/questions/36901803/gettimeofday-not-declared-in-this-scope-cygwin +# define _BSD_SOURCE +// some versions of cygwin (most) do not support std::to_string. Use the libstd check. +// https://gcc.gnu.org/onlinedocs/gcc-4.8.2/libstdc++/api/a01053_source.html line 2812-2813 +# if !((__cplusplus >= 201103L) && defined(_GLIBCXX_USE_C99) \ + && !defined(_GLIBCXX_HAVE_BROKEN_VSWPRINTF)) + +# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING + +# endif +#endif // __CYGWIN__ + +//////////////////////////////////////////////////////////////////////////////// +// Visual C++ +#ifdef _MSC_VER + +# if _MSC_VER >= 1900 // Visual Studio 2015 or newer +# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +# endif + +// Universal Windows platform does not support SEH +// Or console colours (or console at all...) +# if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +# define CATCH_CONFIG_COLOUR_NONE +# else +# define CATCH_INTERNAL_CONFIG_WINDOWS_SEH +# endif + +// MSVC traditional preprocessor needs some workaround for __VA_ARGS__ +// _MSVC_TRADITIONAL == 0 means new conformant preprocessor +// _MSVC_TRADITIONAL == 1 means old traditional non-conformant preprocessor +# if !defined(_MSVC_TRADITIONAL) || (defined(_MSVC_TRADITIONAL) && _MSVC_TRADITIONAL) +# define CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +# endif + +#endif // _MSC_VER + +//////////////////////////////////////////////////////////////////////////////// +// Check if we are compiled with -fno-exceptions or equivalent +#if defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND) +# define CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////// +// DJGPP +#ifdef __DJGPP__ +# define CATCH_INTERNAL_CONFIG_NO_WCHAR +#endif // __DJGPP__ + +//////////////////////////////////////////////////////////////////////////////// +// Embarcadero C++Build +#if defined(__BORLANDC__) + #define CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN +#endif + +//////////////////////////////////////////////////////////////////////////////// + +// Use of __COUNTER__ is suppressed during code analysis in +// CLion/AppCode 2017.2.x and former, because __COUNTER__ is not properly +// handled by it. +// Otherwise all supported compilers support COUNTER macro, +// but user still might want to turn it off +#if ( !defined(__JETBRAINS_IDE__) || __JETBRAINS_IDE__ >= 20170300L ) + #define CATCH_INTERNAL_CONFIG_COUNTER +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Check if string_view is available and usable +// The check is split apart to work around v140 (VS2015) preprocessor issue... +#if defined(__has_include) +#if __has_include() && defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// +// Check if optional is available and usable +#if defined(__has_include) +# if __has_include() && defined(CATCH_CPP17_OR_GREATER) +# define CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL +# endif // __has_include() && defined(CATCH_CPP17_OR_GREATER) +#endif // __has_include + +//////////////////////////////////////////////////////////////////////////////// +// Check if variant is available and usable +#if defined(__has_include) +# if __has_include() && defined(CATCH_CPP17_OR_GREATER) +# if defined(__clang__) && (__clang_major__ < 8) + // work around clang bug with libstdc++ https://bugs.llvm.org/show_bug.cgi?id=31852 + // fix should be in clang 8, workaround in libstdc++ 8.2 +# include +# if defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9) +# define CATCH_CONFIG_NO_CPP17_VARIANT +# else +# define CATCH_INTERNAL_CONFIG_CPP17_VARIANT +# endif // defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9) +# else +# define CATCH_INTERNAL_CONFIG_CPP17_VARIANT +# endif // defined(__clang__) && (__clang_major__ < 8) +# endif // __has_include() && defined(CATCH_CPP17_OR_GREATER) +#endif // __has_include + +#if defined(CATCH_INTERNAL_CONFIG_COUNTER) && !defined(CATCH_CONFIG_NO_COUNTER) && !defined(CATCH_CONFIG_COUNTER) +# define CATCH_CONFIG_COUNTER +#endif +#if defined(CATCH_INTERNAL_CONFIG_WINDOWS_SEH) && !defined(CATCH_CONFIG_NO_WINDOWS_SEH) && !defined(CATCH_CONFIG_WINDOWS_SEH) && !defined(CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH) +# define CATCH_CONFIG_WINDOWS_SEH +#endif +// This is set by default, because we assume that unix compilers are posix-signal-compatible by default. +#if defined(CATCH_INTERNAL_CONFIG_POSIX_SIGNALS) && !defined(CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_POSIX_SIGNALS) +# define CATCH_CONFIG_POSIX_SIGNALS +#endif +// This is set by default, because we assume that compilers with no wchar_t support are just rare exceptions. +#if !defined(CATCH_INTERNAL_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_WCHAR) +# define CATCH_CONFIG_WCHAR +#endif + +#if !defined(CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_CPP11_TO_STRING) +# define CATCH_CONFIG_CPP11_TO_STRING +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_NO_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_CPP17_OPTIONAL) +# define CATCH_CONFIG_CPP17_OPTIONAL +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_NO_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) +# define CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_NO_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_CPP17_STRING_VIEW) +# define CATCH_CONFIG_CPP17_STRING_VIEW +#endif + +#if defined(CATCH_INTERNAL_CONFIG_CPP17_VARIANT) && !defined(CATCH_CONFIG_NO_CPP17_VARIANT) && !defined(CATCH_CONFIG_CPP17_VARIANT) +# define CATCH_CONFIG_CPP17_VARIANT +#endif + +#if defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT) +# define CATCH_INTERNAL_CONFIG_NEW_CAPTURE +#endif + +#if defined(CATCH_INTERNAL_CONFIG_NEW_CAPTURE) && !defined(CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NEW_CAPTURE) +# define CATCH_CONFIG_NEW_CAPTURE +#endif + +#if !defined(CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) +# define CATCH_CONFIG_DISABLE_EXCEPTIONS +#endif + +#if defined(CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_NO_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_POLYFILL_ISNAN) +# define CATCH_CONFIG_POLYFILL_ISNAN +#endif + +#if !defined(CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS +#endif +#if !defined(CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS +#endif +#if !defined(CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS) +# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS +# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS +#endif + +#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) +#define CATCH_TRY if ((true)) +#define CATCH_CATCH_ALL if ((false)) +#define CATCH_CATCH_ANON(type) if ((false)) +#else +#define CATCH_TRY try +#define CATCH_CATCH_ALL catch (...) +#define CATCH_CATCH_ANON(type) catch (type) +#endif + +#if defined(CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_NO_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR) +#define CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#endif + +// end catch_compiler_capabilities.h +#define INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) name##line +#define INTERNAL_CATCH_UNIQUE_NAME_LINE( name, line ) INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) +#ifdef CATCH_CONFIG_COUNTER +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __COUNTER__ ) +#else +# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __LINE__ ) +#endif + +#include +#include +#include + +// We need a dummy global operator<< so we can bring it into Catch namespace later +struct Catch_global_namespace_dummy {}; +std::ostream& operator<<(std::ostream&, Catch_global_namespace_dummy); + +namespace Catch { + + struct CaseSensitive { enum Choice { + Yes, + No + }; }; + + class NonCopyable { + NonCopyable( NonCopyable const& ) = delete; + NonCopyable( NonCopyable && ) = delete; + NonCopyable& operator = ( NonCopyable const& ) = delete; + NonCopyable& operator = ( NonCopyable && ) = delete; + + protected: + NonCopyable(); + virtual ~NonCopyable(); + }; + + struct SourceLineInfo { + + SourceLineInfo() = delete; + SourceLineInfo( char const* _file, std::size_t _line ) noexcept + : file( _file ), + line( _line ) + {} + + SourceLineInfo( SourceLineInfo const& other ) = default; + SourceLineInfo& operator = ( SourceLineInfo const& ) = default; + SourceLineInfo( SourceLineInfo&& ) noexcept = default; + SourceLineInfo& operator = ( SourceLineInfo&& ) noexcept = default; + + bool empty() const noexcept; + bool operator == ( SourceLineInfo const& other ) const noexcept; + bool operator < ( SourceLineInfo const& other ) const noexcept; + + char const* file; + std::size_t line; + }; + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ); + + // Bring in operator<< from global namespace into Catch namespace + // This is necessary because the overload of operator<< above makes + // lookup stop at namespace Catch + using ::operator<<; + + // Use this in variadic streaming macros to allow + // >> +StreamEndStop + // as well as + // >> stuff +StreamEndStop + struct StreamEndStop { + std::string operator+() const; + }; + template + T const& operator + ( T const& value, StreamEndStop ) { + return value; + } +} + +#define CATCH_INTERNAL_LINEINFO \ + ::Catch::SourceLineInfo( __FILE__, static_cast( __LINE__ ) ) + +// end catch_common.h +namespace Catch { + + struct RegistrarForTagAliases { + RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo ); + }; + +} // end namespace Catch + +#define CATCH_REGISTER_TAG_ALIAS( alias, spec ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::RegistrarForTagAliases INTERNAL_CATCH_UNIQUE_NAME( AutoRegisterTagAlias )( alias, spec, CATCH_INTERNAL_LINEINFO ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + +// end catch_tag_alias_autoregistrar.h +// start catch_test_registry.h + +// start catch_interfaces_testcase.h + +#include + +namespace Catch { + + class TestSpec; + + struct ITestInvoker { + virtual void invoke () const = 0; + virtual ~ITestInvoker(); + }; + + class TestCase; + struct IConfig; + + struct ITestCaseRegistry { + virtual ~ITestCaseRegistry(); + virtual std::vector const& getAllTests() const = 0; + virtual std::vector const& getAllTestsSorted( IConfig const& config ) const = 0; + }; + + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ); + std::vector filterTests( std::vector const& testCases, TestSpec const& testSpec, IConfig const& config ); + std::vector const& getAllTestCasesSorted( IConfig const& config ); + +} + +// end catch_interfaces_testcase.h +// start catch_stringref.h + +#include +#include +#include + +namespace Catch { + + /// A non-owning string class (similar to the forthcoming std::string_view) + /// Note that, because a StringRef may be a substring of another string, + /// it may not be null terminated. c_str() must return a null terminated + /// string, however, and so the StringRef will internally take ownership + /// (taking a copy), if necessary. In theory this ownership is not externally + /// visible - but it does mean (substring) StringRefs should not be shared between + /// threads. + class StringRef { + public: + using size_type = std::size_t; + + private: + friend struct StringRefTestAccess; + + char const* m_start; + size_type m_size; + + char* m_data = nullptr; + + void takeOwnership(); + + static constexpr char const* const s_empty = ""; + + public: // construction/ assignment + StringRef() noexcept + : StringRef( s_empty, 0 ) + {} + + StringRef( StringRef const& other ) noexcept + : m_start( other.m_start ), + m_size( other.m_size ) + {} + + StringRef( StringRef&& other ) noexcept + : m_start( other.m_start ), + m_size( other.m_size ), + m_data( other.m_data ) + { + other.m_data = nullptr; + } + + StringRef( char const* rawChars ) noexcept; + + StringRef( char const* rawChars, size_type size ) noexcept + : m_start( rawChars ), + m_size( size ) + {} + + StringRef( std::string const& stdString ) noexcept + : m_start( stdString.c_str() ), + m_size( stdString.size() ) + {} + + ~StringRef() noexcept { + delete[] m_data; + } + + auto operator = ( StringRef const &other ) noexcept -> StringRef& { + delete[] m_data; + m_data = nullptr; + m_start = other.m_start; + m_size = other.m_size; + return *this; + } + + operator std::string() const; + + void swap( StringRef& other ) noexcept; + + public: // operators + auto operator == ( StringRef const& other ) const noexcept -> bool; + auto operator != ( StringRef const& other ) const noexcept -> bool; + + auto operator[] ( size_type index ) const noexcept -> char; + + public: // named queries + auto empty() const noexcept -> bool { + return m_size == 0; + } + auto size() const noexcept -> size_type { + return m_size; + } + + auto numberOfCharacters() const noexcept -> size_type; + auto c_str() const -> char const*; + + public: // substrings and searches + auto substr( size_type start, size_type size ) const noexcept -> StringRef; + + // Returns the current start pointer. + // Note that the pointer can change when if the StringRef is a substring + auto currentData() const noexcept -> char const*; + + private: // ownership queries - may not be consistent between calls + auto isOwned() const noexcept -> bool; + auto isSubstring() const noexcept -> bool; + }; + + auto operator + ( StringRef const& lhs, StringRef const& rhs ) -> std::string; + auto operator + ( StringRef const& lhs, char const* rhs ) -> std::string; + auto operator + ( char const* lhs, StringRef const& rhs ) -> std::string; + + auto operator += ( std::string& lhs, StringRef const& sr ) -> std::string&; + auto operator << ( std::ostream& os, StringRef const& sr ) -> std::ostream&; + + inline auto operator "" _sr( char const* rawChars, std::size_t size ) noexcept -> StringRef { + return StringRef( rawChars, size ); + } + +} // namespace Catch + +inline auto operator "" _catch_sr( char const* rawChars, std::size_t size ) noexcept -> Catch::StringRef { + return Catch::StringRef( rawChars, size ); +} + +// end catch_stringref.h +// start catch_type_traits.hpp + + +#include + +namespace Catch{ + +#ifdef CATCH_CPP17_OR_GREATER + template + inline constexpr auto is_unique = std::true_type{}; + + template + inline constexpr auto is_unique = std::bool_constant< + (!std::is_same_v && ...) && is_unique + >{}; +#else + +template +struct is_unique : std::true_type{}; + +template +struct is_unique : std::integral_constant +::value + && is_unique::value + && is_unique::value +>{}; + +#endif +} + +// end catch_type_traits.hpp +// start catch_preprocessor.hpp + + +#define CATCH_RECURSION_LEVEL0(...) __VA_ARGS__ +#define CATCH_RECURSION_LEVEL1(...) CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL2(...) CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL3(...) CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL4(...) CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(__VA_ARGS__))) +#define CATCH_RECURSION_LEVEL5(...) CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(__VA_ARGS__))) + +#ifdef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_EXPAND_VARGS(...) __VA_ARGS__ +// MSVC needs more evaluations +#define CATCH_RECURSION_LEVEL6(...) CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(__VA_ARGS__))) +#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL6(CATCH_RECURSION_LEVEL6(__VA_ARGS__)) +#else +#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL5(__VA_ARGS__) +#endif + +#define CATCH_REC_END(...) +#define CATCH_REC_OUT + +#define CATCH_EMPTY() +#define CATCH_DEFER(id) id CATCH_EMPTY() + +#define CATCH_REC_GET_END2() 0, CATCH_REC_END +#define CATCH_REC_GET_END1(...) CATCH_REC_GET_END2 +#define CATCH_REC_GET_END(...) CATCH_REC_GET_END1 +#define CATCH_REC_NEXT0(test, next, ...) next CATCH_REC_OUT +#define CATCH_REC_NEXT1(test, next) CATCH_DEFER ( CATCH_REC_NEXT0 ) ( test, next, 0) +#define CATCH_REC_NEXT(test, next) CATCH_REC_NEXT1(CATCH_REC_GET_END test, next) + +#define CATCH_REC_LIST0(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST1(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0) ) ( f, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST2(f, x, peek, ...) f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ ) + +#define CATCH_REC_LIST0_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST1_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0_UD) ) ( f, userdata, peek, __VA_ARGS__ ) +#define CATCH_REC_LIST2_UD(f, userdata, x, peek, ...) f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ ) + +// Applies the function macro `f` to each of the remaining parameters, inserts commas between the results, +// and passes userdata as the first parameter to each invocation, +// e.g. CATCH_REC_LIST_UD(f, x, a, b, c) evaluates to f(x, a), f(x, b), f(x, c) +#define CATCH_REC_LIST_UD(f, userdata, ...) CATCH_RECURSE(CATCH_REC_LIST2_UD(f, userdata, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +#define CATCH_REC_LIST(f, ...) CATCH_RECURSE(CATCH_REC_LIST2(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0)) + +#define INTERNAL_CATCH_EXPAND1(param) INTERNAL_CATCH_EXPAND2(param) +#define INTERNAL_CATCH_EXPAND2(...) INTERNAL_CATCH_NO## __VA_ARGS__ +#define INTERNAL_CATCH_DEF(...) INTERNAL_CATCH_DEF __VA_ARGS__ +#define INTERNAL_CATCH_NOINTERNAL_CATCH_DEF +#define INTERNAL_CATCH_STRINGIZE(...) INTERNAL_CATCH_STRINGIZE2(__VA_ARGS__) +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_STRINGIZE2(...) #__VA_ARGS__ +#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param)) +#else +// MSVC is adding extra space and needs another indirection to expand INTERNAL_CATCH_NOINTERNAL_CATCH_DEF +#define INTERNAL_CATCH_STRINGIZE2(...) INTERNAL_CATCH_STRINGIZE3(__VA_ARGS__) +#define INTERNAL_CATCH_STRINGIZE3(...) #__VA_ARGS__ +#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) (INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param)) + 1) +#endif + +#define INTERNAL_CATCH_REMOVE_PARENS(...) INTERNAL_CATCH_EXPAND1(INTERNAL_CATCH_DEF __VA_ARGS__) + +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name, __VA_ARGS__) +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name,...) Name " - " #__VA_ARGS__ +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME(Name,...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__)) +#else +// MSVC is adding extra space and needs more calls to properly remove () +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME3(Name,...) Name " -" #__VA_ARGS__ +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME1(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME2(Name, __VA_ARGS__) +#define INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME(Name, ...) INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME1(Name, INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__))) +#endif + +#define INTERNAL_CATCH_MAKE_TYPE_LIST(types) Catch::TypeList + +#define INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(types)\ + CATCH_REC_LIST(INTERNAL_CATCH_MAKE_TYPE_LIST,INTERNAL_CATCH_REMOVE_PARENS(types)) + +// end catch_preprocessor.hpp +// start catch_meta.hpp + + +#include + +namespace Catch { +template< typename... > +struct TypeList {}; + +template< typename... > +struct append; + +template< template class L1 + , typename...E1 + , template class L2 + , typename...E2 +> +struct append< L1, L2 > { + using type = L1; +}; + +template< template class L1 + , typename...E1 + , template class L2 + , typename...E2 + , typename...Rest +> +struct append< L1, L2, Rest...> { + using type = typename append< L1, Rest... >::type; +}; + +template< template class + , typename... +> +struct rewrap; + +template< template class Container + , template class List + , typename...elems +> +struct rewrap> { + using type = TypeList< Container< elems... > >; +}; + +template< template class Container + , template class List + , class...Elems + , typename...Elements> + struct rewrap, Elements...> { + using type = typename append>, typename rewrap::type>::type; +}; + +template< template class...Containers > +struct combine { + template< typename...Types > + struct with_types { + template< template class Final > + struct into { + using type = typename append, typename rewrap::type...>::type; + }; + }; +}; + +template +struct always_false : std::false_type {}; + +} // namespace Catch + +// end catch_meta.hpp +namespace Catch { + +template +class TestInvokerAsMethod : public ITestInvoker { + void (C::*m_testAsMethod)(); +public: + TestInvokerAsMethod( void (C::*testAsMethod)() ) noexcept : m_testAsMethod( testAsMethod ) {} + + void invoke() const override { + C obj; + (obj.*m_testAsMethod)(); + } +}; + +auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker*; + +template +auto makeTestInvoker( void (C::*testAsMethod)() ) noexcept -> ITestInvoker* { + return new(std::nothrow) TestInvokerAsMethod( testAsMethod ); +} + +struct NameAndTags { + NameAndTags( StringRef const& name_ = StringRef(), StringRef const& tags_ = StringRef() ) noexcept; + StringRef name; + StringRef tags; +}; + +struct AutoReg : NonCopyable { + AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept; + ~AutoReg(); +}; + +} // end namespace Catch + +#if defined(CATCH_CONFIG_DISABLE) + #define INTERNAL_CATCH_TESTCASE_NO_REGISTRATION( TestName, ... ) \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION( TestName, ClassName, ... ) \ + namespace{ \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \ + void test(); \ + }; \ + } \ + void TestName::test() + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION( TestName, ... ) \ + template \ + static void TestName() + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( TestName, ClassName, ... ) \ + namespace{ \ + template \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName ) { \ + void test(); \ + }; \ + } \ + template \ + void TestName::test() +#endif + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TESTCASE2( TestName, ... ) \ + static void TestName(); \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &TestName ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + static void TestName() + #define INTERNAL_CATCH_TESTCASE( ... ) \ + INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &QualifiedMethod ), CATCH_INTERNAL_LINEINFO, "&" #QualifiedMethod, Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEST_CASE_METHOD2( TestName, ClassName, ... )\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \ + void test(); \ + }; \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( Catch::makeTestInvoker( &TestName::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \ + } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + void TestName::test() + #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... ) \ + INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, __VA_ARGS__ ) + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( Function ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + + /////////////////////////////////////////////////////////////////////////////// + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_2(TestName, TestFunc, Name, Tags, ... )\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template \ + static void TestFunc();\ + namespace {\ + template \ + struct TestName{\ + template \ + TestName(Ts...names){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(CATCH_REC_LIST(INTERNAL_CATCH_REMOVE_PARENS, __VA_ARGS__)) \ + using expander = int[];\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFunc ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ names, Tags } ), 0)... };/* NOLINT */ \ + }\ + };\ + INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestName, Name, __VA_ARGS__) \ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template \ + static void TestFunc() + +#if defined(CATCH_CPP17_OR_GREATER) +#define CATCH_INTERNAL_CHECK_UNIQUE_TYPES(...) static_assert(Catch::is_unique<__VA_ARGS__>,"Duplicate type detected in declaration of template test case"); +#else +#define CATCH_INTERNAL_CHECK_UNIQUE_TYPES(...) static_assert(Catch::is_unique<__VA_ARGS__>::value,"Duplicate type detected in declaration of template test case"); +#endif + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \ + INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestName, Name, ...)\ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\ + TestName(CATCH_REC_LIST_UD(INTERNAL_CATCH_TEMPLATE_UNIQUE_NAME,Name, __VA_ARGS__));\ + return 0;\ + }(); + + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(TestName, TestFuncName, Name, Tags, TmplTypes, TypesList) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template static void TestFuncName(); \ + namespace { \ + template \ + struct TestName { \ + TestName() { \ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(Types...) \ + int index = 0; \ + using expander = int[]; \ + constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\ + constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\ + constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFuncName ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */\ + } \ + }; \ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){ \ + using TestInit = Catch::combine \ + ::with_types::into::type; \ + TestInit(); \ + return 0; \ + }(); \ + } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template \ + static void TestFuncName() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ),Name,Tags,__VA_ARGS__) +#else + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, ... ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ \ + template \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName ) { \ + void test();\ + };\ + template \ + struct TestNameClass{\ + template \ + TestNameClass(Ts...names){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(CATCH_REC_LIST(INTERNAL_CATCH_REMOVE_PARENS, __VA_ARGS__)) \ + using expander = int[];\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ names, Tags } ), 0)... };/* NOLINT */ \ + }\ + };\ + INTERNAL_CATCH_TEMPLATE_REGISTRY_INITIATE(TestNameClass, Name, __VA_ARGS__)\ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS\ + template \ + void TestName::test() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \ + INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, __VA_ARGS__ ) ) +#endif + + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2(TestNameClass, TestName, ClassName, Name, Tags, TmplTypes, TypesList)\ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + template \ + struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName ) { \ + void test();\ + };\ + namespace {\ + template\ + struct TestNameClass{\ + TestNameClass(){\ + CATCH_INTERNAL_CHECK_UNIQUE_TYPES(Types...)\ + int index = 0;\ + using expander = int[];\ + constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\ + constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\ + constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\ + (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */ \ + }\ + };\ + static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\ + using TestInit = Catch::combine\ + ::with_types::into::type;\ + TestInit();\ + return 0;\ + }(); \ + }\ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + template \ + void TestName::test() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\ + INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, __VA_ARGS__ ) +#else + #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\ + INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, __VA_ARGS__ ) ) +#endif + +// end catch_test_registry.h +// start catch_capture.hpp + +// start catch_assertionhandler.h + +// start catch_assertioninfo.h + +// start catch_result_type.h + +namespace Catch { + + // ResultWas::OfType enum + struct ResultWas { enum OfType { + Unknown = -1, + Ok = 0, + Info = 1, + Warning = 2, + + FailureBit = 0x10, + + ExpressionFailed = FailureBit | 1, + ExplicitFailure = FailureBit | 2, + + Exception = 0x100 | FailureBit, + + ThrewException = Exception | 1, + DidntThrowException = Exception | 2, + + FatalErrorCondition = 0x200 | FailureBit + + }; }; + + bool isOk( ResultWas::OfType resultType ); + bool isJustInfo( int flags ); + + // ResultDisposition::Flags enum + struct ResultDisposition { enum Flags { + Normal = 0x01, + + ContinueOnFailure = 0x02, // Failures fail test, but execution continues + FalseTest = 0x04, // Prefix expression with ! + SuppressFail = 0x08 // Failures are reported but do not fail the test + }; }; + + ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ); + + bool shouldContinueOnFailure( int flags ); + inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; } + bool shouldSuppressFailure( int flags ); + +} // end namespace Catch + +// end catch_result_type.h +namespace Catch { + + struct AssertionInfo + { + StringRef macroName; + SourceLineInfo lineInfo; + StringRef capturedExpression; + ResultDisposition::Flags resultDisposition; + + // We want to delete this constructor but a compiler bug in 4.8 means + // the struct is then treated as non-aggregate + //AssertionInfo() = delete; + }; + +} // end namespace Catch + +// end catch_assertioninfo.h +// start catch_decomposer.h + +// start catch_tostring.h + +#include +#include +#include +#include +// start catch_stream.h + +#include +#include +#include + +namespace Catch { + + std::ostream& cout(); + std::ostream& cerr(); + std::ostream& clog(); + + class StringRef; + + struct IStream { + virtual ~IStream(); + virtual std::ostream& stream() const = 0; + }; + + auto makeStream( StringRef const &filename ) -> IStream const*; + + class ReusableStringStream { + std::size_t m_index; + std::ostream* m_oss; + public: + ReusableStringStream(); + ~ReusableStringStream(); + + auto str() const -> std::string; + + template + auto operator << ( T const& value ) -> ReusableStringStream& { + *m_oss << value; + return *this; + } + auto get() -> std::ostream& { return *m_oss; } + }; +} + +// end catch_stream.h + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW +#include +#endif + +#ifdef __OBJC__ +// start catch_objc_arc.hpp + +#import + +#ifdef __has_feature +#define CATCH_ARC_ENABLED __has_feature(objc_arc) +#else +#define CATCH_ARC_ENABLED 0 +#endif + +void arcSafeRelease( NSObject* obj ); +id performOptionalSelector( id obj, SEL sel ); + +#if !CATCH_ARC_ENABLED +inline void arcSafeRelease( NSObject* obj ) { + [obj release]; +} +inline id performOptionalSelector( id obj, SEL sel ) { + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; + return nil; +} +#define CATCH_UNSAFE_UNRETAINED +#define CATCH_ARC_STRONG +#else +inline void arcSafeRelease( NSObject* ){} +inline id performOptionalSelector( id obj, SEL sel ) { +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warc-performSelector-leaks" +#endif + if( [obj respondsToSelector: sel] ) + return [obj performSelector: sel]; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + return nil; +} +#define CATCH_UNSAFE_UNRETAINED __unsafe_unretained +#define CATCH_ARC_STRONG __strong +#endif + +// end catch_objc_arc.hpp +#endif + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4180) // We attempt to stream a function (address) by const&, which MSVC complains about but is harmless +#endif + +namespace Catch { + namespace Detail { + + extern const std::string unprintableString; + + std::string rawMemoryToString( const void *object, std::size_t size ); + + template + std::string rawMemoryToString( const T& object ) { + return rawMemoryToString( &object, sizeof(object) ); + } + + template + class IsStreamInsertable { + template + static auto test(int) + -> decltype(std::declval() << std::declval(), std::true_type()); + + template + static auto test(...)->std::false_type; + + public: + static const bool value = decltype(test(0))::value; + }; + + template + std::string convertUnknownEnumToString( E e ); + + template + typename std::enable_if< + !std::is_enum::value && !std::is_base_of::value, + std::string>::type convertUnstreamable( T const& ) { + return Detail::unprintableString; + } + template + typename std::enable_if< + !std::is_enum::value && std::is_base_of::value, + std::string>::type convertUnstreamable(T const& ex) { + return ex.what(); + } + + template + typename std::enable_if< + std::is_enum::value + , std::string>::type convertUnstreamable( T const& value ) { + return convertUnknownEnumToString( value ); + } + +#if defined(_MANAGED) + //! Convert a CLR string to a utf8 std::string + template + std::string clrReferenceToString( T^ ref ) { + if (ref == nullptr) + return std::string("null"); + auto bytes = System::Text::Encoding::UTF8->GetBytes(ref->ToString()); + cli::pin_ptr p = &bytes[0]; + return std::string(reinterpret_cast(p), bytes->Length); + } +#endif + + } // namespace Detail + + // If we decide for C++14, change these to enable_if_ts + template + struct StringMaker { + template + static + typename std::enable_if<::Catch::Detail::IsStreamInsertable::value, std::string>::type + convert(const Fake& value) { + ReusableStringStream rss; + // NB: call using the function-like syntax to avoid ambiguity with + // user-defined templated operator<< under clang. + rss.operator<<(value); + return rss.str(); + } + + template + static + typename std::enable_if::value, std::string>::type + convert( const Fake& value ) { +#if !defined(CATCH_CONFIG_FALLBACK_STRINGIFIER) + return Detail::convertUnstreamable(value); +#else + return CATCH_CONFIG_FALLBACK_STRINGIFIER(value); +#endif + } + }; + + namespace Detail { + + // This function dispatches all stringification requests inside of Catch. + // Should be preferably called fully qualified, like ::Catch::Detail::stringify + template + std::string stringify(const T& e) { + return ::Catch::StringMaker::type>::type>::convert(e); + } + + template + std::string convertUnknownEnumToString( E e ) { + return ::Catch::Detail::stringify(static_cast::type>(e)); + } + +#if defined(_MANAGED) + template + std::string stringify( T^ e ) { + return ::Catch::StringMaker::convert(e); + } +#endif + + } // namespace Detail + + // Some predefined specializations + + template<> + struct StringMaker { + static std::string convert(const std::string& str); + }; + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW + template<> + struct StringMaker { + static std::string convert(std::string_view str); + }; +#endif + + template<> + struct StringMaker { + static std::string convert(char const * str); + }; + template<> + struct StringMaker { + static std::string convert(char * str); + }; + +#ifdef CATCH_CONFIG_WCHAR + template<> + struct StringMaker { + static std::string convert(const std::wstring& wstr); + }; + +# ifdef CATCH_CONFIG_CPP17_STRING_VIEW + template<> + struct StringMaker { + static std::string convert(std::wstring_view str); + }; +# endif + + template<> + struct StringMaker { + static std::string convert(wchar_t const * str); + }; + template<> + struct StringMaker { + static std::string convert(wchar_t * str); + }; +#endif + + // TBD: Should we use `strnlen` to ensure that we don't go out of the buffer, + // while keeping string semantics? + template + struct StringMaker { + static std::string convert(char const* str) { + return ::Catch::Detail::stringify(std::string{ str }); + } + }; + template + struct StringMaker { + static std::string convert(signed char const* str) { + return ::Catch::Detail::stringify(std::string{ reinterpret_cast(str) }); + } + }; + template + struct StringMaker { + static std::string convert(unsigned char const* str) { + return ::Catch::Detail::stringify(std::string{ reinterpret_cast(str) }); + } + }; + + template<> + struct StringMaker { + static std::string convert(int value); + }; + template<> + struct StringMaker { + static std::string convert(long value); + }; + template<> + struct StringMaker { + static std::string convert(long long value); + }; + template<> + struct StringMaker { + static std::string convert(unsigned int value); + }; + template<> + struct StringMaker { + static std::string convert(unsigned long value); + }; + template<> + struct StringMaker { + static std::string convert(unsigned long long value); + }; + + template<> + struct StringMaker { + static std::string convert(bool b); + }; + + template<> + struct StringMaker { + static std::string convert(char c); + }; + template<> + struct StringMaker { + static std::string convert(signed char c); + }; + template<> + struct StringMaker { + static std::string convert(unsigned char c); + }; + + template<> + struct StringMaker { + static std::string convert(std::nullptr_t); + }; + + template<> + struct StringMaker { + static std::string convert(float value); + }; + template<> + struct StringMaker { + static std::string convert(double value); + }; + + template + struct StringMaker { + template + static std::string convert(U* p) { + if (p) { + return ::Catch::Detail::rawMemoryToString(p); + } else { + return "nullptr"; + } + } + }; + + template + struct StringMaker { + static std::string convert(R C::* p) { + if (p) { + return ::Catch::Detail::rawMemoryToString(p); + } else { + return "nullptr"; + } + } + }; + +#if defined(_MANAGED) + template + struct StringMaker { + static std::string convert( T^ ref ) { + return ::Catch::Detail::clrReferenceToString(ref); + } + }; +#endif + + namespace Detail { + template + std::string rangeToString(InputIterator first, InputIterator last) { + ReusableStringStream rss; + rss << "{ "; + if (first != last) { + rss << ::Catch::Detail::stringify(*first); + for (++first; first != last; ++first) + rss << ", " << ::Catch::Detail::stringify(*first); + } + rss << " }"; + return rss.str(); + } + } + +#ifdef __OBJC__ + template<> + struct StringMaker { + static std::string convert(NSString * nsstring) { + if (!nsstring) + return "nil"; + return std::string("@") + [nsstring UTF8String]; + } + }; + template<> + struct StringMaker { + static std::string convert(NSObject* nsObject) { + return ::Catch::Detail::stringify([nsObject description]); + } + + }; + namespace Detail { + inline std::string stringify( NSString* nsstring ) { + return StringMaker::convert( nsstring ); + } + + } // namespace Detail +#endif // __OBJC__ + +} // namespace Catch + +////////////////////////////////////////////////////// +// Separate std-lib types stringification, so it can be selectively enabled +// This means that we do not bring in + +#if defined(CATCH_CONFIG_ENABLE_ALL_STRINGMAKERS) +# define CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER +# define CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER +# define CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER +# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +# define CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER +#endif + +// Separate std::pair specialization +#if defined(CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER) +#include +namespace Catch { + template + struct StringMaker > { + static std::string convert(const std::pair& pair) { + ReusableStringStream rss; + rss << "{ " + << ::Catch::Detail::stringify(pair.first) + << ", " + << ::Catch::Detail::stringify(pair.second) + << " }"; + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER + +#if defined(CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_OPTIONAL) +#include +namespace Catch { + template + struct StringMaker > { + static std::string convert(const std::optional& optional) { + ReusableStringStream rss; + if (optional.has_value()) { + rss << ::Catch::Detail::stringify(*optional); + } else { + rss << "{ }"; + } + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER + +// Separate std::tuple specialization +#if defined(CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER) +#include +namespace Catch { + namespace Detail { + template< + typename Tuple, + std::size_t N = 0, + bool = (N < std::tuple_size::value) + > + struct TupleElementPrinter { + static void print(const Tuple& tuple, std::ostream& os) { + os << (N ? ", " : " ") + << ::Catch::Detail::stringify(std::get(tuple)); + TupleElementPrinter::print(tuple, os); + } + }; + + template< + typename Tuple, + std::size_t N + > + struct TupleElementPrinter { + static void print(const Tuple&, std::ostream&) {} + }; + + } + + template + struct StringMaker> { + static std::string convert(const std::tuple& tuple) { + ReusableStringStream rss; + rss << '{'; + Detail::TupleElementPrinter>::print(tuple, rss.get()); + rss << " }"; + return rss.str(); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER + +#if defined(CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_VARIANT) +#include +namespace Catch { + template<> + struct StringMaker { + static std::string convert(const std::monostate&) { + return "{ }"; + } + }; + + template + struct StringMaker> { + static std::string convert(const std::variant& variant) { + if (variant.valueless_by_exception()) { + return "{valueless variant}"; + } else { + return std::visit( + [](const auto& value) { + return ::Catch::Detail::stringify(value); + }, + variant + ); + } + } + }; +} +#endif // CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER + +namespace Catch { + struct not_this_one {}; // Tag type for detecting which begin/ end are being selected + + // Import begin/ end from std here so they are considered alongside the fallback (...) overloads in this namespace + using std::begin; + using std::end; + + not_this_one begin( ... ); + not_this_one end( ... ); + + template + struct is_range { + static const bool value = + !std::is_same())), not_this_one>::value && + !std::is_same())), not_this_one>::value; + }; + +#if defined(_MANAGED) // Managed types are never ranges + template + struct is_range { + static const bool value = false; + }; +#endif + + template + std::string rangeToString( Range const& range ) { + return ::Catch::Detail::rangeToString( begin( range ), end( range ) ); + } + + // Handle vector specially + template + std::string rangeToString( std::vector const& v ) { + ReusableStringStream rss; + rss << "{ "; + bool first = true; + for( bool b : v ) { + if( first ) + first = false; + else + rss << ", "; + rss << ::Catch::Detail::stringify( b ); + } + rss << " }"; + return rss.str(); + } + + template + struct StringMaker::value && !::Catch::Detail::IsStreamInsertable::value>::type> { + static std::string convert( R const& range ) { + return rangeToString( range ); + } + }; + + template + struct StringMaker { + static std::string convert(T const(&arr)[SZ]) { + return rangeToString(arr); + } + }; + +} // namespace Catch + +// Separate std::chrono::duration specialization +#if defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +#include +#include +#include + +namespace Catch { + +template +struct ratio_string { + static std::string symbol(); +}; + +template +std::string ratio_string::symbol() { + Catch::ReusableStringStream rss; + rss << '[' << Ratio::num << '/' + << Ratio::den << ']'; + return rss.str(); +} +template <> +struct ratio_string { + static std::string symbol(); +}; +template <> +struct ratio_string { + static std::string symbol(); +}; +template <> +struct ratio_string { + static std::string symbol(); +}; +template <> +struct ratio_string { + static std::string symbol(); +}; +template <> +struct ratio_string { + static std::string symbol(); +}; +template <> +struct ratio_string { + static std::string symbol(); +}; + + //////////// + // std::chrono::duration specializations + template + struct StringMaker> { + static std::string convert(std::chrono::duration const& duration) { + ReusableStringStream rss; + rss << duration.count() << ' ' << ratio_string::symbol() << 's'; + return rss.str(); + } + }; + template + struct StringMaker>> { + static std::string convert(std::chrono::duration> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " s"; + return rss.str(); + } + }; + template + struct StringMaker>> { + static std::string convert(std::chrono::duration> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " m"; + return rss.str(); + } + }; + template + struct StringMaker>> { + static std::string convert(std::chrono::duration> const& duration) { + ReusableStringStream rss; + rss << duration.count() << " h"; + return rss.str(); + } + }; + + //////////// + // std::chrono::time_point specialization + // Generic time_point cannot be specialized, only std::chrono::time_point + template + struct StringMaker> { + static std::string convert(std::chrono::time_point const& time_point) { + return ::Catch::Detail::stringify(time_point.time_since_epoch()) + " since epoch"; + } + }; + // std::chrono::time_point specialization + template + struct StringMaker> { + static std::string convert(std::chrono::time_point const& time_point) { + auto converted = std::chrono::system_clock::to_time_t(time_point); + +#ifdef _MSC_VER + std::tm timeInfo = {}; + gmtime_s(&timeInfo, &converted); +#else + std::tm* timeInfo = std::gmtime(&converted); +#endif + + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + char timeStamp[timeStampSize]; + const char * const fmt = "%Y-%m-%dT%H:%M:%SZ"; + +#ifdef _MSC_VER + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); +#else + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); +#endif + return std::string(timeStamp); + } + }; +} +#endif // CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// end catch_tostring.h +#include + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable:4389) // '==' : signed/unsigned mismatch +#pragma warning(disable:4018) // more "signed/unsigned mismatch" +#pragma warning(disable:4312) // Converting int to T* using reinterpret_cast (issue on x64 platform) +#pragma warning(disable:4180) // qualifier applied to function type has no meaning +#pragma warning(disable:4800) // Forcing result to true or false +#endif + +namespace Catch { + + struct ITransientExpression { + auto isBinaryExpression() const -> bool { return m_isBinaryExpression; } + auto getResult() const -> bool { return m_result; } + virtual void streamReconstructedExpression( std::ostream &os ) const = 0; + + ITransientExpression( bool isBinaryExpression, bool result ) + : m_isBinaryExpression( isBinaryExpression ), + m_result( result ) + {} + + // We don't actually need a virtual destructor, but many static analysers + // complain if it's not here :-( + virtual ~ITransientExpression(); + + bool m_isBinaryExpression; + bool m_result; + + }; + + void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs ); + + template + class BinaryExpr : public ITransientExpression { + LhsT m_lhs; + StringRef m_op; + RhsT m_rhs; + + void streamReconstructedExpression( std::ostream &os ) const override { + formatReconstructedExpression + ( os, Catch::Detail::stringify( m_lhs ), m_op, Catch::Detail::stringify( m_rhs ) ); + } + + public: + BinaryExpr( bool comparisonResult, LhsT lhs, StringRef op, RhsT rhs ) + : ITransientExpression{ true, comparisonResult }, + m_lhs( lhs ), + m_op( op ), + m_rhs( rhs ) + {} + + template + auto operator && ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator || ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator == ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator != ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator > ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator < ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator >= ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator <= ( T ) const -> BinaryExpr const { + static_assert(always_false::value, + "chained comparisons are not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + }; + + template + class UnaryExpr : public ITransientExpression { + LhsT m_lhs; + + void streamReconstructedExpression( std::ostream &os ) const override { + os << Catch::Detail::stringify( m_lhs ); + } + + public: + explicit UnaryExpr( LhsT lhs ) + : ITransientExpression{ false, static_cast(lhs) }, + m_lhs( lhs ) + {} + }; + + // Specialised comparison functions to handle equality comparisons between ints and pointers (NULL deduces as an int) + template + auto compareEqual( LhsT const& lhs, RhsT const& rhs ) -> bool { return static_cast(lhs == rhs); } + template + auto compareEqual( T* const& lhs, int rhs ) -> bool { return lhs == reinterpret_cast( rhs ); } + template + auto compareEqual( T* const& lhs, long rhs ) -> bool { return lhs == reinterpret_cast( rhs ); } + template + auto compareEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast( lhs ) == rhs; } + template + auto compareEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast( lhs ) == rhs; } + + template + auto compareNotEqual( LhsT const& lhs, RhsT&& rhs ) -> bool { return static_cast(lhs != rhs); } + template + auto compareNotEqual( T* const& lhs, int rhs ) -> bool { return lhs != reinterpret_cast( rhs ); } + template + auto compareNotEqual( T* const& lhs, long rhs ) -> bool { return lhs != reinterpret_cast( rhs ); } + template + auto compareNotEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast( lhs ) != rhs; } + template + auto compareNotEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast( lhs ) != rhs; } + + template + class ExprLhs { + LhsT m_lhs; + public: + explicit ExprLhs( LhsT lhs ) : m_lhs( lhs ) {} + + template + auto operator == ( RhsT const& rhs ) -> BinaryExpr const { + return { compareEqual( m_lhs, rhs ), m_lhs, "==", rhs }; + } + auto operator == ( bool rhs ) -> BinaryExpr const { + return { m_lhs == rhs, m_lhs, "==", rhs }; + } + + template + auto operator != ( RhsT const& rhs ) -> BinaryExpr const { + return { compareNotEqual( m_lhs, rhs ), m_lhs, "!=", rhs }; + } + auto operator != ( bool rhs ) -> BinaryExpr const { + return { m_lhs != rhs, m_lhs, "!=", rhs }; + } + + template + auto operator > ( RhsT const& rhs ) -> BinaryExpr const { + return { static_cast(m_lhs > rhs), m_lhs, ">", rhs }; + } + template + auto operator < ( RhsT const& rhs ) -> BinaryExpr const { + return { static_cast(m_lhs < rhs), m_lhs, "<", rhs }; + } + template + auto operator >= ( RhsT const& rhs ) -> BinaryExpr const { + return { static_cast(m_lhs >= rhs), m_lhs, ">=", rhs }; + } + template + auto operator <= ( RhsT const& rhs ) -> BinaryExpr const { + return { static_cast(m_lhs <= rhs), m_lhs, "<=", rhs }; + } + + template + auto operator && ( RhsT const& ) -> BinaryExpr const { + static_assert(always_false::value, + "operator&& is not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + template + auto operator || ( RhsT const& ) -> BinaryExpr const { + static_assert(always_false::value, + "operator|| is not supported inside assertions, " + "wrap the expression inside parentheses, or decompose it"); + } + + auto makeUnaryExpr() const -> UnaryExpr { + return UnaryExpr{ m_lhs }; + } + }; + + void handleExpression( ITransientExpression const& expr ); + + template + void handleExpression( ExprLhs const& expr ) { + handleExpression( expr.makeUnaryExpr() ); + } + + struct Decomposer { + template + auto operator <= ( T const& lhs ) -> ExprLhs { + return ExprLhs{ lhs }; + } + + auto operator <=( bool value ) -> ExprLhs { + return ExprLhs{ value }; + } + }; + +} // end namespace Catch + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +// end catch_decomposer.h +// start catch_interfaces_capture.h + +#include + +namespace Catch { + + class AssertionResult; + struct AssertionInfo; + struct SectionInfo; + struct SectionEndInfo; + struct MessageInfo; + struct MessageBuilder; + struct Counts; + struct BenchmarkInfo; + struct BenchmarkStats; + struct AssertionReaction; + struct SourceLineInfo; + + struct ITransientExpression; + struct IGeneratorTracker; + + struct IResultCapture { + + virtual ~IResultCapture(); + + virtual bool sectionStarted( SectionInfo const& sectionInfo, + Counts& assertions ) = 0; + virtual void sectionEnded( SectionEndInfo const& endInfo ) = 0; + virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) = 0; + + virtual auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& = 0; + + virtual void benchmarkStarting( BenchmarkInfo const& info ) = 0; + virtual void benchmarkEnded( BenchmarkStats const& stats ) = 0; + + virtual void pushScopedMessage( MessageInfo const& message ) = 0; + virtual void popScopedMessage( MessageInfo const& message ) = 0; + + virtual void emplaceUnscopedMessage( MessageBuilder const& builder ) = 0; + + virtual void handleFatalErrorCondition( StringRef message ) = 0; + + virtual void handleExpr + ( AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction ) = 0; + virtual void handleMessage + ( AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction ) = 0; + virtual void handleUnexpectedExceptionNotThrown + ( AssertionInfo const& info, + AssertionReaction& reaction ) = 0; + virtual void handleUnexpectedInflightException + ( AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction ) = 0; + virtual void handleIncomplete + ( AssertionInfo const& info ) = 0; + virtual void handleNonExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction ) = 0; + + virtual bool lastAssertionPassed() = 0; + virtual void assertionPassed() = 0; + + // Deprecated, do not use: + virtual std::string getCurrentTestName() const = 0; + virtual const AssertionResult* getLastResult() const = 0; + virtual void exceptionEarlyReported() = 0; + }; + + IResultCapture& getResultCapture(); +} + +// end catch_interfaces_capture.h +namespace Catch { + + struct TestFailureException{}; + struct AssertionResultData; + struct IResultCapture; + class RunContext; + + class LazyExpression { + friend class AssertionHandler; + friend struct AssertionStats; + friend class RunContext; + + ITransientExpression const* m_transientExpression = nullptr; + bool m_isNegated; + public: + LazyExpression( bool isNegated ); + LazyExpression( LazyExpression const& other ); + LazyExpression& operator = ( LazyExpression const& ) = delete; + + explicit operator bool() const; + + friend auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream&; + }; + + struct AssertionReaction { + bool shouldDebugBreak = false; + bool shouldThrow = false; + }; + + class AssertionHandler { + AssertionInfo m_assertionInfo; + AssertionReaction m_reaction; + bool m_completed = false; + IResultCapture& m_resultCapture; + + public: + AssertionHandler + ( StringRef const& macroName, + SourceLineInfo const& lineInfo, + StringRef capturedExpression, + ResultDisposition::Flags resultDisposition ); + ~AssertionHandler() { + if ( !m_completed ) { + m_resultCapture.handleIncomplete( m_assertionInfo ); + } + } + + template + void handleExpr( ExprLhs const& expr ) { + handleExpr( expr.makeUnaryExpr() ); + } + void handleExpr( ITransientExpression const& expr ); + + void handleMessage(ResultWas::OfType resultType, StringRef const& message); + + void handleExceptionThrownAsExpected(); + void handleUnexpectedExceptionNotThrown(); + void handleExceptionNotThrownAsExpected(); + void handleThrowingCallSkipped(); + void handleUnexpectedInflightException(); + + void complete(); + void setCompleted(); + + // query + auto allowThrows() const -> bool; + }; + + void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString ); + +} // namespace Catch + +// end catch_assertionhandler.h +// start catch_message.h + +#include +#include + +namespace Catch { + + struct MessageInfo { + MessageInfo( StringRef const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ); + + StringRef macroName; + std::string message; + SourceLineInfo lineInfo; + ResultWas::OfType type; + unsigned int sequence; + + bool operator == ( MessageInfo const& other ) const; + bool operator < ( MessageInfo const& other ) const; + private: + static unsigned int globalCount; + }; + + struct MessageStream { + + template + MessageStream& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + ReusableStringStream m_stream; + }; + + struct MessageBuilder : MessageStream { + MessageBuilder( StringRef const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ); + + template + MessageBuilder& operator << ( T const& value ) { + m_stream << value; + return *this; + } + + MessageInfo m_info; + }; + + class ScopedMessage { + public: + explicit ScopedMessage( MessageBuilder const& builder ); + ScopedMessage( ScopedMessage& duplicate ) = delete; + ScopedMessage( ScopedMessage&& old ); + ~ScopedMessage(); + + MessageInfo m_info; + bool m_moved; + }; + + class Capturer { + std::vector m_messages; + IResultCapture& m_resultCapture = getResultCapture(); + size_t m_captured = 0; + public: + Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names ); + ~Capturer(); + + void captureValue( size_t index, std::string const& value ); + + template + void captureValues( size_t index, T const& value ) { + captureValue( index, Catch::Detail::stringify( value ) ); + } + + template + void captureValues( size_t index, T const& value, Ts const&... values ) { + captureValue( index, Catch::Detail::stringify(value) ); + captureValues( index+1, values... ); + } + }; + +} // end namespace Catch + +// end catch_message.h +#if !defined(CATCH_CONFIG_DISABLE) + +#if !defined(CATCH_CONFIG_DISABLE_STRINGIFICATION) + #define CATCH_INTERNAL_STRINGIFY(...) #__VA_ARGS__ +#else + #define CATCH_INTERNAL_STRINGIFY(...) "Disabled by CATCH_CONFIG_DISABLE_STRINGIFICATION" +#endif + +#if defined(CATCH_CONFIG_FAST_COMPILE) || defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + +/////////////////////////////////////////////////////////////////////////////// +// Another way to speed-up compilation is to omit local try-catch for REQUIRE* +// macros. +#define INTERNAL_CATCH_TRY +#define INTERNAL_CATCH_CATCH( capturer ) + +#else // CATCH_CONFIG_FAST_COMPILE + +#define INTERNAL_CATCH_TRY try +#define INTERNAL_CATCH_CATCH( handler ) catch(...) { handler.handleUnexpectedInflightException(); } + +#endif + +#define INTERNAL_CATCH_REACT( handler ) handler.complete(); + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TEST( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \ + INTERNAL_CATCH_TRY { \ + CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \ + catchAssertionHandler.handleExpr( Catch::Decomposer() <= __VA_ARGS__ ); \ + CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \ + } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( (void)0, (false) && static_cast( !!(__VA_ARGS__) ) ) // the expression here is never evaluated at runtime but it forces the compiler to give it a look + // The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&. + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_IF( macroName, resultDisposition, ... ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \ + if( Catch::getResultCapture().lastAssertionPassed() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_ELSE( macroName, resultDisposition, ... ) \ + INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \ + if( !Catch::getResultCapture().lastAssertionPassed() ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_NO_THROW( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \ + try { \ + static_cast(__VA_ARGS__); \ + catchAssertionHandler.handleExceptionNotThrownAsExpected(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS( macroName, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast(__VA_ARGS__); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleExceptionThrownAsExpected(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_AS( macroName, exceptionType, resultDisposition, expr ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(expr) ", " CATCH_INTERNAL_STRINGIFY(exceptionType), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast(expr); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( exceptionType const& ) { \ + catchAssertionHandler.handleExceptionThrownAsExpected(); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::StringRef(), resultDisposition ); \ + catchAssertionHandler.handleMessage( messageType, ( Catch::MessageStream() << __VA_ARGS__ + ::Catch::StreamEndStop() ).m_stream.str() ); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_CAPTURE( varName, macroName, ... ) \ + auto varName = Catch::Capturer( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info, #__VA_ARGS__ ); \ + varName.captureValues( 0, __VA_ARGS__ ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_INFO( macroName, log ) \ + Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage )( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log ); + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_UNSCOPED_INFO( macroName, log ) \ + Catch::getResultCapture().emplaceUnscopedMessage( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log ) + +/////////////////////////////////////////////////////////////////////////////// +// Although this is matcher-based, it can be used with just a string +#define INTERNAL_CATCH_THROWS_STR_MATCHES( macroName, resultDisposition, matcher, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast(__VA_ARGS__); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( ... ) { \ + Catch::handleExceptionMatchExpr( catchAssertionHandler, matcher, #matcher##_catch_sr ); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +#endif // CATCH_CONFIG_DISABLE + +// end catch_capture.hpp +// start catch_section.h + +// start catch_section_info.h + +// start catch_totals.h + +#include + +namespace Catch { + + struct Counts { + Counts operator - ( Counts const& other ) const; + Counts& operator += ( Counts const& other ); + + std::size_t total() const; + bool allPassed() const; + bool allOk() const; + + std::size_t passed = 0; + std::size_t failed = 0; + std::size_t failedButOk = 0; + }; + + struct Totals { + + Totals operator - ( Totals const& other ) const; + Totals& operator += ( Totals const& other ); + + Totals delta( Totals const& prevTotals ) const; + + int error = 0; + Counts assertions; + Counts testCases; + }; +} + +// end catch_totals.h +#include + +namespace Catch { + + struct SectionInfo { + SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name ); + + // Deprecated + SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name, + std::string const& ) : SectionInfo( _lineInfo, _name ) {} + + std::string name; + std::string description; // !Deprecated: this will always be empty + SourceLineInfo lineInfo; + }; + + struct SectionEndInfo { + SectionInfo sectionInfo; + Counts prevAssertions; + double durationInSeconds; + }; + +} // end namespace Catch + +// end catch_section_info.h +// start catch_timer.h + +#include + +namespace Catch { + + auto getCurrentNanosecondsSinceEpoch() -> uint64_t; + auto getEstimatedClockResolution() -> uint64_t; + + class Timer { + uint64_t m_nanoseconds = 0; + public: + void start(); + auto getElapsedNanoseconds() const -> uint64_t; + auto getElapsedMicroseconds() const -> uint64_t; + auto getElapsedMilliseconds() const -> unsigned int; + auto getElapsedSeconds() const -> double; + }; + +} // namespace Catch + +// end catch_timer.h +#include + +namespace Catch { + + class Section : NonCopyable { + public: + Section( SectionInfo const& info ); + ~Section(); + + // This indicates whether the section should be executed or not + explicit operator bool() const; + + private: + SectionInfo m_info; + + std::string m_name; + Counts m_assertions; + bool m_sectionIncluded; + Timer m_timer; + }; + +} // end namespace Catch + +#define INTERNAL_CATCH_SECTION( ... ) \ + CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) ) \ + CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS + +#define INTERNAL_CATCH_DYNAMIC_SECTION( ... ) \ + CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \ + if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, (Catch::ReusableStringStream() << __VA_ARGS__).str() ) ) \ + CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS + +// end catch_section.h +// start catch_benchmark.h + +#include +#include + +namespace Catch { + + class BenchmarkLooper { + + std::string m_name; + std::size_t m_count = 0; + std::size_t m_iterationsToRun = 1; + uint64_t m_resolution; + Timer m_timer; + + static auto getResolution() -> uint64_t; + public: + // Keep most of this inline as it's on the code path that is being timed + BenchmarkLooper( StringRef name ) + : m_name( name ), + m_resolution( getResolution() ) + { + reportStart(); + m_timer.start(); + } + + explicit operator bool() { + if( m_count < m_iterationsToRun ) + return true; + return needsMoreIterations(); + } + + void increment() { + ++m_count; + } + + void reportStart(); + auto needsMoreIterations() -> bool; + }; + +} // end namespace Catch + +#define BENCHMARK( name ) \ + for( Catch::BenchmarkLooper looper( name ); looper; looper.increment() ) + +// end catch_benchmark.h +// start catch_interfaces_exception.h + +// start catch_interfaces_registry_hub.h + +#include +#include + +namespace Catch { + + class TestCase; + struct ITestCaseRegistry; + struct IExceptionTranslatorRegistry; + struct IExceptionTranslator; + struct IReporterRegistry; + struct IReporterFactory; + struct ITagAliasRegistry; + class StartupExceptionRegistry; + + using IReporterFactoryPtr = std::shared_ptr; + + struct IRegistryHub { + virtual ~IRegistryHub(); + + virtual IReporterRegistry const& getReporterRegistry() const = 0; + virtual ITestCaseRegistry const& getTestCaseRegistry() const = 0; + virtual ITagAliasRegistry const& getTagAliasRegistry() const = 0; + + virtual IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const = 0; + + virtual StartupExceptionRegistry const& getStartupExceptionRegistry() const = 0; + }; + + struct IMutableRegistryHub { + virtual ~IMutableRegistryHub(); + virtual void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) = 0; + virtual void registerListener( IReporterFactoryPtr const& factory ) = 0; + virtual void registerTest( TestCase const& testInfo ) = 0; + virtual void registerTranslator( const IExceptionTranslator* translator ) = 0; + virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) = 0; + virtual void registerStartupException() noexcept = 0; + }; + + IRegistryHub const& getRegistryHub(); + IMutableRegistryHub& getMutableRegistryHub(); + void cleanUp(); + std::string translateActiveException(); + +} + +// end catch_interfaces_registry_hub.h +#if defined(CATCH_CONFIG_DISABLE) + #define INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( translatorName, signature) \ + static std::string translatorName( signature ) +#endif + +#include +#include +#include + +namespace Catch { + using exceptionTranslateFunction = std::string(*)(); + + struct IExceptionTranslator; + using ExceptionTranslators = std::vector>; + + struct IExceptionTranslator { + virtual ~IExceptionTranslator(); + virtual std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const = 0; + }; + + struct IExceptionTranslatorRegistry { + virtual ~IExceptionTranslatorRegistry(); + + virtual std::string translateActiveException() const = 0; + }; + + class ExceptionTranslatorRegistrar { + template + class ExceptionTranslator : public IExceptionTranslator { + public: + + ExceptionTranslator( std::string(*translateFunction)( T& ) ) + : m_translateFunction( translateFunction ) + {} + + std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const override { + try { + if( it == itEnd ) + std::rethrow_exception(std::current_exception()); + else + return (*it)->translate( it+1, itEnd ); + } + catch( T& ex ) { + return m_translateFunction( ex ); + } + } + + protected: + std::string(*m_translateFunction)( T& ); + }; + + public: + template + ExceptionTranslatorRegistrar( std::string(*translateFunction)( T& ) ) { + getMutableRegistryHub().registerTranslator + ( new ExceptionTranslator( translateFunction ) ); + } + }; +} + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION2( translatorName, signature ) \ + static std::string translatorName( signature ); \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &translatorName ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \ + static std::string translatorName( signature ) + +#define INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION2( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) + +// end catch_interfaces_exception.h +// start catch_approx.h + +#include + +namespace Catch { +namespace Detail { + + class Approx { + private: + bool equalityComparisonImpl(double other) const; + // Validates the new margin (margin >= 0) + // out-of-line to avoid including stdexcept in the header + void setMargin(double margin); + // Validates the new epsilon (0 < epsilon < 1) + // out-of-line to avoid including stdexcept in the header + void setEpsilon(double epsilon); + + public: + explicit Approx ( double value ); + + static Approx custom(); + + Approx operator-() const; + + template ::value>::type> + Approx operator()( T const& value ) { + Approx approx( static_cast(value) ); + approx.m_epsilon = m_epsilon; + approx.m_margin = m_margin; + approx.m_scale = m_scale; + return approx; + } + + template ::value>::type> + explicit Approx( T const& value ): Approx(static_cast(value)) + {} + + template ::value>::type> + friend bool operator == ( const T& lhs, Approx const& rhs ) { + auto lhs_v = static_cast(lhs); + return rhs.equalityComparisonImpl(lhs_v); + } + + template ::value>::type> + friend bool operator == ( Approx const& lhs, const T& rhs ) { + return operator==( rhs, lhs ); + } + + template ::value>::type> + friend bool operator != ( T const& lhs, Approx const& rhs ) { + return !operator==( lhs, rhs ); + } + + template ::value>::type> + friend bool operator != ( Approx const& lhs, T const& rhs ) { + return !operator==( rhs, lhs ); + } + + template ::value>::type> + friend bool operator <= ( T const& lhs, Approx const& rhs ) { + return static_cast(lhs) < rhs.m_value || lhs == rhs; + } + + template ::value>::type> + friend bool operator <= ( Approx const& lhs, T const& rhs ) { + return lhs.m_value < static_cast(rhs) || lhs == rhs; + } + + template ::value>::type> + friend bool operator >= ( T const& lhs, Approx const& rhs ) { + return static_cast(lhs) > rhs.m_value || lhs == rhs; + } + + template ::value>::type> + friend bool operator >= ( Approx const& lhs, T const& rhs ) { + return lhs.m_value > static_cast(rhs) || lhs == rhs; + } + + template ::value>::type> + Approx& epsilon( T const& newEpsilon ) { + double epsilonAsDouble = static_cast(newEpsilon); + setEpsilon(epsilonAsDouble); + return *this; + } + + template ::value>::type> + Approx& margin( T const& newMargin ) { + double marginAsDouble = static_cast(newMargin); + setMargin(marginAsDouble); + return *this; + } + + template ::value>::type> + Approx& scale( T const& newScale ) { + m_scale = static_cast(newScale); + return *this; + } + + std::string toString() const; + + private: + double m_epsilon; + double m_margin; + double m_scale; + double m_value; + }; +} // end namespace Detail + +namespace literals { + Detail::Approx operator "" _a(long double val); + Detail::Approx operator "" _a(unsigned long long val); +} // end namespace literals + +template<> +struct StringMaker { + static std::string convert(Catch::Detail::Approx const& value); +}; + +} // end namespace Catch + +// end catch_approx.h +// start catch_string_manip.h + +#include +#include + +namespace Catch { + + bool startsWith( std::string const& s, std::string const& prefix ); + bool startsWith( std::string const& s, char prefix ); + bool endsWith( std::string const& s, std::string const& suffix ); + bool endsWith( std::string const& s, char suffix ); + bool contains( std::string const& s, std::string const& infix ); + void toLowerInPlace( std::string& s ); + std::string toLower( std::string const& s ); + std::string trim( std::string const& str ); + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ); + + struct pluralise { + pluralise( std::size_t count, std::string const& label ); + + friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ); + + std::size_t m_count; + std::string m_label; + }; +} + +// end catch_string_manip.h +#ifndef CATCH_CONFIG_DISABLE_MATCHERS +// start catch_capture_matchers.h + +// start catch_matchers.h + +#include +#include + +namespace Catch { +namespace Matchers { + namespace Impl { + + template struct MatchAllOf; + template struct MatchAnyOf; + template struct MatchNotOf; + + class MatcherUntypedBase { + public: + MatcherUntypedBase() = default; + MatcherUntypedBase ( MatcherUntypedBase const& ) = default; + MatcherUntypedBase& operator = ( MatcherUntypedBase const& ) = delete; + std::string toString() const; + + protected: + virtual ~MatcherUntypedBase(); + virtual std::string describe() const = 0; + mutable std::string m_cachedToString; + }; + +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#endif + + template + struct MatcherMethod { + virtual bool match( ObjectT const& arg ) const = 0; + }; + +#ifdef __clang__ +# pragma clang diagnostic pop +#endif + + template + struct MatcherBase : MatcherUntypedBase, MatcherMethod { + + MatchAllOf operator && ( MatcherBase const& other ) const; + MatchAnyOf operator || ( MatcherBase const& other ) const; + MatchNotOf operator ! () const; + }; + + template + struct MatchAllOf : MatcherBase { + bool match( ArgT const& arg ) const override { + for( auto matcher : m_matchers ) { + if (!matcher->match(arg)) + return false; + } + return true; + } + std::string describe() const override { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + bool first = true; + for( auto matcher : m_matchers ) { + if( first ) + first = false; + else + description += " and "; + description += matcher->toString(); + } + description += " )"; + return description; + } + + MatchAllOf& operator && ( MatcherBase const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector const*> m_matchers; + }; + template + struct MatchAnyOf : MatcherBase { + + bool match( ArgT const& arg ) const override { + for( auto matcher : m_matchers ) { + if (matcher->match(arg)) + return true; + } + return false; + } + std::string describe() const override { + std::string description; + description.reserve( 4 + m_matchers.size()*32 ); + description += "( "; + bool first = true; + for( auto matcher : m_matchers ) { + if( first ) + first = false; + else + description += " or "; + description += matcher->toString(); + } + description += " )"; + return description; + } + + MatchAnyOf& operator || ( MatcherBase const& other ) { + m_matchers.push_back( &other ); + return *this; + } + + std::vector const*> m_matchers; + }; + + template + struct MatchNotOf : MatcherBase { + + MatchNotOf( MatcherBase const& underlyingMatcher ) : m_underlyingMatcher( underlyingMatcher ) {} + + bool match( ArgT const& arg ) const override { + return !m_underlyingMatcher.match( arg ); + } + + std::string describe() const override { + return "not " + m_underlyingMatcher.toString(); + } + MatcherBase const& m_underlyingMatcher; + }; + + template + MatchAllOf MatcherBase::operator && ( MatcherBase const& other ) const { + return MatchAllOf() && *this && other; + } + template + MatchAnyOf MatcherBase::operator || ( MatcherBase const& other ) const { + return MatchAnyOf() || *this || other; + } + template + MatchNotOf MatcherBase::operator ! () const { + return MatchNotOf( *this ); + } + + } // namespace Impl + +} // namespace Matchers + +using namespace Matchers; +using Matchers::Impl::MatcherBase; + +} // namespace Catch + +// end catch_matchers.h +// start catch_matchers_floating.h + +#include +#include + +namespace Catch { +namespace Matchers { + + namespace Floating { + + enum class FloatingPointKind : uint8_t; + + struct WithinAbsMatcher : MatcherBase { + WithinAbsMatcher(double target, double margin); + bool match(double const& matchee) const override; + std::string describe() const override; + private: + double m_target; + double m_margin; + }; + + struct WithinUlpsMatcher : MatcherBase { + WithinUlpsMatcher(double target, int ulps, FloatingPointKind baseType); + bool match(double const& matchee) const override; + std::string describe() const override; + private: + double m_target; + int m_ulps; + FloatingPointKind m_type; + }; + + } // namespace Floating + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + Floating::WithinUlpsMatcher WithinULP(double target, int maxUlpDiff); + Floating::WithinUlpsMatcher WithinULP(float target, int maxUlpDiff); + Floating::WithinAbsMatcher WithinAbs(double target, double margin); + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_floating.h +// start catch_matchers_generic.hpp + +#include +#include + +namespace Catch { +namespace Matchers { +namespace Generic { + +namespace Detail { + std::string finalizeDescription(const std::string& desc); +} + +template +class PredicateMatcher : public MatcherBase { + std::function m_predicate; + std::string m_description; +public: + + PredicateMatcher(std::function const& elem, std::string const& descr) + :m_predicate(std::move(elem)), + m_description(Detail::finalizeDescription(descr)) + {} + + bool match( T const& item ) const override { + return m_predicate(item); + } + + std::string describe() const override { + return m_description; + } +}; + +} // namespace Generic + + // The following functions create the actual matcher objects. + // The user has to explicitly specify type to the function, because + // infering std::function is hard (but possible) and + // requires a lot of TMP. + template + Generic::PredicateMatcher Predicate(std::function const& predicate, std::string const& description = "") { + return Generic::PredicateMatcher(predicate, description); + } + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_generic.hpp +// start catch_matchers_string.h + +#include + +namespace Catch { +namespace Matchers { + + namespace StdString { + + struct CasedString + { + CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ); + std::string adjustString( std::string const& str ) const; + std::string caseSensitivitySuffix() const; + + CaseSensitive::Choice m_caseSensitivity; + std::string m_str; + }; + + struct StringMatcherBase : MatcherBase { + StringMatcherBase( std::string const& operation, CasedString const& comparator ); + std::string describe() const override; + + CasedString m_comparator; + std::string m_operation; + }; + + struct EqualsMatcher : StringMatcherBase { + EqualsMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct ContainsMatcher : StringMatcherBase { + ContainsMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct StartsWithMatcher : StringMatcherBase { + StartsWithMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + struct EndsWithMatcher : StringMatcherBase { + EndsWithMatcher( CasedString const& comparator ); + bool match( std::string const& source ) const override; + }; + + struct RegexMatcher : MatcherBase { + RegexMatcher( std::string regex, CaseSensitive::Choice caseSensitivity ); + bool match( std::string const& matchee ) const override; + std::string describe() const override; + + private: + std::string m_regex; + CaseSensitive::Choice m_caseSensitivity; + }; + + } // namespace StdString + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + StdString::RegexMatcher Matches( std::string const& regex, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes ); + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_string.h +// start catch_matchers_vector.h + +#include + +namespace Catch { +namespace Matchers { + + namespace Vector { + namespace Detail { + template + size_t count(InputIterator first, InputIterator last, T const& item) { + size_t cnt = 0; + for (; first != last; ++first) { + if (*first == item) { + ++cnt; + } + } + return cnt; + } + template + bool contains(InputIterator first, InputIterator last, T const& item) { + for (; first != last; ++first) { + if (*first == item) { + return true; + } + } + return false; + } + } + + template + struct ContainsElementMatcher : MatcherBase> { + + ContainsElementMatcher(T const &comparator) : m_comparator( comparator) {} + + bool match(std::vector const &v) const override { + for (auto const& el : v) { + if (el == m_comparator) { + return true; + } + } + return false; + } + + std::string describe() const override { + return "Contains: " + ::Catch::Detail::stringify( m_comparator ); + } + + T const& m_comparator; + }; + + template + struct ContainsMatcher : MatcherBase> { + + ContainsMatcher(std::vector const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector const &v) const override { + // !TBD: see note in EqualsMatcher + if (m_comparator.size() > v.size()) + return false; + for (auto const& comparator : m_comparator) { + auto present = false; + for (const auto& el : v) { + if (el == comparator) { + present = true; + break; + } + } + if (!present) { + return false; + } + } + return true; + } + std::string describe() const override { + return "Contains: " + ::Catch::Detail::stringify( m_comparator ); + } + + std::vector const& m_comparator; + }; + + template + struct EqualsMatcher : MatcherBase> { + + EqualsMatcher(std::vector const &comparator) : m_comparator( comparator ) {} + + bool match(std::vector const &v) const override { + // !TBD: This currently works if all elements can be compared using != + // - a more general approach would be via a compare template that defaults + // to using !=. but could be specialised for, e.g. std::vector etc + // - then just call that directly + if (m_comparator.size() != v.size()) + return false; + for (std::size_t i = 0; i < v.size(); ++i) + if (m_comparator[i] != v[i]) + return false; + return true; + } + std::string describe() const override { + return "Equals: " + ::Catch::Detail::stringify( m_comparator ); + } + std::vector const& m_comparator; + }; + + template + struct UnorderedEqualsMatcher : MatcherBase> { + UnorderedEqualsMatcher(std::vector const& target) : m_target(target) {} + bool match(std::vector const& vec) const override { + // Note: This is a reimplementation of std::is_permutation, + // because I don't want to include inside the common path + if (m_target.size() != vec.size()) { + return false; + } + auto lfirst = m_target.begin(), llast = m_target.end(); + auto rfirst = vec.begin(), rlast = vec.end(); + // Cut common prefix to optimize checking of permuted parts + while (lfirst != llast && *lfirst == *rfirst) { + ++lfirst; ++rfirst; + } + if (lfirst == llast) { + return true; + } + + for (auto mid = lfirst; mid != llast; ++mid) { + // Skip already counted items + if (Detail::contains(lfirst, mid, *mid)) { + continue; + } + size_t num_vec = Detail::count(rfirst, rlast, *mid); + if (num_vec == 0 || Detail::count(lfirst, llast, *mid) != num_vec) { + return false; + } + } + + return true; + } + + std::string describe() const override { + return "UnorderedEquals: " + ::Catch::Detail::stringify(m_target); + } + private: + std::vector const& m_target; + }; + + } // namespace Vector + + // The following functions create the actual matcher objects. + // This allows the types to be inferred + + template + Vector::ContainsMatcher Contains( std::vector const& comparator ) { + return Vector::ContainsMatcher( comparator ); + } + + template + Vector::ContainsElementMatcher VectorContains( T const& comparator ) { + return Vector::ContainsElementMatcher( comparator ); + } + + template + Vector::EqualsMatcher Equals( std::vector const& comparator ) { + return Vector::EqualsMatcher( comparator ); + } + + template + Vector::UnorderedEqualsMatcher UnorderedEquals(std::vector const& target) { + return Vector::UnorderedEqualsMatcher(target); + } + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_vector.h +namespace Catch { + + template + class MatchExpr : public ITransientExpression { + ArgT const& m_arg; + MatcherT m_matcher; + StringRef m_matcherString; + public: + MatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString ) + : ITransientExpression{ true, matcher.match( arg ) }, + m_arg( arg ), + m_matcher( matcher ), + m_matcherString( matcherString ) + {} + + void streamReconstructedExpression( std::ostream &os ) const override { + auto matcherAsString = m_matcher.toString(); + os << Catch::Detail::stringify( m_arg ) << ' '; + if( matcherAsString == Detail::unprintableString ) + os << m_matcherString; + else + os << matcherAsString; + } + }; + + using StringMatcher = Matchers::Impl::MatcherBase; + + void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString ); + + template + auto makeMatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString ) -> MatchExpr { + return MatchExpr( arg, matcher, matcherString ); + } + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CHECK_THAT( macroName, matcher, resultDisposition, arg ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(arg) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + INTERNAL_CATCH_TRY { \ + catchAssertionHandler.handleExpr( Catch::makeMatchExpr( arg, matcher, #matcher##_catch_sr ) ); \ + } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +/////////////////////////////////////////////////////////////////////////////// +#define INTERNAL_CATCH_THROWS_MATCHES( macroName, exceptionType, resultDisposition, matcher, ... ) \ + do { \ + Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(exceptionType) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \ + if( catchAssertionHandler.allowThrows() ) \ + try { \ + static_cast(__VA_ARGS__ ); \ + catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \ + } \ + catch( exceptionType const& ex ) { \ + catchAssertionHandler.handleExpr( Catch::makeMatchExpr( ex, matcher, #matcher##_catch_sr ) ); \ + } \ + catch( ... ) { \ + catchAssertionHandler.handleUnexpectedInflightException(); \ + } \ + else \ + catchAssertionHandler.handleThrowingCallSkipped(); \ + INTERNAL_CATCH_REACT( catchAssertionHandler ) \ + } while( false ) + +// end catch_capture_matchers.h +#endif +// start catch_generators.hpp + +// start catch_interfaces_generatortracker.h + + +#include + +namespace Catch { + + namespace Generators { + class GeneratorUntypedBase { + public: + GeneratorUntypedBase() = default; + virtual ~GeneratorUntypedBase(); + // Attempts to move the generator to the next element + // + // Returns true iff the move succeeded (and a valid element + // can be retrieved). + virtual bool next() = 0; + }; + using GeneratorBasePtr = std::unique_ptr; + + } // namespace Generators + + struct IGeneratorTracker { + virtual ~IGeneratorTracker(); + virtual auto hasGenerator() const -> bool = 0; + virtual auto getGenerator() const -> Generators::GeneratorBasePtr const& = 0; + virtual void setGenerator( Generators::GeneratorBasePtr&& generator ) = 0; + }; + +} // namespace Catch + +// end catch_interfaces_generatortracker.h +// start catch_enforce.h + +#include + +namespace Catch { +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + template + [[noreturn]] + void throw_exception(Ex const& e) { + throw e; + } +#else // ^^ Exceptions are enabled // Exceptions are disabled vv + [[noreturn]] + void throw_exception(std::exception const& e); +#endif +} // namespace Catch; + +#define CATCH_PREPARE_EXCEPTION( type, msg ) \ + type( ( Catch::ReusableStringStream() << msg ).str() ) +#define CATCH_INTERNAL_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::logic_error, CATCH_INTERNAL_LINEINFO << ": Internal Catch error: " << msg)) +#define CATCH_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::domain_error, msg )) +#define CATCH_RUNTIME_ERROR( msg ) \ + Catch::throw_exception(CATCH_PREPARE_EXCEPTION( std::runtime_error, msg )) +#define CATCH_ENFORCE( condition, msg ) \ + do{ if( !(condition) ) CATCH_ERROR( msg ); } while(false) + +// end catch_enforce.h +#include +#include +#include + +#include +#include + +namespace Catch { + +class GeneratorException : public std::exception { + const char* const m_msg = ""; + +public: + GeneratorException(const char* msg): + m_msg(msg) + {} + + const char* what() const noexcept override final; +}; + +namespace Generators { + + // !TBD move this into its own location? + namespace pf{ + template + std::unique_ptr make_unique( Args&&... args ) { + return std::unique_ptr(new T(std::forward(args)...)); + } + } + + template + struct IGenerator : GeneratorUntypedBase { + virtual ~IGenerator() = default; + + // Returns the current element of the generator + // + // \Precondition The generator is either freshly constructed, + // or the last call to `next()` returned true + virtual T const& get() const = 0; + using type = T; + }; + + template + class SingleValueGenerator final : public IGenerator { + T m_value; + public: + SingleValueGenerator(T const& value) : m_value( value ) {} + SingleValueGenerator(T&& value) : m_value(std::move(value)) {} + + T const& get() const override { + return m_value; + } + bool next() override { + return false; + } + }; + + template + class FixedValuesGenerator final : public IGenerator { + std::vector m_values; + size_t m_idx = 0; + public: + FixedValuesGenerator( std::initializer_list values ) : m_values( values ) {} + + T const& get() const override { + return m_values[m_idx]; + } + bool next() override { + ++m_idx; + return m_idx < m_values.size(); + } + }; + + template + class GeneratorWrapper final { + std::unique_ptr> m_generator; + public: + GeneratorWrapper(std::unique_ptr> generator): + m_generator(std::move(generator)) + {} + T const& get() const { + return m_generator->get(); + } + bool next() { + return m_generator->next(); + } + }; + + template + GeneratorWrapper value(T&& value) { + return GeneratorWrapper(pf::make_unique>(std::forward(value))); + } + template + GeneratorWrapper values(std::initializer_list values) { + return GeneratorWrapper(pf::make_unique>(values)); + } + + template + class Generators : public IGenerator { + std::vector> m_generators; + size_t m_current = 0; + + void populate(GeneratorWrapper&& generator) { + m_generators.emplace_back(std::move(generator)); + } + void populate(T&& val) { + m_generators.emplace_back(value(std::move(val))); + } + template + void populate(U&& val) { + populate(T(std::move(val))); + } + template + void populate(U&& valueOrGenerator, Gs... moreGenerators) { + populate(std::forward(valueOrGenerator)); + populate(std::forward(moreGenerators)...); + } + + public: + template + Generators(Gs... moreGenerators) { + m_generators.reserve(sizeof...(Gs)); + populate(std::forward(moreGenerators)...); + } + + T const& get() const override { + return m_generators[m_current].get(); + } + + bool next() override { + if (m_current >= m_generators.size()) { + return false; + } + const bool current_status = m_generators[m_current].next(); + if (!current_status) { + ++m_current; + } + return m_current < m_generators.size(); + } + }; + + template + GeneratorWrapper> table( std::initializer_list::type...>> tuples ) { + return values>( tuples ); + } + + // Tag type to signal that a generator sequence should convert arguments to a specific type + template + struct as {}; + + template + auto makeGenerators( GeneratorWrapper&& generator, Gs... moreGenerators ) -> Generators { + return Generators(std::move(generator), std::forward(moreGenerators)...); + } + template + auto makeGenerators( GeneratorWrapper&& generator ) -> Generators { + return Generators(std::move(generator)); + } + template + auto makeGenerators( T&& val, Gs... moreGenerators ) -> Generators { + return makeGenerators( value( std::forward( val ) ), std::forward( moreGenerators )... ); + } + template + auto makeGenerators( as, U&& val, Gs... moreGenerators ) -> Generators { + return makeGenerators( value( T( std::forward( val ) ) ), std::forward( moreGenerators )... ); + } + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker&; + + template + // Note: The type after -> is weird, because VS2015 cannot parse + // the expression used in the typedef inside, when it is in + // return type. Yeah. + auto generate( SourceLineInfo const& lineInfo, L const& generatorExpression ) -> decltype(std::declval().get()) { + using UnderlyingType = typename decltype(generatorExpression())::type; + + IGeneratorTracker& tracker = acquireGeneratorTracker( lineInfo ); + if (!tracker.hasGenerator()) { + tracker.setGenerator(pf::make_unique>(generatorExpression())); + } + + auto const& generator = static_cast const&>( *tracker.getGenerator() ); + return generator.get(); + } + +} // namespace Generators +} // namespace Catch + +#define GENERATE( ... ) \ + Catch::Generators::generate( CATCH_INTERNAL_LINEINFO, []{ using namespace Catch::Generators; return makeGenerators( __VA_ARGS__ ); } ) + +// end catch_generators.hpp +// start catch_generators_generic.hpp + +namespace Catch { +namespace Generators { + + template + class TakeGenerator : public IGenerator { + GeneratorWrapper m_generator; + size_t m_returned = 0; + size_t m_target; + public: + TakeGenerator(size_t target, GeneratorWrapper&& generator): + m_generator(std::move(generator)), + m_target(target) + { + assert(target != 0 && "Empty generators are not allowed"); + } + T const& get() const override { + return m_generator.get(); + } + bool next() override { + ++m_returned; + if (m_returned >= m_target) { + return false; + } + + const auto success = m_generator.next(); + // If the underlying generator does not contain enough values + // then we cut short as well + if (!success) { + m_returned = m_target; + } + return success; + } + }; + + template + GeneratorWrapper take(size_t target, GeneratorWrapper&& generator) { + return GeneratorWrapper(pf::make_unique>(target, std::move(generator))); + } + + template + class FilterGenerator : public IGenerator { + GeneratorWrapper m_generator; + Predicate m_predicate; + public: + template + FilterGenerator(P&& pred, GeneratorWrapper&& generator): + m_generator(std::move(generator)), + m_predicate(std::forward

(pred)) + { + if (!m_predicate(m_generator.get())) { + // It might happen that there are no values that pass the + // filter. In that case we throw an exception. + auto has_initial_value = next(); + if (!has_initial_value) { + Catch::throw_exception(GeneratorException("No valid value found in filtered generator")); + } + } + } + + T const& get() const override { + return m_generator.get(); + } + + bool next() override { + bool success = m_generator.next(); + if (!success) { + return false; + } + while (!m_predicate(m_generator.get()) && (success = m_generator.next()) == true); + return success; + } + }; + + template + GeneratorWrapper filter(Predicate&& pred, GeneratorWrapper&& generator) { + return GeneratorWrapper(std::unique_ptr>(pf::make_unique>(std::forward(pred), std::move(generator)))); + } + + template + class RepeatGenerator : public IGenerator { + GeneratorWrapper m_generator; + mutable std::vector m_returned; + size_t m_target_repeats; + size_t m_current_repeat = 0; + size_t m_repeat_index = 0; + public: + RepeatGenerator(size_t repeats, GeneratorWrapper&& generator): + m_generator(std::move(generator)), + m_target_repeats(repeats) + { + assert(m_target_repeats > 0 && "Repeat generator must repeat at least once"); + } + + T const& get() const override { + if (m_current_repeat == 0) { + m_returned.push_back(m_generator.get()); + return m_returned.back(); + } + return m_returned[m_repeat_index]; + } + + bool next() override { + // There are 2 basic cases: + // 1) We are still reading the generator + // 2) We are reading our own cache + + // In the first case, we need to poke the underlying generator. + // If it happily moves, we are left in that state, otherwise it is time to start reading from our cache + if (m_current_repeat == 0) { + const auto success = m_generator.next(); + if (!success) { + ++m_current_repeat; + } + return m_current_repeat < m_target_repeats; + } + + // In the second case, we need to move indices forward and check that we haven't run up against the end + ++m_repeat_index; + if (m_repeat_index == m_returned.size()) { + m_repeat_index = 0; + ++m_current_repeat; + } + return m_current_repeat < m_target_repeats; + } + }; + + template + GeneratorWrapper repeat(size_t repeats, GeneratorWrapper&& generator) { + return GeneratorWrapper(pf::make_unique>(repeats, std::move(generator))); + } + + template + class MapGenerator : public IGenerator { + // TBD: provide static assert for mapping function, for friendly error message + GeneratorWrapper m_generator; + Func m_function; + // To avoid returning dangling reference, we have to save the values + T m_cache; + public: + template + MapGenerator(F2&& function, GeneratorWrapper&& generator) : + m_generator(std::move(generator)), + m_function(std::forward(function)), + m_cache(m_function(m_generator.get())) + {} + + T const& get() const override { + return m_cache; + } + bool next() override { + const auto success = m_generator.next(); + if (success) { + m_cache = m_function(m_generator.get()); + } + return success; + } + }; + + template + GeneratorWrapper map(Func&& function, GeneratorWrapper&& generator) { + return GeneratorWrapper( + pf::make_unique>(std::forward(function), std::move(generator)) + ); + } + template + GeneratorWrapper map(Func&& function, GeneratorWrapper&& generator) { + return GeneratorWrapper( + pf::make_unique>(std::forward(function), std::move(generator)) + ); + } + + template + class ChunkGenerator final : public IGenerator> { + std::vector m_chunk; + size_t m_chunk_size; + GeneratorWrapper m_generator; + bool m_used_up = false; + public: + ChunkGenerator(size_t size, GeneratorWrapper generator) : + m_chunk_size(size), m_generator(std::move(generator)) + { + m_chunk.reserve(m_chunk_size); + m_chunk.push_back(m_generator.get()); + for (size_t i = 1; i < m_chunk_size; ++i) { + if (!m_generator.next()) { + Catch::throw_exception(GeneratorException("Not enough values to initialize the first chunk")); + } + m_chunk.push_back(m_generator.get()); + } + } + std::vector const& get() const override { + return m_chunk; + } + bool next() override { + m_chunk.clear(); + for (size_t idx = 0; idx < m_chunk_size; ++idx) { + if (!m_generator.next()) { + return false; + } + m_chunk.push_back(m_generator.get()); + } + return true; + } + }; + + template + GeneratorWrapper> chunk(size_t size, GeneratorWrapper&& generator) { + return GeneratorWrapper>( + pf::make_unique>(size, std::move(generator)) + ); + } + +} // namespace Generators +} // namespace Catch + +// end catch_generators_generic.hpp +// start catch_generators_specific.hpp + +// start catch_context.h + +#include + +namespace Catch { + + struct IResultCapture; + struct IRunner; + struct IConfig; + struct IMutableContext; + + using IConfigPtr = std::shared_ptr; + + struct IContext + { + virtual ~IContext(); + + virtual IResultCapture* getResultCapture() = 0; + virtual IRunner* getRunner() = 0; + virtual IConfigPtr const& getConfig() const = 0; + }; + + struct IMutableContext : IContext + { + virtual ~IMutableContext(); + virtual void setResultCapture( IResultCapture* resultCapture ) = 0; + virtual void setRunner( IRunner* runner ) = 0; + virtual void setConfig( IConfigPtr const& config ) = 0; + + private: + static IMutableContext *currentContext; + friend IMutableContext& getCurrentMutableContext(); + friend void cleanUpContext(); + static void createContext(); + }; + + inline IMutableContext& getCurrentMutableContext() + { + if( !IMutableContext::currentContext ) + IMutableContext::createContext(); + return *IMutableContext::currentContext; + } + + inline IContext& getCurrentContext() + { + return getCurrentMutableContext(); + } + + void cleanUpContext(); +} + +// end catch_context.h +// start catch_interfaces_config.h + +#include +#include +#include +#include + +namespace Catch { + + enum class Verbosity { + Quiet = 0, + Normal, + High + }; + + struct WarnAbout { enum What { + Nothing = 0x00, + NoAssertions = 0x01, + NoTests = 0x02 + }; }; + + struct ShowDurations { enum OrNot { + DefaultForReporter, + Always, + Never + }; }; + struct RunTests { enum InWhatOrder { + InDeclarationOrder, + InLexicographicalOrder, + InRandomOrder + }; }; + struct UseColour { enum YesOrNo { + Auto, + Yes, + No + }; }; + struct WaitForKeypress { enum When { + Never, + BeforeStart = 1, + BeforeExit = 2, + BeforeStartAndExit = BeforeStart | BeforeExit + }; }; + + class TestSpec; + + struct IConfig : NonCopyable { + + virtual ~IConfig(); + + virtual bool allowThrows() const = 0; + virtual std::ostream& stream() const = 0; + virtual std::string name() const = 0; + virtual bool includeSuccessfulResults() const = 0; + virtual bool shouldDebugBreak() const = 0; + virtual bool warnAboutMissingAssertions() const = 0; + virtual bool warnAboutNoTests() const = 0; + virtual int abortAfter() const = 0; + virtual bool showInvisibles() const = 0; + virtual ShowDurations::OrNot showDurations() const = 0; + virtual TestSpec const& testSpec() const = 0; + virtual bool hasTestFilters() const = 0; + virtual RunTests::InWhatOrder runOrder() const = 0; + virtual unsigned int rngSeed() const = 0; + virtual int benchmarkResolutionMultiple() const = 0; + virtual UseColour::YesOrNo useColour() const = 0; + virtual std::vector const& getSectionsToRun() const = 0; + virtual Verbosity verbosity() const = 0; + }; + + using IConfigPtr = std::shared_ptr; +} + +// end catch_interfaces_config.h +#include + +namespace Catch { +namespace Generators { + +template +class RandomFloatingGenerator final : public IGenerator { + // FIXME: What is the right seed? + std::minstd_rand m_rand; + std::uniform_real_distribution m_dist; + Float m_current_number; +public: + + RandomFloatingGenerator(Float a, Float b): + m_rand(getCurrentContext().getConfig()->rngSeed()), + m_dist(a, b) { + static_cast(next()); + } + + Float const& get() const override { + return m_current_number; + } + bool next() override { + m_current_number = m_dist(m_rand); + return true; + } +}; + +template +class RandomIntegerGenerator final : public IGenerator { + std::minstd_rand m_rand; + std::uniform_int_distribution m_dist; + Integer m_current_number; +public: + + RandomIntegerGenerator(Integer a, Integer b): + m_rand(getCurrentContext().getConfig()->rngSeed()), + m_dist(a, b) { + static_cast(next()); + } + + Integer const& get() const override { + return m_current_number; + } + bool next() override { + m_current_number = m_dist(m_rand); + return true; + } +}; + +// TODO: Ideally this would be also constrained against the various char types, +// but I don't expect users to run into that in practice. +template +typename std::enable_if::value && !std::is_same::value, +GeneratorWrapper>::type +random(T a, T b) { + return GeneratorWrapper( + pf::make_unique>(a, b) + ); +} + +template +typename std::enable_if::value, +GeneratorWrapper>::type +random(T a, T b) { + return GeneratorWrapper( + pf::make_unique>(a, b) + ); +} + +template +class RangeGenerator final : public IGenerator { + T m_current; + T m_end; + T m_step; + bool m_positive; + +public: + RangeGenerator(T const& start, T const& end, T const& step): + m_current(start), + m_end(end), + m_step(step), + m_positive(m_step > T(0)) + { + assert(m_current != m_end && "Range start and end cannot be equal"); + assert(m_step != T(0) && "Step size cannot be zero"); + assert(((m_positive && m_current <= m_end) || (!m_positive && m_current >= m_end)) && "Step moves away from end"); + } + + RangeGenerator(T const& start, T const& end): + RangeGenerator(start, end, (start < end) ? T(1) : T(-1)) + {} + + T const& get() const override { + return m_current; + } + + bool next() override { + m_current += m_step; + return (m_positive) ? (m_current < m_end) : (m_current > m_end); + } +}; + +template +GeneratorWrapper range(T const& start, T const& end, T const& step) { + static_assert(std::is_integral::value && !std::is_same::value, "Type must be an integer"); + return GeneratorWrapper(pf::make_unique>(start, end, step)); +} + +template +GeneratorWrapper range(T const& start, T const& end) { + static_assert(std::is_integral::value && !std::is_same::value, "Type must be an integer"); + return GeneratorWrapper(pf::make_unique>(start, end)); +} + +} // namespace Generators +} // namespace Catch + +// end catch_generators_specific.hpp + +// These files are included here so the single_include script doesn't put them +// in the conditionally compiled sections +// start catch_test_case_info.h + +#include +#include +#include + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +namespace Catch { + + struct ITestInvoker; + + struct TestCaseInfo { + enum SpecialProperties{ + None = 0, + IsHidden = 1 << 1, + ShouldFail = 1 << 2, + MayFail = 1 << 3, + Throws = 1 << 4, + NonPortable = 1 << 5, + Benchmark = 1 << 6 + }; + + TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::vector const& _tags, + SourceLineInfo const& _lineInfo ); + + friend void setTags( TestCaseInfo& testCaseInfo, std::vector tags ); + + bool isHidden() const; + bool throws() const; + bool okToFail() const; + bool expectedToFail() const; + + std::string tagsAsString() const; + + std::string name; + std::string className; + std::string description; + std::vector tags; + std::vector lcaseTags; + SourceLineInfo lineInfo; + SpecialProperties properties; + }; + + class TestCase : public TestCaseInfo { + public: + + TestCase( ITestInvoker* testCase, TestCaseInfo&& info ); + + TestCase withName( std::string const& _newName ) const; + + void invoke() const; + + TestCaseInfo const& getTestCaseInfo() const; + + bool operator == ( TestCase const& other ) const; + bool operator < ( TestCase const& other ) const; + + private: + std::shared_ptr test; + }; + + TestCase makeTestCase( ITestInvoker* testCase, + std::string const& className, + NameAndTags const& nameAndTags, + SourceLineInfo const& lineInfo ); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_case_info.h +// start catch_interfaces_runner.h + +namespace Catch { + + struct IRunner { + virtual ~IRunner(); + virtual bool aborting() const = 0; + }; +} + +// end catch_interfaces_runner.h + +#ifdef __OBJC__ +// start catch_objc.hpp + +#import + +#include + +// NB. Any general catch headers included here must be included +// in catch.hpp first to make sure they are included by the single +// header for non obj-usage + +/////////////////////////////////////////////////////////////////////////////// +// This protocol is really only here for (self) documenting purposes, since +// all its methods are optional. +@protocol OcFixture + +@optional + +-(void) setUp; +-(void) tearDown; + +@end + +namespace Catch { + + class OcMethod : public ITestInvoker { + + public: + OcMethod( Class cls, SEL sel ) : m_cls( cls ), m_sel( sel ) {} + + virtual void invoke() const { + id obj = [[m_cls alloc] init]; + + performOptionalSelector( obj, @selector(setUp) ); + performOptionalSelector( obj, m_sel ); + performOptionalSelector( obj, @selector(tearDown) ); + + arcSafeRelease( obj ); + } + private: + virtual ~OcMethod() {} + + Class m_cls; + SEL m_sel; + }; + + namespace Detail{ + + inline std::string getAnnotation( Class cls, + std::string const& annotationName, + std::string const& testCaseName ) { + NSString* selStr = [[NSString alloc] initWithFormat:@"Catch_%s_%s", annotationName.c_str(), testCaseName.c_str()]; + SEL sel = NSSelectorFromString( selStr ); + arcSafeRelease( selStr ); + id value = performOptionalSelector( cls, sel ); + if( value ) + return [(NSString*)value UTF8String]; + return ""; + } + } + + inline std::size_t registerTestMethods() { + std::size_t noTestMethods = 0; + int noClasses = objc_getClassList( nullptr, 0 ); + + Class* classes = (CATCH_UNSAFE_UNRETAINED Class *)malloc( sizeof(Class) * noClasses); + objc_getClassList( classes, noClasses ); + + for( int c = 0; c < noClasses; c++ ) { + Class cls = classes[c]; + { + u_int count; + Method* methods = class_copyMethodList( cls, &count ); + for( u_int m = 0; m < count ; m++ ) { + SEL selector = method_getName(methods[m]); + std::string methodName = sel_getName(selector); + if( startsWith( methodName, "Catch_TestCase_" ) ) { + std::string testCaseName = methodName.substr( 15 ); + std::string name = Detail::getAnnotation( cls, "Name", testCaseName ); + std::string desc = Detail::getAnnotation( cls, "Description", testCaseName ); + const char* className = class_getName( cls ); + + getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, NameAndTags( name.c_str(), desc.c_str() ), SourceLineInfo("",0) ) ); + noTestMethods++; + } + } + free(methods); + } + } + return noTestMethods; + } + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) + + namespace Matchers { + namespace Impl { + namespace NSStringMatchers { + + struct StringHolder : MatcherBase{ + StringHolder( NSString* substr ) : m_substr( [substr copy] ){} + StringHolder( StringHolder const& other ) : m_substr( [other.m_substr copy] ){} + StringHolder() { + arcSafeRelease( m_substr ); + } + + bool match( NSString* arg ) const override { + return false; + } + + NSString* CATCH_ARC_STRONG m_substr; + }; + + struct Equals : StringHolder { + Equals( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str isEqualToString:m_substr]; + } + + std::string describe() const override { + return "equals string: " + Catch::Detail::stringify( m_substr ); + } + }; + + struct Contains : StringHolder { + Contains( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location != NSNotFound; + } + + std::string describe() const override { + return "contains string: " + Catch::Detail::stringify( m_substr ); + } + }; + + struct StartsWith : StringHolder { + StartsWith( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == 0; + } + + std::string describe() const override { + return "starts with: " + Catch::Detail::stringify( m_substr ); + } + }; + struct EndsWith : StringHolder { + EndsWith( NSString* substr ) : StringHolder( substr ){} + + bool match( NSString* str ) const override { + return (str != nil || m_substr == nil ) && + [str rangeOfString:m_substr].location == [str length] - [m_substr length]; + } + + std::string describe() const override { + return "ends with: " + Catch::Detail::stringify( m_substr ); + } + }; + + } // namespace NSStringMatchers + } // namespace Impl + + inline Impl::NSStringMatchers::Equals + Equals( NSString* substr ){ return Impl::NSStringMatchers::Equals( substr ); } + + inline Impl::NSStringMatchers::Contains + Contains( NSString* substr ){ return Impl::NSStringMatchers::Contains( substr ); } + + inline Impl::NSStringMatchers::StartsWith + StartsWith( NSString* substr ){ return Impl::NSStringMatchers::StartsWith( substr ); } + + inline Impl::NSStringMatchers::EndsWith + EndsWith( NSString* substr ){ return Impl::NSStringMatchers::EndsWith( substr ); } + + } // namespace Matchers + + using namespace Matchers; + +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +} // namespace Catch + +/////////////////////////////////////////////////////////////////////////////// +#define OC_MAKE_UNIQUE_NAME( root, uniqueSuffix ) root##uniqueSuffix +#define OC_TEST_CASE2( name, desc, uniqueSuffix ) \ ++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Name_test_, uniqueSuffix ) \ +{ \ +return @ name; \ +} \ ++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Description_test_, uniqueSuffix ) \ +{ \ +return @ desc; \ +} \ +-(void) OC_MAKE_UNIQUE_NAME( Catch_TestCase_test_, uniqueSuffix ) + +#define OC_TEST_CASE( name, desc ) OC_TEST_CASE2( name, desc, __LINE__ ) + +// end catch_objc.hpp +#endif + +#ifdef CATCH_CONFIG_EXTERNAL_INTERFACES +// start catch_external_interfaces.h + +// start catch_reporter_bases.hpp + +// start catch_interfaces_reporter.h + +// start catch_config.hpp + +// start catch_test_spec_parser.h + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// start catch_test_spec.h + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpadded" +#endif + +// start catch_wildcard_pattern.h + +namespace Catch +{ + class WildcardPattern { + enum WildcardPosition { + NoWildcard = 0, + WildcardAtStart = 1, + WildcardAtEnd = 2, + WildcardAtBothEnds = WildcardAtStart | WildcardAtEnd + }; + + public: + + WildcardPattern( std::string const& pattern, CaseSensitive::Choice caseSensitivity ); + virtual ~WildcardPattern() = default; + virtual bool matches( std::string const& str ) const; + + private: + std::string adjustCase( std::string const& str ) const; + CaseSensitive::Choice m_caseSensitivity; + WildcardPosition m_wildcard = NoWildcard; + std::string m_pattern; + }; +} + +// end catch_wildcard_pattern.h +#include +#include +#include + +namespace Catch { + + class TestSpec { + struct Pattern { + virtual ~Pattern(); + virtual bool matches( TestCaseInfo const& testCase ) const = 0; + }; + using PatternPtr = std::shared_ptr; + + class NamePattern : public Pattern { + public: + NamePattern( std::string const& name ); + virtual ~NamePattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + WildcardPattern m_wildcardPattern; + }; + + class TagPattern : public Pattern { + public: + TagPattern( std::string const& tag ); + virtual ~TagPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + std::string m_tag; + }; + + class ExcludedPattern : public Pattern { + public: + ExcludedPattern( PatternPtr const& underlyingPattern ); + virtual ~ExcludedPattern(); + virtual bool matches( TestCaseInfo const& testCase ) const override; + private: + PatternPtr m_underlyingPattern; + }; + + struct Filter { + std::vector m_patterns; + + bool matches( TestCaseInfo const& testCase ) const; + }; + + public: + bool hasFilters() const; + bool matches( TestCaseInfo const& testCase ) const; + + private: + std::vector m_filters; + + friend class TestSpecParser; + }; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_spec.h +// start catch_interfaces_tag_alias_registry.h + +#include + +namespace Catch { + + struct TagAlias; + + struct ITagAliasRegistry { + virtual ~ITagAliasRegistry(); + // Nullptr if not present + virtual TagAlias const* find( std::string const& alias ) const = 0; + virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const = 0; + + static ITagAliasRegistry const& get(); + }; + +} // end namespace Catch + +// end catch_interfaces_tag_alias_registry.h +namespace Catch { + + class TestSpecParser { + enum Mode{ None, Name, QuotedName, Tag, EscapedName }; + Mode m_mode = None; + bool m_exclusion = false; + std::size_t m_start = std::string::npos, m_pos = 0; + std::string m_arg; + std::vector m_escapeChars; + TestSpec::Filter m_currentFilter; + TestSpec m_testSpec; + ITagAliasRegistry const* m_tagAliases = nullptr; + + public: + TestSpecParser( ITagAliasRegistry const& tagAliases ); + + TestSpecParser& parse( std::string const& arg ); + TestSpec testSpec(); + + private: + void visitChar( char c ); + void startNewMode( Mode mode, std::size_t start ); + void escape(); + std::string subString() const; + + template + void addPattern() { + std::string token = subString(); + for( std::size_t i = 0; i < m_escapeChars.size(); ++i ) + token = token.substr( 0, m_escapeChars[i]-m_start-i ) + token.substr( m_escapeChars[i]-m_start-i+1 ); + m_escapeChars.clear(); + if( startsWith( token, "exclude:" ) ) { + m_exclusion = true; + token = token.substr( 8 ); + } + if( !token.empty() ) { + TestSpec::PatternPtr pattern = std::make_shared( token ); + if( m_exclusion ) + pattern = std::make_shared( pattern ); + m_currentFilter.m_patterns.push_back( pattern ); + } + m_exclusion = false; + m_mode = None; + } + + void addFilter(); + }; + TestSpec parseTestSpec( std::string const& arg ); + +} // namespace Catch + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_test_spec_parser.h +// Libstdc++ doesn't like incomplete classes for unique_ptr + +#include +#include +#include + +#ifndef CATCH_CONFIG_CONSOLE_WIDTH +#define CATCH_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { + + struct IStream; + + struct ConfigData { + bool listTests = false; + bool listTags = false; + bool listReporters = false; + bool listTestNamesOnly = false; + + bool showSuccessfulTests = false; + bool shouldDebugBreak = false; + bool noThrow = false; + bool showHelp = false; + bool showInvisibles = false; + bool filenamesAsTags = false; + bool libIdentify = false; + + int abortAfter = -1; + unsigned int rngSeed = 0; + int benchmarkResolutionMultiple = 100; + + Verbosity verbosity = Verbosity::Normal; + WarnAbout::What warnings = WarnAbout::Nothing; + ShowDurations::OrNot showDurations = ShowDurations::DefaultForReporter; + RunTests::InWhatOrder runOrder = RunTests::InDeclarationOrder; + UseColour::YesOrNo useColour = UseColour::Auto; + WaitForKeypress::When waitForKeypress = WaitForKeypress::Never; + + std::string outputFilename; + std::string name; + std::string processName; +#ifndef CATCH_CONFIG_DEFAULT_REPORTER +#define CATCH_CONFIG_DEFAULT_REPORTER "console" +#endif + std::string reporterName = CATCH_CONFIG_DEFAULT_REPORTER; +#undef CATCH_CONFIG_DEFAULT_REPORTER + + std::vector testsOrTags; + std::vector sectionsToRun; + }; + + class Config : public IConfig { + public: + + Config() = default; + Config( ConfigData const& data ); + virtual ~Config() = default; + + std::string const& getFilename() const; + + bool listTests() const; + bool listTestNamesOnly() const; + bool listTags() const; + bool listReporters() const; + + std::string getProcessName() const; + std::string const& getReporterName() const; + + std::vector const& getTestsOrTags() const; + std::vector const& getSectionsToRun() const override; + + virtual TestSpec const& testSpec() const override; + bool hasTestFilters() const override; + + bool showHelp() const; + + // IConfig interface + bool allowThrows() const override; + std::ostream& stream() const override; + std::string name() const override; + bool includeSuccessfulResults() const override; + bool warnAboutMissingAssertions() const override; + bool warnAboutNoTests() const override; + ShowDurations::OrNot showDurations() const override; + RunTests::InWhatOrder runOrder() const override; + unsigned int rngSeed() const override; + int benchmarkResolutionMultiple() const override; + UseColour::YesOrNo useColour() const override; + bool shouldDebugBreak() const override; + int abortAfter() const override; + bool showInvisibles() const override; + Verbosity verbosity() const override; + + private: + + IStream const* openStream(); + ConfigData m_data; + + std::unique_ptr m_stream; + TestSpec m_testSpec; + bool m_hasTestFilters = false; + }; + +} // end namespace Catch + +// end catch_config.hpp +// start catch_assertionresult.h + +#include + +namespace Catch { + + struct AssertionResultData + { + AssertionResultData() = delete; + + AssertionResultData( ResultWas::OfType _resultType, LazyExpression const& _lazyExpression ); + + std::string message; + mutable std::string reconstructedExpression; + LazyExpression lazyExpression; + ResultWas::OfType resultType; + + std::string reconstructExpression() const; + }; + + class AssertionResult { + public: + AssertionResult() = delete; + AssertionResult( AssertionInfo const& info, AssertionResultData const& data ); + + bool isOk() const; + bool succeeded() const; + ResultWas::OfType getResultType() const; + bool hasExpression() const; + bool hasMessage() const; + std::string getExpression() const; + std::string getExpressionInMacro() const; + bool hasExpandedExpression() const; + std::string getExpandedExpression() const; + std::string getMessage() const; + SourceLineInfo getSourceInfo() const; + StringRef getTestMacroName() const; + + //protected: + AssertionInfo m_info; + AssertionResultData m_resultData; + }; + +} // end namespace Catch + +// end catch_assertionresult.h +// start catch_option.hpp + +namespace Catch { + + // An optional type + template + class Option { + public: + Option() : nullableValue( nullptr ) {} + Option( T const& _value ) + : nullableValue( new( storage ) T( _value ) ) + {} + Option( Option const& _other ) + : nullableValue( _other ? new( storage ) T( *_other ) : nullptr ) + {} + + ~Option() { + reset(); + } + + Option& operator= ( Option const& _other ) { + if( &_other != this ) { + reset(); + if( _other ) + nullableValue = new( storage ) T( *_other ); + } + return *this; + } + Option& operator = ( T const& _value ) { + reset(); + nullableValue = new( storage ) T( _value ); + return *this; + } + + void reset() { + if( nullableValue ) + nullableValue->~T(); + nullableValue = nullptr; + } + + T& operator*() { return *nullableValue; } + T const& operator*() const { return *nullableValue; } + T* operator->() { return nullableValue; } + const T* operator->() const { return nullableValue; } + + T valueOr( T const& defaultValue ) const { + return nullableValue ? *nullableValue : defaultValue; + } + + bool some() const { return nullableValue != nullptr; } + bool none() const { return nullableValue == nullptr; } + + bool operator !() const { return nullableValue == nullptr; } + explicit operator bool() const { + return some(); + } + + private: + T *nullableValue; + alignas(alignof(T)) char storage[sizeof(T)]; + }; + +} // end namespace Catch + +// end catch_option.hpp +#include +#include +#include +#include +#include + +namespace Catch { + + struct ReporterConfig { + explicit ReporterConfig( IConfigPtr const& _fullConfig ); + + ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream ); + + std::ostream& stream() const; + IConfigPtr fullConfig() const; + + private: + std::ostream* m_stream; + IConfigPtr m_fullConfig; + }; + + struct ReporterPreferences { + bool shouldRedirectStdOut = false; + bool shouldReportAllAssertions = false; + }; + + template + struct LazyStat : Option { + LazyStat& operator=( T const& _value ) { + Option::operator=( _value ); + used = false; + return *this; + } + void reset() { + Option::reset(); + used = false; + } + bool used = false; + }; + + struct TestRunInfo { + TestRunInfo( std::string const& _name ); + std::string name; + }; + struct GroupInfo { + GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ); + + std::string name; + std::size_t groupIndex; + std::size_t groupsCounts; + }; + + struct AssertionStats { + AssertionStats( AssertionResult const& _assertionResult, + std::vector const& _infoMessages, + Totals const& _totals ); + + AssertionStats( AssertionStats const& ) = default; + AssertionStats( AssertionStats && ) = default; + AssertionStats& operator = ( AssertionStats const& ) = delete; + AssertionStats& operator = ( AssertionStats && ) = delete; + virtual ~AssertionStats(); + + AssertionResult assertionResult; + std::vector infoMessages; + Totals totals; + }; + + struct SectionStats { + SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ); + SectionStats( SectionStats const& ) = default; + SectionStats( SectionStats && ) = default; + SectionStats& operator = ( SectionStats const& ) = default; + SectionStats& operator = ( SectionStats && ) = default; + virtual ~SectionStats(); + + SectionInfo sectionInfo; + Counts assertions; + double durationInSeconds; + bool missingAssertions; + }; + + struct TestCaseStats { + TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ); + + TestCaseStats( TestCaseStats const& ) = default; + TestCaseStats( TestCaseStats && ) = default; + TestCaseStats& operator = ( TestCaseStats const& ) = default; + TestCaseStats& operator = ( TestCaseStats && ) = default; + virtual ~TestCaseStats(); + + TestCaseInfo testInfo; + Totals totals; + std::string stdOut; + std::string stdErr; + bool aborting; + }; + + struct TestGroupStats { + TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ); + TestGroupStats( GroupInfo const& _groupInfo ); + + TestGroupStats( TestGroupStats const& ) = default; + TestGroupStats( TestGroupStats && ) = default; + TestGroupStats& operator = ( TestGroupStats const& ) = default; + TestGroupStats& operator = ( TestGroupStats && ) = default; + virtual ~TestGroupStats(); + + GroupInfo groupInfo; + Totals totals; + bool aborting; + }; + + struct TestRunStats { + TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ); + + TestRunStats( TestRunStats const& ) = default; + TestRunStats( TestRunStats && ) = default; + TestRunStats& operator = ( TestRunStats const& ) = default; + TestRunStats& operator = ( TestRunStats && ) = default; + virtual ~TestRunStats(); + + TestRunInfo runInfo; + Totals totals; + bool aborting; + }; + + struct BenchmarkInfo { + std::string name; + }; + struct BenchmarkStats { + BenchmarkInfo info; + std::size_t iterations; + uint64_t elapsedTimeInNanoseconds; + }; + + struct IStreamingReporter { + virtual ~IStreamingReporter() = default; + + // Implementing class must also provide the following static methods: + // static std::string getDescription(); + // static std::set getSupportedVerbosities() + + virtual ReporterPreferences getPreferences() const = 0; + + virtual void noMatchingTestCases( std::string const& spec ) = 0; + + virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0; + virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0; + + virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0; + virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0; + + // *** experimental *** + virtual void benchmarkStarting( BenchmarkInfo const& ) {} + + virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0; + + // The return value indicates if the messages buffer should be cleared: + virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0; + + // *** experimental *** + virtual void benchmarkEnded( BenchmarkStats const& ) {} + + virtual void sectionEnded( SectionStats const& sectionStats ) = 0; + virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0; + virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0; + virtual void testRunEnded( TestRunStats const& testRunStats ) = 0; + + virtual void skipTest( TestCaseInfo const& testInfo ) = 0; + + // Default empty implementation provided + virtual void fatalErrorEncountered( StringRef name ); + + virtual bool isMulti() const; + }; + using IStreamingReporterPtr = std::unique_ptr; + + struct IReporterFactory { + virtual ~IReporterFactory(); + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const = 0; + virtual std::string getDescription() const = 0; + }; + using IReporterFactoryPtr = std::shared_ptr; + + struct IReporterRegistry { + using FactoryMap = std::map; + using Listeners = std::vector; + + virtual ~IReporterRegistry(); + virtual IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const = 0; + virtual FactoryMap const& getFactories() const = 0; + virtual Listeners const& getListeners() const = 0; + }; + +} // end namespace Catch + +// end catch_interfaces_reporter.h +#include +#include +#include +#include +#include +#include +#include + +namespace Catch { + void prepareExpandedExpression(AssertionResult& result); + + // Returns double formatted as %.3f (format expected on output) + std::string getFormattedDuration( double duration ); + + template + struct StreamingReporterBase : IStreamingReporter { + + StreamingReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) ) + CATCH_ERROR( "Verbosity level not supported by this reporter" ); + } + + ReporterPreferences getPreferences() const override { + return m_reporterPrefs; + } + + static std::set getSupportedVerbosities() { + return { Verbosity::Normal }; + } + + ~StreamingReporterBase() override = default; + + void noMatchingTestCases(std::string const&) override {} + + void testRunStarting(TestRunInfo const& _testRunInfo) override { + currentTestRunInfo = _testRunInfo; + } + void testGroupStarting(GroupInfo const& _groupInfo) override { + currentGroupInfo = _groupInfo; + } + + void testCaseStarting(TestCaseInfo const& _testInfo) override { + currentTestCaseInfo = _testInfo; + } + void sectionStarting(SectionInfo const& _sectionInfo) override { + m_sectionStack.push_back(_sectionInfo); + } + + void sectionEnded(SectionStats const& /* _sectionStats */) override { + m_sectionStack.pop_back(); + } + void testCaseEnded(TestCaseStats const& /* _testCaseStats */) override { + currentTestCaseInfo.reset(); + } + void testGroupEnded(TestGroupStats const& /* _testGroupStats */) override { + currentGroupInfo.reset(); + } + void testRunEnded(TestRunStats const& /* _testRunStats */) override { + currentTestCaseInfo.reset(); + currentGroupInfo.reset(); + currentTestRunInfo.reset(); + } + + void skipTest(TestCaseInfo const&) override { + // Don't do anything with this by default. + // It can optionally be overridden in the derived class. + } + + IConfigPtr m_config; + std::ostream& stream; + + LazyStat currentTestRunInfo; + LazyStat currentGroupInfo; + LazyStat currentTestCaseInfo; + + std::vector m_sectionStack; + ReporterPreferences m_reporterPrefs; + }; + + template + struct CumulativeReporterBase : IStreamingReporter { + template + struct Node { + explicit Node( T const& _value ) : value( _value ) {} + virtual ~Node() {} + + using ChildNodes = std::vector>; + T value; + ChildNodes children; + }; + struct SectionNode { + explicit SectionNode(SectionStats const& _stats) : stats(_stats) {} + virtual ~SectionNode() = default; + + bool operator == (SectionNode const& other) const { + return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo; + } + bool operator == (std::shared_ptr const& other) const { + return operator==(*other); + } + + SectionStats stats; + using ChildSections = std::vector>; + using Assertions = std::vector; + ChildSections childSections; + Assertions assertions; + std::string stdOut; + std::string stdErr; + }; + + struct BySectionInfo { + BySectionInfo( SectionInfo const& other ) : m_other( other ) {} + BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {} + bool operator() (std::shared_ptr const& node) const { + return ((node->stats.sectionInfo.name == m_other.name) && + (node->stats.sectionInfo.lineInfo == m_other.lineInfo)); + } + void operator=(BySectionInfo const&) = delete; + + private: + SectionInfo const& m_other; + }; + + using TestCaseNode = Node; + using TestGroupNode = Node; + using TestRunNode = Node; + + CumulativeReporterBase( ReporterConfig const& _config ) + : m_config( _config.fullConfig() ), + stream( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = false; + if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) ) + CATCH_ERROR( "Verbosity level not supported by this reporter" ); + } + ~CumulativeReporterBase() override = default; + + ReporterPreferences getPreferences() const override { + return m_reporterPrefs; + } + + static std::set getSupportedVerbosities() { + return { Verbosity::Normal }; + } + + void testRunStarting( TestRunInfo const& ) override {} + void testGroupStarting( GroupInfo const& ) override {} + + void testCaseStarting( TestCaseInfo const& ) override {} + + void sectionStarting( SectionInfo const& sectionInfo ) override { + SectionStats incompleteStats( sectionInfo, Counts(), 0, false ); + std::shared_ptr node; + if( m_sectionStack.empty() ) { + if( !m_rootSection ) + m_rootSection = std::make_shared( incompleteStats ); + node = m_rootSection; + } + else { + SectionNode& parentNode = *m_sectionStack.back(); + auto it = + std::find_if( parentNode.childSections.begin(), + parentNode.childSections.end(), + BySectionInfo( sectionInfo ) ); + if( it == parentNode.childSections.end() ) { + node = std::make_shared( incompleteStats ); + parentNode.childSections.push_back( node ); + } + else + node = *it; + } + m_sectionStack.push_back( node ); + m_deepestSection = std::move(node); + } + + void assertionStarting(AssertionInfo const&) override {} + + bool assertionEnded(AssertionStats const& assertionStats) override { + assert(!m_sectionStack.empty()); + // AssertionResult holds a pointer to a temporary DecomposedExpression, + // which getExpandedExpression() calls to build the expression string. + // Our section stack copy of the assertionResult will likely outlive the + // temporary, so it must be expanded or discarded now to avoid calling + // a destroyed object later. + prepareExpandedExpression(const_cast( assertionStats.assertionResult ) ); + SectionNode& sectionNode = *m_sectionStack.back(); + sectionNode.assertions.push_back(assertionStats); + return true; + } + void sectionEnded(SectionStats const& sectionStats) override { + assert(!m_sectionStack.empty()); + SectionNode& node = *m_sectionStack.back(); + node.stats = sectionStats; + m_sectionStack.pop_back(); + } + void testCaseEnded(TestCaseStats const& testCaseStats) override { + auto node = std::make_shared(testCaseStats); + assert(m_sectionStack.size() == 0); + node->children.push_back(m_rootSection); + m_testCases.push_back(node); + m_rootSection.reset(); + + assert(m_deepestSection); + m_deepestSection->stdOut = testCaseStats.stdOut; + m_deepestSection->stdErr = testCaseStats.stdErr; + } + void testGroupEnded(TestGroupStats const& testGroupStats) override { + auto node = std::make_shared(testGroupStats); + node->children.swap(m_testCases); + m_testGroups.push_back(node); + } + void testRunEnded(TestRunStats const& testRunStats) override { + auto node = std::make_shared(testRunStats); + node->children.swap(m_testGroups); + m_testRuns.push_back(node); + testRunEndedCumulative(); + } + virtual void testRunEndedCumulative() = 0; + + void skipTest(TestCaseInfo const&) override {} + + IConfigPtr m_config; + std::ostream& stream; + std::vector m_assertions; + std::vector>> m_sections; + std::vector> m_testCases; + std::vector> m_testGroups; + + std::vector> m_testRuns; + + std::shared_ptr m_rootSection; + std::shared_ptr m_deepestSection; + std::vector> m_sectionStack; + ReporterPreferences m_reporterPrefs; + }; + + template + char const* getLineOfChars() { + static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0}; + if( !*line ) { + std::memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 ); + line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0; + } + return line; + } + + struct TestEventListenerBase : StreamingReporterBase { + TestEventListenerBase( ReporterConfig const& _config ); + + static std::set getSupportedVerbosities(); + + void assertionStarting(AssertionInfo const&) override; + bool assertionEnded(AssertionStats const&) override; + }; + +} // end namespace Catch + +// end catch_reporter_bases.hpp +// start catch_console_colour.h + +namespace Catch { + + struct Colour { + enum Code { + None = 0, + + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White, + BrightYellow = Bright | Yellow, + + // By intention + FileName = LightGrey, + Warning = BrightYellow, + ResultError = BrightRed, + ResultSuccess = BrightGreen, + ResultExpectedFailure = Warning, + + Error = BrightRed, + Success = Green, + + OriginalExpression = Cyan, + ReconstructedExpression = BrightYellow, + + SecondaryText = LightGrey, + Headers = White + }; + + // Use constructed object for RAII guard + Colour( Code _colourCode ); + Colour( Colour&& other ) noexcept; + Colour& operator=( Colour&& other ) noexcept; + ~Colour(); + + // Use static method for one-shot changes + static void use( Code _colourCode ); + + private: + bool m_moved = false; + }; + + std::ostream& operator << ( std::ostream& os, Colour const& ); + +} // end namespace Catch + +// end catch_console_colour.h +// start catch_reporter_registrars.hpp + + +namespace Catch { + + template + class ReporterRegistrar { + + class ReporterFactory : public IReporterFactory { + + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const override { + return std::unique_ptr( new T( config ) ); + } + + virtual std::string getDescription() const override { + return T::getDescription(); + } + }; + + public: + + explicit ReporterRegistrar( std::string const& name ) { + getMutableRegistryHub().registerReporter( name, std::make_shared() ); + } + }; + + template + class ListenerRegistrar { + + class ListenerFactory : public IReporterFactory { + + virtual IStreamingReporterPtr create( ReporterConfig const& config ) const override { + return std::unique_ptr( new T( config ) ); + } + virtual std::string getDescription() const override { + return std::string(); + } + }; + + public: + + ListenerRegistrar() { + getMutableRegistryHub().registerListener( std::make_shared() ); + } + }; +} + +#if !defined(CATCH_CONFIG_DISABLE) + +#define CATCH_REGISTER_REPORTER( name, reporterType ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ReporterRegistrar catch_internal_RegistrarFor##reporterType( name ); } \ + CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS + +#define CATCH_REGISTER_LISTENER( listenerType ) \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \ + namespace{ Catch::ListenerRegistrar catch_internal_RegistrarFor##listenerType; } \ + CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS +#else // CATCH_CONFIG_DISABLE + +#define CATCH_REGISTER_REPORTER(name, reporterType) +#define CATCH_REGISTER_LISTENER(listenerType) + +#endif // CATCH_CONFIG_DISABLE + +// end catch_reporter_registrars.hpp +// Allow users to base their work off existing reporters +// start catch_reporter_compact.h + +namespace Catch { + + struct CompactReporter : StreamingReporterBase { + + using StreamingReporterBase::StreamingReporterBase; + + ~CompactReporter() override; + + static std::string getDescription(); + + ReporterPreferences getPreferences() const override; + + void noMatchingTestCases(std::string const& spec) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& _assertionStats) override; + + void sectionEnded(SectionStats const& _sectionStats) override; + + void testRunEnded(TestRunStats const& _testRunStats) override; + + }; + +} // end namespace Catch + +// end catch_reporter_compact.h +// start catch_reporter_console.h + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + // Fwd decls + struct SummaryColumn; + class TablePrinter; + + struct ConsoleReporter : StreamingReporterBase { + std::unique_ptr m_tablePrinter; + + ConsoleReporter(ReporterConfig const& config); + ~ConsoleReporter() override; + static std::string getDescription(); + + void noMatchingTestCases(std::string const& spec) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& _assertionStats) override; + + void sectionStarting(SectionInfo const& _sectionInfo) override; + void sectionEnded(SectionStats const& _sectionStats) override; + + void benchmarkStarting(BenchmarkInfo const& info) override; + void benchmarkEnded(BenchmarkStats const& stats) override; + + void testCaseEnded(TestCaseStats const& _testCaseStats) override; + void testGroupEnded(TestGroupStats const& _testGroupStats) override; + void testRunEnded(TestRunStats const& _testRunStats) override; + + private: + + void lazyPrint(); + + void lazyPrintWithoutClosingBenchmarkTable(); + void lazyPrintRunInfo(); + void lazyPrintGroupInfo(); + void printTestCaseAndSectionHeader(); + + void printClosedHeader(std::string const& _name); + void printOpenHeader(std::string const& _name); + + // if string has a : in first line will set indent to follow it on + // subsequent lines + void printHeaderString(std::string const& _string, std::size_t indent = 0); + + void printTotals(Totals const& totals); + void printSummaryRow(std::string const& label, std::vector const& cols, std::size_t row); + + void printTotalsDivider(Totals const& totals); + void printSummaryDivider(); + + private: + bool m_headerPrinted = false; + }; + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +// end catch_reporter_console.h +// start catch_reporter_junit.h + +// start catch_xmlwriter.h + +#include + +namespace Catch { + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) noexcept; + ScopedElement& operator=( ScopedElement&& other ) noexcept; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + + XmlWriter( std::ostream& os = Catch::cout() ); + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + ReusableStringStream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + XmlWriter& writeComment( std::string const& text ); + + void writeStylesheetRef( std::string const& url ); + + XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + private: + + void writeDeclaration(); + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +} + +// end catch_xmlwriter.h +namespace Catch { + + class JunitReporter : public CumulativeReporterBase { + public: + JunitReporter(ReporterConfig const& _config); + + ~JunitReporter() override; + + static std::string getDescription(); + + void noMatchingTestCases(std::string const& /*spec*/) override; + + void testRunStarting(TestRunInfo const& runInfo) override; + + void testGroupStarting(GroupInfo const& groupInfo) override; + + void testCaseStarting(TestCaseInfo const& testCaseInfo) override; + bool assertionEnded(AssertionStats const& assertionStats) override; + + void testCaseEnded(TestCaseStats const& testCaseStats) override; + + void testGroupEnded(TestGroupStats const& testGroupStats) override; + + void testRunEndedCumulative() override; + + void writeGroup(TestGroupNode const& groupNode, double suiteTime); + + void writeTestCase(TestCaseNode const& testCaseNode); + + void writeSection(std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode); + + void writeAssertions(SectionNode const& sectionNode); + void writeAssertion(AssertionStats const& stats); + + XmlWriter xml; + Timer suiteTimer; + std::string stdOutForSuite; + std::string stdErrForSuite; + unsigned int unexpectedExceptions = 0; + bool m_okToFail = false; + }; + +} // end namespace Catch + +// end catch_reporter_junit.h +// start catch_reporter_xml.h + +namespace Catch { + class XmlReporter : public StreamingReporterBase { + public: + XmlReporter(ReporterConfig const& _config); + + ~XmlReporter() override; + + static std::string getDescription(); + + virtual std::string getStylesheetRef() const; + + void writeSourceInfo(SourceLineInfo const& sourceInfo); + + public: // StreamingReporterBase + + void noMatchingTestCases(std::string const& s) override; + + void testRunStarting(TestRunInfo const& testInfo) override; + + void testGroupStarting(GroupInfo const& groupInfo) override; + + void testCaseStarting(TestCaseInfo const& testInfo) override; + + void sectionStarting(SectionInfo const& sectionInfo) override; + + void assertionStarting(AssertionInfo const&) override; + + bool assertionEnded(AssertionStats const& assertionStats) override; + + void sectionEnded(SectionStats const& sectionStats) override; + + void testCaseEnded(TestCaseStats const& testCaseStats) override; + + void testGroupEnded(TestGroupStats const& testGroupStats) override; + + void testRunEnded(TestRunStats const& testRunStats) override; + + private: + Timer m_testCaseTimer; + XmlWriter m_xml; + int m_sectionDepth = 0; + }; + +} // end namespace Catch + +// end catch_reporter_xml.h + +// end catch_external_interfaces.h +#endif + +#endif // ! CATCH_CONFIG_IMPL_ONLY + +#ifdef CATCH_IMPL +// start catch_impl.hpp + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#endif + +// Keep these here for external reporters +// start catch_test_case_tracker.h + +#include +#include +#include + +namespace Catch { +namespace TestCaseTracking { + + struct NameAndLocation { + std::string name; + SourceLineInfo location; + + NameAndLocation( std::string const& _name, SourceLineInfo const& _location ); + }; + + struct ITracker; + + using ITrackerPtr = std::shared_ptr; + + struct ITracker { + virtual ~ITracker(); + + // static queries + virtual NameAndLocation const& nameAndLocation() const = 0; + + // dynamic queries + virtual bool isComplete() const = 0; // Successfully completed or failed + virtual bool isSuccessfullyCompleted() const = 0; + virtual bool isOpen() const = 0; // Started but not complete + virtual bool hasChildren() const = 0; + + virtual ITracker& parent() = 0; + + // actions + virtual void close() = 0; // Successfully complete + virtual void fail() = 0; + virtual void markAsNeedingAnotherRun() = 0; + + virtual void addChild( ITrackerPtr const& child ) = 0; + virtual ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) = 0; + virtual void openChild() = 0; + + // Debug/ checking + virtual bool isSectionTracker() const = 0; + virtual bool isGeneratorTracker() const = 0; + }; + + class TrackerContext { + + enum RunState { + NotStarted, + Executing, + CompletedCycle + }; + + ITrackerPtr m_rootTracker; + ITracker* m_currentTracker = nullptr; + RunState m_runState = NotStarted; + + public: + + static TrackerContext& instance(); + + ITracker& startRun(); + void endRun(); + + void startCycle(); + void completeCycle(); + + bool completedCycle() const; + ITracker& currentTracker(); + void setCurrentTracker( ITracker* tracker ); + }; + + class TrackerBase : public ITracker { + protected: + enum CycleState { + NotStarted, + Executing, + ExecutingChildren, + NeedsAnotherRun, + CompletedSuccessfully, + Failed + }; + + using Children = std::vector; + NameAndLocation m_nameAndLocation; + TrackerContext& m_ctx; + ITracker* m_parent; + Children m_children; + CycleState m_runState = NotStarted; + + public: + TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ); + + NameAndLocation const& nameAndLocation() const override; + bool isComplete() const override; + bool isSuccessfullyCompleted() const override; + bool isOpen() const override; + bool hasChildren() const override; + + void addChild( ITrackerPtr const& child ) override; + + ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) override; + ITracker& parent() override; + + void openChild() override; + + bool isSectionTracker() const override; + bool isGeneratorTracker() const override; + + void open(); + + void close() override; + void fail() override; + void markAsNeedingAnotherRun() override; + + private: + void moveToParent(); + void moveToThis(); + }; + + class SectionTracker : public TrackerBase { + std::vector m_filters; + public: + SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ); + + bool isSectionTracker() const override; + + bool isComplete() const override; + + static SectionTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ); + + void tryOpen(); + + void addInitialFilters( std::vector const& filters ); + void addNextFilters( std::vector const& filters ); + }; + +} // namespace TestCaseTracking + +using TestCaseTracking::ITracker; +using TestCaseTracking::TrackerContext; +using TestCaseTracking::SectionTracker; + +} // namespace Catch + +// end catch_test_case_tracker.h + +// start catch_leak_detector.h + +namespace Catch { + + struct LeakDetector { + LeakDetector(); + ~LeakDetector(); + }; + +} +// end catch_leak_detector.h +// Cpp files will be included in the single-header file here +// start catch_approx.cpp + +#include +#include + +namespace { + +// Performs equivalent check of std::fabs(lhs - rhs) <= margin +// But without the subtraction to allow for INFINITY in comparison +bool marginComparison(double lhs, double rhs, double margin) { + return (lhs + margin >= rhs) && (rhs + margin >= lhs); +} + +} + +namespace Catch { +namespace Detail { + + Approx::Approx ( double value ) + : m_epsilon( std::numeric_limits::epsilon()*100 ), + m_margin( 0.0 ), + m_scale( 0.0 ), + m_value( value ) + {} + + Approx Approx::custom() { + return Approx( 0 ); + } + + Approx Approx::operator-() const { + auto temp(*this); + temp.m_value = -temp.m_value; + return temp; + } + + std::string Approx::toString() const { + ReusableStringStream rss; + rss << "Approx( " << ::Catch::Detail::stringify( m_value ) << " )"; + return rss.str(); + } + + bool Approx::equalityComparisonImpl(const double other) const { + // First try with fixed margin, then compute margin based on epsilon, scale and Approx's value + // Thanks to Richard Harris for his help refining the scaled margin value + return marginComparison(m_value, other, m_margin) || marginComparison(m_value, other, m_epsilon * (m_scale + std::fabs(m_value))); + } + + void Approx::setMargin(double margin) { + CATCH_ENFORCE(margin >= 0, + "Invalid Approx::margin: " << margin << '.' + << " Approx::Margin has to be non-negative."); + m_margin = margin; + } + + void Approx::setEpsilon(double epsilon) { + CATCH_ENFORCE(epsilon >= 0 && epsilon <= 1.0, + "Invalid Approx::epsilon: " << epsilon << '.' + << " Approx::epsilon has to be in [0, 1]"); + m_epsilon = epsilon; + } + +} // end namespace Detail + +namespace literals { + Detail::Approx operator "" _a(long double val) { + return Detail::Approx(val); + } + Detail::Approx operator "" _a(unsigned long long val) { + return Detail::Approx(val); + } +} // end namespace literals + +std::string StringMaker::convert(Catch::Detail::Approx const& value) { + return value.toString(); +} + +} // end namespace Catch +// end catch_approx.cpp +// start catch_assertionhandler.cpp + +// start catch_debugger.h + +namespace Catch { + bool isDebuggerActive(); +} + +#ifdef CATCH_PLATFORM_MAC + + #define CATCH_TRAP() __asm__("int $3\n" : : ) /* NOLINT */ + +#elif defined(CATCH_PLATFORM_LINUX) + // If we can use inline assembler, do it because this allows us to break + // directly at the location of the failing check instead of breaking inside + // raise() called from it, i.e. one stack frame below. + #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) + #define CATCH_TRAP() asm volatile ("int $3") /* NOLINT */ + #else // Fall back to the generic way. + #include + + #define CATCH_TRAP() raise(SIGTRAP) + #endif +#elif defined(_MSC_VER) + #define CATCH_TRAP() __debugbreak() +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) void __stdcall DebugBreak(); + #define CATCH_TRAP() DebugBreak() +#endif + +#ifdef CATCH_TRAP + #define CATCH_BREAK_INTO_DEBUGGER() []{ if( Catch::isDebuggerActive() ) { CATCH_TRAP(); } }() +#else + #define CATCH_BREAK_INTO_DEBUGGER() []{}() +#endif + +// end catch_debugger.h +// start catch_run_context.h + +// start catch_fatal_condition.h + +// start catch_windows_h_proxy.h + + +#if defined(CATCH_PLATFORM_WINDOWS) + +#if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX) +# define CATCH_DEFINED_NOMINMAX +# define NOMINMAX +#endif +#if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN) +# define CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# define WIN32_LEAN_AND_MEAN +#endif + +#ifdef __AFXDLL +#include +#else +#include +#endif + +#ifdef CATCH_DEFINED_NOMINMAX +# undef NOMINMAX +#endif +#ifdef CATCH_DEFINED_WIN32_LEAN_AND_MEAN +# undef WIN32_LEAN_AND_MEAN +#endif + +#endif // defined(CATCH_PLATFORM_WINDOWS) + +// end catch_windows_h_proxy.h +#if defined( CATCH_CONFIG_WINDOWS_SEH ) + +namespace Catch { + + struct FatalConditionHandler { + + static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo); + FatalConditionHandler(); + static void reset(); + ~FatalConditionHandler(); + + private: + static bool isSet; + static ULONG guaranteeSize; + static PVOID exceptionHandlerHandle; + }; + +} // namespace Catch + +#elif defined ( CATCH_CONFIG_POSIX_SIGNALS ) + +#include + +namespace Catch { + + struct FatalConditionHandler { + + static bool isSet; + static struct sigaction oldSigActions[]; + static stack_t oldSigStack; + static char altStackMem[]; + + static void handleSignal( int sig ); + + FatalConditionHandler(); + ~FatalConditionHandler(); + static void reset(); + }; + +} // namespace Catch + +#else + +namespace Catch { + struct FatalConditionHandler { + void reset(); + }; +} + +#endif + +// end catch_fatal_condition.h +#include + +namespace Catch { + + struct IMutableContext; + + /////////////////////////////////////////////////////////////////////////// + + class RunContext : public IResultCapture, public IRunner { + + public: + RunContext( RunContext const& ) = delete; + RunContext& operator =( RunContext const& ) = delete; + + explicit RunContext( IConfigPtr const& _config, IStreamingReporterPtr&& reporter ); + + ~RunContext() override; + + void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount ); + void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount ); + + Totals runTest(TestCase const& testCase); + + IConfigPtr config() const; + IStreamingReporter& reporter() const; + + public: // IResultCapture + + // Assertion handlers + void handleExpr + ( AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction ) override; + void handleMessage + ( AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction ) override; + void handleUnexpectedExceptionNotThrown + ( AssertionInfo const& info, + AssertionReaction& reaction ) override; + void handleUnexpectedInflightException + ( AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction ) override; + void handleIncomplete + ( AssertionInfo const& info ) override; + void handleNonExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction ) override; + + bool sectionStarted( SectionInfo const& sectionInfo, Counts& assertions ) override; + + void sectionEnded( SectionEndInfo const& endInfo ) override; + void sectionEndedEarly( SectionEndInfo const& endInfo ) override; + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& override; + + void benchmarkStarting( BenchmarkInfo const& info ) override; + void benchmarkEnded( BenchmarkStats const& stats ) override; + + void pushScopedMessage( MessageInfo const& message ) override; + void popScopedMessage( MessageInfo const& message ) override; + + void emplaceUnscopedMessage( MessageBuilder const& builder ) override; + + std::string getCurrentTestName() const override; + + const AssertionResult* getLastResult() const override; + + void exceptionEarlyReported() override; + + void handleFatalErrorCondition( StringRef message ) override; + + bool lastAssertionPassed() override; + + void assertionPassed() override; + + public: + // !TBD We need to do this another way! + bool aborting() const final; + + private: + + void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr ); + void invokeActiveTestCase(); + + void resetAssertionInfo(); + bool testForMissingAssertions( Counts& assertions ); + + void assertionEnded( AssertionResult const& result ); + void reportExpr + ( AssertionInfo const &info, + ResultWas::OfType resultType, + ITransientExpression const *expr, + bool negated ); + + void populateReaction( AssertionReaction& reaction ); + + private: + + void handleUnfinishedSections(); + + TestRunInfo m_runInfo; + IMutableContext& m_context; + TestCase const* m_activeTestCase = nullptr; + ITracker* m_testCaseTracker = nullptr; + Option m_lastResult; + + IConfigPtr m_config; + Totals m_totals; + IStreamingReporterPtr m_reporter; + std::vector m_messages; + std::vector m_messageScopes; /* Keeps owners of so-called unscoped messages. */ + AssertionInfo m_lastAssertionInfo; + std::vector m_unfinishedSections; + std::vector m_activeSections; + TrackerContext m_trackerContext; + bool m_lastAssertionPassed = false; + bool m_shouldReportUnexpected = true; + bool m_includeSuccessfulResults; + }; + +} // end namespace Catch + +// end catch_run_context.h +namespace Catch { + + namespace { + auto operator <<( std::ostream& os, ITransientExpression const& expr ) -> std::ostream& { + expr.streamReconstructedExpression( os ); + return os; + } + } + + LazyExpression::LazyExpression( bool isNegated ) + : m_isNegated( isNegated ) + {} + + LazyExpression::LazyExpression( LazyExpression const& other ) : m_isNegated( other.m_isNegated ) {} + + LazyExpression::operator bool() const { + return m_transientExpression != nullptr; + } + + auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream& { + if( lazyExpr.m_isNegated ) + os << "!"; + + if( lazyExpr ) { + if( lazyExpr.m_isNegated && lazyExpr.m_transientExpression->isBinaryExpression() ) + os << "(" << *lazyExpr.m_transientExpression << ")"; + else + os << *lazyExpr.m_transientExpression; + } + else { + os << "{** error - unchecked empty expression requested **}"; + } + return os; + } + + AssertionHandler::AssertionHandler + ( StringRef const& macroName, + SourceLineInfo const& lineInfo, + StringRef capturedExpression, + ResultDisposition::Flags resultDisposition ) + : m_assertionInfo{ macroName, lineInfo, capturedExpression, resultDisposition }, + m_resultCapture( getResultCapture() ) + {} + + void AssertionHandler::handleExpr( ITransientExpression const& expr ) { + m_resultCapture.handleExpr( m_assertionInfo, expr, m_reaction ); + } + void AssertionHandler::handleMessage(ResultWas::OfType resultType, StringRef const& message) { + m_resultCapture.handleMessage( m_assertionInfo, resultType, message, m_reaction ); + } + + auto AssertionHandler::allowThrows() const -> bool { + return getCurrentContext().getConfig()->allowThrows(); + } + + void AssertionHandler::complete() { + setCompleted(); + if( m_reaction.shouldDebugBreak ) { + + // If you find your debugger stopping you here then go one level up on the + // call-stack for the code that caused it (typically a failed assertion) + + // (To go back to the test and change execution, jump over the throw, next) + CATCH_BREAK_INTO_DEBUGGER(); + } + if (m_reaction.shouldThrow) { +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + throw Catch::TestFailureException(); +#else + CATCH_ERROR( "Test failure requires aborting test!" ); +#endif + } + } + void AssertionHandler::setCompleted() { + m_completed = true; + } + + void AssertionHandler::handleUnexpectedInflightException() { + m_resultCapture.handleUnexpectedInflightException( m_assertionInfo, Catch::translateActiveException(), m_reaction ); + } + + void AssertionHandler::handleExceptionThrownAsExpected() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + void AssertionHandler::handleExceptionNotThrownAsExpected() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + + void AssertionHandler::handleUnexpectedExceptionNotThrown() { + m_resultCapture.handleUnexpectedExceptionNotThrown( m_assertionInfo, m_reaction ); + } + + void AssertionHandler::handleThrowingCallSkipped() { + m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction); + } + + // This is the overload that takes a string and infers the Equals matcher from it + // The more general overload, that takes any string matcher, is in catch_capture_matchers.cpp + void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString ) { + handleExceptionMatchExpr( handler, Matchers::Equals( str ), matcherString ); + } + +} // namespace Catch +// end catch_assertionhandler.cpp +// start catch_assertionresult.cpp + +namespace Catch { + AssertionResultData::AssertionResultData(ResultWas::OfType _resultType, LazyExpression const & _lazyExpression): + lazyExpression(_lazyExpression), + resultType(_resultType) {} + + std::string AssertionResultData::reconstructExpression() const { + + if( reconstructedExpression.empty() ) { + if( lazyExpression ) { + ReusableStringStream rss; + rss << lazyExpression; + reconstructedExpression = rss.str(); + } + } + return reconstructedExpression; + } + + AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data ) + : m_info( info ), + m_resultData( data ) + {} + + // Result was a success + bool AssertionResult::succeeded() const { + return Catch::isOk( m_resultData.resultType ); + } + + // Result was a success, or failure is suppressed + bool AssertionResult::isOk() const { + return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition ); + } + + ResultWas::OfType AssertionResult::getResultType() const { + return m_resultData.resultType; + } + + bool AssertionResult::hasExpression() const { + return m_info.capturedExpression[0] != 0; + } + + bool AssertionResult::hasMessage() const { + return !m_resultData.message.empty(); + } + + std::string AssertionResult::getExpression() const { + if( isFalseTest( m_info.resultDisposition ) ) + return "!(" + m_info.capturedExpression + ")"; + else + return m_info.capturedExpression; + } + + std::string AssertionResult::getExpressionInMacro() const { + std::string expr; + if( m_info.macroName[0] == 0 ) + expr = m_info.capturedExpression; + else { + expr.reserve( m_info.macroName.size() + m_info.capturedExpression.size() + 4 ); + expr += m_info.macroName; + expr += "( "; + expr += m_info.capturedExpression; + expr += " )"; + } + return expr; + } + + bool AssertionResult::hasExpandedExpression() const { + return hasExpression() && getExpandedExpression() != getExpression(); + } + + std::string AssertionResult::getExpandedExpression() const { + std::string expr = m_resultData.reconstructExpression(); + return expr.empty() + ? getExpression() + : expr; + } + + std::string AssertionResult::getMessage() const { + return m_resultData.message; + } + SourceLineInfo AssertionResult::getSourceInfo() const { + return m_info.lineInfo; + } + + StringRef AssertionResult::getTestMacroName() const { + return m_info.macroName; + } + +} // end namespace Catch +// end catch_assertionresult.cpp +// start catch_benchmark.cpp + +namespace Catch { + + auto BenchmarkLooper::getResolution() -> uint64_t { + return getEstimatedClockResolution() * getCurrentContext().getConfig()->benchmarkResolutionMultiple(); + } + + void BenchmarkLooper::reportStart() { + getResultCapture().benchmarkStarting( { m_name } ); + } + auto BenchmarkLooper::needsMoreIterations() -> bool { + auto elapsed = m_timer.getElapsedNanoseconds(); + + // Exponentially increasing iterations until we're confident in our timer resolution + if( elapsed < m_resolution ) { + m_iterationsToRun *= 10; + return true; + } + + getResultCapture().benchmarkEnded( { { m_name }, m_count, elapsed } ); + return false; + } + +} // end namespace Catch +// end catch_benchmark.cpp +// start catch_capture_matchers.cpp + +namespace Catch { + + using StringMatcher = Matchers::Impl::MatcherBase; + + // This is the general overload that takes a any string matcher + // There is another overload, in catch_assertionhandler.h/.cpp, that only takes a string and infers + // the Equals matcher (so the header does not mention matchers) + void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString ) { + std::string exceptionMessage = Catch::translateActiveException(); + MatchExpr expr( exceptionMessage, matcher, matcherString ); + handler.handleExpr( expr ); + } + +} // namespace Catch +// end catch_capture_matchers.cpp +// start catch_commandline.cpp + +// start catch_commandline.h + +// start catch_clara.h + +// Use Catch's value for console width (store Clara's off to the side, if present) +#ifdef CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#undef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#endif +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH-1 + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wweak-vtables" +#pragma clang diagnostic ignored "-Wexit-time-destructors" +#pragma clang diagnostic ignored "-Wshadow" +#endif + +// start clara.hpp +// Copyright 2017 Two Blue Cubes Ltd. All rights reserved. +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// See https://github.com/philsquared/Clara for more details + +// Clara v1.1.5 + + +#ifndef CATCH_CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_CONFIG_CONSOLE_WIDTH 80 +#endif + +#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +#ifndef CLARA_CONFIG_OPTIONAL_TYPE +#ifdef __has_include +#if __has_include() && __cplusplus >= 201703L +#include +#define CLARA_CONFIG_OPTIONAL_TYPE std::optional +#endif +#endif +#endif + +// ----------- #included from clara_textflow.hpp ----------- + +// TextFlowCpp +// +// A single-header library for wrapping and laying out basic text, by Phil Nash +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// This project is hosted at https://github.com/philsquared/textflowcpp + + +#include +#include +#include +#include + +#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH 80 +#endif + +namespace Catch { +namespace clara { +namespace TextFlow { + +inline auto isWhitespace(char c) -> bool { + static std::string chars = " \t\n\r"; + return chars.find(c) != std::string::npos; +} +inline auto isBreakableBefore(char c) -> bool { + static std::string chars = "[({<|"; + return chars.find(c) != std::string::npos; +} +inline auto isBreakableAfter(char c) -> bool { + static std::string chars = "])}>.,:;*+-=&/\\"; + return chars.find(c) != std::string::npos; +} + +class Columns; + +class Column { + std::vector m_strings; + size_t m_width = CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH; + size_t m_indent = 0; + size_t m_initialIndent = std::string::npos; + +public: + class iterator { + friend Column; + + Column const& m_column; + size_t m_stringIndex = 0; + size_t m_pos = 0; + + size_t m_len = 0; + size_t m_end = 0; + bool m_suffix = false; + + iterator(Column const& column, size_t stringIndex) + : m_column(column), + m_stringIndex(stringIndex) {} + + auto line() const -> std::string const& { return m_column.m_strings[m_stringIndex]; } + + auto isBoundary(size_t at) const -> bool { + assert(at > 0); + assert(at <= line().size()); + + return at == line().size() || + (isWhitespace(line()[at]) && !isWhitespace(line()[at - 1])) || + isBreakableBefore(line()[at]) || + isBreakableAfter(line()[at - 1]); + } + + void calcLength() { + assert(m_stringIndex < m_column.m_strings.size()); + + m_suffix = false; + auto width = m_column.m_width - indent(); + m_end = m_pos; + while (m_end < line().size() && line()[m_end] != '\n') + ++m_end; + + if (m_end < m_pos + width) { + m_len = m_end - m_pos; + } else { + size_t len = width; + while (len > 0 && !isBoundary(m_pos + len)) + --len; + while (len > 0 && isWhitespace(line()[m_pos + len - 1])) + --len; + + if (len > 0) { + m_len = len; + } else { + m_suffix = true; + m_len = width - 1; + } + } + } + + auto indent() const -> size_t { + auto initial = m_pos == 0 && m_stringIndex == 0 ? m_column.m_initialIndent : std::string::npos; + return initial == std::string::npos ? m_column.m_indent : initial; + } + + auto addIndentAndSuffix(std::string const &plain) const -> std::string { + return std::string(indent(), ' ') + (m_suffix ? plain + "-" : plain); + } + + public: + using difference_type = std::ptrdiff_t; + using value_type = std::string; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::forward_iterator_tag; + + explicit iterator(Column const& column) : m_column(column) { + assert(m_column.m_width > m_column.m_indent); + assert(m_column.m_initialIndent == std::string::npos || m_column.m_width > m_column.m_initialIndent); + calcLength(); + if (m_len == 0) + m_stringIndex++; // Empty string + } + + auto operator *() const -> std::string { + assert(m_stringIndex < m_column.m_strings.size()); + assert(m_pos <= m_end); + return addIndentAndSuffix(line().substr(m_pos, m_len)); + } + + auto operator ++() -> iterator& { + m_pos += m_len; + if (m_pos < line().size() && line()[m_pos] == '\n') + m_pos += 1; + else + while (m_pos < line().size() && isWhitespace(line()[m_pos])) + ++m_pos; + + if (m_pos == line().size()) { + m_pos = 0; + ++m_stringIndex; + } + if (m_stringIndex < m_column.m_strings.size()) + calcLength(); + return *this; + } + auto operator ++(int) -> iterator { + iterator prev(*this); + operator++(); + return prev; + } + + auto operator ==(iterator const& other) const -> bool { + return + m_pos == other.m_pos && + m_stringIndex == other.m_stringIndex && + &m_column == &other.m_column; + } + auto operator !=(iterator const& other) const -> bool { + return !operator==(other); + } + }; + using const_iterator = iterator; + + explicit Column(std::string const& text) { m_strings.push_back(text); } + + auto width(size_t newWidth) -> Column& { + assert(newWidth > 0); + m_width = newWidth; + return *this; + } + auto indent(size_t newIndent) -> Column& { + m_indent = newIndent; + return *this; + } + auto initialIndent(size_t newIndent) -> Column& { + m_initialIndent = newIndent; + return *this; + } + + auto width() const -> size_t { return m_width; } + auto begin() const -> iterator { return iterator(*this); } + auto end() const -> iterator { return { *this, m_strings.size() }; } + + inline friend std::ostream& operator << (std::ostream& os, Column const& col) { + bool first = true; + for (auto line : col) { + if (first) + first = false; + else + os << "\n"; + os << line; + } + return os; + } + + auto operator + (Column const& other)->Columns; + + auto toString() const -> std::string { + std::ostringstream oss; + oss << *this; + return oss.str(); + } +}; + +class Spacer : public Column { + +public: + explicit Spacer(size_t spaceWidth) : Column("") { + width(spaceWidth); + } +}; + +class Columns { + std::vector m_columns; + +public: + + class iterator { + friend Columns; + struct EndTag {}; + + std::vector const& m_columns; + std::vector m_iterators; + size_t m_activeIterators; + + iterator(Columns const& columns, EndTag) + : m_columns(columns.m_columns), + m_activeIterators(0) { + m_iterators.reserve(m_columns.size()); + + for (auto const& col : m_columns) + m_iterators.push_back(col.end()); + } + + public: + using difference_type = std::ptrdiff_t; + using value_type = std::string; + using pointer = value_type * ; + using reference = value_type & ; + using iterator_category = std::forward_iterator_tag; + + explicit iterator(Columns const& columns) + : m_columns(columns.m_columns), + m_activeIterators(m_columns.size()) { + m_iterators.reserve(m_columns.size()); + + for (auto const& col : m_columns) + m_iterators.push_back(col.begin()); + } + + auto operator ==(iterator const& other) const -> bool { + return m_iterators == other.m_iterators; + } + auto operator !=(iterator const& other) const -> bool { + return m_iterators != other.m_iterators; + } + auto operator *() const -> std::string { + std::string row, padding; + + for (size_t i = 0; i < m_columns.size(); ++i) { + auto width = m_columns[i].width(); + if (m_iterators[i] != m_columns[i].end()) { + std::string col = *m_iterators[i]; + row += padding + col; + if (col.size() < width) + padding = std::string(width - col.size(), ' '); + else + padding = ""; + } else { + padding += std::string(width, ' '); + } + } + return row; + } + auto operator ++() -> iterator& { + for (size_t i = 0; i < m_columns.size(); ++i) { + if (m_iterators[i] != m_columns[i].end()) + ++m_iterators[i]; + } + return *this; + } + auto operator ++(int) -> iterator { + iterator prev(*this); + operator++(); + return prev; + } + }; + using const_iterator = iterator; + + auto begin() const -> iterator { return iterator(*this); } + auto end() const -> iterator { return { *this, iterator::EndTag() }; } + + auto operator += (Column const& col) -> Columns& { + m_columns.push_back(col); + return *this; + } + auto operator + (Column const& col) -> Columns { + Columns combined = *this; + combined += col; + return combined; + } + + inline friend std::ostream& operator << (std::ostream& os, Columns const& cols) { + + bool first = true; + for (auto line : cols) { + if (first) + first = false; + else + os << "\n"; + os << line; + } + return os; + } + + auto toString() const -> std::string { + std::ostringstream oss; + oss << *this; + return oss.str(); + } +}; + +inline auto Column::operator + (Column const& other) -> Columns { + Columns cols; + cols += *this; + cols += other; + return cols; +} +} + +} +} + +// ----------- end of #include from clara_textflow.hpp ----------- +// ........... back in clara.hpp + +#include +#include +#include +#include +#include + +#if !defined(CATCH_PLATFORM_WINDOWS) && ( defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) ) +#define CATCH_PLATFORM_WINDOWS +#endif + +namespace Catch { namespace clara { +namespace detail { + + // Traits for extracting arg and return type of lambdas (for single argument lambdas) + template + struct UnaryLambdaTraits : UnaryLambdaTraits {}; + + template + struct UnaryLambdaTraits { + static const bool isValid = false; + }; + + template + struct UnaryLambdaTraits { + static const bool isValid = true; + using ArgType = typename std::remove_const::type>::type; + using ReturnType = ReturnT; + }; + + class TokenStream; + + // Transport for raw args (copied from main args, or supplied via init list for testing) + class Args { + friend TokenStream; + std::string m_exeName; + std::vector m_args; + + public: + Args( int argc, char const* const* argv ) + : m_exeName(argv[0]), + m_args(argv + 1, argv + argc) {} + + Args( std::initializer_list args ) + : m_exeName( *args.begin() ), + m_args( args.begin()+1, args.end() ) + {} + + auto exeName() const -> std::string { + return m_exeName; + } + }; + + // Wraps a token coming from a token stream. These may not directly correspond to strings as a single string + // may encode an option + its argument if the : or = form is used + enum class TokenType { + Option, Argument + }; + struct Token { + TokenType type; + std::string token; + }; + + inline auto isOptPrefix( char c ) -> bool { + return c == '-' +#ifdef CATCH_PLATFORM_WINDOWS + || c == '/' +#endif + ; + } + + // Abstracts iterators into args as a stream of tokens, with option arguments uniformly handled + class TokenStream { + using Iterator = std::vector::const_iterator; + Iterator it; + Iterator itEnd; + std::vector m_tokenBuffer; + + void loadBuffer() { + m_tokenBuffer.resize( 0 ); + + // Skip any empty strings + while( it != itEnd && it->empty() ) + ++it; + + if( it != itEnd ) { + auto const &next = *it; + if( isOptPrefix( next[0] ) ) { + auto delimiterPos = next.find_first_of( " :=" ); + if( delimiterPos != std::string::npos ) { + m_tokenBuffer.push_back( { TokenType::Option, next.substr( 0, delimiterPos ) } ); + m_tokenBuffer.push_back( { TokenType::Argument, next.substr( delimiterPos + 1 ) } ); + } else { + if( next[1] != '-' && next.size() > 2 ) { + std::string opt = "- "; + for( size_t i = 1; i < next.size(); ++i ) { + opt[1] = next[i]; + m_tokenBuffer.push_back( { TokenType::Option, opt } ); + } + } else { + m_tokenBuffer.push_back( { TokenType::Option, next } ); + } + } + } else { + m_tokenBuffer.push_back( { TokenType::Argument, next } ); + } + } + } + + public: + explicit TokenStream( Args const &args ) : TokenStream( args.m_args.begin(), args.m_args.end() ) {} + + TokenStream( Iterator it, Iterator itEnd ) : it( it ), itEnd( itEnd ) { + loadBuffer(); + } + + explicit operator bool() const { + return !m_tokenBuffer.empty() || it != itEnd; + } + + auto count() const -> size_t { return m_tokenBuffer.size() + (itEnd - it); } + + auto operator*() const -> Token { + assert( !m_tokenBuffer.empty() ); + return m_tokenBuffer.front(); + } + + auto operator->() const -> Token const * { + assert( !m_tokenBuffer.empty() ); + return &m_tokenBuffer.front(); + } + + auto operator++() -> TokenStream & { + if( m_tokenBuffer.size() >= 2 ) { + m_tokenBuffer.erase( m_tokenBuffer.begin() ); + } else { + if( it != itEnd ) + ++it; + loadBuffer(); + } + return *this; + } + }; + + class ResultBase { + public: + enum Type { + Ok, LogicError, RuntimeError + }; + + protected: + ResultBase( Type type ) : m_type( type ) {} + virtual ~ResultBase() = default; + + virtual void enforceOk() const = 0; + + Type m_type; + }; + + template + class ResultValueBase : public ResultBase { + public: + auto value() const -> T const & { + enforceOk(); + return m_value; + } + + protected: + ResultValueBase( Type type ) : ResultBase( type ) {} + + ResultValueBase( ResultValueBase const &other ) : ResultBase( other ) { + if( m_type == ResultBase::Ok ) + new( &m_value ) T( other.m_value ); + } + + ResultValueBase( Type, T const &value ) : ResultBase( Ok ) { + new( &m_value ) T( value ); + } + + auto operator=( ResultValueBase const &other ) -> ResultValueBase & { + if( m_type == ResultBase::Ok ) + m_value.~T(); + ResultBase::operator=(other); + if( m_type == ResultBase::Ok ) + new( &m_value ) T( other.m_value ); + return *this; + } + + ~ResultValueBase() override { + if( m_type == Ok ) + m_value.~T(); + } + + union { + T m_value; + }; + }; + + template<> + class ResultValueBase : public ResultBase { + protected: + using ResultBase::ResultBase; + }; + + template + class BasicResult : public ResultValueBase { + public: + template + explicit BasicResult( BasicResult const &other ) + : ResultValueBase( other.type() ), + m_errorMessage( other.errorMessage() ) + { + assert( type() != ResultBase::Ok ); + } + + template + static auto ok( U const &value ) -> BasicResult { return { ResultBase::Ok, value }; } + static auto ok() -> BasicResult { return { ResultBase::Ok }; } + static auto logicError( std::string const &message ) -> BasicResult { return { ResultBase::LogicError, message }; } + static auto runtimeError( std::string const &message ) -> BasicResult { return { ResultBase::RuntimeError, message }; } + + explicit operator bool() const { return m_type == ResultBase::Ok; } + auto type() const -> ResultBase::Type { return m_type; } + auto errorMessage() const -> std::string { return m_errorMessage; } + + protected: + void enforceOk() const override { + + // Errors shouldn't reach this point, but if they do + // the actual error message will be in m_errorMessage + assert( m_type != ResultBase::LogicError ); + assert( m_type != ResultBase::RuntimeError ); + if( m_type != ResultBase::Ok ) + std::abort(); + } + + std::string m_errorMessage; // Only populated if resultType is an error + + BasicResult( ResultBase::Type type, std::string const &message ) + : ResultValueBase(type), + m_errorMessage(message) + { + assert( m_type != ResultBase::Ok ); + } + + using ResultValueBase::ResultValueBase; + using ResultBase::m_type; + }; + + enum class ParseResultType { + Matched, NoMatch, ShortCircuitAll, ShortCircuitSame + }; + + class ParseState { + public: + + ParseState( ParseResultType type, TokenStream const &remainingTokens ) + : m_type(type), + m_remainingTokens( remainingTokens ) + {} + + auto type() const -> ParseResultType { return m_type; } + auto remainingTokens() const -> TokenStream { return m_remainingTokens; } + + private: + ParseResultType m_type; + TokenStream m_remainingTokens; + }; + + using Result = BasicResult; + using ParserResult = BasicResult; + using InternalParseResult = BasicResult; + + struct HelpColumns { + std::string left; + std::string right; + }; + + template + inline auto convertInto( std::string const &source, T& target ) -> ParserResult { + std::stringstream ss; + ss << source; + ss >> target; + if( ss.fail() ) + return ParserResult::runtimeError( "Unable to convert '" + source + "' to destination type" ); + else + return ParserResult::ok( ParseResultType::Matched ); + } + inline auto convertInto( std::string const &source, std::string& target ) -> ParserResult { + target = source; + return ParserResult::ok( ParseResultType::Matched ); + } + inline auto convertInto( std::string const &source, bool &target ) -> ParserResult { + std::string srcLC = source; + std::transform( srcLC.begin(), srcLC.end(), srcLC.begin(), []( char c ) { return static_cast( std::tolower(c) ); } ); + if (srcLC == "y" || srcLC == "1" || srcLC == "true" || srcLC == "yes" || srcLC == "on") + target = true; + else if (srcLC == "n" || srcLC == "0" || srcLC == "false" || srcLC == "no" || srcLC == "off") + target = false; + else + return ParserResult::runtimeError( "Expected a boolean value but did not recognise: '" + source + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + } +#ifdef CLARA_CONFIG_OPTIONAL_TYPE + template + inline auto convertInto( std::string const &source, CLARA_CONFIG_OPTIONAL_TYPE& target ) -> ParserResult { + T temp; + auto result = convertInto( source, temp ); + if( result ) + target = std::move(temp); + return result; + } +#endif // CLARA_CONFIG_OPTIONAL_TYPE + + struct NonCopyable { + NonCopyable() = default; + NonCopyable( NonCopyable const & ) = delete; + NonCopyable( NonCopyable && ) = delete; + NonCopyable &operator=( NonCopyable const & ) = delete; + NonCopyable &operator=( NonCopyable && ) = delete; + }; + + struct BoundRef : NonCopyable { + virtual ~BoundRef() = default; + virtual auto isContainer() const -> bool { return false; } + virtual auto isFlag() const -> bool { return false; } + }; + struct BoundValueRefBase : BoundRef { + virtual auto setValue( std::string const &arg ) -> ParserResult = 0; + }; + struct BoundFlagRefBase : BoundRef { + virtual auto setFlag( bool flag ) -> ParserResult = 0; + virtual auto isFlag() const -> bool { return true; } + }; + + template + struct BoundValueRef : BoundValueRefBase { + T &m_ref; + + explicit BoundValueRef( T &ref ) : m_ref( ref ) {} + + auto setValue( std::string const &arg ) -> ParserResult override { + return convertInto( arg, m_ref ); + } + }; + + template + struct BoundValueRef> : BoundValueRefBase { + std::vector &m_ref; + + explicit BoundValueRef( std::vector &ref ) : m_ref( ref ) {} + + auto isContainer() const -> bool override { return true; } + + auto setValue( std::string const &arg ) -> ParserResult override { + T temp; + auto result = convertInto( arg, temp ); + if( result ) + m_ref.push_back( temp ); + return result; + } + }; + + struct BoundFlagRef : BoundFlagRefBase { + bool &m_ref; + + explicit BoundFlagRef( bool &ref ) : m_ref( ref ) {} + + auto setFlag( bool flag ) -> ParserResult override { + m_ref = flag; + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + template + struct LambdaInvoker { + static_assert( std::is_same::value, "Lambda must return void or clara::ParserResult" ); + + template + static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult { + return lambda( arg ); + } + }; + + template<> + struct LambdaInvoker { + template + static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult { + lambda( arg ); + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + template + inline auto invokeLambda( L const &lambda, std::string const &arg ) -> ParserResult { + ArgType temp{}; + auto result = convertInto( arg, temp ); + return !result + ? result + : LambdaInvoker::ReturnType>::invoke( lambda, temp ); + } + + template + struct BoundLambda : BoundValueRefBase { + L m_lambda; + + static_assert( UnaryLambdaTraits::isValid, "Supplied lambda must take exactly one argument" ); + explicit BoundLambda( L const &lambda ) : m_lambda( lambda ) {} + + auto setValue( std::string const &arg ) -> ParserResult override { + return invokeLambda::ArgType>( m_lambda, arg ); + } + }; + + template + struct BoundFlagLambda : BoundFlagRefBase { + L m_lambda; + + static_assert( UnaryLambdaTraits::isValid, "Supplied lambda must take exactly one argument" ); + static_assert( std::is_same::ArgType, bool>::value, "flags must be boolean" ); + + explicit BoundFlagLambda( L const &lambda ) : m_lambda( lambda ) {} + + auto setFlag( bool flag ) -> ParserResult override { + return LambdaInvoker::ReturnType>::invoke( m_lambda, flag ); + } + }; + + enum class Optionality { Optional, Required }; + + struct Parser; + + class ParserBase { + public: + virtual ~ParserBase() = default; + virtual auto validate() const -> Result { return Result::ok(); } + virtual auto parse( std::string const& exeName, TokenStream const &tokens) const -> InternalParseResult = 0; + virtual auto cardinality() const -> size_t { return 1; } + + auto parse( Args const &args ) const -> InternalParseResult { + return parse( args.exeName(), TokenStream( args ) ); + } + }; + + template + class ComposableParserImpl : public ParserBase { + public: + template + auto operator|( T const &other ) const -> Parser; + + template + auto operator+( T const &other ) const -> Parser; + }; + + // Common code and state for Args and Opts + template + class ParserRefImpl : public ComposableParserImpl { + protected: + Optionality m_optionality = Optionality::Optional; + std::shared_ptr m_ref; + std::string m_hint; + std::string m_description; + + explicit ParserRefImpl( std::shared_ptr const &ref ) : m_ref( ref ) {} + + public: + template + ParserRefImpl( T &ref, std::string const &hint ) + : m_ref( std::make_shared>( ref ) ), + m_hint( hint ) + {} + + template + ParserRefImpl( LambdaT const &ref, std::string const &hint ) + : m_ref( std::make_shared>( ref ) ), + m_hint(hint) + {} + + auto operator()( std::string const &description ) -> DerivedT & { + m_description = description; + return static_cast( *this ); + } + + auto optional() -> DerivedT & { + m_optionality = Optionality::Optional; + return static_cast( *this ); + }; + + auto required() -> DerivedT & { + m_optionality = Optionality::Required; + return static_cast( *this ); + }; + + auto isOptional() const -> bool { + return m_optionality == Optionality::Optional; + } + + auto cardinality() const -> size_t override { + if( m_ref->isContainer() ) + return 0; + else + return 1; + } + + auto hint() const -> std::string { return m_hint; } + }; + + class ExeName : public ComposableParserImpl { + std::shared_ptr m_name; + std::shared_ptr m_ref; + + template + static auto makeRef(LambdaT const &lambda) -> std::shared_ptr { + return std::make_shared>( lambda) ; + } + + public: + ExeName() : m_name( std::make_shared( "" ) ) {} + + explicit ExeName( std::string &ref ) : ExeName() { + m_ref = std::make_shared>( ref ); + } + + template + explicit ExeName( LambdaT const& lambda ) : ExeName() { + m_ref = std::make_shared>( lambda ); + } + + // The exe name is not parsed out of the normal tokens, but is handled specially + auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override { + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) ); + } + + auto name() const -> std::string { return *m_name; } + auto set( std::string const& newName ) -> ParserResult { + + auto lastSlash = newName.find_last_of( "\\/" ); + auto filename = ( lastSlash == std::string::npos ) + ? newName + : newName.substr( lastSlash+1 ); + + *m_name = filename; + if( m_ref ) + return m_ref->setValue( filename ); + else + return ParserResult::ok( ParseResultType::Matched ); + } + }; + + class Arg : public ParserRefImpl { + public: + using ParserRefImpl::ParserRefImpl; + + auto parse( std::string const &, TokenStream const &tokens ) const -> InternalParseResult override { + auto validationResult = validate(); + if( !validationResult ) + return InternalParseResult( validationResult ); + + auto remainingTokens = tokens; + auto const &token = *remainingTokens; + if( token.type != TokenType::Argument ) + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) ); + + assert( !m_ref->isFlag() ); + auto valueRef = static_cast( m_ref.get() ); + + auto result = valueRef->setValue( remainingTokens->token ); + if( !result ) + return InternalParseResult( result ); + else + return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) ); + } + }; + + inline auto normaliseOpt( std::string const &optName ) -> std::string { +#ifdef CATCH_PLATFORM_WINDOWS + if( optName[0] == '/' ) + return "-" + optName.substr( 1 ); + else +#endif + return optName; + } + + class Opt : public ParserRefImpl { + protected: + std::vector m_optNames; + + public: + template + explicit Opt( LambdaT const &ref ) : ParserRefImpl( std::make_shared>( ref ) ) {} + + explicit Opt( bool &ref ) : ParserRefImpl( std::make_shared( ref ) ) {} + + template + Opt( LambdaT const &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {} + + template + Opt( T &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {} + + auto operator[]( std::string const &optName ) -> Opt & { + m_optNames.push_back( optName ); + return *this; + } + + auto getHelpColumns() const -> std::vector { + std::ostringstream oss; + bool first = true; + for( auto const &opt : m_optNames ) { + if (first) + first = false; + else + oss << ", "; + oss << opt; + } + if( !m_hint.empty() ) + oss << " <" << m_hint << ">"; + return { { oss.str(), m_description } }; + } + + auto isMatch( std::string const &optToken ) const -> bool { + auto normalisedToken = normaliseOpt( optToken ); + for( auto const &name : m_optNames ) { + if( normaliseOpt( name ) == normalisedToken ) + return true; + } + return false; + } + + using ParserBase::parse; + + auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override { + auto validationResult = validate(); + if( !validationResult ) + return InternalParseResult( validationResult ); + + auto remainingTokens = tokens; + if( remainingTokens && remainingTokens->type == TokenType::Option ) { + auto const &token = *remainingTokens; + if( isMatch(token.token ) ) { + if( m_ref->isFlag() ) { + auto flagRef = static_cast( m_ref.get() ); + auto result = flagRef->setFlag( true ); + if( !result ) + return InternalParseResult( result ); + if( result.value() == ParseResultType::ShortCircuitAll ) + return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) ); + } else { + auto valueRef = static_cast( m_ref.get() ); + ++remainingTokens; + if( !remainingTokens ) + return InternalParseResult::runtimeError( "Expected argument following " + token.token ); + auto const &argToken = *remainingTokens; + if( argToken.type != TokenType::Argument ) + return InternalParseResult::runtimeError( "Expected argument following " + token.token ); + auto result = valueRef->setValue( argToken.token ); + if( !result ) + return InternalParseResult( result ); + if( result.value() == ParseResultType::ShortCircuitAll ) + return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) ); + } + return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) ); + } + } + return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) ); + } + + auto validate() const -> Result override { + if( m_optNames.empty() ) + return Result::logicError( "No options supplied to Opt" ); + for( auto const &name : m_optNames ) { + if( name.empty() ) + return Result::logicError( "Option name cannot be empty" ); +#ifdef CATCH_PLATFORM_WINDOWS + if( name[0] != '-' && name[0] != '/' ) + return Result::logicError( "Option name must begin with '-' or '/'" ); +#else + if( name[0] != '-' ) + return Result::logicError( "Option name must begin with '-'" ); +#endif + } + return ParserRefImpl::validate(); + } + }; + + struct Help : Opt { + Help( bool &showHelpFlag ) + : Opt([&]( bool flag ) { + showHelpFlag = flag; + return ParserResult::ok( ParseResultType::ShortCircuitAll ); + }) + { + static_cast( *this ) + ("display usage information") + ["-?"]["-h"]["--help"] + .optional(); + } + }; + + struct Parser : ParserBase { + + mutable ExeName m_exeName; + std::vector m_options; + std::vector m_args; + + auto operator|=( ExeName const &exeName ) -> Parser & { + m_exeName = exeName; + return *this; + } + + auto operator|=( Arg const &arg ) -> Parser & { + m_args.push_back(arg); + return *this; + } + + auto operator|=( Opt const &opt ) -> Parser & { + m_options.push_back(opt); + return *this; + } + + auto operator|=( Parser const &other ) -> Parser & { + m_options.insert(m_options.end(), other.m_options.begin(), other.m_options.end()); + m_args.insert(m_args.end(), other.m_args.begin(), other.m_args.end()); + return *this; + } + + template + auto operator|( T const &other ) const -> Parser { + return Parser( *this ) |= other; + } + + // Forward deprecated interface with '+' instead of '|' + template + auto operator+=( T const &other ) -> Parser & { return operator|=( other ); } + template + auto operator+( T const &other ) const -> Parser { return operator|( other ); } + + auto getHelpColumns() const -> std::vector { + std::vector cols; + for (auto const &o : m_options) { + auto childCols = o.getHelpColumns(); + cols.insert( cols.end(), childCols.begin(), childCols.end() ); + } + return cols; + } + + void writeToStream( std::ostream &os ) const { + if (!m_exeName.name().empty()) { + os << "usage:\n" << " " << m_exeName.name() << " "; + bool required = true, first = true; + for( auto const &arg : m_args ) { + if (first) + first = false; + else + os << " "; + if( arg.isOptional() && required ) { + os << "["; + required = false; + } + os << "<" << arg.hint() << ">"; + if( arg.cardinality() == 0 ) + os << " ... "; + } + if( !required ) + os << "]"; + if( !m_options.empty() ) + os << " options"; + os << "\n\nwhere options are:" << std::endl; + } + + auto rows = getHelpColumns(); + size_t consoleWidth = CATCH_CLARA_CONFIG_CONSOLE_WIDTH; + size_t optWidth = 0; + for( auto const &cols : rows ) + optWidth = (std::max)(optWidth, cols.left.size() + 2); + + optWidth = (std::min)(optWidth, consoleWidth/2); + + for( auto const &cols : rows ) { + auto row = + TextFlow::Column( cols.left ).width( optWidth ).indent( 2 ) + + TextFlow::Spacer(4) + + TextFlow::Column( cols.right ).width( consoleWidth - 7 - optWidth ); + os << row << std::endl; + } + } + + friend auto operator<<( std::ostream &os, Parser const &parser ) -> std::ostream& { + parser.writeToStream( os ); + return os; + } + + auto validate() const -> Result override { + for( auto const &opt : m_options ) { + auto result = opt.validate(); + if( !result ) + return result; + } + for( auto const &arg : m_args ) { + auto result = arg.validate(); + if( !result ) + return result; + } + return Result::ok(); + } + + using ParserBase::parse; + + auto parse( std::string const& exeName, TokenStream const &tokens ) const -> InternalParseResult override { + + struct ParserInfo { + ParserBase const* parser = nullptr; + size_t count = 0; + }; + const size_t totalParsers = m_options.size() + m_args.size(); + assert( totalParsers < 512 ); + // ParserInfo parseInfos[totalParsers]; // <-- this is what we really want to do + ParserInfo parseInfos[512]; + + { + size_t i = 0; + for (auto const &opt : m_options) parseInfos[i++].parser = &opt; + for (auto const &arg : m_args) parseInfos[i++].parser = &arg; + } + + m_exeName.set( exeName ); + + auto result = InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) ); + while( result.value().remainingTokens() ) { + bool tokenParsed = false; + + for( size_t i = 0; i < totalParsers; ++i ) { + auto& parseInfo = parseInfos[i]; + if( parseInfo.parser->cardinality() == 0 || parseInfo.count < parseInfo.parser->cardinality() ) { + result = parseInfo.parser->parse(exeName, result.value().remainingTokens()); + if (!result) + return result; + if (result.value().type() != ParseResultType::NoMatch) { + tokenParsed = true; + ++parseInfo.count; + break; + } + } + } + + if( result.value().type() == ParseResultType::ShortCircuitAll ) + return result; + if( !tokenParsed ) + return InternalParseResult::runtimeError( "Unrecognised token: " + result.value().remainingTokens()->token ); + } + // !TBD Check missing required options + return result; + } + }; + + template + template + auto ComposableParserImpl::operator|( T const &other ) const -> Parser { + return Parser() | static_cast( *this ) | other; + } +} // namespace detail + +// A Combined parser +using detail::Parser; + +// A parser for options +using detail::Opt; + +// A parser for arguments +using detail::Arg; + +// Wrapper for argc, argv from main() +using detail::Args; + +// Specifies the name of the executable +using detail::ExeName; + +// Convenience wrapper for option parser that specifies the help option +using detail::Help; + +// enum of result types from a parse +using detail::ParseResultType; + +// Result type for parser operation +using detail::ParserResult; + +}} // namespace Catch::clara + +// end clara.hpp +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// Restore Clara's value for console width, if present +#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH +#endif + +// end catch_clara.h +namespace Catch { + + clara::Parser makeCommandLineParser( ConfigData& config ); + +} // end namespace Catch + +// end catch_commandline.h +#include +#include + +namespace Catch { + + clara::Parser makeCommandLineParser( ConfigData& config ) { + + using namespace clara; + + auto const setWarning = [&]( std::string const& warning ) { + auto warningSet = [&]() { + if( warning == "NoAssertions" ) + return WarnAbout::NoAssertions; + + if ( warning == "NoTests" ) + return WarnAbout::NoTests; + + return WarnAbout::Nothing; + }(); + + if (warningSet == WarnAbout::Nothing) + return ParserResult::runtimeError( "Unrecognised warning: '" + warning + "'" ); + config.warnings = static_cast( config.warnings | warningSet ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const loadTestNamesFromFile = [&]( std::string const& filename ) { + std::ifstream f( filename.c_str() ); + if( !f.is_open() ) + return ParserResult::runtimeError( "Unable to load input file: '" + filename + "'" ); + + std::string line; + while( std::getline( f, line ) ) { + line = trim(line); + if( !line.empty() && !startsWith( line, '#' ) ) { + if( !startsWith( line, '"' ) ) + line = '"' + line + '"'; + config.testsOrTags.push_back( line + ',' ); + } + } + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setTestOrder = [&]( std::string const& order ) { + if( startsWith( "declared", order ) ) + config.runOrder = RunTests::InDeclarationOrder; + else if( startsWith( "lexical", order ) ) + config.runOrder = RunTests::InLexicographicalOrder; + else if( startsWith( "random", order ) ) + config.runOrder = RunTests::InRandomOrder; + else + return clara::ParserResult::runtimeError( "Unrecognised ordering: '" + order + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setRngSeed = [&]( std::string const& seed ) { + if( seed != "time" ) + return clara::detail::convertInto( seed, config.rngSeed ); + config.rngSeed = static_cast( std::time(nullptr) ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setColourUsage = [&]( std::string const& useColour ) { + auto mode = toLower( useColour ); + + if( mode == "yes" ) + config.useColour = UseColour::Yes; + else if( mode == "no" ) + config.useColour = UseColour::No; + else if( mode == "auto" ) + config.useColour = UseColour::Auto; + else + return ParserResult::runtimeError( "colour mode must be one of: auto, yes or no. '" + useColour + "' not recognised" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setWaitForKeypress = [&]( std::string const& keypress ) { + auto keypressLc = toLower( keypress ); + if( keypressLc == "start" ) + config.waitForKeypress = WaitForKeypress::BeforeStart; + else if( keypressLc == "exit" ) + config.waitForKeypress = WaitForKeypress::BeforeExit; + else if( keypressLc == "both" ) + config.waitForKeypress = WaitForKeypress::BeforeStartAndExit; + else + return ParserResult::runtimeError( "keypress argument must be one of: start, exit or both. '" + keypress + "' not recognised" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setVerbosity = [&]( std::string const& verbosity ) { + auto lcVerbosity = toLower( verbosity ); + if( lcVerbosity == "quiet" ) + config.verbosity = Verbosity::Quiet; + else if( lcVerbosity == "normal" ) + config.verbosity = Verbosity::Normal; + else if( lcVerbosity == "high" ) + config.verbosity = Verbosity::High; + else + return ParserResult::runtimeError( "Unrecognised verbosity, '" + verbosity + "'" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + auto const setReporter = [&]( std::string const& reporter ) { + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + + auto lcReporter = toLower( reporter ); + auto result = factories.find( lcReporter ); + + if( factories.end() != result ) + config.reporterName = lcReporter; + else + return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" ); + return ParserResult::ok( ParseResultType::Matched ); + }; + + auto cli + = ExeName( config.processName ) + | Help( config.showHelp ) + | Opt( config.listTests ) + ["-l"]["--list-tests"] + ( "list all/matching test cases" ) + | Opt( config.listTags ) + ["-t"]["--list-tags"] + ( "list all/matching tags" ) + | Opt( config.showSuccessfulTests ) + ["-s"]["--success"] + ( "include successful tests in output" ) + | Opt( config.shouldDebugBreak ) + ["-b"]["--break"] + ( "break into debugger on failure" ) + | Opt( config.noThrow ) + ["-e"]["--nothrow"] + ( "skip exception tests" ) + | Opt( config.showInvisibles ) + ["-i"]["--invisibles"] + ( "show invisibles (tabs, newlines)" ) + | Opt( config.outputFilename, "filename" ) + ["-o"]["--out"] + ( "output filename" ) + | Opt( setReporter, "name" ) + ["-r"]["--reporter"] + ( "reporter to use (defaults to console)" ) + | Opt( config.name, "name" ) + ["-n"]["--name"] + ( "suite name" ) + | Opt( [&]( bool ){ config.abortAfter = 1; } ) + ["-a"]["--abort"] + ( "abort at first failure" ) + | Opt( [&]( int x ){ config.abortAfter = x; }, "no. failures" ) + ["-x"]["--abortx"] + ( "abort after x failures" ) + | Opt( setWarning, "warning name" ) + ["-w"]["--warn"] + ( "enable warnings" ) + | Opt( [&]( bool flag ) { config.showDurations = flag ? ShowDurations::Always : ShowDurations::Never; }, "yes|no" ) + ["-d"]["--durations"] + ( "show test durations" ) + | Opt( loadTestNamesFromFile, "filename" ) + ["-f"]["--input-file"] + ( "load test names to run from a file" ) + | Opt( config.filenamesAsTags ) + ["-#"]["--filenames-as-tags"] + ( "adds a tag for the filename" ) + | Opt( config.sectionsToRun, "section name" ) + ["-c"]["--section"] + ( "specify section to run" ) + | Opt( setVerbosity, "quiet|normal|high" ) + ["-v"]["--verbosity"] + ( "set output verbosity" ) + | Opt( config.listTestNamesOnly ) + ["--list-test-names-only"] + ( "list all/matching test cases names only" ) + | Opt( config.listReporters ) + ["--list-reporters"] + ( "list all reporters" ) + | Opt( setTestOrder, "decl|lex|rand" ) + ["--order"] + ( "test case order (defaults to decl)" ) + | Opt( setRngSeed, "'time'|number" ) + ["--rng-seed"] + ( "set a specific seed for random numbers" ) + | Opt( setColourUsage, "yes|no" ) + ["--use-colour"] + ( "should output be colourised" ) + | Opt( config.libIdentify ) + ["--libidentify"] + ( "report name and version according to libidentify standard" ) + | Opt( setWaitForKeypress, "start|exit|both" ) + ["--wait-for-keypress"] + ( "waits for a keypress before exiting" ) + | Opt( config.benchmarkResolutionMultiple, "multiplier" ) + ["--benchmark-resolution-multiple"] + ( "multiple of clock resolution to run benchmarks" ) + + | Arg( config.testsOrTags, "test name|pattern|tags" ) + ( "which test or tests to use" ); + + return cli; + } + +} // end namespace Catch +// end catch_commandline.cpp +// start catch_common.cpp + +#include +#include + +namespace Catch { + + bool SourceLineInfo::empty() const noexcept { + return file[0] == '\0'; + } + bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const noexcept { + return line == other.line && (file == other.file || std::strcmp(file, other.file) == 0); + } + bool SourceLineInfo::operator < ( SourceLineInfo const& other ) const noexcept { + // We can assume that the same file will usually have the same pointer. + // Thus, if the pointers are the same, there is no point in calling the strcmp + return line < other.line || ( line == other.line && file != other.file && (std::strcmp(file, other.file) < 0)); + } + + std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) { +#ifndef __GNUG__ + os << info.file << '(' << info.line << ')'; +#else + os << info.file << ':' << info.line; +#endif + return os; + } + + std::string StreamEndStop::operator+() const { + return std::string(); + } + + NonCopyable::NonCopyable() = default; + NonCopyable::~NonCopyable() = default; + +} +// end catch_common.cpp +// start catch_config.cpp + +namespace Catch { + + Config::Config( ConfigData const& data ) + : m_data( data ), + m_stream( openStream() ) + { + TestSpecParser parser(ITagAliasRegistry::get()); + if (data.testsOrTags.empty()) { + parser.parse("~[.]"); // All not hidden tests + } + else { + m_hasTestFilters = true; + for( auto const& testOrTags : data.testsOrTags ) + parser.parse( testOrTags ); + } + m_testSpec = parser.testSpec(); + } + + std::string const& Config::getFilename() const { + return m_data.outputFilename ; + } + + bool Config::listTests() const { return m_data.listTests; } + bool Config::listTestNamesOnly() const { return m_data.listTestNamesOnly; } + bool Config::listTags() const { return m_data.listTags; } + bool Config::listReporters() const { return m_data.listReporters; } + + std::string Config::getProcessName() const { return m_data.processName; } + std::string const& Config::getReporterName() const { return m_data.reporterName; } + + std::vector const& Config::getTestsOrTags() const { return m_data.testsOrTags; } + std::vector const& Config::getSectionsToRun() const { return m_data.sectionsToRun; } + + TestSpec const& Config::testSpec() const { return m_testSpec; } + bool Config::hasTestFilters() const { return m_hasTestFilters; } + + bool Config::showHelp() const { return m_data.showHelp; } + + // IConfig interface + bool Config::allowThrows() const { return !m_data.noThrow; } + std::ostream& Config::stream() const { return m_stream->stream(); } + std::string Config::name() const { return m_data.name.empty() ? m_data.processName : m_data.name; } + bool Config::includeSuccessfulResults() const { return m_data.showSuccessfulTests; } + bool Config::warnAboutMissingAssertions() const { return !!(m_data.warnings & WarnAbout::NoAssertions); } + bool Config::warnAboutNoTests() const { return !!(m_data.warnings & WarnAbout::NoTests); } + ShowDurations::OrNot Config::showDurations() const { return m_data.showDurations; } + RunTests::InWhatOrder Config::runOrder() const { return m_data.runOrder; } + unsigned int Config::rngSeed() const { return m_data.rngSeed; } + int Config::benchmarkResolutionMultiple() const { return m_data.benchmarkResolutionMultiple; } + UseColour::YesOrNo Config::useColour() const { return m_data.useColour; } + bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; } + int Config::abortAfter() const { return m_data.abortAfter; } + bool Config::showInvisibles() const { return m_data.showInvisibles; } + Verbosity Config::verbosity() const { return m_data.verbosity; } + + IStream const* Config::openStream() { + return Catch::makeStream(m_data.outputFilename); + } + +} // end namespace Catch +// end catch_config.cpp +// start catch_console_colour.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +// start catch_errno_guard.h + +namespace Catch { + + class ErrnoGuard { + public: + ErrnoGuard(); + ~ErrnoGuard(); + private: + int m_oldErrno; + }; + +} + +// end catch_errno_guard.h +#include + +namespace Catch { + namespace { + + struct IColourImpl { + virtual ~IColourImpl() = default; + virtual void use( Colour::Code _colourCode ) = 0; + }; + + struct NoColourImpl : IColourImpl { + void use( Colour::Code ) {} + + static IColourImpl* instance() { + static NoColourImpl s_instance; + return &s_instance; + } + }; + + } // anon namespace +} // namespace Catch + +#if !defined( CATCH_CONFIG_COLOUR_NONE ) && !defined( CATCH_CONFIG_COLOUR_WINDOWS ) && !defined( CATCH_CONFIG_COLOUR_ANSI ) +# ifdef CATCH_PLATFORM_WINDOWS +# define CATCH_CONFIG_COLOUR_WINDOWS +# else +# define CATCH_CONFIG_COLOUR_ANSI +# endif +#endif + +#if defined ( CATCH_CONFIG_COLOUR_WINDOWS ) ///////////////////////////////////////// + +namespace Catch { +namespace { + + class Win32ColourImpl : public IColourImpl { + public: + Win32ColourImpl() : stdoutHandle( GetStdHandle(STD_OUTPUT_HANDLE) ) + { + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo( stdoutHandle, &csbiInfo ); + originalForegroundAttributes = csbiInfo.wAttributes & ~( BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_BLUE | BACKGROUND_INTENSITY ); + originalBackgroundAttributes = csbiInfo.wAttributes & ~( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE | FOREGROUND_INTENSITY ); + } + + virtual void use( Colour::Code _colourCode ) override { + switch( _colourCode ) { + case Colour::None: return setTextAttribute( originalForegroundAttributes ); + case Colour::White: return setTextAttribute( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::Red: return setTextAttribute( FOREGROUND_RED ); + case Colour::Green: return setTextAttribute( FOREGROUND_GREEN ); + case Colour::Blue: return setTextAttribute( FOREGROUND_BLUE ); + case Colour::Cyan: return setTextAttribute( FOREGROUND_BLUE | FOREGROUND_GREEN ); + case Colour::Yellow: return setTextAttribute( FOREGROUND_RED | FOREGROUND_GREEN ); + case Colour::Grey: return setTextAttribute( 0 ); + + case Colour::LightGrey: return setTextAttribute( FOREGROUND_INTENSITY ); + case Colour::BrightRed: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED ); + case Colour::BrightGreen: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN ); + case Colour::BrightWhite: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE ); + case Colour::BrightYellow: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED | FOREGROUND_GREEN ); + + case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" ); + + default: + CATCH_ERROR( "Unknown colour requested" ); + } + } + + private: + void setTextAttribute( WORD _textAttribute ) { + SetConsoleTextAttribute( stdoutHandle, _textAttribute | originalBackgroundAttributes ); + } + HANDLE stdoutHandle; + WORD originalForegroundAttributes; + WORD originalBackgroundAttributes; + }; + + IColourImpl* platformColourInstance() { + static Win32ColourImpl s_instance; + + IConfigPtr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = UseColour::Yes; + return colourMode == UseColour::Yes + ? &s_instance + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#elif defined( CATCH_CONFIG_COLOUR_ANSI ) ////////////////////////////////////// + +#include + +namespace Catch { +namespace { + + // use POSIX/ ANSI console terminal codes + // Thanks to Adam Strzelecki for original contribution + // (http://github.com/nanoant) + // https://github.com/philsquared/Catch/pull/131 + class PosixColourImpl : public IColourImpl { + public: + virtual void use( Colour::Code _colourCode ) override { + switch( _colourCode ) { + case Colour::None: + case Colour::White: return setColour( "[0m" ); + case Colour::Red: return setColour( "[0;31m" ); + case Colour::Green: return setColour( "[0;32m" ); + case Colour::Blue: return setColour( "[0;34m" ); + case Colour::Cyan: return setColour( "[0;36m" ); + case Colour::Yellow: return setColour( "[0;33m" ); + case Colour::Grey: return setColour( "[1;30m" ); + + case Colour::LightGrey: return setColour( "[0;37m" ); + case Colour::BrightRed: return setColour( "[1;31m" ); + case Colour::BrightGreen: return setColour( "[1;32m" ); + case Colour::BrightWhite: return setColour( "[1;37m" ); + case Colour::BrightYellow: return setColour( "[1;33m" ); + + case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" ); + default: CATCH_INTERNAL_ERROR( "Unknown colour requested" ); + } + } + static IColourImpl* instance() { + static PosixColourImpl s_instance; + return &s_instance; + } + + private: + void setColour( const char* _escapeCode ) { + getCurrentContext().getConfig()->stream() + << '\033' << _escapeCode; + } + }; + + bool useColourOnPlatform() { + return +#ifdef CATCH_PLATFORM_MAC + !isDebuggerActive() && +#endif +#if !(defined(__DJGPP__) && defined(__STRICT_ANSI__)) + isatty(STDOUT_FILENO) +#else + false +#endif + ; + } + IColourImpl* platformColourInstance() { + ErrnoGuard guard; + IConfigPtr config = getCurrentContext().getConfig(); + UseColour::YesOrNo colourMode = config + ? config->useColour() + : UseColour::Auto; + if( colourMode == UseColour::Auto ) + colourMode = useColourOnPlatform() + ? UseColour::Yes + : UseColour::No; + return colourMode == UseColour::Yes + ? PosixColourImpl::instance() + : NoColourImpl::instance(); + } + +} // end anon namespace +} // end namespace Catch + +#else // not Windows or ANSI /////////////////////////////////////////////// + +namespace Catch { + + static IColourImpl* platformColourInstance() { return NoColourImpl::instance(); } + +} // end namespace Catch + +#endif // Windows/ ANSI/ None + +namespace Catch { + + Colour::Colour( Code _colourCode ) { use( _colourCode ); } + Colour::Colour( Colour&& rhs ) noexcept { + m_moved = rhs.m_moved; + rhs.m_moved = true; + } + Colour& Colour::operator=( Colour&& rhs ) noexcept { + m_moved = rhs.m_moved; + rhs.m_moved = true; + return *this; + } + + Colour::~Colour(){ if( !m_moved ) use( None ); } + + void Colour::use( Code _colourCode ) { + static IColourImpl* impl = platformColourInstance(); + impl->use( _colourCode ); + } + + std::ostream& operator << ( std::ostream& os, Colour const& ) { + return os; + } + +} // end namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + +// end catch_console_colour.cpp +// start catch_context.cpp + +namespace Catch { + + class Context : public IMutableContext, NonCopyable { + + public: // IContext + virtual IResultCapture* getResultCapture() override { + return m_resultCapture; + } + virtual IRunner* getRunner() override { + return m_runner; + } + + virtual IConfigPtr const& getConfig() const override { + return m_config; + } + + virtual ~Context() override; + + public: // IMutableContext + virtual void setResultCapture( IResultCapture* resultCapture ) override { + m_resultCapture = resultCapture; + } + virtual void setRunner( IRunner* runner ) override { + m_runner = runner; + } + virtual void setConfig( IConfigPtr const& config ) override { + m_config = config; + } + + friend IMutableContext& getCurrentMutableContext(); + + private: + IConfigPtr m_config; + IRunner* m_runner = nullptr; + IResultCapture* m_resultCapture = nullptr; + }; + + IMutableContext *IMutableContext::currentContext = nullptr; + + void IMutableContext::createContext() + { + currentContext = new Context(); + } + + void cleanUpContext() { + delete IMutableContext::currentContext; + IMutableContext::currentContext = nullptr; + } + IContext::~IContext() = default; + IMutableContext::~IMutableContext() = default; + Context::~Context() = default; +} +// end catch_context.cpp +// start catch_debug_console.cpp + +// start catch_debug_console.h + +#include + +namespace Catch { + void writeToDebugConsole( std::string const& text ); +} + +// end catch_debug_console.h +#ifdef CATCH_PLATFORM_WINDOWS + + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + ::OutputDebugStringA( text.c_str() ); + } + } + +#else + + namespace Catch { + void writeToDebugConsole( std::string const& text ) { + // !TBD: Need a version for Mac/ XCode and other IDEs + Catch::cout() << text; + } + } + +#endif // Platform +// end catch_debug_console.cpp +// start catch_debugger.cpp + +#ifdef CATCH_PLATFORM_MAC + +# include +# include +# include +# include +# include +# include +# include + +namespace Catch { + + // The following function is taken directly from the following technical note: + // http://developer.apple.com/library/mac/#qa/qa2004/qa1361.html + + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive(){ + + int mib[4]; + struct kinfo_proc info; + std::size_t size; + + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + + info.kp_proc.p_flag = 0; + + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + + // Call sysctl. + + size = sizeof(info); + if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0) != 0 ) { + Catch::cerr() << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl; + return false; + } + + // We're being debugged if the P_TRACED flag is set. + + return ( (info.kp_proc.p_flag & P_TRACED) != 0 ); + } + } // namespace Catch + +#elif defined(CATCH_PLATFORM_LINUX) + #include + #include + + namespace Catch{ + // The standard POSIX way of detecting a debugger is to attempt to + // ptrace() the process, but this needs to be done from a child and not + // this process itself to still allow attaching to this process later + // if wanted, so is rather heavy. Under Linux we have the PID of the + // "debugger" (which doesn't need to be gdb, of course, it could also + // be strace, for example) in /proc/$PID/status, so just get it from + // there instead. + bool isDebuggerActive(){ + // Libstdc++ has a bug, where std::ifstream sets errno to 0 + // This way our users can properly assert over errno values + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for( std::string line; std::getline(in, line); ) { + static const int PREFIX_LEN = 11; + if( line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0 ) { + // We're traced if the PID is not 0 and no other PID starts + // with 0 digit, so it's enough to check for just a single + // character. + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + + return false; + } + } // namespace Catch +#elif defined(_MSC_VER) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#elif defined(__MINGW32__) + extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent(); + namespace Catch { + bool isDebuggerActive() { + return IsDebuggerPresent() != 0; + } + } +#else + namespace Catch { + bool isDebuggerActive() { return false; } + } +#endif // Platform +// end catch_debugger.cpp +// start catch_decomposer.cpp + +namespace Catch { + + ITransientExpression::~ITransientExpression() = default; + + void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs ) { + if( lhs.size() + rhs.size() < 40 && + lhs.find('\n') == std::string::npos && + rhs.find('\n') == std::string::npos ) + os << lhs << " " << op << " " << rhs; + else + os << lhs << "\n" << op << "\n" << rhs; + } +} +// end catch_decomposer.cpp +// start catch_enforce.cpp + +namespace Catch { +#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS_CUSTOM_HANDLER) + [[noreturn]] + void throw_exception(std::exception const& e) { + Catch::cerr() << "Catch will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; + std::terminate(); + } +#endif +} // namespace Catch; +// end catch_enforce.cpp +// start catch_errno_guard.cpp + +#include + +namespace Catch { + ErrnoGuard::ErrnoGuard():m_oldErrno(errno){} + ErrnoGuard::~ErrnoGuard() { errno = m_oldErrno; } +} +// end catch_errno_guard.cpp +// start catch_exception_translator_registry.cpp + +// start catch_exception_translator_registry.h + +#include +#include +#include + +namespace Catch { + + class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry { + public: + ~ExceptionTranslatorRegistry(); + virtual void registerTranslator( const IExceptionTranslator* translator ); + virtual std::string translateActiveException() const override; + std::string tryTranslators() const; + + private: + std::vector> m_translators; + }; +} + +// end catch_exception_translator_registry.h +#ifdef __OBJC__ +#import "Foundation/Foundation.h" +#endif + +namespace Catch { + + ExceptionTranslatorRegistry::~ExceptionTranslatorRegistry() { + } + + void ExceptionTranslatorRegistry::registerTranslator( const IExceptionTranslator* translator ) { + m_translators.push_back( std::unique_ptr( translator ) ); + } + +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + std::string ExceptionTranslatorRegistry::translateActiveException() const { + try { +#ifdef __OBJC__ + // In Objective-C try objective-c exceptions first + @try { + return tryTranslators(); + } + @catch (NSException *exception) { + return Catch::Detail::stringify( [exception description] ); + } +#else + // Compiling a mixed mode project with MSVC means that CLR + // exceptions will be caught in (...) as well. However, these + // do not fill-in std::current_exception and thus lead to crash + // when attempting rethrow. + // /EHa switch also causes structured exceptions to be caught + // here, but they fill-in current_exception properly, so + // at worst the output should be a little weird, instead of + // causing a crash. + if (std::current_exception() == nullptr) { + return "Non C++ exception. Possibly a CLR exception."; + } + return tryTranslators(); +#endif + } + catch( TestFailureException& ) { + std::rethrow_exception(std::current_exception()); + } + catch( std::exception& ex ) { + return ex.what(); + } + catch( std::string& msg ) { + return msg; + } + catch( const char* msg ) { + return msg; + } + catch(...) { + return "Unknown exception"; + } + } + + std::string ExceptionTranslatorRegistry::tryTranslators() const { + if (m_translators.empty()) { + std::rethrow_exception(std::current_exception()); + } else { + return m_translators[0]->translate(m_translators.begin() + 1, m_translators.end()); + } + } + +#else // ^^ Exceptions are enabled // Exceptions are disabled vv + std::string ExceptionTranslatorRegistry::translateActiveException() const { + CATCH_INTERNAL_ERROR("Attempted to translate active exception under CATCH_CONFIG_DISABLE_EXCEPTIONS!"); + } + + std::string ExceptionTranslatorRegistry::tryTranslators() const { + CATCH_INTERNAL_ERROR("Attempted to use exception translators under CATCH_CONFIG_DISABLE_EXCEPTIONS!"); + } +#endif + +} +// end catch_exception_translator_registry.cpp +// start catch_fatal_condition.cpp + +#if defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#endif + +#if defined( CATCH_CONFIG_WINDOWS_SEH ) || defined( CATCH_CONFIG_POSIX_SIGNALS ) + +namespace { + // Report the error condition + void reportFatal( char const * const message ) { + Catch::getCurrentContext().getResultCapture()->handleFatalErrorCondition( message ); + } +} + +#endif // signals/SEH handling + +#if defined( CATCH_CONFIG_WINDOWS_SEH ) + +namespace Catch { + struct SignalDefs { DWORD id; const char* name; }; + + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + static SignalDefs signalDefs[] = { + { EXCEPTION_ILLEGAL_INSTRUCTION, "SIGILL - Illegal instruction signal" }, + { EXCEPTION_STACK_OVERFLOW, "SIGSEGV - Stack overflow" }, + { EXCEPTION_ACCESS_VIOLATION, "SIGSEGV - Segmentation violation signal" }, + { EXCEPTION_INT_DIVIDE_BY_ZERO, "Divide by zero error" }, + }; + + LONG CALLBACK FatalConditionHandler::handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) { + for (auto const& def : signalDefs) { + if (ExceptionInfo->ExceptionRecord->ExceptionCode == def.id) { + reportFatal(def.name); + } + } + // If its not an exception we care about, pass it along. + // This stops us from eating debugger breaks etc. + return EXCEPTION_CONTINUE_SEARCH; + } + + FatalConditionHandler::FatalConditionHandler() { + isSet = true; + // 32k seems enough for Catch to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + exceptionHandlerHandle = nullptr; + // Register as first handler in current chain + exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + } + + void FatalConditionHandler::reset() { + if (isSet) { + RemoveVectoredExceptionHandler(exceptionHandlerHandle); + SetThreadStackGuarantee(&guaranteeSize); + exceptionHandlerHandle = nullptr; + isSet = false; + } + } + + FatalConditionHandler::~FatalConditionHandler() { + reset(); + } + +bool FatalConditionHandler::isSet = false; +ULONG FatalConditionHandler::guaranteeSize = 0; +PVOID FatalConditionHandler::exceptionHandlerHandle = nullptr; + +} // namespace Catch + +#elif defined( CATCH_CONFIG_POSIX_SIGNALS ) + +namespace Catch { + + struct SignalDefs { + int id; + const char* name; + }; + + // 32kb for the alternate stack seems to be sufficient. However, this value + // is experimentally determined, so that's not guaranteed. + constexpr static std::size_t sigStackSize = 32768 >= MINSIGSTKSZ ? 32768 : MINSIGSTKSZ; + + static SignalDefs signalDefs[] = { + { SIGINT, "SIGINT - Terminal interrupt signal" }, + { SIGILL, "SIGILL - Illegal instruction signal" }, + { SIGFPE, "SIGFPE - Floating point error signal" }, + { SIGSEGV, "SIGSEGV - Segmentation violation signal" }, + { SIGTERM, "SIGTERM - Termination request signal" }, + { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" } + }; + + void FatalConditionHandler::handleSignal( int sig ) { + char const * name = ""; + for (auto const& def : signalDefs) { + if (sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise( sig ); + } + + FatalConditionHandler::FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = sigStackSize; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = { }; + + sa.sa_handler = handleSignal; + sa.sa_flags = SA_ONSTACK; + for (std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + FatalConditionHandler::~FatalConditionHandler() { + reset(); + } + + void FatalConditionHandler::reset() { + if( isSet ) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + char FatalConditionHandler::altStackMem[sigStackSize] = {}; + +} // namespace Catch + +#else + +namespace Catch { + void FatalConditionHandler::reset() {} +} + +#endif // signals/SEH handling + +#if defined(__GNUC__) +# pragma GCC diagnostic pop +#endif +// end catch_fatal_condition.cpp +// start catch_generators.cpp + +// start catch_random_number_generator.h + +#include +#include + +namespace Catch { + + struct IConfig; + + std::mt19937& rng(); + void seedRng( IConfig const& config ); + unsigned int rngSeed(); + +} + +// end catch_random_number_generator.h +#include +#include + +namespace Catch { + +IGeneratorTracker::~IGeneratorTracker() {} + +const char* GeneratorException::what() const noexcept { + return m_msg; +} + +namespace Generators { + + GeneratorUntypedBase::~GeneratorUntypedBase() {} + + auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& { + return getResultCapture().acquireGeneratorTracker( lineInfo ); + } + +} // namespace Generators +} // namespace Catch +// end catch_generators.cpp +// start catch_interfaces_capture.cpp + +namespace Catch { + IResultCapture::~IResultCapture() = default; +} +// end catch_interfaces_capture.cpp +// start catch_interfaces_config.cpp + +namespace Catch { + IConfig::~IConfig() = default; +} +// end catch_interfaces_config.cpp +// start catch_interfaces_exception.cpp + +namespace Catch { + IExceptionTranslator::~IExceptionTranslator() = default; + IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() = default; +} +// end catch_interfaces_exception.cpp +// start catch_interfaces_registry_hub.cpp + +namespace Catch { + IRegistryHub::~IRegistryHub() = default; + IMutableRegistryHub::~IMutableRegistryHub() = default; +} +// end catch_interfaces_registry_hub.cpp +// start catch_interfaces_reporter.cpp + +// start catch_reporter_listening.h + +namespace Catch { + + class ListeningReporter : public IStreamingReporter { + using Reporters = std::vector; + Reporters m_listeners; + IStreamingReporterPtr m_reporter = nullptr; + ReporterPreferences m_preferences; + + public: + ListeningReporter(); + + void addListener( IStreamingReporterPtr&& listener ); + void addReporter( IStreamingReporterPtr&& reporter ); + + public: // IStreamingReporter + + ReporterPreferences getPreferences() const override; + + void noMatchingTestCases( std::string const& spec ) override; + + static std::set getSupportedVerbosities(); + + void benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) override; + void benchmarkEnded( BenchmarkStats const& benchmarkStats ) override; + + void testRunStarting( TestRunInfo const& testRunInfo ) override; + void testGroupStarting( GroupInfo const& groupInfo ) override; + void testCaseStarting( TestCaseInfo const& testInfo ) override; + void sectionStarting( SectionInfo const& sectionInfo ) override; + void assertionStarting( AssertionInfo const& assertionInfo ) override; + + // The return value indicates if the messages buffer should be cleared: + bool assertionEnded( AssertionStats const& assertionStats ) override; + void sectionEnded( SectionStats const& sectionStats ) override; + void testCaseEnded( TestCaseStats const& testCaseStats ) override; + void testGroupEnded( TestGroupStats const& testGroupStats ) override; + void testRunEnded( TestRunStats const& testRunStats ) override; + + void skipTest( TestCaseInfo const& testInfo ) override; + bool isMulti() const override; + + }; + +} // end namespace Catch + +// end catch_reporter_listening.h +namespace Catch { + + ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig ) + : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {} + + ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream ) + : m_stream( &_stream ), m_fullConfig( _fullConfig ) {} + + std::ostream& ReporterConfig::stream() const { return *m_stream; } + IConfigPtr ReporterConfig::fullConfig() const { return m_fullConfig; } + + TestRunInfo::TestRunInfo( std::string const& _name ) : name( _name ) {} + + GroupInfo::GroupInfo( std::string const& _name, + std::size_t _groupIndex, + std::size_t _groupsCount ) + : name( _name ), + groupIndex( _groupIndex ), + groupsCounts( _groupsCount ) + {} + + AssertionStats::AssertionStats( AssertionResult const& _assertionResult, + std::vector const& _infoMessages, + Totals const& _totals ) + : assertionResult( _assertionResult ), + infoMessages( _infoMessages ), + totals( _totals ) + { + assertionResult.m_resultData.lazyExpression.m_transientExpression = _assertionResult.m_resultData.lazyExpression.m_transientExpression; + + if( assertionResult.hasMessage() ) { + // Copy message into messages list. + // !TBD This should have been done earlier, somewhere + MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() ); + builder << assertionResult.getMessage(); + builder.m_info.message = builder.m_stream.str(); + + infoMessages.push_back( builder.m_info ); + } + } + + AssertionStats::~AssertionStats() = default; + + SectionStats::SectionStats( SectionInfo const& _sectionInfo, + Counts const& _assertions, + double _durationInSeconds, + bool _missingAssertions ) + : sectionInfo( _sectionInfo ), + assertions( _assertions ), + durationInSeconds( _durationInSeconds ), + missingAssertions( _missingAssertions ) + {} + + SectionStats::~SectionStats() = default; + + TestCaseStats::TestCaseStats( TestCaseInfo const& _testInfo, + Totals const& _totals, + std::string const& _stdOut, + std::string const& _stdErr, + bool _aborting ) + : testInfo( _testInfo ), + totals( _totals ), + stdOut( _stdOut ), + stdErr( _stdErr ), + aborting( _aborting ) + {} + + TestCaseStats::~TestCaseStats() = default; + + TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo, + Totals const& _totals, + bool _aborting ) + : groupInfo( _groupInfo ), + totals( _totals ), + aborting( _aborting ) + {} + + TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo ) + : groupInfo( _groupInfo ), + aborting( false ) + {} + + TestGroupStats::~TestGroupStats() = default; + + TestRunStats::TestRunStats( TestRunInfo const& _runInfo, + Totals const& _totals, + bool _aborting ) + : runInfo( _runInfo ), + totals( _totals ), + aborting( _aborting ) + {} + + TestRunStats::~TestRunStats() = default; + + void IStreamingReporter::fatalErrorEncountered( StringRef ) {} + bool IStreamingReporter::isMulti() const { return false; } + + IReporterFactory::~IReporterFactory() = default; + IReporterRegistry::~IReporterRegistry() = default; + +} // end namespace Catch +// end catch_interfaces_reporter.cpp +// start catch_interfaces_runner.cpp + +namespace Catch { + IRunner::~IRunner() = default; +} +// end catch_interfaces_runner.cpp +// start catch_interfaces_testcase.cpp + +namespace Catch { + ITestInvoker::~ITestInvoker() = default; + ITestCaseRegistry::~ITestCaseRegistry() = default; +} +// end catch_interfaces_testcase.cpp +// start catch_leak_detector.cpp + +#ifdef CATCH_CONFIG_WINDOWS_CRTDBG +#include + +namespace Catch { + + LeakDetector::LeakDetector() { + int flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); + flag |= _CRTDBG_LEAK_CHECK_DF; + flag |= _CRTDBG_ALLOC_MEM_DF; + _CrtSetDbgFlag(flag); + _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR); + // Change this to leaking allocation's number to break there + _CrtSetBreakAlloc(-1); + } +} + +#else + + Catch::LeakDetector::LeakDetector() {} + +#endif + +Catch::LeakDetector::~LeakDetector() { + Catch::cleanUp(); +} +// end catch_leak_detector.cpp +// start catch_list.cpp + +// start catch_list.h + +#include + +namespace Catch { + + std::size_t listTests( Config const& config ); + + std::size_t listTestsNamesOnly( Config const& config ); + + struct TagInfo { + void add( std::string const& spelling ); + std::string all() const; + + std::set spellings; + std::size_t count = 0; + }; + + std::size_t listTags( Config const& config ); + + std::size_t listReporters(); + + Option list( std::shared_ptr const& config ); + +} // end namespace Catch + +// end catch_list.h +// start catch_text.h + +namespace Catch { + using namespace clara::TextFlow; +} + +// end catch_text.h +#include +#include +#include + +namespace Catch { + + std::size_t listTests( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.hasTestFilters() ) + Catch::cout() << "Matching test cases:\n"; + else { + Catch::cout() << "All available test cases:\n"; + } + + auto matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCaseInfo : matchedTestCases ) { + Colour::Code colour = testCaseInfo.isHidden() + ? Colour::SecondaryText + : Colour::None; + Colour colourGuard( colour ); + + Catch::cout() << Column( testCaseInfo.name ).initialIndent( 2 ).indent( 4 ) << "\n"; + if( config.verbosity() >= Verbosity::High ) { + Catch::cout() << Column( Catch::Detail::stringify( testCaseInfo.lineInfo ) ).indent(4) << std::endl; + std::string description = testCaseInfo.description; + if( description.empty() ) + description = "(NO DESCRIPTION)"; + Catch::cout() << Column( description ).indent(4) << std::endl; + } + if( !testCaseInfo.tags.empty() ) + Catch::cout() << Column( testCaseInfo.tagsAsString() ).indent( 6 ) << "\n"; + } + + if( !config.hasTestFilters() ) + Catch::cout() << pluralise( matchedTestCases.size(), "test case" ) << '\n' << std::endl; + else + Catch::cout() << pluralise( matchedTestCases.size(), "matching test case" ) << '\n' << std::endl; + return matchedTestCases.size(); + } + + std::size_t listTestsNamesOnly( Config const& config ) { + TestSpec testSpec = config.testSpec(); + std::size_t matchedTests = 0; + std::vector matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCaseInfo : matchedTestCases ) { + matchedTests++; + if( startsWith( testCaseInfo.name, '#' ) ) + Catch::cout() << '"' << testCaseInfo.name << '"'; + else + Catch::cout() << testCaseInfo.name; + if ( config.verbosity() >= Verbosity::High ) + Catch::cout() << "\t@" << testCaseInfo.lineInfo; + Catch::cout() << std::endl; + } + return matchedTests; + } + + void TagInfo::add( std::string const& spelling ) { + ++count; + spellings.insert( spelling ); + } + + std::string TagInfo::all() const { + std::string out; + for( auto const& spelling : spellings ) + out += "[" + spelling + "]"; + return out; + } + + std::size_t listTags( Config const& config ) { + TestSpec testSpec = config.testSpec(); + if( config.hasTestFilters() ) + Catch::cout() << "Tags for matching test cases:\n"; + else { + Catch::cout() << "All available tags:\n"; + } + + std::map tagCounts; + + std::vector matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config ); + for( auto const& testCase : matchedTestCases ) { + for( auto const& tagName : testCase.getTestCaseInfo().tags ) { + std::string lcaseTagName = toLower( tagName ); + auto countIt = tagCounts.find( lcaseTagName ); + if( countIt == tagCounts.end() ) + countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first; + countIt->second.add( tagName ); + } + } + + for( auto const& tagCount : tagCounts ) { + ReusableStringStream rss; + rss << " " << std::setw(2) << tagCount.second.count << " "; + auto str = rss.str(); + auto wrapper = Column( tagCount.second.all() ) + .initialIndent( 0 ) + .indent( str.size() ) + .width( CATCH_CONFIG_CONSOLE_WIDTH-10 ); + Catch::cout() << str << wrapper << '\n'; + } + Catch::cout() << pluralise( tagCounts.size(), "tag" ) << '\n' << std::endl; + return tagCounts.size(); + } + + std::size_t listReporters() { + Catch::cout() << "Available reporters:\n"; + IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories(); + std::size_t maxNameLen = 0; + for( auto const& factoryKvp : factories ) + maxNameLen = (std::max)( maxNameLen, factoryKvp.first.size() ); + + for( auto const& factoryKvp : factories ) { + Catch::cout() + << Column( factoryKvp.first + ":" ) + .indent(2) + .width( 5+maxNameLen ) + + Column( factoryKvp.second->getDescription() ) + .initialIndent(0) + .indent(2) + .width( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 ) + << "\n"; + } + Catch::cout() << std::endl; + return factories.size(); + } + + Option list( std::shared_ptr const& config ) { + Option listedCount; + getCurrentMutableContext().setConfig( config ); + if( config->listTests() ) + listedCount = listedCount.valueOr(0) + listTests( *config ); + if( config->listTestNamesOnly() ) + listedCount = listedCount.valueOr(0) + listTestsNamesOnly( *config ); + if( config->listTags() ) + listedCount = listedCount.valueOr(0) + listTags( *config ); + if( config->listReporters() ) + listedCount = listedCount.valueOr(0) + listReporters(); + return listedCount; + } + +} // end namespace Catch +// end catch_list.cpp +// start catch_matchers.cpp + +namespace Catch { +namespace Matchers { + namespace Impl { + + std::string MatcherUntypedBase::toString() const { + if( m_cachedToString.empty() ) + m_cachedToString = describe(); + return m_cachedToString; + } + + MatcherUntypedBase::~MatcherUntypedBase() = default; + + } // namespace Impl +} // namespace Matchers + +using namespace Matchers; +using Matchers::Impl::MatcherBase; + +} // namespace Catch +// end catch_matchers.cpp +// start catch_matchers_floating.cpp + +// start catch_polyfills.hpp + +namespace Catch { + bool isnan(float f); + bool isnan(double d); +} + +// end catch_polyfills.hpp +// start catch_to_string.hpp + +#include + +namespace Catch { + template + std::string to_string(T const& t) { +#if defined(CATCH_CONFIG_CPP11_TO_STRING) + return std::to_string(t); +#else + ReusableStringStream rss; + rss << t; + return rss.str(); +#endif + } +} // end namespace Catch + +// end catch_to_string.hpp +#include +#include +#include + +namespace Catch { +namespace Matchers { +namespace Floating { +enum class FloatingPointKind : uint8_t { + Float, + Double +}; +} +} +} + +namespace { + +template +struct Converter; + +template <> +struct Converter { + static_assert(sizeof(float) == sizeof(int32_t), "Important ULP matcher assumption violated"); + Converter(float f) { + std::memcpy(&i, &f, sizeof(f)); + } + int32_t i; +}; + +template <> +struct Converter { + static_assert(sizeof(double) == sizeof(int64_t), "Important ULP matcher assumption violated"); + Converter(double d) { + std::memcpy(&i, &d, sizeof(d)); + } + int64_t i; +}; + +template +auto convert(T t) -> Converter { + return Converter(t); +} + +template +bool almostEqualUlps(FP lhs, FP rhs, int maxUlpDiff) { + // Comparison with NaN should always be false. + // This way we can rule it out before getting into the ugly details + if (Catch::isnan(lhs) || Catch::isnan(rhs)) { + return false; + } + + auto lc = convert(lhs); + auto rc = convert(rhs); + + if ((lc.i < 0) != (rc.i < 0)) { + // Potentially we can have +0 and -0 + return lhs == rhs; + } + + auto ulpDiff = std::abs(lc.i - rc.i); + return ulpDiff <= maxUlpDiff; +} + +} + +namespace Catch { +namespace Matchers { +namespace Floating { + WithinAbsMatcher::WithinAbsMatcher(double target, double margin) + :m_target{ target }, m_margin{ margin } { + CATCH_ENFORCE(margin >= 0, "Invalid margin: " << margin << '.' + << " Margin has to be non-negative."); + } + + // Performs equivalent check of std::fabs(lhs - rhs) <= margin + // But without the subtraction to allow for INFINITY in comparison + bool WithinAbsMatcher::match(double const& matchee) const { + return (matchee + m_margin >= m_target) && (m_target + m_margin >= matchee); + } + + std::string WithinAbsMatcher::describe() const { + return "is within " + ::Catch::Detail::stringify(m_margin) + " of " + ::Catch::Detail::stringify(m_target); + } + + WithinUlpsMatcher::WithinUlpsMatcher(double target, int ulps, FloatingPointKind baseType) + :m_target{ target }, m_ulps{ ulps }, m_type{ baseType } { + CATCH_ENFORCE(ulps >= 0, "Invalid ULP setting: " << ulps << '.' + << " ULPs have to be non-negative."); + } + +#if defined(__clang__) +#pragma clang diagnostic push +// Clang <3.5 reports on the default branch in the switch below +#pragma clang diagnostic ignored "-Wunreachable-code" +#endif + + bool WithinUlpsMatcher::match(double const& matchee) const { + switch (m_type) { + case FloatingPointKind::Float: + return almostEqualUlps(static_cast(matchee), static_cast(m_target), m_ulps); + case FloatingPointKind::Double: + return almostEqualUlps(matchee, m_target, m_ulps); + default: + CATCH_INTERNAL_ERROR( "Unknown FloatingPointKind value" ); + } + } + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + + std::string WithinUlpsMatcher::describe() const { + return "is within " + Catch::to_string(m_ulps) + " ULPs of " + ::Catch::Detail::stringify(m_target) + ((m_type == FloatingPointKind::Float)? "f" : ""); + } + +}// namespace Floating + +Floating::WithinUlpsMatcher WithinULP(double target, int maxUlpDiff) { + return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Double); +} + +Floating::WithinUlpsMatcher WithinULP(float target, int maxUlpDiff) { + return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Float); +} + +Floating::WithinAbsMatcher WithinAbs(double target, double margin) { + return Floating::WithinAbsMatcher(target, margin); +} + +} // namespace Matchers +} // namespace Catch + +// end catch_matchers_floating.cpp +// start catch_matchers_generic.cpp + +std::string Catch::Matchers::Generic::Detail::finalizeDescription(const std::string& desc) { + if (desc.empty()) { + return "matches undescribed predicate"; + } else { + return "matches predicate: \"" + desc + '"'; + } +} +// end catch_matchers_generic.cpp +// start catch_matchers_string.cpp + +#include + +namespace Catch { +namespace Matchers { + + namespace StdString { + + CasedString::CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_str( adjustString( str ) ) + {} + std::string CasedString::adjustString( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No + ? toLower( str ) + : str; + } + std::string CasedString::caseSensitivitySuffix() const { + return m_caseSensitivity == CaseSensitive::No + ? " (case insensitive)" + : std::string(); + } + + StringMatcherBase::StringMatcherBase( std::string const& operation, CasedString const& comparator ) + : m_comparator( comparator ), + m_operation( operation ) { + } + + std::string StringMatcherBase::describe() const { + std::string description; + description.reserve(5 + m_operation.size() + m_comparator.m_str.size() + + m_comparator.caseSensitivitySuffix().size()); + description += m_operation; + description += ": \""; + description += m_comparator.m_str; + description += "\""; + description += m_comparator.caseSensitivitySuffix(); + return description; + } + + EqualsMatcher::EqualsMatcher( CasedString const& comparator ) : StringMatcherBase( "equals", comparator ) {} + + bool EqualsMatcher::match( std::string const& source ) const { + return m_comparator.adjustString( source ) == m_comparator.m_str; + } + + ContainsMatcher::ContainsMatcher( CasedString const& comparator ) : StringMatcherBase( "contains", comparator ) {} + + bool ContainsMatcher::match( std::string const& source ) const { + return contains( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + StartsWithMatcher::StartsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "starts with", comparator ) {} + + bool StartsWithMatcher::match( std::string const& source ) const { + return startsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + EndsWithMatcher::EndsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "ends with", comparator ) {} + + bool EndsWithMatcher::match( std::string const& source ) const { + return endsWith( m_comparator.adjustString( source ), m_comparator.m_str ); + } + + RegexMatcher::RegexMatcher(std::string regex, CaseSensitive::Choice caseSensitivity): m_regex(std::move(regex)), m_caseSensitivity(caseSensitivity) {} + + bool RegexMatcher::match(std::string const& matchee) const { + auto flags = std::regex::ECMAScript; // ECMAScript is the default syntax option anyway + if (m_caseSensitivity == CaseSensitive::Choice::No) { + flags |= std::regex::icase; + } + auto reg = std::regex(m_regex, flags); + return std::regex_match(matchee, reg); + } + + std::string RegexMatcher::describe() const { + return "matches " + ::Catch::Detail::stringify(m_regex) + ((m_caseSensitivity == CaseSensitive::Choice::Yes)? " case sensitively" : " case insensitively"); + } + + } // namespace StdString + + StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EqualsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::ContainsMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::EndsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) { + return StdString::StartsWithMatcher( StdString::CasedString( str, caseSensitivity) ); + } + + StdString::RegexMatcher Matches(std::string const& regex, CaseSensitive::Choice caseSensitivity) { + return StdString::RegexMatcher(regex, caseSensitivity); + } + +} // namespace Matchers +} // namespace Catch +// end catch_matchers_string.cpp +// start catch_message.cpp + +// start catch_uncaught_exceptions.h + +namespace Catch { + bool uncaught_exceptions(); +} // end namespace Catch + +// end catch_uncaught_exceptions.h +#include +#include + +namespace Catch { + + MessageInfo::MessageInfo( StringRef const& _macroName, + SourceLineInfo const& _lineInfo, + ResultWas::OfType _type ) + : macroName( _macroName ), + lineInfo( _lineInfo ), + type( _type ), + sequence( ++globalCount ) + {} + + bool MessageInfo::operator==( MessageInfo const& other ) const { + return sequence == other.sequence; + } + + bool MessageInfo::operator<( MessageInfo const& other ) const { + return sequence < other.sequence; + } + + // This may need protecting if threading support is added + unsigned int MessageInfo::globalCount = 0; + + //////////////////////////////////////////////////////////////////////////// + + Catch::MessageBuilder::MessageBuilder( StringRef const& macroName, + SourceLineInfo const& lineInfo, + ResultWas::OfType type ) + :m_info(macroName, lineInfo, type) {} + + //////////////////////////////////////////////////////////////////////////// + + ScopedMessage::ScopedMessage( MessageBuilder const& builder ) + : m_info( builder.m_info ), m_moved() + { + m_info.message = builder.m_stream.str(); + getResultCapture().pushScopedMessage( m_info ); + } + + ScopedMessage::ScopedMessage( ScopedMessage&& old ) + : m_info( old.m_info ), m_moved() + { + old.m_moved = true; + } + + ScopedMessage::~ScopedMessage() { + if ( !uncaught_exceptions() && !m_moved ){ + getResultCapture().popScopedMessage(m_info); + } + } + + Capturer::Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names ) { + auto trimmed = [&] (size_t start, size_t end) { + while (names[start] == ',' || isspace(names[start])) { + ++start; + } + while (names[end] == ',' || isspace(names[end])) { + --end; + } + return names.substr(start, end - start + 1); + }; + + size_t start = 0; + std::stack openings; + for (size_t pos = 0; pos < names.size(); ++pos) { + char c = names[pos]; + switch (c) { + case '[': + case '{': + case '(': + // It is basically impossible to disambiguate between + // comparison and start of template args in this context +// case '<': + openings.push(c); + break; + case ']': + case '}': + case ')': +// case '>': + openings.pop(); + break; + case ',': + if (start != pos && openings.size() == 0) { + m_messages.emplace_back(macroName, lineInfo, resultType); + m_messages.back().message = trimmed(start, pos); + m_messages.back().message += " := "; + start = pos; + } + } + } + assert(openings.size() == 0 && "Mismatched openings"); + m_messages.emplace_back(macroName, lineInfo, resultType); + m_messages.back().message = trimmed(start, names.size() - 1); + m_messages.back().message += " := "; + } + Capturer::~Capturer() { + if ( !uncaught_exceptions() ){ + assert( m_captured == m_messages.size() ); + for( size_t i = 0; i < m_captured; ++i ) + m_resultCapture.popScopedMessage( m_messages[i] ); + } + } + + void Capturer::captureValue( size_t index, std::string const& value ) { + assert( index < m_messages.size() ); + m_messages[index].message += value; + m_resultCapture.pushScopedMessage( m_messages[index] ); + m_captured++; + } + +} // end namespace Catch +// end catch_message.cpp +// start catch_output_redirect.cpp + +// start catch_output_redirect.h +#ifndef TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H +#define TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H + +#include +#include +#include + +namespace Catch { + + class RedirectedStream { + std::ostream& m_originalStream; + std::ostream& m_redirectionStream; + std::streambuf* m_prevBuf; + + public: + RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream ); + ~RedirectedStream(); + }; + + class RedirectedStdOut { + ReusableStringStream m_rss; + RedirectedStream m_cout; + public: + RedirectedStdOut(); + auto str() const -> std::string; + }; + + // StdErr has two constituent streams in C++, std::cerr and std::clog + // This means that we need to redirect 2 streams into 1 to keep proper + // order of writes + class RedirectedStdErr { + ReusableStringStream m_rss; + RedirectedStream m_cerr; + RedirectedStream m_clog; + public: + RedirectedStdErr(); + auto str() const -> std::string; + }; + + class RedirectedStreams { + public: + RedirectedStreams(RedirectedStreams const&) = delete; + RedirectedStreams& operator=(RedirectedStreams const&) = delete; + RedirectedStreams(RedirectedStreams&&) = delete; + RedirectedStreams& operator=(RedirectedStreams&&) = delete; + + RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr); + ~RedirectedStreams(); + private: + std::string& m_redirectedCout; + std::string& m_redirectedCerr; + RedirectedStdOut m_redirectedStdOut; + RedirectedStdErr m_redirectedStdErr; + }; + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + + // Windows's implementation of std::tmpfile is terrible (it tries + // to create a file inside system folder, thus requiring elevated + // privileges for the binary), so we have to use tmpnam(_s) and + // create the file ourselves there. + class TempFile { + public: + TempFile(TempFile const&) = delete; + TempFile& operator=(TempFile const&) = delete; + TempFile(TempFile&&) = delete; + TempFile& operator=(TempFile&&) = delete; + + TempFile(); + ~TempFile(); + + std::FILE* getFile(); + std::string getContents(); + + private: + std::FILE* m_file = nullptr; + #if defined(_MSC_VER) + char m_buffer[L_tmpnam] = { 0 }; + #endif + }; + + class OutputRedirect { + public: + OutputRedirect(OutputRedirect const&) = delete; + OutputRedirect& operator=(OutputRedirect const&) = delete; + OutputRedirect(OutputRedirect&&) = delete; + OutputRedirect& operator=(OutputRedirect&&) = delete; + + OutputRedirect(std::string& stdout_dest, std::string& stderr_dest); + ~OutputRedirect(); + + private: + int m_originalStdout = -1; + int m_originalStderr = -1; + TempFile m_stdoutFile; + TempFile m_stderrFile; + std::string& m_stdoutDest; + std::string& m_stderrDest; + }; + +#endif + +} // end namespace Catch + +#endif // TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H +// end catch_output_redirect.h +#include +#include +#include +#include +#include + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + #if defined(_MSC_VER) + #include //_dup and _dup2 + #define dup _dup + #define dup2 _dup2 + #define fileno _fileno + #else + #include // dup and dup2 + #endif +#endif + +namespace Catch { + + RedirectedStream::RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream ) + : m_originalStream( originalStream ), + m_redirectionStream( redirectionStream ), + m_prevBuf( m_originalStream.rdbuf() ) + { + m_originalStream.rdbuf( m_redirectionStream.rdbuf() ); + } + + RedirectedStream::~RedirectedStream() { + m_originalStream.rdbuf( m_prevBuf ); + } + + RedirectedStdOut::RedirectedStdOut() : m_cout( Catch::cout(), m_rss.get() ) {} + auto RedirectedStdOut::str() const -> std::string { return m_rss.str(); } + + RedirectedStdErr::RedirectedStdErr() + : m_cerr( Catch::cerr(), m_rss.get() ), + m_clog( Catch::clog(), m_rss.get() ) + {} + auto RedirectedStdErr::str() const -> std::string { return m_rss.str(); } + + RedirectedStreams::RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr) + : m_redirectedCout(redirectedCout), + m_redirectedCerr(redirectedCerr) + {} + + RedirectedStreams::~RedirectedStreams() { + m_redirectedCout += m_redirectedStdOut.str(); + m_redirectedCerr += m_redirectedStdErr.str(); + } + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + +#if defined(_MSC_VER) + TempFile::TempFile() { + if (tmpnam_s(m_buffer)) { + CATCH_RUNTIME_ERROR("Could not get a temp filename"); + } + if (fopen_s(&m_file, m_buffer, "w")) { + char buffer[100]; + if (strerror_s(buffer, errno)) { + CATCH_RUNTIME_ERROR("Could not translate errno to a string"); + } + CATCH_RUNTIME_ERROR("Coul dnot open the temp file: '" << m_buffer << "' because: " << buffer); + } + } +#else + TempFile::TempFile() { + m_file = std::tmpfile(); + if (!m_file) { + CATCH_RUNTIME_ERROR("Could not create a temp file."); + } + } + +#endif + + TempFile::~TempFile() { + // TBD: What to do about errors here? + std::fclose(m_file); + // We manually create the file on Windows only, on Linux + // it will be autodeleted +#if defined(_MSC_VER) + std::remove(m_buffer); +#endif + } + + FILE* TempFile::getFile() { + return m_file; + } + + std::string TempFile::getContents() { + std::stringstream sstr; + char buffer[100] = {}; + std::rewind(m_file); + while (std::fgets(buffer, sizeof(buffer), m_file)) { + sstr << buffer; + } + return sstr.str(); + } + + OutputRedirect::OutputRedirect(std::string& stdout_dest, std::string& stderr_dest) : + m_originalStdout(dup(1)), + m_originalStderr(dup(2)), + m_stdoutDest(stdout_dest), + m_stderrDest(stderr_dest) { + dup2(fileno(m_stdoutFile.getFile()), 1); + dup2(fileno(m_stderrFile.getFile()), 2); + } + + OutputRedirect::~OutputRedirect() { + Catch::cout() << std::flush; + fflush(stdout); + // Since we support overriding these streams, we flush cerr + // even though std::cerr is unbuffered + Catch::cerr() << std::flush; + Catch::clog() << std::flush; + fflush(stderr); + + dup2(m_originalStdout, 1); + dup2(m_originalStderr, 2); + + m_stdoutDest += m_stdoutFile.getContents(); + m_stderrDest += m_stderrFile.getContents(); + } + +#endif // CATCH_CONFIG_NEW_CAPTURE + +} // namespace Catch + +#if defined(CATCH_CONFIG_NEW_CAPTURE) + #if defined(_MSC_VER) + #undef dup + #undef dup2 + #undef fileno + #endif +#endif +// end catch_output_redirect.cpp +// start catch_polyfills.cpp + +#include + +namespace Catch { + +#if !defined(CATCH_CONFIG_POLYFILL_ISNAN) + bool isnan(float f) { + return std::isnan(f); + } + bool isnan(double d) { + return std::isnan(d); + } +#else + // For now we only use this for embarcadero + bool isnan(float f) { + return std::_isnan(f); + } + bool isnan(double d) { + return std::_isnan(d); + } +#endif + +} // end namespace Catch +// end catch_polyfills.cpp +// start catch_random_number_generator.cpp + +namespace Catch { + + std::mt19937& rng() { + static std::mt19937 s_rng; + return s_rng; + } + + void seedRng( IConfig const& config ) { + if( config.rngSeed() != 0 ) { + std::srand( config.rngSeed() ); + rng().seed( config.rngSeed() ); + } + } + + unsigned int rngSeed() { + return getCurrentContext().getConfig()->rngSeed(); + } +} +// end catch_random_number_generator.cpp +// start catch_registry_hub.cpp + +// start catch_test_case_registry_impl.h + +#include +#include +#include +#include + +namespace Catch { + + class TestCase; + struct IConfig; + + std::vector sortTests( IConfig const& config, std::vector const& unsortedTestCases ); + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ); + + void enforceNoDuplicateTestCases( std::vector const& functions ); + + std::vector filterTests( std::vector const& testCases, TestSpec const& testSpec, IConfig const& config ); + std::vector const& getAllTestCasesSorted( IConfig const& config ); + + class TestRegistry : public ITestCaseRegistry { + public: + virtual ~TestRegistry() = default; + + virtual void registerTest( TestCase const& testCase ); + + std::vector const& getAllTests() const override; + std::vector const& getAllTestsSorted( IConfig const& config ) const override; + + private: + std::vector m_functions; + mutable RunTests::InWhatOrder m_currentSortOrder = RunTests::InDeclarationOrder; + mutable std::vector m_sortedFunctions; + std::size_t m_unnamedCount = 0; + std::ios_base::Init m_ostreamInit; // Forces cout/ cerr to be initialised + }; + + /////////////////////////////////////////////////////////////////////////// + + class TestInvokerAsFunction : public ITestInvoker { + void(*m_testAsFunction)(); + public: + TestInvokerAsFunction( void(*testAsFunction)() ) noexcept; + + void invoke() const override; + }; + + std::string extractClassName( StringRef const& classOrQualifiedMethodName ); + + /////////////////////////////////////////////////////////////////////////// + +} // end namespace Catch + +// end catch_test_case_registry_impl.h +// start catch_reporter_registry.h + +#include + +namespace Catch { + + class ReporterRegistry : public IReporterRegistry { + + public: + + ~ReporterRegistry() override; + + IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const override; + + void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ); + void registerListener( IReporterFactoryPtr const& factory ); + + FactoryMap const& getFactories() const override; + Listeners const& getListeners() const override; + + private: + FactoryMap m_factories; + Listeners m_listeners; + }; +} + +// end catch_reporter_registry.h +// start catch_tag_alias_registry.h + +// start catch_tag_alias.h + +#include + +namespace Catch { + + struct TagAlias { + TagAlias(std::string const& _tag, SourceLineInfo _lineInfo); + + std::string tag; + SourceLineInfo lineInfo; + }; + +} // end namespace Catch + +// end catch_tag_alias.h +#include + +namespace Catch { + + class TagAliasRegistry : public ITagAliasRegistry { + public: + ~TagAliasRegistry() override; + TagAlias const* find( std::string const& alias ) const override; + std::string expandAliases( std::string const& unexpandedTestSpec ) const override; + void add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ); + + private: + std::map m_registry; + }; + +} // end namespace Catch + +// end catch_tag_alias_registry.h +// start catch_startup_exception_registry.h + +#include +#include + +namespace Catch { + + class StartupExceptionRegistry { + public: + void add(std::exception_ptr const& exception) noexcept; + std::vector const& getExceptions() const noexcept; + private: + std::vector m_exceptions; + }; + +} // end namespace Catch + +// end catch_startup_exception_registry.h +// start catch_singletons.hpp + +namespace Catch { + + struct ISingleton { + virtual ~ISingleton(); + }; + + void addSingleton( ISingleton* singleton ); + void cleanupSingletons(); + + template + class Singleton : SingletonImplT, public ISingleton { + + static auto getInternal() -> Singleton* { + static Singleton* s_instance = nullptr; + if( !s_instance ) { + s_instance = new Singleton; + addSingleton( s_instance ); + } + return s_instance; + } + + public: + static auto get() -> InterfaceT const& { + return *getInternal(); + } + static auto getMutable() -> MutableInterfaceT& { + return *getInternal(); + } + }; + +} // namespace Catch + +// end catch_singletons.hpp +namespace Catch { + + namespace { + + class RegistryHub : public IRegistryHub, public IMutableRegistryHub, + private NonCopyable { + + public: // IRegistryHub + RegistryHub() = default; + IReporterRegistry const& getReporterRegistry() const override { + return m_reporterRegistry; + } + ITestCaseRegistry const& getTestCaseRegistry() const override { + return m_testCaseRegistry; + } + IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const override { + return m_exceptionTranslatorRegistry; + } + ITagAliasRegistry const& getTagAliasRegistry() const override { + return m_tagAliasRegistry; + } + StartupExceptionRegistry const& getStartupExceptionRegistry() const override { + return m_exceptionRegistry; + } + + public: // IMutableRegistryHub + void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) override { + m_reporterRegistry.registerReporter( name, factory ); + } + void registerListener( IReporterFactoryPtr const& factory ) override { + m_reporterRegistry.registerListener( factory ); + } + void registerTest( TestCase const& testInfo ) override { + m_testCaseRegistry.registerTest( testInfo ); + } + void registerTranslator( const IExceptionTranslator* translator ) override { + m_exceptionTranslatorRegistry.registerTranslator( translator ); + } + void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) override { + m_tagAliasRegistry.add( alias, tag, lineInfo ); + } + void registerStartupException() noexcept override { + m_exceptionRegistry.add(std::current_exception()); + } + + private: + TestRegistry m_testCaseRegistry; + ReporterRegistry m_reporterRegistry; + ExceptionTranslatorRegistry m_exceptionTranslatorRegistry; + TagAliasRegistry m_tagAliasRegistry; + StartupExceptionRegistry m_exceptionRegistry; + }; + } + + using RegistryHubSingleton = Singleton; + + IRegistryHub const& getRegistryHub() { + return RegistryHubSingleton::get(); + } + IMutableRegistryHub& getMutableRegistryHub() { + return RegistryHubSingleton::getMutable(); + } + void cleanUp() { + cleanupSingletons(); + cleanUpContext(); + } + std::string translateActiveException() { + return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException(); + } + +} // end namespace Catch +// end catch_registry_hub.cpp +// start catch_reporter_registry.cpp + +namespace Catch { + + ReporterRegistry::~ReporterRegistry() = default; + + IStreamingReporterPtr ReporterRegistry::create( std::string const& name, IConfigPtr const& config ) const { + auto it = m_factories.find( name ); + if( it == m_factories.end() ) + return nullptr; + return it->second->create( ReporterConfig( config ) ); + } + + void ReporterRegistry::registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) { + m_factories.emplace(name, factory); + } + void ReporterRegistry::registerListener( IReporterFactoryPtr const& factory ) { + m_listeners.push_back( factory ); + } + + IReporterRegistry::FactoryMap const& ReporterRegistry::getFactories() const { + return m_factories; + } + IReporterRegistry::Listeners const& ReporterRegistry::getListeners() const { + return m_listeners; + } + +} +// end catch_reporter_registry.cpp +// start catch_result_type.cpp + +namespace Catch { + + bool isOk( ResultWas::OfType resultType ) { + return ( resultType & ResultWas::FailureBit ) == 0; + } + bool isJustInfo( int flags ) { + return flags == ResultWas::Info; + } + + ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) { + return static_cast( static_cast( lhs ) | static_cast( rhs ) ); + } + + bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; } + bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; } + +} // end namespace Catch +// end catch_result_type.cpp +// start catch_run_context.cpp + +#include +#include +#include + +namespace Catch { + + namespace Generators { + struct GeneratorTracker : TestCaseTracking::TrackerBase, IGeneratorTracker { + GeneratorBasePtr m_generator; + + GeneratorTracker( TestCaseTracking::NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : TrackerBase( nameAndLocation, ctx, parent ) + {} + ~GeneratorTracker(); + + static GeneratorTracker& acquire( TrackerContext& ctx, TestCaseTracking::NameAndLocation const& nameAndLocation ) { + std::shared_ptr tracker; + + ITracker& currentTracker = ctx.currentTracker(); + if( TestCaseTracking::ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isGeneratorTracker() ); + tracker = std::static_pointer_cast( childTracker ); + } + else { + tracker = std::make_shared( nameAndLocation, ctx, ¤tTracker ); + currentTracker.addChild( tracker ); + } + + if( !ctx.completedCycle() && !tracker->isComplete() ) { + tracker->open(); + } + + return *tracker; + } + + // TrackerBase interface + bool isGeneratorTracker() const override { return true; } + auto hasGenerator() const -> bool override { + return !!m_generator; + } + void close() override { + TrackerBase::close(); + // Generator interface only finds out if it has another item on atual move + if (m_runState == CompletedSuccessfully && m_generator->next()) { + m_children.clear(); + m_runState = Executing; + } + } + + // IGeneratorTracker interface + auto getGenerator() const -> GeneratorBasePtr const& override { + return m_generator; + } + void setGenerator( GeneratorBasePtr&& generator ) override { + m_generator = std::move( generator ); + } + }; + GeneratorTracker::~GeneratorTracker() {} + } + + RunContext::RunContext(IConfigPtr const& _config, IStreamingReporterPtr&& reporter) + : m_runInfo(_config->name()), + m_context(getCurrentMutableContext()), + m_config(_config), + m_reporter(std::move(reporter)), + m_lastAssertionInfo{ StringRef(), SourceLineInfo("",0), StringRef(), ResultDisposition::Normal }, + m_includeSuccessfulResults( m_config->includeSuccessfulResults() || m_reporter->getPreferences().shouldReportAllAssertions ) + { + m_context.setRunner(this); + m_context.setConfig(m_config); + m_context.setResultCapture(this); + m_reporter->testRunStarting(m_runInfo); + } + + RunContext::~RunContext() { + m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, aborting())); + } + + void RunContext::testGroupStarting(std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount) { + m_reporter->testGroupStarting(GroupInfo(testSpec, groupIndex, groupsCount)); + } + + void RunContext::testGroupEnded(std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount) { + m_reporter->testGroupEnded(TestGroupStats(GroupInfo(testSpec, groupIndex, groupsCount), totals, aborting())); + } + + Totals RunContext::runTest(TestCase const& testCase) { + Totals prevTotals = m_totals; + + std::string redirectedCout; + std::string redirectedCerr; + + auto const& testInfo = testCase.getTestCaseInfo(); + + m_reporter->testCaseStarting(testInfo); + + m_activeTestCase = &testCase; + + ITracker& rootTracker = m_trackerContext.startRun(); + assert(rootTracker.isSectionTracker()); + static_cast(rootTracker).addInitialFilters(m_config->getSectionsToRun()); + do { + m_trackerContext.startCycle(); + m_testCaseTracker = &SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(testInfo.name, testInfo.lineInfo)); + runCurrentTest(redirectedCout, redirectedCerr); + } while (!m_testCaseTracker->isSuccessfullyCompleted() && !aborting()); + + Totals deltaTotals = m_totals.delta(prevTotals); + if (testInfo.expectedToFail() && deltaTotals.testCases.passed > 0) { + deltaTotals.assertions.failed++; + deltaTotals.testCases.passed--; + deltaTotals.testCases.failed++; + } + m_totals.testCases += deltaTotals.testCases; + m_reporter->testCaseEnded(TestCaseStats(testInfo, + deltaTotals, + redirectedCout, + redirectedCerr, + aborting())); + + m_activeTestCase = nullptr; + m_testCaseTracker = nullptr; + + return deltaTotals; + } + + IConfigPtr RunContext::config() const { + return m_config; + } + + IStreamingReporter& RunContext::reporter() const { + return *m_reporter; + } + + void RunContext::assertionEnded(AssertionResult const & result) { + if (result.getResultType() == ResultWas::Ok) { + m_totals.assertions.passed++; + m_lastAssertionPassed = true; + } else if (!result.isOk()) { + m_lastAssertionPassed = false; + if( m_activeTestCase->getTestCaseInfo().okToFail() ) + m_totals.assertions.failedButOk++; + else + m_totals.assertions.failed++; + } + else { + m_lastAssertionPassed = true; + } + + // We have no use for the return value (whether messages should be cleared), because messages were made scoped + // and should be let to clear themselves out. + static_cast(m_reporter->assertionEnded(AssertionStats(result, m_messages, m_totals))); + + if (result.getResultType() != ResultWas::Warning) + m_messageScopes.clear(); + + // Reset working state + resetAssertionInfo(); + m_lastResult = result; + } + void RunContext::resetAssertionInfo() { + m_lastAssertionInfo.macroName = StringRef(); + m_lastAssertionInfo.capturedExpression = "{Unknown expression after the reported line}"_sr; + } + + bool RunContext::sectionStarted(SectionInfo const & sectionInfo, Counts & assertions) { + ITracker& sectionTracker = SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(sectionInfo.name, sectionInfo.lineInfo)); + if (!sectionTracker.isOpen()) + return false; + m_activeSections.push_back(§ionTracker); + + m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo; + + m_reporter->sectionStarting(sectionInfo); + + assertions = m_totals.assertions; + + return true; + } + auto RunContext::acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& { + using namespace Generators; + GeneratorTracker& tracker = GeneratorTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( "generator", lineInfo ) ); + assert( tracker.isOpen() ); + m_lastAssertionInfo.lineInfo = lineInfo; + return tracker; + } + + bool RunContext::testForMissingAssertions(Counts& assertions) { + if (assertions.total() != 0) + return false; + if (!m_config->warnAboutMissingAssertions()) + return false; + if (m_trackerContext.currentTracker().hasChildren()) + return false; + m_totals.assertions.failed++; + assertions.failed++; + return true; + } + + void RunContext::sectionEnded(SectionEndInfo const & endInfo) { + Counts assertions = m_totals.assertions - endInfo.prevAssertions; + bool missingAssertions = testForMissingAssertions(assertions); + + if (!m_activeSections.empty()) { + m_activeSections.back()->close(); + m_activeSections.pop_back(); + } + + m_reporter->sectionEnded(SectionStats(endInfo.sectionInfo, assertions, endInfo.durationInSeconds, missingAssertions)); + m_messages.clear(); + m_messageScopes.clear(); + } + + void RunContext::sectionEndedEarly(SectionEndInfo const & endInfo) { + if (m_unfinishedSections.empty()) + m_activeSections.back()->fail(); + else + m_activeSections.back()->close(); + m_activeSections.pop_back(); + + m_unfinishedSections.push_back(endInfo); + } + void RunContext::benchmarkStarting( BenchmarkInfo const& info ) { + m_reporter->benchmarkStarting( info ); + } + void RunContext::benchmarkEnded( BenchmarkStats const& stats ) { + m_reporter->benchmarkEnded( stats ); + } + + void RunContext::pushScopedMessage(MessageInfo const & message) { + m_messages.push_back(message); + } + + void RunContext::popScopedMessage(MessageInfo const & message) { + m_messages.erase(std::remove(m_messages.begin(), m_messages.end(), message), m_messages.end()); + } + + void RunContext::emplaceUnscopedMessage( MessageBuilder const& builder ) { + m_messageScopes.emplace_back( builder ); + } + + std::string RunContext::getCurrentTestName() const { + return m_activeTestCase + ? m_activeTestCase->getTestCaseInfo().name + : std::string(); + } + + const AssertionResult * RunContext::getLastResult() const { + return &(*m_lastResult); + } + + void RunContext::exceptionEarlyReported() { + m_shouldReportUnexpected = false; + } + + void RunContext::handleFatalErrorCondition( StringRef message ) { + // First notify reporter that bad things happened + m_reporter->fatalErrorEncountered(message); + + // Don't rebuild the result -- the stringification itself can cause more fatal errors + // Instead, fake a result data. + AssertionResultData tempResult( ResultWas::FatalErrorCondition, { false } ); + tempResult.message = message; + AssertionResult result(m_lastAssertionInfo, tempResult); + + assertionEnded(result); + + handleUnfinishedSections(); + + // Recreate section for test case (as we will lose the one that was in scope) + auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name); + + Counts assertions; + assertions.failed = 1; + SectionStats testCaseSectionStats(testCaseSection, assertions, 0, false); + m_reporter->sectionEnded(testCaseSectionStats); + + auto const& testInfo = m_activeTestCase->getTestCaseInfo(); + + Totals deltaTotals; + deltaTotals.testCases.failed = 1; + deltaTotals.assertions.failed = 1; + m_reporter->testCaseEnded(TestCaseStats(testInfo, + deltaTotals, + std::string(), + std::string(), + false)); + m_totals.testCases.failed++; + testGroupEnded(std::string(), m_totals, 1, 1); + m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, false)); + } + + bool RunContext::lastAssertionPassed() { + return m_lastAssertionPassed; + } + + void RunContext::assertionPassed() { + m_lastAssertionPassed = true; + ++m_totals.assertions.passed; + resetAssertionInfo(); + m_messageScopes.clear(); + } + + bool RunContext::aborting() const { + return m_totals.assertions.failed >= static_cast(m_config->abortAfter()); + } + + void RunContext::runCurrentTest(std::string & redirectedCout, std::string & redirectedCerr) { + auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo(); + SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name); + m_reporter->sectionStarting(testCaseSection); + Counts prevAssertions = m_totals.assertions; + double duration = 0; + m_shouldReportUnexpected = true; + m_lastAssertionInfo = { "TEST_CASE"_sr, testCaseInfo.lineInfo, StringRef(), ResultDisposition::Normal }; + + seedRng(*m_config); + + Timer timer; + CATCH_TRY { + if (m_reporter->getPreferences().shouldRedirectStdOut) { +#if !defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT) + RedirectedStreams redirectedStreams(redirectedCout, redirectedCerr); + + timer.start(); + invokeActiveTestCase(); +#else + OutputRedirect r(redirectedCout, redirectedCerr); + timer.start(); + invokeActiveTestCase(); +#endif + } else { + timer.start(); + invokeActiveTestCase(); + } + duration = timer.getElapsedSeconds(); + } CATCH_CATCH_ANON (TestFailureException&) { + // This just means the test was aborted due to failure + } CATCH_CATCH_ALL { + // Under CATCH_CONFIG_FAST_COMPILE, unexpected exceptions under REQUIRE assertions + // are reported without translation at the point of origin. + if( m_shouldReportUnexpected ) { + AssertionReaction dummyReaction; + handleUnexpectedInflightException( m_lastAssertionInfo, translateActiveException(), dummyReaction ); + } + } + Counts assertions = m_totals.assertions - prevAssertions; + bool missingAssertions = testForMissingAssertions(assertions); + + m_testCaseTracker->close(); + handleUnfinishedSections(); + m_messages.clear(); + m_messageScopes.clear(); + + SectionStats testCaseSectionStats(testCaseSection, assertions, duration, missingAssertions); + m_reporter->sectionEnded(testCaseSectionStats); + } + + void RunContext::invokeActiveTestCase() { + FatalConditionHandler fatalConditionHandler; // Handle signals + m_activeTestCase->invoke(); + fatalConditionHandler.reset(); + } + + void RunContext::handleUnfinishedSections() { + // If sections ended prematurely due to an exception we stored their + // infos here so we can tear them down outside the unwind process. + for (auto it = m_unfinishedSections.rbegin(), + itEnd = m_unfinishedSections.rend(); + it != itEnd; + ++it) + sectionEnded(*it); + m_unfinishedSections.clear(); + } + + void RunContext::handleExpr( + AssertionInfo const& info, + ITransientExpression const& expr, + AssertionReaction& reaction + ) { + m_reporter->assertionStarting( info ); + + bool negated = isFalseTest( info.resultDisposition ); + bool result = expr.getResult() != negated; + + if( result ) { + if (!m_includeSuccessfulResults) { + assertionPassed(); + } + else { + reportExpr(info, ResultWas::Ok, &expr, negated); + } + } + else { + reportExpr(info, ResultWas::ExpressionFailed, &expr, negated ); + populateReaction( reaction ); + } + } + void RunContext::reportExpr( + AssertionInfo const &info, + ResultWas::OfType resultType, + ITransientExpression const *expr, + bool negated ) { + + m_lastAssertionInfo = info; + AssertionResultData data( resultType, LazyExpression( negated ) ); + + AssertionResult assertionResult{ info, data }; + assertionResult.m_resultData.lazyExpression.m_transientExpression = expr; + + assertionEnded( assertionResult ); + } + + void RunContext::handleMessage( + AssertionInfo const& info, + ResultWas::OfType resultType, + StringRef const& message, + AssertionReaction& reaction + ) { + m_reporter->assertionStarting( info ); + + m_lastAssertionInfo = info; + + AssertionResultData data( resultType, LazyExpression( false ) ); + data.message = message; + AssertionResult assertionResult{ m_lastAssertionInfo, data }; + assertionEnded( assertionResult ); + if( !assertionResult.isOk() ) + populateReaction( reaction ); + } + void RunContext::handleUnexpectedExceptionNotThrown( + AssertionInfo const& info, + AssertionReaction& reaction + ) { + handleNonExpr(info, Catch::ResultWas::DidntThrowException, reaction); + } + + void RunContext::handleUnexpectedInflightException( + AssertionInfo const& info, + std::string const& message, + AssertionReaction& reaction + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) ); + data.message = message; + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + populateReaction( reaction ); + } + + void RunContext::populateReaction( AssertionReaction& reaction ) { + reaction.shouldDebugBreak = m_config->shouldDebugBreak(); + reaction.shouldThrow = aborting() || (m_lastAssertionInfo.resultDisposition & ResultDisposition::Normal); + } + + void RunContext::handleIncomplete( + AssertionInfo const& info + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) ); + data.message = "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE"; + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + } + void RunContext::handleNonExpr( + AssertionInfo const &info, + ResultWas::OfType resultType, + AssertionReaction &reaction + ) { + m_lastAssertionInfo = info; + + AssertionResultData data( resultType, LazyExpression( false ) ); + AssertionResult assertionResult{ info, data }; + assertionEnded( assertionResult ); + + if( !assertionResult.isOk() ) + populateReaction( reaction ); + } + + IResultCapture& getResultCapture() { + if (auto* capture = getCurrentContext().getResultCapture()) + return *capture; + else + CATCH_INTERNAL_ERROR("No result capture instance"); + } +} +// end catch_run_context.cpp +// start catch_section.cpp + +namespace Catch { + + Section::Section( SectionInfo const& info ) + : m_info( info ), + m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) ) + { + m_timer.start(); + } + + Section::~Section() { + if( m_sectionIncluded ) { + SectionEndInfo endInfo{ m_info, m_assertions, m_timer.getElapsedSeconds() }; + if( uncaught_exceptions() ) + getResultCapture().sectionEndedEarly( endInfo ); + else + getResultCapture().sectionEnded( endInfo ); + } + } + + // This indicates whether the section should be executed or not + Section::operator bool() const { + return m_sectionIncluded; + } + +} // end namespace Catch +// end catch_section.cpp +// start catch_section_info.cpp + +namespace Catch { + + SectionInfo::SectionInfo + ( SourceLineInfo const& _lineInfo, + std::string const& _name ) + : name( _name ), + lineInfo( _lineInfo ) + {} + +} // end namespace Catch +// end catch_section_info.cpp +// start catch_session.cpp + +// start catch_session.h + +#include + +namespace Catch { + + class Session : NonCopyable { + public: + + Session(); + ~Session() override; + + void showHelp() const; + void libIdentify(); + + int applyCommandLine( int argc, char const * const * argv ); + #if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(UNICODE) + int applyCommandLine( int argc, wchar_t const * const * argv ); + #endif + + void useConfigData( ConfigData const& configData ); + + template + int run(int argc, CharT const * const argv[]) { + if (m_startupExceptions) + return 1; + int returnCode = applyCommandLine(argc, argv); + if (returnCode == 0) + returnCode = run(); + return returnCode; + } + + int run(); + + clara::Parser const& cli() const; + void cli( clara::Parser const& newParser ); + ConfigData& configData(); + Config& config(); + private: + int runInternal(); + + clara::Parser m_cli; + ConfigData m_configData; + std::shared_ptr m_config; + bool m_startupExceptions = false; + }; + +} // end namespace Catch + +// end catch_session.h +// start catch_version.h + +#include + +namespace Catch { + + // Versioning information + struct Version { + Version( Version const& ) = delete; + Version& operator=( Version const& ) = delete; + Version( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ); + + unsigned int const majorVersion; + unsigned int const minorVersion; + unsigned int const patchNumber; + + // buildNumber is only used if branchName is not null + char const * const branchName; + unsigned int const buildNumber; + + friend std::ostream& operator << ( std::ostream& os, Version const& version ); + }; + + Version const& libraryVersion(); +} + +// end catch_version.h +#include +#include + +namespace Catch { + + namespace { + const int MaxExitCode = 255; + + IStreamingReporterPtr createReporter(std::string const& reporterName, IConfigPtr const& config) { + auto reporter = Catch::getRegistryHub().getReporterRegistry().create(reporterName, config); + CATCH_ENFORCE(reporter, "No reporter registered with name: '" << reporterName << "'"); + + return reporter; + } + + IStreamingReporterPtr makeReporter(std::shared_ptr const& config) { + if (Catch::getRegistryHub().getReporterRegistry().getListeners().empty()) { + return createReporter(config->getReporterName(), config); + } + + // On older platforms, returning std::unique_ptr + // when the return type is std::unique_ptr + // doesn't compile without a std::move call. However, this causes + // a warning on newer platforms. Thus, we have to work around + // it a bit and downcast the pointer manually. + auto ret = std::unique_ptr(new ListeningReporter); + auto& multi = static_cast(*ret); + auto const& listeners = Catch::getRegistryHub().getReporterRegistry().getListeners(); + for (auto const& listener : listeners) { + multi.addListener(listener->create(Catch::ReporterConfig(config))); + } + multi.addReporter(createReporter(config->getReporterName(), config)); + return ret; + } + + Catch::Totals runTests(std::shared_ptr const& config) { + auto reporter = makeReporter(config); + + RunContext context(config, std::move(reporter)); + + Totals totals; + + context.testGroupStarting(config->name(), 1, 1); + + TestSpec testSpec = config->testSpec(); + + auto const& allTestCases = getAllTestCasesSorted(*config); + for (auto const& testCase : allTestCases) { + if (!context.aborting() && matchTest(testCase, testSpec, *config)) + totals += context.runTest(testCase); + else + context.reporter().skipTest(testCase); + } + + if (config->warnAboutNoTests() && totals.testCases.total() == 0) { + ReusableStringStream testConfig; + + bool first = true; + for (const auto& input : config->getTestsOrTags()) { + if (!first) { testConfig << ' '; } + first = false; + testConfig << input; + } + + context.reporter().noMatchingTestCases(testConfig.str()); + totals.error = -1; + } + + context.testGroupEnded(config->name(), totals, 1, 1); + return totals; + } + + void applyFilenamesAsTags(Catch::IConfig const& config) { + auto& tests = const_cast&>(getAllTestCasesSorted(config)); + for (auto& testCase : tests) { + auto tags = testCase.tags; + + std::string filename = testCase.lineInfo.file; + auto lastSlash = filename.find_last_of("\\/"); + if (lastSlash != std::string::npos) { + filename.erase(0, lastSlash); + filename[0] = '#'; + } + + auto lastDot = filename.find_last_of('.'); + if (lastDot != std::string::npos) { + filename.erase(lastDot); + } + + tags.push_back(std::move(filename)); + setTags(testCase, tags); + } + } + + } // anon namespace + + Session::Session() { + static bool alreadyInstantiated = false; + if( alreadyInstantiated ) { + CATCH_TRY { CATCH_INTERNAL_ERROR( "Only one instance of Catch::Session can ever be used" ); } + CATCH_CATCH_ALL { getMutableRegistryHub().registerStartupException(); } + } + + // There cannot be exceptions at startup in no-exception mode. +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + const auto& exceptions = getRegistryHub().getStartupExceptionRegistry().getExceptions(); + if ( !exceptions.empty() ) { + m_startupExceptions = true; + Colour colourGuard( Colour::Red ); + Catch::cerr() << "Errors occurred during startup!" << '\n'; + // iterate over all exceptions and notify user + for ( const auto& ex_ptr : exceptions ) { + try { + std::rethrow_exception(ex_ptr); + } catch ( std::exception const& ex ) { + Catch::cerr() << Column( ex.what() ).indent(2) << '\n'; + } + } + } +#endif + + alreadyInstantiated = true; + m_cli = makeCommandLineParser( m_configData ); + } + Session::~Session() { + Catch::cleanUp(); + } + + void Session::showHelp() const { + Catch::cout() + << "\nCatch v" << libraryVersion() << "\n" + << m_cli << std::endl + << "For more detailed usage please see the project docs\n" << std::endl; + } + void Session::libIdentify() { + Catch::cout() + << std::left << std::setw(16) << "description: " << "A Catch test executable\n" + << std::left << std::setw(16) << "category: " << "testframework\n" + << std::left << std::setw(16) << "framework: " << "Catch Test\n" + << std::left << std::setw(16) << "version: " << libraryVersion() << std::endl; + } + + int Session::applyCommandLine( int argc, char const * const * argv ) { + if( m_startupExceptions ) + return 1; + + auto result = m_cli.parse( clara::Args( argc, argv ) ); + if( !result ) { + config(); + getCurrentMutableContext().setConfig(m_config); + Catch::cerr() + << Colour( Colour::Red ) + << "\nError(s) in input:\n" + << Column( result.errorMessage() ).indent( 2 ) + << "\n\n"; + Catch::cerr() << "Run with -? for usage\n" << std::endl; + return MaxExitCode; + } + + if( m_configData.showHelp ) + showHelp(); + if( m_configData.libIdentify ) + libIdentify(); + m_config.reset(); + return 0; + } + +#if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(UNICODE) + int Session::applyCommandLine( int argc, wchar_t const * const * argv ) { + + char **utf8Argv = new char *[ argc ]; + + for ( int i = 0; i < argc; ++i ) { + int bufSize = WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, NULL, 0, NULL, NULL ); + + utf8Argv[ i ] = new char[ bufSize ]; + + WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, utf8Argv[i], bufSize, NULL, NULL ); + } + + int returnCode = applyCommandLine( argc, utf8Argv ); + + for ( int i = 0; i < argc; ++i ) + delete [] utf8Argv[ i ]; + + delete [] utf8Argv; + + return returnCode; + } +#endif + + void Session::useConfigData( ConfigData const& configData ) { + m_configData = configData; + m_config.reset(); + } + + int Session::run() { + if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeStart ) != 0 ) { + Catch::cout() << "...waiting for enter/ return before starting" << std::endl; + static_cast(std::getchar()); + } + int exitCode = runInternal(); + if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeExit ) != 0 ) { + Catch::cout() << "...waiting for enter/ return before exiting, with code: " << exitCode << std::endl; + static_cast(std::getchar()); + } + return exitCode; + } + + clara::Parser const& Session::cli() const { + return m_cli; + } + void Session::cli( clara::Parser const& newParser ) { + m_cli = newParser; + } + ConfigData& Session::configData() { + return m_configData; + } + Config& Session::config() { + if( !m_config ) + m_config = std::make_shared( m_configData ); + return *m_config; + } + + int Session::runInternal() { + if( m_startupExceptions ) + return 1; + + if (m_configData.showHelp || m_configData.libIdentify) { + return 0; + } + + CATCH_TRY { + config(); // Force config to be constructed + + seedRng( *m_config ); + + if( m_configData.filenamesAsTags ) + applyFilenamesAsTags( *m_config ); + + // Handle list request + if( Option listed = list( m_config ) ) + return static_cast( *listed ); + + auto totals = runTests( m_config ); + // Note that on unices only the lower 8 bits are usually used, clamping + // the return value to 255 prevents false negative when some multiple + // of 256 tests has failed + return (std::min) (MaxExitCode, (std::max) (totals.error, static_cast(totals.assertions.failed))); + } +#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) + catch( std::exception& ex ) { + Catch::cerr() << ex.what() << std::endl; + return MaxExitCode; + } +#endif + } + +} // end namespace Catch +// end catch_session.cpp +// start catch_singletons.cpp + +#include + +namespace Catch { + + namespace { + static auto getSingletons() -> std::vector*& { + static std::vector* g_singletons = nullptr; + if( !g_singletons ) + g_singletons = new std::vector(); + return g_singletons; + } + } + + ISingleton::~ISingleton() {} + + void addSingleton(ISingleton* singleton ) { + getSingletons()->push_back( singleton ); + } + void cleanupSingletons() { + auto& singletons = getSingletons(); + for( auto singleton : *singletons ) + delete singleton; + delete singletons; + singletons = nullptr; + } + +} // namespace Catch +// end catch_singletons.cpp +// start catch_startup_exception_registry.cpp + +namespace Catch { +void StartupExceptionRegistry::add( std::exception_ptr const& exception ) noexcept { + CATCH_TRY { + m_exceptions.push_back(exception); + } CATCH_CATCH_ALL { + // If we run out of memory during start-up there's really not a lot more we can do about it + std::terminate(); + } + } + + std::vector const& StartupExceptionRegistry::getExceptions() const noexcept { + return m_exceptions; + } + +} // end namespace Catch +// end catch_startup_exception_registry.cpp +// start catch_stream.cpp + +#include +#include +#include +#include +#include +#include + +namespace Catch { + + Catch::IStream::~IStream() = default; + + namespace detail { namespace { + template + class StreamBufImpl : public std::streambuf { + char data[bufferSize]; + WriterF m_writer; + + public: + StreamBufImpl() { + setp( data, data + sizeof(data) ); + } + + ~StreamBufImpl() noexcept { + StreamBufImpl::sync(); + } + + private: + int overflow( int c ) override { + sync(); + + if( c != EOF ) { + if( pbase() == epptr() ) + m_writer( std::string( 1, static_cast( c ) ) ); + else + sputc( static_cast( c ) ); + } + return 0; + } + + int sync() override { + if( pbase() != pptr() ) { + m_writer( std::string( pbase(), static_cast( pptr() - pbase() ) ) ); + setp( pbase(), epptr() ); + } + return 0; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + struct OutputDebugWriter { + + void operator()( std::string const&str ) { + writeToDebugConsole( str ); + } + }; + + /////////////////////////////////////////////////////////////////////////// + + class FileStream : public IStream { + mutable std::ofstream m_ofs; + public: + FileStream( StringRef filename ) { + m_ofs.open( filename.c_str() ); + CATCH_ENFORCE( !m_ofs.fail(), "Unable to open file: '" << filename << "'" ); + } + ~FileStream() override = default; + public: // IStream + std::ostream& stream() const override { + return m_ofs; + } + }; + + /////////////////////////////////////////////////////////////////////////// + + class CoutStream : public IStream { + mutable std::ostream m_os; + public: + // Store the streambuf from cout up-front because + // cout may get redirected when running tests + CoutStream() : m_os( Catch::cout().rdbuf() ) {} + ~CoutStream() override = default; + + public: // IStream + std::ostream& stream() const override { return m_os; } + }; + + /////////////////////////////////////////////////////////////////////////// + + class DebugOutStream : public IStream { + std::unique_ptr> m_streamBuf; + mutable std::ostream m_os; + public: + DebugOutStream() + : m_streamBuf( new StreamBufImpl() ), + m_os( m_streamBuf.get() ) + {} + + ~DebugOutStream() override = default; + + public: // IStream + std::ostream& stream() const override { return m_os; } + }; + + }} // namespace anon::detail + + /////////////////////////////////////////////////////////////////////////// + + auto makeStream( StringRef const &filename ) -> IStream const* { + if( filename.empty() ) + return new detail::CoutStream(); + else if( filename[0] == '%' ) { + if( filename == "%debug" ) + return new detail::DebugOutStream(); + else + CATCH_ERROR( "Unrecognised stream: '" << filename << "'" ); + } + else + return new detail::FileStream( filename ); + } + + // This class encapsulates the idea of a pool of ostringstreams that can be reused. + struct StringStreams { + std::vector> m_streams; + std::vector m_unused; + std::ostringstream m_referenceStream; // Used for copy state/ flags from + + auto add() -> std::size_t { + if( m_unused.empty() ) { + m_streams.push_back( std::unique_ptr( new std::ostringstream ) ); + return m_streams.size()-1; + } + else { + auto index = m_unused.back(); + m_unused.pop_back(); + return index; + } + } + + void release( std::size_t index ) { + m_streams[index]->copyfmt( m_referenceStream ); // Restore initial flags and other state + m_unused.push_back(index); + } + }; + + ReusableStringStream::ReusableStringStream() + : m_index( Singleton::getMutable().add() ), + m_oss( Singleton::getMutable().m_streams[m_index].get() ) + {} + + ReusableStringStream::~ReusableStringStream() { + static_cast( m_oss )->str(""); + m_oss->clear(); + Singleton::getMutable().release( m_index ); + } + + auto ReusableStringStream::str() const -> std::string { + return static_cast( m_oss )->str(); + } + + /////////////////////////////////////////////////////////////////////////// + +#ifndef CATCH_CONFIG_NOSTDOUT // If you #define this you must implement these functions + std::ostream& cout() { return std::cout; } + std::ostream& cerr() { return std::cerr; } + std::ostream& clog() { return std::clog; } +#endif +} +// end catch_stream.cpp +// start catch_string_manip.cpp + +#include +#include +#include +#include + +namespace Catch { + + namespace { + char toLowerCh(char c) { + return static_cast( std::tolower( c ) ); + } + } + + bool startsWith( std::string const& s, std::string const& prefix ) { + return s.size() >= prefix.size() && std::equal(prefix.begin(), prefix.end(), s.begin()); + } + bool startsWith( std::string const& s, char prefix ) { + return !s.empty() && s[0] == prefix; + } + bool endsWith( std::string const& s, std::string const& suffix ) { + return s.size() >= suffix.size() && std::equal(suffix.rbegin(), suffix.rend(), s.rbegin()); + } + bool endsWith( std::string const& s, char suffix ) { + return !s.empty() && s[s.size()-1] == suffix; + } + bool contains( std::string const& s, std::string const& infix ) { + return s.find( infix ) != std::string::npos; + } + void toLowerInPlace( std::string& s ) { + std::transform( s.begin(), s.end(), s.begin(), toLowerCh ); + } + std::string toLower( std::string const& s ) { + std::string lc = s; + toLowerInPlace( lc ); + return lc; + } + std::string trim( std::string const& str ) { + static char const* whitespaceChars = "\n\r\t "; + std::string::size_type start = str.find_first_not_of( whitespaceChars ); + std::string::size_type end = str.find_last_not_of( whitespaceChars ); + + return start != std::string::npos ? str.substr( start, 1+end-start ) : std::string(); + } + + bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ) { + bool replaced = false; + std::size_t i = str.find( replaceThis ); + while( i != std::string::npos ) { + replaced = true; + str = str.substr( 0, i ) + withThis + str.substr( i+replaceThis.size() ); + if( i < str.size()-withThis.size() ) + i = str.find( replaceThis, i+withThis.size() ); + else + i = std::string::npos; + } + return replaced; + } + + pluralise::pluralise( std::size_t count, std::string const& label ) + : m_count( count ), + m_label( label ) + {} + + std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser ) { + os << pluraliser.m_count << ' ' << pluraliser.m_label; + if( pluraliser.m_count != 1 ) + os << 's'; + return os; + } + +} +// end catch_string_manip.cpp +// start catch_stringref.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +#include +#include +#include + +namespace { + const uint32_t byte_2_lead = 0xC0; + const uint32_t byte_3_lead = 0xE0; + const uint32_t byte_4_lead = 0xF0; +} + +namespace Catch { + StringRef::StringRef( char const* rawChars ) noexcept + : StringRef( rawChars, static_cast(std::strlen(rawChars) ) ) + {} + + StringRef::operator std::string() const { + return std::string( m_start, m_size ); + } + + void StringRef::swap( StringRef& other ) noexcept { + std::swap( m_start, other.m_start ); + std::swap( m_size, other.m_size ); + std::swap( m_data, other.m_data ); + } + + auto StringRef::c_str() const -> char const* { + if( isSubstring() ) + const_cast( this )->takeOwnership(); + return m_start; + } + auto StringRef::currentData() const noexcept -> char const* { + return m_start; + } + + auto StringRef::isOwned() const noexcept -> bool { + return m_data != nullptr; + } + auto StringRef::isSubstring() const noexcept -> bool { + return m_start[m_size] != '\0'; + } + + void StringRef::takeOwnership() { + if( !isOwned() ) { + m_data = new char[m_size+1]; + memcpy( m_data, m_start, m_size ); + m_data[m_size] = '\0'; + m_start = m_data; + } + } + auto StringRef::substr( size_type start, size_type size ) const noexcept -> StringRef { + if( start < m_size ) + return StringRef( m_start+start, size ); + else + return StringRef(); + } + auto StringRef::operator == ( StringRef const& other ) const noexcept -> bool { + return + size() == other.size() && + (std::strncmp( m_start, other.m_start, size() ) == 0); + } + auto StringRef::operator != ( StringRef const& other ) const noexcept -> bool { + return !operator==( other ); + } + + auto StringRef::operator[](size_type index) const noexcept -> char { + return m_start[index]; + } + + auto StringRef::numberOfCharacters() const noexcept -> size_type { + size_type noChars = m_size; + // Make adjustments for uft encodings + for( size_type i=0; i < m_size; ++i ) { + char c = m_start[i]; + if( ( c & byte_2_lead ) == byte_2_lead ) { + noChars--; + if (( c & byte_3_lead ) == byte_3_lead ) + noChars--; + if( ( c & byte_4_lead ) == byte_4_lead ) + noChars--; + } + } + return noChars; + } + + auto operator + ( StringRef const& lhs, StringRef const& rhs ) -> std::string { + std::string str; + str.reserve( lhs.size() + rhs.size() ); + str += lhs; + str += rhs; + return str; + } + auto operator + ( StringRef const& lhs, const char* rhs ) -> std::string { + return std::string( lhs ) + std::string( rhs ); + } + auto operator + ( char const* lhs, StringRef const& rhs ) -> std::string { + return std::string( lhs ) + std::string( rhs ); + } + + auto operator << ( std::ostream& os, StringRef const& str ) -> std::ostream& { + return os.write(str.currentData(), str.size()); + } + + auto operator+=( std::string& lhs, StringRef const& rhs ) -> std::string& { + lhs.append(rhs.currentData(), rhs.size()); + return lhs; + } + +} // namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif +// end catch_stringref.cpp +// start catch_tag_alias.cpp + +namespace Catch { + TagAlias::TagAlias(std::string const & _tag, SourceLineInfo _lineInfo): tag(_tag), lineInfo(_lineInfo) {} +} +// end catch_tag_alias.cpp +// start catch_tag_alias_autoregistrar.cpp + +namespace Catch { + + RegistrarForTagAliases::RegistrarForTagAliases(char const* alias, char const* tag, SourceLineInfo const& lineInfo) { + CATCH_TRY { + getMutableRegistryHub().registerTagAlias(alias, tag, lineInfo); + } CATCH_CATCH_ALL { + // Do not throw when constructing global objects, instead register the exception to be processed later + getMutableRegistryHub().registerStartupException(); + } + } + +} +// end catch_tag_alias_autoregistrar.cpp +// start catch_tag_alias_registry.cpp + +#include + +namespace Catch { + + TagAliasRegistry::~TagAliasRegistry() {} + + TagAlias const* TagAliasRegistry::find( std::string const& alias ) const { + auto it = m_registry.find( alias ); + if( it != m_registry.end() ) + return &(it->second); + else + return nullptr; + } + + std::string TagAliasRegistry::expandAliases( std::string const& unexpandedTestSpec ) const { + std::string expandedTestSpec = unexpandedTestSpec; + for( auto const& registryKvp : m_registry ) { + std::size_t pos = expandedTestSpec.find( registryKvp.first ); + if( pos != std::string::npos ) { + expandedTestSpec = expandedTestSpec.substr( 0, pos ) + + registryKvp.second.tag + + expandedTestSpec.substr( pos + registryKvp.first.size() ); + } + } + return expandedTestSpec; + } + + void TagAliasRegistry::add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) { + CATCH_ENFORCE( startsWith(alias, "[@") && endsWith(alias, ']'), + "error: tag alias, '" << alias << "' is not of the form [@alias name].\n" << lineInfo ); + + CATCH_ENFORCE( m_registry.insert(std::make_pair(alias, TagAlias(tag, lineInfo))).second, + "error: tag alias, '" << alias << "' already registered.\n" + << "\tFirst seen at: " << find(alias)->lineInfo << "\n" + << "\tRedefined at: " << lineInfo ); + } + + ITagAliasRegistry::~ITagAliasRegistry() {} + + ITagAliasRegistry const& ITagAliasRegistry::get() { + return getRegistryHub().getTagAliasRegistry(); + } + +} // end namespace Catch +// end catch_tag_alias_registry.cpp +// start catch_test_case_info.cpp + +#include +#include +#include +#include + +namespace Catch { + + namespace { + TestCaseInfo::SpecialProperties parseSpecialTag( std::string const& tag ) { + if( startsWith( tag, '.' ) || + tag == "!hide" ) + return TestCaseInfo::IsHidden; + else if( tag == "!throws" ) + return TestCaseInfo::Throws; + else if( tag == "!shouldfail" ) + return TestCaseInfo::ShouldFail; + else if( tag == "!mayfail" ) + return TestCaseInfo::MayFail; + else if( tag == "!nonportable" ) + return TestCaseInfo::NonPortable; + else if( tag == "!benchmark" ) + return static_cast( TestCaseInfo::Benchmark | TestCaseInfo::IsHidden ); + else + return TestCaseInfo::None; + } + bool isReservedTag( std::string const& tag ) { + return parseSpecialTag( tag ) == TestCaseInfo::None && tag.size() > 0 && !std::isalnum( static_cast(tag[0]) ); + } + void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) { + CATCH_ENFORCE( !isReservedTag(tag), + "Tag name: [" << tag << "] is not allowed.\n" + << "Tag names starting with non alpha-numeric characters are reserved\n" + << _lineInfo ); + } + } + + TestCase makeTestCase( ITestInvoker* _testCase, + std::string const& _className, + NameAndTags const& nameAndTags, + SourceLineInfo const& _lineInfo ) + { + bool isHidden = false; + + // Parse out tags + std::vector tags; + std::string desc, tag; + bool inTag = false; + std::string _descOrTags = nameAndTags.tags; + for (char c : _descOrTags) { + if( !inTag ) { + if( c == '[' ) + inTag = true; + else + desc += c; + } + else { + if( c == ']' ) { + TestCaseInfo::SpecialProperties prop = parseSpecialTag( tag ); + if( ( prop & TestCaseInfo::IsHidden ) != 0 ) + isHidden = true; + else if( prop == TestCaseInfo::None ) + enforceNotReservedTag( tag, _lineInfo ); + + tags.push_back( tag ); + tag.clear(); + inTag = false; + } + else + tag += c; + } + } + if( isHidden ) { + tags.push_back( "." ); + } + + TestCaseInfo info( nameAndTags.name, _className, desc, tags, _lineInfo ); + return TestCase( _testCase, std::move(info) ); + } + + void setTags( TestCaseInfo& testCaseInfo, std::vector tags ) { + std::sort(begin(tags), end(tags)); + tags.erase(std::unique(begin(tags), end(tags)), end(tags)); + testCaseInfo.lcaseTags.clear(); + + for( auto const& tag : tags ) { + std::string lcaseTag = toLower( tag ); + testCaseInfo.properties = static_cast( testCaseInfo.properties | parseSpecialTag( lcaseTag ) ); + testCaseInfo.lcaseTags.push_back( lcaseTag ); + } + testCaseInfo.tags = std::move(tags); + } + + TestCaseInfo::TestCaseInfo( std::string const& _name, + std::string const& _className, + std::string const& _description, + std::vector const& _tags, + SourceLineInfo const& _lineInfo ) + : name( _name ), + className( _className ), + description( _description ), + lineInfo( _lineInfo ), + properties( None ) + { + setTags( *this, _tags ); + } + + bool TestCaseInfo::isHidden() const { + return ( properties & IsHidden ) != 0; + } + bool TestCaseInfo::throws() const { + return ( properties & Throws ) != 0; + } + bool TestCaseInfo::okToFail() const { + return ( properties & (ShouldFail | MayFail ) ) != 0; + } + bool TestCaseInfo::expectedToFail() const { + return ( properties & (ShouldFail ) ) != 0; + } + + std::string TestCaseInfo::tagsAsString() const { + std::string ret; + // '[' and ']' per tag + std::size_t full_size = 2 * tags.size(); + for (const auto& tag : tags) { + full_size += tag.size(); + } + ret.reserve(full_size); + for (const auto& tag : tags) { + ret.push_back('['); + ret.append(tag); + ret.push_back(']'); + } + + return ret; + } + + TestCase::TestCase( ITestInvoker* testCase, TestCaseInfo&& info ) : TestCaseInfo( std::move(info) ), test( testCase ) {} + + TestCase TestCase::withName( std::string const& _newName ) const { + TestCase other( *this ); + other.name = _newName; + return other; + } + + void TestCase::invoke() const { + test->invoke(); + } + + bool TestCase::operator == ( TestCase const& other ) const { + return test.get() == other.test.get() && + name == other.name && + className == other.className; + } + + bool TestCase::operator < ( TestCase const& other ) const { + return name < other.name; + } + + TestCaseInfo const& TestCase::getTestCaseInfo() const + { + return *this; + } + +} // end namespace Catch +// end catch_test_case_info.cpp +// start catch_test_case_registry_impl.cpp + +#include + +namespace Catch { + + std::vector sortTests( IConfig const& config, std::vector const& unsortedTestCases ) { + + std::vector sorted = unsortedTestCases; + + switch( config.runOrder() ) { + case RunTests::InLexicographicalOrder: + std::sort( sorted.begin(), sorted.end() ); + break; + case RunTests::InRandomOrder: + seedRng( config ); + std::shuffle( sorted.begin(), sorted.end(), rng() ); + break; + case RunTests::InDeclarationOrder: + // already in declaration order + break; + } + return sorted; + } + bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ) { + return testSpec.matches( testCase ) && ( config.allowThrows() || !testCase.throws() ); + } + + void enforceNoDuplicateTestCases( std::vector const& functions ) { + std::set seenFunctions; + for( auto const& function : functions ) { + auto prev = seenFunctions.insert( function ); + CATCH_ENFORCE( prev.second, + "error: TEST_CASE( \"" << function.name << "\" ) already defined.\n" + << "\tFirst seen at " << prev.first->getTestCaseInfo().lineInfo << "\n" + << "\tRedefined at " << function.getTestCaseInfo().lineInfo ); + } + } + + std::vector filterTests( std::vector const& testCases, TestSpec const& testSpec, IConfig const& config ) { + std::vector filtered; + filtered.reserve( testCases.size() ); + for( auto const& testCase : testCases ) + if( matchTest( testCase, testSpec, config ) ) + filtered.push_back( testCase ); + return filtered; + } + std::vector const& getAllTestCasesSorted( IConfig const& config ) { + return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config ); + } + + void TestRegistry::registerTest( TestCase const& testCase ) { + std::string name = testCase.getTestCaseInfo().name; + if( name.empty() ) { + ReusableStringStream rss; + rss << "Anonymous test case " << ++m_unnamedCount; + return registerTest( testCase.withName( rss.str() ) ); + } + m_functions.push_back( testCase ); + } + + std::vector const& TestRegistry::getAllTests() const { + return m_functions; + } + std::vector const& TestRegistry::getAllTestsSorted( IConfig const& config ) const { + if( m_sortedFunctions.empty() ) + enforceNoDuplicateTestCases( m_functions ); + + if( m_currentSortOrder != config.runOrder() || m_sortedFunctions.empty() ) { + m_sortedFunctions = sortTests( config, m_functions ); + m_currentSortOrder = config.runOrder(); + } + return m_sortedFunctions; + } + + /////////////////////////////////////////////////////////////////////////// + TestInvokerAsFunction::TestInvokerAsFunction( void(*testAsFunction)() ) noexcept : m_testAsFunction( testAsFunction ) {} + + void TestInvokerAsFunction::invoke() const { + m_testAsFunction(); + } + + std::string extractClassName( StringRef const& classOrQualifiedMethodName ) { + std::string className = classOrQualifiedMethodName; + if( startsWith( className, '&' ) ) + { + std::size_t lastColons = className.rfind( "::" ); + std::size_t penultimateColons = className.rfind( "::", lastColons-1 ); + if( penultimateColons == std::string::npos ) + penultimateColons = 1; + className = className.substr( penultimateColons, lastColons-penultimateColons ); + } + return className; + } + +} // end namespace Catch +// end catch_test_case_registry_impl.cpp +// start catch_test_case_tracker.cpp + +#include +#include +#include +#include +#include + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +#endif + +namespace Catch { +namespace TestCaseTracking { + + NameAndLocation::NameAndLocation( std::string const& _name, SourceLineInfo const& _location ) + : name( _name ), + location( _location ) + {} + + ITracker::~ITracker() = default; + + TrackerContext& TrackerContext::instance() { + static TrackerContext s_instance; + return s_instance; + } + + ITracker& TrackerContext::startRun() { + m_rootTracker = std::make_shared( NameAndLocation( "{root}", CATCH_INTERNAL_LINEINFO ), *this, nullptr ); + m_currentTracker = nullptr; + m_runState = Executing; + return *m_rootTracker; + } + + void TrackerContext::endRun() { + m_rootTracker.reset(); + m_currentTracker = nullptr; + m_runState = NotStarted; + } + + void TrackerContext::startCycle() { + m_currentTracker = m_rootTracker.get(); + m_runState = Executing; + } + void TrackerContext::completeCycle() { + m_runState = CompletedCycle; + } + + bool TrackerContext::completedCycle() const { + return m_runState == CompletedCycle; + } + ITracker& TrackerContext::currentTracker() { + return *m_currentTracker; + } + void TrackerContext::setCurrentTracker( ITracker* tracker ) { + m_currentTracker = tracker; + } + + TrackerBase::TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : m_nameAndLocation( nameAndLocation ), + m_ctx( ctx ), + m_parent( parent ) + {} + + NameAndLocation const& TrackerBase::nameAndLocation() const { + return m_nameAndLocation; + } + bool TrackerBase::isComplete() const { + return m_runState == CompletedSuccessfully || m_runState == Failed; + } + bool TrackerBase::isSuccessfullyCompleted() const { + return m_runState == CompletedSuccessfully; + } + bool TrackerBase::isOpen() const { + return m_runState != NotStarted && !isComplete(); + } + bool TrackerBase::hasChildren() const { + return !m_children.empty(); + } + + void TrackerBase::addChild( ITrackerPtr const& child ) { + m_children.push_back( child ); + } + + ITrackerPtr TrackerBase::findChild( NameAndLocation const& nameAndLocation ) { + auto it = std::find_if( m_children.begin(), m_children.end(), + [&nameAndLocation]( ITrackerPtr const& tracker ){ + return + tracker->nameAndLocation().location == nameAndLocation.location && + tracker->nameAndLocation().name == nameAndLocation.name; + } ); + return( it != m_children.end() ) + ? *it + : nullptr; + } + ITracker& TrackerBase::parent() { + assert( m_parent ); // Should always be non-null except for root + return *m_parent; + } + + void TrackerBase::openChild() { + if( m_runState != ExecutingChildren ) { + m_runState = ExecutingChildren; + if( m_parent ) + m_parent->openChild(); + } + } + + bool TrackerBase::isSectionTracker() const { return false; } + bool TrackerBase::isGeneratorTracker() const { return false; } + + void TrackerBase::open() { + m_runState = Executing; + moveToThis(); + if( m_parent ) + m_parent->openChild(); + } + + void TrackerBase::close() { + + // Close any still open children (e.g. generators) + while( &m_ctx.currentTracker() != this ) + m_ctx.currentTracker().close(); + + switch( m_runState ) { + case NeedsAnotherRun: + break; + + case Executing: + m_runState = CompletedSuccessfully; + break; + case ExecutingChildren: + if( m_children.empty() || m_children.back()->isComplete() ) + m_runState = CompletedSuccessfully; + break; + + case NotStarted: + case CompletedSuccessfully: + case Failed: + CATCH_INTERNAL_ERROR( "Illogical state: " << m_runState ); + + default: + CATCH_INTERNAL_ERROR( "Unknown state: " << m_runState ); + } + moveToParent(); + m_ctx.completeCycle(); + } + void TrackerBase::fail() { + m_runState = Failed; + if( m_parent ) + m_parent->markAsNeedingAnotherRun(); + moveToParent(); + m_ctx.completeCycle(); + } + void TrackerBase::markAsNeedingAnotherRun() { + m_runState = NeedsAnotherRun; + } + + void TrackerBase::moveToParent() { + assert( m_parent ); + m_ctx.setCurrentTracker( m_parent ); + } + void TrackerBase::moveToThis() { + m_ctx.setCurrentTracker( this ); + } + + SectionTracker::SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent ) + : TrackerBase( nameAndLocation, ctx, parent ) + { + if( parent ) { + while( !parent->isSectionTracker() ) + parent = &parent->parent(); + + SectionTracker& parentSection = static_cast( *parent ); + addNextFilters( parentSection.m_filters ); + } + } + + bool SectionTracker::isComplete() const { + bool complete = true; + + if ((m_filters.empty() || m_filters[0] == "") || + std::find(m_filters.begin(), m_filters.end(), + m_nameAndLocation.name) != m_filters.end()) + complete = TrackerBase::isComplete(); + return complete; + + } + + bool SectionTracker::isSectionTracker() const { return true; } + + SectionTracker& SectionTracker::acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ) { + std::shared_ptr section; + + ITracker& currentTracker = ctx.currentTracker(); + if( ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) { + assert( childTracker ); + assert( childTracker->isSectionTracker() ); + section = std::static_pointer_cast( childTracker ); + } + else { + section = std::make_shared( nameAndLocation, ctx, ¤tTracker ); + currentTracker.addChild( section ); + } + if( !ctx.completedCycle() ) + section->tryOpen(); + return *section; + } + + void SectionTracker::tryOpen() { + if( !isComplete() && (m_filters.empty() || m_filters[0].empty() || m_filters[0] == m_nameAndLocation.name ) ) + open(); + } + + void SectionTracker::addInitialFilters( std::vector const& filters ) { + if( !filters.empty() ) { + m_filters.push_back(""); // Root - should never be consulted + m_filters.push_back(""); // Test Case - not a section filter + m_filters.insert( m_filters.end(), filters.begin(), filters.end() ); + } + } + void SectionTracker::addNextFilters( std::vector const& filters ) { + if( filters.size() > 1 ) + m_filters.insert( m_filters.end(), ++filters.begin(), filters.end() ); + } + +} // namespace TestCaseTracking + +using TestCaseTracking::ITracker; +using TestCaseTracking::TrackerContext; +using TestCaseTracking::SectionTracker; + +} // namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif +// end catch_test_case_tracker.cpp +// start catch_test_registry.cpp + +namespace Catch { + + auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker* { + return new(std::nothrow) TestInvokerAsFunction( testAsFunction ); + } + + NameAndTags::NameAndTags( StringRef const& name_ , StringRef const& tags_ ) noexcept : name( name_ ), tags( tags_ ) {} + + AutoReg::AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept { + CATCH_TRY { + getMutableRegistryHub() + .registerTest( + makeTestCase( + invoker, + extractClassName( classOrMethod ), + nameAndTags, + lineInfo)); + } CATCH_CATCH_ALL { + // Do not throw when constructing global objects, instead register the exception to be processed later + getMutableRegistryHub().registerStartupException(); + } + } + + AutoReg::~AutoReg() = default; +} +// end catch_test_registry.cpp +// start catch_test_spec.cpp + +#include +#include +#include +#include + +namespace Catch { + + TestSpec::Pattern::~Pattern() = default; + TestSpec::NamePattern::~NamePattern() = default; + TestSpec::TagPattern::~TagPattern() = default; + TestSpec::ExcludedPattern::~ExcludedPattern() = default; + + TestSpec::NamePattern::NamePattern( std::string const& name ) + : m_wildcardPattern( toLower( name ), CaseSensitive::No ) + {} + bool TestSpec::NamePattern::matches( TestCaseInfo const& testCase ) const { + return m_wildcardPattern.matches( toLower( testCase.name ) ); + } + + TestSpec::TagPattern::TagPattern( std::string const& tag ) : m_tag( toLower( tag ) ) {} + bool TestSpec::TagPattern::matches( TestCaseInfo const& testCase ) const { + return std::find(begin(testCase.lcaseTags), + end(testCase.lcaseTags), + m_tag) != end(testCase.lcaseTags); + } + + TestSpec::ExcludedPattern::ExcludedPattern( PatternPtr const& underlyingPattern ) : m_underlyingPattern( underlyingPattern ) {} + bool TestSpec::ExcludedPattern::matches( TestCaseInfo const& testCase ) const { return !m_underlyingPattern->matches( testCase ); } + + bool TestSpec::Filter::matches( TestCaseInfo const& testCase ) const { + // All patterns in a filter must match for the filter to be a match + for( auto const& pattern : m_patterns ) { + if( !pattern->matches( testCase ) ) + return false; + } + return true; + } + + bool TestSpec::hasFilters() const { + return !m_filters.empty(); + } + bool TestSpec::matches( TestCaseInfo const& testCase ) const { + // A TestSpec matches if any filter matches + for( auto const& filter : m_filters ) + if( filter.matches( testCase ) ) + return true; + return false; + } +} +// end catch_test_spec.cpp +// start catch_test_spec_parser.cpp + +namespace Catch { + + TestSpecParser::TestSpecParser( ITagAliasRegistry const& tagAliases ) : m_tagAliases( &tagAliases ) {} + + TestSpecParser& TestSpecParser::parse( std::string const& arg ) { + m_mode = None; + m_exclusion = false; + m_start = std::string::npos; + m_arg = m_tagAliases->expandAliases( arg ); + m_escapeChars.clear(); + for( m_pos = 0; m_pos < m_arg.size(); ++m_pos ) + visitChar( m_arg[m_pos] ); + if( m_mode == Name ) + addPattern(); + return *this; + } + TestSpec TestSpecParser::testSpec() { + addFilter(); + return m_testSpec; + } + + void TestSpecParser::visitChar( char c ) { + if( m_mode == None ) { + switch( c ) { + case ' ': return; + case '~': m_exclusion = true; return; + case '[': return startNewMode( Tag, ++m_pos ); + case '"': return startNewMode( QuotedName, ++m_pos ); + case '\\': return escape(); + default: startNewMode( Name, m_pos ); break; + } + } + if( m_mode == Name ) { + if( c == ',' ) { + addPattern(); + addFilter(); + } + else if( c == '[' ) { + if( subString() == "exclude:" ) + m_exclusion = true; + else + addPattern(); + startNewMode( Tag, ++m_pos ); + } + else if( c == '\\' ) + escape(); + } + else if( m_mode == EscapedName ) + m_mode = Name; + else if( m_mode == QuotedName && c == '"' ) + addPattern(); + else if( m_mode == Tag && c == ']' ) + addPattern(); + } + void TestSpecParser::startNewMode( Mode mode, std::size_t start ) { + m_mode = mode; + m_start = start; + } + void TestSpecParser::escape() { + if( m_mode == None ) + m_start = m_pos; + m_mode = EscapedName; + m_escapeChars.push_back( m_pos ); + } + std::string TestSpecParser::subString() const { return m_arg.substr( m_start, m_pos - m_start ); } + + void TestSpecParser::addFilter() { + if( !m_currentFilter.m_patterns.empty() ) { + m_testSpec.m_filters.push_back( m_currentFilter ); + m_currentFilter = TestSpec::Filter(); + } + } + + TestSpec parseTestSpec( std::string const& arg ) { + return TestSpecParser( ITagAliasRegistry::get() ).parse( arg ).testSpec(); + } + +} // namespace Catch +// end catch_test_spec_parser.cpp +// start catch_timer.cpp + +#include + +static const uint64_t nanosecondsInSecond = 1000000000; + +namespace Catch { + + auto getCurrentNanosecondsSinceEpoch() -> uint64_t { + return std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch() ).count(); + } + + namespace { + auto estimateClockResolution() -> uint64_t { + uint64_t sum = 0; + static const uint64_t iterations = 1000000; + + auto startTime = getCurrentNanosecondsSinceEpoch(); + + for( std::size_t i = 0; i < iterations; ++i ) { + + uint64_t ticks; + uint64_t baseTicks = getCurrentNanosecondsSinceEpoch(); + do { + ticks = getCurrentNanosecondsSinceEpoch(); + } while( ticks == baseTicks ); + + auto delta = ticks - baseTicks; + sum += delta; + + // If we have been calibrating for over 3 seconds -- the clock + // is terrible and we should move on. + // TBD: How to signal that the measured resolution is probably wrong? + if (ticks > startTime + 3 * nanosecondsInSecond) { + return sum / ( i + 1u ); + } + } + + // We're just taking the mean, here. To do better we could take the std. dev and exclude outliers + // - and potentially do more iterations if there's a high variance. + return sum/iterations; + } + } + auto getEstimatedClockResolution() -> uint64_t { + static auto s_resolution = estimateClockResolution(); + return s_resolution; + } + + void Timer::start() { + m_nanoseconds = getCurrentNanosecondsSinceEpoch(); + } + auto Timer::getElapsedNanoseconds() const -> uint64_t { + return getCurrentNanosecondsSinceEpoch() - m_nanoseconds; + } + auto Timer::getElapsedMicroseconds() const -> uint64_t { + return getElapsedNanoseconds()/1000; + } + auto Timer::getElapsedMilliseconds() const -> unsigned int { + return static_cast(getElapsedMicroseconds()/1000); + } + auto Timer::getElapsedSeconds() const -> double { + return getElapsedMicroseconds()/1000000.0; + } + +} // namespace Catch +// end catch_timer.cpp +// start catch_tostring.cpp + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +# pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +// Enable specific decls locally +#if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER) +#define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER +#endif + +#include +#include + +namespace Catch { + +namespace Detail { + + const std::string unprintableString = "{?}"; + + namespace { + const int hexThreshold = 255; + + struct Endianness { + enum Arch { Big, Little }; + + static Arch which() { + union _{ + int asInt; + char asChar[sizeof (int)]; + } u; + + u.asInt = 1; + return ( u.asChar[sizeof(int)-1] == 1 ) ? Big : Little; + } + }; + } + + std::string rawMemoryToString( const void *object, std::size_t size ) { + // Reverse order for little endian architectures + int i = 0, end = static_cast( size ), inc = 1; + if( Endianness::which() == Endianness::Little ) { + i = end-1; + end = inc = -1; + } + + unsigned char const *bytes = static_cast(object); + ReusableStringStream rss; + rss << "0x" << std::setfill('0') << std::hex; + for( ; i != end; i += inc ) + rss << std::setw(2) << static_cast(bytes[i]); + return rss.str(); + } +} + +template +std::string fpToString( T value, int precision ) { + if (Catch::isnan(value)) { + return "nan"; + } + + ReusableStringStream rss; + rss << std::setprecision( precision ) + << std::fixed + << value; + std::string d = rss.str(); + std::size_t i = d.find_last_not_of( '0' ); + if( i != std::string::npos && i != d.size()-1 ) { + if( d[i] == '.' ) + i++; + d = d.substr( 0, i+1 ); + } + return d; +} + +//// ======================================================= //// +// +// Out-of-line defs for full specialization of StringMaker +// +//// ======================================================= //// + +std::string StringMaker::convert(const std::string& str) { + if (!getCurrentContext().getConfig()->showInvisibles()) { + return '"' + str + '"'; + } + + std::string s("\""); + for (char c : str) { + switch (c) { + case '\n': + s.append("\\n"); + break; + case '\t': + s.append("\\t"); + break; + default: + s.push_back(c); + break; + } + } + s.append("\""); + return s; +} + +#ifdef CATCH_CONFIG_CPP17_STRING_VIEW +std::string StringMaker::convert(std::string_view str) { + return ::Catch::Detail::stringify(std::string{ str }); +} +#endif + +std::string StringMaker::convert(char const* str) { + if (str) { + return ::Catch::Detail::stringify(std::string{ str }); + } else { + return{ "{null string}" }; + } +} +std::string StringMaker::convert(char* str) { + if (str) { + return ::Catch::Detail::stringify(std::string{ str }); + } else { + return{ "{null string}" }; + } +} + +#ifdef CATCH_CONFIG_WCHAR +std::string StringMaker::convert(const std::wstring& wstr) { + std::string s; + s.reserve(wstr.size()); + for (auto c : wstr) { + s += (c <= 0xff) ? static_cast(c) : '?'; + } + return ::Catch::Detail::stringify(s); +} + +# ifdef CATCH_CONFIG_CPP17_STRING_VIEW +std::string StringMaker::convert(std::wstring_view str) { + return StringMaker::convert(std::wstring(str)); +} +# endif + +std::string StringMaker::convert(wchar_t const * str) { + if (str) { + return ::Catch::Detail::stringify(std::wstring{ str }); + } else { + return{ "{null string}" }; + } +} +std::string StringMaker::convert(wchar_t * str) { + if (str) { + return ::Catch::Detail::stringify(std::wstring{ str }); + } else { + return{ "{null string}" }; + } +} +#endif + +std::string StringMaker::convert(int value) { + return ::Catch::Detail::stringify(static_cast(value)); +} +std::string StringMaker::convert(long value) { + return ::Catch::Detail::stringify(static_cast(value)); +} +std::string StringMaker::convert(long long value) { + ReusableStringStream rss; + rss << value; + if (value > Detail::hexThreshold) { + rss << " (0x" << std::hex << value << ')'; + } + return rss.str(); +} + +std::string StringMaker::convert(unsigned int value) { + return ::Catch::Detail::stringify(static_cast(value)); +} +std::string StringMaker::convert(unsigned long value) { + return ::Catch::Detail::stringify(static_cast(value)); +} +std::string StringMaker::convert(unsigned long long value) { + ReusableStringStream rss; + rss << value; + if (value > Detail::hexThreshold) { + rss << " (0x" << std::hex << value << ')'; + } + return rss.str(); +} + +std::string StringMaker::convert(bool b) { + return b ? "true" : "false"; +} + +std::string StringMaker::convert(signed char value) { + if (value == '\r') { + return "'\\r'"; + } else if (value == '\f') { + return "'\\f'"; + } else if (value == '\n') { + return "'\\n'"; + } else if (value == '\t') { + return "'\\t'"; + } else if ('\0' <= value && value < ' ') { + return ::Catch::Detail::stringify(static_cast(value)); + } else { + char chstr[] = "' '"; + chstr[1] = value; + return chstr; + } +} +std::string StringMaker::convert(char c) { + return ::Catch::Detail::stringify(static_cast(c)); +} +std::string StringMaker::convert(unsigned char c) { + return ::Catch::Detail::stringify(static_cast(c)); +} + +std::string StringMaker::convert(std::nullptr_t) { + return "nullptr"; +} + +std::string StringMaker::convert(float value) { + return fpToString(value, 5) + 'f'; +} +std::string StringMaker::convert(double value) { + return fpToString(value, 10); +} + +std::string ratio_string::symbol() { return "a"; } +std::string ratio_string::symbol() { return "f"; } +std::string ratio_string::symbol() { return "p"; } +std::string ratio_string::symbol() { return "n"; } +std::string ratio_string::symbol() { return "u"; } +std::string ratio_string::symbol() { return "m"; } + +} // end namespace Catch + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + +// end catch_tostring.cpp +// start catch_totals.cpp + +namespace Catch { + + Counts Counts::operator - ( Counts const& other ) const { + Counts diff; + diff.passed = passed - other.passed; + diff.failed = failed - other.failed; + diff.failedButOk = failedButOk - other.failedButOk; + return diff; + } + + Counts& Counts::operator += ( Counts const& other ) { + passed += other.passed; + failed += other.failed; + failedButOk += other.failedButOk; + return *this; + } + + std::size_t Counts::total() const { + return passed + failed + failedButOk; + } + bool Counts::allPassed() const { + return failed == 0 && failedButOk == 0; + } + bool Counts::allOk() const { + return failed == 0; + } + + Totals Totals::operator - ( Totals const& other ) const { + Totals diff; + diff.assertions = assertions - other.assertions; + diff.testCases = testCases - other.testCases; + return diff; + } + + Totals& Totals::operator += ( Totals const& other ) { + assertions += other.assertions; + testCases += other.testCases; + return *this; + } + + Totals Totals::delta( Totals const& prevTotals ) const { + Totals diff = *this - prevTotals; + if( diff.assertions.failed > 0 ) + ++diff.testCases.failed; + else if( diff.assertions.failedButOk > 0 ) + ++diff.testCases.failedButOk; + else + ++diff.testCases.passed; + return diff; + } + +} +// end catch_totals.cpp +// start catch_uncaught_exceptions.cpp + +#include + +namespace Catch { + bool uncaught_exceptions() { +#if defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) + return std::uncaught_exceptions() > 0; +#else + return std::uncaught_exception(); +#endif + } +} // end namespace Catch +// end catch_uncaught_exceptions.cpp +// start catch_version.cpp + +#include + +namespace Catch { + + Version::Version + ( unsigned int _majorVersion, + unsigned int _minorVersion, + unsigned int _patchNumber, + char const * const _branchName, + unsigned int _buildNumber ) + : majorVersion( _majorVersion ), + minorVersion( _minorVersion ), + patchNumber( _patchNumber ), + branchName( _branchName ), + buildNumber( _buildNumber ) + {} + + std::ostream& operator << ( std::ostream& os, Version const& version ) { + os << version.majorVersion << '.' + << version.minorVersion << '.' + << version.patchNumber; + // branchName is never null -> 0th char is \0 if it is empty + if (version.branchName[0]) { + os << '-' << version.branchName + << '.' << version.buildNumber; + } + return os; + } + + Version const& libraryVersion() { + static Version version( 2, 7, 0, "", 0 ); + return version; + } + +} +// end catch_version.cpp +// start catch_wildcard_pattern.cpp + +#include + +namespace Catch { + + WildcardPattern::WildcardPattern( std::string const& pattern, + CaseSensitive::Choice caseSensitivity ) + : m_caseSensitivity( caseSensitivity ), + m_pattern( adjustCase( pattern ) ) + { + if( startsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 1 ); + m_wildcard = WildcardAtStart; + } + if( endsWith( m_pattern, '*' ) ) { + m_pattern = m_pattern.substr( 0, m_pattern.size()-1 ); + m_wildcard = static_cast( m_wildcard | WildcardAtEnd ); + } + } + + bool WildcardPattern::matches( std::string const& str ) const { + switch( m_wildcard ) { + case NoWildcard: + return m_pattern == adjustCase( str ); + case WildcardAtStart: + return endsWith( adjustCase( str ), m_pattern ); + case WildcardAtEnd: + return startsWith( adjustCase( str ), m_pattern ); + case WildcardAtBothEnds: + return contains( adjustCase( str ), m_pattern ); + default: + CATCH_INTERNAL_ERROR( "Unknown enum" ); + } + } + + std::string WildcardPattern::adjustCase( std::string const& str ) const { + return m_caseSensitivity == CaseSensitive::No ? toLower( str ) : str; + } +} +// end catch_wildcard_pattern.cpp +// start catch_xmlwriter.cpp + +#include + +using uchar = unsigned char; + +namespace Catch { + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: http://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: http://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see http://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? + if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + (0x80 <= value && value < 0x800 && encBytes > 2) || + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) noexcept + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) noexcept { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + writeDeclaration(); + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + XmlWriter& XmlWriter::writeComment( std::string const& text ) { + ensureTagClosed(); + m_os << m_indent << ""; + m_needsNewline = true; + return *this; + } + + void XmlWriter::writeStylesheetRef( std::string const& url ) { + m_os << "\n"; + } + + XmlWriter& XmlWriter::writeBlankLine() { + ensureTagClosed(); + m_os << '\n'; + return *this; + } + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } +} +// end catch_xmlwriter.cpp +// start catch_reporter_bases.cpp + +#include +#include +#include +#include +#include + +namespace Catch { + void prepareExpandedExpression(AssertionResult& result) { + result.getExpandedExpression(); + } + + // Because formatting using c++ streams is stateful, drop down to C is required + // Alternatively we could use stringstream, but its performance is... not good. + std::string getFormattedDuration( double duration ) { + // Max exponent + 1 is required to represent the whole part + // + 1 for decimal point + // + 3 for the 3 decimal places + // + 1 for null terminator + const std::size_t maxDoubleSize = DBL_MAX_10_EXP + 1 + 1 + 3 + 1; + char buffer[maxDoubleSize]; + + // Save previous errno, to prevent sprintf from overwriting it + ErrnoGuard guard; +#ifdef _MSC_VER + sprintf_s(buffer, "%.3f", duration); +#else + std::sprintf(buffer, "%.3f", duration); +#endif + return std::string(buffer); + } + + TestEventListenerBase::TestEventListenerBase(ReporterConfig const & _config) + :StreamingReporterBase(_config) {} + + std::set TestEventListenerBase::getSupportedVerbosities() { + return { Verbosity::Quiet, Verbosity::Normal, Verbosity::High }; + } + + void TestEventListenerBase::assertionStarting(AssertionInfo const &) {} + + bool TestEventListenerBase::assertionEnded(AssertionStats const &) { + return false; + } + +} // end namespace Catch +// end catch_reporter_bases.cpp +// start catch_reporter_compact.cpp + +namespace { + +#ifdef CATCH_PLATFORM_MAC + const char* failedString() { return "FAILED"; } + const char* passedString() { return "PASSED"; } +#else + const char* failedString() { return "failed"; } + const char* passedString() { return "passed"; } +#endif + + // Colour::LightGrey + Catch::Colour::Code dimColour() { return Catch::Colour::FileName; } + + std::string bothOrAll( std::size_t count ) { + return count == 1 ? std::string() : + count == 2 ? "both " : "all " ; + } + +} // anon namespace + +namespace Catch { +namespace { +// Colour, message variants: +// - white: No tests ran. +// - red: Failed [both/all] N test cases, failed [both/all] M assertions. +// - white: Passed [both/all] N test cases (no assertions). +// - red: Failed N tests cases, failed M assertions. +// - green: Passed [both/all] N tests cases with M assertions. +void printTotals(std::ostream& out, const Totals& totals) { + if (totals.testCases.total() == 0) { + out << "No tests ran."; + } else if (totals.testCases.failed == totals.testCases.total()) { + Colour colour(Colour::ResultError); + const std::string qualify_assertions_failed = + totals.assertions.failed == totals.assertions.total() ? + bothOrAll(totals.assertions.failed) : std::string(); + out << + "Failed " << bothOrAll(totals.testCases.failed) + << pluralise(totals.testCases.failed, "test case") << ", " + "failed " << qualify_assertions_failed << + pluralise(totals.assertions.failed, "assertion") << '.'; + } else if (totals.assertions.total() == 0) { + out << + "Passed " << bothOrAll(totals.testCases.total()) + << pluralise(totals.testCases.total(), "test case") + << " (no assertions)."; + } else if (totals.assertions.failed) { + Colour colour(Colour::ResultError); + out << + "Failed " << pluralise(totals.testCases.failed, "test case") << ", " + "failed " << pluralise(totals.assertions.failed, "assertion") << '.'; + } else { + Colour colour(Colour::ResultSuccess); + out << + "Passed " << bothOrAll(totals.testCases.passed) + << pluralise(totals.testCases.passed, "test case") << + " with " << pluralise(totals.assertions.passed, "assertion") << '.'; + } +} + +// Implementation of CompactReporter formatting +class AssertionPrinter { +public: + AssertionPrinter& operator= (AssertionPrinter const&) = delete; + AssertionPrinter(AssertionPrinter const&) = delete; + AssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages) + : stream(_stream) + , result(_stats.assertionResult) + , messages(_stats.infoMessages) + , itMessage(_stats.infoMessages.begin()) + , printInfoMessages(_printInfoMessages) {} + + void print() { + printSourceInfo(); + + itMessage = messages.begin(); + + switch (result.getResultType()) { + case ResultWas::Ok: + printResultType(Colour::ResultSuccess, passedString()); + printOriginalExpression(); + printReconstructedExpression(); + if (!result.hasExpression()) + printRemainingMessages(Colour::None); + else + printRemainingMessages(); + break; + case ResultWas::ExpressionFailed: + if (result.isOk()) + printResultType(Colour::ResultSuccess, failedString() + std::string(" - but was ok")); + else + printResultType(Colour::Error, failedString()); + printOriginalExpression(); + printReconstructedExpression(); + printRemainingMessages(); + break; + case ResultWas::ThrewException: + printResultType(Colour::Error, failedString()); + printIssue("unexpected exception with message:"); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::FatalErrorCondition: + printResultType(Colour::Error, failedString()); + printIssue("fatal error condition with message:"); + printMessage(); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::DidntThrowException: + printResultType(Colour::Error, failedString()); + printIssue("expected exception, got none"); + printExpressionWas(); + printRemainingMessages(); + break; + case ResultWas::Info: + printResultType(Colour::None, "info"); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::Warning: + printResultType(Colour::None, "warning"); + printMessage(); + printRemainingMessages(); + break; + case ResultWas::ExplicitFailure: + printResultType(Colour::Error, failedString()); + printIssue("explicitly"); + printRemainingMessages(Colour::None); + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + printResultType(Colour::Error, "** internal error **"); + break; + } + } + +private: + void printSourceInfo() const { + Colour colourGuard(Colour::FileName); + stream << result.getSourceInfo() << ':'; + } + + void printResultType(Colour::Code colour, std::string const& passOrFail) const { + if (!passOrFail.empty()) { + { + Colour colourGuard(colour); + stream << ' ' << passOrFail; + } + stream << ':'; + } + } + + void printIssue(std::string const& issue) const { + stream << ' ' << issue; + } + + void printExpressionWas() { + if (result.hasExpression()) { + stream << ';'; + { + Colour colour(dimColour()); + stream << " expression was:"; + } + printOriginalExpression(); + } + } + + void printOriginalExpression() const { + if (result.hasExpression()) { + stream << ' ' << result.getExpression(); + } + } + + void printReconstructedExpression() const { + if (result.hasExpandedExpression()) { + { + Colour colour(dimColour()); + stream << " for: "; + } + stream << result.getExpandedExpression(); + } + } + + void printMessage() { + if (itMessage != messages.end()) { + stream << " '" << itMessage->message << '\''; + ++itMessage; + } + } + + void printRemainingMessages(Colour::Code colour = dimColour()) { + if (itMessage == messages.end()) + return; + + // using messages.end() directly yields (or auto) compilation error: + std::vector::const_iterator itEnd = messages.end(); + const std::size_t N = static_cast(std::distance(itMessage, itEnd)); + + { + Colour colourGuard(colour); + stream << " with " << pluralise(N, "message") << ':'; + } + + for (; itMessage != itEnd; ) { + // If this assertion is a warning ignore any INFO messages + if (printInfoMessages || itMessage->type != ResultWas::Info) { + stream << " '" << itMessage->message << '\''; + if (++itMessage != itEnd) { + Colour colourGuard(dimColour()); + stream << " and"; + } + } + } + } + +private: + std::ostream& stream; + AssertionResult const& result; + std::vector messages; + std::vector::const_iterator itMessage; + bool printInfoMessages; +}; + +} // anon namespace + + std::string CompactReporter::getDescription() { + return "Reports test results on a single line, suitable for IDEs"; + } + + ReporterPreferences CompactReporter::getPreferences() const { + return m_reporterPrefs; + } + + void CompactReporter::noMatchingTestCases( std::string const& spec ) { + stream << "No test cases matched '" << spec << '\'' << std::endl; + } + + void CompactReporter::assertionStarting( AssertionInfo const& ) {} + + bool CompactReporter::assertionEnded( AssertionStats const& _assertionStats ) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool printInfoMessages = true; + + // Drop out if result was successful and we're not printing those + if( !m_config->includeSuccessfulResults() && result.isOk() ) { + if( result.getResultType() != ResultWas::Warning ) + return false; + printInfoMessages = false; + } + + AssertionPrinter printer( stream, _assertionStats, printInfoMessages ); + printer.print(); + + stream << std::endl; + return true; + } + + void CompactReporter::sectionEnded(SectionStats const& _sectionStats) { + if (m_config->showDurations() == ShowDurations::Always) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + } + + void CompactReporter::testRunEnded( TestRunStats const& _testRunStats ) { + printTotals( stream, _testRunStats.totals ); + stream << '\n' << std::endl; + StreamingReporterBase::testRunEnded( _testRunStats ); + } + + CompactReporter::~CompactReporter() {} + + CATCH_REGISTER_REPORTER( "compact", CompactReporter ) + +} // end namespace Catch +// end catch_reporter_compact.cpp +// start catch_reporter_console.cpp + +#include +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + +namespace { + +// Formatter impl for ConsoleReporter +class ConsoleAssertionPrinter { +public: + ConsoleAssertionPrinter& operator= (ConsoleAssertionPrinter const&) = delete; + ConsoleAssertionPrinter(ConsoleAssertionPrinter const&) = delete; + ConsoleAssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages) + : stream(_stream), + stats(_stats), + result(_stats.assertionResult), + colour(Colour::None), + message(result.getMessage()), + messages(_stats.infoMessages), + printInfoMessages(_printInfoMessages) { + switch (result.getResultType()) { + case ResultWas::Ok: + colour = Colour::Success; + passOrFail = "PASSED"; + //if( result.hasMessage() ) + if (_stats.infoMessages.size() == 1) + messageLabel = "with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "with messages"; + break; + case ResultWas::ExpressionFailed: + if (result.isOk()) { + colour = Colour::Success; + passOrFail = "FAILED - but was ok"; + } else { + colour = Colour::Error; + passOrFail = "FAILED"; + } + if (_stats.infoMessages.size() == 1) + messageLabel = "with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "with messages"; + break; + case ResultWas::ThrewException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to unexpected exception with "; + if (_stats.infoMessages.size() == 1) + messageLabel += "message"; + if (_stats.infoMessages.size() > 1) + messageLabel += "messages"; + break; + case ResultWas::FatalErrorCondition: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "due to a fatal error condition"; + break; + case ResultWas::DidntThrowException: + colour = Colour::Error; + passOrFail = "FAILED"; + messageLabel = "because no exception was thrown where one was expected"; + break; + case ResultWas::Info: + messageLabel = "info"; + break; + case ResultWas::Warning: + messageLabel = "warning"; + break; + case ResultWas::ExplicitFailure: + passOrFail = "FAILED"; + colour = Colour::Error; + if (_stats.infoMessages.size() == 1) + messageLabel = "explicitly with message"; + if (_stats.infoMessages.size() > 1) + messageLabel = "explicitly with messages"; + break; + // These cases are here to prevent compiler warnings + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + passOrFail = "** internal error **"; + colour = Colour::Error; + break; + } + } + + void print() const { + printSourceInfo(); + if (stats.totals.assertions.total() > 0) { + printResultType(); + printOriginalExpression(); + printReconstructedExpression(); + } else { + stream << '\n'; + } + printMessage(); + } + +private: + void printResultType() const { + if (!passOrFail.empty()) { + Colour colourGuard(colour); + stream << passOrFail << ":\n"; + } + } + void printOriginalExpression() const { + if (result.hasExpression()) { + Colour colourGuard(Colour::OriginalExpression); + stream << " "; + stream << result.getExpressionInMacro(); + stream << '\n'; + } + } + void printReconstructedExpression() const { + if (result.hasExpandedExpression()) { + stream << "with expansion:\n"; + Colour colourGuard(Colour::ReconstructedExpression); + stream << Column(result.getExpandedExpression()).indent(2) << '\n'; + } + } + void printMessage() const { + if (!messageLabel.empty()) + stream << messageLabel << ':' << '\n'; + for (auto const& msg : messages) { + // If this assertion is a warning ignore any INFO messages + if (printInfoMessages || msg.type != ResultWas::Info) + stream << Column(msg.message).indent(2) << '\n'; + } + } + void printSourceInfo() const { + Colour colourGuard(Colour::FileName); + stream << result.getSourceInfo() << ": "; + } + + std::ostream& stream; + AssertionStats const& stats; + AssertionResult const& result; + Colour::Code colour; + std::string passOrFail; + std::string messageLabel; + std::string message; + std::vector messages; + bool printInfoMessages; +}; + +std::size_t makeRatio(std::size_t number, std::size_t total) { + std::size_t ratio = total > 0 ? CATCH_CONFIG_CONSOLE_WIDTH * number / total : 0; + return (ratio == 0 && number > 0) ? 1 : ratio; +} + +std::size_t& findMax(std::size_t& i, std::size_t& j, std::size_t& k) { + if (i > j && i > k) + return i; + else if (j > k) + return j; + else + return k; +} + +struct ColumnInfo { + enum Justification { Left, Right }; + std::string name; + int width; + Justification justification; +}; +struct ColumnBreak {}; +struct RowBreak {}; + +class Duration { + enum class Unit { + Auto, + Nanoseconds, + Microseconds, + Milliseconds, + Seconds, + Minutes + }; + static const uint64_t s_nanosecondsInAMicrosecond = 1000; + static const uint64_t s_nanosecondsInAMillisecond = 1000 * s_nanosecondsInAMicrosecond; + static const uint64_t s_nanosecondsInASecond = 1000 * s_nanosecondsInAMillisecond; + static const uint64_t s_nanosecondsInAMinute = 60 * s_nanosecondsInASecond; + + uint64_t m_inNanoseconds; + Unit m_units; + +public: + explicit Duration(uint64_t inNanoseconds, Unit units = Unit::Auto) + : m_inNanoseconds(inNanoseconds), + m_units(units) { + if (m_units == Unit::Auto) { + if (m_inNanoseconds < s_nanosecondsInAMicrosecond) + m_units = Unit::Nanoseconds; + else if (m_inNanoseconds < s_nanosecondsInAMillisecond) + m_units = Unit::Microseconds; + else if (m_inNanoseconds < s_nanosecondsInASecond) + m_units = Unit::Milliseconds; + else if (m_inNanoseconds < s_nanosecondsInAMinute) + m_units = Unit::Seconds; + else + m_units = Unit::Minutes; + } + + } + + auto value() const -> double { + switch (m_units) { + case Unit::Microseconds: + return m_inNanoseconds / static_cast(s_nanosecondsInAMicrosecond); + case Unit::Milliseconds: + return m_inNanoseconds / static_cast(s_nanosecondsInAMillisecond); + case Unit::Seconds: + return m_inNanoseconds / static_cast(s_nanosecondsInASecond); + case Unit::Minutes: + return m_inNanoseconds / static_cast(s_nanosecondsInAMinute); + default: + return static_cast(m_inNanoseconds); + } + } + auto unitsAsString() const -> std::string { + switch (m_units) { + case Unit::Nanoseconds: + return "ns"; + case Unit::Microseconds: + return "us"; + case Unit::Milliseconds: + return "ms"; + case Unit::Seconds: + return "s"; + case Unit::Minutes: + return "m"; + default: + return "** internal error **"; + } + + } + friend auto operator << (std::ostream& os, Duration const& duration) -> std::ostream& { + return os << duration.value() << " " << duration.unitsAsString(); + } +}; +} // end anon namespace + +class TablePrinter { + std::ostream& m_os; + std::vector m_columnInfos; + std::ostringstream m_oss; + int m_currentColumn = -1; + bool m_isOpen = false; + +public: + TablePrinter( std::ostream& os, std::vector columnInfos ) + : m_os( os ), + m_columnInfos( std::move( columnInfos ) ) {} + + auto columnInfos() const -> std::vector const& { + return m_columnInfos; + } + + void open() { + if (!m_isOpen) { + m_isOpen = true; + *this << RowBreak(); + for (auto const& info : m_columnInfos) + *this << info.name << ColumnBreak(); + *this << RowBreak(); + m_os << Catch::getLineOfChars<'-'>() << "\n"; + } + } + void close() { + if (m_isOpen) { + *this << RowBreak(); + m_os << std::endl; + m_isOpen = false; + } + } + + template + friend TablePrinter& operator << (TablePrinter& tp, T const& value) { + tp.m_oss << value; + return tp; + } + + friend TablePrinter& operator << (TablePrinter& tp, ColumnBreak) { + auto colStr = tp.m_oss.str(); + // This takes account of utf8 encodings + auto strSize = Catch::StringRef(colStr).numberOfCharacters(); + tp.m_oss.str(""); + tp.open(); + if (tp.m_currentColumn == static_cast(tp.m_columnInfos.size() - 1)) { + tp.m_currentColumn = -1; + tp.m_os << "\n"; + } + tp.m_currentColumn++; + + auto colInfo = tp.m_columnInfos[tp.m_currentColumn]; + auto padding = (strSize + 2 < static_cast(colInfo.width)) + ? std::string(colInfo.width - (strSize + 2), ' ') + : std::string(); + if (colInfo.justification == ColumnInfo::Left) + tp.m_os << colStr << padding << " "; + else + tp.m_os << padding << colStr << " "; + return tp; + } + + friend TablePrinter& operator << (TablePrinter& tp, RowBreak) { + if (tp.m_currentColumn > 0) { + tp.m_os << "\n"; + tp.m_currentColumn = -1; + } + return tp; + } +}; + +ConsoleReporter::ConsoleReporter(ReporterConfig const& config) + : StreamingReporterBase(config), + m_tablePrinter(new TablePrinter(config.stream(), + { + { "benchmark name", CATCH_CONFIG_CONSOLE_WIDTH - 32, ColumnInfo::Left }, + { "iters", 8, ColumnInfo::Right }, + { "elapsed ns", 14, ColumnInfo::Right }, + { "average", 14, ColumnInfo::Right } + })) {} +ConsoleReporter::~ConsoleReporter() = default; + +std::string ConsoleReporter::getDescription() { + return "Reports test results as plain lines of text"; +} + +void ConsoleReporter::noMatchingTestCases(std::string const& spec) { + stream << "No test cases matched '" << spec << '\'' << std::endl; +} + +void ConsoleReporter::assertionStarting(AssertionInfo const&) {} + +bool ConsoleReporter::assertionEnded(AssertionStats const& _assertionStats) { + AssertionResult const& result = _assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + // Drop out if result was successful but we're not printing them. + if (!includeResults && result.getResultType() != ResultWas::Warning) + return false; + + lazyPrint(); + + ConsoleAssertionPrinter printer(stream, _assertionStats, includeResults); + printer.print(); + stream << std::endl; + return true; +} + +void ConsoleReporter::sectionStarting(SectionInfo const& _sectionInfo) { + m_headerPrinted = false; + StreamingReporterBase::sectionStarting(_sectionInfo); +} +void ConsoleReporter::sectionEnded(SectionStats const& _sectionStats) { + m_tablePrinter->close(); + if (_sectionStats.missingAssertions) { + lazyPrint(); + Colour colour(Colour::ResultError); + if (m_sectionStack.size() > 1) + stream << "\nNo assertions in section"; + else + stream << "\nNo assertions in test case"; + stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl; + } + if (m_config->showDurations() == ShowDurations::Always) { + stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl; + } + if (m_headerPrinted) { + m_headerPrinted = false; + } + StreamingReporterBase::sectionEnded(_sectionStats); +} + +void ConsoleReporter::benchmarkStarting(BenchmarkInfo const& info) { + lazyPrintWithoutClosingBenchmarkTable(); + + auto nameCol = Column( info.name ).width( static_cast( m_tablePrinter->columnInfos()[0].width - 2 ) ); + + bool firstLine = true; + for (auto line : nameCol) { + if (!firstLine) + (*m_tablePrinter) << ColumnBreak() << ColumnBreak() << ColumnBreak(); + else + firstLine = false; + + (*m_tablePrinter) << line << ColumnBreak(); + } +} +void ConsoleReporter::benchmarkEnded(BenchmarkStats const& stats) { + Duration average(stats.elapsedTimeInNanoseconds / stats.iterations); + (*m_tablePrinter) + << stats.iterations << ColumnBreak() + << stats.elapsedTimeInNanoseconds << ColumnBreak() + << average << ColumnBreak(); +} + +void ConsoleReporter::testCaseEnded(TestCaseStats const& _testCaseStats) { + m_tablePrinter->close(); + StreamingReporterBase::testCaseEnded(_testCaseStats); + m_headerPrinted = false; +} +void ConsoleReporter::testGroupEnded(TestGroupStats const& _testGroupStats) { + if (currentGroupInfo.used) { + printSummaryDivider(); + stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n"; + printTotals(_testGroupStats.totals); + stream << '\n' << std::endl; + } + StreamingReporterBase::testGroupEnded(_testGroupStats); +} +void ConsoleReporter::testRunEnded(TestRunStats const& _testRunStats) { + printTotalsDivider(_testRunStats.totals); + printTotals(_testRunStats.totals); + stream << std::endl; + StreamingReporterBase::testRunEnded(_testRunStats); +} + +void ConsoleReporter::lazyPrint() { + + m_tablePrinter->close(); + lazyPrintWithoutClosingBenchmarkTable(); +} + +void ConsoleReporter::lazyPrintWithoutClosingBenchmarkTable() { + + if (!currentTestRunInfo.used) + lazyPrintRunInfo(); + if (!currentGroupInfo.used) + lazyPrintGroupInfo(); + + if (!m_headerPrinted) { + printTestCaseAndSectionHeader(); + m_headerPrinted = true; + } +} +void ConsoleReporter::lazyPrintRunInfo() { + stream << '\n' << getLineOfChars<'~'>() << '\n'; + Colour colour(Colour::SecondaryText); + stream << currentTestRunInfo->name + << " is a Catch v" << libraryVersion() << " host application.\n" + << "Run with -? for options\n\n"; + + if (m_config->rngSeed() != 0) + stream << "Randomness seeded to: " << m_config->rngSeed() << "\n\n"; + + currentTestRunInfo.used = true; +} +void ConsoleReporter::lazyPrintGroupInfo() { + if (!currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1) { + printClosedHeader("Group: " + currentGroupInfo->name); + currentGroupInfo.used = true; + } +} +void ConsoleReporter::printTestCaseAndSectionHeader() { + assert(!m_sectionStack.empty()); + printOpenHeader(currentTestCaseInfo->name); + + if (m_sectionStack.size() > 1) { + Colour colourGuard(Colour::Headers); + + auto + it = m_sectionStack.begin() + 1, // Skip first section (test case) + itEnd = m_sectionStack.end(); + for (; it != itEnd; ++it) + printHeaderString(it->name, 2); + } + + SourceLineInfo lineInfo = m_sectionStack.back().lineInfo; + + if (!lineInfo.empty()) { + stream << getLineOfChars<'-'>() << '\n'; + Colour colourGuard(Colour::FileName); + stream << lineInfo << '\n'; + } + stream << getLineOfChars<'.'>() << '\n' << std::endl; +} + +void ConsoleReporter::printClosedHeader(std::string const& _name) { + printOpenHeader(_name); + stream << getLineOfChars<'.'>() << '\n'; +} +void ConsoleReporter::printOpenHeader(std::string const& _name) { + stream << getLineOfChars<'-'>() << '\n'; + { + Colour colourGuard(Colour::Headers); + printHeaderString(_name); + } +} + +// if string has a : in first line will set indent to follow it on +// subsequent lines +void ConsoleReporter::printHeaderString(std::string const& _string, std::size_t indent) { + std::size_t i = _string.find(": "); + if (i != std::string::npos) + i += 2; + else + i = 0; + stream << Column(_string).indent(indent + i).initialIndent(indent) << '\n'; +} + +struct SummaryColumn { + + SummaryColumn( std::string _label, Colour::Code _colour ) + : label( std::move( _label ) ), + colour( _colour ) {} + SummaryColumn addRow( std::size_t count ) { + ReusableStringStream rss; + rss << count; + std::string row = rss.str(); + for (auto& oldRow : rows) { + while (oldRow.size() < row.size()) + oldRow = ' ' + oldRow; + while (oldRow.size() > row.size()) + row = ' ' + row; + } + rows.push_back(row); + return *this; + } + + std::string label; + Colour::Code colour; + std::vector rows; + +}; + +void ConsoleReporter::printTotals( Totals const& totals ) { + if (totals.testCases.total() == 0) { + stream << Colour(Colour::Warning) << "No tests ran\n"; + } else if (totals.assertions.total() > 0 && totals.testCases.allPassed()) { + stream << Colour(Colour::ResultSuccess) << "All tests passed"; + stream << " (" + << pluralise(totals.assertions.passed, "assertion") << " in " + << pluralise(totals.testCases.passed, "test case") << ')' + << '\n'; + } else { + + std::vector columns; + columns.push_back(SummaryColumn("", Colour::None) + .addRow(totals.testCases.total()) + .addRow(totals.assertions.total())); + columns.push_back(SummaryColumn("passed", Colour::Success) + .addRow(totals.testCases.passed) + .addRow(totals.assertions.passed)); + columns.push_back(SummaryColumn("failed", Colour::ResultError) + .addRow(totals.testCases.failed) + .addRow(totals.assertions.failed)); + columns.push_back(SummaryColumn("failed as expected", Colour::ResultExpectedFailure) + .addRow(totals.testCases.failedButOk) + .addRow(totals.assertions.failedButOk)); + + printSummaryRow("test cases", columns, 0); + printSummaryRow("assertions", columns, 1); + } +} +void ConsoleReporter::printSummaryRow(std::string const& label, std::vector const& cols, std::size_t row) { + for (auto col : cols) { + std::string value = col.rows[row]; + if (col.label.empty()) { + stream << label << ": "; + if (value != "0") + stream << value; + else + stream << Colour(Colour::Warning) << "- none -"; + } else if (value != "0") { + stream << Colour(Colour::LightGrey) << " | "; + stream << Colour(col.colour) + << value << ' ' << col.label; + } + } + stream << '\n'; +} + +void ConsoleReporter::printTotalsDivider(Totals const& totals) { + if (totals.testCases.total() > 0) { + std::size_t failedRatio = makeRatio(totals.testCases.failed, totals.testCases.total()); + std::size_t failedButOkRatio = makeRatio(totals.testCases.failedButOk, totals.testCases.total()); + std::size_t passedRatio = makeRatio(totals.testCases.passed, totals.testCases.total()); + while (failedRatio + failedButOkRatio + passedRatio < CATCH_CONFIG_CONSOLE_WIDTH - 1) + findMax(failedRatio, failedButOkRatio, passedRatio)++; + while (failedRatio + failedButOkRatio + passedRatio > CATCH_CONFIG_CONSOLE_WIDTH - 1) + findMax(failedRatio, failedButOkRatio, passedRatio)--; + + stream << Colour(Colour::Error) << std::string(failedRatio, '='); + stream << Colour(Colour::ResultExpectedFailure) << std::string(failedButOkRatio, '='); + if (totals.testCases.allPassed()) + stream << Colour(Colour::ResultSuccess) << std::string(passedRatio, '='); + else + stream << Colour(Colour::Success) << std::string(passedRatio, '='); + } else { + stream << Colour(Colour::Warning) << std::string(CATCH_CONFIG_CONSOLE_WIDTH - 1, '='); + } + stream << '\n'; +} +void ConsoleReporter::printSummaryDivider() { + stream << getLineOfChars<'-'>() << '\n'; +} + +CATCH_REGISTER_REPORTER("console", ConsoleReporter) + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +// end catch_reporter_console.cpp +// start catch_reporter_junit.cpp + +#include +#include +#include +#include + +namespace Catch { + + namespace { + std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + +#ifdef _MSC_VER + std::tm timeInfo = {}; + gmtime_s(&timeInfo, &rawtime); +#else + std::tm* timeInfo; + timeInfo = std::gmtime(&rawtime); +#endif + + char timeStamp[timeStampSize]; + const char * const fmt = "%Y-%m-%dT%H:%M:%SZ"; + +#ifdef _MSC_VER + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); +#else + std::strftime(timeStamp, timeStampSize, fmt, timeInfo); +#endif + return std::string(timeStamp); + } + + std::string fileNameTag(const std::vector &tags) { + auto it = std::find_if(begin(tags), + end(tags), + [] (std::string const& tag) {return tag.front() == '#'; }); + if (it != tags.end()) + return it->substr(1); + return std::string(); + } + } // anonymous namespace + + JunitReporter::JunitReporter( ReporterConfig const& _config ) + : CumulativeReporterBase( _config ), + xml( _config.stream() ) + { + m_reporterPrefs.shouldRedirectStdOut = true; + m_reporterPrefs.shouldReportAllAssertions = true; + } + + JunitReporter::~JunitReporter() {} + + std::string JunitReporter::getDescription() { + return "Reports test results in an XML format that looks like Ant's junitreport target"; + } + + void JunitReporter::noMatchingTestCases( std::string const& /*spec*/ ) {} + + void JunitReporter::testRunStarting( TestRunInfo const& runInfo ) { + CumulativeReporterBase::testRunStarting( runInfo ); + xml.startElement( "testsuites" ); + if( m_config->rngSeed() != 0 ) { + xml.startElement( "properties" ); + xml.scopedElement( "property" ) + .writeAttribute( "name", "random-seed" ) + .writeAttribute( "value", m_config->rngSeed() ); + xml.endElement(); + } + } + + void JunitReporter::testGroupStarting( GroupInfo const& groupInfo ) { + suiteTimer.start(); + stdOutForSuite.clear(); + stdErrForSuite.clear(); + unexpectedExceptions = 0; + CumulativeReporterBase::testGroupStarting( groupInfo ); + } + + void JunitReporter::testCaseStarting( TestCaseInfo const& testCaseInfo ) { + m_okToFail = testCaseInfo.okToFail(); + } + + bool JunitReporter::assertionEnded( AssertionStats const& assertionStats ) { + if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException && !m_okToFail ) + unexpectedExceptions++; + return CumulativeReporterBase::assertionEnded( assertionStats ); + } + + void JunitReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + stdOutForSuite += testCaseStats.stdOut; + stdErrForSuite += testCaseStats.stdErr; + CumulativeReporterBase::testCaseEnded( testCaseStats ); + } + + void JunitReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + double suiteTime = suiteTimer.getElapsedSeconds(); + CumulativeReporterBase::testGroupEnded( testGroupStats ); + writeGroup( *m_testGroups.back(), suiteTime ); + } + + void JunitReporter::testRunEndedCumulative() { + xml.endElement(); + } + + void JunitReporter::writeGroup( TestGroupNode const& groupNode, double suiteTime ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" ); + TestGroupStats const& stats = groupNode.value; + xml.writeAttribute( "name", stats.groupInfo.name ); + xml.writeAttribute( "errors", unexpectedExceptions ); + xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions ); + xml.writeAttribute( "tests", stats.totals.assertions.total() ); + xml.writeAttribute( "hostname", "tbd" ); // !TBD + if( m_config->showDurations() == ShowDurations::Never ) + xml.writeAttribute( "time", "" ); + else + xml.writeAttribute( "time", suiteTime ); + xml.writeAttribute( "timestamp", getCurrentTimestamp() ); + + // Write test cases + for( auto const& child : groupNode.children ) + writeTestCase( *child ); + + xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite ), false ); + xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite ), false ); + } + + void JunitReporter::writeTestCase( TestCaseNode const& testCaseNode ) { + TestCaseStats const& stats = testCaseNode.value; + + // All test cases have exactly one section - which represents the + // test case itself. That section may have 0-n nested sections + assert( testCaseNode.children.size() == 1 ); + SectionNode const& rootSection = *testCaseNode.children.front(); + + std::string className = stats.testInfo.className; + + if( className.empty() ) { + className = fileNameTag(stats.testInfo.tags); + if ( className.empty() ) + className = "global"; + } + + if ( !m_config->name().empty() ) + className = m_config->name() + "." + className; + + writeSection( className, "", rootSection ); + } + + void JunitReporter::writeSection( std::string const& className, + std::string const& rootName, + SectionNode const& sectionNode ) { + std::string name = trim( sectionNode.stats.sectionInfo.name ); + if( !rootName.empty() ) + name = rootName + '/' + name; + + if( !sectionNode.assertions.empty() || + !sectionNode.stdOut.empty() || + !sectionNode.stdErr.empty() ) { + XmlWriter::ScopedElement e = xml.scopedElement( "testcase" ); + if( className.empty() ) { + xml.writeAttribute( "classname", name ); + xml.writeAttribute( "name", "root" ); + } + else { + xml.writeAttribute( "classname", className ); + xml.writeAttribute( "name", name ); + } + xml.writeAttribute( "time", ::Catch::Detail::stringify( sectionNode.stats.durationInSeconds ) ); + + writeAssertions( sectionNode ); + + if( !sectionNode.stdOut.empty() ) + xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false ); + if( !sectionNode.stdErr.empty() ) + xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false ); + } + for( auto const& childNode : sectionNode.childSections ) + if( className.empty() ) + writeSection( name, "", *childNode ); + else + writeSection( className, name, *childNode ); + } + + void JunitReporter::writeAssertions( SectionNode const& sectionNode ) { + for( auto const& assertion : sectionNode.assertions ) + writeAssertion( assertion ); + } + + void JunitReporter::writeAssertion( AssertionStats const& stats ) { + AssertionResult const& result = stats.assertionResult; + if( !result.isOk() ) { + std::string elementName; + switch( result.getResultType() ) { + case ResultWas::ThrewException: + case ResultWas::FatalErrorCondition: + elementName = "error"; + break; + case ResultWas::ExplicitFailure: + elementName = "failure"; + break; + case ResultWas::ExpressionFailed: + elementName = "failure"; + break; + case ResultWas::DidntThrowException: + elementName = "failure"; + break; + + // We should never see these here: + case ResultWas::Info: + case ResultWas::Warning: + case ResultWas::Ok: + case ResultWas::Unknown: + case ResultWas::FailureBit: + case ResultWas::Exception: + elementName = "internalError"; + break; + } + + XmlWriter::ScopedElement e = xml.scopedElement( elementName ); + + xml.writeAttribute( "message", result.getExpandedExpression() ); + xml.writeAttribute( "type", result.getTestMacroName() ); + + ReusableStringStream rss; + if( !result.getMessage().empty() ) + rss << result.getMessage() << '\n'; + for( auto const& msg : stats.infoMessages ) + if( msg.type == ResultWas::Info ) + rss << msg.message << '\n'; + + rss << "at " << result.getSourceInfo(); + xml.writeText( rss.str(), false ); + } + } + + CATCH_REGISTER_REPORTER( "junit", JunitReporter ) + +} // end namespace Catch +// end catch_reporter_junit.cpp +// start catch_reporter_listening.cpp + +#include + +namespace Catch { + + ListeningReporter::ListeningReporter() { + // We will assume that listeners will always want all assertions + m_preferences.shouldReportAllAssertions = true; + } + + void ListeningReporter::addListener( IStreamingReporterPtr&& listener ) { + m_listeners.push_back( std::move( listener ) ); + } + + void ListeningReporter::addReporter(IStreamingReporterPtr&& reporter) { + assert(!m_reporter && "Listening reporter can wrap only 1 real reporter"); + m_reporter = std::move( reporter ); + m_preferences.shouldRedirectStdOut = m_reporter->getPreferences().shouldRedirectStdOut; + } + + ReporterPreferences ListeningReporter::getPreferences() const { + return m_preferences; + } + + std::set ListeningReporter::getSupportedVerbosities() { + return std::set{ }; + } + + void ListeningReporter::noMatchingTestCases( std::string const& spec ) { + for ( auto const& listener : m_listeners ) { + listener->noMatchingTestCases( spec ); + } + m_reporter->noMatchingTestCases( spec ); + } + + void ListeningReporter::benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) { + for ( auto const& listener : m_listeners ) { + listener->benchmarkStarting( benchmarkInfo ); + } + m_reporter->benchmarkStarting( benchmarkInfo ); + } + void ListeningReporter::benchmarkEnded( BenchmarkStats const& benchmarkStats ) { + for ( auto const& listener : m_listeners ) { + listener->benchmarkEnded( benchmarkStats ); + } + m_reporter->benchmarkEnded( benchmarkStats ); + } + + void ListeningReporter::testRunStarting( TestRunInfo const& testRunInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testRunStarting( testRunInfo ); + } + m_reporter->testRunStarting( testRunInfo ); + } + + void ListeningReporter::testGroupStarting( GroupInfo const& groupInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testGroupStarting( groupInfo ); + } + m_reporter->testGroupStarting( groupInfo ); + } + + void ListeningReporter::testCaseStarting( TestCaseInfo const& testInfo ) { + for ( auto const& listener : m_listeners ) { + listener->testCaseStarting( testInfo ); + } + m_reporter->testCaseStarting( testInfo ); + } + + void ListeningReporter::sectionStarting( SectionInfo const& sectionInfo ) { + for ( auto const& listener : m_listeners ) { + listener->sectionStarting( sectionInfo ); + } + m_reporter->sectionStarting( sectionInfo ); + } + + void ListeningReporter::assertionStarting( AssertionInfo const& assertionInfo ) { + for ( auto const& listener : m_listeners ) { + listener->assertionStarting( assertionInfo ); + } + m_reporter->assertionStarting( assertionInfo ); + } + + // The return value indicates if the messages buffer should be cleared: + bool ListeningReporter::assertionEnded( AssertionStats const& assertionStats ) { + for( auto const& listener : m_listeners ) { + static_cast( listener->assertionEnded( assertionStats ) ); + } + return m_reporter->assertionEnded( assertionStats ); + } + + void ListeningReporter::sectionEnded( SectionStats const& sectionStats ) { + for ( auto const& listener : m_listeners ) { + listener->sectionEnded( sectionStats ); + } + m_reporter->sectionEnded( sectionStats ); + } + + void ListeningReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + for ( auto const& listener : m_listeners ) { + listener->testCaseEnded( testCaseStats ); + } + m_reporter->testCaseEnded( testCaseStats ); + } + + void ListeningReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + for ( auto const& listener : m_listeners ) { + listener->testGroupEnded( testGroupStats ); + } + m_reporter->testGroupEnded( testGroupStats ); + } + + void ListeningReporter::testRunEnded( TestRunStats const& testRunStats ) { + for ( auto const& listener : m_listeners ) { + listener->testRunEnded( testRunStats ); + } + m_reporter->testRunEnded( testRunStats ); + } + + void ListeningReporter::skipTest( TestCaseInfo const& testInfo ) { + for ( auto const& listener : m_listeners ) { + listener->skipTest( testInfo ); + } + m_reporter->skipTest( testInfo ); + } + + bool ListeningReporter::isMulti() const { + return true; + } + +} // end namespace Catch +// end catch_reporter_listening.cpp +// start catch_reporter_xml.cpp + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch + // Note that 4062 (not all labels are handled + // and default is missing) is enabled +#endif + +namespace Catch { + XmlReporter::XmlReporter( ReporterConfig const& _config ) + : StreamingReporterBase( _config ), + m_xml(_config.stream()) + { + m_reporterPrefs.shouldRedirectStdOut = true; + m_reporterPrefs.shouldReportAllAssertions = true; + } + + XmlReporter::~XmlReporter() = default; + + std::string XmlReporter::getDescription() { + return "Reports test results as an XML document"; + } + + std::string XmlReporter::getStylesheetRef() const { + return std::string(); + } + + void XmlReporter::writeSourceInfo( SourceLineInfo const& sourceInfo ) { + m_xml + .writeAttribute( "filename", sourceInfo.file ) + .writeAttribute( "line", sourceInfo.line ); + } + + void XmlReporter::noMatchingTestCases( std::string const& s ) { + StreamingReporterBase::noMatchingTestCases( s ); + } + + void XmlReporter::testRunStarting( TestRunInfo const& testInfo ) { + StreamingReporterBase::testRunStarting( testInfo ); + std::string stylesheetRef = getStylesheetRef(); + if( !stylesheetRef.empty() ) + m_xml.writeStylesheetRef( stylesheetRef ); + m_xml.startElement( "Catch" ); + if( !m_config->name().empty() ) + m_xml.writeAttribute( "name", m_config->name() ); + if( m_config->rngSeed() != 0 ) + m_xml.scopedElement( "Randomness" ) + .writeAttribute( "seed", m_config->rngSeed() ); + } + + void XmlReporter::testGroupStarting( GroupInfo const& groupInfo ) { + StreamingReporterBase::testGroupStarting( groupInfo ); + m_xml.startElement( "Group" ) + .writeAttribute( "name", groupInfo.name ); + } + + void XmlReporter::testCaseStarting( TestCaseInfo const& testInfo ) { + StreamingReporterBase::testCaseStarting(testInfo); + m_xml.startElement( "TestCase" ) + .writeAttribute( "name", trim( testInfo.name ) ) + .writeAttribute( "description", testInfo.description ) + .writeAttribute( "tags", testInfo.tagsAsString() ); + + writeSourceInfo( testInfo.lineInfo ); + + if ( m_config->showDurations() == ShowDurations::Always ) + m_testCaseTimer.start(); + m_xml.ensureTagClosed(); + } + + void XmlReporter::sectionStarting( SectionInfo const& sectionInfo ) { + StreamingReporterBase::sectionStarting( sectionInfo ); + if( m_sectionDepth++ > 0 ) { + m_xml.startElement( "Section" ) + .writeAttribute( "name", trim( sectionInfo.name ) ); + writeSourceInfo( sectionInfo.lineInfo ); + m_xml.ensureTagClosed(); + } + } + + void XmlReporter::assertionStarting( AssertionInfo const& ) { } + + bool XmlReporter::assertionEnded( AssertionStats const& assertionStats ) { + + AssertionResult const& result = assertionStats.assertionResult; + + bool includeResults = m_config->includeSuccessfulResults() || !result.isOk(); + + if( includeResults || result.getResultType() == ResultWas::Warning ) { + // Print any info messages in tags. + for( auto const& msg : assertionStats.infoMessages ) { + if( msg.type == ResultWas::Info && includeResults ) { + m_xml.scopedElement( "Info" ) + .writeText( msg.message ); + } else if ( msg.type == ResultWas::Warning ) { + m_xml.scopedElement( "Warning" ) + .writeText( msg.message ); + } + } + } + + // Drop out if result was successful but we're not printing them. + if( !includeResults && result.getResultType() != ResultWas::Warning ) + return true; + + // Print the expression if there is one. + if( result.hasExpression() ) { + m_xml.startElement( "Expression" ) + .writeAttribute( "success", result.succeeded() ) + .writeAttribute( "type", result.getTestMacroName() ); + + writeSourceInfo( result.getSourceInfo() ); + + m_xml.scopedElement( "Original" ) + .writeText( result.getExpression() ); + m_xml.scopedElement( "Expanded" ) + .writeText( result.getExpandedExpression() ); + } + + // And... Print a result applicable to each result type. + switch( result.getResultType() ) { + case ResultWas::ThrewException: + m_xml.startElement( "Exception" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::FatalErrorCondition: + m_xml.startElement( "FatalErrorCondition" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + case ResultWas::Info: + m_xml.scopedElement( "Info" ) + .writeText( result.getMessage() ); + break; + case ResultWas::Warning: + // Warning will already have been written + break; + case ResultWas::ExplicitFailure: + m_xml.startElement( "Failure" ); + writeSourceInfo( result.getSourceInfo() ); + m_xml.writeText( result.getMessage() ); + m_xml.endElement(); + break; + default: + break; + } + + if( result.hasExpression() ) + m_xml.endElement(); + + return true; + } + + void XmlReporter::sectionEnded( SectionStats const& sectionStats ) { + StreamingReporterBase::sectionEnded( sectionStats ); + if( --m_sectionDepth > 0 ) { + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResults" ); + e.writeAttribute( "successes", sectionStats.assertions.passed ); + e.writeAttribute( "failures", sectionStats.assertions.failed ); + e.writeAttribute( "expectedFailures", sectionStats.assertions.failedButOk ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", sectionStats.durationInSeconds ); + + m_xml.endElement(); + } + } + + void XmlReporter::testCaseEnded( TestCaseStats const& testCaseStats ) { + StreamingReporterBase::testCaseEnded( testCaseStats ); + XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResult" ); + e.writeAttribute( "success", testCaseStats.totals.assertions.allOk() ); + + if ( m_config->showDurations() == ShowDurations::Always ) + e.writeAttribute( "durationInSeconds", m_testCaseTimer.getElapsedSeconds() ); + + if( !testCaseStats.stdOut.empty() ) + m_xml.scopedElement( "StdOut" ).writeText( trim( testCaseStats.stdOut ), false ); + if( !testCaseStats.stdErr.empty() ) + m_xml.scopedElement( "StdErr" ).writeText( trim( testCaseStats.stdErr ), false ); + + m_xml.endElement(); + } + + void XmlReporter::testGroupEnded( TestGroupStats const& testGroupStats ) { + StreamingReporterBase::testGroupEnded( testGroupStats ); + // TODO: Check testGroupStats.aborting and act accordingly. + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testGroupStats.totals.assertions.passed ) + .writeAttribute( "failures", testGroupStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testGroupStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + void XmlReporter::testRunEnded( TestRunStats const& testRunStats ) { + StreamingReporterBase::testRunEnded( testRunStats ); + m_xml.scopedElement( "OverallResults" ) + .writeAttribute( "successes", testRunStats.totals.assertions.passed ) + .writeAttribute( "failures", testRunStats.totals.assertions.failed ) + .writeAttribute( "expectedFailures", testRunStats.totals.assertions.failedButOk ); + m_xml.endElement(); + } + + CATCH_REGISTER_REPORTER( "xml", XmlReporter ) + +} // end namespace Catch + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +// end catch_reporter_xml.cpp + +namespace Catch { + LeakDetector leakDetector; +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +// end catch_impl.hpp +#endif + +#ifdef CATCH_CONFIG_MAIN +// start catch_default_main.hpp + +#ifndef __OBJC__ + +#if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(_UNICODE) && !defined(DO_NOT_USE_WMAIN) +// Standard C/C++ Win32 Unicode wmain entry point +extern "C" int wmain (int argc, wchar_t * argv[], wchar_t * []) { +#else +// Standard C/C++ main entry point +int main (int argc, char * argv[]) { +#endif + + return Catch::Session().run( argc, argv ); +} + +#else // __OBJC__ + +// Objective-C entry point +int main (int argc, char * const argv[]) { +#if !CATCH_ARC_ENABLED + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; +#endif + + Catch::registerTestMethods(); + int result = Catch::Session().run( argc, (char**)argv ); + +#if !CATCH_ARC_ENABLED + [pool drain]; +#endif + + return result; +} + +#endif // __OBJC__ + +// end catch_default_main.hpp +#endif + +#if !defined(CATCH_CONFIG_IMPL_ONLY) + +#ifdef CLARA_CONFIG_MAIN_NOT_DEFINED +# undef CLARA_CONFIG_MAIN +#endif + +#if !defined(CATCH_CONFIG_DISABLE) +////// +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#define CATCH_REQUIRE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) + +#define CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr ) +#endif// CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ ) + +#define CATCH_CHECK( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) +#define CATCH_CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CATCH_CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CATCH_CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ ) + +#define CATCH_CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CATCH_CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define CATCH_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg ) +#define CATCH_WARN( msg ) INTERNAL_CATCH_MSG( "CATCH_WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define CATCH_CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CATCH_CAPTURE",__VA_ARGS__ ) + +#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) +#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) +#define CATCH_REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) +#define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) +#define CATCH_DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ ) +#define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define CATCH_FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#endif + +#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE) +#define CATCH_STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__ , #__VA_ARGS__ ); CATCH_SUCCEED( #__VA_ARGS__ ) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); CATCH_SUCCEED( #__VA_ARGS__ ) +#else +#define CATCH_STATIC_REQUIRE( ... ) CATCH_REQUIRE( __VA_ARGS__ ) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) CATCH_REQUIRE_FALSE( __VA_ARGS__ ) +#endif + +// "BDD-style" convenience wrappers +#define CATCH_SCENARIO( ... ) CATCH_TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) +#define CATCH_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc ) +#define CATCH_AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc ) +#define CATCH_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc ) +#define CATCH_AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc ) +#define CATCH_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc ) +#define CATCH_AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc ) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#define REQUIRE( ... ) INTERNAL_CATCH_TEST( "REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) + +#define REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr ) +#define REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ ) + +#define CHECK( ... ) INTERNAL_CATCH_TEST( "CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ ) +#define CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ ) + +#define CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr ) +#define CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg ) + +#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg ) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg ) +#define UNSCOPED_INFO( msg ) INTERNAL_CATCH_UNSCOPED_INFO( "UNSCOPED_INFO", msg ) +#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg ) +#define CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CAPTURE",__VA_ARGS__ ) + +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ ) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ ) +#define REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ ) +#define SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ ) +#define DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ ) +#define FAIL( ... ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ ) +#define FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define SUCCEED( ... ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ ) +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE() + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) ) +#endif + +#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE) +#define STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__, #__VA_ARGS__ ); SUCCEED( #__VA_ARGS__ ) +#define STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); SUCCEED( "!(" #__VA_ARGS__ ")" ) +#else +#define STATIC_REQUIRE( ... ) REQUIRE( __VA_ARGS__ ) +#define STATIC_REQUIRE_FALSE( ... ) REQUIRE_FALSE( __VA_ARGS__ ) +#endif + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) + +// "BDD-style" convenience wrappers +#define SCENARIO( ... ) TEST_CASE( "Scenario: " __VA_ARGS__ ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ ) + +#define GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc ) +#define AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc ) +#define WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc ) +#define AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc ) +#define THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc ) +#define AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc ) + +using Catch::Detail::Approx; + +#else // CATCH_CONFIG_DISABLE + +////// +// If this config identifier is defined then all CATCH macros are prefixed with CATCH_ +#ifdef CATCH_CONFIG_PREFIX_ALL + +#define CATCH_REQUIRE( ... ) (void)(0) +#define CATCH_REQUIRE_FALSE( ... ) (void)(0) + +#define CATCH_REQUIRE_THROWS( ... ) (void)(0) +#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0) +#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif// CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_REQUIRE_NOTHROW( ... ) (void)(0) + +#define CATCH_CHECK( ... ) (void)(0) +#define CATCH_CHECK_FALSE( ... ) (void)(0) +#define CATCH_CHECKED_IF( ... ) if (__VA_ARGS__) +#define CATCH_CHECKED_ELSE( ... ) if (!(__VA_ARGS__)) +#define CATCH_CHECK_NOFAIL( ... ) (void)(0) + +#define CATCH_CHECK_THROWS( ... ) (void)(0) +#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) (void)(0) +#define CATCH_CHECK_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CATCH_CHECK_NOTHROW( ... ) (void)(0) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CATCH_CHECK_THAT( arg, matcher ) (void)(0) + +#define CATCH_REQUIRE_THAT( arg, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define CATCH_INFO( msg ) (void)(0) +#define CATCH_WARN( msg ) (void)(0) +#define CATCH_CAPTURE( msg ) (void)(0) + +#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_METHOD_AS_TEST_CASE( method, ... ) +#define CATCH_REGISTER_TEST_CASE( Function, ... ) (void)(0) +#define CATCH_SECTION( ... ) +#define CATCH_DYNAMIC_SECTION( ... ) +#define CATCH_FAIL( ... ) (void)(0) +#define CATCH_FAIL_CHECK( ... ) (void)(0) +#define CATCH_SUCCEED( ... ) (void)(0) + +#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) ) +#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#endif + +// "BDD-style" convenience wrappers +#define CATCH_SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) +#define CATCH_GIVEN( desc ) +#define CATCH_AND_GIVEN( desc ) +#define CATCH_WHEN( desc ) +#define CATCH_AND_WHEN( desc ) +#define CATCH_THEN( desc ) +#define CATCH_AND_THEN( desc ) + +#define CATCH_STATIC_REQUIRE( ... ) (void)(0) +#define CATCH_STATIC_REQUIRE_FALSE( ... ) (void)(0) + +// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required +#else + +#define REQUIRE( ... ) (void)(0) +#define REQUIRE_FALSE( ... ) (void)(0) + +#define REQUIRE_THROWS( ... ) (void)(0) +#define REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0) +#define REQUIRE_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define REQUIRE_NOTHROW( ... ) (void)(0) + +#define CHECK( ... ) (void)(0) +#define CHECK_FALSE( ... ) (void)(0) +#define CHECKED_IF( ... ) if (__VA_ARGS__) +#define CHECKED_ELSE( ... ) if (!(__VA_ARGS__)) +#define CHECK_NOFAIL( ... ) (void)(0) + +#define CHECK_THROWS( ... ) (void)(0) +#define CHECK_THROWS_AS( expr, exceptionType ) (void)(0) +#define CHECK_THROWS_WITH( expr, matcher ) (void)(0) +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS +#define CHECK_NOTHROW( ... ) (void)(0) + +#if !defined(CATCH_CONFIG_DISABLE_MATCHERS) +#define CHECK_THAT( arg, matcher ) (void)(0) + +#define REQUIRE_THAT( arg, matcher ) (void)(0) +#endif // CATCH_CONFIG_DISABLE_MATCHERS + +#define INFO( msg ) (void)(0) +#define WARN( msg ) (void)(0) +#define CAPTURE( msg ) (void)(0) + +#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) +#define METHOD_AS_TEST_CASE( method, ... ) +#define REGISTER_TEST_CASE( Function, ... ) (void)(0) +#define SECTION( ... ) +#define DYNAMIC_SECTION( ... ) +#define FAIL( ... ) (void)(0) +#define FAIL_CHECK( ... ) (void)(0) +#define SUCCEED( ... ) (void)(0) +#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ )) + +#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#else +#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) ) ) +#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), className ) ) +#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ ) +#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) +#endif + +#define STATIC_REQUIRE( ... ) (void)(0) +#define STATIC_REQUIRE_FALSE( ... ) (void)(0) + +#endif + +#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature ) + +// "BDD-style" convenience wrappers +#define SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) ) +#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className ) + +#define GIVEN( desc ) +#define AND_GIVEN( desc ) +#define WHEN( desc ) +#define AND_WHEN( desc ) +#define THEN( desc ) +#define AND_THEN( desc ) + +using Catch::Detail::Approx; + +#endif + +#endif // ! CATCH_CONFIG_IMPL_ONLY + +// start catch_reenable_warnings.h + + +#ifdef __clang__ +# ifdef __ICC // icpc defines the __clang__ macro +# pragma warning(pop) +# else +# pragma clang diagnostic pop +# endif +#elif defined __GNUC__ +# pragma GCC diagnostic pop +#endif + +// end catch_reenable_warnings.h +// end catch.hpp +#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED + diff --git a/tests/common.h b/tests/common.h new file mode 100644 index 0000000..59f9b87 --- /dev/null +++ b/tests/common.h @@ -0,0 +1,507 @@ +#pragma once + +#include +#include + +#include "catch.hpp" + +#include +#include + +#define REPORTS_ERROR(expr) witest::ReportsError(wistd::is_same{}, [&]() { return expr; }) +#define REQUIRE_ERROR(expr) REQUIRE(REPORTS_ERROR(expr)) +#define REQUIRE_NOERROR(expr) REQUIRE_FALSE(REPORTS_ERROR(expr)) + +#define CRASHES(expr) witest::DoesCodeCrash([&]() { return expr; }) +#define REQUIRE_CRASH(expr) REQUIRE(CRASHES(expr)) +#define REQUIRE_NOCRASH(expr) REQUIRE_FALSE(CRASHES(expr)) + +// NOTE: SUCCEEDED/FAILED macros not used here since Catch2 can give us better diagnostics if it knows the HRESULT value +#define REQUIRE_SUCCEEDED(expr) REQUIRE((HRESULT)(expr) >= 0) +#define REQUIRE_FAILED(expr) REQUIRE((HRESULT)(expr) < 0) + +// MACRO double evaluation check. +// The following macro illustrates a common problem with writing macros: +// #define MY_MAX(a, b) (((a) > (b)) ? (a) : (b)) +// The issue is that whatever code is being used for both a and b is being executed twice. +// This isn't harmful when thinking of constant numerics, but consider this example: +// MY_MAX(4, InterlockedIncrement(&cCount)) +// This evaluates the (B) parameter twice and results in incrementing the counter twice. +// We use MDEC in unit tests to verify that this kind of pattern is not present. A test +// of this kind: +// MY_MAX(MDEC(4), MDEC(InterlockedIncrement(&cCount)) +// will verify that the parameters are not evaluated more than once. +#define MDEC(PARAM) (witest::details::MacroDoubleEvaluationCheck(__LINE__, #PARAM), PARAM) + +// There's some functionality that we need for testing that's not available for the app partition. Since those tests are +// primarily compilation tests, declare what's needed here +extern "C" { + +#if !WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM | WINAPI_PARTITION_GAMES) +WINBASEAPI _Ret_maybenull_ +PVOID WINAPI AddVectoredExceptionHandler(_In_ ULONG First, _In_ PVECTORED_EXCEPTION_HANDLER Handler); + +WINBASEAPI +ULONG WINAPI RemoveVectoredExceptionHandler(_In_ PVOID Handle); +#endif + +} + +#pragma warning(push) +#pragma warning(disable: 4702) // Unreachable code + +namespace witest +{ + namespace details + { + inline void MacroDoubleEvaluationCheck(size_t uLine, _In_ const char* pszCode) + { + struct SEval + { + size_t uLine; + const char* pszCode; + }; + + static SEval rgEval[15] = {}; + static size_t nOffset = 0; + + for (auto& eval : rgEval) + { + if ((eval.uLine == uLine) && (eval.pszCode != nullptr) && (0 == strcmp(pszCode, eval.pszCode))) + { + // This verification indicates that macro-double-evaluation check is firing for a particular usage of MDEC(). + FAIL("Expression '" << pszCode << "' double evaluated in macro on line " << uLine); + } + } + + rgEval[nOffset].uLine = uLine; + rgEval[nOffset].pszCode = pszCode; + nOffset = (nOffset + 1) % ARRAYSIZE(rgEval); + } + + template + class AssignTemporaryValueCleanup + { + public: + AssignTemporaryValueCleanup(_In_ AssignTemporaryValueCleanup const &) = delete; + AssignTemporaryValueCleanup & operator=(_In_ AssignTemporaryValueCleanup const &) = delete; + + explicit AssignTemporaryValueCleanup(_Inout_ T *pVal, T val) WI_NOEXCEPT : + m_pVal(pVal), + m_valOld(*pVal) + { + *pVal = val; + } + + AssignTemporaryValueCleanup(_Inout_ AssignTemporaryValueCleanup && other) WI_NOEXCEPT : + m_pVal(other.m_pVal), + m_valOld(other.m_valOld) + { + other.m_pVal = nullptr; + } + + ~AssignTemporaryValueCleanup() WI_NOEXCEPT + { + operator()(); + } + + void operator()() WI_NOEXCEPT + { + if (m_pVal != nullptr) + { + *m_pVal = m_valOld; + m_pVal = nullptr; + } + } + + void Dismiss() WI_NOEXCEPT + { + m_pVal = nullptr; + } + + private: + T *m_pVal; + T m_valOld; + }; + } + + // Use the following routine to allow for a variable to be swapped with another and automatically revert the + // assignment at the end of the scope. + // Example: + // int nFoo = 10 + // { + // auto revert = witest::AssignTemporaryValue(&nFoo, 12); + // // nFoo will now be 12 within this scope... + // } + // // and nFoo is back to 10 within the outer scope + template + inline witest::details::AssignTemporaryValueCleanup AssignTemporaryValue(_Inout_ T *pVal, T val) WI_NOEXCEPT + { + return witest::details::AssignTemporaryValueCleanup(pVal, val); + } + + //! Global class which tracks objects that derive from @ref AllocatedObject. + //! Use `witest::g_objectCount.Leaked()` to determine if an object deriving from `AllocatedObject` has been leaked. + class GlobalCount + { + public: + int m_count = 0; + + //! Returns `true` if there are any objects that derive from @ref AllocatedObject still in memory. + bool Leaked() const + { + return (m_count != 0); + } + + ~GlobalCount() + { + if (Leaked()) + { + // NOTE: This runs when no test is active, but will still cause an assert failure to notify + FAIL("GlobalCount is non-zero; there is a leak somewhere"); + } + } + }; + __declspec(selectany) GlobalCount g_objectCount; + + //! Derive an allocated test object from witest::AllocatedObject to ensure that those objects aren't leaked in the test. + //! Note that you can call g_objectCount.Leaked() at any point to determine if a leak has already occurred (assuming that + //! all objects should have been destroyed at that point. + class AllocatedObject + { + public: + AllocatedObject() { g_objectCount.m_count++; } + ~AllocatedObject() { g_objectCount.m_count--; } + }; + + template + bool DoesCodeThrow(Lambda&& callOp) + { +#ifdef WIL_ENABLE_EXCEPTIONS + try +#endif + { + callOp(); + } +#ifdef WIL_ENABLE_EXCEPTIONS + catch (...) + { + return true; + } +#endif + + return false; + } + + [[noreturn]] + inline void __stdcall TranslateFailFastException(PEXCEPTION_RECORD rec, PCONTEXT, DWORD) + { + // RaiseFailFastException cannot be continued or handled. By instead calling RaiseException, it allows us to + // handle exceptions + ::RaiseException(rec->ExceptionCode, rec->ExceptionFlags, rec->NumberParameters, rec->ExceptionInformation); + } + + constexpr DWORD msvc_exception_code = 0xE06D7363; + + // This is a MAJOR hack. Catch2 registers a vectored exception handler - which gets run before our handler below - + // that interprets a set of exception codes as fatal. We don't want this behavior since we may be expecting such + // crashes, so instead translate all exception codes to something not fatal + inline LONG WINAPI TranslateExceptionCodeHandler(PEXCEPTION_POINTERS info) + { + if (info->ExceptionRecord->ExceptionCode != witest::msvc_exception_code) + { + info->ExceptionRecord->ExceptionCode = STATUS_STACK_BUFFER_OVERRUN; + } + + return EXCEPTION_CONTINUE_SEARCH; + } + + namespace details + { + inline bool DoesCodeCrash(wistd::function& callOp) + { + bool result = false; + __try + { + callOp(); + } + // Let C++ exceptions pass through + __except ((::GetExceptionCode() != msvc_exception_code) ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH) + { + result = true; + } + return result; + } + } + + inline bool DoesCodeCrash(wistd::function callOp) + { + // See above; we don't want to actually fail fast, so make sure we raise a different exception instead + auto restoreHandler = AssignTemporaryValue(&wil::details::g_pfnRaiseFailFastException, TranslateFailFastException); + + auto handler = AddVectoredExceptionHandler(1, TranslateExceptionCodeHandler); + auto removeVectoredHandler = wil::scope_exit([&] { RemoveVectoredExceptionHandler(handler); }); + + return details::DoesCodeCrash(callOp); + } + + template + bool ReportsError(wistd::false_type, Lambda&& callOp) + { + bool doesThrow = false; + bool doesCrash = DoesCodeCrash([&]() + { + doesThrow = DoesCodeThrow(callOp); + }); + + return doesThrow || doesCrash; + } + + template + bool ReportsError(wistd::true_type, Lambda&& callOp) + { + return FAILED(callOp()); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + class TestFailureCache final : + public wil::details::IFailureCallback + { + public: + TestFailureCache() : + m_callbackHolder(this) + { + } + + void clear() + { + m_failures.clear(); + } + + size_t size() const + { + return m_failures.size(); + } + + bool empty() const + { + return m_failures.empty(); + } + + const wil::FailureInfo& operator[](size_t pos) const + { + return m_failures.at(pos).GetFailureInfo(); + } + + // IFailureCallback + bool NotifyFailure(wil::FailureInfo const & failure) WI_NOEXCEPT override + { + m_failures.emplace_back(failure); + return false; + } + + private: + std::vector m_failures; + wil::details::ThreadFailureCallbackHolder m_callbackHolder; + }; +#endif + + inline HRESULT GetTempFileName(wchar_t (&result)[MAX_PATH]) + { + wchar_t dir[MAX_PATH]; + RETURN_LAST_ERROR_IF(::GetTempPathW(MAX_PATH, dir) == 0); + RETURN_LAST_ERROR_IF(::GetTempFileNameW(dir, L"wil", 0, result) == 0); + return S_OK; + } + + inline HRESULT CreateUniqueFolderPath(wchar_t (&buffer)[MAX_PATH], PCWSTR root = nullptr) + { + if (root) + { + RETURN_LAST_ERROR_IF(::GetTempFileNameW(root, L"wil", 0, buffer) == 0); + } + else + { + wchar_t tempPath[MAX_PATH]; + RETURN_LAST_ERROR_IF(::GetTempPathW(ARRAYSIZE(tempPath), tempPath) == 0); + RETURN_LAST_ERROR_IF(::GetLongPathNameW(tempPath, tempPath, ARRAYSIZE(tempPath)) == 0); + RETURN_LAST_ERROR_IF(::GetTempFileNameW(tempPath, L"wil", 0, buffer) == 0); + } + RETURN_IF_WIN32_BOOL_FALSE(DeleteFileW(buffer)); + PathCchRemoveExtension(buffer, ARRAYSIZE(buffer)); + return S_OK; + } + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + + struct TestFolder + { + TestFolder() + { + if (SUCCEEDED(CreateUniqueFolderPath(m_path)) && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path))) + { + m_valid = true; + } + } + + TestFolder(PCWSTR path) + { + if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path)) && SUCCEEDED(wil::CreateDirectoryDeepNoThrow(m_path))) + { + m_valid = true; + } + } + + TestFolder(const TestFolder&) = delete; + TestFolder& operator=(const TestFolder&) = delete; + + TestFolder(TestFolder&& other) + { + if (other.m_valid) + { + m_valid = true; + other.m_valid = false; + wcscpy_s(m_path, other.m_path); + } + } + + ~TestFolder() + { + if (m_valid) + { + wil::RemoveDirectoryRecursiveNoThrow(m_path); + } + } + + operator bool() const + { + return m_valid; + } + + operator PCWSTR() const + { + return m_path; + } + + PCWSTR Path() const + { + return m_path; + } + + private: + + bool m_valid = false; + wchar_t m_path[MAX_PATH] = L""; + }; + + struct TestFile + { + TestFile(PCWSTR path) + { + if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), path))) + { + Create(); + } + } + + TestFile(PCWSTR dirPath, PCWSTR fileName) + { + if (SUCCEEDED(StringCchCopyW(m_path, ARRAYSIZE(m_path), dirPath)) && SUCCEEDED(PathCchAppend(m_path, ARRAYSIZE(m_path), fileName))) + { + Create(); + } + } + + TestFile(const TestFile&) = delete; + TestFile& operator=(const TestFile&) = delete; + + TestFile(TestFile&& other) + { + if (other.m_valid) + { + m_valid = true; + m_deleteDir = other.m_deleteDir; + other.m_valid = other.m_deleteDir = false; + wcscpy_s(m_path, other.m_path); + } + } + + ~TestFile() + { + // Best effort on all of these + if (m_valid) + { + ::DeleteFileW(m_path); + } + if (m_deleteDir) + { + size_t parentLength; + if (wil::try_get_parent_path_range(m_path, &parentLength)) + { + m_path[parentLength] = L'\0'; + ::RemoveDirectoryW(m_path); + m_path[parentLength] = L'\\'; + } + } + } + + operator bool() const + { + return m_valid; + } + + operator PCWSTR() const + { + return m_path; + } + + PCWSTR Path() const + { + return m_path; + } + + private: + + HRESULT Create() + { + WI_ASSERT(!m_valid && !m_deleteDir); + wil::unique_hfile fileHandle(::CreateFileW(m_path, + FILE_WRITE_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, + CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr)); + if (!fileHandle) + { + auto err = ::GetLastError(); + size_t parentLength; + if ((err == ERROR_PATH_NOT_FOUND) && wil::try_get_parent_path_range(m_path, &parentLength)) + { + m_path[parentLength] = L'\0'; + RETURN_IF_FAILED(wil::CreateDirectoryDeepNoThrow(m_path)); + m_deleteDir = true; + + m_path[parentLength] = L'\\'; + fileHandle.reset(::CreateFileW(m_path, + FILE_WRITE_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, + CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, nullptr)); + RETURN_LAST_ERROR_IF(!fileHandle); + } + else + { + RETURN_WIN32(err); + } + } + + m_valid = true; + return S_OK; + } + + bool m_valid = false; + bool m_deleteDir = false; + wchar_t m_path[MAX_PATH] = L""; + }; + +#endif +} + +#pragma warning(pop) diff --git a/tests/cpplatest/CMakeLists.txt b/tests/cpplatest/CMakeLists.txt new file mode 100644 index 0000000..989f3dd --- /dev/null +++ b/tests/cpplatest/CMakeLists.txt @@ -0,0 +1,25 @@ + +# Compilers often don't use the latest C++ standard as the default. Periodically update this value (possibly conditioned +# on compiler) as new standards are ratified/support is available +set(CMAKE_CXX_STANDARD 17) + +project(witest.cpplatest) +add_executable(witest.cpplatest) + +target_sources(witest.cpplatest PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../main.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../CppWinRTTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../CommonTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ComTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../FileSystemTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResourceTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResultTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../SafeCastTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../StlTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../TokenHelpersTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../UniqueWinRTEventTokenTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WatcherTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WinRTTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WistdTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../wiTest.cpp + ) diff --git a/tests/main.cpp b/tests/main.cpp new file mode 100644 index 0000000..4b05624 --- /dev/null +++ b/tests/main.cpp @@ -0,0 +1,7 @@ + +#pragma comment(lib, "Pathcch.lib") +#pragma comment(lib, "RuntimeObject.lib") +#pragma comment(lib, "Synchronization.lib") + +#define CATCH_CONFIG_MAIN +#include "catch.hpp" diff --git a/tests/noexcept/CMakeLists.txt b/tests/noexcept/CMakeLists.txt new file mode 100644 index 0000000..d0fd972 --- /dev/null +++ b/tests/noexcept/CMakeLists.txt @@ -0,0 +1,26 @@ + +project(witest.noexcept) +add_executable(witest.noexcept) + +# Turn off exceptions for this test +replace_cxx_flag("/EHsc" "") +add_definitions(-DCATCH_CONFIG_DISABLE_EXCEPTIONS) + +# Catch2 has a no exceptions mode (configured above), however still includes STL headers which contain try...catch +# statements... Thankfully MSVC just gives us a warning that we can disable +append_cxx_flag("/wd4530") + +target_sources(witest.noexcept PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../main.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../CommonTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ComTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../FileSystemTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResourceTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResultTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../SafeCastTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../TokenHelpersTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../UniqueWinRTEventTokenTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WatcherTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WinRTTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WistdTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../wiTest.cpp) diff --git a/tests/normal/CMakeLists.txt b/tests/normal/CMakeLists.txt new file mode 100644 index 0000000..37714cf --- /dev/null +++ b/tests/normal/CMakeLists.txt @@ -0,0 +1,20 @@ + +project(witest) +add_executable(witest) + +target_sources(witest PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../main.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../CommonTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ComTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../FileSystemTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResourceTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../ResultTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../SafeCastTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../StlTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../TokenHelpersTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../UniqueWinRTEventTokenTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WatcherTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WinRTTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../WistdTests.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/../wiTest.cpp + ) diff --git a/tests/test_objects.h b/tests/test_objects.h new file mode 100644 index 0000000..7b04627 --- /dev/null +++ b/tests/test_objects.h @@ -0,0 +1,107 @@ +#pragma once + +#include "catch.hpp" + +// Useful for validating that the copy constructor is never called (e.g. to validate perfect forwarding). Note that +// the copy constructor/assignment operator are not deleted since we want to be able to validate in scenarios that +// require CopyConstructible (e.g. for wistd::function) +struct fail_on_copy +{ + fail_on_copy() = default; + + fail_on_copy(const fail_on_copy&) + { + FAIL("Copy constructor invoked for fail_on_copy type"); + } + + fail_on_copy(fail_on_copy&&) = default; + + fail_on_copy& operator=(const fail_on_copy&) + { + FAIL("Copy assignment operator invoked for fail_on_copy type"); + return *this; + } + + fail_on_copy& operator=(fail_on_copy&&) = default; +}; + +// Useful for validating that objects get copied e.g. as opposed to capturing a reference +struct value_holder +{ + int value = 0xbadf00d; + + ~value_holder() + { + value = 0xbadf00d; + } +}; + +// Useful for validating that functions, etc. are callable with move-only types +// Example real type that is move only is Microsoft::WRL::Wrappers::HString +struct cannot_copy +{ + cannot_copy(const cannot_copy&) = delete; + cannot_copy& operator=(const cannot_copy&) = delete; + + cannot_copy(cannot_copy&&) = default; + cannot_copy& operator=(cannot_copy&&) = default; +}; + +// State for object_counter type. This has the unfortunate side effect that the object_counter type cannot be used in +// contexts that require a default constructible type, but has the nice property that it allows for tests to run +// concurrently +struct object_counter_state +{ + volatile LONG constructed_count = 0; + volatile LONG destructed_count = 0; + volatile LONG copy_count = 0; + volatile LONG move_count = 0; + + LONG instance_count() + { + return constructed_count - destructed_count; + } +}; + +struct object_counter +{ + object_counter_state* state; + + object_counter(object_counter_state& s) : + state(&s) + { + ::InterlockedIncrement(&state->constructed_count); + } + + object_counter(const object_counter& other) : + state(other.state) + { + ::InterlockedIncrement(&state->constructed_count); + ::InterlockedIncrement(&state->copy_count); + } + + object_counter(object_counter&& other) : + state(other.state) + { + ::InterlockedIncrement(&state->constructed_count); + ::InterlockedIncrement(&state->move_count); + } + + ~object_counter() + { + ::InterlockedIncrement(&state->destructed_count); + state = nullptr; + } + + object_counter& operator=(const object_counter&) + { + ::InterlockedIncrement(&state->copy_count); + return *this; + } + + object_counter& operator=(object_counter&&) + { + ::InterlockedIncrement(&state->move_count); + return *this; + } +}; diff --git a/tests/wiTest.cpp b/tests/wiTest.cpp new file mode 100644 index 0000000..aea96a8 --- /dev/null +++ b/tests/wiTest.cpp @@ -0,0 +1,3448 @@ + +#include + +#ifdef WIL_ENABLE_EXCEPTIONS +#include +#include +#include +#endif + +#include +#include +#include +#include +#include +#include + +// Do not include most headers until after the WIL headers to ensure that we're not inadvertently adding any unnecessary +// dependencies to STL, WRL, or indirectly retrieved headers + +#ifndef __cplusplus_winrt +#include +#include +#endif + +// Include Resource.h a second time after including other headers +#include + +#include "common.h" +#include "MallocSpy.h" +#include "test_objects.h" + +#pragma warning(push) +#pragma warning(disable: 4702) // Unreachable code + +TEST_CASE("WindowsInternalTests::CommonHelpers", "[resource]") +{ + { + wil::unique_handle spHandle; + REQUIRE(spHandle == nullptr); + REQUIRE(nullptr == spHandle); + REQUIRE_FALSE(spHandle != nullptr); + REQUIRE_FALSE(nullptr != spHandle); + + //equivalence check will static_assert because spMutex does not allow pointer access + wil::mutex_release_scope_exit spMutex; + //REQUIRE(spMutex == nullptr); + //REQUIRE(nullptr == spMutex); + + //equivalence check will static_assert because spFile does not use nullptr_t as a invalid value + wil::unique_hfile spFile; + //REQUIRE(spFile == nullptr); + } +#ifdef __WIL_WINBASE_STL + { + wil::shared_handle spHandle; + REQUIRE(spHandle == nullptr); + REQUIRE(nullptr == spHandle); + REQUIRE_FALSE(spHandle != nullptr); + REQUIRE_FALSE(nullptr != spHandle); + } +#endif +} + +TEST_CASE("WindowsInternalTests::AssertMacros", "[result_macros]") +{ + //WI_ASSERT macros are all no-ops if in retail +#ifndef RESULT_DEBUG + WI_ASSERT(false); + WI_ASSERT_MSG(false, "WI_ASSERT_MSG"); + WI_ASSERT_NOASSUME(false); + WI_ASSERT_MSG_NOASSUME(false, "WI_ASSERT_MSG_NOASSUME"); + WI_VERIFY(false); + WI_VERIFY_MSG(false, "WI_VERIFY_MSG"); +#endif + + WI_ASSERT(true); + WI_ASSERT_MSG(true, "WI_ASSERT_MSG"); + WI_ASSERT_NOASSUME(true); + WI_ASSERT_MSG_NOASSUME(true, "WI_ASSERT_MSG_NOASSUME"); + WI_VERIFY(true); + WI_VERIFY_MSG(true, "WI_VERIFY_MSG"); +} + +void __stdcall EmptyResultMacrosLoggingCallback(wil::FailureInfo*, PWSTR, size_t) WI_NOEXCEPT +{ +} + +#ifdef WIL_ENABLE_EXCEPTIONS +// Test Result Macros +void TestErrorCallbacks() +{ + { + size_t callbackCount = 0; + auto monitor = wil::ThreadFailureCallback([&](wil::FailureInfo const &failure) -> bool + { + REQUIRE(failure.hr == E_ACCESSDENIED); + callbackCount++; + return false; + }); + + size_t const depthCount = 10; + for (size_t index = 0; index < depthCount; index++) + { + LOG_HR(E_ACCESSDENIED); + } + REQUIRE(callbackCount == depthCount); + } + { + wil::ThreadFailureCache cache; + + LOG_HR(E_ACCESSDENIED); + REQUIRE(cache.GetFailure() != nullptr); + REQUIRE(cache.GetFailure()->hr == E_ACCESSDENIED); + + wil::ThreadFailureCache cacheNested; + + LOG_HR(E_FAIL); unsigned short errorLine = __LINE__; + LOG_HR(E_FAIL); + LOG_HR(E_FAIL); + REQUIRE(cache.GetFailure()->hr == E_FAIL); + REQUIRE(cache.GetFailure()->uLineNumber == errorLine); + REQUIRE(cacheNested.GetFailure()->hr == E_FAIL); + REQUIRE(cacheNested.GetFailure()->uLineNumber == errorLine); + } +} + +DWORD WINAPI ErrorCallbackThreadTest(_In_ LPVOID lpParameter) +{ + try + { + HANDLE hEvent = reinterpret_cast(lpParameter); + + for (size_t stress = 0; stress < 200; stress++) + { + Sleep(1); // allow the threadpool to saturate the thread count... + TestErrorCallbacks(); + } + THROW_IF_WIN32_BOOL_FALSE(::SetEvent(hEvent)); + } + catch (...) + { + FAIL(); + } + return 1; +} + +void StressErrorCallbacks() +{ + auto restore = witest::AssignTemporaryValue(&wil::g_fResultOutputDebugString, false); + + size_t const threadCount = 20; + wil::unique_event eventArray[threadCount]; + + for (size_t index = 0; index < threadCount; index++) + { + eventArray[index].create(); +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + THROW_IF_WIN32_BOOL_FALSE(::QueueUserWorkItem(ErrorCallbackThreadTest, eventArray[index].get(), 0)); +#else + ErrorCallbackThreadTest(eventArray[index].get()); +#endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) */ + } + for (size_t index = 0; index < threadCount; index++) + { + eventArray[index].wait(); + } +} + +TEST_CASE("WindowsInternalTests::ResultMacrosStress", "[!hide][result_macros][stress]") +{ + auto restore = witest::AssignTemporaryValue(&wil::g_pfnResultLoggingCallback, EmptyResultMacrosLoggingCallback); + StressErrorCallbacks(); +} +#endif + +#define E_AD HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED) +void SetAD() +{ + ::SetLastError(ERROR_ACCESS_DENIED); +} + +class AlternateAccessDeniedException +{ +}; + +#ifdef WIL_ENABLE_EXCEPTIONS +class DerivedAccessDeniedException : public wil::ResultException +{ +public: + DerivedAccessDeniedException() : ResultException(E_AD) {} +}; + +HRESULT __stdcall TestResultCaughtFromException() WI_NOEXCEPT +{ + try + { + throw; + } + catch (AlternateAccessDeniedException) + { + return E_AD; + } + catch (...) + { + } + return S_OK; +} +#endif + +HANDLE hValid = reinterpret_cast(1); +HANDLE& hValidRef() { return hValid; } +HANDLE hNull = NULL; +HANDLE hInvalid = INVALID_HANDLE_VALUE; +void* pValid = reinterpret_cast(1); +void*& pValidRef() { return pValid; } +void* pNull = nullptr; +void*& pNullRef() { return pNull; } +bool fTrue = true; +bool& fTrueRef() { return fTrue; } +bool fFalse = false; +bool& fFalseRef() { return fFalse; } +BOOL fTRUE = TRUE; +BOOL& fTRUERef() { return fTRUE; } +BOOL fFALSE = FALSE; +DWORD errSuccess = ERROR_SUCCESS; +DWORD& errSuccessRef() { return errSuccess; } +HRESULT hrOK = S_OK; +HRESULT& hrOKRef() { return hrOK; } +HRESULT hrFAIL = E_FAIL; +HRESULT& hrFAILRef() { return hrFAIL; } +const HRESULT E_hrOutOfPaper = HRESULT_FROM_WIN32(ERROR_OUT_OF_PAPER); +NTSTATUS ntOK = STATUS_SUCCESS; +NTSTATUS& ntOKRef() { return ntOK; } +NTSTATUS ntFAIL = STATUS_NO_MEMORY; +NTSTATUS& ntFAILRef() { return ntFAIL; } +const HRESULT S_hrNtOkay = wil::details::NtStatusToHr(STATUS_SUCCESS); +const HRESULT E_hrNtAssertionFailure = wil::details::NtStatusToHr(STATUS_ASSERTION_FAILURE); + +wil::StoredFailureInfo g_log; + +void __stdcall ResultMacrosLoggingCallback(wil::FailureInfo *pFailure, PWSTR, size_t) WI_NOEXCEPT +{ + g_log = *pFailure; +} + +enum class EType +{ + None = 0x00, + Expected = 0x02, + Msg = 0x04, + FailFast = 0x08, // overall fail fast (throw exception on successful result code, for example) + FailFastMacro = 0x10, // explicit use of fast fail fast (FAIL_FAST_IF...) + NoContext = 0x20 // file and line info can be wrong (throw does not happen in context to code) +}; +DEFINE_ENUM_FLAG_OPERATORS(EType); + +template +bool VerifyResult(unsigned int lineNumber, EType type, HRESULT hr, TLambda&& lambda) +{ + bool succeeded = true; +#ifdef WIL_ENABLE_EXCEPTIONS + try + { +#endif + HRESULT lambdaResult = E_FAIL; + bool didFailFast = true; + { + didFailFast = witest::DoesCodeCrash([&]() + { + lambdaResult = lambda(); + }); + } + if (WI_IsFlagSet(type, EType::FailFast)) + { + REQUIRE(didFailFast); + } + else + { + if (WI_IsFlagClear(type, EType::Expected)) + { + if (SUCCEEDED(hr)) + { + REQUIRE(hr == lambdaResult); + REQUIRE(lineNumber != g_log.GetFailureInfo().uLineNumber); + REQUIRE(!didFailFast); + } + else + { + REQUIRE((WI_IsFlagSet(type, EType::NoContext) || (g_log.GetFailureInfo().uLineNumber == lineNumber))); + REQUIRE(g_log.GetFailureInfo().hr == hr); + REQUIRE((WI_IsFlagClear(type, EType::Msg) || (nullptr != wcsstr(g_log.GetFailureInfo().pszMessage, L"msg")))); + REQUIRE((WI_IsFlagClear(type, EType::FailFastMacro) || (didFailFast))); + REQUIRE((WI_IsFlagSet(type, EType::FailFastMacro) || (!didFailFast))); + } + } + } +#ifdef WIL_ENABLE_EXCEPTIONS + } + catch (...) + { + succeeded = false; + } +#endif + + // Ensure we come out clean... + ::SetLastError(ERROR_SUCCESS); + return succeeded; +} + +#ifdef WIL_ENABLE_EXCEPTIONS +template +HRESULT TranslateException(TLambda&& lambda) +{ + try + { + lambda(); + } + catch (wil::ResultException &re) + { + return re.GetErrorCode(); + } +#ifdef __cplusplus_winrt + catch (Platform::Exception ^pe) + { + return wil::details::GetErrorCode(pe); + } +#endif + catch (...) + { + FAIL(); + } + return S_OK; +} +#endif + +#define REQUIRE_RETURNS(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::None, hr, lambda)) +#define REQUIRE_RETURNS_MSG(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::Msg, hr, lambda)) +#define REQUIRE_RETURNS_EXPECTED(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::Expected, hr, lambda)) + +#ifdef WIL_ENABLE_EXCEPTIONS +#define REQUIRE_THROWS_RESULT(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::None, hr, [&] { return TranslateException(lambda); })) +#define REQUIRE_THROWS_MSG(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::Msg, hr, [&] { return TranslateException(lambda); })) +#else +#define REQUIRE_THROWS_RESULT(hr, lambda) +#define REQUIRE_THROWS_MSG(hr, lambda) +#endif + +#define REQUIRE_LOG(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::None, hr, [&] { auto fn = (lambda); fn(); return hr; })) +#define REQUIRE_LOG_MSG(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::Msg, hr, [&] { auto fn = (lambda); fn(); return hr; })) + +#define REQUIRE_FAILFAST(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::FailFastMacro, hr, [&] { auto fn = (lambda); fn(); return hr; })) +#define REQUIRE_FAILFAST_MSG(hr, lambda) REQUIRE(VerifyResult(__LINE__, EType::FailFastMacro | EType::Msg, hr, [&] { auto fn = (lambda); fn(); return hr; })) +#define REQUIRE_FAILFAST_UNSPECIFIED(lambda) REQUIRE(VerifyResult(__LINE__, EType::FailFast, S_OK, [&] { auto fn = (lambda); fn(); return S_OK; })) + +TEST_CASE("WindowsInternalTests::ResultMacros", "[result_macros]") +{ + auto restoreLoggingCallback = witest::AssignTemporaryValue(&wil::g_pfnResultLoggingCallback, ResultMacrosLoggingCallback); +#ifdef WIL_ENABLE_EXCEPTIONS + auto restoreExceptionCallback = witest::AssignTemporaryValue(&wil::g_pfnResultFromCaughtException, TestResultCaughtFromException); +#endif + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR(MDEC(hrOKRef())); }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR(MDEC(hrOKRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_FAIL, [] { RETURN_HR(E_FAIL); }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { RETURN_HR_MSG(E_FAIL, "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { THROW_HR(E_FAIL); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { THROW_HR_MSG(E_FAIL, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_FAIL, [] { LOG_HR(E_FAIL); }); + REQUIRE_LOG_MSG(E_FAIL, [] { LOG_HR_MSG(E_FAIL, "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_FAIL, [] { FAIL_FAST_HR(E_FAIL); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { FAIL_FAST_HR_MSG(E_FAIL, "msg: %d", __LINE__); }); + + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR(); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR_MSG("msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_AD, [] { SetAD(); RETURN_LAST_ERROR(); }); + REQUIRE_RETURNS_MSG(E_AD, [] { SetAD(); RETURN_LAST_ERROR_MSG("msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_AD, [] { SetAD(); THROW_LAST_ERROR(); }); + REQUIRE_THROWS_MSG(E_AD, [] { SetAD(); THROW_LAST_ERROR_MSG("msg: %d", __LINE__); }); + REQUIRE_LOG(E_AD, [] { SetAD(); LOG_LAST_ERROR(); }); + REQUIRE_LOG_MSG(E_AD, [] { SetAD(); LOG_LAST_ERROR_MSG("msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR(); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR_MSG("msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_WIN32(MDEC(errSuccessRef())); }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_WIN32_MSG(MDEC(errSuccessRef()), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_WIN32(MDEC(errSuccessRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_WIN32_MSG(MDEC(errSuccessRef()), "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_AD, [] { RETURN_WIN32(ERROR_ACCESS_DENIED); }); + REQUIRE_RETURNS_MSG(E_AD, [] { RETURN_WIN32_MSG(ERROR_ACCESS_DENIED, "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_AD, [] { THROW_WIN32(ERROR_ACCESS_DENIED); }); + REQUIRE_THROWS_MSG(E_AD, [] { THROW_WIN32_MSG(ERROR_ACCESS_DENIED, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_AD, [] { LOG_WIN32(ERROR_ACCESS_DENIED); }); + REQUIRE_LOG_MSG(E_AD, [] { LOG_WIN32_MSG(ERROR_ACCESS_DENIED, "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_AD, [] { FAIL_FAST_WIN32(ERROR_ACCESS_DENIED); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { FAIL_FAST_WIN32_MSG(ERROR_ACCESS_DENIED, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_IF_FAILED(MDEC(hrOKRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_IF_FAILED_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_FAILED_EXPECTED(MDEC(hrOKRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(S_OK == THROW_IF_FAILED(MDEC(hrOKRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(S_OK == THROW_IF_FAILED_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(S_OK == LOG_IF_FAILED(MDEC(hrOKRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(S_OK == LOG_IF_FAILED_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(S_OK == FAIL_FAST_IF_FAILED(MDEC(hrOKRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(S_OK == FAIL_FAST_IF_FAILED_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(E_FAIL, [] { RETURN_IF_FAILED(E_FAIL); return S_OK; }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { RETURN_IF_FAILED_MSG(E_FAIL, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_IF_FAILED_EXPECTED(E_FAIL); return S_OK; }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { THROW_IF_FAILED(E_FAIL); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { THROW_IF_FAILED_MSG(E_FAIL, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(E_FAIL == LOG_IF_FAILED(E_FAIL)); }); + REQUIRE_LOG_MSG(E_FAIL, [] { REQUIRE(E_FAIL == LOG_IF_FAILED_MSG(E_FAIL, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_FAIL, [] { FAIL_FAST_IF_FAILED(E_FAIL); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { FAIL_FAST_IF_FAILED_MSG(E_FAIL, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_IF_WIN32_BOOL_FALSE(MDEC(fTRUERef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_IF_WIN32_BOOL_FALSE_MSG(MDEC(fTRUERef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(MDEC(fTRUERef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fTRUE == THROW_IF_WIN32_BOOL_FALSE(MDEC(fTRUERef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fTRUE == THROW_IF_WIN32_BOOL_FALSE_MSG(MDEC(fTRUERef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fTRUE == LOG_IF_WIN32_BOOL_FALSE(MDEC(fTRUERef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fTRUE == LOG_IF_WIN32_BOOL_FALSE_MSG(MDEC(fTRUERef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fTRUE == FAIL_FAST_IF_WIN32_BOOL_FALSE(MDEC(fTRUERef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fTRUE == FAIL_FAST_IF_WIN32_BOOL_FALSE_MSG(MDEC(fTRUERef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(E_AD, [] { SetAD(); RETURN_IF_WIN32_BOOL_FALSE(fFALSE); return S_OK; }); + REQUIRE_RETURNS_MSG(E_AD, [] { SetAD(); RETURN_IF_WIN32_BOOL_FALSE_MSG(fFALSE, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_AD, [] { SetAD(); RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(fFALSE); return S_OK; }); + REQUIRE_THROWS_RESULT(E_AD, [] { SetAD(); THROW_IF_WIN32_BOOL_FALSE(fFALSE); }); + REQUIRE_THROWS_MSG(E_AD, [] { SetAD(); THROW_IF_WIN32_BOOL_FALSE_MSG(fFALSE, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_AD, [] { SetAD(); REQUIRE(fFALSE == LOG_IF_WIN32_BOOL_FALSE(fFALSE)); }); + REQUIRE_LOG_MSG(E_AD, [] { SetAD(); REQUIRE(fFALSE == LOG_IF_WIN32_BOOL_FALSE_MSG(fFALSE, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_AD, [] { SetAD(); FAIL_FAST_IF_WIN32_BOOL_FALSE(fFALSE); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { SetAD(); FAIL_FAST_IF_WIN32_BOOL_FALSE_MSG(fFALSE, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_IF_WIN32_ERROR(MDEC(hrOKRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_IF_WIN32_ERROR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_WIN32_ERROR_EXPECTED(MDEC(hrOKRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(S_OK == THROW_IF_WIN32_ERROR(MDEC(hrOKRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(S_OK == THROW_IF_WIN32_ERROR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(S_OK == LOG_IF_WIN32_ERROR(MDEC(hrOKRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(S_OK == LOG_IF_WIN32_ERROR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(S_OK == FAIL_FAST_IF_WIN32_ERROR(MDEC(hrOKRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(S_OK == FAIL_FAST_IF_WIN32_ERROR_MSG(MDEC(hrOKRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(E_hrOutOfPaper, [] { RETURN_IF_WIN32_ERROR(ERROR_OUT_OF_PAPER); return S_OK; }); + REQUIRE_RETURNS_MSG(E_hrOutOfPaper, [] { RETURN_IF_WIN32_ERROR_MSG(ERROR_OUT_OF_PAPER, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_hrOutOfPaper, [] { RETURN_IF_WIN32_ERROR_EXPECTED(ERROR_OUT_OF_PAPER); return S_OK; }); + REQUIRE_THROWS_RESULT(E_hrOutOfPaper, [] { THROW_IF_WIN32_ERROR(ERROR_OUT_OF_PAPER); }); + REQUIRE_THROWS_MSG(E_hrOutOfPaper, [] { THROW_IF_WIN32_ERROR_MSG(ERROR_OUT_OF_PAPER, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_hrOutOfPaper, [] { REQUIRE(ERROR_OUT_OF_PAPER == LOG_IF_WIN32_ERROR(ERROR_OUT_OF_PAPER)); }); + REQUIRE_LOG_MSG(E_hrOutOfPaper, [] { REQUIRE(ERROR_OUT_OF_PAPER == LOG_IF_WIN32_ERROR_MSG(ERROR_OUT_OF_PAPER, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_hrOutOfPaper, [] { FAIL_FAST_IF_WIN32_ERROR(ERROR_OUT_OF_PAPER); }); + REQUIRE_FAILFAST_MSG(E_hrOutOfPaper, [] { FAIL_FAST_IF_WIN32_ERROR_MSG(ERROR_OUT_OF_PAPER, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_hrNtOkay, [] { RETURN_NTSTATUS(MDEC(ntOKRef())); }); + REQUIRE_RETURNS_MSG(S_hrNtOkay, [] { RETURN_NTSTATUS_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_NTSTATUS(MDEC(ntOKRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_NTSTATUS_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_hrNtAssertionFailure, [] { RETURN_NTSTATUS(STATUS_ASSERTION_FAILURE); }); + REQUIRE_RETURNS_MSG(E_hrNtAssertionFailure, [] { RETURN_NTSTATUS_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_hrNtAssertionFailure, [] { THROW_NTSTATUS(STATUS_ASSERTION_FAILURE); }); + REQUIRE_THROWS_MSG(E_hrNtAssertionFailure, [] { THROW_NTSTATUS_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_hrNtAssertionFailure, [] { LOG_NTSTATUS(STATUS_ASSERTION_FAILURE); }); + REQUIRE_LOG_MSG(E_hrNtAssertionFailure, [] { LOG_NTSTATUS_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_hrNtAssertionFailure, [] { FAIL_FAST_NTSTATUS(STATUS_ASSERTION_FAILURE); }); + REQUIRE_FAILFAST_MSG(E_hrNtAssertionFailure, [] { FAIL_FAST_NTSTATUS_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_IF_NTSTATUS_FAILED_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_NTSTATUS_FAILED_EXPECTED(MDEC(ntOKRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(STATUS_WAIT_0 == THROW_IF_NTSTATUS_FAILED(MDEC(ntOKRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(STATUS_WAIT_0 == THROW_IF_NTSTATUS_FAILED_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(STATUS_WAIT_0 == LOG_IF_NTSTATUS_FAILED(MDEC(ntOKRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(STATUS_WAIT_0 == LOG_IF_NTSTATUS_FAILED_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(STATUS_WAIT_0 == FAIL_FAST_IF_NTSTATUS_FAILED(MDEC(ntOKRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(STATUS_WAIT_0 == FAIL_FAST_IF_NTSTATUS_FAILED_MSG(MDEC(ntOKRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(E_hrNtAssertionFailure, [] { RETURN_IF_NTSTATUS_FAILED(STATUS_ASSERTION_FAILURE); return S_OK; }); + REQUIRE_RETURNS_MSG(E_hrNtAssertionFailure, [] { RETURN_IF_NTSTATUS_FAILED_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_hrNtAssertionFailure, [] { RETURN_IF_NTSTATUS_FAILED_EXPECTED(STATUS_ASSERTION_FAILURE); return S_OK; }); + REQUIRE_THROWS_RESULT(E_hrNtAssertionFailure, [] { THROW_IF_NTSTATUS_FAILED(STATUS_ASSERTION_FAILURE); }); + REQUIRE_THROWS_MSG(E_hrNtAssertionFailure, [] { THROW_IF_NTSTATUS_FAILED_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_hrNtAssertionFailure, [] { REQUIRE(STATUS_ASSERTION_FAILURE == LOG_IF_NTSTATUS_FAILED(STATUS_ASSERTION_FAILURE)); }); + REQUIRE_LOG_MSG(E_hrNtAssertionFailure, [] { REQUIRE(STATUS_ASSERTION_FAILURE == LOG_IF_NTSTATUS_FAILED_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_hrNtAssertionFailure, [] { FAIL_FAST_IF_NTSTATUS_FAILED(STATUS_ASSERTION_FAILURE); }); + REQUIRE_FAILFAST_MSG(E_hrNtAssertionFailure, [] { FAIL_FAST_IF_NTSTATUS_FAILED_MSG(STATUS_ASSERTION_FAILURE, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { RETURN_IF_NTSTATUS_FAILED(STATUS_NO_MEMORY); return S_OK; }); + REQUIRE_RETURNS_MSG(E_OUTOFMEMORY, [] { RETURN_IF_NTSTATUS_FAILED_MSG(STATUS_NO_MEMORY, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_OUTOFMEMORY, [] { RETURN_IF_NTSTATUS_FAILED_EXPECTED(STATUS_NO_MEMORY); return S_OK; }); + REQUIRE_THROWS_RESULT(E_OUTOFMEMORY, [] { THROW_IF_NTSTATUS_FAILED(STATUS_NO_MEMORY); }); + REQUIRE_THROWS_MSG(E_OUTOFMEMORY, [] { THROW_IF_NTSTATUS_FAILED_MSG(STATUS_NO_MEMORY, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { REQUIRE(STATUS_NO_MEMORY == LOG_IF_NTSTATUS_FAILED(STATUS_NO_MEMORY)); }); + REQUIRE_LOG_MSG(E_OUTOFMEMORY, [] { REQUIRE(STATUS_NO_MEMORY == LOG_IF_NTSTATUS_FAILED_MSG(STATUS_NO_MEMORY, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_OUTOFMEMORY, [] { FAIL_FAST_IF_NTSTATUS_FAILED(STATUS_NO_MEMORY); }); + REQUIRE_FAILFAST_MSG(E_OUTOFMEMORY, [] { FAIL_FAST_IF_NTSTATUS_FAILED_MSG(STATUS_NO_MEMORY, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_IF_NULL_ALLOC(MDEC(pValidRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_IF_NULL_ALLOC_MSG(MDEC(pValidRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_NULL_ALLOC_EXPECTED(MDEC(pValidRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(pValid == THROW_IF_NULL_ALLOC(MDEC(pValidRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(pValid == THROW_IF_NULL_ALLOC_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(pValid == LOG_IF_NULL_ALLOC(MDEC(pValidRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(pValid == LOG_IF_NULL_ALLOC_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(pValid == FAIL_FAST_IF_NULL_ALLOC(MDEC(pValidRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(pValid == FAIL_FAST_IF_NULL_ALLOC_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { RETURN_IF_NULL_ALLOC(pNull); return S_OK; }); + REQUIRE_RETURNS_MSG(E_OUTOFMEMORY, [] { RETURN_IF_NULL_ALLOC_MSG(pNull, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_OUTOFMEMORY, [] { RETURN_IF_NULL_ALLOC_EXPECTED(pNull); return S_OK; }); + REQUIRE_THROWS_RESULT(E_OUTOFMEMORY, [] { THROW_IF_NULL_ALLOC(pNull); }); + REQUIRE_THROWS_MSG(E_OUTOFMEMORY, [] { THROW_IF_NULL_ALLOC_MSG(pNull, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { REQUIRE(pNull == LOG_IF_NULL_ALLOC(pNull)); }); + REQUIRE_LOG_MSG(E_OUTOFMEMORY, [] { REQUIRE(pNull == LOG_IF_NULL_ALLOC_MSG(pNull, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_OUTOFMEMORY, [] { FAIL_FAST_IF_NULL_ALLOC(pNull); }); + REQUIRE_FAILFAST_MSG(E_OUTOFMEMORY, [] { FAIL_FAST_IF_NULL_ALLOC_MSG(pNull, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(MDEC(S_OK), MDEC(fTrueRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(MDEC(S_OK), MDEC(fTrueRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(MDEC(S_OK), MDEC(fTrueRef())); return S_OK; }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF(MDEC(S_OK), MDEC(fTrueRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF_MSG(MDEC(S_OK), MDEC(fTrueRef()), "msg: %d", __LINE__); }); + REQUIRE_RETURNS(E_FAIL, [] { RETURN_HR_IF(E_FAIL, fTrue); return S_OK; }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { RETURN_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_HR_IF_EXPECTED(E_FAIL, fTrue); return S_OK; }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { THROW_HR_IF(E_FAIL, fTrue); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { THROW_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(fTrue == LOG_HR_IF(E_FAIL, fTrue)); }); + REQUIRE_LOG_MSG(E_FAIL, [] { REQUIRE(fTrue == LOG_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_FAIL, [] { FAIL_FAST_HR_IF(E_FAIL, fTrue); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { FAIL_FAST_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(MDEC(S_OK), MDEC(fTrueRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(MDEC(S_OK), MDEC(fTrueRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(MDEC(S_OK), MDEC(fTrueRef())); return S_OK; }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF(MDEC(S_OK), MDEC(fTrueRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF_MSG(MDEC(S_OK), MDEC(fTrueRef()), "msg: %d", __LINE__); }); + REQUIRE_RETURNS(E_FAIL, [] { RETURN_HR_IF(E_FAIL, fTrue); return S_OK; }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { RETURN_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_HR_IF_EXPECTED(E_FAIL, fTrue); return S_OK; }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { THROW_HR_IF(E_FAIL, fTrue); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { THROW_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(fTrue == LOG_HR_IF(E_FAIL, fTrue)); }); + REQUIRE_LOG_MSG(E_FAIL, [] { REQUIRE(fTrue == LOG_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_FAIL, [] { FAIL_FAST_HR_IF(E_FAIL, fTrue); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { FAIL_FAST_HR_IF_MSG(E_FAIL, fTrue, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(MDEC(S_OK), MDEC(fFalseRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(MDEC(S_OK), MDEC(fFalseRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(E_FAIL, MDEC(fFalseRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(E_FAIL, MDEC(fFalseRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(MDEC(S_OK), MDEC(fFalseRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(MDEC(S_OK), MDEC(fFalseRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF(MDEC(S_OK), MDEC(fFalseRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF_MSG(MDEC(S_OK), MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF(E_FAIL, MDEC(fFalseRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_EXPECTED(E_FAIL, MDEC(fFalseRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fFalse == THROW_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fFalse == LOG_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF(E_FAIL, MDEC(fFalseRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_HR_IF_MSG(E_FAIL, MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF_NULL(S_OK, pNull); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_NULL_MSG(S_OK, pNull, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_NULL_EXPECTED(S_OK, pNull); return S_OK; }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF_NULL(S_OK, pNull); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_HR_IF_NULL_MSG(S_OK, pNull, "msg: %d", __LINE__); }); + REQUIRE_RETURNS(E_FAIL, [] { RETURN_HR_IF_NULL(E_FAIL, pNull); return S_OK; }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { RETURN_HR_IF_NULL_MSG(E_FAIL, pNull, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_HR_IF_NULL_EXPECTED(E_FAIL, pNull); return S_OK; }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { THROW_HR_IF_NULL(E_FAIL, pNull); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { THROW_HR_IF_NULL_MSG(E_FAIL, pNull, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(pNull == LOG_HR_IF_NULL(E_FAIL, pNull)); }); + REQUIRE_LOG_MSG(E_FAIL, [] { REQUIRE(pNull == LOG_HR_IF_NULL_MSG(E_FAIL, pNull, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_FAIL, [] { FAIL_FAST_HR_IF_NULL(E_FAIL, pNull); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { FAIL_FAST_HR_IF_NULL_MSG(E_FAIL, pNull, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF_NULL(MDEC(S_OK), MDEC(pValidRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_NULL_MSG(MDEC(S_OK), MDEC(pValidRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_NULL_EXPECTED(MDEC(S_OK), MDEC(pValidRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(pValid == THROW_HR_IF_NULL(MDEC(S_OK), MDEC(pValidRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(pValid == THROW_HR_IF_NULL_MSG(MDEC(S_OK), MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(pValid == LOG_HR_IF_NULL(MDEC(S_OK), MDEC(pValidRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(pValid == LOG_HR_IF_NULL_MSG(MDEC(S_OK), MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(pValid == FAIL_FAST_HR_IF_NULL(MDEC(S_OK), MDEC(pValidRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(pValid == FAIL_FAST_HR_IF_NULL_MSG(MDEC(S_OK), MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_RETURNS(S_OK, [] { RETURN_HR_IF_NULL(E_FAIL, MDEC(pValidRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_HR_IF_NULL_MSG(E_FAIL, MDEC(pValidRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_HR_IF_NULL_EXPECTED(E_FAIL, MDEC(pValidRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(pValid == THROW_HR_IF_NULL(E_FAIL, MDEC(pValidRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(pValid == THROW_HR_IF_NULL_MSG(E_FAIL, MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(pValid == LOG_HR_IF_NULL(E_FAIL, MDEC(pValidRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(pValid == LOG_HR_IF_NULL_MSG(E_FAIL, MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(pValid == FAIL_FAST_HR_IF_NULL(E_FAIL, MDEC(pValidRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(pValid == FAIL_FAST_HR_IF_NULL_MSG(E_FAIL, MDEC(pValidRef()), "msg: %d", __LINE__)); }); + + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR_IF(fTrue); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR_IF_MSG(fTrue, "msg: %d", __LINE__); }); + REQUIRE_RETURNS(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF(fTrue); return S_OK; }); + REQUIRE_RETURNS_MSG(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF_MSG(fTrue, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF_EXPECTED(fTrue); return S_OK; }); + REQUIRE_THROWS_RESULT(E_AD, [] { SetAD(); THROW_LAST_ERROR_IF(fTrue); }); + REQUIRE_THROWS_MSG(E_AD, [] { SetAD(); THROW_LAST_ERROR_IF_MSG(fTrue, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_AD, [] { SetAD(); REQUIRE(fTrue == LOG_LAST_ERROR_IF(fTrue)); }); + REQUIRE_LOG_MSG(E_AD, [] { SetAD(); REQUIRE(fTrue == LOG_LAST_ERROR_IF_MSG(fTrue, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR_IF(fTrue); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR_IF_MSG(fTrue, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_LAST_ERROR_IF(MDEC(fFalseRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_LAST_ERROR_IF_MSG(MDEC(fFalseRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_LAST_ERROR_IF_EXPECTED(MDEC(fFalseRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(fFalse == THROW_LAST_ERROR_IF(MDEC(fFalseRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(fFalse == THROW_LAST_ERROR_IF_MSG(MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(fFalse == LOG_LAST_ERROR_IF(MDEC(fFalseRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(fFalse == LOG_LAST_ERROR_IF_MSG(MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_LAST_ERROR_IF(MDEC(fFalseRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_LAST_ERROR_IF_MSG(MDEC(fFalseRef()), "msg: %d", __LINE__)); }); + + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR_IF_NULL(pNull); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { ::SetLastError(0); FAIL_FAST_LAST_ERROR_IF_NULL_MSG(pNull, "msg: %d", __LINE__); }); + REQUIRE_RETURNS(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF_NULL(pNull); return S_OK; }); + REQUIRE_RETURNS_MSG(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF_NULL_MSG(pNull, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_AD, [] { SetAD(); RETURN_LAST_ERROR_IF_NULL_EXPECTED(pNull); return S_OK; }); + REQUIRE_THROWS_RESULT(E_AD, [] { SetAD(); THROW_LAST_ERROR_IF_NULL(pNull); }); + REQUIRE_THROWS_MSG(E_AD, [] { SetAD(); THROW_LAST_ERROR_IF_NULL_MSG(pNull, "msg: %d", __LINE__); }); + REQUIRE_LOG(E_AD, [] { SetAD(); REQUIRE(pNull == LOG_LAST_ERROR_IF_NULL(pNull)); }); + REQUIRE_LOG_MSG(E_AD, [] { SetAD(); REQUIRE(pNull == LOG_LAST_ERROR_IF_NULL_MSG(pNull, "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR_IF_NULL(pNull); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { SetAD(); FAIL_FAST_LAST_ERROR_IF_NULL_MSG(pNull, "msg: %d", __LINE__); }); + + REQUIRE_RETURNS(S_OK, [] { RETURN_LAST_ERROR_IF_NULL(MDEC(pValidRef())); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { RETURN_LAST_ERROR_IF_NULL_MSG(MDEC(pValidRef()), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_LAST_ERROR_IF_NULL_EXPECTED(MDEC(pValidRef())); return S_OK; }); + REQUIRE_THROWS_RESULT(S_OK, [] { REQUIRE(pNull != THROW_LAST_ERROR_IF_NULL(MDEC(pValidRef()))); }); + REQUIRE_THROWS_MSG(S_OK, [] { REQUIRE(pNull != THROW_LAST_ERROR_IF_NULL_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(pNull != LOG_LAST_ERROR_IF_NULL(MDEC(pValidRef()))); }); + REQUIRE_LOG_MSG(S_OK, [] { REQUIRE(pNull != LOG_LAST_ERROR_IF_NULL_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(pNull != FAIL_FAST_LAST_ERROR_IF_NULL(MDEC(pValidRef()))); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { REQUIRE(pNull != FAIL_FAST_LAST_ERROR_IF_NULL_MSG(MDEC(pValidRef()), "msg: %d", __LINE__)); }); + + REQUIRE_LOG(S_OK, [] { REQUIRE(true == SUCCEEDED_LOG(MDEC(S_OK))); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(false == SUCCEEDED_LOG(E_FAIL)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(false == FAILED_LOG(MDEC(S_OK))); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(true == FAILED_LOG(E_FAIL)); }); + + REQUIRE_LOG(ERROR_SUCCESS, [] { REQUIRE(true == SUCCEEDED_WIN32_LOG(MDEC(ERROR_SUCCESS))); }); + REQUIRE_LOG(HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED), [] { REQUIRE(false == SUCCEEDED_WIN32_LOG(ERROR_ACCESS_DENIED)); }); + REQUIRE_LOG(ERROR_SUCCESS, [] { REQUIRE(false == FAILED_WIN32_LOG(MDEC(ERROR_SUCCESS))); }); + REQUIRE_LOG(HRESULT_FROM_WIN32(ERROR_ACCESS_DENIED), [] { REQUIRE(true == FAILED_WIN32_LOG(ERROR_ACCESS_DENIED)); }); + + REQUIRE_LOG(ntOK, [] { REQUIRE(true == SUCCEEDED_NTSTATUS_LOG(MDEC(ntOK))); }); + REQUIRE_LOG(wil::details::NtStatusToHr(ntFAIL), [] { REQUIRE(false == SUCCEEDED_NTSTATUS_LOG(ntFAIL)); }); + REQUIRE_LOG(ntOK, [] { REQUIRE(false == FAILED_NTSTATUS_LOG(MDEC(ntOK))); }); + REQUIRE_LOG(wil::details::NtStatusToHr(ntFAIL), [] { REQUIRE(true == FAILED_NTSTATUS_LOG(ntFAIL)); }); + + // FAIL_FAST_IMMEDIATE* directly invokes __fastfail, which we can't catch, so disabled for now + // REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_IMMEDIATE(); }); + // REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_IMMEDIATE_IF_FAILED(E_FAIL); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(S_OK == FAIL_FAST_IMMEDIATE_IF_FAILED(MDEC(S_OK))); }); + // REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_IMMEDIATE_IF(fTrue); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(fFalse == FAIL_FAST_IMMEDIATE_IF(MDEC(fFalseRef()))); }); + // REQUIRE_FAILFAST_UNSPECIFIED([] { FAIL_FAST_IMMEDIATE_IF_NULL(pNull); }); + REQUIRE_FAILFAST(S_OK, [] { REQUIRE(pValid == FAIL_FAST_IMMEDIATE_IF_NULL(MDEC(pValidRef()))); }); + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_RETURNS(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_RETURN(); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_RETURN_MSG("msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_RETURN_EXPECTED(); return S_OK; }); + REQUIRE_LOG(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_LOG(); }); + REQUIRE_LOG_MSG(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_LOG_MSG("msg: %d", __LINE__); }); + REQUIRE_FAILFAST(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_FAIL_FAST(); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_FAIL_FAST_MSG("msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_THROW_NORMALIZED(); }); + REQUIRE_THROWS_MSG(S_OK, [] { try { THROW_IF_FAILED(hrOK); } CATCH_THROW_NORMALIZED_MSG("msg: %d", __LINE__); }); + + REQUIRE_RETURNS(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_RETURN(); return S_OK; }); + REQUIRE_RETURNS_MSG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_RETURN_MSG("msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_RETURN_EXPECTED(); return S_OK; }); + REQUIRE_LOG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_LOG(); }); + REQUIRE_LOG_MSG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_LOG_MSG("msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_FAIL_FAST(); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_FAIL_FAST_MSG("msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_THROW_NORMALIZED(); }); + REQUIRE_THROWS_MSG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_THROW_NORMALIZED_MSG("msg: %d", __LINE__); }); + + REQUIRE_FAILFAST_UNSPECIFIED([] { try { if (FAILED(hrFAIL)) { throw E_FAIL; } } CATCH_FAIL_FAST(); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { try { if (FAILED(hrFAIL)) { throw E_FAIL; } } CATCH_FAIL_FAST_MSG("msg: %d", __LINE__); }); + + REQUIRE_THROWS_RESULT(E_AD, [] { THROW_EXCEPTION(MDEC(DerivedAccessDeniedException())); }); + REQUIRE_THROWS_MSG(E_AD, [] { THROW_EXCEPTION_MSG(MDEC(DerivedAccessDeniedException()), "msg: %d", __LINE__); }); + + REQUIRE_LOG(E_AD, [] { try { throw AlternateAccessDeniedException(); } CATCH_LOG(); }); + REQUIRE_THROWS_RESULT(E_AD, [] { try { throw AlternateAccessDeniedException(); } CATCH_THROW_NORMALIZED(); }); + + REQUIRE_RETURNS(S_OK, [] { return wil::ResultFromException([] { THROW_IF_FAILED(hrOK); }); }); + REQUIRE_RETURNS(E_FAIL, [] { return wil::ResultFromException([] { THROW_IF_FAILED(hrFAIL); }); }); + REQUIRE(E_AD == wil::ResultFromException([] { throw AlternateAccessDeniedException(); })); + + try { THROW_HR(E_FAIL); } + catch (...) { REQUIRE(E_FAIL == wil::ResultFromCaughtException()); }; +#endif + +#ifdef WIL_ENABLE_EXCEPTIONS + REQUIRE_LOG(E_FAIL, [] { try { THROW_IF_FAILED(hrFAIL); } CATCH_LOG(); }); +#endif + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_IF_NULL_ALLOC(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_IF_NULL_ALLOC_EXPECTED(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_IF_NULL_ALLOC(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_IF_NULL_ALLOC_EXPECTED(MDEC(pInt)); return S_OK; }); + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_HR_IF_NULL(E_OUTOFMEMORY, MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_HR_IF_NULL_MSG(E_OUTOFMEMORY, pInt, "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_OUTOFMEMORY, [] { std::unique_ptr pInt; RETURN_HR_IF_NULL_EXPECTED(E_OUTOFMEMORY, MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_HR_IF_NULL(E_OUTOFMEMORY, MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_HR_IF_NULL_MSG(E_OUTOFMEMORY, MDEC(pInt), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { std::unique_ptr pInt(new int(5)); RETURN_HR_IF_NULL_EXPECTED(E_OUTOFMEMORY, MDEC(pInt)); return S_OK; }); + + REQUIRE_RETURNS(E_AD, [] { std::unique_ptr pInt; SetAD(); RETURN_LAST_ERROR_IF_NULL(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(E_AD, [] { std::unique_ptr pInt; SetAD(); RETURN_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_AD, [] { std::unique_ptr pInt; SetAD(); RETURN_LAST_ERROR_IF_NULL_EXPECTED(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS(S_OK, [] { std::unique_ptr pInt(new int(5)); SetAD(); RETURN_LAST_ERROR_IF_NULL(MDEC(pInt)); return S_OK; }); + REQUIRE_RETURNS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); SetAD(); RETURN_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(S_OK, [] { std::unique_ptr pInt(new int(5)); SetAD(); RETURN_LAST_ERROR_IF_NULL_EXPECTED(MDEC(pInt)); return S_OK; }); + + REQUIRE_THROWS_RESULT(E_OUTOFMEMORY, [] { std::unique_ptr pInt; THROW_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_THROWS_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; THROW_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; LOG_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_LOG_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; LOG_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_OUTOFMEMORY, [] { std::unique_ptr pInt; FAIL_FAST_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_FAILFAST_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; FAIL_FAST_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_THROWS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_LOG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_LOG_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_IF_NULL_ALLOC(MDEC(pInt)); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_IF_NULL_ALLOC_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + + REQUIRE_LOG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(pInt)); }); + REQUIRE_LOG_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; LOG_HR_IF_NULL_MSG(MDEC(E_OUTOFMEMORY), MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_FAIL, [] { std::unique_ptr pInt; FAIL_FAST_HR_IF_NULL(MDEC(E_FAIL), MDEC(pInt)); }); + REQUIRE_FAILFAST_MSG(E_FAIL, [] { std::unique_ptr pInt; FAIL_FAST_HR_IF_NULL_MSG(MDEC(E_FAIL), MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_OUTOFMEMORY, [] { std::unique_ptr pInt; THROW_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(pInt)); }); + REQUIRE_THROWS_MSG(E_OUTOFMEMORY, [] { std::unique_ptr pInt; THROW_HR_IF_NULL_MSG(MDEC(E_OUTOFMEMORY), MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_LOG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(pInt)); }); + REQUIRE_LOG_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_HR_IF_NULL_MSG(MDEC(E_OUTOFMEMORY), MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_HR_IF_NULL(MDEC(E_FAIL), MDEC(pInt)); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_HR_IF_NULL_MSG(MDEC(E_FAIL), MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(pInt)); }); + REQUIRE_THROWS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_HR_IF_NULL_MSG(MDEC(E_OUTOFMEMORY), MDEC(pInt), "msg: %d", __LINE__); }); + + REQUIRE_LOG(E_AD, [] { std::unique_ptr pInt; SetAD(); LOG_LAST_ERROR_IF_NULL(MDEC(pInt)); }); + REQUIRE_LOG_MSG(E_AD, [] { std::unique_ptr pInt; SetAD(); LOG_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(E_AD, [] { std::unique_ptr pInt; SetAD(); FAIL_FAST_LAST_ERROR_IF_NULL(pInt); }); + REQUIRE_FAILFAST_MSG(E_AD, [] { std::unique_ptr pInt; SetAD(); FAIL_FAST_LAST_ERROR_IF_NULL_MSG(pInt, "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(E_AD, [] { std::unique_ptr pInt; SetAD(); THROW_LAST_ERROR_IF_NULL(MDEC(pInt)); }); + REQUIRE_THROWS_MSG(E_AD, [] { std::unique_ptr pInt; SetAD(); THROW_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_LOG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_LAST_ERROR_IF_NULL(MDEC(pInt)); }); + REQUIRE_LOG_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); LOG_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + REQUIRE_FAILFAST(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_LAST_ERROR_IF_NULL(pInt); }); + REQUIRE_FAILFAST_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_LAST_ERROR_IF_NULL_MSG(pInt, "msg: %d", __LINE__); }); + REQUIRE_THROWS_RESULT(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_LAST_ERROR_IF_NULL(MDEC(pInt)); }); + REQUIRE_THROWS_MSG(S_OK, [] { std::unique_ptr pInt(new int(5)); THROW_LAST_ERROR_IF_NULL_MSG(MDEC(pInt), "msg: %d", __LINE__); }); + + // REQUIRE_FAILFAST_UNSPECIFIED([] { std::unique_ptr pInt; FAIL_FAST_IMMEDIATE_IF_NULL(pNull); }); + REQUIRE_FAILFAST(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_IMMEDIATE_IF_NULL(MDEC(pValidRef())); }); + REQUIRE_FAILFAST_UNSPECIFIED([] { std::unique_ptr pInt; FAIL_FAST_IF_NULL(pNull); }); + REQUIRE_FAILFAST(S_OK, [] { std::unique_ptr pInt(new int(5)); FAIL_FAST_IF_NULL(MDEC(pInt)); }); + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { Microsoft::WRL::ComPtr ptr; RETURN_IF_NULL_ALLOC(MDEC(ptr)); return S_OK; }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { Microsoft::WRL::ComPtr ptr; LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(ptr)); }); + + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { std::shared_ptr ptr; RETURN_IF_NULL_ALLOC(MDEC(ptr)); return S_OK; }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { std::shared_ptr ptr; LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(ptr)); }); + REQUIRE_RETURNS(S_OK, [] { std::shared_ptr ptr(new int(5)); RETURN_IF_NULL_ALLOC(MDEC(ptr)); return S_OK; }); + REQUIRE_LOG(S_OK, [] { std::shared_ptr ptr(new int(5)); LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(ptr)); }); + +#ifdef __cplusplus_winrt + REQUIRE_RETURNS(E_OUTOFMEMORY, [] { Platform::String^ str(nullptr); RETURN_IF_NULL_ALLOC(MDEC(str)); return S_OK; }); + REQUIRE_LOG(E_OUTOFMEMORY, [] { Platform::String^ str(nullptr); LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(str)); }); + REQUIRE_RETURNS(S_OK, [] { Platform::String^ str(L"a"); RETURN_IF_NULL_ALLOC(MDEC(str)); return S_OK; }); + REQUIRE_LOG(S_OK, [] { Platform::String^ str(L"a"); LOG_HR_IF_NULL(MDEC(E_OUTOFMEMORY), MDEC(str)); }); +#endif +} + +#define WRAP_LAMBDA(code) [&] {code;}; + +//these macros should all have compile errors due to use of an invalid type +void InvalidTypeChecks() +{ + std::unique_ptr boolCastClass; + std::vector noBoolCastClass; + + //WRAP_LAMBDA(RETURN_IF_FAILED(fTrue)); + //WRAP_LAMBDA(RETURN_IF_FAILED(fTRUE)); + //WRAP_LAMBDA(RETURN_IF_FAILED(boolCastClass)); + //WRAP_LAMBDA(RETURN_IF_FAILED(noBoolCastClass)); + //WRAP_LAMBDA(RETURN_IF_FAILED(errSuccess)); + + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE(fTrue)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE(noBoolCastClass)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE(hrOK)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE(errSuccess)); + + //WRAP_LAMBDA(RETURN_HR_IF(errSuccess, false)); + //WRAP_LAMBDA(RETURN_HR_IF(errSuccess, true)); + //WRAP_LAMBDA(RETURN_HR_IF(hrOK, noBoolCastClass)); + //WRAP_LAMBDA(RETURN_HR_IF(hrOK, hrOK)); + //WRAP_LAMBDA(RETURN_HR_IF(hrOK, errSuccess)); + + //WRAP_LAMBDA(RETURN_HR_IF_NULL(errSuccess, nullptr)); + //WRAP_LAMBDA(RETURN_HR_IF_NULL(errSuccess, pValid)); + + //WRAP_LAMBDA(RETURN_LAST_ERROR_IF(noBoolCastClass)); + //WRAP_LAMBDA(RETURN_LAST_ERROR_IF(errSuccess)); + //WRAP_LAMBDA(RETURN_LAST_ERROR_IF(hrOK)); + + //WRAP_LAMBDA(RETURN_IF_FAILED_EXPECTED(fTrue)); + //WRAP_LAMBDA(RETURN_IF_FAILED_EXPECTED(fTRUE)); + //WRAP_LAMBDA(RETURN_IF_FAILED_EXPECTED(boolCastClass)); + //WRAP_LAMBDA(RETURN_IF_FAILED_EXPECTED(noBoolCastClass)); + //WRAP_LAMBDA(RETURN_IF_FAILED_EXPECTED(errSuccess)); + + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(fTrue)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(noBoolCastClass)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(hrOK)); + //WRAP_LAMBDA(RETURN_IF_WIN32_BOOL_FALSE_EXPECTED(errSuccess)); + + //LOG_IF_FAILED(fTrue); + //LOG_IF_FAILED(fTRUE); + //LOG_IF_FAILED(boolCastClass); + //LOG_IF_FAILED(noBoolCastClass); + //LOG_IF_FAILED(errSuccess); + + //LOG_IF_WIN32_BOOL_FALSE(fTrue); + //LOG_IF_WIN32_BOOL_FALSE(noBoolCastClass); + //LOG_IF_WIN32_BOOL_FALSE(hrOK); + //LOG_IF_WIN32_BOOL_FALSE(errSuccess); + + //LOG_HR_IF(errSuccess, false); + //LOG_HR_IF(errSuccess, true); + //LOG_HR_IF(hrOK, noBoolCastClass); + //LOG_HR_IF(hrOK, hrOK); + //LOG_HR_IF(hrOK, errSuccess); + + //FAIL_FAST_IF_FAILED(fTrue); + //FAIL_FAST_IF_FAILED(fTRUE); + //FAIL_FAST_IF_FAILED(boolCastClass); + //FAIL_FAST_IF_FAILED(noBoolCastClass); + //FAIL_FAST_IF_FAILED(errSuccess); + + //FAIL_FAST_IF_WIN32_BOOL_FALSE(fTrue); + //FAIL_FAST_IF_WIN32_BOOL_FALSE(noBoolCastClass); + //FAIL_FAST_IF_WIN32_BOOL_FALSE(hrOK); + //FAIL_FAST_IF_WIN32_BOOL_FALSE(errSuccess); + + //FAIL_FAST_HR_IF(errSuccess, false); + //FAIL_FAST_HR_IF(errSuccess, true); + //FAIL_FAST_HR_IF(hrOK, noBoolCastClass); + //FAIL_FAST_HR_IF(hrOK, hrOK); + //FAIL_FAST_HR_IF(hrOK, errSuccess); + + //THROW_IF_FAILED(fTrue); + //THROW_IF_FAILED(fTRUE); + //THROW_IF_FAILED(boolCastClass); + //THROW_IF_FAILED(noBoolCastClass); + //THROW_IF_FAILED(errSuccess); + + //THROW_IF_WIN32_BOOL_FALSE(fTrue); + //THROW_IF_WIN32_BOOL_FALSE(noBoolCastClass); + //THROW_IF_WIN32_BOOL_FALSE(hrOK); + //THROW_IF_WIN32_BOOL_FALSE(errSuccess); + + //THROW_HR_IF(errSuccess, false); + //THROW_HR_IF(errSuccess, true); + //THROW_HR_IF(hrOK, noBoolCastClass); + //THROW_HR_IF(hrOK, hrOK); + //THROW_HR_IF(hrOK, errSuccess); + + //FAIL_FAST_IF(noBoolCastClass); + //FAIL_FAST_IF(hrOK); + //FAIL_FAST_IF(errSuccess); + + //FAIL_FAST_IMMEDIATE_IF_FAILED(fTrue); + //FAIL_FAST_IMMEDIATE_IF_FAILED(fTRUE); + //FAIL_FAST_IMMEDIATE_IF_FAILED(boolCastClass); + //FAIL_FAST_IMMEDIATE_IF_FAILED(noBoolCastClass); + //FAIL_FAST_IMMEDIATE_IF_FAILED(errSuccess); + + //FAIL_FAST_IMMEDIATE_IF(noBoolCastClass); + //FAIL_FAST_IMMEDIATE_IF(hrOK); + //FAIL_FAST_IMMEDIATE_IF(errSuccess); +} + +TEST_CASE("WindowsInternalTests::UniqueHandle", "[resource][unique_any]") +{ + { + // default construction test + wil::unique_handle spHandle; + REQUIRE(spHandle.get() == nullptr); + + // null ptr assignment creation + wil::unique_handle spNullHandle = nullptr; + REQUIRE(spNullHandle.get() == nullptr); + + // explicit construction from the invalid value + wil::unique_handle spInvalidHandle(nullptr); + REQUIRE(spInvalidHandle.get() == nullptr); + + // valid handle creation + wil::unique_handle spValidHandle(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + REQUIRE(spValidHandle.get() != nullptr); + auto const handleValue = spValidHandle.get(); + + // r-value construction + wil::unique_handle spMoveHandle = wistd::move(spValidHandle); + REQUIRE(spValidHandle.get() == nullptr); + REQUIRE(spMoveHandle.get() == handleValue); + + // nullptr-assignment + spNullHandle = nullptr; + REQUIRE(spNullHandle.get() == nullptr); + + // r-value assignment + spValidHandle = wistd::move(spMoveHandle); + REQUIRE(spValidHandle.get() == handleValue); + REQUIRE(spMoveHandle.get() == nullptr); + + // swap + spValidHandle.swap(spMoveHandle); + REQUIRE(spValidHandle.get() == nullptr); + REQUIRE(spMoveHandle.get() == handleValue); + + // operator bool + REQUIRE_FALSE(spValidHandle); + REQUIRE(spMoveHandle); + + // release + auto ptrValidHandle = spValidHandle.release(); + auto ptrMoveHandle = spMoveHandle.release(); + REQUIRE(ptrValidHandle == nullptr); + REQUIRE(ptrMoveHandle == handleValue); + REQUIRE(spValidHandle.get() == nullptr); + REQUIRE(spMoveHandle.get() == nullptr); + + // reset + spValidHandle.reset(); + spMoveHandle.reset(); + REQUIRE(spValidHandle.get() == nullptr); + REQUIRE(spMoveHandle.get() == nullptr); + spValidHandle.reset(ptrValidHandle); + spMoveHandle.reset(ptrMoveHandle); + REQUIRE(spValidHandle.get() == nullptr); + REQUIRE(spMoveHandle.get() == handleValue); + spNullHandle.reset(nullptr); + REQUIRE(spNullHandle.get() == nullptr); + + // address + REQUIRE(*spMoveHandle.addressof() == handleValue); + REQUIRE(*(&spMoveHandle) == nullptr); + *(&spMoveHandle) = ::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0); + REQUIRE(spMoveHandle); + } + { + // default construction test + wil::unique_hfile spHandle; + REQUIRE(spHandle.get() == INVALID_HANDLE_VALUE); + + // implicit construction from the invalid value + wil::unique_hfile spNullHandle; // = nullptr; // method explicitly disabled as nullptr isn't the invalid value + REQUIRE(spNullHandle.get() == INVALID_HANDLE_VALUE); + + // assignment from the invalid value + // spNullHandle = nullptr; // method explicitly disabled as nullptr isn't the invalid value + REQUIRE(spNullHandle.get() == INVALID_HANDLE_VALUE); + + // explicit construction from the invalid value + wil::unique_hfile spInvalidHandle(INVALID_HANDLE_VALUE); + REQUIRE(spInvalidHandle.get() == INVALID_HANDLE_VALUE); + + // valid handle creation + wchar_t tempFileName[MAX_PATH]; + REQUIRE_SUCCEEDED(witest::GetTempFileName(tempFileName)); + + CREATEFILE2_EXTENDED_PARAMETERS params = { sizeof(params) }; + params.dwFileAttributes = FILE_ATTRIBUTE_TEMPORARY; + wil::unique_hfile spValidHandle(::CreateFile2(tempFileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_DELETE, CREATE_ALWAYS, ¶ms)); + + ::DeleteFileW(tempFileName); + REQUIRE(spValidHandle.get() != INVALID_HANDLE_VALUE); + auto const handleValue = spValidHandle.get(); + + // r-value construction + wil::unique_hfile spMoveHandle = wistd::move(spValidHandle); + REQUIRE(spValidHandle.get() == INVALID_HANDLE_VALUE); + REQUIRE(spMoveHandle.get() == handleValue); + + // nullptr-assignment -- uncomment to check intentional compilation error + // spNullHandle = nullptr; + + // r-value assignment + spValidHandle = wistd::move(spMoveHandle); + REQUIRE(spValidHandle.get() == handleValue); + REQUIRE(spMoveHandle.get() == INVALID_HANDLE_VALUE); + + // swap + spValidHandle.swap(spMoveHandle); + REQUIRE(spValidHandle.get() == INVALID_HANDLE_VALUE); + REQUIRE(spMoveHandle.get() == handleValue); + + // operator bool + REQUIRE_FALSE(spValidHandle); + REQUIRE(spMoveHandle); + + // release + auto ptrValidHandle = spValidHandle.release(); + auto ptrMoveHandle = spMoveHandle.release(); + REQUIRE(ptrValidHandle == INVALID_HANDLE_VALUE); + REQUIRE(ptrMoveHandle == handleValue); + REQUIRE(spValidHandle.get() == INVALID_HANDLE_VALUE); + REQUIRE(spMoveHandle.get() == INVALID_HANDLE_VALUE); + + // reset + spValidHandle.reset(); + spMoveHandle.reset(); + REQUIRE(spValidHandle.get() == INVALID_HANDLE_VALUE); + REQUIRE(spMoveHandle.get() == INVALID_HANDLE_VALUE); + spValidHandle.reset(ptrValidHandle); + spMoveHandle.reset(ptrMoveHandle); + REQUIRE(spValidHandle.get() == INVALID_HANDLE_VALUE); + REQUIRE(spMoveHandle.get() == handleValue); + // uncomment to test intentional compilation error due to conflict with INVALID_HANDLE_VALUE + // spNullHandle.reset(nullptr); + + // address + REQUIRE(*spMoveHandle.addressof() == handleValue); + REQUIRE(*(&spMoveHandle) == INVALID_HANDLE_VALUE); + + wchar_t tempFileName2[MAX_PATH]; + REQUIRE_SUCCEEDED(witest::GetTempFileName(tempFileName2)); + + CREATEFILE2_EXTENDED_PARAMETERS params2 = { sizeof(params2) }; + params2.dwFileAttributes = FILE_ATTRIBUTE_TEMPORARY; + *(&spMoveHandle) = ::CreateFile2(tempFileName2, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_DELETE, CREATE_ALWAYS, ¶ms2); + + ::DeleteFileW(tempFileName2); + REQUIRE(spMoveHandle); + + // ensure that mistaken nullptr usage is not valid... + spMoveHandle.reset(); + *(&spMoveHandle) = nullptr; + REQUIRE_FALSE(spMoveHandle); + } + + auto hFirst = ::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0); + auto hSecond= ::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0); + + wil::unique_handle spLeft(hFirst); + wil::unique_handle spRight(hSecond); + + REQUIRE(spRight.get() == hSecond); + REQUIRE(spLeft.get() == hFirst); + swap(spLeft, spRight); + REQUIRE(spLeft.get() == hSecond); + REQUIRE(spRight.get() == hFirst); + swap(spLeft, spRight); + + REQUIRE((spLeft.get() == spRight.get()) == (spLeft == spRight)); + REQUIRE((spLeft.get() != spRight.get()) == (spLeft != spRight)); + REQUIRE((spLeft.get() < spRight.get()) == (spLeft < spRight)); + REQUIRE((spLeft.get() <= spRight.get()) == (spLeft <= spRight)); + REQUIRE((spLeft.get() >= spRight.get()) == (spLeft >= spRight)); + REQUIRE((spLeft.get() > spRight.get()) == (spLeft > spRight)); + + // test stl container use (hash & std::less) +#ifdef WIL_ENABLE_EXCEPTIONS + std::unordered_set hashSet; + hashSet.insert(std::move(spLeft)); + hashSet.insert(std::move(spRight)); + std::multiset set; + set.insert(std::move(spLeft)); + set.insert(std::move(spRight)); +#endif +} + +#ifdef WIL_ENABLE_EXCEPTIONS +TEST_CASE("WindowsInternalTests::SharedHandle", "[resource][shared_any]") +{ + // default construction + wil::shared_handle spHandle; + REQUIRE(spHandle.get() == nullptr); + + // pointer construction + wil::shared_handle spValid(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + auto ptr = spValid.get(); + REQUIRE(spValid.get() != nullptr); + + // null construction + wil::shared_handle spNull = nullptr; + REQUIRE(spNull.get() == nullptr); + + // Present to verify that it doesn't compile (disabled) + // wil::shared_hfile spFile = nullptr; + + // copy construction + wil::shared_handle spCopy = spValid; + REQUIRE(spCopy.get() == ptr); + + // r-value construction + wil::shared_handle spMove = wistd::move(spCopy); + REQUIRE(spMove.get() == ptr); + REQUIRE(spCopy.get() == nullptr); + + // unique handle construction + wil::shared_handle spFromUnique = wil::unique_handle(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + REQUIRE(spFromUnique.get() != nullptr); + + // direct assignment + wil::shared_handle spAssign; + spAssign = spValid; + REQUIRE(spAssign.get() == ptr); + + // empty reset + spFromUnique.reset(); + REQUIRE(spFromUnique.get() == nullptr); + + // reset against unique ptr + spFromUnique.reset(wil::unique_handle(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0))); + REQUIRE(spFromUnique.get() != nullptr); + + // reset against raw pointer + spAssign.reset(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + REQUIRE(spAssign.get() != nullptr); + REQUIRE(spAssign.get() != ptr); + + // ref-count checks + REQUIRE(spAssign.use_count() == 1); + + // bool operator + REQUIRE(spAssign); + spAssign.reset(); + REQUIRE_FALSE(spAssign); + + // swap and compare + wil::shared_handle sp1(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + wil::shared_handle sp2(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + auto ptr1 = sp1.get(); + auto ptr2 = sp2.get(); + sp1.swap(sp2); + REQUIRE(sp1.get() == ptr2); + REQUIRE(sp2.get() == ptr1); + swap(sp1, sp2); + REQUIRE(sp1.get() == ptr1); + REQUIRE(sp2.get() == ptr2); + REQUIRE((ptr1 == ptr2) == (sp1 == sp2)); + REQUIRE((ptr1 != ptr2) == (sp1 != sp2)); + REQUIRE((ptr1 < ptr2) == (sp1 < sp2)); + REQUIRE((ptr1 <= ptr2) == (sp1 <= sp2)); + REQUIRE((ptr1 > ptr2) == (sp1 > sp2)); + REQUIRE((ptr1 >= ptr2) == (sp1 >= sp2)); + + // construction + wil::weak_handle wh; + REQUIRE_FALSE(wh.lock()); + wil::weak_handle wh1 = sp1; + REQUIRE(wh1.lock()); + REQUIRE(wh1.lock().get() == ptr1); + wil::weak_handle wh1copy = wh1; + REQUIRE(wh1copy.lock()); + + // assignment + wh = wh1; + REQUIRE(wh.lock().get() == ptr1); + wh = sp2; + REQUIRE(wh.lock().get() == ptr2); + + // reset + wh.reset(); + REQUIRE_FALSE(wh.lock()); + + // expiration + wh = sp1; + sp1.reset(); + REQUIRE(wh.expired()); + REQUIRE_FALSE(wh.lock()); + + // swap + wh1 = sp1; + wil::weak_handle wh2 = sp2; + ptr1 = sp1.get(); + ptr2 = sp2.get(); + REQUIRE(wh1.lock().get() == ptr1); + REQUIRE(wh2.lock().get() == ptr2); + wh1.swap(wh2); + REQUIRE(wh1.lock().get() == ptr2); + REQUIRE(wh2.lock().get() == ptr1); + swap(wh1, wh2); + REQUIRE(wh1.lock().get() == ptr1); + REQUIRE(wh2.lock().get() == ptr2); + + // address + sp1.reset(::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0)); + REQUIRE(sp1); + &sp1; // frees the pointer... + REQUIRE_FALSE(sp1); + sp2 = sp1; + REQUIRE_FALSE(sp2); + *(&sp1) = ::CreateEventEx(nullptr, nullptr, CREATE_EVENT_INITIAL_SET, 0); + REQUIRE(sp1); + REQUIRE_FALSE(sp2); + + // test stl container use (hash & std::less) + std::unordered_set hashSet; + hashSet.insert(sp1); + hashSet.insert(sp2); + std::set set; + set.insert(sp1); + set.insert(sp2); +} +#endif + +template +void EventTestCommon() +{ + // Constructor tests... + event_t e1; + REQUIRE_FALSE(e1); + event_t e2(::CreateEventEx(nullptr, nullptr, 0, 0)); + REQUIRE(e2); + wil::unique_handle h1(::CreateEventEx(nullptr, nullptr, 0, 0)); + REQUIRE(h1); + event_t e3(h1.release()); + REQUIRE(e3); + REQUIRE_FALSE(h1); + event_t e4(std::move(e2)); + REQUIRE(e4); + REQUIRE_FALSE(e2); + + // inherited address tests... + REQUIRE(e4); + &e4; + REQUIRE_FALSE(e4); + auto hFill = ::CreateEventEx(nullptr, nullptr, 0, 0); + *(&e4) = hFill; + REQUIRE(e4); + REQUIRE(*e4.addressof() == hFill); + REQUIRE(e4); + + // assignment... + event_t e5; + e5 = std::move(e4); + REQUIRE(e5); + REQUIRE_FALSE(e4); + + // various event-based tests + event_t eManual; + eManual.create(wil::EventOptions::ManualReset); + REQUIRE_FALSE(eManual.is_signaled()); + eManual.SetEvent(); + REQUIRE(eManual.is_signaled()); + eManual.ResetEvent(); + REQUIRE_FALSE(eManual.is_signaled()); + { + auto exit = eManual.SetEvent_scope_exit(); + REQUIRE_FALSE(eManual.is_signaled()); + } + REQUIRE(eManual.is_signaled()); + { + auto exit = eManual.ResetEvent_scope_exit(); + REQUIRE(eManual.is_signaled()); + } + REQUIRE_FALSE(eManual.is_signaled()); + REQUIRE_FALSE(eManual.wait(50)); + REQUIRE_FALSE(wil::handle_wait(eManual.get(), 50)); + eManual.SetEvent(); + REQUIRE(eManual.wait(50)); + REQUIRE(wil::handle_wait(eManual.get(), 50)); + + REQUIRE(eManual.wait(50)); + + REQUIRE(eManual.try_create(wil::EventOptions::ManualReset, L"IExist")); + REQUIRE_FALSE(eManual.try_open(L"IDontExist")); +} + +template +void MutexTestCommon() +{ + // Constructor tests... + mutex_t m1; + REQUIRE_FALSE(m1); + mutex_t m2(::CreateMutexEx(nullptr, nullptr, 0, 0)); + REQUIRE(m2); + wil::unique_handle h1(::CreateMutexEx(nullptr, nullptr, 0, 0)); + REQUIRE(h1); + mutex_t m3(h1.release()); + REQUIRE(m3); + REQUIRE_FALSE(h1); + mutex_t m4(std::move(m2)); + REQUIRE(m4); + REQUIRE_FALSE(m2); + + // inherited address tests... + REQUIRE(m4); + &m4; + REQUIRE_FALSE(m4); + auto hFill = ::CreateMutexEx(nullptr, nullptr, 0, 0); + *(&m4) = hFill; + REQUIRE(m4); + REQUIRE(*m4.addressof() == hFill); + REQUIRE(m4); + + // assignment... + mutex_t m5; + m5 = std::move(m4); + REQUIRE(m5); + REQUIRE_FALSE(m4); + + // various mutex-based tests + mutex_t eManual; + eManual.create(nullptr, CREATE_MUTEX_INITIAL_OWNER); + eManual.ReleaseMutex(); + eManual.create(nullptr, CREATE_MUTEX_INITIAL_OWNER); + { + auto release = eManual.ReleaseMutex_scope_exit(); + } + { + DWORD dwStatus; + auto release = eManual.acquire(&dwStatus); + REQUIRE(release); + REQUIRE(dwStatus == WAIT_OBJECT_0); + } + + // pass-through methods -- test compilation; + REQUIRE(eManual.try_create(L"FOO-TEST")); + REQUIRE(eManual.try_open(L"FOO-TEST")); +} + +template +void SemaphoreTestCommon() +{ + // Constructor tests... + semaphore_t m1; + REQUIRE_FALSE(m1); + semaphore_t m2(::CreateSemaphoreEx(nullptr, 1, 1, nullptr, 0, 0)); + REQUIRE(m2); + wil::unique_handle h1(::CreateSemaphoreEx(nullptr, 1, 1, nullptr, 0, 0)); + REQUIRE(h1); + semaphore_t m3(h1.release()); + REQUIRE(m3); + REQUIRE_FALSE(h1); + semaphore_t m4(std::move(m2)); + REQUIRE(m4); + REQUIRE_FALSE(m2); + + // inherited address tests... + REQUIRE(m4); + &m4; + REQUIRE_FALSE(m4); + auto hFill = ::CreateSemaphoreEx(nullptr, 1, 1, nullptr, 0, 0); + *(&m4) = hFill; + REQUIRE(m4); + REQUIRE(*m4.addressof() == hFill); + REQUIRE(m4); + + // assignment... + semaphore_t m5; + m5 = std::move(m4); + REQUIRE(m5); + REQUIRE_FALSE(m4); + + // various semaphore-based tests + semaphore_t eManual; + eManual.create(1, 1); + WaitForSingleObjectEx(eManual.get(), INFINITE, true); + eManual.ReleaseSemaphore(); + eManual.create(1, 1); + WaitForSingleObjectEx(eManual.get(), INFINITE, true); + { + auto release = eManual.ReleaseSemaphore_scope_exit(); + } + { + DWORD dwStatus; + auto release = eManual.acquire(&dwStatus); + REQUIRE(release); + REQUIRE(dwStatus == WAIT_OBJECT_0); + } + + // pass-through methods -- test compilation; + REQUIRE(eManual.try_create(1, 1, L"BAR-TEST")); + REQUIRE(eManual.try_open(L"BAR-TEST")); +} + +TEST_CASE("WindowsInternalTests::HandleWrappers", "[resource][unique_any]") +{ + EventTestCommon(); + EventTestCommon(); + + // intentionally disabled in the non-exception version... + // wil::unique_event_nothrow testEvent2(wil::EventOptions::ManualReset); + wil::unique_event_failfast testEvent3(wil::EventOptions::ManualReset); +#ifdef WIL_ENABLE_EXCEPTIONS + EventTestCommon(); + + wil::unique_event testEvent(wil::EventOptions::ManualReset); + { + REQUIRE_FALSE(wil::event_is_signaled(testEvent.get())); + auto eventSet = wil::SetEvent_scope_exit(testEvent.get()); + REQUIRE_FALSE(wil::event_is_signaled(testEvent.get())); + } + { + REQUIRE(wil::event_is_signaled(testEvent.get())); + auto eventSet = wil::ResetEvent_scope_exit(testEvent.get()); + REQUIRE(wil::event_is_signaled(testEvent.get())); + } + REQUIRE_FALSE(wil::event_is_signaled(testEvent.get())); + REQUIRE_FALSE(wil::handle_wait(testEvent.get(), 0)); + + // Exception-based - no return + testEvent.create(wil::EventOptions::ManualReset); +#endif + + // Error-code based -- returns HR + wil::unique_event_nothrow testEventNoExcept; + REQUIRE(SUCCEEDED(testEventNoExcept.create(wil::EventOptions::ManualReset))); + + + MutexTestCommon(); + MutexTestCommon(); + + // intentionally disabled in the non-exception version... + // wil::unique_mutex_nothrow testMutex2(L"FOO-TEST-2"); + wil::unique_mutex_failfast testMutex3(L"FOO-TEST-3"); +#ifdef WIL_ENABLE_EXCEPTIONS + MutexTestCommon(); + + wil::unique_mutex testMutex(L"FOO-TEST"); + WaitForSingleObjectEx(testMutex.get(), INFINITE, TRUE); + { + auto release = wil::ReleaseMutex_scope_exit(testMutex.get()); + } + + // Exception-based - no return + testMutex.create(nullptr); +#endif + + // Error-code based -- returns HR + wil::unique_mutex_nothrow testMutexNoExcept; + REQUIRE(SUCCEEDED(testMutexNoExcept.create(nullptr))); + + + SemaphoreTestCommon(); + SemaphoreTestCommon(); + + // intentionally disabled in the non-exception version... + // wil::unique_semaphore_nothrow testSemaphore2(1, 1); + wil::unique_semaphore_failfast testSemaphore3(1, 1); +#ifdef WIL_ENABLE_EXCEPTIONS + SemaphoreTestCommon(); + + wil::unique_semaphore testSemaphore(1, 1); + WaitForSingleObjectEx(testSemaphore.get(), INFINITE, true); + { + auto release = wil::ReleaseSemaphore_scope_exit(testSemaphore.get()); + } + + // Exception-based - no return + testSemaphore.create(1, 1); +#endif + + // Error-code based -- returns HR + wil::unique_semaphore_nothrow testSemaphoreNoExcept; + REQUIRE(SUCCEEDED(testSemaphoreNoExcept.create(1, 1))); + + auto unique_cotaskmem_string_failfast1 = wil::make_cotaskmem_string_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_failfast1.get()) == 0); + + auto unique_cotaskmem_string_nothrow1 = wil::make_cotaskmem_string_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_nothrow1.get()) == 0); + + auto unique_cotaskmem_string_nothrow2 = wil::make_cotaskmem_string_nothrow(L""); + REQUIRE(wcscmp(L"", unique_cotaskmem_string_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_cotaskmem_string_te1 = wil::make_cotaskmem_string(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_te1.get()) == 0); + + auto unique_cotaskmem_string_te2 = wil::make_cotaskmem_string(L""); + REQUIRE(wcscmp(L"", unique_cotaskmem_string_te2.get()) == 0); + + auto unique_cotaskmem_string_range1 = wil::make_cotaskmem_string(L"Foo", 2); + REQUIRE(wcscmp(L"Fo", unique_cotaskmem_string_range1.get()) == 0); + + auto unique_cotaskmem_string_range2 = wil::make_cotaskmem_string(nullptr, 2); + unique_cotaskmem_string_range2.get()[0] = L'F'; + unique_cotaskmem_string_range2.get()[1] = L'o'; + REQUIRE(wcscmp(L"Fo", unique_cotaskmem_string_range2.get()) == 0); + + auto unique_cotaskmem_string_range3 = wil::make_cotaskmem_string(nullptr, 0); + REQUIRE(wcscmp(L"", unique_cotaskmem_string_range3.get()) == 0); +#endif + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + { + auto verify = MakeSecureDeleterMallocSpy(); + REQUIRE_SUCCEEDED(::CoRegisterMallocSpy(verify.Get())); + auto removeSpy = wil::scope_exit([&] { ::CoRevokeMallocSpy(); }); + + auto unique_cotaskmem_string_secure_failfast1 = wil::make_cotaskmem_string_secure_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_secure_failfast1.get()) == 0); + + auto unique_cotaskmem_string_secure_nothrow1 = wil::make_cotaskmem_string_secure_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_secure_nothrow1.get()) == 0); + + auto unique_cotaskmem_string_secure_nothrow2 = wil::make_cotaskmem_string_secure_nothrow(L""); + REQUIRE(wcscmp(L"", unique_cotaskmem_string_secure_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_cotaskmem_string_secure_te1 = wil::make_cotaskmem_string_secure(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_cotaskmem_string_secure_te1.get()) == 0); + + auto unique_cotaskmem_string_secure_te2 = wil::make_cotaskmem_string_secure(L""); + REQUIRE(wcscmp(L"", unique_cotaskmem_string_secure_te2.get()) == 0); +#endif + } + + auto unique_hlocal_string_failfast1 = wil::make_hlocal_string_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_failfast1.get()) == 0); + + auto unique_hlocal_string_nothrow1 = wil::make_hlocal_string_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_nothrow1.get()) == 0); + + auto unique_hlocal_string_nothrow2 = wil::make_hlocal_string_nothrow(L""); + REQUIRE(wcscmp(L"", unique_hlocal_string_nothrow2.get()) == 0); + + auto unique_hlocal_ansistring_failfast1 = wil::make_hlocal_ansistring_failfast("Foo"); + REQUIRE(strcmp("Foo", unique_hlocal_ansistring_failfast1.get()) == 0); + + auto unique_hlocal_ansistring_nothrow1 = wil::make_hlocal_ansistring_nothrow("Foo"); + REQUIRE(strcmp("Foo", unique_hlocal_ansistring_nothrow1.get()) == 0); + + auto unique_hlocal_ansistring_nothrow2 = wil::make_hlocal_ansistring_nothrow(""); + REQUIRE(strcmp("", unique_hlocal_ansistring_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_hlocal_string_te1 = wil::make_hlocal_string(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_te1.get()) == 0); + + auto unique_hlocal_string_te2 = wil::make_hlocal_string(L""); + REQUIRE(wcscmp(L"", unique_hlocal_string_te2.get()) == 0); + + auto unique_hlocal_string_range1 = wil::make_hlocal_string(L"Foo", 2); + REQUIRE(wcscmp(L"Fo", unique_hlocal_string_range1.get()) == 0); + + auto unique_hlocal_string_range2 = wil::make_hlocal_string(nullptr, 2); + unique_hlocal_string_range2.get()[0] = L'F'; + unique_hlocal_string_range2.get()[1] = L'o'; + REQUIRE(wcscmp(L"Fo", unique_hlocal_string_range2.get()) == 0); + + auto unique_hlocal_string_range3 = wil::make_hlocal_string(nullptr, 0); + REQUIRE(wcscmp(L"", unique_hlocal_string_range3.get()) == 0); + + auto unique_hlocal_ansistring_te1 = wil::make_hlocal_ansistring("Foo"); + REQUIRE(strcmp("Foo", unique_hlocal_ansistring_te1.get()) == 0); + + auto unique_hlocal_ansistring_te2 = wil::make_hlocal_ansistring(""); + REQUIRE(strcmp("", unique_hlocal_ansistring_te2.get()) == 0); + + auto unique_hlocal_ansistring_range1 = wil::make_hlocal_ansistring("Foo", 2); + REQUIRE(strcmp("Fo", unique_hlocal_ansistring_range1.get()) == 0); + + auto unique_hlocal_ansistring_range2 = wil::make_hlocal_ansistring(nullptr, 2); + unique_hlocal_ansistring_range2.get()[0] = L'F'; + unique_hlocal_ansistring_range2.get()[1] = L'o'; + REQUIRE(strcmp("Fo", unique_hlocal_ansistring_range2.get()) == 0); + + auto unique_hlocal_ansistring_range3 = wil::make_hlocal_ansistring(nullptr, 0); + REQUIRE(strcmp("", unique_hlocal_ansistring_range3.get()) == 0); +#endif + + { + auto verify = MakeSecureDeleterMallocSpy(); + REQUIRE_SUCCEEDED(::CoRegisterMallocSpy(verify.Get())); + auto removeSpy = wil::scope_exit([&] { ::CoRevokeMallocSpy(); }); + + auto unique_hlocal_string_secure_failfast1 = wil::make_hlocal_string_secure_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_secure_failfast1.get()) == 0); + + auto unique_hlocal_string_secure_nothrow1 = wil::make_hlocal_string_secure_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_secure_nothrow1.get()) == 0); + + auto unique_hlocal_string_secure_nothrow2 = wil::make_hlocal_string_secure_nothrow(L""); + REQUIRE(wcscmp(L"", unique_hlocal_string_secure_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_hlocal_string_secure_te1 = wil::make_hlocal_string_secure(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_hlocal_string_secure_te1.get()) == 0); + + auto unique_hlocal_string_secure_te2 = wil::make_hlocal_string_secure(L""); + REQUIRE(wcscmp(L"", unique_hlocal_string_secure_te2.get()) == 0); +#endif + } + + auto unique_process_heap_string_failfast1 = wil::make_process_heap_string_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_process_heap_string_failfast1.get()) == 0); + + auto unique_process_heap_string_nothrow1 = wil::make_process_heap_string_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_process_heap_string_nothrow1.get()) == 0); + + auto unique_process_heap_string_nothrow2 = wil::make_process_heap_string_nothrow(L""); + REQUIRE(wcscmp(L"", unique_process_heap_string_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_process_heap_string_te1 = wil::make_process_heap_string(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_process_heap_string_te1.get()) == 0); + + auto unique_process_heap_string_te2 = wil::make_process_heap_string(L""); + REQUIRE(wcscmp(L"", unique_process_heap_string_te2.get()) == 0); + + auto unique_process_heap_string_range1 = wil::make_process_heap_string(L"Foo", 2); + REQUIRE(wcscmp(L"Fo", unique_process_heap_string_range1.get()) == 0); + + auto unique_process_heap_string_range2 = wil::make_process_heap_string(nullptr, 2); + unique_process_heap_string_range2.get()[0] = L'F'; + unique_process_heap_string_range2.get()[1] = L'o'; + REQUIRE(wcscmp(L"Fo", unique_process_heap_string_range2.get()) == 0); + + auto unique_process_heap_string_range3 = wil::make_process_heap_string(nullptr, 0); + REQUIRE(wcscmp(L"", unique_process_heap_string_range3.get()) == 0); +#endif + +#endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) */ + + auto unique_bstr_failfast1 = wil::make_bstr_failfast(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_bstr_failfast1.get()) == 0); + + auto unique_bstr_nothrow1 = wil::make_bstr_nothrow(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_bstr_nothrow1.get()) == 0); + + auto unique_bstr_nothrow2 = wil::make_bstr_nothrow(L""); + REQUIRE(wcscmp(L"", unique_bstr_nothrow2.get()) == 0); + +#ifdef WIL_ENABLE_EXCEPTIONS + auto unique_bstr_te1 = wil::make_bstr(L"Foo"); + REQUIRE(wcscmp(L"Foo", unique_bstr_te1.get()) == 0); + + auto unique_bstr_te2 = wil::make_bstr(L""); + REQUIRE(wcscmp(L"", unique_bstr_te2.get()) == 0); + + + auto testString = wil::make_cotaskmem_string(L"Foo"); + { + auto cleanupMemory = wil::SecureZeroMemory_scope_exit(testString.get()); + } + REQUIRE(0 == testString.get()[0]); + + auto testString2 = wil::make_cotaskmem_string(L"Bar"); + { + auto cleanupMemory = wil::SecureZeroMemory_scope_exit(testString2.get(), wcslen(testString2.get()) * sizeof(testString2.get()[0])); + } + REQUIRE(0 == testString2.get()[0]); +#endif +} + +TEST_CASE("WindowsInternalTests::Locking", "[resource]") +{ + { + SRWLOCK rwlock = SRWLOCK_INIT; + { + auto lock = wil::AcquireSRWLockExclusive(&rwlock); + REQUIRE(lock); + + auto lockRecursive = wil::TryAcquireSRWLockExclusive(&rwlock); + REQUIRE_FALSE(lockRecursive); + + auto lockRecursiveShared = wil::TryAcquireSRWLockShared(&rwlock); + REQUIRE_FALSE(lockRecursiveShared); + } + { + auto lock = wil::AcquireSRWLockShared(&rwlock); + REQUIRE(lock); + + auto lockRecursive = wil::TryAcquireSRWLockShared(&rwlock); + REQUIRE(lockRecursive); + + auto lockRecursiveExclusive = wil::TryAcquireSRWLockExclusive(&rwlock); + REQUIRE_FALSE(lockRecursiveExclusive); + } + { + auto lock = wil::TryAcquireSRWLockExclusive(&rwlock); + REQUIRE(lock); + } + { + auto lock = wil::TryAcquireSRWLockShared(&rwlock); + REQUIRE(lock); + } + } + + { + wil::srwlock rwlock; + { + auto lock = rwlock.lock_exclusive(); + REQUIRE(lock); + + auto lockRecursive = rwlock.try_lock_exclusive(); + REQUIRE_FALSE(lockRecursive); + + auto lockRecursiveShared = rwlock.try_lock_shared(); + REQUIRE_FALSE(lockRecursiveShared); + } + { + auto lock = rwlock.lock_shared(); + REQUIRE(lock); + + auto lockRecursive = rwlock.try_lock_shared(); + REQUIRE(lockRecursive); + + auto lockRecursiveExclusive = rwlock.try_lock_exclusive(); + REQUIRE_FALSE(lockRecursiveExclusive); + } + { + auto lock = rwlock.try_lock_exclusive(); + REQUIRE(lock); + } + { + auto lock = rwlock.try_lock_shared(); + REQUIRE(lock); + } + } + + { + CRITICAL_SECTION cs; + ::InitializeCriticalSectionEx(&cs, 0, 0); + auto lock = wil::EnterCriticalSection(&cs); + REQUIRE(lock); + auto tryLock = wil::TryEnterCriticalSection(&cs); + REQUIRE(tryLock); + } + { + wil::critical_section cs; + auto lock = cs.lock(); + REQUIRE(lock); + auto tryLock = cs.try_lock(); + REQUIRE(tryLock); + } +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +TEST_CASE("WindowsInternalTests::GDIWrappers", "[resource]") +{ + { + auto dc = wil::GetDC(::GetDesktopWindow()); + } + { + auto dc = wil::GetWindowDC(::GetDesktopWindow()); + } + { + auto dc = wil::BeginPaint(::GetDesktopWindow()); + wil::unique_hbrush brush(::CreateSolidBrush(0xffffff)); + auto select = wil::SelectObject(dc.get(), brush.get()); + } +} +#endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) */ + +void TestOutHandle(_Out_ HANDLE *pHandle) +{ + *pHandle = nullptr; +} + +void TestOutAlloc(_Out_ int **ppInt) +{ + *ppInt = new int(5); +} + +void TestCoTask(_Outptr_result_buffer_(*charCount) PWSTR *ppsz, size_t *charCount) +{ + *charCount = 0; + PWSTR psz = static_cast(::CoTaskMemAlloc(10)); + if (psz != nullptr) + { + *charCount = 5; + *psz = L'\0'; + } + *ppsz = psz; +} + +void TestVoid(_Out_ void **ppv) +{ + *ppv = nullptr; +} + +void TestByte(_Out_ BYTE **ppByte) +{ + *ppByte = nullptr; +} + +struct my_deleter +{ + template + void operator()(T* p) const + { + delete p; + } +}; + +TEST_CASE("WindowsInternalTests::WistdTests", "[resource][wistd]") +{ + wil::unique_handle spHandle; + TestOutHandle(wil::out_param(spHandle)); + + wistd::unique_ptr spInt; + TestOutAlloc(wil::out_param(spInt)); + + std::unique_ptr spIntStd; + TestOutAlloc(wil::out_param(spIntStd)); + + wil::unique_cotaskmem_string spsz0; + size_t count; + TestCoTask(wil::out_param(spsz0), &count); + + std::unique_ptr spsz1; + TestCoTask(wil::out_param(spsz1), &count); + + wistd::unique_ptr spsz2; + TestCoTask(wil::out_param(spsz2), &count); + + wil::unique_cotaskmem_ptr spsz3; + TestCoTask(wil::out_param(spsz3), &count); + + wil::unique_cotaskmem_ptr spv; + TestVoid(wil::out_param(spv)); + + std::unique_ptr spIntStd2; + TestByte(wil::out_param_ptr(spIntStd2)); + + struct Nothing + { + int n; + Nothing(int param) : n(param) {} + void Method() {} + }; + + auto spff = wil::make_unique_failfast(3); + auto sp = wil::make_unique_nothrow(3); + REQUIRE(sp); +#ifdef WIL_ENABLE_EXCEPTIONS + THROW_IF_NULL_ALLOC(sp.get()); + THROW_IF_NULL_ALLOC(sp); +#endif + sp->Method(); + decltype(sp) sp2; + sp2 = wistd::move(sp); + sp2.get(); + + wistd::unique_ptr spConstruct; + wistd::unique_ptr spConstruct2 = nullptr; + spConstruct = nullptr; + wistd::unique_ptr spConstruct3(new int(3)); + my_deleter d; + wistd::unique_ptr spConstruct4(new int(4), d); + wistd::unique_ptr spConstruct5(new int(5), my_deleter()); + wistd::unique_ptr spConstruct6(wistd::unique_ptr(new int(6))); + spConstruct = std::move(spConstruct2); + spConstruct.swap(spConstruct2); + REQUIRE(*spConstruct4 == 4); + spConstruct4.get(); + if (spConstruct4) + { + } + spConstruct.reset(); + spConstruct.release(); + + auto spTooBig = wil::make_unique_nothrow(static_cast(-1)); + REQUIRE_FALSE(spTooBig); + // REQUIRE_FAILFAST_UNSPECIFIED([]{ auto spTooBigFF = wil::make_unique_failfast(static_cast(-1)); }); + + object_counter_state state; + count = 0; + { + object_counter c{ state }; + REQUIRE(state.instance_count() == 1); + + wistd::function fn = [&count, c](int param) + { + count += param; + }; + REQUIRE(state.instance_count() == 2); + + fn(3); + REQUIRE(count == 3); + } + REQUIRE(state.instance_count() == 0); + + count = 0; + { + wistd::function fn; + { + object_counter c{ state }; + REQUIRE(state.instance_count() == 1); + fn = [&count, c](int param) + { + count += param; + }; + REQUIRE(state.instance_count() == 2); + } + REQUIRE(state.instance_count() == 1); + fn(3); + REQUIRE(count == 3); + } + + { + // Size Check -- the current implementation allows for 10 pointers to be passed through the lambda + int a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12; + (void)a11; (void)a12; + + wistd::function fn = [&a1, &a2, &a3, &a4, &a5, &a6, &a7, &a8, &a9, &a10]() + { + (void)a1; (void)a2; (void)a3; (void)a4; (void)a5; (void)a6; (void)a7; (void)a8; (void)a9; (void)a10; + }; + auto fnCopy = fn; + + // Uncomment to double-check static assert. Reports: + // "The sizeof(wistd::function) has grown too large for the reserved buffer (10 pointers). Refactor to reduce size of the capture." + // wistd::function fn2 = [&a1, &a2, &a3, &a4, &a5, &a6, &a7, &a8, &a9, &a10, &a11]() + // { + // a1; a2; a3; a4; a5; a6; a7; a8; a9; a10; a11; + // }; + } +} + +template +void NullptrRaiiTests(lambda_t const &fnCreate) +{ + // nullptr_t construct + test_t var1 = nullptr; // implicit + REQUIRE_FALSE(var1); + test_t var2(nullptr); // explicit + REQUIRE_FALSE(var2); + + // nullptr_t assingment + var1.reset(fnCreate()); + REQUIRE(var1); + var1 = nullptr; + REQUIRE_FALSE(var1); + + // nullptr_t reset + var1.reset(fnCreate()); + REQUIRE(var1); + var1.reset(nullptr); + REQUIRE_FALSE(var1); +} + +template +void ReleaseRaiiTests(lambda_t const &fnCreate) +{ + test_t var1(fnCreate()); + REQUIRE(var1); + auto ptr = var1.release(); + REQUIRE_FALSE(var1); + REQUIRE(ptr != test_t::policy::invalid_value()); + REQUIRE(var1.get() == test_t::policy::invalid_value()); + + var1.reset(ptr); +} + +template +void GetRaiiTests(lambda_t const &fnCreate) +{ + test_t var1; + REQUIRE_FALSE(var1); + REQUIRE(var1.get() == test_t::policy::invalid_value()); + + var1.reset(fnCreate()); + REQUIRE(var1); + REQUIRE(var1.get() != test_t::policy::invalid_value()); +} + +template +void SharedRaiiTests(lambda_t const &fnCreate) +{ + // copy construction + test_t var1(fnCreate()); + REQUIRE(var1); + test_t var2 = var1; // implicit + REQUIRE(var1); + REQUIRE(var2); + test_t var3(var1); // explicit + + // copy assignment + test_t var4(fnCreate()); + test_t var5; + var5 = var4; + REQUIRE(var5); + REQUIRE(var4); + + // r-value construction from unique_ptr + typename test_t::unique_t unique1(fnCreate()); + test_t var7(std::move(unique1)); // explicit + REQUIRE(var7); + REQUIRE_FALSE(unique1); + typename test_t::unique_t unique2(fnCreate()); + test_t var8 = std::move(unique2); // implicit + REQUIRE(var8); + REQUIRE_FALSE(unique2); + + // r-value assignment from unique_ptr + var8.reset(); + REQUIRE_FALSE(var8); + unique2.reset(fnCreate()); + var8 = std::move(unique2); + REQUIRE(var8); + REQUIRE_FALSE(unique2); + + // use_count() + REQUIRE(var8.use_count() == 1); + auto var9 = var8; + REQUIRE(var8.use_count() == 2); +} + +template +void WeakRaiiTests(lambda_t const &fnCreate) +{ + typedef typename test_t::shared_t shared_type; + + // base constructor + test_t weak1; + + // construct from shared + shared_type shared1(fnCreate()); + test_t weak2 = shared1; // implicit + test_t weak3(shared1); // explicit + + // construct from weak + test_t weak4 = weak2; // implicit + test_t weak5(weak2); // explicit + + // assign from weak + weak2 = weak5; + + // assign from shared + weak2 = shared1; + + // reset + weak2.reset(); + REQUIRE_FALSE(weak2.lock()); + + // swap + test_t swap1 = shared1; + test_t swap2; + REQUIRE(swap1.lock()); + REQUIRE_FALSE(swap2.lock()); + swap1.swap(swap2); + REQUIRE_FALSE(swap1.lock()); + REQUIRE(swap2.lock()); + + // expired + REQUIRE_FALSE(swap2.expired()); + shared1.reset(); + REQUIRE(swap2.expired()); + + // lock + shared1.reset(fnCreate()); + weak1 = shared1; + auto shared2 = weak1.lock(); + REQUIRE(shared2); + shared2.reset(); + REQUIRE(weak1.lock()); + shared1.reset(); + shared2 = weak1.lock(); + REQUIRE_FALSE(shared2); +} + +template +void AddressRaiiTests(lambda_t const &fnCreate) +{ + test_t var1(fnCreate()); + REQUIRE(var1); + + &var1; + REQUIRE_FALSE(var1); // the address operator does an auto-release + + *(&var1) = fnCreate(); + REQUIRE(var1); + + REQUIRE(var1.addressof() != nullptr); + REQUIRE(var1); // verify that 'addressof()' does not auto-release +} + +template +void BasicRaiiTests(lambda_t const &fnCreate) +{ + auto invalidHandle = test_t::policy::invalid_value(); + + // no-constructor construction + test_t var1; + REQUIRE_FALSE(var1); + + // construct from a given resource + test_t var2(fnCreate()); // r-value + REQUIRE(var2); + test_t var3(invalidHandle); // l-value + REQUIRE_FALSE(var3); + + // r-value construct from the same type + test_t var4(std::move(var2)); // explicit + REQUIRE(var4); + REQUIRE_FALSE(var2); + test_t varMove(fnCreate()); + test_t var4implicit = std::move(varMove); // implicit + REQUIRE(var4implicit); + + // move assignment + var2 = std::move(var4); + REQUIRE(var2); + REQUIRE_FALSE(var4); + + // swap + var2.swap(var4); + REQUIRE(var4); + REQUIRE_FALSE(var2); + + // explicit bool cast + REQUIRE(static_cast(var4)); + REQUIRE_FALSE(static_cast(var2)); + + // reset + var4.reset(); + REQUIRE_FALSE(var4); + var4.reset(fnCreate()); // r-value + REQUIRE(var4); + var4.reset(invalidHandle); // l-value + REQUIRE_FALSE(var4); +} + +template +void EventRaiiTests() +{ + test_t var1; + var1.create(wil::EventOptions::ManualReset); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + + // SetEvent/ResetEvent + var1.SetEvent(); + REQUIRE(wil::event_is_signaled(var1.get())); + var1.ResetEvent(); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + + // SetEvent/ResetEvent scope_exit + { + auto exit = var1.SetEvent_scope_exit(); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + } + REQUIRE(wil::event_is_signaled(var1.get())); + { + auto exit = var1.ResetEvent_scope_exit(); + REQUIRE(wil::event_is_signaled(var1.get())); + } + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + + // is_signaled + REQUIRE_FALSE(var1.is_signaled()); + + // wait + REQUIRE_FALSE(var1.wait(50)); + + // try_create + bool exists = false; + REQUIRE(var1.try_create(wil::EventOptions::ManualReset, L"wiltestevent", nullptr, &exists)); + REQUIRE_FALSE(exists); + test_t var2; + REQUIRE(var2.try_create(wil::EventOptions::ManualReset, L"wiltestevent", nullptr, &exists)); + REQUIRE(exists); + test_t var3; + REQUIRE_FALSE(var3.try_create(wil::EventOptions::ManualReset, L"\\illegal\\chars\\too\\\\many\\\\namespaces", nullptr, &exists)); + REQUIRE(::GetLastError() != ERROR_SUCCESS); + + // try_open + test_t var4; + REQUIRE_FALSE(var4.try_open(L"\\illegal\\chars\\too\\\\many\\\\namespaces")); + REQUIRE(::GetLastError() != ERROR_SUCCESS); + REQUIRE(var4.try_open(L"wiltestevent")); +} + +void EventTests() +{ + static_assert(sizeof(wil::unique_event_nothrow) == sizeof(HANDLE), "event_t should be sizeof(HANDLE) to allow for raw array utilization"); + + auto fnCreate = []() { return CreateEventEx(nullptr, nullptr, CREATE_EVENT_MANUAL_RESET, 0); }; + + BasicRaiiTests(fnCreate); + NullptrRaiiTests(fnCreate); + GetRaiiTests(fnCreate); + ReleaseRaiiTests(fnCreate); + AddressRaiiTests(fnCreate); + EventRaiiTests(); + + BasicRaiiTests(fnCreate); + NullptrRaiiTests(fnCreate); + GetRaiiTests(fnCreate); + ReleaseRaiiTests(fnCreate); + AddressRaiiTests(fnCreate); + EventRaiiTests(); + + wil::unique_event_nothrow event4; + REQUIRE(S_OK == event4.create(wil::EventOptions::ManualReset)); + REQUIRE(FAILED(event4.create(wil::EventOptions::ManualReset, L"\\illegal\\chars\\too\\\\many\\\\namespaces"))); + +#ifdef WIL_ENABLE_EXCEPTIONS + static_assert(sizeof(wil::unique_event) == sizeof(HANDLE), "event_t should be sizeof(HANDLE) to allow for raw array utilization"); + + BasicRaiiTests(fnCreate); + NullptrRaiiTests(fnCreate); + GetRaiiTests(fnCreate); + ReleaseRaiiTests(fnCreate); + AddressRaiiTests(fnCreate); + EventRaiiTests(); + + BasicRaiiTests(fnCreate); + NullptrRaiiTests(fnCreate); + GetRaiiTests(fnCreate); + AddressRaiiTests(fnCreate); + SharedRaiiTests(fnCreate); + EventRaiiTests(); + + WeakRaiiTests(fnCreate); + + // explicitly disabled + // wil::unique_event_nothrow event1(wil::EventOptions::ManualReset); + wil::unique_event event2(wil::EventOptions::ManualReset); + wil::shared_event event3(wil::EventOptions::ManualReset); + + event2.create(wil::EventOptions::ManualReset); + REQUIRE(event2); + event3.create(wil::EventOptions::ManualReset); + REQUIRE(event3); + REQUIRE_THROWS(event2.create(wil::EventOptions::ManualReset, L"\\illegal\\chars\\too\\\\many\\\\namespaces") ); + REQUIRE_THROWS(event3.create(wil::EventOptions::ManualReset, L"\\illegal\\chars\\too\\\\many\\\\namespaces") ); + + wil::unique_event var1(wil::EventOptions::ManualReset); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + { + auto autoset = wil::SetEvent_scope_exit(var1.get()); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + REQUIRE(autoset.get() == var1.get()); + // &autoset; // verified disabled + // autoset.addressof(); // verified disabled + } + REQUIRE(wil::event_is_signaled(var1.get())); + { + auto autoreset = wil::ResetEvent_scope_exit(var1.get()); + REQUIRE(wil::event_is_signaled(var1.get())); + autoreset.reset(); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + } + { + auto autoset = wil::SetEvent_scope_exit(var1.get()); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + autoset.release(); + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); + } + REQUIRE_FALSE(wil::event_is_signaled(var1.get())); +#endif +} + +typedef wil::unique_struct unique_prop_variant_no_init; + +void SetPropVariantValue(_In_ int intVal, _Out_ PROPVARIANT* ppropvar) +{ + ppropvar->intVal = intVal; + ppropvar->vt = VT_INT; +} + +template +void TestUniquePropVariant() +{ + { + wil::unique_prop_variant spPropVariant; + REQUIRE(spPropVariant.vt == VT_EMPTY); + } + + // constructor test + { + PROPVARIANT propVariant; + SetPropVariantValue(12, &propVariant); + T spPropVariant(propVariant); + REQUIRE(((spPropVariant.intVal == 12) && (spPropVariant.vt == VT_INT))); + + T spPropVariant2(wistd::move(propVariant)); + REQUIRE(((spPropVariant2.intVal == 12) && (spPropVariant2.vt == VT_INT))); + + //spPropVariant = propVariant; // deleted function + //spPropVariant = wistd::move(propVariant); // deleted function + //spPropVariant.swap(propVariant); //deleted function + } + + // move constructor + { + T spPropVariant; + SetPropVariantValue(12, &spPropVariant); + REQUIRE(((spPropVariant.intVal == 12) && (spPropVariant.vt == VT_INT))); + + T spPropVariant2(wistd::move(spPropVariant)); + REQUIRE(spPropVariant.vt == VT_EMPTY); + REQUIRE(((spPropVariant2.intVal == 12) && (spPropVariant2.vt == VT_INT))); + + //T spPropVariant3(spPropVariant); // deleted function + //spPropVariant2 = spPropVariant; // deleted function + } + + // move operator + { + T spPropVariant; + SetPropVariantValue(12, &spPropVariant); + T spPropVariant2 = wistd::move(spPropVariant); + REQUIRE(spPropVariant.vt == VT_EMPTY); + REQUIRE(((spPropVariant2.intVal == 12) && (spPropVariant2.vt == VT_INT))); + } + + // reset + { + PROPVARIANT propVariant; + SetPropVariantValue(22, &propVariant); + T spPropVariant; + SetPropVariantValue(12, &spPropVariant); + T spPropVariant2; + + //spPropVariant2.reset(spPropVariant); // deleted function + spPropVariant.reset(propVariant); + REQUIRE(spPropVariant.intVal == 22); + REQUIRE(propVariant.intVal == 22); + + spPropVariant.reset(); + REQUIRE(spPropVariant.vt == VT_EMPTY); + } + + // swap + { + T spPropVariant; + SetPropVariantValue(12, &spPropVariant); + T spPropVariant2; + SetPropVariantValue(22, &spPropVariant2); + + spPropVariant.swap(spPropVariant2); + REQUIRE(spPropVariant.intVal == 22); + REQUIRE(spPropVariant2.intVal == 12); + } + + // release, addressof, reset_and_addressof + { + T spPropVariant; + SetPropVariantValue(12, &spPropVariant); + + [](PROPVARIANT* propVariant) + { + REQUIRE(propVariant->vt == VT_EMPTY); + }(spPropVariant.reset_and_addressof()); + + SetPropVariantValue(12, &spPropVariant); + PROPVARIANT* pPropVariant = spPropVariant.addressof(); + REQUIRE(pPropVariant->intVal == 12); + REQUIRE(spPropVariant.intVal == 12); + + PROPVARIANT propVariant = spPropVariant.release(); + REQUIRE(propVariant.intVal == 12); + REQUIRE(spPropVariant.vt == VT_EMPTY); + } +} + +TEST_CASE("WindowsInternalTests::ResourceTemplateTests", "[resource]") +{ + EventTests(); + TestUniquePropVariant(); + TestUniquePropVariant(); +} + +inline unsigned long long ToInt64(const FILETIME &ft) +{ + return (static_cast(ft.dwHighDateTime) << 32) + ft.dwLowDateTime; +} + +inline FILETIME FromInt64(unsigned long long i64) +{ + FILETIME ft = { static_cast(i64), static_cast(i64 >> 32) }; + return ft; +} + +TEST_CASE("WindowsInternalTests::Win32HelperTests", "[win32_helpers]") +{ + auto systemTime = wil::filetime::get_system_time(); + REQUIRE(ToInt64(systemTime) == wil::filetime::to_int64(systemTime)); + auto systemTime64 = wil::filetime::to_int64(systemTime); +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + auto ft1 = FromInt64(systemTime64); + auto ft2 = wil::filetime::from_int64(systemTime64); + REQUIRE(CompareFileTime(&ft1, &ft2) == 0); +#endif /* WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) */ + + REQUIRE(systemTime64 == wil::filetime::to_int64(wil::filetime::from_int64(systemTime64))); + REQUIRE((systemTime64 + wil::filetime_duration::one_hour) == (systemTime64 + (wil::filetime_duration::one_minute * 60))); + auto systemTimePlusOneHour = wil::filetime::add(systemTime, wil::filetime_duration::one_hour); + auto systemTimePlusOneHour64 = wil::filetime::to_int64(systemTimePlusOneHour); + REQUIRE(systemTimePlusOneHour64 == (systemTime64 + wil::filetime_duration::one_hour)); +} + +TEST_CASE("WindowsInternalTests::InitOnceNonTests") +{ + bool called = false; + bool winner = false; + INIT_ONCE init{}; + REQUIRE_FALSE(wil::init_once_initialized(init)); + + // Call, but fail. Should transport the HRESULT back, but mark us as not the winner + called = false; + winner = false; + REQUIRE(E_FAIL == wil::init_once_nothrow(init, [&] { called = true; return E_FAIL; }, &winner)); + REQUIRE_FALSE(wil::init_once_initialized(init)); + REQUIRE(called); + REQUIRE_FALSE(winner); + + // Call, succeed. Should mark us as the winner. + called = false; + winner = false; + REQUIRE_SUCCEEDED(wil::init_once_nothrow(init, [&] { called = true; return S_OK; }, &winner)); + REQUIRE(wil::init_once_initialized(init)); + REQUIRE(called); + REQUIRE(winner); + + // Call again. Should not actually be invoked and should not be the winner + called = false; + winner = false; + REQUIRE_SUCCEEDED(wil::init_once_nothrow(init, [&] { called = false; return S_OK; }, &winner)); + REQUIRE(wil::init_once_initialized(init)); + REQUIRE_FALSE(called); + REQUIRE_FALSE(winner); + + // Call again. Still not invoked, but we don't care if we're the winner + called = false; + REQUIRE_SUCCEEDED(wil::init_once_nothrow(init, [&] { called = false; return S_OK; })); + REQUIRE(wil::init_once_initialized(init)); + REQUIRE_FALSE(called); + +#ifdef WIL_ENABLE_EXCEPTIONS + called = false; + winner = false; + init = {}; + + // A thrown exception leaves the object un-initialized + REQUIRE_THROWS_AS(winner = wil::init_once(init, [&] { called = true; throw wil::ResultException(E_FAIL); }), wil::ResultException); + REQUIRE_FALSE(wil::init_once_initialized(init)); + REQUIRE(called); + REQUIRE_FALSE(winner); + + // Success! + called = false; + winner = false; + REQUIRE_NOTHROW(winner = wil::init_once(init, [&] { called = true; })); + REQUIRE(wil::init_once_initialized(init)); + REQUIRE(called); + REQUIRE(winner); + + // No-op success! + called = false; + winner = false; + REQUIRE_NOTHROW(winner = wil::init_once(init, [&] { called = true; })); + REQUIRE(wil::init_once_initialized(init)); + REQUIRE_FALSE(called); + REQUIRE_FALSE(winner); +#endif // WIL_ENABLE_EXCEPTIONS +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +TEST_CASE("WindowsInternalTests::TestUniquePointerCases", "[resource][unique_any]") +{ + // wil::unique_process_heap_ptr tests + { + wil::unique_process_heap_ptr empty; // null case + } + { + wil::unique_process_heap_ptr heapMemory(::HeapAlloc(::GetProcessHeap(), 0, 100)); + REQUIRE(static_cast(heapMemory)); + } + + // wil::unique_cotaskmem_ptr tests + { + wil::unique_cotaskmem_ptr empty; // null case + } + { + wil::unique_cotaskmem_ptr cotaskmemMemory(CoTaskMemAlloc(100)); + REQUIRE(static_cast(cotaskmemMemory)); + } + { + auto cotaskmemMemory = wil::make_unique_cotaskmem_nothrow(42); + REQUIRE(static_cast(cotaskmemMemory)); + REQUIRE(*cotaskmemMemory == static_cast(42)); + } + { + struct S { size_t s; S() : s(42) {} }; + auto cotaskmemMemory = wil::make_unique_cotaskmem_nothrow(); + REQUIRE(static_cast(cotaskmemMemory)); + REQUIRE(cotaskmemMemory->s == static_cast(42)); + } + { + auto cotaskmemArrayMemory = wil::make_unique_cotaskmem_nothrow(12); + REQUIRE(static_cast(cotaskmemArrayMemory)); + } + { + struct S { size_t s; S() : s(42) {} }; + const size_t size = 12; + auto cotaskmemArrayMemory = wil::make_unique_cotaskmem_nothrow(size); + REQUIRE(static_cast(cotaskmemArrayMemory)); + bool verified = true; + for (auto& elem : wil::make_range(cotaskmemArrayMemory.get(), size)) if (elem.s != 42) verified = false; + REQUIRE(verified); + } + + // wil::unique_cotaskmem_secure_ptr tests + { + wil::unique_cotaskmem_secure_ptr empty; // null case + } + { + wil::unique_cotaskmem_secure_ptr cotaskmemMemory(CoTaskMemAlloc(100)); + REQUIRE(static_cast(cotaskmemMemory)); + } + { + auto cotaskmemMemory = wil::make_unique_cotaskmem_secure_nothrow(42); + REQUIRE(static_cast(cotaskmemMemory)); + REQUIRE(*cotaskmemMemory == static_cast(42)); + } + { + struct S { size_t s; S() : s(42) {} }; + auto cotaskmemMemory = wil::make_unique_cotaskmem_secure_nothrow(); + REQUIRE(static_cast(cotaskmemMemory)); + REQUIRE(cotaskmemMemory->s == static_cast(42)); + } + { + auto cotaskmemArrayMemory = wil::make_unique_cotaskmem_secure_nothrow(12); + REQUIRE(static_cast(cotaskmemArrayMemory)); + } + { + struct S { size_t s; S() : s(42) {} }; + const size_t size = 12; + auto cotaskmemArrayMemory = wil::make_unique_cotaskmem_secure_nothrow(size); + REQUIRE(static_cast(cotaskmemArrayMemory)); + bool verified = true; + for (auto& elem : wil::make_range(cotaskmemArrayMemory.get(), size)) if (elem.s != 42) verified = false; + REQUIRE(verified); + } + + // wil::unique_hlocal_ptr tests + { + wil::unique_hlocal_ptr empty; // null case + } + { + wil::unique_hlocal_ptr localMemory(LocalAlloc(LPTR, 100)); + REQUIRE(static_cast(localMemory)); + } + { + auto localMemory = wil::make_unique_hlocal_nothrow(42); + REQUIRE(static_cast(localMemory)); + REQUIRE(*localMemory == static_cast(42)); + } + { + struct S { size_t s; S() : s(42) {} }; + auto localMemory = wil::make_unique_hlocal_nothrow(); + REQUIRE(static_cast(localMemory)); + REQUIRE(localMemory->s == static_cast(42)); + } + { + auto localArrayMemory = wil::make_unique_hlocal_nothrow(12); + REQUIRE(static_cast(localArrayMemory)); + } + { + struct S { size_t s; S() : s(42) {} }; + const size_t size = 12; + auto localArrayMemory = wil::make_unique_hlocal_nothrow(size); + REQUIRE(static_cast(localArrayMemory)); + bool verified = true; + for (auto& elem : wil::make_range(localArrayMemory.get(), size)) if (elem.s != 42) verified = false; + REQUIRE(verified); + } + + // wil::unique_hlocal_secure_ptr tests + { + wil::unique_hlocal_secure_ptr empty; // null case + } + { + wil::unique_hlocal_secure_ptr localMemory(LocalAlloc(LPTR, 100)); + REQUIRE(static_cast(localMemory)); + } + { + auto localMemory = wil::make_unique_hlocal_secure_nothrow(42); + REQUIRE(static_cast(localMemory)); + REQUIRE(*localMemory == static_cast(42)); + } + { + struct S { size_t s; S() : s(42) {} }; + auto localMemory = wil::make_unique_hlocal_secure_nothrow(); + REQUIRE(static_cast(localMemory)); + REQUIRE(localMemory->s == static_cast(42)); + } + { + auto localArrayMemory = wil::make_unique_hlocal_secure_nothrow(12); + REQUIRE(static_cast(localArrayMemory)); + } + { + struct S { size_t s; S() : s(42) {} }; + const size_t size = 12; + auto localArrayMemory = wil::make_unique_hlocal_secure_nothrow(size); + REQUIRE(static_cast(localArrayMemory)); + bool verified = true; + for (auto& elem : wil::make_range(localArrayMemory.get(), size)) if (elem.s != 42) verified = false; + REQUIRE(verified); + } + + // wil::unique_hglobal_ptr tests + { + wil::unique_hglobal_ptr empty; // null case + } + { + wil::unique_hglobal_ptr globalMemory(GlobalAlloc(GPTR, 100)); + REQUIRE(static_cast(globalMemory)); + } + + { + // The following uses are blocked due to a static assert failure + + //struct S { ~S() {} }; + + //auto cotaskmemMemory = wil::make_unique_cotaskmem_nothrow(); + //auto cotaskmemArrayMemory = wil::make_unique_cotaskmem_nothrow(1); + //auto cotaskmemMemory2 = wil::make_unique_cotaskmem_secure_nothrow(); + //auto cotaskmemArrayMemory2 = wil::make_unique_cotaskmem_secure_nothrow(1); + + //auto localMemory = wil::make_unique_hlocal_nothrow(); + //auto localArrayMemory = wil::make_unique_hlocal_nothrow(1); + //auto localMemory2 = wil::make_unique_hlocal_secure_nothrow(); + //auto localArrayMemory2 = wil::make_unique_hlocal_secure_nothrow(1); + } +} +#endif + +void GetDWORDArray(_Out_ size_t* count, _Outptr_result_buffer_(*count) DWORD** numbers) +{ + const size_t size = 5; + auto ptr = static_cast(::CoTaskMemAlloc(sizeof(DWORD) * size)); + REQUIRE(ptr); + ::ZeroMemory(ptr, sizeof(DWORD) * size); + *numbers = ptr; + *count = size; +} + +void GetHSTRINGArray(_Out_ ULONG* count, _Outptr_result_buffer_(*count) HSTRING** strings) +{ + const size_t size = 5; + auto ptr = static_cast(::CoTaskMemAlloc(sizeof(HSTRING) * size)); + REQUIRE(ptr); + for (UINT i = 0; i < size; ++i) + { + REQUIRE_SUCCEEDED(WindowsCreateString(L"test", static_cast(wcslen(L"test")), &ptr[i])); + } + *strings = ptr; + *count = static_cast(size); +} + +void GetPOINTArray(_Out_ UINT32* count, _Outptr_result_buffer_(*count) POINT** points) +{ + const size_t size = 5; + auto ptr = static_cast(::CoTaskMemAlloc(sizeof(POINT) * size)); + REQUIRE(ptr); + for (UINT i = 0; i < size; ++i) + { + ptr[i].x = ptr[i].y = i; + } + *points = ptr; + *count = static_cast(size); +} + +#ifdef WIL_ENABLE_EXCEPTIONS +void GetHANDLEArray(_Out_ size_t* count, _Outptr_result_buffer_(*count) HANDLE** events) +{ + const size_t size = 5; + HANDLE* ptr = reinterpret_cast(::CoTaskMemAlloc(sizeof(HANDLE) * size)); + for (auto& val : wil::make_range(ptr, size)) + { + val = wil::unique_event(wil::EventOptions::ManualReset).release(); + } + *events = ptr; + *count = size; +} +#endif + +interface __declspec(uuid("EDCA4ADC-DF46-442A-A69D-FDFD8BC37B31")) IFakeObject : public IUnknown +{ + STDMETHOD_(void, DoStuff)() = 0; +}; + +class ArrayTestObject : witest::AllocatedObject, + public Microsoft::WRL::RuntimeClass, IFakeObject> +{ +public: + HRESULT RuntimeClassInitialize(UINT n) { m_number = n; return S_OK; }; + STDMETHOD_(void, DoStuff)() {} +private: + UINT m_number; +}; + +void GetUnknownArray(_Out_ size_t* count, _Outptr_result_buffer_(*count) IFakeObject*** objects) +{ + const size_t size = 5; + auto ptr = reinterpret_cast(::CoTaskMemAlloc(sizeof(IFakeObject*) * size)); + REQUIRE(ptr); + for (UINT i = 0; i < size; ++i) + { + Microsoft::WRL::ComPtr obj; + REQUIRE_SUCCEEDED(Microsoft::WRL::MakeAndInitialize(&obj, i)); + ptr[i] = obj.Detach(); + } + *objects = ptr; + *count = size; +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +TEST_CASE("WindowsInternalTests::TestUniqueArrayCases", "[resource]") +{ + // wil::unique_cotaskmem_array_ptr tests + { + wil::unique_cotaskmem_array_ptr values; + GetDWORDArray(values.size_address(), &values); + } + { + wil::unique_cotaskmem_array_ptr strings; + GetHSTRINGArray(strings.size_address(), &strings); + for (ULONG i = 0; i < strings.size(); ++i) + { + REQUIRE(WindowsGetStringLen(strings[i]) == wcslen(L"test")); + } + } + { + wil::unique_cotaskmem_array_ptr points; + GetPOINTArray(points.size_address(), &points); + for (ULONG i = 0; i < points.size(); ++i) + { + REQUIRE((ULONG)points[i].x == i); + } + } +#ifdef WIL_ENABLE_EXCEPTIONS + { + wil::unique_cotaskmem_array_ptr events; + GetHANDLEArray(events.size_address(), &events); + } + { + wil::unique_cotaskmem_array_ptr> objects; + GetUnknownArray(objects.size_address(), &objects); + for (ULONG i = 0; i < objects.size(); ++i) + { + objects[i]->DoStuff(); + } + } +#endif + { + wil::unique_cotaskmem_array_ptr values = nullptr; + REQUIRE(!values); + REQUIRE(values.size() == 0); + + // move onto self + values = wistd::move(values); + REQUIRE(!values); + + // fetch + GetDWORDArray(values.size_address(), &values); + REQUIRE(!!values); + REQUIRE(values.size() > 0); + REQUIRE(!values.empty()); + + // move onto self + values = wistd::move(values); + REQUIRE(!!values); + + decltype(values) values2(wistd::move(values)); + REQUIRE(!values); + REQUIRE(!!values2); + REQUIRE(values2.size() > 0); + + values = wistd::move(values2); + REQUIRE(!!values); + REQUIRE(!values2); + + values = nullptr; + REQUIRE(!values); + GetDWORDArray(values.size_address(), &values); + REQUIRE(!!values); + + auto size = values.size(); + auto ptr = values.release(); + + REQUIRE(!values); + REQUIRE(values.empty()); + + decltype(values) values3(ptr, size); + REQUIRE(!!values3); + REQUIRE(values3.size() == size); + + values3.swap(values); + REQUIRE(!!values); + REQUIRE(!values.empty()); + REQUIRE(!values3); + REQUIRE(values3.empty()); + + REQUIRE(!values.empty()); + size_t count = 0; + for (auto it = values.begin(); it != values.end(); ++it) + { + ++count; + } + REQUIRE(count == values.size()); + + count = 0; + for (auto it = values.cbegin(); it != values.cend(); ++it) + { + ++count; + } + REQUIRE(count == values.size()); + + for (size_t index = 0; index < values.size(); index++) + { + auto& val = values[index]; + REQUIRE(val == 0); + } + + auto& front = values.front(); + REQUIRE(front == 0); + auto& back = values.back(); + REQUIRE(back == 0); + + [](const wil::unique_cotaskmem_array_ptr& cvalues) + { + size_t count = 0; + for (auto it = cvalues.begin(); it != cvalues.end(); ++it) + { + ++count; + } + REQUIRE(count == cvalues.size()); + for (size_t index = 0; index < cvalues.size(); index++) + { + auto& val = cvalues[index]; + REQUIRE(val == 0); + } + + auto& front = cvalues.front(); + REQUIRE(front == 0); + auto& back = cvalues.back(); + REQUIRE(back == 0); + + REQUIRE(cvalues.data() != nullptr); + }(values); + + auto data1 = values.data(); + auto data2 = values.get(); + REQUIRE((data1 && (data1 == data2))); + + values.reset(); + REQUIRE(!values); + REQUIRE(values.empty()); + + GetDWORDArray(values2.size_address(), &values2); + size = values2.size(); + ptr = values2.release(); + + values.reset(ptr, size); + REQUIRE(!!values); + REQUIRE(!values.empty()); + + REQUIRE(&values2 == values2.addressof()); + } +} +#endif + +#ifndef __cplusplus_winrt +TEST_CASE("WindowsInternalTests::VerifyMakeAgileCallback", "[wrl]") +{ + using namespace ABI::Windows::Foundation; + + class CallbackClient + { + public: + HRESULT On(IMemoryBufferReference*, IInspectable*) + { + return S_OK; + } + }; + CallbackClient callbackClient; + +#ifdef WIL_ENABLE_EXCEPTIONS + auto cbAgile = wil::MakeAgileCallback>([](IMemoryBufferReference*, IInspectable*) -> HRESULT + { + return S_OK; + }); + REQUIRE(wil::is_agile(cbAgile)); + + auto cbAgileWithMember = wil::MakeAgileCallback>(&callbackClient, &CallbackClient::On); + REQUIRE(wil::is_agile(cbAgileWithMember)); +#endif + auto cbAgileNoThrow = wil::MakeAgileCallbackNoThrow>([](IMemoryBufferReference*, IInspectable*) -> HRESULT + { + return S_OK; + }); + REQUIRE(wil::is_agile(cbAgileNoThrow)); + + auto cbAgileWithMemberNoThrow = wil::MakeAgileCallbackNoThrow>(&callbackClient, &CallbackClient::On); + REQUIRE(wil::is_agile(cbAgileWithMemberNoThrow)); +} +#endif + +TEST_CASE("WindowsInternalTests::Ranges", "[common]") +{ + { + int things[10]{}; + unsigned int count = 0; + for (auto& m : wil::make_range(things, ARRAYSIZE(things))) + { + ++count; + m = 1; + } + REQUIRE(ARRAYSIZE(things) == count); + REQUIRE(1 == things[1]); + } + + { + int things[10]{}; + unsigned int count = 0; + for (auto m : wil::make_range(things, ARRAYSIZE(things))) + { + ++count; + m = 1; + } + REQUIRE(ARRAYSIZE(things) == count); + REQUIRE(0 == things[0]); + } + + { + int things[10]{}; + unsigned int count = 0; + auto range = wil::make_range(things, ARRAYSIZE(things)); + for (auto m : range) + { + (void)m; + ++count; + } + REQUIRE(ARRAYSIZE(things) == count); + } + + { + int things[10]{}; + unsigned int count = 0; + const auto range = wil::make_range(things, ARRAYSIZE(things)); + for (auto m : range) + { + (void)m; + ++count; + } + REQUIRE(ARRAYSIZE(things) == count); + } +} + +TEST_CASE("WindowsInternalTests::HStringTests", "[resource][unique_any]") +{ + const wchar_t kittens[] = L"kittens"; + + { + wchar_t* bufferStorage = nullptr; + wil::unique_hstring_buffer theBuffer; + REQUIRE_SUCCEEDED(::WindowsPreallocateStringBuffer(ARRAYSIZE(kittens), &bufferStorage, &theBuffer)); + REQUIRE_SUCCEEDED(StringCchCopyW(bufferStorage, ARRAYSIZE(kittens), kittens)); + + // Promote sets the promoted-to value but resets theBuffer + wil::unique_hstring promoted; + REQUIRE_SUCCEEDED(wil::make_hstring_from_buffer_nothrow(wistd::move(theBuffer), &promoted)); + REQUIRE(static_cast(promoted)); + REQUIRE_FALSE(static_cast(theBuffer)); + } + + { + wchar_t* bufferStorage = nullptr; + wil::unique_hstring_buffer theBuffer; + REQUIRE_SUCCEEDED(::WindowsPreallocateStringBuffer(ARRAYSIZE(kittens), &bufferStorage, &theBuffer)); + REQUIRE_SUCCEEDED(StringCchCopyW(bufferStorage, ARRAYSIZE(kittens), kittens)); + + // Failure to promote retains the buffer state + REQUIRE_FAILED(wil::make_hstring_from_buffer_nothrow(wistd::move(theBuffer), nullptr)); + REQUIRE(static_cast(theBuffer)); + } + +#ifdef WIL_ENABLE_EXCEPTIONS + { + wchar_t* bufferStorage = nullptr; + wil::unique_hstring_buffer theBuffer; + THROW_IF_FAILED(::WindowsPreallocateStringBuffer(ARRAYSIZE(kittens), &bufferStorage, &theBuffer)); + THROW_IF_FAILED(StringCchCopyW(bufferStorage, ARRAYSIZE(kittens), kittens)); + + wil::unique_hstring promoted; + REQUIRE_NOTHROW(promoted = wil::make_hstring_from_buffer(wistd::move(theBuffer))); + REQUIRE(static_cast(promoted)); + REQUIRE_FALSE(static_cast(theBuffer)); + } +#endif +} + +struct ThreadPoolWaitTestContext +{ + volatile LONG Counter = 0; + wil::unique_event_nothrow Event; +}; + +static void __stdcall ThreadPoolWaitTestCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*instance*/, + _Inout_opt_ void* context, + _Inout_ PTP_WAIT wait, + _In_ TP_WAIT_RESULT /*waitResult*/) +{ + ThreadPoolWaitTestContext& myContext = *reinterpret_cast(context); + SetThreadpoolWait(wait, myContext.Event.get(), nullptr); + ::InterlockedIncrement(&myContext.Counter); +} + +template +void ThreadPoolWaitTestHelper(bool requireExactCallbackCount) +{ + ThreadPoolWaitTestContext myContext; + REQUIRE_SUCCEEDED(myContext.Event.create()); + + WaitResourceT wait; + wait.reset(CreateThreadpoolWait(ThreadPoolWaitTestCallback, &myContext, NULL)); + REQUIRE(wait); + + SetThreadpoolWait(wait.get(), myContext.Event.get(), nullptr); + + const int loopCount = 5; + for (int currCallbackCount = 0; currCallbackCount != loopCount; ++currCallbackCount) + { + // Signal event. + myContext.Event.SetEvent(); + + // Wait until 'myContext.Counter' increments by 1. + for (int itr = 0; itr != 50 && currCallbackCount == myContext.Counter; ++itr) + { + Sleep(10); + } + + // Ensure we didn't timeout + REQUIRE(currCallbackCount + 1 == myContext.Counter); + } + + // Signal one last event. + myContext.Event.SetEvent(); + + // Close thread-pool wait. + wait.reset(); + myContext.Event.reset(); + + // Verify counter. + if (requireExactCallbackCount) + { + REQUIRE(loopCount + 1 == myContext.Counter); + } + else + { + REQUIRE((loopCount + 1 == myContext.Counter || loopCount == myContext.Counter)); + } +} + +TEST_CASE("WindowsInternalTests::ThreadPoolWaitTest", "[resource][unique_threadpool_wait]") +{ + ThreadPoolWaitTestHelper(false); + ThreadPoolWaitTestHelper(true); +} + +struct ThreadPoolWaitWorkContext +{ + volatile LONG Counter = 0; +}; + +static void __stdcall ThreadPoolWaitWorkCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*instance*/, + _Inout_opt_ void* context, + _Inout_ PTP_WORK /*work*/) +{ + ThreadPoolWaitWorkContext& myContext = *reinterpret_cast(context); + ::InterlockedIncrement(&myContext.Counter); +} + +template +void ThreadPoolWaitWorkHelper(bool requireExactCallbackCount) +{ + ThreadPoolWaitWorkContext myContext; + + WaitResourceT work; + work.reset(CreateThreadpoolWork(ThreadPoolWaitWorkCallback, &myContext, NULL)); + REQUIRE(work); + + const int loopCount = 5; + for (int itr = 0; itr != loopCount; ++itr) + { + SubmitThreadpoolWork(work.get()); + } + + work.reset(); + + if (requireExactCallbackCount) + { + REQUIRE(loopCount == myContext.Counter); + } + else + { + REQUIRE(loopCount >= myContext.Counter); + } +} + +TEST_CASE("WindowsInternalTests::ThreadPoolWorkTest", "[resource][unique_threadpool_work]") +{ + ThreadPoolWaitWorkHelper(false); + ThreadPoolWaitWorkHelper(true); +} + +struct ThreadPoolTimerWorkContext +{ + volatile LONG Counter = 0; + wil::unique_event_nothrow Event; +}; + +static void __stdcall ThreadPoolTimerWorkCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*instance*/, + _Inout_opt_ void* context, + _Inout_ PTP_TIMER /*timer*/) +{ + ThreadPoolTimerWorkContext& myContext = *reinterpret_cast(context); + myContext.Event.SetEvent(); + ::InterlockedIncrement(&myContext.Counter); +} + +template +void ThreadPoolTimerWorkHelper(SetThreadpoolTimerT const &setThreadpoolTimerFn, bool requireExactCallbackCount) +{ + ThreadPoolTimerWorkContext myContext; + REQUIRE_SUCCEEDED(myContext.Event.create()); + + TimerResourceT timer; + timer.reset(CreateThreadpoolTimer(ThreadPoolTimerWorkCallback, &myContext, nullptr)); + REQUIRE(timer); + + const int loopCount = 5; + for (int currCallbackCount = 0; currCallbackCount != loopCount; ++currCallbackCount) + { + // Schedule timer + myContext.Event.ResetEvent(); + const auto allowedWindow = 0; + LONGLONG dueTime = -5 * 10000I64; // 5ms + setThreadpoolTimerFn(timer.get(), reinterpret_cast(&dueTime), 0, allowedWindow); + + // Wait until 'myContext.Counter' increments by 1. + REQUIRE(myContext.Event.wait(500)); + for (int itr = 0; itr != 50 && currCallbackCount == myContext.Counter; ++itr) + { + Sleep(10); + } + + // Ensure we didn't timeout + REQUIRE(currCallbackCount + 1 == myContext.Counter); + } + + // Schedule one last timer. + myContext.Event.ResetEvent(); + const auto allowedWindow = 0; + LONGLONG dueTime = -5 * 10000I64; // 5ms + setThreadpoolTimerFn(timer.get(), reinterpret_cast(&dueTime), 0, allowedWindow); + + if (requireExactCallbackCount) + { + // Wait for the event to be set + REQUIRE(myContext.Event.wait(500)); + } + + // Close timer. + timer.reset(); + myContext.Event.reset(); + + // Verify counter. + if (requireExactCallbackCount) + { + REQUIRE(loopCount + 1 == myContext.Counter); + } + else + { + REQUIRE((loopCount + 1 == myContext.Counter || loopCount == myContext.Counter)); + } +} + +TEST_CASE("WindowsInternalTests::ThreadPoolTimerTest", "[resource][unique_threadpool_timer]") +{ + static_assert(sizeof(FILETIME) == sizeof(LONGLONG), "FILETIME and LONGLONG must be same size"); + ThreadPoolTimerWorkHelper(SetThreadpoolTimer, false); + ThreadPoolTimerWorkHelper(SetThreadpoolTimer, true); +} + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +static void __stdcall SlimEventTrollCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*instance*/, + _Inout_opt_ void* context, + _Inout_ PTP_TIMER /*timer*/) +{ + auto event = reinterpret_cast(context); + + // Wake up the thread without setting the event. + // Note: This relies on the fact that the 'wil::slim_event' class only has a single member variable. + WakeByAddressAll(event); +} + +static void __stdcall SlimEventFriendlyCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*instance*/, + _Inout_opt_ void* context, + _Inout_ PTP_TIMER /*timer*/) +{ + auto event = reinterpret_cast(context); + event->SetEvent(); +} + +TEST_CASE("WindowsInternalTests::SlimEventTests", "[resource][slim_event]") +{ + { + wil::slim_event event; + + // Verify simple timeouts work on an auto-reset event. + REQUIRE_FALSE(event.wait(/*timeout(ms)*/ 0)); + REQUIRE_FALSE(event.wait(/*timeout(ms)*/ 10)); + + wil::unique_threadpool_timer trollTimer(CreateThreadpoolTimer(SlimEventTrollCallback, &event, nullptr)); + REQUIRE(trollTimer); + + FILETIME trollDueTime = wil::filetime::from_int64(0); + SetThreadpoolTimer(trollTimer.get(), &trollDueTime, /*period(ms)*/ 5, /*window(ms)*/ 0); + + // Ensure we timeout in spite of being constantly woken up unnecessarily. + REQUIRE_FALSE(event.wait(/*timeout(ms)*/ 100)); + + wil::unique_threadpool_timer friendlyTimer(CreateThreadpoolTimer(SlimEventFriendlyCallback, &event, nullptr)); + REQUIRE(friendlyTimer); + + FILETIME friendlyDueTime = wil::filetime::from_int64(UINT64(-100 * wil::filetime_duration::one_millisecond)); // 100ms (relative to now) + SetThreadpoolTimer(friendlyTimer.get(), &friendlyDueTime, /*period(ms)*/ 0, /*window(ms)*/ 0); + + // Now that the 'friendlyTimer' is queued, we should succeed. + REQUIRE(event.wait(INFINITE)); + + // Ensure event is auto-reset. + REQUIRE_FALSE(event.wait(/*timeout(ms)*/ 100)); + } + + { + wil::slim_event_manual_reset manualResetEvent; + + // Verify simple timeouts work on a manual-reset event. + REQUIRE_FALSE(manualResetEvent.wait(/*timeout(ms)*/ 0)); + REQUIRE_FALSE(manualResetEvent.wait(/*timeout(ms)*/ 10)); + + // Ensure multiple waits can occur on a manual-reset event. + manualResetEvent.SetEvent(); + REQUIRE(manualResetEvent.wait()); + REQUIRE(manualResetEvent.wait(/*timeout(ms)*/ 100)); + REQUIRE(manualResetEvent.wait(INFINITE)); + + // Verify 'ResetEvent' works. + manualResetEvent.ResetEvent(); + REQUIRE_FALSE(manualResetEvent.wait(/*timeout(ms)*/ 10)); + } + +} +#endif // WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + +struct ConditionVariableCSCallbackContext +{ + wil::condition_variable event; + wil::critical_section lock; + auto acquire() { return lock.lock(); } +}; + +struct ConditionVariableSRWCallbackContext +{ + wil::condition_variable event; + wil::srwlock lock; + auto acquire() { return lock.lock_exclusive(); } +}; + +template +static void __stdcall ConditionVariableCallback( + _Inout_ PTP_CALLBACK_INSTANCE /*Instance*/, + _Inout_opt_ void* Context) +{ + auto callbackContext = reinterpret_cast(Context); + + // Acquire the lock to ensure we don't notify the condition variable before the other thread has + // gone to sleep. + auto gate = callbackContext->acquire(); + + // Signal the condition variable. + callbackContext->event.notify_all(); +} + +// A quick sanity check of the 'wil::condition_variable' type. +TEST_CASE("WindowsInternalTests::ConditionVariableTests", "[resource][condition_variable]") +{ + SECTION("Test 'wil::condition_variable' with 'wil::critical_section'") + { + ConditionVariableCSCallbackContext callbackContext; + auto gate = callbackContext.lock.lock(); + + // Schedule the thread that will wake up this thread. + REQUIRE(TrySubmitThreadpoolCallback(ConditionVariableCallback, &callbackContext, nullptr)); + + // Wait on the condition variable. + REQUIRE(callbackContext.event.wait_for(gate, /*timeout(ms)*/ 500)); + } + + SECTION("Test 'wil::condition_variable' with 'wil::srwlock'") + { + ConditionVariableSRWCallbackContext callbackContext; + + // Test exclusive lock. + { + auto gate = callbackContext.lock.lock_exclusive(); + + // Schedule the thread that will wake up this thread. + REQUIRE(TrySubmitThreadpoolCallback(ConditionVariableCallback, &callbackContext, nullptr)); + + // Wait on the condition variable. + REQUIRE(callbackContext.event.wait_for(gate, /*timeout(ms)*/ 500)); + } + + // Test shared lock. + { + auto gate = callbackContext.lock.lock_shared(); + + // Schedule the thread that will wake up this thread. + REQUIRE(TrySubmitThreadpoolCallback(ConditionVariableCallback, &callbackContext, nullptr)); + + // Wait on the condition variable. + REQUIRE(callbackContext.event.wait_for(gate, /*timeout(ms)*/ 500)); + } + } +} + +TEST_CASE("WindowsInternalTests::ReturnWithExpectedTests", "[result_macros]") +{ + wil::g_pfnResultLoggingCallback = ResultMacrosLoggingCallback; + + // Succeeded + REQUIRE_RETURNS_EXPECTED(S_OK, [] { RETURN_IF_FAILED_WITH_EXPECTED(MDEC(hrOKRef()), E_UNEXPECTED); return S_OK; }); + + // Expected + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_IF_FAILED_WITH_EXPECTED(E_FAIL, E_FAIL); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_UNEXPECTED, [] { RETURN_IF_FAILED_WITH_EXPECTED(E_UNEXPECTED, E_FAIL, E_UNEXPECTED, E_POINTER, E_INVALIDARG); return S_OK; }); + + // Unexpected + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_IF_FAILED_WITH_EXPECTED(E_FAIL, E_UNEXPECTED); return S_OK; }); + REQUIRE_RETURNS_EXPECTED(E_FAIL, [] { RETURN_IF_FAILED_WITH_EXPECTED(E_FAIL, E_UNEXPECTED, E_POINTER, E_INVALIDARG); return S_OK; }); +} + +TEST_CASE("WindowsInternalTests::LogWithExpectedTests", "[result_macros]") +{ + wil::g_pfnResultLoggingCallback = ResultMacrosLoggingCallback; + + // Succeeded + REQUIRE_LOG(S_OK, [] { REQUIRE(S_OK == LOG_IF_FAILED_WITH_EXPECTED(MDEC(hrOKRef()), E_FAIL, E_INVALIDARG)); }); + + // Expected + REQUIRE_LOG(S_OK, [] { REQUIRE(E_UNEXPECTED == LOG_IF_FAILED_WITH_EXPECTED(E_UNEXPECTED, E_UNEXPECTED, E_INVALIDARG)); }); + REQUIRE_LOG(S_OK, [] { REQUIRE(E_UNEXPECTED == LOG_IF_FAILED_WITH_EXPECTED(E_UNEXPECTED, E_FAIL, E_UNEXPECTED, E_POINTER, E_INVALIDARG)); }); + + // Unexpected + REQUIRE_LOG(E_FAIL, [] { REQUIRE(E_FAIL == LOG_IF_FAILED_WITH_EXPECTED(E_FAIL, E_UNEXPECTED)); }); + REQUIRE_LOG(E_FAIL, [] { REQUIRE(E_FAIL == LOG_IF_FAILED_WITH_EXPECTED(E_FAIL, E_UNEXPECTED, E_POINTER, E_INVALIDARG)); }); +} + +// Verifies that the shutdown-aware objects respect the alignment +// of the wrapped object. +template class Wrapper> +void VerifyAlignment() +{ + // Some of the wrappers require a method called ProcessShutdown(), so we'll give it one. + struct alignment_sensitive_struct + { + // Use SLIST_HEADER as our poster child alignment-sensitive data type. + SLIST_HEADER value; + void ProcessShutdown() { } + }; + static_assert(alignof(alignment_sensitive_struct) != alignof(char), "Need to choose a better alignment-sensitive type"); + + // Create a custom structure that tries to force misalignment. + struct attempted_misalignment + { + char c; + Wrapper wrapper; + } possibly_misaligned; + + static_assert(alignof(attempted_misalignment) == alignof(alignment_sensitive_struct), "Wrapper type does not respect alignment"); + + // Verify that the wrapper type placed the inner object at proper alignment. + // Note: use std::addressof in case the alignment_sensitive_struct overrides the & operator. + REQUIRE(reinterpret_cast(std::addressof(possibly_misaligned.wrapper.get())) % alignof(alignment_sensitive_struct) == 0); +} + +TEST_CASE("WindowsInternalTests::ShutdownAwareObjectAlignmentTests", "[result_macros]") +{ + VerifyAlignment(); + VerifyAlignment(); + VerifyAlignment(); +} + +#pragma warning(pop) diff --git a/tests/workarounds/readme.md b/tests/workarounds/readme.md new file mode 100644 index 0000000..d2f248f --- /dev/null +++ b/tests/workarounds/readme.md @@ -0,0 +1,2 @@ + +We try and be as conformant as possible, but sometimes dependencies make that difficult. For example, WRL has had a number of conformance issues that keep getting uncovered. The files here are fixed up copies of those files and the include path is modified such that these directories appear first. diff --git a/tests/workarounds/wrl/wrl/async.h b/tests/workarounds/wrl/wrl/async.h new file mode 100644 index 0000000..ac8296a --- /dev/null +++ b/tests/workarounds/wrl/wrl/async.h @@ -0,0 +1,1356 @@ +// +// Copyright (C) Microsoft Corporation +// All rights reserved. +// +// Code in details namespace is for internal usage within the library code +// +#ifndef _WRL_ASYNC_H_ +#define _WRL_ASYNC_H_ + +#ifdef _MSC_VER +#pragma once +#endif // _MSC_VER + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpragma-pack" +#pragma clang diagnostic ignored "-Wignored-qualifiers" +#pragma clang diagnostic ignored "-Wextra-tokens" +#pragma clang diagnostic ignored "-Wreorder" +#endif + +#include +#include +#include + +#include +#include +#if !defined(MIDL_NS_PREFIX) && !defined(____x_ABI_CWindows_CFoundation_CDiagnostics_CITracingStatusChangedEventArgs_FWD_DEFINED__) +namespace ABI { +namespace Windows { +namespace Foundation { +typedef ::Windows::Foundation::IAsyncActionCompletedHandler IAsyncActionCompletedHandler; +namespace Diagnostics { +typedef ::Windows::Foundation::Diagnostics::CausalitySource CausalitySource; +typedef ::Windows::Foundation::Diagnostics::IAsyncCausalityTracerStatics IAsyncCausalityTracerStatics; +typedef ::Windows::Foundation::Diagnostics::TracingStatusChangedEventArgs TracingStatusChangedEventArgs; +typedef ::Windows::Foundation::Diagnostics::ITracingStatusChangedEventArgs ITracingStatusChangedEventArgs; + +typedef ::Windows::Foundation::Diagnostics::CausalityTraceLevel CausalityTraceLevel; +const ::Windows::Foundation::Diagnostics::CausalityTraceLevel CausalityTraceLevel_Verbose = ::Windows::Foundation::Diagnostics::CausalityTraceLevel_Verbose; +const ::Windows::Foundation::Diagnostics::CausalityTraceLevel CausalityTraceLevel_Important = ::Windows::Foundation::Diagnostics::CausalityTraceLevel_Important; +const ::Windows::Foundation::Diagnostics::CausalityTraceLevel CausalityTraceLevel_Required = ::Windows::Foundation::Diagnostics::CausalityTraceLevel_Required; + +const ::Windows::Foundation::Diagnostics::CausalityRelation CausalityRelation_Join = ::Windows::Foundation::Diagnostics::CausalityRelation_Join; +const ::Windows::Foundation::Diagnostics::CausalityRelation CausalityRelation_Choice = ::Windows::Foundation::Diagnostics::CausalityRelation_Choice; +const ::Windows::Foundation::Diagnostics::CausalityRelation CausalityRelation_Error = ::Windows::Foundation::Diagnostics::CausalityRelation_Error; +const ::Windows::Foundation::Diagnostics::CausalityRelation CausalityRelation_Cancel = ::Windows::Foundation::Diagnostics::CausalityRelation_Cancel; +const ::Windows::Foundation::Diagnostics::CausalityRelation CausalityRelation_AssignDelegate = ::Windows::Foundation::Diagnostics::CausalityRelation_AssignDelegate ; + +const ::Windows::Foundation::Diagnostics::CausalitySynchronousWork CausalitySynchronousWork_CompletionNotification = ::Windows::Foundation::Diagnostics::CausalitySynchronousWork_CompletionNotification; +const ::Windows::Foundation::Diagnostics::CausalitySynchronousWork CausalitySynchronousWork_ProgressNotification = ::Windows::Foundation::Diagnostics::CausalitySynchronousWork_ProgressNotification; +const ::Windows::Foundation::Diagnostics::CausalitySynchronousWork CausalitySynchronousWork_Execution = ::Windows::Foundation::Diagnostics::CausalitySynchronousWork_Execution; +} +} +} +} +#endif + +#include +#include +#include +#include +#include + +#include + +// Set packing +#include + +#pragma warning(push) +// nonstandard extension used: override specifier 'override' +#pragma warning(disable: 4481) + +// GUID identifying the the Windows Platform for logging purposes +// {C54C95D9-5B6E-41E9-A28D-2DD68F94B500} +extern _declspec(selectany) const GUID GUID_CAUSALITY_WINDOWS_PLATFORM_ID = +{ 0xc54c95d9, 0x5b6e, 0x41e9, { 0xa2, 0x8d, 0x2d, 0xd6, 0x8f, 0x94, 0xb5, 0x0 } }; + +namespace Microsoft { +namespace WRL { + +// Designates the error propagation policy used by FireProgress and FireComplete. If PropagateDelegateError is the mode +// then failures returned from the async completion and progress delegates are propagated. If IgnoreDelegateError +// is the mode, then failures returned from the async completion and progress delegates are converted to successes and +// the errors are swallowed. +enum ErrorPropagationPolicy +{ + PropagateDelegateError = 1, + IgnoreDelegateError = 2 +}; + +namespace Details +{ + // contains states indicating existance or lack of options + struct AsyncOptionsBase + { + static const bool hasCausalityOptions = false; + static const bool hasErrorPropagationPolicy = false; + static const bool hasCausalityOperationName = false; + static const bool isCausalityEnabled = true; + }; + + template < PCWSTR OpName > + struct IsOperationName + { + static const bool Value = true; + }; + + template < > + struct IsOperationName< nullptr > + { + static const bool Value = false; + }; +} + +// Options for error propagation and the defaults are set here. +#if defined(BUILD_WINDOWS) && (NTDDI_VERSION >= NTDDI_WINBLUE) +template < ErrorPropagationPolicy errorPropagationPolicy = PropagateErrorWithWin8Quirk> +#else +template < ErrorPropagationPolicy errorPropagationPolicy = Microsoft::WRL::ErrorPropagationPolicy::IgnoreDelegateError> +#endif +struct ErrorPropagationOptions : public Microsoft::WRL::Details::AsyncOptionsBase +{ + static const ErrorPropagationPolicy PropagationPolicy = errorPropagationPolicy; + static const bool hasErrorPropagationPolicy = true; +}; + +#ifndef _WRL_DISABLE_CAUSALITY_ + +// Options for causality tracing and the needed defaults are set here. The following class may be used as +// a reference to add more options to AsyncBase +#ifdef BUILD_WINDOWS +#define WRL_DEFAULT_CAUSALITY_GUID GUID_CAUSALITY_WINDOWS_PLATFORM_ID +#define WRL_DEFAULT_CAUSALITY_SOURCE ::ABI::Windows::Foundation::Diagnostics::CausalitySource::CausalitySource_System +#else +#define WRL_DEFAULT_CAUSALITY_GUID GUID_NULL +#define WRL_DEFAULT_CAUSALITY_SOURCE ::ABI::Windows::Foundation::Diagnostics::CausalitySource::CausalitySource_Application +#endif //BUILD_WINDOWS + +template < + PCWSTR OpName = nullptr, + const GUID &PlatformId = WRL_DEFAULT_CAUSALITY_GUID, + ::ABI::Windows::Foundation::Diagnostics::CausalitySource CausalitySource = WRL_DEFAULT_CAUSALITY_SOURCE +> +struct AsyncCausalityOptions : public Microsoft::WRL::Details::AsyncOptionsBase +{ + static PCWSTR GetAsyncOperationName() + { + return OpName; + } + + static const GUID GetPlatformId() + { + return PlatformId; + } + + static ::ABI::Windows::Foundation::Diagnostics::CausalitySource GetCausalitySource() + { + return CausalitySource; + } + + static const bool hasCausalityOptions = true; + static const bool hasCausalityOperationName = Microsoft::WRL::Details::IsOperationName::Value; +}; + +// This option type for causality tracing disables just the tracing part. +extern __declspec(selectany) const WCHAR DisableCausalityAsyncOperationName[] = L"Disabled"; +struct DisableCausality : public AsyncCausalityOptions< DisableCausalityAsyncOperationName > +{ + static const bool isCausalityEnabled = false; +}; + +#endif // _WRL_DISABLE_CAUSALITY_ + +namespace Details +{ +// maps internal definitions for AsyncStatus and defines states that are not client visible +enum AsyncStatusInternal +{ + // non-client visible internal states + _Undefined = -2, + _Created = -1, + + // client visible states (must match AsyncStatus exactly) + _Started = static_cast(::ABI::Windows::Foundation::AsyncStatus::Started), + _Completed = static_cast(::ABI::Windows::Foundation::AsyncStatus::Completed), + _Canceled = static_cast(::ABI::Windows::Foundation::AsyncStatus::Canceled), + _Error = static_cast(::ABI::Windows::Foundation::AsyncStatus::Error), + + // non-client visible internal states + _Closed +}; + +template < typename T > +struct DerefHelper; + +template < typename T > +struct DerefHelper +{ + typedef T DerefType; +}; + +#pragma region AsyncOptionsHelper + +#ifndef _WRL_DISABLE_CAUSALITY_ + +// Provides the name for a async operation/action for logging purposes +template < typename TComplete, bool hasName, typename TOptions > +struct CausalityNameHelper +{ + // Provides a default string for the Async Operation/Action name if a name is not provided + // for logging purposes + static PCWSTR GetName() + { + return TOptions::GetAsyncOperationName(); + } +}; + +// Specialization to handle the logging for those classes that implement z_get_rc_name_impl +template < typename TComplete, typename TOptions > +struct CausalityNameHelper< TComplete, false, TOptions > +{ + static PCWSTR GetName() + { + return TComplete::z_get_rc_name_impl(); + } +}; + +// Specialization to handle the logging for IAsyncAction +template +struct CausalityNameHelper< ::ABI::Windows::Foundation::IAsyncActionCompletedHandler, false, TOptions > +{ + // IAsyncActionCompletedHandler is not templatized so it does not implement z_get_rc_name_impl + // hence this specialization + static PCWSTR GetName() + { + return L"Windows.Foundation.IAsyncActionCompletedHandler"; + } +}; + +#endif // _WRL_DISABLE_CAUSALITY_ + +// helper class to switch between default or given options for error +// propagation +template < bool hasValue, typename TOptions > +struct ErrorPropagationOptionsHelper; + +// provides the given options for error propagation +template < typename TOptions > +struct ErrorPropagationOptionsHelper< true , TOptions > +{ + static const Microsoft::WRL::ErrorPropagationPolicy PropagationPolicy = TOptions::PropagationPolicy; +}; + +// provides default options for error propagation +template < typename TOptions > +struct ErrorPropagationOptionsHelper< false , TOptions > +{ + static const Microsoft::WRL::ErrorPropagationPolicy PropagationPolicy = Microsoft::WRL::ErrorPropagationOptions<>::PropagationPolicy; +}; + +#ifndef _WRL_DISABLE_CAUSALITY_ + +// helper class to switch between default or given options for +// Async Causality Options +template < bool hasCausalityOptions, typename TComplete, typename TOptions > +struct AsyncCausalityOptionsHelper; + +// provides the given options for causality tracing +template < typename TComplete, typename TOptions > +struct AsyncCausalityOptionsHelper < true, TComplete, TOptions > +{ +#ifdef BUILD_WINDOWS + static_assert(!(__is_base_of(::ABI::Windows::Foundation::IAsyncActionCompletedHandler, TComplete) && !TOptions::hasCausalityOperationName),"Please add name to Asynchronous Operations for Better Diagnostics: http://winri/BreakingChanges/BreakingChangeForm/Index/1992"); +#endif + static PCWSTR GetAsyncOperationName() + { + return CausalityNameHelper< TComplete, TOptions::hasCausalityOperationName, TOptions >::GetName(); + } + + static const GUID GetPlatformId() + { + return TOptions::GetPlatformId(); + } + + static const ::ABI::Windows::Foundation::Diagnostics::CausalitySource GetCausalitySource() + { + return TOptions::GetCausalitySource(); + } + + static const bool CausalityEnabled = TOptions::isCausalityEnabled; +}; + +// provides the default options for causality tracing +template < typename TComplete, typename TOptions > +struct AsyncCausalityOptionsHelper < false, TComplete, TOptions > +{ +#ifdef BUILD_WINDOWS + static_assert(!(__is_base_of(::ABI::Windows::Foundation::IAsyncActionCompletedHandler, TComplete) && !TOptions::hasCausalityOperationName),"Please add name to Asynchronous Operations for Better Diagnostics: http://winri/BreakingChanges/BreakingChangeForm/Index/1992"); +#endif + static PCWSTR GetAsyncOperationName() + { + return CausalityNameHelper< TComplete, TOptions::hasCausalityOperationName, TOptions >::GetName(); + } + + static const GUID GetPlatformId() + { + return Microsoft::WRL::AsyncCausalityOptions<>::GetPlatformId(); + } + + static ::ABI::Windows::Foundation::Diagnostics::CausalitySource GetCausalitySource() + { + return Microsoft::WRL::AsyncCausalityOptions<>::GetCausalitySource(); + } + + static const bool CausalityEnabled = TOptions::isCausalityEnabled; +}; + +#endif // _WRL_DISABLE_CAUSALITY_ + +// helper calls to accumulate all options +// Future options have to be added here +template < typename TComplete, typename TOptions > +struct AsyncOptionsHelper : +#ifndef _WRL_DISABLE_CAUSALITY_ + public AsyncCausalityOptionsHelper < TOptions::hasCausalityOptions, TComplete, TOptions >, +#endif // _WRL_DISABLE_CAUSALITY_ + public ErrorPropagationOptionsHelper < TOptions::hasErrorPropagationPolicy, TOptions > +{ +}; + +#pragma endregion +// End of AsyncOptionsHelper + +} // Details + +// designates whether the "GetResults" method returns a single result (after complete fires) or multiple results +// (which are progressively consumable between Start state and before Close is called) +enum AsyncResultType +{ + SingleResult = 0x0001, + MultipleResults = 0x0002 +}; + +// indicates how an attempt to transition to a terminal state of Completed or Error should behave with respect to +// the (client-requested) Canceled state. +enum CancelTransitionPolicy +{ + // If the async operation is presently in a (client-requested) Canceled state, this indicates that + // it will stay in the Canceled state as opposed to transitioning to a terminal Completed or Error + // state. + RemainCanceled = 0, + + // If the async operation is presently in a (client-requested) Canceled state, this indicates that + // state should transition from that Canceled state to the terminal state of Completed or Error as + // determined by the call utilizing this flag. + TransitionFromCanceled +}; + +#pragma region AsyncOptions + +template < ErrorPropagationPolicy errorPropagationPolicy > +struct ErrorPropagationPolicyTraits; + +// This error propagation policy passes through all errors +template <> +struct ErrorPropagationPolicyTraits< PropagateDelegateError > +{ + static HRESULT FireCompletionErrorPropagationPolicyFilter(HRESULT hrIn, IUnknown *, void * = nullptr) + { + // Ignore errors if the error is caused by a disconnected object + if (hrIn == RPC_E_DISCONNECTED || hrIn == HRESULT_FROM_WIN32(RPC_S_SERVER_UNAVAILABLE) || hrIn == JSCRIPT_E_CANTEXECUTE) + { + ::RoTransformError(hrIn, S_OK, nullptr); + hrIn = S_OK; + } + return hrIn; + } + + static HRESULT FireProgressErrorPropagationPolicyFilter(HRESULT hrIn, IUnknown *, void * = nullptr) + { + // Ignore errors if the error is caused by a disconnected object + if (hrIn == RPC_E_DISCONNECTED || hrIn == HRESULT_FROM_WIN32(RPC_S_SERVER_UNAVAILABLE) || hrIn == JSCRIPT_E_CANTEXECUTE) + { + ::RoTransformError(hrIn, S_OK, nullptr); + hrIn = S_OK; + } + return hrIn; + } +}; + +// This error propagation policy ignores all errors and converts them to S_OK +template <> +struct ErrorPropagationPolicyTraits< IgnoreDelegateError > +{ + static HRESULT FireCompletionErrorPropagationPolicyFilter(HRESULT hrIn, IUnknown *, void * = nullptr) + { + if (FAILED(hrIn)) + { + ::RoTransformError(hrIn, S_OK, nullptr); + hrIn = S_OK; + } + return hrIn; + } + + static HRESULT FireProgressErrorPropagationPolicyFilter(HRESULT hrIn, IUnknown *, void * = nullptr) + { + if (FAILED(hrIn)) + { + ::RoTransformError(hrIn, S_OK, nullptr); + hrIn = S_OK; + } + return hrIn; + } +}; + +// All options for the AsyncBase class are accumulated here. This class may be expanded to include +// new options as needed. +template < +#if defined(BUILD_WINDOWS) && (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorPropagationPolicy errorPropagationPolicy = PropagateErrorWithWin8Quirk, +#else + ErrorPropagationPolicy errorPropagationPolicy = ErrorPropagationPolicy::IgnoreDelegateError, +#endif + PCWSTR OpName = nullptr, +#ifdef BUILD_WINDOWS + const GUID &PlatformId = GUID_CAUSALITY_WINDOWS_PLATFORM_ID, + ::ABI::Windows::Foundation::Diagnostics::CausalitySource CausalitySource = ::ABI::Windows::Foundation::Diagnostics::CausalitySource::CausalitySource_System +#else + const GUID &PlatformId = GUID_NULL, + ::ABI::Windows::Foundation::Diagnostics::CausalitySource CausalitySource = ::ABI::Windows::Foundation::Diagnostics::CausalitySource::CausalitySource_Application +#endif //BUILD_WINDOWS +> +struct AsyncOptions : +#ifndef _WRL_DISABLE_CAUSALITY_ + public AsyncCausalityOptions, +#endif // _WRL_DISABLE_CAUSALITY_ + public ErrorPropagationOptions +{ + static const bool hasCausalityOptions = true; + static const bool hasErrorPropagationPolicy = true; + static const bool hasCausalityOperationName = Microsoft::WRL::Details::IsOperationName::Value; + static const bool isCausalityEnabled = true; +}; + +#pragma endregion +// End of AsyncOptions region + +#ifndef _WRL_DISABLE_CAUSALITY_ + _declspec(selectany) INIT_ONCE gCausalityInitOnce = INIT_ONCE_STATIC_INIT; + _declspec(selectany) ::ABI::Windows::Foundation::Diagnostics::IAsyncCausalityTracerStatics* gCausality; +#endif // _WRL_DISABLE_CAUSALITY_ + +// AsyncBase - base class that implements the WinRT Async state machine +// this base class is designed to be used with WRL to implement an async worker object +template < + typename TComplete, + typename TProgress = Details::Nil, + AsyncResultType resultType = SingleResult, + typename TAsyncBaseOptions = AsyncOptions<> +> +class AsyncBase : public AsyncBase< TComplete, Details::Nil, resultType, TAsyncBaseOptions > +{ + typedef typename Details::ArgTraitsHelper< TProgress >::Traits ProgressTraits; + typedef Microsoft::WRL::Details::AsyncOptionsHelper< TComplete, TAsyncBaseOptions > AllOptions; + friend class AsyncBase< TComplete, Details::Nil, resultType, TAsyncBaseOptions >; + +public: + + // since this is designed to be used inside of an RuntimeClass<> template, we can + // only have a default constructor + AsyncBase() : + progressDelegate_(nullptr), + progressDelegateBucketAssist_(nullptr) + { + } + + // Delegate Helpers + STDMETHOD(PutOnProgress)(TProgress* progressHandler) + { + HRESULT hr = this->CheckValidStateForDelegateCall(); + if (SUCCEEDED(hr)) + { + progressDelegate_ = progressHandler; + + if (progressDelegate_ != nullptr) + { + progressDelegateBucketAssist_ = Microsoft::WRL::Details::GetDelegateBucketAssist(progressDelegate_.Get()); + } + + this->TraceDelegateAssigned(); + } + return hr; + } + + STDMETHOD(GetOnProgress)(TProgress** progressHandler) + { + *progressHandler = nullptr; + HRESULT hr = this->CheckValidStateForDelegateCall(); + if (SUCCEEDED(hr)) + { + progressDelegate_.CopyTo(progressHandler); + } + return hr; + } + + HRESULT FireProgress(const typename ProgressTraits::Arg2Type arg) + { + HRESULT hr = S_OK; + ComPtr< ::ABI::Windows::Foundation::IAsyncInfo > asyncInfo = this; + ComPtr::DerefType> operationInterface; + if (progressDelegate_) + { + hr = asyncInfo.As(&operationInterface); + if (SUCCEEDED(hr)) + { + this->TraceProgressNotificationStart(); + + hr = progressDelegate_->Invoke(operationInterface.Get(), arg); + + this->TraceProgressNotificationComplete(); + } + } + + // filter the errors per the Error Propagation Policy + hr = ErrorPropagationPolicyTraits< AllOptions::PropagationPolicy >::FireProgressErrorPropagationPolicyFilter(hr, progressDelegate_.Get(), progressDelegateBucketAssist_); + + return hr; + } + + HRESULT FireCompletion(void) override + { + // "this" may be deleted during the completion call. Remove progress prior to firing completion. + progressDelegate_.Reset(); + return AsyncBase< TComplete, Details::Nil, resultType, TAsyncBaseOptions >::FireCompletion(); + } + +private: + ::Microsoft::WRL::ComPtr progressDelegate_; + void *progressDelegateBucketAssist_; +}; + +template < typename TComplete, AsyncResultType resultType, typename TAsyncBaseOptions > +class AsyncBase< TComplete, Details::Nil, resultType, TAsyncBaseOptions > : public ::Microsoft::WRL::Implements< ::ABI::Windows::Foundation::IAsyncInfo > +{ + typedef typename Details::ArgTraitsHelper< TComplete >::Traits CompleteTraits; + typedef Microsoft::WRL::Details::AsyncOptionsHelper AllOptions; +public: + // since this is designed to be used inside of a RuntimeClass<> template, we can + // only have a default constructor + AsyncBase() : + currentStatus_(Details::AsyncStatusInternal::_Created), + id_(1), + errorCode_(S_OK), + completeDelegate_(nullptr), + completeDelegateBucketAssist_(nullptr), + asyncOperationBucketAssist_(nullptr), + cCompleteDelegateAssigned_(0), + cCallbackMade_(0) + { + } + + // The TraceCompletion in logged if the FireCompletion occurs and the completion call back is assigned + // if the callback was not made then the async operation completion is logged here + virtual ~AsyncBase() + { + if (!cCallbackMade_) + { + TraceOperationComplete(); + } + } + + // IAsyncInfo::put_Id + STDMETHOD(put_Id)(const unsigned int id) + { + if (id == 0) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + WCHAR const pszParamName[] = L"id"; + ::RoOriginateErrorW(E_INVALIDARG, ARRAYSIZE(pszParamName) - 1, pszParamName); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_INVALIDARG; + } + id_ = id; + + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + if (current != Details::_Created) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_METHOD_CALL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_ILLEGAL_METHOD_CALL; + } + + return S_OK; + } + + // IAsyncInfo::get_Id + STDMETHOD(get_Id)(unsigned int *id) override + { + *id = id_; + return CheckValidStateForAsyncInfoCall(); + } + + // IAsyncInfo::get_Status + STDMETHOD(get_Status)(::ABI::Windows::Foundation::AsyncStatus *status) override + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + *status = static_cast< ::ABI::Windows::Foundation::AsyncStatus >(current); + return CheckValidStateForAsyncInfoCall(); + } + + // IAsyncInfo::get_ErrorCode + STDMETHOD(get_ErrorCode)(HRESULT* errorCode) override + { + HRESULT hr = CheckValidStateForAsyncInfoCall(); + if (SUCCEEDED(hr)) + { + ErrorCode(errorCode); + } + else + { + // Do not propagate the error and error info associated with the async error if this call generated + // a specific error itself. + *errorCode = hr; + } + return hr; + } + +protected: + // Start - this is not externally visible since async operations "hot start" before returning to the caller + STDMETHOD(Start)(void) + { + HRESULT hr = S_OK; + if (TransitionToState(Details::_Started)) + { + hr = OnStart(); + +#ifndef _WRL_DISABLE_CAUSALITY_ + if (SUCCEEDED(hr) && + ::InitOnceExecuteOnce(&gCausalityInitOnce, InitCausality, NULL, NULL)) + { + TraceOperationStart(); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + else + { + hr = E_ILLEGAL_STATE_CHANGE; +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_STATE_CHANGE, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + } + return hr; + } + +public: + // IAsyncInfo::Cancel + STDMETHOD(Cancel)(void) + { + if (TransitionToState(Details::_Canceled)) + { + OnCancel(); + + TraceCancellation(); + } + return S_OK; + } + + // IAsyncInfo::Close + STDMETHOD(Close)(void) override + { + HRESULT hr = S_OK; + if (TransitionToState(Details::_Closed)) + { + OnClose(); + } + else // illegal state change + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + if (current == Details::_Closed) + { + hr = S_OK; // Closed => Closed transition is just ignored + } + else + { + hr = E_ILLEGAL_STATE_CHANGE; +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_STATE_CHANGE, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + } + } + return hr; + } + + // Delegate helpers + STDMETHOD(PutOnComplete)(TComplete* completeHandler) + { + HRESULT hr = CheckValidStateForDelegateCall(); + if (SUCCEEDED(hr)) + { + // this delegate property is "write once" + if (InterlockedIncrement(&cCompleteDelegateAssigned_) == 1) + { + if (completeHandler != nullptr) + { + completeDelegateBucketAssist_ = Microsoft::WRL::Details::GetDelegateBucketAssist(completeHandler); + } + + completeDelegate_ = completeHandler; + + // Guarantee that the write of completeDelegate_ is ordered with respect to the read of state below + // as perceived from FireCompletion on another thread. + MemoryBarrier(); + + this->TraceDelegateAssigned(); + + // in the "hot start" case, put_Completed could have been called after the async operation has hit + // a terminal state. If so, fire the completion immediately. + if (IsTerminalState()) + { + FireCompletion(); + } + } + else + { + hr = E_ILLEGAL_DELEGATE_ASSIGNMENT; +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_DELEGATE_ASSIGNMENT, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + } + } + return hr; + } + + STDMETHOD(GetOnComplete)(TComplete** completeHandler) + { + *completeHandler = nullptr; + HRESULT hr = CheckValidStateForDelegateCall(); + if (SUCCEEDED(hr)) + { + completeDelegate_.CopyTo(completeHandler); + } + return hr; + } + + virtual HRESULT FireCompletion() + { + HRESULT hr = S_OK; + // must do this *before* the InterlockedIncrement! + TryTransitionToCompleted(); + + __WRL_ASSERT__(IsTerminalState() && "Must only call FireCompletion when operation is in terminal state"); + + // we guarantee that completion can only ever be fired once + if (completeDelegate_ != nullptr && InterlockedIncrement(&cCallbackMade_) == 1) + { + ComPtr< ::ABI::Windows::Foundation::IAsyncInfo> asyncInfo = this; + ComPtr::DerefType> operationInterface; + + TraceOperationComplete(); + + if (SUCCEEDED(asyncInfo.As(&operationInterface))) + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + TraceCompletionNotificationStart(); + + hr = completeDelegate_->Invoke(operationInterface.Get(), static_cast<::ABI::Windows::Foundation::AsyncStatus>(current)); + // Filter the errors as per the Error Propagation Policy + hr = ErrorPropagationPolicyTraits< AllOptions::PropagationPolicy >::FireCompletionErrorPropagationPolicyFilter(hr, completeDelegate_.Get(), completeDelegateBucketAssist_); + completeDelegate_ = nullptr; + + + TraceCompletionNotificationComplete(); + } + } + + return hr; + } + +protected: + + inline void CurrentStatus(Details::AsyncStatusInternal *status) + { + ::_InterlockedCompareExchange(reinterpret_cast(status), currentStatus_, static_cast(*status)); + __WRL_ASSERT__(*status != Details::_Undefined); + } + + // This method returns the error code stored as a result of a transition into the error state. + // In addition, if there is any restricted error information associated with the error that was captured at the time + // of the error transition, it will be associated with the calling thread via a SetRestrictedErrorInfo call. + inline void ErrorCode(HRESULT *error) + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + // Do not allow visibility of the error until such point as we have had a successful state transition into the error state. + // The error + information is not a single atomic quantity. It is not considered "published" until we are actively in the error state. + if (current != Details::_Error) + { + *error = S_OK; + } + else + { + ::_InterlockedCompareExchange(reinterpret_cast(error), errorCode_, static_cast(*error)); + if (errorInfo_ != nullptr) + { + SetRestrictedErrorInfo(errorInfo_.Get()); + } + } + } + + bool TryTransitionToCompleted(CancelTransitionPolicy cancelBehavior = CancelTransitionPolicy::RemainCanceled) + { + bool bTransition = TransitionToState(Details::AsyncStatusInternal::_Completed); + if (!bTransition && cancelBehavior == CancelTransitionPolicy::TransitionFromCanceled) + { + bTransition = TransitionCanceledToCompleted(); + } + return bTransition; + } + + bool TryTransitionToError(const HRESULT error, CancelTransitionPolicy cancelBehavior = CancelTransitionPolicy::RemainCanceled, _In_opt_ void * bucketAssist = nullptr) + { + // In addition to the result being transitioned to, there might be restricted error information associated with the error. It + // is assumed that such is on the calling thread. If we successfully transition to the error state with "error" as the code, + // we must also capture the error info and funnel it over to callers of GetResults / ErrorCode. Our call to + // GetRestrictedErrorInfo below will capture the error info after which, it is owned by this async operation. + // + // Since there are multiple pieces of information and the capturing of these are not atomic, no one from the outside is allowed + // to view these until the state transition to error is complete. This happens in two parts: + // + // - A successful CAS from S_OK to error (meaning that this error is the one being captures) + // - A successful state change into the error state (via the Transition* call below) + + if (bucketAssist != nullptr) + { + asyncOperationBucketAssist_ = bucketAssist; + } + bool bTransition = false; + if (::_InterlockedCompareExchange(reinterpret_cast(&errorCode_), error, S_OK) == S_OK) + { + (void)GetRestrictedErrorInfo(&errorInfo_); + + // This thread is the "owner" of the rights to transition to the error state. + bTransition = TransitionToState(Details::AsyncStatusInternal::_Error); + if (!bTransition && cancelBehavior == CancelTransitionPolicy::TransitionFromCanceled) + { + bTransition = TransitionCanceledToError(); + } + } + + if (bTransition) + { + TraceError(); + } + + // if we return true, then we did a valid state transition + // queue firing of completed event (cannot be done from this call frame) + // otherwise we are already in a terminal state: error, canceled, completed, or closed + // and we ignore the transition request to the Error state + return bTransition; + } + + // This method checks to see if the delegate properties can be + // modified in the current state and generates the appropriate + // error hr in the case of violation. + inline HRESULT CheckValidStateForDelegateCall() + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + if (current == Details::_Closed) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_METHOD_CALL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_ILLEGAL_METHOD_CALL; + } + return S_OK; + } + + // This method checks to see if results can be collected in the + // current state and generates the appropriate error hr in + // the case of a violation. + inline HRESULT CheckValidStateForResultsCall() + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + if (current == Details::_Error) + { + HRESULT hr; + + // Make sure to propagate any restricted error info associated with the asynchronous failure. + ErrorCode(&hr); + return hr; + } +#pragma warning(push) +#pragma warning(disable: 4127) // Conditional expression is constant + // single result only legal in Completed state + if (resultType == SingleResult) +#pragma warning(pop) + { + if (current != Details::_Completed) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_METHOD_CALL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_ILLEGAL_METHOD_CALL; + } + } + // multiple results can be called after async operation is running (started) and before/after Completed + else if (current != Details::_Started && + current != Details::_Canceled && + current != Details::_Completed) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_METHOD_CALL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_ILLEGAL_METHOD_CALL; + } + return S_OK; + } + + // This method can be called by derived classes periodically to determine + // whether the asynchronous operation should continue processing or should + // be halted. + inline bool ContinueAsyncOperation() + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + return (current == Details::_Started); + } + + // These methods are used to allow the async worker implementation do work on + // state transitions. No real "work" should be done in these methods. In other words + // they should not block for a long time on UI timescales. + virtual HRESULT OnStart(void) = 0; + virtual void OnClose(void) = 0; + virtual void OnCancel(void) = 0; + +private: +#ifndef _WRL_DISABLE_CAUSALITY_ + // This method is used to initialize the Causality tracking + static BOOL WINAPI InitCausality( + _Inout_opt_ PINIT_ONCE InitOnce, + _Inout_opt_ PVOID Parameter, + _Out_opt_ PVOID* Context ) + { + UNREFERENCED_PARAMETER(InitOnce); + UNREFERENCED_PARAMETER(Parameter); + UNREFERENCED_PARAMETER(Context); + +#ifdef _NO_CAUSALITY_DOWNLEVEL_ + // do not attempt to trace causality on OS versions less than 6.2 (Windows 8) + OSVERSIONINFOEX osvi; + DWORDLONG dwlConditionMask = 0; + ZeroMemory(&osvi, sizeof(OSVERSIONINFOEX)); + osvi.dwOSVersionInfoSize = sizeof(OSVERSIONINFOEX); + osvi.dwMajorVersion = 6; + osvi.dwMinorVersion = 2; + + VER_SET_CONDITION(dwlConditionMask, VER_MAJORVERSION, VER_GREATER_EQUAL); + VER_SET_CONDITION(dwlConditionMask, VER_MINORVERSION, VER_GREATER_EQUAL); + + // if running on Windows 8 or greater + if (VerifyVersionInfo(&osvi, VER_MAJORVERSION | VER_MINORVERSION, dwlConditionMask)) + { +#endif + Microsoft::WRL::Wrappers::HStringReference hstrCausalityTraceName(RuntimeClass_Windows_Foundation_Diagnostics_AsyncCausalityTracer); + if (FAILED(::Windows::Foundation::GetActivationFactory(hstrCausalityTraceName.Get(), &gCausality))) + { + return FALSE; + } + return TRUE; +#ifdef _NO_CAUSALITY_DOWNLEVEL_ + } + else + { + gCausality = nullptr; + return FALSE; + } +#endif + } +#endif _WRL_DISABLE_CAUSALITY_ + + // This method is used to check if calls to the AsyncInfo properties + // (id, status, error code) are legal in the current state. It also + // generates the appropriate error hr to return in the case of an + // illegal call. + inline HRESULT CheckValidStateForAsyncInfoCall() + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + if (current == Details::_Closed) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(E_ILLEGAL_METHOD_CALL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_ILLEGAL_METHOD_CALL; + } + else if (current == Details::_Created) // error in async ::ABI object - returned to caller not started + { + // No RoOriginateError needed since this can hit multiple times in expected scenarios + + return E_ASYNC_OPERATION_NOT_STARTED; + } + + return S_OK; + } + + inline bool TransitionToState(const Details::AsyncStatusInternal newState) + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + + // This enforces the valid state transitions of the asynchronous worker object + // state machine. + switch(newState) + { + case Details::_Started: + if (current != Details::_Created) + { + return false; + } + break; + case Details::_Completed: + if (current != Details::_Started) + { + return false; + } + break; + case Details::_Canceled: + if (current != Details::_Started) + { + return false; + } + break; + case Details::_Error: + if (current != Details::_Started) + { + return false; + } + break; + case Details::_Closed: + if (!IsTerminalState(current)) + { + return false; + } + break; + default: + return false; + break; + } + // attempt the transition to the new state + // Note: if currentStatus_ == current, then there was no intervening write + // by the async work object and the swap succeeded. + Details::AsyncStatusInternal retState = static_cast( + ::_InterlockedCompareExchange(reinterpret_cast(¤tStatus_), + newState, + static_cast(current))); + + // ICE returns the former state, if the returned state and the + // state we captured at the beginning of this method are the same, + // the swap succeeded. + return (retState == current); + } + +protected: + + // It is legal for an async operation object to transition from (client-requested) Canceled + // state to Completed if, for example, the operation completed near the time of the cancellation request. + // An operation which is no longer responsive to client requests to cancel and intends to complete + // successfully despite any new incoming requests to cancel should call TryTransitionToCompleted and + // pass TransitionFromCanceled instead of using this method. + inline bool TransitionCanceledToCompleted() + { + // this is somewhat overly pessimistic since the client cannot possibly transition + // the operation out of the canceled state (only the async operation itself can call + // this method) + Details::AsyncStatusInternal retState = static_cast( + ::_InterlockedCompareExchange(reinterpret_cast(¤tStatus_), + Details::AsyncStatusInternal::_Completed, + Details::AsyncStatusInternal::_Canceled)); + return (retState == Details::AsyncStatusInternal::_Canceled); + } + + // It is legal for an async operation object to transition from (client-requested) Canceled + // state to the error state if, for example, the operation encountered an error near the time + // of the cancellation request. An operation which is no longer responsive to client requests to cancel + // and intends to complete with an error despite any new incoming requests to cancel should call + // TryTransitionToError and pass TransitionFromCanceled instead of using this method. + inline bool TransitionCanceledToError() + { + Details::AsyncStatusInternal retState = static_cast( + ::_InterlockedCompareExchange(reinterpret_cast(¤tStatus_), + Details::AsyncStatusInternal::_Error, + Details::AsyncStatusInternal::_Canceled)); + return (retState == Details::AsyncStatusInternal::_Canceled); + } + + inline bool IsTerminalState() + { + Details::AsyncStatusInternal current = Details::_Undefined; + CurrentStatus(¤t); + return IsTerminalState(current); + } + + inline bool IsTerminalState(Details::AsyncStatusInternal status) + { + return (status == Details::_Error || + status == Details::_Canceled || + status == Details::_Completed || + status == Details::_Closed); + } + + long cCallbackMade_; + long cCompleteDelegateAssigned_; + +#pragma region TracingMethods + + void TraceOperationStart() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceOperationCreation( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Required, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + Microsoft::WRL::Wrappers::HStringReference(AllOptions::GetAsyncOperationName()).Get(), + id_); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceOperationComplete() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + Details::AsyncStatusInternal status; + CurrentStatus(&status); + gCausality->TraceOperationCompletion( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Required, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + static_cast< ::ABI::Windows::Foundation::AsyncStatus >(status)); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceProgressNotificationStart() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkStart( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Important, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_ProgressNotification); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceProgressNotificationComplete() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkCompletion( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Important, + AllOptions::GetCausalitySource(), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_ProgressNotification); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceCompletionNotificationStart() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkStart( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Required, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_CompletionNotification); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceCompletionNotificationComplete() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkCompletion( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Required, + AllOptions::GetCausalitySource(), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_CompletionNotification); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceExecutionStart(::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel traceLevel) + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkStart( + traceLevel, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_Execution); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceExecutionComplete(::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel traceLevel) + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceSynchronousWorkCompletion( + traceLevel, + AllOptions::GetCausalitySource(), + ::ABI::Windows::Foundation::Diagnostics::CausalitySynchronousWork_Execution); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceDelegateAssigned() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceOperationRelation( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Verbose, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalityRelation_AssignDelegate); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceError() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceOperationRelation( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Verbose, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalityRelation_Error); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + + void TraceCancellation() + { +#ifndef _WRL_DISABLE_CAUSALITY_ + if (gCausality && AllOptions::CausalityEnabled) + { + // Ignoring HRESULT intentionally. Tracking failure should not change the + // normal flow of AsyncOperations + gCausality->TraceOperationRelation( + ::ABI::Windows::Foundation::Diagnostics::CausalityTraceLevel_Important, + AllOptions::GetCausalitySource(), + AllOptions::GetPlatformId(), + reinterpret_cast< UINT64 >(this), + ::ABI::Windows::Foundation::Diagnostics::CausalityRelation_Cancel); + } +#endif // _WRL_DISABLE_CAUSALITY_ + } + +#pragma endregion + +private: + ::Microsoft::WRL::ComPtr completeDelegate_; + void *completeDelegateBucketAssist_; + + ::Microsoft::WRL::ComPtr errorInfo_; + Details::AsyncStatusInternal volatile currentStatus_; + HRESULT volatile errorCode_; + unsigned int id_; + +protected: + void *asyncOperationBucketAssist_; +}; + +}} // namespace Microsoft::WRL + +#pragma warning(pop) + +#ifndef _WRL_DISABLE_CAUSALITY +#define CAUSALITY_OPTIONS(OpName) \ + Microsoft::WRL::AsyncCausalityOptions< OpName > +#define ASYNCBASE_CAUSALITY_OPTIONS(TComplete, OpName) \ + Microsoft::WRL::AsyncBase< TComplete, Microsoft::WRL::Details::Nil, Microsoft::WRL::AsyncResultType::SingleResult, CAUSALITY_OPTIONS(OpName)> +#define ASYNCBASE_WITH_PROGRESS_CAUSALITY_OPTIONS(TComplete, TProgress, OpName) \ + Microsoft::WRL::AsyncBase< TComplete, TProgress, Microsoft::WRL::AsyncResultType::SingleResult, CAUSALITY_OPTIONS(OpName)> +#define ASYNCBASE_DISABLE_CAUSALITY(TComplete) \ + Microsoft::WRL::AsyncBase< TComplete, Microsoft::WRL::Details::Nil, Microsoft::WRL::AsyncResultType::SingleResult, Microsoft::WRL::DisableCausality> +#define ASYNCBASE_WITH_PROGRESS_DISABLE_CAUSALITY(TComplete, TProgress) \ + Microsoft::WRL::AsyncBase< TComplete, TProgress, Microsoft::WRL::AsyncResultType::SingleResult, Microsoft::WRL::DisableCausality> +#endif // _WRL_DISABLE_CAUSALITY + +// Restore packing +#include + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif // _WRL_ASYNC_H_ + +#ifdef BUILD_WINDOWS +#include +#endif diff --git a/tests/workarounds/wrl/wrl/implements.h b/tests/workarounds/wrl/wrl/implements.h new file mode 100644 index 0000000..c4c2489 --- /dev/null +++ b/tests/workarounds/wrl/wrl/implements.h @@ -0,0 +1,2655 @@ +// +// Copyright (C) Microsoft Corporation +// All rights reserved. +// +// Code in Details namespace is for internal usage within the library code +// + +#ifndef _WRL_IMPLEMENTS_H_ +#define _WRL_IMPLEMENTS_H_ + +#ifdef _MSC_VER +#pragma once +#endif // _MSC_VER + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpragma-pack" +#pragma clang diagnostic ignored "-Wunused-value" +#pragma clang diagnostic ignored "-Wmicrosoft-sealed" +#pragma clang diagnostic ignored "-Winaccessible-base" +#endif + +#pragma region includes + +#include +#include +#ifdef BUILD_WINDOWS +#include +#endif +#include +#include + +#include +#include +#include // IMarshal +#include // CLSID_StdGlobalInterfaceTable +#include + +#include +#include + +#if (NTDDI_VERSION >= NTDDI_WINBLUE) +#include "roerrorapi.h" +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + +// Set packing +#include + +#pragma endregion + +#ifndef __WRL_NO_DEFAULT_LIB__ +#pragma comment(lib, "ole32.lib") // For CoTaskMemAlloc +#endif + +#pragma region disable warnings + +#pragma warning(push) +#pragma warning(disable: 4584) // 'class1' : base-class 'class2' is already a base-class of 'class3' +#pragma warning(disable: 4481) // nonstandard extension used: override specifier 'override' + +#pragma endregion // disable warnings + +namespace Microsoft { +namespace WRL { + +// Indicator for RuntimeClass,Implements and ChainInterfaces that T interface +// will be not accessible on IID list +// Example: +// struct MyRuntimeClass : RuntimeClass> {} +template +struct CloakedIid : T +{ +}; + +enum RuntimeClassType +{ + WinRt = 0x0001, + ClassicCom = 0x0002, + WinRtClassicComMix = WinRt | ClassicCom, + InhibitWeakReference = 0x0004, + Delegate = ClassicCom, + InhibitFtmBase = 0x0008, + InhibitRoOriginateError = 0x0010 +}; + +template +struct RuntimeClassFlags +{ + static const unsigned int value = flags; +}; + +namespace Details +{ +// Empty struct used for validating template parameter types in Implements +struct ImplementsBase +{ +}; + +} // namespace Details + +// MixIn modifier allows to combine QI from +// a class that doesn't have default constructor on it +template +struct MixIn +{ +}; + +// ComposableBase template to allow deriving from a RuntimeClass +// Optionally allows specifying the base factory and statics interface +template +class ComposableBase +{ +}; +// Back-compat indicator for RuntimeClass to not support IWeakReferenceSource +typedef RuntimeClassFlags InhibitWeakReferencePolicy; + +template +struct ErrorHelper +{ + static void OriginateError(HRESULT hr, HSTRING message) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ::RoOriginateError(hr, message); +#else + UNREFERENCED_PARAMETER(hr); + UNREFERENCED_PARAMETER(message); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + } +}; + +template<> +struct ErrorHelper +{ + static void OriginateError(HRESULT hr, HSTRING message) + { + UNREFERENCED_PARAMETER(hr); + UNREFERENCED_PARAMETER(message); + // No-Op + } +}; + +namespace Details +{ + +//Forward declaration +struct CreatorMap; + +// Sections automatically generate a list of pointers to CreatorMap through the linker +// Sections a and z are used as a terminators +#pragma section("minATL$__a", read) +// Section f is used to put com objects to creator map +#pragma section("minATL$__f", read) +// Section m divides COM entries from WinRT entries +#pragma section("minATL$__m", read) +// Section r is used to put WinRT objects to creator map +#pragma section("minATL$__r", read) +#pragma section("minATL$__z", read) + +extern "C" +{ +// Location of the first and last entries for the linker generated list of pointers to CreatorMapEntry +__declspec(selectany) __declspec(allocate("minATL$__a")) const CreatorMap* __pobjectentryfirst = nullptr; +// Section m divides COM objects from WinRT objects +// - sections between a and m we store COM object info +// - sections between m+1 and z we store WinRT object info +__declspec(selectany) __declspec(allocate("minATL$__m")) const CreatorMap* __pobjectentrymid = nullptr; +__declspec(selectany) __declspec(allocate("minATL$__z")) const CreatorMap* __pobjectentrylast = nullptr; +} + +// Base class used by all module classes. +class __declspec(novtable) ModuleBase +{ +private: + // Lock that synchronize access and termination of factories + static void* moduleLock_; + + static_assert(sizeof(moduleLock_) == sizeof(SRWLOCK), "cacheLock must have the same size as SRWLOCK"); +protected: + static volatile unsigned long objectCount_; +public: + static ModuleBase *module_; + + ModuleBase() throw() + { +#ifdef _DEBUG + // WRLs support for activatable classes requires there is only one instance of Module<>, this assert + // ensures there is only one. Since Module<> is templatized, using different template parameters will + // result in multiple instances, avoid this by making sure all code in a component uses the same parameters. + // Note that the C++ CX runtime creates an instance; Module, + // so mixing it with non CX code can result in this assert. + // WRL supports static and dynamically allocated Module<>, choose dynamic by defining __WRL_DISABLE_STATIC_INITIALIZE__ + // and allocate that instance with new but only once, for example in the main() entry point of an application. + __WRL_ASSERT__(::InterlockedCompareExchangePointer(reinterpret_cast(&module_), this, nullptr) == nullptr && + "The module was already instantiated"); + + SRWLOCK initSRWLOCK = SRWLOCK_INIT; + __WRL_ASSERT__(reinterpret_cast(&moduleLock_)->Ptr == initSRWLOCK.Ptr && "Different value for moduleLock_ than SRWLOCK_INIT"); + (initSRWLOCK); +#else + module_ = this; +#endif + } + + ModuleBase(const ModuleBase&) = delete; + ModuleBase& operator=(const ModuleBase&) = delete; + + virtual ~ModuleBase() throw() + { +#ifdef _DEBUG + __WRL_ASSERT__(::InterlockedCompareExchangePointer(reinterpret_cast(&module_), nullptr, this) == this && + "The module was already instantiated"); +#else + module_ = nullptr; +#endif + } + + // Number of active objects in the module + STDMETHOD_(unsigned long, IncrementObjectCount)() = 0; + STDMETHOD_(unsigned long, DecrementObjectCount)() = 0; + + STDMETHOD_(unsigned long, GetObjectCount)() const + { + return objectCount_; + } + + STDMETHOD_(const CreatorMap**, GetFirstEntryPointer)() const + { + return &__pobjectentryfirst; + } + + STDMETHOD_(const CreatorMap**, GetMidEntryPointer)() const + { + return &__pobjectentrymid; + } + + STDMETHOD_(const CreatorMap**, GetLastEntryPointer)() const + { + return &__pobjectentrylast; + } + + STDMETHOD_(SRWLOCK*, GetLock)() const + { + return reinterpret_cast(&moduleLock_); + } + + STDMETHOD(RegisterWinRTObject)(_In_opt_z_ const wchar_t*, _In_z_ const wchar_t** activatableClassIds, _Inout_ RO_REGISTRATION_COOKIE* cookie, unsigned int) = 0; + STDMETHOD(UnregisterWinRTObject)(_In_opt_z_ const wchar_t*, _In_ RO_REGISTRATION_COOKIE) = 0; + STDMETHOD(RegisterCOMObject)(_In_opt_z_ const wchar_t*, _In_ IID*, _In_ IClassFactory**, _Inout_ DWORD*, unsigned int) = 0; + STDMETHOD(UnregisterCOMObject)(_In_opt_z_ const wchar_t*, _Inout_ DWORD*, unsigned int) = 0; +}; + +__declspec(selectany) volatile unsigned long ModuleBase::objectCount_ = 0; +// moduleLock_ value must be equal SRWLOCK_INIT which is nullptr +__declspec(selectany) void* ModuleBase::moduleLock_ = nullptr; +__declspec(selectany) ModuleBase *ModuleBase::module_ = nullptr; + +#pragma region helper types +// Empty struct used as default template parameter +class Nil +{ +}; + +// Used on RuntimeClass to protect it from being constructed with new +class DontUseNewUseMake +{ +private: + void* operator new(size_t) throw() + { + __WRL_ASSERT__(false); + return 0; + } + +public: + void* operator new(size_t, _In_ void* placement) throw() + { + return placement; + } +}; + +// RuntimeClassBase is used for detection of RuntimeClass in Make method +class RuntimeClassBase +{ +}; + +// RuntimeClassBaseT provides helper methods for QI and getting IIDs +template +class RuntimeClassBaseT : private RuntimeClassBase +{ +protected: + template + static HRESULT AsIID(_In_ T* implements, REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) throw() + { + *ppvObject = nullptr; +#pragma warning(push) +// Conditional expression is constant +#pragma warning(disable: 4127) +// Potential comparison of a constant with another constant +#pragma warning(disable: 6326) +// Conditional check using template parameter is constant and can be used to optimize the code + bool isRefDelegated = false; + // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. + if (InlineIsEqualGUID(riid, __uuidof(IUnknown)) || ((RuntimeClassTypeT & WinRt) != 0 && InlineIsEqualGUID(riid, __uuidof(IInspectable)))) +#pragma warning(pop) + { + *ppvObject = implements->CastToUnknown(); + static_cast(*ppvObject)->AddRef(); + return S_OK; + } + + HRESULT hr = implements->CanCastTo(riid, ppvObject, &isRefDelegated); + if (SUCCEEDED(hr) && !isRefDelegated) + { + static_cast(*ppvObject)->AddRef(); + } + +#pragma warning(suppress: 6102) // '*ppvObject' is used but may not be initialized + _Analysis_assume_(SUCCEEDED(hr) || (*ppvObject == nullptr)); + + return hr; + } + template + static HRESULT GetImplementedIIDS( + _In_ T* implements, + _Out_ ULONG *iidCount, + _When_(*iidCount == 0, _At_(*iids, _Post_null_)) + _When_(*iidCount > 0, _At_(*iids, _Post_notnull_)) + _Result_nullonfailure_ IID **iids) throw() + { + *iids = nullptr; + *iidCount = 0; + unsigned long count = implements->GetIidCount(); + + // If there is no iids the CoTaskMemAlloc don't have to be called + if (count == 0) + { + return S_OK; + } + + IID* iidArray = reinterpret_cast(::CoTaskMemAlloc(sizeof(IID) * count)); + if (iidArray == nullptr) + { + return E_OUTOFMEMORY; + } + + unsigned long index = 0; + + // assign the IIDs to the array + implements->FillArrayWithIid(&index, iidArray); + __WRL_ASSERT__(index == count); + + // and return it + *iidCount = count; + *iids = iidArray; + return S_OK; + } + +public: + HRESULT RuntimeClassInitialize() throw() + { + return S_OK; + } +}; + +// Base class required to mark FtmBase +class FtmBaseMarker +{ +}; + +// Verifies that I is derived from specified base +template +struct VerifyInterfaceHelper; + +// Specialization for ClassicCom interface +template +struct VerifyInterfaceHelper +{ + static void Verify() throw() + { +#ifdef __WRL_STRICT__ + // Make sure that your interfaces inherit from IUnknown and are not IUnknown and/or IInspectable based + // The IUnknown is allowed only on RuntimeClass as first template parameter + static_assert(__is_base_of(IUnknown, I) && !__is_base_of(IInspectable, I) && !(doStrictCheck && IsSame::value), + "'I' has to derive from 'IUnknown' and not from 'IInspectable'. 'I' must not be IUnknown."); +#else + static_assert(__is_base_of(IUnknown, I), "'I' has to derive from 'IUnknown'."); +#endif + } +}; + +// Specialization for WinRtClassicComMix interface +template +struct VerifyInterfaceHelper +{ + static void Verify() throw() + { +#ifdef __WRL_STRICT__ + // Make sure that your interfaces inherit from IUnknown and are not IUnknown and/or IInspectable + // except when IInspectable is the first template parameter + static_assert(__is_base_of(IUnknown, I) && + (doStrictCheck ? !(IsSame::value || IsSame::value) : __is_base_of(IInspectable, I)), + "'I' has to derive from 'IUnknown' and must not be IUnknown and/or IInspectable."); +#else + static_assert(__is_base_of(IUnknown, I), "'I' has to derive from 'IUnknown'."); +#endif + } +}; + +// Specialization for WinRt interface +template +struct VerifyInterfaceHelper +{ + static void Verify() throw() + { +#ifdef __WRL_STRICT__ + // IWeakReferenceSource is exception for WinRt and can be used however it cannot be first templated interface + // Make sure that your interfaces inherit from IInspectable and are not IInspectable + // The IInspectable is allowed only on RuntimeClass as first template parameter + static_assert((__is_base_of(IWeakReferenceSource, I) && doStrictCheck) || + (__is_base_of(IInspectable, I) && !(doStrictCheck && IsSame::value)), + "'I' has to derive from 'IWeakReferenceSource' or 'IInspectable' and must not be IInspectable"); +#else + // IWeakReference and IWeakReferneceSource are exceptions for WinRT + static_assert(__is_base_of(IWeakReference, I) || + __is_base_of(IWeakReferenceSource, I) || + __is_base_of(IInspectable, I), "'I' has to derive from 'IWeakReference', 'IWeakReferenceSource' or 'IInspectable'"); +#endif + } +}; + +// Specialization for Implements passed as template parameter +template +struct VerifyInterfaceHelper +{ + static void Verify() throw() + { +#ifdef __WRL_STRICT__ + // Verifies if Implements has correct RuntimeClassFlags setting + // Allow using FtmBase on classes configured with RuntimeClassFlags (Default configuration) + static_assert(I::ClassFlags::value == type || + type == WinRtClassicComMix || + __is_base_of(::Microsoft::WRL::Details::FtmBaseMarker, I), + "Implements class must have the same and/or compatibile flags configuration"); +#endif + } +}; + +// Specialization for Implements passed as first template parameter +template +struct VerifyInterfaceHelper +{ + static void Verify() throw() + { +#ifdef __WRL_STRICT__ + // Verifies if Implements has correct RuntimeClassFlags setting + static_assert(I::ClassFlags::value == type || type == WinRtClassicComMix, + "Implements class must have the same and/or compatible flags configuration." + "If you use WRL::FtmBase it cannot be specified as first template parameter on RuntimeClass"); + + // Besides make sure that the first interface on Implements meet flags requirement + VerifyInterfaceHelper::Verify(); +#endif + } +}; + +// Interface traits provides casting and filling iids methods helpers +template +struct __declspec(novtable) InterfaceTraits +{ + typedef I0 Base; + static const unsigned long IidCount = 1; + + template + static void Verify() throw() + { + VerifyInterfaceHelper::Verify(); + } + + template + static Base* CastToBase(_In_ T* ptr) throw() + { + return static_cast(ptr); + } + + template + static IUnknown* CastToUnknown(_In_ T* ptr) throw() + { + return static_cast(static_cast(ptr)); + } + + template + _Success_(return == true) + static bool CanCastTo(_In_ T* ptr, REFIID riid, _Outptr_ void **ppv) throw() + { + // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. + if (InlineIsEqualGUID(riid, __uuidof(Base))) + { + *ppv = static_cast(ptr); + return true; + } + + return false; + } + + static void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + *(iids + *index) = __uuidof(Base); + (*index)++; + } +}; + +// Specialization of traits for cloaked interface +template +struct __declspec(novtable) InterfaceTraits> +{ + typedef CloakedType Base; + static const unsigned long IidCount = 0; + + template + static void Verify() throw() + { + VerifyInterfaceHelper::Verify(); + } + + template + static Base* CastToBase(_In_ T* ptr) throw() + { + return static_cast(ptr); + } + + template + static IUnknown* CastToUnknown(_In_ T* ptr) throw() + { + return static_cast(static_cast(ptr)); + } + + template + _Success_(return == true) + static bool CanCastTo(_In_ T* ptr, REFIID riid, _Outptr_ void **ppv) throw() + { + // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. + if (InlineIsEqualGUID(riid, __uuidof(Base))) + { + *ppv = static_cast(ptr); + return true; + } + + return false; + } + + // Cloaked specialization makes it always IID list empty + static void FillArrayWithIid(_Inout_ unsigned long*, _Inout_ IID*) throw() + { + } +}; + +// Specialization for Nil parameter +template<> +struct __declspec(novtable) InterfaceTraits +{ + typedef Nil Base; + static const unsigned long IidCount = 0; + + template + static void Verify() throw() + { + } + + static void FillArrayWithIid(_Inout_ unsigned long *, _Inout_ IID*) throw() + { + } + + template + _Success_(return == true) + static bool CanCastTo(_In_ T*, REFIID, _Outptr_ void **) throw() + { + return false; + } +}; + +// Verify inheritance +template +struct VerifyInheritanceHelper +{ + static void Verify() throw() + { + static_assert(Details::IsBaseOfStrict::Base, typename InterfaceTraits::Base>::value, "'I' needs to inherit from 'Base'."); + } +}; + +template +struct VerifyInheritanceHelper +{ + static void Verify() throw() + { + } +}; + +#pragma endregion // helper types + +} // namespace Details + +inline Details::ModuleBase* GetModuleBase() throw() +{ + return Details::ModuleBase::module_; +} + +// ChainInterfaces - template allows specifying a derived COM interface along with its class hierarchy to allow QI for the base interfaces +template +struct ChainInterfaces : I0 +{ +protected: + template + static void Verify() throw() + { + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + } + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() + { + typename Details::InterfaceTraits::Base* ptr = Details::InterfaceTraits::CastToBase(this); + + return (Details::InterfaceTraits::CanCastTo(this, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv)) ? S_OK : E_NOINTERFACE; + } + + IUnknown* CastToUnknown() throw() + { + return Details::InterfaceTraits::CastToUnknown(this); + } + + static const unsigned long IidCount = + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount; + + static void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + } +}; + +template +struct ChainInterfaces, I1, I2, I3, I4, I5, I6, I7, I8, I9> +{ + static_assert(!hasImplements, "Cannot use ChainInterfaces> to Mix a class implementing interfaces using \"Implements\""); + +protected: + template + static void Verify() throw() + { + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + Details::InterfaceTraits::template Verify(); + + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + Details::VerifyInheritanceHelper::Verify(); + } + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() + { + BaseType* ptr = static_cast(static_cast(this)); + + return ( + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv) || + Details::InterfaceTraits::CanCastTo(ptr, riid, ppv)) ? S_OK : E_NOINTERFACE; + } + + // It's not possible to cast to IUnknown when Base interface inherit more interfaces + // The RuntimeClass is taking always the first interface as IUnknown thus it's required to + // list IInspectable or IUnknown class before MixIn parameter, such as: + // struct MyRuntimeClass : RuntimeClass, IFoo, IBar>, MyIndependentImplementation {} + IUnknown* CastToUnknown() throw() = delete; + + static const unsigned long IidCount = + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount + + Details::InterfaceTraits::IidCount; + + static void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + Details::InterfaceTraits::FillArrayWithIid(index, iids); + } +}; + +namespace Details +{ + +#pragma region Implements helper templates + +// Helper template used by Implements. This template traverses a list of interfaces and adds them as base class and information +// to enable QI. doStrictCheck is typically false only for the first interface, allowing IInspectable to be explicitly specified +// only as the first interface. +template +struct __declspec(novtable) ImplementsHelper; + +template +struct __declspec(novtable) ImplementsMarker +{}; + +template +struct __declspec(novtable) MarkImplements; + +template +struct __declspec(novtable) MarkImplements +{ + typedef I0 Type; +}; + +template +struct __declspec(novtable) MarkImplements +{ + typedef ImplementsMarker Type; +}; + +template +struct __declspec(novtable) MarkImplements, true> +{ + // Cloaked Implements type will be handled in the nested processing. + // Applying the ImplementsMarker too early will bypass Cloaked behavior. + typedef CloakedIid Type; +}; + +template +struct __declspec(novtable) MarkImplements, true> +{ + // Implements type in mix-ins will be handled in the nested processing. + typedef MixIn Type; +}; + +// AdjustImplements pre-processes the type list for more efficient builds. +template +struct __declspec(novtable) AdjustImplements; + +template +struct __declspec(novtable) AdjustImplements +{ + typedef ImplementsHelper::Type, Bases...> Type; +}; + +// Use AdjustImplements to remove instances of "Details::Nil" from the type list. +template +struct __declspec(novtable) AdjustImplements +{ + typedef typename AdjustImplements::Type Type; +}; + + +template +struct __declspec(novtable) AdjustImplements +{ + typedef ImplementsHelper Type; +}; + + +// Specialization handles unadorned interfaces +template +struct __declspec(novtable) ImplementsHelper : + I0, + AdjustImplements::Type +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + VerifyInterfaceHelper::Verify(); + // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. + if (InlineIsEqualGUID(riid, __uuidof(I0))) + { + *ppv = reinterpret_cast(reinterpret_cast(this)); + return S_OK; + } + return AdjustImplements::Type::CanCastTo(riid, ppv, pRefDelegated); + } + + IUnknown* CastToUnknown() throw() + { + return reinterpret_cast(reinterpret_cast(this)); + } + + unsigned long GetIidCount() throw() + { + return 1 + AdjustImplements::Type::GetIidCount(); + } + + // FillArrayWithIid + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + *(iids + *index) = __uuidof(I0); + (*index)++; + AdjustImplements::Type::FillArrayWithIid(index, iids); + } +}; + + +// Selector is used to "tag" base interfaces to be used in casting, since a runtime class may indirectly derive from +// the same interface or Implements<> template multiple times +template +struct __declspec(novtable) Selector : public base +{ +}; + +// Specialization handles types that derive from ImplementsHelper (e.g. nested Implements). +template +struct __declspec(novtable) ImplementsHelper, TInterfaces...> : + Selector, TInterfaces...>>, + Selector::Type, ImplementsHelper, TInterfaces...>> +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + typedef Selector, TInterfaces...>> CurrentType; + typedef Selector::Type, ImplementsHelper, TInterfaces...>> BaseType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + VerifyInterfaceHelper::Verify(); + HRESULT hr = CurrentType::CanCastTo(riid, ppv); + if (hr == E_NOINTERFACE) + { + hr = BaseType::CanCastTo(riid, ppv, pRefDelegated); + } + return hr; + } + + IUnknown* CastToUnknown() throw() + { + // First in list wins. + return CurrentType::CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + return CurrentType::GetIidCount() + BaseType::GetIidCount(); + } + + // FillArrayWithIid + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + CurrentType::FillArrayWithIid(index, iids); + BaseType::FillArrayWithIid(index, iids); + } +}; + +// CloakedIid instance. Since the first "real" interface should be checked against doStrictCheck, +// pass this through unchanged. Two specializations for cloaked prevent the need to use the Selector +// used in the Implements<> case. The same can't be done there because some type ambiguities are unavoidable. +template +struct __declspec(novtable) ImplementsHelper, I1, TInterfaces...> : + AdjustImplements::Type, + AdjustImplements::Type +{ + template friend struct ImplementsHelper; + template friend class Details::RuntimeClassBaseT; + +protected: + + typedef typename AdjustImplements::Type CurrentType; + typedef typename AdjustImplements::Type BaseType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + VerifyInterfaceHelper::Verify(); + + HRESULT hr = CurrentType::CanCastTo(riid, ppv, pRefDelegated); + if (SUCCEEDED(hr)) + { + return S_OK; + } + return BaseType::CanCastTo(riid, ppv, pRefDelegated); + } + + IUnknown* CastToUnknown() throw() + { + return CurrentType::CastToUnknown(); + } + + // Don't expose the cloaked IID(s), but continue processing the rest of the interfaces + unsigned long GetIidCount() throw() + { + return BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + BaseType::FillArrayWithIid(index, iids); + } +}; + +template +struct __declspec(novtable) ImplementsHelper> : + AdjustImplements::Type +{ + template friend struct ImplementsHelper; + template friend class Details::RuntimeClassBaseT; + +protected: + + typedef typename AdjustImplements::Type CurrentType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + VerifyInterfaceHelper::Verify(); + + return CurrentType::CanCastTo(riid, ppv, pRefDelegated); + } + + IUnknown* CastToUnknown() throw() + { + return CurrentType::CastToUnknown(); + } + + // Don't expose the cloaked IID(s), but continue processing the rest of the interfaces + unsigned long GetIidCount() throw() + { + return 0; + } + + void FillArrayWithIid(_Inout_ unsigned long * /*index*/, _Inout_ IID* /*iids*/) throw() + { + // no-op + } +}; + + +// terminal case specialization. +template +struct __declspec(novtable) ImplementsHelper +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + template friend class Details::RuntimeClassBaseT; + + HRESULT CanCastTo(_In_ REFIID /*riid*/, _Outptr_ void ** /*ppv*/, bool * /*pRefDelegated*/ = nullptr) throw() + { + return E_NOINTERFACE; + } + + // IUnknown* CastToUnknown() throw(); // not defined for terminal case. + + unsigned long GetIidCount() throw() + { + return 0; + } + + void FillArrayWithIid(_Inout_ unsigned long * /*index*/, _Inout_ IID* /*iids*/) throw() + { + } +}; + +// Specialization handles chaining interfaces +template +struct __declspec(novtable) ImplementsHelper, TInterfaces...> : + ChainInterfaces, + AdjustImplements::Type +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + template friend class Details::RuntimeClassBaseT; + typedef typename AdjustImplements::Type BaseType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + ChainInterfaces::template Verify(); + + HRESULT hr = ChainInterfaces::CanCastTo(riid, ppv); + if (FAILED(hr)) + { + hr = BaseType::CanCastTo(riid, ppv, pRefDelegated); + } + + return hr; + } + + IUnknown* CastToUnknown() throw() + { + return ChainInterfaces::CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + return ChainInterfaces::IidCount + BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + ChainInterfaces::FillArrayWithIid(index, iids); + BaseType::FillArrayWithIid(index, iids); + } +}; + + +// Mixin specialization +template +struct __declspec(novtable) ImplementsHelper, TInterfaces...> : + AdjustImplements::Type +{ + static_assert(hasImplements, "Cannot use MixIn to with a class not deriving from \"Implements\""); + + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + template friend class Details::RuntimeClassBaseT; + typedef typename AdjustImplements::Type BaseType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + VerifyInterfaceHelper::Verify(); + + HRESULT hr = static_cast(static_cast(this))->CanCastTo(riid, ppv); + if (FAILED(hr)) + { + hr = BaseType::CanCastTo(riid, ppv, pRefDelegated); + } + + return hr; + } + + IUnknown* CastToUnknown() throw() + { + return static_cast(static_cast(this))->CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + return static_cast(static_cast(this))->GetIidCount() + + BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + static_cast(static_cast(this))->FillArrayWithIid(index, iids); + BaseType::FillArrayWithIid(index, iids); + } +}; + +// Specialization handles inheriting COM objects. ComposableBase must be the last non-nil interface in the list. +// Trailing nil's are allowed for compatibility with some tools that pad out the list. +template +struct AreAllNil +{ + static const bool value = false; +}; + +template +struct AreAllNil +{ + static const bool value = AreAllNil::value; +}; + +template <> +struct AreAllNil +{ + static const bool value = true; +}; + +template +struct __declspec(novtable) ImplementsHelper, TInterfaces...> : + ImplementsHelper> +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + template friend class Details::RuntimeClassBaseT; + + typedef ImplementsHelper> BaseType; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated = nullptr) throw() + { + static_assert(AreAllNil::value, "ComposableBase should be the last template parameter to RuntimeClass"); + return BaseType::CanCastTo(riid, ppv, pRefDelegated); + } + + IUnknown* CastToUnknown() throw() + { + static_assert(AreAllNil::value, "ComposableBase should be the last template parameter to RuntimeClass"); + return BaseType::CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + static_assert(AreAllNil::value, "ComposableBase should be the last template parameter to RuntimeClass"); + return BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + static_assert(AreAllNil::value, "ComposableBase should be the last template parameter to RuntimeClass"); + BaseType::FillArrayWithIid(index, iids); + } +}; + +template +struct __declspec(novtable) ImplementsHelper> +{ + template friend struct ImplementsHelper; + template friend class RuntimeClassBaseT; + +protected: + template friend class Details::RuntimeClassBaseT; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv, bool *pRefDelegated) throw() + { + *pRefDelegated = true; + return composableBase_.CopyTo(riid, ppv); + } + + IUnknown* CastToUnknown() throw() + { + return nullptr; + } + + unsigned long GetIidCount() throw() + { + return iidCount_; + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + for(unsigned long i = 0; i < iidCount_; i++) + { + *(iids + *index) = *(iidsCached_ + i); + (*index)++; + } + } + + ImplementsHelper() throw() : iidsCached_(nullptr), iidCount_(0) + { + } + + ~ImplementsHelper() throw() + { + ::CoTaskMemFree(iidsCached_); + iidsCached_ = nullptr; + iidCount_ = 0; + } + +public: + HRESULT SetComposableBasePointers(_In_ IInspectable* base, _In_opt_ FactoryInterface* baseFactory = nullptr) throw() + { + if (composableBase_ != nullptr) + { +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorHelper::OriginateError(E_UNEXPECTED, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_UNEXPECTED; + } + + HRESULT hr = base->GetIids(&iidCount_, &iidsCached_); + if (SUCCEEDED(hr)) + { + composableBase_ = base; + composableBaseFactory_ = baseFactory; + } + return hr; + } + + ComPtr GetComposableBase() throw() + { + return composableBase_; + } + + ComPtr GetComposableBaseFactory() throw() + { + return composableBaseFactory_; + } + +private: + ComPtr composableBase_; + ComPtr composableBaseFactory_; + IID *iidsCached_; + unsigned long iidCount_; +}; + +#pragma endregion // Implements helper templates + +} // namespace Details + +// Implements - template implementing QI using the information provided through its template parameters +// Each template parameter has to be one of the following: +// * COM Interface +// * A class that implements one or more COM interfaces +// * ChainInterfaces template +template +struct __declspec(novtable) Implements : + Details::AdjustImplements, true, I0, TInterfaces...>::Type, + Details::ImplementsBase +{ +public: + typedef RuntimeClassFlags ClassFlags; + typedef I0 FirstInterface; +protected: + typedef typename Details::AdjustImplements, true, I0, TInterfaces...>::Type BaseType; + template friend struct Details::ImplementsHelper; + template friend class Details::RuntimeClassBaseT; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() + { + return BaseType::CanCastTo(riid, ppv); + } + + IUnknown* CastToUnknown() throw() + { + return BaseType::CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + return BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + BaseType::FillArrayWithIid(index, iids); + } +}; + +template +struct __declspec(novtable) Implements, I0, TInterfaces...> : + Details::AdjustImplements, true, I0, TInterfaces...>::Type, + Details::ImplementsBase +{ +public: + typedef RuntimeClassFlags ClassFlags; + typedef I0 FirstInterface; +protected: + + typedef typename Details::AdjustImplements, true, I0, TInterfaces...>::Type BaseType; + template friend struct Details::ImplementsHelper; + template friend class Details::RuntimeClassBaseT; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() + { + return BaseType::CanCastTo(riid, ppv); + } + + IUnknown* CastToUnknown() throw() + { + return BaseType::CastToUnknown(); + } + + unsigned long GetIidCount() throw() + { + return BaseType::GetIidCount(); + } + + void FillArrayWithIid(_Inout_ unsigned long *index, _Inout_ IID* iids) throw() + { + BaseType::FillArrayWithIid(index, iids); + } +}; + +class FtmBase : + public Implements< + ::Microsoft::WRL::RuntimeClassFlags, + ::Microsoft::WRL::CloakedIid< ::IMarshal> >, + // Inheriting from FtmBaseMarker allows using FtmBase on classes configured with RuntimeClassFlags (Default configuration) + private ::Microsoft::WRL::Details::FtmBaseMarker +{ + // defining type 'Super' for other compilers since '__super' is a VC++-specific language extension + using Super = Implements< + ::Microsoft::WRL::RuntimeClassFlags, + ::Microsoft::WRL::CloakedIid< ::IMarshal> >; +protected: + template friend struct Details::ImplementsHelper; + + HRESULT CanCastTo(REFIID riid, _Outptr_ void **ppv) throw() + { + // Prefer InlineIsEqualGUID over other forms since it's better perf on 4-byte aligned data, which is almost always the case. + if (InlineIsEqualGUID(riid, __uuidof(::IAgileObject))) + { + + *ppv = Super::CastToUnknown(); + return S_OK; + } + + return Super::CanCastTo(riid, ppv); + } + +public: + FtmBase() throw() + { + ComPtr unknown; + if (SUCCEEDED(::CoCreateFreeThreadedMarshaler(nullptr, &unknown))) + { + unknown.As(&marshaller_); + } + } + + // IMarshal Methods +#pragma warning(suppress: 6101) // PREFast cannot see through the smart-pointer invocation + STDMETHOD(GetUnmarshalClass)(_In_ REFIID riid, + _In_opt_ void *pv, + _In_ DWORD dwDestContext, + _Reserved_ void *pvDestContext, + _In_ DWORD mshlflags, + _Out_ CLSID *pCid) override + { + if (marshaller_) + { + return marshaller_->GetUnmarshalClass(riid, pv, dwDestContext, pvDestContext, mshlflags, pCid); + } + return E_OUTOFMEMORY; + } + +#pragma warning(suppress: 6101) // PREFast cannot see through the smart-pointer invocation + STDMETHOD(GetMarshalSizeMax)(_In_ REFIID riid, _In_opt_ void *pv, _In_ DWORD dwDestContext, + _Reserved_ void *pvDestContext, _In_ DWORD mshlflags, _Out_ DWORD *pSize) override + { + if (marshaller_) + { + return marshaller_->GetMarshalSizeMax(riid, pv, dwDestContext, pvDestContext, mshlflags, pSize); + } + return E_OUTOFMEMORY; + } + + STDMETHOD(MarshalInterface)(_In_ IStream *pStm, _In_ REFIID riid, _In_opt_ void *pv, _In_ DWORD dwDestContext, + _Reserved_ void *pvDestContext, _In_ DWORD mshlflags) override + { + if (marshaller_) + { + return marshaller_->MarshalInterface(pStm, riid, pv, dwDestContext, pvDestContext, mshlflags); + } + return E_OUTOFMEMORY; + } + +#pragma warning(suppress: 6101) // PREFast cannot see through the smart-pointer invocation + STDMETHOD(UnmarshalInterface)(_In_ IStream *pStm, _In_ REFIID riid, _Outptr_ void **ppv) override + { + if (marshaller_) + { + return marshaller_->UnmarshalInterface(pStm, riid, ppv); + } + return E_OUTOFMEMORY; + } + + STDMETHOD(ReleaseMarshalData)(_In_ IStream *pStm) override + { + if (marshaller_) + { + return marshaller_->ReleaseMarshalData(pStm); + } + return E_OUTOFMEMORY; + } + + STDMETHOD(DisconnectObject)(_In_ DWORD dwReserved) override + { + if (marshaller_) + { + return marshaller_->DisconnectObject(dwReserved); + } + return E_OUTOFMEMORY; + } + + static HRESULT CreateGlobalInterfaceTable(_Out_ IGlobalInterfaceTable **git) throw() + { + *git = nullptr; + return ::CoCreateInstance(CLSID_StdGlobalInterfaceTable, + nullptr, + CLSCTX_INPROC_SERVER, + __uuidof(IGlobalInterfaceTable), + reinterpret_cast(git)); + } + + ::Microsoft::WRL::ComPtr marshaller_; // Holds a reference to the free threaded marshaler +}; + +namespace Details +{ + +#ifdef _PERF_COUNTERS +class __declspec(novtable) PerfCountersBase +{ +public: + ULONG GetAddRefCount() throw() + { + return addRefCount_; + } + + ULONG GetReleaseCount() throw() + { + return releaseCount_; + } + + ULONG GetQueryInterfaceCount() throw() + { + return queryInterfaceCount_; + } + + void ResetPerfCounters() throw() + { + addRefCount_ = 0; + releaseCount_ = 0; + queryInterfaceCount_ = 0; + } + +protected: + PerfCountersBase() throw() : + addRefCount_(0), + releaseCount_(0), + queryInterfaceCount_(0) + { + } + + void IncrementAddRefCount() throw() + { + InterlockedIncrement(&addRefCount_); + } + + void IncrementReleaseCount() throw() + { + InterlockedIncrement(&releaseCount_); + } + + void IncrementQueryInterfaceCount() throw() + { + InterlockedIncrement(&queryInterfaceCount_); + } + +private: + volatile unsigned long addRefCount_; + volatile unsigned long releaseCount_; + volatile unsigned long queryInterfaceCount_; +}; +#endif + +#if defined(_X86_) || defined(_AMD64_) + +#define UnknownIncrementReference InterlockedIncrement +#define UnknownDecrementReference InterlockedDecrement +#define UnknownBarrierAfterInterlock() +#define UnknownInterlockedCompareExchangePointer InterlockedCompareExchangePointer +#define UnknownInterlockedCompareExchangePointerForIncrement InterlockedCompareExchangePointer +#define UnknownInterlockedCompareExchangePointerForRelease InterlockedCompareExchangePointer + +#elif defined(_ARM_) + +#define UnknownIncrementReference InterlockedIncrementNoFence +#define UnknownDecrementReference InterlockedDecrementRelease +#define UnknownBarrierAfterInterlock() __dmb(_ARM_BARRIER_ISH) +#define UnknownInterlockedCompareExchangePointer InterlockedCompareExchangePointer +#define UnknownInterlockedCompareExchangePointerForIncrement InterlockedCompareExchangePointerNoFence +#define UnknownInterlockedCompareExchangePointerForRelease InterlockedCompareExchangePointerRelease + +#elif defined(_ARM64_) + +#define UnknownIncrementReference InterlockedIncrementNoFence +#define UnknownDecrementReference InterlockedDecrementRelease +#define UnknownBarrierAfterInterlock() __dmb(_ARM64_BARRIER_ISH) +#define UnknownInterlockedCompareExchangePointer InterlockedCompareExchangePointer +#define UnknownInterlockedCompareExchangePointerForIncrement InterlockedCompareExchangePointerNoFence +#define UnknownInterlockedCompareExchangePointerForRelease InterlockedCompareExchangePointerRelease + +#else + +#error Unsupported architecture. + +#endif + +// Since variadic templates can't have a parameter pack after default arguments, provide a convenient helper for defaults. +#define DETAILS_RTCLASS_FLAGS_ARGUMENTS(RuntimeClassFlagsT) \ + RuntimeClassFlagsT, \ + (RuntimeClassFlagsT::value & InhibitWeakReference) == 0, \ + (RuntimeClassFlagsT::value & WinRt) == WinRt, \ + __WRL_IMPLEMENTS_FTM_BASE__(RuntimeClassFlagsT::value) \ + +template +class __declspec(novtable) RuntimeClassImpl; + +#pragma warning(push) +// PREFast cannot see through template instantiation for AsIID() +#pragma warning(disable: 6388) + +template +class __declspec(novtable) RuntimeClassImpl : + public Details::AdjustImplements::Type, + public RuntimeClassBaseT, + protected RuntimeClassFlags, + public DontUseNewUseMake +#ifdef _PERF_COUNTERS + , public PerfCountersBase +#endif +{ +public: + typedef RuntimeClassFlagsT ClassFlags; + + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) + { +#ifdef _PERF_COUNTERS + IncrementQueryInterfaceCount(); +#endif + return Super::AsIID(this, riid, ppvObject); + } + + STDMETHOD_(ULONG, AddRef)() + { + return InternalAddRef(); + } + + STDMETHOD_(ULONG, Release)() + { + ULONG ref = InternalRelease(); + if (ref == 0) + { + delete this; + + auto modulePtr = ::Microsoft::WRL::GetModuleBase(); + if (modulePtr != nullptr) + { + modulePtr->DecrementObjectCount(); + } + } + + return ref; + } + +protected: + using Super = RuntimeClassBaseT; + + RuntimeClassImpl() throw() : refcount_(1) + { + } + + virtual ~RuntimeClassImpl() throw() + { + // Set refcount_ to -(LONG_MAX/2) to protect destruction and + // also catch mismatched Release in debug builds + refcount_ = -(LONG_MAX/2); + } + + unsigned long InternalAddRef() throw() + { +#ifdef _PERF_COUNTERS + IncrementAddRefCount(); +#endif + return UnknownIncrementReference(&refcount_); + } + + unsigned long InternalRelease() throw() + { +#ifdef _PERF_COUNTERS + IncrementReleaseCount(); +#endif + // A release fence is required to ensure all guarded memory accesses are + // complete before any thread can begin destroying the object. + unsigned long newValue = UnknownDecrementReference(&refcount_); + if (newValue == 0) + { + // An acquire fence is required before object destruction to ensure + // that the destructor cannot observe values changing on other threads. + UnknownBarrierAfterInterlock(); + } + return newValue; + } + + unsigned long GetRefCount() const throw() + { + return refcount_; + } + + friend class WeakReferenceImpl; + +private: + volatile long refcount_; +}; + +template +struct HasIInspectable; + +template +struct HasIInspectable +{ + static const bool isIInspectable = __is_base_of(IInspectable, I); +}; + +template +struct HasIInspectable +{ + static const bool isIInspectable = HasIInspectable::isIInspectable; +}; + +#ifdef __WRL_STRICT__ +template +#else +template::isIInspectable> +#endif +struct IInspectableInjector; + +template +struct IInspectableInjector +{ + typedef Details::Nil InspectableIfNeeded; +}; + +template +struct IInspectableInjector +{ + typedef IInspectable InspectableIfNeeded; +}; + +// Implements IInspectable in ILst +template +class __declspec(novtable) RuntimeClassImpl : + public Details::AdjustImplements::InspectableIfNeeded, I0, TInterfaces...>::Type, + public RuntimeClassBaseT, + protected RuntimeClassFlags, + public DontUseNewUseMake +#ifdef _PERF_COUNTERS + , public PerfCountersBase +#endif +{ +public: + typedef RuntimeClassFlagsT ClassFlags; + + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) + { +#ifdef _PERF_COUNTERS + IncrementQueryInterfaceCount(); +#endif + return Super::AsIID(this, riid, ppvObject); + } + + STDMETHOD_(ULONG, AddRef)() + { + return InternalAddRef(); + } + + STDMETHOD_(ULONG, Release)() + { + ULONG ref = InternalRelease(); + if (ref == 0) + { + delete this; + + auto modulePtr = ::Microsoft::WRL::GetModuleBase(); + if (modulePtr != nullptr) + { + modulePtr->DecrementObjectCount(); + } + } + + return ref; + } + + // IInspectable methods + STDMETHOD(GetIids)( + _Out_ ULONG *iidCount, + _When_(*iidCount == 0, _At_(*iids, _Post_null_)) + _When_(*iidCount > 0, _At_(*iids, _Post_notnull_)) + _Result_nullonfailure_ IID **iids) + { + return Super::GetImplementedIIDS(this, iidCount, iids); + } + +#if !defined(__WRL_STRICT__) || !defined(__WRL_FORCE_INSPECTABLE_CLASS_MACRO__) + STDMETHOD(GetRuntimeClassName)(_Out_ HSTRING* runtimeClassName) + { + *runtimeClassName = nullptr; + + __WRL_ASSERT__(false && "Use InspectableClass macro to set runtime class name and trust level."); + +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorHelper::OriginateError(E_NOTIMPL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_NOTIMPL; + } + + STDMETHOD(GetTrustLevel)(_Out_ ::TrustLevel*) + { + __WRL_ASSERT__(false && "Use InspectableClass macro to set runtime class name and trust level."); + +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorHelper::OriginateError(E_NOTIMPL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_NOTIMPL; + } +#endif // !defined(__WRL_STRICT__) || !defined(__WRL_FORCE_INSPECTABLE_CLASS_MACRO__) + +protected: + using Super = RuntimeClassBaseT; + + RuntimeClassImpl() throw() : refcount_(1) + { + } + + virtual ~RuntimeClassImpl() throw() + { + // Set refcount_ to -(LONG_MAX/2) to protect destruction and + // also catch mismatched Release in debug builds + refcount_ = -(LONG_MAX/2); + } + + unsigned long InternalAddRef() throw() + { +#ifdef _PERF_COUNTERS + IncrementAddRefCount(); +#endif + return UnknownIncrementReference(&refcount_); + } + + unsigned long InternalRelease() throw() + { +#ifdef _PERF_COUNTERS + IncrementReleaseCount(); +#endif + // A release fence is required to ensure all guarded memory accesses are + // complete before any thread can begin destroying the object. + unsigned long newValue = UnknownDecrementReference(&refcount_); + if (newValue == 0) + { + // An acquire fence is required before object destruction to ensure + // that the destructor cannot observe values changing on other threads. + UnknownBarrierAfterInterlock(); + } + return newValue; + } + + unsigned long GetRefCount() const throw() + { + return refcount_; + } +private: + volatile long refcount_; +}; + +class StrongReference +{ +public: + StrongReference(long refCount = 1) throw() : strongRefCount_(refCount) {} + + ~StrongReference() throw() + { + // Set refcount_ to -(LONG_MAX/2) to protect destruction and + // also catch mismatched Release in debug builds + strongRefCount_ = -(LONG_MAX / 2); + } + + unsigned long IncrementStrongReference() throw() + { + return UnknownIncrementReference(&strongRefCount_); + } + + unsigned long DecrementStrongReference() throw() + { + // A release fence is required to ensure all guarded memory accesses are + // complete before any thread can begin destroying the object. + unsigned long newValue = UnknownDecrementReference(&strongRefCount_); + if (newValue == 0) + { + // An acquire fence is required before object destruction to ensure + // that the destructor cannot observe values changing on other threads. + UnknownBarrierAfterInterlock(); + } + return newValue; + } + + unsigned long GetStrongReferenceCount() throw() + { + return strongRefCount_; + } + + void SetStrongReference(unsigned long value) throw() + { + strongRefCount_ = value; + } + + long strongRefCount_; +}; + +// To support storing, encoding and decoding reference-count/pointers regardless of the target platform. +// In a RuntimeClass, the refCount_ member can mean either +// 1. actual reference count +// 2. pointer to the weak reference object which holds the strong count +// The member +// 1. If it is a count, the most significant bit will be OFF +// 2. If it is an encoded pointer to the weak reference, the most significant bit will be turned ON +// To test which mode it is +// 1. Test for negative +// 2. If it is, it is an encoded pointer to the weak reference +// 3. If it is not, it is the actual reference count +// To yield the encoded pointer +// 1. Test the value for negative +// 2. If it is, shift the value to the left and cast it to a WeakReferenceImpl* +// +const UINT_PTR EncodeWeakReferencePointerFlag = static_cast(1) << ((sizeof(UINT_PTR) * 8) - 1); + +union ReferenceCountOrWeakReferencePointer +{ + // Represents the count when it is a count (in practice only the least significant 4 bytes) + UINT_PTR refCount; + // Pointer size, *signed* to help with ease of casting and such + INT_PTR rawValue; + // The hint that this could also be a pointer + void* ifHighBitIsSetThenShiftLeftToYieldPointerToWeakReference; +}; + +// Helper methods to test, decode and decode the different representations of ReferenceCountOrWeakReferencePointer +inline bool IsValueAPointerToWeakReference(INT_PTR value) +{ + return value < 0; +} + +// Forward declaration +class WeakReferenceImpl; +inline INT_PTR EncodeWeakReferencePointer(Microsoft::WRL::Details::WeakReferenceImpl* value); +inline Microsoft::WRL::Details::WeakReferenceImpl* DecodeWeakReferencePointer(INT_PTR value); + +// Helper functions, originally from winnt.h, needed to get the semantics right when fetching values +// in multi-threaded scenarios. This is in order to guarantee the compiler emits exactly one read +// (compiler may decide to re-fetch the value later on, which in some cases can break code) +#if defined(_ARM_) + +FORCEINLINE +LONG +ReadULongPtrNoFence ( + _In_ _Interlocked_operand_ DWORD const volatile *Source + ) + +{ + LONG Value; + + Value = __iso_volatile_load32((int *)Source); + return Value; +} + +#elif defined(_ARM64_) + +FORCEINLINE +LONG64 +ReadULongPtrNoFence ( + _In_ _Interlocked_operand_ DWORD64 const volatile *Source + ) + +{ + LONG64 Value; + + Value = __iso_volatile_load64((__int64 *)Source); + return Value; +} + +#elif defined(_X86_) + +FORCEINLINE +LONG +ReadULongPtrNoFence ( + _In_ _Interlocked_operand_ DWORD const volatile *Source + ) + +{ + LONG Value; + + Value = *Source; + return Value; +} + +#elif defined(_AMD64_) + +FORCEINLINE +LONG64 +ReadULongPtrNoFence ( + _In_ _Interlocked_operand_ DWORD64 const volatile *Source + ) + +{ + LONG64 Value; + + Value = *Source; + return Value; +} + +#else + +#error Unsupported architecture. + +#endif + +template +inline T ReadValueFromPointerNoFence(_In_ const volatile T* value) +{ + ULONG_PTR currentValue = ReadULongPtrNoFence(reinterpret_cast(value)); + const T* currentPointerToValue = reinterpret_cast(¤tValue); + return *currentPointerToValue; +} + +inline WeakReferenceImpl* CreateWeakReference(_In_ IUnknown*); + +// Implementation of activatable class that implements IWeakReferenceSource +// and delegates reference counting to WeakReferenceImpl object +template +class __declspec(novtable) RuntimeClassImpl : + public Details::AdjustImplements::InspectableIfNeeded, I0, IWeakReferenceSource, TInterfaces...>::Type, + public RuntimeClassBaseT, + public DontUseNewUseMake +#ifdef _PERF_COUNTERS + , public PerfCountersBase +#endif +{ +public: + typedef RuntimeClassFlagsT ClassFlags; + + RuntimeClassImpl() throw() + { + refCount_.rawValue = 1; + } + + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) + { +#ifdef _PERF_COUNTERS + IncrementQueryInterfaceCount(); +#endif + return Super::AsIID(this, riid, ppvObject); + } + + STDMETHOD_(ULONG, AddRef)() + { + return InternalAddRef(); + } + + STDMETHOD_(ULONG, Release)() + { + ULONG ref = InternalRelease(); + if (ref == 0) + { + delete this; + + auto modulePtr = ::Microsoft::WRL::GetModuleBase(); + if (modulePtr != nullptr) + { + modulePtr->DecrementObjectCount(); + } + } + + return ref; + } + + // IInspectable methods + STDMETHOD(GetIids)( + _Out_ ULONG *iidCount, + _When_(*iidCount == 0, _At_(*iids, _Post_null_)) + _When_(*iidCount > 0, _At_(*iids, _Post_notnull_)) + _Result_nullonfailure_ IID **iids) + { + return Super::GetImplementedIIDS(this, iidCount, iids); + } + +#if !defined(__WRL_STRICT__) || !defined(__WRL_FORCE_INSPECTABLE_CLASS_MACRO__) + STDMETHOD(GetRuntimeClassName)(_Out_ HSTRING* runtimeClassName) + { + *runtimeClassName = nullptr; + + __WRL_ASSERT__(false && "Use InspectableClass macro to set runtime class name and trust level."); + +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorHelper::OriginateError(E_NOTIMPL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_NOTIMPL; + } + + STDMETHOD(GetTrustLevel)(_Out_ ::TrustLevel*) + { + __WRL_ASSERT__(false && "Use InspectableClass macro to set runtime class name and trust level."); +#if (NTDDI_VERSION >= NTDDI_WINBLUE) + ErrorHelper::OriginateError(E_NOTIMPL, nullptr); +#endif // (NTDDI_VERSION >= NTDDI_WINBLUE) + return E_NOTIMPL; + } +#endif // !defined(__WRL_STRICT__) || !defined(__WRL_FORCE_INSPECTABLE_CLASS_MACRO__) + + STDMETHOD(GetWeakReference)(_Outptr_ IWeakReference **weakReference); + + virtual ~RuntimeClassImpl() throw(); + +protected: + template friend class Details::RuntimeClassBaseT; + using ImplementsHelper = typename Details::AdjustImplements::InspectableIfNeeded, I0, IWeakReferenceSource, TInterfaces...>::Type; + using Super = RuntimeClassBaseT; + + unsigned long InternalAddRef() throw(); + + unsigned long InternalRelease() throw(); + + unsigned long GetRefCount() const throw(); + + friend class WeakReferenceImpl; + +#ifdef __WRL_UNITTEST__ +protected: +#else +private: +#endif + ReferenceCountOrWeakReferencePointer refCount_; +}; + +inline INT_PTR EncodeWeakReferencePointer(Microsoft::WRL::Details::WeakReferenceImpl* value) +{ + return ((reinterpret_cast(value) >> 1) | EncodeWeakReferencePointerFlag); +} + +inline Microsoft::WRL::Details::WeakReferenceImpl* DecodeWeakReferencePointer(INT_PTR value) +{ + return reinterpret_cast(value << 1); +} + +#pragma warning(pop) // C6388 + +template +class __declspec(novtable) RuntimeClassImpl : + public RuntimeClassImpl +{ +}; + +template +class __declspec(novtable) RuntimeClassImpl : + public RuntimeClassImpl +{ +}; + +// To minimize breaks with code written against WRL before variadic support was added, this form is maintained. +template +struct InterfaceListHelper +{ + typedef InterfaceListHelper TypeT; +}; + +template < + typename ILst, + class RuntimeClassFlagsT, + bool implementsWeakReferenceSource = (RuntimeClassFlagsT::value & InhibitWeakReference) == 0, + bool implementsInspectable = (RuntimeClassFlagsT::value & WinRt) == WinRt, + bool implementsFtmBase = __WRL_IMPLEMENTS_FTM_BASE__(RuntimeClassFlagsT::value) +> +class RuntimeClass; + + +template < + typename RuntimeClassFlagsT, + bool implementsWeakReferenceSource, + bool implementsInspectable, + bool implementsFtmBase, + typename ...TInterfaces +> +class RuntimeClass, RuntimeClassFlagsT, implementsWeakReferenceSource, implementsInspectable, implementsFtmBase> : + public RuntimeClassImpl +{ +protected: + HRESULT CustomQueryInterface(REFIID /*riid*/, _Outptr_result_nullonfailure_ void** /*ppvObject*/, _Out_ bool *handled) + { + *handled = false; + return S_OK; + } +}; + +} // namespace Details + +// The RuntimeClass IUnknown methods +// It inherits from Details::RuntimeClass that provides helper methods for reference counting and +// collecting IIDs +template +class RuntimeClass : + public Details::RuntimeClassImpl), TInterfaces...> +{ + RuntimeClass(const RuntimeClass&); + RuntimeClass& operator=(const RuntimeClass&); +protected: + HRESULT CustomQueryInterface(REFIID /*riid*/, _Outptr_result_nullonfailure_ void** /*ppvObject*/, _Out_ bool *handled) + { + *handled = false; + return S_OK; + } +public: + RuntimeClass() throw() + { + auto modulePtr = ::Microsoft::WRL::GetModuleBase(); + if (modulePtr != nullptr) + { + modulePtr->IncrementObjectCount(); + } + } + typedef RuntimeClass RuntimeClassT; +}; + +template +class RuntimeClass, TInterfaces...> : + public Details::RuntimeClassImpl), TInterfaces...> +{ + RuntimeClass(const RuntimeClass&); + RuntimeClass& operator=(const RuntimeClass&); +protected: + HRESULT CustomQueryInterface(REFIID /*riid*/, _Outptr_result_nullonfailure_ void** /*ppvObject*/, _Out_ bool *handled) + { + *handled = false; + return S_OK; + } +public: + RuntimeClass() throw() + { + auto modulePtr = ::Microsoft::WRL::GetModuleBase(); + if (modulePtr != nullptr) + { + modulePtr->IncrementObjectCount(); + } + } + typedef RuntimeClass RuntimeClassT; +}; + +namespace Details +{ +//Weak reference implementation + class WeakReferenceImpl sealed: + public ::Microsoft::WRL::RuntimeClass, IWeakReference>, + public StrongReference + { + public: + WeakReferenceImpl(_In_ IUnknown* unk) throw() : StrongReference(LONG_MAX / 2), unknown_(unk) + { + // Set ref count to 2 to avoid unnecessary interlocked increment operation while returning + // WeakReferenceImpl from GetWeakReference method. One reference is hold by the object the second is hold + // by the caller of GetWeakReference method. + refcount_ = 2; + } + + virtual ~WeakReferenceImpl() throw() + { + } + + STDMETHOD(Resolve)(REFIID riid, _Outptr_result_maybenull_ _Result_nullonfailure_ IInspectable **ppvObject) + { + *ppvObject = nullptr; + + for(;;) + { + long ref = this->strongRefCount_; + if (ref == 0) + { + return S_OK; + } + + // InterlockedCompareExchange calls _InterlockedCompareExchange intrinsic thus we call directly _InterlockedCompareExchange to save the call + if (::_InterlockedCompareExchange(&this->strongRefCount_, ref + 1, ref) == ref) + { +#ifdef _PERF_COUNTERS + // This artificially manipulates the strong ref count via AddRef to account for the resolve + // interlocked operation above when tallying reference counting operations. + unknown_->AddRef(); + ::_InterlockedDecrement(&this->strongRefCount_); +#endif + break; + } + } + + HRESULT hr = unknown_->QueryInterface(riid, reinterpret_cast(ppvObject)); + unknown_->Release(); + return hr; + } + + +private: + IUnknown *unknown_; +}; + +template +RuntimeClassImpl::~RuntimeClassImpl() throw() +{ + if (IsValueAPointerToWeakReference(refCount_.rawValue)) + { + WeakReferenceImpl* weakRef = DecodeWeakReferencePointer(refCount_.rawValue); + weakRef->Release(); + weakRef = nullptr; + } +} + +template +unsigned long RuntimeClassImpl::GetRefCount() const throw() +{ + ReferenceCountOrWeakReferencePointer currentValue = ReadValueFromPointerNoFence(&refCount_); + + if (IsValueAPointerToWeakReference(currentValue.rawValue)) + { + WeakReferenceImpl* weakRef = DecodeWeakReferencePointer(currentValue.rawValue); + return weakRef->GetStrongReferenceCount(); + } + else + { + return static_cast(currentValue.refCount); + } +} + +template +unsigned long RuntimeClassImpl::InternalAddRef() throw() +{ +#ifdef _PERF_COUNTERS + IncrementAddRefCount(); +#endif + + ReferenceCountOrWeakReferencePointer currentValue = ReadValueFromPointerNoFence(&refCount_); + + for (;;) + { + if (!IsValueAPointerToWeakReference(currentValue.rawValue)) + { + UINT_PTR updateValue = currentValue.refCount + 1; + +#ifdef __WRL_UNITTEST__ + OnBeforeInternalAddRefIncrement(); +#endif + + INT_PTR previousValue = reinterpret_cast(UnknownInterlockedCompareExchangePointerForIncrement(reinterpret_cast(&(refCount_.refCount)), reinterpret_cast(updateValue), reinterpret_cast(currentValue.refCount))); + if (previousValue == currentValue.rawValue) + { + return static_cast(updateValue); + } + + currentValue.rawValue = previousValue; + } + else + { + WeakReferenceImpl* weakRef = DecodeWeakReferencePointer(currentValue.rawValue); + return weakRef->IncrementStrongReference(); + } + } +} + +template +unsigned long RuntimeClassImpl::InternalRelease() throw() +{ +#ifdef _PERF_COUNTERS + IncrementReleaseCount(); +#endif + + ReferenceCountOrWeakReferencePointer currentValue = ReadValueFromPointerNoFence(&refCount_); + + for (;;) + { + if (!IsValueAPointerToWeakReference(currentValue.rawValue)) + { + UINT_PTR updateValue = currentValue.refCount - 1; + +#ifdef __WRL_UNITTEST__ + OnBeforeInternalReleaseDecrement(); +#endif + + INT_PTR previousValue = reinterpret_cast(UnknownInterlockedCompareExchangePointerForIncrement(reinterpret_cast(&(refCount_.refCount)), reinterpret_cast(updateValue), reinterpret_cast(currentValue.refCount))); + if (previousValue == currentValue.rawValue) + { + return static_cast(updateValue); + } + + currentValue.rawValue = previousValue; + } + else + { + WeakReferenceImpl* weakRef = DecodeWeakReferencePointer(currentValue.rawValue); + return weakRef->DecrementStrongReference(); + } + } +} + +template +HRESULT RuntimeClassImpl::GetWeakReference(_Outptr_ IWeakReference **weakReference) +{ + WeakReferenceImpl* weakRef = nullptr; + INT_PTR encodedWeakRef = 0; + ReferenceCountOrWeakReferencePointer currentValue = ReadValueFromPointerNoFence(&refCount_); + + *weakReference = nullptr; + + if (IsValueAPointerToWeakReference(currentValue.rawValue)) + { + weakRef = DecodeWeakReferencePointer(currentValue.rawValue); + + weakRef->AddRef(); + *weakReference = weakRef; + return S_OK; + } + + // WeakReferenceImpl is created with ref count 2 to avoid interlocked increment + weakRef = CreateWeakReference(ImplementsHelper::CastToUnknown()); + if (weakRef == nullptr) + { + return E_OUTOFMEMORY; + } + + encodedWeakRef = EncodeWeakReferencePointer(weakRef); + + for (;;) + { + INT_PTR previousValue = 0; + + weakRef->SetStrongReference(static_cast(currentValue.refCount)); + +#ifdef __WRL_UNITTEST__ + OnBeforeGetWeakReferenceSwap(); +#endif + + previousValue = reinterpret_cast(UnknownInterlockedCompareExchangePointer(reinterpret_cast(&(this->refCount_.ifHighBitIsSetThenShiftLeftToYieldPointerToWeakReference)), reinterpret_cast(encodedWeakRef), currentValue.ifHighBitIsSetThenShiftLeftToYieldPointerToWeakReference)); + if (previousValue == currentValue.rawValue) + { + // No need to call AddRef in this case, WeakReferenceImpl is created with ref count 2 to avoid interlocked increment + *weakReference = weakRef; + return S_OK; + } + else if (IsValueAPointerToWeakReference(previousValue)) + { + // Another thread beat this call to create the weak reference. + + delete weakRef; + + weakRef = DecodeWeakReferencePointer(previousValue); + weakRef->AddRef(); + *weakReference = weakRef; + return S_OK; + } + + // Another thread won via an AddRef or Release. + // Let's try again + currentValue.rawValue = previousValue; + } +} + +// Memory allocation for object that doesn't support weak references +// It only allocates memory +template +class MakeAllocator +{ +public: + MakeAllocator() throw() : buffer_(nullptr) + { + } + + ~MakeAllocator() throw() + { + if (buffer_ != nullptr) + { + delete buffer_; + } + } + + void* Allocate() throw() + { + __WRL_ASSERT__(buffer_ == nullptr); + // Allocate memory with operator new(size, nothrow) only + // This will allow developer to override one operator only + // to enable different memory allocation model + buffer_ = (char*) (operator new (sizeof(T), std::nothrow)); + return buffer_; + } + + void Detach() throw() + { + buffer_ = nullptr; + } +private: + char* buffer_; +}; + +} //Details + +#pragma region make overloads + +namespace Details { + +// Make and MakeAndInitialize functions must not be marked as throw() as the constructor is allowed to throw exceptions. +template +ComPtr Make(TArgs&&... args) +{ + static_assert(__is_base_of(Details::RuntimeClassBase, T), "Make can only instantiate types that derive from RuntimeClass"); + ComPtr object; + Details::MakeAllocator allocator; + void *buffer = allocator.Allocate(); + if (buffer != nullptr) + { + auto ptr = new (buffer)T(Details::Forward(args)...); + object.Attach(ptr); + allocator.Detach(); + } + return object; +} + +#pragma warning(push) +#pragma warning(disable:6387 6388 28196) // PREFast does not understand call to ComPtr::CopyTo() is safe here + +template +HRESULT MakeAndInitialize(_Outptr_result_nullonfailure_ I** result, TArgs&&... args) +{ + static_assert(__is_base_of(Details::RuntimeClassBase, T), "Make can only instantiate types that derive from RuntimeClass"); + static_assert(__is_base_of(I, T), "The 'T' runtime class doesn't implement 'I' interface"); + *result = nullptr; + Details::MakeAllocator allocator; + void *buffer = allocator.Allocate(); + if (buffer == nullptr) { return E_OUTOFMEMORY; } + auto ptr = new (buffer)T; + ComPtr object; + object.Attach(ptr); + allocator.Detach(); + HRESULT hr = object->RuntimeClassInitialize(Details::Forward(args)...); + if (FAILED(hr)) { return hr; } + return object.CopyTo(result); +} + +#pragma warning(pop) // C6387 C6388 C28196 + +template +HRESULT MakeAndInitialize(_Inout_ ComPtrRef> ppvObject, TArgs&&... args) +{ + return MakeAndInitialize(ppvObject.ReleaseAndGetAddressOf(), Details::Forward(args)...); +} + +} //end of Details + +using Details::MakeAndInitialize; +using Details::Make; + +#pragma endregion // make overloads + +namespace Details +{ + inline WeakReferenceImpl* CreateWeakReference(_In_ IUnknown* unk) + { + return Make(unk).Detach(); + } +} + +#define InspectableClass(runtimeClassName, trustLevel) \ + public: \ + static _Null_terminated_ const wchar_t* STDMETHODCALLTYPE InternalGetRuntimeClassName() throw() \ + { \ + static_assert((RuntimeClassT::ClassFlags::value & ::Microsoft::WRL::WinRtClassicComMix) == ::Microsoft::WRL::WinRt || \ + (RuntimeClassT::ClassFlags::value & ::Microsoft::WRL::WinRtClassicComMix) == ::Microsoft::WRL::WinRtClassicComMix, \ + "'InspectableClass' macro must not be used with ClassicCom clasess."); \ + static_assert(__is_base_of(::Microsoft::WRL::Details::RuntimeClassBase, RuntimeClassT), "'InspectableClass' macro can only be used with ::Windows::WRL::RuntimeClass types"); \ + static_assert(!__is_base_of(IActivationFactory, RuntimeClassT), "Incorrect usage of IActivationFactory interface. Make sure that your RuntimeClass doesn't implement IActivationFactory interface use ::Windows::WRL::ActivationFactory instead or 'InspectableClass' macro is not used on ::Windows::WRL::ActivationFactory"); \ + return runtimeClassName; \ + } \ + static ::TrustLevel STDMETHODCALLTYPE InternalGetTrustLevel() throw() \ + { \ + return trustLevel; \ + } \ + STDMETHOD(GetRuntimeClassName)(_Out_ HSTRING* runtimeName) \ + { \ + *runtimeName = nullptr; \ + HRESULT hr = S_OK; \ + auto name = InternalGetRuntimeClassName(); \ + if (name != nullptr) \ + { \ + hr = ::WindowsCreateString(name, static_cast(::wcslen(name)), runtimeName); \ + } \ + return hr; \ + } \ + STDMETHOD(GetTrustLevel)(_Out_ ::TrustLevel* trustLvl) \ + { \ + *trustLvl = trustLevel; \ + return S_OK; \ + } \ + STDMETHOD(GetIids)(_Out_ ULONG *iidCount, \ + _When_(*iidCount == 0, _At_(*iids, _Post_null_)) \ + _When_(*iidCount > 0, _At_(*iids, _Post_notnull_)) \ + _Result_nullonfailure_ IID **iids) \ + { \ + return RuntimeClassT::GetIids(iidCount, iids); \ + } \ + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) \ + { \ + bool handled = false; \ + HRESULT hr = this->CustomQueryInterface(riid, ppvObject, &handled); \ + if (FAILED(hr) || handled) return hr; \ + return RuntimeClassT::QueryInterface(riid, ppvObject); \ + } \ + STDMETHOD_(ULONG, Release)() \ + { \ + return RuntimeClassT::Release(); \ + } \ + STDMETHOD_(ULONG, AddRef)() \ + { \ + return RuntimeClassT::AddRef(); \ + } \ + private: + +#define MixInHelper() \ + public: \ + STDMETHOD(QueryInterface)(REFIID riid, _Outptr_result_nullonfailure_ void **ppvObject) \ + { \ + static_assert((RuntimeClassT::ClassFlags::value & ::Microsoft::WRL::WinRt) == 0, "'MixInClass' macro must not be used with WinRt clasess."); \ + static_assert(__is_base_of(::Microsoft::WRL::Details::RuntimeClassBase, RuntimeClassT), "'MixInHelper' macro can only be used with ::Windows::WRL::RuntimeClass types"); \ + static_assert(!__is_base_of(IClassFactory, RuntimeClassT), "Incorrect usage of IClassFactory interface. Make sure that your RuntimeClass doesn't implement IClassFactory interface use ::Windows::WRL::ClassFactory instead or 'MixInHelper' macro is not used on ::Windows::WRL::ClassFactory"); \ + return RuntimeClassT::QueryInterface(riid, ppvObject); \ + } \ + STDMETHOD_(ULONG, Release)() \ + { \ + return RuntimeClassT::Release(); \ + } \ + STDMETHOD_(ULONG, AddRef)() \ + { \ + return RuntimeClassT::AddRef(); \ + } \ + private: + +// Please make sure that those macros are in sync with those ones from 'wrl/module.h' +#ifndef WrlCreatorMapIncludePragmaEx +#define WrlCreatorMapIncludePragmaEx(className, group) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'WrlCreatorMapIncludePragmaEx' macro"); +#endif + +#ifndef WrlCreatorMapIncludePragma +#define WrlCreatorMapIncludePragma(className) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'WrlCreatorMapIncludePragma' macro"); +#endif + +#ifndef ActivatableClassWithFactoryEx +#define ActivatableClassWithFactoryEx(className, factory, groupId) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'ActivatableClassWithFactoryEx' macro"); +#endif + +#ifndef ActivatableClassWithFactory +#define ActivatableClassWithFactory(className, factory) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'ActivatableClassWithFactory' macro"); +#endif + +#ifndef ActivatableClass +#define ActivatableClass(className) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'ActivatableClass' macro"); +#endif + +#ifndef ActivatableStaticOnlyFactoryEx +#define ActivatableStaticOnlyFactoryEx(factory, serverName) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'ActivatableStaticOnlyFactoryEx' macro"); +#endif + +#ifndef ActivatableStaticOnlyFactory +#define ActivatableStaticOnlyFactory(factory) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'ActivatableStaticOnlyFactory' macro"); +#endif + +#ifndef CoCreatableClassWithFactoryEx +#define CoCreatableClassWithFactoryEx(className, factory, groupId) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'CoCreatableClassWithFactory' macro"); +#endif + +#ifndef CoCreatableClassWithFactory +#define CoCreatableClassWithFactory(className, factory) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'CoCreatableClassWithFactory' macro"); +#endif + +#ifndef CoCreatableClass +#define CoCreatableClass(className) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'CoCreatableClass' macro"); +#endif + +#ifndef CoCreatableClassWrlCreatorMapInclude +#define CoCreatableClassWrlCreatorMapInclude(className) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'CoCreatableClassWrlCreatorMapInclude' macro"); +#endif + +#ifndef CoCreatableClassWrlCreatorMapIncludeEx +#define CoCreatableClassWrlCreatorMapIncludeEx(className, groupId) static_assert(false, "It's required to include 'wrl/module.h' to be able to use 'CoCreatableClassWrlCreatorMapInclude' macro"); +#endif + +#undef UnknownIncrementReference +#undef UnknownDecrementReference +#undef UnknownBarrierAfterInterlock +#undef UnknownInterlockedCompareExchangePointer +#undef UnknownInterlockedCompareExchangePointerForIncrement +#undef UnknownInterlockedCompareExchangePointerForRelease + +}} // namespace Microsoft::WRL + +#pragma warning(pop) + +// Restore packing +#include + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif // _WRL_IMPLEMENTS_H_