From 4c3051cfbe678747a1c319d27b34268bde578213 Mon Sep 17 00:00:00 2001 From: Michael Friesen <3517159+mtfriesen@users.noreply.github.com> Date: Wed, 14 Dec 2022 15:55:28 -0500 Subject: [PATCH] Support multiple outstanding receives (#3292) * support multiple outstanding receives * fix leak and sqe initialization * fix indent --- src/platform/datapath_winuser.c | 206 ++++++++++++++++++-------------- 1 file changed, 115 insertions(+), 91 deletions(-) diff --git a/src/platform/datapath_winuser.c b/src/platform/datapath_winuser.c index b0a770138..70d6e8920 100644 --- a/src/platform/datapath_winuser.c +++ b/src/platform/datapath_winuser.c @@ -139,6 +139,11 @@ typedef struct CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT { // CXPLAT_POOL* OwningPool; + // + // The owning per-processor socket. + // + CXPLAT_SOCKET_PROC* SocketProc; + // // The reference count of the receive buffer. // @@ -149,6 +154,25 @@ typedef struct CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT { // CXPLAT_ROUTE Route; + // + // The receive SQE. + // + DATAPATH_IO_SQE Sqe; + + // + // Contains the control data resulting from the receive. + // + char ControlBuf[ + WSA_CMSG_SPACE(sizeof(IN6_PKTINFO)) + // IP_PKTINFO + WSA_CMSG_SPACE(sizeof(DWORD)) + // UDP_COALESCED_INFO + WSA_CMSG_SPACE(sizeof(INT)) // IP_ECN + ]; + + // + // Contains the input and output message data. + // + WSAMSG WsaMsgHdr; + } CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT; // @@ -287,19 +311,6 @@ typedef struct QUIC_CACHEALIGN CXPLAT_SOCKET_PROC { union { // - // Normal TCP/UDP socket data - // - struct { - WSABUF RecvWsaBuf; - char RecvWsaMsgControlBuf[ - WSA_CMSG_SPACE(sizeof(IN6_PKTINFO)) + // IP_PKTINFO - WSA_CMSG_SPACE(sizeof(DWORD)) + // UDP_COALESCED_INFO - WSA_CMSG_SPACE(sizeof(INT)) // IP_ECN - ]; - WSAMSG RecvWsaMsgHdr; - CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* CurrentRecvContext; - }; - // // TCP Listener socket data // struct { @@ -348,6 +359,11 @@ typedef struct CXPLAT_SOCKET { // uint16_t Mtu; + // + // The size of a receive buffer's payload. + // + uint32_t RecvBufLen; + // // Socket type. // @@ -609,7 +625,8 @@ QUIC_STATUS CxPlatSocketStartReceive( _In_ CXPLAT_SOCKET_PROC* SocketProc, _Out_opt_ ULONG* SyncIoResult, - _Out_opt_ uint16_t* SyncBytesReceived + _Out_opt_ uint16_t* SyncBytesReceived, + _Out_opt_ CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT** SyncRecvContext ); QUIC_STATUS @@ -1508,6 +1525,11 @@ CxPlatSocketCreateUdp( } CxPlatRefInitializeEx(&Socket->RefCount, SocketCount); + Socket->RecvBufLen = + (Datapath->Features & CXPLAT_DATAPATH_FEATURE_RECV_COALESCING) ? + MAX_URO_PAYLOAD_LENGTH : + Socket->Mtu - CXPLAT_MIN_IPV4_HEADER_SIZE - CXPLAT_UDP_HEADER_SIZE; + for (uint16_t i = 0; i < SocketCount; i++) { Socket->Processors[i].Parent = Socket; Socket->Processors[i].DatapathProc = NULL; @@ -1516,11 +1538,6 @@ CxPlatSocketCreateUdp( Socket->Processors[i].ShutdownSqe.CqeType = CXPLAT_CQE_TYPE_SOCKET_SHUTDOWN; CxPlatDatapathSqeInitialize( &Socket->Processors[i].IoSqe.DatapathSqe, CXPLAT_CQE_TYPE_SOCKET_IO); - - Socket->Processors[i].RecvWsaBuf.len = - (Datapath->Features & CXPLAT_DATAPATH_FEATURE_RECV_COALESCING) ? - MAX_URO_PAYLOAD_LENGTH : - Socket->Mtu - CXPLAT_MIN_IPV4_HEADER_SIZE - CXPLAT_UDP_HEADER_SIZE; CxPlatRundownInitialize(&Socket->Processors[i].UpcallRundown); } @@ -2001,7 +2018,7 @@ QUIC_DISABLED_BY_FUZZER_END; *NewSocket = Socket; for (uint16_t i = 0; i < SocketCount; i++) { - Status = CxPlatSocketStartReceive(&Socket->Processors[i], NULL, NULL); + Status = CxPlatSocketStartReceive(&Socket->Processors[i], NULL, NULL, NULL); if (QUIC_FAILED(Status)) { goto Error; } @@ -2072,6 +2089,7 @@ CxPlatSocketCreateTcpInternal( AffinitizedProcessor = RemoteAddress ? (((uint16_t)CxPlatProcCurrentNumber()) % Datapath->ProcCount) : 0; Socket->Mtu = CXPLAT_MAX_MTU; + Socket->RecvBufLen = MAX_URO_PAYLOAD_LENGTH; CxPlatRefInitializeEx(&Socket->RefCount, 1); SocketProc = &Socket->Processors[0]; @@ -2079,7 +2097,6 @@ CxPlatSocketCreateTcpInternal( SocketProc->Socket = INVALID_SOCKET; SocketProc->ShutdownSqe.CqeType = CXPLAT_CQE_TYPE_SOCKET_SHUTDOWN; CxPlatDatapathSqeInitialize(&SocketProc->IoSqe.DatapathSqe, CXPLAT_CQE_TYPE_SOCKET_IO); - SocketProc->RecvWsaBuf.len = MAX_URO_PAYLOAD_LENGTH; CxPlatRundownInitialize(&SocketProc->UpcallRundown); SocketProc->Socket = @@ -2660,12 +2677,6 @@ CxPlatSocketContextUninitializeComplete( CxPlatSocketDelete(SocketProc->AcceptSocket); SocketProc->AcceptSocket = NULL; } - - } else if (SocketProc->CurrentRecvContext != NULL) { - CxPlatPoolFree( - SocketProc->CurrentRecvContext->OwningPool, - SocketProc->CurrentRecvContext); - SocketProc->CurrentRecvContext = NULL; } CxPlatRundownUninitialize(&SocketProc->UpcallRundown); @@ -2715,20 +2726,33 @@ CxPlatSocketGetRemoteAddress( CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* CxPlatSocketAllocRecvContext( - _In_ CXPLAT_DATAPATH_PROC* DatapathProc + _In_ CXPLAT_SOCKET_PROC* SocketProc ) { + CXPLAT_DATAPATH_PROC* DatapathProc = SocketProc->DatapathProc; CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext = CxPlatPoolAlloc(&DatapathProc->RecvDatagramPool); if (RecvContext != NULL) { RecvContext->OwningPool = &DatapathProc->RecvDatagramPool; + RecvContext->SocketProc = SocketProc; RecvContext->ReferenceCount = 0; +#if DEBUG + RecvContext->Sqe.IoType = 0; +#endif } return RecvContext; } +void +CxPlatSocketFreeRecvContext( + _In_ CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext + ) +{ + CxPlatPoolFree(RecvContext->OwningPool, RecvContext); +} + QUIC_STATUS CxPlatSocketStartAccept( _In_ CXPLAT_SOCKET_PROC* ListenerSocketProc @@ -2896,7 +2920,7 @@ CxPlatDataPathSocketProcessAcceptCompletion( goto Error; } - if (QUIC_FAILED(CxPlatSocketStartReceive(AcceptSocketProc, NULL, NULL))) { + if (QUIC_FAILED(CxPlatSocketStartReceive(AcceptSocketProc, NULL, NULL, NULL))) { goto Error; } @@ -2970,7 +2994,7 @@ CxPlatDataPathSocketProcessConnectCompletion( // // Try to start a new receive. // - (void)CxPlatSocketStartReceive(SocketProc, NULL, NULL); + (void)CxPlatSocketStartReceive(SocketProc, NULL, NULL, NULL); } else { QuicTraceEvent( @@ -3021,7 +3045,8 @@ QUIC_STATUS CxPlatSocketStartReceive( _In_ CXPLAT_SOCKET_PROC* SocketProc, _Out_opt_ ULONG* SyncIoResult, - _Out_opt_ uint16_t* SyncBytesReceived + _Out_opt_ uint16_t* SyncBytesReceived, + _Out_opt_ CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT** SyncRecvContext ) { QUIC_STATUS Status = QUIC_STATUS_SUCCESS; @@ -3029,44 +3054,42 @@ CxPlatSocketStartReceive( CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext; int Result; DWORD BytesRecv = 0; + WSABUF WsaBuf; CXPLAT_DBG_ASSERT((SyncIoResult != NULL) == (SyncBytesReceived != NULL)); + CXPLAT_DBG_ASSERT((SyncIoResult != NULL) == (SyncRecvContext != NULL)); CXPLAT_DBG_ASSERT(SocketProc->Parent->Type != CXPLAT_SOCKET_TCP_LISTENER); // // Get a receive buffer we can pass to WinSock. // - if (SocketProc->CurrentRecvContext == NULL) { - SocketProc->CurrentRecvContext = - CxPlatSocketAllocRecvContext(SocketProc->DatapathProc); - if (SocketProc->CurrentRecvContext == NULL) { - Status = QUIC_STATUS_OUT_OF_MEMORY; - QuicTraceEvent( - AllocFailure, - "Allocation of '%s' failed. (%llu bytes)", - "Socket Receive Buffer", - SocketProc->Parent->Datapath->RecvPayloadOffset + MAX_URO_PAYLOAD_LENGTH); - goto Error; - } + RecvContext = CxPlatSocketAllocRecvContext(SocketProc); + if (RecvContext == NULL) { + Status = QUIC_STATUS_OUT_OF_MEMORY; + QuicTraceEvent( + AllocFailure, + "Allocation of '%s' failed. (%llu bytes)", + "Socket Receive Buffer", + SocketProc->Parent->Datapath->RecvPayloadOffset + MAX_URO_PAYLOAD_LENGTH); + goto Error; } - RecvContext = SocketProc->CurrentRecvContext; - CxPlatStartDatapathIo(&SocketProc->IoSqe, DATAPATH_IO_RECV); + CxPlatDatapathSqeInitialize(&RecvContext->Sqe.DatapathSqe, CXPLAT_CQE_TYPE_SOCKET_IO); + CxPlatStartDatapathIo(&RecvContext->Sqe, DATAPATH_IO_RECV); - SocketProc->RecvWsaBuf.buf = ((CHAR*)RecvContext) + Datapath->RecvPayloadOffset; + WsaBuf.buf = ((CHAR*)RecvContext) + Datapath->RecvPayloadOffset; + WsaBuf.len = SocketProc->Parent->RecvBufLen; - RtlZeroMemory( - &SocketProc->RecvWsaMsgHdr, - sizeof(SocketProc->RecvWsaMsgHdr)); + RtlZeroMemory(&RecvContext->WsaMsgHdr, sizeof(RecvContext->WsaMsgHdr)); - SocketProc->RecvWsaMsgHdr.name = (PSOCKADDR)&RecvContext->Route.RemoteAddress; - SocketProc->RecvWsaMsgHdr.namelen = sizeof(RecvContext->Route.RemoteAddress); + RecvContext->WsaMsgHdr.name = (PSOCKADDR)&RecvContext->Route.RemoteAddress; + RecvContext->WsaMsgHdr.namelen = sizeof(RecvContext->Route.RemoteAddress); - SocketProc->RecvWsaMsgHdr.lpBuffers = &SocketProc->RecvWsaBuf; - SocketProc->RecvWsaMsgHdr.dwBufferCount = 1; + RecvContext->WsaMsgHdr.lpBuffers = &WsaBuf; + RecvContext->WsaMsgHdr.dwBufferCount = 1; - SocketProc->RecvWsaMsgHdr.Control.buf = SocketProc->RecvWsaMsgControlBuf; - SocketProc->RecvWsaMsgHdr.Control.len = sizeof(SocketProc->RecvWsaMsgControlBuf); + RecvContext->WsaMsgHdr.Control.buf = RecvContext->ControlBuf; + RecvContext->WsaMsgHdr.Control.len = sizeof(RecvContext->ControlBuf); Retry_recv: @@ -3074,20 +3097,20 @@ Retry_recv: Result = SocketProc->Parent->Datapath->WSARecvMsg( SocketProc->Socket, - &SocketProc->RecvWsaMsgHdr, + &RecvContext->WsaMsgHdr, &BytesRecv, - &SocketProc->IoSqe.DatapathSqe.Sqe.Overlapped, + &RecvContext->Sqe.DatapathSqe.Sqe.Overlapped, NULL); } else { DWORD Flags = 0; Result = WSARecv( SocketProc->Socket, - &SocketProc->RecvWsaBuf, + &WsaBuf, 1, &BytesRecv, &Flags, - &SocketProc->IoSqe.DatapathSqe.Sqe.Overlapped, + &RecvContext->Sqe.DatapathSqe.Sqe.Overlapped, NULL); } @@ -3109,8 +3132,9 @@ Retry_recv: if (SyncBytesReceived != NULL) { *SyncBytesReceived = 0; *SyncIoResult = WsaError; + *SyncRecvContext = RecvContext; } - CxPlatStopInlineDatapathIo(&SocketProc->IoSqe); + CxPlatStopInlineDatapathIo(&RecvContext->Sqe); goto Error; } } @@ -3122,9 +3146,9 @@ Retry_recv: // if (!CxPlatEventQEnqueueEx( SocketProc->DatapathProc->EventQ, - &SocketProc->IoSqe.DatapathSqe.Sqe, + &RecvContext->Sqe.DatapathSqe.Sqe, BytesRecv, - &SocketProc->IoSqe.DatapathSqe)) { + &RecvContext->Sqe.DatapathSqe)) { DWORD LastError = GetLastError(); QuicTraceEvent( DatapathErrorStatus, @@ -3140,7 +3164,8 @@ Retry_recv: CXPLAT_DBG_ASSERT(BytesRecv < UINT16_MAX); *SyncBytesReceived = (uint16_t)BytesRecv; *SyncIoResult = NO_ERROR; - CxPlatStopInlineDatapathIo(&SocketProc->IoSqe); + *SyncRecvContext = RecvContext; + CxPlatStopInlineDatapathIo(&RecvContext->Sqe); } Error: @@ -3176,7 +3201,7 @@ CxPlatDataPathUdpRecvComplete( } } else if (IoResult == ERROR_MORE_DATA || - (IoResult == NO_ERROR && SocketProc->RecvWsaBuf.len < NumberOfBytesTransferred)) { + (IoResult == NO_ERROR && SocketProc->Parent->RecvBufLen < NumberOfBytesTransferred)) { CxPlatConvertFromMappedV6(RemoteAddr, RemoteAddr); @@ -3207,9 +3232,9 @@ CxPlatDataPathUdpRecvComplete( BOOLEAN IsCoalesced = FALSE; INT ECN = 0; - for (WSACMSGHDR *CMsg = WSA_CMSG_FIRSTHDR(&SocketProc->RecvWsaMsgHdr); + for (WSACMSGHDR *CMsg = WSA_CMSG_FIRSTHDR(&RecvContext->WsaMsgHdr); CMsg != NULL; - CMsg = WSA_CMSG_NXTHDR(&SocketProc->RecvWsaMsgHdr, CMsg)) { + CMsg = WSA_CMSG_NXTHDR(&RecvContext->WsaMsgHdr, CMsg)) { if (CMsg->cmsg_level == IPPROTO_IPV6) { if (CMsg->cmsg_type == IPV6_PKTINFO) { @@ -3276,7 +3301,7 @@ CxPlatDataPathUdpRecvComplete( CASTED_CLOG_BYTEARRAY(sizeof(*LocalAddr), LocalAddr), CASTED_CLOG_BYTEARRAY(sizeof(*RemoteAddr), RemoteAddr)); - CXPLAT_DBG_ASSERT(NumberOfBytesTransferred <= SocketProc->RecvWsaBuf.len); + CXPLAT_DBG_ASSERT(NumberOfBytesTransferred <= SocketProc->Parent->RecvBufLen); Datagram = (CXPLAT_RECV_DATA*)(RecvContext + 1); @@ -3326,6 +3351,7 @@ CxPlatDataPathUdpRecvComplete( } } + RecvContext = NULL; CXPLAT_DBG_ASSERT(RecvDataChain); #ifdef QUIC_FUZZER @@ -3365,6 +3391,10 @@ CxPlatDataPathUdpRecvComplete( Drop: + if (RecvContext != NULL) { + CxPlatSocketFreeRecvContext(RecvContext); + } + return NeedReceive; } @@ -3372,7 +3402,8 @@ BOOLEAN CxPlatDataPathStartReceive( _In_ CXPLAT_SOCKET_PROC* SocketProc, _Out_opt_ ULONG* IoResult, - _Out_opt_ uint16_t* InlineBytesTransferred + _Out_opt_ uint16_t* InlineBytesTransferred, + _Out_opt_ CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT** RecvContext ) { // @@ -3387,7 +3418,8 @@ CxPlatDataPathStartReceive( CxPlatSocketStartReceive( SocketProc, IoResult, - InlineBytesTransferred); + InlineBytesTransferred, + RecvContext); } while (Status == QUIC_STATUS_OUT_OF_MEMORY && ++RetryCount < MAX_RECV_RETRIES); if (Status == QUIC_STATUS_OUT_OF_MEMORY) { @@ -3460,7 +3492,7 @@ CxPlatDataPathTcpRecvComplete( CASTED_CLOG_BYTEARRAY(sizeof(*LocalAddr), LocalAddr), CASTED_CLOG_BYTEARRAY(sizeof(*RemoteAddr), RemoteAddr)); - CXPLAT_DBG_ASSERT(NumberOfBytesTransferred <= SocketProc->RecvWsaBuf.len); + CXPLAT_DBG_ASSERT(NumberOfBytesTransferred <= SocketProc->Parent->RecvBufLen); CXPLAT_DATAPATH* Datapath = SocketProc->Parent->Datapath; CXPLAT_RECV_DATA* Data = (CXPLAT_RECV_DATA*)(RecvContext + 1); @@ -3478,6 +3510,7 @@ CxPlatDataPathTcpRecvComplete( Data->Allocated = TRUE; Data->QueuedOnConnection = FALSE; RecvContext->ReferenceCount++; + RecvContext = NULL; SocketProc->Parent->Datapath->TcpHandlers.Receive( SocketProc->Parent, @@ -3495,6 +3528,10 @@ CxPlatDataPathTcpRecvComplete( Drop: + if (RecvContext != NULL) { + CxPlatSocketFreeRecvContext(RecvContext); + } + return NeedReceive; } @@ -3527,9 +3564,7 @@ CxPlatRecvDataReturn( // // Clean up the data indication. // - CxPlatPoolFree( - BatchedInternalContext->OwningPool, - BatchedInternalContext); + CxPlatSocketFreeRecvContext(BatchedInternalContext); } BatchedInternalContext = InternalContext; @@ -3544,9 +3579,7 @@ CxPlatRecvDataReturn( // // Clean up the data indication. // - CxPlatPoolFree( - BatchedInternalContext->OwningPool, - BatchedInternalContext); + CxPlatSocketFreeRecvContext(BatchedInternalContext); } } @@ -3556,8 +3589,9 @@ CxPlatDataPathSocketProcessReceiveCompletion( _In_ CXPLAT_CQE* Cqe ) { - CXPLAT_SOCKET_PROC* SocketProc = CONTAINING_RECORD(Sqe, CXPLAT_SOCKET_PROC, IoSqe); - CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext; + CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext = + CONTAINING_RECORD(Sqe, CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT, Sqe); + CXPLAT_SOCKET_PROC* SocketProc = RecvContext->SocketProc; if (!CxPlatRundownAcquire(&SocketProc->UpcallRundown)) { return; @@ -3569,17 +3603,6 @@ CxPlatDataPathSocketProcessReceiveCompletion( for (ULONG InlineReceiveCount = 10; InlineReceiveCount > 0; InlineReceiveCount--) { BOOLEAN StartReceive; - // - // Copy the current receive buffer locally. On error cases, we leave the - // buffer set as the current receive buffer because we are only using it - // inline. Otherwise, we remove it as the current because we are giving - // it to the client. - // - CXPLAT_DBG_ASSERT(SocketProc->CurrentRecvContext != NULL); - RecvContext = SocketProc->CurrentRecvContext; - if (IoResult == NO_ERROR) { - SocketProc->CurrentRecvContext = NULL; - } if (SocketProc->Parent->Type == CXPLAT_SOCKET_UDP) { StartReceive = @@ -3601,7 +3624,8 @@ CxPlatDataPathSocketProcessReceiveCompletion( !CxPlatDataPathStartReceive( SocketProc, InlineReceiveCount > 1 ? &IoResult : NULL, - InlineReceiveCount > 1 ? &BytesTransferred : NULL)) { + InlineReceiveCount > 1 ? &BytesTransferred : NULL, + InlineReceiveCount > 1 ? &RecvContext : NULL)) { break; } } @@ -4243,7 +4267,7 @@ CxPlatFuzzerReceiveInject( } CXPLAT_DATAPATH_INTERNAL_RECV_CONTEXT* RecvContext = - CxPlatSocketAllocRecvContext(Socket->SocketProc->DatapathProc); + CxPlatSocketAllocRecvContext(Socket->SocketProc); if (!RecvContext) { return; }