diff --git a/credentials/credentials.go b/credentials/credentials.go index 576cf62e..7cfef6c2 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -86,21 +86,9 @@ type TransportAuthenticator interface { Credentials } -// tlsCreds is the credentials required for authenticating a connection. type tlsCreds struct { - // serverName is used to verify the hostname on the returned - // certificates. It is also included in the client's handshake - // to support virtual hosting. This is optional. If it is not - // set gRPC internals will use the dialing address instead. - serverName string - // rootCAs defines the set of root certificate authorities - // that clients use when verifying server certificates. - // If rootCAs is nil, tls uses the host's root CA set. - rootCAs *x509.CertPool - // certificates contains one or more certificate chains - // to present to the other side of the connection. - // Server configurations must include at least one certificate. - certificates []tls.Certificate + // TLS configuration + config tls.Config } // GetRequestMetadata returns nil, nil since TLS credentials does not have @@ -110,18 +98,13 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e } func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) { - name := c.serverName - if name == "" { - name, _, err = net.SplitHostPort(addr) + if c.config.ServerName == "" { + c.config.ServerName, _, err = net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("credentials: failed to parse server address %v", err) } } - return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{ - RootCAs: c.rootCAs, - NextProtos: alpnProtoStr, - ServerName: name, - }) + return tls.DialWithDialer(dialer, "tcp", addr, &c.config) } // Dial connects to addr and performs TLS handshake. @@ -132,18 +115,18 @@ func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) { // NewListener creates a net.Listener with a TLS configuration constructed // from the information in tlsCreds. func (c *tlsCreds) NewListener(lis net.Listener) net.Listener { - return tls.NewListener(lis, &tls.Config{ - Certificates: c.certificates, - NextProtos: alpnProtoStr, - }) + return tls.NewListener(lis, &c.config) +} + +func NewTLS(c *tls.Config) TransportAuthenticator { + tc := &tlsCreds{*c} + tc.config.NextProtos = alpnProtoStr + return tc } // NewClientTLSFromCert constructs a TLS from the input certificate for client. func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { - return &tlsCreds{ - serverName: serverName, - rootCAs: cp, - } + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) } // NewClientTLSFromFile constructs a TLS from the input certificate file for client. @@ -156,17 +139,12 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, if !cp.AppendCertsFromPEM(b) { return nil, fmt.Errorf("credentials: failed to append certificates") } - return &tlsCreds{ - serverName: serverName, - rootCAs: cp, - }, nil + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil } // NewServerTLSFromCert constructs a TLS from the input certificate for server. func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { - return &tlsCreds{ - certificates: []tls.Certificate{*cert}, - } + return NewTLS(&tls.Config{ Certificates: []tls.Certificate{*cert} }) } // NewServerTLSFromFile constructs a TLS from the input certificate file and key @@ -176,9 +154,7 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err if err != nil { return nil, err } - return &tlsCreds{ - certificates: []tls.Certificate{cert}, - }, nil + return NewTLS(&tls.Config{ Certificates: []tls.Certificate{cert} }), nil } // TokenSource supplies credentials from an oauth2.TokenSource.