585 строки
21 KiB
Go
585 строки
21 KiB
Go
/*
|
|
*
|
|
* Copyright 2020 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package xds
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/credentials/tls/certprovider"
|
|
"google.golang.org/grpc/internal"
|
|
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
|
|
"google.golang.org/grpc/internal/grpctest"
|
|
"google.golang.org/grpc/internal/testutils"
|
|
"google.golang.org/grpc/resolver"
|
|
"google.golang.org/grpc/testdata"
|
|
)
|
|
|
|
const (
|
|
defaultTestTimeout = 10 * time.Second
|
|
defaultTestShortTimeout = 10 * time.Millisecond
|
|
defaultTestCertSAN = "*.test.example.com"
|
|
authority = "authority"
|
|
)
|
|
|
|
type s struct {
|
|
grpctest.Tester
|
|
}
|
|
|
|
func Test(t *testing.T) {
|
|
grpctest.RunSubTests(t, s{})
|
|
}
|
|
|
|
// makeFallbackClientCreds returns real TLS client credentials for use as the
// fallback credentials in multiple tests. The credentials trust the server CA
// certificate from testdata and use "x.test.example.com" as the server name
// override. The calling test is failed immediately if the credentials cannot
// be created.
func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
	// Mark this as a helper so that a failure is reported against the line in
	// the calling test rather than against this function.
	t.Helper()
	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
	if err != nil {
		t.Fatal(err)
	}
	return creds
}
|
|
|
|
// testServer is a no-op server which listens on a local TCP port for incoming
|
|
// connections, and performs a manual TLS handshake on the received raw
|
|
// connection using a user specified handshake function. It then makes the
|
|
// result of the handshake operation available through a channel for tests to
|
|
// inspect. Tests should stop the testServer as part of their cleanup.
|
|
type testServer struct {
|
|
lis net.Listener
|
|
address string // Listening address of the test server.
|
|
handshakeFunc testHandshakeFunc // Test specified handshake function.
|
|
hsResult *testutils.Channel // Channel to deliver handshake results.
|
|
}
|
|
|
|
// handshakeResult wraps the result of the handshake operation on the test
|
|
// server. It consists of TLS connection state and an error, if the handshake
|
|
// failed. This result is delivered on the `hsResult` channel on the testServer.
|
|
type handshakeResult struct {
|
|
connState tls.ConnectionState
|
|
err error
|
|
}
|
|
|
|
// Configurable handshake function for the testServer. Tests can set this to
|
|
// simulate different conditions like handshake success, failure, timeout etc.
|
|
type testHandshakeFunc func(net.Conn) handshakeResult
|
|
|
|
// newTestServerWithHandshakeFunc creates a testServer that uses f to perform
// the server side of the TLS handshake on every accepted connection, and starts
// it listening on a local TCP port.
//
// It panics if the server cannot start listening. Without a listener, the
// returned testServer would have an empty address and a nil listener. That
// would only surface later, as a confusing dial failure or as a nil-pointer
// panic in stop().
func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
	ts := &testServer{
		handshakeFunc: f,
		hsResult:      testutils.NewChannel(),
	}
	if err := ts.start(); err != nil {
		panic(fmt.Sprintf("testServer failed to start listening: %v", err))
	}
	return ts
}
|
|
|
|
// start opens a listener on an ephemeral localhost TCP port, records its
// address, and launches the goroutine that serves incoming connections. If
// net.Listen fails, its error is returned and nothing is started.
func (ts *testServer) start() error {
	listener, err := net.Listen("tcp", "localhost:0")
	if err == nil {
		ts.lis = listener
		ts.address = listener.Addr().String()
		go ts.handleConn()
	}
	return err
}
|
|
|
|
// handleconn accepts a new raw connection, and invokes the test provided
|
|
// handshake function to perform TLS handshake, and returns the result on the
|
|
// `hsResult` channel.
|
|
func (ts *testServer) handleConn() {
|
|
for {
|
|
rawConn, err := ts.lis.Accept()
|
|
if err != nil {
|
|
// Once the listeners closed, Accept() will return with an error.
|
|
return
|
|
}
|
|
hsr := ts.handshakeFunc(rawConn)
|
|
ts.hsResult.Send(hsr)
|
|
}
|
|
}
|
|
|
|
// stop closes the associated listener which causes the connection handling
|
|
// goroutine to exit.
|
|
func (ts *testServer) stop() {
|
|
ts.lis.Close()
|
|
}
|
|
|
|
// testServerTLSHandshake is a testHandshakeFunc that performs a server-side TLS
// handshake presenting the server1 certificate from testdata. The tls.Config
// does not set ClientAuth, so the server never asks the client for a
// certificate. This simulates a successful handshake without client
// authentication.
func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
	serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
	if err != nil {
		return handshakeResult{err: err}
	}
	tlsConn := tls.Server(rawConn, &tls.Config{
		Certificates: []tls.Certificate{serverCert},
	})
	if err = tlsConn.Handshake(); err != nil {
		return handshakeResult{err: err}
	}
	return handshakeResult{connState: tlsConn.ConnectionState()}
}
|
|
|
|
// testServerMutualTLSHandshake is a testHandshakeFunc intended to simulate a
// successful handshake with mutual authentication. The server presents the
// server1 certificate from testdata and uses the client CA certificate from
// testdata as its pool of client roots.
//
// NOTE(review): tls.Config.ClientAuth is left at its zero value
// (tls.NoClientCert). As a result, the server never requests a client
// certificate and ClientCAs is not used for verification. Turning this into a
// real mTLS handshake (e.g. tls.RequireAndVerifyClientCert) would also require
// the "mTLS" test cases to present an identity certificate signed by
// client_ca_cert.pem (they currently use server1_cert.pem). Confirm before
// changing.
func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
	fail := func(err error) handshakeResult { return handshakeResult{err: err} }

	serverCert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
	if err != nil {
		return fail(err)
	}
	clientCAPEM, err := ioutil.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
	if err != nil {
		return fail(err)
	}
	clientRoots := x509.NewCertPool()
	clientRoots.AppendCertsFromPEM(clientCAPEM)

	tlsConn := tls.Server(rawConn, &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		ClientCAs:    clientRoots,
	})
	if err := tlsConn.Handshake(); err != nil {
		return fail(err)
	}
	return handshakeResult{connState: tlsConn.ConnectionState()}
}
|
|
|
|
// fakeProvider is an implementation of the certprovider.Provider interface
|
|
// which returns the configured key material and error in calls to
|
|
// KeyMaterial().
|
|
type fakeProvider struct {
|
|
km *certprovider.KeyMaterial
|
|
err error
|
|
}
|
|
|
|
func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
|
|
return f.km, f.err
|
|
}
|
|
|
|
func (f *fakeProvider) Close() {}
|
|
|
|
// makeIdentityProvider returns a fakeProvider whose key material holds only the
// identity certificate and private key loaded from the given testdata-relative
// paths. The calling test is failed immediately if the key pair cannot be
// loaded.
func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
	t.Helper()
	identity, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
	if err != nil {
		t.Fatal(err)
	}
	km := &certprovider.KeyMaterial{Certs: []tls.Certificate{identity}}
	return &fakeProvider{km: km}
}
|
|
|
|
// makeRootProvider returns a fakeProvider whose key material holds only a root
// certificate pool built from the PEM file at the given testdata-relative path.
// The calling test is failed immediately if the file cannot be read or if it
// contains no parseable certificate.
func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
	// Attribute failures to the caller's line rather than to this helper.
	t.Helper()
	pemData, err := ioutil.ReadFile(testdata.Path(caPath))
	if err != nil {
		t.Fatal(err)
	}
	roots := x509.NewCertPool()
	// AppendCertsFromPEM reports false when no certificate could be parsed.
	// Continuing with an empty pool would make later handshakes fail for the
	// wrong reason, so fail loudly here instead.
	if !roots.AppendCertsFromPEM(pemData) {
		t.Fatalf("failed to parse any root certificate from %q", caPath)
	}
	return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
}
|
|
|
|
// newTestContextWithHandshakeInfo derives a context from parent that carries a
// HandshakeInfo built from the given root and identity providers and accepted
// SANs. The HandshakeInfo is attached the same way the tests expect
// ClientHandshake() to find it.
func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sans ...string) context.Context {
	// The CDS balancer does something very similar when it intercepts
	// NewSubConn(): it builds a HandshakeInfo and stores it in the address
	// attributes.
	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, xdsinternal.NewHandshakeInfo(root, identity, sans...))

	// Normally the transport layer moves the address attributes into the
	// context handed to the handshaker. These tests call the handshaker
	// directly, so that step is replicated here through the internal hook.
	newCtx := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
	chi := credentials.ClientHandshakeInfo{Attributes: addr.Attributes}
	return newCtx(parent, chi)
}
|
|
|
|
// compareAuthInfo checks the AuthInfo returned by a successful client-side
// handshake against the handshake result published by ts. It returns an error
// in any of these cases:
//   - ai is not TLS AuthInfo;
//   - no result arrives from ts before ctx is done;
//   - the server-side handshake failed;
//   - the two connection states disagree on a field checked by
//     compareConnState.
func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
	if authType := ai.AuthType(); authType != "tls" {
		return fmt.Errorf("ClientHandshake returned authType %q, want %q", authType, "tls")
	}
	tlsInfo, ok := ai.(credentials.TLSInfo)
	if !ok {
		return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
	}

	// Fetch what the server observed for the same handshake.
	val, err := ts.hsResult.Receive(ctx)
	if err != nil {
		return fmt.Errorf("testServer failed to return handshake result: %v", err)
	}
	serverResult := val.(handshakeResult)
	if serverResult.err != nil {
		return fmt.Errorf("testServer handshake failure: %v", serverResult.err)
	}

	// AuthInfo contains a lot of information. Only the subset that the TLS
	// credentials tests verify is compared here.
	return compareConnState(tlsInfo.State, serverResult.connState)
}
|
|
|
|
// compareConnState compares a subset of two TLS connection states: Version,
// HandshakeComplete, CipherSuite and NegotiatedProtocol, checked in that
// order. It returns an error describing the first mismatching field, or nil if
// all four fields match. All other fields are ignored.
func compareConnState(got, want tls.ConnectionState) error {
	if got.Version != want.Version {
		return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
	}
	if got.HandshakeComplete != want.HandshakeComplete {
		return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
	}
	if got.CipherSuite != want.CipherSuite {
		return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
	}
	if got.NegotiatedProtocol != want.NegotiatedProtocol {
		return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
	}
	return nil
}
|
|
|
|
// TestClientCredsWithoutFallback verifies that NewClientCredentials() returns
// an error when ClientOptions does not specify fallback credentials.
func (s) TestClientCredsWithoutFallback(t *testing.T) {
	_, err := NewClientCredentials(ClientOptions{})
	if err == nil {
		t.Fatal("NewClientCredentials() succeeded without specifying fallback")
	}
}
|
|
|
|
// TestClientCredsInvalidHandshakeInfo verifies that ClientHandshake() fails
// when the HandshakeInfo in the context is invalid because it has no root
// certificate provider (only an identity provider is set).
func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
	creds, err := NewClientCredentials(opts)
	if err != nil {
		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	// A nil root provider paired with a non-nil identity provider.
	ctx = newTestContextWithHandshakeInfo(ctx, nil, &fakeProvider{})
	if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
		t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
	}
}
|
|
|
|
// TestClientCredsProviderFailure verifies that ClientHandshake() fails, and
// surfaces the provider's error, when a certificate provider in the
// HandshakeInfo returns an error from KeyMaterial(). Both the root and the
// identity provider cases are covered.
func (s) TestClientCredsProviderFailure(t *testing.T) {
	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
	creds, err := NewClientCredentials(opts)
	if err != nil {
		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
	}

	for _, tc := range []struct {
		desc             string
		rootProvider     certprovider.Provider
		identityProvider certprovider.Provider
		wantErr          string
	}{
		{
			desc:         "erroring root provider",
			rootProvider: &fakeProvider{err: errors.New("root provider error")},
			wantErr:      "root provider error",
		},
		{
			desc:             "erroring identity provider",
			rootProvider:     &fakeProvider{km: &certprovider.KeyMaterial{}},
			identityProvider: &fakeProvider{err: errors.New("identity provider error")},
			wantErr:          "identity provider error",
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			ctx = newTestContextWithHandshakeInfo(ctx, tc.rootProvider, tc.identityProvider)
			_, _, err := creds.ClientHandshake(ctx, authority, nil)
			if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
				t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, tc.wantErr)
			}
		})
	}
}
|
|
|
|
// TestClientCredsSuccess verifies successful client handshakes in four cases:
// delegation to the fallback credentials (no HandshakeInfo in the context),
// TLS with a root provider, and two cases with both a root and an identity
// provider configured (with and without accepted SANs). Each successful
// handshake's AuthInfo is compared with what the test server observed.
func (s) TestClientCredsSuccess(t *testing.T) {
	tests := []struct {
		desc          string
		handshakeFunc testHandshakeFunc
		// handshakeInfoCtx returns the context passed to ClientHandshake().
		// It receives the running subtest's *testing.T, so that helpers which
		// call t.Fatal fail and stop that subtest. Calling them on the parent
		// test from inside a subtest would invoke FailNow on the wrong
		// *testing.T.
		handshakeInfoCtx func(ctx context.Context, t *testing.T) context.Context
	}{
		{
			desc:          "fallback",
			handshakeFunc: testServerTLSHandshake,
			handshakeInfoCtx: func(ctx context.Context, _ *testing.T) context.Context {
				// Since we don't add a HandshakeInfo to the context, the
				// ClientHandshake() method will delegate to the fallback.
				return ctx
			},
		},
		{
			desc:          "TLS",
			handshakeFunc: testServerTLSHandshake,
			handshakeInfoCtx: func(ctx context.Context, t *testing.T) context.Context {
				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
			},
		},
		{
			desc:          "mTLS",
			handshakeFunc: testServerMutualTLSHandshake,
			handshakeInfoCtx: func(ctx context.Context, t *testing.T) context.Context {
				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
			},
		},
		{
			desc:          "mTLS with no acceptedSANs specified",
			handshakeFunc: testServerMutualTLSHandshake,
			handshakeInfoCtx: func(ctx context.Context, t *testing.T) context.Context {
				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
			},
		},
	}

	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
			defer ts.stop()

			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
			creds, err := NewClientCredentials(opts)
			if err != nil {
				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
			}

			conn, err := net.Dial("tcp", ts.address)
			if err != nil {
				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
			}
			defer conn.Close()

			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			_, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx, t), authority, conn)
			if err != nil {
				t.Fatalf("ClientHandshake() returned failed: %q", err)
			}
			if err := compareAuthInfo(ctx, ts, ai); err != nil {
				t.Fatal(err)
			}
		})
	}
}
|
|
|
|
// TestClientCredsHandshakeTimeout verifies that ClientHandshake() fails when the
// server never sends any handshake data and the context deadline expires. The
// test server's handshake function blocks until the client-side handshake has
// returned and then reports a sentinel error. The test checks that this error
// is what the server delivers.
func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
	// clientDone is closed once the client-side handshake has returned. Until
	// then, the server-side handshake function blocks on it without writing any
	// handshake data, which makes the client-side handshake time out.
	clientDone := make(chan struct{})
	hErr := errors.New("server handshake error")
	ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
		<-clientDone
		return handshakeResult{err: hErr}
	})
	defer ts.stop()

	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
	creds, err := NewClientCredentials(opts)
	if err != nil {
		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
	}
	// Build the root provider before dialing. Once the connection is accepted,
	// the server goroutine blocks until clientDone is closed, so nothing that
	// can call t.Fatal should run between the dial and that close.
	rootProvider := makeRootProvider(t, "x509/server_ca_cert.pem")

	conn, err := net.Dial("tcp", ts.address)
	if err != nil {
		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
	}
	defer conn.Close()

	sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
	defer sCancel()
	ctx := newTestContextWithHandshakeInfo(sCtx, rootProvider, nil, defaultTestCertSAN)
	_, _, err = creds.ClientHandshake(ctx, authority, conn)
	// Unblock the server-side handshake before checking the result. Otherwise
	// the server goroutine would leak if the check below fails the test.
	close(clientDone)
	if err == nil {
		t.Fatal("ClientHandshake() succeeded when expected to timeout")
	}

	// Read the handshake result from the testServer and make sure the expected
	// error is returned.
	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	val, err := ts.hsResult.Receive(ctx)
	if err != nil {
		t.Fatalf("testServer failed to return handshake result: %v", err)
	}
	hsr := val.(handshakeResult)
	if !errors.Is(hsr.err, hErr) {
		t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
	}
}
|
|
|
|
// TestClientCredsHandshakeFailure verifies that ClientHandshake() fails with
// the expected error in two cases: when the server certificate cannot be
// validated against the configured roots, and when it does not match the
// accepted SANs.
func (s) TestClientCredsHandshakeFailure(t *testing.T) {
	for _, tc := range []struct {
		desc          string
		handshakeFunc testHandshakeFunc
		rootProvider  certprovider.Provider
		san           string
		wantErr       string
	}{
		{
			desc:          "cert validation failure",
			handshakeFunc: testServerTLSHandshake,
			rootProvider:  makeRootProvider(t, "x509/client_ca_cert.pem"),
			san:           defaultTestCertSAN,
			wantErr:       "x509: certificate signed by unknown authority",
		},
		{
			desc:          "SAN mismatch",
			handshakeFunc: testServerTLSHandshake,
			rootProvider:  makeRootProvider(t, "x509/server_ca_cert.pem"),
			san:           "bad-san",
			wantErr:       "does not match any of the accepted SANs",
		},
	} {
		t.Run(tc.desc, func(t *testing.T) {
			ts := newTestServerWithHandshakeFunc(tc.handshakeFunc)
			defer ts.stop()

			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
			creds, err := NewClientCredentials(opts)
			if err != nil {
				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
			}

			conn, err := net.Dial("tcp", ts.address)
			if err != nil {
				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
			}
			defer conn.Close()

			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
			defer cancel()
			hsCtx := newTestContextWithHandshakeInfo(ctx, tc.rootProvider, nil, tc.san)
			_, _, err = creds.ClientHandshake(hsCtx, authority, conn)
			if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
				t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, tc.wantErr)
			}
		})
	}
}
|
|
|
|
// TestClientCredsProviderSwitch verifies the case where the first attempt of
// ClientHandshake fails because of a handshake failure. The root certificate
// provider in the same HandshakeInfo is then replaced, and the second attempt
// succeeds. This approximates the flow of events when the control plane
// specifies new security config, which results in new certificate providers
// being used.
func (s) TestClientCredsProviderSwitch(t *testing.T) {
	ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
	defer ts.stop()

	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
	creds, err := NewClientCredentials(opts)
	if err != nil {
		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
	}

	conn1, err := net.Dial("tcp", ts.address)
	if err != nil {
		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
	}
	defer conn1.Close()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	// Create a root provider which will fail the handshake because it does not
	// use the correct trust roots.
	root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
	handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, defaultTestCertSAN)

	// We need to repeat most of what newTestContextWithHandshakeInfo() does
	// here because we need access to the underlying HandshakeInfo so that we
	// can update it before the next call to ClientHandshake().
	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
	contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
	ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
	if _, _, err := creds.ClientHandshake(ctx, authority, conn1); err == nil {
		t.Fatal("ClientHandshake() succeeded when expected to fail")
	}
	// Drain the result channel on the test server so that the result of the
	// next handshake can be inspected. If no result arrives, the same ctx is
	// used for the rest of the test, so the remaining steps would only fail
	// with misleading errors. Stop here instead.
	if _, err := ts.hsResult.Receive(ctx); err != nil {
		t.Fatalf("testServer failed to return handshake result: %v", err)
	}

	conn2, err := net.Dial("tcp", ts.address)
	if err != nil {
		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
	}
	defer conn2.Close()

	// Create a new root provider which uses the correct trust roots, and
	// update the HandshakeInfo with the new provider.
	root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
	handshakeInfo.SetRootCertProvider(root2)
	_, ai, err := creds.ClientHandshake(ctx, authority, conn2)
	if err != nil {
		t.Fatalf("ClientHandshake() returned failed: %q", err)
	}
	if err := compareAuthInfo(ctx, ts, ai); err != nil {
		t.Fatal(err)
	}
}
|
|
|
|
// TestClientClone verifies that Clone() on client credentials returns a new
// instance rather than the receiver itself.
func (s) TestClientClone(t *testing.T) {
	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
	orig, err := NewClientCredentials(opts)
	if err != nil {
		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
	}

	// credsImpl has no exported fields, and digging into it with cmp options
	// would not be meaningful. The only property checked is that the clone is
	// a different instance from the original.
	clone := orig.Clone()
	if clone == orig {
		t.Fatal("return value from Clone() doesn't point to new credentials instance")
	}
}
|