Add support for certification revocation list files

Signed-off-by: Hormoz Kheradmand <hormoz.kheradmand@shopify.com>
This commit is contained in:
Hormoz Kheradmand 2021-08-30 20:56:13 +00:00
Родитель 6b31715e81
Коммит 8e06dc7f59
28 изменённых файлов: 538 добавлений и 65 удалений

Просмотреть файл

@ -3,6 +3,7 @@
"LdapCert": "path/to/ldap-client-cert.pem",
"LdapKey": "path/to/ldap-client-key.pem",
"LdapCA": "path/to/ldap-client-ca.pem",
"LdapCRL": "path/to/ldap-client-crl.pem",
"User": "uid=vitessROuser,ou=users,ou=people,dc=example,dc=com",
"Password": "sUpErSeCuRe1",
"GroupQuery": "ou=groups,ou=people,dc=example,dc=com",

Просмотреть файл

@ -48,7 +48,9 @@ var cmdMap map[string]cmdFunc
func init() {
cmdMap = map[string]cmdFunc{
"CreateCA": cmdCreateCA,
"CreateCRL": cmdCreateCRL,
"CreateSignedCert": cmdCreateSignedCert,
"RevokeCert": cmdRevokeCert,
}
}
@ -65,6 +67,28 @@ func cmdCreateCA(subFlags *flag.FlagSet, args []string) {
tlstest.CreateCA(*root)
}
func cmdCreateCRL(subFlags *flag.FlagSet, args []string) {
subFlags.Parse(args)
if subFlags.NArg() != 1 {
log.Fatalf("CreateCRL command takes a single CA name as a parameter")
}
ca := subFlags.Arg(0)
tlstest.CreateCRL(*root, ca)
}
func cmdRevokeCert(subFlags *flag.FlagSet, args []string) {
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")
subFlags.Parse(args)
if subFlags.NArg() != 1 {
log.Fatalf("RevokeCert command takes a single name as a parameter")
}
name := subFlags.Arg(0)
tlstest.RevokeCertAndRegenerateCRL(*root, *parent, name)
}
func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")
serial := subFlags.String("serial", "01", "Serial number for the certificate to create. Should be different for two certificates with the same parent.")
@ -74,11 +98,13 @@ func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
if subFlags.NArg() != 1 {
log.Fatalf("CreateSignedCert command takes a single name as a parameter")
}
name := subFlags.Arg(0)
if *commonName == "" {
*commonName = subFlags.Arg(0)
*commonName = name
}
tlstest.CreateSignedCert(*root, *parent, *serial, subFlags.Arg(0), *commonName)
tlstest.CreateSignedCert(*root, *parent, *serial, name, *commonName)
}
func main() {

Просмотреть файл

@ -54,12 +54,14 @@ func TestValidCert(t *testing.T) {
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername)
tlstest.CreateCRL(root, tlstest.CA)
// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
if err != nil {
@ -136,12 +138,14 @@ func TestNoCert(t *testing.T) {
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateCRL(root, tlstest.CA)
// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
if err != nil {

Просмотреть файл

@ -282,7 +282,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
}
// Build the TLS config.
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, serverName, tlsVersion)
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion)
if err != nil {
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err)
}

Просмотреть файл

@ -187,6 +187,7 @@ func TestTLSClientDisabled(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
@ -260,6 +261,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
@ -381,6 +383,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
@ -465,6 +468,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
@ -511,4 +515,12 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
if conn != nil {
conn.Close()
}
// Now revoke the server certificate and make sure we can't connect
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "server")
params.SslCrl = path.Join(root, "ca-crl.pem")
_, err = Connect(context.Background(), params)
require.Error(t, err)
require.Contains(t, err.Error(), "Certificate revoked: CommonName=server.example.com")
}

Просмотреть файл

