Adding tests for callerid in client -> vtgate connections.

Fixing all the places that broke, in bson and grpc.
This commit is contained in:
Alain Jobart 2015-08-06 11:08:56 -07:00
Родитель b436bf95f6
Коммит 6813971a29
16 изменённых файлов: 465 добавлений и 100 удалений

Просмотреть файл

@ -0,0 +1,53 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proto
// DO NOT EDIT.
// FILE GENERATED BY BSONGEN.
import (
"bytes"
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
)
// MarshalBson bson-encodes CallerID.
func (callerID *CallerID) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
bson.EncodeString(buf, "Principal", callerID.Principal)
bson.EncodeString(buf, "Component", callerID.Component)
bson.EncodeString(buf, "Subcomponent", callerID.Subcomponent)
lenWriter.Close()
}
// UnmarshalBson bson-decodes into CallerID.
func (callerID *CallerID) UnmarshalBson(buf *bytes.Buffer, kind byte) {
switch kind {
case bson.EOO, bson.Object:
// valid
case bson.Null:
return
default:
panic(bson.NewBsonError("unexpected kind %v for CallerID", kind))
}
bson.Next(buf, 4)
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "Principal":
callerID.Principal = bson.DecodeString(buf, kind)
case "Component":
callerID.Component = bson.DecodeString(buf, kind)
case "Subcomponent":
callerID.Subcomponent = bson.DecodeString(buf, kind)
default:
bson.Skip(buf, kind)
}
}
}

Просмотреть файл

