Merge pull request #971 from alainjobart/resharding

Resharding
This commit is contained in:
Alain Jobart 2015-08-07 08:51:07 -07:00
Родитель 8d9479bf7c 07a81e272c
Коммит ef1e7a908e
6 изменённых файлов: 333 добавлений и 3 удалений

Просмотреть файл

@ -0,0 +1,176 @@
// Copyright 2015 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"encoding/json"
"fmt"
"reflect"
"strings"
log "github.com/golang/glog"
"golang.org/x/net/context"
"github.com/youtube/vitess/go/tb"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate/proto"
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
pb "github.com/youtube/vitess/go/vt/proto/vtrpc"
)
const callerIDPrefix = "callerid "
// callerIDClient implements vtgateservice.VTGateService, and checks
// the received callerid matches the one passed out of band by the client.
type callerIDClient struct {
fallback vtgateservice.VTGateService
}
func newCallerIDClient(fallback vtgateservice.VTGateService) *callerIDClient {
return &callerIDClient{
fallback: fallback,
}
}
// checkCallerID will see if this module is handling the request,
// and if it is, check the callerID from the context.
// Returns false if the query is not for this module.
// Returns true and the the error to return with the call
// if this module is handling the request.
func (c *callerIDClient) checkCallerID(ctx context.Context, received string) (bool, error) {
if !strings.HasPrefix(received, callerIDPrefix) {
return false, nil
}
jsonCallerID := []byte(received[len(callerIDPrefix):])
expectedCallerID := &pb.CallerID{}
if err := json.Unmarshal(jsonCallerID, expectedCallerID); err != nil {
return true, fmt.Errorf("cannot unmarshal provided callerid: %v", err)
}
receivedCallerID := callerid.EffectiveCallerIDFromContext(ctx)
if receivedCallerID == nil {
return true, fmt.Errorf("no callerid received in the query")
}
if !reflect.DeepEqual(receivedCallerID, expectedCallerID) {
return true, fmt.Errorf("callerid mismatch, got %v expected %v", receivedCallerID, expectedCallerID)
}
return true, nil
}
func (c *callerIDClient) Execute(ctx context.Context, query *proto.Query, reply *proto.QueryResult) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.Execute(ctx, query, reply)
}
func (c *callerIDClient) ExecuteShard(ctx context.Context, query *proto.QueryShard, reply *proto.QueryResult) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.ExecuteShard(ctx, query, reply)
}
func (c *callerIDClient) ExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, reply *proto.QueryResult) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.ExecuteKeyspaceIds(ctx, query, reply)
}
func (c *callerIDClient) ExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, reply *proto.QueryResult) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.ExecuteKeyRanges(ctx, query, reply)
}
func (c *callerIDClient) ExecuteEntityIds(ctx context.Context, query *proto.EntityIdsQuery, reply *proto.QueryResult) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.ExecuteEntityIds(ctx, query, reply)
}
func (c *callerIDClient) ExecuteBatchShard(ctx context.Context, batchQuery *proto.BatchQueryShard, reply *proto.QueryResultList) error {
if len(batchQuery.Queries) == 1 {
if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok {
return err
}
}
return c.fallback.ExecuteBatchShard(ctx, batchQuery, reply)
}
func (c *callerIDClient) ExecuteBatchKeyspaceIds(ctx context.Context, batchQuery *proto.KeyspaceIdBatchQuery, reply *proto.QueryResultList) error {
if len(batchQuery.Queries) == 1 {
if ok, err := c.checkCallerID(ctx, batchQuery.Queries[0].Sql); ok {
return err
}
}
return c.fallback.ExecuteBatchKeyspaceIds(ctx, batchQuery, reply)
}
func (c *callerIDClient) StreamExecute(ctx context.Context, query *proto.Query, sendReply func(*proto.QueryResult) error) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.StreamExecute(ctx, query, sendReply)
}
func (c *callerIDClient) StreamExecuteShard(ctx context.Context, query *proto.QueryShard, sendReply func(*proto.QueryResult) error) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.StreamExecuteShard(ctx, query, sendReply)
}
func (c *callerIDClient) StreamExecuteKeyRanges(ctx context.Context, query *proto.KeyRangeQuery, sendReply func(*proto.QueryResult) error) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.StreamExecuteKeyRanges(ctx, query, sendReply)
}
func (c *callerIDClient) StreamExecuteKeyspaceIds(ctx context.Context, query *proto.KeyspaceIdQuery, sendReply func(*proto.QueryResult) error) error {
if ok, err := c.checkCallerID(ctx, query.Sql); ok {
return err
}
return c.fallback.StreamExecuteKeyspaceIds(ctx, query, sendReply)
}
func (c *callerIDClient) Begin(ctx context.Context, outSession *proto.Session) error {
return c.fallback.Begin(ctx, outSession)
}
func (c *callerIDClient) Commit(ctx context.Context, inSession *proto.Session) error {
return c.fallback.Commit(ctx, inSession)
}
func (c *callerIDClient) Rollback(ctx context.Context, inSession *proto.Session) error {
return c.fallback.Rollback(ctx, inSession)
}
func (c *callerIDClient) SplitQuery(ctx context.Context, req *proto.SplitQueryRequest, reply *proto.SplitQueryResult) error {
if ok, err := c.checkCallerID(ctx, req.Query.Sql); ok {
return err
}
return c.fallback.SplitQuery(ctx, req, reply)
}
func (c *callerIDClient) GetSrvKeyspace(ctx context.Context, keyspace string) (*topo.SrvKeyspace, error) {
return c.fallback.GetSrvKeyspace(ctx, keyspace)
}
func (c *callerIDClient) HandlePanic(err *error) {
if x := recover(); x != nil {
log.Errorf("Uncaught panic:\n%v\n%s", x, tb.Stack(4))
*err = fmt.Errorf("uncaught panic: %v", x)
}
}

Просмотреть файл

@ -8,12 +8,12 @@ import (
"fmt"
log "github.com/golang/glog"
"golang.org/x/net/context"
"github.com/youtube/vitess/go/tb"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate/proto"
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
"golang.org/x/net/context"
)
// errorClient implements vtgateservice.VTGateService

Просмотреть файл

@ -0,0 +1,61 @@
// Copyright 2015 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"encoding/json"
"testing"
"time"
"github.com/youtube/vitess/go/vt/callerid"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
"golang.org/x/net/context"
)
// This file contains the reference test for clients. It tests
// all the corner cases of the API, and makes sure the go client
// is full featured.
//
// It can be used as a template by other languages for their test suites.
//
// TODO(team) add more unit test cases.
// testCallerID adds a caller ID to a context, and makes sure the server
// gets it.
func testCallerID(t *testing.T, conn *vtgateconn.VTGateConn) {
t.Log("testCallerID")
ctx := context.Background()
callerID := callerid.NewEffectiveCallerID("test_principal", "test_component", "test_subcomponent")
ctx = callerid.NewContext(ctx, callerID, nil)
data, err := json.Marshal(callerID)
if err != nil {
t.Errorf("failed to marshal callerid: %v", err)
return
}
query := callerIDPrefix + string(data)
// test Execute forwards the callerID
if _, err := conn.Execute(ctx, query, nil, topo.TYPE_MASTER); err != nil {
t.Errorf("failed to pass callerid: %v", err)
}
// FIXME(alainjobart) add all function calls
}
func testGoClient(t *testing.T, protocol, addr string) {
// Create a client connecting to the server
ctx := context.Background()
conn, err := vtgateconn.DialProtocol(ctx, protocol, addr, 30*time.Second)
if err != nil {
t.Fatalf("dial failed: %v", err)
}
testCallerID(t, conn)
// and clean up
conn.Close()
}

Просмотреть файл

@ -0,0 +1,45 @@
// Copyright 2015 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"net"
"net/http"
"testing"
"github.com/youtube/vitess/go/rpcplus"
"github.com/youtube/vitess/go/rpcwrap/bsonrpc"
"github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateservice"
// import the gorpc client, it will register itself
_ "github.com/youtube/vitess/go/vt/vtgate/gorpcvtgateconn"
)
// TestGoRPCGoClient tests the go client using goRPC
func TestGoRPCGoClient(t *testing.T) {
service := createService()
// listen on a random port
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot listen: %v", err)
}
defer listener.Close()
// Create a Go Rpc server and listen on the port
server := rpcplus.NewServer()
server.Register(gorpcvtgateservice.New(service))
// create the HTTP server, serve the server from it
handler := http.NewServeMux()
bsonrpc.ServeCustomRPC(handler, server, false)
httpServer := http.Server{
Handler: handler,
}
go httpServer.Serve(listener)
// and run the test suite
testGoClient(t, "gorpc", listener.Addr().String())
}

Просмотреть файл

@ -0,0 +1,37 @@
// Copyright 2015 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"net"
"testing"
"google.golang.org/grpc"
"github.com/youtube/vitess/go/vt/vtgate/grpcvtgateservice"
// import the grpc client, it will register itself
_ "github.com/youtube/vitess/go/vt/vtgate/grpcvtgateconn"
)
// TestGRPCGoClient tests the go client using gRPC
func TestGRPCGoClient(t *testing.T) {
service := createService()
// listen on a random port
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("Cannot listen: %v", err)
}
defer listener.Close()
// Create a gRPC server and listen on the port
server := grpc.NewServer()
grpcvtgateservice.RegisterForTest(server, service)
go server.Serve(listener)
// and run the test suite
testGoClient(t, "grpc", listener.Addr().String())
}

Просмотреть файл

@ -13,12 +13,23 @@ import (
"github.com/youtube/vitess/go/exit"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/vtgate"
"github.com/youtube/vitess/go/vt/vtgate/vtgateservice"
)
func init() {
servenv.RegisterDefaultFlags()
}
// createService creates the implementation chain of all the test cases
func createService() vtgateservice.VTGateService {
var s vtgateservice.VTGateService
s = newTerminalClient()
s = newSuccessClient(s)
s = newErrorClient(s)
s = newCallerIDClient(s)
return s
}
func main() {
defer exit.Recover()
@ -26,9 +37,9 @@ func main() {
servenv.Init()
// The implementation chain.
c := newErrorClient(newSuccessClient(newTerminalClient()))
s := createService()
for _, f := range vtgate.RegisterVTGates {
f(c)
f(s)
}
servenv.RunDefault()