зеркало из https://github.com/microsoft/cppwinrt.git
Add `resume_agile` to allow coroutine to resume in any apartment (#1356)
This commit is contained in:
Родитель
23c4ced66a
Коммит
fac72c82b0
|
@ -99,44 +99,49 @@ namespace winrt::impl
|
|||
return async.GetResults();
|
||||
}
|
||||
|
||||
template<typename Awaiter>
|
||||
struct disconnect_aware_handler
|
||||
struct ignore_apartment_context {};
|
||||
|
||||
template<bool preserve_context, typename Awaiter>
|
||||
struct disconnect_aware_handler : private std::conditional_t<preserve_context, resume_apartment_context, ignore_apartment_context>
|
||||
{
|
||||
disconnect_aware_handler(Awaiter* awaiter, coroutine_handle<> handle) noexcept
|
||||
: m_awaiter(awaiter), m_handle(handle) { }
|
||||
|
||||
disconnect_aware_handler(disconnect_aware_handler&& other) noexcept
|
||||
: m_context(std::move(other.m_context))
|
||||
, m_awaiter(std::exchange(other.m_awaiter, {}))
|
||||
, m_handle(std::exchange(other.m_handle, {})) { }
|
||||
disconnect_aware_handler(disconnect_aware_handler&& other) = default;
|
||||
|
||||
~disconnect_aware_handler()
|
||||
{
|
||||
if (m_handle) Complete();
|
||||
if (m_handle.value) Complete();
|
||||
}
|
||||
|
||||
template<typename Async>
|
||||
void operator()(Async&&, Windows::Foundation::AsyncStatus status)
|
||||
{
|
||||
m_awaiter->status = status;
|
||||
m_awaiter.value->status = status;
|
||||
Complete();
|
||||
}
|
||||
|
||||
private:
|
||||
resume_apartment_context m_context;
|
||||
Awaiter* m_awaiter;
|
||||
coroutine_handle<> m_handle;
|
||||
movable_primitive<Awaiter*> m_awaiter;
|
||||
movable_primitive<coroutine_handle<>, nullptr> m_handle;
|
||||
|
||||
void Complete()
|
||||
{
|
||||
if (m_awaiter->suspending.exchange(false, std::memory_order_release))
|
||||
if (m_awaiter.value->suspending.exchange(false, std::memory_order_release))
|
||||
{
|
||||
m_handle = nullptr; // resumption deferred to await_suspend
|
||||
m_handle.value = nullptr; // resumption deferred to await_suspend
|
||||
}
|
||||
else
|
||||
{
|
||||
auto handle = std::exchange(m_handle, {});
|
||||
if (!resume_apartment(m_context, handle, &m_awaiter->failure))
|
||||
auto handle = m_handle.detach();
|
||||
if constexpr (preserve_context)
|
||||
{
|
||||
if (!resume_apartment(*this, handle, &m_awaiter.value->failure))
|
||||
{
|
||||
handle.resume();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
handle.resume();
|
||||
}
|
||||
|
@ -145,7 +150,7 @@ namespace winrt::impl
|
|||
};
|
||||
|
||||
#ifdef WINRT_IMPL_COROUTINES
|
||||
template <typename Async>
|
||||
template <typename Async, bool preserve_context = true>
|
||||
struct await_adapter : cancellable_awaiter<await_adapter<Async>>
|
||||
{
|
||||
await_adapter(Async const& async) : async(async) { }
|
||||
|
@ -185,7 +190,7 @@ namespace winrt::impl
|
|||
private:
|
||||
bool register_completed_callback(coroutine_handle<> handle)
|
||||
{
|
||||
async.Completed(disconnect_aware_handler(this, handle));
|
||||
async.Completed(disconnect_aware_handler<preserve_context, await_adapter>(this, handle));
|
||||
return suspending.exchange(false, std::memory_order_acquire);
|
||||
}
|
||||
|
||||
|
@ -249,6 +254,15 @@ namespace winrt::impl
|
|||
}
|
||||
|
||||
#ifdef WINRT_IMPL_COROUTINES
|
||||
WINRT_EXPORT namespace winrt
|
||||
{
|
||||
template<typename Async, typename = std::enable_if_t<std::is_convertible_v<Async, winrt::Windows::Foundation::IAsyncInfo>>>
|
||||
inline impl::await_adapter<Async, false> resume_agile(Async const& async)
|
||||
{
|
||||
return { async };
|
||||
};
|
||||
}
|
||||
|
||||
WINRT_EXPORT namespace winrt::Windows::Foundation
|
||||
{
|
||||
inline impl::await_adapter<IAsyncAction> operator co_await(IAsyncAction const& async)
|
||||
|
|
|
@ -52,23 +52,14 @@ namespace winrt::impl
|
|||
{
|
||||
resume_apartment_context() = default;
|
||||
resume_apartment_context(std::nullptr_t) : m_context(nullptr), m_context_type(-1) {}
|
||||
resume_apartment_context(resume_apartment_context const&) = default;
|
||||
resume_apartment_context(resume_apartment_context&& other) noexcept :
|
||||
m_context(std::move(other.m_context)), m_context_type(std::exchange(other.m_context_type, -1)) {}
|
||||
resume_apartment_context& operator=(resume_apartment_context const&) = default;
|
||||
resume_apartment_context& operator=(resume_apartment_context&& other) noexcept
|
||||
{
|
||||
m_context = std::move(other.m_context);
|
||||
m_context_type = std::exchange(other.m_context_type, -1);
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool valid() const noexcept
|
||||
{
|
||||
return m_context_type >= 0;
|
||||
return m_context_type.value >= 0;
|
||||
}
|
||||
|
||||
com_ptr<IContextCallback> m_context = try_capture<IContextCallback>(WINRT_IMPL_CoGetObjectContext);
|
||||
int32_t m_context_type = get_apartment_type().first;
|
||||
movable_primitive<int32_t, -1> m_context_type = get_apartment_type().first;
|
||||
};
|
||||
|
||||
inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept
|
||||
|
@ -124,7 +115,7 @@ namespace winrt::impl
|
|||
{
|
||||
return false;
|
||||
}
|
||||
else if (context.m_context_type == 1 /* APTTYPE_MTA */)
|
||||
else if (context.m_context_type.value == 1 /* APTTYPE_MTA */)
|
||||
{
|
||||
resume_background(handle);
|
||||
return true;
|
||||
|
|
|
@ -193,6 +193,25 @@ namespace winrt::impl
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T, auto empty_value = T{}>
|
||||
struct movable_primitive
|
||||
{
|
||||
T value = empty_value;
|
||||
movable_primitive() = default;
|
||||
movable_primitive(T const& init) : value(init) {}
|
||||
movable_primitive(movable_primitive const&) = default;
|
||||
movable_primitive(movable_primitive&& other) :
|
||||
value(other.detach()) {}
|
||||
movable_primitive& operator=(movable_primitive const&) = default;
|
||||
movable_primitive& operator=(movable_primitive&& other)
|
||||
{
|
||||
value = other.detach();
|
||||
return *this;
|
||||
}
|
||||
|
||||
T detach() { return std::exchange(value, empty_value); }
|
||||
};
|
||||
|
||||
template <typename T, typename Enable = void>
|
||||
struct arg
|
||||
{
|
||||
|
|
|
@ -18,13 +18,9 @@ namespace
|
|||
|
||||
static handle signal{ CreateEventW(nullptr, false, false, nullptr) };
|
||||
|
||||
IAsyncAction OtherForegroundAsync()
|
||||
IAsyncAction OtherForegroundAsync(DispatcherQueue dispatcher)
|
||||
{
|
||||
// Simple coroutine that completes on a unique STA thread.
|
||||
|
||||
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
|
||||
auto dispatcher = controller.DispatcherQueue();
|
||||
|
||||
// Simple coroutine that completes on the specified STA thread.
|
||||
co_await resume_foreground(dispatcher);
|
||||
}
|
||||
|
||||
|
@ -35,37 +31,37 @@ namespace
|
|||
co_await resume_background();
|
||||
}
|
||||
|
||||
IAsyncAction ForegroundAsync(DispatcherQueue dispatcher)
|
||||
// Coroutine that completes on dispatcher1, while potentially blocking dispatcher2.
|
||||
IAsyncAction ForegroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
|
||||
{
|
||||
REQUIRE(!is_sta());
|
||||
co_await resume_foreground(dispatcher);
|
||||
co_await resume_foreground(dispatcher1);
|
||||
REQUIRE(is_sta());
|
||||
|
||||
// This exercises one STA thread waiting on another thus one context callback
|
||||
// completing on another.
|
||||
uint32_t id = GetCurrentThreadId();
|
||||
co_await OtherForegroundAsync();
|
||||
co_await OtherForegroundAsync(dispatcher2);
|
||||
REQUIRE(id == GetCurrentThreadId());
|
||||
|
||||
// This just avoids the ForegroundAsync coroutine completing before
|
||||
// BackgroundAsync waits on the result, forcing the Completed handler
|
||||
// to be called on the foreground thread. This just makes the test
|
||||
// success/failure more predictable.
|
||||
// This Sleep() makes it more likely that the caller will actually suspend in await_suspend,
|
||||
// so that the Completed handler triggers a resumption from the dispatcher1 thread.
|
||||
Sleep(100);
|
||||
}
|
||||
|
||||
fire_and_forget SignalFromForeground(DispatcherQueue dispatcher)
|
||||
fire_and_forget SignalFromForeground(DispatcherQueue dispatcher1)
|
||||
{
|
||||
REQUIRE(!is_sta());
|
||||
co_await resume_foreground(dispatcher);
|
||||
co_await resume_foreground(dispatcher1);
|
||||
REQUIRE(is_sta());
|
||||
|
||||
// Previously, this signal was never raised because the foreground thread
|
||||
// was always blocked waiting for ContextCallback to return.
|
||||
// Previously, we never got here because of a deadlock:
|
||||
// The dispatcher1 thread was blocked waiting for ContextCallback to return,
|
||||
// but the ContextCallback is waiting for this event to get signaled.
|
||||
REQUIRE(SetEvent(signal.get()));
|
||||
}
|
||||
|
||||
IAsyncAction BackgroundAsync(DispatcherQueue dispatcher)
|
||||
IAsyncAction BackgroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2)
|
||||
{
|
||||
// Switch to a background (MTA) thread.
|
||||
co_await resume_background();
|
||||
|
@ -76,19 +72,19 @@ namespace
|
|||
co_await OtherBackgroundAsync();
|
||||
REQUIRE(!is_sta());
|
||||
|
||||
// Wait for a coroutine that completes on a foreground (STA) thread.
|
||||
co_await ForegroundAsync(dispatcher);
|
||||
// Wait for a coroutine that completes on a the dispatcher1 thread (STA).
|
||||
co_await ForegroundAsync(dispatcher1, dispatcher2);
|
||||
|
||||
// Resumption should automatically switch to a background (MTA) thread
|
||||
// without blocking the Completed handler (which would in turn block the foreground thread).
|
||||
// without blocking the Completed handler (which would in turn block the dispatcher1 thread).
|
||||
REQUIRE(!is_sta());
|
||||
|
||||
// Attempt to signal from the foreground thread under the assumption
|
||||
// that the foreground thread is not blocked.
|
||||
SignalFromForeground(dispatcher);
|
||||
// Attempt to signal from the dispatcher1 thread under the assumption
|
||||
// that the dispatcher1 thread is not blocked.
|
||||
SignalFromForeground(dispatcher1);
|
||||
|
||||
// Block the background (MTA) thread indefinitely until the signal is raied.
|
||||
// Previously this would deadlock.
|
||||
// Block the background (MTA) thread indefinitely until the signal is raised.
|
||||
// Previously this would hang because the signal never got raised.
|
||||
REQUIRE(WAIT_OBJECT_0 == WaitForSingleObject(signal.get(), INFINITE));
|
||||
}
|
||||
}
|
||||
|
@ -99,9 +95,44 @@ TEST_CASE("await_adapter", "[.clang-crash]")
|
|||
#else
|
||||
TEST_CASE("await_adapter")
|
||||
#endif
|
||||
{
|
||||
auto controller1 = DispatcherQueueController::CreateOnDedicatedThread();
|
||||
auto controller2 = DispatcherQueueController::CreateOnDedicatedThread();
|
||||
|
||||
BackgroundAsync(controller1.DispatcherQueue(), controller2.DispatcherQueue()).get();
|
||||
controller1.ShutdownQueueAsync().get();
|
||||
controller2.ShutdownQueueAsync().get();
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
IAsyncAction OtherBackgroundDelayAsync()
|
||||
{
|
||||
// Simple coroutine that completes on some MTA thread after a brief delay
|
||||
// to ensure that the caller has suspended.
|
||||
|
||||
co_await resume_after(100ms);
|
||||
}
|
||||
|
||||
IAsyncAction AgileAsync(DispatcherQueue dispatcher)
|
||||
{
|
||||
// Switch to the STA.
|
||||
co_await resume_foreground(dispatcher);
|
||||
REQUIRE(is_sta());
|
||||
|
||||
// Ask for agile resumption of a coroutine that finishes on a background thread.
|
||||
// Add a 100ms delay to ensure we suspend.
|
||||
co_await resume_agile(OtherBackgroundDelayAsync());
|
||||
// We should be on the background thread now.
|
||||
REQUIRE(!is_sta());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("await_adapter_agile")
|
||||
{
|
||||
auto controller = DispatcherQueueController::CreateOnDedicatedThread();
|
||||
auto dispatcher = controller.DispatcherQueue();
|
||||
|
||||
BackgroundAsync(dispatcher).get();
|
||||
AgileAsync(dispatcher).get();
|
||||
controller.ShutdownQueueAsync().get();
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче