diff --git a/balancer/balancer.go b/balancer/balancer.go index cd2682f5..b7c685ad 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -128,6 +128,10 @@ type PickOptions struct{} type DoneInfo struct { // Err is the rpc error the RPC finished with. It could be nil. Err error + // BytesSent indicates if any bytes have been sent to the server. + BytesSent bool + // BytesReceived indicates if any byte has been received from the server. + BytesReceived bool } var ( diff --git a/call.go b/call.go index 0854f84b..9c4b2551 100644 --- a/call.go +++ b/call.go @@ -277,11 +277,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } // Retry a non-failfast RPC when // i) the server started to drain before this RPC was initiated. @@ -301,11 +301,11 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli err = recvResponse(ctx, cc.dopts, t, c, stream, reply) if err != nil { if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -323,12 +323,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) } t.CloseStream(stream, nil) + err = stream.Status().Err() if done != nil { - updateRPCInfoInContext(ctx, rpcInfo{ - bytesSent: true, - bytesReceived: stream.BytesReceived(), + done(balancer.DoneInfo{ + Err: err, + BytesSent: true, + BytesReceived: stream.BytesReceived(), }) - done(balancer.DoneInfo{Err: err}) } if !c.failFast && stream.Unprocessed() { // In these cases, the server did not receive the data, but we still @@ -339,6 +340,6 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli continue } } - return stream.Status().Err() + return err } } diff --git a/clientconn.go b/clientconn.go index 9d68df9f..87b3b578 100644 --- a/clientconn.go +++ b/clientconn.go @@ -97,6 +97,8 @@ type dialOptions struct { callOptions []CallOption // This is to support v1 balancer. balancerBuilder balancer.Builder + // This is to support grpclb. + resolverBuilder resolver.Builder } const ( @@ -204,6 +206,13 @@ func WithBalancerBuilder(b balancer.Builder) DialOption { } } +// withResolverBuilder is only for grpclb. +func withResolverBuilder(b resolver.Builder) DialOption { + return func(o *dialOptions) { + o.resolverBuilder = b + } +} + // WithServiceConfig returns a DialOption which has a channel to read the service configuration. // DEPRECATED: service config should be received through name resolver, as specified here. 
// https://github.com/grpc/grpc/blob/master/doc/service_config.md @@ -283,18 +292,23 @@ func WithTimeout(d time.Duration) DialOption { } } +func withContextDialer(f func(context.Context, string) (net.Conn, error)) DialOption { + return func(o *dialOptions) { + o.copts.Dialer = f + } +} + // WithDialer returns a DialOption that specifies a function to use for dialing network addresses. // If FailOnNonTempDialError() is set to true, and an error is returned by f, gRPC checks the error's // Temporary() method to decide if it should try to reconnect to the network address. func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { - return func(o *dialOptions) { - o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) { + return withContextDialer( + func(ctx context.Context, addr string) (net.Conn, error) { if deadline, ok := ctx.Deadline(); ok { return f(addr, deadline.Sub(time.Now())) } return f(addr, 0) - } - } + }) } // WithStatsHandler returns a DialOption that specifies the stats handler diff --git a/grpclb.go b/grpclb.go index db56ff36..8489c690 100644 --- a/grpclb.go +++ b/grpclb.go @@ -19,21 +19,32 @@ package grpc import ( - "errors" - "fmt" - "math/rand" - "net" + "strconv" + "strings" "sync" "time" "golang.org/x/net/context" - "google.golang.org/grpc/codes" - lbmpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/naming" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" ) +const ( + lbTokeyKey = "lb-token" + defaultFallbackTimeout = 10 * time.Second +) + +func convertDuration(d *lbpb.Duration) time.Duration { + if d == nil { + return 0 + } + return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond +} + // Client API for LoadBalancer service. // Mostly copied from generated pb.go file. // To avoid circular dependency. @@ -59,646 +70,270 @@ type balanceLoadClientStream struct { ClientStream } -func (x *balanceLoadClientStream) Send(m *lbmpb.LoadBalanceRequest) error { +func (x *balanceLoadClientStream) Send(m *lbpb.LoadBalanceRequest) error { return x.ClientStream.SendMsg(m) } -func (x *balanceLoadClientStream) Recv() (*lbmpb.LoadBalanceResponse, error) { - m := new(lbmpb.LoadBalanceResponse) +func (x *balanceLoadClientStream) Recv() (*lbpb.LoadBalanceResponse, error) { + m := new(lbpb.LoadBalanceResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } -// NewGRPCLBBalancer creates a grpclb load balancer. -func NewGRPCLBBalancer(r naming.Resolver) Balancer { - return &grpclbBalancer{ - r: r, +// NewLBBuilder creates a builder for grpclb. For testing only. +func NewLBBuilder() balancer.Builder { + // TODO(bar grpclb) this function is exported for testing only, remove it when resolver supports selecting grpclb. + return NewLBBuilderWithFallbackTimeout(defaultFallbackTimeout) +} + +// NewLBBuilderWithFallbackTimeout creates a grpclb builder with the given +// fallbackTimeout. If no response is received from the remote balancer within +// fallbackTimeout, the backend addresses from the resolved address list will be +// used. +// +// Only call this function when a non-default fallback timeout is needed. 
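+//
+// A sketch of intended usage (editorial illustration, not part of this patch;
+// the manual resolver and the resolver.GRPCLB address type come from this same
+// change, and lbAddr is a placeholder):
+//
+//	r, cleanup := manual.GenerateAndRegisterManualResolver()
+//	defer cleanup()
+//	cc, err := grpc.Dial(r.Scheme()+":///my.service",
+//		grpc.WithBalancerBuilder(grpc.NewLBBuilderWithFallbackTimeout(5*time.Second)),
+//		grpc.WithInsecure())
+//	if err != nil {
+//		// handle dial error
+//	}
+//	defer cc.Close()
+//	r.NewAddress([]resolver.Address{{Addr: lbAddr, Type: resolver.GRPCLB, ServerName: "lb.example.com"}})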
+func NewLBBuilderWithFallbackTimeout(fallbackTimeout time.Duration) balancer.Builder { + return &lbBuilder{ + fallbackTimeout: fallbackTimeout, } } -type remoteBalancerInfo struct { - addr string - // the server name used for authentication with the remote LB server. - name string +type lbBuilder struct { + fallbackTimeout time.Duration } -// grpclbAddrInfo consists of the information of a backend server. -type grpclbAddrInfo struct { - addr Address - connected bool - // dropForRateLimiting indicates whether this particular request should be - // dropped by the client for rate limiting. - dropForRateLimiting bool - // dropForLoadBalancing indicates whether this particular request should be - // dropped by the client for load balancing. - dropForLoadBalancing bool +func (b *lbBuilder) Name() string { + return "grpclb" } -type grpclbBalancer struct { - r naming.Resolver - target string - mu sync.Mutex - seq int // a sequence number to make sure addrCh does not get stale addresses. - w naming.Watcher - addrCh chan []Address - rbs []remoteBalancerInfo - addrs []*grpclbAddrInfo - next int - waitCh chan struct{} - done bool - rand *rand.Rand +func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer { + // This generates a manual resolver builder with a random scheme. This + // scheme will be used to dial to remote LB, so we can send filtered address + // updates to remote LB ClientConn using this manual resolver. + scheme := "grpclb_internal_" + strconv.FormatInt(time.Now().UnixNano(), 36) + r := manual.NewBuilderWithScheme(scheme) - clientStats lbmpb.ClientStats + var target string + targetSplitted := strings.Split(cc.Target(), ":///") + if len(targetSplitted) < 2 { + target = cc.Target() + } else { + target = targetSplitted[1] + } + + lb := &lbBalancer{ + cc: cc, + target: target, + opt: opt, + fallbackTimeout: b.fallbackTimeout, + doneCh: make(chan struct{}), + + manualResolver: r, + csEvltr: &connectivityStateEvaluator{}, + subConns: make(map[resolver.Address]balancer.SubConn), + scStates: make(map[balancer.SubConn]connectivity.State), + picker: &errPicker{err: balancer.ErrNoSubConnAvailable}, + clientStats: &rpcStats{}, + } + + return lb } -func (b *grpclbBalancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error { - updates, err := w.Next() - if err != nil { - grpclog.Warningf("grpclb: failed to get next addr update from watcher: %v", err) - return err - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return ErrClientConnClosing - } - for _, update := range updates { - switch update.Op { - case naming.Add: - var exist bool - for _, v := range b.rbs { - // TODO: Is the same addr with different server name a different balancer? - if update.Addr == v.addr { - exist = true - break - } - } - if exist { - continue - } - md, ok := update.Metadata.(*naming.AddrMetadataGRPCLB) - if !ok { - // TODO: Revisit the handling here and may introduce some fallback mechanism. - grpclog.Errorf("The name resolution contains unexpected metadata %v", update.Metadata) - continue - } - switch md.AddrType { - case naming.Backend: - // TODO: Revisit the handling here and may introduce some fallback mechanism. 
- grpclog.Errorf("The name resolution does not give grpclb addresses") - continue - case naming.GRPCLB: - b.rbs = append(b.rbs, remoteBalancerInfo{ - addr: update.Addr, - name: md.ServerName, - }) - default: - grpclog.Errorf("Received unknow address type %d", md.AddrType) - continue - } - case naming.Delete: - for i, v := range b.rbs { - if update.Addr == v.addr { - copy(b.rbs[i:], b.rbs[i+1:]) - b.rbs = b.rbs[:len(b.rbs)-1] - break - } - } - default: - grpclog.Errorf("Unknown update.Op %v", update.Op) - } - } - // TODO: Fall back to the basic round-robin load balancing if the resulting address is - // not a load balancer. - select { - case <-ch: - default: - } - ch <- b.rbs - return nil +type lbBalancer struct { + cc balancer.ClientConn + target string + opt balancer.BuildOptions + fallbackTimeout time.Duration + doneCh chan struct{} + + // manualResolver is used in the remote LB ClientConn inside grpclb. When + // resolved address updates are received by grpclb, filtered updates will be + // send to remote LB ClientConn through this resolver. + manualResolver *manual.Resolver + // The ClientConn to talk to the remote balancer. + ccRemoteLB *ClientConn + + // Support client side load reporting. Each picker gets a reference to this, + // and will update its content. + clientStats *rpcStats + + mu sync.Mutex // guards everything following. + // The full server list including drops, used to check if the newly received + // serverList contains anything new. Each generate picker will also have + // reference to this list to do the first layer pick. + fullServerList []*lbpb.Server + // All backends addresses, with metadata set to nil. This list contains all + // backend addresses in the same order and with the same duplicates as in + // serverlist. When generating picker, a SubConn slice with the same order + // but with only READY SCs will be gerenated. + backendAddrs []resolver.Address + // Roundrobin functionalities. + csEvltr *connectivityStateEvaluator + state connectivity.State + subConns map[resolver.Address]balancer.SubConn // Used to new/remove SubConn. + scStates map[balancer.SubConn]connectivity.State // Used to filter READY SubConns. + picker balancer.Picker + // Support fallback to resolved backend addresses if there's no response + // from remote balancer within fallbackTimeout. + fallbackTimerExpired bool + serverListReceived bool + // resolvedBackendAddrs is resolvedAddrs minus remote balancers. It's set + // when resolved address updates are received, and read in the goroutine + // handling fallback. + resolvedBackendAddrs []resolver.Address } -func convertDuration(d *lbmpb.Duration) time.Duration { - if d == nil { - return 0 - } - return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond -} - -func (b *grpclbBalancer) processServerList(l *lbmpb.ServerList, seq int) { - if l == nil { +// regeneratePicker takes a snapshot of the balancer, and generates a picker from +// it. The picker +// - always returns ErrTransientFailure if the balancer is in TransientFailure, +// - does two layer roundrobin pick otherwise. +// Caller must hold lb.mu. 
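+// The picker is rebuilt as an immutable snapshot instead of being mutated in
+// place, so Pick never needs to take lb.mu.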
+func (lb *lbBalancer) regeneratePicker() { + if lb.state == connectivity.TransientFailure { + lb.picker = &errPicker{err: balancer.ErrTransientFailure} return } - servers := l.GetServers() - var ( - sl []*grpclbAddrInfo - addrs []Address - ) - for _, s := range servers { - md := metadata.Pairs("lb-token", s.LoadBalanceToken) - ip := net.IP(s.IpAddress) - ipStr := ip.String() - if ip.To4() == nil { - // Add square brackets to ipv6 addresses, otherwise net.Dial() and - // net.SplitHostPort() will return too many colons error. - ipStr = fmt.Sprintf("[%s]", ipStr) + var readySCs []balancer.SubConn + for _, a := range lb.backendAddrs { + if sc, ok := lb.subConns[a]; ok { + if st, ok := lb.scStates[sc]; ok && st == connectivity.Ready { + readySCs = append(readySCs, sc) + } } - addr := Address{ - Addr: fmt.Sprintf("%s:%d", ipStr, s.Port), - Metadata: &md, - } - sl = append(sl, &grpclbAddrInfo{ - addr: addr, - dropForRateLimiting: s.DropForRateLimiting, - dropForLoadBalancing: s.DropForLoadBalancing, - }) - addrs = append(addrs, addr) } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { + + if len(lb.fullServerList) <= 0 { + if len(readySCs) <= 0 { + lb.picker = &errPicker{err: balancer.ErrNoSubConnAvailable} + return + } + lb.picker = &rrPicker{subConns: readySCs} return } - if len(sl) > 0 { - // reset b.next to 0 when replacing the server list. - b.next = 0 - b.addrs = sl - b.addrCh <- addrs + lb.picker = &lbPicker{ + serverList: lb.fullServerList, + subConns: readySCs, + stats: lb.clientStats, } return } -func (b *grpclbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration, done <-chan struct{}) { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - case <-done: - return - } - b.mu.Lock() - stats := b.clientStats - b.clientStats = lbmpb.ClientStats{} // Clear the stats. - b.mu.Unlock() - t := time.Now() - stats.Timestamp = &lbmpb.Timestamp{ - Seconds: t.Unix(), - Nanos: int32(t.Nanosecond()), - } - if err := s.Send(&lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_ClientStats{ - ClientStats: &stats, - }, - }); err != nil { - grpclog.Errorf("grpclb: failed to send load report: %v", err) +func (lb *lbBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { + grpclog.Infof("lbBalancer: handle SubConn state change: %p, %v", sc, s) + lb.mu.Lock() + defer lb.mu.Unlock() + + oldS, ok := lb.scStates[sc] + if !ok { + grpclog.Infof("lbBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + return + } + lb.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. 
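+		// (RemoveSubConn was already called in refreshSubConns; this is the
+		// deferred cleanup of the state entry.)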
+ delete(lb.scStates, sc) + } + + oldAggrState := lb.state + lb.state = lb.csEvltr.recordTransition(oldS, s) + + // Regenerate picker when one of the following happens: + // - this sc became ready from not-ready + // - this sc became not-ready from ready + // - the aggregated state of balancer became TransientFailure from non-TransientFailure + // - the aggregated state of balancer became non-TransientFailure from TransientFailure + if (oldS == connectivity.Ready) != (s == connectivity.Ready) || + (lb.state == connectivity.TransientFailure) != (oldAggrState == connectivity.TransientFailure) { + lb.regeneratePicker() + } + + lb.cc.UpdateBalancerState(lb.state, lb.picker) + return +} + +// fallbackToBackendsAfter blocks for fallbackTimeout and falls back to use +// resolved backends (backends received from resolver, not from remote balancer) +// if no connection to remote balancers was successful. +func (lb *lbBalancer) fallbackToBackendsAfter(fallbackTimeout time.Duration) { + timer := time.NewTimer(fallbackTimeout) + defer timer.Stop() + select { + case <-timer.C: + case <-lb.doneCh: + return + } + lb.mu.Lock() + if lb.serverListReceived { + lb.mu.Unlock() + return + } + lb.fallbackTimerExpired = true + lb.refreshSubConns(lb.resolvedBackendAddrs) + lb.mu.Unlock() +} + +// HandleResolvedAddrs sends the updated remoteLB addresses to remoteLB +// clientConn. The remoteLB clientConn will handle creating/removing remoteLB +// connections. +func (lb *lbBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { + grpclog.Infof("lbBalancer: handleResolvedResult: %+v", addrs) + if len(addrs) <= 0 { + return + } + + var remoteBalancerAddrs, backendAddrs []resolver.Address + for _, a := range addrs { + if a.Type == resolver.GRPCLB { + remoteBalancerAddrs = append(remoteBalancerAddrs, a) + } else { + backendAddrs = append(backendAddrs, a) + } + } + + if lb.ccRemoteLB == nil { + if len(remoteBalancerAddrs) <= 0 { + grpclog.Errorf("grpclb: no remote balancer address is available, should never happen") return } + // First time receiving resolved addresses, create a cc to remote + // balancers. + lb.dialRemoteLB(remoteBalancerAddrs[0].ServerName) + // Start the fallback goroutine. + go lb.fallbackToBackendsAfter(lb.fallbackTimeout) } + + // cc to remote balancers uses lb.manualResolver. Send the updated remote + // balancer addresses to it through manualResolver. + lb.manualResolver.NewAddress(remoteBalancerAddrs) + + lb.mu.Lock() + lb.resolvedBackendAddrs = backendAddrs + // If serverListReceived is true, connection to remote balancer was + // successful and there's no need to do fallback anymore. + // If fallbackTimerExpired is false, fallback hasn't happened yet. + if !lb.serverListReceived && lb.fallbackTimerExpired { + // This means we received a new list of resolved backends, and we are + // still in fallback mode. Need to update the list of backends we are + // using to the new list of backends. 
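+		// refreshSubConns requires lb.mu, which is still held here.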
+ lb.refreshSubConns(lb.resolvedBackendAddrs) + } + lb.mu.Unlock() } -func (b *grpclbBalancer) callRemoteBalancer(lbc *loadBalancerClient, seq int) (retry bool) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - stream, err := lbc.BalanceLoad(ctx) - if err != nil { - grpclog.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) +func (lb *lbBalancer) Close() { + select { + case <-lb.doneCh: return + default: } - b.mu.Lock() - if b.done { - b.mu.Unlock() - return - } - b.mu.Unlock() - initReq := &lbmpb.LoadBalanceRequest{ - LoadBalanceRequestType: &lbmpb.LoadBalanceRequest_InitialRequest{ - InitialRequest: &lbmpb.InitialLoadBalanceRequest{ - Name: b.target, - }, - }, - } - if err := stream.Send(initReq); err != nil { - grpclog.Errorf("grpclb: failed to send init request: %v", err) - // TODO: backoff on retry? - return true - } - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv init response: %v", err) - // TODO: backoff on retry? - return true - } - initResp := reply.GetInitialResponse() - if initResp == nil { - grpclog.Errorf("grpclb: reply from remote balancer did not include initial response.") - return - } - // TODO: Support delegation. - if initResp.LoadBalancerDelegate != "" { - // delegation - grpclog.Errorf("TODO: Delegation is not supported yet.") - return - } - streamDone := make(chan struct{}) - defer close(streamDone) - b.mu.Lock() - b.clientStats = lbmpb.ClientStats{} // Clear client stats. - b.mu.Unlock() - if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 { - go b.sendLoadReport(stream, d, streamDone) - } - // Retrieve the server list. - for { - reply, err := stream.Recv() - if err != nil { - grpclog.Errorf("grpclb: failed to recv server list: %v", err) - break - } - b.mu.Lock() - if b.done || seq < b.seq { - b.mu.Unlock() - return - } - b.seq++ // tick when receiving a new list of servers. - seq = b.seq - b.mu.Unlock() - if serverList := reply.GetServerList(); serverList != nil { - b.processServerList(serverList, seq) - } - } - return true -} - -func (b *grpclbBalancer) Start(target string, config BalancerConfig) error { - b.rand = rand.New(rand.NewSource(time.Now().Unix())) - // TODO: Fall back to the basic direct connection if there is no name resolver. - if b.r == nil { - return errors.New("there is no name resolver installed") - } - b.target = target - b.mu.Lock() - if b.done { - b.mu.Unlock() - return ErrClientConnClosing - } - b.addrCh = make(chan []Address) - w, err := b.r.Resolve(target) - if err != nil { - b.mu.Unlock() - grpclog.Errorf("grpclb: failed to resolve address: %v, err: %v", target, err) - return err - } - b.w = w - b.mu.Unlock() - balancerAddrsCh := make(chan []remoteBalancerInfo, 1) - // Spawn a goroutine to monitor the name resolution of remote load balancer. - go func() { - for { - if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil { - grpclog.Warningf("grpclb: the naming watcher stops working due to %v.\n", err) - close(balancerAddrsCh) - return - } - } - }() - // Spawn a goroutine to talk to the remote load balancer. - go func() { - var ( - cc *ClientConn - // ccError is closed when there is an error in the current cc. - // A new rb should be picked from rbs and connected. 
- ccError chan struct{} - rb *remoteBalancerInfo - rbs []remoteBalancerInfo - rbIdx int - ) - - defer func() { - if ccError != nil { - select { - case <-ccError: - default: - close(ccError) - } - } - if cc != nil { - cc.Close() - } - }() - - for { - var ok bool - select { - case rbs, ok = <-balancerAddrsCh: - if !ok { - return - } - foundIdx := -1 - if rb != nil { - for i, trb := range rbs { - if trb == *rb { - foundIdx = i - break - } - } - } - if foundIdx >= 0 { - if foundIdx >= 1 { - // Move the address in use to the beginning of the list. - b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0] - rbIdx = 0 - } - continue // If found, don't dial new cc. - } else if len(rbs) > 0 { - // Pick a random one from the list, instead of always using the first one. - if l := len(rbs); l > 1 && rb != nil { - tmpIdx := b.rand.Intn(l - 1) - b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0] - } - rbIdx = 0 - rb = &rbs[0] - } else { - // foundIdx < 0 && len(rbs) <= 0. - rb = nil - } - case <-ccError: - ccError = nil - if rbIdx < len(rbs)-1 { - rbIdx++ - rb = &rbs[rbIdx] - } else { - rb = nil - } - } - - if rb == nil { - continue - } - - if cc != nil { - cc.Close() - } - // Talk to the remote load balancer to get the server list. - var ( - err error - dopts []DialOption - ) - if creds := config.DialCreds; creds != nil { - if rb.name != "" { - if err := creds.OverrideServerName(rb.name); err != nil { - grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v", err) - continue - } - } - dopts = append(dopts, WithTransportCredentials(creds)) - } else { - dopts = append(dopts, WithInsecure()) - } - if dialer := config.Dialer; dialer != nil { - // WithDialer takes a different type of function, so we instead use a special DialOption here. - dopts = append(dopts, func(o *dialOptions) { o.copts.Dialer = dialer }) - } - dopts = append(dopts, WithBlock()) - ccError = make(chan struct{}) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - cc, err = DialContext(ctx, rb.addr, dopts...) - cancel() - if err != nil { - grpclog.Warningf("grpclb: failed to setup a connection to the remote balancer %v: %v", rb.addr, err) - close(ccError) - continue - } - b.mu.Lock() - b.seq++ // tick when getting a new balancer address - seq := b.seq - b.next = 0 - b.mu.Unlock() - go func(cc *ClientConn, ccError chan struct{}) { - lbc := &loadBalancerClient{cc} - b.callRemoteBalancer(lbc, seq) - cc.Close() - select { - case <-ccError: - default: - close(ccError) - } - }(cc, ccError) - } - }() - return nil -} - -func (b *grpclbBalancer) down(addr Address, err error) { - b.mu.Lock() - defer b.mu.Unlock() - for _, a := range b.addrs { - if addr == a.addr { - a.connected = false - break - } + close(lb.doneCh) + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.Close() } } - -func (b *grpclbBalancer) Up(addr Address) func(error) { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return nil - } - var cnt int - for _, a := range b.addrs { - if a.addr == addr { - if a.connected { - return nil - } - a.connected = true - } - if a.connected && !a.dropForRateLimiting && !a.dropForLoadBalancing { - cnt++ - } - } - // addr is the only one which is connected. Notify the Get() callers who are blocking. 
- if cnt == 1 && b.waitCh != nil { - close(b.waitCh) - b.waitCh = nil - } - return func(err error) { - b.down(addr, err) - } -} - -func (b *grpclbBalancer) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) { - var ch chan struct{} - b.mu.Lock() - if b.done { - b.mu.Unlock() - err = ErrClientConnClosing - return - } - seq := b.seq - - defer func() { - if err != nil { - return - } - put = func() { - s, ok := rpcInfoFromContext(ctx) - if !ok { - return - } - b.mu.Lock() - defer b.mu.Unlock() - if b.done || seq < b.seq { - return - } - b.clientStats.NumCallsFinished++ - if !s.bytesSent { - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - } else if s.bytesReceived { - b.clientStats.NumCallsFinishedKnownReceived++ - } - } - }() - - b.clientStats.NumCallsStarted++ - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "%s drops requests", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - if !opts.BlockingWait { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = Errorf(codes.Unavailable, "there is no address available") - return - } - // Wait on b.waitCh for non-failfast RPCs. - if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() - for { - select { - case <-ctx.Done(): - b.mu.Lock() - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ctx.Err() - return - case <-ch: - b.mu.Lock() - if b.done { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithClientFailedToSend++ - b.mu.Unlock() - err = ErrClientConnClosing - return - } - - if len(b.addrs) > 0 { - if b.next >= len(b.addrs) { - b.next = 0 - } - next := b.next - for { - a := b.addrs[next] - next = (next + 1) % len(b.addrs) - if a.connected { - if !a.dropForRateLimiting && !a.dropForLoadBalancing { - addr = a.addr - b.next = next - b.mu.Unlock() - return - } - if !opts.BlockingWait { - b.next = next - if a.dropForLoadBalancing { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForLoadBalancing++ - } else if a.dropForRateLimiting { - b.clientStats.NumCallsFinished++ - b.clientStats.NumCallsFinishedWithDropForRateLimiting++ - } - b.mu.Unlock() - err = Errorf(codes.Unavailable, "drop requests for the addreess %s", a.addr.Addr) - return - } - } - if next == b.next { - // Has iterated all the possible address but none is connected. - break - } - } - } - // The newly added addr got removed by Down() again. 
- if b.waitCh == nil { - ch = make(chan struct{}) - b.waitCh = ch - } else { - ch = b.waitCh - } - b.mu.Unlock() - } - } -} - -func (b *grpclbBalancer) Notify() <-chan []Address { - return b.addrCh -} - -func (b *grpclbBalancer) Close() error { - b.mu.Lock() - defer b.mu.Unlock() - if b.done { - return errBalancerClosed - } - b.done = true - if b.waitCh != nil { - close(b.waitCh) - } - if b.addrCh != nil { - close(b.addrCh) - } - if b.w != nil { - b.w.Close() - } - return nil -} diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index 1491be78..9bab731c 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -27,12 +27,14 @@ import ( "fmt" "io" "net" + "strconv" "strings" "sync" "testing" "time" "github.com/golang/protobuf/proto" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -41,16 +43,20 @@ import ( lbspb "google.golang.org/grpc/grpclb/grpc_lb_v1/service" _ "google.golang.org/grpc/grpclog/glogger" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/naming" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" "google.golang.org/grpc/test/leakcheck" + + _ "google.golang.org/grpc/grpclog/glogger" ) var ( - lbsn = "bar.com" - besn = "foo.com" - lbToken = "iamatoken" + lbServerName = "bar.com" + beServerName = "foo.com" + lbToken = "iamatoken" // Resolver replaces localhost with fakeName in Next(). // Dialer replaces fakeName with localhost when dialing. @@ -58,83 +64,6 @@ var ( fakeName = "fake.Name" ) -type testWatcher struct { - // the channel to receives name resolution updates - update chan *naming.Update - // the side channel to get to know how many updates in a batch - side chan int - // the channel to notifiy update injector that the update reading is done - readDone chan int -} - -func (w *testWatcher) Next() (updates []*naming.Update, err error) { - n, ok := <-w.side - if !ok { - return nil, fmt.Errorf("w.side is closed") - } - for i := 0; i < n; i++ { - u, ok := <-w.update - if !ok { - break - } - if u != nil { - // Resolver replaces localhost with fakeName in Next(). - // Custom dialer will replace fakeName with localhost when dialing. - u.Addr = strings.Replace(u.Addr, "localhost", fakeName, 1) - updates = append(updates, u) - } - } - w.readDone <- 0 - return -} - -func (w *testWatcher) Close() { - close(w.side) -} - -// Inject naming resolution updates to the testWatcher. 
-func (w *testWatcher) inject(updates []*naming.Update) { - w.side <- len(updates) - for _, u := range updates { - w.update <- u - } - <-w.readDone -} - -type testNameResolver struct { - w *testWatcher - addrs []string -} - -func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) { - r.w = &testWatcher{ - update: make(chan *naming.Update, len(r.addrs)), - side: make(chan int, 1), - readDone: make(chan int), - } - r.w.side <- len(r.addrs) - for _, addr := range r.addrs { - r.w.update <- &naming.Update{ - Op: naming.Add, - Addr: addr, - Metadata: &naming.AddrMetadataGRPCLB{ - AddrType: naming.GRPCLB, - ServerName: lbsn, - }, - } - } - go func() { - <-r.w.readDone - }() - return r.w, nil -} - -func (r *testNameResolver) inject(updates []*naming.Update) { - if r.w != nil { - r.w.inject(updates) - } -} - type serverNameCheckCreds struct { mu sync.Mutex sn string @@ -199,23 +128,22 @@ func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) { } type remoteBalancer struct { - sls []*lbmpb.ServerList - intervals []time.Duration + sls chan *lbmpb.ServerList statsDura time.Duration done chan struct{} mu sync.Mutex stats lbmpb.ClientStats } -func newRemoteBalancer(sls []*lbmpb.ServerList, intervals []time.Duration) *remoteBalancer { +func newRemoteBalancer(intervals []time.Duration) *remoteBalancer { return &remoteBalancer{ - sls: sls, - intervals: intervals, - done: make(chan struct{}), + sls: make(chan *lbmpb.ServerList, 1), + done: make(chan struct{}), } } func (b *remoteBalancer) stop() { + close(b.sls) close(b.done) } @@ -225,7 +153,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer return err } initReq := req.GetInitialRequest() - if initReq.Name != besn { + if initReq.Name != beServerName { return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) } resp := &lbmpb.LoadBalanceResponse{ @@ -260,8 +188,7 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer b.mu.Unlock() } }() - for k, v := range b.sls { - time.Sleep(b.intervals[k]) + for v := range b.sls { resp = &lbmpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbmpb.LoadBalanceResponse_ServerList{ ServerList: v, @@ -278,7 +205,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbspb.LoadBalancer_BalanceLoadServer type testServer struct { testpb.TestServiceServer - addr string + addr string + fallback bool } const testmdkey = "testmd" @@ -288,7 +216,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E if !ok { return nil, status.Error(codes.Internal, "failed to receive metadata") } - if md == nil || md["lb-token"][0] != lbToken { + if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) { return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md) } grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr)) @@ -299,13 +227,13 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ return nil } -func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) { +func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) { for _, l := range lis { creds := &serverNameCheckCreds{ sn: sn, } s := grpc.NewServer(grpc.Creds(creds)) - testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String()}) + testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback}) servers = append(servers, s) go func(s *grpc.Server, l net.Listener) { s.Serve(l) @@ -348,7 +276,7 
@@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er beListeners = append(beListeners, beLis) } - backends := startBackends(besn, beListeners...) + backends := startBackends(beServerName, false, beListeners...) // Start a load balancer. lbLis, err := net.Listen("tcp", "localhost:0") @@ -357,21 +285,21 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er return } lbCreds := &serverNameCheckCreds{ - sn: lbsn, + sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) if err != nil { err = fmt.Errorf("Failed to generate the port number %v", err) return } - ls = newRemoteBalancer(nil, nil) + ls = newRemoteBalancer(nil) lbspb.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) }() tss = &testServers{ - lbAddr: lbLis.Addr().String(), + lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port), ls: ls, lb: lb, beIPs: beIPs, @@ -389,6 +317,10 @@ func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), er func TestGRPCLB(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) @@ -405,136 +337,178 @@ func TestGRPCLB(t *testing.T) { sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} + tss.ls.sls <- sl creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilder()), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) } } -func TestDropRequest(t *testing.T) { +// The remote balancer sends response with duplicates to grpclb client. 
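+// The expected behavior is weighted roundrobin: the picker cycles through the
+// server list with duplicates intact, so a backend listed twice receives twice
+// the traffic.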
+func TestGRPCLBWeighted(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(2) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbmpb.ServerList{{ - Servers: []*lbmpb.Server{{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, - }, { - IpAddress: tss.beIPs[1], - Port: int32(tss.bePorts[1]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: false, - }}, + + beServers := []*lbmpb.Server{{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, { + IpAddress: tss.beIPs[1], + Port: int32(tss.bePorts[1]), + LoadBalanceToken: lbToken, }} - tss.ls.intervals = []time.Duration{0} + portsToIndex := make(map[int]int) + for i := range beServers { + portsToIndex[tss.bePorts[i]] = i + } + creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilder()), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - // Wait until the first connection is up. - // The first one has Drop set to true, error should contain "drop requests". - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { - break + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + sequences := []string{"00101", "00011"} + for _, seq := range sequences { + var ( + bes []*lbmpb.Server + p peer.Peer + result string + ) + for _, s := range seq { + bes = append(bes, beServers[s-'0']) + } + tss.ls.sls <- &lbmpb.ServerList{Servers: bes} + + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) } + result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port]) } - } - // The 1st, non-fail-fast RPC should succeed. This ensures both server - // connections are made, because the first one has DropForLoadBalancing set to true. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { - t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", testC, err) - } - for i := 0; i < 3; i++ { - // Odd fail-fast RPCs should fail, because the 1st backend has DropForLoadBalancing - // set to true. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Unavailable { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) - } - // Even fail-fast RPCs should succeed since they choose the - // non-drop-request backend according to the round robin policy. 
- if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + // The generated result will be in format of "0010100101". + if !strings.Contains(result, strings.Repeat(seq, 2)) { + t.Errorf("got result sequence %q, want patten %q", result, seq) } } } -func TestDropRequestFailedNonFailFast(t *testing.T) { +func TestDropRequest(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbmpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, + tss.ls.sls <- &lbmpb.ServerList{ + Servers: []*lbmpb.Server{{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + DropForLoadBalancing: false, + }, { + DropForLoadBalancing: true, + }}, } - var bes []*lbmpb.Server - bes = append(bes, be) - sl := &lbmpb.ServerList{ - Servers: bes, - } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilder()), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + // The 1st, non-fail-fast RPC should succeed. This ensures both server + // connections are made, because the first one has DropForLoadBalancing set to true. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.SayHello(_, _) = _, %v, want _, ", testC, err) + } + for _, failfast := range []bool{true, false} { + for i := 0; i < 3; i++ { + // Even RPCs should fail, because the 2st backend has + // DropForLoadBalancing set to true. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); grpc.Code(err) != codes.Unavailable { + t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) + } + // Odd RPCs should succeed since they choose the non-drop-request + // backend according to the round robin policy. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil { + t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + } } } // When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. 
func TestBalancerDisconnects(t *testing.T) { defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + var ( - lbAddrs []string - lbs []*grpc.Server + tests []*testServers + lbs []*grpc.Server ) - for i := 0; i < 3; i++ { + for i := 0; i < 2; i++ { tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) @@ -551,78 +525,149 @@ func TestBalancerDisconnects(t *testing.T) { sl := &lbmpb.ServerList{ Servers: bes, } - tss.ls.sls = []*lbmpb.ServerList{sl} - tss.ls.intervals = []time.Duration{0} + tss.ls.sls <- sl - lbAddrs = append(lbAddrs, tss.lbAddr) + tests = append(tests, tss) lbs = append(lbs, tss.lb) } creds := serverNameCheckCreds{ - expected: besn, + expected: beServerName, } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - resolver := &testNameResolver{ - addrs: lbAddrs[:2], - } - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(resolver)), - grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilder()), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - var previousTrailer string - trailer := metadata.MD{} - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + + r.NewAddress([]resolver.Address{{ + Addr: tests[0].lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: tests[1].lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + + var p peer.Peer + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) - } else { - previousTrailer = trailer[testmdkey][0] } - // The initial resolver update contains lbs[0] and lbs[1]. - // When lbs[0] is stopped, lbs[1] should be used. + if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] { + t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0]) + } + lbs[0].Stop() - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + // Stop balancer[0], balancer[1] should be used by grpclb. + // Check peer address to see if that happened. + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) - } else if trailer[testmdkey][0] != previousTrailer { - // A new backend server should receive the request. - // The trailer contains the backend address, so the trailer should be different from the previous one. - previousTrailer = trailer[testmdkey][0] - break } - time.Sleep(100 * time.Millisecond) + if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] { + return + } + time.Sleep(time.Millisecond) } - // Inject a update to add lbs[2] to resolved addresses. - resolver.inject([]*naming.Update{ - {Op: naming.Add, - Addr: lbAddrs[2], - Metadata: &naming.AddrMetadataGRPCLB{ - AddrType: naming.GRPCLB, - ServerName: lbsn, - }, - }, - }) - // Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used. 
- lbs[1].Stop() - for { - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { + t.Fatalf("No RPC sent to second backend after 1 second") +} + +func TestFallback(t *testing.T) { + defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + + tss, cleanup, err := newLoadBalancer(1) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + + // Start a standalone backend. + beLis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen %v", err) + } + defer beLis.Close() + standaloneBEs := startBackends(beServerName, true, beLis) + defer stopBackends(standaloneBEs) + + be := &lbmpb.Server{ + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + } + var bes []*lbmpb.Server + bes = append(bes, be) + sl := &lbmpb.ServerList{ + Servers: bes, + } + tss.ls.sls <- sl + creds := serverNameCheckCreds{ + expected: beServerName, + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilderWithFallbackTimeout(100*time.Millisecond)), + grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + r.NewAddress([]resolver.Address{{ + Addr: "", + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: beLis.Addr().String(), + Type: resolver.Backend, + ServerName: beServerName, + }}) + + var p peer.Peer + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + if p.Addr.String() != beLis.Addr().String() { + t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) + } + + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }, { + Addr: beLis.Addr().String(), + Type: resolver.Backend, + ServerName: beServerName, + }}) + + for i := 0; i < 1000; i++ { + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) - } else if trailer[testmdkey][0] != previousTrailer { - // A new backend server should receive the request. - // The trailer contains the backend address, so the trailer should be different from the previous one. 
- break } - time.Sleep(100 * time.Millisecond) + if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] { + return + } + time.Sleep(time.Millisecond) } + t.Fatalf("No RPC sent to backend behind remote balancer after 1 second") } type failPreRPCCred struct{} func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - if strings.Contains(uri[0], "failtosend") { + if strings.Contains(uri[0], failtosendURI) { return nil, fmt.Errorf("rpc should fail to send") } return nil, nil @@ -640,35 +685,46 @@ func checkStats(stats *lbmpb.ClientStats, expected *lbmpb.ClientStats) error { } func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbmpb.ClientStats { - tss, cleanup, err := newLoadBalancer(3) + defer leakcheck.Check(t) + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + + tss, cleanup, err := newLoadBalancer(1) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - tss.ls.sls = []*lbmpb.ServerList{{ + tss.ls.sls <- &lbmpb.ServerList{ Servers: []*lbmpb.Server{{ - IpAddress: tss.beIPs[2], - Port: int32(tss.bePorts[2]), + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), LoadBalanceToken: lbToken, DropForLoadBalancing: dropForLoadBalancing, DropForRateLimiting: dropForRateLimiting, }}, - }} - tss.ls.intervals = []time.Duration{0} + } tss.ls.statsDura = 100 * time.Millisecond - creds := serverNameCheckCreds{expected: besn} + creds := serverNameCheckCreds{expected: beServerName} ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cc, err := grpc.DialContext(ctx, besn, - grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{addrs: []string{tss.lbAddr}})), - grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}), - grpc.WithBlock(), grpc.WithDialer(fakeNameDialer)) + cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, + grpc.WithBalancerBuilder(grpc.NewLBBuilder()), + grpc.WithTransportCredentials(&creds), + grpc.WithPerRPCCredentials(failPreRPCCred{}), + grpc.WithDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() + r.NewAddress([]resolver.Address{{ + Addr: tss.lbAddr, + Type: resolver.GRPCLB, + ServerName: lbServerName, + }}) + runRPCs(cc) time.Sleep(1 * time.Second) tss.ls.mu.Lock() @@ -677,7 +733,11 @@ func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool return stats } -const countRPC = 40 +const ( + countRPC = 40 + failtosendURI = "failtosend" + dropErrDesc = "request dropped by grpclb" +) func TestGRPCLBStatsUnarySuccess(t *testing.T) { defer leakcheck.Check(t) @@ -709,7 +769,7 @@ func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -737,7 +797,7 @@ func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -766,7 +826,7 @@ func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) } for i := 0; i < countRPC-1; i++ { - grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, 
cc) + grpc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil, cc) } }) @@ -824,7 +884,7 @@ func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -852,7 +912,7 @@ func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { for { c++ if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { - if strings.Contains(err.Error(), "drops requests") { + if strings.Contains(err.Error(), dropErrDesc) { break } } @@ -887,7 +947,7 @@ func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) { } } for i := 0; i < countRPC-1; i++ { - grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") + grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, failtosendURI) } }) diff --git a/grpclb_picker.go b/grpclb_picker.go new file mode 100644 index 00000000..872c7cce --- /dev/null +++ b/grpclb_picker.go @@ -0,0 +1,159 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package grpc + +import ( + "sync" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" + lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages" + "google.golang.org/grpc/status" +) + +type rpcStats struct { + NumCallsStarted int64 + NumCallsFinished int64 + NumCallsFinishedWithDropForRateLimiting int64 + NumCallsFinishedWithDropForLoadBalancing int64 + NumCallsFinishedWithClientFailedToSend int64 + NumCallsFinishedKnownReceived int64 +} + +// toClientStats converts rpcStats to lbpb.ClientStats, and clears rpcStats. 
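+// The counters are read-and-reset with atomic swaps, so each count is reported
+// to the remote balancer exactly once without holding a lock.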
+func (s *rpcStats) toClientStats() *lbpb.ClientStats { + stats := &lbpb.ClientStats{ + NumCallsStarted: atomic.SwapInt64(&s.NumCallsStarted, 0), + NumCallsFinished: atomic.SwapInt64(&s.NumCallsFinished, 0), + NumCallsFinishedWithDropForRateLimiting: atomic.SwapInt64(&s.NumCallsFinishedWithDropForRateLimiting, 0), + NumCallsFinishedWithDropForLoadBalancing: atomic.SwapInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 0), + NumCallsFinishedWithClientFailedToSend: atomic.SwapInt64(&s.NumCallsFinishedWithClientFailedToSend, 0), + NumCallsFinishedKnownReceived: atomic.SwapInt64(&s.NumCallsFinishedKnownReceived, 0), + } + return stats +} + +func (s *rpcStats) dropForRateLimiting() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForRateLimiting, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) dropForLoadBalancing() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithDropForLoadBalancing, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) failedToSend() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedWithClientFailedToSend, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +func (s *rpcStats) knownReceived() { + atomic.AddInt64(&s.NumCallsStarted, 1) + atomic.AddInt64(&s.NumCallsFinishedKnownReceived, 1) + atomic.AddInt64(&s.NumCallsFinished, 1) +} + +type errPicker struct { + // Pick always returns this err. + err error +} + +func (p *errPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + return nil, nil, p.err +} + +// rrPicker does roundrobin on subConns. It's typically used when there's no +// response from remote balancer, and grpclb falls back to the resolved +// backends. +// +// It guaranteed that len(subConns) > 0. +type rrPicker struct { + mu sync.Mutex + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int +} + +func (p *rrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + sc := p.subConns[p.subConnsNext] + p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns) + return sc, nil, nil +} + +// lbPicker does two layers of picks: +// +// First layer: roundrobin on all servers in serverList, including drops and backends. +// - If it picks a drop, the RPC will fail as being dropped. +// - If it picks a backend, do a second layer pick to pick the real backend. +// +// Second layer: roundrobin on all READY backends. +// +// It's guaranteed that len(serverList) > 0. +type lbPicker struct { + mu sync.Mutex + serverList []*lbpb.Server + serverListNext int + subConns []balancer.SubConn // The subConns that were READY when taking the snapshot. + subConnsNext int + + stats *rpcStats +} + +func (p *lbPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) { + p.mu.Lock() + defer p.mu.Unlock() + + // Layer one roundrobin on serverList. + s := p.serverList[p.serverListNext] + p.serverListNext = (p.serverListNext + 1) % len(p.serverList) + + // If it's a drop, return an error and fail the RPC. 
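+	// Drops are terminal: no SubConn is consulted, and the drop is recorded in
+	// the client stats that go out with the next load report.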
+	if s.DropForRateLimiting {
+		p.stats.dropForRateLimiting()
+		return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
+	}
+	if s.DropForLoadBalancing {
+		p.stats.dropForLoadBalancing()
+		return nil, nil, status.Errorf(codes.Unavailable, "request dropped by grpclb")
+	}
+
+	// If it's not a drop but there are no ready subConns.
+	if len(p.subConns) <= 0 {
+		return nil, nil, balancer.ErrNoSubConnAvailable
+	}
+
+	// Return the next ready subConn in the list, and also collect RPC stats.
+	sc := p.subConns[p.subConnsNext]
+	p.subConnsNext = (p.subConnsNext + 1) % len(p.subConns)
+	done := func(info balancer.DoneInfo) {
+		if !info.BytesSent {
+			p.stats.failedToSend()
+		} else if info.BytesReceived {
+			p.stats.knownReceived()
+		}
+	}
+	return sc, done, nil
+}
diff --git a/grpclb_remote_balancer.go b/grpclb_remote_balancer.go
new file mode 100644
index 00000000..50d4ccd3
--- /dev/null
+++ b/grpclb_remote_balancer.go
@@ -0,0 +1,254 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package grpc
+
+import (
+	"fmt"
+	"net"
+	"reflect"
+	"time"
+
+	"golang.org/x/net/context"
+	"google.golang.org/grpc/balancer"
+	"google.golang.org/grpc/connectivity"
+	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1/messages"
+	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/resolver"
+)
+
+// processServerList updates the balancer's internal state, creates/removes
+// SubConns and regenerates the picker using the received serverList.
+func (lb *lbBalancer) processServerList(l *lbpb.ServerList) {
+	grpclog.Infof("lbBalancer: processing server list: %+v", l)
+	lb.mu.Lock()
+	defer lb.mu.Unlock()
+
+	// Set serverListReceived to true so that fallback will not take effect
+	// when the fallback timeout fires.
+	lb.serverListReceived = true
+
+	// If the new server list == old server list, do nothing.
+	if reflect.DeepEqual(lb.fullServerList, l.Servers) {
+		grpclog.Infof("lbBalancer: new serverlist same as the previous one, ignoring")
+		return
+	}
+	lb.fullServerList = l.Servers
+
+	var backendAddrs []resolver.Address
+	for _, s := range l.Servers {
+		if s.DropForLoadBalancing || s.DropForRateLimiting {
+			continue
+		}
+
+		md := metadata.Pairs(lbTokeyKey, s.LoadBalanceToken)
+		ip := net.IP(s.IpAddress)
+		ipStr := ip.String()
+		if ip.To4() == nil {
+			// Add square brackets to IPv6 addresses; otherwise net.Dial() and
+			// net.SplitHostPort() will return a "too many colons" error.
+			ipStr = fmt.Sprintf("[%s]", ipStr)
+		}
+		addr := resolver.Address{
+			Addr:     fmt.Sprintf("%s:%d", ipStr, s.Port),
+			Metadata: &md,
+		}
+
+		backendAddrs = append(backendAddrs, addr)
+	}
+
+	// Call refreshSubConns to create/remove SubConns.
+	backendsUpdated := lb.refreshSubConns(backendAddrs)
+	// If no backend was updated, no SubConn will be created or removed. But
+	// since the full serverList was different, there might be updates in drops
+	// or pick weights (different number of duplicates). We need to update the
+	// picker with the full list.
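+	// (When backends were updated, the picker is instead regenerated in
+	// HandleSubConnStateChange as the new SubConns report state changes.)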
+	if !backendsUpdated {
+		lb.regeneratePicker()
+		lb.cc.UpdateBalancerState(lb.state, lb.picker)
+	}
+}
+
+// refreshSubConns creates/removes SubConns with backendAddrs. It returns a
+// bool indicating whether the backendAddrs are different from the cached
+// backendAddrs (whether any SubConn was created or removed).
+// Caller must hold lb.mu.
+func (lb *lbBalancer) refreshSubConns(backendAddrs []resolver.Address) bool {
+	lb.backendAddrs = nil
+	var backendsUpdated bool
+	// addrsSet is the set converted from backendAddrs, used for quick address
+	// lookup.
+	addrsSet := make(map[resolver.Address]struct{})
+	// Create new SubConns.
+	for _, addr := range backendAddrs {
+		addrWithoutMD := addr
+		addrWithoutMD.Metadata = nil
+		addrsSet[addrWithoutMD] = struct{}{}
+		lb.backendAddrs = append(lb.backendAddrs, addrWithoutMD)
+
+		if _, ok := lb.subConns[addrWithoutMD]; !ok {
+			backendsUpdated = true
+
+			// Use addrWithMD to create the SubConn.
+			sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
+			if err != nil {
+				grpclog.Warningf("grpclb: failed to create new SubConn: %v", err)
+				continue
+			}
+			lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
+			lb.scStates[sc] = connectivity.Idle
+			sc.Connect()
+		}
+	}
+
+	for a, sc := range lb.subConns {
+		// a was removed by the resolver.
+		if _, ok := addrsSet[a]; !ok {
+			backendsUpdated = true
+
+			lb.cc.RemoveSubConn(sc)
+			delete(lb.subConns, a)
+			// Keep the state of this sc in lb.scStates until sc's state becomes Shutdown.
+			// The entry will be deleted in HandleSubConnStateChange.
+		}
+	}
+
+	return backendsUpdated
+}
+
+func (lb *lbBalancer) readServerList(s *balanceLoadClientStream) error {
+	for {
+		reply, err := s.Recv()
+		if err != nil {
+			return fmt.Errorf("grpclb: failed to recv server list: %v", err)
+		}
+		if serverList := reply.GetServerList(); serverList != nil {
+			lb.processServerList(serverList)
+		}
+	}
+}
+
+func (lb *lbBalancer) sendLoadReport(s *balanceLoadClientStream, interval time.Duration) {
+	ticker := time.NewTicker(interval)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+		case <-s.Context().Done():
+			return
+		}
+		stats := lb.clientStats.toClientStats()
+		t := time.Now()
+		stats.Timestamp = &lbpb.Timestamp{
+			Seconds: t.Unix(),
+			Nanos:   int32(t.Nanosecond()),
+		}
+		if err := s.Send(&lbpb.LoadBalanceRequest{
+			LoadBalanceRequestType: &lbpb.LoadBalanceRequest_ClientStats{
+				ClientStats: stats,
+			},
+		}); err != nil {
+			return
+		}
+	}
+}
+
+func (lb *lbBalancer) callRemoteBalancer() error {
+	lbClient := &loadBalancerClient{cc: lb.ccRemoteLB}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	stream, err := lbClient.BalanceLoad(ctx, FailFast(false))
+	if err != nil {
+		return fmt.Errorf("grpclb: failed to perform RPC to the remote balancer: %v", err)
+	}
+
+	// grpclb handshake on the stream.
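+	// The handshake is an InitialLoadBalanceRequest carrying the target name,
+	// answered by an InitialLoadBalanceResponse, before any server lists are
+	// sent on the stream.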
+	initReq := &lbpb.LoadBalanceRequest{
+		LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
+			InitialRequest: &lbpb.InitialLoadBalanceRequest{
+				Name: lb.target,
+			},
+		},
+	}
+	if err := stream.Send(initReq); err != nil {
+		return fmt.Errorf("grpclb: failed to send init request: %v", err)
+	}
+	reply, err := stream.Recv()
+	if err != nil {
+		return fmt.Errorf("grpclb: failed to recv init response: %v", err)
+	}
+	initResp := reply.GetInitialResponse()
+	if initResp == nil {
+		return fmt.Errorf("grpclb: reply from remote balancer did not include initial response")
+	}
+	if initResp.LoadBalancerDelegate != "" {
+		return fmt.Errorf("grpclb: delegation is not supported")
+	}
+
+	go func() {
+		if d := convertDuration(initResp.ClientStatsReportInterval); d > 0 {
+			lb.sendLoadReport(stream, d)
+		}
+	}()
+	return lb.readServerList(stream)
+}
+
+func (lb *lbBalancer) watchRemoteBalancer() {
+	for {
+		err := lb.callRemoteBalancer()
+		select {
+		case <-lb.doneCh:
+			return
+		default:
+			if err != nil {
+				grpclog.Error(err)
+			}
+		}
+	}
+}
+
+func (lb *lbBalancer) dialRemoteLB(remoteLBName string) {
+	var dopts []DialOption
+	if creds := lb.opt.DialCreds; creds != nil {
+		if err := creds.OverrideServerName(remoteLBName); err == nil {
+			dopts = append(dopts, WithTransportCredentials(creds))
+		} else {
+			grpclog.Warningf("grpclb: failed to override the server name in the credentials: %v, using Insecure", err)
+			dopts = append(dopts, WithInsecure())
+		}
+	} else {
+		dopts = append(dopts, WithInsecure())
+	}
+	if lb.opt.Dialer != nil {
+		// WithDialer takes a different type of function, so we instead use a
+		// special DialOption here.
+		dopts = append(dopts, withContextDialer(lb.opt.Dialer))
+	}
+	// Explicitly set pickfirst as the balancer.
+	dopts = append(dopts, WithBalancerBuilder(newPickfirstBuilder()))
+	dopts = append(dopts, withResolverBuilder(lb.manualResolver))
+	// Dial using manualResolver.Scheme, a random scheme generated when grpclb
+	// was initialized. The target name is not important.
+	cc, err := Dial("grpclb:///grpclb.server", dopts...)
+	if err != nil {
+		grpclog.Fatalf("failed to dial: %v", err)
+	}
+	lb.ccRemoteLB = cc
+	go lb.watchRemoteBalancer()
+}
diff --git a/picker_wrapper.go b/picker_wrapper.go
index 9085dbc9..db82bfb3 100644
--- a/picker_wrapper.go
+++ b/picker_wrapper.go
@@ -97,7 +97,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
		p = bp.picker
		bp.mu.Unlock()
-		subConn, put, err := p.Pick(ctx, opts)
+		subConn, done, err := p.Pick(ctx, opts)
		if err != nil {
			switch err {
@@ -120,7 +120,7 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
				continue
			}
			if t, ok := acw.getAddrConn().getReadyTransport(); ok {
-				return t, put, nil
+				return t, done, nil
			}
			grpclog.Infof("blockingPicker: the picked transport is not ready, loop back to repick")
			// If ok == false, ac.state is not READY.
diff --git a/resolver/manual/manual.go b/resolver/manual/manual.go
index de2fb25b..7a790903 100644
--- a/resolver/manual/manual.go
+++ b/resolver/manual/manual.go
@@ -16,7 +16,8 @@
 *
 */

-// Package manual contains a resolver for testing purpose only.
+// Package manual defines a resolver that can be used to manually send resolved
+// addresses to ClientConn.
package manual

import (
diff --git a/resolver/resolver.go b/resolver/resolver.go
index 0dd887fa..e27db82f 100644
--- a/resolver/resolver.go
+++ b/resolver/resolver.go
@@ -78,7 +78,9 @@ type Address struct {
	// Type is the type of this address.
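+	// Possible values are this package's AddressType constants (Backend for a
+	// regular server address, GRPCLB for a load balancer address).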
	Type AddressType
	// ServerName is the name of this address.
-	// It's the name of the grpc load balancer, which will be used for authentication.
+	//
+	// e.g. if Type is GRPCLB, ServerName should be the name of the remote load
+	// balancer, not the name of the backend.
	ServerName string
	// Metadata is the information associated with Addr, which may be used
	// to make load balancing decision.
diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go
index c07e174a..e285ad9b 100644
--- a/resolver_conn_wrapper.go
+++ b/resolver_conn_wrapper.go
@@ -61,12 +61,18 @@ func parseTarget(target string) (ret resolver.Target) {
// newCCResolverWrapper parses cc.target for scheme and gets the resolver
// builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it.
+//
+// If the withResolverBuilder dial option is set, the specified resolver will
+// be used instead.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
	grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme)
-	rb := resolver.Get(cc.parsedTarget.Scheme)
+	rb := cc.dopts.resolverBuilder
	if rb == nil {
-		return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
+		rb = resolver.Get(cc.parsedTarget.Scheme)
+		if rb == nil {
+			return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
+		}
	}

	ccr := &ccResolverWrapper{
diff --git a/rpc_util.go b/rpc_util.go
index 07d03a85..6ab195c2 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -441,9 +441,7 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
}

type rpcInfo struct {
-	failfast      bool
-	bytesSent     bool
-	bytesReceived bool
+	failfast bool
}

type rpcInfoContextKey struct{}
@@ -457,14 +455,6 @@ func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {
	return
}

-func updateRPCInfoInContext(ctx context.Context, s rpcInfo) {
-	if ss, ok := rpcInfoFromContext(ctx); ok {
-		ss.bytesReceived = s.bytesReceived
-		ss.bytesSent = s.bytesSent
-	}
-	return
-}
-
// Code returns the error code for err if it was produced by the rpc system.
// Otherwise, it returns codes.Unknown.
//
diff --git a/stream.go b/stream.go
index 9eeaafef..f2ded47d 100644
--- a/stream.go
+++ b/stream.go
@@ -232,7 +232,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
	s, err = t.NewStream(ctx, callHdr)
	if err != nil {
		if done != nil {
-			done(balancer.DoneInfo{Err: err})
+			doneInfo := balancer.DoneInfo{Err: err}
+			if _, ok := err.(transport.ConnectionError); ok {
+				// If the error is a connection error, the transport was
+				// sending data on the wire, and we are not sure whether
+				// anything was actually sent. If it is not a connection
+				// error, we are sure nothing reached the wire.
+				doneInfo.BytesSent = true
+			}
+			done(doneInfo)
			done = nil
		}
		// In the event of any error from NewStream, we never attempted to write
@@ -529,11 +536,11 @@ func (cs *clientStream) finish(err error) {
		o.after(cs.c)
	}
	if cs.done != nil {
-		updateRPCInfoInContext(cs.s.Context(), rpcInfo{
-			bytesSent:     true,
-			bytesReceived: cs.s.BytesReceived(),
+		cs.done(balancer.DoneInfo{
+			Err:           err,
+			BytesSent:     true,
+			BytesReceived: cs.s.BytesReceived(),
		})
-		cs.done(balancer.DoneInfo{Err: err})
		cs.done = nil
	}
	if cs.statsHandler != nil {