зеркало из https://github.com/github/vitess-gh.git
Коммит
ef1e7a908e
|
@ -0,0 +1,176 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/youtube/vitess/go/tb"
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/proto"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
|
||||
|
||||
pb "github.com/youtube/vitess/go/vt/proto/vtrpc"
|
||||
)
|
||||
|
||||
const callerIDPrefix = "callerid "
|
||||
|
||||
// callerIDClient implements vtgateservice.VTGateService, and checks
|
||||
// the received callerid matches the one passed out of band by the client.
|
||||
type callerIDClient struct {
|
||||
fallback vtgateservice.VTGateService
|
||||
}
|
||||
|
||||
func newCallerIDClient(fallback vtgateservice.VTGateService) *callerIDClient {
|
||||
return &callerIDClient{
|
||||
fallback: fallback,
|
||||
}
|
||||
}
|
||||
|
||||
// checkCallerID will see if this module is handling the request,
|
||||
// and if it is, check the callerID from the context.
|
||||
// Returns false if the query is not for this module.
|
||||
// Returns true and the the error to return with the call
|
||||
// if this module is handling the request.
|
||||
func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bool, error) {
|
||||
if !strings.HasPrefix(received, callerIDPrefix) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
jsonCallerID := []byte(received[len(callerIDPrefix):])
|
||||
expectedCallerID := &pb.CallerID{}
|
||||
if err := json.Unmarshal(jsonCallerID, expectedCallerID); err != nil {
|
||||
return true, fmt.Errorf("cannot unmarshal provided callerid: %v", err)
|
||||
}
|
||||
|
||||
receivedCallerID := callerid.EffectiveCallerIDFromContext(ctx)
|
||||
if receivedCallerID == nil {
|
||||
return true, fmt.Errorf("no callerid received in the query")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(receivedCallerID, expectedCallerID) {
|
||||
return true, fmt.Errorf("callerid mismatch, got %v expected %v", receivedCallerID, expectedCallerID)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *callerIDClient) Execute(ctx context.Context, query *proto.Query, reply *proto.QueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.Execute(ctx, query, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteShard(ctx context.Context, query *proto.QueryShard, reply *proto.QueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.ExecuteShard(ctx, query, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, reply *proto.QueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.ExecuteKeyspaceIds(ctx, query, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, reply *proto.QueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.ExecuteKeyRanges(ctx, query, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQuery, reply *proto.QueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.ExecuteEntityIds(ctx, query, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteBatchShard(ctx context.Context, batchQuery *proto.BatchQueryShard, reply *proto.QueryResultList) error {
|
||||
if len(batchQuery.Queries) == 1 {
|
||||
if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.fallback.ExecuteBatchShard(ctx, batchQuery, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) ExecuteBatchKeyspaceIds(ctx context.Context, batchQuery *proto.KeyspaceIdBatchQuery, reply *proto.QueryResultList) error {
|
||||
if len(batchQuery.Queries) == 1 {
|
||||
if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.fallback.ExecuteBatchKeyspaceIds(ctx, batchQuery, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) StreamExecute(ctx context.Context, query *proto.Query, sendReply func(*proto.QueryResult) error) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.StreamExecute(ctx, query, sendReply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) StreamExecuteShard(ctx context.Context, query *proto.QueryShard, sendReply func(*proto.QueryResult) error) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.StreamExecuteShard(ctx, query, sendReply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) StreamExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, sendReply func(*proto.QueryResult) error) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.StreamExecuteKeyRanges(ctx, query, sendReply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, sendReply func(*proto.QueryResult) error) error {
|
||||
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.StreamExecuteKeyspaceIds(ctx, query, sendReply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) Begin(ctx context.Context, outSession *proto.Session) error {
|
||||
return c.fallback.Begin(ctx, outSession)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) Commit(ctx context.Context, inSession *proto.Session) error {
|
||||
return c.fallback.Commit(ctx, inSession)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) Rollback(ctx context.Context, inSession *proto.Session) error {
|
||||
return c.fallback.Rollback(ctx, inSession)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest, reply *proto.SplitQueryResult) error {
|
||||
if ok, err := c.checkCallerID(ctx, req.Query.Sql); ok {
|
||||
return err
|
||||
}
|
||||
return c.fallback.SplitQuery(ctx, req, reply)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) GetSrvKeyspace(ctx context.Context, keyspace string) (*topo.SrvKeyspace, error) {
|
||||
return c.fallback.GetSrvKeyspace(ctx, keyspace)
|
||||
}
|
||||
|
||||
func (c *callerIDClient) HandlePanic(err *error) {
|
||||
if x := recover(); x != nil {
|
||||
log.Errorf("Uncaught panic:\n%v\n%s", x, tb.Stack(4))
|
||||
*err = fmt.Errorf("uncaught panic: %v", x)
|
||||
}
|
||||
}
|
|
@ -8,12 +8,12 @@ import (
|
|||
"fmt"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/youtube/vitess/go/tb"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/proto"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// errorClient implements vtgateservice.VTGateService
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/callerid"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// This file contains the reference test for clients. It tests
|
||||
// all the corner cases of the API, and makes sure the go client
|
||||
// is full featured.
|
||||
//
|
||||
// It can be used as a template by other languages for their test suites.
|
||||
//
|
||||
// TODO(team) add more unit test cases.
|
||||
|
||||
// testCallerID adds a caller ID to a context, and makes sure the server
|
||||
// gets it.
|
||||
func testCallerID(t *testing.T, conn *vtgateconn.VTGateConn) {
|
||||
t.Log("testCallerID")
|
||||
ctx := context.Background()
|
||||
callerID := callerid.NewEffectiveCallerID("test_principal", "test_component", "test_subcomponent")
|
||||
ctx = callerid.NewContext(ctx, callerID, nil)
|
||||
|
||||
data, err := json.Marshal(callerID)
|
||||
if err != nil {
|
||||
t.Errorf("failed to marshal callerid: %v", err)
|
||||
return
|
||||
}
|
||||
query := callerIDPrefix + string(data)
|
||||
|
||||
// test Execute forwards the callerID
|
||||
if _, err := conn.Execute(ctx, query, nil, topo.TYPE_MASTER); err != nil {
|
||||
t.Errorf("failed to pass callerid: %v", err)
|
||||
}
|
||||
|
||||
// FIXME(alainjobart) add all function calls
|
||||
}
|
||||
|
||||
func testGoClient(t *testing.T, protocol, addr string) {
|
||||
// Create a client connecting to the server
|
||||
ctx := context.Background()
|
||||
conn, err := vtgateconn.DialProtocol(ctx, protocol, addr, 30*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
}
|
||||
|
||||
testCallerID(t, conn)
|
||||
|
||||
// and clean up
|
||||
conn.Close()
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/rpcplus"
|
||||
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
|
||||
|
||||
// import the gorpc client, it will register itself
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
|
||||
)
|
||||
|
||||
// TestGoRPCGoClient tests the go client using goRPC
|
||||
func TestGoRPCGoClient(t *testing.T) {
|
||||
service := createService()
|
||||
|
||||
// listen on a random port
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create a Go Rpc server and listen on the port
|
||||
server := rpcplus.NewServer()
|
||||
server.Register(gorpcvtgateservice.New(service))
|
||||
|
||||
// create the HTTP server, serve the server from it
|
||||
handler := http.NewServeMux()
|
||||
bsonrpc.ServeCustomRPC(handler, server, false)
|
||||
httpServer := http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
go httpServer.Serve(listener)
|
||||
|
||||
// and run the test suite
|
||||
testGoClient(t, "gorpc", listener.Addr().String())
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/vtgate/grpcvtgateservice"
|
||||
|
||||
// import the grpc client, it will register itself
|
||||
_ "github.com/youtube/vitess/go/vt/vtgate/grpcvtgateconn"
|
||||
)
|
||||
|
||||
// TestGRPCGoClient tests the go client using gRPC
|
||||
func TestGRPCGoClient(t *testing.T) {
|
||||
service := createService()
|
||||
|
||||
// listen on a random port
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create a gRPC server and listen on the port
|
||||
server := grpc.NewServer()
|
||||
grpcvtgateservice.RegisterForTest(server, service)
|
||||
go server.Serve(listener)
|
||||
|
||||
// and run the test suite
|
||||
testGoClient(t, "grpc", listener.Addr().String())
|
||||
}
|
|
@ -13,12 +13,23 @@ import (
|
|||
"github.com/youtube/vitess/go/exit"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/vtgate"
|
||||
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
|
||||
)
|
||||
|
||||
func init() {
|
||||
servenv.RegisterDefaultFlags()
|
||||
}
|
||||
|
||||
// createService creates the implementation chain of all the test cases
|
||||
func createService() vtgateservice.VTGateService {
|
||||
var s vtgateservice.VTGateService
|
||||
s = newTerminalClient()
|
||||
s = newSuccessClient(s)
|
||||
s = newErrorClient(s)
|
||||
s = newCallerIDClient(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func main() {
|
||||
defer exit.Recover()
|
||||
|
||||
|
@ -26,9 +37,9 @@ func main() {
|
|||
servenv.Init()
|
||||
|
||||
// The implementation chain.
|
||||
c := newErrorClient(newSuccessClient(newTerminalClient()))
|
||||
s := createService()
|
||||
for _, f := range vtgate.RegisterVTGates {
|
||||
f(c)
|
||||
f(s)
|
||||
}
|
||||
|
||||
servenv.RunDefault()
|
||||
|
|
Загрузка…
Ссылка в новой задаче