@ -39,6 +39,7 @@ type ConnParams struct {
SslCa string `json:"ssl_ca"`
SslCaPath string `json:"ssl_ca_path"`
SslCert string `json:"ssl_cert"`
SslCrl string `json:"ssl_crl"`
SslKey string `json:"ssl_key"`
TLSMinVersion string `json:"tls_min_version"`
ServerName string `json:"server_name"`

Просмотреть файл

@ -126,6 +126,7 @@ func TestSSLConnection(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)

Просмотреть файл

@ -215,6 +215,7 @@ type ServerConfig struct {
LdapCert string
LdapKey string
LdapCA string
LdapCRL string
LdapTLSMinVersion string
}
@ -250,7 +251,7 @@ func (lci *ClientImpl) Connect(network string, config *ServerConfig) error {
return err
}
tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, serverName, tlsVersion)
tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, config.LdapCRL, serverName, tlsVersion)
if err != nil {
return err
}

Просмотреть файл

@ -357,6 +357,7 @@ func FuzzTLSServer(data []byte) int {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
return -1

Просмотреть файл

@ -833,6 +833,7 @@ func TestTLSServer(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
@ -924,12 +925,16 @@ func TestTLSRequired(t *testing.T) {
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert")
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client")
// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err)
@ -966,7 +971,6 @@ func TestTLSRequired(t *testing.T) {
}
// setup conn params with TLS
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
params.SslMode = vttls.VerifyIdentity
params.SslCa = path.Join(root, "ca-cert.pem")
params.SslCert = path.Join(root, "client-cert.pem")
@ -977,6 +981,16 @@ func TestTLSRequired(t *testing.T) {
if conn != nil {
conn.Close()
}
// setup conn params with TLS, but with a revoked client certificate
params.SslCert = path.Join(root, "revoked-client-cert.pem")
params.SslKey = path.Join(root, "revoked-client-key.pem")
conn, err = Connect(context.Background(), params)
require.NotNil(t, err)
require.Contains(t, err.Error(), "remote error: tls: bad certificate")
if conn != nil {
conn.Close()
}
}
func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
@ -1013,6 +1027,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)

Просмотреть файл

@ -374,7 +374,7 @@ func tabletConnExtraArgs(name string) []string {
}
func getVitessClient(addr string) (vtgateservicepb.VitessClient, error) {
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, grpcName)
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, "", grpcName)
if err != nil {
return nil, err
}

Просмотреть файл

@ -36,6 +36,7 @@ var (
cert = flag.String("binlog_player_grpc_cert", "", "the cert to use to connect")
key = flag.String("binlog_player_grpc_key", "", "the key to use to connect")
ca = flag.String("binlog_player_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("binlog_player_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("binlog_player_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -48,7 +49,7 @@ type client struct {
func (client *client) Dial(tablet *topodatapb.Tablet) error {
addr := netutil.JoinHostPort(tablet.Hostname, tablet.PortMap["grpc"])
var err error
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return err
}

Просмотреть файл

@ -125,7 +125,7 @@ func interceptors() []grpc.DialOption {
// SecureDialOption returns the gRPC dial option to use for the
// given client connection. It is either using TLS, or Insecure if
// nothing is set.
func SecureDialOption(cert, key, ca, name string) (grpc.DialOption, error) {
func SecureDialOption(cert, key, ca, crl, name string) (grpc.DialOption, error) {
// No security options set, just return.
if (cert == "" || key == "") && ca == "" {
return grpc.WithInsecure(), nil
@ -133,7 +133,7 @@ func SecureDialOption(cert, key, ca, name string) (grpc.DialOption, error) {
// Load the config. At this point we know
// we want a strict config with verify identity.
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, name, tls.VersionTLS12)
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, crl, name, tls.VersionTLS12)
if err != nil {
return nil, err
}

Просмотреть файл

@ -67,6 +67,9 @@ var (
// GRPCCA is the CA to use if TLS is enabled
GRPCCA = flag.String("grpc_ca", "", "server CA to use for gRPC connections, requires TLS, and enforces client certificate check")
// GRPCCRL is the CRL (Certificate Revocation List) to use if TLS is enabled
GRPCCRL = flag.String("grpc_crl", "", "path to a certificate revocation list in PEM format, client certificates will be further verified against this file during TLS handshake")
GRPCEnableOptionalTLS = flag.Bool("grpc_enable_optional_tls", false, "enable optional TLS mode when a server accepts both TLS and plain-text connections on the same port")
// GRPCServerCA if specified will combine server cert and server CA
@ -133,7 +136,7 @@ func createGRPCServer() {
var opts []grpc.ServerOption
if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" {
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCServerCA, tls.VersionTLS12)
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCCRL, *GRPCServerCA, tls.VersionTLS12)
if err != nil {
log.Exitf("Failed to log gRPC cert/key/ca: %v", err)
}

Просмотреть файл

@ -36,6 +36,7 @@ var (
cert = flag.String("throttler_client_grpc_cert", "", "the cert to use to connect")
key = flag.String("throttler_client_grpc_key", "", "the key to use to connect")
ca = flag.String("throttler_client_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("throttler_client_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("throttler_client_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -45,7 +46,7 @@ type client struct {
}
func factory(addr string) (throttlerclient.Client, error) {
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}

Просмотреть файл

@ -31,7 +31,15 @@ const (
// CA is the name of the CA toplevel cert.
CA = "ca"
caConfig = `
caConfigTemplate = `
[ ca ]
default_ca = default_ca
[ default_ca ]
database = %s
default_md = default
default_crl_days = 30
[ req ]
default_bits = 4096
default_keyfile = keyfile.pem
@ -90,10 +98,28 @@ func openssl(argv ...string) {
cmd := exec.Command("openssl", argv...)
output, err := cmd.CombinedOutput()
if err != nil {
if len(output) > 0 {
log.Errorf("openssl %v returned:\n%v", argv, string(output))
}
log.Fatalf("openssl %v failed: %v", argv, err)
}
if len(output) > 0 {
log.Infof("openssl %v returned:\n%v", argv, string(output))
}
// createKeyDBAndCAConfig creates a key database and ca config file
// for the passed in CA (possibly an intermediate CA)
func createKeyDBAndCAConfig(root, parent string) {
databasePath := path.Join(root, parent+"-keys.db")
if _, err := os.Stat(databasePath); os.IsNotExist(err) {
if err := os.WriteFile(databasePath, []byte{}, os.ModePerm); err != nil {
log.Fatalf("cannot write file %v: %v", databasePath, err)
}
}
config := path.Join(root, parent+"-ca.config")
if _, err := os.Stat(config); os.IsNotExist(err) {
if err := os.WriteFile(config, []byte(fmt.Sprintf(caConfigTemplate, databasePath)), os.ModePerm); err != nil {
log.Fatalf("cannot write file %v: %v", config, err)
}
}
}
@ -104,12 +130,11 @@ func CreateCA(root string) {
log.Infof("Creating test root CA in %v", root)
key := path.Join(root, "ca-key.pem")
cert := path.Join(root, "ca-cert.pem")
config := path.Join(root, "ca-ca.config")
createKeyDBAndCAConfig(root, "ca")
openssl("genrsa", "-out", key)
config := path.Join(root, "ca.config")
if err := os.WriteFile(config, []byte(caConfig), os.ModePerm); err != nil {
log.Fatalf("cannot write file %v: %v", config, err)
}
openssl("req", "-new", "-x509", "-nodes", "-days", "3600", "-batch",
"-config", config,
"-key", key,
@ -147,50 +172,139 @@ func CreateSignedCert(root, parent, serial, name, commonName string) {
"-out", cert)
}
// CreateCRL creates a new empty certificate revocation list
// for the provided parent
func CreateCRL(root, parent string) {
log.Infof("Creating CRL for root CA in %v", root)
caKey := path.Join(root, parent+"-key.pem")
caCert := path.Join(root, parent+"-cert.pem")
configPath := path.Join(root, parent+"-ca.config")
crlPath := path.Join(root, parent+"-crl.pem")
createKeyDBAndCAConfig(root, parent)
openssl("ca", "-gencrl",
"-keyfile", caKey,
"-cert", caCert,
"-config", configPath,
"-out",
crlPath,
)
}
// RevokeCertAndRegenerateCRL revokes a provided certificate under the
// provided parent CA and regenerates the CRL file for that parent
func RevokeCertAndRegenerateCRL(root, parent, name string) {
log.Infof("Revoking certificate %s", name)
caKey := path.Join(root, parent+"-key.pem")
caCert := path.Join(root, parent+"-cert.pem")
cert := path.Join(root, name+"-cert.pem")
configPath := path.Join(root, parent+"-ca.config")
createKeyDBAndCAConfig(root, parent)
openssl("ca", "-revoke", cert,
"-keyfile", caKey,
"-cert", caCert,
"-config", configPath,
)
CreateCRL(root, parent)
}
// ClientServerKeyPairs is used in tests
type ClientServerKeyPairs struct {
ServerCert string
ServerKey string
ServerCA string
ServerName string
ClientCert string
ClientKey string
ClientCA string
ServerCert string
ServerKey string
ServerCA string
ServerName string
ServerCRL string
RevokedServerCert string
RevokedServerKey string
RevokedServerName string
ClientCert string
ClientKey string
ClientCA string
ClientCRL string
RevokedClientCert string
RevokedClientKey string
RevokedClientName string
CombinedCRL string
}
var serialCounter = 0
// CreateClientServerCertPairs creates certificate pairs for use in test
func CreateClientServerCertPairs(root string) ClientServerKeyPairs {
// Create the certs and configs.
CreateCA(root)
serverSerial := fmt.Sprintf("%03d", serialCounter*2+1)
clientSerial := fmt.Sprintf("%03d", serialCounter*2+2)
serverCASerial := fmt.Sprintf("%03d", serialCounter*2+1)
serverSerial := fmt.Sprintf("%03d", serialCounter*2+3)
revokedServerSerial := fmt.Sprintf("%03d", serialCounter*2+5)
clientCASerial := fmt.Sprintf("%03d", serialCounter*2+2)
clientCertSerial := fmt.Sprintf("%03d", serialCounter*2+4)
revokedClientSerial := fmt.Sprintf("%03d", serialCounter*2+6)
serialCounter = serialCounter + 1
serialCounter = serialCounter + 3
serverName := fmt.Sprintf("server-%s", serverSerial)
serverCACommonName := fmt.Sprintf("Server %s CA", serverSerial)
serverCAName := fmt.Sprintf("servers-ca-%s", serverCASerial)
serverCACommonName := fmt.Sprintf("Servers %s CA", serverCASerial)
serverCertName := fmt.Sprintf("server-instance-%s", serverSerial)
serverCertCommonName := fmt.Sprintf("server%s.example.com", serverSerial)
revokedServerCertName := fmt.Sprintf("server-instance-%s", revokedServerSerial)
revokedServerCertCommonName := fmt.Sprintf("server%s.example.com", revokedServerSerial)
clientName := fmt.Sprintf("clients-%s", serverSerial)
clientCACommonName := fmt.Sprintf("Clients %s CA", serverSerial)
clientCertName := fmt.Sprintf("client-instance-%s", serverSerial)
clientCertCommonName := fmt.Sprintf("Client Instance %s", serverSerial)
clientCAName := fmt.Sprintf("clients-ca-%s", clientCASerial)
clientCACommonName := fmt.Sprintf("Clients %s CA", clientCASerial)
clientCertName := fmt.Sprintf("client-instance-%s", clientCertSerial)
clientCertCommonName := fmt.Sprintf("client%s.example.com", clientCertSerial)
revokedClientCertName := fmt.Sprintf("client-instance-%s", revokedClientSerial)
revokedClientCertCommonName := fmt.Sprintf("client%s.example.com", revokedClientSerial)
CreateSignedCert(root, CA, serverSerial, serverName, serverCACommonName)
CreateSignedCert(root, serverName, serverSerial, serverCertName, serverCertCommonName)
CreateSignedCert(root, CA, serverCASerial, serverCAName, serverCACommonName)
CreateSignedCert(root, serverCAName, serverSerial, serverCertName, serverCertCommonName)
CreateSignedCert(root, serverCAName, revokedServerSerial, revokedServerCertName, revokedServerCertCommonName)
RevokeCertAndRegenerateCRL(root, serverCAName, revokedServerCertName)
CreateSignedCert(root, CA, clientSerial, clientName, clientCACommonName)
CreateSignedCert(root, clientName, serverSerial, clientCertName, clientCertCommonName)
CreateSignedCert(root, CA, clientCASerial, clientCAName, clientCACommonName)
CreateSignedCert(root, clientCAName, clientCertSerial, clientCertName, clientCertCommonName)
CreateSignedCert(root, clientCAName, revokedClientSerial, revokedClientCertName, revokedClientCertCommonName)
RevokeCertAndRegenerateCRL(root, clientCAName, revokedClientCertName)
serverCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", serverCAName))
clientCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", clientCAName))
combinedCRLPath := path.Join(root, fmt.Sprintf("%s-%s-combined-crl.pem", serverCAName, clientCAName))
serverCRLBytes, err := ioutil.ReadFile(serverCRLPath)
if err != nil {
log.Fatalf("Could not read server CRL file")
}
clientCRLBytes, err := ioutil.ReadFile(clientCRLPath)
if err != nil {
log.Fatalf("Could not read client CRL file")
}
err = ioutil.WriteFile(combinedCRLPath, append(serverCRLBytes, clientCRLBytes...), 0777)
if err != nil {
log.Fatalf("Could not write combined CRL file")
}
return ClientServerKeyPairs{
ServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
ServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
ServerCA: path.Join(root, fmt.Sprintf("%s-cert.pem", serverName)),
ClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
ClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
ClientCA: path.Join(root, fmt.Sprintf("%s-cert.pem", clientName)),
ServerName: serverCertCommonName,
ServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
ServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
ServerCA: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCAName)),
ServerCRL: serverCRLPath,
RevokedServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedServerCertName)),
RevokedServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", revokedServerCertName)),
ClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
ClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
ClientCA: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCAName)),
ClientCRL: clientCRLPath,
RevokedClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedClientCertName)),
RevokedClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", revokedClientCertName)),
CombinedCRL: combinedCRLPath,
ServerName: serverCertCommonName,
RevokedServerName: revokedServerCertCommonName,
RevokedClientName: revokedClientCertCommonName,
}
}

Просмотреть файл

@ -65,6 +65,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.ClientCRL,
serverCA,
tls.VersionTLS12)
if err != nil {
@ -75,6 +76,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
clientServerKeyPairs.ClientCert,
clientServerKeyPairs.ClientKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.ServerCRL,
clientServerKeyPairs.ServerName,
tls.VersionTLS12)
if err != nil {
@ -109,9 +111,6 @@ func testClientServer(t *testing.T, combineCerts bool) {
}()
serverConn, err := listener.Accept()
if clientErr != nil {
t.Fatalf("Dial failed: %v", clientErr)
}
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
@ -127,6 +126,10 @@ func testClientServer(t *testing.T, combineCerts bool) {
wg.Wait()
if clientErr != nil {
t.Fatalf("Dial failed: %v", clientErr)
}
//
// Negative case: connect a client with wrong cert (using the
// server cert on the client side).
@ -137,6 +140,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.ServerCRL,
clientServerKeyPairs.ServerName,
tls.VersionTLS12)
if err != nil {
@ -188,6 +192,7 @@ func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Co
keypairs.ServerCert,
keypairs.ServerKey,
keypairs.ClientCA,
keypairs.ClientCRL,
"",
tls.VersionTLS12)
}
@ -197,6 +202,7 @@ func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Confi
keypairs.ServerCert,
keypairs.ServerKey,
keypairs.ClientCA,
keypairs.ClientCRL,
keypairs.ServerCA,
tls.VersionTLS12)
}
@ -207,6 +213,7 @@ func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) {
keypairs.ClientCert,
keypairs.ClientKey,
keypairs.ServerCA,
keypairs.ServerCRL,
keypairs.ServerName,
tls.VersionTLS12)
}
@ -296,6 +303,7 @@ func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int,
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.ClientCRL,
serverCA,
tls.VersionTLS12)
@ -312,3 +320,171 @@ func TestNumberOfCertsWithoutCombining(t *testing.T) {
func TestNumberOfCertsWithCombining(t *testing.T) {
testNumberOfCertsWithOrWithoutCombining(t, 2, true)
}
func assertTLSHandshakeFails(t *testing.T, serverConfig, clientConfig *tls.Config) {
// Create a TLS server listener.
listener, err := tls.Listen("tcp", ":0", serverConfig)
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
addr := listener.Addr().String()
defer listener.Close()
// create a dialer with timeout
dialer := new(net.Dialer)
dialer.Timeout = 10 * time.Second
wg := sync.WaitGroup{}
var clientErr error
wg.Add(1)
go func() {
defer wg.Done()
var clientConn *tls.Conn
clientConn, clientErr = tls.DialWithDialer(dialer, "tcp", addr, clientConfig)
if clientErr == nil {
clientConn.Close()
}
}()
serverConn, err := listener.Accept()
if err != nil {
// We should always be able to accept on the socket
t.Fatalf("Accept failed: %v", err)
}
err = serverConn.(*tls.Conn).Handshake()
if err != nil {
if !(strings.Contains(err.Error(), "Certificate revoked: CommonName=") ||
strings.Contains(err.Error(), "remote error: tls: bad certificate")) {
t.Fatalf("Wrong error returned: %v", err)
}
} else {
t.Fatal("Server should have failed the TLS handshake but it did not")
}
serverConn.Close()
wg.Wait()
}
func TestClientServerWithRevokedServerCert(t *testing.T) {
root, err := ioutil.TempDir("", "tlstest")
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)
clientServerKeyPairs := CreateClientServerCertPairs(root)
serverConfig, err := vttls.ServerConfig(
clientServerKeyPairs.RevokedServerCert,
clientServerKeyPairs.RevokedServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.ClientCRL,
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
clientConfig, err := vttls.ClientConfig(
vttls.VerifyIdentity,
clientServerKeyPairs.ClientCert,
clientServerKeyPairs.ClientKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.ServerCRL,
clientServerKeyPairs.RevokedServerName,
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
assertTLSHandshakeFails(t, serverConfig, clientConfig)
serverConfig, err = vttls.ServerConfig(
clientServerKeyPairs.RevokedServerCert,
clientServerKeyPairs.RevokedServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.CombinedCRL,
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
clientConfig, err = vttls.ClientConfig(
vttls.VerifyIdentity,
clientServerKeyPairs.ClientCert,
clientServerKeyPairs.ClientKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.CombinedCRL,
clientServerKeyPairs.RevokedServerName,
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
assertTLSHandshakeFails(t, serverConfig, clientConfig)
}
func TestClientServerWithRevokedClientCert(t *testing.T) {
root, err := ioutil.TempDir("", "tlstest")
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)
clientServerKeyPairs := CreateClientServerCertPairs(root)
// Single CRL
serverConfig, err := vttls.ServerConfig(
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.ClientCRL,
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
clientConfig, err := vttls.ClientConfig(
vttls.VerifyIdentity,
clientServerKeyPairs.RevokedClientCert,
clientServerKeyPairs.RevokedClientKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.ServerCRL,
clientServerKeyPairs.ServerName,
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
assertTLSHandshakeFails(t, serverConfig, clientConfig)
// CombinedCRL
serverConfig, err = vttls.ServerConfig(
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA,
clientServerKeyPairs.CombinedCRL,
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
clientConfig, err = vttls.ClientConfig(
vttls.VerifyIdentity,
clientServerKeyPairs.RevokedClientCert,
clientServerKeyPairs.RevokedClientKey,
clientServerKeyPairs.ServerCA,
clientServerKeyPairs.CombinedCRL,
clientServerKeyPairs.ServerName,
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
assertTLSHandshakeFails(t, serverConfig, clientConfig)
}

Просмотреть файл

@ -30,6 +30,7 @@ var (
cert = flag.String("vtctld_grpc_cert", "", "the cert to use to connect")
key = flag.String("vtctld_grpc_key", "", "the key to use to connect")
ca = flag.String("vtctld_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("vtctld_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("vtctld_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -37,5 +38,5 @@ var (
// insecure if no flags were set) based on the vtctld_grpc_* flags declared by
// this package.
func SecureDialOption() (grpc.DialOption, error) {
return grpcclient.SecureDialOption(*cert, *key, *ca, *name)
return grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
}

Просмотреть файл

@ -41,6 +41,7 @@ var (
cert = flag.String("vtgate_grpc_cert", "", "the cert to use to connect")
key = flag.String("vtgate_grpc_key", "", "the key to use to connect")
ca = flag.String("vtgate_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("vtgate_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("vtgate_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -60,7 +61,7 @@ func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) {
// DialWithOpts allows for custom dial options to be set on a vtgateConn.
func DialWithOpts(ctx context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc {
return func(ctx context.Context, address string) (vtgateconn.Impl, error) {
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}

Просмотреть файл

@ -63,6 +63,7 @@ var (
mysqlSslCert = flag.String("mysql_server_ssl_cert", "", "Path to the ssl cert for mysql server plugin SSL")
mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL")
mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.")
mysqlSslCrl = flag.String("mysql_server_ssl_crl", "", "Path to ssl CRL for mysql server plugin SSL")
mysqlTLSMinVersion = flag.String("mysql_server_tls_min_version", "", "Configures the minimal TLS version negotiated when SSL is enabled. Defaults to TLSv1.2. Options: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3.")
@ -364,8 +365,8 @@ var sigChan chan os.Signal
var vtgateHandle *vtgateHandler
// initTLSConfig inits tls config for the given mysql listener
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA, mysqlMinTLSVersion)
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
if err != nil {
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
return err
@ -376,7 +377,7 @@ func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mys
signal.Notify(sigChan, syscall.SIGHUP)
go func() {
for range sigChan {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA, mysqlMinTLSVersion)
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
if err != nil {
log.Errorf("grpcutils.TLSServerConfig failed: %v", err)
} else {
@ -437,7 +438,7 @@ func initMySQLProtocol() {
log.Exitf("mysql.NewListener failed: %v", err)
}
_ = initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslServerCA, *mysqlServerRequireSecureTransport, tlsVersion)
_ = initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslCrl, *mysqlSslServerCA, *mysqlServerRequireSecureTransport, tlsVersion)
}
mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS)
// Check for the connection threshold

Просмотреть файл

@ -264,6 +264,7 @@ func testInitTLSConfig(t *testing.T, serverCA bool) {
}
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateCRL(root, tlstest.CA)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
serverCACert := ""
@ -272,7 +273,7 @@ func testInitTLSConfig(t *testing.T, serverCA bool) {
}
listener := &mysql.Listener{}
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), serverCACert, true, tls.VersionTLS12); err != nil {
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil {
t.Fatalf("init tls config failure due to: +%v", err)
}

Просмотреть файл

@ -105,9 +105,9 @@ func querylogzHandler(ch chan interface{}, w http.ResponseWriter, r *http.Reques
stats, ok := out.(*LogStats)
if !ok {
err := fmt.Errorf("unexpected value in %s: %#v (expecting value of type %T)", QueryLogger.Name(), out, &LogStats{})
io.WriteString(w, `<tr class="error">`)
io.WriteString(w, err.Error())
io.WriteString(w, "</tr>")
_, _ = io.WriteString(w, `<tr class="error">`)
_, _ = io.WriteString(w, err.Error())
_, _ = io.WriteString(w, "</tr>")
log.Error(err)
continue
}

Просмотреть файл

@ -44,6 +44,7 @@ var (
cert = flag.String("tablet_grpc_cert", "", "the cert to use to connect")
key = flag.String("tablet_grpc_key", "", "the key to use to connect")
ca = flag.String("tablet_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("tablet_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("tablet_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -73,7 +74,7 @@ func DialTablet(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (querys
} else {
addr = tablet.Hostname
}
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}

Просмотреть файл

@ -230,7 +230,7 @@ func (dialer *cachedConnDialer) pollOnce(ctx context.Context, addr string) (clie
// It returns the three-tuple of client-interface, closer, and error that the
// main dial func returns.
func (dialer *cachedConnDialer) newdial(ctx context.Context, addr string) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) {
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
dialer.connWaitSema.Release()
return nil, nil, err

Просмотреть файл

@ -48,6 +48,7 @@ var (
cert = flag.String("tablet_manager_grpc_cert", "", "the cert to use to connect")
key = flag.String("tablet_manager_grpc_key", "", "the key to use to connect")
ca = flag.String("tablet_manager_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("tablet_manager_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("tablet_manager_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -111,7 +112,7 @@ func NewClient() *Client {
// dial returns a client to use
func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) {
addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, nil, err
}
@ -125,7 +126,7 @@ func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (
func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) {
addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}

93
go/vt/vttls/crl.go Normal file
Просмотреть файл

@ -0,0 +1,93 @@
/*
Copyright 2021 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package vttls
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"time"
"vitess.io/vitess/go/vt/log"
)
type verifyPeerCertificateFunc func([][]byte, [][]*x509.Certificate) error
func certIsRevoked(cert *x509.Certificate, crl *pkix.CertificateList) bool {
if crl.HasExpired(time.Now()) {
log.Warningf("The current Certificate Revocation List (CRL) is past expiry date and must be updated. Revoked certificates will still be rejected in this state.")
}
for _, revoked := range crl.TBSCertList.RevokedCertificates {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 {
return true
}
}
return false
}
func verifyPeerCertificateAgainstCRL(crl string) (verifyPeerCertificateFunc, error) {
crlSet, err := loadCRLSet(crl)
if err != nil {
return nil, err
}
return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
for _, chain := range verifiedChains {
for i := 0; i < len(chain)-1; i++ {
cert := chain[i]
issuerCert := chain[i+1]
for _, crl := range crlSet {
if issuerCert.CheckCRLSignature(crl) == nil {
if certIsRevoked(cert, crl) {
return fmt.Errorf("Certificate revoked: CommonName=%v", cert.Subject.CommonName)
}
}
}
}
}
return nil
}, nil
}
func loadCRLSet(crl string) ([]*pkix.CertificateList, error) {
body, err := ioutil.ReadFile(crl)
if err != nil {
return nil, err
}
crlSet := make([]*pkix.CertificateList, 0)
for len(body) > 0 {
var block *pem.Block
block, body = pem.Decode(body)
if block == nil {
break
}
if block.Type != "X509 CRL" {
continue
}
parsedCRL, err := x509.ParseCRL(block.Bytes)
if err != nil {
return nil, err
}
crlSet = append(crlSet, parsedCRL)
}
return crlSet, nil
}

Просмотреть файл

@ -128,7 +128,7 @@ var onceByKeys = sync.Map{}
// ClientConfig returns the TLS config to use for a client to
// connect to a server with the provided parameters.
func ClientConfig(mode SslMode, cert, key, ca, name string, minTLSVersion uint16) (*tls.Config, error) {
func ClientConfig(mode SslMode, cert, key, ca, crl, name string, minTLSVersion uint16) (*tls.Config, error) {
config := newTLSConfig(minTLSVersion)
// Load the client-side cert & key if any.
@ -190,12 +190,20 @@ func ClientConfig(mode SslMode, cert, key, ca, name string, minTLSVersion uint16
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid mode: %s", mode)
}
if crl != "" {
crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
if err != nil {
return nil, err
}
config.VerifyPeerCertificate = crlFunc
}
return config, nil
}
// ServerConfig returns the TLS config to use for a server to
// accept client connections.
func ServerConfig(cert, key, ca, serverCA string, minTLSVersion uint16) (*tls.Config, error) {
func ServerConfig(cert, key, ca, crl, serverCA string, minTLSVersion uint16) (*tls.Config, error) {
config := newTLSConfig(minTLSVersion)
var certificates *[]tls.Certificate
@ -225,6 +233,14 @@ func ServerConfig(cert, key, ca, serverCA string, minTLSVersion uint16) (*tls.Co
config.ClientAuth = tls.RequireAndVerifyClientCert
}
if crl != "" {
crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
if err != nil {
return nil, err
}
config.VerifyPeerCertificate = crlFunc
}
return config, nil
}

Просмотреть файл

@ -39,6 +39,7 @@ var (
cert = flag.String("vtworker_client_grpc_cert", "", "the cert to use to connect")
key = flag.String("vtworker_client_grpc_key", "", "the key to use to connect")
ca = flag.String("vtworker_client_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("vtworker_client_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("vtworker_client_grpc_server_name", "", "the server name to use to validate server certificate")
)
@ -49,7 +50,7 @@ type gRPCVtworkerClient struct {
func gRPCVtworkerClientFactory(addr string) (vtworkerclient.Client, error) {
// create the RPC client
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}