testutils: Add a context parameter to the Receive() method. (#3835)

This commit is contained in:
Easwar Swaminathan 2020-08-27 13:55:15 -07:00 коммит произвёл GitHub
Родитель 35afeb6efe
Коммит d25c71b543
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
24 изменённых файлов: 674 добавлений и 455 удалений

Просмотреть файл

@ -19,6 +19,7 @@
package rls package rls
import ( import (
"context"
"net" "net"
"testing" "testing"
"time" "time"
@ -32,6 +33,8 @@ import (
"google.golang.org/grpc/testdata" "google.golang.org/grpc/testdata"
) )
const defaultTestTimeout = 1 * time.Second
type s struct { type s struct {
grpctest.Tester grpctest.Tester
} }
@ -99,7 +102,9 @@ func (s) TestUpdateControlChannelFirstConfig(t *testing.T) {
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel") t.Fatal("Timeout expired when waiting for LB policy to create control channel")
} }
@ -132,7 +137,9 @@ func (s) TestUpdateControlChannelSwitch(t *testing.T) {
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis1.connCh.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis1.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel") t.Fatal("Timeout expired when waiting for LB policy to create control channel")
} }
@ -140,7 +147,7 @@ func (s) TestUpdateControlChannelSwitch(t *testing.T) {
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis2.connCh.Receive(); err != nil { if _, err := lis2.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel") t.Fatal("Timeout expired when waiting for LB policy to create control channel")
} }
@ -169,14 +176,17 @@ func (s) TestUpdateControlChannelTimeout(t *testing.T) {
lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second} lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel") t.Fatal("Timeout expired when waiting for LB policy to create control channel")
} }
lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second} lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second}
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != testutils.ErrRecvTimeout { if _, err := lis.connCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed") t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed")
} }
@ -215,7 +225,9 @@ func (s) TestUpdateControlChannelWithCreds(t *testing.T) {
t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg)
rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg})
if _, err := lis.connCh.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := lis.connCh.Receive(ctx); err != nil {
t.Fatal("Timeout expired when waiting for LB policy to create control channel") t.Fatal("Timeout expired when waiting for LB policy to create control channel")
} }

Просмотреть файл

@ -19,6 +19,7 @@
package rls package rls
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
@ -82,7 +83,9 @@ func (s) TestLookupFailure(t *testing.T) {
errCh.Send(nil) errCh.Send(nil)
}) })
if e, err := errCh.Receive(); err != nil || e != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
} }
} }
@ -106,7 +109,9 @@ func (s) TestLookupDeadlineExceeded(t *testing.T) {
errCh.Send(nil) errCh.Send(nil)
}) })
if e, err := errCh.Receive(); err != nil || e != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
} }
} }
@ -150,7 +155,9 @@ func (s) TestLookupSuccess(t *testing.T) {
// Make sure that the fake server received the expected RouteLookupRequest // Make sure that the fake server received the expected RouteLookupRequest
// proto. // proto.
req, err := server.RequestChan.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
req, err := server.RequestChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timed out wile waiting for a RouteLookupRequest") t.Fatalf("Timed out wile waiting for a RouteLookupRequest")
} }
@ -168,7 +175,7 @@ func (s) TestLookupSuccess(t *testing.T) {
}, },
} }
if e, err := errCh.Receive(); err != nil || e != nil { if e, err := errCh.Receive(ctx); err != nil || e != nil {
t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err)
} }
} }

Просмотреть файл

@ -260,7 +260,9 @@ func TestPick_DataCacheMiss_PendingCacheMiss(t *testing.T) {
// If the test specified that a new RLS request should be made, // If the test specified that a new RLS request should be made,
// verify it. // verify it.
if test.wantRLSRequest { if test.wantRLSRequest {
if rlsErr, err := rlsCh.Receive(); err != nil || rlsErr != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err) t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
} }
} }
@ -339,7 +341,9 @@ func TestPick_DataCacheMiss_PendingCacheHit(t *testing.T) {
} }
// Make sure that no RLS request was sent out. // Make sure that no RLS request was sent out.
if _, err := rlsCh.Receive(); err != testutils.ErrRecvTimeout { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("RLS request sent out when pending entry exists") t.Fatalf("RLS request sent out when pending entry exists")
} }
}) })
@ -483,7 +487,9 @@ func TestPick_DataCacheHit_PendingCacheMiss(t *testing.T) {
// If the test specified that a new RLS request should be made, // If the test specified that a new RLS request should be made,
// verify it. // verify it.
if test.wantRLSRequest { if test.wantRLSRequest {
if rlsErr, err := rlsCh.Receive(); err != nil || rlsErr != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if rlsErr, err := rlsCh.Receive(ctx); err != nil || rlsErr != nil {
t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err) t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err)
} }
} }
@ -590,7 +596,9 @@ func TestPick_DataCacheHit_PendingCacheHit(t *testing.T) {
t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr) t.Fatalf("Pick() returned error {%v}, want {%v}", err, test.wantErr)
} }
// Make sure that no RLS request was sent out. // Make sure that no RLS request was sent out.
if _, err := rlsCh.Receive(); err != testutils.ErrRecvTimeout { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := rlsCh.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("RLS request sent out when pending entry exists") t.Fatalf("RLS request sent out when pending entry exists")
} }
if test.wantErr != nil { if test.wantErr != nil {

Просмотреть файл

@ -55,6 +55,7 @@ const (
exampleResource = "https://backend.example.com/api" exampleResource = "https://backend.example.com/api"
exampleAudience = "example-backend-service" exampleAudience = "example-backend-service"
testScope = "https://www.googleapis.com/auth/monitoring" testScope = "https://www.googleapis.com/auth/monitoring"
defaultTestTimeout = 1 * time.Second
) )
var ( var (
@ -142,7 +143,11 @@ type fakeHTTPDoer struct {
func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) { func (fc *fakeHTTPDoer) Do(req *http.Request) (*http.Response, error) {
fc.reqCh.Send(req) fc.reqCh.Send(req)
val, err := fc.respCh.Receive()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := fc.respCh.Receive(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -240,7 +245,10 @@ func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters)
// by the tests. So, any errors encountered are pushed to an error channel // by the tests. So, any errors encountered are pushed to an error channel
// which is monitored by the test. // which is monitored by the test.
func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) { func receiveAndCompareRequest(reqCh *testutils.Channel, errCh chan error) {
val, err := reqCh.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := reqCh.Receive(ctx)
if err != nil { if err != nil {
errCh <- err errCh <- err
return return
@ -430,7 +438,10 @@ func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
if _, err := fc.reqCh.Receive(); err != testutils.ErrRecvTimeout { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fc.reqCh.Receive(ctx); err != context.DeadlineExceeded {
errCh <- err errCh <- err
return return
} }

Просмотреть файл

@ -179,7 +179,9 @@ func (s) TestStoreSingleProvider(t *testing.T) {
// Our fakeProviderBuilder pushes newly created providers on a channel. Grab // Our fakeProviderBuilder pushes newly created providers on a channel. Grab
// the fake provider from that channel. // the fake provider from that channel.
p, err := fpb1.providerChan.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
p, err := fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -188,8 +190,6 @@ func (s) TestStoreSingleProvider(t *testing.T) {
// Attempt to read from key material from the Provider returned by the // Attempt to read from key material from the Provider returned by the
// store. This will fail because we have not pushed any key material into // store. This will fail because we have not pushed any key material into
// our fake provider. // our fake provider.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, prov, nil); !errors.Is(err, context.DeadlineExceeded) { if err := readAndVerifyKeyMaterial(ctx, prov, nil); !errors.Is(err, context.DeadlineExceeded) {
t.Fatal(err) t.Fatal(err)
} }
@ -208,8 +208,6 @@ func (s) TestStoreSingleProvider(t *testing.T) {
// updated key material. // updated key material.
testKM2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem") testKM2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
fakeProv.newKeyMaterial(testKM2, nil) fakeProv.newKeyMaterial(testKM2, nil)
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, prov, testKM2); err != nil { if err := readAndVerifyKeyMaterial(ctx, prov, testKM2); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -236,7 +234,9 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) {
defer provFoo2.Close() defer provFoo2.Close()
// Our fakeProviderBuilder pushes newly created providers on a channel. // Our fakeProviderBuilder pushes newly created providers on a channel.
// Grab the fake provider for optsFoo. // Grab the fake provider for optsFoo.
p, err := fpb1.providerChan.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
p, err := fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -248,7 +248,7 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) {
} }
defer provBar1.Close() defer provBar1.Close()
// Grab the fake provider for optsBar. // Grab the fake provider for optsBar.
p, err = fpb1.providerChan.Receive() p, err = fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -258,8 +258,6 @@ func (s) TestStoreSingleProviderSameConfigDifferentOpts(t *testing.T) {
// appropriate key material and the bar provider times out. // appropriate key material and the bar provider times out.
fooKM := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") fooKM := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
fakeProvFoo.newKeyMaterial(fooKM, nil) fakeProvFoo.newKeyMaterial(fooKM, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, provFoo1, fooKM); err != nil { if err := readAndVerifyKeyMaterial(ctx, provFoo1, fooKM); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -302,7 +300,9 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) {
defer prov1.Close() defer prov1.Close()
// Our fakeProviderBuilder pushes newly created providers on a channel. Grab // Our fakeProviderBuilder pushes newly created providers on a channel. Grab
// the fake provider from that channel. // the fake provider from that channel.
p1, err := fpb1.providerChan.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
p1, err := fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -314,7 +314,7 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) {
} }
defer prov2.Close() defer prov2.Close()
// Grab the second provider from the channel. // Grab the second provider from the channel.
p2, err := fpb1.providerChan.Receive() p2, err := fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -325,8 +325,6 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) {
km1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") km1 := loadKeyMaterials(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
fakeProv1.newKeyMaterial(km1, nil) fakeProv1.newKeyMaterial(km1, nil)
fakeProv2.newKeyMaterial(km1, nil) fakeProv2.newKeyMaterial(km1, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil { if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -339,8 +337,6 @@ func (s) TestStoreSingleProviderDifferentConfigs(t *testing.T) {
// material. // material.
km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem") km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
fakeProv2.newKeyMaterial(km2, nil) fakeProv2.newKeyMaterial(km2, nil)
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil { if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -366,7 +362,9 @@ func (s) TestStoreMultipleProviders(t *testing.T) {
defer prov1.Close() defer prov1.Close()
// Our fakeProviderBuilder pushes newly created providers on a channel. Grab // Our fakeProviderBuilder pushes newly created providers on a channel. Grab
// the fake provider from that channel. // the fake provider from that channel.
p1, err := fpb1.providerChan.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
p1, err := fpb1.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider1Name)
} }
@ -378,7 +376,7 @@ func (s) TestStoreMultipleProviders(t *testing.T) {
} }
defer prov2.Close() defer prov2.Close()
// Grab the second provider from the channel. // Grab the second provider from the channel.
p2, err := fpb2.providerChan.Receive() p2, err := fpb2.providerChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider2Name) t.Fatalf("Timeout when expecting certProvider %q to be created", fakeProvider2Name)
} }
@ -390,8 +388,6 @@ func (s) TestStoreMultipleProviders(t *testing.T) {
fakeProv1.newKeyMaterial(km1, nil) fakeProv1.newKeyMaterial(km1, nil)
km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem") km2 := loadKeyMaterials(t, "x509/server2_cert.pem", "x509/server2_key.pem", "x509/client_ca_cert.pem")
fakeProv2.newKeyMaterial(km2, nil) fakeProv2.newKeyMaterial(km2, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil { if err := readAndVerifyKeyMaterial(ctx, prov1, km1); err != nil {
t.Fatal(err) t.Fatal(err)
} }

Просмотреть файл

@ -18,21 +18,11 @@
package testutils package testutils
import ( import (
"errors" "context"
"time"
) )
// ErrRecvTimeout is an error to indicate that a receive operation on the // DefaultChanBufferSize is the default buffer size of the underlying channel.
// channel timed out. const DefaultChanBufferSize = 1
var ErrRecvTimeout = errors.New("timed out when waiting for value on channel")
const (
// DefaultChanRecvTimeout is the default timeout for receive operations on the
// underlying channel.
DefaultChanRecvTimeout = 1 * time.Second
// DefaultChanBufferSize is the default buffer size of the underlying channel.
DefaultChanBufferSize = 1
)
// Channel wraps a generic channel and provides a timed receive operation. // Channel wraps a generic channel and provides a timed receive operation.
type Channel struct { type Channel struct {
@ -44,19 +34,32 @@ func (cwt *Channel) Send(value interface{}) {
cwt.ch <- value cwt.ch <- value
} }
// Receive returns the value received on the underlying channel, or // Receive returns the value received on the underlying channel, or the error
// ErrRecvTimeout if DefaultChanRecvTimeout amount of time elapses. // returned by ctx if it is closed or cancelled.
func (cwt *Channel) Receive() (interface{}, error) { func (cwt *Channel) Receive(ctx context.Context) (interface{}, error) {
timer := time.NewTimer(DefaultChanRecvTimeout)
select { select {
case <-timer.C: case <-ctx.Done():
return nil, ErrRecvTimeout return nil, ctx.Err()
case got := <-cwt.ch: case got := <-cwt.ch:
timer.Stop()
return got, nil return got, nil
} }
} }
// Replace clears the value on the underlying channel, and sends the new value.
//
// It's expected to be used with a size-1 channel, to only keep the most
// up-to-date item. This method is inherently racy when invoked concurrently
// from multiple goroutines.
func (cwt *Channel) Replace(value interface{}) {
for {
select {
case cwt.ch <- value:
return
case <-cwt.ch:
}
}
}
// NewChannel returns a new Channel. // NewChannel returns a new Channel.
func NewChannel() *Channel { func NewChannel() *Channel {
return NewChannelWithSize(DefaultChanBufferSize) return NewChannelWithSize(DefaultChanBufferSize)

Просмотреть файл

@ -17,6 +17,7 @@
package cdsbalancer package cdsbalancer
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -30,19 +31,19 @@ import (
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
xdsinternal "google.golang.org/grpc/xds/internal" xdsinternal "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/balancer/edsbalancer" "google.golang.org/grpc/xds/internal/balancer/edsbalancer"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeclient" "google.golang.org/grpc/xds/internal/testutils/fakeclient"
) )
const ( const (
clusterName = "cluster1" clusterName = "cluster1"
serviceName = "service1" serviceName = "service1"
defaultTestTimeout = 2 * time.Second defaultTestTimeout = 1 * time.Second
) )
type s struct { type s struct {
@ -90,13 +91,13 @@ func invokeWatchCbAndWait(xdsC *fakeclient.Client, cdsW cdsWatchInfo, wantCCS ba
// to the test. // to the test.
type testEDSBalancer struct { type testEDSBalancer struct {
// ccsCh is a channel used to signal the receipt of a ClientConn update. // ccsCh is a channel used to signal the receipt of a ClientConn update.
ccsCh chan balancer.ClientConnState ccsCh *testutils.Channel
// scStateCh is a channel used to signal the receipt of a SubConn update. // scStateCh is a channel used to signal the receipt of a SubConn update.
scStateCh chan subConnWithState scStateCh *testutils.Channel
// resolverErrCh is a channel used to signal a resolver error. // resolverErrCh is a channel used to signal a resolver error.
resolverErrCh chan error resolverErrCh *testutils.Channel
// closeCh is a channel used to signal the closing of this balancer. // closeCh is a channel used to signal the closing of this balancer.
closeCh chan struct{} closeCh *testutils.Channel
} }
type subConnWithState struct { type subConnWithState struct {
@ -106,89 +107,86 @@ type subConnWithState struct {
func newTestEDSBalancer() *testEDSBalancer { func newTestEDSBalancer() *testEDSBalancer {
return &testEDSBalancer{ return &testEDSBalancer{
ccsCh: make(chan balancer.ClientConnState, 1), ccsCh: testutils.NewChannel(),
scStateCh: make(chan subConnWithState, 1), scStateCh: testutils.NewChannel(),
resolverErrCh: make(chan error, 1), resolverErrCh: testutils.NewChannel(),
closeCh: make(chan struct{}, 1), closeCh: testutils.NewChannel(),
} }
} }
func (tb *testEDSBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { func (tb *testEDSBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
tb.ccsCh <- ccs tb.ccsCh.Send(ccs)
return nil return nil
} }
func (tb *testEDSBalancer) ResolverError(err error) { func (tb *testEDSBalancer) ResolverError(err error) {
tb.resolverErrCh <- err tb.resolverErrCh.Send(err)
} }
func (tb *testEDSBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { func (tb *testEDSBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
tb.scStateCh <- subConnWithState{sc: sc, state: state} tb.scStateCh.Send(subConnWithState{sc: sc, state: state})
} }
func (tb *testEDSBalancer) Close() { func (tb *testEDSBalancer) Close() {
tb.closeCh <- struct{}{} tb.closeCh.Send(struct{}{})
} }
// waitForClientConnUpdate verifies if the testEDSBalancer receives the // waitForClientConnUpdate verifies if the testEDSBalancer receives the
// provided ClientConnState within a reasonable amount of time. // provided ClientConnState within a reasonable amount of time.
func (tb *testEDSBalancer) waitForClientConnUpdate(wantCCS balancer.ClientConnState) error { func (tb *testEDSBalancer) waitForClientConnUpdate(wantCCS balancer.ClientConnState) error {
timer := time.NewTimer(defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
select { defer cancel()
case <-timer.C: ccs, err := tb.ccsCh.Receive(ctx)
return errors.New("Timeout when expecting ClientConn update on EDS balancer") if err != nil {
case gotCCS := <-tb.ccsCh: return err
timer.Stop()
if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) {
return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS)
}
return nil
} }
gotCCS := ccs.(balancer.ClientConnState)
if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) {
return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS)
}
return nil
} }
// waitForSubConnUpdate verifies if the testEDSBalancer receives the provided // waitForSubConnUpdate verifies if the testEDSBalancer receives the provided
// SubConn update within a reasonable amount of time. // SubConn update within a reasonable amount of time.
func (tb *testEDSBalancer) waitForSubConnUpdate(wantSCS subConnWithState) error { func (tb *testEDSBalancer) waitForSubConnUpdate(wantSCS subConnWithState) error {
timer := time.NewTimer(defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
select { defer cancel()
case <-timer.C: scs, err := tb.scStateCh.Receive(ctx)
return errors.New("Timeout when expecting SubConn update on EDS balancer") if err != nil {
case gotSCS := <-tb.scStateCh: return err
timer.Stop()
if !cmp.Equal(gotSCS, wantSCS, cmp.AllowUnexported(subConnWithState{})) {
return fmt.Errorf("received SubConnState: %+v, want %+v", gotSCS, wantSCS)
}
return nil
} }
gotSCS := scs.(subConnWithState)
if !cmp.Equal(gotSCS, wantSCS, cmp.AllowUnexported(subConnWithState{})) {
return fmt.Errorf("received SubConnState: %+v, want %+v", gotSCS, wantSCS)
}
return nil
} }
// waitForResolverError verifies if the testEDSBalancer receives the // waitForResolverError verifies if the testEDSBalancer receives the
// provided resolver error within a reasonable amount of time. // provided resolver error within a reasonable amount of time.
func (tb *testEDSBalancer) waitForResolverError(wantErr error) error { func (tb *testEDSBalancer) waitForResolverError(wantErr error) error {
timer := time.NewTimer(defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
select { defer cancel()
case <-timer.C: gotErr, err := tb.resolverErrCh.Receive(ctx)
return errors.New("Timeout when expecting a resolver error") if err != nil {
case gotErr := <-tb.resolverErrCh: return err
timer.Stop()
if gotErr != wantErr {
return fmt.Errorf("received resolver error: %v, want %v", gotErr, wantErr)
}
return nil
} }
if gotErr != wantErr {
return fmt.Errorf("received resolver error: %v, want %v", gotErr, wantErr)
}
return nil
} }
// waitForClose verifies that the edsBalancer is closed with a reasonable // waitForClose verifies that the edsBalancer is closed with a reasonable
// amount of time. // amount of time.
func (tb *testEDSBalancer) waitForClose() error { func (tb *testEDSBalancer) waitForClose() error {
timer := time.NewTimer(defaultTestTimeout) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
select { defer cancel()
case <-timer.C: if _, err := tb.closeCh.Receive(ctx); err != nil {
return errors.New("Timeout when expecting a close") return err
case <-tb.closeCh:
timer.Stop()
return nil
} }
return nil
} }
// cdsCCS is a helper function to construct a good update passed from the // cdsCCS is a helper function to construct a good update passed from the
@ -254,7 +252,10 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal
if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil {
t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err)
} }
gotCluster, err := xdsC.WaitForWatchCluster()
ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
gotCluster, err := xdsC.WaitForWatchCluster(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) t.Fatalf("xdsClient.WatchCDS failed with error: %v", err)
} }
@ -328,7 +329,9 @@ func (s) TestUpdateClientConnState(t *testing.T) {
// When we wanted an error and got it, we should return early. // When we wanted an error and got it, we should return early.
return return
} }
gotCluster, err := xdsC.WaitForWatchCluster() ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
gotCluster, err := xdsC.WaitForWatchCluster(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchCDS failed with error: %v", err) t.Fatalf("xdsClient.WatchCDS failed with error: %v", err)
} }
@ -364,7 +367,9 @@ func (s) TestUpdateClientConnStateWithSameState(t *testing.T) {
if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil {
t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err)
} }
if _, err := xdsC.WaitForWatchCluster(); err != testutils.ErrRecvTimeout { ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if _, err := xdsC.WaitForWatchCluster(ctx); err != context.DeadlineExceeded {
t.Fatalf("waiting for WatchCluster() should have timed out, but returned error: %v", err) t.Fatalf("waiting for WatchCluster() should have timed out, but returned error: %v", err)
} }
} }
@ -422,13 +427,17 @@ func (s) TestHandleClusterUpdateError(t *testing.T) {
// And this is not a resource not found error, watch shouldn't be canceled. // And this is not a resource not found error, watch shouldn't be canceled.
err1 := errors.New("cdsBalancer resolver error 1") err1 := errors.New("cdsBalancer resolver error 1")
xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, err1) xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, err1)
if err := xdsC.WaitForCancelClusterWatch(); err == nil { ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsB.waitForResolverError(err1); err == nil { if err := edsB.waitForResolverError(err1); err == nil {
t.Fatal("eds balancer shouldn't get error (shouldn't be built yet)") t.Fatal("eds balancer shouldn't get error (shouldn't be built yet)")
} }
state, err := tcc.newPickerCh.Receive() ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
state, err := tcc.newPickerCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get picker, expect an error picker") t.Fatalf("failed to get picker, expect an error picker")
} }
@ -447,7 +456,9 @@ func (s) TestHandleClusterUpdateError(t *testing.T) {
// is not a resource not found error, watch shouldn't be canceled // is not a resource not found error, watch shouldn't be canceled
err2 := errors.New("cdsBalancer resolver error 2") err2 := errors.New("cdsBalancer resolver error 2")
xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, err2) xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, err2)
if err := xdsC.WaitForCancelClusterWatch(); err == nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsB.waitForResolverError(err2); err != nil { if err := edsB.waitForResolverError(err2); err != nil {
@ -458,7 +469,9 @@ func (s) TestHandleClusterUpdateError(t *testing.T) {
// means CDS resource is removed, and eds should receive the error. // means CDS resource is removed, and eds should receive the error.
resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error") resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error")
xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, resourceErr) xdsC.InvokeWatchClusterCallback(xdsclient.ClusterUpdate{}, resourceErr)
if err := xdsC.WaitForCancelClusterWatch(); err == nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err == nil {
t.Fatalf("want watch to be not canceled, watchForCancel should timeout") t.Fatalf("want watch to be not canceled, watchForCancel should timeout")
} }
if err := edsB.waitForResolverError(resourceErr); err != nil { if err := edsB.waitForResolverError(resourceErr); err != nil {
@ -479,13 +492,17 @@ func (s) TestResolverError(t *testing.T) {
// Not a resource not found error, watch shouldn't be canceled. // Not a resource not found error, watch shouldn't be canceled.
err1 := errors.New("cdsBalancer resolver error 1") err1 := errors.New("cdsBalancer resolver error 1")
cdsB.ResolverError(err1) cdsB.ResolverError(err1)
if err := xdsC.WaitForCancelClusterWatch(); err == nil { ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsB.waitForResolverError(err1); err == nil { if err := edsB.waitForResolverError(err1); err == nil {
t.Fatal("eds balancer shouldn't get error (shouldn't be built yet)") t.Fatal("eds balancer shouldn't get error (shouldn't be built yet)")
} }
state, err := tcc.newPickerCh.Receive() ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
state, err := tcc.newPickerCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get picker, expect an error picker") t.Fatalf("failed to get picker, expect an error picker")
} }
@ -504,7 +521,9 @@ func (s) TestResolverError(t *testing.T) {
// should receive the error. // should receive the error.
err2 := errors.New("cdsBalancer resolver error 2") err2 := errors.New("cdsBalancer resolver error 2")
cdsB.ResolverError(err2) cdsB.ResolverError(err2)
if err := xdsC.WaitForCancelClusterWatch(); err == nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsB.waitForResolverError(err2); err != nil { if err := edsB.waitForResolverError(err2); err != nil {
@ -515,7 +534,9 @@ func (s) TestResolverError(t *testing.T) {
// receive the error. // receive the error.
resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error") resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "cdsBalancer resource not found error")
cdsB.ResolverError(resourceErr) cdsB.ResolverError(resourceErr)
if err := xdsC.WaitForCancelClusterWatch(); err != nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil {
t.Fatalf("want watch to be canceled, watchForCancel failed: %v", err) t.Fatalf("want watch to be canceled, watchForCancel failed: %v", err)
} }
if err := edsB.waitForResolverError(resourceErr); err != nil { if err := edsB.waitForResolverError(resourceErr); err != nil {
@ -559,7 +580,9 @@ func (s) TestClose(t *testing.T) {
} }
cdsB.Close() cdsB.Close()
if err := xdsC.WaitForCancelClusterWatch(); err != nil { ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelClusterWatch(ctx); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := edsB.waitForClose(); err != nil { if err := edsB.waitForClose(); err != nil {

Просмотреть файл

@ -20,10 +20,12 @@ package edsbalancer
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
wrapperspb "github.com/golang/protobuf/ptypes/wrappers" wrapperspb "github.com/golang/protobuf/ptypes/wrappers"
@ -34,17 +36,20 @@ import (
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
scpb "google.golang.org/grpc/internal/proto/grpc_service_config" scpb "google.golang.org/grpc/internal/proto/grpc_service_config"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/xds/internal/balancer/lrs" "google.golang.org/grpc/xds/internal/balancer/lrs"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/bootstrap"
"google.golang.org/grpc/xds/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeclient" "google.golang.org/grpc/xds/internal/testutils/fakeclient"
_ "google.golang.org/grpc/xds/internal/client/v2" // V2 client registration. _ "google.golang.org/grpc/xds/internal/client/v2" // V2 client registration.
) )
const defaultTestTimeout = 1 * time.Second
func init() { func init() {
balancer.Register(&edsBalancerBuilder{}) balancer.Register(&edsBalancerBuilder{})
@ -52,7 +57,7 @@ func init() {
return &bootstrap.Config{ return &bootstrap.Config{
BalancerName: testBalancerNameFooBar, BalancerName: testBalancerNameFooBar,
Creds: grpc.WithInsecure(), Creds: grpc.WithInsecure(),
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
}, nil }, nil
} }
} }
@ -120,7 +125,10 @@ func (f *fakeEDSBalancer) updateState(priority priorityType, s balancer.State) {
func (f *fakeEDSBalancer) close() {} func (f *fakeEDSBalancer) close() {}
func (f *fakeEDSBalancer) waitForChildPolicy(wantPolicy *loadBalancingConfig) error { func (f *fakeEDSBalancer) waitForChildPolicy(wantPolicy *loadBalancingConfig) error {
val, err := f.childPolicy.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := f.childPolicy.Receive(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error waiting for childPolicy: %v", err) return fmt.Errorf("error waiting for childPolicy: %v", err)
} }
@ -132,7 +140,10 @@ func (f *fakeEDSBalancer) waitForChildPolicy(wantPolicy *loadBalancingConfig) er
} }
func (f *fakeEDSBalancer) waitForSubConnStateChange(wantState *scStateChange) error { func (f *fakeEDSBalancer) waitForSubConnStateChange(wantState *scStateChange) error {
val, err := f.subconnStateChange.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := f.subconnStateChange.Receive(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error waiting for subconnStateChange: %v", err) return fmt.Errorf("error waiting for subconnStateChange: %v", err)
} }
@ -144,7 +155,10 @@ func (f *fakeEDSBalancer) waitForSubConnStateChange(wantState *scStateChange) er
} }
func (f *fakeEDSBalancer) waitForEDSResponse(wantUpdate xdsclient.EndpointsUpdate) error { func (f *fakeEDSBalancer) waitForEDSResponse(wantUpdate xdsclient.EndpointsUpdate) error {
val, err := f.edsUpdate.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := f.edsUpdate.Receive(ctx)
if err != nil { if err != nil {
return fmt.Errorf("error waiting for edsUpdate: %v", err) return fmt.Errorf("error waiting for edsUpdate: %v", err)
} }
@ -176,7 +190,9 @@ func (*fakeSubConn) Connect() { panic("implement me")
func waitForNewXDSClientWithEDSWatch(t *testing.T, ch *testutils.Channel, wantName string) *fakeclient.Client { func waitForNewXDSClientWithEDSWatch(t *testing.T, ch *testutils.Channel, wantName string) *fakeclient.Client {
t.Helper() t.Helper()
val, err := ch.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := ch.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("error when waiting for a new xds client: %v", err) t.Fatalf("error when waiting for a new xds client: %v", err)
return nil return nil
@ -186,7 +202,7 @@ func waitForNewXDSClientWithEDSWatch(t *testing.T, ch *testutils.Channel, wantNa
t.Fatalf("xdsClient created to balancer: %v, want %v", xdsC.Name(), wantName) t.Fatalf("xdsClient created to balancer: %v, want %v", xdsC.Name(), wantName)
return nil return nil
} }
_, err = xdsC.WaitForWatchEDS() _, err = xdsC.WaitForWatchEDS(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
return nil return nil
@ -199,7 +215,9 @@ func waitForNewXDSClientWithEDSWatch(t *testing.T, ch *testutils.Channel, wantNa
func waitForNewEDSLB(t *testing.T, ch *testutils.Channel) *fakeEDSBalancer { func waitForNewEDSLB(t *testing.T, ch *testutils.Channel) *fakeEDSBalancer {
t.Helper() t.Helper()
val, err := ch.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := ch.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("error when waiting for a new edsLB: %v", err) t.Fatalf("error when waiting for a new edsLB: %v", err)
return nil return nil
@ -439,12 +457,14 @@ func (s) TestErrorFromXDSClientUpdate(t *testing.T) {
} }
defer edsB.Close() defer edsB.Close()
edsB.UpdateClientConnState(balancer.ClientConnState{ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
BalancerConfig: &EDSConfig{ BalancerConfig: &EDSConfig{
BalancerName: testBalancerNameFooBar, BalancerName: testBalancerNameFooBar,
EDSServiceName: testEDSClusterName, EDSServiceName: testEDSClusterName,
}, },
}) }); err != nil {
t.Fatal(err)
}
xdsC := waitForNewXDSClientWithEDSWatch(t, xdsClientCh, testBalancerNameFooBar) xdsC := waitForNewXDSClientWithEDSWatch(t, xdsClientCh, testBalancerNameFooBar)
xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil)
@ -455,7 +475,10 @@ func (s) TestErrorFromXDSClientUpdate(t *testing.T) {
connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error")
xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, connectionErr) xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, connectionErr)
if err := xdsC.WaitForCancelEDSWatch(); err == nil {
ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelEDSWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err == nil { if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err == nil {
@ -467,7 +490,9 @@ func (s) TestErrorFromXDSClientUpdate(t *testing.T) {
// Even if error is resource not found, watch shouldn't be canceled, because // Even if error is resource not found, watch shouldn't be canceled, because
// this is an EDS resource removed (and xds client actually never sends this // this is an EDS resource removed (and xds client actually never sends this
// error, but we still handles it). // error, but we still handles it).
if err := xdsC.WaitForCancelEDSWatch(); err == nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelEDSWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err != nil { if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err != nil {
@ -496,12 +521,14 @@ func (s) TestErrorFromResolver(t *testing.T) {
} }
defer edsB.Close() defer edsB.Close()
edsB.UpdateClientConnState(balancer.ClientConnState{ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
BalancerConfig: &EDSConfig{ BalancerConfig: &EDSConfig{
BalancerName: testBalancerNameFooBar, BalancerName: testBalancerNameFooBar,
EDSServiceName: testEDSClusterName, EDSServiceName: testEDSClusterName,
}, },
}) }); err != nil {
t.Fatal(err)
}
xdsC := waitForNewXDSClientWithEDSWatch(t, xdsClientCh, testBalancerNameFooBar) xdsC := waitForNewXDSClientWithEDSWatch(t, xdsClientCh, testBalancerNameFooBar)
xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil)
@ -512,7 +539,10 @@ func (s) TestErrorFromResolver(t *testing.T) {
connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error")
edsB.ResolverError(connectionErr) edsB.ResolverError(connectionErr)
if err := xdsC.WaitForCancelEDSWatch(); err == nil {
ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelEDSWatch(ctx); err == nil {
t.Fatal("watch was canceled, want not canceled (timeout error)") t.Fatal("watch was canceled, want not canceled (timeout error)")
} }
if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err == nil { if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err == nil {
@ -521,7 +551,9 @@ func (s) TestErrorFromResolver(t *testing.T) {
resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error")
edsB.ResolverError(resourceErr) edsB.ResolverError(resourceErr)
if err := xdsC.WaitForCancelEDSWatch(); err != nil { ctx, ctxCancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer ctxCancel()
if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil {
t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err)
} }
if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err != nil { if err := edsLB.waitForEDSResponse(xdsclient.EndpointsUpdate{}); err != nil {

Просмотреть файл

@ -19,10 +19,10 @@
package edsbalancer package edsbalancer
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"time"
xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -31,11 +31,12 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/attributes" "google.golang.org/grpc/attributes"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
xdsinternal "google.golang.org/grpc/xds/internal" xdsinternal "google.golang.org/grpc/xds/internal"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/bootstrap"
"google.golang.org/grpc/xds/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeclient" "google.golang.org/grpc/xds/internal/testutils/fakeclient"
"google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/testutils/fakeserver"
"google.golang.org/grpc/xds/internal/version" "google.golang.org/grpc/xds/internal/version"
@ -51,13 +52,16 @@ var (
func verifyExpectedRequests(fs *fakeserver.Server, resourceNames ...string) error { func verifyExpectedRequests(fs *fakeserver.Server, resourceNames ...string) error {
wantReq := &xdspb.DiscoveryRequest{ wantReq := &xdspb.DiscoveryRequest{
TypeUrl: version.V2EndpointsURL, TypeUrl: version.V2EndpointsURL,
Node: testutils.EmptyNodeProtoV2, Node: xdstestutils.EmptyNodeProtoV2,
} }
for _, name := range resourceNames { for _, name := range resourceNames {
if name != "" { if name != "" {
wantReq.ResourceNames = []string{name} wantReq.ResourceNames = []string{name}
} }
req, err := fs.XDSRequestChan.TimedReceive(time.Millisecond * 100)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
req, err := fs.XDSRequestChan.Receive(ctx)
if err != nil { if err != nil {
return fmt.Errorf("timed out when expecting request {%+v} at fake server", wantReq) return fmt.Errorf("timed out when expecting request {%+v} at fake server", wantReq)
} }
@ -98,7 +102,7 @@ func (s) TestClientWrapperWatchEDS(t *testing.T) {
return &bootstrap.Config{ return &bootstrap.Config{
BalancerName: fakeServer.Address, BalancerName: fakeServer.Address,
Creds: grpc.WithInsecure(), Creds: grpc.WithInsecure(),
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
}, nil }, nil
} }
defer func() { bootstrapConfigNew = oldBootstrapConfigNew }() defer func() { bootstrapConfigNew = oldBootstrapConfigNew }()
@ -109,7 +113,9 @@ func (s) TestClientWrapperWatchEDS(t *testing.T) {
BalancerName: fakeServer.Address, BalancerName: fakeServer.Address,
EDSServiceName: "", EDSServiceName: "",
}, nil) }, nil)
if _, err := fakeServer.NewConnChan.TimedReceive(1 * time.Second); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.NewConnChan.Receive(ctx); err != nil {
t.Fatal("Failed to connect to fake server") t.Fatal("Failed to connect to fake server")
} }
t.Log("Client connection established to fake server...") t.Log("Client connection established to fake server...")
@ -161,7 +167,10 @@ func (s) TestClientWrapperHandleUpdateError(t *testing.T) {
xdsC := fakeclient.NewClient() xdsC := fakeclient.NewClient()
cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC)) cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC))
gotCluster, err := xdsC.WaitForWatchEDS()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotCluster, err := xdsC.WaitForWatchEDS(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
} }
@ -175,7 +184,9 @@ func (s) TestClientWrapperHandleUpdateError(t *testing.T) {
// //
// TODO: check for loseContact() when errors indicating "lose contact" are // TODO: check for loseContact() when errors indicating "lose contact" are
// handled correctly. // handled correctly.
gotUpdate, err := edsRespChan.Receive() ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotUpdate, err := edsRespChan.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("edsBalancer failed to get edsUpdate %v", err) t.Fatalf("edsBalancer failed to get edsUpdate %v", err)
} }
@ -199,10 +210,13 @@ func (s) TestClientWrapperGetsXDSClientInAttributes(t *testing.T) {
cw := newXDSClientWrapper(nil, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, nil) cw := newXDSClientWrapper(nil, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, nil)
defer cw.close() defer cw.close()
// Verify that the eds watch is registered for the expected resource name.
xdsC1 := fakeclient.NewClient() xdsC1 := fakeclient.NewClient()
cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC1)) cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC1))
gotCluster, err := xdsC1.WaitForWatchEDS()
// Verify that the eds watch is registered for the expected resource name.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotCluster, err := xdsC1.WaitForWatchEDS(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
} }
@ -216,7 +230,7 @@ func (s) TestClientWrapperGetsXDSClientInAttributes(t *testing.T) {
// close client that are passed through attributes). // close client that are passed through attributes).
xdsC2 := fakeclient.NewClient() xdsC2 := fakeclient.NewClient()
cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC2)) cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC2))
gotCluster, err = xdsC2.WaitForWatchEDS() gotCluster, err = xdsC2.WaitForWatchEDS(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
} }
@ -224,7 +238,7 @@ func (s) TestClientWrapperGetsXDSClientInAttributes(t *testing.T) {
t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName) t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName)
} }
if err := xdsC1.WaitForClose(); err != testutils.ErrRecvTimeout { if err := xdsC1.WaitForClose(ctx); err != context.DeadlineExceeded {
t.Fatalf("clientWrapper closed xdsClient received in attributes") t.Fatalf("clientWrapper closed xdsClient received in attributes")
} }
} }

Просмотреть файл

@ -19,6 +19,7 @@
package edsbalancer package edsbalancer
import ( import (
"context"
"testing" "testing"
"google.golang.org/grpc/attributes" "google.golang.org/grpc/attributes"
@ -41,12 +42,16 @@ func (s) TestXDSLoadReporting(t *testing.T) {
defer edsB.Close() defer edsB.Close()
xdsC := fakeclient.NewClient() xdsC := fakeclient.NewClient()
edsB.UpdateClientConnState(balancer.ClientConnState{ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
ResolverState: resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)}, ResolverState: resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
BalancerConfig: &EDSConfig{LrsLoadReportingServerName: new(string)}, BalancerConfig: &EDSConfig{LrsLoadReportingServerName: new(string)},
}) }); err != nil {
t.Fatal(err)
}
gotCluster, err := xdsC.WaitForWatchEDS() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotCluster, err := xdsC.WaitForWatchEDS(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
} }
@ -54,7 +59,7 @@ func (s) TestXDSLoadReporting(t *testing.T) {
t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName) t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName)
} }
got, err := xdsC.WaitForReportLoad() got, err := xdsC.WaitForReportLoad(ctx)
if err != nil { if err != nil {
t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
} }

Просмотреть файл

@ -19,13 +19,15 @@
package client package client
import ( import (
"context"
"testing" "testing"
"time" "time"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/bootstrap"
"google.golang.org/grpc/xds/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/version" "google.golang.org/grpc/xds/internal/version"
) )
@ -47,6 +49,7 @@ const (
testEDSName = "test-eds" testEDSName = "test-eds"
defaultTestWatchExpiryTimeout = 500 * time.Millisecond defaultTestWatchExpiryTimeout = 500 * time.Millisecond
defaultTestTimeout = 1 * time.Second
) )
func clientOpts(balancerName string, overrideWatchExpiryTImeout bool) Options { func clientOpts(balancerName string, overrideWatchExpiryTImeout bool) Options {
@ -58,7 +61,7 @@ func clientOpts(balancerName string, overrideWatchExpiryTImeout bool) Options {
Config: bootstrap.Config{ Config: bootstrap.Config{
BalancerName: balancerName, BalancerName: balancerName,
Creds: grpc.WithInsecure(), Creds: grpc.WithInsecure(),
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
}, },
WatchExpiryTimeout: watchExpiryTimeout, WatchExpiryTimeout: watchExpiryTimeout,
} }
@ -132,12 +135,18 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) {
clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err})
// Calls another watch inline, to ensure there's deadlock. // Calls another watch inline, to ensure there's deadlock.
c.WatchCluster("another-random-name", func(ClusterUpdate, error) {}) c.WatchCluster("another-random-name", func(ClusterUpdate, error) {})
if _, err := v2Client.addWatches[ClusterResource].Receive(); firstTime && err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); firstTime && err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
firstTime = false firstTime = false
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -146,7 +155,7 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -155,7 +164,7 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) {
testCDSName: wantUpdate2, testCDSName: wantUpdate2,
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
} }

Просмотреть файл

@ -19,9 +19,10 @@
package client package client
import ( import (
"context"
"testing" "testing"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/internal/testutils"
) )
type clusterUpdateErr struct { type clusterUpdateErr struct {
@ -52,7 +53,10 @@ func (s) TestClusterWatch(t *testing.T) {
cancelWatch := c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { cancelWatch := c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -67,7 +71,7 @@ func (s) TestClusterWatch(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -77,7 +81,7 @@ func (s) TestClusterWatch(t *testing.T) {
"randomName": {}, "randomName": {},
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %+v, %v, want channel recv timeout", u, err) t.Errorf("unexpected clusterUpdate: %+v, %v, want channel recv timeout", u, err)
} }
@ -87,7 +91,7 @@ func (s) TestClusterWatch(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := clusterUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := clusterUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -109,6 +113,9 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) {
var clusterUpdateChs []*testutils.Channel var clusterUpdateChs []*testutils.Channel
var cancelLastWatch func() var cancelLastWatch func()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
const count = 2 const count = 2
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
clusterUpdateCh := testutils.NewChannel() clusterUpdateCh := testutils.NewChannel()
@ -116,8 +123,13 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) {
cancelLastWatch = c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { cancelLastWatch = c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -127,7 +139,7 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := clusterUpdateChs[i].Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateChs[i].Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -139,12 +151,12 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count-1; i++ { for i := 0; i < count-1; i++ {
if u, err := clusterUpdateChs[i].Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateChs[i].Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := clusterUpdateChs[count-1].TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := clusterUpdateChs[count-1].Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -167,14 +179,21 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) {
const count = 2 const count = 2
// Two watches for the same name. // Two watches for the same name.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
clusterUpdateCh := testutils.NewChannel() clusterUpdateCh := testutils.NewChannel()
clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh) clusterUpdateChs = append(clusterUpdateChs, clusterUpdateCh)
c.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -183,7 +202,7 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) {
c.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) {
clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil { if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -195,12 +214,12 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := clusterUpdateChs[i].Receive(); err != nil || u != (clusterUpdateErr{wantUpdate1, nil}) { if u, err := clusterUpdateChs[i].Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate1, nil}) {
t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected clusterUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := clusterUpdateCh2.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) { if u, err := clusterUpdateCh2.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -223,7 +242,10 @@ func (s) TestClusterWatchAfterCache(t *testing.T) {
c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -232,7 +254,7 @@ func (s) TestClusterWatchAfterCache(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -241,17 +263,19 @@ func (s) TestClusterWatchAfterCache(t *testing.T) {
c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName, func(update ClusterUpdate, err error) {
clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err})
}) })
if n, err := v2Client.addWatches[ClusterResource].Receive(); err == nil { if n, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err)
} }
// New watch should receives the update. // New watch should receives the update.
if u, err := clusterUpdateCh2.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := clusterUpdateCh2.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
// Old watch should see nothing. // Old watch should see nothing.
if u, err := clusterUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := clusterUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected clusterUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -275,11 +299,14 @@ func (s) TestClusterWatchExpiryTimer(t *testing.T) {
c.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { c.WatchCluster(testCDSName, func(u ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
u, err := clusterUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2) u, err := clusterUpdateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get clusterUpdate: %v", err) t.Fatalf("failed to get clusterUpdate: %v", err)
} }
@ -311,7 +338,10 @@ func (s) TestClusterWatchExpiryTimerStop(t *testing.T) {
c.WatchCluster(testCDSName, func(u ClusterUpdate, err error) { c.WatchCluster(testCDSName, func(u ClusterUpdate, err error) {
clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err}) clusterUpdateCh.Send(clusterUpdateErr{u: u, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -320,13 +350,13 @@ func (s) TestClusterWatchExpiryTimerStop(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := clusterUpdateCh.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) { if u, err := clusterUpdateCh.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
// Wait for an error, the error should never happen. // Wait for an error, the error should never happen.
u, err := clusterUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2) u, err := clusterUpdateCh.Receive(ctx)
if err != testutils.ErrRecvTimeout { if err != context.DeadlineExceeded {
t.Fatalf("got unexpected: %v, %v, want recv timeout", u.(clusterUpdateErr).u, u.(clusterUpdateErr).err) t.Fatalf("got unexpected: %v, %v, want recv timeout", u.(clusterUpdateErr).u, u.(clusterUpdateErr).err)
} }
} }
@ -353,7 +383,10 @@ func (s) TestClusterResourceRemoved(t *testing.T) {
c.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName+"1", func(update ClusterUpdate, err error) {
clusterUpdateCh1.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh1.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
// Another watch for a different name. // Another watch for a different name.
@ -361,7 +394,7 @@ func (s) TestClusterResourceRemoved(t *testing.T) {
c.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) { c.WatchCluster(testCDSName+"2", func(update ClusterUpdate, err error) {
clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err}) clusterUpdateCh2.Send(clusterUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ClusterResource].Receive(); err != nil { if _, err := v2Client.addWatches[ClusterResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -372,11 +405,11 @@ func (s) TestClusterResourceRemoved(t *testing.T) {
testCDSName + "2": wantUpdate2, testCDSName + "2": wantUpdate2,
}) })
if u, err := clusterUpdateCh1.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate1, nil}) { if u, err := clusterUpdateCh1.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate1, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
if u, err := clusterUpdateCh2.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) { if u, err := clusterUpdateCh2.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -386,12 +419,12 @@ func (s) TestClusterResourceRemoved(t *testing.T) {
}) })
// watcher 1 should get an error. // watcher 1 should get an error.
if u, err := clusterUpdateCh1.Receive(); err != nil || ErrType(u.(clusterUpdateErr).err) != ErrorTypeResourceNotFound { if u, err := clusterUpdateCh1.Receive(ctx); err != nil || ErrType(u.(clusterUpdateErr).err) != ErrorTypeResourceNotFound {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err)
} }
// watcher 2 should get the same update again. // watcher 2 should get the same update again.
if u, err := clusterUpdateCh2.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) { if u, err := clusterUpdateCh2.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -401,12 +434,14 @@ func (s) TestClusterResourceRemoved(t *testing.T) {
}) })
// watcher 1 should get an error. // watcher 1 should get an error.
if u, err := clusterUpdateCh1.Receive(); err != testutils.ErrRecvTimeout { if u, err := clusterUpdateCh1.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected clusterUpdate: %v, want receiving from channel timeout", u) t.Errorf("unexpected clusterUpdate: %v, want receiving from channel timeout", u)
} }
// watcher 2 should get the same update again. // watcher 2 should get the same update again.
if u, err := clusterUpdateCh2.Receive(); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := clusterUpdateCh2.Receive(ctx); err != nil || u != (clusterUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected clusterUpdate: %v, error receiving from channel: %v", u, err)
} }
} }

Просмотреть файл

@ -19,13 +19,14 @@
package client package client
import ( import (
"context"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/testutils"
) )
var ( var (
@ -71,7 +72,10 @@ func (s) TestEndpointsWatch(t *testing.T) {
cancelWatch := c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { cancelWatch := c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) {
endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -80,7 +84,7 @@ func (s) TestEndpointsWatch(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := endpointsUpdateCh.Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) {
t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -89,7 +93,7 @@ func (s) TestEndpointsWatch(t *testing.T) {
"randomName": {}, "randomName": {},
}) })
if u, err := endpointsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := endpointsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err)
} }
@ -99,7 +103,9 @@ func (s) TestEndpointsWatch(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := endpointsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := endpointsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -123,14 +129,21 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) {
var cancelLastWatch func() var cancelLastWatch func()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
endpointsUpdateCh := testutils.NewChannel() endpointsUpdateCh := testutils.NewChannel()
endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh)
cancelLastWatch = c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { cancelLastWatch = c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) {
endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -140,7 +153,7 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := endpointsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) {
t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -152,12 +165,12 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count-1; i++ { for i := 0; i < count-1; i++ {
if u, err := endpointsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) {
t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := endpointsUpdateChs[count-1].TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := endpointsUpdateChs[count-1].Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -180,14 +193,21 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) {
const count = 2 const count = 2
// Two watches for the same name. // Two watches for the same name.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
endpointsUpdateCh := testutils.NewChannel() endpointsUpdateCh := testutils.NewChannel()
endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh) endpointsUpdateChs = append(endpointsUpdateChs, endpointsUpdateCh)
c.WatchEndpoints(testCDSName+"1", func(update EndpointsUpdate, err error) { c.WatchEndpoints(testCDSName+"1", func(update EndpointsUpdate, err error) {
endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -196,7 +216,7 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) {
c.WatchEndpoints(testCDSName+"2", func(update EndpointsUpdate, err error) { c.WatchEndpoints(testCDSName+"2", func(update EndpointsUpdate, err error) {
endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); err != nil { if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -208,12 +228,12 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := endpointsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate1, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate1, nil}, endpointsCmpOpts...) {
t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected endpointsUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := endpointsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate2, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate2, nil}, endpointsCmpOpts...) {
t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -236,7 +256,10 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) {
c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) {
endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -245,7 +268,7 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) {
testCDSName: wantUpdate, testCDSName: wantUpdate,
}) })
if u, err := endpointsUpdateCh.Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) { if u, err := endpointsUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) {
t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -254,17 +277,19 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) {
c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) {
endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh2.Send(endpointsUpdateErr{u: update, err: err})
}) })
if n, err := v2Client.addWatches[EndpointsResource].Receive(); err == nil { if n, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err)
} }
// New watch should receives the update. // New watch should receives the update.
if u, err := endpointsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := endpointsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, endpointsUpdateErr{wantUpdate, nil}, endpointsCmpOpts...) {
t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected endpointsUpdate: %v, error receiving from channel: %v", u, err)
} }
// Old watch should see nothing. // Old watch should see nothing.
if u, err := endpointsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := endpointsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected endpointsUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -288,11 +313,13 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) {
c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { c.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) {
endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[EndpointsResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[EndpointsResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
u, err := endpointsUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2) u, err := endpointsUpdateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get endpointsUpdate: %v", err) t.Fatalf("failed to get endpointsUpdate: %v", err)
} }

Просмотреть файл

@ -19,9 +19,10 @@
package client package client
import ( import (
"context"
"testing" "testing"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/internal/testutils"
) )
type ldsUpdateErr struct { type ldsUpdateErr struct {
@ -49,7 +50,10 @@ func (s) TestLDSWatch(t *testing.T) {
cancelWatch := c.watchLDS(testLDSName, func(update ListenerUpdate, err error) { cancelWatch := c.watchLDS(testLDSName, func(update ListenerUpdate, err error) {
ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -58,7 +62,7 @@ func (s) TestLDSWatch(t *testing.T) {
testLDSName: wantUpdate, testLDSName: wantUpdate,
}) })
if u, err := ldsUpdateCh.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { if u, err := ldsUpdateCh.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -68,7 +72,7 @@ func (s) TestLDSWatch(t *testing.T) {
"randomName": {}, "randomName": {},
}) })
if u, err := ldsUpdateCh.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { if u, err := ldsUpdateCh.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err)
} }
@ -78,7 +82,7 @@ func (s) TestLDSWatch(t *testing.T) {
testLDSName: wantUpdate, testLDSName: wantUpdate,
}) })
if u, err := ldsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := ldsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -102,14 +106,21 @@ func (s) TestLDSTwoWatchSameResourceName(t *testing.T) {
var cancelLastWatch func() var cancelLastWatch func()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ldsUpdateCh := testutils.NewChannel() ldsUpdateCh := testutils.NewChannel()
ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh)
cancelLastWatch = c.watchLDS(testLDSName, func(update ListenerUpdate, err error) { cancelLastWatch = c.watchLDS(testLDSName, func(update ListenerUpdate, err error) {
ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -119,7 +130,7 @@ func (s) TestLDSTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := ldsUpdateChs[i].Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { if u, err := ldsUpdateChs[i].Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -131,12 +142,12 @@ func (s) TestLDSTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count-1; i++ { for i := 0; i < count-1; i++ {
if u, err := ldsUpdateChs[i].Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { if u, err := ldsUpdateChs[i].Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := ldsUpdateChs[count-1].TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := ldsUpdateChs[count-1].Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -159,14 +170,21 @@ func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) {
const count = 2 const count = 2
// Two watches for the same name. // Two watches for the same name.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
ldsUpdateCh := testutils.NewChannel() ldsUpdateCh := testutils.NewChannel()
ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh) ldsUpdateChs = append(ldsUpdateChs, ldsUpdateCh)
c.watchLDS(testLDSName+"1", func(update ListenerUpdate, err error) { c.watchLDS(testLDSName+"1", func(update ListenerUpdate, err error) {
ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -175,7 +193,7 @@ func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) {
c.watchLDS(testLDSName+"2", func(update ListenerUpdate, err error) { c.watchLDS(testLDSName+"2", func(update ListenerUpdate, err error) {
ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -187,12 +205,12 @@ func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := ldsUpdateChs[i].Receive(); err != nil || u != (ldsUpdateErr{wantUpdate1, nil}) { if u, err := ldsUpdateChs[i].Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate1, nil}) {
t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected ListenerUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := ldsUpdateCh2.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) { if u, err := ldsUpdateCh2.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -215,7 +233,10 @@ func (s) TestLDSWatchAfterCache(t *testing.T) {
c.watchLDS(testLDSName, func(update ListenerUpdate, err error) { c.watchLDS(testLDSName, func(update ListenerUpdate, err error) {
ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -224,7 +245,7 @@ func (s) TestLDSWatchAfterCache(t *testing.T) {
testLDSName: wantUpdate, testLDSName: wantUpdate,
}) })
if u, err := ldsUpdateCh.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { if u, err := ldsUpdateCh.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -233,17 +254,19 @@ func (s) TestLDSWatchAfterCache(t *testing.T) {
c.watchLDS(testLDSName, func(update ListenerUpdate, err error) { c.watchLDS(testLDSName, func(update ListenerUpdate, err error) {
ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err})
}) })
if n, err := v2Client.addWatches[ListenerResource].Receive(); err == nil { if n, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err)
} }
// New watch should receives the update. // New watch should receives the update.
if u, err := ldsUpdateCh2.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := ldsUpdateCh2.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
// Old watch should see nothing. // Old watch should see nothing.
if u, err := ldsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := ldsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected ListenerUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -270,7 +293,10 @@ func (s) TestLDSResourceRemoved(t *testing.T) {
c.watchLDS(testLDSName+"1", func(update ListenerUpdate, err error) { c.watchLDS(testLDSName+"1", func(update ListenerUpdate, err error) {
ldsUpdateCh1.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh1.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
// Another watch for a different name. // Another watch for a different name.
@ -278,7 +304,7 @@ func (s) TestLDSResourceRemoved(t *testing.T) {
c.watchLDS(testLDSName+"2", func(update ListenerUpdate, err error) { c.watchLDS(testLDSName+"2", func(update ListenerUpdate, err error) {
ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err}) ldsUpdateCh2.Send(ldsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -289,11 +315,11 @@ func (s) TestLDSResourceRemoved(t *testing.T) {
testLDSName + "2": wantUpdate2, testLDSName + "2": wantUpdate2,
}) })
if u, err := ldsUpdateCh1.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate1, nil}) { if u, err := ldsUpdateCh1.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate1, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
if u, err := ldsUpdateCh2.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) { if u, err := ldsUpdateCh2.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -303,12 +329,12 @@ func (s) TestLDSResourceRemoved(t *testing.T) {
}) })
// watcher 1 should get an error. // watcher 1 should get an error.
if u, err := ldsUpdateCh1.Receive(); err != nil || ErrType(u.(ldsUpdateErr).err) != ErrorTypeResourceNotFound { if u, err := ldsUpdateCh1.Receive(ctx); err != nil || ErrType(u.(ldsUpdateErr).err) != ErrorTypeResourceNotFound {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err)
} }
// watcher 2 should get the same update again. // watcher 2 should get the same update again.
if u, err := ldsUpdateCh2.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) { if u, err := ldsUpdateCh2.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -318,12 +344,14 @@ func (s) TestLDSResourceRemoved(t *testing.T) {
}) })
// watcher 1 should get an error. // watcher 1 should get an error.
if u, err := ldsUpdateCh1.Receive(); err != testutils.ErrRecvTimeout { if u, err := ldsUpdateCh1.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u) t.Errorf("unexpected ListenerUpdate: %v, want receiving from channel timeout", u)
} }
// watcher 2 should get the same update again. // watcher 2 should get the same update again.
if u, err := ldsUpdateCh2.Receive(); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := ldsUpdateCh2.Receive(ctx); err != nil || u != (ldsUpdateErr{wantUpdate2, nil}) {
t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected ListenerUpdate: %v, error receiving from channel: %v", u, err)
} }
} }

Просмотреть файл

@ -19,10 +19,11 @@
package client package client
import ( import (
"context"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/internal/testutils"
) )
type rdsUpdateErr struct { type rdsUpdateErr struct {
@ -50,7 +51,9 @@ func (s) TestRDSWatch(t *testing.T) {
cancelWatch := c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) { cancelWatch := c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -59,7 +62,7 @@ func (s) TestRDSWatch(t *testing.T) {
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
if u, err := rdsUpdateCh.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -68,7 +71,7 @@ func (s) TestRDSWatch(t *testing.T) {
"randomName": {}, "randomName": {},
}) })
if u, err := rdsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := rdsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err)
} }
@ -78,7 +81,9 @@ func (s) TestRDSWatch(t *testing.T) {
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
if u, err := rdsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := rdsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -102,14 +107,21 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) {
var cancelLastWatch func() var cancelLastWatch func()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
rdsUpdateCh := testutils.NewChannel() rdsUpdateCh := testutils.NewChannel()
rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh)
cancelLastWatch = c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) { cancelLastWatch = c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -119,7 +131,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
@ -131,12 +143,12 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) {
}) })
for i := 0; i < count-1; i++ { for i := 0; i < count-1; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := rdsUpdateChs[count-1].TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := rdsUpdateChs[count-1].Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -159,14 +171,21 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) {
const count = 2 const count = 2
// Two watches for the same name. // Two watches for the same name.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
rdsUpdateCh := testutils.NewChannel() rdsUpdateCh := testutils.NewChannel()
rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh) rdsUpdateChs = append(rdsUpdateChs, rdsUpdateCh)
c.watchRDS(testRDSName+"1", func(update RouteConfigUpdate, err error) { c.watchRDS(testRDSName+"1", func(update RouteConfigUpdate, err error) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); i == 0 && err != nil {
t.Fatalf("want new watch to start, got error %v", err) if i == 0 {
// A new watch is registered on the underlying API client only for
// the first iteration because we are using the same resource name.
if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err)
}
} }
} }
@ -175,7 +194,7 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) {
c.watchRDS(testRDSName+"2", func(update RouteConfigUpdate, err error) { c.watchRDS(testRDSName+"2", func(update RouteConfigUpdate, err error) {
rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -187,12 +206,12 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) {
}) })
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if u, err := rdsUpdateChs[i].Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate1, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateChs[i].Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate1, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err) t.Errorf("i=%v, unexpected RouteConfigUpdate: %v, error receiving from channel: %v", i, u, err)
} }
} }
if u, err := rdsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate2, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate2, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -215,7 +234,9 @@ func (s) TestRDSWatchAfterCache(t *testing.T) {
c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) { c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) {
rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -224,7 +245,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) {
testRDSName: wantUpdate, testRDSName: wantUpdate,
}) })
if u, err := rdsUpdateCh.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { if u, err := rdsUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -233,17 +254,19 @@ func (s) TestRDSWatchAfterCache(t *testing.T) {
c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) { c.watchRDS(testRDSName, func(update RouteConfigUpdate, err error) {
rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err}) rdsUpdateCh2.Send(rdsUpdateErr{u: update, err: err})
}) })
if n, err := v2Client.addWatches[RouteConfigResource].Receive(); err == nil { if n, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err) t.Fatalf("want no new watch to start (recv timeout), got resource name: %v error %v", n, err)
} }
// New watch should receives the update. // New watch should receives the update.
if u, err := rdsUpdateCh2.Receive(); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := rdsUpdateCh2.Receive(ctx); err != nil || !cmp.Equal(u, rdsUpdateErr{wantUpdate, nil}, cmp.AllowUnexported(rdsUpdateErr{})) {
t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, error receiving from channel: %v", u, err)
} }
// Old watch should see nothing. // Old watch should see nothing.
if u, err := rdsUpdateCh.TimedReceive(chanRecvTimeout); err != testutils.ErrRecvTimeout { if u, err := rdsUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }

Просмотреть файл

@ -19,12 +19,13 @@
package client package client
import ( import (
"context"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/internal/testutils"
) )
type serviceUpdateErr struct { type serviceUpdateErr struct {
@ -56,20 +57,22 @@ func (s) TestServiceWatch(t *testing.T) {
wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}} wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -87,7 +90,7 @@ func (s) TestServiceWatch(t *testing.T) {
}}, }},
}, },
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate2, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate2, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -114,20 +117,22 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}} wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -135,7 +140,7 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName + "2"}, testLDSName: {RouteConfigName: testRDSName + "2"},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
@ -144,7 +149,7 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err)
} }
@ -154,7 +159,9 @@ func (s) TestServiceWatchLDSUpdate(t *testing.T) {
testRDSName + "2": {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "2": 1}}}}, testRDSName + "2": {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "2": 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate2, nil}, serviceCmpOpts...) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate2, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
} }
@ -181,20 +188,22 @@ func (s) TestServiceWatchSecond(t *testing.T) {
wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}} wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -204,7 +213,7 @@ func (s) TestServiceWatchSecond(t *testing.T) {
serviceUpdateCh2.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh2.Send(serviceUpdateErr{u: update, err: err})
}) })
u, err := serviceUpdateCh2.Receive() u, err := serviceUpdateCh2.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get serviceUpdate: %v", err) t.Fatalf("failed to get serviceUpdate: %v", err)
} }
@ -225,11 +234,11 @@ func (s) TestServiceWatchSecond(t *testing.T) {
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
if u, err := serviceUpdateCh2.Receive(); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh2.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -253,10 +262,13 @@ func (s) TestServiceWatchWithNoResponseFromServer(t *testing.T) {
c.WatchService(testLDSName, func(update ServiceUpdate, err error) { c.WatchService(testLDSName, func(update ServiceUpdate, err error) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
u, err := serviceUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2) u, err := serviceUpdateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get serviceUpdate: %v", err) t.Fatalf("failed to get serviceUpdate: %v", err)
} }
@ -288,17 +300,19 @@ func (s) TestServiceWatchEmptyRDS(t *testing.T) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{}) v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{})
u, err := serviceUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2) u, err := serviceUpdateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("failed to get serviceUpdate: %v", err) t.Fatalf("failed to get serviceUpdate: %v", err)
} }
@ -331,18 +345,20 @@ func (s) TestServiceWatchWithClientClose(t *testing.T) {
serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err}) serviceUpdateCh.Send(serviceUpdateErr{u: update, err: err})
}) })
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
// Client is closed before it receives the RDS response. // Client is closed before it receives the RDS response.
c.Close() c.Close()
if u, err := serviceUpdateCh.TimedReceive(defaultTestWatchExpiryTimeout * 2); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err) t.Errorf("unexpected serviceUpdate: %v, %v, want channel recv timeout", u, err)
} }
} }
@ -369,20 +385,22 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) {
wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}} wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
@ -390,7 +408,7 @@ func (s) TestServiceNotCancelRDSOnSameLDSUpdate(t *testing.T) {
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if v, err := v2Client.removeWatches[RouteConfigResource].Receive(); err == nil { if v, err := v2Client.removeWatches[RouteConfigResource].Receive(ctx); err == nil {
t.Fatalf("unexpected rds watch cancel: %v", v) t.Fatalf("unexpected rds watch cancel: %v", v)
} }
} }
@ -420,30 +438,32 @@ func (s) TestServiceResourceRemoved(t *testing.T) {
wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}} wantUpdate := ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}
if _, err := v2Client.addWatches[ListenerResource].Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[ListenerResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName: 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) { if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{wantUpdate, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
// Remove LDS resource, should cancel the RDS watch, and trigger resource // Remove LDS resource, should cancel the RDS watch, and trigger resource
// removed error. // removed error.
v2Client.r.NewListeners(map[string]ListenerUpdate{}) v2Client.r.NewListeners(map[string]ListenerUpdate{})
if _, err := v2Client.removeWatches[RouteConfigResource].Receive(); err != nil { if _, err := v2Client.removeWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want watch to be canceled, got error %v", err) t.Fatalf("want watch to be canceled, got error %v", err)
} }
if u, err := serviceUpdateCh.Receive(); err != nil || ErrType(u.(serviceUpdateErr).err) != ErrorTypeResourceNotFound { if u, err := serviceUpdateCh.Receive(ctx); err != nil || ErrType(u.(serviceUpdateErr).err) != ErrorTypeResourceNotFound {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v, want update with error resource not found", u, err)
} }
@ -452,7 +472,7 @@ func (s) TestServiceResourceRemoved(t *testing.T) {
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new": 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new": 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected serviceUpdate: %v, want receiving from channel timeout", u) t.Errorf("unexpected serviceUpdate: %v, want receiving from channel timeout", u)
} }
@ -462,17 +482,21 @@ func (s) TestServiceResourceRemoved(t *testing.T) {
v2Client.r.NewListeners(map[string]ListenerUpdate{ v2Client.r.NewListeners(map[string]ListenerUpdate{
testLDSName: {RouteConfigName: testRDSName}, testLDSName: {RouteConfigName: testRDSName},
}) })
if _, err := v2Client.addWatches[RouteConfigResource].Receive(); err != nil { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := v2Client.addWatches[RouteConfigResource].Receive(ctx); err != nil {
t.Fatalf("want new watch to start, got error %v", err) t.Fatalf("want new watch to start, got error %v", err)
} }
if u, err := serviceUpdateCh.Receive(); err != testutils.ErrRecvTimeout { if u, err := serviceUpdateCh.Receive(ctx); err != context.DeadlineExceeded {
t.Errorf("unexpected serviceUpdate: %v, want receiving from channel timeout", u) t.Errorf("unexpected serviceUpdate: %v, want receiving from channel timeout", u)
} }
v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{ v2Client.r.NewRouteConfigs(map[string]RouteConfigUpdate{
testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new2": 1}}}}, testRDSName: {Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new2": 1}}}},
}) })
if u, err := serviceUpdateCh.Receive(); err != nil || !cmp.Equal(u, serviceUpdateErr{ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new2": 1}}}}, nil}, serviceCmpOpts...) { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if u, err := serviceUpdateCh.Receive(ctx); err != nil || !cmp.Equal(u, serviceUpdateErr{ServiceUpdate{Routes: []*Route{{Prefix: newStringP(""), Action: map[string]uint32{testCDSName + "new2": 1}}}}, nil}, serviceCmpOpts...) {
t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err) t.Errorf("unexpected serviceUpdate: %v, error receiving from channel: %v", u, err)
} }
} }

Просмотреть файл

