Merge pull request #967 from alainjobart/replication

Replication
This commit is contained in:
Alain Jobart 2015-08-05 20:54:16 -07:00
Родитель d0de5242c9 d56b66ea1c
Коммит b436bf95f6
7 изменённых файлов: 329 добавлений и 586 удалений

Просмотреть файл

@ -398,21 +398,13 @@ func (m *GetSessionIdRequest) GetImmediateCallerId() *VTGateCallerID {
// GetSessionIdResponse is the returned value from GetSessionId
type GetSessionIdResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
SessionId int64 `protobuf:"varint,2,opt,name=session_id" json:"session_id,omitempty"`
SessionId int64 `protobuf:"varint,1,opt,name=session_id" json:"session_id,omitempty"`
}
func (m *GetSessionIdResponse) Reset() { *m = GetSessionIdResponse{} }
func (m *GetSessionIdResponse) String() string { return proto.CompactTextString(m) }
func (*GetSessionIdResponse) ProtoMessage() {}
func (m *GetSessionIdResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
// ExecuteRequest is the payload to Execute
type ExecuteRequest struct {
EffectiveCallerId *vtrpc.CallerID `protobuf:"bytes,1,opt,name=effective_caller_id" json:"effective_caller_id,omitempty"`
@ -457,21 +449,13 @@ func (m *ExecuteRequest) GetQuery() *BoundQuery {
// ExecuteResponse is the returned value from Execute
type ExecuteResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
Result *QueryResult `protobuf:"bytes,2,opt,name=result" json:"result,omitempty"`
Result *QueryResult `protobuf:"bytes,1,opt,name=result" json:"result,omitempty"`
}
func (m *ExecuteResponse) Reset() { *m = ExecuteResponse{} }
func (m *ExecuteResponse) String() string { return proto.CompactTextString(m) }
func (*ExecuteResponse) ProtoMessage() {}
func (m *ExecuteResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
func (m *ExecuteResponse) GetResult() *QueryResult {
if m != nil {
return m.Result
@ -524,21 +508,13 @@ func (m *ExecuteBatchRequest) GetQueries() []*BoundQuery {
// ExecuteBatchResponse is the returned value from ExecuteBatch
type ExecuteBatchResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
Results []*QueryResult `protobuf:"bytes,2,rep,name=results" json:"results,omitempty"`
Results []*QueryResult `protobuf:"bytes,1,rep,name=results" json:"results,omitempty"`
}
func (m *ExecuteBatchResponse) Reset() { *m = ExecuteBatchResponse{} }
func (m *ExecuteBatchResponse) String() string { return proto.CompactTextString(m) }
func (*ExecuteBatchResponse) ProtoMessage() {}
func (m *ExecuteBatchResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
func (m *ExecuteBatchResponse) GetResults() []*QueryResult {
if m != nil {
return m.Results
@ -589,21 +565,13 @@ func (m *StreamExecuteRequest) GetQuery() *BoundQuery {
// StreamExecuteResponse is the returned value from StreamExecute
type StreamExecuteResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
Result *QueryResult `protobuf:"bytes,2,opt,name=result" json:"result,omitempty"`
Result *QueryResult `protobuf:"bytes,1,opt,name=result" json:"result,omitempty"`
}
func (m *StreamExecuteResponse) Reset() { *m = StreamExecuteResponse{} }
func (m *StreamExecuteResponse) String() string { return proto.CompactTextString(m) }
func (*StreamExecuteResponse) ProtoMessage() {}
func (m *StreamExecuteResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
func (m *StreamExecuteResponse) GetResult() *QueryResult {
if m != nil {
return m.Result
@ -646,21 +614,13 @@ func (m *BeginRequest) GetTarget() *Target {
// BeginResponse is the returned value from Begin
type BeginResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
TransactionId int64 `protobuf:"varint,2,opt,name=transaction_id" json:"transaction_id,omitempty"`
TransactionId int64 `protobuf:"varint,1,opt,name=transaction_id" json:"transaction_id,omitempty"`
}
func (m *BeginResponse) Reset() { *m = BeginResponse{} }
func (m *BeginResponse) String() string { return proto.CompactTextString(m) }
func (*BeginResponse) ProtoMessage() {}
func (m *BeginResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
// CommitRequest is the payload to Commit
type CommitRequest struct {
EffectiveCallerId *vtrpc.CallerID `protobuf:"bytes,1,opt,name=effective_caller_id" json:"effective_caller_id,omitempty"`
@ -697,20 +657,12 @@ func (m *CommitRequest) GetTarget() *Target {
// CommitResponse is the returned value from Commit
type CommitResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
}
func (m *CommitResponse) Reset() { *m = CommitResponse{} }
func (m *CommitResponse) String() string { return proto.CompactTextString(m) }
func (*CommitResponse) ProtoMessage() {}
func (m *CommitResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
// RollbackRequest is the payload to Rollback
type RollbackRequest struct {
EffectiveCallerId *vtrpc.CallerID `protobuf:"bytes,1,opt,name=effective_caller_id" json:"effective_caller_id,omitempty"`
@ -747,20 +699,12 @@ func (m *RollbackRequest) GetTarget() *Target {
// RollbackResponse is the returned value from Rollback
type RollbackResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
}
func (m *RollbackResponse) Reset() { *m = RollbackResponse{} }
func (m *RollbackResponse) String() string { return proto.CompactTextString(m) }
func (*RollbackResponse) ProtoMessage() {}
func (m *RollbackResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
// SplitQueryRequest is the payload for SplitQuery
type SplitQueryRequest struct {
EffectiveCallerId *vtrpc.CallerID `protobuf:"bytes,1,opt,name=effective_caller_id" json:"effective_caller_id,omitempty"`
@ -826,21 +770,13 @@ func (m *QuerySplit) GetQuery() *BoundQuery {
// SplitQueryResponse is returned by SplitQuery and represents all the queries
// to execute in order to get the entire data set.
type SplitQueryResponse struct {
Error *vtrpc.RPCError `protobuf:"bytes,1,opt,name=error" json:"error,omitempty"`
Queries []*QuerySplit `protobuf:"bytes,2,rep,name=queries" json:"queries,omitempty"`
Queries []*QuerySplit `protobuf:"bytes,1,rep,name=queries" json:"queries,omitempty"`
}
func (m *SplitQueryResponse) Reset() { *m = SplitQueryResponse{} }
func (m *SplitQueryResponse) String() string { return proto.CompactTextString(m) }
func (*SplitQueryResponse) ProtoMessage() {}
func (m *SplitQueryResponse) GetError() *vtrpc.RPCError {
if m != nil {
return m.Error
}
return nil
}
func (m *SplitQueryResponse) GetQueries() []*QuerySplit {
if m != nil {
return m.Queries

Просмотреть файл

@ -15,6 +15,7 @@ import (
"github.com/youtube/vitess/go/netutil"
"github.com/youtube/vitess/go/rpcplus"
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/rpc"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
@ -132,6 +133,26 @@ func (conn *TabletBson) Execute(ctx context.Context, query string, bindVars map[
return qr, nil
}
func getEffectiveCallerID(ctx context.Context) *tproto.CallerID {
if ef := callerid.EffectiveCallerIDFromContext(ctx); ef != nil {
return &tproto.CallerID{
Principal: ef.Principal,
Component: ef.Component,
Subcomponent: ef.Subcomponent,
}
}
return nil
}
func getImmediateCallerID(ctx context.Context) *tproto.VTGateCallerID {
if im := callerid.ImmediateCallerIDFromContext(ctx); im != nil {
return &tproto.VTGateCallerID{
Username: im.Username,
}
}
return nil
}
// Execute2 should not be used now other than in tests.
// It is the CallerID enabled version of Execute
// Execute2 sends to query to VTTablet
@ -143,6 +164,9 @@ func (conn *TabletBson) Execute2(ctx context.Context, query string, bindVars map
}
req := &tproto.ExecuteRequest{
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
QueryRequest: tproto.Query{
Sql: query,
BindVariables: bindVars,
@ -206,6 +230,9 @@ func (conn *TabletBson) ExecuteBatch2(ctx context.Context, queries []tproto.Boun
}
req := tproto.ExecuteBatchRequest{
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
QueryBatch: tproto.QueryList{
Queries: queries,
AsTransaction: asTransaction,
@ -290,13 +317,17 @@ func (conn *TabletBson) StreamExecute2(ctx context.Context, query string, bindVa
return nil, nil, tabletconn.ConnClosed
}
q := &tproto.Query{
Sql: query,
BindVariables: bindVars,
TransactionId: transactionID,
SessionId: conn.sessionID,
req := &tproto.StreamExecuteRequest{
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
Query: &tproto.Query{
Sql: query,
BindVariables: bindVars,
TransactionId: transactionID,
SessionId: conn.sessionID,
},
}
req := &tproto.StreamExecuteRequest{Query: q}
// Use QueryResult instead of StreamExecuteResult for now, due to backwards compatability reasons.
// It'll be easuer to migrate all end-points to using StreamExecuteResult instead of
// maintaining a mixture of QueryResult and StreamExecuteResult channel returns.
@ -373,7 +404,10 @@ func (conn *TabletBson) Begin2(ctx context.Context) (transactionID int64, err er
}
beginRequest := &tproto.BeginRequest{
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
SessionId: conn.sessionID,
}
beginResponse := new(tproto.BeginResponse)
action := func() error {
@ -418,8 +452,11 @@ func (conn *TabletBson) Commit2(ctx context.Context, transactionID int64) error
}
commitRequest := &tproto.CommitRequest{
SessionId: conn.sessionID,
TransactionId: transactionID,
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
SessionId: conn.sessionID,
TransactionId: transactionID,
}
commitResponse := new(tproto.CommitResponse)
action := func() error {
@ -464,8 +501,11 @@ func (conn *TabletBson) Rollback2(ctx context.Context, transactionID int64) erro
}
rollbackRequest := &tproto.RollbackRequest{
SessionId: conn.sessionID,
TransactionId: transactionID,
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
SessionId: conn.sessionID,
TransactionId: transactionID,
}
rollbackResponse := new(tproto.RollbackResponse)
action := func() error {
@ -489,10 +529,13 @@ func (conn *TabletBson) SplitQuery(ctx context.Context, query tproto.BoundQuery,
return
}
req := &tproto.SplitQueryRequest{
Query: query,
SplitColumn: splitColumn,
SplitCount: splitCount,
SessionID: conn.sessionID,
Target: conn.target,
EffectiveCallerID: getEffectiveCallerID(ctx),
ImmediateCallerID: getImmediateCallerID(ctx),
Query: query,
SplitColumn: splitColumn,
SplitCount: splitCount,
SessionID: conn.sessionID,
}
reply := new(tproto.SplitQueryResult)
action := func() error {

Просмотреть файл

@ -33,14 +33,15 @@ func (q *query) GetSessionId(ctx context.Context, request *pb.GetSessionIdReques
defer q.server.HandlePanic(&err)
sessionInfo := new(proto.SessionInfo)
gsiErr := q.server.GetSessionId(&proto.SessionParams{
if err := q.server.GetSessionId(&proto.SessionParams{
Keyspace: request.Keyspace,
Shard: request.Shard,
}, sessionInfo)
}, sessionInfo); err != nil {
return nil, err
}
return &pb.GetSessionIdResponse{
SessionId: sessionInfo.SessionId,
Error: tabletserver.TabletErrorToRPCError(gsiErr),
}, nil
}
@ -48,20 +49,17 @@ func (q *query) GetSessionId(ctx context.Context, request *pb.GetSessionIdReques
func (q *query) Execute(ctx context.Context, request *pb.ExecuteRequest) (response *pb.ExecuteResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
reply := new(mproto.QueryResult)
execErr := q.server.Execute(ctx, request.Target, &proto.Query{
if err := q.server.Execute(ctx, request.Target, &proto.Query{
Sql: string(request.Query.Sql),
BindVariables: proto.Proto3ToBindVariables(request.Query.BindVariables),
SessionId: request.SessionId,
TransactionId: request.TransactionId,
}, reply)
if execErr != nil {
return &pb.ExecuteResponse{
Error: tabletserver.TabletErrorToRPCError(execErr),
}, nil
}, reply); err != nil {
return nil, err
}
return &pb.ExecuteResponse{
Result: mproto.QueryResultToProto3(reply),
@ -72,20 +70,17 @@ func (q *query) Execute(ctx context.Context, request *pb.ExecuteRequest) (respon
func (q *query) ExecuteBatch(ctx context.Context, request *pb.ExecuteBatchRequest) (response *pb.ExecuteBatchResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
reply := new(proto.QueryResultList)
execErr := q.server.ExecuteBatch(ctx, request.Target, &proto.QueryList{
if err := q.server.ExecuteBatch(ctx, request.Target, &proto.QueryList{
Queries: proto.Proto3ToBoundQueryList(request.Queries),
SessionId: request.SessionId,
AsTransaction: request.AsTransaction,
TransactionId: request.TransactionId,
}, reply)
if execErr != nil {
return &pb.ExecuteBatchResponse{
Error: tabletserver.TabletErrorToRPCError(execErr),
}, nil
}, reply); err != nil {
return nil, err
}
return &pb.ExecuteBatchResponse{
Results: proto.QueryResultListToProto3(reply.List),
@ -96,10 +91,10 @@ func (q *query) ExecuteBatch(ctx context.Context, request *pb.ExecuteBatchReques
func (q *query) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Query_StreamExecuteServer) (err error) {
defer q.server.HandlePanic(&err)
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
seErr := q.server.StreamExecute(ctx, request.Target, &proto.Query{
return q.server.StreamExecute(ctx, request.Target, &proto.Query{
Sql: string(request.Query.Sql),
BindVariables: proto.Proto3ToBindVariables(request.Query.BindVariables),
SessionId: request.SessionId,
@ -108,31 +103,20 @@ func (q *query) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Query
Result: mproto.QueryResultToProto3(reply),
})
})
if seErr != nil {
response := &pb.StreamExecuteResponse{
Error: tabletserver.TabletErrorToRPCError(seErr),
}
if err := stream.Send(response); err != nil {
return err
}
}
return nil
}
// Begin is part of the queryservice.QueryServer interface
func (q *query) Begin(ctx context.Context, request *pb.BeginRequest) (response *pb.BeginResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
txInfo := new(proto.TransactionInfo)
if beginErr := q.server.Begin(ctx, request.Target, &proto.Session{
if err := q.server.Begin(ctx, request.Target, &proto.Session{
SessionId: request.SessionId,
}, txInfo); beginErr != nil {
return &pb.BeginResponse{
Error: tabletserver.TabletErrorToRPCError(beginErr),
}, nil
}, txInfo); err != nil {
return nil, err
}
return &pb.BeginResponse{
@ -144,52 +128,50 @@ func (q *query) Begin(ctx context.Context, request *pb.BeginRequest) (response *
func (q *query) Commit(ctx context.Context, request *pb.CommitRequest) (response *pb.CommitResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
commitErr := q.server.Commit(ctx, request.Target, &proto.Session{
if err := q.server.Commit(ctx, request.Target, &proto.Session{
SessionId: request.SessionId,
TransactionId: request.TransactionId,
})
return &pb.CommitResponse{
Error: tabletserver.TabletErrorToRPCError(commitErr),
}, nil
}); err != nil {
return nil, err
}
return &pb.CommitResponse{}, nil
}
// Rollback is part of the queryservice.QueryServer interface
func (q *query) Rollback(ctx context.Context, request *pb.RollbackRequest) (response *pb.RollbackResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
rollbackErr := q.server.Rollback(ctx, request.Target, &proto.Session{
if err := q.server.Rollback(ctx, request.Target, &proto.Session{
SessionId: request.SessionId,
TransactionId: request.TransactionId,
})
}); err != nil {
return nil, err
}
return &pb.RollbackResponse{
Error: tabletserver.TabletErrorToRPCError(rollbackErr),
}, nil
return &pb.RollbackResponse{}, nil
}
// SplitQuery is part of the queryservice.QueryServer interface
func (q *query) SplitQuery(ctx context.Context, request *pb.SplitQueryRequest) (response *pb.SplitQueryResponse, err error) {
defer q.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.GetEffectiveCallerId(),
request.GetImmediateCallerId(),
request.EffectiveCallerId,
request.ImmediateCallerId,
)
reply := &proto.SplitQueryResult{}
if sqErr := q.server.SplitQuery(ctx, request.Target, &proto.SplitQueryRequest{
if err := q.server.SplitQuery(ctx, request.Target, &proto.SplitQueryRequest{
Query: *proto.Proto3ToBoundQuery(request.Query),
SplitColumn: request.SplitColumn,
SplitCount: int(request.SplitCount),
SessionID: request.SessionId,
}, reply); sqErr != nil {
return &pb.SplitQueryResponse{
Error: tabletserver.TabletErrorToRPCError(sqErr),
}, nil
}, reply); err != nil {
return nil, err
}
return &pb.SplitQueryResponse{
Queries: proto.QuerySplitsToProto3(reply.Queries),

Просмотреть файл

@ -12,16 +12,15 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/netutil"
"github.com/youtube/vitess/go/vt/callerid"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
"github.com/youtube/vitess/go/vt/vterrors"
"golang.org/x/net/context"
"google.golang.org/grpc"
pb "github.com/youtube/vitess/go/vt/proto/query"
pbs "github.com/youtube/vitess/go/vt/proto/queryservice"
pbt "github.com/youtube/vitess/go/vt/proto/topodata"
pbv "github.com/youtube/vitess/go/vt/proto/vtrpc"
)
const protocolName = "grpc"
@ -68,10 +67,6 @@ func DialTablet(ctx context.Context, endPoint *pbt.EndPoint, keyspace, shard str
cc.Close()
return nil, err
}
if gsir.Error != nil {
cc.Close()
return nil, tabletErrorFromRPCError(gsir.Error)
}
result.sessionID = gsir.SessionId
} else {
// we use target
@ -94,17 +89,17 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, query string, bindVars
}
req := &pb.ExecuteRequest{
Query: tproto.BoundQueryToProto3(query, bindVars),
TransactionId: transactionID,
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
TransactionId: transactionID,
SessionId: conn.sessionID,
}
er, err := conn.c.Execute(ctx, req)
if err != nil {
return nil, tabletErrorFromGRPC(err)
}
if er.Error != nil {
return nil, tabletErrorFromRPCError(er.Error)
}
return mproto.Proto3ToQueryResult(er.Result), nil
}
@ -122,10 +117,13 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto.
}
req := &pb.ExecuteBatchRequest{
Queries: make([]*pb.BoundQuery, len(queries)),
AsTransaction: asTransaction,
TransactionId: transactionID,
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
Queries: make([]*pb.BoundQuery, len(queries)),
AsTransaction: asTransaction,
TransactionId: transactionID,
SessionId: conn.sessionID,
}
for i, q := range queries {
req.Queries[i] = tproto.BoundQueryToProto3(q.Sql, q.BindVariables)
@ -134,9 +132,6 @@ func (conn *gRPCQueryClient) ExecuteBatch(ctx context.Context, queries []tproto.
if err != nil {
return nil, tabletErrorFromGRPC(err)
}
if ebr.Error != nil {
return nil, tabletErrorFromRPCError(ebr.Error)
}
return tproto.Proto3ToQueryResultList(ebr.Results), nil
}
@ -154,8 +149,11 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi
}
req := &pb.StreamExecuteRequest{
Query: tproto.BoundQueryToProto3(query, bindVars),
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
SessionId: conn.sessionID,
}
stream, err := conn.c.StreamExecute(ctx, req)
if err != nil {
@ -173,11 +171,6 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, query string, bi
close(sr)
return
}
if ser.Error != nil {
finalError = tabletErrorFromRPCError(ser.Error)
close(sr)
return
}
sr <- mproto.Proto3ToQueryResult(ser.Result)
}
}()
@ -200,15 +193,15 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context) (transactionID int64, er
}
req := &pb.BeginRequest{
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
SessionId: conn.sessionID,
}
br, err := conn.c.Begin(ctx, req)
if err != nil {
return 0, tabletErrorFromGRPC(err)
}
if br.Error != nil {
return 0, tabletErrorFromRPCError(br.Error)
}
return br.TransactionId, nil
}
@ -226,16 +219,16 @@ func (conn *gRPCQueryClient) Commit(ctx context.Context, transactionID int64) er
}
req := &pb.CommitRequest{
TransactionId: transactionID,
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
TransactionId: transactionID,
SessionId: conn.sessionID,
}
cr, err := conn.c.Commit(ctx, req)
_, err := conn.c.Commit(ctx, req)
if err != nil {
return tabletErrorFromGRPC(err)
}
if cr.Error != nil {
return tabletErrorFromRPCError(cr.Error)
}
return nil
}
@ -253,16 +246,16 @@ func (conn *gRPCQueryClient) Rollback(ctx context.Context, transactionID int64)
}
req := &pb.RollbackRequest{
TransactionId: transactionID,
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
TransactionId: transactionID,
SessionId: conn.sessionID,
}
rr, err := conn.c.Rollback(ctx, req)
_, err := conn.c.Rollback(ctx, req)
if err != nil {
return tabletErrorFromGRPC(err)
}
if rr.Error != nil {
return tabletErrorFromRPCError(rr.Error)
}
return nil
}
@ -281,18 +274,18 @@ func (conn *gRPCQueryClient) SplitQuery(ctx context.Context, query tproto.BoundQ
}
req := &pb.SplitQueryRequest{
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
SplitColumn: splitColumn,
SplitCount: int64(splitCount),
SessionId: conn.sessionID,
Target: conn.target,
EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx),
ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
SplitColumn: splitColumn,
SplitCount: int64(splitCount),
SessionId: conn.sessionID,
}
sqr, err := conn.c.SplitQuery(ctx, req)
if err != nil {
return nil, tabletErrorFromGRPC(err)
}
if sqr.Error != nil {
return nil, tabletErrorFromRPCError(sqr.Error)
}
return tproto.Proto3ToQuerySplits(sqr.Queries), nil
}
@ -371,19 +364,3 @@ func (conn *gRPCQueryClient) EndPoint() *pbt.EndPoint {
func tabletErrorFromGRPC(err error) error {
return tabletconn.OperationalError(fmt.Sprintf("vttablet: %v", err))
}
// tabletErrorFromRPCError reconstructs a tablet error from the
// RPCError, using the RPCError code.
func tabletErrorFromRPCError(rpcErr *pbv.RPCError) error {
ve := vterrors.FromVtRPCError(rpcErr)
// see if the range is in the tablet error range
if ve.Code >= int64(pbv.ErrorCode_TabletError) && ve.Code <= int64(pbv.ErrorCode_UnknownTabletError) {
return &tabletconn.ServerError{
Code: int(ve.Code - int64(pbv.ErrorCode_TabletError)),
Err: fmt.Sprintf("vttablet: %v", ve.Error()),
}
}
return tabletconn.OperationalError(fmt.Sprintf("vttablet: %v", ve.Message))
}

Просмотреть файл

@ -15,21 +15,25 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/tabletserver"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
"golang.org/x/net/context"
pb "github.com/youtube/vitess/go/vt/proto/query"
pbt "github.com/youtube/vitess/go/vt/proto/topodata"
pbv "github.com/youtube/vitess/go/vt/proto/vtrpc"
)
// FakeQueryService has the server side of this fake
type FakeQueryService struct {
t *testing.T
hasError bool
panics bool
streamExecutePanicsEarly bool
panicWait chan struct{}
// if set, we will also check Target, ImmediateCallerId and EffectiveCallerId
checkExtraFields bool
}
// HandlePanic is part of the queryservice.QueryService interface
@ -39,48 +43,56 @@ func (f *FakeQueryService) HandlePanic(err *error) {
}
}
// TestKeyspace is the Keyspace we use for this test
const TestKeyspace = "test_keyspace"
// testTarget is the target we use for this test
var testTarget = &pb.Target{
Keyspace: "test_keyspace",
Shard: "test_shard",
TabletType: pbt.TabletType_REPLICA,
}
// TestShard is the Shard we use for this test
const TestShard = "test_shard"
var testCallerID = &pbv.CallerID{
Principal: "test_principal",
Component: "test_component",
Subcomponent: "test_subcomponent",
}
// TestTabletType is the TabletType we use for this test
const TestTabletType = pbt.TabletType_UNKNOWN
var testVTGateCallerID = &pb.VTGateCallerID{
Username: "test_username",
}
const testAsTransaction bool = true
const testSessionID int64 = 5678
var testTabletError = tabletserver.NewTabletError(tabletserver.ErrFail, "generic error")
const expectedErrMatch string = "error: generic error"
// Verifies the returned error has the properties that we expect.
func verifyError(t *testing.T, err error, method string) {
if err == nil {
t.Errorf("%s was expecting an error, didn't get one", method)
return
func (f *FakeQueryService) checkTargetCallerID(ctx context.Context, name string, target *pb.Target) {
if !reflect.DeepEqual(target, testTarget) {
f.t.Errorf("invalid Target for %v: got %#v expected %#v", name, target, testTarget)
}
if se, ok := err.(*tabletconn.ServerError); ok {
if se.Code != tabletconn.ERR_NORMAL {
t.Errorf("Unexpected error code from %s: got %v, wanted %v", method, se.Code, tabletconn.ERR_NORMAL)
}
ef := callerid.EffectiveCallerIDFromContext(ctx)
if ef == nil {
f.t.Errorf("no effective caller id for %v", name)
} else {
t.Errorf("Unexpected error type from %s: got %v, wanted tabletconn.ServerError", method, reflect.TypeOf(err))
if !reflect.DeepEqual(ef, testCallerID) {
f.t.Errorf("invalid effective caller id for %v: got %v expected %v", name, ef, testCallerID)
}
}
if !strings.Contains(err.Error(), expectedErrMatch) {
t.Errorf("Unexpected error from %s: got %v, wanted err containing %v", method, err, expectedErrMatch)
im := callerid.ImmediateCallerIDFromContext(ctx)
if im == nil {
f.t.Errorf("no immediate caller id for %v", name)
} else {
if !reflect.DeepEqual(im, testVTGateCallerID) {
f.t.Errorf("invalid immediate caller id for %v: got %v expected %v", name, im, testVTGateCallerID)
}
}
}
// GetSessionId is part of the queryservice.QueryService interface
func (f *FakeQueryService) GetSessionId(sessionParams *proto.SessionParams, sessionInfo *proto.SessionInfo) error {
if sessionParams.Keyspace != TestKeyspace {
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, TestKeyspace)
if sessionParams.Keyspace != testTarget.Keyspace {
f.t.Errorf("invalid keyspace: got %v expected %v", sessionParams.Keyspace, testTarget.Keyspace)
}
if sessionParams.Shard != TestShard {
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, TestShard)
if sessionParams.Shard != testTarget.Shard {
f.t.Errorf("invalid shard: got %v expected %v", sessionParams.Shard, testTarget.Shard)
}
sessionInfo.SessionId = testSessionID
return nil
@ -88,14 +100,15 @@ func (f *FakeQueryService) GetSessionId(sessionParams *proto.SessionParams, sess
// Begin is part of the queryservice.QueryService interface
func (f *FakeQueryService) Begin(ctx context.Context, target *pb.Target, session *proto.Session, txInfo *proto.TransactionInfo) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "Begin", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Begin: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != 0 {
f.t.Errorf("Begin: invalid TransactionId: got %v expected 0", session.TransactionId)
@ -118,13 +131,6 @@ func testBegin(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testBeginError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBeginError")
ctx := context.Background()
_, err := conn.Begin(ctx)
verifyError(t, err, "Begin")
}
func testBeginPanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBeginPanics")
ctx := context.Background()
@ -136,6 +142,7 @@ func testBeginPanics(t *testing.T, conn tabletconn.TabletConn) {
func testBegin2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBegin2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
transactionID, err := conn.Begin2(ctx)
if err != nil {
t.Fatalf("Begin2 failed: %v", err)
@ -145,13 +152,6 @@ func testBegin2(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testBegin2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBegin2Error")
ctx := context.Background()
_, err := conn.Begin2(ctx)
verifyError(t, err, "Begin2")
}
func testBegin2Panics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBegin2Panics")
ctx := context.Background()
@ -162,14 +162,15 @@ func testBegin2Panics(t *testing.T, conn tabletconn.TabletConn) {
// Commit is part of the queryservice.QueryService interface
func (f *FakeQueryService) Commit(ctx context.Context, target *pb.Target, session *proto.Session) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "Commit", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Commit: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != commitTransactionID {
f.t.Errorf("Commit: invalid TransactionId: got %v expected %v", session.TransactionId, commitTransactionID)
@ -188,13 +189,6 @@ func testCommit(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testCommitError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testCommitError")
ctx := context.Background()
err := conn.Commit(ctx, commitTransactionID)
verifyError(t, err, "Commit")
}
func testCommitPanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testCommitPanics")
ctx := context.Background()
@ -206,19 +200,13 @@ func testCommitPanics(t *testing.T, conn tabletconn.TabletConn) {
func testCommit2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testCommit2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
err := conn.Commit2(ctx, commitTransactionID)
if err != nil {
t.Fatalf("Commit2 failed: %v", err)
}
}
func testCommit2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testCommit2Error")
ctx := context.Background()
err := conn.Commit2(ctx, commitTransactionID)
verifyError(t, err, "Commit2")
}
func testCommit2Panics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testCommit2Panics")
ctx := context.Background()
@ -229,14 +217,15 @@ func testCommit2Panics(t *testing.T, conn tabletconn.TabletConn) {
// Rollback is part of the queryservice.QueryService interface
func (f *FakeQueryService) Rollback(ctx context.Context, target *pb.Target, session *proto.Session) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if session.SessionId != testSessionID {
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "Rollback", target)
} else {
if session.SessionId != testSessionID {
f.t.Errorf("Rollback: invalid SessionId: got %v expected %v", session.SessionId, testSessionID)
}
}
if session.TransactionId != rollbackTransactionID {
f.t.Errorf("Rollback: invalid TransactionId: got %v expected %v", session.TransactionId, rollbackTransactionID)
@ -255,13 +244,6 @@ func testRollback(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testRollbackError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testRollbackError")
ctx := context.Background()
err := conn.Rollback(ctx, commitTransactionID)
verifyError(t, err, "Rollback")
}
func testRollbackPanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testRollbackPanics")
ctx := context.Background()
@ -273,19 +255,13 @@ func testRollbackPanics(t *testing.T, conn tabletconn.TabletConn) {
func testRollback2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testRollback2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
err := conn.Rollback2(ctx, rollbackTransactionID)
if err != nil {
t.Fatalf("Rollback2 failed: %v", err)
}
}
func testRollback2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testRollback2Error")
ctx := context.Background()
err := conn.Rollback2(ctx, commitTransactionID)
verifyError(t, err, "Rollback2")
}
func testRollback2Panics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testRollback2Panics")
ctx := context.Background()
@ -296,9 +272,6 @@ func testRollback2Panics(t *testing.T, conn tabletconn.TabletConn) {
// Execute is part of the queryservice.QueryService interface
func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query *proto.Query, reply *mproto.QueryResult) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
@ -308,8 +281,12 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *pb.Target, query
if !reflect.DeepEqual(query.BindVariables, executeBindVars) {
f.t.Errorf("invalid Execute.Query.BindVariables: got %v expected %v", query.BindVariables, executeBindVars)
}
if query.SessionId != testSessionID {
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "Execute", target)
} else {
if query.SessionId != testSessionID {
f.t.Errorf("invalid Execute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
}
}
if query.TransactionId != executeTransactionID {
f.t.Errorf("invalid Execute.Query.TransactionId: got %v expected %v", query.TransactionId, executeTransactionID)
@ -366,6 +343,7 @@ func testExecute(t *testing.T, conn tabletconn.TabletConn) {
func testExecute2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecute2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
qr, err := conn.Execute2(ctx, executeQuery, executeBindVars, executeTransactionID)
if err != nil {
t.Fatalf("Execute failed: %v", err)
@ -375,20 +353,6 @@ func testExecute2(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testExecuteError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecuteError")
ctx := context.Background()
_, err := conn.Execute(ctx, executeQuery, executeBindVars, executeTransactionID)
verifyError(t, err, "Execute")
}
func testExecute2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecute2Error")
ctx := context.Background()
_, err := conn.Execute2(ctx, executeQuery, executeBindVars, executeTransactionID)
verifyError(t, err, "Execute")
}
func testExecutePanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecutePanics")
ctx := context.Background()
@ -405,9 +369,6 @@ func testExecute2Panics(t *testing.T, conn tabletconn.TabletConn) {
}
}
var panicWait chan struct{}
var errorWait chan struct{}
// StreamExecute is part of the queryservice.QueryService interface
func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target, query *proto.Query, sendReply func(*mproto.QueryResult) error) error {
if f.panics && f.streamExecutePanicsEarly {
@ -419,23 +380,22 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *pb.Target,
if !reflect.DeepEqual(query.BindVariables, streamExecuteBindVars) {
f.t.Errorf("invalid StreamExecute.Query.BindVariables: got %v expected %v", query.BindVariables, streamExecuteBindVars)
}
if query.SessionId != testSessionID {
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "StreamExecute", target)
} else {
if query.SessionId != testSessionID {
f.t.Errorf("invalid StreamExecute.Query.SessionId: got %v expected %v", query.SessionId, testSessionID)
}
}
if err := sendReply(&streamExecuteQueryResult1); err != nil {
f.t.Errorf("sendReply1 failed: %v", err)
}
if f.panics && !f.streamExecutePanicsEarly {
// wait until the client gets the response, then panics
<-panicWait
<-f.panicWait
f.panicWait = make(chan struct{}) // for next test
panic(fmt.Errorf("test-triggered panic late"))
}
if f.hasError {
// wait until the client has the response, since all streaming implementation may not
// send previous messages if an error has been triggered.
<-errorWait
return testTabletError
}
if err := sendReply(&streamExecuteQueryResult2); err != nil {
f.t.Errorf("sendReply2 failed: %v", err)
}
@ -512,36 +472,6 @@ func testStreamExecute(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testStreamExecuteError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testStreamExecuteError")
ctx := context.Background()
stream, errFunc, err := conn.StreamExecute(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
if err != nil {
t.Fatalf("StreamExecute failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecute failed: cannot read result1")
}
if len(qr.Rows) == 0 {
qr.Rows = nil
}
if !reflect.DeepEqual(*qr, streamExecuteQueryResult1) {
t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, streamExecuteQueryResult1)
}
// signal to the server that the first result has been received
close(errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecute channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecute")
// reset state for the test
errorWait = make(chan struct{})
}
func testStreamExecutePanics(t *testing.T, conn tabletconn.TabletConn, fake *FakeQueryService) {
t.Log("testStreamExecutePanics")
// early panic is before sending the Fields, that is returned
@ -581,20 +511,19 @@ func testStreamExecutePanics(t *testing.T, conn tabletconn.TabletConn, fake *Fak
if !reflect.DeepEqual(*qr, streamExecuteQueryResult1) {
t.Errorf("Unexpected result1 from StreamExecute: got %v wanted %v", qr, streamExecuteQueryResult1)
}
close(panicWait)
close(fake.panicWait)
if _, ok := <-stream; ok {
t.Fatalf("StreamExecute returned more results")
}
if err := errFunc(); err == nil || !strings.Contains(err.Error(), "caught test panic") {
t.Fatalf("unexpected panic error: %v", err)
}
// Make a new panicWait channel, to reset the state to the beginning of the test
panicWait = make(chan struct{})
}
func testStreamExecute2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testStreamExecute2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
stream, errFunc, err := conn.StreamExecute2(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
if err != nil {
t.Fatalf("StreamExecute2 failed: %v", err)
@ -628,42 +557,13 @@ func testStreamExecute2(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testStreamExecute2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testStreamExecute2Error")
ctx := context.Background()
stream, errFunc, err := conn.StreamExecute2(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
if err != nil {
t.Fatalf("StreamExecute2 failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecute2 failed: cannot read result1")
}
if len(qr.Rows) == 0 {
qr.Rows = nil
}
if !reflect.DeepEqual(*qr, streamExecuteQueryResult1) {
t.Errorf("Unexpected result1 from StreamExecute2: got %v wanted %v", qr, streamExecuteQueryResult1)
}
// signal to the server that the first result has been received
close(errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecute2 channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecute2")
// reset state for the test
errorWait = make(chan struct{})
}
func testStreamExecute2Panics(t *testing.T, conn tabletconn.TabletConn, fake *FakeQueryService) {
t.Log("testStreamExecute2Panics")
// early panic is before sending the Fields, that is returned
// by the StreamExecute2 call itself, or as the first error
// by ErrFunc
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
fake.streamExecutePanicsEarly = true
stream, errFunc, err := conn.StreamExecute2(ctx, streamExecuteQuery, streamExecuteBindVars, streamExecuteTransactionID)
if err != nil {
@ -697,30 +597,29 @@ func testStreamExecute2Panics(t *testing.T, conn tabletconn.TabletConn, fake *Fa
if !reflect.DeepEqual(*qr, streamExecuteQueryResult1) {
t.Errorf("Unexpected result1 from StreamExecute2: got %v wanted %v", qr, streamExecuteQueryResult1)
}
close(panicWait)
close(fake.panicWait)
if _, ok := <-stream; ok {
t.Fatalf("StreamExecute2 returned more results")
}
if err := errFunc(); err == nil || !strings.Contains(err.Error(), "caught test panic") {
t.Fatalf("unexpected panic error: %v", err)
}
// Make a new panicWait channel, to reset the state to the beginning of the test
panicWait = make(chan struct{})
}
// ExecuteBatch is part of the queryservice.QueryService interface
func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *pb.Target, queryList *proto.QueryList, reply *proto.QueryResultList) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if !reflect.DeepEqual(queryList.Queries, executeBatchQueries) {
f.t.Errorf("invalid ExecuteBatch.QueryList.Queries: got %v expected %v", queryList.Queries, executeBatchQueries)
}
if queryList.SessionId != testSessionID {
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "ExecuteBatch", target)
} else {
if queryList.SessionId != testSessionID {
f.t.Errorf("invalid ExecuteBatch.QueryList.SessionId: got %v expected %v", queryList.SessionId, testSessionID)
}
}
if queryList.AsTransaction != testAsTransaction {
f.t.Errorf("invalid ExecuteBatch.QueryList.AsTransaction: got %v expected %v", queryList.AsTransaction, testAsTransaction)
@ -800,13 +699,6 @@ func testExecuteBatch(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testExecuteBatchError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBatchExecuteError")
ctx := context.Background()
_, err := conn.ExecuteBatch(ctx, executeBatchQueries, true, executeBatchTransactionID)
verifyError(t, err, "ExecuteBatch")
}
func testExecuteBatchPanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecuteBatchPanics")
ctx := context.Background()
@ -818,6 +710,7 @@ func testExecuteBatchPanics(t *testing.T, conn tabletconn.TabletConn) {
func testExecuteBatch2(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecuteBatch2")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
qrl, err := conn.ExecuteBatch2(ctx, executeBatchQueries, true, executeBatchTransactionID)
if err != nil {
t.Fatalf("ExecuteBatch failed: %v", err)
@ -827,13 +720,6 @@ func testExecuteBatch2(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testExecuteBatch2Error(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testBatchExecute2Error")
ctx := context.Background()
_, err := conn.ExecuteBatch2(ctx, executeBatchQueries, true, executeBatchTransactionID)
verifyError(t, err, "ExecuteBatch")
}
func testExecuteBatch2Panics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testExecuteBatch2Panics")
ctx := context.Background()
@ -844,12 +730,12 @@ func testExecuteBatch2Panics(t *testing.T, conn tabletconn.TabletConn) {
// SplitQuery is part of the queryservice.QueryService interface
func (f *FakeQueryService) SplitQuery(ctx context.Context, target *pb.Target, req *proto.SplitQueryRequest, reply *proto.SplitQueryResult) error {
if f.hasError {
return testTabletError
}
if f.panics {
panic(fmt.Errorf("test-triggered panic"))
}
if f.checkExtraFields {
f.checkTargetCallerID(ctx, "SplitQuery", target)
}
if !reflect.DeepEqual(req.Query, splitQueryBoundQuery) {
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", req.Query, splitQueryBoundQuery)
}
@ -889,6 +775,7 @@ var splitQueryQuerySplitList = []proto.QuerySplit{
func testSplitQuery(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testSplitQuery")
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, testVTGateCallerID)
qsl, err := conn.SplitQuery(ctx, splitQueryBoundQuery, splitQuerySplitColumn, splitQuerySplitCount)
if err != nil {
t.Fatalf("SplitQuery failed: %v", err)
@ -898,13 +785,6 @@ func testSplitQuery(t *testing.T, conn tabletconn.TabletConn) {
}
}
func testSplitQueryError(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testSplitQueryError")
ctx := context.Background()
_, err := conn.SplitQuery(ctx, splitQueryBoundQuery, splitQuerySplitColumn, splitQuerySplitCount)
verifyError(t, err, "SplitQuery")
}
func testSplitQueryPanics(t *testing.T, conn tabletconn.TabletConn) {
t.Log("testSplitQueryPanics")
ctx := context.Background()
@ -1005,14 +885,11 @@ func testStreamHealthPanics(t *testing.T, conn tabletconn.TabletConn) {
// CreateFakeServer returns the fake server for the tests
func CreateFakeServer(t *testing.T) *FakeQueryService {
// Make the synchronization channels on init, so there's no state shared between servers
panicWait = make(chan struct{})
errorWait = make(chan struct{})
return &FakeQueryService{
t: t,
panics: false,
streamExecutePanicsEarly: false,
panicWait: make(chan struct{}),
}
}
@ -1021,62 +898,64 @@ func TestSuite(t *testing.T, protocol string, endPoint *pbt.EndPoint, fake *Fake
// make sure we use the right client
*tabletconn.TabletProtocol = protocol
// create a connection
// create a connection, using sessionId
ctx := context.Background()
conn, err := tabletconn.GetDialer()(ctx, endPoint, TestKeyspace, TestShard, TestTabletType, 30*time.Second)
conn, err := tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
defer conn.Close()
// run the normal tests
testBegin(t, conn)
testBegin2(t, conn)
testCommit(t, conn)
testCommit2(t, conn)
testRollback(t, conn)
testRollback2(t, conn)
testExecute(t, conn)
testExecute2(t, conn)
testStreamExecute(t, conn)
testStreamExecute2(t, conn)
testExecuteBatch(t, conn)
testExecuteBatch2(t, conn)
testSplitQuery(t, conn)
testStreamHealth(t, conn)
// fake should return an error, make sure errors are handled properly
fake.hasError = true
testBeginError(t, conn)
testBegin2Error(t, conn)
testCommitError(t, conn)
testCommit2Error(t, conn)
testRollbackError(t, conn)
testRollback2Error(t, conn)
testExecuteError(t, conn)
testExecute2Error(t, conn)
testStreamExecuteError(t, conn)
testStreamExecute2Error(t, conn)
testExecuteBatchError(t, conn)
testExecuteBatch2Error(t, conn)
testSplitQueryError(t, conn)
fake.hasError = false
// create a new connection that expects the extra fields
conn.Close()
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_REPLICA, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
// force panics, make sure they're caught
// run the tests that expect extra fields
fake.checkExtraFields = true
testBegin2(t, conn)
testCommit2(t, conn)
testRollback2(t, conn)
testExecute2(t, conn)
testStreamExecute2(t, conn)
testExecuteBatch2(t, conn)
testSplitQuery(t, conn)
// force panics, make sure they're caught (with extra fields)
fake.panics = true
testBeginPanics(t, conn)
testBegin2Panics(t, conn)
testCommitPanics(t, conn)
testCommit2Panics(t, conn)
testRollbackPanics(t, conn)
testRollback2Panics(t, conn)
testExecutePanics(t, conn)
testExecute2Panics(t, conn)
testStreamExecutePanics(t, conn, fake)
testStreamExecute2Panics(t, conn, fake)
testExecuteBatchPanics(t, conn)
testExecuteBatch2Panics(t, conn)
testSplitQueryPanics(t, conn)
testStreamHealthPanics(t, conn)
// force panic without extra fields
conn.Close()
conn, err = tabletconn.GetDialer()(ctx, endPoint, testTarget.Keyspace, testTarget.Shard, pbt.TabletType_UNKNOWN, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
fake.checkExtraFields = false
testBeginPanics(t, conn)
testCommitPanics(t, conn)
testRollbackPanics(t, conn)
testExecutePanics(t, conn)
testExecuteBatchPanics(t, conn)
testStreamExecutePanics(t, conn, fake)
fake.panics = false
conn.Close()
}

Просмотреть файл

@ -159,8 +159,7 @@ message GetSessionIdRequest {
// GetSessionIdResponse is the returned value from GetSessionId
message GetSessionIdResponse {
vtrpc.RPCError error = 1;
int64 session_id = 2;
int64 session_id = 1;
}
// ExecuteRequest is the payload to Execute
@ -175,8 +174,7 @@ message ExecuteRequest {
// ExecuteResponse is the returned value from Execute
message ExecuteResponse {
vtrpc.RPCError error = 1;
QueryResult result = 2;
QueryResult result = 1;
}
// ExecuteBatchRequest is the payload to ExecuteBatch
@ -192,8 +190,7 @@ message ExecuteBatchRequest {
// ExecuteBatchResponse is the returned value from ExecuteBatch
message ExecuteBatchResponse {
vtrpc.RPCError error = 1;
repeated QueryResult results = 2;
repeated QueryResult results = 1;
}
// StreamExecuteRequest is the payload to StreamExecute
@ -207,8 +204,7 @@ message StreamExecuteRequest {
// StreamExecuteResponse is the returned value from StreamExecute
message StreamExecuteResponse {
vtrpc.RPCError error = 1;
QueryResult result = 2;
QueryResult result = 1;
}
// BeginRequest is the payload to Begin
@ -221,8 +217,7 @@ message BeginRequest {
// BeginResponse is the returned value from Begin
message BeginResponse {
vtrpc.RPCError error = 1;
int64 transaction_id = 2;
int64 transaction_id = 1;
}
// CommitRequest is the payload to Commit
@ -235,9 +230,7 @@ message CommitRequest {
}
// CommitResponse is the returned value from Commit
message CommitResponse {
vtrpc.RPCError error = 1;
}
message CommitResponse {}
// RollbackRequest is the payload to Rollback
message RollbackRequest {
@ -249,9 +242,7 @@ message RollbackRequest {
}
// RollbackResponse is the returned value from Rollback
message RollbackResponse {
vtrpc.RPCError error = 1;
}
message RollbackResponse {}
// SplitQueryRequest is the payload for SplitQuery
message SplitQueryRequest {
@ -276,8 +267,7 @@ message QuerySplit {
// SplitQueryResponse is returned by SplitQuery and represents all the queries
// to execute in order to get the entire data set.
message SplitQueryResponse {
vtrpc.RPCError error = 1;
repeated QuerySplit queries = 2;
repeated QuerySplit queries = 1;
}
// StreamHealthRequest is the payload for StreamHealth

Различия файлов скрыты, потому что одна или несколько строк слишком длинны