@ -170,6 +170,8 @@ type CallerID struct {
Subcomponent string
}
//go:generate bsongen -file $GOFILE -type CallerID -o callerid_bson.go
// VTGateCallerID is the BSON implementation of the proto3 query.VTGateCallerID
type VTGateCallerID struct {
Username string

Просмотреть файл

@ -13,6 +13,7 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/rpcplus"
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/rpc"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
@ -43,12 +44,24 @@ func dial(ctx context.Context, address string, timeout time.Duration) (vtgatecon
return &vtgateConn{rpcConn: rpcConn}, nil
}
func getEffectiveCallerID(ctx context.Context) *tproto.CallerID {
if ef := callerid.EffectiveCallerIDFromContext(ctx); ef != nil {
return &tproto.CallerID{
Principal: ef.Principal,
Component: ef.Component,
Subcomponent: ef.Subcomponent,
}
}
return nil
}
func (conn *vtgateConn) Execute(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType, notInTransaction bool, session interface{}) (*mproto.QueryResult, interface{}, error) {
var s *proto.Session
if session != nil {
s = session.(*proto.Session)
}
request := proto.Query{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
TabletType: tabletType,
@ -74,6 +87,7 @@ func (conn *vtgateConn) ExecuteShard(ctx context.Context, query string, keyspace
s = session.(*proto.Session)
}
request := proto.QueryShard{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -101,6 +115,7 @@ func (conn *vtgateConn) ExecuteKeyspaceIds(ctx context.Context, query string, ke
s = session.(*proto.Session)
}
request := proto.KeyspaceIdQuery{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -128,6 +143,7 @@ func (conn *vtgateConn) ExecuteKeyRanges(ctx context.Context, query string, keys
s = session.(*proto.Session)
}
request := proto.KeyRangeQuery{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -155,6 +171,7 @@ func (conn *vtgateConn) ExecuteEntityIds(ctx context.Context, query string, keys
s = session.(*proto.Session)
}
request := proto.EntityIdsQuery{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -183,6 +200,7 @@ func (conn *vtgateConn) ExecuteBatchShard(ctx context.Context, queries []proto.B
s = session.(*proto.Session)
}
request := proto.BatchQueryShard{
CallerID: getEffectiveCallerID(ctx),
Queries: queries,
TabletType: tabletType,
AsTransaction: asTransaction,
@ -207,6 +225,7 @@ func (conn *vtgateConn) ExecuteBatchKeyspaceIds(ctx context.Context, queries []p
s = session.(*proto.Session)
}
request := proto.KeyspaceIdBatchQuery{
CallerID: getEffectiveCallerID(ctx),
Queries: queries,
TabletType: tabletType,
AsTransaction: asTransaction,
@ -227,6 +246,7 @@ func (conn *vtgateConn) ExecuteBatchKeyspaceIds(ctx context.Context, queries []p
func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &proto.Query{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
TabletType: tabletType,
@ -239,6 +259,7 @@ func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVar
func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &proto.QueryShard{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -253,6 +274,7 @@ func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, ke
func (conn *vtgateConn) StreamExecuteKeyRanges(ctx context.Context, query string, keyspace string, keyRanges []key.KeyRange, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &proto.KeyRangeQuery{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -267,6 +289,7 @@ func (conn *vtgateConn) StreamExecuteKeyRanges(ctx context.Context, query string
func (conn *vtgateConn) StreamExecuteKeyspaceIds(ctx context.Context, query string, keyspace string, keyspaceIds []key.KeyspaceId, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &proto.KeyspaceIdQuery{
CallerID: getEffectiveCallerID(ctx),
Sql: query,
BindVariables: bindVars,
Keyspace: keyspace,
@ -309,7 +332,9 @@ func (conn *vtgateConn) Rollback(ctx context.Context, session interface{}) error
}
func (conn *vtgateConn) Begin2(ctx context.Context) (interface{}, error) {
request := new(proto.BeginRequest)
request := &proto.BeginRequest{
CallerID: getEffectiveCallerID(ctx),
}
reply := new(proto.BeginResponse)
if err := conn.rpcConn.Call(ctx, "VTGate.Begin2", request, reply); err != nil {
return nil, err
@ -328,7 +353,8 @@ func (conn *vtgateConn) Begin2(ctx context.Context) (interface{}, error) {
func (conn *vtgateConn) Commit2(ctx context.Context, session interface{}) error {
s := session.(*proto.Session)
request := &proto.CommitRequest{
Session: s,
CallerID: getEffectiveCallerID(ctx),
Session: s,
}
reply := new(proto.CommitResponse)
if err := conn.rpcConn.Call(ctx, "VTGate.Commit2", request, reply); err != nil {
@ -340,7 +366,8 @@ func (conn *vtgateConn) Commit2(ctx context.Context, session interface{}) error
func (conn *vtgateConn) Rollback2(ctx context.Context, session interface{}) error {
s := session.(*proto.Session)
request := &proto.RollbackRequest{
Session: s,
CallerID: getEffectiveCallerID(ctx),
Session: s,
}
reply := new(proto.RollbackResponse)
if err := conn.rpcConn.Call(ctx, "VTGate.Rollback2", request, reply); err != nil {
@ -351,6 +378,7 @@ func (conn *vtgateConn) Rollback2(ctx context.Context, session interface{}) erro
func (conn *vtgateConn) SplitQuery(ctx context.Context, keyspace string, query tproto.BoundQuery, splitColumn string, splitCount int) ([]proto.SplitQueryPart, error) {
request := &proto.SplitQueryRequest{
CallerID: getEffectiveCallerID(ctx),
Keyspace: keyspace,
Query: query,
SplitColumn: splitColumn,

Просмотреть файл

@ -9,6 +9,7 @@ import (
"flag"
"time"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/rpc"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/topo"
@ -28,11 +29,14 @@ type VTGate struct {
}
// Execute is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) Execute(ctx context.Context, query *proto.Query, reply *proto.QueryResult) (err error) {
func (vtg *VTGate) Execute(ctx context.Context, request *proto.Query, reply *proto.QueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.Execute(ctx, query, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.Execute(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -41,11 +45,14 @@ func (vtg *VTGate) Execute(ctx context.Context, query *proto.Query, reply *proto
}
// ExecuteShard is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteShard(ctx context.Context, query *proto.QueryShard, reply *proto.QueryResult) (err error) {
func (vtg *VTGate) ExecuteShard(ctx context.Context, request *proto.QueryShard, reply *proto.QueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteShard(ctx, query, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteShard(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -54,11 +61,14 @@ func (vtg *VTGate) ExecuteShard(ctx context.Context, query *proto.QueryShard, re
}
// ExecuteKeyspaceIds is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, reply *proto.QueryResult) (err error) {
func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, request *proto.KeyspaceIdQuery, reply *proto.QueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteKeyspaceIds(ctx, query, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteKeyspaceIds(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -67,11 +77,14 @@ func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, query *proto.Keyspace
}
// ExecuteKeyRanges is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, reply *proto.QueryResult) (err error) {
func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, request *proto.KeyRangeQuery, reply *proto.QueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteKeyRanges(ctx, query, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteKeyRanges(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -80,11 +93,14 @@ func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQu
}
// ExecuteEntityIds is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQuery, reply *proto.QueryResult) (err error) {
func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, request *proto.EntityIdsQuery, reply *proto.QueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteEntityIds(ctx, query, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteEntityIds(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -93,11 +109,14 @@ func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQ
}
// ExecuteBatchShard is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteBatchShard(ctx context.Context, batchQuery *proto.BatchQueryShard, reply *proto.QueryResultList) (err error) {
func (vtg *VTGate) ExecuteBatchShard(ctx context.Context, request *proto.BatchQueryShard, reply *proto.QueryResultList) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteBatchShard(ctx, batchQuery, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteBatchShard(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResultList(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -107,11 +126,14 @@ func (vtg *VTGate) ExecuteBatchShard(ctx context.Context, batchQuery *proto.Batc
// ExecuteBatchKeyspaceIds is the RPC version of
// vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, batchQuery *proto.KeyspaceIdBatchQuery, reply *proto.QueryResultList) (err error) {
func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, request *proto.KeyspaceIdBatchQuery, reply *proto.QueryResultList) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.ExecuteBatchKeyspaceIds(ctx, batchQuery, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.ExecuteBatchKeyspaceIds(ctx, request, reply)
vtgate.AddVtGateErrorToQueryResultList(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil
@ -120,35 +142,47 @@ func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, batchQuery *prot
}
// StreamExecute is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecute(ctx context.Context, query *proto.Query, sendReply func(interface{}) error) (err error) {
func (vtg *VTGate) StreamExecute(ctx context.Context, request *proto.Query, sendReply func(interface{}) error) (err error) {
defer vtg.server.HandlePanic(&err)
return vtg.server.StreamExecute(ctx, query, func(value *proto.QueryResult) error {
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
return vtg.server.StreamExecute(ctx, request, func(value *proto.QueryResult) error {
return sendReply(value)
})
}
// StreamExecuteShard is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteShard(ctx context.Context, query *proto.QueryShard, sendReply func(interface{}) error) (err error) {
func (vtg *VTGate) StreamExecuteShard(ctx context.Context, request *proto.QueryShard, sendReply func(interface{}) error) (err error) {
defer vtg.server.HandlePanic(&err)
return vtg.server.StreamExecuteShard(ctx, query, func(value *proto.QueryResult) error {
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
return vtg.server.StreamExecuteShard(ctx, request, func(value *proto.QueryResult) error {
return sendReply(value)
})
}
// StreamExecuteKeyRanges is the RPC version of
// vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, sendReply func(interface{}) error) (err error) {
func (vtg *VTGate) StreamExecuteKeyRanges(ctx context.Context, request *proto.KeyRangeQuery, sendReply func(interface{}) error) (err error) {
defer vtg.server.HandlePanic(&err)
return vtg.server.StreamExecuteKeyRanges(ctx, query, func(value *proto.QueryResult) error {
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
return vtg.server.StreamExecuteKeyRanges(ctx, request, func(value *proto.QueryResult) error {
return sendReply(value)
})
}
// StreamExecuteKeyspaceIds is the RPC version of
// vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, sendReply func(interface{}) error) (err error) {
func (vtg *VTGate) StreamExecuteKeyspaceIds(ctx context.Context, request *proto.KeyspaceIdQuery, sendReply func(interface{}) error) (err error) {
defer vtg.server.HandlePanic(&err)
return vtg.server.StreamExecuteKeyspaceIds(ctx, query, func(value *proto.QueryResult) error {
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
return vtg.server.StreamExecuteKeyspaceIds(ctx, request, func(value *proto.QueryResult) error {
return sendReply(value)
})
}
@ -182,6 +216,9 @@ func (vtg *VTGate) Begin2(ctx context.Context, request *proto.BeginRequest, repl
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
// Don't pass in a nil pointer
reply.Session = &proto.Session{}
vtgErr := vtg.server.Begin(ctx, reply.Session)
@ -197,6 +234,9 @@ func (vtg *VTGate) Commit2(ctx context.Context, request *proto.CommitRequest, re
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.Commit(ctx, request.Session)
vtgate.AddVtGateErrorToCommitResponse(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
@ -210,6 +250,9 @@ func (vtg *VTGate) Rollback2(ctx context.Context, request *proto.RollbackRequest
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.Rollback(ctx, request.Session)
vtgate.AddVtGateErrorToRollbackResponse(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
@ -219,11 +262,14 @@ func (vtg *VTGate) Rollback2(ctx context.Context, request *proto.RollbackRequest
}
// SplitQuery is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest, reply *proto.SplitQueryResult) (err error) {
func (vtg *VTGate) SplitQuery(ctx context.Context, request *proto.SplitQueryRequest, reply *proto.SplitQueryResult) (err error) {
defer vtg.server.HandlePanic(&err)
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout))
defer cancel()
vtgErr := vtg.server.SplitQuery(ctx, req, reply)
ctx = callerid.NewContext(ctx,
callerid.GoRPCEffectiveCallerID(request.CallerID),
callerid.NewImmediateCallerID("gorpc client"))
vtgErr := vtg.server.SplitQuery(ctx, request, reply)
vtgate.AddVtGateErrorToSplitQueryResult(vtgErr, reply)
if *vtgate.RPCErrorOnlyInReply {
return nil

Просмотреть файл

@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/topo"
@ -51,6 +52,7 @@ func (conn *vtgateConn) Execute(ctx context.Context, query string, bindVars map[
s = session.(*pb.Session)
}
request := &pb.ExecuteRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Query: tproto.BoundQueryToProto3(query, bindVars),
TabletType: topo.TabletTypeToProto(tabletType),
@ -72,6 +74,7 @@ func (conn *vtgateConn) ExecuteShard(ctx context.Context, query string, keyspace
s = session.(*pb.Session)
}
request := &pb.ExecuteShardsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
@ -95,6 +98,7 @@ func (conn *vtgateConn) ExecuteKeyspaceIds(ctx context.Context, query string, ke
s = session.(*pb.Session)
}
request := &pb.ExecuteKeyspaceIdsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
@ -118,6 +122,7 @@ func (conn *vtgateConn) ExecuteKeyRanges(ctx context.Context, query string, keys
s = session.(*pb.Session)
}
request := &pb.ExecuteKeyRangesRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
@ -141,6 +146,7 @@ func (conn *vtgateConn) ExecuteEntityIds(ctx context.Context, query string, keys
s = session.(*pb.Session)
}
request := &pb.ExecuteEntityIdsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
@ -165,6 +171,7 @@ func (conn *vtgateConn) ExecuteBatchShard(ctx context.Context, queries []proto.B
s = session.(*pb.Session)
}
request := &pb.ExecuteBatchShardsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Queries: proto.BoundShardQueriesToProto(queries),
TabletType: topo.TabletTypeToProto(tabletType),
@ -186,6 +193,7 @@ func (conn *vtgateConn) ExecuteBatchKeyspaceIds(ctx context.Context, queries []p
s = session.(*pb.Session)
}
request := &pb.ExecuteBatchKeyspaceIdsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: s,
Queries: proto.BoundKeyspaceIdQueriesToProto(queries),
TabletType: topo.TabletTypeToProto(tabletType),
@ -203,6 +211,7 @@ func (conn *vtgateConn) ExecuteBatchKeyspaceIds(ctx context.Context, queries []p
func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &pb.StreamExecuteRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
TabletType: topo.TabletTypeToProto(tabletType),
}
@ -237,6 +246,7 @@ func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVar
func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &pb.StreamExecuteShardsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
Shards: shards,
@ -273,6 +283,7 @@ func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, ke
func (conn *vtgateConn) StreamExecuteKeyRanges(ctx context.Context, query string, keyspace string, keyRanges []key.KeyRange, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &pb.StreamExecuteKeyRangesRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
KeyRanges: key.KeyRangesToProto(keyRanges),
@ -309,6 +320,7 @@ func (conn *vtgateConn) StreamExecuteKeyRanges(ctx context.Context, query string
func (conn *vtgateConn) StreamExecuteKeyspaceIds(ctx context.Context, query string, keyspace string, keyspaceIds []key.KeyspaceId, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) {
req := &pb.StreamExecuteKeyspaceIdsRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Query: tproto.BoundQueryToProto3(query, bindVars),
Keyspace: keyspace,
KeyspaceIds: key.KeyspaceIdsToProto(keyspaceIds),
@ -344,7 +356,9 @@ func (conn *vtgateConn) StreamExecuteKeyspaceIds(ctx context.Context, query stri
}
func (conn *vtgateConn) Begin(ctx context.Context) (interface{}, error) {
request := &pb.BeginRequest{}
request := &pb.BeginRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
}
response, err := conn.c.Begin(ctx, request)
if err != nil {
return nil, err
@ -357,7 +371,8 @@ func (conn *vtgateConn) Begin(ctx context.Context) (interface{}, error) {
func (conn *vtgateConn) Commit(ctx context.Context, session interface{}) error {
request := &pb.CommitRequest{
Session: session.(*pb.Session),
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: session.(*pb.Session),
}
response, err := conn.c.Commit(ctx, request)
if err != nil {
@ -371,7 +386,8 @@ func (conn *vtgateConn) Commit(ctx context.Context, session interface{}) error {
func (conn *vtgateConn) Rollback(ctx context.Context, session interface{}) error {
request := &pb.RollbackRequest{
Session: session.(*pb.Session),
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Session: session.(*pb.Session),
}
response, err := conn.c.Rollback(ctx, request)
if err != nil {
@ -397,6 +413,7 @@ func (conn *vtgateConn) Rollback2(ctx context.Context, session interface{}) erro
func (conn *vtgateConn) SplitQuery(ctx context.Context, keyspace string, query tproto.BoundQuery, splitColumn string, splitCount int) ([]proto.SplitQueryPart, error) {
request := &pb.SplitQueryRequest{
CallerId: callerid.EffectiveCallerIDFromContext(ctx),
Keyspace: keyspace,
Query: tproto.BoundQueryToProto3(query.Sql, query.BindVariables),
SplitColumn: splitColumn,

Просмотреть файл

@ -9,6 +9,8 @@ import (
"google.golang.org/grpc"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/callinfo"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/servenv"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
@ -30,6 +32,9 @@ type VTGate struct {
// Execute is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) Execute(ctx context.Context, request *pb.ExecuteRequest) (response *pb.ExecuteResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.Query{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -56,6 +61,9 @@ func (vtg *VTGate) Execute(ctx context.Context, request *pb.ExecuteRequest) (res
// ExecuteShards is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteShards(ctx context.Context, request *pb.ExecuteShardsRequest) (response *pb.ExecuteShardsResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.QueryShard{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -84,6 +92,9 @@ func (vtg *VTGate) ExecuteShards(ctx context.Context, request *pb.ExecuteShardsR
// ExecuteKeyspaceIds is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, request *pb.ExecuteKeyspaceIdsRequest) (response *pb.ExecuteKeyspaceIdsResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.KeyspaceIdQuery{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -112,6 +123,9 @@ func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, request *pb.ExecuteKe
// ExecuteKeyRanges is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, request *pb.ExecuteKeyRangesRequest) (response *pb.ExecuteKeyRangesResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.KeyRangeQuery{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -140,6 +154,9 @@ func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, request *pb.ExecuteKeyR
// ExecuteEntityIds is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, request *pb.ExecuteEntityIdsRequest) (response *pb.ExecuteEntityIdsResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.EntityIdsQuery{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -169,7 +186,9 @@ func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, request *pb.ExecuteEnti
// ExecuteBatchShards is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteBatchShards(ctx context.Context, request *pb.ExecuteBatchShardsRequest) (response *pb.ExecuteBatchShardsResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.BatchQueryShard{
Session: proto.ProtoToSession(request.Session),
Queries: proto.ProtoToBoundShardQueries(request.Queries),
@ -196,7 +215,9 @@ func (vtg *VTGate) ExecuteBatchShards(ctx context.Context, request *pb.ExecuteBa
// vtgateservice.VTGateService method
func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, request *pb.ExecuteBatchKeyspaceIdsRequest) (response *pb.ExecuteBatchKeyspaceIdsResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.KeyspaceIdBatchQuery{
Session: proto.ProtoToSession(request.Session),
Queries: proto.ProtoToBoundKeyspaceIdQueries(request.Queries),
@ -222,13 +243,15 @@ func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, request *pb.Exec
// StreamExecute is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Vitess_StreamExecuteServer) (err error) {
defer vtg.server.HandlePanic(&err)
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.Query{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
TabletType: topo.ProtoToTabletType(request.TabletType),
}
return vtg.server.StreamExecute(stream.Context(), query, func(value *proto.QueryResult) error {
return vtg.server.StreamExecute(ctx, query, func(value *proto.QueryResult) error {
return stream.Send(&pb.StreamExecuteResponse{
Result: mproto.QueryResultToProto3(value.Result),
})
@ -238,7 +261,9 @@ func (vtg *VTGate) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Vi
// StreamExecuteShards is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteShards(request *pb.StreamExecuteShardsRequest, stream pbs.Vitess_StreamExecuteShardsServer) (err error) {
defer vtg.server.HandlePanic(&err)
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.QueryShard{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -246,7 +271,7 @@ func (vtg *VTGate) StreamExecuteShards(request *pb.StreamExecuteShardsRequest, s
Shards: request.Shards,
TabletType: topo.ProtoToTabletType(request.TabletType),
}
return vtg.server.StreamExecuteShard(stream.Context(), query, func(value *proto.QueryResult) error {
return vtg.server.StreamExecuteShard(ctx, query, func(value *proto.QueryResult) error {
return stream.Send(&pb.StreamExecuteShardsResponse{
Result: mproto.QueryResultToProto3(value.Result),
})
@ -257,7 +282,9 @@ func (vtg *VTGate) StreamExecuteShards(request *pb.StreamExecuteShardsRequest, s
// vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteKeyRanges(request *pb.StreamExecuteKeyRangesRequest, stream pbs.Vitess_StreamExecuteKeyRangesServer) (err error) {
defer vtg.server.HandlePanic(&err)
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.KeyRangeQuery{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -265,7 +292,7 @@ func (vtg *VTGate) StreamExecuteKeyRanges(request *pb.StreamExecuteKeyRangesRequ
KeyRanges: key.ProtoToKeyRanges(request.KeyRanges),
TabletType: topo.ProtoToTabletType(request.TabletType),
}
return vtg.server.StreamExecuteKeyRanges(stream.Context(), query, func(value *proto.QueryResult) error {
return vtg.server.StreamExecuteKeyRanges(ctx, query, func(value *proto.QueryResult) error {
return stream.Send(&pb.StreamExecuteKeyRangesResponse{
Result: mproto.QueryResultToProto3(value.Result),
})
@ -276,7 +303,9 @@ func (vtg *VTGate) StreamExecuteKeyRanges(request *pb.StreamExecuteKeyRangesRequ
// vtgateservice.VTGateService method
func (vtg *VTGate) StreamExecuteKeyspaceIds(request *pb.StreamExecuteKeyspaceIdsRequest, stream pbs.Vitess_StreamExecuteKeyspaceIdsServer) (err error) {
defer vtg.server.HandlePanic(&err)
ctx := callerid.NewContext(callinfo.GRPCCallInfo(stream.Context()),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.KeyspaceIdQuery{
Sql: string(request.Query.Sql),
BindVariables: tproto.Proto3ToBindVariables(request.Query.BindVariables),
@ -284,7 +313,7 @@ func (vtg *VTGate) StreamExecuteKeyspaceIds(request *pb.StreamExecuteKeyspaceIds
KeyspaceIds: key.ProtoToKeyspaceIds(request.KeyspaceIds),
TabletType: topo.ProtoToTabletType(request.TabletType),
}
return vtg.server.StreamExecuteKeyspaceIds(stream.Context(), query, func(value *proto.QueryResult) error {
return vtg.server.StreamExecuteKeyspaceIds(ctx, query, func(value *proto.QueryResult) error {
return stream.Send(&pb.StreamExecuteKeyspaceIdsResponse{
Result: mproto.QueryResultToProto3(value.Result),
})
@ -294,6 +323,9 @@ func (vtg *VTGate) StreamExecuteKeyspaceIds(request *pb.StreamExecuteKeyspaceIds
// Begin is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) Begin(ctx context.Context, request *pb.BeginRequest) (response *pb.BeginResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
outSession := new(proto.Session)
beginErr := vtg.server.Begin(ctx, outSession)
response = &pb.BeginResponse{
@ -312,6 +344,9 @@ func (vtg *VTGate) Begin(ctx context.Context, request *pb.BeginRequest) (respons
// Commit is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) Commit(ctx context.Context, request *pb.CommitRequest) (response *pb.CommitResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
commitErr := vtg.server.Commit(ctx, proto.ProtoToSession(request.Session))
response = &pb.CommitResponse{
Error: vtgate.VtGateErrorToVtRPCError(commitErr, ""),
@ -328,6 +363,9 @@ func (vtg *VTGate) Commit(ctx context.Context, request *pb.CommitRequest) (respo
// Rollback is the RPC version of vtgateservice.VTGateService method
func (vtg *VTGate) Rollback(ctx context.Context, request *pb.RollbackRequest) (response *pb.RollbackResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
rollbackErr := vtg.server.Rollback(ctx, proto.ProtoToSession(request.Session))
response = &pb.RollbackResponse{
Error: vtgate.VtGateErrorToVtRPCError(rollbackErr, ""),
@ -345,6 +383,9 @@ func (vtg *VTGate) Rollback(ctx context.Context, request *pb.RollbackRequest) (r
func (vtg *VTGate) SplitQuery(ctx context.Context, request *pb.SplitQueryRequest) (response *pb.SplitQueryResponse, err error) {
defer vtg.server.HandlePanic(&err)
ctx = callerid.NewContext(callinfo.GRPCCallInfo(ctx),
request.CallerId,
callerid.NewImmediateCallerID("grpc client"))
query := &proto.SplitQueryRequest{
Keyspace: request.Keyspace,
Query: tproto.BoundQuery{

Просмотреть файл

@ -12,6 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes BatchQueryShard.
@ -19,6 +20,12 @@ func (batchQueryShard *BatchQueryShard) MarshalBson(buf *bytes2.ChunkedWriter, k
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if batchQueryShard.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*batchQueryShard.CallerID).MarshalBson(buf, "CallerID")
}
// []BoundShardQuery
{
bson.EncodePrefix(buf, bson.Array, "Queries")
@ -54,6 +61,12 @@ func (batchQueryShard *BatchQueryShard) UnmarshalBson(buf *bytes.Buffer, kind by
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
batchQueryShard.CallerID = new(tproto.CallerID)
(*batchQueryShard.CallerID).UnmarshalBson(buf, kind)
}
case "Queries":
// []BoundShardQuery
if kind != bson.Null {

Просмотреть файл

@ -12,6 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes EntityIdsQuery.
@ -19,6 +20,12 @@ func (entityIdsQuery *EntityIdsQuery) MarshalBson(buf *bytes2.ChunkedWriter, key
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if entityIdsQuery.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*entityIdsQuery.CallerID).MarshalBson(buf, "CallerID")
}
bson.EncodeString(buf, "Sql", entityIdsQuery.Sql)
// map[string]interface{}
{
@ -66,6 +73,12 @@ func (entityIdsQuery *EntityIdsQuery) UnmarshalBson(buf *bytes.Buffer, kind byte
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
entityIdsQuery.CallerID = new(tproto.CallerID)
(*entityIdsQuery.CallerID).UnmarshalBson(buf, kind)
}
case "Sql":
entityIdsQuery.Sql = bson.DecodeString(buf, kind)
case "BindVariables":

Просмотреть файл

@ -13,6 +13,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes KeyRangeQuery.
@ -20,6 +21,12 @@ func (keyRangeQuery *KeyRangeQuery) MarshalBson(buf *bytes2.ChunkedWriter, key s
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if keyRangeQuery.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*keyRangeQuery.CallerID).MarshalBson(buf, "CallerID")
}
bson.EncodeString(buf, "Sql", keyRangeQuery.Sql)
// map[string]interface{}
{
@ -66,6 +73,12 @@ func (keyRangeQuery *KeyRangeQuery) UnmarshalBson(buf *bytes.Buffer, kind byte)
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
keyRangeQuery.CallerID = new(tproto.CallerID)
(*keyRangeQuery.CallerID).UnmarshalBson(buf, kind)
}
case "Sql":
keyRangeQuery.Sql = bson.DecodeString(buf, kind)
case "BindVariables":

Просмотреть файл

@ -12,6 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes KeyspaceIdBatchQuery.
@ -19,6 +20,12 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) MarshalBson(buf *bytes2.Chunke
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if keyspaceIdBatchQuery.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*keyspaceIdBatchQuery.CallerID).MarshalBson(buf, "CallerID")
}
// []BoundKeyspaceIdQuery
{
bson.EncodePrefix(buf, bson.Array, "Queries")
@ -54,6 +61,12 @@ func (keyspaceIdBatchQuery *KeyspaceIdBatchQuery) UnmarshalBson(buf *bytes.Buffe
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
keyspaceIdBatchQuery.CallerID = new(tproto.CallerID)
(*keyspaceIdBatchQuery.CallerID).UnmarshalBson(buf, kind)
}
case "Queries":
// []BoundKeyspaceIdQuery
if kind != bson.Null {

Просмотреть файл

@ -13,6 +13,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes KeyspaceIdQuery.
@ -20,6 +21,12 @@ func (keyspaceIdQuery *KeyspaceIdQuery) MarshalBson(buf *bytes2.ChunkedWriter, k
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if keyspaceIdQuery.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*keyspaceIdQuery.CallerID).MarshalBson(buf, "CallerID")
}
bson.EncodeString(buf, "Sql", keyspaceIdQuery.Sql)
// map[string]interface{}
{
@ -66,6 +73,12 @@ func (keyspaceIdQuery *KeyspaceIdQuery) UnmarshalBson(buf *bytes.Buffer, kind by
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
keyspaceIdQuery.CallerID = new(tproto.CallerID)
(*keyspaceIdQuery.CallerID).UnmarshalBson(buf, kind)
}
case "Sql":
keyspaceIdQuery.Sql = bson.DecodeString(buf, kind)
case "BindVariables":

Просмотреть файл

@ -12,6 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes Query.
@ -19,6 +20,12 @@ func (query *Query) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if query.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*query.CallerID).MarshalBson(buf, "CallerID")
}
bson.EncodeString(buf, "Sql", query.Sql)
// map[string]interface{}
{
@ -55,6 +62,12 @@ func (query *Query) UnmarshalBson(buf *bytes.Buffer, kind byte) {
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
query.CallerID = new(tproto.CallerID)
(*query.CallerID).UnmarshalBson(buf, kind)
}
case "Sql":
query.Sql = bson.DecodeString(buf, kind)
case "BindVariables":

Просмотреть файл

@ -12,6 +12,7 @@ import (
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
)
// MarshalBson bson-encodes QueryShard.
@ -19,6 +20,12 @@ func (queryShard *QueryShard) MarshalBson(buf *bytes2.ChunkedWriter, key string)
bson.EncodeOptionalPrefix(buf, bson.Object, key)
lenWriter := bson.NewLenWriter(buf)
// *tproto.CallerID
if queryShard.CallerID == nil {
bson.EncodePrefix(buf, bson.Null, "CallerID")
} else {
(*queryShard.CallerID).MarshalBson(buf, "CallerID")
}
bson.EncodeString(buf, "Sql", queryShard.Sql)
// map[string]interface{}
{
@ -65,6 +72,12 @@ func (queryShard *QueryShard) UnmarshalBson(buf *bytes.Buffer, kind byte) {
for kind := bson.NextByte(buf); kind != bson.EOO; kind = bson.NextByte(buf) {
switch bson.ReadCString(buf) {
case "CallerID":
// *tproto.CallerID
if kind != bson.Null {
queryShard.CallerID = new(tproto.CallerID)
(*queryShard.CallerID).UnmarshalBson(buf, kind)
}
case "Sql":
queryShard.Sql = bson.DecodeString(buf, kind)
case "BindVariables":

Просмотреть файл

@ -43,6 +43,7 @@ func (shardSession *ShardSession) String() string {
// Query represents a keyspace agnostic query request.
type Query struct {
CallerID *tproto.CallerID // only used by BSON
Sql string
BindVariables map[string]interface{}
TabletType topo.TabletType
@ -55,6 +56,7 @@ type Query struct {
// QueryShard represents a query request for the
// specified list of shards.
type QueryShard struct {
CallerID *tproto.CallerID // only used by BSON
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -69,6 +71,7 @@ type QueryShard struct {
// KeyspaceIdQuery represents a query request for the
// specified list of keyspace IDs.
type KeyspaceIdQuery struct {
CallerID *tproto.CallerID // only used by BSON
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -83,6 +86,7 @@ type KeyspaceIdQuery struct {
// KeyRangeQuery represents a query request for the
// specified list of keyranges.
type KeyRangeQuery struct {
CallerID *tproto.CallerID // only used by BSON
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -104,6 +108,7 @@ type EntityId struct {
// EntityIdsQuery represents a query request for the specified KeyspaceId map.
type EntityIdsQuery struct {
CallerID *tproto.CallerID // only used by BSON
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -140,6 +145,7 @@ type BoundShardQuery struct {
// BatchQueryShard represents a batch query request
// for the specified shards.
type BatchQueryShard struct {
CallerID *tproto.CallerID // only used by BSON
Queries []BoundShardQuery
TabletType topo.TabletType
AsTransaction bool
@ -162,6 +168,7 @@ type BoundKeyspaceIdQuery struct {
// KeyspaceIdBatchQuery represents a batch query request
// for the specified keyspace IDs.
type KeyspaceIdBatchQuery struct {
CallerID *tproto.CallerID // only used by BSON
Queries []BoundKeyspaceIdQuery
TabletType topo.TabletType
AsTransaction bool
@ -180,6 +187,7 @@ type QueryResultList struct {
// SplitQueryRequest is a request to split a query into multiple parts
type SplitQueryRequest struct {
CallerID *tproto.CallerID // only used by BSON
Keyspace string
Query tproto.BoundQuery
SplitColumn string
@ -200,9 +208,9 @@ type SplitQueryResult struct {
Err *mproto.RPCError
}
// BeginRequest is the BSON implementation of the proto3 query.BeginkRequest
// BeginRequest is the BSON implementation of the proto3 query.BeginRequest
type BeginRequest struct {
CallerID *tproto.CallerID
CallerID *tproto.CallerID // only used by BSON
}
// BeginResponse is the BSON implementation of the proto3 vtgate.BeginResponse
@ -215,7 +223,7 @@ type BeginResponse struct {
// CommitRequest is the BSON implementation of the proto3 vtgate.CommitRequest
type CommitRequest struct {
CallerID *tproto.CallerID
CallerID *tproto.CallerID // only used by BSON
Session *Session
}
@ -228,7 +236,7 @@ type CommitResponse struct {
// RollbackRequest is the BSON implementation of the proto3 vtgate.RollbackRequest
type RollbackRequest struct {
CallerID *tproto.CallerID
CallerID *tproto.CallerID // only used by BSON
Session *Session
}

Просмотреть файл

@ -12,6 +12,7 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
kproto "github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/topo"
)
@ -91,6 +92,7 @@ func TestSession(t *testing.T) {
}
type reflectQueryShard struct {
CallerID *tproto.CallerID
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -101,6 +103,7 @@ type reflectQueryShard struct {
}
type extraQueryShard struct {
CallerID *tproto.CallerID
Extra int
Sql string
BindVariables map[string]interface{}
@ -209,6 +212,7 @@ type reflectBoundShardQuery struct {
}
type reflectBatchQueryShard struct {
CallerID *tproto.CallerID
Queries []reflectBoundShardQuery
TabletType topo.TabletType
AsTransaction bool
@ -216,6 +220,7 @@ type reflectBatchQueryShard struct {
}
type extraBatchQueryShard struct {
CallerID *tproto.CallerID
Extra int
Queries []reflectBoundShardQuery
TabletType topo.TabletType
@ -393,6 +398,7 @@ func TestQueryResultList(t *testing.T) {
}
type reflectKeyspaceIdQuery struct {
CallerID *tproto.CallerID
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -403,6 +409,7 @@ type reflectKeyspaceIdQuery struct {
}
type extraKeyspaceIdQuery struct {
CallerID *tproto.CallerID
Extra int
Sql string
BindVariables map[string]interface{}
@ -465,6 +472,7 @@ func TestKeyspaceIdQuery(t *testing.T) {
}
type reflectKeyRangeQuery struct {
CallerID *tproto.CallerID
Sql string
BindVariables map[string]interface{}
Keyspace string
@ -475,6 +483,7 @@ type reflectKeyRangeQuery struct {
}
type extraKeyRangeQuery struct {
CallerID *tproto.CallerID
Extra int
Sql string
BindVariables map[string]interface{}
@ -544,6 +553,7 @@ type reflectBoundKeyspaceIdQuery struct {
}
type reflectKeyspaceIdBatchQuery struct {
CallerID *tproto.CallerID
Queries []reflectBoundKeyspaceIdQuery
TabletType topo.TabletType
AsTransaction bool
@ -551,6 +561,7 @@ type reflectKeyspaceIdBatchQuery struct {
}
type extraKeyspaceIdBatchQuery struct {
CallerID *tproto.CallerID
Extra int
Queries []reflectBoundKeyspaceIdQuery
TabletType topo.TabletType

Просмотреть файл

@ -17,6 +17,7 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/tb"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/topo"
@ -24,6 +25,8 @@ import (
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
"golang.org/x/net/context"
pbv "github.com/youtube/vitess/go/vt/proto/vtrpc"
)
// fakeVTGateService has the server side of this fake
@ -34,10 +37,31 @@ type fakeVTGateService struct {
// If True, calls to Begin/2 will always succeed. This is necessary so that
// we can test subsequent calls in the transaction (e.g., Commit, Rollback).
forceBeginSuccess bool
hasCallerID bool
}
var errTestVtGateError = errors.New("test vtgate error")
func newContext() context.Context {
ctx := context.Background()
ctx = callerid.NewContext(ctx, testCallerID, nil)
return ctx
}
func (f *fakeVTGateService) checkCallerID(ctx context.Context, name string) {
if !f.hasCallerID {
return
}
ef := callerid.EffectiveCallerIDFromContext(ctx)
if ef == nil {
f.t.Errorf("no effective caller id for %v", name)
} else {
if !reflect.DeepEqual(ef, testCallerID) {
f.t.Errorf("invalid effective caller id for %v: got %v expected %v", name, ef, testCallerID)
}
}
}
// Execute is part of the VTGateService interface
func (f *fakeVTGateService) Execute(ctx context.Context, query *proto.Query, reply *proto.QueryResult) error {
if f.hasError {
@ -46,6 +70,8 @@ func (f *fakeVTGateService) Execute(ctx context.Context, query *proto.Query, rep
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "Execute")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -66,6 +92,8 @@ func (f *fakeVTGateService) ExecuteShard(ctx context.Context, query *proto.Query
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteShard")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -86,6 +114,8 @@ func (f *fakeVTGateService) ExecuteKeyspaceIds(ctx context.Context, query *proto
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteKeyspaceIds")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -106,6 +136,8 @@ func (f *fakeVTGateService) ExecuteKeyRanges(ctx context.Context, query *proto.K
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteKeyRanges")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -126,6 +158,8 @@ func (f *fakeVTGateService) ExecuteEntityIds(ctx context.Context, query *proto.E
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteEntityIds")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -146,6 +180,8 @@ func (f *fakeVTGateService) ExecuteBatchShard(ctx context.Context, batchQuery *p
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteBatchShard")
batchQuery.CallerID = nil
execCase, ok := execMap[batchQuery.Queries[0].Sql]
if !ok {
return fmt.Errorf("no match for: %s", batchQuery.Queries[0].Sql)
@ -170,6 +206,8 @@ func (f *fakeVTGateService) ExecuteBatchKeyspaceIds(ctx context.Context, batchQu
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "ExecuteBatchKeyspaceIds")
batchQuery.CallerID = nil
execCase, ok := execMap[batchQuery.Queries[0].Sql]
if !ok {
return fmt.Errorf("no match for: %s", batchQuery.Queries[0].Sql)
@ -195,6 +233,8 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, query *proto.Quer
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
}
f.checkCallerID(ctx, "StreamExecute")
query.CallerID = nil
if !reflect.DeepEqual(query, execCase.execQuery) {
f.t.Errorf("StreamExecute: %+v, want %+v", query, execCase.execQuery)
return nil
@ -224,6 +264,8 @@ func (f *fakeVTGateService) StreamExecuteShard(ctx context.Context, query *proto
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "StreamExecuteShard")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -257,6 +299,8 @@ func (f *fakeVTGateService) StreamExecuteKeyRanges(ctx context.Context, query *p
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "StreamExecuteKeyRanges")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -290,6 +334,8 @@ func (f *fakeVTGateService) StreamExecuteKeyspaceIds(ctx context.Context, query
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "StreamExecuteKeyspaceIds")
query.CallerID = nil
execCase, ok := execMap[query.Sql]
if !ok {
return fmt.Errorf("no match for: %s", query.Sql)
@ -320,6 +366,7 @@ func (f *fakeVTGateService) StreamExecuteKeyspaceIds(ctx context.Context, query
// Begin is part of the VTGateService interface
func (f *fakeVTGateService) Begin(ctx context.Context, outSession *proto.Session) error {
f.checkCallerID(ctx, "Begin")
switch {
case f.forceBeginSuccess:
case f.hasError:
@ -334,6 +381,7 @@ func (f *fakeVTGateService) Begin(ctx context.Context, outSession *proto.Session
// Commit is part of the VTGateService interface
func (f *fakeVTGateService) Commit(ctx context.Context, inSession *proto.Session) error {
f.checkCallerID(ctx, "Commit")
if f.hasError {
return errTestVtGateError
}
@ -354,6 +402,7 @@ func (f *fakeVTGateService) Rollback(ctx context.Context, inSession *proto.Sessi
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "Rollback")
if !reflect.DeepEqual(inSession, session2) {
return errors.New("rollback: session mismatch")
}
@ -368,6 +417,8 @@ func (f *fakeVTGateService) SplitQuery(ctx context.Context, req *proto.SplitQuer
if f.panics {
panic(fmt.Errorf("test forced panic"))
}
f.checkCallerID(ctx, "SplitQuery")
req.CallerID = nil
if !reflect.DeepEqual(req, splitQueryRequest) {
f.t.Errorf("SplitQuery has wrong input: got %#v wanted %#v", req, splitQueryRequest)
}
@ -392,8 +443,9 @@ func (f *fakeVTGateService) GetSrvKeyspace(ctx context.Context, keyspace string)
// CreateFakeServer returns the fake server for the tests
func CreateFakeServer(t *testing.T) vtgateservice.VTGateService {
return &fakeVTGateService{
t: t,
panics: false,
t: t,
panics: false,
hasCallerID: true,
}
}
@ -425,9 +477,11 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testStreamExecuteShard(t, conn)
testStreamExecuteKeyRanges(t, conn)
testStreamExecuteKeyspaceIds(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
testTxPass(t, conn)
testTxPassNotInTransaction(t, conn)
testTxFail(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testTx2Pass(t, conn)
testTx2PassNotInTransaction(t, conn)
testTx2Fail(t, conn)
@ -440,7 +494,9 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
// First test errors in Begin, and then force it to succeed so we can test
// subsequent calls in the transaction.
fakeServer.(*fakeVTGateService).forceBeginSuccess = false
fakeServer.(*fakeVTGateService).hasCallerID = false
testBeginError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testBegin2Error(t, conn)
fakeServer.(*fakeVTGateService).forceBeginSuccess = true
@ -455,8 +511,10 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
// testStreamExecuteShardError(t, conn)
// testStreamExecuteKeyRangesError(t, conn)
// testStreamExecuteKeyspaceIdsError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
testCommitError(t, conn)
testRollbackError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommit2Error(t, conn)
testRollback2Error(t, conn)
testSplitQueryError(t, conn)
@ -469,7 +527,9 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
// First test errors in Begin, and then force it to succeed so we can test
// subsequent calls in the transaction.
fakeServer.(*fakeVTGateService).forceBeginSuccess = false
fakeServer.(*fakeVTGateService).hasCallerID = false
testBeginPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testBegin2Panic(t, conn)
fakeServer.(*fakeVTGateService).forceBeginSuccess = true
@ -484,8 +544,10 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testStreamExecuteShardPanic(t, conn)
testStreamExecuteKeyRangesPanic(t, conn)
testStreamExecuteKeyspaceIdsPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
testCommitPanic(t, conn)
testRollbackPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommit2Panic(t, conn)
testRollback2Panic(t, conn)
testSplitQueryPanic(t, conn)
@ -513,7 +575,7 @@ func verifyError(t *testing.T, err error, method string) {
}
func testExecute(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
qr, err := conn.Execute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
@ -538,21 +600,21 @@ func testExecute(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.Execute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
verifyError(t, err, "Execute")
}
func testExecutePanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.Execute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
expectPanic(t, err)
}
func testExecuteShard(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
qr, err := conn.ExecuteShard(ctx, execCase.shardQuery.Sql, execCase.shardQuery.Keyspace, execCase.shardQuery.Shards, execCase.shardQuery.BindVariables, execCase.shardQuery.TabletType)
if err != nil {
@ -577,21 +639,21 @@ func testExecuteShard(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteShardError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteShard(ctx, execCase.shardQuery.Sql, execCase.shardQuery.Keyspace, execCase.shardQuery.Shards, execCase.shardQuery.BindVariables, execCase.shardQuery.TabletType)
verifyError(t, err, "ExecuteShard")
}
func testExecuteShardPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteShard(ctx, execCase.execQuery.Sql, "ks", []string{"1", "2"}, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
expectPanic(t, err)
}
func testExecuteKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
qr, err := conn.ExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
if err != nil {
@ -616,21 +678,21 @@ func testExecuteKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteKeyspaceIdsError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
verifyError(t, err, "ExecuteKeyspaceIds")
}
func testExecuteKeyspaceIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
expectPanic(t, err)
}
func testExecuteKeyRanges(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
qr, err := conn.ExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
if err != nil {
@ -655,21 +717,21 @@ func testExecuteKeyRanges(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteKeyRangesError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
verifyError(t, err, "ExecuteKeyRanges")
}
func testExecuteKeyRangesPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
expectPanic(t, err)
}
func testExecuteEntityIds(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
qr, err := conn.ExecuteEntityIds(ctx, execCase.entityIdsQuery.Sql, execCase.entityIdsQuery.Keyspace, execCase.entityIdsQuery.EntityColumnName, execCase.entityIdsQuery.EntityKeyspaceIDs, execCase.entityIdsQuery.BindVariables, execCase.entityIdsQuery.TabletType)
if err != nil {
@ -694,21 +756,21 @@ func testExecuteEntityIds(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteEntityIdsError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteEntityIds(ctx, execCase.entityIdsQuery.Sql, execCase.entityIdsQuery.Keyspace, execCase.entityIdsQuery.EntityColumnName, execCase.entityIdsQuery.EntityKeyspaceIDs, execCase.entityIdsQuery.BindVariables, execCase.entityIdsQuery.TabletType)
verifyError(t, err, "ExecuteEntityIds")
}
func testExecuteEntityIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteEntityIds(ctx, execCase.entityIdsQuery.Sql, execCase.entityIdsQuery.Keyspace, execCase.entityIdsQuery.EntityColumnName, execCase.entityIdsQuery.EntityKeyspaceIDs, execCase.entityIdsQuery.BindVariables, execCase.entityIdsQuery.TabletType)
expectPanic(t, err)
}
func testExecuteBatchShard(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
ql, err := conn.ExecuteBatchShard(ctx, execCase.batchQueryShard.Queries, execCase.batchQueryShard.TabletType, execCase.batchQueryShard.AsTransaction)
if err != nil {
@ -733,21 +795,21 @@ func testExecuteBatchShard(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteBatchShardError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteBatchShard(ctx, execCase.batchQueryShard.Queries, execCase.batchQueryShard.TabletType, execCase.batchQueryShard.AsTransaction)
verifyError(t, err, "ExecuteBatchShard")
}
func testExecuteBatchShardPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteBatchShard(ctx, execCase.batchQueryShard.Queries, execCase.batchQueryShard.TabletType, execCase.batchQueryShard.AsTransaction)
expectPanic(t, err)
}
func testExecuteBatchKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
ql, err := conn.ExecuteBatchKeyspaceIds(ctx, execCase.keyspaceIdBatchQuery.Queries, execCase.keyspaceIdBatchQuery.TabletType, execCase.batchQueryShard.AsTransaction)
if err != nil {
@ -772,21 +834,21 @@ func testExecuteBatchKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testExecuteBatchKeyspaceIdsError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteBatchKeyspaceIds(ctx, execCase.keyspaceIdBatchQuery.Queries, execCase.keyspaceIdBatchQuery.TabletType, execCase.keyspaceIdBatchQuery.AsTransaction)
verifyError(t, err, "ExecuteBatchKeyspaceIds")
}
func testExecuteBatchKeyspaceIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
_, err := conn.ExecuteBatchKeyspaceIds(ctx, execCase.keyspaceIdBatchQuery.Queries, execCase.keyspaceIdBatchQuery.TabletType, execCase.keyspaceIdBatchQuery.AsTransaction)
expectPanic(t, err)
}
func testStreamExecute(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
@ -840,7 +902,7 @@ func testStreamExecute(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecutePanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
@ -854,7 +916,7 @@ func testStreamExecutePanic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecuteShard(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteShard(ctx, execCase.shardQuery.Sql, execCase.shardQuery.Keyspace, execCase.shardQuery.Shards, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
@ -908,7 +970,7 @@ func testStreamExecuteShard(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecuteShardPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteShard(ctx, execCase.shardQuery.Sql, execCase.shardQuery.Keyspace, execCase.shardQuery.Shards, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
@ -922,7 +984,7 @@ func testStreamExecuteShardPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecuteKeyRanges(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
if err != nil {
@ -976,7 +1038,7 @@ func testStreamExecuteKeyRanges(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecuteKeyRangesPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
if err != nil {
@ -990,7 +1052,7 @@ func testStreamExecuteKeyRangesPanic(t *testing.T, conn *vtgateconn.VTGateConn)
}
func testStreamExecuteKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
if err != nil {
@ -1044,7 +1106,7 @@ func testStreamExecuteKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testStreamExecuteKeyspaceIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["request1"]
packets, errFunc, err := conn.StreamExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
if err != nil {
@ -1058,7 +1120,7 @@ func testStreamExecuteKeyspaceIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn
}
func testTxPass(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["txRequest"]
// Execute
@ -1161,7 +1223,7 @@ func testTxPass(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testTxPassNotInTransaction(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["txRequestNIT"]
tx, err := conn.Begin(ctx)
@ -1201,7 +1263,7 @@ func testTxPassNotInTransaction(t *testing.T, conn *vtgateconn.VTGateConn) {
// Same as testTxPass, but with Begin2/Commit2/Rollback2 instead
func testTx2Pass(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["txRequest"]
// Execute
@ -1305,7 +1367,7 @@ func testTx2Pass(t *testing.T, conn *vtgateconn.VTGateConn) {
// Same as testTxPassNotInTransaction, but with Begin2/Commit2/Rollback2 instead
func testTx2PassNotInTransaction(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
execCase := execMap["txRequestNIT"]
tx, err := conn.Begin2(ctx)
@ -1344,13 +1406,13 @@ func testTx2PassNotInTransaction(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testBeginError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.Begin(ctx)
verifyError(t, err, "Begin")
}
func testCommitError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin(ctx)
if err != nil {
t.Error(err)
@ -1360,7 +1422,7 @@ func testCommitError(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testRollbackError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin(ctx)
if err != nil {
t.Error(err)
@ -1370,13 +1432,13 @@ func testRollbackError(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testBegin2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.Begin2(ctx)
verifyError(t, err, "Begin2")
}
func testCommit2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin2(ctx)
if err != nil {
t.Error(err)
@ -1386,7 +1448,7 @@ func testCommit2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testRollback2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin2(ctx)
if err != nil {
t.Error(err)
@ -1396,13 +1458,13 @@ func testRollback2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testBeginPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.Begin(ctx)
expectPanic(t, err)
}
func testCommitPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin(ctx)
if err != nil {
t.Error(err)
@ -1412,7 +1474,7 @@ func testCommitPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testRollbackPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin(ctx)
if err != nil {
t.Error(err)
@ -1422,13 +1484,13 @@ func testRollbackPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testBegin2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.Begin2(ctx)
expectPanic(t, err)
}
func testCommit2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin2(ctx)
if err != nil {
t.Error(err)
@ -1438,7 +1500,7 @@ func testCommit2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testRollback2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin2(ctx)
if err != nil {
t.Error(err)
@ -1448,7 +1510,7 @@ func testRollback2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testTxFail(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin(ctx)
if err != nil {
t.Error(err)
@ -1525,7 +1587,7 @@ func testTxFail(t *testing.T, conn *vtgateconn.VTGateConn) {
// Same as testTxFail, but with Begin2/Commit2/Rollback2 instead
func testTx2Fail(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
tx, err := conn.Begin2(ctx)
if err != nil {
t.Error(err)
@ -1589,7 +1651,7 @@ func testTx2Fail(t *testing.T, conn *vtgateconn.VTGateConn) {
t.Error(err)
}
tx, err = conn.Begin(ctx)
tx, err = conn.Begin2(ctx)
if err != nil {
t.Error(err)
}
@ -1601,7 +1663,7 @@ func testTx2Fail(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testSplitQuery(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
qsl, err := conn.SplitQuery(ctx, splitQueryRequest.Keyspace, splitQueryRequest.Query, splitQueryRequest.SplitColumn, splitQueryRequest.SplitCount)
if err != nil {
t.Fatalf("SplitQuery failed: %v", err)
@ -1613,19 +1675,19 @@ func testSplitQuery(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testSplitQueryError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.SplitQuery(ctx, splitQueryRequest.Keyspace, splitQueryRequest.Query, splitQueryRequest.SplitColumn, splitQueryRequest.SplitCount)
verifyError(t, err, "SplitQuery")
}
func testSplitQueryPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.SplitQuery(ctx, splitQueryRequest.Keyspace, splitQueryRequest.Query, splitQueryRequest.SplitColumn, splitQueryRequest.SplitCount)
expectPanic(t, err)
}
func testGetSrvKeyspace(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
sk, err := conn.GetSrvKeyspace(ctx, getSrvKeyspaceKeyspace)
if err != nil {
t.Fatalf("GetSrvKeyspace failed: %v", err)
@ -1636,17 +1698,23 @@ func testGetSrvKeyspace(t *testing.T, conn *vtgateconn.VTGateConn) {
}
func testGetSrvKeyspaceError(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.GetSrvKeyspace(ctx, getSrvKeyspaceKeyspace)
verifyError(t, err, "GetSrvKeyspace")
}
func testGetSrvKeyspacePanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := context.Background()
ctx := newContext()
_, err := conn.GetSrvKeyspace(ctx, getSrvKeyspaceKeyspace)
expectPanic(t, err)
}
var testCallerID = &pbv.CallerID{
Principal: "test_principal",
Component: "test_component",
Subcomponent: "test_subcomponent",
}
var execMap = map[string]struct {
execQuery *proto.Query
shardQuery *proto.QueryShard