зеркало из https://github.com/mozilla/gecko-dev.git
912 строки
29 KiB
C++
912 строки
29 KiB
C++
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
|
|
/* vim: set ts=2 et sw=2 tw=80: */
|
|
/* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
* You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
#include "tls_agent.h"
|
|
#include "databuffer.h"
|
|
#include "keyhi.h"
|
|
#include "pk11func.h"
|
|
#include "ssl.h"
|
|
#include "sslerr.h"
|
|
#include "sslproto.h"
|
|
#include "tls_parser.h"
|
|
|
|
extern "C" {
|
|
// This is not something that should make you happy.
|
|
#include "libssl_internals.h"
|
|
}
|
|
|
|
#define GTEST_HAS_RTTI 0
|
|
#include "gtest/gtest.h"
|
|
#include "gtest_utils.h"
|
|
#include "scoped_ptrs.h"
|
|
|
|
extern std::string g_working_dir_path;
|
|
|
|
namespace nss_test {
|
|
|
|
// Printable names for the agent states.
// NOTE(review): presumably indexed by the TlsAgent::State enumerator value;
// confirm against the enum declaration in the header.
const char* TlsAgent::states[] = {
    "INIT",
    "CONNECTING",
    "CONNECTED",
    "ERROR",
};
|
|
|
|
// Well-known agent names. An agent's name is what it hands to
// PK11_FindCertFromNickname() when it looks up its own credentials.
const std::string TlsAgent::kClient("client");    // both sign and encrypt
const std::string TlsAgent::kRsa2048("rsa2048");  // bigger
const std::string TlsAgent::kServerRsa("rsa");    // both sign and encrypt
const std::string TlsAgent::kServerRsaSign("rsa_sign");
const std::string TlsAgent::kServerRsaPss("rsa_pss");
const std::string TlsAgent::kServerRsaDecrypt("rsa_decrypt");
const std::string TlsAgent::kServerEcdsa256("ecdsa256");
const std::string TlsAgent::kServerEcdsa384("ecdsa384");
const std::string TlsAgent::kServerEcdhRsa("ecdh_rsa");
const std::string TlsAgent::kServerEcdhEcdsa("ecdh_ecdsa");
const std::string TlsAgent::kServerDsa("dsa");
|
|
|
|
// Creates an agent in STATE_INIT: no descriptors or adapter attached, every
// expectation flag and hook-called flag cleared, counters at zero and the
// cached channel / cipher-suite info zero-filled. The version range is seeded
// with the library default for the transport variant matching |mode|.
TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
    : name_(name),
      mode_(mode),
      server_key_bits_(0),
      pr_fd_(nullptr),
      adapter_(nullptr),
      ssl_fd_(nullptr),
      role_(role),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_resumption_(false),
      expect_client_auth_(false),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      handshake_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_() {
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));

  const auto variant =
      (mode_ == STREAM) ? ssl_variant_stream : ssl_variant_datagram;
  const SECStatus rv = SSL_VersionRangeGetDefault(variant, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}
|
|
|
|
// Cancels any outstanding poller registration and timer, then closes whichever
// descriptors the agent still holds.
TlsAgent::~TlsAgent() {
  if (adapter_) {
    // Only the readable-event registration is cancelled here.
    // The adapter is closed when the FD closes.
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }

  if (timer_handle_) {
    timer_handle_->Cancel();
  }

  // EnsureTlsSetup() clears pr_fd_ once the SSL layer has been imported on top
  // of it, so pr_fd_ is only still set here if that never happened.
  if (pr_fd_) {
    PR_Close(pr_fd_);
  }

  if (ssl_fd_) {
    PR_Close(ssl_fd_);
  }
}
|
|
|
|
// Moves the agent to |state|, logging the transition. Setting the state it is
// already in is a silent no-op.
void TlsAgent::SetState(State state) {
  if (state_ != state) {
    LOG("Changing state from " << state_ << " to " << state);
    state_ = state;
  }
}
|
|
|
|
bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
|
|
const SSLExtraServerCertData* serverCertData) {
|
|
ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr));
|
|
EXPECT_NE(nullptr, cert.get());
|
|
if (!cert.get()) return false;
|
|
|
|
ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
|
|
EXPECT_NE(nullptr, pub.get());
|
|
if (!pub.get()) return false;
|
|
if (updateKeyBits) {
|
|
server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
|
|
}
|
|
|
|
ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr));
|
|
EXPECT_NE(nullptr, priv.get());
|
|
if (!priv.get()) return false;
|
|
|
|
SECStatus rv =
|
|
SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null);
|
|
EXPECT_EQ(SECFailure, rv);
|
|
rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData,
|
|
serverCertData ? sizeof(*serverCertData) : 0);
|
|
return rv == SECSuccess;
|
|
}
|
|
|
|
// Lazily layers TLS (or DTLS) on top of pr_fd_ and installs the agent's
// standard hooks. Safe to call repeatedly: once ssl_fd_ exists this returns
// true immediately.
//
// modelSocket: optional model descriptor handed to SSL_ImportFD() /
//              DTLS_ImportFD().
//
// Returns false (after recording a gtest expectation failure) as soon as any
// step fails. That includes server certificate configuration, which was
// previously reported but otherwise ignored, so setup could claim success with
// no usable server certificate.
bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
  // Don't set up twice
  if (ssl_fd_) return true;

  if (adapter_->mode() == STREAM) {
    ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
  } else {
    ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
  }

  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) return false;
  // The imported descriptor now stands in for pr_fd_; forget the raw handle so
  // the destructor does not close it a second time.
  pr_fd_ = nullptr;

  SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  if (role_ == SERVER) {
    const bool cert_configured = ConfigServerCert(name_, true);
    EXPECT_TRUE(cert_configured);
    if (!cert_configured) return false;

    rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  } else {
    rv = SSL_SetURL(ssl_fd_, "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  return true;
}
|
|
|
|
void TlsAgent::SetupClientAuth() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
ASSERT_EQ(CLIENT, role_);
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
|
|
reinterpret_cast<void*>(this)));
|
|
}
|
|
|
|
// Looks up the certificate stored under this agent's name and its matching
// private key.
//
// cert: out; receives the certificate, or nullptr on failure.
// priv: out; receives the private key, or nullptr on failure.
//
// Returns true when both were found; the caller then owns both objects. On
// failure nothing is handed out: if the key lookup fails, the certificate that
// was already found is released here rather than leaked.
bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
                                        SECKEYPrivateKey** priv) const {
  *priv = nullptr;
  *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
  EXPECT_NE(nullptr, *cert);
  if (!*cert) return false;

  *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
  EXPECT_NE(nullptr, *priv);
  if (!*priv) {
    // No key, so the certificate is useless to the caller: drop our reference.
    CERT_DestroyCertificate(*cert);
    *cert = nullptr;
    return false;
  }

  return true;
}
|
|
|
|
// Callback registered with SSL_GetClientAuthDataHook(). |self| is the TlsAgent
// that registered it. Expects the peer certificate to be visible already, then
// supplies the agent's own credentials through |cert| and |privKey|.
// Returns SECSuccess if the credentials were found, SECFailure otherwise.
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** cert,
                                          SECKEYPrivateKey** privKey) {
  TlsAgent* agent = static_cast<TlsAgent*>(self);

  ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
  EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";

  const bool found = agent->GetClientAuthCredentials(cert, privKey);
  return found ? SECSuccess : SECFailure;
}
|
|
|
|
void TlsAgent::RequestClientAuth(bool requireAuth) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
ASSERT_EQ(SERVER, role_);
|
|
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
|
|
EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
|
|
requireAuth ? PR_TRUE : PR_FALSE));
|
|
|
|
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
|
|
ssl_fd_, &TlsAgent::ClientAuthenticated, this));
|
|
expect_client_auth_ = true;
|
|
}
|
|
|
|
// Ensures the TLS layer exists (optionally cloning settings from |model|),
// resets the handshake for this agent's role and enters STATE_CONNECTING.
void TlsAgent::StartConnect(PRFileDesc* model) {
  EXPECT_TRUE(EnsureTlsSetup(model));

  const PRBool as_server = (role_ == SERVER) ? PR_TRUE : PR_FALSE;
  const SECStatus rv = SSL_ResetHandshake(ssl_fd_, as_server);
  EXPECT_EQ(SECSuccess, rv);

  SetState(STATE_CONNECTING);
}
|
|
|
|
void TlsAgent::DisableAllCiphers() {
|
|
for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
|
|
SECStatus rv =
|
|
SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
}
|
|
|
|
// Not actually all groups, just the ones that we are actually willing
|
|
// to use.
|
|
const std::vector<SSLNamedGroup> kAllDHEGroups = {
|
|
ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1,
|
|
ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
|
|
ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
|
|
|
|
const std::vector<SSLNamedGroup> kECDHEGroups = {
|
|
ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1};
|
|
|
|
const std::vector<SSLNamedGroup> kFFDHEGroups = {
|
|
ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
|
|
ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
|
|
|
|
// Defined because the big DHE groups are ridiculously slow.
|
|
const std::vector<SSLNamedGroup> kFasterDHEGroups = {
|
|
ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1,
|
|
ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072};
|
|
|
|
// Enables every implemented cipher suite whose key exchange type is |kea|,
// plus every suite whose key exchange type is ssl_kea_tls13_any. Suites that
// do not match are left as they are. Stops at the first suite whose info
// cannot be retrieved.
void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t idx = 0; idx < SSL_NumImplementedCiphers; ++idx) {
    const uint16_t suite = SSL_ImplementedCiphers[idx];

    SSLCipherSuiteInfo suiteInfo;
    SECStatus rv = SSL_GetCipherSuiteInfo(suite, &suiteInfo, sizeof(suiteInfo));
    ASSERT_EQ(SECSuccess, rv);
    EXPECT_EQ(sizeof(suiteInfo), suiteInfo.length);

    const bool matches =
        (suiteInfo.keaType == kea) || (suiteInfo.keaType == ssl_kea_tls13_any);
    if (!matches) {
      continue;
    }

    rv = SSL_CipherPrefSet(ssl_fd_, suite, PR_TRUE);
    EXPECT_EQ(SECSuccess, rv);
  }
}
|
|
|
|
// Restricts the named groups to the family matching |kea|: finite-field
// groups for ssl_kea_dh, elliptic-curve groups for ssl_kea_ecdh. Any other
// key exchange type leaves the group configuration untouched.
void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
  if (kea == ssl_kea_dh) {
    ConfigNamedGroups(kFFDHEGroups);
  } else if (kea == ssl_kea_ecdh) {
    ConfigNamedGroups(kECDHEGroups);
  }
}
|
|
|
|
// Configures the elliptic-curve groups when |authType| is one of the ECDH /
// ECDSA types or ssl_auth_tls13_any; does nothing for any other type.
void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
  const bool wants_ec_groups =
      authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
      authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any;
  if (!wants_ec_groups) {
    return;
  }
  ConfigNamedGroups(kECDHEGroups);
}
|
|
|
|
// Enables every implemented cipher suite that authenticates with |authType|,
// plus every suite whose key exchange type is ssl_kea_tls13_any. Stops at the
// first suite whose info cannot be retrieved.
void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
  EXPECT_TRUE(EnsureTlsSetup());

  for (size_t idx = 0; idx < SSL_NumImplementedCiphers; ++idx) {
    const uint16_t suite = SSL_ImplementedCiphers[idx];

    SSLCipherSuiteInfo suiteInfo;
    SECStatus rv = SSL_GetCipherSuiteInfo(suite, &suiteInfo, sizeof(suiteInfo));
    ASSERT_EQ(SECSuccess, rv);

    const bool matches = (suiteInfo.authType == authType) ||
                         (suiteInfo.keaType == ssl_kea_tls13_any);
    if (!matches) {
      continue;
    }

    rv = SSL_CipherPrefSet(ssl_fd_, suite, PR_TRUE);
    EXPECT_EQ(SECSuccess, rv);
  }
}
|
|
|
|
void TlsAgent::EnableSingleCipher(uint16_t cipher) {
|
|
DisableAllCiphers();
|
|
SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
// Replaces the set of named groups enabled on ssl_fd_ with |groups|.
//
// Uses data() rather than &groups[0] so that an empty vector does not index
// out of bounds; the library then sees a count of zero and decides the outcome.
void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
  EXPECT_TRUE(EnsureTlsSetup());
  SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, groups.data(),
                                      static_cast<unsigned int>(groups.size()));
  EXPECT_EQ(SECSuccess, rv);
}
|
|
|
|
void TlsAgent::SetSessionTicketsEnabled(bool en) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
|
|
en ? PR_TRUE : PR_FALSE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SetSessionCacheEnabled(bool en) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::Set0RttEnabled(bool en) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv =
|
|
SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
|
|
vrange_.min = minver;
|
|
vrange_.max = maxver;
|
|
|
|
if (ssl_fd_) {
|
|
SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
}
|
|
|
|
// Copies the agent's recorded version range into |*minver| and |*maxver|.
// Both pointers must be non-null.
void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
  *maxver = vrange_.max;
  *minver = vrange_.min;
}
|
|
|
|
// Pins the protocol version that the handshake checks will expect.
void TlsAgent::SetExpectedVersion(uint16_t version) { expected_version_ = version; }
|
|
|
|
// Overrides the server key strength (in bits) used by the KEA/auth checks.
void TlsAgent::SetServerKeyBits(uint16_t bits) {
  server_key_bits_ = bits;
}
|
|
|
|
// Arms the agent to treat the next read/write failure as expected.
void TlsAgent::ExpectReadWriteError() {
  expect_readwrite_error_ = true;
}
|
|
|
|
void TlsAgent::SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
|
|
size_t count) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
EXPECT_LE(count, SSL_SignatureMaxCount());
|
|
EXPECT_EQ(SECSuccess, SSL_SignaturePrefSet(ssl_fd_, algorithms,
|
|
static_cast<unsigned int>(count)));
|
|
EXPECT_EQ(SECFailure, SSL_SignaturePrefSet(ssl_fd_, algorithms, 0))
|
|
<< "setting no algorithms should fail and do nothing";
|
|
|
|
std::vector<SSLSignatureAndHashAlg> configuredAlgorithms(count);
|
|
unsigned int configuredCount;
|
|
EXPECT_EQ(SECFailure,
|
|
SSL_SignaturePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
|
|
<< "get algorithms, algorithms is nullptr";
|
|
EXPECT_EQ(SECFailure, SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0],
|
|
&configuredCount, 0))
|
|
<< "get algorithms, too little space";
|
|
EXPECT_EQ(SECFailure,
|
|
SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0], nullptr,
|
|
configuredAlgorithms.size()))
|
|
<< "get algorithms, algCountOut is nullptr";
|
|
|
|
EXPECT_EQ(SECSuccess, SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0],
|
|
&configuredCount,
|
|
configuredAlgorithms.size()));
|
|
// SignaturePrefSet drops unsupported algorithms silently, so the number that
|
|
// are configured might be fewer.
|
|
EXPECT_LE(configuredCount, count);
|
|
unsigned int i = 0;
|
|
for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
|
|
if (i < configuredCount &&
|
|
algorithms[j].hashAlg == configuredAlgorithms[i].hashAlg &&
|
|
algorithms[j].sigAlg == configuredAlgorithms[i].sigAlg) {
|
|
++i;
|
|
}
|
|
}
|
|
EXPECT_EQ(i, configuredCount) << "algorithms in use were all set";
|
|
}
|
|
|
|
// Expects a connected agent whose negotiated key exchange is |type| with a
// key size of exactly |kea_size| bits.
void TlsAgent::CheckKEA(SSLKEAType type, size_t kea_size) const {
  EXPECT_EQ(STATE_CONNECTED, state_);

  const SSLChannelInfo& negotiated = info_;
  EXPECT_EQ(type, negotiated.keaType);
  EXPECT_EQ(kea_size, negotiated.keaKeyBits);
}
|
|
|
|
// Checks the negotiated key exchange against the key size this harness
// expects for |type|:
//   ECDH -> whatever SSLInt_DetermineKEABits() reports,
//   DH   -> 2048 bits,
//   RSA  -> the server key size.
// Any other type is reported as a test failure.
void TlsAgent::CheckKEA(SSLKEAType type) const {
  const PRUint32 ec_bits = SSLInt_DetermineKEABits(
      server_key_bits_, info_.authType, csinfo_.symKeyBits);

  if (type == ssl_kea_ecdh) {
    CheckKEA(type, ec_bits);
  } else if (type == ssl_kea_dh) {
    CheckKEA(type, 2048);
  } else if (type == ssl_kea_rsa) {
    CheckKEA(type, server_key_bits_);
  } else {
    EXPECT_TRUE(false) << "Unknown KEA type";
  }
}
|
|
|
|
// Expects a connected agent that authenticated with |type| using a key of
// server_key_bits_ bits. Before TLS 1.3 it additionally checks the legacy
// csinfo_.authAlgorithm field, which reports a different value than authType
// for some types.
void TlsAgent::CheckAuthType(SSLAuthType type) const {
  EXPECT_EQ(STATE_CONNECTED, state_);
  EXPECT_EQ(type, info_.authType);
  EXPECT_EQ(server_key_bits_, info_.authKeyBits);

  // The legacy field is only checked for versions before TLS 1.3.
  const bool check_legacy_field =
      info_.protocolVersion < SSL_LIBRARY_VERSION_TLS_1_3;
  if (!check_legacy_field) {
    return;
  }

  // Check authAlgorithm, which is the old value for authType.
  if (type == ssl_auth_rsa_sign) {
    EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
        << "authAlgorithm for RSA is always decrypt";
  } else if (type == ssl_auth_ecdh_rsa) {
    EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
        << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
  } else if (type == ssl_auth_ecdh_ecdsa) {
    EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
        << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
  } else {
    EXPECT_EQ(type, csinfo_.authAlgorithm)
        << "authAlgorithm is (usually) the same as authType";
  }
}
|
|
|
|
void TlsAgent::EnableFalseStart() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
falsestart_enabled_ = true;
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
|
|
}
|
|
|
|
// Tells the callback checks that the next handshake should be a resumption.
void TlsAgent::ExpectResumption() {
  expect_resumption_ = true;
}
|
|
|
|
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
|
|
EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
|
|
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
|
|
}
|
|
|
|
// Verifies the outcome of protocol negotiation.
//
// expected_state: the negotiation state the library should report.
// expected:       the protocol that should have been chosen; must be empty
//                 exactly when the reported state is
//                 SSL_NEXT_PROTO_NO_SUPPORT.
//
// The out-parameters are given defined starting values so that, if
// SSL_GetNextProto() itself fails, the remaining expectations compare against
// "no support / empty" instead of reading indeterminate memory.
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
                         const std::string& expected) const {
  SSLNextProtoState state = SSL_NEXT_PROTO_NO_SUPPORT;
  char chosen[10] = {0};
  unsigned int chosen_len = 0;
  SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
                                  reinterpret_cast<unsigned char*>(chosen),
                                  &chosen_len, sizeof(chosen));
  EXPECT_EQ(SECSuccess, rv);
  EXPECT_EQ(expected_state, state);
  if (state == SSL_NEXT_PROTO_NO_SUPPORT) {
    EXPECT_EQ("", expected);
  } else {
    EXPECT_NE("", expected);
    EXPECT_EQ(expected, std::string(chosen, chosen_len));
  }
}
|
|
|
|
void TlsAgent::EnableSrtp() {
|
|
EXPECT_TRUE(EnsureTlsSetup());
|
|
const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
|
|
SRTP_AES128_CM_HMAC_SHA1_32};
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
|
|
}
|
|
|
|
// Expects that SRTP was negotiated with SRTP_AES128_CM_HMAC_SHA1_80.
// |actual| starts at zero so that a failed query yields a defined mismatch
// rather than a read of an uninitialized value.
void TlsAgent::CheckSrtp() const {
  uint16_t actual = 0;
  EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
  EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
|
|
|
|
void TlsAgent::CheckErrorCode(int32_t expected) const {
|
|
EXPECT_EQ(STATE_ERROR, state_);
|
|
EXPECT_EQ(expected, error_code_)
|
|
<< "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
|
|
<< PORT_ErrorToName(expected) << std::endl;
|
|
}
|
|
|
|
void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
|
|
ASSERT_EQ(0, error_code_);
|
|
WAIT_(error_code_ != 0, delay);
|
|
EXPECT_EQ(expected, error_code_)
|
|
<< "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
|
|
<< PORT_ErrorToName(expected) << std::endl;
|
|
}
|
|
|
|
void TlsAgent::CheckPreliminaryInfo() {
|
|
SSLPreliminaryChannelInfo info;
|
|
EXPECT_EQ(SECSuccess,
|
|
SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
|
|
EXPECT_EQ(sizeof(info), info.length);
|
|
EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
|
|
EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);
|
|
|
|
// A version of 0 is invalid and indicates no expectation. This value is
|
|
// initialized to 0 so that tests that don't explicitly set an expected
|
|
// version can negotiate a version.
|
|
if (!expected_version_) {
|
|
expected_version_ = info.protocolVersion;
|
|
}
|
|
EXPECT_EQ(expected_version_, info.protocolVersion);
|
|
|
|
// As with the version; 0 is the null cipher suite (and also invalid).
|
|
if (!expected_cipher_suite_) {
|
|
expected_cipher_suite_ = info.cipherSuite;
|
|
}
|
|
EXPECT_EQ(expected_cipher_suite_, info.cipherSuite);
|
|
}
|
|
|
|
// Check that all the expected callbacks have been called.
|
|
void TlsAgent::CheckCallbacks() const {
|
|
// If false start happens, the handshake is reported as being complete at the
|
|
// point that false start happens.
|
|
if (expect_resumption_ || !falsestart_enabled_) {
|
|
EXPECT_TRUE(handshake_callback_called_);
|
|
}
|
|
|
|
// These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
|
|
if (role_ == SERVER) {
|
|
PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn);
|
|
EXPECT_EQ(((!expect_resumption_ && have_sni) ||
|
|
expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
|
|
sni_hook_called_);
|
|
} else {
|
|
EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
|
|
// Note that this isn't unconditionally called, even with false start on.
|
|
// But the callback is only skipped if a cipher that is ridiculously weak
|
|
// (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
|
|
EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
|
|
can_falsestart_hook_called_);
|
|
}
|
|
}
|
|
|
|
// Forgets the expected version and cipher suite; 0 means "no expectation", so
// the next handshake is free to negotiate either.
void TlsAgent::ResetPreliminaryInfo() {
  expected_cipher_suite_ = 0;
  expected_version_ = 0;
}
|
|
|
|
void TlsAgent::Connected() {
|
|
LOG("Handshake success");
|
|
CheckPreliminaryInfo();
|
|
CheckCallbacks();
|
|
|
|
SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
EXPECT_EQ(sizeof(info_), info_.length);
|
|
|
|
// Preliminary values are exposed through callbacks during the handshake.
|
|
// If either expected values were set or the callbacks were called, check
|
|
// that the final values are correct.
|
|
EXPECT_EQ(expected_version_, info_.protocolVersion);
|
|
EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
|
|
|
|
rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
|
|
|
|
if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
|
|
PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_);
|
|
// We use one ciphersuite in each direction, plus one that's kept around
|
|
// by DTLS for retransmission.
|
|
EXPECT_EQ(((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2, cipherSuites);
|
|
}
|
|
|
|
SetState(STATE_CONNECTED);
|
|
}
|
|
|
|
void TlsAgent::EnableExtendedMasterSecret() {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv =
|
|
SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
|
|
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
|
|
if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
|
|
expected = PR_TRUE;
|
|
}
|
|
ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
|
|
<< "unexpected extended master secret state for " << name_;
|
|
}
|
|
|
|
void TlsAgent::CheckEarlyDataAccepted(bool expected) {
|
|
if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
|
|
expected = false;
|
|
}
|
|
ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
|
|
<< "unexpected early data state for " << name_;
|
|
}
|
|
|
|
void TlsAgent::CheckSecretsDestroyed() {
|
|
ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_));
|
|
}
|
|
|
|
void TlsAgent::DisableRollbackDetection() {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
|
|
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::EnableCompression() {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE);
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
|
|
ASSERT_TRUE(EnsureTlsSetup());
|
|
|
|
SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version);
|
|
ASSERT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
void TlsAgent::Handshake() {
|
|
LOGV("Handshake");
|
|
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
|
|
if (rv == SECSuccess) {
|
|
Connected();
|
|
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
return;
|
|
}
|
|
|
|
int32_t err = PR_GetError();
|
|
if (err == PR_WOULD_BLOCK_ERROR) {
|
|
LOGV("Would have blocked");
|
|
if (mode_ == DGRAM) {
|
|
if (timer_handle_) {
|
|
timer_handle_->Cancel();
|
|
timer_handle_ = nullptr;
|
|
}
|
|
|
|
PRIntervalTime timeout;
|
|
rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
|
|
if (rv == SECSuccess) {
|
|
Poller::Instance()->SetTimer(
|
|
timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
|
|
}
|
|
}
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
return;
|
|
}
|
|
|
|
LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
|
|
<< PORT_ErrorToString(err));
|
|
error_code_ = err;
|
|
SetState(STATE_ERROR);
|
|
}
|
|
|
|
// Expects the agent to be connected and puts it back into STATE_CONNECTING
// ahead of a renegotiation.
void TlsAgent::PrepareForRenegotiate() {
  EXPECT_EQ(STATE_CONNECTED, state_);
  SetState(STATE_CONNECTING);
}
|
|
|
|
void TlsAgent::StartRenegotiate() {
|
|
PrepareForRenegotiate();
|
|
|
|
SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
|
|
EXPECT_EQ(SECSuccess, rv);
|
|
}
|
|
|
|
// Delivers |buf| straight to the peer adapter as a received packet; the data
// is not written through ssl_fd_.
void TlsAgent::SendDirect(const DataBuffer& buf) {
  LOG("Send Direct " << buf);

  auto* peer = adapter_->peer();
  peer->PacketReceived(buf);
}
|
|
|
|
// True for the error codes after which ReadBytes() keeps waiting for more
// data: would-block and a short DTLS read.
static bool ErrorIsNonFatal(PRErrorCode code) {
  if (code == PR_WOULD_BLOCK_ERROR) {
    return true;
  }
  return code == SSL_ERROR_RX_SHORT_DTLS_READ;
}
|
|
|
|
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
|
|
uint8_t block[4096];
|
|
|
|
ASSERT_LT(blocksize, sizeof(block));
|
|
|
|
while (bytes) {
|
|
size_t tosend = std::min(blocksize, bytes);
|
|
|
|
for (size_t i = 0; i < tosend; ++i) {
|
|
block[i] = 0xff & send_ctr_;
|
|
++send_ctr_;
|
|
}
|
|
|
|
LOGV("Writing " << tosend << " bytes");
|
|
int32_t rv = PR_Write(ssl_fd_, block, tosend);
|
|
if (expect_readwrite_error_) {
|
|
EXPECT_GT(0, rv);
|
|
EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
|
|
error_code_ = PR_GetError();
|
|
expect_readwrite_error_ = false;
|
|
} else {
|
|
ASSERT_EQ(tosend, static_cast<size_t>(rv));
|
|
}
|
|
|
|
bytes -= tosend;
|
|
}
|
|
}
|
|
|
|
void TlsAgent::ReadBytes() {
|
|
uint8_t block[1024];
|
|
|
|
int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
|
|
LOGV("ReadBytes " << rv);
|
|
int32_t err;
|
|
|
|
if (rv >= 0) {
|
|
size_t count = static_cast<size_t>(rv);
|
|
for (size_t i = 0; i < count; ++i) {
|
|
ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
|
|
recv_ctr_++;
|
|
}
|
|
} else {
|
|
err = PR_GetError();
|
|
LOG("Read error " << PORT_ErrorToName(err) << ": "
|
|
<< PORT_ErrorToString(err));
|
|
if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
|
|
error_code_ = err;
|
|
expect_readwrite_error_ = false;
|
|
}
|
|
}
|
|
|
|
// If closed, then don't bother waiting around.
|
|
if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) {
|
|
LOGV("Re-arming");
|
|
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
|
|
&TlsAgent::ReadableCallback);
|
|
}
|
|
}
|
|
|
|
// Restarts the send-side payload counter at zero.
void TlsAgent::ResetSentBytes() {
  send_ctr_ = 0;
}
|
|
|
|
// Applies a resumption mode bitmask: the RESUME_SESSIONID bit controls the
// session cache (through the inverted SSL_NO_CACHE option) and the
// RESUME_TICKET bit controls session tickets.
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
  EXPECT_TRUE(EnsureTlsSetup());

  const bool use_session_ids = (mode & RESUME_SESSIONID) != 0;
  SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
                               use_session_ids ? PR_FALSE : PR_TRUE);
  EXPECT_EQ(SECSuccess, rv);

  const bool use_tickets = (mode & RESUME_TICKET) != 0;
  rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
                     use_tickets ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
}
|
|
|
|
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
|
|
::testing::internal::ParamGenerator<std::string>
|
|
TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
|
|
|
|
// Configures the server session ID cache under the global working directory
// before each test. The call's result is not checked.
void TlsAgentTestBase::SetUp() {
  const char* cache_dir = g_working_dir_path.c_str();
  SSL_ConfigServerSessionIDCache(1024, 0, 0, cache_dir);
}
|
|
|
|
// Destroys the agent and clears/shuts down the session caches after each
// test. The pointer is reset so that nothing (for example EnsureInit()) can
// mistake the freed agent for a live one.
void TlsAgentTestBase::TearDown() {
  delete agent_;
  agent_ = nullptr;
  SSL_ClearSessionCache();
  SSL_ShutdownServerSessionIDCache();
}
|
|
|
|
// Throws away the current agent (if any) and builds a fresh one; Init()
// uses |server_name| only when the fixture's role is not CLIENT.
void TlsAgentTestBase::Reset(const std::string& server_name) {
  // Init() overwrites agent_, so the old instance has to go first.
  delete agent_;
  Init(server_name);
}
|
|
|
|
void TlsAgentTestBase::Init(const std::string& server_name) {
|
|
agent_ =
|
|
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
|
|
role_, mode_);
|
|
agent_->Init();
|
|
fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
|
|
agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
|
|
agent_->StartConnect();
|
|
}
|
|
|
|
void TlsAgentTestBase::EnsureInit() {
|
|
if (!agent_) {
|
|
Init();
|
|
}
|
|
}
|
|
|
|
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
|
|
TlsAgent::State expected_state,
|
|
int32_t error_code) {
|
|
std::cerr << "Process message: " << buffer << std::endl;
|
|
EnsureInit();
|
|
agent_->adapter()->PacketReceived(buffer);
|
|
agent_->Handshake();
|
|
|
|
ASSERT_EQ(expected_state, agent_->state());
|
|
|
|
if (expected_state == TlsAgent::STATE_ERROR) {
|
|
ASSERT_EQ(error_code, agent_->error_code());
|
|
}
|
|
}
|
|
|
|
// Serializes a record into |out|, starting at offset 0:
//   content type (1 byte), version (2 bytes; converted with
//   TlsVersionToDtlsVersion() unless |mode| is STREAM), then for DGRAM only
//   |seq_num| as two 4-byte halves (high half first), the payload length
//   (2 bytes) and finally the |len| payload bytes from |buf|.
void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
                                  const uint8_t* buf, size_t len,
                                  DataBuffer* out, uint64_t seq_num) {
  size_t pos = 0;
  pos = out->Write(pos, type, 1);

  if (mode == STREAM) {
    pos = out->Write(pos, version, 2);
  } else {
    pos = out->Write(pos, TlsVersionToDtlsVersion(version), 2);
  }

  if (mode == DGRAM) {
    pos = out->Write(pos, seq_num >> 32, 4);
    pos = out->Write(pos, seq_num & PR_UINT32_MAX, 4);
  }

  pos = out->Write(pos, len, 2);
  out->Write(pos, buf, len);
}
|
|
|
|
void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
|
|
const uint8_t* buf, size_t len,
|
|
DataBuffer* out, uint64_t seq_num) const {
|
|
MakeRecord(mode_, type, version, buf, len, out, seq_num);
|
|
}
|
|
|
|
void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
|
|
const uint8_t* data, size_t hs_len,
|
|
DataBuffer* out,
|
|
uint64_t seq_num) const {
|
|
return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
|
|
0);
|
|
}
|
|
|
|
void TlsAgentTestBase::MakeHandshakeMessageFragment(
|
|
uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
|
|
uint64_t seq_num, uint32_t fragment_offset,
|
|
uint32_t fragment_length) const {
|
|
size_t index = 0;
|
|
if (!fragment_length) fragment_length = hs_len;
|
|
index = out->Write(index, hs_type, 1); // Handshake record type.
|
|
index = out->Write(index, hs_len, 3); // Handshake length
|
|
if (mode_ == DGRAM) {
|
|
index = out->Write(index, seq_num, 2);
|
|
index = out->Write(index, fragment_offset, 3);
|
|
index = out->Write(index, fragment_length, 3);
|
|
}
|
|
if (data) {
|
|
index = out->Write(index, data, fragment_length);
|
|
} else {
|
|
for (size_t i = 0; i < fragment_length; ++i) {
|
|
index = out->Write(index, 1, 1);
|
|
}
|
|
}
|
|
}
|
|
|
|
void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
|
|
size_t hs_len,
|
|
DataBuffer* out) {
|
|
size_t index = 0;
|
|
index = out->Write(index, kTlsHandshakeType, 1); // Content Type
|
|
index = out->Write(index, 3, 1); // Version high
|
|
index = out->Write(index, 1, 1); // Version low
|
|
index = out->Write(index, 4 + hs_len, 2); // Length
|
|
|
|
index = out->Write(index, hs_type, 1); // Handshake record type.
|
|
index = out->Write(index, hs_len, 3); // Handshake length
|
|
for (size_t i = 0; i < hs_len; ++i) {
|
|
index = out->Write(index, 1, 1);
|
|
}
|
|
}
|
|
|
|
} // namespace nss_test
|