зеркало из https://github.com/github/vitess-gh.git
Adding tests for caller ids, and fixing grpc / bson to they work.
This commit is contained in:
Родитель
9c7241c70d
Коммит
ed0c4c9cf5
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/youtube/vitess/go/netutil"
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
"github.com/youtube/vitess/go/vt/rpc"
|
||||
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
|
||||
|
@ -132,6 +133,26 @@ func (conn *TabletBson) Execute(ctx context.Context, query string, bindVars map[
|
|||
return qr, nil
|
||||
}
|
||||
|
||||
func getEffectiveCallerID(ctx context.Context) *tproto.CallerID {
|
||||
if ef := callerid.EffectiveCallerIDFromContext(ctx); ef != nil {
|
||||
return &tproto.CallerID{
|
||||
Principal: ef.Principal,
|
||||
Component: ef.Component,
|
||||
Subcomponent: ef.Subcomponent,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getImmediateCallerID(ctx context.Context) *tproto.VTGateCallerID {
|
||||
if im := callerid.ImmediateCallerIDFromContext(ctx); im != nil {
|
||||
return &tproto.VTGateCallerID{
|
||||
Username: im.Username,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute2 should not be used now other than in tests.
|
||||
// It is the CallerID enabled version of Execute
|
||||
// Execute2 sends to query to VTTablet
|
||||
|
@ -143,7 +164,9 @@ func (conn *TabletBson) Execute2(ctx context.Context, query string, bindVars map
|
|||
}
|
||||
|
||||
req := &tproto.ExecuteRequest{
|
||||
Target: conn.target,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
QueryRequest: tproto.Query{
|
||||
Sql: query,
|
||||
BindVariables: bindVars,
|
||||
|
@ -207,7 +230,9 @@ func (conn *TabletBson) ExecuteBatch2(ctx context.Context, queries []tproto.Boun
|
|||
}
|
||||
|
||||
req := tproto.ExecuteBatchRequest{
|
||||
Target: conn.target,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
QueryBatch: tproto.QueryList{
|
||||
Queries: queries,
|
||||
AsTransaction: asTransaction,
|
||||
|
@ -293,7 +318,9 @@ func (conn *TabletBson) StreamExecute2(ctx context.Context, query string, bindVa
|
|||
}
|
||||
|
||||
req := &tproto.StreamExecuteRequest{
|
||||
Target: conn.target,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
Query: &tproto.Query{
|
||||
Sql: query,
|
||||
BindVariables: bindVars,
|
||||
|
@ -377,8 +404,10 @@ func (conn *TabletBson) Begin2(ctx context.Context) (transactionID int64, err er
|
|||
}
|
||||
|
||||
beginRequest := &tproto.BeginRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
beginResponse := new(tproto.BeginResponse)
|
||||
action := func() error {
|
||||
|
@ -423,9 +452,11 @@ func (conn *TabletBson) Commit2(ctx context.Context, transactionID int64) error
|
|||
}
|
||||
|
||||
commitRequest := &tproto.CommitRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
}
|
||||
commitResponse := new(tproto.CommitResponse)
|
||||
action := func() error {
|
||||
|
@ -470,9 +501,11 @@ func (conn *TabletBson) Rollback2(ctx context.Context, transactionID int64) erro
|
|||
}
|
||||
|
||||
rollbackRequest := &tproto.RollbackRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
SessionId: conn.sessionID,
|
||||
TransactionId: transactionID,
|
||||
}
|
||||
rollbackResponse := new(tproto.RollbackResponse)
|
||||
action := func() error {
|
||||
|
@ -496,11 +529,13 @@ func (conn *TabletBson) SplitQuery(ctx context.Context, query tproto.BoundQuery,
|
|||
return
|
||||
}
|
||||
req := &tproto.SplitQueryRequest{
|
||||
Target: conn.target,
|
||||
Query: query,
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: splitCount,
|
||||
SessionID: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerID: getEffectiveCallerID(ctx),
|
||||
ImmediateCallerID: getImmediateCallerID(ctx),
|
||||
Query: query,
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: splitCount,
|
||||
SessionID: conn.sessionID,
|
||||
}
|
||||
reply := new(tproto.SplitQueryResult)
|
||||
action := func() error {
|
||||
|
|
|
@ -49,8 +49,8 @@ func (q *query) GetSessionId(ctx context.Context, request *pb.GetSessionIdReques
|
|||
func (q *query) Execute(ctx context.Context, request *pb.ExecuteRequest) (response *pb.ExecuteResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
reply := new(mproto.QueryResult)
|
||||
if err := q.server.Execute(ctx, request.Target, &proto.Query{
|
||||
|
@ -70,8 +70,8 @@ func (q *query) Execute(ctx context.Context, request *pb.ExecuteRequest) (respon
|
|||
func (q *query) ExecuteBatch(ctx context.Context, request *pb.ExecuteBatchRequest) (response *pb.ExecuteBatchResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
reply := new(proto.QueryResultList)
|
||||
if err := q.server.ExecuteBatch(ctx, request.Target, &proto.QueryList{
|
||||
|
@ -91,8 +91,8 @@ func (q *query) ExecuteBatch(ctx context.Context, request *pb.ExecuteBatchReques
|
|||
func (q *query) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Query_StreamExecuteServer) (err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
return q.server.StreamExecute(ctx, request.Target, &proto.Query{
|
||||
Sql: string(request.Query.Sql),
|
||||
|
@ -109,8 +109,8 @@ func (q *query) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Query
|
|||
func (q *query) Begin(ctx context.Context, request *pb.BeginRequest) (response *pb.BeginResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
txInfo := new(proto.TransactionInfo)
|
||||
if err := q.server.Begin(ctx, request.Target, &proto.Session{
|
||||
|
@ -128,8 +128,8 @@ func (q *query) Begin(ctx context.Context, request *pb.BeginRequest) (response *
|
|||
func (q *query) Commit(ctx context.Context, request *pb.CommitRequest) (response *pb.CommitResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
if err := q.server.Commit(ctx, request.Target, &proto.Session{
|
||||
SessionId: request.SessionId,
|
||||
|
@ -144,8 +144,8 @@ func (q *query) Commit(ctx context.Context, request *pb.CommitRequest) (response
|
|||
func (q *query) Rollback(ctx context.Context, request *pb.RollbackRequest) (response *pb.RollbackResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
if err := q.server.Rollback(ctx, request.Target, &proto.Session{
|
||||
SessionId: request.SessionId,
|
||||
|
@ -161,8 +161,8 @@ func (q *query) Rollback(ctx context.Context, request *pb.RollbackRequest) (resp
|
|||
func (q *query) SplitQuery(ctx context.Context, request *pb.SplitQueryRequest) (response *pb.SplitQueryResponse, err error) {
|
||||
defer q.server.HandlePanic(&err)
|
||||
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
|
||||
request.GetEffectiveCallerId(),
|
||||
request.GetImmediateCallerId(),
|
||||
request.EffectiveCallerId,
|
||||
request.ImmediateCallerId,
|
||||
)
|
||||
reply := &proto.SplitQueryResult{}
|
||||
if err := q.server.SplitQuery(ctx, request.Target, &proto.SplitQueryRequest{
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/netutil"
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
|
||||
"golang.org/x/net/context"
|
||||
|
@ -88,10 +89,12 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, query string, bindVars
|
|||
}
|
||||
|
||||
req := &pb.ExecuteRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
er, err := conn.c.Execute(ctx, req)
|
||||
if err != nil {
|
||||
|
@ -114,11 +117,13 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto.
|
|||
}
|
||||
|
||||
req := &pb.ExecuteBatchRequest{
|
||||
Target: conn.target,
|
||||
Queries: make([]*pb.BoundQuery, len(queries)),
|
||||
AsTransaction: asTransaction,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
Queries: make([]*pb.BoundQuery, len(queries)),
|
||||
AsTransaction: asTransaction,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
for i, q := range queries {
|
||||
req.Queries[i] = tproto.BoundQueryToProto3(q.Sql, q.BindVariables)
|
||||
|
@ -144,9 +149,11 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi
|
|||
}
|
||||
|
||||
req := &pb.StreamExecuteRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
Query: tproto.BoundQueryToProto3(query, bindVars),
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
stream, err := conn.c.StreamExecute(ctx, req)
|
||||
if err != nil {
|
||||
|
@ -186,8 +193,10 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context) (transactionID int64, er
|
|||
}
|
||||
|
||||
req := &pb.BeginRequest{
|
||||
Target: conn.target,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
br, err := conn.c.Begin(ctx, req)
|
||||
if err != nil {
|
||||
|
@ -210,9 +219,11 @@ func (conn *gRPCQueryClient) Commit(ctx context.Context, transactionID int64) er
|
|||
}
|
||||
|
||||
req := &pb.CommitRequest{
|
||||
Target: conn.target,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
_, err := conn.c.Commit(ctx, req)
|
||||
if err != nil {
|
||||
|
@ -235,9 +246,11 @@ func (conn *gRPCQueryClient) Rollback(ctx context.Context, transactionID int64)
|
|||
}
|
||||
|
||||
req := &pb.RollbackRequest{
|
||||
Target: conn.target,
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
TransactionId: transactionID,
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
_, err := conn.c.Rollback(ctx, req)
|
||||
if err != nil {
|
||||
|
@ -261,11 +274,13 @@ func (conn *gRPCQueryClient) SplitQuery(ctx context.Context, query tproto.BoundQ
|
|||
}
|
||||
|
||||
req := &pb.SplitQueryRequest{
|
||||
Target: conn.target,
|
||||
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: int64(splitCount),
|
||||
SessionId: conn.sessionID,
|
||||
Target: conn.target,
|
||||
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
|
||||
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
|
||||
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
|
||||
SplitColumn: splitColumn,
|
||||
SplitCount: int64(splitCount),
|
||||
SessionId: conn.sessionID,
|
||||
}
|
||||
sqr, err := conn.c.SplitQuery(ctx, req)
|
||||
if err != nil {
|
||||
|
|
|
@ -15,12 +15,14 @@ import (
|
|||
|
||||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/proto"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
pb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
pbt "github.com/youtube/vitess/go/vt/proto/topodata"
|
||||
pbv "github.com/youtube/vitess/go/vt/proto/vtrpc"
|
||||
)
|
||||
|
||||
// FakeQueryService has the server side of this fake
|
||||
|
@ -48,13 +50,39 @@ var testTarget = &pb.Target{
|
|||
TabletType: pbt.TabletType_REPLICA,
|
||||
}
|
||||
|
||||
var testCallerID = &pbv.CallerID{
|
||||
Principal: "test_principal",
|
||||
Component: "test_component",
|
||||
Subcomponent: "test_subcomponent",
|
||||
}
|
||||
|
||||
var testVTGateCallerID = &pb.VTGateCallerID{
|
||||
Username: "test_username",
|
||||
}
|
||||
|
||||
const testAsTransaction bool = true
|
||||
|
||||
const testSessionID int64 = 5678
|
||||
|
||||
func (f *FakeQueryService) checkTarget(name string, target *pb.Target) {
|
||||
func (f *FakeQueryService) checkTargetCallerID(ctx context.Context, name string, target *pb.Target) {
|
||||
if !reflect.DeepEqual(target, testTarget) {
|
||||
f.t.Errorf("invalid Target for %v: for %#v expected %#v", name, target, testTarget)
|
||||
f.t.Errorf("invalid Target for %v: got %#v expected %#v", name, target, testTarget)
|
||||
}
|
||||
ef := callerid.EffectiveCallerIDFromContext(ctx)
|
||||
if ef == nil {
|
||||
f.t.Errorf("no effective caller id for %v", name)
|
||||
} else {
|
||||
if !reflect.DeepEqual(ef, testCallerID) {
|
||||
f.t.Errorf("invalid effective caller id for %v: got %v expected %v", name, ef, testCallerID)
|
||||
}
|
||||
}
|
||||
im := callerid.ImmediateCallerIDFromContext(ctx)
|
||||
if im == nil {
|
||||
f.t.Errorf("no immediate caller id for %v", name)
|
||||
} else {
|
||||
if !reflect.DeepEqual(im, testVTGateCallerID) {
|
||||
f.t.Errorf("invalid immediate caller id for %v: got %v expected %v", name, im, testVTGateCallerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,7 +104,7 @@ func (f *FakeQueryService) Begin(ctx context.Context, target *pb.Target, session
|
|||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Begin", target)
|
||||
f.checkTargetCallerID(ctx, "Begin", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
|
@ -114,6 +142,7 @@ func testBeginPanics(t *testing.T, conn tabletconn.TabletConn) {
|
|||
func testBegin2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testBegin2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
transactionID, err := conn.Begin2(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Begin2 failed: %v", err)
|
||||
|
@ -137,7 +166,7 @@ func (f *FakeQueryService) Commit(ctx context.Context, target *pb.Target, sessio
|
|||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Commit", target)
|
||||
f.checkTargetCallerID(ctx, "Commit", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
|
@ -171,6 +200,7 @@ func testCommitPanics(t *testing.T, conn tabletconn.TabletConn) {
|
|||
func testCommit2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testCommit2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
err := conn.Commit2(ctx, commitTransactionID)
|
||||
if err != nil {
|
||||
t.Fatalf("Commit2 failed: %v", err)
|
||||
|
@ -191,7 +221,7 @@ func (f *FakeQueryService) Rollback(ctx context.Context, target *pb.Target, sess
|
|||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Rollback", target)
|
||||
f.checkTargetCallerID(ctx, "Rollback", target)
|
||||
} else {
|
||||
if session.SessionId != testSessionID {
|
||||
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
|
||||
|
@ -225,6 +255,7 @@ func testRollbackPanics(t *testing.T, conn tabletconn.TabletConn) {
|
|||
func testRollback2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testRollback2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
err := conn.Rollback2(ctx, rollbackTransactionID)
|
||||
if err != nil {
|
||||
t.Fatalf("Rollback2 failed: %v", err)
|
||||
|
@ -251,7 +282,7 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query
|
|||
f.t.Errorf("invalid Execute.Query.BindVariables: got %v expected %v", query.BindVariables, executeBindVars)
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("Execute", target)
|
||||
f.checkTargetCallerID(ctx, "Execute", target)
|
||||
} else {
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
|
@ -312,6 +343,7 @@ func testExecute(t *testing.T, conn tabletconn.TabletConn) {
|
|||
func testExecute2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testExecute2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
qr, err := conn.Execute2(ctx, executeQuery, executeBindVars, executeTransactionID)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute failed: %v", err)
|
||||
|
@ -349,7 +381,7 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target,
|
|||
f.t.Errorf("invalid StreamExecute.Query.BindVariables: got %v expected %v", query.BindVariables, streamExecuteBindVars)
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("StreamExecute", target)
|
||||
f.checkTargetCallerID(ctx, "StreamExecute", target)
|
||||
} else {
|
||||
if query.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
|
||||
|
@ -491,6 +523,7 @@ func testStreamExecutePanics(t *testing.T, conn tabletconn.TabletConn, fake *Fak
|
|||
func testStreamExecute2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testStreamExecute2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
stream, errFunc, err := conn.StreamExecute2(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
|
||||
if err != nil {
|
||||
t.Fatalf("StreamExecute2 failed: %v", err)
|
||||
|
@ -530,6 +563,7 @@ func testStreamExecute2Panics(t *testing.T, conn tabletconn.TabletConn, fake *Fa
|
|||
// by the StreamExecute2 call itself, or as the first error
|
||||
// by ErrFunc
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
fake.streamExecutePanicsEarly = true
|
||||
stream, errFunc, err := conn.StreamExecute2(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
|
||||
if err != nil {
|
||||
|
@ -581,7 +615,7 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *pb.Target,
|
|||
f.t.Errorf("invalid ExecuteBatch.QueryList.Queries: got %v expected %v", queryList.Queries, executeBatchQueries)
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("ExecuteBatch", target)
|
||||
f.checkTargetCallerID(ctx, "ExecuteBatch", target)
|
||||
} else {
|
||||
if queryList.SessionId != testSessionID {
|
||||
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
|
||||
|
@ -676,6 +710,7 @@ func testExecuteBatchPanics(t *testing.T, conn tabletconn.TabletConn) {
|
|||
func testExecuteBatch2(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testExecuteBatch2")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
qrl, err := conn.ExecuteBatch2(ctx, executeBatchQueries, true, executeBatchTransactionID)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteBatch failed: %v", err)
|
||||
|
@ -699,7 +734,7 @@ func (f *FakeQueryService) SplitQuery(ctx context.Context, target *pb.Target, re
|
|||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if f.checkExtraFields {
|
||||
f.checkTarget("SplitQuery", target)
|
||||
f.checkTargetCallerID(ctx, "SplitQuery", target)
|
||||
}
|
||||
if !reflect.DeepEqual(req.Query, splitQueryBoundQuery) {
|
||||
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", req.Query, splitQueryBoundQuery)
|
||||
|
@ -740,6 +775,7 @@ var splitQueryQuerySplitList = []proto.QuerySplit{
|
|||
func testSplitQuery(t *testing.T, conn tabletconn.TabletConn) {
|
||||
t.Log("testSplitQuery")
|
||||
ctx := context.Background()
|
||||
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
|
||||
qsl, err := conn.SplitQuery(ctx, splitQueryBoundQuery, splitQuerySplitColumn, splitQuerySplitCount)
|
||||
if err != nil {
|
||||
t.Fatalf("SplitQuery failed: %v", err)
|
||||
|
@ -906,9 +942,9 @@ func TestSuite(t *testing.T, protocol string, endPoint *pbt.EndPoint, fake *Fake
|
|||
testExecuteBatch2Panics(t, conn)
|
||||
testSplitQueryPanics(t, conn)
|
||||
testStreamHealthPanics(t, conn)
|
||||
conn.Close()
|
||||
|
||||
// force panic without extra fields
|
||||
conn.Close()
|
||||
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
|
|
Загрузка…
Ссылка в новой задаче