зеркало из https://github.com/github/vitess-gh.git
Add support for certification revocation list files
Signed-off-by: Hormoz Kheradmand <hormoz.kheradmand@shopify.com>
This commit is contained in:
Родитель
6b31715e81
Коммит
8e06dc7f59
|
@ -3,6 +3,7 @@
|
|||
"LdapCert": "path/to/ldap-client-cert.pem",
|
||||
"LdapKey": "path/to/ldap-client-key.pem",
|
||||
"LdapCA": "path/to/ldap-client-ca.pem",
|
||||
"LdapCRL": "path/to/ldap-client-crl.pem",
|
||||
"User": "uid=vitessROuser,ou=users,ou=people,dc=example,dc=com",
|
||||
"Password": "sUpErSeCuRe1",
|
||||
"GroupQuery": "ou=groups,ou=people,dc=example,dc=com",
|
||||
|
|
|
@ -48,7 +48,9 @@ var cmdMap map[string]cmdFunc
|
|||
func init() {
|
||||
cmdMap = map[string]cmdFunc{
|
||||
"CreateCA": cmdCreateCA,
|
||||
"CreateCRL": cmdCreateCRL,
|
||||
"CreateSignedCert": cmdCreateSignedCert,
|
||||
"RevokeCert": cmdRevokeCert,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,6 +67,28 @@ func cmdCreateCA(subFlags *flag.FlagSet, args []string) {
|
|||
tlstest.CreateCA(*root)
|
||||
}
|
||||
|
||||
func cmdCreateCRL(subFlags *flag.FlagSet, args []string) {
|
||||
subFlags.Parse(args)
|
||||
if subFlags.NArg() != 1 {
|
||||
log.Fatalf("CreateCRL command takes a single CA name as a parameter")
|
||||
}
|
||||
|
||||
ca := subFlags.Arg(0)
|
||||
tlstest.CreateCRL(*root, ca)
|
||||
}
|
||||
|
||||
func cmdRevokeCert(subFlags *flag.FlagSet, args []string) {
|
||||
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")
|
||||
|
||||
subFlags.Parse(args)
|
||||
if subFlags.NArg() != 1 {
|
||||
log.Fatalf("RevokeCert command takes a single name as a parameter")
|
||||
}
|
||||
|
||||
name := subFlags.Arg(0)
|
||||
tlstest.RevokeCertAndRegenerateCRL(*root, *parent, name)
|
||||
}
|
||||
|
||||
func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
|
||||
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")
|
||||
serial := subFlags.String("serial", "01", "Serial number for the certificate to create. Should be different for two certificates with the same parent.")
|
||||
|
@ -74,11 +98,13 @@ func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
|
|||
if subFlags.NArg() != 1 {
|
||||
log.Fatalf("CreateSignedCert command takes a single name as a parameter")
|
||||
}
|
||||
|
||||
name := subFlags.Arg(0)
|
||||
if *commonName == "" {
|
||||
*commonName = subFlags.Arg(0)
|
||||
*commonName = name
|
||||
}
|
||||
|
||||
tlstest.CreateSignedCert(*root, *parent, *serial, subFlags.Arg(0), *commonName)
|
||||
tlstest.CreateSignedCert(*root, *parent, *serial, name, *commonName)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
|
|
@ -54,12 +54,14 @@ func TestValidCert(t *testing.T) {
|
|||
tlstest.CreateCA(root)
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername)
|
||||
tlstest.CreateCRL(root, tlstest.CA)
|
||||
|
||||
// Create the server with TLS config.
|
||||
serverConfig, err := vttls.ServerConfig(
|
||||
path.Join(root, "server-cert.pem"),
|
||||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
path.Join(root, "ca-crl.pem"),
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
|
@ -136,12 +138,14 @@ func TestNoCert(t *testing.T) {
|
|||
defer os.RemoveAll(root)
|
||||
tlstest.CreateCA(root)
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
|
||||
tlstest.CreateCRL(root, tlstest.CA)
|
||||
|
||||
// Create the server with TLS config.
|
||||
serverConfig, err := vttls.ServerConfig(
|
||||
path.Join(root, "server-cert.pem"),
|
||||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
path.Join(root, "ca-crl.pem"),
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
|
|
|
@ -282,7 +282,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
|
|||
}
|
||||
|
||||
// Build the TLS config.
|
||||
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, serverName, tlsVersion)
|
||||
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion)
|
||||
if err != nil {
|
||||
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err)
|
||||
}
|
||||
|
|
|
@ -187,6 +187,7 @@ func TestTLSClientDisabled(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
l.TLSConfig.Store(serverConfig)
|
||||
|
@ -260,6 +261,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
l.TLSConfig.Store(serverConfig)
|
||||
|
@ -381,6 +383,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
l.TLSConfig.Store(serverConfig)
|
||||
|
@ -465,6 +468,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
l.TLSConfig.Store(serverConfig)
|
||||
|
@ -511,4 +515,12 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
|
|||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// Now revoke the server certificate and make sure we can't connect
|
||||
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "server")
|
||||
|
||||
params.SslCrl = path.Join(root, "ca-crl.pem")
|
||||
_, err = Connect(context.Background(), params)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Certificate revoked: CommonName=server.example.com")
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ type ConnParams struct {
|
|||
SslCa string `json:"ssl_ca"`
|
||||
SslCaPath string `json:"ssl_ca_path"`
|
||||
SslCert string `json:"ssl_cert"`
|
||||
SslCrl string `json:"ssl_crl"`
|
||||
SslKey string `json:"ssl_key"`
|
||||
TLSMinVersion string `json:"tls_min_version"`
|
||||
ServerName string `json:"server_name"`
|
||||
|
|
|
@ -126,6 +126,7 @@ func TestSSLConnection(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
|
|
|
@ -215,6 +215,7 @@ type ServerConfig struct {
|
|||
LdapCert string
|
||||
LdapKey string
|
||||
LdapCA string
|
||||
LdapCRL string
|
||||
LdapTLSMinVersion string
|
||||
}
|
||||
|
||||
|
@ -250,7 +251,7 @@ func (lci *ClientImpl) Connect(network string, config *ServerConfig) error {
|
|||
return err
|
||||
}
|
||||
|
||||
tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, serverName, tlsVersion)
|
||||
tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, config.LdapCRL, serverName, tlsVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -357,6 +357,7 @@ func FuzzTLSServer(data []byte) int {
|
|||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
return -1
|
||||
|
|
|
@ -833,6 +833,7 @@ func TestTLSServer(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
l.TLSConfig.Store(serverConfig)
|
||||
|
@ -924,12 +925,16 @@ func TestTLSRequired(t *testing.T) {
|
|||
defer os.RemoveAll(root)
|
||||
tlstest.CreateCA(root)
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert")
|
||||
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client")
|
||||
|
||||
// Create the server with TLS config.
|
||||
serverConfig, err := vttls.ServerConfig(
|
||||
path.Join(root, "server-cert.pem"),
|
||||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
path.Join(root, "ca-crl.pem"),
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
require.NoError(t, err)
|
||||
|
@ -966,7 +971,6 @@ func TestTLSRequired(t *testing.T) {
|
|||
}
|
||||
|
||||
// setup conn params with TLS
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
|
||||
params.SslMode = vttls.VerifyIdentity
|
||||
params.SslCa = path.Join(root, "ca-cert.pem")
|
||||
params.SslCert = path.Join(root, "client-cert.pem")
|
||||
|
@ -977,6 +981,16 @@ func TestTLSRequired(t *testing.T) {
|
|||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// setup conn params with TLS, but with a revoked client certificate
|
||||
params.SslCert = path.Join(root, "revoked-client-cert.pem")
|
||||
params.SslKey = path.Join(root, "revoked-client-key.pem")
|
||||
conn, err = Connect(context.Background(), params)
|
||||
require.NotNil(t, err)
|
||||
require.Contains(t, err.Error(), "remote error: tls: bad certificate")
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
|
||||
|
@ -1013,6 +1027,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
|
|||
path.Join(root, "server-key.pem"),
|
||||
path.Join(root, "ca-cert.pem"),
|
||||
"",
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
|
|
|
@ -374,7 +374,7 @@ func tabletConnExtraArgs(name string) []string {
|
|||
}
|
||||
|
||||
func getVitessClient(addr string) (vtgateservicepb.VitessClient, error) {
|
||||
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, grpcName)
|
||||
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, "", grpcName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ var (
|
|||
cert = flag.String("binlog_player_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("binlog_player_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("binlog_player_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("binlog_player_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("binlog_player_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -48,7 +49,7 @@ type client struct {
|
|||
func (client *client) Dial(tablet *topodatapb.Tablet) error {
|
||||
addr := netutil.JoinHostPort(tablet.Hostname, tablet.PortMap["grpc"])
|
||||
var err error
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ func interceptors() []grpc.DialOption {
|
|||
// SecureDialOption returns the gRPC dial option to use for the
|
||||
// given client connection. It is either using TLS, or Insecure if
|
||||
// nothing is set.
|
||||
func SecureDialOption(cert, key, ca, name string) (grpc.DialOption, error) {
|
||||
func SecureDialOption(cert, key, ca, crl, name string) (grpc.DialOption, error) {
|
||||
// No security options set, just return.
|
||||
if (cert == "" || key == "") && ca == "" {
|
||||
return grpc.WithInsecure(), nil
|
||||
|
@ -133,7 +133,7 @@ func SecureDialOption(cert, key, ca, name string) (grpc.DialOption, error) {
|
|||
|
||||
// Load the config. At this point we know
|
||||
// we want a strict config with verify identity.
|
||||
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, name, tls.VersionTLS12)
|
||||
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, crl, name, tls.VersionTLS12)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -67,6 +67,9 @@ var (
|
|||
// GRPCCA is the CA to use if TLS is enabled
|
||||
GRPCCA = flag.String("grpc_ca", "", "server CA to use for gRPC connections, requires TLS, and enforces client certificate check")
|
||||
|
||||
// GRPCCRL is the CRL (Certificate Revocation List) to use if TLS is enabled
|
||||
GRPCCRL = flag.String("grpc_crl", "", "path to a certificate revocation list in PEM format, client certificates will be further verified against this file during TLS handshake")
|
||||
|
||||
GRPCEnableOptionalTLS = flag.Bool("grpc_enable_optional_tls", false, "enable optional TLS mode when a server accepts both TLS and plain-text connections on the same port")
|
||||
|
||||
// GRPCServerCA if specified will combine server cert and server CA
|
||||
|
@ -133,7 +136,7 @@ func createGRPCServer() {
|
|||
|
||||
var opts []grpc.ServerOption
|
||||
if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" {
|
||||
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCServerCA, tls.VersionTLS12)
|
||||
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCCRL, *GRPCServerCA, tls.VersionTLS12)
|
||||
if err != nil {
|
||||
log.Exitf("Failed to log gRPC cert/key/ca: %v", err)
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ var (
|
|||
cert = flag.String("throttler_client_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("throttler_client_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("throttler_client_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("throttler_client_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("throttler_client_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -45,7 +46,7 @@ type client struct {
|
|||
}
|
||||
|
||||
func factory(addr string) (throttlerclient.Client, error) {
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -31,7 +31,15 @@ const (
|
|||
// CA is the name of the CA toplevel cert.
|
||||
CA = "ca"
|
||||
|
||||
caConfig = `
|
||||
caConfigTemplate = `
|
||||
[ ca ]
|
||||
default_ca = default_ca
|
||||
|
||||
[ default_ca ]
|
||||
database = %s
|
||||
default_md = default
|
||||
default_crl_days = 30
|
||||
|
||||
[ req ]
|
||||
default_bits = 4096
|
||||
default_keyfile = keyfile.pem
|
||||
|
@ -90,10 +98,28 @@ func openssl(argv ...string) {
|
|||
cmd := exec.Command("openssl", argv...)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
if len(output) > 0 {
|
||||
log.Errorf("openssl %v returned:\n%v", argv, string(output))
|
||||
}
|
||||
log.Fatalf("openssl %v failed: %v", argv, err)
|
||||
}
|
||||
if len(output) > 0 {
|
||||
log.Infof("openssl %v returned:\n%v", argv, string(output))
|
||||
}
|
||||
|
||||
// createKeyDBAndCAConfig creates a key database and ca config file
|
||||
// for the passed in CA (possibly an intermediate CA)
|
||||
func createKeyDBAndCAConfig(root, parent string) {
|
||||
databasePath := path.Join(root, parent+"-keys.db")
|
||||
if _, err := os.Stat(databasePath); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(databasePath, []byte{}, os.ModePerm); err != nil {
|
||||
log.Fatalf("cannot write file %v: %v", databasePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
config := path.Join(root, parent+"-ca.config")
|
||||
if _, err := os.Stat(config); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(config, []byte(fmt.Sprintf(caConfigTemplate, databasePath)), os.ModePerm); err != nil {
|
||||
log.Fatalf("cannot write file %v: %v", config, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,12 +130,11 @@ func CreateCA(root string) {
|
|||
log.Infof("Creating test root CA in %v", root)
|
||||
key := path.Join(root, "ca-key.pem")
|
||||
cert := path.Join(root, "ca-cert.pem")
|
||||
config := path.Join(root, "ca-ca.config")
|
||||
createKeyDBAndCAConfig(root, "ca")
|
||||
|
||||
openssl("genrsa", "-out", key)
|
||||
|
||||
config := path.Join(root, "ca.config")
|
||||
if err := os.WriteFile(config, []byte(caConfig), os.ModePerm); err != nil {
|
||||
log.Fatalf("cannot write file %v: %v", config, err)
|
||||
}
|
||||
openssl("req", "-new", "-x509", "-nodes", "-days", "3600", "-batch",
|
||||
"-config", config,
|
||||
"-key", key,
|
||||
|
@ -147,50 +172,139 @@ func CreateSignedCert(root, parent, serial, name, commonName string) {
|
|||
"-out", cert)
|
||||
}
|
||||
|
||||
// CreateCRL creates a new empty certificate revocation list
|
||||
// for the provided parent
|
||||
func CreateCRL(root, parent string) {
|
||||
log.Infof("Creating CRL for root CA in %v", root)
|
||||
caKey := path.Join(root, parent+"-key.pem")
|
||||
caCert := path.Join(root, parent+"-cert.pem")
|
||||
configPath := path.Join(root, parent+"-ca.config")
|
||||
crlPath := path.Join(root, parent+"-crl.pem")
|
||||
createKeyDBAndCAConfig(root, parent)
|
||||
|
||||
openssl("ca", "-gencrl",
|
||||
"-keyfile", caKey,
|
||||
"-cert", caCert,
|
||||
"-config", configPath,
|
||||
"-out",
|
||||
crlPath,
|
||||
)
|
||||
}
|
||||
|
||||
// RevokeCertAndRegenerateCRL revokes a provided certificate under the
|
||||
// provided parent CA and regenerates the CRL file for that parent
|
||||
func RevokeCertAndRegenerateCRL(root, parent, name string) {
|
||||
log.Infof("Revoking certificate %s", name)
|
||||
caKey := path.Join(root, parent+"-key.pem")
|
||||
caCert := path.Join(root, parent+"-cert.pem")
|
||||
cert := path.Join(root, name+"-cert.pem")
|
||||
configPath := path.Join(root, parent+"-ca.config")
|
||||
createKeyDBAndCAConfig(root, parent)
|
||||
|
||||
openssl("ca", "-revoke", cert,
|
||||
"-keyfile", caKey,
|
||||
"-cert", caCert,
|
||||
"-config", configPath,
|
||||
)
|
||||
|
||||
CreateCRL(root, parent)
|
||||
}
|
||||
|
||||
// ClientServerKeyPairs is used in tests
|
||||
type ClientServerKeyPairs struct {
|
||||
ServerCert string
|
||||
ServerKey string
|
||||
ServerCA string
|
||||
ServerName string
|
||||
ClientCert string
|
||||
ClientKey string
|
||||
ClientCA string
|
||||
ServerCert string
|
||||
ServerKey string
|
||||
ServerCA string
|
||||
ServerName string
|
||||
ServerCRL string
|
||||
RevokedServerCert string
|
||||
RevokedServerKey string
|
||||
RevokedServerName string
|
||||
ClientCert string
|
||||
ClientKey string
|
||||
ClientCA string
|
||||
ClientCRL string
|
||||
RevokedClientCert string
|
||||
RevokedClientKey string
|
||||
RevokedClientName string
|
||||
CombinedCRL string
|
||||
}
|
||||
|
||||
var serialCounter = 0
|
||||
|
||||
// CreateClientServerCertPairs creates certificate pairs for use in test
|
||||
func CreateClientServerCertPairs(root string) ClientServerKeyPairs {
|
||||
// Create the certs and configs.
|
||||
CreateCA(root)
|
||||
|
||||
serverSerial := fmt.Sprintf("%03d", serialCounter*2+1)
|
||||
clientSerial := fmt.Sprintf("%03d", serialCounter*2+2)
|
||||
serverCASerial := fmt.Sprintf("%03d", serialCounter*2+1)
|
||||
serverSerial := fmt.Sprintf("%03d", serialCounter*2+3)
|
||||
revokedServerSerial := fmt.Sprintf("%03d", serialCounter*2+5)
|
||||
clientCASerial := fmt.Sprintf("%03d", serialCounter*2+2)
|
||||
clientCertSerial := fmt.Sprintf("%03d", serialCounter*2+4)
|
||||
revokedClientSerial := fmt.Sprintf("%03d", serialCounter*2+6)
|
||||
|
||||
serialCounter = serialCounter + 1
|
||||
serialCounter = serialCounter + 3
|
||||
|
||||
serverName := fmt.Sprintf("server-%s", serverSerial)
|
||||
serverCACommonName := fmt.Sprintf("Server %s CA", serverSerial)
|
||||
serverCAName := fmt.Sprintf("servers-ca-%s", serverCASerial)
|
||||
serverCACommonName := fmt.Sprintf("Servers %s CA", serverCASerial)
|
||||
serverCertName := fmt.Sprintf("server-instance-%s", serverSerial)
|
||||
serverCertCommonName := fmt.Sprintf("server%s.example.com", serverSerial)
|
||||
revokedServerCertName := fmt.Sprintf("server-instance-%s", revokedServerSerial)
|
||||
revokedServerCertCommonName := fmt.Sprintf("server%s.example.com", revokedServerSerial)
|
||||
|
||||
clientName := fmt.Sprintf("clients-%s", serverSerial)
|
||||
clientCACommonName := fmt.Sprintf("Clients %s CA", serverSerial)
|
||||
clientCertName := fmt.Sprintf("client-instance-%s", serverSerial)
|
||||
clientCertCommonName := fmt.Sprintf("Client Instance %s", serverSerial)
|
||||
clientCAName := fmt.Sprintf("clients-ca-%s", clientCASerial)
|
||||
clientCACommonName := fmt.Sprintf("Clients %s CA", clientCASerial)
|
||||
clientCertName := fmt.Sprintf("client-instance-%s", clientCertSerial)
|
||||
clientCertCommonName := fmt.Sprintf("client%s.example.com", clientCertSerial)
|
||||
revokedClientCertName := fmt.Sprintf("client-instance-%s", revokedClientSerial)
|
||||
revokedClientCertCommonName := fmt.Sprintf("client%s.example.com", revokedClientSerial)
|
||||
|
||||
CreateSignedCert(root, CA, serverSerial, serverName, serverCACommonName)
|
||||
CreateSignedCert(root, serverName, serverSerial, serverCertName, serverCertCommonName)
|
||||
CreateSignedCert(root, CA, serverCASerial, serverCAName, serverCACommonName)
|
||||
CreateSignedCert(root, serverCAName, serverSerial, serverCertName, serverCertCommonName)
|
||||
CreateSignedCert(root, serverCAName, revokedServerSerial, revokedServerCertName, revokedServerCertCommonName)
|
||||
RevokeCertAndRegenerateCRL(root, serverCAName, revokedServerCertName)
|
||||
|
||||
CreateSignedCert(root, CA, clientSerial, clientName, clientCACommonName)
|
||||
CreateSignedCert(root, clientName, serverSerial, clientCertName, clientCertCommonName)
|
||||
CreateSignedCert(root, CA, clientCASerial, clientCAName, clientCACommonName)
|
||||
CreateSignedCert(root, clientCAName, clientCertSerial, clientCertName, clientCertCommonName)
|
||||
CreateSignedCert(root, clientCAName, revokedClientSerial, revokedClientCertName, revokedClientCertCommonName)
|
||||
RevokeCertAndRegenerateCRL(root, clientCAName, revokedClientCertName)
|
||||
|
||||
serverCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", serverCAName))
|
||||
clientCRLPath := path.Join(root, fmt.Sprintf("%s-crl.pem", clientCAName))
|
||||
combinedCRLPath := path.Join(root, fmt.Sprintf("%s-%s-combined-crl.pem", serverCAName, clientCAName))
|
||||
|
||||
serverCRLBytes, err := ioutil.ReadFile(serverCRLPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not read server CRL file")
|
||||
}
|
||||
|
||||
clientCRLBytes, err := ioutil.ReadFile(clientCRLPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not read client CRL file")
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(combinedCRLPath, append(serverCRLBytes, clientCRLBytes...), 0777)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not write combined CRL file")
|
||||
}
|
||||
|
||||
return ClientServerKeyPairs{
|
||||
ServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
|
||||
ServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
|
||||
ServerCA: path.Join(root, fmt.Sprintf("%s-cert.pem", serverName)),
|
||||
ClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
|
||||
ClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
|
||||
ClientCA: path.Join(root, fmt.Sprintf("%s-cert.pem", clientName)),
|
||||
ServerName: serverCertCommonName,
|
||||
ServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCertName)),
|
||||
ServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", serverCertName)),
|
||||
ServerCA: path.Join(root, fmt.Sprintf("%s-cert.pem", serverCAName)),
|
||||
ServerCRL: serverCRLPath,
|
||||
RevokedServerCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedServerCertName)),
|
||||
RevokedServerKey: path.Join(root, fmt.Sprintf("%s-key.pem", revokedServerCertName)),
|
||||
ClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCertName)),
|
||||
ClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", clientCertName)),
|
||||
ClientCA: path.Join(root, fmt.Sprintf("%s-cert.pem", clientCAName)),
|
||||
ClientCRL: clientCRLPath,
|
||||
RevokedClientCert: path.Join(root, fmt.Sprintf("%s-cert.pem", revokedClientCertName)),
|
||||
RevokedClientKey: path.Join(root, fmt.Sprintf("%s-key.pem", revokedClientCertName)),
|
||||
CombinedCRL: combinedCRLPath,
|
||||
ServerName: serverCertCommonName,
|
||||
RevokedServerName: revokedServerCertCommonName,
|
||||
RevokedClientName: revokedClientCertCommonName,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,6 +65,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
|
|||
clientServerKeyPairs.ServerCert,
|
||||
clientServerKeyPairs.ServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.ClientCRL,
|
||||
serverCA,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
|
@ -75,6 +76,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
|
|||
clientServerKeyPairs.ClientCert,
|
||||
clientServerKeyPairs.ClientKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.ServerCRL,
|
||||
clientServerKeyPairs.ServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
|
@ -109,9 +111,6 @@ func testClientServer(t *testing.T, combineCerts bool) {
|
|||
}()
|
||||
|
||||
serverConn, err := listener.Accept()
|
||||
if clientErr != nil {
|
||||
t.Fatalf("Dial failed: %v", clientErr)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Accept failed: %v", err)
|
||||
}
|
||||
|
@ -127,6 +126,10 @@ func testClientServer(t *testing.T, combineCerts bool) {
|
|||
|
||||
wg.Wait()
|
||||
|
||||
if clientErr != nil {
|
||||
t.Fatalf("Dial failed: %v", clientErr)
|
||||
}
|
||||
|
||||
//
|
||||
// Negative case: connect a client with wrong cert (using the
|
||||
// server cert on the client side).
|
||||
|
@ -137,6 +140,7 @@ func testClientServer(t *testing.T, combineCerts bool) {
|
|||
clientServerKeyPairs.ServerCert,
|
||||
clientServerKeyPairs.ServerKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.ServerCRL,
|
||||
clientServerKeyPairs.ServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
|
@ -188,6 +192,7 @@ func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Co
|
|||
keypairs.ServerCert,
|
||||
keypairs.ServerKey,
|
||||
keypairs.ClientCA,
|
||||
keypairs.ClientCRL,
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
}
|
||||
|
@ -197,6 +202,7 @@ func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Confi
|
|||
keypairs.ServerCert,
|
||||
keypairs.ServerKey,
|
||||
keypairs.ClientCA,
|
||||
keypairs.ClientCRL,
|
||||
keypairs.ServerCA,
|
||||
tls.VersionTLS12)
|
||||
}
|
||||
|
@ -207,6 +213,7 @@ func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) {
|
|||
keypairs.ClientCert,
|
||||
keypairs.ClientKey,
|
||||
keypairs.ServerCA,
|
||||
keypairs.ServerCRL,
|
||||
keypairs.ServerName,
|
||||
tls.VersionTLS12)
|
||||
}
|
||||
|
@ -296,6 +303,7 @@ func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int,
|
|||
clientServerKeyPairs.ServerCert,
|
||||
clientServerKeyPairs.ServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.ClientCRL,
|
||||
serverCA,
|
||||
tls.VersionTLS12)
|
||||
|
||||
|
@ -312,3 +320,171 @@ func TestNumberOfCertsWithoutCombining(t *testing.T) {
|
|||
func TestNumberOfCertsWithCombining(t *testing.T) {
|
||||
testNumberOfCertsWithOrWithoutCombining(t, 2, true)
|
||||
}
|
||||
|
||||
func assertTLSHandshakeFails(t *testing.T, serverConfig, clientConfig *tls.Config) {
|
||||
// Create a TLS server listener.
|
||||
listener, err := tls.Listen("tcp", ":0", serverConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Listen failed: %v", err)
|
||||
}
|
||||
addr := listener.Addr().String()
|
||||
defer listener.Close()
|
||||
// create a dialer with timeout
|
||||
dialer := new(net.Dialer)
|
||||
dialer.Timeout = 10 * time.Second
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
var clientErr error
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var clientConn *tls.Conn
|
||||
clientConn, clientErr = tls.DialWithDialer(dialer, "tcp", addr, clientConfig)
|
||||
if clientErr == nil {
|
||||
clientConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
serverConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// We should always be able to accept on the socket
|
||||
t.Fatalf("Accept failed: %v", err)
|
||||
}
|
||||
|
||||
err = serverConn.(*tls.Conn).Handshake()
|
||||
if err != nil {
|
||||
if !(strings.Contains(err.Error(), "Certificate revoked: CommonName=") ||
|
||||
strings.Contains(err.Error(), "remote error: tls: bad certificate")) {
|
||||
t.Fatalf("Wrong error returned: %v", err)
|
||||
}
|
||||
} else {
|
||||
t.Fatal("Server should have failed the TLS handshake but it did not")
|
||||
}
|
||||
serverConn.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestClientServerWithRevokedServerCert(t *testing.T) {
|
||||
root, err := ioutil.TempDir("", "tlstest")
|
||||
if err != nil {
|
||||
t.Fatalf("TempDir failed: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(root)
|
||||
|
||||
clientServerKeyPairs := CreateClientServerCertPairs(root)
|
||||
|
||||
serverConfig, err := vttls.ServerConfig(
|
||||
clientServerKeyPairs.RevokedServerCert,
|
||||
clientServerKeyPairs.RevokedServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.ClientCRL,
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
clientConfig, err := vttls.ClientConfig(
|
||||
vttls.VerifyIdentity,
|
||||
clientServerKeyPairs.ClientCert,
|
||||
clientServerKeyPairs.ClientKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.ServerCRL,
|
||||
clientServerKeyPairs.RevokedServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSClientConfig failed: %v", err)
|
||||
}
|
||||
|
||||
assertTLSHandshakeFails(t, serverConfig, clientConfig)
|
||||
|
||||
serverConfig, err = vttls.ServerConfig(
|
||||
clientServerKeyPairs.RevokedServerCert,
|
||||
clientServerKeyPairs.RevokedServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.CombinedCRL,
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
clientConfig, err = vttls.ClientConfig(
|
||||
vttls.VerifyIdentity,
|
||||
clientServerKeyPairs.ClientCert,
|
||||
clientServerKeyPairs.ClientKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.CombinedCRL,
|
||||
clientServerKeyPairs.RevokedServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSClientConfig failed: %v", err)
|
||||
}
|
||||
|
||||
assertTLSHandshakeFails(t, serverConfig, clientConfig)
|
||||
}
|
||||
|
||||
func TestClientServerWithRevokedClientCert(t *testing.T) {
|
||||
root, err := ioutil.TempDir("", "tlstest")
|
||||
if err != nil {
|
||||
t.Fatalf("TempDir failed: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(root)
|
||||
|
||||
clientServerKeyPairs := CreateClientServerCertPairs(root)
|
||||
|
||||
// Single CRL
|
||||
|
||||
serverConfig, err := vttls.ServerConfig(
|
||||
clientServerKeyPairs.ServerCert,
|
||||
clientServerKeyPairs.ServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.ClientCRL,
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
clientConfig, err := vttls.ClientConfig(
|
||||
vttls.VerifyIdentity,
|
||||
clientServerKeyPairs.RevokedClientCert,
|
||||
clientServerKeyPairs.RevokedClientKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.ServerCRL,
|
||||
clientServerKeyPairs.ServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSClientConfig failed: %v", err)
|
||||
}
|
||||
|
||||
assertTLSHandshakeFails(t, serverConfig, clientConfig)
|
||||
|
||||
// CombinedCRL
|
||||
|
||||
serverConfig, err = vttls.ServerConfig(
|
||||
clientServerKeyPairs.ServerCert,
|
||||
clientServerKeyPairs.ServerKey,
|
||||
clientServerKeyPairs.ClientCA,
|
||||
clientServerKeyPairs.CombinedCRL,
|
||||
"",
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSServerConfig failed: %v", err)
|
||||
}
|
||||
|
||||
clientConfig, err = vttls.ClientConfig(
|
||||
vttls.VerifyIdentity,
|
||||
clientServerKeyPairs.RevokedClientCert,
|
||||
clientServerKeyPairs.RevokedClientKey,
|
||||
clientServerKeyPairs.ServerCA,
|
||||
clientServerKeyPairs.CombinedCRL,
|
||||
clientServerKeyPairs.ServerName,
|
||||
tls.VersionTLS12)
|
||||
if err != nil {
|
||||
t.Fatalf("TLSClientConfig failed: %v", err)
|
||||
}
|
||||
|
||||
assertTLSHandshakeFails(t, serverConfig, clientConfig)
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ var (
|
|||
cert = flag.String("vtctld_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("vtctld_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("vtctld_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("vtctld_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("vtctld_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -37,5 +38,5 @@ var (
|
|||
// insecure if no flags were set) based on the vtctld_grpc_* flags declared by
|
||||
// this package.
|
||||
func SecureDialOption() (grpc.DialOption, error) {
|
||||
return grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
return grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ var (
|
|||
cert = flag.String("vtgate_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("vtgate_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("vtgate_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("vtgate_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("vtgate_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -60,7 +61,7 @@ func dial(ctx context.Context, addr string) (vtgateconn.Impl, error) {
|
|||
// DialWithOpts allows for custom dial options to be set on a vtgateConn.
|
||||
func DialWithOpts(ctx context.Context, opts ...grpc.DialOption) vtgateconn.DialerFunc {
|
||||
return func(ctx context.Context, address string) (vtgateconn.Impl, error) {
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -63,6 +63,7 @@ var (
|
|||
mysqlSslCert = flag.String("mysql_server_ssl_cert", "", "Path to the ssl cert for mysql server plugin SSL")
|
||||
mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL")
|
||||
mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.")
|
||||
mysqlSslCrl = flag.String("mysql_server_ssl_crl", "", "Path to ssl CRL for mysql server plugin SSL")
|
||||
|
||||
mysqlTLSMinVersion = flag.String("mysql_server_tls_min_version", "", "Configures the minimal TLS version negotiated when SSL is enabled. Defaults to TLSv1.2. Options: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3.")
|
||||
|
||||
|
@ -364,8 +365,8 @@ var sigChan chan os.Signal
|
|||
var vtgateHandle *vtgateHandler
|
||||
|
||||
// initTLSConfig inits tls config for the given mysql listener
|
||||
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
|
||||
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA, mysqlMinTLSVersion)
|
||||
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool, mysqlMinTLSVersion uint16) error {
|
||||
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
|
||||
if err != nil {
|
||||
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
|
||||
return err
|
||||
|
@ -376,7 +377,7 @@ func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mys
|
|||
signal.Notify(sigChan, syscall.SIGHUP)
|
||||
go func() {
|
||||
for range sigChan {
|
||||
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA, mysqlMinTLSVersion)
|
||||
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslCrl, mysqlSslServerCA, mysqlMinTLSVersion)
|
||||
if err != nil {
|
||||
log.Errorf("grpcutils.TLSServerConfig failed: %v", err)
|
||||
} else {
|
||||
|
@ -437,7 +438,7 @@ func initMySQLProtocol() {
|
|||
log.Exitf("mysql.NewListener failed: %v", err)
|
||||
}
|
||||
|
||||
_ = initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslServerCA, *mysqlServerRequireSecureTransport, tlsVersion)
|
||||
_ = initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslCrl, *mysqlSslServerCA, *mysqlServerRequireSecureTransport, tlsVersion)
|
||||
}
|
||||
mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS)
|
||||
// Check for the connection threshold
|
||||
|
|
|
@ -264,6 +264,7 @@ func testInitTLSConfig(t *testing.T, serverCA bool) {
|
|||
}
|
||||
defer os.RemoveAll(root)
|
||||
tlstest.CreateCA(root)
|
||||
tlstest.CreateCRL(root, tlstest.CA)
|
||||
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
|
||||
|
||||
serverCACert := ""
|
||||
|
@ -272,7 +273,7 @@ func testInitTLSConfig(t *testing.T, serverCA bool) {
|
|||
}
|
||||
|
||||
listener := &mysql.Listener{}
|
||||
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), serverCACert, true, tls.VersionTLS12); err != nil {
|
||||
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), path.Join(root, "ca-crl.pem"), serverCACert, true, tls.VersionTLS12); err != nil {
|
||||
t.Fatalf("init tls config failure due to: +%v", err)
|
||||
}
|
||||
|
||||
|
|
|
@ -105,9 +105,9 @@ func querylogzHandler(ch chan interface{}, w http.ResponseWriter, r *http.Reques
|
|||
stats, ok := out.(*LogStats)
|
||||
if !ok {
|
||||
err := fmt.Errorf("unexpected value in %s: %#v (expecting value of type %T)", QueryLogger.Name(), out, &LogStats{})
|
||||
io.WriteString(w, `<tr class="error">`)
|
||||
io.WriteString(w, err.Error())
|
||||
io.WriteString(w, "</tr>")
|
||||
_, _ = io.WriteString(w, `<tr class="error">`)
|
||||
_, _ = io.WriteString(w, err.Error())
|
||||
_, _ = io.WriteString(w, "</tr>")
|
||||
log.Error(err)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -44,6 +44,7 @@ var (
|
|||
cert = flag.String("tablet_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("tablet_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("tablet_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("tablet_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("tablet_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -73,7 +74,7 @@ func DialTablet(tablet *topodatapb.Tablet, failFast grpcclient.FailFast) (querys
|
|||
} else {
|
||||
addr = tablet.Hostname
|
||||
}
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -230,7 +230,7 @@ func (dialer *cachedConnDialer) pollOnce(ctx context.Context, addr string) (clie
|
|||
// It returns the three-tuple of client-interface, closer, and error that the
|
||||
// main dial func returns.
|
||||
func (dialer *cachedConnDialer) newdial(ctx context.Context, addr string) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) {
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
dialer.connWaitSema.Release()
|
||||
return nil, nil, err
|
||||
|
|
|
@ -48,6 +48,7 @@ var (
|
|||
cert = flag.String("tablet_manager_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("tablet_manager_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("tablet_manager_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("tablet_manager_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("tablet_manager_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -111,7 +112,7 @@ func NewClient() *Client {
|
|||
// dial returns a client to use
|
||||
func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, io.Closer, error) {
|
||||
addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -125,7 +126,7 @@ func (client *grpcClient) dial(ctx context.Context, tablet *topodatapb.Tablet) (
|
|||
|
||||
func (client *grpcClient) dialPool(ctx context.Context, tablet *topodatapb.Tablet) (tabletmanagerservicepb.TabletManagerClient, error) {
|
||||
addr := netutil.JoinHostPort(tablet.Hostname, int32(tablet.PortMap["grpc"]))
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
Copyright 2021 The Vitess Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package vttls
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
|
||||
"vitess.io/vitess/go/vt/log"
|
||||
)
|
||||
|
||||
type verifyPeerCertificateFunc func([][]byte, [][]*x509.Certificate) error
|
||||
|
||||
func certIsRevoked(cert *x509.Certificate, crl *pkix.CertificateList) bool {
|
||||
if crl.HasExpired(time.Now()) {
|
||||
log.Warningf("The current Certificate Revocation List (CRL) is past expiry date and must be updated. Revoked certificates will still be rejected in this state.")
|
||||
}
|
||||
|
||||
for _, revoked := range crl.TBSCertList.RevokedCertificates {
|
||||
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func verifyPeerCertificateAgainstCRL(crl string) (verifyPeerCertificateFunc, error) {
|
||||
crlSet, err := loadCRLSet(crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
for _, chain := range verifiedChains {
|
||||
for i := 0; i < len(chain)-1; i++ {
|
||||
cert := chain[i]
|
||||
issuerCert := chain[i+1]
|
||||
for _, crl := range crlSet {
|
||||
if issuerCert.CheckCRLSignature(crl) == nil {
|
||||
if certIsRevoked(cert, crl) {
|
||||
return fmt.Errorf("Certificate revoked: CommonName=%v", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadCRLSet(crl string) ([]*pkix.CertificateList, error) {
|
||||
body, err := ioutil.ReadFile(crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
crlSet := make([]*pkix.CertificateList, 0)
|
||||
for len(body) > 0 {
|
||||
var block *pem.Block
|
||||
block, body = pem.Decode(body)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
if block.Type != "X509 CRL" {
|
||||
continue
|
||||
}
|
||||
|
||||
parsedCRL, err := x509.ParseCRL(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
crlSet = append(crlSet, parsedCRL)
|
||||
}
|
||||
return crlSet, nil
|
||||
}
|
|
@ -128,7 +128,7 @@ var onceByKeys = sync.Map{}
|
|||
|
||||
// ClientConfig returns the TLS config to use for a client to
|
||||
// connect to a server with the provided parameters.
|
||||
func ClientConfig(mode SslMode, cert, key, ca, name string, minTLSVersion uint16) (*tls.Config, error) {
|
||||
func ClientConfig(mode SslMode, cert, key, ca, crl, name string, minTLSVersion uint16) (*tls.Config, error) {
|
||||
config := newTLSConfig(minTLSVersion)
|
||||
|
||||
// Load the client-side cert & key if any.
|
||||
|
@ -190,12 +190,20 @@ func ClientConfig(mode SslMode, cert, key, ca, name string, minTLSVersion uint16
|
|||
return nil, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "invalid mode: %s", mode)
|
||||
}
|
||||
|
||||
if crl != "" {
|
||||
crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.VerifyPeerCertificate = crlFunc
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// ServerConfig returns the TLS config to use for a server to
|
||||
// accept client connections.
|
||||
func ServerConfig(cert, key, ca, serverCA string, minTLSVersion uint16) (*tls.Config, error) {
|
||||
func ServerConfig(cert, key, ca, crl, serverCA string, minTLSVersion uint16) (*tls.Config, error) {
|
||||
config := newTLSConfig(minTLSVersion)
|
||||
|
||||
var certificates *[]tls.Certificate
|
||||
|
@ -225,6 +233,14 @@ func ServerConfig(cert, key, ca, serverCA string, minTLSVersion uint16) (*tls.Co
|
|||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
if crl != "" {
|
||||
crlFunc, err := verifyPeerCertificateAgainstCRL(crl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.VerifyPeerCertificate = crlFunc
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ var (
|
|||
cert = flag.String("vtworker_client_grpc_cert", "", "the cert to use to connect")
|
||||
key = flag.String("vtworker_client_grpc_key", "", "the key to use to connect")
|
||||
ca = flag.String("vtworker_client_grpc_ca", "", "the server ca to use to validate servers when connecting")
|
||||
crl = flag.String("vtworker_client_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
|
||||
name = flag.String("vtworker_client_grpc_server_name", "", "the server name to use to validate server certificate")
|
||||
)
|
||||
|
||||
|
@ -49,7 +50,7 @@ type gRPCVtworkerClient struct {
|
|||
|
||||
func gRPCVtworkerClientFactory(addr string) (vtworkerclient.Client, error) {
|
||||
// create the RPC client
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
|
||||
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче