diff --git a/clientconn.go b/clientconn.go index fb4e0446..3aaee7d8 100644 --- a/clientconn.go +++ b/clientconn.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" @@ -124,12 +125,13 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { // e.g. to use dns resolver, a "dns:///" prefix should be applied to the target. func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { cc := &ClientConn{ - target: target, - csMgr: &connectivityStateManager{}, - conns: make(map[*addrConn]struct{}), - dopts: defaultDialOptions(), - blockingpicker: newPickerWrapper(), - czData: new(channelzData), + target: target, + csMgr: &connectivityStateManager{}, + conns: make(map[*addrConn]struct{}), + dopts: defaultDialOptions(), + blockingpicker: newPickerWrapper(), + czData: new(channelzData), + firstResolveEvent: grpcsync.NewEvent(), } cc.retryThrottler.Store((*retryThrottler)(nil)) cc.ctx, cc.cancel = context.WithCancel(context.Background()) @@ -402,6 +404,8 @@ type ClientConn struct { balancerWrapper *ccBalancerWrapper retryThrottler atomic.Value + firstResolveEvent *grpcsync.Event + channelzID int64 // channelz unique identification number czData *channelzData } @@ -447,6 +451,25 @@ func (cc *ClientConn) scWatcher() { } } +// waitForResolvedAddrs blocks until the resolver has provided addresses or the +// context expires. Returns nil unless the context expires first; otherwise +// returns a status error based on the context. +func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error { + // This is on the RPC path, so we use a fast path to avoid the + // more-expensive "select" below after the resolver has returned once. + if cc.firstResolveEvent.HasFired() { + return nil + } + select { + case <-cc.firstResolveEvent.Done(): + return nil + case <-ctx.Done(): + return status.FromContextError(ctx.Err()).Err() + case <-cc.ctx.Done(): + return ErrClientConnClosing + } +} + func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { cc.mu.Lock() defer cc.mu.Unlock() @@ -460,6 +483,7 @@ func (cc *ClientConn) handleResolvedAddrs(addrs []resolver.Address, err error) { } cc.curAddresses = addrs + cc.firstResolveEvent.Fire() if cc.dopts.balancerBuilder == nil { // Only look at balancer types and switch balancer if balancer dial diff --git a/internal/grpcsync/event.go b/internal/grpcsync/event.go index 85dbea88..fbe697c3 100644 --- a/internal/grpcsync/event.go +++ b/internal/grpcsync/event.go @@ -20,12 +20,16 @@ // the sync package. package grpcsync -import "sync" +import ( + "sync" + "sync/atomic" +) // Event represents a one-time event that may occur in the future. type Event struct { - c chan struct{} - o sync.Once + fired int32 + c chan struct{} + o sync.Once } // Fire causes e to complete. It is safe to call multiple times, and @@ -34,6 +38,7 @@ type Event struct { func (e *Event) Fire() bool { ret := false e.o.Do(func() { + atomic.StoreInt32(&e.fired, 1) close(e.c) ret = true }) @@ -47,12 +52,7 @@ func (e *Event) Done() <-chan struct{} { // HasFired returns true if Fire has been called. func (e *Event) HasFired() bool { - select { - case <-e.c: - return true - default: - return false - } + return atomic.LoadInt32(&e.fired) == 1 } // NewEvent returns a new, ready-to-use Event. diff --git a/stream.go b/stream.go index 15e0e2ec..47aa822c 100644 --- a/stream.go +++ b/stream.go @@ -166,6 +166,11 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth }() } c := defaultCallInfo() + // Provide an opportunity for the first RPC to see the first service config + // provided by the resolver. + if err := cc.waitForResolvedAddrs(ctx); err != nil { + return nil, err + } mc := cc.GetMethodConfig(method) if mc.WaitForReady != nil { c.failFast = !*mc.WaitForReady diff --git a/test/end2end_test.go b/test/end2end_test.go index 59a7d8ce..39e9b00f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7147,3 +7147,55 @@ func (lis notifyingListener) Accept() (net.Conn, error) { defer lis.connEstablished.Fire() return lis.Listener.Accept() } + +func TestRPCWaitsForResolver(t *testing.T) { + te := testServiceConfigSetup(t, tcpClearRREnv) + te.startServer(&testServer{security: tcpClearRREnv.security}) + defer te.tearDown() + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + te.resolverScheme = r.Scheme() + te.nonBlockingDial = true + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + // With no resolved addresses yet, this will timeout. + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.DeadlineExceeded) + } + + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func() { + time.Sleep(time.Second) + r.NewServiceConfig(`{ + "methodConfig": [ + { + "name": [ + { + "service": "grpc.testing.TestService", + "method": "UnaryCall" + } + ], + "maxRequestMessageBytes": 0 + } + ] + }`) + r.NewAddress([]resolver.Address{{Addr: te.srvAddr}}) + }() + // We wait a second before providing a service config and resolving + // addresses. So this will wait for that and then honor the + // maxRequestMessageBytes it contains. + if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{ResponseType: testpb.PayloadType_UNCOMPRESSABLE}); status.Code(err) != codes.ResourceExhausted { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err) + } + if got := ctx.Err(); got != nil { + t.Fatalf("ctx.Err() = %v; want nil (deadline should be set short by service config)", got) + } + if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, nil", err) + } +}