Merge branch 'master' of github.com:youtube/vitess

This commit is contained in:
Alain Jobart 2014-01-31 17:00:38 -08:00
Родитель ab9e3b267c 3d0e81a147
Коммит be8a9544c6
17 изменённых файлов: 868 добавлений и 198 удалений

Просмотреть файл

@ -6,6 +6,7 @@ package topo
import (
"bytes"
"fmt"
"sort"
"github.com/youtube/vitess/go/bson"
@ -13,6 +14,10 @@ import (
"github.com/youtube/vitess/go/vt/key"
)
// This is the shard name for when the keyrange covers the entire space
// for unsharded database.
const SHARD_ZERO = "0"
// SrvShard contains a roll-up of the shard in the local namespace.
// In zk, it is under /zk/local/vt/ns/<keyspace>/<shard>
type SrvShard struct {
@ -120,6 +125,13 @@ func (ss *SrvShard) UnmarshalBson(buf *bytes.Buffer) {
}
}
func (ss *SrvShard) ShardName() string {
if !ss.KeyRange.IsPartial() {
return SHARD_ZERO
}
return fmt.Sprintf("%v-%v", string(ss.KeyRange.Start.Hex()), string(ss.KeyRange.End.Hex()))
}
// KeyspacePartition represents a continuous set of shards to
// serve an entire data set.
type KeyspacePartition struct {

Просмотреть файл

@ -25,8 +25,8 @@ func (vtg *VTGate) ExecuteBatchShard(context *rpcproto.Context, batchQuery *prot
return vtg.server.ExecuteBatchShard(context, batchQuery, reply)
}
func (vtg *VTGate) StreamExecuteShard(context *rpcproto.Context, query *proto.QueryShard, sendReply func(interface{}) error) error {
return vtg.server.StreamExecuteShard(context, query, func(value *proto.QueryResult) error {
func (vtg *VTGate) StreamExecuteKeyRange(context *rpcproto.Context, query *proto.StreamQueryKeyRange, sendReply func(interface{}) error) error {
return vtg.server.StreamExecuteKeyRange(context, query, func(value *proto.QueryResult) error {
return sendReply(value)
})
}

Просмотреть файл

@ -72,14 +72,14 @@ func (session *Session) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "InTransaction":
session.InTransaction = bson.DecodeBool(buf, kind)
case "ShardSessions":
session.ShardSessions = decodeShardSessionsBson(buf, kind)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
@ -117,8 +117,8 @@ func (shardSession *ShardSession) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "Keyspace":
shardSession.Keyspace = bson.DecodeString(buf, kind)
case "Shard":
@ -128,7 +128,7 @@ func (shardSession *ShardSession) UnmarshalBson(buf *bytes.Buffer) {
case "TransactionId":
shardSession.TransactionId = bson.DecodeInt64(buf, kind)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
@ -170,8 +170,8 @@ func (qrs *QueryShard) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "Sql":
qrs.Sql = bson.DecodeString(buf, kind)
case "BindVariables":
@ -186,7 +186,7 @@ func (qrs *QueryShard) UnmarshalBson(buf *bytes.Buffer) {
qrs.Session = new(Session)
qrs.Session.UnmarshalBson(buf)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
@ -237,8 +237,8 @@ func (qr *QueryResult) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "Fields":
qr.Fields = mproto.DecodeFieldsBson(buf, kind)
case "RowsAffected":
@ -253,7 +253,7 @@ func (qr *QueryResult) UnmarshalBson(buf *bytes.Buffer) {
case "Error":
qr.Error = bson.DecodeString(buf, kind)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
@ -293,8 +293,8 @@ func (bqs *BatchQueryShard) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "Queries":
bqs.Queries = tproto.DecodeQueriesBson(buf, kind)
case "Keyspace":
@ -307,7 +307,7 @@ func (bqs *BatchQueryShard) UnmarshalBson(buf *bytes.Buffer) {
bqs.Session = new(Session)
bqs.Session.UnmarshalBson(buf)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
@ -345,8 +345,8 @@ func (qrl *QueryResultList) UnmarshalBson(buf *bytes.Buffer) {
kind := bson.NextByte(buf)
for kind != bson.EOO {
key := bson.ReadCString(buf)
switch key {
keyName := bson.ReadCString(buf)
switch keyName {
case "List":
qrl.List = tproto.DecodeResultsBson(buf, kind)
case "Session":
@ -355,7 +355,61 @@ func (qrl *QueryResultList) UnmarshalBson(buf *bytes.Buffer) {
case "Error":
qrl.Error = bson.DecodeString(buf, kind)
default:
panic(bson.NewBsonError("Unrecognized tag %s", key))
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}
}
type StreamQueryKeyRange struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRange string
TabletType topo.TabletType
Session *Session
}
func (sqs *StreamQueryKeyRange) MarshalBson(buf *bytes2.ChunkedWriter) {
lenWriter := bson.NewLenWriter(buf)
bson.EncodeString(buf, "Sql", sqs.Sql)
tproto.EncodeBindVariablesBson(buf, "BindVariables", sqs.BindVariables)
bson.EncodeString(buf, "Keyspace", sqs.Keyspace)
bson.EncodeString(buf, "KeyRange", sqs.KeyRange)
bson.EncodeString(buf, "TabletType", string(sqs.TabletType))
if sqs.Session != nil {
bson.EncodePrefix(buf, bson.Object, "Session")
sqs.Session.MarshalBson(buf)
}
buf.WriteByte(0)
lenWriter.RecordLen()
}
func (sqs *StreamQueryKeyRange) UnmarshalBson(buf *bytes.Buffer) {
bson.Next(buf, 4)
kind := bson.NextByte(buf)
for kind != bson.EOO {
keyName := bson.ReadCString(buf)
switch keyName {
case "Sql":
sqs.Sql = bson.DecodeString(buf, kind)
case "BindVariables":
sqs.BindVariables = tproto.DecodeBindVariablesBson(buf, kind)
case "Keyspace":
sqs.Keyspace = bson.DecodeString(buf, kind)
case "KeyRange":
sqs.KeyRange = bson.DecodeString(buf, kind)
case "TabletType":
sqs.TabletType = topo.TabletType(bson.DecodeString(buf, kind))
case "Session":
sqs.Session = new(Session)
sqs.Session.UnmarshalBson(buf)
default:
panic(bson.NewBsonError("Unrecognized tag %s", keyName))
}
kind = bson.NextByte(buf)
}

Просмотреть файл

@ -393,6 +393,78 @@ func TestQueryResultList(t *testing.T) {
}
unexpected, err := bson.Marshal(&badQueryResultList{})
if err != nil {
t.Error(err)
}
err = bson.Unmarshal(unexpected, &unmarshalled)
want = "Unrecognized tag Extra"
if err == nil || want != err.Error() {
t.Errorf("want %v, got %v", want, err)
}
}
type reflectStreamQueryKeyRange struct {
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRange string
TabletType topo.TabletType
Session *Session
}
type badStreamQueryKeyRange struct {
Extra int
Sql string
BindVariables map[string]interface{}
Keyspace string
KeyRange string
TabletType topo.TabletType
Session *Session
}
func TestStreamQueryKeyRange(t *testing.T) {
reflected, err := bson.Marshal(&reflectStreamQueryKeyRange{
Sql: "query",
BindVariables: map[string]interface{}{"val": int64(1)},
Keyspace: "keyspace",
KeyRange: "10-18",
TabletType: "replica",
Session: &commonSession,
})
if err != nil {
t.Error(err)
}
want := string(reflected)
custom := StreamQueryKeyRange{
Sql: "query",
BindVariables: map[string]interface{}{"val": int64(1)},
Keyspace: "keyspace",
KeyRange: "10-18",
TabletType: "replica",
Session: &commonSession,
}
encoded, err := bson.Marshal(&custom)
if err != nil {
t.Error(err)
}
got := string(encoded)
if want != got {
t.Errorf("want\n%#v, got\n%#v", want, got)
}
var unmarshalled StreamQueryKeyRange
err = bson.Unmarshal(encoded, &unmarshalled)
if err != nil {
t.Error(err)
}
if !reflect.DeepEqual(custom, unmarshalled) {
t.Errorf("want \n%#v, got \n%#v", custom, unmarshalled)
}
unexpected, err := bson.Marshal(&badStreamQueryKeyRange{})
if err != nil {
t.Error(err)
}

Просмотреть файл

@ -14,12 +14,13 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/sync2"
"github.com/youtube/vitess/go/vt/key"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/tabletserver/tabletconn"
"github.com/youtube/vitess/go/vt/topo"
)
// sandbox_test.go provides a sandbox for unit testing Barnacle.
// sandbox_test.go provides a sandbox for unit testing VTGate.
func init() {
tabletconn.RegisterDialer("sandbox", sandboxDialer)
@ -48,6 +49,11 @@ var (
transactionId sync2.AtomicInt64
)
const (
TEST_SHARDED = "TestSharded"
TEST_UNSHARDED = "TestUnshared"
)
func resetSandbox() {
sandmu.Lock()
defer sandmu.Unlock()
@ -62,12 +68,90 @@ func resetSandbox() {
type sandboxTopo struct {
}
var ShardSpec = "-20-40-60-80-a0-c0-e0-"
var ShardedKrArray key.KeyRangeArray
func getAllShards() (key.KeyRangeArray, error) {
if ShardedKrArray != nil {
return ShardedKrArray, nil
}
ShardedKrArray, err := key.ParseShardingSpec(ShardSpec)
if err != nil {
return nil, err
}
return ShardedKrArray, nil
}
func getKeyRangeName(kr key.KeyRange) string {
return fmt.Sprintf("%v-%v", string(kr.Start.Hex()), string(kr.End.Hex()))
}
func getUidForShard(shardName string) (int, error) {
// Try simple unsharded case first
uid, err := strconv.Atoi(shardName)
if err == nil {
return uid, nil
}
shards, err := getAllShards()
if err != nil {
return 0, fmt.Errorf("shard not found %v", shardName)
}
for i, s := range shards {
if shardName == getKeyRangeName(s) {
return i, nil
}
}
return 0, fmt.Errorf("shard not found %v", shardName)
}
func (sct *sandboxTopo) GetSrvKeyspaceNames(cell string) ([]string, error) {
panic(fmt.Errorf("not implemented"))
return []string{TEST_SHARDED, TEST_UNSHARDED}, nil
}
func (sct *sandboxTopo) GetSrvKeyspace(cell, keyspace string) (*topo.SrvKeyspace, error) {
panic(fmt.Errorf("not implemented"))
shardKrArray, err := getAllShards()
if err != nil {
return nil, err
}
shards := make([]topo.SrvShard, 0, len(shardKrArray))
for i := 0; i < len(shardKrArray); i++ {
shard := topo.SrvShard{
KeyRange: shardKrArray[i],
ServedTypes: []topo.TabletType{topo.TYPE_MASTER},
TabletTypes: []topo.TabletType{topo.TYPE_MASTER},
}
shards = append(shards, shard)
}
shardedSrvKeyspace := &topo.SrvKeyspace{
Partitions: map[topo.TabletType]*topo.KeyspacePartition{
topo.TYPE_MASTER: &topo.KeyspacePartition{
Shards: shards,
},
},
TabletTypes: []topo.TabletType{topo.TYPE_MASTER},
}
unshardedSrvKeyspace := &topo.SrvKeyspace{
Partitions: map[topo.TabletType]*topo.KeyspacePartition{
topo.TYPE_MASTER: &topo.KeyspacePartition{
Shards: []topo.SrvShard{
{KeyRange: key.KeyRange{Start: "", End: ""},
ServedTypes: []topo.TabletType{topo.TYPE_MASTER},
TabletTypes: []topo.TabletType{topo.TYPE_MASTER},
},
},
},
},
TabletTypes: []topo.TabletType{topo.TYPE_MASTER},
}
// Return unsharded SrvKeyspace record if asked
// By default return the sharded keyspace
if keyspace == TEST_UNSHARDED {
return unshardedSrvKeyspace, nil
}
return shardedSrvKeyspace, nil
}
func (sct *sandboxTopo) GetEndPoints(cell, keyspace, shard string, tabletType topo.TabletType) (*topo.EndPoints, error) {
@ -78,7 +162,7 @@ func (sct *sandboxTopo) GetEndPoints(cell, keyspace, shard string, tabletType to
endPointMustFail--
return nil, fmt.Errorf("topo error")
}
uid, err := strconv.Atoi(shard)
uid, err := getUidForShard(shard)
if err != nil {
panic(err)
}
@ -105,6 +189,14 @@ func sandboxDialer(context interface{}, endPoint topo.EndPoint, keyspace, shard
return tconn, nil
}
func mapTestConn(shard string, conn tabletconn.TabletConn) {
uid, err := getUidForShard(shard)
if err != nil {
panic(err)
}
testConns[uint32(uid)] = conn
}
// sandboxConn satisfies the TabletConn interface
type sandboxConn struct {
endPoint topo.EndPoint

Просмотреть файл

@ -127,7 +127,7 @@ func (sdc *ShardConn) Close() {
sdc.conn = nil
}
// withRetry sets up the connection and exexutes the action. If there are connection errors,
// withRetry sets up the connection and executes the action. If there are connection errors,
// it retries retryCount times before failing. It does not retry if the connection is in
// the middle of a transaction.
func (sdc *ShardConn) withRetry(context interface{}, action func(conn tabletconn.TabletConn) error, transactionId int64) error {

Просмотреть файл

@ -6,12 +6,14 @@ package vtgate
import (
"flag"
"fmt"
"sync"
"time"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/stats"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/topo"
)
@ -219,3 +221,36 @@ func (server *ResilientSrvTopoServer) GetEndPoints(cell, keyspace, shard string,
entry.value = result
return result, nil
}
// This maps a list of keyranges to shard names.
func resolveKeyRangeToShards(topoServer SrvTopoServer, cell, keyspace string, tabletType topo.TabletType, kr key.KeyRange) ([]string, error) {
srvKeyspace, err := topoServer.GetSrvKeyspace(cell, keyspace)
if err != nil {
return nil, fmt.Errorf("Error in reading the keyspace %v", err)
}
tabletTypePartition, ok := srvKeyspace.Partitions[tabletType]
if !ok {
return nil, fmt.Errorf("No shards available for tablet type '%v' in keyspace '%v'", tabletType, keyspace)
}
topo.SrvShardArray(tabletTypePartition.Shards).Sort()
shards := make([]string, 0, 1)
if !kr.IsPartial() {
for j := 0; j < len(tabletTypePartition.Shards); j++ {
shards = append(shards, tabletTypePartition.Shards[j].ShardName())
}
return shards, nil
}
for j := 0; j < len(tabletTypePartition.Shards); j++ {
shard := tabletTypePartition.Shards[j]
if key.KeyRangesIntersect(kr, shard.KeyRange) {
shards = append(shards, shard.ShardName())
}
if kr.End < shard.KeyRange.Start {
break
}
}
return shards, nil
}

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgate
import (
"reflect"
"testing"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/topo"
)
func TestKeyRangeToShardMap(t *testing.T) {
ts := new(sandboxTopo)
var testCases = []struct {
keyspace string
keyRange string
shards []string
}{
{keyspace: TEST_SHARDED, keyRange: "20-40", shards: []string{"20-40"}},
// check for partial keyrange, spanning one shard
{keyspace: TEST_SHARDED, keyRange: "10-18", shards: []string{"-20"}},
// check for keyrange intersecting with multiple shards
{keyspace: TEST_SHARDED, keyRange: "10-40", shards: []string{"-20", "20-40"}},
// check for keyrange intersecting with multiple shards
{keyspace: TEST_SHARDED, keyRange: "1C-2A", shards: []string{"-20", "20-40"}},
// test for sharded, non-partial keyrange spanning the entire space.
{keyspace: TEST_SHARDED, keyRange: "", shards: []string{"-20", "20-40", "40-60", "60-80", "80-A0", "A0-C0", "C0-E0", "E0-"}},
// test for unsharded, non-partial keyrange spanning the entire space.
{keyspace: TEST_UNSHARDED, keyRange: "", shards: []string{"0"}},
}
for _, testCase := range testCases {
var keyRange key.KeyRange
var err error
if testCase.keyRange == "" {
keyRange = key.KeyRange{Start: "", End: ""}
} else {
krArray, err := key.ParseShardingSpec(testCase.keyRange)
if err != nil {
t.Errorf("Got error while parsing sharding spec %v", err)
}
keyRange = krArray[0]
}
gotShards, err := resolveKeyRangeToShards(ts, "", testCase.keyspace, topo.TYPE_MASTER, keyRange)
if err != nil {
t.Errorf("want nil, got %v", err)
}
if !reflect.DeepEqual(testCase.shards, gotShards) {
t.Errorf("want \n%#v, got \n%#v", testCase.shards, gotShards)
}
}
}

Просмотреть файл

@ -7,10 +7,12 @@
package vtgate
import (
"fmt"
"time"
log "github.com/golang/glog"
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/vtgate/proto"
)
@ -78,16 +80,57 @@ func (vtg *VTGate) ExecuteBatchShard(context interface{}, batchQuery *proto.Batc
return nil
}
// StreamExecuteShard executes a streaming query on the specified shards.
func (vtg *VTGate) StreamExecuteShard(context interface{}, query *proto.QueryShard, sendReply func(*proto.QueryResult) error) error {
err := vtg.scatterConn.StreamExecute(
// This function implements the restriction of handling one keyrange
// and one shard since streaming doesn't support merge sorting the results.
// The input/output api is generic though.
func (vtg *VTGate) mapKrToShardsForStreaming(streamQuery *proto.StreamQueryKeyRange) ([]string, error) {
var keyRange key.KeyRange
var err error
if streamQuery.KeyRange == "" {
keyRange = key.KeyRange{Start: "", End: ""}
} else {
krArray, err := key.ParseShardingSpec(streamQuery.KeyRange)
if err != nil {
return nil, err
}
keyRange = krArray[0]
}
shards, err := resolveKeyRangeToShards(vtg.scatterConn.toposerv,
vtg.scatterConn.cell,
streamQuery.Keyspace,
streamQuery.TabletType,
keyRange)
if err != nil {
return nil, err
}
if len(shards) != 1 {
return nil, fmt.Errorf("KeyRange cannot map to more than one shard")
}
return shards, nil
}
// StreamExecuteKeyRange executes a streaming query on the specified KeyRange.
// The KeyRange is resolved to shards using the serving graph.
// This function currently temporarily enforces the restriction of executing on one keyrange
// and one shard since it cannot merge-sort the results to guarantee ordering of
// response which is needed for checkpointing. The api supports supplying multiple keyranges
// to make it future proof.
func (vtg *VTGate) StreamExecuteKeyRange(context interface{}, streamQuery *proto.StreamQueryKeyRange, sendReply func(*proto.QueryResult) error) error {
shards, err := vtg.mapKrToShardsForStreaming(streamQuery)
if err != nil {
return err
}
err = vtg.scatterConn.StreamExecute(
context,
query.Sql,
query.BindVariables,
query.Keyspace,
query.Shards,
query.TabletType,
NewSafeSession(query.Session),
streamQuery.Sql,
streamQuery.BindVariables,
streamQuery.Keyspace,
shards,
streamQuery.TabletType,
NewSafeSession(streamQuery.Session),
func(mreply *mproto.QueryResult) error {
reply := new(proto.QueryResult)
proto.PopulateQueryResult(mreply, reply)
@ -96,12 +139,13 @@ func (vtg *VTGate) StreamExecuteShard(context interface{}, query *proto.QuerySha
// are sent.
return sendReply(reply)
})
if err != nil {
log.Errorf("StreamExecuteShard: %v, query: %#v", err, query)
log.Errorf("StreamExecuteKeyRange: %v, query: %#v", err, streamQuery)
}
// now we can send the final Session info.
if query.Session != nil {
sendReply(&proto.QueryResult{Session: query.Session})
if streamQuery.Session != nil {
sendReply(&proto.QueryResult{Session: streamQuery.Session})
}
return err
}

Просмотреть файл

@ -10,6 +10,7 @@ import (
"time"
tproto "github.com/youtube/vitess/go/vt/tabletserver/proto"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate/proto"
)
@ -78,8 +79,8 @@ func TestVTGateExecuteShard(t *testing.T) {
func TestVTGateExecuteBatchShard(t *testing.T) {
resetSandbox()
testConns[0] = &sandboxConn{}
testConns[1] = &sandboxConn{}
mapTestConn("-20", &sandboxConn{})
mapTestConn("20-40", &sandboxConn{})
q := proto.BatchQueryShard{
Queries: []tproto.BoundQuery{{
"query",
@ -88,7 +89,7 @@ func TestVTGateExecuteBatchShard(t *testing.T) {
"query",
nil,
}},
Shards: []string{"0", "1"},
Shards: []string{"-20", "20-40"},
}
qrl := new(proto.QueryResultList)
err := RpcVTGate.ExecuteBatchShard(nil, &q, qrl)
@ -113,16 +114,18 @@ func TestVTGateExecuteBatchShard(t *testing.T) {
}
}
func TestVTGateStreamExecuteShard(t *testing.T) {
func TestVTGateStreamExecuteKeyRange(t *testing.T) {
resetSandbox()
sbc := &sandboxConn{}
testConns[0] = sbc
q := proto.QueryShard{
Sql: "query",
Shards: []string{"0"},
mapTestConn("-20", sbc)
sq := proto.StreamQueryKeyRange{
Sql: "query",
KeyRange: "-20",
TabletType: topo.TYPE_MASTER,
}
// Test for successful execution
var qrs []*proto.QueryResult
err := RpcVTGate.StreamExecuteShard(nil, &q, func(r *proto.QueryResult) error {
err := RpcVTGate.StreamExecuteKeyRange(nil, &sq, func(r *proto.QueryResult) error {
qrs = append(qrs, r)
return nil
})
@ -136,10 +139,10 @@ func TestVTGateStreamExecuteShard(t *testing.T) {
t.Errorf("want \n%#v, got \n%#v", want, qrs)
}
q.Session = new(proto.Session)
sq.Session = new(proto.Session)
qrs = nil
RpcVTGate.Begin(nil, q.Session)
err = RpcVTGate.StreamExecuteShard(nil, &q, func(r *proto.QueryResult) error {
RpcVTGate.Begin(nil, sq.Session)
err = RpcVTGate.StreamExecuteKeyRange(nil, &sq, func(r *proto.QueryResult) error {
qrs = append(qrs, r)
return nil
})
@ -149,8 +152,9 @@ func TestVTGateStreamExecuteShard(t *testing.T) {
Session: &proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{{
Shard: "0",
Shard: "-20",
TransactionId: 1,
TabletType: topo.TYPE_MASTER,
}},
},
},
@ -158,4 +162,23 @@ func TestVTGateStreamExecuteShard(t *testing.T) {
if !reflect.DeepEqual(want, qrs) {
t.Errorf("want \n%#v, got \n%#v", want, qrs)
}
// Test for error condition - multiple shards
sq.KeyRange = "10-40"
err = RpcVTGate.StreamExecuteKeyRange(nil, &sq, func(r *proto.QueryResult) error {
qrs = append(qrs, r)
return nil
})
if err == nil {
t.Errorf("want not nil, got %v", err)
}
// Test for error condition - multiple shards, non-partial keyspace
sq.KeyRange = ""
err = RpcVTGate.StreamExecuteKeyRange(nil, &sq, func(r *proto.QueryResult) error {
qrs = append(qrs, r)
return nil
})
if err == nil {
t.Errorf("want not nil, got %v", err)
}
}

Просмотреть файл

@ -8,6 +8,40 @@ from vtdb import dbexceptions
# bind_vars for distrubuting the workload of streaming queries.
# Keyrange that spans the entire space, used
# for unsharded database.
NON_PARTIAL_KEYRANGE = ""
MIN_KEY = ''
MAX_KEY = ''
KIT_UNSET = ""
KIT_UINT64 = "uint64"
KIT_BYTES = "bytes"
class KeyRange(object):
Start = None
End = None
def __init__(self, kr):
if isinstance(kr, str):
if kr == NON_PARTIAL_KEYRANGE:
self.Start = ""
self.End = ""
return
else:
kr = kr.split('-')
if not isinstance(kr, tuple) and not isinstance(kr, list) or len(kr) != 2:
raise dbexceptions.ProgrammingError("keyrange must be a list or tuple or a '-' separated str %s" % keyrange)
self.Start = kr[0].strip()
self.End = kr[1].strip()
def __str__(self):
if self.Start == MIN_KEY and self.End == MAX_KEY:
return NON_PARTIAL_KEYRANGE
return '%s-%s' % (self.Start, self.End)
class StreamingTaskMap(object):
keyrange_list = None
@ -27,7 +61,7 @@ class StreamingTaskMap(object):
#kr_chunks.append(hex(kr).split('0x')[1])
kr_chunks.append('%x' % kr)
kr_chunks[-1] = ''
self.keyrange_list = [(kr_chunks[i], kr_chunks[i+1]) for i in xrange(len(kr_chunks) - 1)]
self.keyrange_list = [str(KeyRange((kr_chunks[i], kr_chunks[i+1],))) for i in xrange(len(kr_chunks) - 1)]
# Compute the task map for a streaming query.
@ -54,17 +88,13 @@ def _true_int_kr_value(kr_value):
return int(kr_value, base=16)
MIN_KEY = ''
MAX_KEY = ''
KIT_UNSET = ""
KIT_UINT64 = "uint64"
KIT_BYTES = "bytes"
# Compute the where clause and bind_vars for a given keyrange.
def create_where_clause_for_keyrange(keyrange, keyspace_col_name='keyspace_id', keyspace_col_type=KIT_UINT64):
if isinstance(keyrange, str):
# If the keyrange is for unsharded db, there is no
# where clause to add to or bind_vars to add to.
if keyrange == NON_PARTIAL_KEYRANGE:
return "", {}
keyrange = keyrange.split('-')
if not isinstance(keyrange, tuple) and not isinstance(keyrange, list) or len(keyrange) != 2:

Просмотреть файл

@ -11,6 +11,11 @@ from net import gorpc
from vtdb import cursor
from vtdb import dbexceptions
from vtdb import field_types
from vtdb import keyrange
# This is the shard name for when the keyrange covers the entire space
# for unsharded database.
SHARD_ZERO = "0"
_errno_pattern = re.compile('\(errno (\d+)\)')
@ -26,12 +31,6 @@ def convert_exception(exc, *args):
return dbexceptions.TimeoutError(new_args)
elif isinstance(exc, gorpc.AppError):
msg = str(exc[0]).lower()
if msg.startswith('retry'):
return dbexceptions.RetryError(new_args)
if msg.startswith('fatal'):
return dbexceptions.FatalError(new_args)
if msg.startswith('tx_pool_full'):
return dbexceptions.TxPoolFull(new_args)
match = _errno_pattern.search(msg)
if match:
mysql_errno = int(match.group(1))
@ -204,12 +203,20 @@ class VtgateConnection(object):
# (that way we avoid using a member variable here for such a corner case)
def _stream_execute(self, sql, bind_variables):
new_binds = field_types.convert_bind_vars(bind_variables)
key_range = None
# For the unsharded keyspace, the keyrange should cover the
# entire space.
if self.shard == SHARD_ZERO:
key_range = str(keyrange.KeyRange(""))
else:
key_range = str(keyrange.KeyRange(self.shard))
req = {
'Sql': sql,
'BindVariables': new_binds,
'Keyspace': self.keyspace,
'KeyRange': key_range,
'TabletType': self.tablet_type,
'Shards': [self.shard],
}
self._add_session(req)
@ -218,7 +225,7 @@ class VtgateConnection(object):
self._stream_result = None
self._stream_result_index = 0
try:
self.client.stream_call('VTGate.StreamExecuteShard', req)
self.client.stream_call('VTGate.StreamExecuteKeyRange', req)
first_response = self.client.stream_next()
reply = first_response.reply

Просмотреть файл

@ -12,22 +12,22 @@ from vtdb import keyrange
# and where clauses for streaming queries.
pkid_pack = struct.Struct('!Q').pack
int_shard_kid_map = {('', '10'):[1, 100, 1000, 100000, 527875958493693904, 626750931627689502, 345387386794260318, 332484755310826578],
('10', '20'):[1842642426274125671, 1326307661227634652, 1761124146422844620, 1661669973250483744],
('20', '30'):[3361397649937244239, 3303511690915522723, 2444880764308344533, 2973657788686139039],
('30', '40'):[3821005920507858605, 4575089859165626432, 3607090456016432961, 3979558375123453425],
('40', '50'):[5129057445097465905, 5464969577815708398, 5190676584475132364, 5762096070688827561],
('50', '60'):[6419540613918919447, 6867152356089593986, 6601838130703675400, 6132605084892127391],
('60', '70'):[7251511061270371980, 7395364497868053835, 7814586147633440734, 7968977924086033834],
('70', '80'):[8653665459643609079, 8419099072545971426, 9020726671664230611, 9064594986161620444],
('80', '90'):[9767889778372766922, 9742070682920810358, 10296850775085416642, 9537430901666854108],
('90', 'a0'):[10440455099304929791, 11454183276974683945, 11185910247776122031, 10460396697869122981],
('a0', 'b0'):[11935085245138597119, 12115696589214223782, 12639360876311033978, 12548906240535188165],
('b0', 'c0'):[13379616110062597001, 12826553979133932576, 13288572810772383281, 13471801046560785347],
('c0', 'd0'):[14394342688314745188, 14639660031570920207, 14646353412066152016, 14186650213447467187],
('d0', 'e0'):[15397348460895960623, 16014223083986915239, 15058390871463382185, 15811857963302932363],
('e0', 'f0'):[17275711019497396001, 16979796627403646478, 16635982235308289704, 16906674090344806032],
('f0', ''):[18229242992218358675, 17623451135465171527, 18333015752598164958, 17775908119782706671],
int_shard_kid_map = {'-10':[1, 100, 1000, 100000, 527875958493693904, 626750931627689502, 345387386794260318, 332484755310826578],
'10-20':[1842642426274125671, 1326307661227634652, 1761124146422844620, 1661669973250483744],
'20-30':[3361397649937244239, 3303511690915522723, 2444880764308344533, 2973657788686139039],
'30-40':[3821005920507858605, 4575089859165626432, 3607090456016432961, 3979558375123453425],
'40-50':[5129057445097465905, 5464969577815708398, 5190676584475132364, 5762096070688827561],
'50-60':[6419540613918919447, 6867152356089593986, 6601838130703675400, 6132605084892127391],
'60-70':[7251511061270371980, 7395364497868053835, 7814586147633440734, 7968977924086033834],
'70-80':[8653665459643609079, 8419099072545971426, 9020726671664230611, 9064594986161620444],
'80-90':[9767889778372766922, 9742070682920810358, 10296850775085416642, 9537430901666854108],
'90-a0':[10440455099304929791, 11454183276974683945, 11185910247776122031, 10460396697869122981],
'a0-b0':[11935085245138597119, 12115696589214223782, 12639360876311033978, 12548906240535188165],
'b0-c0':[13379616110062597001, 12826553979133932576, 13288572810772383281, 13471801046560785347],
'c0-d0':[14394342688314745188, 14639660031570920207, 14646353412066152016, 14186650213447467187],
'd0-e0':[15397348460895960623, 16014223083986915239, 15058390871463382185, 15811857963302932363],
'e0-f0':[17275711019497396001, 16979796627403646478, 16635982235308289704, 16906674090344806032],
'f0-':[18229242992218358675, 17623451135465171527, 18333015752598164958, 17775908119782706671],
}
# str_shard_kid_map is derived from int_shard_kid_map
@ -60,9 +60,10 @@ class TestKeyRange(unittest.TestCase):
def test_bind_values_for_int_keyspace(self):
stm = keyrange.create_streaming_task_map(16, 16)
for i, kr in enumerate(stm.keyrange_list):
kr_parts = kr.split('-')
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr)
if len(bind_vars.keys()) == 1:
if kr[0] == '':
if kr_parts[0] == '':
self.assertNotEqual(where_clause.find('<'), -1)
else:
self.assertNotEqual(where_clause.find('>='), -1)
@ -73,7 +74,7 @@ class TestKeyRange(unittest.TestCase):
kid_list = int_shard_kid_map[kr]
for keyspace_id in kid_list:
if len(bind_vars.keys()) == 1:
if kr[0] == '':
if kr_parts[0] == '':
self.assertLess(keyspace_id, bind_vars['keyspace_id0'])
else:
self.assertGreaterEqual(keyspace_id, bind_vars['keyspace_id0'])
@ -90,9 +91,10 @@ class TestKeyRange(unittest.TestCase):
def test_bind_values_for_str_keyspace(self):
stm = keyrange.create_streaming_task_map(16, 16)
for i, kr in enumerate(stm.keyrange_list):
kr_parts = kr.split('-')
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(kr, keyspace_col_type=keyrange.KIT_BYTES)
if len(bind_vars.keys()) == 1:
if kr[0] == '':
if kr_parts[0] == '':
self.assertNotEqual(where_clause.find('<'), -1)
else:
self.assertNotEqual(where_clause.find('>='), -1)
@ -103,7 +105,7 @@ class TestKeyRange(unittest.TestCase):
kid_list = str_shard_kid_map[kr]
for keyspace_id in kid_list:
if len(bind_vars.keys()) == 1:
if kr[0] == '':
if kr_parts[0] == '':
self.assertLess(keyspace_id.encode('hex'), bind_vars['keyspace_id0'])
else:
self.assertGreaterEqual(keyspace_id.encode('hex'), bind_vars['keyspace_id0'])
@ -111,5 +113,12 @@ class TestKeyRange(unittest.TestCase):
self.assertGreaterEqual(keyspace_id.encode('hex'), bind_vars['keyspace_id0'])
self.assertLess(keyspace_id.encode('hex'), bind_vars['keyspace_id1'])
def test_bind_values_for_unsharded_keyspace(self):
stm = keyrange.create_streaming_task_map(1, 1)
self.assertEqual(len(stm.keyrange_list), 1)
where_clause, bind_vars = keyrange.create_where_clause_for_keyrange(stm.keyrange_list[0])
self.assertEqual(where_clause, "")
self.assertEqual(bind_vars, {})
if __name__ == '__main__':
utils.main()

Просмотреть файл

@ -415,7 +415,7 @@ primary key (name)
utils.run_vtctl(['RebuildKeyspaceGraph', '-use-served-types', 'test_keyspace'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(rdonly): -80 80-\n' +
'Partitions(replica): -80 80-\n' +
@ -501,7 +501,7 @@ primary key (name)
# now serve rdonly from the split shards
utils.run_vtctl(['MigrateServedTypes', 'test_keyspace/80-', 'rdonly'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(rdonly): -80 80-C0 C0-\n' +
'Partitions(replica): -80 80-\n' +
@ -510,7 +510,7 @@ primary key (name)
# then serve replica from the split shards
utils.run_vtctl(['MigrateServedTypes', 'test_keyspace/80-', 'replica'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(rdonly): -80 80-C0 C0-\n' +
'Partitions(replica): -80 80-C0 C0-\n' +
@ -519,14 +519,14 @@ primary key (name)
# move replica back and forth
utils.run_vtctl(['MigrateServedTypes', '-reverse', 'test_keyspace/80-', 'replica'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(rdonly): -80 80-C0 C0-\n' +
'Partitions(replica): -80 80-\n' +
'TabletTypes: master,rdonly,replica')
utils.run_vtctl(['MigrateServedTypes', 'test_keyspace/80-', 'replica'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(rdonly): -80 80-C0 C0-\n' +
'Partitions(replica): -80 80-C0 C0-\n' +
@ -567,7 +567,7 @@ primary key (name)
# then serve master from the split shards
utils.run_vtctl(['MigrateServedTypes', 'test_keyspace/80-', 'master'],
auto_log=True)
self._check_srv_keyspace('test_nj', 'test_keyspace',
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-C0 C0-\n' +
'Partitions(rdonly): -80 80-C0 C0-\n' +
'Partitions(replica): -80 80-C0 C0-\n' +
@ -583,27 +583,5 @@ primary key (name)
shard_2_master, shard_2_replica1, shard_2_replica2,
shard_3_master, shard_3_replica, shard_3_rdonly])
def _check_srv_keyspace(self, cell, keyspace, expected):
ks = utils.run_vtctl_json(['GetSrvKeyspace', cell, keyspace])
result = ""
for tablet_type in sorted(ks['Partitions'].keys()):
result += "Partitions(%s):" % tablet_type
partition = ks['Partitions'][tablet_type]
for shard in partition['Shards']:
result = result + " %s-%s" % (shard['KeyRange']['Start'],
shard['KeyRange']['End'])
result += "\n"
result += "TabletTypes: " + ",".join(sorted(ks['TabletTypes']))
logging.debug("Cell %s keyspace %s has data:\n%s", cell, keyspace, result)
self.assertEqual(expected, result,
"Mismatch in srv keyspace for cell %s keyspace %s, expected:\n%s\ngot:\n%s" % (
cell, keyspace, expected, result))
self.assertEqual('keyspace_id', ks.get('ShardingColumnName'),
"Got wrong ShardingColumnName in SrvKeyspace: %s" %
str(ks))
self.assertEqual(keyspace_id_type, ks.get('ShardingColumnType'),
"Got wrong ShardingColumnType in SrvKeyspace: %s" %
str(ks))
if __name__ == '__main__':
utils.main()

Просмотреть файл

@ -138,6 +138,8 @@ def kill_sub_processes():
logging.debug("kill_sub_processes: %s", str(e))
def kill_sub_process(proc):
if proc is None:
return
pid = proc.pid
proc.kill()
if pid and pid in pid_map:
@ -343,6 +345,8 @@ def vtgate_start(cell='test_nj', retry_delay=1, retry_count=1, topo_impl=None, t
return sp, port
def vtgate_kill(sp):
if sp is None:
return
kill_sub_process(sp)
sp.wait()
@ -483,3 +487,25 @@ def wait_db_read_only(uid):
logging.warning("wait_db_read_only: %s", str(e))
time.sleep(1.0)
raise e
def check_srv_keyspace(cell, keyspace, expected, keyspace_id_type='uint64'):
ks = run_vtctl_json(['GetSrvKeyspace', cell, keyspace])
result = ""
for tablet_type in sorted(ks['TabletTypes']):
result += "Partitions(%s):" % tablet_type
partition = ks['Partitions'][tablet_type]
for shard in partition['Shards']:
result = result + " %s-%s" % (shard['KeyRange']['Start'],
shard['KeyRange']['End'])
result += "\n"
result += "TabletTypes: " + ",".join(sorted(ks['TabletTypes']))
logging.debug("Cell %s keyspace %s has data:\n%s", cell, keyspace, result)
if expected != result:
raise Exception("Mismatch in srv keyspace for cell %s keyspace %s, expected:\n%s\ngot:\n%s" % (
cell, keyspace, expected, result))
if 'keyspace_id' != ks.get('ShardingColumnName'):
raise Exception("Got wrong ShardingColumnName in SrvKeyspace: %s" %
str(ks))
if keyspace_id_type != ks.get('ShardingColumnType'):
raise Exception("Got wrong ShardingColumnType in SrvKeyspace: %s" %
str(ks))

Просмотреть файл

@ -13,11 +13,17 @@ import tablet
import utils
from net import gorpc
from vtdb import tablet as tablet3
from vtdb import cursor
from vtdb import vtclient
from vtdb import vtgate
from vtdb import dbexceptions
from zk import zkocc
VTGATE_PROTOCOL_TABLET = 'v0'
VTGATE_PROTOCOL_V1BSON = 'v1bson'
vtgate_protocol = VTGATE_PROTOCOL_TABLET
shard_0_master = tablet.Tablet()
shard_0_replica = tablet.Tablet()
@ -27,15 +33,30 @@ shard_1_replica = tablet.Tablet()
vtgate_server = None
vtgate_port = None
shard_names = ['-80', '80-']
shard_kid_map = {'-80': [527875958493693904, 626750931627689502,
345387386794260318, 332484755310826578,
1842642426274125671, 1326307661227634652,
1761124146422844620, 1661669973250483744,
3361397649937244239, 2444880764308344533],
'80-': [9767889778372766922, 9742070682920810358,
10296850775085416642, 9537430901666854108,
10440455099304929791, 11454183276974683945,
11185910247776122031, 10460396697869122981,
13379616110062597001, 12826553979133932576],
}
create_vt_insert_test = '''create table vt_insert_test (
id bigint auto_increment,
msg varchar(64),
keyspace_id bigint(20) unsigned NOT NULL,
primary key (id)
) Engine=InnoDB'''
create_vt_a = '''create table vt_a (
eid bigint,
id int,
keyspace_id bigint(20) unsigned NOT NULL,
primary key(eid, id)
) Engine=InnoDB'''
@ -58,6 +79,7 @@ def setUpModule():
raise
def tearDownModule():
global vtgate_server
logging.debug("in tearDownModule")
if utils.options.skip_teardown:
return
@ -89,55 +111,84 @@ def setup_tablets():
# Start up a master mysql and vttablet
logging.debug("Setting up tablets")
utils.run_vtctl(['CreateKeyspace', 'test_keyspace'])
shard_0_master.init_tablet('master', keyspace='test_keyspace', shard='0')
shard_0_replica.init_tablet('replica', keyspace='test_keyspace', shard='0')
shard_1_master.init_tablet('master', keyspace='test_keyspace', shard='1')
shard_1_replica.init_tablet('replica', keyspace='test_keyspace', shard='1')
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/0'], auto_log=True)
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/1'], auto_log=True)
utils.validate_topology()
shard_0_master.create_db(shard_0_master.dbname)
shard_0_replica.create_db(shard_0_master.dbname)
shard_1_master.create_db(shard_0_master.dbname)
shard_1_replica.create_db(shard_0_master.dbname)
for t in [shard_0_master, shard_0_replica, shard_1_master, shard_1_replica]:
t.mquery(shard_0_master.dbname, create_vt_insert_test)
t.mquery(shard_0_master.dbname, create_vt_a)
utils.run_vtctl(['SetKeyspaceShardingInfo', '-force', 'test_keyspace',
'keyspace_id', 'uint64'])
shard_0_master.init_tablet('master', keyspace='test_keyspace', shard='-80')
shard_0_replica.init_tablet('replica', keyspace='test_keyspace', shard='-80')
shard_1_master.init_tablet('master', keyspace='test_keyspace', shard='80-')
shard_1_replica.init_tablet('replica', keyspace='test_keyspace', shard='80-')
utils.run_vtctl(['RebuildKeyspaceGraph', 'test_keyspace'], auto_log=True)
vtgate_server, vtgate_port = utils.vtgate_start()
for t in [shard_0_master, shard_0_replica, shard_1_master, shard_1_replica]:
t.create_db('vt_test_keyspace')
t.mquery(shard_0_master.dbname, create_vt_insert_test)
t.mquery(shard_0_master.dbname, create_vt_a)
t.start_vttablet(wait_for_state=None)
for t in [shard_0_master, shard_0_replica, shard_1_master, shard_1_replica]:
t.wait_for_vttablet_state('SERVING')
utils.run_vtctl(['ReparentShard', '-force', 'test_keyspace/0',
utils.run_vtctl(['ReparentShard', '-force', 'test_keyspace/-80',
shard_0_master.tablet_alias], auto_log=True)
utils.run_vtctl(['ReparentShard', '-force', 'test_keyspace/1',
utils.run_vtctl(['ReparentShard', '-force', 'test_keyspace/80-',
shard_1_master.tablet_alias], auto_log=True)
utils.run_vtctl(['RebuildKeyspaceGraph', '-use-served-types', 'test_keyspace'],
auto_log=True)
utils.check_srv_keyspace('test_nj', 'test_keyspace',
'Partitions(master): -80 80-\n' +
'Partitions(replica): -80 80-\n' +
'TabletTypes: master,replica')
vtgate_server, vtgate_port = utils.vtgate_start()
def get_master_connection(shard_index=0, user=None, password=None):
global vtgate_protocol
global vtgate_port
timeout = 10.0
master_conn = None
shard = shard_names[shard_index]
if vtgate_protocol == VTGATE_PROTOCOL_TABLET:
vtgate_addrs = []
elif vtgate_protocol == VTGATE_PROTOCOL_V1BSON:
vtgate_addrs = ["localhost:%s"%(vtgate_port),]
else:
raise Exception("Unknown vtgate_protocol %s", vtgate_protocol)
def get_master_connection(shard='1', user=None, password=None):
logging.debug("connecting to master with params")
vtgate_client = zkocc.ZkOccConnection("localhost:%u" % vtgate_port,
"test_nj", 30.0)
master_conn = vtclient.VtOCCConnection(vtgate_client, 'test_keyspace', shard,
"master", 10.0,
user=user, password=password)
"master", timeout,
user=user, password=password,
vtgate_protocol=vtgate_protocol,
vtgate_addrs=vtgate_addrs)
master_conn.connect()
return master_conn
def get_replica_connection(shard='1', user=None, password=None):
def get_replica_connection(shard_index=0, user=None, password=None):
global vtgate_protocol
global vtgate_port
logging.debug("connecting to replica with params %s %s", user, password)
timeout = 10.0
shard = shard_names[shard_index]
if vtgate_protocol == VTGATE_PROTOCOL_TABLET:
vtgate_addrs = []
elif vtgate_protocol == VTGATE_PROTOCOL_V1BSON:
vtgate_addrs = ["localhost:%s"%(vtgate_port),]
else:
raise Exception("Unknown vtgate_protocol %s", vtgate_protocol)
vtgate_client = zkocc.ZkOccConnection("localhost:%u" % vtgate_port,
"test_nj", 30.0)
replica_conn = vtclient.VtOCCConnection(vtgate_client, 'test_keyspace', shard,
"replica", 10.0,
user=user, password=password)
"replica", timeout,
user=user, password=password,
vtgate_protocol=vtgate_protocol,
vtgate_addrs=vtgate_addrs)
replica_conn.connect()
return replica_conn
@ -145,39 +196,56 @@ def do_write(count):
master_conn = get_master_connection()
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
kid_list = shard_kid_map[master_conn.shard]
for x in xrange(count):
master_conn._execute("insert into vt_insert_test (msg) values (%(msg)s)",
{'msg': 'test %s' % x})
keyspace_id = kid_list[count%len(kid_list)]
master_conn._execute("insert into vt_insert_test (msg, keyspace_id) values (%(msg)s, %(keyspace_id)s)",
{'msg': 'test %s' % x, 'keyspace_id': keyspace_id})
master_conn.commit()
class TestTabletFunctions(unittest.TestCase):
def setUp(self):
self.shard_index = 0
self.master_tablet = shard_0_master
self.replica_tablet = shard_0_replica
def test_connect(self):
global vtgate_protocol
try:
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
self.assertNotEqual(master_conn, None)
self.assertIsInstance(master_conn, vtclient.VtOCCConnection,
"Invalid master connection")
try:
replica_conn = get_replica_connection()
replica_conn = get_replica_connection(shard_index=self.shard_index)
except Exception, e:
logging.debug("Connection to shard0 replica failed with error %s" %
str(e))
logging.debug("Connection to %s replica failed with error %s" %
(shard_names[self.shard_index], str(e)))
raise
self.assertNotEqual(replica_conn, None)
self.assertIsInstance(replica_conn, vtclient.VtOCCConnection,
"Invalid replica connection")
if vtgate_protocol == VTGATE_PROTOCOL_TABLET:
self.assertIsInstance(master_conn.conn, tablet3.TabletConnection,
"Invalid master connection")
self.assertIsInstance(replica_conn.conn, tablet3.TabletConnection,
"Invalid replica connection")
elif vtgate_protocol == VTGATE_PROTOCOL_V1BSON:
self.assertIsInstance(master_conn.conn, vtgate.VtgateConnection,
"Invalid master connection")
self.assertIsInstance(replica_conn.conn, vtgate.VtgateConnection,
"Invalid replica connection")
def test_writes(self):
try:
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
count = 10
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
kid_list = shard_kid_map[master_conn.shard]
for x in xrange(count):
master_conn._execute("insert into vt_insert_test (msg) values (%(msg)s)", {'msg': 'test %s' % x})
keyspace_id = kid_list[count%len(kid_list)]
master_conn._execute("insert into vt_insert_test (msg, keyspace_id) values (%(msg)s, %(keyspace_id)s)",
{'msg': 'test %s' % x, 'keyspace_id': keyspace_id})
master_conn.commit()
results, rowcount = master_conn._execute("select * from vt_insert_test",
{})[:2]
@ -188,17 +256,23 @@ class TestTabletFunctions(unittest.TestCase):
def test_batch_read(self):
try:
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
count = 10
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
kid_list = shard_kid_map[master_conn.shard]
for x in xrange(count):
master_conn._execute("insert into vt_insert_test (msg) values (%(msg)s)", {'msg': 'test %s' % x})
keyspace_id = kid_list[count%len(kid_list)]
master_conn._execute("insert into vt_insert_test (msg, keyspace_id) values (%(msg)s, %(keyspace_id)s)",
{'msg': 'test %s' % x, 'keyspace_id': keyspace_id})
master_conn.commit()
master_conn.begin()
master_conn._execute("delete from vt_a", {})
for x in xrange(count):
master_conn._execute("insert into vt_a (eid, id) values (%(eid)s, %(id)s)", {'eid': x, 'id': x})
keyspace_id = kid_list[count%len(kid_list)]
master_conn._execute("insert into vt_a (eid, id, keyspace_id) \
values (%(eid)s, %(id)s, %(keyspace_id)s)",
{'eid': x, 'id': x, 'keyspace_id': keyspace_id})
master_conn.commit()
rowsets = master_conn._execute_batch(["select * from vt_insert_test",
"select * from vt_a"], [{}, {}])
@ -210,20 +284,23 @@ class TestTabletFunctions(unittest.TestCase):
def test_batch_write(self):
try:
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
count = 10
query_list = []
bind_vars_list = []
query_list.append("delete from vt_insert_test")
bind_vars_list.append({})
kid_list = shard_kid_map[master_conn.shard]
for x in xrange(count):
query_list.append("insert into vt_insert_test (msg) values (%(msg)s)")
bind_vars_list.append({'msg': 'test %s' % x})
keyspace_id = kid_list[count%len(kid_list)]
query_list.append("insert into vt_insert_test (msg, keyspace_id) values (%(msg)s, %(keyspace_id)s)")
bind_vars_list.append({'msg': 'test %s' % x, 'keyspace_id': keyspace_id})
query_list.append("delete from vt_a")
bind_vars_list.append({})
for x in xrange(count):
query_list.append("insert into vt_a (eid, id) values (%(eid)s, %(id)s)")
bind_vars_list.append({'eid': x, 'id': x})
keyspace_id = kid_list[count%len(kid_list)]
query_list.append("insert into vt_a (eid, id, keyspace_id) values (%(eid)s, %(id)s, %(keyspace_id)s)")
bind_vars_list.append({'eid': x, 'id': x, 'keyspace_id': keyspace_id})
master_conn.begin()
master_conn._execute_batch(query_list, bind_vars_list)
master_conn.commit()
@ -239,7 +316,7 @@ class TestTabletFunctions(unittest.TestCase):
count = 100
do_write(count)
# Fetch a subset of the total size.
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
stream_cursor = cursor.StreamCursor(master_conn)
stream_cursor.execute("select * from vt_insert_test", {})
fetch_size = 10
@ -257,7 +334,7 @@ class TestTabletFunctions(unittest.TestCase):
count = 100
do_write(count)
# Fetch all.
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
stream_cursor = cursor.StreamCursor(master_conn)
stream_cursor.execute("select * from vt_insert_test", {})
rows = stream_cursor.fetchall()
@ -274,7 +351,7 @@ class TestTabletFunctions(unittest.TestCase):
count = 100
do_write(count)
# Fetch one.
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
stream_cursor = cursor.StreamCursor(master_conn)
stream_cursor.execute("select * from vt_insert_test", {})
rows = stream_cursor.fetchone()
@ -285,7 +362,7 @@ class TestTabletFunctions(unittest.TestCase):
def test_streaming_zero_results(self):
try:
master_conn = get_master_connection()
master_conn = get_master_connection(shard_index=self.shard_index)
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
@ -302,30 +379,36 @@ class TestTabletFunctions(unittest.TestCase):
class TestFailures(unittest.TestCase):
def setUp(self):
self.shard_index = 0
self.master_tablet = shard_0_master
self.replica_tablet = shard_0_replica
def test_tablet_restart_read(self):
try:
replica_conn = get_replica_connection()
replica_conn = get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
shard_1_replica.kill_vttablet()
self.fail("Connection to shard %s replica failed with error %s" % (shard_names[self.shard_index], str(e)))
self.replica_tablet.kill_vttablet()
with self.assertRaises(dbexceptions.OperationalError):
replica_conn._execute("select 1 from vt_insert_test", {})
proc = shard_1_replica.start_vttablet()
proc = self.replica_tablet.start_vttablet()
try:
results = replica_conn._execute("select 1 from vt_insert_test", {})
except Exception, e:
self.fail("Communication with shard0 replica failed with error %s" % str(e))
self.fail("Communication with shard %s replica failed with error %s" % (shard_names[self.shard_index], str(e)))
def test_tablet_restart_stream_execute(self):
try:
replica_conn = get_replica_connection()
replica_conn = get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
self.fail("Connection to %s replica failed with error %s" % (shard_names[self.shard_index], str(e)))
stream_cursor = cursor.StreamCursor(replica_conn)
shard_1_replica.kill_vttablet()
self.replica_tablet.kill_vttablet()
with self.assertRaises(dbexceptions.OperationalError):
stream_cursor.execute("select * from vt_insert_test", {})
proc = shard_1_replica.start_vttablet()
proc = self.replica_tablet.start_vttablet()
self.replica_tablet.wait_for_vttablet_state('SERVING')
try:
stream_cursor.execute("select * from vt_insert_test", {})
except Exception, e:
@ -337,10 +420,10 @@ class TestFailures(unittest.TestCase):
master_conn = get_master_connection()
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
shard_1_master.kill_vttablet()
self.master_tablet.kill_vttablet()
with self.assertRaises(dbexceptions.OperationalError):
master_conn.begin()
proc = shard_1_master.start_vttablet()
proc = self.master_tablet.start_vttablet()
master_conn.begin()
def test_tablet_fail_write(self):
@ -350,19 +433,21 @@ class TestFailures(unittest.TestCase):
self.fail("Connection to shard0 master failed with error %s" % str(e))
with self.assertRaises(dbexceptions.OperationalError):
master_conn.begin()
shard_1_master.kill_vttablet()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
proc = shard_1_master.start_vttablet()
with self.assertRaises(dbexceptions.OperationalError):
master_conn.begin()
shard_1_master.kill_vttablet()
self.master_tablet.kill_vttablet()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
proc = self.master_tablet.start_vttablet()
try:
master_conn = get_master_connection()
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
def test_query_timeout(self):
try:
replica_conn = get_replica_connection()
replica_conn = get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
with self.assertRaises(dbexceptions.TimeoutError):
@ -377,15 +462,15 @@ class TestFailures(unittest.TestCase):
def test_restart_mysql_failure(self):
try:
replica_conn = get_replica_connection()
replica_conn = get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
utils.wait_procs([shard_1_replica.shutdown_mysql(),])
utils.wait_procs([self.replica_tablet.shutdown_mysql(),])
with self.assertRaises(dbexceptions.DatabaseError):
replica_conn._execute("select 1 from vt_insert_test", {})
utils.wait_procs([shard_1_replica.start_mysql(),])
shard_1_replica.kill_vttablet()
shard_1_replica.start_vttablet()
utils.wait_procs([self.replica_tablet.start_mysql(),])
self.replica_tablet.kill_vttablet()
self.replica_tablet.start_vttablet()
replica_conn._execute("select 1 from vt_insert_test", {})
def test_retry_txn_pool_full(self):
@ -404,8 +489,10 @@ class TestFailures(unittest.TestCase):
class TestAuthentication(unittest.TestCase):
def setUp(self):
shard_1_replica.kill_vttablet()
shard_1_replica.start_vttablet(auth=True)
self.shard_index = 0
self.replica_tablet = shard_0_replica
self.replica_tablet.kill_vttablet()
self.replica_tablet.start_vttablet(auth=True)
credentials_file_name = os.path.join(environment.vttop, 'test', 'test_data',
'authcredentials_test.json')
credentials_file = open(credentials_file_name, 'r')
@ -416,7 +503,7 @@ class TestAuthentication(unittest.TestCase):
def test_correct_credentials(self):
try:
replica_conn = get_replica_connection(user=self.user,
replica_conn = get_replica_connection(shard_index = self.shard_index, user=self.user,
password=self.password)
replica_conn.connect()
finally:
@ -424,7 +511,7 @@ class TestAuthentication(unittest.TestCase):
def test_secondary_credentials(self):
try:
replica_conn = get_replica_connection(user=self.user,
replica_conn = get_replica_connection(shard_index = self.shard_index, user=self.user,
password=self.secondary_password)
replica_conn.connect()
finally:
@ -432,16 +519,16 @@ class TestAuthentication(unittest.TestCase):
def test_incorrect_user(self):
with self.assertRaises(dbexceptions.OperationalError):
replica_conn = get_replica_connection(user="romek", password="ma raka")
replica_conn = get_replica_connection(shard_index = self.shard_index, user="romek", password="ma raka")
replica_conn.connect()
def test_incorrect_credentials(self):
with self.assertRaises(dbexceptions.OperationalError):
replica_conn = get_replica_connection(user=self.user, password="ma raka")
replica_conn = get_replica_connection(shard_index = self.shard_index, user=self.user, password="ma raka")
replica_conn.connect()
def test_challenge_is_used(self):
replica_conn = get_replica_connection(user=self.user,
replica_conn = get_replica_connection(shard_index = self.shard_index, user=self.user,
password=self.password)
replica_conn.connect()
challenge = ""
@ -451,7 +538,7 @@ class TestAuthentication(unittest.TestCase):
'AuthenticatorCRAMMD5.Authenticate', {"Proof": proof})
def test_only_few_requests_are_allowed(self):
replica_conn = get_replica_connection(user=self.user,
replica_conn = get_replica_connection(shard_index = self.shard_index, user=self.user,
password=self.password)
replica_conn.connect()
for i in range(4):

146
test/vtgate_test.py Executable file
Просмотреть файл

@ -0,0 +1,146 @@
#!/usr/bin/python
#
# Copyright 2013, Google Inc. All rights reserved.
# Use of this source code is governed by a BSD-style license that can
# be found in the LICENSE file.
import unittest
import vtdb_test
import utils
from vtdb import cursor
from vtdb import dbexceptions
def setUpModule():
vtdb_test.setUpModule()
def tearDownModule():
vtdb_test.tearDownModule()
class TestClientApi(vtdb_test.TestTabletFunctions):
pass
# FIXME(shrutip): this class needs reworking once
# the error handling is resolved the right way at vtgate binary.
class TestFailures(unittest.TestCase):
def setUp(self):
self.shard_index = 0
self.master_tablet = vtdb_test.shard_0_master
self.replica_tablet = vtdb_test.shard_0_replica
def test_tablet_restart_read(self):
try:
replica_conn = vtdb_test.get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard %s replica failed with error %s" % (shard_names[self.shard_index], str(e)))
self.replica_tablet.kill_vttablet()
with self.assertRaises(dbexceptions.DatabaseError):
replica_conn._execute("select 1 from vt_insert_test", {})
proc = self.replica_tablet.start_vttablet()
try:
results = replica_conn._execute("select 1 from vt_insert_test", {})
except Exception, e:
self.fail("Communication with shard %s replica failed with error %s" % (shard_names[self.shard_index], str(e)))
def test_tablet_restart_stream_execute(self):
try:
replica_conn = vtdb_test.get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
stream_cursor = cursor.StreamCursor(replica_conn)
self.replica_tablet.kill_vttablet()
# FIXME(shrutip): this sometimes throws a TimeoutError but catching
# DatabaseError as that is a superclass anyways.
with self.assertRaises(dbexceptions.DatabaseError):
stream_cursor.execute("select * from vt_insert_test", {})
proc = self.replica_tablet.start_vttablet()
try:
# This goes through a reconnect loop since connection to vtgate is closed
# by the timeout error above.
stream_cursor.execute("select * from vt_insert_test", {})
except Exception, e:
self.fail("Communication with shard0 replica failed with error %s" %
str(e))
# vtgate begin doesn't make any back-end connections to
# vttablet so the kill and restart shouldn't have any effect.
def test_tablet_restart_begin(self):
try:
master_conn = vtdb_test.get_master_connection()
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
self.master_tablet.kill_vttablet()
master_conn.begin()
proc = self.master_tablet.start_vttablet()
master_conn.begin()
def test_tablet_fail_write(self):
try:
master_conn = vtdb_test.get_master_connection()
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
with self.assertRaises(dbexceptions.DatabaseError):
master_conn.begin()
self.master_tablet.kill_vttablet()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
proc = self.master_tablet.start_vttablet()
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
def test_query_timeout(self):
try:
replica_conn = vtdb_test.get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
with self.assertRaises(dbexceptions.TimeoutError):
replica_conn._execute("select sleep(12) from dual", {})
try:
master_conn = vtdb_test.get_master_connection()
except Exception, e:
self.fail("Connection to shard0 master failed with error %s" % str(e))
with self.assertRaises(dbexceptions.TimeoutError):
master_conn._execute("select sleep(12) from dual", {})
# FIXME(shrutip): flaky test, making it NOP for now
def test_restart_mysql_failure(self):
return
try:
replica_conn = vtdb_test.get_replica_connection(shard_index=self.shard_index)
except Exception, e:
self.fail("Connection to shard0 replica failed with error %s" % str(e))
utils.wait_procs([self.replica_tablet.shutdown_mysql(),])
with self.assertRaises(dbexceptions.DatabaseError):
replica_conn._execute("select 1 from vt_insert_test", {})
utils.wait_procs([self.replica_tablet.start_mysql(),])
self.replica_tablet.kill_vttablet()
self.replica_tablet.start_vttablet()
self.replica_tablet.wait_for_vttablet_state('SERVING')
replica_conn._execute("select 1 from vt_insert_test", {})
# FIXME(shrutip): this test is basically just testing that
# txn pool full error doesn't get thrown anymore with vtgate.
# vtgate retries for this condition. Not a very high value
# test at this point, could be removed if there is coverage at vtgate level.
def test_retry_txn_pool_full(self):
master_conn = vtdb_test.get_master_connection()
master_conn._execute("set vt_transaction_cap=1", {})
master_conn.begin()
master_conn2 = vtdb_test.get_master_connection()
master_conn2.begin()
master_conn.commit()
master_conn._execute("set vt_transaction_cap=20", {})
master_conn.begin()
master_conn._execute("delete from vt_insert_test", {})
master_conn.commit()
# this test is just re-running an entire vtdb_test.py with a
# client type VTGate
if __name__ == '__main__':
vtdb_test.vtgate_protocol = vtdb_test.VTGATE_PROTOCOL_V1BSON
utils.main()