diff --git a/go/vt/tabletserver/gorpctabletconn/conn.go b/go/vt/tabletserver/gorpctabletconn/conn.go index b0cf6c4811..57aa8bb533 100644 --- a/go/vt/tabletserver/gorpctabletconn/conn.go +++ b/go/vt/tabletserver/gorpctabletconn/conn.go @@ -143,6 +143,7 @@ func (conn *TabletBson) Execute2(ctx context.Context, query string, bindVars map } req := &tproto.ExecuteRequest{ + Target: conn.target, QueryRequest: tproto.Query{ Sql: query, BindVariables: bindVars, @@ -206,6 +207,7 @@ func (conn *TabletBson) ExecuteBatch2(ctx context.Context, queries []tproto.Boun } req := tproto.ExecuteBatchRequest{ + Target: conn.target, QueryBatch: tproto.QueryList{ Queries: queries, AsTransaction: asTransaction, @@ -290,13 +292,15 @@ func (conn *TabletBson) StreamExecute2(ctx context.Context, query string, bindVa return nil, nil, tabletconn.ConnClosed } - q := &tproto.Query{ - Sql: query, - BindVariables: bindVars, - TransactionId: transactionID, - SessionId: conn.sessionID, + req := &tproto.StreamExecuteRequest{ + Target: conn.target, + Query: &tproto.Query{ + Sql: query, + BindVariables: bindVars, + TransactionId: transactionID, + SessionId: conn.sessionID, + }, } - req := &tproto.StreamExecuteRequest{Query: q} // Use QueryResult instead of StreamExecuteResult for now, due to backwards compatability reasons. // It'll be easuer to migrate all end-points to using StreamExecuteResult instead of // maintaining a mixture of QueryResult and StreamExecuteResult channel returns. @@ -373,6 +377,7 @@ func (conn *TabletBson) Begin2(ctx context.Context) (transactionID int64, err er } beginRequest := &tproto.BeginRequest{ + Target: conn.target, SessionId: conn.sessionID, } beginResponse := new(tproto.BeginResponse) @@ -418,6 +423,7 @@ func (conn *TabletBson) Commit2(ctx context.Context, transactionID int64) error } commitRequest := &tproto.CommitRequest{ + Target: conn.target, SessionId: conn.sessionID, TransactionId: transactionID, } @@ -464,6 +470,7 @@ func (conn *TabletBson) Rollback2(ctx context.Context, transactionID int64) erro } rollbackRequest := &tproto.RollbackRequest{ + Target: conn.target, SessionId: conn.sessionID, TransactionId: transactionID, } @@ -489,6 +496,7 @@ func (conn *TabletBson) SplitQuery(ctx context.Context, query tproto.BoundQuery, return } req := &tproto.SplitQueryRequest{ + Target: conn.target, Query: query, SplitColumn: splitColumn, SplitCount: splitCount, diff --git a/go/vt/tabletserver/grpctabletconn/conn.go b/go/vt/tabletserver/grpctabletconn/conn.go index 495c4ca8e5..b419451d11 100644 --- a/go/vt/tabletserver/grpctabletconn/conn.go +++ b/go/vt/tabletserver/grpctabletconn/conn.go @@ -88,6 +88,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, query string, bindVars } req := &pb.ExecuteRequest{ + Target: conn.target, Query: tproto.BoundQueryToProto3(query, bindVars), TransactionId: transactionID, SessionId: conn.sessionID, @@ -113,6 +114,7 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto. } req := &pb.ExecuteBatchRequest{ + Target: conn.target, Queries: make([]*pb.BoundQuery, len(queries)), AsTransaction: asTransaction, TransactionId: transactionID, @@ -142,6 +144,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi } req := &pb.StreamExecuteRequest{ + Target: conn.target, Query: tproto.BoundQueryToProto3(query, bindVars), SessionId: conn.sessionID, } @@ -183,6 +186,7 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context) (transactionID int64, er } req := &pb.BeginRequest{ + Target: conn.target, SessionId: conn.sessionID, } br, err := conn.c.Begin(ctx, req) @@ -206,6 +210,7 @@ func (conn *gRPCQueryClient) Commit(ctx context.Context, transactionID int64) er } req := &pb.CommitRequest{ + Target: conn.target, TransactionId: transactionID, SessionId: conn.sessionID, } @@ -230,6 +235,7 @@ func (conn *gRPCQueryClient) Rollback(ctx context.Context, transactionID int64) } req := &pb.RollbackRequest{ + Target: conn.target, TransactionId: transactionID, SessionId: conn.sessionID, } @@ -255,6 +261,7 @@ func (conn *gRPCQueryClient) SplitQuery(ctx context.Context, query tproto.BoundQ } req := &pb.SplitQueryRequest{ + Target: conn.target, Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables), SplitColumn: splitColumn, SplitCount: int64(splitCount), diff --git a/go/vt/tabletserver/tabletconntest/tabletconntest.go b/go/vt/tabletserver/tabletconntest/tabletconntest.go index 7885f71be8..50caf0c76d 100644 --- a/go/vt/tabletserver/tabletconntest/tabletconntest.go +++ b/go/vt/tabletserver/tabletconntest/tabletconntest.go @@ -29,6 +29,9 @@ type FakeQueryService struct { panics bool streamExecutePanicsEarly bool panicWait chan struct{} + + // if set, we will also check Target, ImmediateCallerId and EffectiveCallerId + checkExtraFields bool } // HandlePanic is part of the queryservice.QueryService interface @@ -38,26 +41,30 @@ func (f *FakeQueryService) HandlePanic(err *error) { } } -// TestKeyspace is the Keyspace we use for this test -const TestKeyspace = "test_keyspace" - -// TestShard is the Shard we use for this test -const TestShard = "test_shard" - -// TestTabletType is the TabletType we use for this test -const TestTabletType = pbt.TabletType_UNKNOWN +// testTarget is the target we use for this test +var testTarget = &pb.Target{ + Keyspace: "test_keyspace", + Shard: "test_shard", + TabletType: pbt.TabletType_REPLICA, +} const testAsTransaction bool = true const testSessionID int64 = 5678 +func (f *FakeQueryService) checkTarget(name string, target *pb.Target) { + if !reflect.DeepEqual(target, testTarget) { + f.t.Errorf("invalid Target for %v: for %#v expected %#v", name, target, testTarget) + } +} + // GetSessionId is part of the queryservice.QueryService interface func (f *FakeQueryService) GetSessionId(sessionParams *proto.SessionParams, sessionInfo *proto.SessionInfo) error { - if sessionParams.Keyspace != TestKeyspace { - f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, TestKeyspace) + if sessionParams.Keyspace != testTarget.Keyspace { + f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, testTarget.Keyspace) } - if sessionParams.Shard != TestShard { - f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, TestShard) + if sessionParams.Shard != testTarget.Shard { + f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, testTarget.Shard) } sessionInfo.SessionId = testSessionID return nil @@ -68,8 +75,12 @@ func (f *FakeQueryService) Begin(ctx context.Context, target *pb.Target, session if f.panics { panic(fmt.Errorf("test-triggered panic")) } - if session.SessionId != testSessionID { - f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("Begin", target) + } else { + if session.SessionId != testSessionID { + f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + } } if session.TransactionId != 0 { f.t.Errorf("Begin: invalid TransactionId: got %v expected 0", session.TransactionId) @@ -125,8 +136,12 @@ func (f *FakeQueryService) Commit(ctx context.Context, target *pb.Target, sessio if f.panics { panic(fmt.Errorf("test-triggered panic")) } - if session.SessionId != testSessionID { - f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("Commit", target) + } else { + if session.SessionId != testSessionID { + f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + } } if session.TransactionId != commitTransactionID { f.t.Errorf("Commit: invalid TransactionId: got %v expected %v", session.TransactionId, commitTransactionID) @@ -175,8 +190,12 @@ func (f *FakeQueryService) Rollback(ctx context.Context, target *pb.Target, sess if f.panics { panic(fmt.Errorf("test-triggered panic")) } - if session.SessionId != testSessionID { - f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("Rollback", target) + } else { + if session.SessionId != testSessionID { + f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID) + } } if session.TransactionId != rollbackTransactionID { f.t.Errorf("Rollback: invalid TransactionId: got %v expected %v", session.TransactionId, rollbackTransactionID) @@ -231,8 +250,12 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query if !reflect.DeepEqual(query.BindVariables, executeBindVars) { f.t.Errorf("invalid Execute.Query.BindVariables: got %v expected %v", query.BindVariables, executeBindVars) } - if query.SessionId != testSessionID { - f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("Execute", target) + } else { + if query.SessionId != testSessionID { + f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID) + } } if query.TransactionId != executeTransactionID { f.t.Errorf("invalid Execute.Query.TransactionId: got %v expected %v", query.TransactionId, executeTransactionID) @@ -325,8 +348,12 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target, if !reflect.DeepEqual(query.BindVariables, streamExecuteBindVars) { f.t.Errorf("invalid StreamExecute.Query.BindVariables: got %v expected %v", query.BindVariables, streamExecuteBindVars) } - if query.SessionId != testSessionID { - f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("StreamExecute", target) + } else { + if query.SessionId != testSessionID { + f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID) + } } if err := sendReply(&streamExecuteQueryResult1); err != nil { f.t.Errorf("sendReply1 failed: %v", err) @@ -553,8 +580,12 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *pb.Target, if !reflect.DeepEqual(queryList.Queries, executeBatchQueries) { f.t.Errorf("invalid ExecuteBatch.QueryList.Queries: got %v expected %v", queryList.Queries, executeBatchQueries) } - if queryList.SessionId != testSessionID { - f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID) + if f.checkExtraFields { + f.checkTarget("ExecuteBatch", target) + } else { + if queryList.SessionId != testSessionID { + f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID) + } } if queryList.AsTransaction != testAsTransaction { f.t.Errorf("invalid ExecuteBatch.QueryList.AsTransaction: got %v expected %v", queryList.AsTransaction, testAsTransaction) @@ -667,6 +698,9 @@ func (f *FakeQueryService) SplitQuery(ctx context.Context, target *pb.Target, re if f.panics { panic(fmt.Errorf("test-triggered panic")) } + if f.checkExtraFields { + f.checkTarget("SplitQuery", target) + } if !reflect.DeepEqual(req.Query, splitQueryBoundQuery) { f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", req.Query, splitQueryBoundQuery) } @@ -828,45 +862,64 @@ func TestSuite(t *testing.T, protocol string, endPoint *pbt.EndPoint, fake *Fake // make sure we use the right client *tabletconn.TabletProtocol = protocol - // create a connection + // create a connection, using sessionId ctx := context.Background() - conn, err := tabletconn.GetDialer()(ctx, endPoint, TestKeyspace, TestShard, TestTabletType, 30*time.Second) + conn, err := tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second) if err != nil { t.Fatalf("dial failed: %v", err) } - defer conn.Close() // run the normal tests testBegin(t, conn) - testBegin2(t, conn) testCommit(t, conn) - testCommit2(t, conn) testRollback(t, conn) - testRollback2(t, conn) testExecute(t, conn) - testExecute2(t, conn) testStreamExecute(t, conn) - testStreamExecute2(t, conn) testExecuteBatch(t, conn) - testExecuteBatch2(t, conn) testSplitQuery(t, conn) testStreamHealth(t, conn) - // force panics, make sure they're caught + // create a new connection that expects the extra fields + conn.Close() + conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_REPLICA, 30*time.Second) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + + // run the tests that expect extra fields + fake.checkExtraFields = true + testBegin2(t, conn) + testCommit2(t, conn) + testRollback2(t, conn) + testExecute2(t, conn) + testStreamExecute2(t, conn) + testExecuteBatch2(t, conn) + testSplitQuery(t, conn) + + // force panics, make sure they're caught (with extra fields) fake.panics = true - testBeginPanics(t, conn) testBegin2Panics(t, conn) - testCommitPanics(t, conn) testCommit2Panics(t, conn) - testRollbackPanics(t, conn) testRollback2Panics(t, conn) - testExecutePanics(t, conn) testExecute2Panics(t, conn) - testStreamExecutePanics(t, conn, fake) testStreamExecute2Panics(t, conn, fake) - testExecuteBatchPanics(t, conn) testExecuteBatch2Panics(t, conn) testSplitQueryPanics(t, conn) testStreamHealthPanics(t, conn) + conn.Close() + + // force panic without extra fields + conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + fake.checkExtraFields = false + testBeginPanics(t, conn) + testCommitPanics(t, conn) + testRollbackPanics(t, conn) + testExecutePanics(t, conn) + testExecuteBatchPanics(t, conn) + testStreamExecutePanics(t, conn, fake) fake.panics = false + conn.Close() }