@ -18,6 +18,7 @@
package v2 package v2
import ( import (
"context"
"fmt" "fmt"
"strconv" "strconv"
"testing" "testing"
@ -28,12 +29,14 @@ import (
anypb "github.com/golang/protobuf/ptypes/any" anypb "github.com/golang/protobuf/ptypes/any"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/internal/testutils"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/testutils/fakeserver"
"google.golang.org/grpc/xds/internal/version" "google.golang.org/grpc/xds/internal/version"
) )
const defaultTestTimeout = 1 * time.Second
func startXDSV2Client(t *testing.T, cc *grpc.ClientConn) (v2c *client, cbLDS, cbRDS, cbCDS, cbEDS *testutils.Channel, cleanup func()) { func startXDSV2Client(t *testing.T, cc *grpc.ClientConn) (v2c *client, cbLDS, cbRDS, cbCDS, cbEDS *testutils.Channel, cleanup func()) {
cbLDS = testutils.NewChannel() cbLDS = testutils.NewChannel()
cbRDS = testutils.NewChannel() cbRDS = testutils.NewChannel()
@ -71,7 +74,9 @@ func startXDSV2Client(t *testing.T, cc *grpc.ClientConn) (v2c *client, cbLDS, cb
// compareXDSRequest reads requests from channel, compare it with want. // compareXDSRequest reads requests from channel, compare it with want.
func compareXDSRequest(ch *testutils.Channel, want *xdspb.DiscoveryRequest, ver, nonce string) error { func compareXDSRequest(ch *testutils.Channel, want *xdspb.DiscoveryRequest, ver, nonce string) error {
val, err := ch.Receive() ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
val, err := ch.Receive(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -133,7 +138,9 @@ func sendGoodResp(t *testing.T, rType xdsclient.ResourceType, fakeServer *fakese
} }
t.Logf("Good %v response acked", rType) t.Logf("Good %v response acked", rType)
if _, err := callbackCh.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := callbackCh.Receive(ctx); err != nil {
return "", fmt.Errorf("timeout when expecting %v update", rType) return "", fmt.Errorf("timeout when expecting %v update", rType)
} }
t.Logf("Good %v response callback executed", rType) t.Logf("Good %v response callback executed", rType)
@ -408,7 +415,9 @@ func (s) TestV2ClientAckCancelResponseRace(t *testing.T) {
} }
versionCDS++ versionCDS++
if req, err := fakeServer.XDSRequestChan.Receive(); err != testutils.ErrRecvTimeout { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if req, err := fakeServer.XDSRequestChan.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("Got unexpected xds request after watch is canceled: %v", req) t.Fatalf("Got unexpected xds request after watch is canceled: %v", req)
} }
@ -417,11 +426,14 @@ func (s) TestV2ClientAckCancelResponseRace(t *testing.T) {
t.Logf("Good %v response pushed to fakeServer...", xdsclient.ClusterResource) t.Logf("Good %v response pushed to fakeServer...", xdsclient.ClusterResource)
// Expect no ACK because watch was canceled. // Expect no ACK because watch was canceled.
if req, err := fakeServer.XDSRequestChan.Receive(); err != testutils.ErrRecvTimeout { if req, err := fakeServer.XDSRequestChan.Receive(ctx); err != context.DeadlineExceeded {
t.Fatalf("Got unexpected xds request after watch is canceled: %v", req) t.Fatalf("Got unexpected xds request after watch is canceled: %v", req)
} }
// Still expected an callback update, because response was good. // Still expected an callback update, because response was good.
if _, err := cbCDS.Receive(); err != nil { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := cbCDS.Receive(ctx); err != nil {
t.Fatalf("Timeout when expecting %v update", xdsclient.ClusterResource) t.Fatalf("Timeout when expecting %v update", xdsclient.ClusterResource)
} }

Просмотреть файл

@ -19,6 +19,7 @@
package v2 package v2
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -35,7 +36,9 @@ import (
// watch. // watch.
func doLDS(t *testing.T, v2c xdsclient.APIClient, fakeServer *fakeserver.Server) { func doLDS(t *testing.T, v2c xdsclient.APIClient, fakeServer *fakeserver.Server) {
v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1) v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1)
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout waiting for LDS request: %v", err) t.Fatalf("Timeout waiting for LDS request: %v", err)
} }
} }

Просмотреть файл

@ -19,6 +19,7 @@
package v2 package v2
import ( import (
"context"
"errors" "errors"
"reflect" "reflect"
"testing" "testing"
@ -29,10 +30,10 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual" "google.golang.org/grpc/resolver/manual"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/testutils/fakeserver"
"google.golang.org/grpc/xds/internal/version" "google.golang.org/grpc/xds/internal/version"
@ -421,7 +422,9 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) {
// Wait till the request makes it to the fakeServer. This ensures that // Wait till the request makes it to the fakeServer. This ensures that
// the watch request has been processed by the v2Client. // the watch request has been processed by the v2Client.
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout waiting for an xDS request: %v", err) t.Fatalf("Timeout waiting for an xDS request: %v", err)
} }
@ -452,16 +455,16 @@ func testWatchHandle(t *testing.T, test *watchHandleTestcase) {
// Cannot directly compare test.wantUpdate with nil (typed vs non-typed nil: // Cannot directly compare test.wantUpdate with nil (typed vs non-typed nil:
// https://golang.org/doc/faq#nil_error). // https://golang.org/doc/faq#nil_error).
if c := test.wantUpdate; c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) { if c := test.wantUpdate; c == nil || (reflect.ValueOf(c).Kind() == reflect.Ptr && reflect.ValueOf(c).IsNil()) {
update, err := gotUpdateCh.Receive() update, err := gotUpdateCh.Receive(ctx)
if err == testutils.ErrRecvTimeout { if err == context.DeadlineExceeded {
return return
} }
t.Fatalf("Unexpected update: +%v", update) t.Fatalf("Unexpected update: +%v", update)
} }
wantUpdate := reflect.ValueOf(test.wantUpdate).Elem().Interface() wantUpdate := reflect.ValueOf(test.wantUpdate).Elem().Interface()
uErr, err := gotUpdateCh.Receive() uErr, err := gotUpdateCh.Receive(ctx)
if err == testutils.ErrRecvTimeout { if err == context.DeadlineExceeded {
t.Fatal("Timeout expecting xDS update") t.Fatal("Timeout expecting xDS update")
} }
gotUpdate := uErr.(updateErr).u gotUpdate := uErr.(updateErr).u
@ -533,7 +536,9 @@ func (s) TestV2ClientBackoffAfterRecvError(t *testing.T) {
t.Log("Started xds v2Client...") t.Log("Started xds v2Client...")
v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1) v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1)
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout expired when expecting an LDS request") t.Fatalf("Timeout expired when expecting an LDS request")
} }
t.Log("FakeServer received request...") t.Log("FakeServer received request...")
@ -552,7 +557,7 @@ func (s) TestV2ClientBackoffAfterRecvError(t *testing.T) {
t.Fatal("Received unexpected LDS callback") t.Fatal("Received unexpected LDS callback")
} }
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout expired when expecting an LDS request") t.Fatalf("Timeout expired when expecting an LDS request")
} }
t.Log("FakeServer received request after backoff...") t.Log("FakeServer received request after backoff...")
@ -583,7 +588,9 @@ func (s) TestV2ClientRetriesAfterBrokenStream(t *testing.T) {
t.Log("Started xds v2Client...") t.Log("Started xds v2Client...")
v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1) v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1)
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout expired when expecting an LDS request") t.Fatalf("Timeout expired when expecting an LDS request")
} }
t.Log("FakeServer received request...") t.Log("FakeServer received request...")
@ -591,20 +598,20 @@ func (s) TestV2ClientRetriesAfterBrokenStream(t *testing.T) {
fakeServer.XDSResponseChan <- &fakeserver.Response{Resp: goodLDSResponse1} fakeServer.XDSResponseChan <- &fakeserver.Response{Resp: goodLDSResponse1}
t.Log("Good LDS response pushed to fakeServer...") t.Log("Good LDS response pushed to fakeServer...")
if _, err := callbackCh.Receive(); err != nil { if _, err := callbackCh.Receive(ctx); err != nil {
t.Fatal("Timeout when expecting LDS update") t.Fatal("Timeout when expecting LDS update")
} }
// Read the ack, so the next request is sent after stream re-creation. // Read the ack, so the next request is sent after stream re-creation.
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout expired when expecting an LDS ACK") t.Fatalf("Timeout expired when expecting an LDS ACK")
} }
fakeServer.XDSResponseChan <- &fakeserver.Response{Err: errors.New("RPC error")} fakeServer.XDSResponseChan <- &fakeserver.Response{Err: errors.New("RPC error")}
t.Log("Bad LDS response pushed to fakeServer...") t.Log("Bad LDS response pushed to fakeServer...")
val, err := fakeServer.XDSRequestChan.Receive() val, err := fakeServer.XDSRequestChan.Receive(ctx)
if err == testutils.ErrRecvTimeout { if err == context.DeadlineExceeded {
t.Fatalf("Timeout expired when expecting LDS update") t.Fatalf("Timeout expired when expecting LDS update")
} }
gotRequest := val.(*fakeserver.Request) gotRequest := val.(*fakeserver.Request)
@ -657,7 +664,9 @@ func (s) TestV2ClientWatchWithoutStream(t *testing.T) {
v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1) v2c.AddWatch(xdsclient.ListenerResource, goodLDSTarget1)
// The watcher should receive an update, with a timeout error in it. // The watcher should receive an update, with a timeout error in it.
if v, err := callbackCh.TimedReceive(100 * time.Millisecond); err == nil { ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
if v, err := callbackCh.Receive(ctx); err == nil {
t.Fatalf("Expect an timeout error from watcher, got %v", v) t.Fatalf("Expect an timeout error from watcher, got %v", v)
} }
@ -667,7 +676,9 @@ func (s) TestV2ClientWatchWithoutStream(t *testing.T) {
Addresses: []resolver.Address{{Addr: fakeServer.Address}}, Addresses: []resolver.Address{{Addr: fakeServer.Address}},
}) })
if _, err := fakeServer.XDSRequestChan.Receive(); err != nil { ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := fakeServer.XDSRequestChan.Receive(ctx); err != nil {
t.Fatalf("Timeout expired when expecting an LDS request") t.Fatalf("Timeout expired when expecting an LDS request")
} }
t.Log("FakeServer received request...") t.Log("FakeServer received request...")
@ -675,7 +686,7 @@ func (s) TestV2ClientWatchWithoutStream(t *testing.T) {
fakeServer.XDSResponseChan <- &fakeserver.Response{Resp: goodLDSResponse1} fakeServer.XDSResponseChan <- &fakeserver.Response{Resp: goodLDSResponse1}
t.Log("Good LDS response pushed to fakeServer...") t.Log("Good LDS response pushed to fakeServer...")
if v, err := callbackCh.Receive(); err != nil { if v, err := callbackCh.Receive(ctx); err != nil {
t.Fatal("Timeout when expecting LDS update") t.Fatal("Timeout when expecting LDS update")
} else if _, ok := v.(xdsclient.ListenerUpdate); !ok { } else if _, ok := v.(xdsclient.ListenerUpdate); !ok {
t.Fatalf("Expect an LDS update from watcher, got %v", v) t.Fatalf("Expect an LDS update from watcher, got %v", v)

Просмотреть файл

@ -24,11 +24,13 @@ import (
"fmt" "fmt"
"net" "net"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/serviceconfig"
xdsinternal "google.golang.org/grpc/xds/internal" xdsinternal "google.golang.org/grpc/xds/internal"
@ -36,21 +38,22 @@ import (
"google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/client"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/client/bootstrap"
"google.golang.org/grpc/xds/internal/testutils" xdstestutils "google.golang.org/grpc/xds/internal/testutils"
"google.golang.org/grpc/xds/internal/testutils/fakeclient" "google.golang.org/grpc/xds/internal/testutils/fakeclient"
) )
const ( const (
targetStr = "target" targetStr = "target"
cluster = "cluster" cluster = "cluster"
balancerName = "dummyBalancer" balancerName = "dummyBalancer"
defaultTestTimeout = 1 * time.Second
) )
var ( var (
validConfig = bootstrap.Config{ validConfig = bootstrap.Config{
BalancerName: balancerName, BalancerName: balancerName,
Creds: grpc.WithInsecure(), Creds: grpc.WithInsecure(),
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
} }
target = resolver.Target{Endpoint: targetStr} target = resolver.Target{Endpoint: targetStr}
) )
@ -135,7 +138,7 @@ func TestResolverBuilder(t *testing.T) {
rbo: resolver.BuildOptions{}, rbo: resolver.BuildOptions{},
config: bootstrap.Config{ config: bootstrap.Config{
Creds: grpc.WithInsecure(), Creds: grpc.WithInsecure(),
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
}, },
wantErr: true, wantErr: true,
}, },
@ -144,7 +147,7 @@ func TestResolverBuilder(t *testing.T) {
rbo: resolver.BuildOptions{}, rbo: resolver.BuildOptions{},
config: bootstrap.Config{ config: bootstrap.Config{
BalancerName: balancerName, BalancerName: balancerName,
NodeProto: testutils.EmptyNodeProtoV2, NodeProto: xdstestutils.EmptyNodeProtoV2,
}, },
xdsClientFunc: getXDSClientMakerFunc(xdsclient.Options{Config: validConfig}), xdsClientFunc: getXDSClientMakerFunc(xdsclient.Options{Config: validConfig}),
wantErr: false, wantErr: false,
@ -248,7 +251,7 @@ func testSetup(t *testing.T, opts setupOpts) (*xdsResolver, *testClientConn, fun
func waitForWatchService(t *testing.T, xdsC *fakeclient.Client, wantTarget string) { func waitForWatchService(t *testing.T, xdsC *fakeclient.Client, wantTarget string) {
t.Helper() t.Helper()
gotTarget, err := xdsC.WaitForWatchService() gotTarget, err := xdsC.WaitForWatchService(context.Background())
if err != nil { if err != nil {
t.Fatalf("xdsClient.WatchService failed with error: %v", err) t.Fatalf("xdsClient.WatchService failed with error: %v", err)
} }
@ -273,7 +276,10 @@ func TestXDSResolverWatchCallbackAfterClose(t *testing.T) {
// update is triggerred on the ClientConn. // update is triggerred on the ClientConn.
xdsR.Close() xdsR.Close()
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{Routes: []*client.Route{{Prefix: newStringP(""), Action: map[string]uint32{cluster: 1}}}}, nil) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{Routes: []*client.Route{{Prefix: newStringP(""), Action: map[string]uint32{cluster: 1}}}}, nil)
if gotVal, gotErr := tcc.stateCh.Receive(); gotErr != testutils.ErrRecvTimeout {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if gotVal, gotErr := tcc.stateCh.Receive(ctx); gotErr != context.DeadlineExceeded {
t.Fatalf("ClientConn.UpdateState called after xdsResolver is closed: %v", gotVal) t.Fatalf("ClientConn.UpdateState called after xdsResolver is closed: %v", gotVal)
} }
} }
@ -297,7 +303,10 @@ func TestXDSResolverBadServiceUpdate(t *testing.T) {
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr := errors.New("bad serviceupdate") suErr := errors.New("bad serviceupdate")
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr)
} }
} }
@ -337,7 +346,10 @@ func TestXDSResolverGoodServiceUpdate(t *testing.T) {
// Invoke the watchAPI callback with a good service update and wait for the // Invoke the watchAPI callback with a good service update and wait for the
// UpdateState method to be called on the ClientConn. // UpdateState method to be called on the ClientConn.
xdsC.InvokeWatchServiceCallback(tt.su, nil) xdsC.InvokeWatchServiceCallback(tt.su, nil)
gotState, err := tcc.stateCh.Receive()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotState, err := tcc.stateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("ClientConn.UpdateState returned error: %v", err) t.Fatalf("ClientConn.UpdateState returned error: %v", err)
} }
@ -377,14 +389,17 @@ func TestXDSResolverGoodUpdateAfterError(t *testing.T) {
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr := errors.New("bad serviceupdate") suErr := errors.New("bad serviceupdate")
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr)
} }
// Invoke the watchAPI callback with a good service update and wait for the // Invoke the watchAPI callback with a good service update and wait for the
// UpdateState method to be called on the ClientConn. // UpdateState method to be called on the ClientConn.
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{Routes: []*client.Route{{Prefix: newStringP(""), Action: map[string]uint32{cluster: 1}}}}, nil) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{Routes: []*client.Route{{Prefix: newStringP(""), Action: map[string]uint32{cluster: 1}}}}, nil)
gotState, err := tcc.stateCh.Receive() gotState, err := tcc.stateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("ClientConn.UpdateState returned error: %v", err) t.Fatalf("ClientConn.UpdateState returned error: %v", err)
} }
@ -400,7 +415,7 @@ func TestXDSResolverGoodUpdateAfterError(t *testing.T) {
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr2 := errors.New("bad serviceupdate 2") suErr2 := errors.New("bad serviceupdate 2")
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr2) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr2)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != nil || gotErrVal != suErr2 { if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != nil || gotErrVal != suErr2 {
t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2) t.Fatalf("ClientConn.ReportError() received %v, want %v", gotErrVal, suErr2)
} }
} }
@ -425,10 +440,16 @@ func TestXDSResolverResourceNotFoundError(t *testing.T) {
// ReportError method to be called on the ClientConn. // ReportError method to be called on the ClientConn.
suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error") suErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "resource removed error")
xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr) xdsC.InvokeWatchServiceCallback(xdsclient.ServiceUpdate{}, suErr)
if gotErrVal, gotErr := tcc.errorCh.Receive(); gotErr != testutils.ErrRecvTimeout {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if gotErrVal, gotErr := tcc.errorCh.Receive(ctx); gotErr != context.DeadlineExceeded {
t.Fatalf("ClientConn.ReportError() received %v, %v, want channel recv timeout", gotErrVal, gotErr) t.Fatalf("ClientConn.ReportError() received %v, %v, want channel recv timeout", gotErrVal, gotErr)
} }
gotState, err := tcc.stateCh.Receive()
ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
gotState, err := tcc.stateCh.Receive(ctx)
if err != nil { if err != nil {
t.Fatalf("ClientConn.UpdateState returned error: %v", err) t.Fatalf("ClientConn.UpdateState returned error: %v", err)
} }

