Fixing more unnecessary interlocked issues

This commit is contained in:
Keith Horton 2024-05-26 22:54:30 -07:00
Родитель 1a2b1de73c
Коммит f1235bb3d1
14 изменённых файлов: 351 добавлений и 342 удалений

Просмотреть файл

@ -248,7 +248,7 @@ void ctsSocket::SetSocket(SOCKET _s) noexcept
m_socket.reset(_s);
}
void ctsSocket::CompleteState(DWORD) noexcept
void ctsSocket::CompleteState(DWORD) const noexcept
{
SetEvent(g_RemovedSocketEvent);
}

Двоичные данные
ctsPerf/ctsPerf.aps

Двоичный файл не отображается.

Двоичные данные
ctsPerf/ctsPerf.rc

Двоичный файл не отображается.

Двоичные данные
ctsTraffic/Resource.rc

Двоичный файл не отображается.

Просмотреть файл

@ -510,7 +510,7 @@ namespace ctsTraffic
{
ctsConfig::g_configSettings->TcpStatusDetails.m_bytesSent.Add(currentTransfer);
}
else
else if (ctsTaskAction::Recv == originalTask.m_ioAction)
{
ctsConfig::g_configSettings->TcpStatusDetails.m_bytesRecv.Add(currentTransfer);
}
@ -834,14 +834,16 @@ namespace ctsTraffic
ctsIoPatternError ctsIoPatternPull::CompleteTaskBackToPattern(const ctsTask& task, uint32_t completedBytes) noexcept
{
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to write to m_statistics (no m_statistics requires Interlocked*)
if (ctsTaskAction::Send == task.m_ioAction)
{
m_statistics.m_bytesSent.Add(completedBytes);
m_statistics.m_bytesSent.AddNoLock(completedBytes);
m_sendBytesInflight -= completedBytes;
}
else
else if (ctsTaskAction::Recv == task.m_ioAction)
{
m_statistics.m_bytesRecv.Add(completedBytes);
m_statistics.m_bytesRecv.AddNoLock(completedBytes);
++m_recvNeeded;
}
@ -896,14 +898,16 @@ namespace ctsTraffic
ctsIoPatternError ctsIoPatternPush::CompleteTaskBackToPattern(const ctsTask& task, uint32_t completedBytes) noexcept
{
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to write to m_statistics (no m_statistics requires Interlocked*)
if (ctsTaskAction::Send == task.m_ioAction)
{
m_statistics.m_bytesSent.Add(completedBytes);
m_statistics.m_bytesSent.AddNoLock(completedBytes);
m_sendBytesInflight -= completedBytes;
}
else
else if (ctsTaskAction::Recv == task.m_ioAction)
{
m_statistics.m_bytesRecv.Add(completedBytes);
m_statistics.m_bytesRecv.AddNoLock(completedBytes);
++m_recvNeeded;
}
@ -981,13 +985,15 @@ namespace ctsTraffic
ctsIoPatternError ctsIoPatternPushPull::CompleteTaskBackToPattern(const ctsTask& task, uint32_t currentTransfer) noexcept
{
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to write to m_statistics (no m_statistics requires Interlocked*)
if (ctsTaskAction::Send == task.m_ioAction)
{
m_statistics.m_bytesSent.Add(currentTransfer);
m_statistics.m_bytesSent.AddNoLock(currentTransfer);
}
else
else if (ctsTaskAction::Recv == task.m_ioAction)
{
m_statistics.m_bytesRecv.Add(currentTransfer);
m_statistics.m_bytesRecv.AddNoLock(currentTransfer);
}
m_ioNeeded = true;
@ -1099,11 +1105,13 @@ namespace ctsTraffic
ctsIoPatternError ctsIoPatternDuplex::CompleteTaskBackToPattern(const ctsTask& task, uint32_t completedBytes) noexcept
{
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to write to m_statistics (no m_statistics requires Interlocked*)
switch (task.m_ioAction)
{
case ctsTaskAction::Send:
{
m_statistics.m_bytesSent.Add(completedBytes);
m_statistics.m_bytesSent.AddNoLock(completedBytes);
m_sendBytesInflight -= completedBytes;
// first, we need to adjust the total back from our over-subscription guard when this task was created
@ -1114,7 +1122,7 @@ namespace ctsTraffic
}
case ctsTaskAction::Recv:
{
m_statistics.m_bytesRecv.Add(completedBytes);
m_statistics.m_bytesRecv.AddNoLock(completedBytes);
++m_recvNeeded;
// first, we need to adjust the total back from our over-subscription guard when this task was created

Просмотреть файл

@ -477,7 +477,9 @@ public:
void PrintStatistics(const ctl::ctSockaddr& localAddr, const ctl::ctSockaddr& remoteAddr) noexcept override
{
// before printing the final results, make sure the timers are stopped
if (0 == GetLastPatternError() && 0 == m_statistics.GetBytesTransferred())
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to read m_statistics (no m_statistics requires Interlocked*)
if (0 == GetLastPatternError() && 0 == m_statistics.GetBytesTransferredNoLock())
{
PRINT_DEBUG_INFO(L"\t\tctsIOPattern::PrintStatistics : reporting a successful IO completion but transfered zero bytes\n");
UpdateLastPatternError(ctsIoPatternError::TooFewBytes);

Просмотреть файл

@ -203,11 +203,13 @@ ctsIoPatternError ctsIoPatternMediaStreamClient::CompleteTaskBackToPattern(const
// track the # of *bits* received
ctsConfig::g_configSettings->UdpStatusDetails.m_bitsReceived.Add(static_cast<int64_t>(completedBytes) * 8LL);
// local stats are called under lock - no need for Interlocked* calls
// because this is called while holding the parent lock (which is the pattern lock)
// -- we don't need to use Interlocked* to write to m_statistics (no m_statistics requires Interlocked*)
m_statistics.m_bitsReceived.AddNoLock(static_cast<int64_t>(completedBytes) * 8LL);
const auto receivedsequenceNumber = ctsMediaStreamMessage::GetSequenceNumberFromTask(task);
if (receivedsequenceNumber > m_finalFrame)
const auto receivedSequenceNumber = ctsMediaStreamMessage::GetSequenceNumberFromTask(task);
if (receivedSequenceNumber > m_finalFrame)
{
ctsConfig::g_configSettings->UdpStatusDetails.m_errorFrames.Increment();
// local stats are called under lock - no need for Interlocked* calls
@ -215,7 +217,7 @@ ctsIoPatternError ctsIoPatternMediaStreamClient::CompleteTaskBackToPattern(const
PRINT_DEBUG_INFO(
L"\t\tctsIOPatternMediaStreamClient recevieved **an unknown** seq number (%lld) (outside the final frame %lu)\n",
receivedsequenceNumber,
receivedSequenceNumber,
m_finalFrame);
}
else
@ -224,7 +226,7 @@ ctsIoPatternError ctsIoPatternMediaStreamClient::CompleteTaskBackToPattern(const
// search our circular queue (starting at the head_entry)
// for the seq number we just received, and if found, tag as received
//
const auto foundSlot = FindSequenceNumber(receivedsequenceNumber);
const auto foundSlot = FindSequenceNumber(receivedSequenceNumber);
if (foundSlot != m_frameEntries.end())
{
const auto bufferedQpc = *reinterpret_cast<int64_t*>(task.m_buffer + 8);
@ -246,7 +248,7 @@ ctsIoPatternError ctsIoPatternMediaStreamClient::CompleteTaskBackToPattern(const
// stop the timer once we receive the last frame
// - it's not perfect (e.g. might have received them out of order)
// - but it will be very close for tracking the total bits/sec
if (static_cast<uint32_t>(receivedsequenceNumber) == m_finalFrame)
if (static_cast<uint32_t>(receivedSequenceNumber) == m_finalFrame)
{
EndStatistics();
}
@ -258,18 +260,18 @@ ctsIoPatternError ctsIoPatternMediaStreamClient::CompleteTaskBackToPattern(const
// local stats are called under lock - no need for Interlocked* calls
m_statistics.m_errorFrames.IncrementNoLock();
if (receivedsequenceNumber < m_headEntry->m_sequenceNumber)
if (receivedSequenceNumber < m_headEntry->m_sequenceNumber)
{
PRINT_DEBUG_INFO(
L"\t\tctsIOPatternMediaStreamClient received **a stale** seq number (%lld) - current seq number (%lld)\n",
receivedsequenceNumber,
receivedSequenceNumber,
m_headEntry->m_sequenceNumber);
}
else
{
PRINT_DEBUG_INFO(
L"\t\tctsIOPatternMediaStreamClient recevieved **a future** seq number (%lld) - head of queue (%lld) tail of queue (%llu)\n",
receivedsequenceNumber,
receivedSequenceNumber,
m_headEntry->m_sequenceNumber,
m_headEntry->m_sequenceNumber + m_frameEntries.size() - 1);
}

Просмотреть файл

@ -117,7 +117,6 @@ namespace ctsTraffic
//
// increment IO count while issuing this Impl, so we hold a ref-count during this out of band callback
// TODO: socket is locked - no need for interlocked
if (lambdaSharedSocket->IncrementIo() > 1)
{
// only running this one task in the OOB callback
@ -140,7 +139,6 @@ namespace ctsTraffic
});
// increment IO count while issuing this Impl, so we hold a ref-count during this out of band callback
// TODO: socket is locked - no need for interlocked
sharedSocket->IncrementIo();
IoImplStatus status = ctsMediaStreamClientIoImpl(
sharedSocket,
@ -247,7 +245,6 @@ namespace ctsTraffic
case ctsTaskAction::Recv:
{
// add-ref the IO about to start
// TODO: the socket locked when ctsMediaStreamClientIoImpl is called
sharedSocket->IncrementIo();
auto callback = [weak_reference = std::weak_ptr(sharedSocket), task](OVERLAPPED* ov) noexcept
{

Просмотреть файл

@ -176,7 +176,6 @@ void ctsReadWriteIocp(const std::weak_ptr<ctsSocket>& weakSocket) noexcept
// else we need to initiate another IO
// add-ref the IO about to start
// TODO: socket is locked - no need for interlocked
ioCount = sharedSocket->IncrementIo();
std::shared_ptr<ctl::ctThreadIocp> ioThreadPool;

Просмотреть файл

@ -730,7 +730,6 @@ namespace ctsTraffic { namespace Rioiocp
// if we're here, we're attempting IO
// pre-incremenet IO tracking on the socket before issuing the IO
// TODO: socket is locked - no need for interlocked
ioRefcount = sharedSocket->IncrementIo();
// must ensure we have room in the RQ & CQ before initiating the IO

Просмотреть файл

@ -305,7 +305,6 @@ namespace ctsTraffic
// where it's handled appropriately
// increment IO for this IO request
// TODO: socket is locked - no need for interlocked
sharedSocket->IncrementIo();
// run the ctsIOTask (next_io) that was scheduled through the TP timer
@ -363,7 +362,6 @@ namespace ctsTraffic
// The IO ref-count must be incremented here to hold an IO count on the socket
// - so that we won't inadvertently call complete_state() while IO is still being scheduled
//
// TODO: socket is locked - no need for interlocked
sharedSocket->IncrementIo();
ctsSendRecvStatus status{};
@ -377,7 +375,6 @@ namespace ctsTraffic
}
// increment IO for each individual request
// TODO: socket is locked - no need for interlocked
sharedSocket->IncrementIo();
if (nextIo.m_timeOffsetMilliseconds > 0)

Просмотреть файл

@ -15,8 +15,6 @@ See the Apache Version 2.0 License for specific language governing permissions a
#include "ctsSocket.h"
// OS headers
#include <Windows.h>
// ctl headers
#include <ctMemoryGuard.hpp>
// project headers
#include "ctsConfig.h"
#include "ctsSocketState.h"
@ -26,335 +24,342 @@ See the Apache Version 2.0 License for specific language governing permissions a
namespace ctsTraffic
{
using namespace ctl;
using namespace std;
using namespace ctl;
using namespace std;
// default values are assigned in the class declaration
// constructor only captures the weak reference to the owning ctsSocketState;
// the SOCKET itself is attached later through SetSocket()
ctsSocket::ctsSocket(weak_ptr<ctsSocketState> parent) noexcept :
m_parent(std::move(parent))
{
}
_No_competing_thread_ ctsSocket::~ctsSocket() noexcept
{
// shutdown() tears down the socket object
Shutdown();
/*
// wait for all IO to be completed before deleting the pattern if this is RIO
// as the IO requests are still using the RIO Buffers we gave it
if (WI_IsFlagSet(ctsConfig::g_configSettings->SocketFlags, WSA_FLAG_REGISTERED_IO))
// default values are assigned in the class declaration
ctsSocket::ctsSocket(weak_ptr<ctsSocketState> parent) noexcept :
m_parent(std::move(parent))
{
while (GetPendedIoCount() != 0)
{
YieldProcessor();
}
}
*/
// if the IO pattern is still alive, must delete it once in the destructor before this object goes away
// - can't reset this in ctsSocket::shutdown since ctsSocket::shutdown can be called from the parent ctsSocketState
// and there may be callbacks still running holding onto a reference to this ctsSocket object
// which causes the potential to AV in the io_pattern
// (a race-condition touching the io_pattern with deleting the io_pattern)
m_pattern.reset();
}
[[nodiscard]] ctsSocket::SocketReference ctsSocket::AcquireSocketLock() const noexcept
{
auto lock = m_lock.lock();
const SOCKET lockedSocketValue = m_socket.get();
return {std::move(lock), lockedSocketValue, m_pattern};
}
void ctsSocket::SetSocket(SOCKET socket) noexcept
{
const auto lock = m_lock.lock();
FAIL_FAST_IF_MSG(
!!m_socket,
"ctsSocket::set_socket trying to set a SOCKET (%Iu) when it has already been set in this object (%Iu)",
socket, m_socket.get());
m_socket.reset(socket);
}
int ctsSocket::CloseSocket(uint32_t errorCode) noexcept
{
const auto lock = m_lock.lock();
auto error = 0l;
if (m_socket)
_No_competing_thread_ ctsSocket::~ctsSocket() noexcept
{
if (errorCode != 0)
{
// always try to RST if we are closing due to an error
// to best-effort notify the opposite endpoint
const wsIOResult result = ctsSetLingerToResetSocket(m_socket.get());
error = static_cast<int>(result.m_errorCode);
}
// shutdown() tears down the socket object
Shutdown();
/*
// wait for all IO to be completed before deleting the pattern if this is RIO
// as the IO requests are still using the RIO Buffers we gave it
if (WI_IsFlagSet(ctsConfig::g_configSettings->SocketFlags, WSA_FLAG_REGISTERED_IO))
{
while (GetPendedIoCount() != 0)
{
YieldProcessor();
}
}
*/
// if the IO pattern is still alive, must delete it once in the destructor before this object goes away
// - can't reset this in ctsSocket::shutdown since ctsSocket::shutdown can be called from the parent ctsSocketState
// and there may be callbacks still running holding onto a reference to this ctsSocket object
// which causes the potential to AV in the io_pattern
// (a race-condition touching the io_pattern with deleting the io_pattern)
m_pattern.reset();
}
// takes the per-socket lock and returns a SocketReference that carries the lock,
// a snapshot of the SOCKET value, and the IO pattern - callers hold the lock for
// as long as the returned SocketReference is alive
[[nodiscard]] ctsSocket::SocketReference ctsSocket::AcquireSocketLock() const noexcept
{
auto lock = m_lock.lock();
const SOCKET lockedSocketValue = m_socket.get();
return {std::move(lock), lockedSocketValue, m_pattern};
}
// transfers ownership of the SOCKET handle into this object (under the socket lock)
// fail-fast if a SOCKET was already set - each ctsSocket owns at most one handle
void ctsSocket::SetSocket(SOCKET socket) noexcept
{
const auto lock = m_lock.lock();
FAIL_FAST_IF_MSG(
!!m_socket,
"ctsSocket::set_socket trying to set a SOCKET (%Iu) when it has already been set in this object (%Iu)",
socket, m_socket.get());
m_socket.reset(socket);
}
// closes the SOCKET (if still open) under the socket lock
// - when closing due to an error, best-effort sets linger to force an RST
//   so the remote endpoint is notified
// - captures final TCP details for the pattern before the handle is destroyed
// returns the error code from the linger/RST attempt (0 on the non-error path)
int ctsSocket::CloseSocket(uint32_t errorCode) noexcept
{
const auto lock = m_lock.lock();
auto error = 0l;
if (m_socket)
{
if (errorCode != 0)
{
// always try to RST if we are closing due to an error
// to best-effort notify the opposite endpoint
const wsIOResult result = ctsSetLingerToResetSocket(m_socket.get());
error = static_cast<int>(result.m_errorCode);
}
if (m_pattern)
{
// if the user asked for TCP details, capture them before we close the socket
m_pattern->PrintTcpInfo(m_localSockaddr, m_targetSockaddr, m_socket.get());
}
m_socket.reset();
}
return error;
}
// lazily creates (on first use) and returns the IOCP threadpool bound to this SOCKET
// - creation is guarded by the same lock that guards the SOCKET
// - returns an empty shared_ptr if the SOCKET was never set or was already closed
const shared_ptr<ctThreadIocp>& ctsSocket::GetIocpThreadpool()
{
// use the SOCKET cs to also guard creation of this TP object
const auto lock = m_lock.lock();
// must verify a valid socket first to avoid racing destroying the iocp shared_ptr as we try to create it here
if (m_socket && !m_tpIocp)
{
// can throw
m_tpIocp = make_shared<ctThreadIocp>(m_socket.get(), ctsConfig::g_configSettings->pTpEnvironment);
}
return m_tpIocp;
}
void ctsSocket::PrintPatternResults(uint32_t lastError) const noexcept
{
if (m_pattern)
{
// if the user asked for TCP details, capture them before we close the socket
m_pattern->PrintTcpInfo(m_localSockaddr, m_targetSockaddr, m_socket.get());
const auto lock = m_lock.lock();
m_pattern->PrintStatistics(
GetLocalSockaddr(),
GetRemoteSockaddr());
}
else
{
// failed during socket creation, bind, or connect
ctsConfig::PrintConnectionResults(lastError);
}
}
void ctsSocket::CompleteState(DWORD errorCode) const noexcept
{
DWORD recordedError = errorCode;
{
const auto lock = m_lock.lock();
FAIL_FAST_IF_MSG(
m_ioCount != 0,
"ctsSocket::complete_state is called with outstanding IO (%d)", m_ioCount);
if (m_pattern)
{
// get the pattern's last_error
recordedError = m_pattern->GetLastPatternError();
// no longer allow any more callbacks
m_pattern->RegisterCallback(nullptr);
}
}
m_socket.reset();
}
return error;
}
const shared_ptr<ctThreadIocp>& ctsSocket::GetIocpThreadpool()
{
// use the SOCKET cs to also guard creation of this TP object
const auto lock = m_lock.lock();
// must verify a valid socket first to avoid racing destroying the iocp shared_ptr as we try to create it here
if (m_socket && !m_tpIocp)
{
m_tpIocp = make_shared<ctThreadIocp>(m_socket.get(), ctsConfig::g_configSettings->pTpEnvironment); // can throw
}
return m_tpIocp;
}
void ctsSocket::PrintPatternResults(uint32_t lastError) const noexcept
{
if (m_pattern)
{
const auto lock = m_lock.lock();
m_pattern->PrintStatistics(
GetLocalSockaddr(),
GetRemoteSockaddr());
}
else
{
// failed during socket creation, bind, or connect
ctsConfig::PrintConnectionResults(lastError);
}
}
void ctsSocket::CompleteState(DWORD errorCode) noexcept
{
const auto currentIoCount = ctMemoryGuardRead(&m_ioCount);
FAIL_FAST_IF_MSG(
currentIoCount != 0,
"ctsSocket::complete_state is called with outstanding IO (%d)", currentIoCount);
DWORD recordedError = errorCode;
if (m_pattern)
{
const auto lock = m_lock.lock();
// get the pattern's last_error
recordedError = m_pattern->GetLastPatternError();
// no longer allow any more callbacks
m_pattern->RegisterCallback(nullptr);
}
if (const auto refParent = m_parent.lock())
{
refParent->CompleteState(recordedError);
}
}
const ctSockaddr& ctsSocket::GetLocalSockaddr() const noexcept
{
return m_localSockaddr;
}
void ctsSocket::SetLocalSockaddr(const ctSockaddr& localAddress) noexcept
{
m_localSockaddr = localAddress;
}
const ctSockaddr& ctsSocket::GetRemoteSockaddr() const noexcept
{
return m_targetSockaddr;
}
void ctsSocket::SetRemoteSockaddr(const ctSockaddr& targetAddress) noexcept
{
m_targetSockaddr = targetAddress;
}
// creates the IO pattern object for this socket and wires it back to this ctsSocket
// - can throw (MakeIoPattern allocates); not noexcept by design
// - when the user did not pin a fixed number of pended sends, starts
//   ideal-send-backlog (ISB) notifications so the pattern can adapt
void ctsSocket::SetIoPattern()
{
m_pattern = ctsIoPattern::MakeIoPattern();
if (!m_pattern)
{
// in test scenarios
return;
}
m_pattern->SetParent(shared_from_this());
if (ctsConfig::g_configSettings->PrePostSends == 0)
{
// user didn't specify a specific # of sends to pend
// start ISB notifications (best effort)
InitiateIsbNotification();
}
}
void ctsSocket::InitiateIsbNotification() noexcept try
{
const auto sharedThis = shared_from_this();
const auto lockedSocket(sharedThis->AcquireSocketLock());
const auto& sharedIocp = GetIocpThreadpool();
OVERLAPPED* ov = sharedIocp->new_request([weak_this_ptr = std::weak_ptr(shared_from_this())](OVERLAPPED* pOverlapped) noexcept {
const auto lambdaSharedThis = weak_this_ptr.lock();
if (!lambdaSharedThis)
// don't hold any locks when calling back into the parent
if (const auto refParent = m_parent.lock())
{
refParent->CompleteState(recordedError);
}
}
const ctSockaddr& ctsSocket::GetLocalSockaddr() const noexcept
{
return m_localSockaddr;
}
void ctsSocket::SetLocalSockaddr(const ctSockaddr& localAddress) noexcept
{
m_localSockaddr = localAddress;
}
const ctSockaddr& ctsSocket::GetRemoteSockaddr() const noexcept
{
return m_targetSockaddr;
}
void ctsSocket::SetRemoteSockaddr(const ctSockaddr& targetAddress) noexcept
{
m_targetSockaddr = targetAddress;
}
void ctsSocket::SetIoPattern()
{
m_pattern = ctsIoPattern::MakeIoPattern();
if (!m_pattern)
{
// in test scenarios
return;
}
DWORD gle = NO_ERROR;
const auto lambdaLockedSocket(lambdaSharedThis->AcquireSocketLock());
const auto lambdaSocket = lambdaLockedSocket.GetSocket();
if (lambdaSocket != INVALID_SOCKET)
m_pattern->SetParent(shared_from_this());
if (ctsConfig::g_configSettings->PrePostSends == 0)
{
DWORD transferred; // unneeded
DWORD flags; // unneeded
if (!WSAGetOverlappedResult(lambdaSocket, pOverlapped, &transferred, FALSE, &flags))
// user didn't specify a specific # of sends to pend
// start ISB notifications (best effort)
InitiateIsbNotification();
}
}
void ctsSocket::InitiateIsbNotification() noexcept try
{
const auto sharedThis = shared_from_this();
const auto lockedSocket(sharedThis->AcquireSocketLock());
const auto& sharedIocp = GetIocpThreadpool();
OVERLAPPED* ov = sharedIocp->new_request(
[weak_this_ptr = std::weak_ptr(shared_from_this())](OVERLAPPED* pOverlapped) noexcept
{
gle = WSAGetLastError();
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
const auto lambdaSharedThis = weak_this_ptr.lock();
if (!lambdaSharedThis)
{
// aborted is expected whenever the socket is closed
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_CHANGE)", gle);
return;
}
DWORD gle = NO_ERROR;
const auto lambdaLockedSocket(lambdaSharedThis->AcquireSocketLock());
const auto lambdaSocket = lambdaLockedSocket.GetSocket();
if (lambdaSocket != INVALID_SOCKET)
{
DWORD transferred; // unneeded
DWORD flags; // unneeded
if (!WSAGetOverlappedResult(lambdaSocket, pOverlapped, &transferred, FALSE, &flags))
{
gle = WSAGetLastError();
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
{
// aborted is expected whenever the socket is closed
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_CHANGE)", gle);
}
}
}
else
{
gle = WSAECANCELLED;
}
if (gle == NO_ERROR)
{
// if the request succeeded, handle the ISB change
// and issue the next
ULONG isb;
if (0 == idealsendbacklogquery(lambdaSocket, &isb))
{
const auto lock = lambdaSharedThis->m_lock.lock();
PRINT_DEBUG_INFO(L"\t\tctsSocket::process_isb_notification : setting ISB to %u bytes\n", isb);
lambdaSharedThis->m_pattern->SetIdealSendBacklog(isb);
}
else
{
gle = WSAGetLastError();
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
{
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_QUERY)", gle);
}
}
lambdaSharedThis->InitiateIsbNotification();
}
}); // lambda for new_request
const auto localSocket = lockedSocket.GetSocket();
if (localSocket != INVALID_SOCKET)
{
if (SOCKET_ERROR == idealsendbacklognotify(localSocket, ov, nullptr))
{
const auto gle = WSAGetLastError();
// expect this to be pending
if (gle != WSA_IO_PENDING)
{
// if the ISB notification failed, tell the TP to no longer track that IO
sharedIocp->cancel_request(ov);
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
{
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_CHANGE)", gle);
}
}
}
}
else
{
gle = WSAECANCELLED;
// there wasn't a SOCKET to initiate the ISB notification, tell the TP to no longer track that IO
sharedIocp->cancel_request(ov);
}
}
catch (...)
{
ctsConfig::PrintThrownException();
}
if (gle == NO_ERROR)
// AcquireSocketLock() must have been called when calling this
// - the lock guards m_ioCount, so no interlocked operation is required
// returns the updated count of outstanding IO on this socket
int32_t ctsSocket::IncrementIo() noexcept
{
    return ++m_ioCount;
}
// AcquireSocketLock() must have been called when calling this
// - the lock guards m_ioCount, so no interlocked operation is required
// fail-fast if the count would go negative (an unmatched Decrement is a bug)
// returns the updated count of outstanding IO on this socket
int32_t ctsSocket::DecrementIo() noexcept
{
    const auto updatedIoCount = --m_ioCount;
    FAIL_FAST_IF_MSG(
        updatedIoCount < 0,
        "ctsSocket: io count fell below zero (%d)\n", updatedIoCount);
    return updatedIoCount;
}
// AcquireSocketLock() must have been called when calling this
// returns the current count of outstanding (pended) IO operations on this socket
int32_t ctsSocket::GetPendedIoCount() const noexcept
{
return m_ioCount;
}
// tears down the socket: closes the SOCKET handle first so outstanding IO
// completes/fails, then destroys the threadpool objects (order matters - see below)
void ctsSocket::Shutdown() noexcept
{
// close the socket to trigger IO to complete/shutdown
CloseSocket();
// Must destroy these threadpool objects outside the CS to prevent a deadlock
// - from when worker threads attempt to callback this ctsSocket object when IO completes
// Must wait for the threadpool from this method when ctsSocketState calls ctsSocket::shutdown
// - instead of calling this from the destructor of ctsSocket, as the final reference
// to this ctsSocket might be from a TP thread - in which case this destructor will deadlock
// (it will wait for all TP threads to exit, but it is using/blocking on of those TP threads)
m_tpIocp.reset();
m_tpTimer.reset();
}
///
/// SetTimer schedules the callback function to be invoked with the given ctsSocket and ctsIOTask
/// - note that the timer is single-shot (SetThreadpoolTimer is given a period of 0); calling SetTimer again replaces the previously stored task and callback
/// - can throw under low resource conditions
///
void ctsSocket::SetTimer(const ctsTask& task, function<void(weak_ptr<ctsSocket>, const ctsTask&)>&& func)
{
const auto lock = m_lock.lock();
m_timerTask = task;
m_timerCallback = std::move(func);
if (!m_tpTimer)
{
// if the request succeeded, handle the ISB change
// and issue the next
ULONG isb;
if (0 == idealsendbacklogquery(lambdaSocket, &isb))
{
const auto lock = lambdaSharedThis->m_lock.lock();
PRINT_DEBUG_INFO(L"\t\tctsSocket::process_isb_notification : setting ISB to %u bytes\n", isb);
lambdaSharedThis->m_pattern->SetIdealSendBacklog(isb);
}
else
{
gle = WSAGetLastError();
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
{
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_QUERY)", gle);
}
}
lambdaSharedThis->InitiateIsbNotification();
m_tpTimer.reset(
CreateThreadpoolTimer(ThreadPoolTimerCallback, this, ctsConfig::g_configSettings->pTpEnvironment));
THROW_LAST_ERROR_IF(!m_tpTimer);
}
}); // lambda for new_request
const auto localSocket = lockedSocket.GetSocket();
if (localSocket != INVALID_SOCKET)
FILETIME relativeTimeout = wil::filetime::from_int64(
-1 * wil::filetime_duration::one_millisecond * task.m_timeOffsetMilliseconds);
SetThreadpoolTimer(m_tpTimer.get(), &relativeTimeout, 0, 0);
}
void NTAPI ctsSocket::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE, PVOID pContext, PTP_TIMER)
{
if (SOCKET_ERROR == idealsendbacklognotify(localSocket, ov, nullptr))
auto* pThis = static_cast<ctsSocket*>(pContext);
ctsTask task;
function<void(weak_ptr<ctsSocket>, const ctsTask&)> callback;
{
const auto gle = WSAGetLastError();
// expect this to be pending
if (gle != WSA_IO_PENDING)
{
// if the ISB notification failed, tell the TP to no longer track that IO
sharedIocp->cancel_request(ov);
if (gle != ERROR_OPERATION_ABORTED && gle != WSAEINTR)
{
ctsConfig::PrintErrorIfFailed("WSAIoctl(SIO_IDEAL_SEND_BACKLOG_CHANGE)", gle);
}
}
const auto lock = pThis->m_lock.lock();
task = pThis->m_timerTask;
callback = std::move(pThis->m_timerCallback);
}
// invoke the callback outside the lock
callback(pThis->weak_from_this(), task);
}
else
{
// there wasn't a SOCKET to initiate the ISB notification, tell the TP to no longer track that IO
sharedIocp->cancel_request(ov);
}
}
catch (...)
{
ctsConfig::PrintThrownException();
}
int32_t ctsSocket::IncrementIo() noexcept
{
// not apply memory guards - we are holding the socket lock
++m_ioCount;
return m_ioCount;
}
int32_t ctsSocket::DecrementIo() noexcept
{
// not apply memory guards - we are holding the socket lock
--m_ioCount;
FAIL_FAST_IF_MSG(
m_ioCount < 0,
"ctsSocket: io count fell below zero (%d)\n", m_ioCount);
return m_ioCount;
}
int32_t ctsSocket::GetPendedIoCount() noexcept
{
// using Interlocked* to read to force CPU synchronization
// even though writes do not need to be synchronized (we are OK having a small drift with the actual count)
return ctMemoryGuardRead(&m_ioCount);
}
void ctsSocket::Shutdown() noexcept
{
// close the socket to trigger IO to complete/shutdown
CloseSocket();
// Must destroy these threadpool objects outside the CS to prevent a deadlock
// - from when worker threads attempt to callback this ctsSocket object when IO completes
// Must wait for the threadpool from this method when ctsSocketState calls ctsSocket::shutdown
// - instead of calling this from the destructor of ctsSocket, as the final reference
// to this ctsSocket might be from a TP thread - in which case this destructor will deadlock
// (it will wait for all TP threads to exit, but it is using/blocking on of those TP threads)
m_tpIocp.reset();
m_tpTimer.reset();
}
///
/// SetTimer schedules the callback function to be invoked with the given ctsSocket and ctsIOTask
/// - note that the timer
/// - can throw under low resource conditions
///
void ctsSocket::SetTimer(const ctsTask& task, function<void(weak_ptr<ctsSocket>, const ctsTask&)>&& func)
{
const auto lock = m_lock.lock();
m_timerTask = task;
m_timerCallback = std::move(func);
if (!m_tpTimer)
{
m_tpTimer.reset(CreateThreadpoolTimer(ThreadPoolTimerCallback, this, ctsConfig::g_configSettings->pTpEnvironment));
THROW_LAST_ERROR_IF(!m_tpTimer);
}
FILETIME relativeTimeout = wil::filetime::from_int64(-1 * wil::filetime_duration::one_millisecond * task.m_timeOffsetMilliseconds);
SetThreadpoolTimer(m_tpTimer.get(), &relativeTimeout, 0, 0);
}
void NTAPI ctsSocket::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE, PVOID pContext, PTP_TIMER)
{
auto* pThis = static_cast<ctsSocket*>(pContext);
ctsTask task{};
function<void(weak_ptr<ctsSocket>, const ctsTask&)> callback;
{
const auto lock = pThis->m_lock.lock();
task = pThis->m_timerTask;
callback = std::move(pThis->m_timerCallback);
}
// invoke the callback outside the lock
callback(pThis->weak_from_this(), task);
}
} // namespace

Просмотреть файл

@ -123,7 +123,7 @@ public:
// The only successful DWORD value is NO_ERROR (0)
// Any other DWORD indicates error
//
void CompleteState(DWORD errorCode) noexcept;
void CompleteState(DWORD errorCode) const noexcept;
//
// Gets/Sets the local address of the SOCKET
@ -147,7 +147,7 @@ public:
//
int32_t IncrementIo() noexcept;
int32_t DecrementIo() noexcept;
int32_t GetPendedIoCount() noexcept;
int32_t GetPendedIoCount() const noexcept;
//
// method for the parent to instruct the ctsSocket to print the connection data
@ -186,11 +186,9 @@ private:
void InitiateIsbNotification() noexcept;
// private members for this socket instance
// mutable is required to EnterCS/LeaveCS in const methods
// mutable is required to EnterCriticalSection/LeaveCriticalSection in const methods
mutable wil::critical_section m_lock{ctsConfig::ctsConfigSettings::c_CriticalSectionSpinlock};
_Guarded_by_(m_lock) wil::unique_socket m_socket;
_Interlocked_ long m_ioCount = 0L;
// maintain a weak-reference to the parent and child
std::weak_ptr<ctsSocketState> m_parent;
@ -206,6 +204,8 @@ private:
ctl::ctSockaddr m_localSockaddr;
ctl::ctSockaddr m_targetSockaddr;
long m_ioCount = 0L;
static void NTAPI ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE, PVOID pContext, PTP_TIMER);
};
} // namespace

Просмотреть файл

@ -305,9 +305,9 @@ namespace ctsTraffic { namespace ctsStatistics
ctsUdpStatistics& operator=(ctsUdpStatistics&&) = delete;
// currently only called by the UDP client - only tracking the receives
[[nodiscard]] int64_t GetBytesTransferred() const noexcept
[[nodiscard]] int64_t GetBytesTransferredNoLock() const noexcept
{
return m_bitsReceived.GetValue() / 8;
return m_bitsReceived.GetValueNoLock() / 8;
}
//
@ -371,9 +371,9 @@ namespace ctsTraffic { namespace ctsStatistics
ctsTcpStatistics operator=(const ctsTcpStatistics&) = delete;
ctsTcpStatistics operator=(ctsTcpStatistics&&) = delete;
[[nodiscard]] int64_t GetBytesTransferred() const noexcept
[[nodiscard]] int64_t GetBytesTransferredNoLock() const noexcept
{
return m_bytesRecv.GetValue() + m_bytesSent.GetValue();
return m_bytesRecv.GetValueNoLock() + m_bytesSent.GetValueNoLock();
}
//