credentials/xds: ServerHandshake() implementation (#4089)
This commit is contained in:
Родитель
03d4b8878b
Коммит
17e2cbe887
|
@ -33,6 +33,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/attributes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -50,9 +51,9 @@ func init() {
|
|||
// credentials implementation.
|
||||
type ClientOptions struct {
|
||||
// FallbackCreds specifies the fallback credentials to be used when either
|
||||
// the `xds` scheme is not used in the user's dial target or when the xDS
|
||||
// server does not return any security configuration. Attempts to create
|
||||
// client credentials without a fallback credentials will fail.
|
||||
// the `xds` scheme is not used in the user's dial target or when the
|
||||
// management server does not return any security configuration. Attempts to
|
||||
// create client credentials without fallback credentials will fail.
|
||||
FallbackCreds credentials.TransportCredentials
|
||||
}
|
||||
|
||||
|
@ -68,6 +69,27 @@ func NewClientCredentials(opts ClientOptions) (credentials.TransportCredentials,
|
|||
}, nil
|
||||
}
|
||||
|
||||
// ServerOptions contains parameters to configure a new server-side xDS
|
||||
// credentials implementation.
|
||||
type ServerOptions struct {
|
||||
// FallbackCreds specifies the fallback credentials to be used when the
|
||||
// management server does not return any security configuration. Attempts to
|
||||
// create server credentials without fallback credentials will fail.
|
||||
FallbackCreds credentials.TransportCredentials
|
||||
}
|
||||
|
||||
// NewServerCredentials returns a new server-side transport credentials
|
||||
// implementation which uses xDS APIs to fetch its security configuration.
|
||||
func NewServerCredentials(opts ServerOptions) (credentials.TransportCredentials, error) {
|
||||
if opts.FallbackCreds == nil {
|
||||
return nil, errors.New("missing fallback credentials")
|
||||
}
|
||||
return &credsImpl{
|
||||
isClient: false,
|
||||
fallback: opts.FallbackCreds,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// credsImpl is an implementation of the credentials.TransportCredentials
|
||||
// interface which uses xDS APIs to fetch its security configuration.
|
||||
type credsImpl struct {
|
||||
|
@ -98,11 +120,15 @@ func getHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
|
|||
// responsible for populating these fields.
|
||||
//
|
||||
// Safe for concurrent access.
|
||||
//
|
||||
// TODO(easwars): Move this type and any other non-user functionality to an
|
||||
// internal package.
|
||||
type HandshakeInfo struct {
|
||||
mu sync.Mutex
|
||||
rootProvider certprovider.Provider
|
||||
identityProvider certprovider.Provider
|
||||
acceptedSANs map[string]bool // Only on the client side.
|
||||
mu sync.Mutex
|
||||
rootProvider certprovider.Provider
|
||||
identityProvider certprovider.Provider
|
||||
acceptedSANs map[string]bool // Only on the client side.
|
||||
requireClientCert bool // Only on server side.
|
||||
}
|
||||
|
||||
// SetRootCertProvider updates the root certificate provider.
|
||||
|
@ -129,6 +155,14 @@ func (hi *HandshakeInfo) SetAcceptedSANs(sans []string) {
|
|||
hi.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetRequireClientCert updates whether a client cert is required during the
|
||||
// ServerHandshake(). A value of true indicates that we are performing mTLS.
|
||||
func (hi *HandshakeInfo) SetRequireClientCert(require bool) {
|
||||
hi.mu.Lock()
|
||||
hi.requireClientCert = require
|
||||
hi.mu.Unlock()
|
||||
}
|
||||
|
||||
// UseFallbackCreds returns true when fallback credentials are to be used based
|
||||
// on the contents of the HandshakeInfo.
|
||||
func (hi *HandshakeInfo) UseFallbackCreds() bool {
|
||||
|
@ -141,27 +175,13 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
|
|||
return hi.identityProvider == nil && hi.rootProvider == nil
|
||||
}
|
||||
|
||||
func (hi *HandshakeInfo) validate(isClient bool) error {
|
||||
func (hi *HandshakeInfo) makeClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
|
||||
hi.mu.Lock()
|
||||
defer hi.mu.Unlock()
|
||||
|
||||
// On the client side, rootProvider is mandatory. IdentityProvider is
|
||||
// optional based on whether the client is doing TLS or mTLS.
|
||||
if isClient && hi.rootProvider == nil {
|
||||
return errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server")
|
||||
if hi.rootProvider == nil {
|
||||
return nil, errors.New("xds: CertificateProvider to fetch trusted roots is missing, cannot perform TLS handshake. Please check configuration on the management server")
|
||||
}
|
||||
|
||||
// On the server side, identityProvider is mandatory. RootProvider is
|
||||
// optional based on whether the server is doing TLS or mTLS.
|
||||
if !isClient && hi.identityProvider == nil {
|
||||
return errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error) {
|
||||
hi.mu.Lock()
|
||||
// Since the call to KeyMaterial() can block, we read the providers under
|
||||
// the lock but call the actual function after releasing the lock.
|
||||
rootProv, idProv := hi.rootProvider, hi.identityProvider
|
||||
|
@ -173,13 +193,13 @@ func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error)
|
|||
// includes hostname verification) or none. We are forced to go with the
|
||||
// latter and perform the normal cert validation ourselves.
|
||||
cfg := &tls.Config{InsecureSkipVerify: true}
|
||||
if rootProv != nil {
|
||||
km, err := rootProv.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err)
|
||||
}
|
||||
cfg.RootCAs = km.Roots
|
||||
|
||||
km, err := rootProv.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err)
|
||||
}
|
||||
cfg.RootCAs = km.Roots
|
||||
|
||||
if idProv != nil {
|
||||
km, err := idProv.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
|
@ -190,6 +210,39 @@ func (hi *HandshakeInfo) makeTLSConfig(ctx context.Context) (*tls.Config, error)
|
|||
return cfg, nil
|
||||
}
|
||||
|
||||
func (hi *HandshakeInfo) makeServerSideTLSConfig(ctx context.Context) (*tls.Config, error) {
|
||||
cfg := &tls.Config{ClientAuth: tls.NoClientCert}
|
||||
hi.mu.Lock()
|
||||
// On the server side, identityProvider is mandatory. RootProvider is
|
||||
// optional based on whether the server is doing TLS or mTLS.
|
||||
if hi.identityProvider == nil {
|
||||
return nil, errors.New("xds: CertificateProvider to fetch identity certificate is missing, cannot perform TLS handshake. Please check configuration on the management server")
|
||||
}
|
||||
// Since the call to KeyMaterial() can block, we read the providers under
|
||||
// the lock but call the actual function after releasing the lock.
|
||||
rootProv, idProv := hi.rootProvider, hi.identityProvider
|
||||
if hi.requireClientCert {
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
hi.mu.Unlock()
|
||||
|
||||
// identityProvider is mandatory on the server side.
|
||||
km, err := idProv.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xds: fetching identity certificates from CertificateProvider failed: %v", err)
|
||||
}
|
||||
cfg.Certificates = km.Certs
|
||||
|
||||
if rootProv != nil {
|
||||
km, err := rootProv.KeyMaterial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xds: fetching trusted roots from CertificateProvider failed: %v", err)
|
||||
}
|
||||
cfg.ClientCAs = km.Roots
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (hi *HandshakeInfo) matchingSANExists(cert *x509.Certificate) bool {
|
||||
if len(hi.acceptedSANs) == 0 {
|
||||
// An empty list of acceptedSANs means "accept everything".
|
||||
|
@ -265,9 +318,6 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
|
|||
if hi.UseFallbackCreds() {
|
||||
return c.fallback.ClientHandshake(ctx, authority, rawConn)
|
||||
}
|
||||
if err := hi.validate(c.isClient); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// We build the tls.Config with the following values
|
||||
// 1. Root certificate as returned by the root provider.
|
||||
|
@ -281,7 +331,7 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
|
|||
// 4. Key usage to match whether client/server usage.
|
||||
// 5. A `VerifyPeerCertificate` function which performs normal peer
|
||||
// cert verification using configured roots, and the custom SAN checks.
|
||||
cfg, err := hi.makeTLSConfig(ctx)
|
||||
cfg, err := hi.makeClientSideTLSConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -349,12 +399,55 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
|
|||
}
|
||||
|
||||
// ServerHandshake performs the TLS handshake on the server-side.
|
||||
func (c *credsImpl) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
func (c *credsImpl) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
if c.isClient {
|
||||
return nil, nil, errors.New("ServerHandshake is not supported for client credentials")
|
||||
}
|
||||
// TODO(easwars): Implement along with server side xDS implementation.
|
||||
return nil, nil, errors.New("not implemented")
|
||||
|
||||
// An xds-enabled gRPC server wraps the underlying raw net.Conn in a type
|
||||
// that provides a way to retrieve `HandshakeInfo`, which contains the
|
||||
// certificate providers to be used during the handshake. If the net.Conn
|
||||
// passed to this function does not implement this interface, or if the
|
||||
// `HandshakeInfo` does not contain the information we are looking for, we
|
||||
// delegate the handshake to the fallback credentials.
|
||||
hiConn, ok := rawConn.(interface{ XDSHandshakeInfo() *HandshakeInfo })
|
||||
if !ok {
|
||||
return c.fallback.ServerHandshake(rawConn)
|
||||
}
|
||||
hi := hiConn.XDSHandshakeInfo()
|
||||
if hi.UseFallbackCreds() {
|
||||
return c.fallback.ServerHandshake(rawConn)
|
||||
}
|
||||
|
||||
// An xds-enabled gRPC server is expected to wrap the underlying raw
|
||||
// net.Conn in a type which provides a way to retrieve the deadline set on
|
||||
// it. If we cannot retrieve the deadline here, we fail (by setting deadline
|
||||
// to time.Now()), instead of using a default deadline and possibly taking
|
||||
// longer to eventually fail.
|
||||
deadline := time.Now()
|
||||
if dConn, ok := rawConn.(interface{ GetDeadline() time.Time }); ok {
|
||||
deadline = dConn.GetDeadline()
|
||||
}
|
||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||
defer cancel()
|
||||
cfg, err := hi.makeServerSideTLSConfig(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conn := tls.Server(rawConn, cfg)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
conn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
info := credentials.TLSInfo{
|
||||
State: conn.ConnectionState(),
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{
|
||||
SecurityLevel: credentials.PrivacyAndIntegrity,
|
||||
},
|
||||
}
|
||||
info.SPIFFEID = credinternal.SPIFFEIDFromState(conn.ConnectionState())
|
||||
return credinternal.WrapSyscallConn(rawConn, conn), info, nil
|
||||
}
|
||||
|
||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||
|
|
|
@ -40,9 +40,10 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
defaultTestTimeout = 1 * time.Second
|
||||
defaultTestCertSAN = "*.test.example.com"
|
||||
authority = "authority"
|
||||
defaultTestTimeout = 10 * time.Second
|
||||
defaultTestShortTimeout = 10 * time.Millisecond
|
||||
defaultTestCertSAN = "*.test.example.com"
|
||||
authority = "authority"
|
||||
)
|
||||
|
||||
type s struct {
|
||||
|
@ -133,17 +134,6 @@ func (ts *testServer) stop() {
|
|||
ts.lis.Close()
|
||||
}
|
||||
|
||||
// A handshake function which simulates a handshake timeout. Tests usually pass
|
||||
// `defaultTestTimeout` to the ClientHandshake() method. This function just
|
||||
// hangs around for twice that duration, thus making sure that the context
|
||||
// passes to the credentials code times out.
|
||||
func testServerTLSHandshakeTimeout(_ net.Conn) handshakeResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*defaultTestTimeout)
|
||||
<-ctx.Done()
|
||||
cancel()
|
||||
return handshakeResult{err: ctx.Err()}
|
||||
}
|
||||
|
||||
// A handshake function which simulates a successful handshake without client
|
||||
// authentication (server does not request for client certificate during the
|
||||
// handshake here).
|
||||
|
@ -239,7 +229,7 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
|
|||
|
||||
// compareAuthInfo compares the AuthInfo received on the client side after a
|
||||
// successful handshake with the authInfo available on the testServer.
|
||||
func compareAuthInfo(ts *testServer, ai credentials.AuthInfo) error {
|
||||
func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
|
||||
if ai.AuthType() != "tls" {
|
||||
return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
|
||||
}
|
||||
|
@ -251,8 +241,6 @@ func compareAuthInfo(ts *testServer, ai credentials.AuthInfo) error {
|
|||
|
||||
// Read the handshake result from the testServer which contains the TLS
|
||||
// connection state and compare it with the one received on the client-side.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("testServer failed to return handshake result: %v", err)
|
||||
|
@ -341,7 +329,7 @@ func (s) TestClientCredsProviderFailure(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider)
|
||||
if _, _, err := creds.ClientHandshake(ctx, authority, nil); !strings.Contains(err.Error(), test.wantErr) {
|
||||
if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
|
||||
t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
|
||||
}
|
||||
})
|
||||
|
@ -410,13 +398,59 @@ func (s) TestClientCredsSuccess(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("ClientHandshake() returned failed: %q", err)
|
||||
}
|
||||
if err := compareAuthInfo(ts, ai); err != nil {
|
||||
if err := compareAuthInfo(ctx, ts, ai); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
|
||||
clientDone := make(chan struct{})
|
||||
// A handshake function which simulates a handshake timeout from the
|
||||
// server-side by simply blocking on the client-side handshake to timeout
|
||||
// and not writing any handshake data.
|
||||
hErr := errors.New("server handshake error")
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
<-clientDone
|
||||
return handshakeResult{err: hErr}
|
||||
})
|
||||
defer ts.stop()
|
||||
|
||||
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
|
||||
creds, err := NewClientCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", ts.address)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
|
||||
defer sCancel()
|
||||
ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
|
||||
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
|
||||
t.Fatal("ClientHandshake() succeeded when expected to timeout")
|
||||
}
|
||||
close(clientDone)
|
||||
|
||||
// Read the handshake result from the testServer and make sure the expected
|
||||
// error is returned.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("testServer failed to return handshake result: %v", err)
|
||||
}
|
||||
hsr := val.(handshakeResult)
|
||||
if hsr.err != hErr {
|
||||
t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientCredsHandshakeFailure verifies different handshake failure cases.
|
||||
func (s) TestClientCredsHandshakeFailure(t *testing.T) {
|
||||
tests := []struct {
|
||||
|
@ -433,13 +467,6 @@ func (s) TestClientCredsHandshakeFailure(t *testing.T) {
|
|||
san: defaultTestCertSAN,
|
||||
wantErr: "x509: certificate signed by unknown authority",
|
||||
},
|
||||
{
|
||||
desc: "handshake times out",
|
||||
handshakeFunc: testServerTLSHandshakeTimeout,
|
||||
rootProvider: makeRootProvider(t, "x509/server_ca_cert.pem"),
|
||||
san: defaultTestCertSAN,
|
||||
wantErr: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
desc: "SAN mismatch",
|
||||
handshakeFunc: testServerTLSHandshake,
|
||||
|
@ -534,13 +561,13 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("ClientHandshake() returned failed: %q", err)
|
||||
}
|
||||
if err := compareAuthInfo(ts, ai); err != nil {
|
||||
if err := compareAuthInfo(ctx, ts, ai); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClone verifies the Clone() method.
|
||||
func (s) TestClone(t *testing.T) {
|
||||
// TestClientClone verifies the Clone() method on client credentials.
|
||||
func (s) TestClientClone(t *testing.T) {
|
||||
opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
|
||||
orig, err := NewClientCredentials(opts)
|
||||
if err != nil {
|
||||
|
@ -549,7 +576,7 @@ func (s) TestClone(t *testing.T) {
|
|||
|
||||
// The credsImpl does not have any exported fields, and it does not make
|
||||
// sense to use any cmp options to look deep into. So, all we make sure here
|
||||
// is that the cloned object points to a different locaiton in memory.
|
||||
// is that the cloned object points to a different location in memory.
|
||||
if clone := orig.Clone(); clone == orig {
|
||||
t.Fatal("return value from Clone() doesn't point to new credentials instance")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,492 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2020 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package xds
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
"google.golang.org/grpc/testdata"
|
||||
)
|
||||
|
||||
func makeClientTLSConfig(t *testing.T, mTLS bool) *tls.Config {
|
||||
t.Helper()
|
||||
|
||||
pemData, err := ioutil.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
roots.AppendCertsFromPEM(pemData)
|
||||
|
||||
var certs []tls.Certificate
|
||||
if mTLS {
|
||||
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/client1_cert.pem"), testdata.Path("x509/client1_key.pem"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
certs = append(certs, cert)
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: certs,
|
||||
RootCAs: roots,
|
||||
ServerName: "*.test.example.com",
|
||||
// Setting this to true completely turns off the certificate validation
|
||||
// on the client side. So, the client side handshake always seems to
|
||||
// succeed. But if we want to turn this ON, we will need to generate
|
||||
// certificates which work with localhost, or supply a custom
|
||||
// verification function. So, the server credentials tests will rely
|
||||
// solely on the success/failure of the server-side handshake.
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a real TLS server credentials which is used as
|
||||
// fallback credentials from multiple tests.
|
||||
func makeFallbackServerCreds(t *testing.T) credentials.TransportCredentials {
|
||||
t.Helper()
|
||||
|
||||
creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
type errorCreds struct {
|
||||
credentials.TransportCredentials
|
||||
}
|
||||
|
||||
// TestServerCredsWithoutFallback verifies that the call to
|
||||
// NewServerCredentials() fails when no fallback is specified.
|
||||
func (s) TestServerCredsWithoutFallback(t *testing.T) {
|
||||
if _, err := NewServerCredentials(ServerOptions{}); err == nil {
|
||||
t.Fatal("NewServerCredentials() succeeded without specifying fallback")
|
||||
}
|
||||
}
|
||||
|
||||
type wrapperConn struct {
|
||||
net.Conn
|
||||
xdsHI *HandshakeInfo
|
||||
deadline time.Time
|
||||
}
|
||||
|
||||
func (wc *wrapperConn) XDSHandshakeInfo() *HandshakeInfo {
|
||||
return wc.xdsHI
|
||||
}
|
||||
|
||||
func (wc *wrapperConn) GetDeadline() time.Time {
|
||||
return wc.deadline
|
||||
}
|
||||
|
||||
func newWrappedConn(conn net.Conn, xdsHI *HandshakeInfo, deadline time.Time) *wrapperConn {
|
||||
return &wrapperConn{Conn: conn, xdsHI: xdsHI, deadline: deadline}
|
||||
}
|
||||
|
||||
// TestServerCredsInvalidHandshakeInfo verifies scenarios where the passed in
|
||||
// HandshakeInfo is invalid because it does not contain the expected certificate
|
||||
// providers.
|
||||
func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: &errorCreds{}}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
info := NewHandshakeInfo(&fakeProvider{}, nil)
|
||||
conn := newWrappedConn(nil, info, time.Time{})
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil {
|
||||
t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerCredsProviderFailure verifies the cases where an expected
|
||||
// certificate provider is missing in the HandshakeInfo value in the context.
|
||||
func (s) TestServerCredsProviderFailure(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: &errorCreds{}}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
rootProvider certprovider.Provider
|
||||
identityProvider certprovider.Provider
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
desc: "erroring identity provider",
|
||||
identityProvider: &fakeProvider{err: errors.New("identity provider error")},
|
||||
wantErr: "identity provider error",
|
||||
},
|
||||
{
|
||||
desc: "erroring root provider",
|
||||
identityProvider: &fakeProvider{km: &certprovider.KeyMaterial{}},
|
||||
rootProvider: &fakeProvider{err: errors.New("root provider error")},
|
||||
wantErr: "root provider error",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
info := NewHandshakeInfo(test.rootProvider, test.identityProvider)
|
||||
conn := newWrappedConn(nil, info, time.Time{})
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
|
||||
t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerCredsHandshakeTimeout verifies the case where the client does not
|
||||
// send required handshake data before the deadline set on the net.Conn passed
|
||||
// to ServerHandshake().
|
||||
func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: &errorCreds{}}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
// Create a test server which uses the xDS server credentials created above
|
||||
// to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
hi := NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo created
|
||||
// above with a very small deadline.
|
||||
d := time.Now().Add(defaultTestShortTimeout)
|
||||
rawConn.SetDeadline(d)
|
||||
conn := newWrappedConn(rawConn, hi, d)
|
||||
|
||||
// ServerHandshake() on the xDS credentials is expected to fail.
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil {
|
||||
return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to timeout")}
|
||||
}
|
||||
return handshakeResult{}
|
||||
})
|
||||
defer ts.stop()
|
||||
|
||||
// Dial the test server, but don't trigger the TLS handshake. This will
|
||||
// cause ServerHandshake() to fail.
|
||||
rawConn, err := net.Dial("tcp", ts.address)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
|
||||
}
|
||||
defer rawConn.Close()
|
||||
|
||||
// Read handshake result from the testServer and expect a failure result.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("testServer failed to return handshake result: %v", err)
|
||||
}
|
||||
hsr := val.(handshakeResult)
|
||||
if hsr.err != nil {
|
||||
t.Fatalf("testServer handshake failure: %v", hsr.err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerCredsHandshakeFailure verifies the case where the server-side
|
||||
// credentials uses a root certificate which does not match the certificate
|
||||
// presented by the client, and hence the handshake must fail.
|
||||
func (s) TestServerCredsHandshakeFailure(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: &errorCreds{}}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
// Create a test server which uses the xDS server credentials created above
|
||||
// to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
// Create a HandshakeInfo which has a root provider which does not match
|
||||
// the certificate sent by the client.
|
||||
hi := NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
// method.
|
||||
conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
|
||||
|
||||
// ServerHandshake() on the xDS credentials is expected to fail.
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil {
|
||||
return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")}
|
||||
}
|
||||
return handshakeResult{}
|
||||
})
|
||||
defer ts.stop()
|
||||
|
||||
// Dial the test server, and trigger the TLS handshake.
|
||||
rawConn, err := net.Dial("tcp", ts.address)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
|
||||
}
|
||||
defer rawConn.Close()
|
||||
tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true))
|
||||
tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read handshake result from the testServer which will return an error if
|
||||
// the handshake succeeded.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("testServer failed to return handshake result: %v", err)
|
||||
}
|
||||
hsr := val.(handshakeResult)
|
||||
if hsr.err != nil {
|
||||
t.Fatalf("testServer handshake failure: %v", hsr.err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerCredsHandshakeSuccess verifies success handshake cases.
|
||||
func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
fallbackCreds credentials.TransportCredentials
|
||||
rootProvider certprovider.Provider
|
||||
identityProvider certprovider.Provider
|
||||
requireClientCert bool
|
||||
}{
|
||||
{
|
||||
desc: "fallback",
|
||||
fallbackCreds: makeFallbackServerCreds(t),
|
||||
},
|
||||
{
|
||||
desc: "TLS",
|
||||
fallbackCreds: &errorCreds{},
|
||||
identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"),
|
||||
},
|
||||
{
|
||||
desc: "mTLS",
|
||||
fallbackCreds: &errorCreds{},
|
||||
identityProvider: makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"),
|
||||
rootProvider: makeRootProvider(t, "x509/client_ca_cert.pem"),
|
||||
requireClientCert: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
// Create an xDS server credentials.
|
||||
opts := ServerOptions{FallbackCreds: test.fallbackCreds}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
// Create a test server which uses the xDS server credentials
|
||||
// created above to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
// Create a HandshakeInfo with information from the test table.
|
||||
hi := NewHandshakeInfo(test.rootProvider, test.identityProvider)
|
||||
hi.SetRequireClientCert(test.requireClientCert)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
// method.
|
||||
conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
|
||||
|
||||
// Invoke the ServerHandshake() method on the xDS credentials
|
||||
// and make some sanity checks before pushing the result for
|
||||
// inspection by the main test body.
|
||||
_, ai, err := creds.ServerHandshake(conn)
|
||||
if err != nil {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)}
|
||||
}
|
||||
if ai.AuthType() != "tls" {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")}
|
||||
}
|
||||
info, ok := ai.(credentials.TLSInfo)
|
||||
if !ok {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})}
|
||||
}
|
||||
return handshakeResult{connState: info.State}
|
||||
})
|
||||
defer ts.stop()
|
||||
|
||||
// Dial the test server, and trigger the TLS handshake.
|
||||
rawConn, err := net.Dial("tcp", ts.address)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
|
||||
}
|
||||
defer rawConn.Close()
|
||||
tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, test.requireClientCert))
|
||||
tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the handshake result from the testServer which contains the
|
||||
// TLS connection state on the server-side and compare it with the
|
||||
// one received on the client-side.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("testServer failed to return handshake result: %v", err)
|
||||
}
|
||||
hsr := val.(handshakeResult)
|
||||
if hsr.err != nil {
|
||||
t.Fatalf("testServer handshake failure: %v", hsr.err)
|
||||
}
|
||||
|
||||
// AuthInfo contains a variety of information. We only verify a
|
||||
// subset here. This is the same subset which is verified in TLS
|
||||
// credentials tests.
|
||||
if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestServerCredsProviderSwitch(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: &errorCreds{}}
|
||||
creds, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
// The first time the handshake function is invoked, it returns a
|
||||
// HandshakeInfo which is expected to fail. Further invocations return a
|
||||
// HandshakeInfo which is expected to succeed.
|
||||
cnt := 0
|
||||
// Create a test server which uses the xDS server credentials created above
|
||||
// to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
cnt++
|
||||
var hi *HandshakeInfo
|
||||
if cnt == 1 {
|
||||
// Create a HandshakeInfo which has a root provider which does not match
|
||||
// the certificate sent by the client.
|
||||
hi = NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
// method.
|
||||
conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
|
||||
|
||||
// ServerHandshake() on the xDS credentials is expected to fail.
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil {
|
||||
return handshakeResult{err: errors.New("ServerHandshake() succeeded when expected to fail")}
|
||||
}
|
||||
return handshakeResult{}
|
||||
}
|
||||
|
||||
hi = NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
// method.
|
||||
conn := newWrappedConn(rawConn, hi, time.Now().Add(defaultTestTimeout))
|
||||
|
||||
// Invoke the ServerHandshake() method on the xDS credentials
|
||||
// and make some sanity checks before pushing the result for
|
||||
// inspection by the main test body.
|
||||
_, ai, err := creds.ServerHandshake(conn)
|
||||
if err != nil {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake() failed: %v", err)}
|
||||
}
|
||||
if ai.AuthType() != "tls" {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake returned authType %q, want %q", ai.AuthType(), "tls")}
|
||||
}
|
||||
info, ok := ai.(credentials.TLSInfo)
|
||||
if !ok {
|
||||
return handshakeResult{err: fmt.Errorf("ServerHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})}
|
||||
}
|
||||
return handshakeResult{connState: info.State}
|
||||
})
|
||||
defer ts.stop()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
// Dial the test server, and trigger the TLS handshake.
|
||||
rawConn, err := net.Dial("tcp", ts.address)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
|
||||
}
|
||||
defer rawConn.Close()
|
||||
tlsConn := tls.Client(rawConn, makeClientTLSConfig(t, true))
|
||||
tlsConn.SetDeadline(time.Now().Add(defaultTestTimeout))
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the handshake result from the testServer which contains the
|
||||
// TLS connection state on the server-side and compare it with the
|
||||
// one received on the client-side.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
val, err := ts.hsResult.Receive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("testServer failed to return handshake result: %v", err)
|
||||
}
|
||||
hsr := val.(handshakeResult)
|
||||
if hsr.err != nil {
|
||||
t.Fatalf("testServer handshake failure: %v", hsr.err)
|
||||
}
|
||||
if i == 0 {
|
||||
// We expect the first handshake to fail. So, we skip checks which
|
||||
// compare connection state.
|
||||
continue
|
||||
}
|
||||
// AuthInfo contains a variety of information. We only verify a
|
||||
// subset here. This is the same subset which is verified in TLS
|
||||
// credentials tests.
|
||||
if err := compareConnState(tlsConn.ConnectionState(), hsr.connState); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerClone verifies the Clone() method on client credentials.
|
||||
func (s) TestServerClone(t *testing.T) {
|
||||
opts := ServerOptions{FallbackCreds: makeFallbackServerCreds(t)}
|
||||
orig, err := NewServerCredentials(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
// The credsImpl does not have any exported fields, and it does not make
|
||||
// sense to use any cmp options to look deep into. So, all we make sure here
|
||||
// is that the cloned object points to a different location in memory.
|
||||
if clone := orig.Clone(); clone == orig {
|
||||
t.Fatal("return value from Clone() doesn't point to new credentials instance")
|
||||
}
|
||||
}
|
1
go.sum
1
go.sum
|
@ -43,6 +43,7 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
|
|||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
|
|
Загрузка…
Ссылка в новой задаче