This commit is contained in:
Alain Jobart 2015-08-05 16:43:39 -07:00
Родитель 63fdce3438
Коммит 9c7241c70d
3 изменённых файлов: 114 добавлений и 46 удалений

Просмотреть файл

@ -143,6 +143,7 @@ func (conn *TabletBson) Execute2(ctx context.Context, query string, bindVars map
}
req := &tproto.ExecuteRequest{
Target: conn.target,
QueryRequest: tproto.Query{
Sql: query,
BindVariables: bindVars,
@ -206,6 +207,7 @@ func (conn *TabletBson) ExecuteBatch2(ctx context.Context, queries []tproto.Boun
}
req := tproto.ExecuteBatchRequest{
Target: conn.target,
QueryBatch: tproto.QueryList{
Queries: queries,
AsTransaction: asTransaction,
@ -290,13 +292,15 @@ func (conn *TabletBson) StreamExecute2(ctx context.Context, query string, bindVa
return nil, nil, tabletconn.ConnClosed
}
q := &tproto.Query{
Sql: query,
BindVariables: bindVars,
TransactionId: transactionID,
SessionId: conn.sessionID,
req := &tproto.StreamExecuteRequest{
Target: conn.target,
Query: &tproto.Query{
Sql: query,
BindVariables: bindVars,
TransactionId: transactionID,
SessionId: conn.sessionID,
},
}
req := &tproto.StreamExecuteRequest{Query: q}
// Use QueryResult instead of StreamExecuteResult for now, due to backwards compatability reasons.
// It'll be easuer to migrate all end-points to using StreamExecuteResult instead of
// maintaining a mixture of QueryResult and StreamExecuteResult channel returns.
@ -373,6 +377,7 @@ func (conn *TabletBson) Begin2(ctx context.Context) (transactionID int64, err er
}
beginRequest := &tproto.BeginRequest{
Target: conn.target,
SessionId: conn.sessionID,
}
beginResponse := new(tproto.BeginResponse)
@ -418,6 +423,7 @@ func (conn *TabletBson) Commit2(ctx context.Context, transactionID int64) error
}
commitRequest := &tproto.CommitRequest{
Target: conn.target,
SessionId: conn.sessionID,
TransactionId: transactionID,
}
@ -464,6 +470,7 @@ func (conn *TabletBson) Rollback2(ctx context.Context, transactionID int64) erro
}
rollbackRequest := &tproto.RollbackRequest{
Target: conn.target,
SessionId: conn.sessionID,
TransactionId: transactionID,
}
@ -489,6 +496,7 @@ func (conn *TabletBson) SplitQuery(ctx context.Context, query tproto.BoundQuery,
return
}
req := &tproto.SplitQueryRequest{
Target: conn.target,
Query: query,
SplitColumn: splitColumn,
SplitCount: splitCount,

Просмотреть файл

@ -88,6 +88,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, query string, bindVars
}
req := &pb.ExecuteRequest{
Target: conn.target,
Query: tproto.BoundQueryToProto3(query, bindVars),
TransactionId: transactionID,
SessionId: conn.sessionID,
@ -113,6 +114,7 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto.
}
req := &pb.ExecuteBatchRequest{
Target: conn.target,
Queries: make([]*pb.BoundQuery, len(queries)),
AsTransaction: asTransaction,
TransactionId: transactionID,
@ -142,6 +144,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi
}
req := &pb.StreamExecuteRequest{
Target: conn.target,
Query: tproto.BoundQueryToProto3(query, bindVars),
SessionId: conn.sessionID,
}
@ -183,6 +186,7 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context) (transactionID int64, er
}
req := &pb.BeginRequest{
Target: conn.target,
SessionId: conn.sessionID,
}
br, err := conn.c.Begin(ctx, req)
@ -206,6 +210,7 @@ func (conn *gRPCQueryClient) Commit(ctx context.Context, transactionID int64) er
}
req := &pb.CommitRequest{
Target: conn.target,
TransactionId: transactionID,
SessionId: conn.sessionID,
}
@ -230,6 +235,7 @@ func (conn *gRPCQueryClient) Rollback(ctx context.Context, transactionID int64)
}
req := &pb.RollbackRequest{
Target: conn.target,
TransactionId: transactionID,
SessionId: conn.sessionID,
}
@ -255,6 +261,7 @@ func (conn *gRPCQueryClient) SplitQuery(ctx context.Context, query tproto.BoundQ
}
req := &pb.SplitQueryRequest{
Target: conn.target,
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
SplitColumn: splitColumn,
SplitCount: int64(splitCount),

Просмотреть файл

@ -29,6 +29,9 @@ type FakeQueryService struct {
panics bool
streamExecutePanicsEarly bool
panicWait chan struct{}
// if set, we will also check Target, ImmediateCallerId and EffectiveCallerId
checkExtraFields bool
}
// HandlePanic is part of the queryservice.QueryService interface
@ -38,26 +41,30 @@ func (f *FakeQueryService) HandlePanic(err *error) {
}
}
// TestKeyspace is the Keyspace we use for this test
const TestKeyspace = "test_keyspace"
// TestShard is the Shard we use for this test
const TestShard = "test_shard"
// TestTabletType is the TabletType we use for this test
const TestTabletType = pbt.TabletType_UNKNOWN
// testTarget is the target we use for this test
var testTarget = &pb.Target{
Keyspace: "test_keyspace",
Shard: "test_shard",
TabletType: pbt.TabletType_REPLICA,
}
const testAsTransaction bool = true
const testSessionID int64 = 5678
func (f *FakeQueryService) checkTarget(name string, target *pb.Target) {
if !reflect.DeepEqual(target, testTarget) {
f.t.Errorf("invalid Target for %v: for %#v expected %#v", name, target, testTarget)
}
}
// GetSessionId is part of the queryservice.QueryService interface
func (f *FakeQueryService) GetSessionId(sessionParams *proto.SessionParams, sessionInfo *proto.SessionInfo) error {
if sessionParams.Keyspace != TestKeyspace {
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, TestKeyspace)
if sessionParams.Keyspace != testTarget.Keyspace {
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, testTarget.Keyspace)
}
if sessionParams.Shard != TestShard {
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, TestShard)
if sessionParams.Shard != testTarget.Shard {
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, testTarget.Shard)
}
sessionInfo.SessionId = testSessionID
return nil
@ -68,8 +75,12 @@ func (f *FakeQueryService) Begin(ctx context.Context, target *pb.Target, session
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("Begin", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != 0 {
f.t.Errorf("Begin: invalid TransactionId: got %v expected 0", session.TransactionId)
@ -125,8 +136,12 @@ func (f *FakeQueryService) Commit(ctx context.Context, target *pb.Target, sessio
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("Commit", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != commitTransactionID {
f.t.Errorf("Commit: invalid TransactionId: got %v expected %v", session.TransactionId, commitTransactionID)
@ -175,8 +190,12 @@ func (f *FakeQueryService) Rollback(ctx context.Context, target *pb.Target, sess
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("Rollback", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != rollbackTransactionID {
f.t.Errorf("Rollback: invalid TransactionId: got %v expected %v", session.TransactionId, rollbackTransactionID)
@ -231,8 +250,12 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query
if !reflect.DeepEqual(query.BindVariables, executeBindVars) {
f.t.Errorf("invalid Execute.Query.BindVariables: got %v expected %v", query.BindVariables, executeBindVars)
}
if query.SessionId != testSessionID {
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("Execute", target)
} else {
if query.SessionId != testSessionID {
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
}
}
if query.TransactionId != executeTransactionID {
f.t.Errorf("invalid Execute.Query.TransactionId: got %v expected %v", query.TransactionId, executeTransactionID)
@ -325,8 +348,12 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target,
if !reflect.DeepEqual(query.BindVariables, streamExecuteBindVars) {
f.t.Errorf("invalid StreamExecute.Query.BindVariables: got %v expected %v", query.BindVariables, streamExecuteBindVars)
}
if query.SessionId != testSessionID {
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("StreamExecute", target)
} else {
if query.SessionId != testSessionID {
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
}
}
if err := sendReply(&streamExecuteQueryResult1); err != nil {
f.t.Errorf("sendReply1 failed: %v", err)
@ -553,8 +580,12 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *pb.Target,
if !reflect.DeepEqual(queryList.Queries, executeBatchQueries) {
f.t.Errorf("invalid ExecuteBatch.QueryList.Queries: got %v expected %v", queryList.Queries, executeBatchQueries)
}
if queryList.SessionId != testSessionID {
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTarget("ExecuteBatch", target)
} else {
if queryList.SessionId != testSessionID {
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
}
}
if queryList.AsTransaction != testAsTransaction {
f.t.Errorf("invalid ExecuteBatch.QueryList.AsTransaction: got %v expected %v", queryList.AsTransaction, testAsTransaction)
@ -667,6 +698,9 @@ func (f *FakeQueryService) SplitQuery(ctx context.Context, target *pb.Target, re
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if f.checkExtraFields {
f.checkTarget("SplitQuery", target)
}
if !reflect.DeepEqual(req.Query, splitQueryBoundQuery) {
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", req.Query, splitQueryBoundQuery)
}
@ -828,45 +862,64 @@ func TestSuite(t *testing.T, protocol string, endPoint *pbt.EndPoint, fake *Fake
// make sure we use the right client
*tabletconn.TabletProtocol = protocol
// create a connection
// create a connection, using sessionId
ctx := context.Background()
conn, err := tabletconn.GetDialer()(ctx, endPoint, TestKeyspace, TestShard, TestTabletType, 30*time.Second)
conn, err := tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
// run the normal tests
testBegin(t, conn)
testBegin2(t, conn)
testCommit(t, conn)
testCommit2(t, conn)
testRollback(t, conn)
testRollback2(t, conn)
testExecute(t, conn)
testExecute2(t, conn)
testStreamExecute(t, conn)
testStreamExecute2(t, conn)
testExecuteBatch(t, conn)
testExecuteBatch2(t, conn)
testSplitQuery(t, conn)
testStreamHealth(t, conn)
// force panics, make sure they're caught
// create a new connection that expects the extra fields
conn.Close()
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_REPLICA, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
// run the tests that expect extra fields
fake.checkExtraFields = true
testBegin2(t, conn)
testCommit2(t, conn)
testRollback2(t, conn)
testExecute2(t, conn)
testStreamExecute2(t, conn)
testExecuteBatch2(t, conn)
testSplitQuery(t, conn)
// force panics, make sure they're caught (with extra fields)
fake.panics = true
testBeginPanics(t, conn)
testBegin2Panics(t, conn)
testCommitPanics(t, conn)
testCommit2Panics(t, conn)
testRollbackPanics(t, conn)
testRollback2Panics(t, conn)
testExecutePanics(t, conn)
testExecute2Panics(t, conn)
testStreamExecutePanics(t, conn, fake)
testStreamExecute2Panics(t, conn, fake)
testExecuteBatchPanics(t, conn)
testExecuteBatch2Panics(t, conn)
testSplitQueryPanics(t, conn)
testStreamHealthPanics(t, conn)
conn.Close()
// force panic without extra fields
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
fake.checkExtraFields = false
testBeginPanics(t, conn)
testCommitPanics(t, conn)
testRollbackPanics(t, conn)
testExecutePanics(t, conn)
testExecuteBatchPanics(t, conn)
testStreamExecutePanics(t, conn, fake)
fake.panics = false
conn.Close()
}