зеркало из https://github.com/github/vitess-gh.git
Adding unit tests for target.
This commit is contained in:
Родитель
63fdce3438
Коммит
9c7241c70d
|
@ -143,6 +143,7 @@ func (conn *TabletBson) Execute2(ctx context.Context, query string, bindVars map
|
|||
}
|
||||
|
||||
req := &tproto.ExecuteRequest{
|
||||
Target: conn.target,
|
||||
QueryRequest: tproto.Query{
|
||||
Sql: query,
|
||||
BindVariables: bindVars,
|
||||
|
@ -206,6 +207,7 @@ func (conn *TabletBson) ExecuteBatch2(ctx context.Context, queries []tproto.Boun
|
|||
}
|
||||
|
||||
req := tproto.ExecuteBatchRequest{
|
||||
Target: conn.target,
|
||||
QueryBatch: tproto.QueryList{
|
||||
Queries: queries,
|
||||
AsTransaction: asTransaction,
|
||||
|
@ -290,13 +292,15 @@ func (conn *TabletBson) StreamExecute2(ctx context.Context, query string, bindVa
|
|||
return nil, nil, tabletconn.ConnClosed
|
||||
}
|
||||
|
||||
q := &tproto.Query{
|
||||
Sql: query,
|
||||
BindVariables: bindVars,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
req := &tproto.StreamExecuteRequest{
|
||||
Target: conn.target,
|
||||
Query: &tproto.Query{
|
||||
Sql: query,
|
||||
BindVariables: bindVars,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
},
|
||||
}
|
||||
req := &tproto.StreamExecuteRequest{Query: q}
|
||||
// Use QueryResult instead of StreamExecuteResult for now, due to backwards compatability reasons.
|
||||
// It'll be easuer to migrate all end-points to using StreamExecuteResult instead of
|
||||
// maintaining a mixture of QueryResult and StreamExecuteResult channel returns.
|
||||
|
@ -373,6 +377,7 @@ func (conn *TabletBson) Begin2(ctx context.Context) (transactionID int64, err er
|
|||
}
|
||||
|
||||
beginRequest := &tproto.BeginRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
beginResponse := new(tproto.BeginResponse)
|
||||
|
@ -418,6 +423,7 @@ func (conn *TabletBson) Commit2(ctx context.Context, transactionID int64) error
|
|||
}
|
||||
|
||||
commitRequest := &tproto.CommitRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
}
|
||||
|
@ -464,6 +470,7 @@ func (conn *TabletBson) Rollback2(ctx context.Context, transactionID int64) erro
|
|||
}
|
||||
|
||||
rollbackRequest := &tproto.RollbackRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
}
|
||||
|
@ -489,6 +496,7 @@ func (conn *TabletBson) SplitQuery(ctx context.Context, query tproto.BoundQuery,
|
|||
return
|
||||
}
|
||||
req := &tproto.SplitQueryRequest{
|
||||
Target: conn.target,
|
||||
Query: query,
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: splitCount,
|
||||
|
|
|
@ -88,6 +88,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, query string, bindVars
|
|||
}
|
||||
|
||||
req := &pb.ExecuteRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
|
@ -113,6 +114,7 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto.
|
|||
}
|
||||
|
||||
req := &pb.ExecuteBatchRequest{
|
||||
Target: conn.target,
|
||||
Queries: make([]*pb.BoundQuery, len(queries)),
|
||||
AsTransaction: asTransaction,
|
||||
TransactionId: transactionID,
|
||||
|
@ -142,6 +144,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi
|
|||
}
|
||||
|
||||
req := &pb.StreamExecuteRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
|
@ -183,6 +186,7 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context) (transactionID int64, er
|
|||
}
|
||||
|
||||
req := &pb.BeginRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
br, err := conn.c.Begin(ctx, req)
|
||||
|
@ -206,6 +210,7 @@ func (conn *gRPCQueryClient) Commit(ctx context.Context, transactionID int64) er
|
|||
}
|
||||
|
||||
req := &pb.CommitRequest{
|
||||
Target: conn.target,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
|
@ -230,6 +235,7 @@ func (conn *gRPCQueryClient) Rollback(ctx context.Context, transactionID int64)
|
|||
}
|
||||
|
||||
req := &pb.RollbackRequest{
|
||||
Target: conn.target,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
|
@ -255,6 +261,7 @@ func (conn *gRPCQueryClient) SplitQuery(ctx context.Context, query tproto.BoundQ
|
|||
}
|
||||
|
||||
req := &pb.SplitQueryRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: int64(splitCount),
|
||||
|
|
|
@ -29,6 +29,9 @@ type FakeQueryService struct {
|
|||
panics bool
|
||||
streamExecutePanicsEarly bool
|
||||
panicWait chan struct{}
|
||||
|
||||
// if set, we will also check Target, ImmediateCallerId and EffectiveCallerId
|
||||
checkExtraFields bool
|
||||
}
|
||||
|
||||
// HandlePanic is part of the queryservice.QueryService interface
|
||||
|
@ -38,26 +41,30 @@ func (f *FakeQueryService) HandlePanic(err *error) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestKeyspace is the Keyspace we use for this test
|
||||
const TestKeyspace = "test_keyspace"
|
||||
|
||||
// TestShard is the Shard we use for this test
|
||||
const TestShard = "test_shard"
|
||||
|
||||
// TestTabletType is the TabletType we use for this test
|
||||
const TestTabletType = pbt.TabletType_UNKNOWN
|
||||
// testTarget is the target we use for this test
|
||||
var testTarget = &pb.Target{
|
||||
Keyspace: "test_keyspace",
|
||||
Shard: "test_shard",
|
||||
TabletType: pbt.TabletType_REPLICA,
|
||||
}
|
||||
|
||||
const testAsTransaction bool = true
|
||||
|
||||
const testSessionID int64 = 5678
|
||||
|
||||
func (f *FakeQueryService) checkTarget(name string, target *pb.Target) {
|
||||
if !reflect.DeepEqual(target, testTarget) {
|
||||
f.t.Errorf("invalid Target for %v: for %#v expected %#v", name, target, testTarget)
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionId is part of the queryservice.QueryService interface
|
||||
func (f *FakeQueryService) GetSessionId(sessionParams *proto.SessionParams, sessionInfo *proto.SessionInfo) error {
|
||||
if sessionParams.Keyspace != TestKeyspace {
|
||||
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, TestKeyspace)
|
||||
if sessionParams.Keyspace != testTarget.Keyspace {
|
||||
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, testTarget.Keyspace)
|
||||
}
|
||||
if sessionParams.Shard != TestShard {
|
||||
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, TestShard)
|
||||
if sessionParams.Shard != testTarget.Shard {
|
||||
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, testTarget.Shard)
|
||||
}
|
||||
sessionInfo.SessionId = testSessionID
|
||||
return nil
|
||||
|
@ -68,8 +75,12 @@ func (f *FakeQueryService) Begin(ctx context.Context, target *pb.Target, session
|
|||
if f.panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Begin", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if session.TransactionId != 0 {
|
||||
f.t.Errorf("Begin: invalid TransactionId: got %v expected 0", session.TransactionId)
|
||||
|
@ -125,8 +136,12 @@ func (f *FakeQueryService) Commit(ctx context.Context, target *pb.Target, sessio
|
|||
if f.panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Commit", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if session.TransactionId != commitTransactionID {
|
||||
f.t.Errorf("Commit: invalid TransactionId: got %v expected %v", session.TransactionId, commitTransactionID)
|
||||
|
@ -175,8 +190,12 @@ func (f *FakeQueryService) Rollback(ctx context.Context, target *pb.Target, sess
|
|||
if f.panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Rollback", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if session.TransactionId != rollbackTransactionID {
|
||||
f.t.Errorf("Rollback: invalid TransactionId: got %v expected %v", session.TransactionId, rollbackTransactionID)
|
||||
|
@ -231,8 +250,12 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query
|
|||
if !reflect.DeepEqual(query.BindVariables, executeBindVars) {
|
||||
f.t.Errorf("invalid Execute.Query.BindVariables: got %v expected %v", query.BindVariables, executeBindVars)
|
||||
}
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Execute", target)
|
||||
} else {
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if query.TransactionId != executeTransactionID {
|
||||
f.t.Errorf("invalid Execute.Query.TransactionId: got %v expected %v", query.TransactionId, executeTransactionID)
|
||||
|
@ -325,8 +348,12 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target,
|
|||
if !reflect.DeepEqual(query.BindVariables, streamExecuteBindVars) {
|
||||
f.t.Errorf("invalid StreamExecute.Query.BindVariables: got %v expected %v", query.BindVariables, streamExecuteBindVars)
|
||||
}
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("StreamExecute", target)
|
||||
} else {
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if err := sendReply(&streamExecuteQueryResult1); err != nil {
|
||||
f.t.Errorf("sendReply1 failed: %v", err)
|
||||
|
@ -553,8 +580,12 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *pb.Target,
|
|||
if !reflect.DeepEqual(queryList.Queries, executeBatchQueries) {
|
||||
f.t.Errorf("invalid ExecuteBatch.QueryList.Queries: got %v expected %v", queryList.Queries, executeBatchQueries)
|
||||
}
|
||||
if queryList.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("ExecuteBatch", target)
|
||||
} else {
|
||||
if queryList.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
|
||||
}
|
||||
}
|
||||
if queryList.AsTransaction != testAsTransaction {
|
||||
f.t.Errorf("invalid ExecuteBatch.QueryList.AsTransaction: got %v expected %v", queryList.AsTransaction, testAsTransaction)
|
||||
|
@ -667,6 +698,9 @@ func (f *FakeQueryService) SplitQuery(ctx context.Context, target *pb.Target, re
|
|||
if f.panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("SplitQuery", target)
|
||||
}
|
||||
if !reflect.DeepEqual(req.Query, splitQueryBoundQuery) {
|
||||
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", req.Query, splitQueryBoundQuery)
|
||||
}
|
||||
|
@ -828,45 +862,64 @@ func TestSuite(t *testing.T, protocol string, endPoint *pbt.EndPoint, fake *Fake
|
|||
// make sure we use the right client
|
||||
*tabletconn.TabletProtocol = protocol
|
||||
|
||||
// create a connection
|
||||
// create a connection, using sessionId
|
||||
ctx := context.Background()
|
||||
conn, err := tabletconn.GetDialer()(ctx, endPoint, TestKeyspace, TestShard, TestTabletType, 30*time.Second)
|
||||
conn, err := tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// run the normal tests
|
||||
testBegin(t, conn)
|
||||
testBegin2(t, conn)
|
||||
testCommit(t, conn)
|
||||
testCommit2(t, conn)
|
||||
testRollback(t, conn)
|
||||
testRollback2(t, conn)
|
||||
testExecute(t, conn)
|
||||
testExecute2(t, conn)
|
||||
testStreamExecute(t, conn)
|
||||
testStreamExecute2(t, conn)
|
||||
testExecuteBatch(t, conn)
|
||||
testExecuteBatch2(t, conn)
|
||||
testSplitQuery(t, conn)
|
||||
testStreamHealth(t, conn)
|
||||
|
||||
// force panics, make sure they're caught
|
||||
// create a new connection that expects the extra fields
|
||||
conn.Close()
|
||||
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_REPLICA, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
}
|
||||
|
||||
// run the tests that expect extra fields
|
||||
fake.checkExtraFields = true
|
||||
testBegin2(t, conn)
|
||||
testCommit2(t, conn)
|
||||
testRollback2(t, conn)
|
||||
testExecute2(t, conn)
|
||||
testStreamExecute2(t, conn)
|
||||
testExecuteBatch2(t, conn)
|
||||
testSplitQuery(t, conn)
|
||||
|
||||
// force panics, make sure they're caught (with extra fields)
|
||||
fake.panics = true
|
||||
testBeginPanics(t, conn)
|
||||
testBegin2Panics(t, conn)
|
||||
testCommitPanics(t, conn)
|
||||
testCommit2Panics(t, conn)
|
||||
testRollbackPanics(t, conn)
|
||||
testRollback2Panics(t, conn)
|
||||
testExecutePanics(t, conn)
|
||||
testExecute2Panics(t, conn)
|
||||
testStreamExecutePanics(t, conn, fake)
|
||||
testStreamExecute2Panics(t, conn, fake)
|
||||
testExecuteBatchPanics(t, conn)
|
||||
testExecuteBatch2Panics(t, conn)
|
||||
testSplitQueryPanics(t, conn)
|
||||
testStreamHealthPanics(t, conn)
|
||||
conn.Close()
|
||||
|
||||
// force panic without extra fields
|
||||
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
}
|
||||
fake.checkExtraFields = false
|
||||
testBeginPanics(t, conn)
|
||||
testCommitPanics(t, conn)
|
||||
testRollbackPanics(t, conn)
|
||||
testExecutePanics(t, conn)
|
||||
testExecuteBatchPanics(t, conn)
|
||||
testStreamExecutePanics(t, conn, fake)
|
||||
fake.panics = false
|
||||
conn.Close()
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче