diff --git a/external/Catch2 b/external/Catch2 index b5be64204..04382af4c 160000 --- a/external/Catch2 +++ b/external/Catch2 @@ -1 +1 @@ -Subproject commit b5be6420427331c2bf1160fa7eedf0f5f1145b6b +Subproject commit 04382af4c640d801d332191528d11bd7ae5819f3 diff --git a/libs/platform/user/kernel_um.cpp b/libs/platform/user/kernel_um.cpp index 7f788e758..3b116369c 100644 --- a/libs/platform/user/kernel_um.cpp +++ b/libs/platform/user/kernel_um.cpp @@ -2,18 +2,147 @@ // SPDX-License-Identifier: MIT #include +#include #include +#include #include "ebpf_platform.h" #include "kernel_um.h" -typedef struct _mock_rundown_ref +/*** + * @brief This following class implements a mock of the Windows Kernel's rundown reference implementation. + * 1) It uses a map to track the number of references to a given EX_RUNDOWN_REF structure. + * 2) The address of the EX_RUNDOWN_REF structure is used as the key to the map. + * 3) A single condition variable is used to wait for the ref count of any EX_RUNDOWN_REF structure to reach 0. + * 4) The class is a singleton and is created during static initialization and destroyed during static destruction. + */ +typedef class _rundown_ref_table { - std::mutex lock; - std::condition_variable cv; - size_t count = 0; - bool rundown_in_progress = false; -} mock_rundown_ref; + private: + static std::unique_ptr<_rundown_ref_table> _instance; + + public: + // The constructor and destructor should be private to ensure that the class is a singleton, but that is not + // possible because the singleton instance is stored in a unique_ptr which requires the constructor to be public. + // The instance of the class can be accessed using the instance() method. + _rundown_ref_table() = default; + ~_rundown_ref_table() = default; + + /** + * @brief Get the singleton instance of the rundown ref table. + * + * @return The singleton instance of the rundown ref table. + */ + static _rundown_ref_table& + instance() + { + return *_instance; + } + + /** + * @brief Initialize the rundown ref table entry for the given context. + * + * @param[in] context The address of a EX_RUNDOWN_REF structure. + */ + void + initialize_rundown_ref(_In_ const void* context) + { + std::unique_lock lock(_lock); + + // Re-initialize the entry if it already exists. + if (_rundown_ref_counts.find((uint64_t)context) != _rundown_ref_counts.end()) { + _rundown_ref_counts.erase((uint64_t)context); + } + + _rundown_ref_counts[(uint64_t)context] = {false, 0}; + } + + /** + * @brief Acquire a rundown ref for the given context. + * + * @param[in] context The address of a EX_RUNDOWN_REF structure. + * @retval true Rundown has not started. + * @retval false Rundown has started. + */ + bool + acquire_rundown_ref(_In_ const void* context) + { + std::unique_lock lock(_lock); + + // Fail if the entry is not initialized. + if (_rundown_ref_counts.find((uint64_t)context) == _rundown_ref_counts.end()) { + throw std::runtime_error("rundown ref table not initialized"); + } + + // Check if the entry is already rundown. + if (std::get<0>(_rundown_ref_counts[(uint64_t)context])) { + return false; + } + + // Increment the ref count if the entry is not rundown. + std::get<1>(_rundown_ref_counts[(uint64_t)context])++; + + return true; + } + + /** + * @brief Release a rundown ref for the given context. + * + * @param[in] context The address of a EX_RUNDOWN_REF structure. + */ + void + release_rundown_ref(_In_ const void* context) + { + std::unique_lock lock(_lock); + + // Fail if the entry is not initialized. + if (_rundown_ref_counts.find((uint64_t)context) == _rundown_ref_counts.end()) { + throw std::runtime_error("rundown ref table not initialized"); + } + + if (std::get<1>(_rundown_ref_counts[(uint64_t)context]) == 0) { + throw std::runtime_error("rundown ref table already released"); + } + + std::get<1>(_rundown_ref_counts[(uint64_t)context])--; + + if (std::get<1>(_rundown_ref_counts[(uint64_t)context]) == 0) { + _rundown_ref_cv.notify_all(); + } + } + + /** + * @brief Wait for the rundown ref count to reach 0 for the given context. + * + * @param[in] context The address of a EX_RUNDOWN_REF structure. + */ + void + wait_for_rundown_ref(_In_ const void* context) + { + std::unique_lock lock(_lock); + + // Fail if the entry is not initialized. + if (_rundown_ref_counts.find((uint64_t)context) == _rundown_ref_counts.end()) { + throw std::runtime_error("rundown ref table not initialized"); + } + + auto& [rundown, ref_count] = _rundown_ref_counts[(uint64_t)context]; + rundown = true; + // Wait for the ref count to reach 0. + _rundown_ref_cv.wait(lock, [&ref_count] { return ref_count == 0; }); + } + + private: + std::mutex _lock; + std::map> _rundown_ref_counts; + std::condition_variable _rundown_ref_cv; +} rundown_ref_table_t; + +/** + * @brief The singleton instance of the rundown ref table. Created during static initialization and destroyed during + * static destruction. + */ +std::unique_ptr<_rundown_ref_table> rundown_ref_table_t::_instance = std::make_unique(); typedef struct _IO_WORKITEM { @@ -66,39 +195,33 @@ unsigned long __cdecl DbgPrintEx( void ExInitializeRundownProtection(_Out_ EX_RUNDOWN_REF* rundown_ref) { - rundown_ref->inner = new mock_rundown_ref(); +#pragma warning(push) +#pragma warning(suppress : 6001) // Uninitialized memory. The rundown_ref is used as a key in a map and is not + // dereferenced. + rundown_ref_table_t::instance().initialize_rundown_ref(rundown_ref); +#pragma warning(pop) } void ExWaitForRundownProtectionRelease(_Inout_ EX_RUNDOWN_REF* rundown_ref) { - auto& rundown = *rundown_ref->inner; - std::unique_lock l(rundown.lock); - rundown.rundown_in_progress = true; - rundown_ref->inner->cv.wait(l, [&] { return rundown.count == 0; }); + rundown_ref_table_t::instance().wait_for_rundown_ref(rundown_ref); } BOOLEAN ExAcquireRundownProtection(_Inout_ EX_RUNDOWN_REF* rundown_ref) { - auto& rundown = *rundown_ref->inner; - std::unique_lock l(rundown.lock); - if (rundown.rundown_in_progress) { - return FALSE; - } else { - rundown.count++; + if (rundown_ref_table_t::instance().acquire_rundown_ref(rundown_ref)) { return TRUE; + } else { + return FALSE; } } void ExReleaseRundownProtection(_Inout_ EX_RUNDOWN_REF* rundown_ref) { - auto& rundown = *rundown_ref->inner; - std::unique_lock l(rundown.lock); - assert(rundown.count > 0); - rundown.count--; - rundown.cv.notify_all(); + rundown_ref_table_t::instance().release_rundown_ref(rundown_ref); } _Acquires_exclusive_lock_(push_lock->lock) void ExAcquirePushLockExclusiveEx( diff --git a/libs/platform/user/kernel_um.h b/libs/platform/user/kernel_um.h index cb3f8bf37..cf132253b 100644 --- a/libs/platform/user/kernel_um.h +++ b/libs/platform/user/kernel_um.h @@ -53,7 +53,7 @@ extern "C" } EX_SPIN_LOCK; typedef struct _EX_RUNDOWN_REF { - struct _mock_rundown_ref* inner; + void* reserved; } EX_RUNDOWN_REF; typedef struct _IO_WORKITEM IO_WORKITEM, *PIO_WORKITEM; typedef void