From 3926816d541db48f3e4c1c87cff75ceeb205309e Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 14 Feb 2018 14:13:10 -0800 Subject: [PATCH] addrConn: Report underlying connection error in RPC error (#1855) --- clientconn.go | 1 + picker_wrapper.go | 19 ++++++++++++++++- test/end2end_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/clientconn.go b/clientconn.go index f61a6fe4..0c59c988 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1129,6 +1129,7 @@ func (ac *addrConn) createTransport(connectRetryNum, ridx int, backoffDeadline, newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, target, copts, onPrefaceReceipt) if err != nil { cancel() + ac.cc.blockingpicker.updateConnectionError(err) ac.mu.Lock() if ac.state == connectivity.Shutdown { // ac.tearDown(...) has been invoked. diff --git a/picker_wrapper.go b/picker_wrapper.go index db82bfb3..4d008259 100644 --- a/picker_wrapper.go +++ b/picker_wrapper.go @@ -36,6 +36,10 @@ type pickerWrapper struct { done bool blockingCh chan struct{} picker balancer.Picker + + // The latest connection happened. + connErrMu sync.Mutex + connErr error } func newPickerWrapper() *pickerWrapper { @@ -43,6 +47,19 @@ func newPickerWrapper() *pickerWrapper { return bp } +func (bp *pickerWrapper) updateConnectionError(err error) { + bp.connErrMu.Lock() + bp.connErr = err + bp.connErrMu.Unlock() +} + +func (bp *pickerWrapper) connectionError() error { + bp.connErrMu.Lock() + err := bp.connErr + bp.connErrMu.Unlock() + return err +} + // updatePicker is called by UpdateBalancerState. It unblocks all blocked pick. func (bp *pickerWrapper) updatePicker(p balancer.Picker) { bp.mu.Lock() @@ -107,7 +124,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer. if !failfast { continue } - return nil, nil, status.Errorf(codes.Unavailable, "%v", err) + return nil, nil, status.Errorf(codes.Unavailable, "%v, latest connection error: %v", err, bp.connectionError()) default: // err is some other error. return nil, nil, toRPCErr(err) diff --git a/test/end2end_test.go b/test/end2end_test.go index 264fcba1..eb8a629b 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -6040,3 +6040,53 @@ func testClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T, e } wg.Wait() } + +const clientAlwaysFailCredErrorMsg = "clientAlwaysFailCred always fails" + +var errClientAlwaysFailCred = errors.New(clientAlwaysFailCredErrorMsg) + +type clientAlwaysFailCred struct{} + +func (c clientAlwaysFailCred) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errClientAlwaysFailCred +} +func (c clientAlwaysFailCred) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c clientAlwaysFailCred) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c clientAlwaysFailCred) Clone() credentials.TransportCredentials { + return nil +} +func (c clientAlwaysFailCred) OverrideServerName(s string) error { + return nil +} + +func TestFailFastRPCErrorOnBadCertificates(t *testing.T) { + te := newTest(t, env{name: "bad-cred", network: "tcp", security: "clientAlwaysFailCred", balancer: "round_robin"}) + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + opts := []grpc.DialOption{grpc.WithTransportCredentials(clientAlwaysFailCred{})} + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, te.srvAddr, opts...) + if err != nil { + t.Fatalf("Dial(_) = %v, want %v", err, nil) + } + defer cc.Close() + + tc := testpb.NewTestServiceClient(cc) + for i := 0; i < 1000; i++ { + // This loop runs for at most 1 second. The first several RPCs will fail + // with Unavailable because the connection hasn't started. When the + // first connection failed with creds error, the next RPC should also + // fail with the expected error. + if _, err = tc.EmptyCall(context.Background(), &testpb.Empty{}); strings.Contains(err.Error(), clientAlwaysFailCredErrorMsg) { + return + } + time.Sleep(time.Millisecond) + } + te.t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want err.Error() contains %q", err, clientAlwaysFailCredErrorMsg) +}