diff --git a/go/cmd/vtgateclienttest/callerid.go b/go/cmd/vtgateclienttest/callerid.go new file mode 100644 index 0000000000..5f70ef2db3 --- /dev/null +++ b/go/cmd/vtgateclienttest/callerid.go @@ -0,0 +1,176 @@ +// Copyright 2015 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + log "github.com/golang/glog" + "golang.org/x/net/context" + + "github.com/youtube/vitess/go/tb" + "github.com/youtube/vitess/go/vt/callerid" + "github.com/youtube/vitess/go/vt/topo" + "github.com/youtube/vitess/go/vt/vtgate/proto" + "github.com/youtube/vitess/go/vt/vtgate/vtgateservice" + + pb "github.com/youtube/vitess/go/vt/proto/vtrpc" +) + +const callerIDPrefix = "callerid " + +// callerIDClient implements vtgateservice.VTGateService, and checks +// the received callerid matches the one passed out of band by the client. +type callerIDClient struct { + fallback vtgateservice.VTGateService +} + +func newCallerIDClient(fallback vtgateservice.VTGateService) *callerIDClient { + return &callerIDClient{ + fallback: fallback, + } +} + +// checkCallerID will see if this module is handling the request, +// and if it is, check the callerID from the context. +// Returns false if the query is not for this module. +// Returns true and the the error to return with the call +// if this module is handling the request. +func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bool, error) { + if !strings.HasPrefix(received, callerIDPrefix) { + return false, nil + } + + jsonCallerID := []byte(received[len(callerIDPrefix):]) + expectedCallerID := &pb.CallerID{} + if err := json.Unmarshal(jsonCallerID, expectedCallerID); err != nil { + return true, fmt.Errorf("cannot unmarshal provided callerid: %v", err) + } + + receivedCallerID := callerid.EffectiveCallerIDFromContext(ctx) + if receivedCallerID == nil { + return true, fmt.Errorf("no callerid received in the query") + } + + if !reflect.DeepEqual(receivedCallerID, expectedCallerID) { + return true, fmt.Errorf("callerid mismatch, got %v expected %v", receivedCallerID, expectedCallerID) + } + + return true, nil +} + +func (c *callerIDClient) Execute(ctx context.Context, query *proto.Query, reply *proto.QueryResult) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.Execute(ctx, query, reply) +} + +func (c *callerIDClient) ExecuteShard(ctx context.Context, query *proto.QueryShard, reply *proto.QueryResult) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.ExecuteShard(ctx, query, reply) +} + +func (c *callerIDClient) ExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, reply *proto.QueryResult) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.ExecuteKeyspaceIds(ctx, query, reply) +} + +func (c *callerIDClient) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, reply *proto.QueryResult) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.ExecuteKeyRanges(ctx, query, reply) +} + +func (c *callerIDClient) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQuery, reply *proto.QueryResult) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.ExecuteEntityIds(ctx, query, reply) +} + +func (c *callerIDClient) ExecuteBatchShard(ctx context.Context, batchQuery *proto.BatchQueryShard, reply *proto.QueryResultList) error { + if len(batchQuery.Queries) == 1 { + if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok { + return err + } + } + return c.fallback.ExecuteBatchShard(ctx, batchQuery, reply) +} + +func (c *callerIDClient) ExecuteBatchKeyspaceIds(ctx context.Context, batchQuery *proto.KeyspaceIdBatchQuery, reply *proto.QueryResultList) error { + if len(batchQuery.Queries) == 1 { + if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok { + return err + } + } + return c.fallback.ExecuteBatchKeyspaceIds(ctx, batchQuery, reply) +} + +func (c *callerIDClient) StreamExecute(ctx context.Context, query *proto.Query, sendReply func(*proto.QueryResult) error) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.StreamExecute(ctx, query, sendReply) +} + +func (c *callerIDClient) StreamExecuteShard(ctx context.Context, query *proto.QueryShard, sendReply func(*proto.QueryResult) error) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.StreamExecuteShard(ctx, query, sendReply) +} + +func (c *callerIDClient) StreamExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, sendReply func(*proto.QueryResult) error) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.StreamExecuteKeyRanges(ctx, query, sendReply) +} + +func (c *callerIDClient) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, sendReply func(*proto.QueryResult) error) error { + if ok, err := c.checkCallerID(ctx, query.Sql); ok { + return err + } + return c.fallback.StreamExecuteKeyspaceIds(ctx, query, sendReply) +} + +func (c *callerIDClient) Begin(ctx context.Context, outSession *proto.Session) error { + return c.fallback.Begin(ctx, outSession) +} + +func (c *callerIDClient) Commit(ctx context.Context, inSession *proto.Session) error { + return c.fallback.Commit(ctx, inSession) +} + +func (c *callerIDClient) Rollback(ctx context.Context, inSession *proto.Session) error { + return c.fallback.Rollback(ctx, inSession) +} + +func (c *callerIDClient) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest, reply *proto.SplitQueryResult) error { + if ok, err := c.checkCallerID(ctx, req.Query.Sql); ok { + return err + } + return c.fallback.SplitQuery(ctx, req, reply) +} + +func (c *callerIDClient) GetSrvKeyspace(ctx context.Context, keyspace string) (*topo.SrvKeyspace, error) { + return c.fallback.GetSrvKeyspace(ctx, keyspace) +} + +func (c *callerIDClient) HandlePanic(err *error) { + if x := recover(); x != nil { + log.Errorf("Uncaught panic:\n%v\n%s", x, tb.Stack(4)) + *err = fmt.Errorf("uncaught panic: %v", x) + } +} diff --git a/go/cmd/vtgateclienttest/errors.go b/go/cmd/vtgateclienttest/errors.go index 87d2c7315d..e5de129ffe 100644 --- a/go/cmd/vtgateclienttest/errors.go +++ b/go/cmd/vtgateclienttest/errors.go @@ -8,12 +8,12 @@ import ( "fmt" log "github.com/golang/glog" + "golang.org/x/net/context" "github.com/youtube/vitess/go/tb" "github.com/youtube/vitess/go/vt/topo" "github.com/youtube/vitess/go/vt/vtgate/proto" "github.com/youtube/vitess/go/vt/vtgate/vtgateservice" - "golang.org/x/net/context" ) // errorClient implements vtgateservice.VTGateService diff --git a/go/cmd/vtgateclienttest/goclient_test.go b/go/cmd/vtgateclienttest/goclient_test.go new file mode 100644 index 0000000000..2db7636d77 --- /dev/null +++ b/go/cmd/vtgateclienttest/goclient_test.go @@ -0,0 +1,61 @@ +// Copyright 2015 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "encoding/json" + "testing" + "time" + + "github.com/youtube/vitess/go/vt/callerid" + "github.com/youtube/vitess/go/vt/topo" + "github.com/youtube/vitess/go/vt/vtgate/vtgateconn" + "golang.org/x/net/context" +) + +// This file contains the reference test for clients. It tests +// all the corner cases of the API, and makes sure the go client +// is full featured. +// +// It can be used as a template by other languages for their test suites. +// +// TODO(team) add more unit test cases. + +// testCallerID adds a caller ID to a context, and makes sure the server +// gets it. +func testCallerID(t *testing.T, conn *vtgateconn.VTGateConn) { + t.Log("testCallerID") + ctx := context.Background() + callerID := callerid.NewEffectiveCallerID("test_principal", "test_component", "test_subcomponent") + ctx = callerid.NewContext(ctx, callerID, nil) + + data, err := json.Marshal(callerID) + if err != nil { + t.Errorf("failed to marshal callerid: %v", err) + return + } + query := callerIDPrefix + string(data) + + // test Execute forwards the callerID + if _, err := conn.Execute(ctx, query, nil, topo.TYPE_MASTER); err != nil { + t.Errorf("failed to pass callerid: %v", err) + } + + // FIXME(alainjobart) add all function calls +} + +func testGoClient(t *testing.T, protocol, addr string) { + // Create a client connecting to the server + ctx := context.Background() + conn, err := vtgateconn.DialProtocol(ctx, protocol, addr, 30*time.Second) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + + testCallerID(t, conn) + + // and clean up + conn.Close() +} diff --git a/go/cmd/vtgateclienttest/gorpc_goclient_test.go b/go/cmd/vtgateclienttest/gorpc_goclient_test.go new file mode 100644 index 0000000000..11e7639320 --- /dev/null +++ b/go/cmd/vtgateclienttest/gorpc_goclient_test.go @@ -0,0 +1,45 @@ +// Copyright 2015 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "net" + "net/http" + "testing" + + "github.com/youtube/vitess/go/rpcplus" + "github.com/youtube/vitess/go/rpcwrap/bsonrpc" + "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice" + + // import the gorpc client, it will register itself + _ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn" +) + +// TestGoRPCGoClient tests the go client using goRPC +func TestGoRPCGoClient(t *testing.T) { + service := createService() + + // listen on a random port + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Cannot listen: %v", err) + } + defer listener.Close() + + // Create a Go Rpc server and listen on the port + server := rpcplus.NewServer() + server.Register(gorpcvtgateservice.New(service)) + + // create the HTTP server, serve the server from it + handler := http.NewServeMux() + bsonrpc.ServeCustomRPC(handler, server, false) + httpServer := http.Server{ + Handler: handler, + } + go httpServer.Serve(listener) + + // and run the test suite + testGoClient(t, "gorpc", listener.Addr().String()) +} diff --git a/go/cmd/vtgateclienttest/grpc_goclient_test.go b/go/cmd/vtgateclienttest/grpc_goclient_test.go new file mode 100644 index 0000000000..92b046afa6 --- /dev/null +++ b/go/cmd/vtgateclienttest/grpc_goclient_test.go @@ -0,0 +1,37 @@ +// Copyright 2015 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "net" + "testing" + + "google.golang.org/grpc" + + "github.com/youtube/vitess/go/vt/vtgate/grpcvtgateservice" + + // import the grpc client, it will register itself + _ "github.com/youtube/vitess/go/vt/vtgate/grpcvtgateconn" +) + +// TestGRPCGoClient tests the go client using gRPC +func TestGRPCGoClient(t *testing.T) { + service := createService() + + // listen on a random port + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Cannot listen: %v", err) + } + defer listener.Close() + + // Create a gRPC server and listen on the port + server := grpc.NewServer() + grpcvtgateservice.RegisterForTest(server, service) + go server.Serve(listener) + + // and run the test suite + testGoClient(t, "grpc", listener.Addr().String()) +} diff --git a/go/cmd/vtgateclienttest/main.go b/go/cmd/vtgateclienttest/main.go index 06fd79ece9..961a9f6e0a 100644 --- a/go/cmd/vtgateclienttest/main.go +++ b/go/cmd/vtgateclienttest/main.go @@ -13,12 +13,23 @@ import ( "github.com/youtube/vitess/go/exit" "github.com/youtube/vitess/go/vt/servenv" "github.com/youtube/vitess/go/vt/vtgate" + "github.com/youtube/vitess/go/vt/vtgate/vtgateservice" ) func init() { servenv.RegisterDefaultFlags() } +// createService creates the implementation chain of all the test cases +func createService() vtgateservice.VTGateService { + var s vtgateservice.VTGateService + s = newTerminalClient() + s = newSuccessClient(s) + s = newErrorClient(s) + s = newCallerIDClient(s) + return s +} + func main() { defer exit.Recover() @@ -26,9 +37,9 @@ func main() { servenv.Init() // The implementation chain. - c := newErrorClient(newSuccessClient(newTerminalClient())) + s := createService() for _, f := range vtgate.RegisterVTGates { - f(c) + f(s) } servenv.RunDefault()