diff --git a/src/inc/msquic.hpp b/src/inc/msquic.hpp index 85966c4ac..63da51fb8 100644 --- a/src/inc/msquic.hpp +++ b/src/inc/msquic.hpp @@ -83,6 +83,16 @@ struct CxPlatLockDispatch { void Acquire() noexcept { CxPlatDispatchLockAcquire(&Handle); } void Release() noexcept { CxPlatDispatchLockRelease(&Handle); } }; + +struct CxPlatRwLockDispatch { + CXPLAT_DISPATCH_RW_LOCK Handle; + CxPlatRwLockDispatch() noexcept { CxPlatDispatchRwLockInitialize(&Handle); } + ~CxPlatRwLockDispatch() noexcept { CxPlatDispatchRwLockUninitialize(&Handle); } + void AcquireShared() noexcept { CxPlatDispatchRwLockAcquireShared(&Handle); } + void AcquireExclusive() noexcept { CxPlatDispatchRwLockAcquireExclusive(&Handle); } + void ReleaseShared() noexcept { CxPlatDispatchRwLockReleaseShared(&Handle); } + void ReleaseExclusive() noexcept { CxPlatDispatchRwLockReleaseExclusive(&Handle); } +}; #pragma warning(pop) struct CxPlatPool { @@ -134,11 +144,11 @@ public: #ifdef CXPLAT_HASH_MIN_SIZE -struct HashTable { +struct CxPlatHashTable { bool Initialized; CXPLAT_HASHTABLE Table; - HashTable() noexcept { Initialized = CxPlatHashtableInitializeEx(&Table, CXPLAT_HASH_MIN_SIZE); } - ~HashTable() noexcept { if (Initialized) { CxPlatHashtableUninitialize(&Table); } } + CxPlatHashTable() noexcept { Initialized = CxPlatHashtableInitializeEx(&Table, CXPLAT_HASH_MIN_SIZE); } + ~CxPlatHashTable() noexcept { if (Initialized) { CxPlatHashtableUninitialize(&Table); } } void Insert(CXPLAT_HASHTABLE_ENTRY* Entry) noexcept { CxPlatHashtableInsert(&Table, Entry, Entry->Signature, nullptr); } void Remove(CXPLAT_HASHTABLE_ENTRY* Entry) noexcept { CxPlatHashtableRemove(&Table, Entry, nullptr); } CXPLAT_HASHTABLE_ENTRY* Lookup(uint64_t Signature) noexcept { diff --git a/src/perf/lib/PerfClient.h b/src/perf/lib/PerfClient.h index 55906e6a3..8d6233ece 100644 --- a/src/perf/lib/PerfClient.h +++ b/src/perf/lib/PerfClient.h @@ -21,7 +21,7 @@ struct PerfClientConnection { HQUIC Handle {nullptr}; TcpConnection* TcpConn; }; - HashTable StreamTable; + CxPlatHashTable StreamTable; uint64_t StreamsCreated {0}; uint64_t StreamsActive {0}; bool WorkerConnComplete {false}; // Indicated completion to worker diff --git a/src/perf/lib/PerfServer.cpp b/src/perf/lib/PerfServer.cpp index a291f6dae..bccdefe08 100644 --- a/src/perf/lib/PerfServer.cpp +++ b/src/perf/lib/PerfServer.cpp @@ -102,8 +102,8 @@ QUIC_STATUS PerfServer::Start( _In_ CXPLAT_EVENT* _StopEvent ) { - if (!Server.Start(&LocalAddr)) { // TCP - //printf("TCP Server failed to start!\n"); + if (!Server.Start(&LocalAddr)) { + WriteOutput("Warning: TCP Server failed to start!\n"); } StopEvent = _StopEvent; @@ -158,17 +158,10 @@ PerfServer::ListenerCallback( switch (Event->Type) { case QUIC_LISTENER_EVENT_NEW_CONNECTION: { BOOLEAN value = TRUE; - MsQuic->SetParam( - Event->NEW_CONNECTION.Connection, - QUIC_PARAM_CONN_DISABLE_1RTT_ENCRYPTION, - sizeof(value), - &value); + MsQuic->SetParam(Event->NEW_CONNECTION.Connection, QUIC_PARAM_CONN_DISABLE_1RTT_ENCRYPTION, sizeof(value), &value); QUIC_CONNECTION_CALLBACK_HANDLER Handler = [](HQUIC Conn, void* Context, QUIC_CONNECTION_EVENT* Event) -> QUIC_STATUS { - return ((PerfServer*)Context)-> - ConnectionCallback( - Conn, - Event); + return ((PerfServer*)Context)->ConnectionCallback(Conn, Event); }; MsQuic->SetCallbackHandler(Event->NEW_CONNECTION.Connection, (void*)Handler, this); Status = MsQuic->ConnectionSetConfiguration(Event->NEW_CONNECTION.Connection, Configuration); @@ -324,7 +317,7 @@ PerfServer::SendTcpResponse( uint64_t BytesLeftToSend = Context->ResponseSize - Context->BytesSent; - auto SendData = new(std::nothrow) TcpSendData(); + auto SendData = TcpSendDataAllocator.Alloc(); SendData->StreamId = (uint32_t)Context->Entry.Signature; SendData->Open = Context->BytesSent == 0 ? 1 : 0; SendData->Buffer = DataBuffer->Buffer; @@ -352,7 +345,7 @@ PerfServer::TcpAcceptCallback( ) { auto This = (PerfServer*)Server->Context; - Connection->Context = This; + Connection->Context = This->TcpConnectionContextAllocator.Alloc(This); } _IRQL_requires_max_(DISPATCH_LEVEL) @@ -365,6 +358,9 @@ PerfServer::TcpConnectCallback( { if (!IsConnected) { Connection->Close(); + auto This = (TcpConnectionContext*)Connection->Context; + auto Server = This->Server; + Server->TcpConnectionContextAllocator.Free(This); } } @@ -381,10 +377,11 @@ PerfServer::TcpReceiveCallback( uint8_t* Buffer ) { - auto This = (PerfServer*)Connection->Context; + auto This = (TcpConnectionContext*)Connection->Context; + auto Server = This->Server; StreamContext* Stream; if (Open) { - if ((Stream = This->StreamContextAllocator.Alloc(This, false, false)) != nullptr) { + if ((Stream = Server->StreamContextAllocator.Alloc(Server, false, false)) != nullptr) { Stream->Entry.Signature = StreamID; Stream->IdealSendBuffer = 1; // TCP uses send buffering, so just set to 1. This->StreamTable.Insert(&Stream->Entry); @@ -402,30 +399,30 @@ PerfServer::TcpReceiveCallback( } if (Abort) { Stream->ResponseSize = 0; // Reset to make sure we stop sending more - auto SendData = new(std::nothrow) TcpSendData(); + auto SendData = Server->TcpSendDataAllocator.Alloc(); SendData->StreamId = StreamID; SendData->Open = Open ? TRUE : FALSE; SendData->Abort = TRUE; - SendData->Buffer = This->DataBuffer->Buffer; + SendData->Buffer = Server->DataBuffer->Buffer; SendData->Length = 0; Connection->Send(SendData); } else if (Fin) { if (Stream->ResponseSizeSet && Stream->ResponseSize != 0) { - This->SendTcpResponse(Stream, Connection); + Server->SendTcpResponse(Stream, Connection); } else { - auto SendData = new(std::nothrow) TcpSendData(); + auto SendData = Server->TcpSendDataAllocator.Alloc(); SendData->StreamId = StreamID; SendData->Open = TRUE; SendData->Fin = TRUE; - SendData->Buffer = This->DataBuffer->Buffer; + SendData->Buffer = Server->DataBuffer->Buffer; SendData->Length = 0; Connection->Send(SendData); } Stream->RecvShutdown = true; if (Stream->SendShutdown) { This->StreamTable.Remove(&Stream->Entry); - This->StreamContextAllocator.Free(Stream); + Server->StreamContextAllocator.Free(Stream); } } } @@ -438,23 +435,24 @@ PerfServer::TcpSendCompleteCallback( TcpSendData* SendDataChain ) { - auto This = (PerfServer*)Connection->Context; + auto This = (TcpConnectionContext*)Connection->Context; + auto Server = This->Server; while (SendDataChain) { auto Data = SendDataChain; auto Entry = This->StreamTable.Lookup(Data->StreamId); if (Entry) { auto Stream = CXPLAT_CONTAINING_RECORD(Entry, StreamContext, Entry); Stream->OutstandingBytes -= Data->Length; - This->SendTcpResponse(Stream, Connection); + Server->SendTcpResponse(Stream, Connection); if ((Data->Fin || Data->Abort) && !Stream->SendShutdown) { Stream->SendShutdown = true; if (Stream->RecvShutdown) { This->StreamTable.Remove(&Stream->Entry); - This->StreamContextAllocator.Free(Stream); + Server->StreamContextAllocator.Free(Stream); } } } SendDataChain = SendDataChain->Next; - delete Data; + Server->TcpSendDataAllocator.Free(Data); } } diff --git a/src/perf/lib/PerfServer.h b/src/perf/lib/PerfServer.h index a575690cf..3e7b35e5a 100644 --- a/src/perf/lib/PerfServer.h +++ b/src/perf/lib/PerfServer.h @@ -53,6 +53,14 @@ public: private: + struct TcpConnectionContext { + PerfServer* Server; + CxPlatHashTable StreamTable; + TcpConnectionContext(PerfServer* Server) : Server(Server) { } + }; + + CxPlatPoolT TcpConnectionContextAllocator; + struct StreamContext { StreamContext( PerfServer* Server, bool Unidirectional, bool BufferedIo) : @@ -76,6 +84,9 @@ private: QUIC_BUFFER LastBuffer; }; + CxPlatPoolT StreamContextAllocator; + CxPlatPoolT TcpSendDataAllocator; + QUIC_STATUS ListenerCallback( _In_ MsQuicListener* Listener, @@ -137,11 +148,9 @@ private: CXPLAT_EVENT* StopEvent {nullptr}; QUIC_BUFFER* DataBuffer {nullptr}; uint8_t PrintStats {FALSE}; - CxPlatPoolT StreamContextAllocator; TcpEngine Engine; TcpServer Server; - HashTable StreamTable; uint32_t CibirIdLength {0}; uint8_t CibirId[7]; // {offset, values} diff --git a/src/perf/lib/Tcp.cpp b/src/perf/lib/Tcp.cpp index f32d4cdc4..95ca5e23e 100644 --- a/src/perf/lib/Tcp.cpp +++ b/src/perf/lib/Tcp.cpp @@ -357,7 +357,7 @@ TcpConnection::TcpConnection( } } QuicAddrSetPort(&Route.RemoteAddress, ServerPort); - Engine->AddConnection(this, 0); // TODO - Correct index + Engine->AddConnection(this, (uint16_t)CxPlatProcCurrentNumber()); Initialized = true; if (QUIC_FAILED( CxPlatSocketCreateTcp( @@ -389,7 +389,7 @@ TcpConnection::TcpConnection( this); Initialized = true; IndicateAccept = true; - Engine->AddConnection(this, 0); // TODO - Correct index + Engine->AddConnection(this, (uint16_t)CxPlatProcCurrentNumber()); Queue(); }