diff --git a/src/core/library.c b/src/core/library.c index c66bc5876..253ab494a 100644 --- a/src/core/library.c +++ b/src/core/library.c @@ -794,7 +794,8 @@ QuicLibrarySetGlobalParam( Status = QUIC_STATUS_INVALID_PARAMETER; break; } - int32_t Value = *(int32_t*)Buffer; + int32_t Value; + CxPlatCopyMemory(&Value, Buffer, sizeof(Value)); if (Value < 0) { Status = QUIC_STATUS_INVALID_PARAMETER; break; @@ -809,7 +810,8 @@ QuicLibrarySetGlobalParam( Status = QUIC_STATUS_INVALID_PARAMETER; break; } - int32_t Value = *(int32_t*)Buffer; + int32_t Value; + CxPlatCopyMemory(&Value, Buffer, sizeof(Value)); if (Value < 0) { Status = QUIC_STATUS_INVALID_PARAMETER; break; diff --git a/src/core/send.c b/src/core/send.c index 396560cdc..29ff63108 100644 --- a/src/core/send.c +++ b/src/core/send.c @@ -806,7 +806,8 @@ Exit: // The only valid reason to not have framed anything is that there was too // little room left in the packet to fit anything more. // - CXPLAT_DBG_ASSERT(Builder->Metadata->FrameCount > PrevFrameCount || RanOutOfRoom); + CXPLAT_DBG_ASSERT(Builder->Metadata->FrameCount > PrevFrameCount || RanOutOfRoom || + CxPlatIsRandomMemoryFailureEnabled()); UNREFERENCED_PARAMETER(RanOutOfRoom); return Builder->Metadata->FrameCount > PrevFrameCount; diff --git a/src/inc/msquic.hpp b/src/inc/msquic.hpp index 751dc892d..c64d42baa 100644 --- a/src/inc/msquic.hpp +++ b/src/inc/msquic.hpp @@ -718,6 +718,10 @@ private: if (Connection) { Status = Connection->SetConfiguration(pThis->Configuration); if (QUIC_FAILED(Status)) { + // + // The connection is being rejected. Let MsQuic free the handle. + // + Connection->Handle = nullptr; delete Connection; } } diff --git a/src/platform/platform_posix.c b/src/platform/platform_posix.c index a069e752c..6a1ab0695 100644 --- a/src/platform/platform_posix.c +++ b/src/platform/platform_posix.c @@ -513,6 +513,7 @@ CxPlatSetAllocFailDenominator( ) { CxPlatform.AllocFailDenominator = Value; + CxPlatform.AllocCounter = 0; } int32_t diff --git a/src/platform/platform_winkernel.c b/src/platform/platform_winkernel.c index ef89d708a..4f796597f 100644 --- a/src/platform/platform_winkernel.c +++ b/src/platform/platform_winkernel.c @@ -242,6 +242,7 @@ CxPlatSetAllocFailDenominator( ) { CxPlatform.AllocFailDenominator = Value; + CxPlatform.AllocCounter = 0; } int32_t diff --git a/src/platform/platform_winuser.c b/src/platform/platform_winuser.c index 57d65778c..69be7c615 100644 --- a/src/platform/platform_winuser.c +++ b/src/platform/platform_winuser.c @@ -601,6 +601,7 @@ CxPlatSetAllocFailDenominator( ) { CxPlatform.AllocFailDenominator = Value; + CxPlatform.AllocCounter = 0; } int32_t CxPlatGetAllocFailDenominator( diff --git a/src/test/MsQuicTests.h b/src/test/MsQuicTests.h index 7ff136a92..54360c595 100644 --- a/src/test/MsQuicTests.h +++ b/src/test/MsQuicTests.h @@ -355,6 +355,10 @@ void QuicTestSlowReceive( ); +void +QuicTestNthAllocFail( + ); + // // QuicDrill tests // @@ -812,4 +816,7 @@ typedef struct { #define IOCTL_QUIC_RUN_SLOW_RECEIVE \ QUIC_CTL_CODE(65, METHOD_BUFFERED, FILE_WRITE_DATA) -#define QUIC_MAX_IOCTL_FUNC_CODE 65 +#define IOCTL_QUIC_RUN_NTH_ALLOC_FAIL \ + QUIC_CTL_CODE(66, METHOD_BUFFERED, FILE_WRITE_DATA) + +#define QUIC_MAX_IOCTL_FUNC_CODE 66 diff --git a/src/test/bin/quic_gtest.cpp b/src/test/bin/quic_gtest.cpp index 4e9b47ed4..7c219184b 100644 --- a/src/test/bin/quic_gtest.cpp +++ b/src/test/bin/quic_gtest.cpp @@ -1327,6 +1327,15 @@ TEST(Misc, SlowReceive) { } } +TEST(Misc, NthAllocFail) { + TestLogger Logger("NthAllocFail"); + if (TestingKernelMode) { + ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_NTH_ALLOC_FAIL)); + } else { + QuicTestNthAllocFail(); + } +} + TEST(Drill, VarIntEncoder) { TestLogger Logger("QuicDrillTestVarIntEncoder"); if (TestingKernelMode) { diff --git a/src/test/bin/winkernel/control.cpp b/src/test/bin/winkernel/control.cpp index 82c8c011a..a71a15c4e 100644 --- a/src/test/bin/winkernel/control.cpp +++ b/src/test/bin/winkernel/control.cpp @@ -434,6 +434,7 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] = sizeof(QUIC_RUN_CRED_VALIDATION), sizeof(QUIC_ABORT_RECEIVE_TYPE), sizeof(QUIC_RUN_KEY_UPDATE_RANDOM_LOSS_PARAMS), + 0, 0 }; @@ -1035,6 +1036,10 @@ QuicTestCtlEvtIoDeviceControl( QuicTestCtlRun(QuicTestSlowReceive()); break; + case IOCTL_QUIC_RUN_NTH_ALLOC_FAIL: + QuicTestCtlRun(QuicTestNthAllocFail()); + break; + default: Status = STATUS_NOT_IMPLEMENTED; break; diff --git a/src/test/lib/DataTest.cpp b/src/test/lib/DataTest.cpp index a0ac61886..52f220eae 100644 --- a/src/test/lib/DataTest.cpp +++ b/src/test/lib/DataTest.cpp @@ -2239,3 +2239,95 @@ QuicTestSlowReceive( TEST_TRUE(Context.ServerStreamShutdown.WaitTimeout(TestWaitTimeout)); TEST_TRUE(Context.ServerStreamHasShutdown); } + +struct NthAllocFailTestContext { + CxPlatEvent ServerStreamRecv; + CxPlatEvent ServerStreamShutdown; + MsQuicStream* ServerStream {nullptr}; + bool ServerStreamHasShutdown {false}; + + static QUIC_STATUS StreamCallback(_In_ MsQuicStream* Stream, _In_opt_ void* Context, _Inout_ QUIC_STREAM_EVENT* Event) { + auto TestContext = (NthAllocFailTestContext*)Context; + if (Event->Type == QUIC_STREAM_EVENT_RECEIVE) { + TestContext->ServerStreamRecv.Set(); + } else if (Event->Type == QUIC_STREAM_EVENT_SHUTDOWN_COMPLETE) { + TestContext->ServerStreamHasShutdown = true; + TestContext->ServerStreamShutdown.Set(); + Stream->ConnectionShutdown(1); + } + return QUIC_STATUS_SUCCESS; + } + + static QUIC_STATUS ConnCallback(_In_ MsQuicConnection*, _In_opt_ void* Context, _Inout_ QUIC_CONNECTION_EVENT* Event) { + auto TestContext = (NthAllocFailTestContext*)Context; + if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) { + TestContext->ServerStream = new MsQuicStream(Event->PEER_STREAM_STARTED.Stream, CleanUpAutoDelete, StreamCallback, Context); + } + return QUIC_STATUS_SUCCESS; + } +}; + +struct AllocFailScope { + ~AllocFailScope() { + int32_t Zero = 0; + MsQuic->SetParam( + nullptr, + QUIC_PARAM_LEVEL_GLOBAL, + QUIC_PARAM_GLOBAL_ALLOC_FAIL_CYCLE, + sizeof(Zero), + &Zero); + } +}; + +#define CONTINUE_ON_FAIL(__condition) { \ + QUIC_STATUS __status = __condition; \ + if (QUIC_FAILED(__status)) { \ + continue; \ + } \ +} + +void +QuicTestNthAllocFail( + ) +{ + AllocFailScope Scope{}; + + for (uint32_t i = 100; i > 1; i--) { + TEST_QUIC_SUCCEEDED(MsQuic->SetParam( + nullptr, + QUIC_PARAM_LEVEL_GLOBAL, + QUIC_PARAM_GLOBAL_ALLOC_FAIL_CYCLE, + sizeof(i), + &i)); + + MsQuicRegistration Registration; + CONTINUE_ON_FAIL(Registration.GetInitStatus()); + + MsQuicConfiguration ServerConfiguration(Registration, "MsQuicTest", MsQuicSettings().SetPeerUnidiStreamCount(1), ServerSelfSignedCredConfig); + CONTINUE_ON_FAIL(ServerConfiguration.GetInitStatus()); + + MsQuicConfiguration ClientConfiguration(Registration, "MsQuicTest", MsQuicCredentialConfig()); + CONTINUE_ON_FAIL(ClientConfiguration.GetInitStatus()); + + NthAllocFailTestContext RecvContext {}; + MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, NthAllocFailTestContext::ConnCallback, &RecvContext); + CONTINUE_ON_FAIL(Listener.GetInitStatus()); + CONTINUE_ON_FAIL(Listener.Start("MsQuicTest")); + QuicAddr ServerLocalAddr; + CONTINUE_ON_FAIL(Listener.GetLocalAddr(ServerLocalAddr)); + + MsQuicConnection Connection(Registration); + CONTINUE_ON_FAIL(Connection.GetInitStatus()); + CONTINUE_ON_FAIL(Connection.StartLocalhost(ClientConfiguration, ServerLocalAddr)); + + MsQuicStream Stream(Connection, QUIC_STREAM_OPEN_FLAG_UNIDIRECTIONAL); + CONTINUE_ON_FAIL(Stream.GetInitStatus()); + + uint8_t RawBuffer[100]; + QUIC_BUFFER Buffer { sizeof(RawBuffer), RawBuffer }; + CONTINUE_ON_FAIL(Stream.Send(&Buffer, 1, QUIC_SEND_FLAG_START | QUIC_SEND_FLAG_FIN)); + + RecvContext.ServerStreamRecv.WaitTimeout(100); + RecvContext.ServerStreamShutdown.WaitTimeout(100); + } +}