Fix QUIC_TLS_SECRETS on Server and Client. (#3539)

* Fix QUIC_TLS_SECRETS on Server.

* Fix logging 0-RTT secrets, and add tests.

* Fix client not emitting ClientRandom with 0-RTT. 

* Create new function to parse ClientRandom only. Fixes memory leak.
This commit is contained in:
Anthony Rossi 2023-04-04 17:46:53 -07:00 коммит произвёл GitHub
Родитель d34b09585f
Коммит 43de82a81d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 218 добавлений и 37 удалений

Просмотреть файл

@ -1533,14 +1533,13 @@ QuicCryptoProcessTlsCompletion(
// //
if (Connection->TlsSecrets != NULL && if (Connection->TlsSecrets != NULL &&
QuicConnIsClient(Connection) && QuicConnIsClient(Connection) &&
Crypto->TlsState.WriteKey == QUIC_PACKET_KEY_INITIAL && (Crypto->TlsState.WriteKey == QUIC_PACKET_KEY_INITIAL ||
Crypto->TlsState.WriteKey == QUIC_PACKET_KEY_0_RTT) &&
Crypto->TlsState.BufferLength > 0) { Crypto->TlsState.BufferLength > 0) {
QUIC_NEW_CONNECTION_INFO Info = { 0 };
QuicCryptoTlsReadInitial( QuicCryptoTlsReadClientRandom(
Connection,
Crypto->TlsState.Buffer, Crypto->TlsState.Buffer,
Crypto->TlsState.BufferLength, Crypto->TlsState.BufferLength,
&Info,
Connection->TlsSecrets); Connection->TlsSecrets);
// //
// Connection is done with TlsSecrets, clean up. // Connection is done with TlsSecrets, clean up.
@ -1837,14 +1836,7 @@ QuicCryptoProcessData(
Connection, Connection,
Buffer.Buffer, Buffer.Buffer,
Buffer.Length, Buffer.Length,
&Info, &Info);
//
// On server, TLS is initialized before the listener
// is told about the connection, so TlsSecrets is still
// NULL.
//
NULL
);
if (QUIC_FAILED(Status)) { if (QUIC_FAILED(Status)) {
QuicConnTransportError( QuicConnTransportError(
Connection, Connection,
@ -1880,6 +1872,19 @@ QuicCryptoProcessData(
Connection->Paths[0].Binding, Connection->Paths[0].Binding,
Connection, Connection,
&Info); &Info);
if (Connection->TlsSecrets != NULL &&
!Connection->State.HandleClosed &&
Connection->State.ExternalOwner) {
//
// At this point, the connection was accepted by the listener,
// so now the ClientRandom can be copied.
//
QuicCryptoTlsReadClientRandom(
Buffer.Buffer,
Buffer.Length,
Connection->TlsSecrets);
}
return Status; return Status;
} }
} }

Просмотреть файл

@ -309,8 +309,20 @@ QuicCryptoTlsReadInitial(
_In_reads_(BufferLength) _In_reads_(BufferLength)
const uint8_t* Buffer, const uint8_t* Buffer,
_In_ uint32_t BufferLength, _In_ uint32_t BufferLength,
_Inout_ QUIC_NEW_CONNECTION_INFO* Info, _Inout_ QUIC_NEW_CONNECTION_INFO* Info
_Inout_opt_ QUIC_TLS_SECRETS* TlsSecrets );
//
// Reads only the ClientRandom out of the initial CRYPTO data.
// MUST ONLY BE CALLED AFTER QuicCryptoTlsReadInitial!!
//
_IRQL_requires_max_(PASSIVE_LEVEL)
QUIC_STATUS
QuicCryptoTlsReadClientRandom(
_In_reads_(BufferLength)
const uint8_t* Buffer,
_In_ uint32_t BufferLength,
_Inout_ QUIC_TLS_SECRETS* TlsSecrets
); );
// //

Просмотреть файл

@ -440,8 +440,7 @@ QuicCryptoTlsReadClientHello(
_In_reads_(BufferLength) _In_reads_(BufferLength)
const uint8_t* Buffer, const uint8_t* Buffer,
_In_ uint32_t BufferLength, _In_ uint32_t BufferLength,
_Inout_ QUIC_NEW_CONNECTION_INFO* Info, _Inout_ QUIC_NEW_CONNECTION_INFO* Info
_Inout_opt_ QUIC_TLS_SECRETS* TlsSecrets
) )
{ {
/* /*
@ -486,10 +485,6 @@ QuicCryptoTlsReadClientHello(
"Parse error. ReadTlsClientHello #2"); "Parse error. ReadTlsClientHello #2");
return QUIC_STATUS_INVALID_PARAMETER; return QUIC_STATUS_INVALID_PARAMETER;
} }
if (TlsSecrets != NULL) {
memcpy(TlsSecrets->ClientRandom, Buffer, TLS_RANDOM_LENGTH);
TlsSecrets->IsSet.ClientRandom = TRUE;
}
BufferLength -= TLS_RANDOM_LENGTH; BufferLength -= TLS_RANDOM_LENGTH;
Buffer += TLS_RANDOM_LENGTH; Buffer += TLS_RANDOM_LENGTH;
@ -605,8 +600,7 @@ QuicCryptoTlsReadInitial(
_In_reads_(BufferLength) _In_reads_(BufferLength)
const uint8_t* Buffer, const uint8_t* Buffer,
_In_ uint32_t BufferLength, _In_ uint32_t BufferLength,
_Inout_ QUIC_NEW_CONNECTION_INFO* Info, _Inout_ QUIC_NEW_CONNECTION_INFO* Info
_Inout_opt_ QUIC_TLS_SECRETS* TlsSecrets
) )
{ {
do { do {
@ -633,9 +627,7 @@ QuicCryptoTlsReadInitial(
Connection, Connection,
Buffer + TLS_MESSAGE_HEADER_LENGTH, Buffer + TLS_MESSAGE_HEADER_LENGTH,
MessageLength, MessageLength,
Info, Info);
TlsSecrets
);
if (QUIC_FAILED(Status)) { if (QUIC_FAILED(Status)) {
return Status; return Status;
} }
@ -664,6 +656,27 @@ QuicCryptoTlsReadInitial(
return QUIC_STATUS_SUCCESS; return QUIC_STATUS_SUCCESS;
} }
_IRQL_requires_max_(PASSIVE_LEVEL)
QUIC_STATUS
QuicCryptoTlsReadClientRandom(
_In_reads_(BufferLength)
const uint8_t* Buffer,
_In_ uint32_t BufferLength,
_Inout_ QUIC_TLS_SECRETS* TlsSecrets
)
{
UNREFERENCED_PARAMETER(BufferLength);
CXPLAT_DBG_ASSERT(
BufferLength >=
TLS_MESSAGE_HEADER_LENGTH + sizeof(uint16_t) + TLS_RANDOM_LENGTH);
Buffer += TLS_MESSAGE_HEADER_LENGTH + sizeof(uint16_t);
memcpy(TlsSecrets->ClientRandom, Buffer, TLS_RANDOM_LENGTH);
TlsSecrets->IsSet.ClientRandom = TRUE;
return QUIC_STATUS_SUCCESS;
}
_IRQL_requires_max_(DISPATCH_LEVEL) _IRQL_requires_max_(DISPATCH_LEVEL)
_Success_(return != NULL) _Success_(return != NULL)
const uint8_t* const uint8_t*

Просмотреть файл

@ -487,12 +487,6 @@ CxPlatTlsSetEncryptionSecretsCallback(
} }
if (TlsContext->TlsSecrets != NULL) { if (TlsContext->TlsSecrets != NULL) {
if (!TlsContext->TlsSecrets->IsSet.ClientRandom) {
if (SSL_get_client_random(Ssl, TlsContext->TlsSecrets->ClientRandom, sizeof(TlsContext->TlsSecrets->ClientRandom)) > 0) {
TlsContext->TlsSecrets->IsSet.ClientRandom = TRUE;
}
}
TlsContext->TlsSecrets->SecretLength = (uint8_t)SecretLen; TlsContext->TlsSecrets->SecretLength = (uint8_t)SecretLen;
switch (KeyType) { switch (KeyType) {
case QUIC_PACKET_KEY_HANDSHAKE: case QUIC_PACKET_KEY_HANDSHAKE:
@ -524,7 +518,11 @@ CxPlatTlsSetEncryptionSecretsCallback(
TlsContext->TlsSecrets = NULL; TlsContext->TlsSecrets = NULL;
break; break;
case QUIC_PACKET_KEY_0_RTT: case QUIC_PACKET_KEY_0_RTT:
if (!TlsContext->IsServer) { if (TlsContext->IsServer) {
CXPLAT_FRE_ASSERT(ReadSecret != NULL);
memcpy(TlsContext->TlsSecrets->ClientEarlyTrafficSecret, ReadSecret, SecretLen);
TlsContext->TlsSecrets->IsSet.ClientEarlyTrafficSecret = TRUE;
} else {
CXPLAT_FRE_ASSERT(WriteSecret != NULL); CXPLAT_FRE_ASSERT(WriteSecret != NULL);
memcpy(TlsContext->TlsSecrets->ClientEarlyTrafficSecret, WriteSecret, SecretLen); memcpy(TlsContext->TlsSecrets->ClientEarlyTrafficSecret, WriteSecret, SecretLen);
TlsContext->TlsSecrets->IsSet.ClientEarlyTrafficSecret = TRUE; TlsContext->TlsSecrets->IsSet.ClientEarlyTrafficSecret = TRUE;

Просмотреть файл

@ -2398,6 +2398,15 @@ CxPlatTlsWriteDataToSchannel(
Result |= CXPLAT_TLS_RESULT_READ_KEY_UPDATED; Result |= CXPLAT_TLS_RESULT_READ_KEY_UPDATED;
if (NewPeerTrafficSecrets[i]->TrafficSecretType == SecTrafficSecret_ClientEarlyData) { if (NewPeerTrafficSecrets[i]->TrafficSecretType == SecTrafficSecret_ClientEarlyData) {
CXPLAT_FRE_ASSERT(FALSE); // TODO - Finish the 0-RTT logic. CXPLAT_FRE_ASSERT(FALSE); // TODO - Finish the 0-RTT logic.
CXPLAT_FRE_ASSERT(TlsContext->IsServer);
if (TlsContext->TlsSecrets != NULL) {
TlsContext->TlsSecrets->SecretLength = (uint8_t)NewPeerTrafficSecrets[i]->TrafficSecretSize;
memcpy(
TlsContext->TlsSecrets->ClientEarlyTrafficSecret,
NewPeerTrafficSecrets[i]->TrafficSecret,
NewPeerTrafficSecrets[i]->TrafficSecretSize);
TlsContext->TlsSecrets->IsSet.ClientEarlyTrafficSecret = TRUE;
}
} else { } else {
if (State->ReadKey == QUIC_PACKET_KEY_INITIAL) { if (State->ReadKey == QUIC_PACKET_KEY_INITIAL) {
if (!QuicPacketKeyCreate( if (!QuicPacketKeyCreate(

Просмотреть файл

@ -53,11 +53,14 @@ struct PingStats
const QUIC_STATUS ExpectedCloseStatus; const QUIC_STATUS ExpectedCloseStatus;
volatile long ConnectionsComplete; volatile long ConnectionsComplete;
volatile long SecretsIndex;
CXPLAT_EVENT CompletionEvent; CXPLAT_EVENT CompletionEvent;
QUIC_BUFFER* ResumptionTicket {nullptr}; QUIC_BUFFER* ResumptionTicket {nullptr};
QUIC_TLS_SECRETS* TlsSecrets {nullptr};
PingStats( PingStats(
uint64_t _PayloadLength, uint64_t _PayloadLength,
uint32_t _ConnectionCount, uint32_t _ConnectionCount,
@ -80,7 +83,8 @@ struct PingStats
AllowDataIncomplete(_AllowDataIncomplete), AllowDataIncomplete(_AllowDataIncomplete),
ServerKeyUpdate(_ServerKeyUpdate), ServerKeyUpdate(_ServerKeyUpdate),
ExpectedCloseStatus(_ExpectedCloseStatus), ExpectedCloseStatus(_ExpectedCloseStatus),
ConnectionsComplete(0) ConnectionsComplete(0),
SecretsIndex(0)
{ {
CxPlatEventInitialize(&CompletionEvent, FALSE, FALSE); CxPlatEventInitialize(&CompletionEvent, FALSE, FALSE);
} }
@ -274,6 +278,15 @@ ListenerAcceptPingConnection(
} }
} }
if (Stats->TlsSecrets) {
auto Status = Connection->SetTlsSecrets(
&(Stats->TlsSecrets[InterlockedIncrement(&Stats->SecretsIndex) - 1]));
if (QUIC_FAILED(Status)) {
TEST_FAILURE("SetParam(QUIC_TLS_SECRETS) failed with 0x%x", Status);
return false;
}
}
Connection->SetPriorityScheme( Connection->SetPriorityScheme(
Stats->FifoScheduling ? Stats->FifoScheduling ?
QUIC_STREAM_SCHEDULING_SCHEME_FIFO : QUIC_STREAM_SCHEDULING_SCHEME_FIFO :
@ -369,6 +382,19 @@ QuicTestConnectAndPing(
// //
} }
UniquePtr<QUIC_TLS_SECRETS[]> ClientSecrets;
UniquePtr<QUIC_TLS_SECRETS[]> ServerSecrets;
if (ClientZeroRtt && !ServerRejectZeroRtt) {
ClientSecrets.reset(
new(std::nothrow) QUIC_TLS_SECRETS[ConnectionCount]);
ServerSecrets.reset(
new(std::nothrow) QUIC_TLS_SECRETS[ConnectionCount]);
if (ClientSecrets == nullptr || ServerSecrets == nullptr) {
return;
}
ServerStats.TlsSecrets = ServerSecrets.get();
}
MsQuicRegistration Registration(NULL, QUIC_EXECUTION_PROFILE_TYPE_MAX_THROUGHPUT, true); MsQuicRegistration Registration(NULL, QUIC_EXECUTION_PROFILE_TYPE_MAX_THROUGHPUT, true);
TEST_TRUE(Registration.IsValid()); TEST_TRUE(Registration.IsValid());
@ -451,6 +477,10 @@ QuicTestConnectAndPing(
if (Connections.get()[i] == nullptr) { if (Connections.get()[i] == nullptr) {
return; return;
} }
if (ClientSecrets) {
TEST_QUIC_SUCCEEDED(
Connections.get()[i]->SetTlsSecrets(&ClientSecrets[i]));
}
} }
QuicAddr LocalAddr; QuicAddr LocalAddr;
@ -512,6 +542,42 @@ QuicTestConnectAndPing(
TEST_FAILURE("Wait for server to complete timed out after %u ms.", TimeoutMs); TEST_FAILURE("Wait for server to complete timed out after %u ms.", TimeoutMs);
return; return;
} }
if (ClientSecrets) {
for (auto i = 0u; i < ConnectionCount; i++) {
auto ServerSecret = &ServerSecrets[i];
bool Match = false;
for (auto j = 0u; j < ConnectionCount; j++) {
auto ClientSecret = &ClientSecrets[j];
if (!memcmp(
ServerSecret->ClientRandom,
ClientSecret->ClientRandom,
sizeof(ClientSecret->ClientRandom))) {
if (Match) {
TEST_FAILURE("Multiple clients with the same ClientRandom?!");
return;
}
TEST_EQUAL(
ClientSecret->IsSet.ClientEarlyTrafficSecret,
ServerSecret->IsSet.ClientEarlyTrafficSecret);
TEST_EQUAL(
ClientSecret->SecretLength,
ServerSecret->SecretLength);
TEST_TRUE(
!memcmp(
ClientSecret->ClientEarlyTrafficSecret,
ServerSecret->ClientEarlyTrafficSecret,
ClientSecret->SecretLength));
Match = true;
}
}
if (!Match) {
TEST_FAILURE("Failed to match Server Secrets to any Client Secrets!");
return;
}
}
}
} }
} }

Просмотреть файл

@ -138,6 +138,13 @@ ListenerAcceptConnection(
(*AcceptContext->NewConnection)->SetPeerCertEventReturnStatus( (*AcceptContext->NewConnection)->SetPeerCertEventReturnStatus(
AcceptContext->PeerCertEventReturnStatus); AcceptContext->PeerCertEventReturnStatus);
} }
if (AcceptContext->TlsSecrets != NULL) {
auto Status = (*AcceptContext->NewConnection)->SetTlsSecrets(AcceptContext->TlsSecrets);
if (QUIC_FAILED(Status)) {
TEST_FAILURE("SetParam(QUIC_PARAM_CONN_TLS_SECRETS) returned 0x%x", Status);
return false;
}
}
CxPlatEventSet(AcceptContext->NewConnectionReady); CxPlatEventSet(AcceptContext->NewConnectionReady);
return true; return true;
} }
@ -164,6 +171,8 @@ QuicTestConnect(
MsQuicAlpn Alpn1("MsQuicTest"); MsQuicAlpn Alpn1("MsQuicTest");
MsQuicAlpn Alpn2("MsQuicTest2", "MsQuicTest"); MsQuicAlpn Alpn2("MsQuicTest2", "MsQuicTest");
QUIC_TLS_SECRETS ClientSecrets{}, ServerSecrets{};
MsQuicSettings Settings; MsQuicSettings Settings;
Settings.SetPeerBidiStreamCount(4); Settings.SetPeerBidiStreamCount(4);
Settings.SetGreaseQuicBitEnabled(GreaseQuicBitEnabled); Settings.SetGreaseQuicBitEnabled(GreaseQuicBitEnabled);
@ -242,6 +251,7 @@ QuicTestConnect(
ServerAcceptCtx.ExpectedCustomTicketValidationResult = QUIC_STATUS_INTERNAL_ERROR; ServerAcceptCtx.ExpectedCustomTicketValidationResult = QUIC_STATUS_INTERNAL_ERROR;
} }
} }
ServerAcceptCtx.TlsSecrets = &ServerSecrets;
Listener.Context = &ServerAcceptCtx; Listener.Context = &ServerAcceptCtx;
@ -249,6 +259,7 @@ QuicTestConnect(
TestConnection Client(Registration); TestConnection Client(Registration);
TEST_TRUE(Client.IsValid()); TEST_TRUE(Client.IsValid());
Client.SetHasRandomLoss(RandomLossPercentage != 0); Client.SetHasRandomLoss(RandomLossPercentage != 0);
TEST_QUIC_SUCCEEDED(Client.SetTlsSecrets(&ClientSecrets));
if (ClientUsesOldVersion) { if (ClientUsesOldVersion) {
TEST_QUIC_SUCCEEDED( TEST_QUIC_SUCCEEDED(
@ -311,6 +322,54 @@ QuicTestConnect(
} }
TEST_TRUE(Server->GetIsConnected()); TEST_TRUE(Server->GetIsConnected());
TEST_EQUAL(
ServerSecrets.IsSet.ClientRandom,
ClientSecrets.IsSet.ClientRandom);
TEST_TRUE(
!memcmp(
ServerSecrets.ClientRandom,
ClientSecrets.ClientRandom,
sizeof(ServerSecrets.ClientRandom)));
TEST_EQUAL(ServerSecrets.SecretLength, ClientSecrets.SecretLength);
TEST_TRUE(ServerSecrets.SecretLength <= QUIC_TLS_SECRETS_MAX_SECRET_LEN);
TEST_EQUAL(
ServerSecrets.IsSet.ClientHandshakeTrafficSecret,
ClientSecrets.IsSet.ClientHandshakeTrafficSecret);
TEST_TRUE(
!memcmp(
ServerSecrets.ClientHandshakeTrafficSecret,
ClientSecrets.ClientHandshakeTrafficSecret,
ServerSecrets.SecretLength));
TEST_EQUAL(
ServerSecrets.IsSet.ServerHandshakeTrafficSecret,
ClientSecrets.IsSet.ServerHandshakeTrafficSecret);
TEST_TRUE(
!memcmp(
ServerSecrets.ServerHandshakeTrafficSecret,
ClientSecrets.ServerHandshakeTrafficSecret,
ServerSecrets.SecretLength));
TEST_EQUAL(
ServerSecrets.IsSet.ClientTrafficSecret0,
ClientSecrets.IsSet.ClientTrafficSecret0);
TEST_TRUE(
!memcmp(
ServerSecrets.ClientTrafficSecret0,
ClientSecrets.ClientTrafficSecret0,
ServerSecrets.SecretLength));
TEST_EQUAL(
ServerSecrets.IsSet.ServerTrafficSecret0,
ClientSecrets.IsSet.ServerTrafficSecret0);
TEST_TRUE(
!memcmp(
ServerSecrets.ServerTrafficSecret0,
ClientSecrets.ServerTrafficSecret0,
ServerSecrets.SecretLength));
if (ClientUsesOldVersion) { if (ClientUsesOldVersion) {
TEST_EQUAL(Server->GetQuicVersion(), OLD_SUPPORTED_VERSION); TEST_EQUAL(Server->GetQuicVersion(), OLD_SUPPORTED_VERSION);
} else { } else {

Просмотреть файл

@ -957,7 +957,8 @@ TestConnection::HandleConnectionEvent(
} }
uint32_t uint32_t
TestConnection::GetDestCidUpdateCount() { TestConnection::GetDestCidUpdateCount()
{
QUIC_STATISTICS_V2 Stats; QUIC_STATISTICS_V2 Stats;
uint32_t StatsSize = sizeof(Stats); uint32_t StatsSize = sizeof(Stats);
QUIC_STATUS Status = QUIC_STATUS Status =
@ -974,11 +975,26 @@ TestConnection::GetDestCidUpdateCount() {
} }
const uint8_t* const uint8_t*
TestConnection::GetNegotiatedAlpn() const { TestConnection::GetNegotiatedAlpn() const
{
return NegotiatedAlpn; return NegotiatedAlpn;
} }
uint8_t uint8_t
TestConnection::GetNegotiatedAlpnLength() const { TestConnection::GetNegotiatedAlpnLength() const
{
return NegotiatedAlpnLength; return NegotiatedAlpnLength;
} }
QUIC_STATUS
TestConnection::SetTlsSecrets(
QUIC_TLS_SECRETS* Secrets
)
{
return
MsQuic->SetParam(
QuicConnection,
QUIC_PARAM_CONN_TLS_SECRETS,
sizeof(*Secrets),
Secrets);
}

Просмотреть файл

@ -305,4 +305,6 @@ public:
const uint8_t* GetNegotiatedAlpn() const; const uint8_t* GetNegotiatedAlpn() const;
uint8_t GetNegotiatedAlpnLength() const; uint8_t GetNegotiatedAlpnLength() const;
QUIC_STATUS SetTlsSecrets(QUIC_TLS_SECRETS* Secrets);
}; };

Просмотреть файл

@ -80,6 +80,7 @@ struct ServerAcceptContext {
CXPLAT_EVENT NewConnectionReady; CXPLAT_EVENT NewConnectionReady;
TestConnection** NewConnection; TestConnection** NewConnection;
void* NewStreamHandler{nullptr}; void* NewStreamHandler{nullptr};
QUIC_TLS_SECRETS* TlsSecrets{nullptr};
QUIC_STATUS ExpectedTransportCloseStatus{QUIC_STATUS_SUCCESS}; QUIC_STATUS ExpectedTransportCloseStatus{QUIC_STATUS_SUCCESS};
QUIC_STATUS ExpectedClientCertValidationResult[2]{}; QUIC_STATUS ExpectedClientCertValidationResult[2]{};
uint32_t ExpectedClientCertValidationResultCount{0}; uint32_t ExpectedClientCertValidationResultCount{0};