// wil/tests/MockingTests.cpp (410 lines, 14 KiB, C++)

// Relatively simple tests as a sanity check to verify that our function mocking & use of detours is working correctly
#include "pch.h"
#include "common.h"
#ifdef WIL_ENABLE_EXCEPTIONS
#include <thread>
#endif
// Replacement used by the function-pointer detour tests: forwards to ::GetFileAttributesW and returns the bitwise
// complement of its result, so a detoured call is easy to tell apart from a real one.
// A thread_local flag guards against this function re-entering itself on the same thread. A nested call made while
// the flag is set returns INVALID_FILE_ATTRIBUTES and does not call ::GetFileAttributesW again.
DWORD __stdcall InvertFileAttributes(PCWSTR path)
{
    thread_local bool inCall = false;
    if (!inCall)
    {
        inCall = true;
        const DWORD attributes = ::GetFileAttributesW(path);
        inCall = false;
        return ~attributes;
    }
    return INVALID_FILE_ATTRIBUTES;
}
TEST_CASE("MockingTests::ThreadDetourWithFunctionPointer", "[mocking]")
{
wchar_t buffer[MAX_PATH];
REQUIRE(::GetSystemDirectoryW(buffer, ARRAYSIZE(buffer)) != 0);
auto realAttr = ::GetFileAttributesW(buffer);
REQUIRE(realAttr != INVALID_FILE_ATTRIBUTES);
REQUIRE(realAttr != 0);
{
witest::detoured_thread_function<&::GetFileAttributesW> detour;
REQUIRE_SUCCEEDED(detour.reset(InvertFileAttributes));
auto inverseAttr = ::GetFileAttributesW(buffer);
REQUIRE(inverseAttr == ~realAttr);
}
REQUIRE(::GetFileAttributesW(buffer) == realAttr);
}
TEST_CASE("MockingTests::GlobalDetourWithFunctionPointer", "[mocking]")
{
wchar_t buffer[MAX_PATH];
REQUIRE(::GetSystemDirectoryW(buffer, ARRAYSIZE(buffer)) != 0);
auto realAttr = ::GetFileAttributesW(buffer);
REQUIRE(realAttr != INVALID_FILE_ATTRIBUTES);
REQUIRE(realAttr != 0);
{
witest::detoured_global_function<&::GetFileAttributesW> detour;
REQUIRE_SUCCEEDED(detour.reset(InvertFileAttributes));
auto inverseAttr = ::GetFileAttributesW(buffer);
REQUIRE(inverseAttr == ~realAttr);
}
REQUIRE(::GetFileAttributesW(buffer) == realAttr);
}
TEST_CASE("MockingTests::ThreadDetourWithLambda", "[mocking]")
{
PCWSTR path = L"$*&><"; // Purposefully nonesense/invalid to test the mocking functionality
{
DWORD expectedAttr = 0;
witest::detoured_thread_function<&::GetFileAttributesW> detour;
REQUIRE_SUCCEEDED(detour.reset([&](PCWSTR) -> DWORD {
return expectedAttr;
}));
REQUIRE(::GetFileAttributesW(path) == expectedAttr);
expectedAttr = 0xc0ffee;
REQUIRE(::GetFileAttributesW(path) == expectedAttr);
}
REQUIRE(::GetFileAttributesW(path) == INVALID_FILE_ATTRIBUTES);
}
TEST_CASE("MockingTests::GlobalDetourWithLambda", "[mocking]")
{
PCWSTR path = L"$*&><"; // Purposefully nonesense/invalid to test the mocking functionality
{
DWORD expectedAttr = 0;
witest::detoured_global_function<&::GetFileAttributesW> detour;
REQUIRE_SUCCEEDED(detour.reset([&](PCWSTR) -> DWORD {
return expectedAttr;
}));
REQUIRE(::GetFileAttributesW(path) == expectedAttr);
expectedAttr = 0xc0ffee;
REQUIRE(::GetFileAttributesW(path) == expectedAttr);
}
REQUIRE(::GetFileAttributesW(path) == INVALID_FILE_ATTRIBUTES);
}
// Turn optimizations off so that calls to the local functions below are not evaluated or folded at compile time.
// Folding would bypass the runtime detours under test. MSVC keeps this setting in effect until the next
// '#pragma optimize' directive.
#pragma optimize("", off) // Don't evaluate at compile-time
// Detour target for the tests that follow. It returns the plain sum of its arguments, so any other result shows that
// a detour ran. 'noinline' keeps callers calling the real, out-of-line function body instead of an inlined copy.
__declspec(noinline) int __cdecl LocalAddFunction(int lhs, int rhs)
{
return lhs + rhs;
}
TEST_CASE("MockingTests::ThreadDetourLocalFunction", "[mocking]")
{
{
witest::detoured_thread_function<&LocalAddFunction> detour;
REQUIRE_SUCCEEDED(detour.reset([](int lhs, int rhs) {
return lhs * rhs;
}));
REQUIRE(LocalAddFunction(2, 3) == 6);
}
REQUIRE(LocalAddFunction(2, 3) == 5);
}
TEST_CASE("MockingTests::GlobalDetourLocalFunction", "[mocking]")
{
{
witest::detoured_global_function<&LocalAddFunction> detour;
REQUIRE_SUCCEEDED(detour.reset([](int lhs, int rhs) {
return lhs * rhs;
}));
REQUIRE(LocalAddFunction(2, 3) == 6);
}
REQUIRE(LocalAddFunction(2, 3) == 5);
}
// noexcept detour targets with the same behavior as LocalAddFunction: each returns the sum of its arguments.
// There is one __cdecl and one __stdcall variant, so detouring can be checked for noexcept function types under both
// calling conventions. Note that MSVC ignores __stdcall on x64, so the two differ only on 32-bit x86.
__declspec(noinline) int __cdecl LocalAddFunctionNoexcept(int lhs, int rhs) noexcept
{
return lhs + rhs;
}
__declspec(noinline) int __stdcall LocalAddFunctionStdcallNoexcept(int lhs, int rhs) noexcept
{
return lhs + rhs;
}
// Shared body for the noexcept detour tests.
// 'DetourT' is the detour type wrapping 'Target', and 'Target' is the noexcept addition function being detoured.
// It installs a noexcept multiplying lambda, checks that it runs, then checks that the addition is restored once the
// detour is destroyed.
template <typename DetourT, auto Target>
static void VerifyNoexceptDetour()
{
    {
        DetourT detour;
        REQUIRE_SUCCEEDED(detour.reset([](int lhs, int rhs) noexcept {
            return lhs * rhs;
        }));
        REQUIRE(Target(2, 3) == 6);
    }
    REQUIRE(Target(2, 3) == 5);
}

