From 0df9a4434bae488bb23fb223b814ae9a2f41d84b Mon Sep 17 00:00:00 2001 From: Brian Pepin Date: Mon, 13 Apr 2020 10:46:53 -0700 Subject: [PATCH] Fix race condition in termination code (#512) * Fix race in termination * Task Queue should alert monitors during termination --- Source/Task/TaskQueue.cpp | 102 +++++++++++++---------- Tests/UnitTests/Tests/TaskQueueTests.cpp | 32 +++++++ 2 files changed, 88 insertions(+), 46 deletions(-) diff --git a/Source/Task/TaskQueue.cpp b/Source/Task/TaskQueue.cpp index b3f845f1..ed8b8a0c 100644 --- a/Source/Task/TaskQueue.cpp +++ b/Source/Task/TaskQueue.cpp @@ -512,15 +512,16 @@ void __stdcall TaskQueuePortImpl::Terminate( _In_ void* token) { TerminationEntry* term = static_cast(token); + referenced_ptr cxt(term->portContext); // Prevent anything else from coming into the queue - term->portContext->SetStatus(TaskQueuePortStatus::Terminated); + cxt->SetStatus(TaskQueuePortStatus::Terminated); - CancelPendingEntries(term->portContext, true); + CancelPendingEntries(cxt.get(), true); // Are there existing suspends? AddSuspend returns // true if this is the first suspend added. - if (term->portContext->AddSuspend()) + if (cxt->AddSuspend()) { ScheduleTermination(term); } @@ -531,47 +532,7 @@ void __stdcall TaskQueuePortImpl::Terminate( } // Balance our add - term->portContext->RemoveSuspend(); -} - -void TaskQueuePortImpl::ScheduleTermination( - _In_ TerminationEntry* term) -{ - // Insert the termination callback into the queue. Even if the - // main queue is empty, we still signal it and run through - // a cycle. This ensures we flush the queue out with no - // races and that the termination callback happens on the right - // thread. - - if (term->callback != nullptr) - { - // This never fails because we preallocate the - // list node. - m_terminationList->push_back(term, term->node); - term->node = 0; // now owned by the list - } - - // We will not signal until we are marked as terminated. The queue could - // still be moving while we are running this terminate call. - - SignalQueue(); - - // We must ensure we poke the queue threads in case there's - // nothing submitted - switch (m_dispatchMode) - { - case XTaskQueueDispatchMode::SerializedThreadPool: - case XTaskQueueDispatchMode::ThreadPool: - m_threadPool.Submit(); - break; - - case XTaskQueueDispatchMode::Immediate: - DrainOneItem(); - break; - - default: - break; - } + cxt->RemoveSuspend(); } HRESULT __stdcall TaskQueuePortImpl::Attach( @@ -1050,6 +1011,52 @@ void TaskQueuePortImpl::SignalTerminations() } } +void TaskQueuePortImpl::ScheduleTermination( + _In_ TerminationEntry* term) +{ + // Insert the termination callback into the queue. Even if the + // main queue is empty, we still signal it and run through + // a cycle. This ensures we flush the queue out with no + // races and that the termination callback happens on the right + // thread. + + if (term->callback != nullptr) + { + // This never fails because we preallocate the + // list node. + m_terminationList->push_back(term, term->node); + term->node = 0; // now owned by the list + } + + // We will not signal until we are marked as terminated. The queue could + // still be moving while we are running this terminate call. + + SignalQueue(); + + // We must ensure we poke the queue threads in case there's + // nothing submitted + switch (m_dispatchMode) + { + case XTaskQueueDispatchMode::SerializedThreadPool: + case XTaskQueueDispatchMode::ThreadPool: + m_threadPool.Submit(); + break; + + default: + break; + } + + m_attachedContexts.Visit([](ITaskQueuePortContext* portContext) + { + portContext->ItemQueued(); + }); + + if (m_dispatchMode == XTaskQueueDispatchMode::Immediate) + { + DrainOneItem(); + } +} + #ifdef _WIN32 void CALLBACK TaskQueuePortImpl::WaitCallback( _In_ PTP_CALLBACK_INSTANCE instance, @@ -1495,8 +1502,11 @@ void TaskQueueImpl::OnTerminationCallback(_In_ void* context) entry->callback(entry->context); } - entry->owner->m_termination.terminated = true; - entry->owner->m_termination.cv.notify_all(); + { + std::unique_lock lock(entry->owner->m_termination.lock); + entry->owner->m_termination.terminated = true; + entry->owner->m_termination.cv.notify_all(); + } entry->owner->Release(); delete entry; diff --git a/Tests/UnitTests/Tests/TaskQueueTests.cpp b/Tests/UnitTests/Tests/TaskQueueTests.cpp index bc437107..33ee4b27 100644 --- a/Tests/UnitTests/Tests/TaskQueueTests.cpp +++ b/Tests/UnitTests/Tests/TaskQueueTests.cpp @@ -963,6 +963,38 @@ public: VERIFY_IS_TRUE(data.terminationCalled); } + DEFINE_TEST_CASE(VerifyWaitTermination) + { + uint64_t start = GetTickCount64(); + do + { + XTaskQueueHandle queue; + HRESULT hr = XTaskQueueCreate(XTaskQueueDispatchMode::ThreadPool, XTaskQueueDispatchMode::ThreadPool, &queue); + if (FAILED(hr)) VERIFY_FAIL(); + + hr = XTaskQueueTerminate(queue, true, nullptr, nullptr); + if (FAILED(hr)) VERIFY_FAIL(); + + XTaskQueueCloseHandle(queue); + } while(GetTickCount64() - start < 5000); + } + + DEFINE_TEST_CASE(VerifyManualDispatchAtTermination) + { + AutoQueueHandle queue; + + VERIFY_SUCCEEDED(XTaskQueueCreate(XTaskQueueDispatchMode::Manual, XTaskQueueDispatchMode::Manual, &queue)); + + auto dispatcher = [](void*, XTaskQueueHandle queue, XTaskQueuePort port) + { + XTaskQueueDispatch(queue, port, 0); + }; + + XTaskQueueRegistrationToken token; + VERIFY_SUCCEEDED(XTaskQueueRegisterMonitor(queue, nullptr, dispatcher, &token)); + VERIFY_SUCCEEDED(XTaskQueueTerminate(queue, true, nullptr, nullptr)); + } + DEFINE_TEST_CASE(VerifyTerminationOfCompositeQueue) { struct TestData