ebpf-for-windows/netebpfext/user/nmr_impl.cpp

354 строки
12 KiB
C++
Исходник Обычный вид История

Refactor netebpfext in preparation for fuzz testing (#1464) * Refactor netebpfext (#1351) * refactor netebpfext to allow unit and fuzzing Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Normalize parameter casing Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Normalize types Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix release build Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix CMake build Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix NetEbpfExt_headers CMake target Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Rename unit to user Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Switch netebpfext_unit -> netebpfext_user Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Flesh out kernel thunk Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * More canonical naming and W4 clean Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix prefast issue * Fix CMake Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix code analysis Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Co-authored-by: Alan Jowett <alan.jowett@microsoft.com> * Clean up fw_thunk.h (#1369) * Clean up fw_thunk.h Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Fix SAL annotations Signed-off-by: Dave Thaler <dthaler@microsoft.com> Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Netebpf ext unit test (#1365) * Add netebfpext start/stop unit tests Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Update NMR thunk and expand netebpf_ext helper Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Fix code-analysis issues Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Alan Jowett 
<alan.jowett@microsoft.com> Co-authored-by: Alan Jowett <alan.jowett@microsoft.com> * Remove duplicate NMR includes (#1370) * Remove duplicate NMR includes Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback and more header cleanup Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Co-authored-by: Alan Jowett <alan.jowett@microsoft.com> * Refactor NMR user mode to be thread safe (#1371) * Refactor NMR user mode to be thread safe Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Co-authored-by: Alan Jowett <alan.jowett@microsoft.com> * Add a unit test that tests classify_packet in user mode (#1379) * Add packet test scaffolding Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Add test Signed-off-by: Dave Thaler <dthaler@microsoft.com> * More of test Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Add dummy packet to test Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Add dummy program Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Cleanup Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Fix mock NMR bug Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Temporarily add printfs Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Incorporate fix for 1378 Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Fix test bug Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Add a packet-drop case to the test Signed-off-by: Dave Thaler <dthaler@microsoft.com> Signed-off-by: Dave Thaler 
<dthaler@microsoft.com> * Try fixing include path Signed-off-by: Dave Thaler <dthaler@microsoft.com> * Fix double free of workitem Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * Revert unrelated changes in ebpfcore and ebpfapi Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> * PR feedback Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Alan Jowett <alan.jowett@microsoft.com> Signed-off-by: Dave Thaler <dthaler@microsoft.com> Co-authored-by: Alan Jowett <alanjo@microsoft.com> Co-authored-by: Alan Jowett <alan.jowett@microsoft.com>
2022-10-18 22:21:54 +03:00
// Copyright (c) Microsoft Corporation
// SPDX-License-Identifier: MIT
#include "nmr_impl.h"
// Register a provider and try to bind it to the clients that are already registered.
//
// characteristics: provider characteristics stored with the new registration.
// context:         opaque provider context stored with the new registration.
// Returns the handle that identifies the new provider registration.
_nmr::nmr_provider_handle
_nmr::register_provider(_In_ const NPI_PROVIDER_CHARACTERISTICS& characteristics, _In_opt_ const void* context)
{
    // The registration must exist in the providers collection before perform_bind
    // can look it up by handle.
    nmr_provider_handle new_provider = add(providers, characteristics, context);

    // Attempt a bind between the new provider and every entry in the clients collection.
    perform_bind(providers, new_provider, clients);

    return new_provider;
}
// Begin deregistering a provider.
//
// Returns true when at least one unbind is still pending; in that case the
// registration is left in the providers collection and the caller still has to
// wait for the unbind to finish (see wait_for_deregister_provider).
// Returns false when every unbind finished and the registration was removed.
bool
_nmr::deregister_provider(_In_ nmr_provider_handle provider_handle)
{
    // Flag the registration as deregistering so that no new bindings are created for it.
    deactivate(providers, provider_handle);

    // Start unbinding everything currently bound to this provider.
    const bool unbind_pending = perform_unbind(providers, provider_handle);

    if (!unbind_pending) {
        // Nothing is outstanding, so the registration can be dropped right away.
        remove(providers, provider_handle);
    }

    return unbind_pending;
}
// Finish a provider deregistration.
//
// Delegates to remove(), which blocks until the registration's binding count
// reaches zero and then erases it from the providers collection.
void
_nmr::wait_for_deregister_provider(_In_ nmr_provider_handle provider_handle)
{
    remove(providers, provider_handle);
}
// Register a client and try to bind it to the providers that are already registered.
//
// characteristics: client characteristics stored with the new registration.
// context:         opaque client context stored with the new registration.
// Returns the handle that identifies the new client registration.
_nmr::nmr_client_handle
_nmr::register_client(_In_ const NPI_CLIENT_CHARACTERISTICS& characteristics, _In_opt_ const void* context)
{
    // The registration must exist in the clients collection before perform_bind
    // can look it up by handle.
    nmr_client_handle new_client = add(clients, characteristics, context);

    // Attempt a bind between the new client and every entry in the providers collection.
    perform_bind(clients, new_client, providers);

    return new_client;
}
// Begin deregistering a client.
//
// Returns true when at least one unbind is still pending; in that case the
// registration is left in the clients collection and the caller still has to
// wait for the unbind to finish (see wait_for_deregister_client).
// Returns false when every unbind finished and the registration was removed.
bool
_nmr::deregister_client(_In_ nmr_client_handle client_handle)
{
    // Flag the registration as deregistering so that no new bindings are created for it.
    deactivate(clients, client_handle);

    // Start unbinding everything currently bound to this client.
    const bool unbind_pending = perform_unbind(clients, client_handle);

    if (!unbind_pending) {
        // Nothing is outstanding, so the registration can be dropped right away.
        remove(clients, client_handle);
    }

    return unbind_pending;
}
// Finish a client deregistration.
//
// Delegates to remove(), which blocks until the registration's binding count
// reaches zero and then erases it from the clients collection.
void
_nmr::wait_for_deregister_client(_In_ nmr_client_handle client_handle)
{
    remove(clients, client_handle);
}
// Mark the client-side status of a binding as UnbindComplete and, when the
// provider-side status is already UnbindComplete, finish the unbind.
//
// Throws std::runtime_error("invalid handle") if binding_handle is not in bindings.
void
_nmr::binding_detach_client_complete(_In_ nmr_binding_handle binding_handle)
{
    bool both_sides_complete = false;
    {
        std::unique_lock guard(lock);

        auto entry = bindings.find(binding_handle);
        if (entry == bindings.end()) {
            throw std::runtime_error("invalid handle");
        }

        _nmr::binding& record = entry->second;

        // This completion is only expected for a side that was left pending.
        ASSERT(record.client_binding_status == binding_status::UnbindPending);
        record.client_binding_status = binding_status::UnbindComplete;

        both_sides_complete = (record.provider_binding_status == binding_status::UnbindComplete);
    }

    // unbind_complete takes the lock itself, so call it only after the guard is gone.
    if (both_sides_complete) {
        unbind_complete(binding_handle);
    }
}
// Mark the provider-side status of a binding as UnbindComplete and, when the
// client-side status is already UnbindComplete, finish the unbind.
//
// Throws std::runtime_error("invalid handle") if binding_handle is not in bindings.
void
_nmr::binding_detach_provider_complete(_In_ nmr_binding_handle binding_handle)
{
    bool both_sides_complete = false;
    {
        std::unique_lock guard(lock);

        auto entry = bindings.find(binding_handle);
        if (entry == bindings.end()) {
            throw std::runtime_error("invalid handle");
        }

        _nmr::binding& record = entry->second;

        // This completion is only expected for a side that was left pending.
        ASSERT(record.provider_binding_status == binding_status::UnbindPending);
        record.provider_binding_status = binding_status::UnbindComplete;

        both_sides_complete = (record.client_binding_status == binding_status::UnbindComplete);
    }

    // unbind_complete takes the lock itself, so call it only after the guard is gone.
    if (both_sides_complete) {
        unbind_complete(binding_handle);
    }
}
// Attach the client side of a binding to its provider.
//
// binding_handle:           handle of an entry in bindings.
// client_binding_context:   client's per-binding context; stored in the binding and
//                           forwarded to the provider.
// client_dispatch:          client's dispatch table; stored in the binding and
//                           forwarded to the provider.
// provider_binding_context: on success, receives the provider's per-binding context.
// provider_dispatch:        on success, receives the provider's dispatch table.
//
// Returns the status reported by the provider's ProviderAttachClient callback.
// The out parameters are written only when that status is a success code.
// Throws std::runtime_error("invalid handle") if binding_handle is not in bindings.
NTSTATUS
_nmr::client_attach_provider(
    _In_ nmr_binding_handle binding_handle,
    _In_ __drv_aliasesMem const void* client_binding_context,
    _In_ const void* client_dispatch,
    _Outptr_ const void** provider_binding_context,
    _Outptr_ const void** provider_dispatch)
{
    std::unique_lock guard(lock);

    // Look up the binding that the handle refers to.
    auto entry = bindings.find(binding_handle);
    if (entry == bindings.end()) {
        throw std::runtime_error("invalid handle");
    }
    auto& record = entry->second;

    // Remember what the client handed us for this binding.
    record.client_binding_context = client_binding_context;
    record.client_dispatch = client_dispatch;

    // The provider callback below runs without the lock held.
    guard.unlock();

    // The provider writes its context and dispatch table directly into the binding.
    const NTSTATUS attach_status = record.provider.characteristics.ProviderAttachClient(
        const_cast<void*>(binding_handle),
        const_cast<void*>(record.provider.context),
        &record.client.characteristics.ClientRegistrationInstance,
        const_cast<void*>(client_binding_context),
        client_dispatch,
        const_cast<void**>(&record.provider_binding_context),
        &record.provider_dispatch);

    if (!NT_SUCCESS(attach_status)) {
        return attach_status;
    }

    // Hand the provider's per-binding context and dispatch table back to the client.
    *provider_binding_context = record.provider_binding_context;
    *provider_dispatch = record.provider_dispatch;
    return attach_status;
}
// Create a binding between a client and a provider if they are compatible.
//
// Returns std::nullopt when the NPI IDs differ or when either side is
// deregistering. Otherwise records a new binding, bumps both binding counts and
// returns an action that invokes the client's ClientAttachProvider callback.
// This function does not take `lock` itself; perform_bind calls it with the
// lock held and runs the returned action after releasing it.
std::optional<_nmr::pending_action_t>
_nmr::bind(_Inout_ client_registration& client, _Inout_ provider_registration& provider)
{
    // Only a client and a provider that share an NPI ID can be bound.
    if (*client.characteristics.ClientRegistrationInstance.NpiId !=
        *provider.characteristics.ProviderRegistrationInstance.NpiId) {
        return std::nullopt;
    }

    // Never create a binding for a side that is deregistering.
    if (client.deregistering || provider.deregistering) {
        return std::nullopt;
    }

    // Each binding counts as a reference on both registrations; remove() waits
    // for these counts to drop to zero before erasing a registration.
    _InterlockedIncrement64(&client.bindings);
    _InterlockedIncrement64(&provider.bindings);

    nmr_binding_handle new_binding = reinterpret_cast<nmr_binding_handle>(next_handle++);
    bindings.insert({new_binding, {provider, client}});

    // Deferred attach: if the client refuses the provider, tear the binding down again.
    pending_action_t attach_action = [&client, &provider, new_binding, this]() {
        const NTSTATUS attach_status = client.characteristics.ClientAttachProvider(
            const_cast<HANDLE>(new_binding),
            const_cast<void*>(client.context),
            &provider.characteristics.ProviderRegistrationInstance);
        if (NT_SUCCESS(attach_status)) {
            return;
        }
        unbind_complete(new_binding);
    };
    return attach_action;
}
// Tear down a binding: run both cleanup callbacks, drop the references the
// binding held on its client and provider, erase it, and wake any waiters.
// All of this happens while `lock` is held.
//
// Throws std::runtime_error("invalid handle") if binding_handle is not in bindings.
void
_nmr::unbind_complete(_In_ nmr_binding_handle binding_handle)
{
    std::unique_lock guard(lock);

    auto entry = bindings.find(binding_handle);
    if (entry == bindings.end()) {
        throw std::runtime_error("invalid handle");
    }
    auto& record = entry->second;

    // Tell the client that its binding context can be freed.
    record.client.characteristics.ClientCleanupBindingContext(const_cast<void*>(record.client_binding_context));

    // Tell the provider that its binding context can be freed.
    record.provider.characteristics.ProviderCleanupBindingContext(const_cast<void*>(record.provider_binding_context));

    // Release the references taken in bind().
    _InterlockedDecrement64(&record.provider.bindings);
    _InterlockedDecrement64(&record.client.bindings);

    bindings.erase(entry);

    // Wake threads blocked in remove() so they can re-check their binding counts.
    bindings_changed.notify_all();
}
// Start unbinding a binding by invoking both detach callbacks.
//
// Returns true when the unbind is still pending (including the case where the
// binding was not in the Ready state on entry), false when both sides finished
// synchronously and the binding was torn down via unbind_complete().
// Throws std::runtime_error("invalid handle") if binding_handle is not in bindings.
bool
_nmr::begin_unbind(_In_ nmr_binding_handle binding_handle)
{
    std::unique_lock guard(lock);

    auto entry = bindings.find(binding_handle);
    if (entry == bindings.end()) {
        throw std::runtime_error("invalid handle");
    }
    auto& record = entry->second;

    // Only start an unbind when both sides are still Ready; otherwise report pending.
    const bool both_ready = (record.client_binding_status == Ready) && (record.provider_binding_status == Ready);
    if (!both_ready) {
        return true;
    }

    // Both callbacks are invoked with the lock held, client first.
    const NTSTATUS client_detach_provider_status =
        record.client.characteristics.ClientDetachProvider(const_cast<void*>(record.client_binding_context));
    const NTSTATUS provider_detach_client_status =
        record.provider.characteristics.ProviderDetachClient(const_cast<void*>(record.provider_binding_context));

    // NOTE(review): the provider-side status is derived from the client's callback
    // result and vice versa. This matches the asserts in binding_detach_client_complete
    // and binding_detach_provider_complete, so it appears intentional - confirm
    // against the callers of those two functions before changing it.
    if (client_detach_provider_status == STATUS_PENDING) {
        record.provider_binding_status = binding_status::UnbindPending;
    } else {
        record.provider_binding_status = binding_status::UnbindComplete;
    }
    if (provider_detach_client_status == STATUS_PENDING) {
        record.client_binding_status = binding_status::UnbindPending;
    } else {
        record.client_binding_status = binding_status::UnbindComplete;
    }

    const bool finished =
        ((record.client_binding_status == binding_status::UnbindComplete) &&
         (record.provider_binding_status == binding_status::UnbindComplete));

    // unbind_complete takes the lock itself, so release it first.
    guard.unlock();

    if (!finished) {
        return true;
    }
    unbind_complete(binding_handle);
    return false;
}
template <typename collection_t, typename characteristics_t>
collection_t::value_type::first_type
_nmr::add(_Inout_ collection_t& collection, _In_ const characteristics_t& characteristics, _In_ const void* context)
{
std::unique_lock l(lock);
auto handle = reinterpret_cast<collection_t::value_type::first_type>(next_handle++);
collection.insert({handle, {characteristics, context}});
return handle;
}
// Mark a registration as deregistering so that bind() refuses to create new
// bindings for it.
//
// Throws std::runtime_error("invalid handle") if `handle` is not in `collection`.
// The dependent parameter type is spelled with `typename` so the definition
// does not rely on compiler leniency.
template <typename collection_t>
void
_nmr::deactivate(_Inout_ collection_t& collection, _Inout_ typename collection_t::value_type::first_type handle)
{
    std::unique_lock l(lock);
    auto it = collection.find(handle);
    if (it == collection.end()) {
        throw std::runtime_error("invalid handle");
    }
    // Block new bindings.
    it->second.deregistering = true;
}
// Erase a registration from `collection`, first blocking until it has no bindings left.
//
// While the binding count is positive, this waits on bindings_changed (which
// unbind_complete() signals) and releases `lock` for the duration of the wait.
// Throws std::runtime_error("invalid handle") if `handle` is not in `collection`.
// The dependent parameter type is spelled with `typename` so the definition
// does not rely on compiler leniency.
template <typename collection_t>
void
_nmr::remove(_Inout_ collection_t& collection, _In_ typename collection_t::value_type::first_type handle)
{
    std::unique_lock l(lock);
    auto it = collection.find(handle);
    if (it == collection.end()) {
        throw std::runtime_error("invalid handle");
    }
    // Wait for the outstanding bindings, if any, to be torn down.
    if (it->second.bindings > 0) {
        bindings_changed.wait(l, [&]() { return it->second.bindings == 0; });
    }
    collection.erase(it);
}
// Try to bind one registration (the initiator) to every registration in the
// target collection.
//
// initiator_collection: collection that contains the initiator (clients or providers).
// initiator_handle:     handle of the initiator within initiator_collection.
// target_collection:    the opposite collection (providers or clients).
//
// Bindings are created under `lock`; the attach actions returned by bind() are
// collected and invoked only after the lock has been released.
// Throws std::runtime_error("invalid handle") if initiator_handle is unknown.
// Dependent type names are spelled with `typename`, which is required when they
// are used as template arguments.
template <typename initiator_collection_t, typename target_collection_t>
void
_nmr::perform_bind(
    _Inout_ initiator_collection_t& initiator_collection,
    _In_ typename initiator_collection_t::value_type::first_type initiator_handle,
    _Inout_ target_collection_t& target_collection)
{
    using initiator_t = typename initiator_collection_t::value_type::second_type;

    // Queue up the bind for each target to be performed outside the lock.
    std::vector<pending_action_t> pending_actions;

    std::unique_lock l(lock);
    auto it = initiator_collection.find(initiator_handle);
    if (it == initiator_collection.end()) {
        throw std::runtime_error("invalid handle");
    }
    auto& initiator = it->second;

    for (auto& [target_handle, target] : target_collection) {
        // If the initiator is a client, then the target must be a provider.
        if constexpr (std::is_same_v<initiator_t, client_registration>) {
            auto result = bind(initiator, target);
            if (result.has_value()) {
                pending_actions.push_back(result.value());
            }
        }
        // If the initiator is a provider, then the target must be a client.
        if constexpr (std::is_same_v<initiator_t, provider_registration>) {
            auto result = bind(target, initiator);
            if (result.has_value()) {
                pending_actions.push_back(result.value());
            }
        }
    }
    l.unlock();

    // Run the attach callbacks without holding the lock.
    for (auto& action : pending_actions) {
        action();
    }
}
// Start unbinding every binding in which the given registration takes part.
//
// initiator_collection: collection that contains the initiator (clients or providers).
// initiator_handle:     handle of the initiator within initiator_collection.
//
// The matching binding handles are collected under `lock`; begin_unbind() is
// then called for each of them after the lock has been released.
// Returns true if begin_unbind() reported a pending unbind for at least one binding.
// Throws std::runtime_error("invalid handle") if initiator_handle is unknown.
// Dependent type names are spelled with `typename`, which is required when they
// are used as template arguments.
template <typename initiator_collection_t>
bool
_nmr::perform_unbind(
    _Inout_ initiator_collection_t& initiator_collection,
    _In_ typename initiator_collection_t::value_type::first_type initiator_handle)
{
    using initiator_t = typename initiator_collection_t::value_type::second_type;

    bool pending = false;
    std::vector<nmr_binding_handle> handles_to_unbind;

    std::unique_lock l(lock);
    auto it = initiator_collection.find(initiator_handle);
    if (it == initiator_collection.end()) {
        throw std::runtime_error("invalid handle");
    }
    auto& initiator = it->second;

    // Find all the bindings that have the initiator as the client or provider.
    // Registrations are matched by address.
    for (auto& [binding_handle, binding] : bindings) {
        // If the initiator is a client, match on the binding's client.
        if constexpr (std::is_same_v<initiator_t, client_registration>) {
            if (&binding.client == &initiator) {
                handles_to_unbind.push_back(binding_handle);
            }
        }
        // If the initiator is a provider, match on the binding's provider.
        if constexpr (std::is_same_v<initiator_t, provider_registration>) {
            if (&binding.provider == &initiator) {
                handles_to_unbind.push_back(binding_handle);
            }
        }
    }
    l.unlock();

    // begin_unbind takes the lock itself, so it is called with the lock released.
    for (auto& binding_handle : handles_to_unbind) {
        pending |= begin_unbind(binding_handle);
    }
    return pending;
}