diff --git a/go/vt/tabletserver/gorpctabletconn/conn.go b/go/vt/tabletserver/gorpctabletconn/conn.go index e50a82812b..c68c21d474 100644 --- a/go/vt/tabletserver/gorpctabletconn/conn.go +++ b/go/vt/tabletserver/gorpctabletconn/conn.go @@ -494,4 +494,4 @@ func tabletErrorFromVitessError(ve *vterrors.VitessError) error { } return tabletconn.OperationalError(fmt.Sprintf("vttablet: %v", ve.Message)) -} \ No newline at end of file +} diff --git a/go/vt/vtgate/gorpcvtgateservice/server.go b/go/vt/vtgate/gorpcvtgateservice/server.go index bb9532c6b9..e568df7847 100644 --- a/go/vt/vtgate/gorpcvtgateservice/server.go +++ b/go/vt/vtgate/gorpcvtgateservice/server.go @@ -31,7 +31,12 @@ func (vtg *VTGate) Execute(ctx context.Context, query *proto.Query, reply *proto defer vtg.server.HandlePanic(&err) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout)) defer cancel() - return vtg.server.Execute(ctx, query, reply) + vtgErr := vtg.server.Execute(ctx, query, reply) + vtgate.AddVtGateErrorToQueryResult(vtgErr, reply) + if *vtgate.RPCErrorOnlyInReply { + return nil + } + return vtgErr } // ExecuteShard is the RPC version of vtgateservice.VTGateService method @@ -39,7 +44,12 @@ func (vtg *VTGate) ExecuteShard(ctx context.Context, query *proto.QueryShard, re defer vtg.server.HandlePanic(&err) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout)) defer cancel() - return vtg.server.ExecuteShard(ctx, query, reply) + vtgErr := vtg.server.ExecuteShard(ctx, query, reply) + vtgate.AddVtGateErrorToQueryResult(vtgErr, reply) + if *vtgate.RPCErrorOnlyInReply { + return nil + } + return vtgErr } // ExecuteKeyspaceIds is the RPC version of vtgateservice.VTGateService method @@ -47,7 +57,12 @@ func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, query *proto.Keyspace defer vtg.server.HandlePanic(&err) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout)) defer cancel() - return vtg.server.ExecuteKeyspaceIds(ctx, query, reply) + vtgErr := vtg.server.ExecuteKeyspaceIds(ctx, query, reply) + vtgate.AddVtGateErrorToQueryResult(vtgErr, reply) + if *vtgate.RPCErrorOnlyInReply { + return nil + } + return vtgErr } // ExecuteKeyRanges is the RPC version of vtgateservice.VTGateService method @@ -55,7 +70,12 @@ func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQu defer vtg.server.HandlePanic(&err) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout)) defer cancel() - return vtg.server.ExecuteKeyRanges(ctx, query, reply) + vtgErr := vtg.server.ExecuteKeyRanges(ctx, query, reply) + vtgate.AddVtGateErrorToQueryResult(vtgErr, reply) + if *vtgate.RPCErrorOnlyInReply { + return nil + } + return vtgErr } // ExecuteEntityIds is the RPC version of vtgateservice.VTGateService method @@ -63,7 +83,12 @@ func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQ defer vtg.server.HandlePanic(&err) ctx, cancel := context.WithDeadline(ctx, time.Now().Add(*rpcTimeout)) defer cancel() - return vtg.server.ExecuteEntityIds(ctx, query, reply) + vtgErr := vtg.server.ExecuteEntityIds(ctx, query, reply) + vtgate.AddVtGateErrorToQueryResult(vtgErr, reply) + if *vtgate.RPCErrorOnlyInReply { + return nil + } + return vtgErr } // ExecuteBatchShard is the RPC version of vtgateservice.VTGateService method diff --git a/go/vt/vtgate/proto/vtgate_proto.go b/go/vt/vtgate/proto/vtgate_proto.go index a179f16e29..0e42693194 100644 --- a/go/vt/vtgate/proto/vtgate_proto.go +++ b/go/vt/vtgate/proto/vtgate_proto.go @@ -121,6 +121,7 @@ type QueryResult struct { Result *mproto.QueryResult Session *Session Error string + Err *mproto.RPCError } //go:generate bsongen -file $GOFILE -type QueryResult -o query_result_bson.go diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 99d75e6eda..8c00ffbd7d 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -8,6 +8,7 @@ package vtgate import ( "errors" + "flag" "fmt" "math" "strings" @@ -84,6 +85,11 @@ type RegisterVTGate func(vtgateservice.VTGateService) // RegisterVTGates stores register funcs for VTGate server. var RegisterVTGates []RegisterVTGate +var ( + // RPCErrorOnlyInReply informs vtgateservice(s) about how to return errors. + RPCErrorOnlyInReply = flag.Bool("rpc-error-only-in-reply", false, "if true, supported RPC calls from vtgateservice(s) will only return errors as part of the RPC server response") +) + // Init initializes VTGate server. func Init(serv SrvTopoServer, schema *planbuilder.Schema, cell string, retryDelay time.Duration, retryCount int, connTimeoutTotal, connTimeoutPerConn, connLife time.Duration, maxInFlight int) { if rpcVTGate != nil { diff --git a/go/vt/vtgate/vtgate_error.go b/go/vt/vtgate/vtgate_error.go new file mode 100644 index 0000000000..e04e13b46c --- /dev/null +++ b/go/vt/vtgate/vtgate_error.go @@ -0,0 +1,34 @@ +// Copyright 2012, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package vtgate + +import ( + mproto "github.com/youtube/vitess/go/mysql/proto" + "github.com/youtube/vitess/go/vt/vterrors" + "github.com/youtube/vitess/go/vt/vtgate/proto" +) + +// rpcErrFromTabletError translate an error from VTGate to an *mproto.RPCError +func rpcErrFromVtGateError(err error) *mproto.RPCError { + if err == nil { + return nil + } + // TODO(aaijazi): for now, we don't have any differentiation of VtGate errors. + // However, we should have them soon, so that clients don't have to parse the + // returned error string. + return &mproto.RPCError{ + Code: vterrors.UnknownVtgateError, + Message: err.Error(), + } +} + +// AddVtGateErrorToQueryResult will mutate a QueryResult struct to fill in the Err +// field with details from the VTGate error. +func AddVtGateErrorToQueryResult(err error, reply *proto.QueryResult) { + if err == nil { + return + } + reply.Err = rpcErrFromVtGateError(err) +} diff --git a/go/vt/vtgate/vtgateconn/vtgateconn_test.go b/go/vt/vtgate/vtgateconn/vtgateconn_test.go index 09857068e5..afe318e1dc 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn_test.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn_test.go @@ -43,6 +43,6 @@ func TestServerError(t *testing.T) { func TestOperationalError(t *testing.T) { if OperationalError("error").Error() == "" { - t.Fatal("operational error is not mepty, should not return empty error") + t.Fatal("operational error is not empty, should not return empty error") } }