TEST_CASE("MockingTests::ThreadDetourNoexceptFunction", "[mocking]")
{
    VerifyNoexceptDetour<witest::detoured_thread_function<&LocalAddFunctionNoexcept>, &LocalAddFunctionNoexcept>();
    VerifyNoexceptDetour<witest::detoured_thread_function<&LocalAddFunctionStdcallNoexcept>, &LocalAddFunctionStdcallNoexcept>();
}

TEST_CASE("MockingTests::GlobalDetourNoexceptFunction", "[mocking]")
{
    VerifyNoexceptDetour<witest::detoured_global_function<&LocalAddFunctionNoexcept>, &LocalAddFunctionNoexcept>();
    VerifyNoexceptDetour<witest::detoured_global_function<&LocalAddFunctionStdcallNoexcept>, &LocalAddFunctionStdcallNoexcept>();
}
TEST_CASE("MockingTests::RecursiveThreadDetouring", "[mocking]")
{
{
witest::detoured_thread_function<&LocalAddFunction> detour;
REQUIRE_SUCCEEDED(detour.reset([](int lhs, int rhs) {
return lhs + rhs + LocalAddFunction(lhs * 2, rhs * 2);
}));
{
witest::detoured_thread_function<&LocalAddFunction> detour2;
REQUIRE_SUCCEEDED(detour2.reset([](int lhs, int rhs) {
return lhs + rhs + LocalAddFunction(lhs * 3, rhs * 3);
}));
// Last registration should be the first to execute
// 5 + 3 * (5 + 2 * 5)
REQUIRE(LocalAddFunction(2, 3) == 50);
}
// (2 + 3) + (4 + 6)
REQUIRE(LocalAddFunction(2, 3) == 15);
}
REQUIRE(LocalAddFunction(2, 3) == 5);
}
TEST_CASE("MockingTests::RecursiveGlobalDetouring", "[mocking]")
{
{
witest::detoured_global_function<&LocalAddFunction> detour;
REQUIRE_SUCCEEDED(detour.reset([](int lhs, int rhs) {
return lhs + rhs + LocalAddFunction(lhs * 2, rhs * 2);
}));
{
witest::detoured_global_function<&LocalAddFunction> detour2;
REQUIRE_SUCCEEDED(detour2.reset([](int lhs, int rhs) {
return lhs + rhs + LocalAddFunction(lhs * 3, rhs * 3);
}));
// Last registration should be the first to execute
// 5 + 3 * (5 + 2 * 5)
REQUIRE(LocalAddFunction(2, 3) == 50);
}
// (2 + 3) + (4 + 6)
REQUIRE(LocalAddFunction(2, 3) == 15);
}
REQUIRE(LocalAddFunction(2, 3) == 5);
}
TEST_CASE("MockingTests::ThreadDetourMoving", "[mocking]")
{
    // A single registration is handed outward through progressively longer-lived owners. The detour must stay in
    // effect across every move and be removed only when the final owner is destroyed.
    witest::detoured_thread_function<&LocalAddFunction> longLived;
    {
        witest::detoured_thread_function<&LocalAddFunction> midLived;
        {
            witest::detoured_thread_function<&LocalAddFunction> shortLived;
            REQUIRE_SUCCEEDED(shortLived.reset([](int lhs, int rhs) {
                return lhs * rhs;
            }));
            REQUIRE(LocalAddFunction(2, 3) == 6);

            midLived = std::move(shortLived); // move-assignment
        }
        REQUIRE(LocalAddFunction(2, 3) == 6); // outlived 'shortLived'

        longLived = std::move(midLived);
    }
    REQUIRE(LocalAddFunction(2, 3) == 6); // outlived 'midLived'

    {
        witest::detoured_thread_function finalOwner(std::move(longLived)); // move-construction (CTAD)
        REQUIRE(LocalAddFunction(2, 3) == 6);
    }
    REQUIRE(LocalAddFunction(2, 3) == 5); // Reverted back by now
}
TEST_CASE("MockingTests::ThreadDetourSwap", "[mocking]")
{
    {
        witest::detoured_thread_function<&LocalAddFunction> longer;
        REQUIRE_SUCCEEDED(longer.reset([](int lhs, int rhs) {
            return lhs * rhs;
        }));

        {
            witest::detoured_thread_function<&LocalAddFunction> shorter;
            REQUIRE_SUCCEEDED(shorter.reset([](int lhs, int rhs) {
                return 2 * LocalAddFunction(lhs, rhs);
            }));
            // Newest registration runs first: the doubling lambda calls through to the multiplying one
            REQUIRE(LocalAddFunction(2, 3) == 12); // 2 * (2 * 3)

            // Swapping which object owns which registration must not change the order in which the detours run
            shorter.swap(longer);
            REQUIRE(LocalAddFunction(2, 3) == 12);
            longer.swap(shorter); // and back again
            REQUIRE(LocalAddFunction(2, 3) == 12);

            // One more swap (an odd total) leaves the doubling lambda owned by 'longer'. The multiplying lambda is
            // then removed when 'shorter' goes out of scope.
            longer.swap(shorter);
        }
        // Only the doubling lambda is left, and it now reaches the real function
        REQUIRE(LocalAddFunction(2, 3) == 10); // 2 * (2 + 3)
    }
    REQUIRE(LocalAddFunction(2, 3) == 5); // Reverted back by now
}
TEST_CASE("MockingTests::ThreadDetourDestructOutOfOrder", "[mocking]")
{
{
witest::detoured_thread_function<&LocalAddFunction> delayed;
{
// Under "normal" circumstances, registration behaves like a stack and we'll always remove from the head of the list.
// Here we force the first registration to fall out of scope to test handling removal of the non-head element.
witest::detoured_thread_function<&LocalAddFunction> first;
REQUIRE_SUCCEEDED(first.reset([](int lhs, int rhs) {
return lhs * rhs;
}));
REQUIRE_SUCCEEDED(delayed.reset([](int lhs, int rhs) {
return 2 * (lhs + rhs);
}));
REQUIRE(LocalAddFunction(2, 3) == 10); // Should execute 'delayed'
}
REQUIRE(LocalAddFunction(2, 3) == 10); // Should still execute 'delayed'
}
REQUIRE(LocalAddFunction(2, 3) == 5); // Reverted back by now
}
TEST_CASE("MockingTests::GlobalDetourDestructOutOfOrder", "[mocking]")
{
{
witest::detoured_global_function<&LocalAddFunction> delayed;
{
// Under "normal" circumstances, registration behaves like a stack and we'll always remove from the head of the list.
// Here we force the first registration to fall out of scope to test handling removal of the non-head element.
witest::detoured_global_function<&LocalAddFunction> first;
REQUIRE_SUCCEEDED(first.reset([](int lhs, int rhs) {
return lhs * rhs;
}));
REQUIRE_SUCCEEDED(delayed.reset([](int lhs, int rhs) {
return 2 * (lhs + rhs);
}));
REQUIRE(LocalAddFunction(2, 3) == 10); // Should execute 'delayed'
}
REQUIRE(LocalAddFunction(2, 3) == 10); // Should still execute 'delayed'
}
REQUIRE(LocalAddFunction(2, 3) == 5); // Reverted back by now
}
#ifdef WIL_ENABLE_EXCEPTIONS
TEST_CASE("MockingTests::ThreadDetourMultithreaded", "[mocking]")
{
witest::detoured_thread_function<&LocalAddFunction> detour([](int lhs, int rhs) {
return lhs * rhs;
});
int otherThreadResult = 0;
std::thread thread([&] {
otherThreadResult = LocalAddFunction(2, 3);
});
thread.join();
REQUIRE(otherThreadResult == 5);
}
TEST_CASE("MockingTests::GlobalDetourMultithreaded", "[mocking]")
{
witest::detoured_global_function<&LocalAddFunction> detour([](int lhs, int rhs) {
return lhs * rhs;
});
int otherThreadResult = 0;
std::thread thread([&] {
otherThreadResult = LocalAddFunction(2, 3);
});
thread.join();
REQUIRE(otherThreadResult == 6);
}
TEST_CASE("MockingTests::GlobalDetourDestructorRace", "[mocking]")
{
    // Calls 'reset' on a global detour while another thread is still executing inside it, then makes a second call
    // while 'reset' is waiting. The in-flight call is expected to finish inside the detour (2 * 3 = 6). The call
    // started during 'reset' is expected to bypass the detour and run the real function (2 + 3 = 5).
    wil::unique_event detourRunningEvent(wil::EventOptions::None);     // 'detouredThread' has entered the detour
    wil::unique_event nonDetourContinueEvent(wil::EventOptions::None); // 'reset' is waiting; start the second call
    wil::unique_event nonDetourCompleteEvent(wil::EventOptions::None); // the second call has returned

    // Per-thread hook on this (the calling) thread's condition-variable waits, used to detect when 'reset' blocks
    witest::detoured_thread_function<&::SleepConditionVariableSRW> cvWaitDetour(
        [&](PCONDITION_VARIABLE cv, PSRWLOCK lock, DWORD dwMilliseconds, ULONG flags) {
            // This should be called during the call to 'reset' since there's an "active" call
            nonDetourContinueEvent.SetEvent(); // Kick off a non-detoured call
            return ::SleepConditionVariableSRW(cv, lock, dwMilliseconds, flags);
        });

    witest::detoured_global_function<&LocalAddFunction> detour([&](int lhs, int rhs) {
        detourRunningEvent.SetEvent();
        nonDetourCompleteEvent.wait(); // Wait until the non-detoured call is complete (implies we're in 'reset')
        return lhs * rhs;
    });

    int detouredResult = 0;
    std::thread detouredThread([&] {
        detouredResult = LocalAddFunction(2, 3);
    });

    int nonDetouredResult = 0;
    std::thread nonDetouredThread([&] {
        nonDetourContinueEvent.wait(); // Wait until 'reset' is called
        nonDetouredResult = LocalAddFunction(2, 3);
        nonDetourCompleteEvent.SetEvent(); // Let the original call complete, which allows 'reset' to complete
    });

    detourRunningEvent.wait(); // Wait for 'detouredThread' to kick off & invoke the detoured function
    detour.reset();            // Kick off everything to continue

    // NOTE: 'reset' waits for all calls to complete, but it does NOT wait for anything after that point (e.g. the
    // assignment of each result to a variable). Both threads must be joined before their results are read.
    detouredThread.join();
    nonDetouredThread.join();

    // Both threads have been joined, so their results are fully written and safe to check
    REQUIRE(detouredResult == 6);
    REQUIRE(nonDetouredResult == 5);
}
#endif