Просмотреть файл

@ -16,6 +16,7 @@
* *
*/ */
// Package testutils provides utility types, for use in xds tests.
package testutils package testutils
import ( import (

Просмотреть файл

@ -1,87 +0,0 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Package testutils provides utility types, for use in xds tests.
package testutils
import (
"errors"
"time"
)
// ErrRecvTimeout is an error to indicate that a receive operation on the
// channel timed out.
var ErrRecvTimeout = errors.New("timed out when waiting for value on channel")
const (
// DefaultChanRecvTimeout is the default timeout for receive operations on the
// underlying channel.
DefaultChanRecvTimeout = 1 * time.Second
// DefaultChanBufferSize is the default buffer size of the underlying channel.
DefaultChanBufferSize = 1
)
// Channel wraps a generic channel and provides a timed receive operation.
type Channel struct {
ch chan interface{}
}
// Send sends value on the underlying channel.
func (cwt *Channel) Send(value interface{}) {
cwt.ch <- value
}
// Replace clears the value on the underlying channel, and sends the new value.
//
// It's expected to be used with a size-1 channel, to only keep the most
// up-to-date item.
func (cwt *Channel) Replace(value interface{}) {
select {
case <-cwt.ch:
default:
}
cwt.ch <- value
}
// TimedReceive returns the value received on the underlying channel, or
// ErrRecvTimeout if timeout amount of time elapsed.
func (cwt *Channel) TimedReceive(timeout time.Duration) (interface{}, error) {
timer := time.NewTimer(timeout)
select {
case <-timer.C:
return nil, ErrRecvTimeout
case got := <-cwt.ch:
timer.Stop()
return got, nil
}
}
// Receive returns the value received on the underlying channel, or
// ErrRecvTimeout if DefaultChanRecvTimeout amount of time elapses.
func (cwt *Channel) Receive() (interface{}, error) {
return cwt.TimedReceive(DefaultChanRecvTimeout)
}
// NewChannel returns a new Channel.
func NewChannel() *Channel {
return NewChannelWithSize(DefaultChanBufferSize)
}
// NewChannelWithSize returns a new Channel with a buffer of bufSize.
func NewChannelWithSize(bufSize int) *Channel {
return &Channel{ch: make(chan interface{}, bufSize)}
}

Просмотреть файл

@ -20,11 +20,12 @@
package fakeclient package fakeclient
import ( import (
"context"
"sync" "sync"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/xds/internal/balancer/lrs" "google.golang.org/grpc/xds/internal/balancer/lrs"
xdsclient "google.golang.org/grpc/xds/internal/client" xdsclient "google.golang.org/grpc/xds/internal/client"
"google.golang.org/grpc/xds/internal/testutils"
) )
// Client is a fake implementation of an xds client. It exposes a bunch of // Client is a fake implementation of an xds client. It exposes a bunch of
@ -58,10 +59,10 @@ func (xdsC *Client) WatchService(target string, callback func(xdsclient.ServiceU
} }
} }
// WaitForWatchService waits for WatchService to be invoked on this client // WaitForWatchService waits for WatchService to be invoked on this client and
// within a reasonable timeout, and returns the serviceName being watched. // returns the serviceName being watched.
func (xdsC *Client) WaitForWatchService() (string, error) { func (xdsC *Client) WaitForWatchService(ctx context.Context) (string, error) {
val, err := xdsC.suWatchCh.Receive() val, err := xdsC.suWatchCh.Receive(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -88,10 +89,10 @@ func (xdsC *Client) WatchCluster(clusterName string, callback func(xdsclient.Clu
} }
} }
// WaitForWatchCluster waits for WatchCluster to be invoked on this client // WaitForWatchCluster waits for WatchCluster to be invoked on this client and
// within a reasonable timeout, and returns the clusterName being watched. // returns the clusterName being watched.
func (xdsC *Client) WaitForWatchCluster() (string, error) { func (xdsC *Client) WaitForWatchCluster(ctx context.Context) (string, error) {
val, err := xdsC.cdsWatchCh.Receive() val, err := xdsC.cdsWatchCh.Receive(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -106,10 +107,10 @@ func (xdsC *Client) InvokeWatchClusterCallback(update xdsclient.ClusterUpdate, e
xdsC.cdsCb(update, err) xdsC.cdsCb(update, err)
} }
// WaitForCancelClusterWatch waits for a CDS watch to be cancelled within a // WaitForCancelClusterWatch waits for a CDS watch to be cancelled and returns
// reasonable timeout, and returns testutils.ErrRecvTimeout otherwise. // context.DeadlineExceeded otherwise.
func (xdsC *Client) WaitForCancelClusterWatch() error { func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) error {
_, err := xdsC.cdsCancelCh.Receive() _, err := xdsC.cdsCancelCh.Receive(ctx)
return err return err
} }
@ -125,10 +126,10 @@ func (xdsC *Client) WatchEndpoints(clusterName string, callback func(xdsclient.E
} }
} }
// WaitForWatchEDS waits for WatchEndpoints to be invoked on this client within a // WaitForWatchEDS waits for WatchEndpoints to be invoked on this client and
// reasonable timeout, and returns the clusterName being watched. // returns the clusterName being watched.
func (xdsC *Client) WaitForWatchEDS() (string, error) { func (xdsC *Client) WaitForWatchEDS(ctx context.Context) (string, error) {
val, err := xdsC.edsWatchCh.Receive() val, err := xdsC.edsWatchCh.Receive(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -143,10 +144,10 @@ func (xdsC *Client) InvokeWatchEDSCallback(update xdsclient.EndpointsUpdate, err
xdsC.edsCb(update, err) xdsC.edsCb(update, err)
} }
// WaitForCancelEDSWatch waits for a EDS watch to be cancelled within a // WaitForCancelEDSWatch waits for a EDS watch to be cancelled and returns
// reasonable timeout, and returns testutils.ErrRecvTimeout otherwise. // context.DeadlineExceeded otherwise.
func (xdsC *Client) WaitForCancelEDSWatch() error { func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) error {
_, err := xdsC.edsCancelCh.Receive() _, err := xdsC.edsCancelCh.Receive(ctx)
return err return err
} }
@ -164,10 +165,10 @@ func (xdsC *Client) ReportLoad(server string, clusterName string, loadStore lrs.
return func() {} return func() {}
} }
// WaitForReportLoad waits for ReportLoad to be invoked on this client within a // WaitForReportLoad waits for ReportLoad to be invoked on this client and
// reasonable timeout, and returns the arguments passed to it. // returns the arguments passed to it.
func (xdsC *Client) WaitForReportLoad() (ReportLoadArgs, error) { func (xdsC *Client) WaitForReportLoad(ctx context.Context) (ReportLoadArgs, error) {
val, err := xdsC.loadReportCh.Receive() val, err := xdsC.loadReportCh.Receive(ctx)
return val.(ReportLoadArgs), err return val.(ReportLoadArgs), err
} }
@ -176,10 +177,10 @@ func (xdsC *Client) Close() {
xdsC.closeCh.Send(nil) xdsC.closeCh.Send(nil)
} }
// WaitForClose waits for Close to be invoked on this client within a // WaitForClose waits for Close to be invoked on this client and returns
// reasonable timeout, and returns testutils.ErrRecvTimeout otherwise. // context.DeadlineExceeded otherwise.
func (xdsC *Client) WaitForClose() error { func (xdsC *Client) WaitForClose(ctx context.Context) error {
_, err := xdsC.closeCh.Receive() _, err := xdsC.closeCh.Receive(ctx)
return err return err
} }

Просмотреть файл

@ -29,8 +29,8 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/grpc/xds/internal/testutils"
discoverypb "github.com/envoyproxy/go-control-plane/envoy/api/v2" discoverypb "github.com/envoyproxy/go-control-plane/envoy/api/v2"
adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2" adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2"