Support timeout for grpc.Dial
This commit is contained in:
Родитель
575a9b2af8
Коммит
a5ca6e56d2
|
@ -36,6 +36,7 @@ package grpc
|
|||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -50,30 +51,35 @@ var (
|
|||
// ErrClientConnClosing indicates that the operation is illegal because
|
||||
// the session is closing.
|
||||
ErrClientConnClosing = errors.New("grpc: the client connection is closing")
|
||||
// ErrClientConnTimeout indicates that the connection could not be
|
||||
// established within the specified timeout.
|
||||
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect")
|
||||
)
|
||||
|
||||
type dialOptions struct {
|
||||
protocol string
|
||||
authOptions []credentials.Credentials
|
||||
}
|
||||
|
||||
// DialOption configures how we set up the connection including auth
|
||||
// credentials.
|
||||
type DialOption func(*dialOptions)
|
||||
// DialOption configures how we set up the connection.
|
||||
type DialOption func(*transport.DialOptions)
|
||||
|
||||
// WithTransportCredentials returns a DialOption which configures a
|
||||
// connection level security credentials (e.g., TLS/SSL).
|
||||
func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.authOptions = append(o.authOptions, creds)
|
||||
return func(o *transport.DialOptions) {
|
||||
o.AuthOptions = append(o.AuthOptions, creds)
|
||||
}
|
||||
}
|
||||
|
||||
// WithPerRPCCredentials returns a DialOption which sets
|
||||
// credentials which will place auth state on each outbound RPC.
|
||||
func WithPerRPCCredentials(creds credentials.Credentials) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.authOptions = append(o.authOptions, creds)
|
||||
return func(o *transport.DialOptions) {
|
||||
o.AuthOptions = append(o.AuthOptions, creds)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a DialOption which configures a timeout for dialing a
|
||||
// client connection.
|
||||
func WithTimeout(d time.Duration) DialOption {
|
||||
return func(o *transport.DialOptions) {
|
||||
o.Timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +108,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||
// ClientConn represents a client connection to an RPC service.
|
||||
type ClientConn struct {
|
||||
target string
|
||||
dopts dialOptions
|
||||
dopts transport.DialOptions
|
||||
shutdownChan chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
@ -119,6 +125,7 @@ type ClientConn struct {
|
|||
|
||||
func (cc *ClientConn) resetTransport(closeTransport bool) error {
|
||||
var retries int
|
||||
start := time.Now()
|
||||
for {
|
||||
cc.mu.Lock()
|
||||
t := cc.transport
|
||||
|
@ -133,12 +140,22 @@ func (cc *ClientConn) resetTransport(closeTransport bool) error {
|
|||
if closeTransport {
|
||||
t.Close()
|
||||
}
|
||||
newTransport, err := transport.NewClientTransport(cc.dopts.protocol, cc.target, cc.dopts.authOptions)
|
||||
// Adjust timeout for the current try.
|
||||
if cc.dopts.Timeout > 0 {
|
||||
cc.dopts.Timeout -= time.Since(start)
|
||||
if cc.dopts.Timeout <= 0 {
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
}
|
||||
newTransport, err := transport.NewClientTransport(cc.target, cc.dopts)
|
||||
if err != nil {
|
||||
// TODO(zhaoq): Record the error with glog.V.
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return ErrClientConnTimeout
|
||||
}
|
||||
closeTransport = false
|
||||
time.Sleep(backoff(retries))
|
||||
retries++
|
||||
// TODO(zhaoq): Record the error with glog.V.
|
||||
log.Printf("grpc: ClientConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target)
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -71,9 +71,15 @@ type Credentials interface {
|
|||
// TransportAuthenticator defines the common interface all supported transport
|
||||
// authentication protocols (e.g., TLS, SSL) must implement.
|
||||
type TransportAuthenticator interface {
|
||||
// Dial connects to the given network address and does the authentication
|
||||
// handshake specified by the corresponding authentication protocol.
|
||||
Dial(addr string) (net.Conn, error)
|
||||
// Dial connects to the given network address using net.Dial and then
|
||||
// does the authentication handshake specified by the corresponding
|
||||
// authentication protocol.
|
||||
Dial(network, addr string) (net.Conn, error)
|
||||
// DialWithDialer connects to the given network address using
|
||||
// dialer.Dialand does the authentication handshake specified by the
|
||||
// corresponding authentication protocol. Any timeout or deadline
|
||||
// given in the dialer apply to connection and handshake as a whole.
|
||||
DialWithDialer(dialer *net.Dialer, network, addr string) (net.Conn, error)
|
||||
// NewListener creates a listener which accepts connections with requested
|
||||
// authentication handshake.
|
||||
NewListener(lis net.Listener) net.Listener
|
||||
|
@ -103,8 +109,7 @@ func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, e
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Dial connects to addr and performs TLS handshake.
|
||||
func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) {
|
||||
func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ net.Conn, err error) {
|
||||
name := c.serverName
|
||||
if name == "" {
|
||||
name, _, err = net.SplitHostPort(addr)
|
||||
|
@ -112,13 +117,18 @@ func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) {
|
|||
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
|
||||
}
|
||||
}
|
||||
return tls.Dial("tcp", addr, &tls.Config{
|
||||
return tls.DialWithDialer(dialer, "tcp", addr, &tls.Config{
|
||||
RootCAs: c.rootCAs,
|
||||
NextProtos: alpnProtoStr,
|
||||
ServerName: name,
|
||||
})
|
||||
}
|
||||
|
||||
// Dial connects to addr and performs TLS handshake.
|
||||
func (c *tlsCreds) Dial(network, addr string) (_ net.Conn, err error) {
|
||||
return c.DialWithDialer(new(net.Dialer), network, addr)
|
||||
}
|
||||
|
||||
// NewListener creates a net.Listener with a TLS configuration constructed
|
||||
// from the information in tlsCreds.
|
||||
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
|
||||
|
|
|
@ -193,6 +193,16 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
|
|||
|
||||
const tlsDir = "testdata/"
|
||||
|
||||
func TestDialTimeout(t *testing.T) {
|
||||
conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond))
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
}
|
||||
if err != grpc.ErrClientConnTimeout {
|
||||
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, grpc.ErrClientConnTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, tc testpb.TestServiceClient) {
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
|
|
|
@ -96,29 +96,38 @@ type http2Client struct {
|
|||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||
// and starts to receive messages on it. Non-nil error returns if construction
|
||||
// fails.
|
||||
func newHTTP2Client(addr string, authOpts []credentials.Credentials) (_ ClientTransport, err error) {
|
||||
func newHTTP2Client(addr string, opts DialOptions) (_ ClientTransport, err error) {
|
||||
var (
|
||||
connErr error
|
||||
conn net.Conn
|
||||
)
|
||||
scheme := "http"
|
||||
// TODO(zhaoq): Use DialTimeout instead.
|
||||
for _, c := range authOpts {
|
||||
for _, c := range opts.AuthOptions {
|
||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||
scheme = "https"
|
||||
// TODO(zhaoq): Now the first TransportAuthenticator is used if there are
|
||||
// multiple ones provided. Revisit this if it is not appropriate. Probably
|
||||
// place the ClientTransport construction into a separate function to make
|
||||
// things clear.
|
||||
conn, connErr = ccreds.Dial(addr)
|
||||
if opts.Timeout > 0 {
|
||||
dialer := &net.Dialer{Timeout: opts.Timeout}
|
||||
conn, connErr = ccreds.DialWithDialer(dialer, "tcp", addr)
|
||||
} else {
|
||||
conn, connErr = ccreds.Dial("tcp", addr)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if scheme == "http" {
|
||||
conn, connErr = net.Dial("tcp", addr)
|
||||
if opts.Timeout > 0 {
|
||||
conn, connErr = net.DialTimeout("tcp", addr, opts.Timeout)
|
||||
} else {
|
||||
conn, connErr = net.Dial("tcp", addr)
|
||||
}
|
||||
}
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
return nil, connErr
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
|
@ -155,7 +164,7 @@ func newHTTP2Client(addr string, authOpts []credentials.Credentials) (_ ClientTr
|
|||
state: reachable,
|
||||
activeStreams: make(map[uint32]*Stream),
|
||||
maxStreams: math.MaxUint32,
|
||||
authCreds: authOpts,
|
||||
authCreds: opts.AuthOptions,
|
||||
}
|
||||
go t.controller()
|
||||
t.writableChan <- 0
|
||||
|
|
|
@ -44,6 +44,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -310,10 +311,17 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
|||
return newHTTP2Server(conn, maxStreams)
|
||||
}
|
||||
|
||||
// NewClientTransport establishes the transport with the required protocol
|
||||
// DialOptions covers all relevant options for dial a client connection.
|
||||
type DialOptions struct {
|
||||
Protocol string
|
||||
AuthOptions []credentials.Credentials
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewClientTransport establishes the transport with the required DialOptions
|
||||
// and returns it to the caller.
|
||||
func NewClientTransport(protocol, target string, authOpts []credentials.Credentials) (ClientTransport, error) {
|
||||
return newHTTP2Client(target, authOpts)
|
||||
func NewClientTransport(target string, opts DialOptions) (ClientTransport, error) {
|
||||
return newHTTP2Client(target, opts)
|
||||
}
|
||||
|
||||
// Options provides additional hints and information for message
|
||||
|
|
|
@ -181,9 +181,12 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
|
|||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
ct, connErr = NewClientTransport("http2", addr, []credentials.Credentials{creds})
|
||||
dopts := DialOptions {
|
||||
AuthOptions: []credentials.Credentials{creds},
|
||||
}
|
||||
ct, connErr = NewClientTransport(addr, dopts)
|
||||
} else {
|
||||
ct, connErr = NewClientTransport("http2", addr, nil)
|
||||
ct, connErr = NewClientTransport(addr, DialOptions{})
|
||||
}
|
||||
if connErr != nil {
|
||||
t.Fatalf("failed to create transport: %v", connErr)
|
||||
|
|
Загрузка…
Ссылка в новой задаче