Merge branch 'master' into gtid

This commit is contained in:
Anthony Yeh 2014-07-28 14:46:14 -07:00
Родитель 150d8da3c3 2205f4ef40
Коммит 8206436dee
54 изменённых файлов: 2636 добавлений и 1713 удалений

Просмотреть файл

@ -1,13 +1,64 @@
create table a(abcd)#{"Action": "create", "NewName": "a"}
drop table b#{"Action": "drop", "TableName": "b"}
alter table c alter foo#{"Action": "alter", "TableName": "c", "NewTable": "c"}
alter table c comment 'aa'#{"Action": "alter", "TableName": "c", "NewTable": "c"}
drop index a on b#{"Action": "alter", "TableName": "b", "NewName": "b"}
rename table a to b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
alter table a rename b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
alter table a rename to b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
create view a asdasd#{"Action": "create", "NewName": "a"}
alter view c alter foo#{"Action": "alter", "TableName": "c", "NewTable": "c"}
drop view b#{"Action": "drop", "TableName": "b"}
select * from a#{"Action": ""}
syntax error#{"Action": ""}
"create table a(abcd)"
{
"Action": "create", "NewName": "a"
}
"drop table b"
{
"Action": "drop", "TableName": "b"
}
"alter table c alter foo"
{
"Action": "alter", "TableName": "c", "NewTable": "c"
}
"alter table c comment 'aa'"
{
"Action": "alter", "TableName": "c", "NewTable": "c"
}
"drop index a on b"
{
"Action": "alter", "TableName": "b", "NewName": "b"
}
"rename table a to b"
{
"Action": "rename", "TableName": "a", "NewTable": "b"
}
"alter table a rename b"
{
"Action": "rename", "TableName": "a", "NewTable": "b"
}
"alter table a rename to b"
{
"Action": "rename", "TableName": "a", "NewTable": "b"
}
"create view a asdasd"
{
"Action": "create", "NewName": "a"
}
"alter view c alter foo"
{
"Action": "alter", "TableName": "c", "NewTable": "c"
}
"drop view b"
{
"Action": "drop", "TableName": "b"
}
"select * from a"
{
"Action": ""
}
"syntax error"
{
"Action": ""
}

Просмотреть файл

@ -2057,7 +2057,7 @@
{
"PlanId":"DDL",
"Reason":"DEFAULT",
"TableName":"",
"TableName":"a",
"FieldQuery":null,
"FullQuery":null,
"OuterQuery":null,
@ -2076,7 +2076,7 @@
{
"PlanId":"DDL",
"Reason":"DEFAULT",
"TableName":"",
"TableName":"a",
"FieldQuery":null,
"FullQuery":null,
"OuterQuery":null,
@ -2095,7 +2095,7 @@
{
"PlanId":"DDL",
"Reason":"DEFAULT",
"TableName":"",
"TableName":"a",
"FieldQuery":null,
"FullQuery":null,
"OuterQuery":null,
@ -2114,7 +2114,7 @@
{
"PlanId":"DDL",
"Reason":"DEFAULT",
"TableName":"",
"TableName":"a",
"FieldQuery":null,
"FullQuery":null,
"OuterQuery":null,

Просмотреть файл

@ -1,7 +1,39 @@
# select
"select * from a"
{
"FullQuery": "select * from a"
"PlanId":"PASS_SELECT",
"Reason":"DEFAULT",
"TableName":"a",
"FieldQuery":null,
"FullQuery":"select * from a",
"OuterQuery":null,
"Subquery":null,
"IndexUsed":"",
"ColumnNumbers":null,
"PKValues":null,
"SecondaryPKValues":null,
"SubqueryPKColumns":null,
"SetKey":"",
"SetValue":null
}
# select join
"select * from a join b"
{
"PlanId":"PASS_SELECT",
"Reason":"DEFAULT",
"TableName":"",
"FieldQuery":null,
"FullQuery":"select * from a join b",
"OuterQuery":null,
"Subquery":null,
"IndexUsed":"",
"ColumnNumbers":null,
"PKValues":null,
"SecondaryPKValues":null,
"SubqueryPKColumns":null,
"SetKey":"",
"SetValue":null
}
# select for update
@ -11,7 +43,20 @@
# union
"select * from a union select * from b"
{
"FullQuery": "select * from a union select * from b"
"PlanId":"PASS_SELECT",
"Reason":"DEFAULT",
"TableName":"",
"FieldQuery":null,
"FullQuery": "select * from a union select * from b",
"OuterQuery":null,
"Subquery":null,
"IndexUsed":"",
"ColumnNumbers":null,
"PKValues":null,
"SecondaryPKValues":null,
"SubqueryPKColumns":null,
"SetKey":"",
"SetValue":null
}
# dml

Просмотреть файл

@ -13,6 +13,7 @@ import (
"github.com/youtube/vitess/go/vt/dbconfigs"
"github.com/youtube/vitess/go/vt/mysqlctl"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tabletserver"
)
@ -21,6 +22,7 @@ var (
enableRowcache = flag.Bool("enable-rowcache", false, "enable rowcacche")
enableInvalidator = flag.Bool("enable-invalidator", false, "enable rowcache invalidator")
binlogPath = flag.String("binlog-path", "", "binlog path used by rowcache invalidator")
tableAclConfig = flag.String("table-acl-config", "", "path to table access checker config file")
)
var schemaOverrides []tabletserver.SchemaOverride
@ -55,6 +57,9 @@ func main() {
data, _ := json.MarshalIndent(schemaOverrides, "", " ")
log.Infof("schemaOverrides: %s\n", data)
if *tableAclConfig != "" {
tableacl.Init(*tableAclConfig)
}
tabletserver.InitQueryService()
err = tabletserver.AllowQueries(&dbConfigs.App, schemaOverrides, tabletserver.LoadCustomRules(), mysqld, true)

Просмотреть файл

@ -14,6 +14,7 @@ import (
"github.com/youtube/vitess/go/vt/dbconfigs"
"github.com/youtube/vitess/go/vt/mysqlctl"
"github.com/youtube/vitess/go/vt/servenv"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tabletmanager"
"github.com/youtube/vitess/go/vt/tabletserver"
"github.com/youtube/vitess/go/vt/topo"
@ -23,6 +24,7 @@ var (
tabletPath = flag.String("tablet-path", "", "tablet alias or path to zk node representing the tablet")
enableRowcache = flag.Bool("enable-rowcache", false, "enable rowcacche")
overridesFile = flag.String("schema-override", "", "schema overrides file")
tableAclConfig = flag.String("table-acl-config", "", "path to table access checker config file")
agent *tabletmanager.ActionAgent
)
@ -77,6 +79,9 @@ func main() {
}
dbcfgs.App.EnableRowcache = *enableRowcache
if *tableAclConfig != "" {
tableacl.Init(*tableAclConfig)
}
tabletserver.InitQueryService()
binlog.RegisterUpdateStreamService(mycnf)

Просмотреть файл

@ -28,13 +28,13 @@ type Histogram struct {
// NewHistogram creates a histogram with auto-generated labels
// based on the cutoffs. The buckets are categorized using the
// following criterion: cutoff[i-1] < value <= cutoff[i]. Anything
// higher than the highest cutoff is labeled as "Max".
// higher than the highest cutoff is labeled as "inf".
func NewHistogram(name string, cutoffs []int64) *Histogram {
labels := make([]string, len(cutoffs)+1)
for i, v := range cutoffs {
labels[i] = fmt.Sprintf("%d", v)
}
labels[len(labels)-1] = "Max"
labels[len(labels)-1] = "inf"
return NewGenericHistogram(name, cutoffs, labels, "Count", "Total")
}
@ -85,8 +85,8 @@ func (h *Histogram) MarshalJSON() ([]byte, error) {
fmt.Fprintf(b, "{")
totalCount := int64(0)
for i, label := range h.labels {
fmt.Fprintf(b, "\"%v\": %v, ", label, h.buckets[i])
totalCount += h.buckets[i]
fmt.Fprintf(b, "\"%v\": %v, ", label, totalCount)
}
fmt.Fprintf(b, "\"%s\": %v, ", h.countLabel, totalCount)
fmt.Fprintf(b, "\"%s\": %v", h.totalLabel, h.total)
@ -128,3 +128,13 @@ func (h *Histogram) Total() (total int64) {
defer h.mu.Unlock()
return h.total
}
// Labels returns the histogram's bucket labels.
// The labels slice is set once at construction and never mutated
// afterwards (see NewHistogram), so no locking is needed here.
func (h *Histogram) Labels() []string {
	return h.labels
}

// Buckets returns a snapshot of the current per-bucket counts.
// A copy is returned: handing out the internal slice under the mutex is
// not enough, because the lock is released on return and callers would
// then race with concurrent writers mutating h.buckets.
func (h *Histogram) Buckets() []int64 {
	h.mu.Lock()
	defer h.mu.Unlock()
	buckets := make([]int64, len(h.buckets))
	copy(buckets, h.buckets)
	return buckets
}

Просмотреть файл

@ -15,7 +15,7 @@ func TestHistogram(t *testing.T) {
for i := 0; i < 10; i++ {
h.Add(int64(i))
}
want := `{"1": 2, "5": 4, "Max": 4, "Count": 10, "Total": 45}`
want := `{"1": 2, "5": 6, "inf": 10, "Count": 10, "Total": 45}`
if h.String() != want {
t.Errorf("want %s, got %s", want, h.String())
}
@ -26,8 +26,8 @@ func TestHistogram(t *testing.T) {
if counts["5"] != 4 {
t.Errorf("want 4, got %d", counts["2"])
}
if counts["Max"] != 4 {
t.Errorf("want 4, got %d", counts["Max"])
if counts["inf"] != 4 {
t.Errorf("want 4, got %d", counts["inf"])
}
if h.Count() != 10 {
t.Errorf("want 10, got %d", h.Count())

Просмотреть файл

@ -118,9 +118,9 @@ var bucketLabels []string
func init() {
bucketLabels = make([]string, len(bucketCutoffs)+1)
for i, v := range bucketCutoffs {
bucketLabels[i] = fmt.Sprintf("%.4f", float64(v)/1e9)
bucketLabels[i] = fmt.Sprintf("%d", v)
}
bucketLabels[len(bucketLabels)-1] = "Max"
bucketLabels[len(bucketLabels)-1] = "inf"
}
// MultiTimings is meant to tracks timing data by categories as well

Просмотреть файл

@ -16,7 +16,7 @@ func TestTimings(t *testing.T) {
tm.Add("tag1", 500*time.Microsecond)
tm.Add("tag1", 1*time.Millisecond)
tm.Add("tag2", 1*time.Millisecond)
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1":{"0.0005":1,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":2,"Time":1500000},"tag2":{"0.0005":0,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":1,"Time":1000000}}}`
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1":{"500000":1,"1000000":2,"5000000":2,"10000000":2,"50000000":2,"100000000":2,"500000000":2,"1000000000":2,"5000000000":2,"10000000000":2,"inf":2,"Count":2,"Time":1500000},"tag2":{"500000":0,"1000000":1,"5000000":1,"10000000":1,"50000000":1,"100000000":1,"500000000":1,"1000000000":1,"5000000000":1,"10000000000":1,"inf":1,"Count":1,"Time":1000000}}}`
if tm.String() != want {
t.Errorf("want %s, got %s", want, tm.String())
}
@ -28,7 +28,7 @@ func TestMultiTimings(t *testing.T) {
mtm.Add([]string{"tag1a", "tag1b"}, 500*time.Microsecond)
mtm.Add([]string{"tag1a", "tag1b"}, 1*time.Millisecond)
mtm.Add([]string{"tag2a", "tag2b"}, 1*time.Millisecond)
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1a.tag1b":{"0.0005":1,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":2,"Time":1500000},"tag2a.tag2b":{"0.0005":0,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":1,"Time":1000000}}}`
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1a.tag1b":{"500000":1,"1000000":2,"5000000":2,"10000000":2,"50000000":2,"100000000":2,"500000000":2,"1000000000":2,"5000000000":2,"10000000000":2,"inf":2,"Count":2,"Time":1500000},"tag2a.tag2b":{"500000":0,"1000000":1,"5000000":1,"10000000":1,"50000000":1,"100000000":1,"500000000":1,"1000000000":1,"5000000000":1,"10000000000":1,"inf":1,"Count":1,"Time":1000000}}}`
if mtm.String() != want {
t.Errorf("want %s, got %s", want, mtm.String())
}

Просмотреть файл

@ -15,8 +15,8 @@ import (
mproto "github.com/youtube/vitess/go/mysql/proto"
"github.com/youtube/vitess/go/vt/client2/tablet"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/topo"
"github.com/youtube/vitess/go/vt/vtgate"
"github.com/youtube/vitess/go/vt/zktopo"
"github.com/youtube/vitess/go/zk"
)
@ -233,7 +233,7 @@ func (sc *ShardedConn) Exec(query string, bindVars map[string]interface{}) (db.R
if sc.srvKeyspace == nil {
return nil, ErrNotConnected
}
shards, err := sqlparser.GetShardList(query, bindVars, sc.shardMaxKeys)
shards, err := vtgate.GetShardList(query, bindVars, sc.shardMaxKeys)
if err != nil {
return nil, err
}

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -6,23 +6,12 @@ package sqlparser
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/youtube/vitess/go/testfiles"
"io"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"sort"
"strings"
"testing"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/testfiles"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/schema"
)
func TestGen(t *testing.T) {
@ -32,118 +21,6 @@ func TestGen(t *testing.T) {
}
}
var (
// SQLZERO is the literal value "0", used as a stand-in bind value.
SQLZERO = sqltypes.MakeString([]byte("0"))
)
// TestExec runs every case in exec_cases.txt against ExecParse using the
// schema from schema_test.json and compares the JSON-marshalled plan
// (or the error string) with the expected output recorded in the file.
func TestExec(t *testing.T) {
testSchema := loadSchema("schema_test.json")
for tcase := range iterateExecFile("exec_cases.txt") {
plan, err := ExecParse(tcase.input, func(name string) (*schema.Table, bool) {
r, ok := testSchema[name]
return r, ok
})
var out string
if err != nil {
// On parse failure the expected output is the error message itself.
out = err.Error()
} else {
bout, err := json.Marshal(plan)
if err != nil {
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
}
out = string(bout)
}
if out != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
}
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
}
}
// TestCustomExec discovers additional schema/case-file pairs under
// sqlparser_test/ (*_schema.json plus sibling *.txt files) and runs the
// same ExecParse comparison as TestExec for each pair.
func TestCustomExec(t *testing.T) {
testSchemas := testfiles.Glob("sqlparser_test/*_schema.json")
if len(testSchemas) == 0 {
t.Log("No schemas to test")
return
}
for _, schemFile := range testSchemas {
schem := loadSchema(schemFile)
t.Logf("Testing schema %s", schemFile)
// Case files share the schema file's prefix, e.g. foo_schema.json -> foo_*.txt.
files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1))
if err != nil {
log.Fatal(err)
}
if len(files) == 0 {
t.Fatalf("No test files for %s", schemFile)
}
getter := func(name string) (*schema.Table, bool) {
r, ok := schem[name]
return r, ok
}
for _, file := range files {
t.Logf("Testing file %s", file)
for tcase := range iterateExecFile(file) {
plan, err := ExecParse(tcase.input, getter)
var out string
if err != nil {
out = err.Error()
} else {
bout, err := json.Marshal(plan)
if err != nil {
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
}
out = string(bout)
}
if out != tcase.output {
t.Errorf("File: %s: Line:%v\n%s\n%s", file, tcase.lineno, tcase.output, out)
}
}
}
}
}
// TestStreamExec runs the streaming-plan cases in stream_cases.txt
// through StreamExecParse and compares the marshalled plan or error text.
func TestStreamExec(t *testing.T) {
for tcase := range iterateExecFile("stream_cases.txt") {
plan, err := StreamExecParse(tcase.input)
var out string
if err != nil {
out = err.Error()
} else {
bout, err := json.Marshal(plan)
if err != nil {
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
}
out = string(bout)
}
if out != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
}
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
}
}
// TestDDL checks DDLParse against ddl_cases.txt: each expected output is
// a JSON object whose Action/TableName/NewName fields (when present)
// must match the parsed plan.
func TestDDL(t *testing.T) {
for tcase := range iterateFiles("sqlparser_test/ddl_cases.txt") {
plan := DDLParse(tcase.input)
expected := make(map[string]interface{})
err := json.Unmarshal([]byte(tcase.output), &expected)
if err != nil {
// NOTE(review): this is an unmarshal failure of the expected output,
// but the message says "marshalling" — misleading when it fires.
panic(fmt.Sprintf("Error marshalling %v", plan))
}
matchString(t, tcase.lineno, expected["Action"], plan.Action)
matchString(t, tcase.lineno, expected["TableName"], plan.TableName)
matchString(t, tcase.lineno, expected["NewName"], plan.NewName)
}
}
// matchString fails the test when expected (interpreted as a string,
// if non-nil) differs from actual. A nil expected means "no check".
func matchString(t *testing.T, line int, expected interface{}, actual string) {
	if expected == nil {
		return
	}
	if expected.(string) != actual {
		t.Errorf("Line %d: expected: %v, received %s", line, expected, actual)
	}
}
func TestParse(t *testing.T) {
for tcase := range iterateFiles("sqlparser_test/*.sql") {
if tcase.output == "" {
@ -162,47 +39,6 @@ func TestParse(t *testing.T) {
}
}
func TestRouting(t *testing.T) {
tabletkeys := []key.KeyspaceId{
"\x00\x00\x00\x00\x00\x00\x00\x02",
"\x00\x00\x00\x00\x00\x00\x00\x04",
"\x00\x00\x00\x00\x00\x00\x00\x06",
"a",
"b",
"d",
}
bindVariables := make(map[string]interface{})
bindVariables["id0"] = 0
bindVariables["id2"] = 2
bindVariables["id3"] = 3
bindVariables["id4"] = 4
bindVariables["id6"] = 6
bindVariables["id8"] = 8
bindVariables["ids"] = []interface{}{1, 4}
bindVariables["a"] = "a"
bindVariables["b"] = "b"
bindVariables["c"] = "c"
bindVariables["d"] = "d"
bindVariables["e"] = "e"
for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
if tcase.output == "" {
tcase.output = tcase.input
}
out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
if err != nil {
if err.Error() != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
}
continue
}
sort.Ints(out)
outstr := fmt.Sprintf("%v", out)
if outstr != tcase.output {
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
}
}
}
func BenchmarkParse1(b *testing.B) {
sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'"
for i := 0; i < b.N; i++ {
@ -223,23 +59,6 @@ func BenchmarkParse2(b *testing.B) {
}
}
// loadSchema reads a JSON array of schema.Table definitions from the
// named test file (resolved via locateFile) and indexes them by table
// name. Any read or parse failure panics, which is acceptable in test
// setup code.
func loadSchema(name string) map[string]*schema.Table {
b, err := ioutil.ReadFile(locateFile(name))
if err != nil {
panic(err)
}
tables := make([]*schema.Table, 0, 8)
err = json.Unmarshal(b, &tables)
if err != nil {
panic(err)
}
s := make(map[string]*schema.Table)
for _, t := range tables {
s[t.Name] = t
}
return s
}
type testCase struct {
file string
lineno int
@ -284,71 +103,3 @@ func iterateFiles(pattern string) (testCaseIterator chan testCase) {
}()
return testCaseIterator
}
// iterateExecFile streams test cases from a file in the exec-case
// format: a JSON-quoted input line followed by its expected output,
// which is either a multi-line JSON object terminated by a line starting
// with '}' or a single quoted line. Cases are delivered on the returned
// channel from a goroutine; the channel is closed at EOF. Malformed
// files panic rather than returning an error.
func iterateExecFile(name string) (testCaseIterator chan testCase) {
name = locateFile(name)
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
panic(fmt.Sprintf("Could not open file %s", name))
}
testCaseIterator = make(chan testCase)
go func() {
defer close(testCaseIterator)
r := bufio.NewReader(fd)
lineno := 0
for {
binput, err := r.ReadBytes('\n')
if err != nil {
if err != io.EOF {
fmt.Printf("Line: %d\n", lineno)
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
}
break
}
lineno++
input := string(binput)
// Skip blanks, comments and "Length:" annotation lines.
if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") {
//fmt.Printf("%s\n", input)
continue
}
// The input line is a JSON string literal; unquote it.
err = json.Unmarshal(binput, &input)
if err != nil {
fmt.Printf("Line: %d, input: %s\n", lineno, binput)
panic(err)
}
input = strings.Trim(input, "\"")
var output []byte
// Accumulate the expected output until its terminator line.
for {
l, err := r.ReadBytes('\n')
lineno++
if err != nil {
fmt.Printf("Line: %d\n", lineno)
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
}
output = append(output, l...)
if l[0] == '}' {
// End of a JSON-object output: drop the trailing newline and
// compact the JSON so comparisons are whitespace-insensitive.
output = output[:len(output)-1]
b := bytes.NewBuffer(make([]byte, 0, 64))
if err := json.Compact(b, output); err == nil {
output = b.Bytes()
}
break
}
if l[0] == '"' {
// Single quoted-line output: strip the surrounding quotes.
output = output[1 : len(output)-2]
break
}
}
testCaseIterator <- testCase{name, lineno, input, string(output)}
}
}()
return testCaseIterator
}
// locateFile resolves name to a usable path: absolute names are returned
// unchanged, relative names are looked up under the sqlparser_test data
// directory.
func locateFile(name string) string {
	if !path.IsAbs(name) {
		return testfiles.Locate("sqlparser_test/" + name)
	}
	return name
}

Просмотреть файл

@ -1,254 +0,0 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sqlparser
import (
"strconv"
"github.com/youtube/vitess/go/vt/key"
)
// Node classification used by the routing analysis below: an operand is
// the sharding column ("entity_id"), a single value, a list of values,
// or anything else (which disables routing on that expression).
const (
EID_NODE = iota
VALUE_NODE
LIST_NODE
OTHER_NODE
)
// RoutingPlan holds the single AST node (if any) that constrains which
// shards a statement must be sent to. A nil criteria means "all shards".
type RoutingPlan struct {
criteria SQLNode
}
// GetShardList parses sql and returns the indexes (into tabletKeys) of
// the shards the statement must be routed to. Panics raised during
// parsing/analysis are converted to err by handleError.
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
defer handleError(&err)
plan := buildPlan(sql)
return shardListFromPlan(plan, bindVariables, tabletKeys), nil
}
// buildPlan parses sql and derives its routing plan; parse errors panic
// (recovered by GetShardList's handleError).
func buildPlan(sql string) (plan *RoutingPlan) {
statement, err := Parse(sql)
if err != nil {
panic(err)
}
return getRoutingPlan(statement)
}
// shardListFromPlan converts the plan's criteria into concrete shard
// indexes. No criteria, or an unsupported operator, yields every shard.
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int) {
if plan.criteria == nil {
return makeList(0, len(tabletKeys))
}
switch criteria := plan.criteria.(type) {
case Values:
// INSERT: all rows resolved to a single shard by findInsertShard.
index := findInsertShard(criteria, bindVariables, tabletKeys)
return []int{index}
case *ComparisonExpr:
switch criteria.Operator {
case "=", "<=>":
index := findShard(criteria.Right, bindVariables, tabletKeys)
return []int{index}
case "<", "<=":
// Everything up to and including the shard holding the bound.
index := findShard(criteria.Right, bindVariables, tabletKeys)
return makeList(0, index+1)
case ">", ">=":
// The shard holding the bound and everything after it.
index := findShard(criteria.Right, bindVariables, tabletKeys)
return makeList(index, len(tabletKeys))
case "in":
return findShardList(criteria.Right, bindVariables, tabletKeys)
}
case *RangeCond:
if criteria.Operator == "between" {
start := findShard(criteria.From, bindVariables, tabletKeys)
last := findShard(criteria.To, bindVariables, tabletKeys)
// BETWEEN bounds may arrive reversed; normalize the range.
if last < start {
start, last = last, start
}
return makeList(start, last+1)
}
}
return makeList(0, len(tabletKeys))
}
// getRoutingPlan extracts the routing criteria from a statement: the
// inserted values for an INSERT (recursing into INSERT ... SELECT), or
// the WHERE clause for SELECT/UPDATE/DELETE. Other statements produce
// an empty plan (all shards).
func getRoutingPlan(statement Statement) (plan *RoutingPlan) {
plan = &RoutingPlan{}
if ins, ok := statement.(*Insert); ok {
if sel, ok := ins.Rows.(SelectStatement); ok {
return getRoutingPlan(sel)
}
plan.criteria = routingAnalyzeValues(ins.Rows.(Values))
return plan
}
var where *Where
switch stmt := statement.(type) {
case *Select:
where = stmt.Where
case *Update:
where = stmt.Where
case *Delete:
where = stmt.Where
}
if where != nil {
plan.criteria = routingAnalyzeBoolean(where.Expr)
}
return plan
}
// routingAnalyzeValues checks that every inserted row is a tuple whose
// first column (the routing column) is a plain value; anything more
// complex panics with a ParserError.
func routingAnalyzeValues(vals Values) Values {
// Analyze first value of every item in the list
for i := 0; i < len(vals); i++ {
switch tuple := vals[i].(type) {
case ValTuple:
result := routingAnalyzeValue(tuple[0])
if result != VALUE_NODE {
panic(NewParserError("insert is too complex"))
}
default:
panic(NewParserError("insert is too complex"))
}
}
return vals
}
// routingAnalyzeBoolean walks a boolean expression and returns the one
// sub-expression usable as a routing criteria, or nil when there is
// none — or more than one candidate under an AND (ambiguous, so no
// routing constraint is used).
func routingAnalyzeBoolean(node BoolExpr) BoolExpr {
switch node := node.(type) {
case *AndExpr:
left := routingAnalyzeBoolean(node.Left)
right := routingAnalyzeBoolean(node.Right)
// Two candidates: give up rather than pick one arbitrarily.
if left != nil && right != nil {
return nil
} else if left != nil {
return left
} else {
return right
}
case *ParenBoolExpr:
return routingAnalyzeBoolean(node.Expr)
case *ComparisonExpr:
switch {
case StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
left := routingAnalyzeValue(node.Left)
right := routingAnalyzeValue(node.Right)
// Usable only as entity-id-vs-value (either order).
if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
return node
}
case node.Operator == "in":
left := routingAnalyzeValue(node.Left)
right := routingAnalyzeValue(node.Right)
if left == EID_NODE && right == LIST_NODE {
return node
}
}
case *RangeCond:
if node.Operator != "between" {
return nil
}
left := routingAnalyzeValue(node.Left)
from := routingAnalyzeValue(node.From)
to := routingAnalyzeValue(node.To)
if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
return node
}
}
return nil
}
// routingAnalyzeValue classifies a value expression for routing: the
// "entity_id" column, a tuple of plain values, a single literal or bind
// variable, or OTHER_NODE for anything else.
func routingAnalyzeValue(valExpr ValExpr) int {
switch node := valExpr.(type) {
case *ColName:
if string(node.Name) == "entity_id" {
return EID_NODE
}
case ValTuple:
for _, n := range node {
if routingAnalyzeValue(n) != VALUE_NODE {
return OTHER_NODE
}
}
return LIST_NODE
case StrVal, NumVal, ValArg:
return VALUE_NODE
}
return OTHER_NODE
}
// findShardList resolves each element of an IN-list to a shard index
// and returns the de-duplicated set (order is unspecified).
func findShardList(valExpr ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) []int {
shardset := make(map[int]bool)
switch node := valExpr.(type) {
case ValTuple:
for _, n := range node {
index := findShard(n, bindVariables, tabletKeys)
shardset[index] = true
}
}
shardlist := make([]int, len(shardset))
index := 0
for k := range shardset {
shardlist[index] = k
index++
}
return shardlist
}
// findInsertShard resolves the first column of every inserted row to a
// shard and requires them all to agree; a multi-shard insert panics.
func findInsertShard(vals Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) int {
index := -1
for i := 0; i < len(vals); i++ {
first_value_expression := vals[i].(ValTuple)[0]
newIndex := findShard(first_value_expression, bindVariables, tabletKeys)
if index == -1 {
index = newIndex
} else if index != newIndex {
panic(NewParserError("insert has multiple shard targets"))
}
}
return index
}
// findShard maps a value expression (resolving bind variables) to the
// index of the shard owning that keyspace-id value.
func findShard(valExpr ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) int {
value := getBoundValue(valExpr, bindVariables)
return key.FindShardForValue(value, tabletKeys)
}
// getBoundValue converts a literal or bind variable into the string
// form used for keyspace-id comparison; numbers are encoded as
// Uint64Key. Unsupported expressions panic.
func getBoundValue(valExpr ValExpr, bindVariables map[string]interface{}) string {
switch node := valExpr.(type) {
case ValTuple:
if len(node) != 1 {
panic(NewParserError("tuples not allowed as insert values"))
}
// TODO: Change parser to create single value tuples into non-tuples.
return getBoundValue(node[0], bindVariables)
case StrVal:
return string(node)
case NumVal:
val, err := strconv.ParseInt(string(node), 10, 64)
if err != nil {
panic(NewParserError("%s", err.Error()))
}
return key.Uint64Key(val).String()
case ValArg:
value := findBindValue(node, bindVariables)
return key.EncodeValue(value)
}
panic("Unexpected token")
}
// findBindValue looks up a bind variable (":name" — the leading byte is
// stripped) and panics with a ParserError when it is missing.
func findBindValue(valArg ValArg, bindVariables map[string]interface{}) interface{} {
if bindVariables == nil {
panic(NewParserError("No bind variable for " + string(valArg)))
}
value, ok := bindVariables[string(valArg[1:])]
if !ok {
panic(NewParserError("No bind variable for " + string(valArg)))
}
return value
}
// makeList returns the consecutive integers in the half-open interval
// [start, end).
func makeList(start, end int) []int {
	list := make([]int, end-start)
	for i := range list {
		list[i] = start + i
	}
	return list
}

Просмотреть файл

@ -9,26 +9,6 @@ import (
"fmt"
)
// ParserError: To be deprecated.
// TODO(sougou): deprecate.
type ParserError struct {
Message string
}
func NewParserError(format string, args ...interface{}) ParserError {
return ParserError{fmt.Sprintf(format, args...)}
}
func (err ParserError) Error() string {
return err.Message
}
func handleError(err *error) {
if x := recover(); x != nil {
*err = x.(ParserError)
}
}
// TrackedBuffer is used to rebuild a query from the ast.
// bindLocations keeps track of locations in the buffer that
// use bind variables for efficient future substitutions.
@ -116,3 +96,7 @@ func (buf *TrackedBuffer) WriteArg(arg string) {
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
return &ParsedQuery{buf.String(), buf.bindLocations}
}
// HasBindVars reports whether any bind-variable locations were recorded
// while the query was being built.
func (buf *TrackedBuffer) HasBindVars() bool {
	return len(buf.bindLocations) > 0
}

41
go/vt/tableacl/role.go Normal file
Просмотреть файл

@ -0,0 +1,41 @@
package tableacl
import "strings"
// Role describes how much access a principal has on a table.
type Role int

const (
	// READER may issue SELECT statements.
	READER Role = iota
	// WRITER may issue SELECT, INSERT and UPDATE statements.
	WRITER
	// ADMIN may issue any statement, DDLs included.
	ADMIN
	// NumRoles counts the roles defined above.
	NumRoles
)

var roleNames = []string{
	"READER",
	"WRITER",
	"ADMIN",
}

// Name returns the printable name of a role, or "" for an out-of-range
// value.
func (r Role) Name() string {
	if READER <= r && r <= ADMIN {
		return roleNames[r]
	}
	return ""
}

// RoleByName maps a role name (case-insensitive) back to its Role. The
// second return value reports whether the name was recognized; on
// failure the returned Role is NumRoles.
func RoleByName(s string) (Role, bool) {
	upper := strings.ToUpper(s)
	for i, name := range roleNames {
		if name == upper {
			return Role(i), true
		}
	}
	return NumRoles, false
}

Просмотреть файл

@ -0,0 +1,31 @@
package tableacl
// allAcl is the lazily built ACL that admits every principal.
var allAcl simpleACL

const (
	// ALL is the wildcard entry matching any principal.
	ALL = "*"
)

// simpleACL keeps all entries in a unique in-memory list
type simpleACL map[string]bool

// NewACL returns an ACL containing exactly the given entries.
func NewACL(entries []string) (ACL, error) {
	acl := make(simpleACL, len(entries))
	for _, entry := range entries {
		acl[entry] = true
	}
	return acl, nil
}

// IsMember reports whether principal is covered by this ACL, either
// directly or through the ALL wildcard.
func (a simpleACL) IsMember(principal string) bool {
	if a[principal] {
		return true
	}
	return a[ALL]
}

// all returns a shared ACL that grants access to everyone.
// NOTE(review): the lazy initialization here is not synchronized —
// presumably only reached after single-threaded setup; confirm before
// calling from concurrent code.
func all() ACL {
	if allAcl == nil {
		allAcl = simpleACL(map[string]bool{ALL: true})
	}
	return allAcl
}

Просмотреть файл

@ -0,0 +1,94 @@
package tableacl
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"regexp"
"strings"
)
// ACL is an interface for Access Control List
type ACL interface {
	// IsMember checks the membership of a principal in this ACL
	IsMember(principal string) bool
}

// tableAcl maps a compiled table-name pattern to the ACL enforced for
// each role on matching tables. Nil until Init (or a test) populates it.
var tableAcl map[*regexp.Regexp]map[Role]ACL

// Init reads the JSON ACL config at configFile and installs the table
// ACLs. Any read or parse failure is fatal: the process exits via
// log.Fatalf.
func Init(configFile string) {
	data, err := ioutil.ReadFile(configFile)
	if err != nil {
		log.Fatalf("unable to read tableACL config file: %v", err)
	}
	if tableAcl, err = load(data); err != nil {
		log.Fatalf("tableACL initialization error: %v", err)
	}
}
// load loads configurations from a JSON byte array and returns the
// table-pattern -> role -> ACL map, or an error for malformed JSON, an
// invalid regexp pattern, or an unknown role name.
//
// Sample configuration
// []byte (`{
// <tableRegexPattern1>: {"READER": "*", "WRITER": "<u2>,<u4>...","ADMIN": "<u5>"},
// <tableRegexPattern2>: {"ADMIN": "<u5>"}
//}`)
func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
// contents: table pattern -> role name -> comma-separated entry list.
var contents map[string]map[string]string
err := json.Unmarshal(config, &contents)
if err != nil {
return nil, err
}
tableAcl := make(map[*regexp.Regexp]map[Role]ACL)
for tblPattern, accessMap := range contents {
re, err := regexp.Compile(tblPattern)
if err != nil {
return nil, fmt.Errorf("regexp compile error %v: %v", tblPattern, err)
}
tableAcl[re] = make(map[Role]ACL)
// Seed every role with an empty entry list so each role gets an ACL
// even when the config mentions only some roles.
entriesByRole := make(map[Role][]string)
for i := READER; i < NumRoles; i++ {
entriesByRole[i] = []string{}
}
for role, entries := range accessMap {
r, ok := RoleByName(role)
if !ok {
return nil, fmt.Errorf("parse error, invalid role %v", role)
}
// Entries must be assigned to all roles up to r
for i := READER; i <= r; i++ {
entriesByRole[i] = append(entriesByRole[i], strings.Split(entries, ",")...)
}
}
for r, entries := range entriesByRole {
a, err := NewACL(entries)
if err != nil {
return nil, err
}
tableAcl[re][r] = a
}
}
return tableAcl, nil
}
// Authorized returns the ACL of principals holding at least minRole on
// table. When no ACL config has been loaded, or no configured pattern
// matches the table, everyone is allowed.
func Authorized(table string, minRole Role) ACL {
	if tableAcl != nil {
		for pattern, roleMap := range tableAcl {
			if pattern.MatchString(table) {
				return roleMap[minRole]
			}
		}
	}
	// No ACL config or no matching pattern: allow all access.
	return all()
}

Просмотреть файл

@ -0,0 +1,92 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tableacl
import (
"testing"
"github.com/youtube/vitess/go/vt/context"
)
// currentUser returns the username reported by a DummyContext, i.e. the
// principal the access checks in these tests run as.
func currentUser() string {
ctx := &context.DummyContext{}
return ctx.GetUsername()
}
// TestParseInvalidJSON verifies that malformed JSON configs are rejected.
func TestParseInvalidJSON(t *testing.T) {
checkLoad([]byte(`{1:2}`), false, t)
checkLoad([]byte(`{"1":"2"}`), false, t)
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
}
// TestInvalidRoleName verifies that an unknown role name is rejected.
func TestInvalidRoleName(t *testing.T) {
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
}
// TestInvalidRegex verifies that a non-compiling table pattern is rejected.
func TestInvalidRegex(t *testing.T) {
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
}
// TestValidConfigs verifies that well-formed configs (including
// mixed-case role names, wildcards and multiple patterns) load cleanly.
func TestValidConfigs(t *testing.T) {
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"}}`), true, t)
checkLoad([]byte(`{
"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"},
"tbl[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3", "ADMIN":"u4"}
}`), true, t)
}
// TestDenyReaderInsert: a principal listed only as READER must not pass
// a WRITER-level check.
func TestDenyReaderInsert(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", WRITER, t, false)
}
// TestAllowReaderSelect: a READER principal passes a READER-level check.
func TestAllowReaderSelect(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", READER, t, true)
}
// TestDenyReaderDDL: a READER principal must not pass an ADMIN-level check.
func TestDenyReaderDDL(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", ADMIN, t, false)
}
// TestAllowUnmatchedTable: tables matching no configured pattern allow
// all access, even at ADMIN level.
func TestAllowUnmatchedTable(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
checkAccess(configData, "UNMATCHED_TABLE", ADMIN, t, true)
}
// TestAllUserReadAcess: the ALL wildcard grants READER access to anyone.
func TestAllUserReadAcess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + ALL + `", "WRITER":"u3"}}`)
checkAccess(configData, "table1", READER, t, true)
}
// TestAllUserWriteAccess: the ALL wildcard grants WRITER access to anyone.
func TestAllUserWriteAccess(t *testing.T) {
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + ALL + `"}}`)
checkAccess(configData, "table1", WRITER, t, true)
}
// checkLoad loads configData into the package-level tableAcl and fails
// the test when the load result does not match the expected validity.
func checkLoad(configData []byte, valid bool, t *testing.T) {
var err error
tableAcl, err = load(configData)
if !valid && err == nil {
t.Errorf("expecting parse error none returned")
}
if valid && err != nil {
t.Errorf("unexpected load error: %v", err)
}
}
// checkAccess loads configData, then asserts whether the current user
// holds the given role on tableName.
func checkAccess(configData []byte, tableName string, role Role, t *testing.T, want bool) {
checkLoad(configData, true, t)
got := Authorized(tableName, role).IsMember(currentUser())
if want != got {
t.Errorf("got %v, want %v", got, want)
}
}

Просмотреть файл

@ -342,6 +342,7 @@ func updateReplicationGraphForPromotedSlave(ts topo.Server, tablet *topo.TabletI
tablet.Type = topo.TYPE_MASTER
tablet.Parent.Cell = ""
tablet.Parent.Uid = topo.NO_TABLET
tablet.Health = nil
err := topo.UpdateTablet(ts, tablet)
if err != nil {
return err

Просмотреть файл

@ -13,8 +13,8 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/stats"
"github.com/youtube/vitess/go/vt/binlog"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tabletserver"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
"github.com/youtube/vitess/go/vt/topo"
)
@ -52,7 +52,7 @@ func (agent *ActionAgent) allowQueries(tablet *topo.Tablet) error {
qrs := tabletserver.LoadCustomRules()
if tablet.KeyRange.IsPartial() {
qr := tabletserver.NewQueryRule("enforce keyspace_id range", "keyspace_id_not_in_range", tabletserver.QR_FAIL)
qr.AddPlanCond(sqlparser.PLAN_INSERT_PK)
qr.AddPlanCond(planbuilder.PLAN_INSERT_PK)
err := qr.AddBindVarCond("keyspace_id", true, true, tabletserver.QR_NOTIN, tablet.KeyRange)
if err != nil {
log.Warningf("Unable to add keyspace rule: %v", err)

Просмотреть файл

@ -0,0 +1,42 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import "github.com/youtube/vitess/go/vt/sqlparser"
// DDLPlan summarizes a parsed DDL statement.
type DDLPlan struct {
	Action    string // DDL verb ("create", "drop", "alter", "rename"); "" when the SQL is not parseable DDL
	TableName string // table the statement operates on; empty for create statements
	NewName   string // target name for create/rename statements; empty otherwise
}
// DDLParse parses sql and reduces it to a DDLPlan. SQL that fails to parse,
// or that is not a DDL statement, yields a plan with an empty Action.
func DDLParse(sql string) (plan *DDLPlan) {
	notDDL := &DDLPlan{Action: ""}
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return notDDL
	}
	if ddl, ok := statement.(*sqlparser.DDL); ok {
		return &DDLPlan{
			Action:    ddl.Action,
			TableName: string(ddl.Table),
			NewName:   string(ddl.NewName),
		}
	}
	return notDDL
}
// analyzeDDL builds the ExecPlan for a DDL statement. TableName is resolved
// through getTable only when the statement names a table that exists in the
// schema; create statements (empty table) and unknown tables leave it blank.
func analyzeDDL(ddl *sqlparser.DDL, getTable TableGetter) *ExecPlan {
	plan := &ExecPlan{PlanId: PLAN_DDL}
	if name := string(ddl.Table); name != "" {
		if info, found := getTable(name); found {
			plan.TableName = info.Name
		}
	}
	return plan
}

Просмотреть файл

@ -0,0 +1,161 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"strconv"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// analyzeUpdate plans an UPDATE. The strategy degrades from PLAN_DML_PK
// (WHERE clause pins the full primary key) to PLAN_DML_SUBQUERY (PK values
// resolved by a prior select) to PLAN_PASS_DML (pass-through); plan.Reason
// records why a downgrade happened.
func analyzeUpdate(upd *sqlparser.Update, getTable TableGetter) (plan *ExecPlan, err error) {
	// Default plan
	plan = &ExecPlan{
		PlanId:    PLAN_PASS_DML,
		FullQuery: GenerateFullQuery(upd),
	}
	tableName := sqlparser.GetTableName(upd.Table)
	if tableName == "" {
		// Not a simple table reference: pass through unchanged.
		plan.Reason = REASON_TABLE
		return plan, nil
	}
	tableInfo, err := plan.setTableInfo(tableName, getTable)
	if err != nil {
		return nil, err
	}
	if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
		log.Warningf("no primary key for table %s", tableName)
		plan.Reason = REASON_TABLE_NOINDEX
		return plan, nil
	}
	// Detect SET-clause writes to PK columns; expressions too complex to
	// analyze force pass-through.
	plan.SecondaryPKValues, err = analyzeUpdateExpressions(upd.Exprs, tableInfo.Indexes[0])
	if err != nil {
		if err == TooComplex {
			plan.Reason = REASON_PK_CHANGE
			return plan, nil
		}
		return nil, err
	}
	plan.PlanId = PLAN_DML_SUBQUERY
	plan.OuterQuery = GenerateUpdateOuterQuery(upd, tableInfo.Indexes[0])
	plan.Subquery = GenerateUpdateSubquery(upd, tableInfo)
	conditions := analyzeWhere(upd.Where)
	if conditions == nil {
		// WHERE absent or not analyzable: keep the subquery plan.
		plan.Reason = REASON_WHERE
		return plan, nil
	}
	pkValues, err := getPKValues(conditions, tableInfo.Indexes[0])
	if err != nil {
		return nil, err
	}
	if pkValues != nil {
		// The WHERE clause pins the full PK: no subquery needed.
		plan.PlanId = PLAN_DML_PK
		plan.OuterQuery = plan.FullQuery
		plan.PKValues = pkValues
		return plan, nil
	}
	return plan, nil
}
// analyzeDelete plans a DELETE using the same degradation ladder as
// analyzeUpdate: PLAN_DML_PK when the WHERE clause pins the full primary
// key, PLAN_DML_SUBQUERY when PK values must be resolved by a prior select,
// PLAN_PASS_DML otherwise.
func analyzeDelete(del *sqlparser.Delete, getTable TableGetter) (plan *ExecPlan, err error) {
	// Default plan
	plan = &ExecPlan{
		PlanId:    PLAN_PASS_DML,
		FullQuery: GenerateFullQuery(del),
	}
	tableName := sqlparser.GetTableName(del.Table)
	if tableName == "" {
		// Not a simple table reference: pass through unchanged.
		plan.Reason = REASON_TABLE
		return plan, nil
	}
	tableInfo, err := plan.setTableInfo(tableName, getTable)
	if err != nil {
		return nil, err
	}
	if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
		log.Warningf("no primary key for table %s", tableName)
		plan.Reason = REASON_TABLE_NOINDEX
		return plan, nil
	}
	plan.PlanId = PLAN_DML_SUBQUERY
	plan.OuterQuery = GenerateDeleteOuterQuery(del, tableInfo.Indexes[0])
	plan.Subquery = GenerateDeleteSubquery(del, tableInfo)
	conditions := analyzeWhere(del.Where)
	if conditions == nil {
		// WHERE absent or not analyzable: keep the subquery plan.
		plan.Reason = REASON_WHERE
		return plan, nil
	}
	pkValues, err := getPKValues(conditions, tableInfo.Indexes[0])
	if err != nil {
		return nil, err
	}
	if pkValues != nil {
		// The WHERE clause pins the full PK: no subquery needed.
		plan.PlanId = PLAN_DML_PK
		plan.OuterQuery = plan.FullQuery
		plan.PKValues = pkValues
		return plan, nil
	}
	return plan, nil
}
// analyzeSet plans a SET statement. For a single "set key = <number>"
// assignment the key and numeric value are extracted into SetKey/SetValue;
// anything else (multiple assignments, non-numeric values) stays a bare
// PLAN_SET carrying only the full query.
func analyzeSet(set *sqlparser.Set) (plan *ExecPlan) {
	plan = &ExecPlan{
		PlanId:    PLAN_SET,
		FullQuery: GenerateFullQuery(set),
	}
	if len(set.Exprs) > 1 { // Multiple set values
		return plan
	}
	update_expression := set.Exprs[0]
	plan.SetKey = string(update_expression.Name.Name)
	numExpr, ok := update_expression.Expr.(sqlparser.NumVal)
	if !ok {
		// Non-numeric value: key is recorded, value left unset.
		return plan
	}
	val := string(numExpr)
	// Try integer first (base auto-detected by ParseInt), then float.
	if ival, err := strconv.ParseInt(val, 0, 64); err == nil {
		plan.SetValue = ival
	} else if fval, err := strconv.ParseFloat(val, 64); err == nil {
		plan.SetValue = fval
	}
	return plan
}
// analyzeUpdateExpressions scans SET expressions for assignments to
// primary-key columns. It returns nil when no PK column is assigned, the
// assigned values (indexed by PK column position) otherwise, or TooComplex
// when a PK assignment is not a simple value.
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) {
	for _, expr := range exprs {
		index := pkIndex.FindColumn(sqlparser.GetColName(expr.Name))
		if index == -1 {
			// Not a PK column: irrelevant here.
			continue
		}
		if !sqlparser.IsValue(expr.Expr) {
			log.Warningf("expression is too complex %v", expr)
			return nil, TooComplex
		}
		if pkValues == nil {
			// Allocate lazily, only once a PK assignment is seen.
			pkValues = make([]interface{}, len(pkIndex.Columns))
		}
		var err error
		pkValues[index], err = sqlparser.AsInterface(expr.Expr)
		if err != nil {
			return nil, err
		}
	}
	return pkValues, nil
}

Просмотреть файл

@ -0,0 +1,143 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"errors"
"fmt"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
var (
	// TooComplex signals that an expression cannot be analyzed for PK-based
	// planning; callers fall back to a less optimized plan.
	TooComplex = errors.New("Complex")
	// execLimit is attached to limitless selects so result size is capped
	// at execution time via the :_vtMaxResultSize bind variable.
	execLimit = &sqlparser.Limit{Rowcount: sqlparser.ValArg(":_vtMaxResultSize")}
)
// ExecPlan is built for selects and DMLs.
// PK values within ExecPlan can be:
// sqltypes.Value: sourced from the query, or
// string: bind variable name starting with ':', or
// nil if no value was specified
type ExecPlan struct {
	PlanId    PlanType   // execution strategy chosen by the analyzer
	Reason    ReasonType // why a more optimized plan was not chosen
	TableName string
	// FieldQuery is used to fetch field info
	FieldQuery *sqlparser.ParsedQuery
	// FullQuery will be set for all plans.
	FullQuery *sqlparser.ParsedQuery
	// For PK plans, only OuterQuery is set.
	// For SUBQUERY plans, Subquery is also set.
	// IndexUsed is set only for PLAN_SELECT_SUBQUERY
	OuterQuery *sqlparser.ParsedQuery
	Subquery   *sqlparser.ParsedQuery
	IndexUsed  string
	// For selects, columns to be returned
	// For PLAN_INSERT_SUBQUERY, columns to be inserted
	ColumnNumbers []int
	// PLAN_PK_EQUAL, PLAN_DML_PK: where clause values
	// PLAN_PK_IN: IN clause values
	// PLAN_INSERT_PK: values clause
	PKValues []interface{}
	// For update: set clause if pk is changing
	SecondaryPKValues []interface{}
	// For PLAN_INSERT_SUBQUERY: pk columns in the subquery result
	SubqueryPKColumns []int
	// PLAN_SET
	SetKey   string
	SetValue interface{}
}
// setTableInfo looks tableName up via getTable, records the canonical table
// name on the plan, and returns the table's schema info. It errors when the
// table is absent from the schema.
func (node *ExecPlan) setTableInfo(tableName string, getTable TableGetter) (*schema.Table, error) {
	info, found := getTable(tableName)
	if !found {
		return nil, fmt.Errorf("table %s not found in schema", tableName)
	}
	node.TableName = info.Name
	return info, nil
}
// TableGetter resolves a table name to its schema information, reporting
// whether the table is known.
type TableGetter func(tableName string) (*schema.Table, bool)
// GetExecPlan parses sql and builds its execution plan, resolving table
// metadata through getTable. Pass-through DML plans are logged as warnings.
func GetExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}
	if plan, err = analyzeSQL(statement, getTable); err != nil {
		return nil, err
	}
	if plan.PlanId == PLAN_PASS_DML {
		log.Warningf("PASS_DML: %s", sql)
	}
	return plan, nil
}
// GetStreamExecPlan builds a plan for streaming execution. Only selects
// without a lock clause and unions are allowed; both stream as a
// pass-through PLAN_PASS_SELECT.
func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
	statement, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, err
	}
	plan = &ExecPlan{
		PlanId:    PLAN_PASS_SELECT,
		FullQuery: GenerateFullQuery(statement),
	}
	switch stmt := statement.(type) {
	case *sqlparser.Select:
		if stmt.Lock != "" {
			return nil, errors.New("select with lock disallowed with streaming")
		}
		tableName, _ := analyzeFrom(stmt.From)
		if tableName != "" {
			// NOTE(review): the error from setTableInfo is discarded here —
			// presumably streaming over a table absent from the schema should
			// still work as pass-through; confirm this is intentional.
			plan.setTableInfo(tableName, getTable)
		}
	case *sqlparser.Union:
		// pass
	default:
		return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt))
	}
	return plan, nil
}
// analyzeSQL dispatches to the per-statement analyzer and returns the
// resulting plan. Statement types with no analyzer are rejected.
func analyzeSQL(statement sqlparser.Statement, getTable TableGetter) (plan *ExecPlan, err error) {
	switch stmt := statement.(type) {
	case *sqlparser.Union:
		// Unions are never optimized: pass through with field metadata.
		unionPlan := &ExecPlan{
			PlanId:     PLAN_PASS_SELECT,
			FieldQuery: GenerateFieldQuery(stmt),
			FullQuery:  GenerateFullQuery(stmt),
			Reason:     REASON_SELECT,
		}
		return unionPlan, nil
	case *sqlparser.Select:
		return analyzeSelect(stmt, getTable)
	case *sqlparser.Insert:
		return analyzeInsert(stmt, getTable)
	case *sqlparser.Update:
		return analyzeUpdate(stmt, getTable)
	case *sqlparser.Delete:
		return analyzeDelete(stmt, getTable)
	case *sqlparser.Set:
		return analyzeSet(stmt), nil
	case *sqlparser.DDL:
		return analyzeDDL(stmt, getTable), nil
	default:
		return nil, errors.New("invalid SQL")
	}
}

Просмотреть файл

@ -0,0 +1,158 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// IndexScore accumulates how well the WHERE-clause columns line up with one
// index while conditions are scanned.
type IndexScore struct {
	Index       *schema.Index
	ColumnMatch []bool // ColumnMatch[i] is true once index column i was seen in a condition
	MatchFailed bool   // set when a condition column is served neither as index nor data column
}
// scoreValue ranks how selective an index match is, aside from the two
// sentinel values below.
type scoreValue int64

const (
	NO_MATCH      = scoreValue(-1) // index unusable for the conditions
	PERFECT_SCORE = scoreValue(0)  // every index column matched
)
// NewIndexScore initializes scoring state for index: no columns matched
// yet, matching not failed.
func NewIndexScore(index *schema.Index) *IndexScore {
	return &IndexScore{
		Index:       index,
		ColumnMatch: make([]bool, len(index.Columns)),
	}
}
// FindMatch records that columnName appears in a condition. It returns the
// column's position within the index, or -1 when the column is not an index
// column. A column that is neither an index column nor a covered data
// column permanently disqualifies the index.
func (is *IndexScore) FindMatch(columnName string) int {
	if is.MatchFailed {
		return -1
	}
	if index := is.Index.FindColumn(columnName); index != -1 {
		is.ColumnMatch[index] = true
		return index
	}
	// If the column is among the data columns, we can still use
	// the index without going to the main table
	if index := is.Index.FindDataColumn(columnName); index == -1 {
		is.MatchFailed = true
	}
	return -1
}
// GetScore reduces the match state to a single value: NO_MATCH if the index
// was disqualified (or nothing matched), PERFECT_SCORE if every index
// column matched, otherwise the cardinality of the deepest column of the
// matched leading prefix.
func (is *IndexScore) GetScore() scoreValue {
	if is.MatchFailed {
		return NO_MATCH
	}
	score := NO_MATCH
	for i, indexColumn := range is.ColumnMatch {
		if indexColumn {
			// Track cardinality of the deepest matched prefix column.
			score = scoreValue(is.Index.Cardinality[i])
			continue
		}
		// First unmatched column ends the usable prefix.
		return score
	}
	return PERFECT_SCORE
}
// NewIndexScoreList builds a fresh IndexScore for each index in indexes.
func NewIndexScoreList(indexes []*schema.Index) []*IndexScore {
	list := make([]*IndexScore, 0, len(indexes))
	for _, index := range indexes {
		list = append(list, NewIndexScore(index))
	}
	return list
}
// getSelectPKValues decides between PK-based select plans: PLAN_PK_EQUAL
// when the conditions pin every PK column to single values, PLAN_PK_IN for
// an IN list on a single-column PK, PLAN_PASS_SELECT otherwise.
func getSelectPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (planId PlanType, pkValues []interface{}, err error) {
	pkValues, err = getPKValues(conditions, pkIndex)
	if err != nil {
		return 0, nil, err
	}
	if pkValues == nil {
		return PLAN_PASS_SELECT, nil, nil
	}
	for _, pkValue := range pkValues {
		inList, ok := pkValue.([]interface{})
		if !ok {
			continue
		}
		if len(pkValues) == 1 {
			// IN list on a single-column PK executes as PK_IN.
			return PLAN_PK_IN, inList, nil
		}
		// An IN list combined with other PK columns is not optimized.
		return PLAN_PASS_SELECT, nil, nil
	}
	return PLAN_PK_EQUAL, pkValues, nil
}
// getPKValues extracts a value for every PK column from equality/IN
// comparisons. It returns (nil, nil) unless each PK column is pinned by
// exactly such a condition; on success the slice is indexed by PK column
// position.
func getPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (pkValues []interface{}, err error) {
	pkIndexScore := NewIndexScore(pkIndex)
	pkValues = make([]interface{}, len(pkIndexScore.ColumnMatch))
	for _, condition := range conditions {
		condition, ok := condition.(*sqlparser.ComparisonExpr)
		if !ok {
			return nil, nil
		}
		if !sqlparser.StringIn(condition.Operator, sqlparser.AST_EQ, sqlparser.AST_IN) {
			return nil, nil
		}
		index := pkIndexScore.FindMatch(string(condition.Left.(*sqlparser.ColName).Name))
		if index == -1 {
			return nil, nil
		}
		switch condition.Operator {
		case sqlparser.AST_EQ, sqlparser.AST_IN:
			var err error
			pkValues[index], err = sqlparser.AsInterface(condition.Right)
			if err != nil {
				return nil, err
			}
		default:
			// Operators were already restricted to EQ/IN above.
			panic("unreachable")
		}
	}
	// Only a complete primary-key match qualifies.
	if pkIndexScore.GetScore() == PERFECT_SCORE {
		return pkValues, nil
	}
	return nil, nil
}
// getIndexMatch scores every index against the condition columns and
// returns the name of the best usable index ("" when none qualifies). A
// perfect match wins immediately; otherwise ties go to the later index,
// which prefers a secondary index over the primary key.
func getIndexMatch(conditions []sqlparser.BoolExpr, indexes []*schema.Index) string {
	indexScores := NewIndexScoreList(indexes)
	for _, condition := range conditions {
		var col string
		switch condition := condition.(type) {
		case *sqlparser.ComparisonExpr:
			col = string(condition.Left.(*sqlparser.ColName).Name)
		case *sqlparser.RangeCond:
			col = string(condition.Left.(*sqlparser.ColName).Name)
		default:
			// analyzeWhere only yields the two condition types above.
			// (Fixed typo: was "unreachaable".)
			panic("unreachable")
		}
		for _, index := range indexScores {
			index.FindMatch(col)
		}
	}
	highScore := NO_MATCH
	highScorer := -1
	for i, index := range indexScores {
		curScore := index.GetScore()
		if curScore == NO_MATCH {
			continue
		}
		if curScore == PERFECT_SCORE {
			highScorer = i
			break
		}
		// Prefer secondary index over primary key
		if curScore >= highScore {
			highScore = curScore
			highScorer = i
		}
	}
	if highScorer == -1 {
		return ""
	}
	return indexes[highScorer].Name
}

Просмотреть файл

@ -0,0 +1,133 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"errors"
log "github.com/golang/glog"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// analyzeInsert plans an INSERT. Plain VALUES inserts with analyzable PK
// values become PLAN_INSERT_PK; INSERT ... SELECT becomes
// PLAN_INSERT_SUBQUERY; upserts and tables without a usable primary key
// stay pass-through.
func analyzeInsert(ins *sqlparser.Insert, getTable TableGetter) (plan *ExecPlan, err error) {
	plan = &ExecPlan{
		PlanId:    PLAN_PASS_DML,
		FullQuery: GenerateFullQuery(ins),
	}
	tableName := sqlparser.GetTableName(ins.Table)
	if tableName == "" {
		// Not a simple table reference: pass through unchanged.
		plan.Reason = REASON_TABLE
		return plan, nil
	}
	tableInfo, err := plan.setTableInfo(tableName, getTable)
	if err != nil {
		return nil, err
	}
	if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
		log.Warningf("no primary key for table %s", tableName)
		plan.Reason = REASON_TABLE_NOINDEX
		return plan, nil
	}
	pkColumnNumbers := getInsertPKColumns(ins.Columns, tableInfo)
	if ins.OnDup != nil {
		// Upserts are not safe for statement based replication:
		// http://bugs.mysql.com/bug.php?id=58637
		plan.Reason = REASON_UPSERT
		return plan, nil
	}
	if sel, ok := ins.Rows.(sqlparser.SelectStatement); ok {
		// INSERT ... SELECT: run the select first, then insert its rows.
		plan.PlanId = PLAN_INSERT_SUBQUERY
		plan.OuterQuery = GenerateInsertOuterQuery(ins)
		plan.Subquery = GenerateSelectLimitQuery(sel)
		if len(ins.Columns) != 0 {
			plan.ColumnNumbers, err = analyzeSelectExprs(sqlparser.SelectExprs(ins.Columns), tableInfo)
			if err != nil {
				return nil, err
			}
		} else {
			// StarExpr node will expand into all columns
			n := sqlparser.SelectExprs{&sqlparser.StarExpr{}}
			plan.ColumnNumbers, err = analyzeSelectExprs(n, tableInfo)
			if err != nil {
				return nil, err
			}
		}
		plan.SubqueryPKColumns = pkColumnNumbers
		return plan, nil
	}
	// If it's not a sqlparser.SelectStatement, it's Values.
	rowList := ins.Rows.(sqlparser.Values)
	pkValues, err := getInsertPKValues(pkColumnNumbers, rowList, tableInfo)
	if err != nil {
		return nil, err
	}
	if pkValues != nil {
		plan.PlanId = PLAN_INSERT_PK
		plan.OuterQuery = plan.FullQuery
		plan.PKValues = pkValues
	}
	return plan, nil
}
// getInsertPKColumns maps each PK column of tableInfo to its position in
// the insert's column list, or -1 when the statement does not supply it.
// An empty column list means all columns in table order, so the table's own
// PK positions are returned directly.
func getInsertPKColumns(columns sqlparser.Columns, tableInfo *schema.Table) (pkColumnNumbers []int) {
	if len(columns) == 0 {
		return tableInfo.PKColumns
	}
	pkIndex := tableInfo.Indexes[0]
	pkColumnNumbers = make([]int, len(pkIndex.Columns))
	for i := range pkColumnNumbers {
		pkColumnNumbers[i] = -1
	}
	for pos, column := range columns {
		idx := pkIndex.FindColumn(sqlparser.GetColName(column.(*sqlparser.NonStarExpr).Expr))
		if idx == -1 {
			continue
		}
		pkColumnNumbers[idx] = pos
	}
	return pkColumnNumbers
}
// getInsertPKValues collects, per PK column, the value supplied by each row
// of the VALUES list: a single value when there is one row, a slice when
// there are several. Columns the statement omits take the schema default.
// Returns (nil, nil) when any PK value is too complex to analyze, and an
// error for malformed rows (subqueries or short tuples).
func getInsertPKValues(pkColumnNumbers []int, rowList sqlparser.Values, tableInfo *schema.Table) (pkValues []interface{}, err error) {
	pkValues = make([]interface{}, len(pkColumnNumbers))
	for index, columnNumber := range pkColumnNumbers {
		if columnNumber == -1 {
			// Column not present in the insert: use the schema default.
			pkValues[index] = tableInfo.GetPKColumn(index).Default
			continue
		}
		values := make([]interface{}, len(rowList))
		for j := 0; j < len(rowList); j++ {
			if _, ok := rowList[j].(*sqlparser.Subquery); ok {
				return nil, errors.New("row subquery not supported for inserts")
			}
			row := rowList[j].(sqlparser.ValTuple)
			if columnNumber >= len(row) {
				return nil, errors.New("column count doesn't match value count")
			}
			node := row[columnNumber]
			if !sqlparser.IsValue(node) {
				log.Warningf("insert is too complex %v", node)
				return nil, nil
			}
			var err error
			values[j], err = sqlparser.AsInterface(node)
			if err != nil {
				return nil, err
			}
		}
		if len(values) == 1 {
			pkValues[index] = values[0]
		} else {
			pkValues[index] = values
		}
	}
	return pkValues, nil
}

Просмотреть файл

@ -0,0 +1,227 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path"
"path/filepath"
"strings"
"testing"
"github.com/youtube/vitess/go/testfiles"
"github.com/youtube/vitess/go/vt/schema"
)
// TestPlan runs every case in exec_cases.txt through GetExecPlan against
// the shared test schema and compares the JSON-marshalled plan (or the
// error text) with the expected output.
func TestPlan(t *testing.T) {
	testSchema := loadSchema("schema_test.json")
	for tcase := range iterateExecFile("exec_cases.txt") {
		plan, err := GetExecPlan(tcase.input, func(name string) (*schema.Table, bool) {
			r, ok := testSchema[name]
			return r, ok
		})
		var out string
		if err != nil {
			out = err.Error()
		} else {
			bout, err := json.Marshal(plan)
			if err != nil {
				// Fail the test instead of panicking the process.
				t.Fatalf("Error marshalling %v: %v", plan, err)
			}
			out = string(bout)
		}
		if out != tcase.output {
			// t.Errorf replaces t.Error(fmt.Sprintf(...)) (staticcheck S1038).
			t.Errorf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out)
		}
	}
}
// TestCustom runs the optional per-schema test bundles under
// sqlparser_test: for each *_schema.json, every sibling *.txt case file is
// executed through GetExecPlan and compared against its expected output.
func TestCustom(t *testing.T) {
	testSchemas := testfiles.Glob("sqlparser_test/*_schema.json")
	if len(testSchemas) == 0 {
		t.Log("No schemas to test")
		return
	}
	for _, schemFile := range testSchemas {
		schem := loadSchema(schemFile)
		t.Logf("Testing schema %s", schemFile)
		files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1))
		if err != nil {
			// NOTE(review): log.Fatal aborts the whole test binary without
			// running other tests; prefer t.Fatal here. Kept because it is
			// the file's only use of the "log" import.
			log.Fatal(err)
		}
		if len(files) == 0 {
			t.Fatalf("No test files for %s", schemFile)
		}
		getter := func(name string) (*schema.Table, bool) {
			r, ok := schem[name]
			return r, ok
		}
		for _, file := range files {
			t.Logf("Testing file %s", file)
			for tcase := range iterateExecFile(file) {
				plan, err := GetExecPlan(tcase.input, getter)
				var out string
				if err != nil {
					out = err.Error()
				} else {
					bout, err := json.Marshal(plan)
					if err != nil {
						// Fail the test instead of panicking the process.
						t.Fatalf("Error marshalling %v: %v", plan, err)
					}
					out = string(bout)
				}
				if out != tcase.output {
					t.Errorf("File: %s: Line:%v\n%s\n%s", file, tcase.lineno, tcase.output, out)
				}
			}
		}
	}
}
// TestStreamPlan runs every case in stream_cases.txt through
// GetStreamExecPlan against the shared test schema and compares the
// JSON-marshalled plan (or error text) with the expected output.
func TestStreamPlan(t *testing.T) {
	testSchema := loadSchema("schema_test.json")
	for tcase := range iterateExecFile("stream_cases.txt") {
		plan, err := GetStreamExecPlan(tcase.input, func(name string) (*schema.Table, bool) {
			r, ok := testSchema[name]
			return r, ok
		})
		var out string
		if err != nil {
			out = err.Error()
		} else {
			bout, err := json.Marshal(plan)
			if err != nil {
				// Fail the test instead of panicking the process.
				t.Fatalf("Error marshalling %v: %v", plan, err)
			}
			out = string(bout)
		}
		if out != tcase.output {
			// t.Errorf replaces t.Error(fmt.Sprintf(...)) (staticcheck S1038).
			t.Errorf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out)
		}
	}
}
// TestDDLPlan runs every case in ddl_cases.txt through DDLParse and checks
// the fields listed in the expected JSON object (absent keys are not
// checked).
func TestDDLPlan(t *testing.T) {
	for tcase := range iterateExecFile("ddl_cases.txt") {
		plan := DDLParse(tcase.input)
		expected := make(map[string]interface{})
		// The old message said "marshalling" and dropped the actual error;
		// report the unmarshal failure with its cause.
		if err := json.Unmarshal([]byte(tcase.output), &expected); err != nil {
			t.Fatalf("Error unmarshalling %s: %v", tcase.output, err)
		}
		matchString(t, tcase.lineno, expected["Action"], plan.Action)
		matchString(t, tcase.lineno, expected["TableName"], plan.TableName)
		matchString(t, tcase.lineno, expected["NewName"], plan.NewName)
	}
}
// matchString asserts actual equals expected when expected is present
// (non-nil) in the parsed test-case JSON; nil means the field is not
// checked. line is the test-file line reported on failure.
func matchString(t *testing.T, line int, expected interface{}, actual string) {
	if expected == nil {
		return
	}
	// t.Errorf replaces t.Error(fmt.Sprintf(...)) (staticcheck S1038).
	if expected.(string) != actual {
		t.Errorf("Line %d: expected: %v, received %s", line, expected, actual)
	}
}
// loadSchema reads a JSON array of schema.Table definitions from the named
// test file and indexes them by table name. It panics on any read or parse
// failure, which is acceptable for test fixtures.
func loadSchema(name string) map[string]*schema.Table {
	b, err := ioutil.ReadFile(locateFile(name))
	if err != nil {
		panic(err)
	}
	var tables []*schema.Table
	if err := json.Unmarshal(b, &tables); err != nil {
		panic(err)
	}
	s := make(map[string]*schema.Table, len(tables))
	for _, table := range tables {
		s[table.Name] = table
	}
	return s
}
// testCase is one parsed entry of a test-case file.
type testCase struct {
	file   string // path of the file the case came from
	lineno int    // line number of the case's last output line
	input  string // SQL input, unquoted
	output string // expected output: compacted JSON object or bare string
}
// iterateExecFile streams the test cases of the named file over a channel.
// Per case, the file contains a JSON-quoted input line followed by either a
// multi-line JSON object terminated by a line starting with '}' (which is
// compacted) or a single quoted output line. Blank lines, '#' comments and
// "Length:" lines between cases are skipped. Malformed files panic, which
// is acceptable for test fixtures.
func iterateExecFile(name string) (testCaseIterator chan testCase) {
	name = locateFile(name)
	fd, err := os.OpenFile(name, os.O_RDONLY, 0)
	if err != nil {
		panic(fmt.Sprintf("Could not open file %s", name))
	}
	testCaseIterator = make(chan testCase)
	go func() {
		defer close(testCaseIterator)
		r := bufio.NewReader(fd)
		lineno := 0
		for {
			binput, err := r.ReadBytes('\n')
			if err != nil {
				if err != io.EOF {
					fmt.Printf("Line: %d\n", lineno)
					panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
				}
				break
			}
			lineno++
			input := string(binput)
			// Skip separators and comments between cases.
			if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") {
				continue
			}
			// The input line is a JSON string literal; decode it.
			err = json.Unmarshal(binput, &input)
			if err != nil {
				fmt.Printf("Line: %d, input: %s\n", lineno, binput)
				panic(err)
			}
			input = strings.Trim(input, "\"")
			var output []byte
			for {
				l, err := r.ReadBytes('\n')
				lineno++
				if err != nil {
					fmt.Printf("Line: %d\n", lineno)
					panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
				}
				output = append(output, l...)
				if l[0] == '}' {
					// End of a JSON-object output: drop the trailing newline
					// and normalize the object by compacting it.
					output = output[:len(output)-1]
					b := bytes.NewBuffer(make([]byte, 0, 64))
					if err := json.Compact(b, output); err == nil {
						output = b.Bytes()
					}
					break
				}
				if l[0] == '"' {
					// Single-line quoted output: strip the surrounding quote
					// characters and the trailing newline.
					output = output[1 : len(output)-2]
					break
				}
			}
			testCaseIterator <- testCase{name, lineno, input, string(output)}
		}
	}()
	return testCaseIterator
}
// locateFile resolves name to an absolute path: absolute names pass through
// untouched, relative names resolve under the sqlparser_test data dir.
func locateFile(name string) string {
	if !path.IsAbs(name) {
		return testfiles.Locate("sqlparser_test/" + name)
	}
	return name
}

Просмотреть файл

@ -0,0 +1,149 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"fmt"
"github.com/youtube/vitess/go/vt/tableacl"
)
// PlanType identifies the execution strategy chosen for a query.
type PlanType int

const (
	// PLAN_PASS_SELECT is pass through select statements. This is the
	// default plan for select statements.
	PLAN_PASS_SELECT PlanType = iota
	// PLAN_PASS_DML is pass through update & delete statements. This is
	// the default plan for update and delete statements.
	PLAN_PASS_DML
	// PLAN_PK_EQUAL is select statement which has an equality where clause
	// on primary key
	PLAN_PK_EQUAL
	// PLAN_PK_IN is select statement with a single IN clause on primary key
	PLAN_PK_IN
	// PLAN_SELECT_SUBQUERY is select statement with a subselect statement
	PLAN_SELECT_SUBQUERY
	// PLAN_DML_PK is an update or delete with an equality where clause(s)
	// on primary key(s)
	PLAN_DML_PK
	// PLAN_DML_SUBQUERY is an update or delete with a subselect statement
	PLAN_DML_SUBQUERY
	// PLAN_INSERT_PK is insert statement where the PK value is
	// supplied with the query
	PLAN_INSERT_PK
	// PLAN_INSERT_SUBQUERY is same as PLAN_DML_SUBQUERY but for inserts
	PLAN_INSERT_SUBQUERY
	// PLAN_SET is for SET statements
	PLAN_SET
	// PLAN_DDL is for DDL statements
	PLAN_DDL
	// NumPlans must stay last: it is the count of valid plan types.
	NumPlans
)
// Must exactly match order of plan constants.
var planName = []string{
	"PASS_SELECT",
	"PASS_DML",
	"PK_EQUAL",
	"PK_IN",
	"SELECT_SUBQUERY",
	"DML_PK",
	"DML_SUBQUERY",
	"INSERT_PK",
	"INSERT_SUBQUERY",
	"SET",
	"DDL",
}
// String returns the symbolic name of the plan type, or "" for values
// outside the defined range.
func (pt PlanType) String() string {
	if 0 <= pt && pt < NumPlans {
		return planName[pt]
	}
	return ""
}
// PlanByName maps a symbolic plan name back to its PlanType. Unknown names
// yield (NumPlans, false).
func PlanByName(s string) (pt PlanType, ok bool) {
	for i := range planName {
		if planName[i] == s {
			return PlanType(i), true
		}
	}
	return NumPlans, false
}
// IsSelect reports whether the plan type executes as a (possibly optimized)
// select.
func (pt PlanType) IsSelect() bool {
	switch pt {
	case PLAN_PASS_SELECT, PLAN_PK_EQUAL, PLAN_PK_IN, PLAN_SELECT_SUBQUERY:
		return true
	}
	return false
}
// MarshalJSON encodes the plan type as its quoted symbolic name.
func (pt PlanType) MarshalJSON() ([]byte, error) {
	return []byte(`"` + pt.String() + `"`), nil
}
// MinRole is the minimum Role required to execute this PlanType.
// The mapping lives in tableAclRoles below.
func (pt PlanType) MinRole() tableacl.Role {
	return tableAclRoles[pt]
}
// tableAclRoles maps each plan type to the minimum table-ACL role that may
// execute it: reads need READER, DMLs need WRITER, DDL needs ADMIN.
var tableAclRoles = map[PlanType]tableacl.Role{
	PLAN_PASS_SELECT:     tableacl.READER,
	PLAN_PK_EQUAL:        tableacl.READER,
	PLAN_PK_IN:           tableacl.READER,
	PLAN_SELECT_SUBQUERY: tableacl.READER,
	PLAN_SET:             tableacl.READER,
	PLAN_PASS_DML:        tableacl.WRITER,
	PLAN_DML_PK:          tableacl.WRITER,
	PLAN_DML_SUBQUERY:    tableacl.WRITER,
	PLAN_INSERT_PK:       tableacl.WRITER,
	PLAN_INSERT_SUBQUERY: tableacl.WRITER,
	PLAN_DDL:             tableacl.ADMIN,
}
// ReasonType records why a plan was downgraded from a more optimized
// strategy; REASON_DEFAULT means no downgrade happened.
type ReasonType int

const (
	REASON_DEFAULT ReasonType = iota
	REASON_SELECT
	REASON_TABLE
	REASON_NOCACHE
	REASON_SELECT_LIST
	REASON_LOCK
	REASON_WHERE
	REASON_ORDER
	REASON_PKINDEX
	REASON_NOINDEX_MATCH
	REASON_TABLE_NOINDEX
	REASON_PK_CHANGE
	REASON_COMPOSITE_PK
	REASON_HAS_HINTS
	REASON_UPSERT
)
// Must exactly match order of reason constants.
var reasonName = []string{
	"DEFAULT",
	"SELECT",
	"TABLE",
	"NOCACHE",
	"SELECT_LIST",
	"LOCK",
	"WHERE",
	"ORDER",
	"PKINDEX",
	"NOINDEX_MATCH",
	"TABLE_NOINDEX",
	"PK_CHANGE",
	"COMPOSITE_PK",
	"HAS_HINTS",
	"UPSERT",
}
// String returns the symbolic name of the reason. Out-of-range values yield
// "" instead of panicking, matching PlanType.String's behavior.
func (rt ReasonType) String() string {
	if rt < 0 || int(rt) >= len(reasonName) {
		return ""
	}
	return reasonName[rt]
}
// MarshalJSON encodes the reason as its quoted symbolic name.
func (rt ReasonType) MarshalJSON() ([]byte, error) {
	return []byte(`"` + rt.String() + `"`), nil
}

Просмотреть файл

@ -0,0 +1,184 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"fmt"
"strconv"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// GenerateFullQuery renders statement back into a ParsedQuery that
// preserves bind-variable locations.
func GenerateFullQuery(statement sqlparser.Statement) *sqlparser.ParsedQuery {
	buf := sqlparser.NewTrackedBuffer(nil)
	buf.Fprintf("%v", statement)
	return buf.ParsedQuery()
}
// GenerateFieldQuery renders the statement through FormatImpossible so it
// returns no rows but still yields the result's field metadata. It returns
// nil when the statement contains bind variables, since such a query cannot
// be executed as-is.
func GenerateFieldQuery(statement sqlparser.Statement) *sqlparser.ParsedQuery {
	buf := sqlparser.NewTrackedBuffer(FormatImpossible)
	buf.Fprintf("%v", statement)
	if buf.HasBindVars() {
		return nil
	}
	return buf.ParsedQuery()
}
// FormatImpossible is a callback function used by TrackedBuffer
// to generate a modified version of the query where all selects
// have impossible where clauses. It overrides a few node types
// and passes the rest down to the default FormatNode.
func FormatImpossible(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
	switch node := node.(type) {
	case *sqlparser.Select:
		buf.Fprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
	case *sqlparser.JoinTableExpr:
		if node.Join == sqlparser.AST_LEFT_JOIN || node.Join == sqlparser.AST_RIGHT_JOIN {
			// ON clause is required for outer joins; emit an impossible one.
			buf.Fprintf("%v %s %v on 1 != 1", node.LeftExpr, node.Join, node.RightExpr)
		} else {
			buf.Fprintf("%v %s %v", node.LeftExpr, node.Join, node.RightExpr)
		}
	default:
		node.Format(buf)
	}
}
// GenerateSelectLimitQuery renders selStmt, temporarily attaching execLimit
// to a limitless SELECT so result size is capped at execution time. The
// statement is mutated in place and restored (Limit back to nil) before
// returning.
func GenerateSelectLimitQuery(selStmt sqlparser.SelectStatement) *sqlparser.ParsedQuery {
	buf := sqlparser.NewTrackedBuffer(nil)
	sel, ok := selStmt.(*sqlparser.Select)
	if ok {
		limit := sel.Limit
		if limit == nil {
			sel.Limit = execLimit
			// Undo the in-place mutation once formatting is done.
			defer func() {
				sel.Limit = nil
			}()
		}
	}
	buf.Fprintf("%v", selStmt)
	return buf.ParsedQuery()
}
// GenerateEqualOuterQuery builds "select <all columns> from <table> where
// <pk = bindvars>" for PLAN_PK_EQUAL execution.
func GenerateEqualOuterQuery(sel *sqlparser.Select, tableInfo *schema.Table) *sqlparser.ParsedQuery {
	b := sqlparser.NewTrackedBuffer(nil)
	b.WriteString("select ")
	writeColumnList(b, tableInfo.Columns)
	b.Fprintf(" from %v where ", sel.From)
	generatePKWhere(b, tableInfo.Indexes[0])
	return b.ParsedQuery()
}
// GenerateInOuterQuery builds "select <all columns> from <table> where
// <pk> in (...)" for PLAN_PK_IN and PLAN_SELECT_SUBQUERY execution.
func GenerateInOuterQuery(sel *sqlparser.Select, tableInfo *schema.Table) *sqlparser.ParsedQuery {
	buf := sqlparser.NewTrackedBuffer(nil)
	fmt.Fprintf(buf, "select ")
	writeColumnList(buf, tableInfo.Columns)
	// We assume there is one and only one PK column.
	// A '*' argument name means all variables of the list.
	buf.Fprintf(" from %v where %s in (%a)", sel.From, tableInfo.Indexes[0].Columns[0], "*")
	return buf.ParsedQuery()
}
// GenerateInsertOuterQuery builds the insert executed after the subquery
// runs; the resolved rows are bound to the _rowValues variable.
func GenerateInsertOuterQuery(ins *sqlparser.Insert) *sqlparser.ParsedQuery {
	b := sqlparser.NewTrackedBuffer(nil)
	b.Fprintf("insert %vinto %v%v values %a%v",
		ins.Comments, ins.Table, ins.Columns, "_rowValues", ins.OnDup)
	return b.ParsedQuery()
}
// GenerateUpdateOuterQuery rewrites the update to target rows by primary
// key; PK values arrive as bind variables at execution time.
func GenerateUpdateOuterQuery(upd *sqlparser.Update, pkIndex *schema.Index) *sqlparser.ParsedQuery {
	b := sqlparser.NewTrackedBuffer(nil)
	b.Fprintf("update %v%v set %v where ", upd.Comments, upd.Table, upd.Exprs)
	generatePKWhere(b, pkIndex)
	return b.ParsedQuery()
}
// GenerateDeleteOuterQuery rewrites the delete to target rows by primary
// key; PK values arrive as bind variables at execution time.
func GenerateDeleteOuterQuery(del *sqlparser.Delete, pkIndex *schema.Index) *sqlparser.ParsedQuery {
	b := sqlparser.NewTrackedBuffer(nil)
	b.Fprintf("delete %vfrom %v where ", del.Comments, del.Table)
	generatePKWhere(b, pkIndex)
	return b.ParsedQuery()
}
// generatePKWhere emits "col0 = :0 and col1 = :1 ..." for the PK columns;
// the bind-variable name for each column is its ordinal position.
func generatePKWhere(buf *sqlparser.TrackedBuffer, pkIndex *schema.Index) {
	for i, col := range pkIndex.Columns {
		if i > 0 {
			buf.WriteString(" and ")
		}
		buf.Fprintf("%s = %a", col, strconv.Itoa(i))
	}
}
// GenerateSelectSubquery builds the PK-resolving subquery used by
// PLAN_SELECT_SUBQUERY, forcing the chosen index via a USE INDEX hint. The
// hint is swapped into the statement's table expression only for the
// duration of generation and restored afterwards.
func GenerateSelectSubquery(sel *sqlparser.Select, tableInfo *schema.Table, index string) *sqlparser.ParsedQuery {
	hint := &sqlparser.IndexHints{Type: sqlparser.AST_USE, Indexes: [][]byte{[]byte(index)}}
	table_expr := sel.From[0].(*sqlparser.AliasedTableExpr)
	savedHint := table_expr.Hints
	table_expr.Hints = hint
	// Restore the caller's AST once the subquery text is generated.
	defer func() {
		table_expr.Hints = savedHint
	}()
	return GenerateSubquery(
		tableInfo.Indexes[0].Columns,
		table_expr,
		sel.Where,
		sel.OrderBy,
		sel.Limit,
		false,
	)
}
// GenerateUpdateSubquery builds the locking select that resolves the PK
// values of the rows the update will modify.
func GenerateUpdateSubquery(upd *sqlparser.Update, tableInfo *schema.Table) *sqlparser.ParsedQuery {
	table := &sqlparser.AliasedTableExpr{Expr: upd.Table}
	return GenerateSubquery(tableInfo.Indexes[0].Columns, table, upd.Where, upd.OrderBy, upd.Limit, true)
}
// GenerateDeleteSubquery builds the locking select that resolves the PK
// values of the rows the delete will remove.
func GenerateDeleteSubquery(del *sqlparser.Delete, tableInfo *schema.Table) *sqlparser.ParsedQuery {
	table := &sqlparser.AliasedTableExpr{Expr: del.Table}
	return GenerateSubquery(tableInfo.Indexes[0].Columns, table, del.Where, del.OrderBy, del.Limit, true)
}
// GenerateSubquery builds the "select <pk columns> from ..." statement used
// by SUBQUERY plans to resolve affected PK values. The caller's limit is
// kept when present; otherwise execLimit caps the result size. When
// for_update is set, a FOR UPDATE clause locks the selected rows.
func GenerateSubquery(columns []string, table *sqlparser.AliasedTableExpr, where *sqlparser.Where, order sqlparser.OrderBy, limit *sqlparser.Limit, for_update bool) *sqlparser.ParsedQuery {
	buf := sqlparser.NewTrackedBuffer(nil)
	if limit == nil {
		limit = execLimit
	}
	fmt.Fprintf(buf, "select ")
	// Join column names with ", "; unlike the previous columns[i]
	// formulation this cannot index past an empty slice.
	for i, col := range columns {
		if i > 0 {
			fmt.Fprintf(buf, ", ")
		}
		fmt.Fprintf(buf, "%s", col)
	}
	buf.Fprintf(" from %v%v%v%v", table, where, order, limit)
	if for_update {
		buf.Fprintf(sqlparser.AST_FOR_UPDATE)
	}
	return buf.ParsedQuery()
}
// writeColumnList writes the table's column names to buf as a
// comma-separated list. Safe for an empty slice (writes nothing), unlike
// the previous trailing-element formulation which indexed columns[i]
// unconditionally.
func writeColumnList(buf *sqlparser.TrackedBuffer, columns []schema.TableColumn) {
	for i, col := range columns {
		if i > 0 {
			fmt.Fprintf(buf, ", ")
		}
		fmt.Fprintf(buf, "%s", col.Name)
	}
}

Просмотреть файл

@ -0,0 +1,219 @@
// Copyright 2014, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package planbuilder
import (
"fmt"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// analyzeSelect plans a SELECT. It starts from pass-through
// PLAN_PASS_SELECT and upgrades to PLAN_PK_EQUAL / PLAN_PK_IN /
// PLAN_SELECT_SUBQUERY only when the table is row-cached and the query
// shape permits; plan.Reason records why an upgrade was abandoned.
func analyzeSelect(sel *sqlparser.Select, getTable TableGetter) (plan *ExecPlan, err error) {
	// Default plan
	plan = &ExecPlan{
		PlanId:     PLAN_PASS_SELECT,
		FieldQuery: GenerateFieldQuery(sel),
		FullQuery:  GenerateSelectLimitQuery(sel),
	}
	// There are bind variables in the SELECT list
	if plan.FieldQuery == nil {
		plan.Reason = REASON_SELECT_LIST
		return plan, nil
	}
	if sel.Distinct != "" || sel.GroupBy != nil || sel.Having != nil {
		plan.Reason = REASON_SELECT
		return plan, nil
	}
	// from
	tableName, hasHints := analyzeFrom(sel.From)
	if tableName == "" {
		plan.Reason = REASON_TABLE
		return plan, nil
	}
	tableInfo, err := plan.setTableInfo(tableName, getTable)
	if err != nil {
		return nil, err
	}
	// Don't improve the plan if the select is locking the row
	if sel.Lock != "" {
		plan.Reason = REASON_LOCK
		return plan, nil
	}
	// Further improvements possible only if table is row-cached
	if tableInfo.CacheType == schema.CACHE_NONE || tableInfo.CacheType == schema.CACHE_W {
		plan.Reason = REASON_NOCACHE
		return plan, nil
	}
	// Select expressions
	selects, err := analyzeSelectExprs(sel.SelectExprs, tableInfo)
	if err != nil {
		return nil, err
	}
	if selects == nil {
		// Expression list could not be mapped onto plain columns.
		plan.Reason = REASON_SELECT_LIST
		return plan, nil
	}
	plan.ColumnNumbers = selects
	// where
	conditions := analyzeWhere(sel.Where)
	if conditions == nil {
		plan.Reason = REASON_WHERE
		return plan, nil
	}
	// order
	if sel.OrderBy != nil {
		plan.Reason = REASON_ORDER
		return plan, nil
	}
	// This check should never fail because we only cache tables with primary keys.
	if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
		panic("unexpected")
	}
	// Attempt PK match only if there's no limit clause
	if sel.Limit == nil {
		planId, pkValues, err := getSelectPKValues(conditions, tableInfo.Indexes[0])
		if err != nil {
			return nil, err
		}
		switch planId {
		case PLAN_PK_EQUAL:
			plan.PlanId = PLAN_PK_EQUAL
			plan.OuterQuery = GenerateEqualOuterQuery(sel, tableInfo)
			plan.PKValues = pkValues
			return plan, nil
		case PLAN_PK_IN:
			plan.PlanId = PLAN_PK_IN
			plan.OuterQuery = GenerateInOuterQuery(sel, tableInfo)
			plan.PKValues = pkValues
			return plan, nil
		}
	}
	// The IN-subquery path below assumes a single-column PK.
	if len(tableInfo.Indexes[0].Columns) != 1 {
		plan.Reason = REASON_COMPOSITE_PK
		return plan, nil
	}
	// TODO: Analyze hints to improve plan.
	if hasHints {
		plan.Reason = REASON_HAS_HINTS
		return plan, nil
	}
	plan.IndexUsed = getIndexMatch(conditions, tableInfo.Indexes)
	if plan.IndexUsed == "" {
		plan.Reason = REASON_NOINDEX_MATCH
		return plan, nil
	}
	if plan.IndexUsed == "PRIMARY" {
		// PK access needs no resolving subquery; stay pass-through.
		plan.Reason = REASON_PKINDEX
		return plan, nil
	}
	// TODO: We can further optimize. Change this to pass-through if select list matches all columns in index.
	plan.PlanId = PLAN_SELECT_SUBQUERY
	plan.OuterQuery = GenerateInOuterQuery(sel, tableInfo)
	plan.Subquery = GenerateSelectSubquery(sel, tableInfo, plan.IndexUsed)
	return plan, nil
}
// analyzeSelectExprs maps each select expression to a column index in table.
// A "*" expands to every column in order. It returns (nil, nil) when an
// expression is not a plain column reference, and an error when a referenced
// column does not exist in the table.
func analyzeSelectExprs(exprs sqlparser.SelectExprs, table *schema.Table) (selects []int, err error) {
	result := make([]int, 0, len(exprs))
	for _, raw := range exprs {
		switch node := raw.(type) {
		case *sqlparser.StarExpr:
			// "*": include every column of the table.
			for i := range table.Columns {
				result = append(result, i)
			}
		case *sqlparser.NonStarExpr:
			colName := sqlparser.GetColName(node.Expr)
			if colName == "" {
				// Not a simple column name; caller falls back to pass-through.
				return nil, nil
			}
			idx := table.FindColumn(colName)
			if idx == -1 {
				return nil, fmt.Errorf("column %s not found in table %s", colName, table.Name)
			}
			result = append(result, idx)
		default:
			panic("unreachable")
		}
	}
	return result, nil
}
// analyzeFrom extracts the single table name referenced by a FROM clause,
// along with whether index hints are attached. Joins and non-aliased table
// expressions yield ("", false).
func analyzeFrom(tableExprs sqlparser.TableExprs) (tablename string, hasHints bool) {
	if len(tableExprs) > 1 {
		// More than one table expression means a join: not analyzable.
		return "", false
	}
	aliased, isAliased := tableExprs[0].(*sqlparser.AliasedTableExpr)
	if !isAliased {
		return "", false
	}
	return sqlparser.GetTableName(aliased.Expr), aliased.Hints != nil
}
// analyzeWhere returns the analyzable conditions of a WHERE clause, or nil
// when the clause is absent or too complex to decompose.
func analyzeWhere(node *sqlparser.Where) (conditions []sqlparser.BoolExpr) {
	if node != nil {
		return analyzeBoolean(node.Expr)
	}
	return nil
}
// analyzeBoolean flattens a boolean expression into a list of simple,
// index-analyzable conditions (col <op> value, col IN (tuple), col BETWEEN).
// It returns nil when any part is too complex, when the conjunction would
// contain more than one IN clause, or for OR expressions.
func analyzeBoolean(node sqlparser.BoolExpr) (conditions []sqlparser.BoolExpr) {
	switch node := node.(type) {
	case *sqlparser.AndExpr:
		left := analyzeBoolean(node.Left)
		right := analyzeBoolean(node.Right)
		if left == nil || right == nil {
			return nil
		}
		// At most one IN clause is supported across the whole conjunction.
		if sqlparser.HasINClause(left) && sqlparser.HasINClause(right) {
			return nil
		}
		return append(left, right...)
	case *sqlparser.ParenBoolExpr:
		// Parentheses are transparent for analysis.
		return analyzeBoolean(node.Expr)
	case *sqlparser.ComparisonExpr:
		switch {
		case sqlparser.StringIn(
			node.Operator,
			sqlparser.AST_EQ,
			sqlparser.AST_LT,
			sqlparser.AST_GT,
			sqlparser.AST_LE,
			sqlparser.AST_GE,
			sqlparser.AST_NSE,
			sqlparser.AST_LIKE):
			// Simple comparison: only "column <op> literal/bindvar" qualifies.
			if sqlparser.IsColName(node.Left) && sqlparser.IsValue(node.Right) {
				return []sqlparser.BoolExpr{node}
			}
		case node.Operator == sqlparser.AST_IN:
			// IN: the right side must be a tuple of simple values.
			if sqlparser.IsColName(node.Left) && sqlparser.IsSimpleTuple(node.Right) {
				return []sqlparser.BoolExpr{node}
			}
		}
	case *sqlparser.RangeCond:
		// Only BETWEEN (not NOT BETWEEN) is analyzable.
		if node.Operator != sqlparser.AST_BETWEEN {
			return nil
		}
		if sqlparser.IsColName(node.Left) && sqlparser.IsValue(node.From) && sqlparser.IsValue(node.To) {
			return []sqlparser.BoolExpr{node}
		}
	}
	// Anything unrecognized (including OR) falls through to nil.
	return nil
}

Просмотреть файл

@ -10,7 +10,7 @@ import (
"testing"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
)
func TestQueryRules(t *testing.T) {
@ -52,7 +52,7 @@ func TestQueryRules(t *testing.T) {
func TestCopy(t *testing.T) {
qrs := NewQueryRules()
qr1 := NewQueryRule("rule 1", "r1", QR_FAIL)
qr1.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
qr1.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
qr1.AddTableCond("aa")
qr1.AddBindVarCond("a", true, false, QR_NOOP, nil)
@ -70,7 +70,7 @@ func TestCopy(t *testing.T) {
t.Errorf("want false, got true")
}
qr1.plans[0] = sqlparser.PLAN_INSERT_PK
qr1.plans[0] = planbuilder.PLAN_INSERT_PK
if qr1.plans[0] == qrf1.plans[0] {
t.Errorf("want false, got true")
}
@ -95,12 +95,12 @@ func TestFilterByPlan(t *testing.T) {
qr1 := NewQueryRule("rule 1", "r1", QR_FAIL)
qr1.SetIPCond("123")
qr1.SetQueryCond("select")
qr1.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
qr1.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
qr1.AddBindVarCond("a", true, false, QR_NOOP, nil)
qr2 := NewQueryRule("rule 2", "r2", QR_FAIL)
qr2.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
qr2.AddPlanCond(sqlparser.PLAN_PK_EQUAL)
qr2.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
qr2.AddPlanCond(planbuilder.PLAN_PK_EQUAL)
qr2.AddBindVarCond("a", true, false, QR_NOOP, nil)
qr3 := NewQueryRule("rule 3", "r3", QR_FAIL)
@ -116,7 +116,7 @@ func TestFilterByPlan(t *testing.T) {
qrs.Add(qr3)
qrs.Add(qr4)
qrs1 := qrs.filterByPlan("select", sqlparser.PLAN_PASS_SELECT, "a")
qrs1 := qrs.filterByPlan("select", planbuilder.PLAN_PASS_SELECT, "a")
if l := len(qrs1.rules); l != 3 {
t.Errorf("want 3, got %d", l)
}
@ -130,7 +130,7 @@ func TestFilterByPlan(t *testing.T) {
t.Errorf("want nil, got non-nil")
}
qrs1 = qrs.filterByPlan("insert", sqlparser.PLAN_PASS_SELECT, "a")
qrs1 = qrs.filterByPlan("insert", planbuilder.PLAN_PASS_SELECT, "a")
if l := len(qrs1.rules); l != 1 {
t.Errorf("want 1, got %d", l)
}
@ -138,7 +138,7 @@ func TestFilterByPlan(t *testing.T) {
t.Errorf("want r2, got %s", qrs1.rules[0].Name)
}
qrs1 = qrs.filterByPlan("insert", sqlparser.PLAN_PK_EQUAL, "a")
qrs1 = qrs.filterByPlan("insert", planbuilder.PLAN_PK_EQUAL, "a")
if l := len(qrs1.rules); l != 1 {
t.Errorf("want 1, got %d", l)
}
@ -146,7 +146,7 @@ func TestFilterByPlan(t *testing.T) {
t.Errorf("want r2, got %s", qrs1.rules[0].Name)
}
qrs1 = qrs.filterByPlan("select", sqlparser.PLAN_INSERT_PK, "a")
qrs1 = qrs.filterByPlan("select", planbuilder.PLAN_INSERT_PK, "a")
if l := len(qrs1.rules); l != 1 {
t.Errorf("want 1, got %d", l)
}
@ -154,12 +154,12 @@ func TestFilterByPlan(t *testing.T) {
t.Errorf("want r3, got %s", qrs1.rules[0].Name)
}
qrs1 = qrs.filterByPlan("sel", sqlparser.PLAN_INSERT_PK, "a")
qrs1 = qrs.filterByPlan("sel", planbuilder.PLAN_INSERT_PK, "a")
if qrs1.rules != nil {
t.Errorf("want nil, got non-nil")
}
qrs1 = qrs.filterByPlan("table", sqlparser.PLAN_PASS_DML, "b")
qrs1 = qrs.filterByPlan("table", planbuilder.PLAN_PASS_DML, "b")
if l := len(qrs1.rules); l != 1 {
t.Errorf("want 1, got %#v, %#v", qrs1.rules[0], qrs1.rules[1])
}
@ -170,7 +170,7 @@ func TestFilterByPlan(t *testing.T) {
qr5 := NewQueryRule("rule 5", "r5", QR_FAIL)
qrs.Add(qr5)
qrs1 = qrs.filterByPlan("sel", sqlparser.PLAN_INSERT_PK, "a")
qrs1 = qrs.filterByPlan("sel", planbuilder.PLAN_INSERT_PK, "a")
if l := len(qrs1.rules); l != 1 {
t.Errorf("want 1, got %d", l)
}
@ -179,7 +179,7 @@ func TestFilterByPlan(t *testing.T) {
}
qrsnil1 := NewQueryRules()
if qrsnil2 := qrsnil1.filterByPlan("", sqlparser.PLAN_PASS_SELECT, "a"); qrsnil2.rules != nil {
if qrsnil2 := qrsnil1.filterByPlan("", planbuilder.PLAN_PASS_SELECT, "a"); qrsnil2.rules != nil {
t.Errorf("want nil, got non-nil")
}
}
@ -204,13 +204,13 @@ func TestQueryRule(t *testing.T) {
t.Errorf("want error")
}
qr.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
qr.AddPlanCond(sqlparser.PLAN_INSERT_PK)
qr.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
qr.AddPlanCond(planbuilder.PLAN_INSERT_PK)
if qr.plans[0] != sqlparser.PLAN_PASS_SELECT {
if qr.plans[0] != planbuilder.PLAN_PASS_SELECT {
t.Errorf("want PASS_SELECT, got %s", qr.plans[0].String())
}
if qr.plans[1] != sqlparser.PLAN_INSERT_PK {
if qr.plans[1] != planbuilder.PLAN_INSERT_PK {
t.Errorf("want INSERT_PK, got %s", qr.plans[1].String())
}
@ -563,10 +563,10 @@ func TestImport(t *testing.T) {
if qrs.rules[0].query == nil {
t.Errorf("want non-nil")
}
if qrs.rules[0].plans[0] != sqlparser.PLAN_PASS_SELECT {
if qrs.rules[0].plans[0] != planbuilder.PLAN_PASS_SELECT {
t.Errorf("want PASS_SELECT, got %s", qrs.rules[0].plans[0].String())
}
if qrs.rules[0].plans[1] != sqlparser.PLAN_INSERT_PK {
if qrs.rules[0].plans[1] != planbuilder.PLAN_INSERT_PK {
t.Errorf("want PASS_INSERT_PK, got %s", qrs.rules[0].plans[0].String())
}
if qrs.rules[0].tableNames[0] != "a" {

Просмотреть файл

@ -5,6 +5,7 @@
package tabletserver
import (
"fmt"
"time"
log "github.com/golang/glog"
@ -15,9 +16,12 @@ import (
"github.com/youtube/vitess/go/sync2"
"github.com/youtube/vitess/go/vt/dbconfigs"
"github.com/youtube/vitess/go/vt/dbconnpool"
"github.com/youtube/vitess/go/vt/logutil"
"github.com/youtube/vitess/go/vt/mysqlctl"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
"github.com/youtube/vitess/go/vt/tabletserver/proto"
)
@ -67,6 +71,10 @@ type QueryEngine struct {
strictMode sync2.AtomicInt64
maxResultSize sync2.AtomicInt64
streamBufferSize sync2.AtomicInt64
strictTableAcl bool
// loggers
accessCheckerLogger *logutil.ThrottledLogger
}
type compiledPlan struct {
@ -140,9 +148,13 @@ func NewQueryEngine(config Config) *QueryEngine {
if config.StrictMode {
qe.strictMode.Set(1)
}
qe.strictTableAcl = config.StrictTableAcl
qe.maxResultSize = sync2.AtomicInt64(config.MaxResultSize)
qe.streamBufferSize = sync2.AtomicInt64(config.StreamBufferSize)
// loggers
qe.accessCheckerLogger = logutil.NewThrottledLogger("accessChecker", 1*time.Second)
// Stats
stats.Publish("MaxResultSize", stats.IntFunc(qe.maxResultSize.Get))
stats.Publish("StreamBufferSize", stats.IntFunc(qe.streamBufferSize.Get))
@ -311,7 +323,9 @@ func (qe *QueryEngine) Execute(logStats *SQLQueryStats, query *proto.Query) (rep
panic(NewTabletError(RETRY, "Query disallowed due to rule: %s", desc))
}
if basePlan.PlanId == sqlparser.PLAN_DDL {
qe.checkTableAcl(basePlan.TableName, basePlan.PlanId, basePlan.Authorized, logStats.context.GetUsername())
if basePlan.PlanId == planbuilder.PLAN_DDL {
return qe.execDDL(logStats, query.Sql)
}
@ -331,36 +345,36 @@ func (qe *QueryEngine) Execute(logStats *SQLQueryStats, query *proto.Query) (rep
invalidator = conn.DirtyKeys(plan.TableName)
}
switch plan.PlanId {
case sqlparser.PLAN_PASS_DML:
case planbuilder.PLAN_PASS_DML:
if qe.strictMode.Get() != 0 {
panic(NewTabletError(FAIL, "DML too complex"))
}
reply = qe.directFetch(logStats, conn, plan.FullQuery, plan.BindVars, nil, nil)
case sqlparser.PLAN_INSERT_PK:
case planbuilder.PLAN_INSERT_PK:
reply = qe.execInsertPK(logStats, conn, plan, invalidator)
case sqlparser.PLAN_INSERT_SUBQUERY:
case planbuilder.PLAN_INSERT_SUBQUERY:
reply = qe.execInsertSubquery(logStats, conn, plan, invalidator)
case sqlparser.PLAN_DML_PK:
case planbuilder.PLAN_DML_PK:
reply = qe.execDMLPK(logStats, conn, plan, invalidator)
case sqlparser.PLAN_DML_SUBQUERY:
case planbuilder.PLAN_DML_SUBQUERY:
reply = qe.execDMLSubquery(logStats, conn, plan, invalidator)
default: // select or set in a transaction, just count as select
reply = qe.execDirect(logStats, plan, conn)
}
} else {
switch plan.PlanId {
case sqlparser.PLAN_PASS_SELECT:
if plan.Reason == sqlparser.REASON_LOCK {
case planbuilder.PLAN_PASS_SELECT:
if plan.Reason == planbuilder.REASON_LOCK {
panic(NewTabletError(FAIL, "Disallowed outside transaction"))
}
reply = qe.execSelect(logStats, plan)
case sqlparser.PLAN_PK_EQUAL:
case planbuilder.PLAN_PK_EQUAL:
reply = qe.execPKEqual(logStats, plan)
case sqlparser.PLAN_PK_IN:
case planbuilder.PLAN_PK_IN:
reply = qe.execPKIN(logStats, plan)
case sqlparser.PLAN_SELECT_SUBQUERY:
case planbuilder.PLAN_SELECT_SUBQUERY:
reply = qe.execSubquery(logStats, plan)
case sqlparser.PLAN_SET:
case planbuilder.PLAN_SET:
waitingForConnectionStart := time.Now()
conn := getOrPanic(qe.connPool)
logStats.WaitingForConnection += time.Now().Sub(waitingForConnectionStart)
@ -395,6 +409,9 @@ func (qe *QueryEngine) StreamExecute(logStats *SQLQueryStats, query *proto.Query
logStats.OriginalSql = query.Sql
defer queryStats.Record("SELECT_STREAM", time.Now())
authorized := tableacl.Authorized(plan.TableName, plan.PlanId.MinRole())
qe.checkTableAcl(plan.TableName, plan.PlanId, authorized, logStats.context.GetUsername())
// does the real work: first get a connection
waitingForConnectionStart := time.Now()
conn := getOrPanic(qe.streamConnPool)
@ -410,6 +427,16 @@ func (qe *QueryEngine) StreamExecute(logStats *SQLQueryStats, query *proto.Query
qe.fullStreamFetch(logStats, conn, plan.FullQuery, query.BindVariables, nil, nil, sendReply)
}
func (qe *QueryEngine) checkTableAcl(table string, planId planbuilder.PlanType, authorized tableacl.ACL, user string) {
if !authorized.IsMember(user) {
err := fmt.Sprintf("table acl error: %v cannot run %v on table %v", user, planId, table)
if qe.strictTableAcl {
panic(NewTabletError(FAIL, err))
}
qe.accessCheckerLogger.Errorf(err)
}
}
// InvalidateForDml performs rowcache invalidations for the dml.
func (qe *QueryEngine) InvalidateForDml(dml *proto.DmlType) {
if qe.cachePool.IsClosed() {
@ -435,7 +462,7 @@ func (qe *QueryEngine) InvalidateForDml(dml *proto.DmlType) {
// InvalidateForDDL performs schema and rowcache changes for the ddl.
func (qe *QueryEngine) InvalidateForDDL(ddlInvalidate *proto.DDLInvalidate) {
ddlPlan := sqlparser.DDLParse(ddlInvalidate.DDL)
ddlPlan := planbuilder.DDLParse(ddlInvalidate.DDL)
if ddlPlan.Action == "" {
panic(NewTabletError(FAIL, "DDL is not understood"))
}
@ -449,7 +476,7 @@ func (qe *QueryEngine) InvalidateForDDL(ddlInvalidate *proto.DDLInvalidate) {
// DDL
func (qe *QueryEngine) execDDL(logStats *SQLQueryStats, ddl string) *mproto.QueryResult {
ddlPlan := sqlparser.DDLParse(ddl)
ddlPlan := planbuilder.DDLParse(ddl)
if ddlPlan.Action == "" {
panic(NewTabletError(FAIL, "DDL is not understood"))
}

Просмотреть файл

@ -11,7 +11,7 @@ import (
"strconv"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
)
//-----------------------------------------------
@ -90,7 +90,7 @@ func (qrs *QueryRules) UnmarshalJSON(data []byte) (err error) {
// filterByPlan creates a new QueryRules by prefiltering on the query and planId. This allows
// us to create query plan specific QueryRules out of the original QueryRules. In the new rules,
// query, plans and tableNames predicates are empty.
func (qrs *QueryRules) filterByPlan(query string, planid sqlparser.PlanType, tableName string) (newqrs *QueryRules) {
func (qrs *QueryRules) filterByPlan(query string, planid planbuilder.PlanType, tableName string) (newqrs *QueryRules) {
var newrules []*QueryRule
for _, qr := range qrs.rules {
if newrule := qr.filterByPlan(query, planid, tableName); newrule != nil {
@ -129,7 +129,7 @@ type QueryRule struct {
requestIP, user, query *regexp.Regexp
// Any matched plan will make this condition true (OR)
plans []sqlparser.PlanType
plans []planbuilder.PlanType
// Any matched tableNames will make this condition true (OR)
tableNames []string
@ -158,7 +158,7 @@ func (qr *QueryRule) Copy() (newqr *QueryRule) {
act: qr.act,
}
if qr.plans != nil {
newqr.plans = make([]sqlparser.PlanType, len(qr.plans))
newqr.plans = make([]planbuilder.PlanType, len(qr.plans))
copy(newqr.plans, qr.plans)
}
if qr.tableNames != nil {
@ -189,7 +189,7 @@ func (qr *QueryRule) SetUserCond(pattern string) (err error) {
// AddPlanCond adds to the list of plans that can be matched for
// the rule to fire.
// This function acts as an OR: Any plan id match is considered a match.
func (qr *QueryRule) AddPlanCond(planType sqlparser.PlanType) {
func (qr *QueryRule) AddPlanCond(planType planbuilder.PlanType) {
qr.plans = append(qr.plans, planType)
}
@ -279,7 +279,7 @@ Error:
// The new QueryRule will contain all the original constraints other
// than the plan and query. If the plan and query don't match the QueryRule,
// then it returns nil.
func (qr *QueryRule) filterByPlan(query string, planid sqlparser.PlanType, tableName string) (newqr *QueryRule) {
func (qr *QueryRule) filterByPlan(query string, planid planbuilder.PlanType, tableName string) (newqr *QueryRule) {
if !reMatch(qr.query, query) {
return nil
}
@ -315,7 +315,7 @@ func reMatch(re *regexp.Regexp, val string) bool {
return re == nil || re.MatchString(val)
}
func planMatch(plans []sqlparser.PlanType, plan sqlparser.PlanType) bool {
func planMatch(plans []planbuilder.PlanType, plan planbuilder.PlanType) bool {
if plans == nil {
return true
}
@ -728,7 +728,7 @@ func buildQueryRule(ruleInfo map[string]interface{}) (qr *QueryRule, err error)
if !ok {
return nil, NewTabletError(FAIL, "want string for Plans")
}
pt, ok := sqlparser.PlanByName(pv)
pt, ok := planbuilder.PlanByName(pv)
if !ok {
return nil, NewTabletError(FAIL, "invalid plan name: %s", pv)
}

Просмотреть файл

@ -41,6 +41,7 @@ func init() {
flag.Float64Var(&qsConfig.IdleTimeout, "queryserver-config-idle-timeout", DefaultQsConfig.IdleTimeout, "query server idle timeout")
flag.Float64Var(&qsConfig.SpotCheckRatio, "queryserver-config-spot-check-ratio", DefaultQsConfig.SpotCheckRatio, "query server rowcache spot check frequency")
flag.BoolVar(&qsConfig.StrictMode, "queryserver-config-strict-mode", DefaultQsConfig.StrictMode, "allow only predictable DMLs and enforces MySQL's STRICT_TRANS_TABLES")
flag.BoolVar(&qsConfig.StrictTableAcl, "queryserver-config-strict-table-acl", DefaultQsConfig.StrictTableAcl, "only allow queries that pass table acl checks")
flag.StringVar(&qsConfig.RowCache.Binary, "rowcache-bin", DefaultQsConfig.RowCache.Binary, "rowcache binary file")
flag.IntVar(&qsConfig.RowCache.Memory, "rowcache-memory", DefaultQsConfig.RowCache.Memory, "rowcache max memory usage in MB")
flag.StringVar(&qsConfig.RowCache.Socket, "rowcache-socket", DefaultQsConfig.RowCache.Socket, "rowcache socket path to listen on")
@ -102,6 +103,7 @@ type Config struct {
RowCache RowCacheConfig
SpotCheckRatio float64
StrictMode bool
StrictTableAcl bool
}
// DefaultQSConfig is the default value for the query service config.
@ -126,6 +128,7 @@ var DefaultQsConfig = Config{
RowCache: RowCacheConfig{Memory: -1, TcpPort: -1, Connections: -1, Threads: -1},
SpotCheckRatio: 0,
StrictMode: true,
StrictTableAcl: false,
}
var qsConfig Config

Просмотреть файл

@ -12,7 +12,7 @@ import (
"time"
"github.com/youtube/vitess/go/acl"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
)
var (
@ -52,7 +52,7 @@ var (
type queryzRow struct {
Query string
Table string
Plan sqlparser.PlanType
Plan planbuilder.PlanType
Count int64
tm time.Duration
Rows int64

Просмотреть файл

@ -22,7 +22,8 @@ import (
"github.com/youtube/vitess/go/timer"
"github.com/youtube/vitess/go/vt/dbconnpool"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tableacl"
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
)
const base_show_tables = "select table_name, table_type, unix_timestamp(create_time), table_comment from information_schema.tables where table_schema = database()"
@ -30,10 +31,11 @@ const base_show_tables = "select table_name, table_type, unix_timestamp(create_t
const maxTableCount = 10000
type ExecPlan struct {
*sqlparser.ExecPlan
TableInfo *TableInfo
Fields []mproto.Field
Rules *QueryRules
*planbuilder.ExecPlan
TableInfo *TableInfo
Fields []mproto.Field
Rules *QueryRules
Authorized tableacl.ACL
mu sync.Mutex
QueryCount int64
@ -318,12 +320,13 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
}
return tableInfo.Table, true
}
splan, err := sqlparser.ExecParse(sql, GetTable)
splan, err := planbuilder.GetExecPlan(sql, GetTable)
if err != nil {
panic(NewTabletError(FAIL, "%s", err))
}
plan = &ExecPlan{ExecPlan: splan, TableInfo: tableInfo}
plan.Rules = si.rules.filterByPlan(sql, plan.PlanId, plan.TableName)
plan.Authorized = tableacl.Authorized(plan.TableName, plan.PlanId.MinRole())
if plan.PlanId.IsSelect() {
if plan.FieldQuery == nil {
log.Warningf("Cannot cache field info: %s", sql)
@ -340,7 +343,7 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
}
plan.Fields = r.Fields
}
} else if plan.PlanId == sqlparser.PLAN_DDL || plan.PlanId == sqlparser.PLAN_SET {
} else if plan.PlanId == planbuilder.PLAN_DDL || plan.PlanId == planbuilder.PLAN_SET {
return plan
}
si.queries.Set(sql, plan)
@ -349,8 +352,15 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
// GetStreamPlan is similar to GetPlan, but doesn't use the cache
// and doesn't enforce a limit. It also just returns the parsed query.
func (si *SchemaInfo) GetStreamPlan(sql string) *sqlparser.StreamExecPlan {
plan, err := sqlparser.StreamExecParse(sql)
func (si *SchemaInfo) GetStreamPlan(sql string) *planbuilder.ExecPlan {
GetTable := func(tableName string) (*schema.Table, bool) {
tableInfo, ok := si.tables[tableName]
if !ok {
return nil, false
}
return tableInfo.Table, true
}
plan, err := planbuilder.GetStreamExecPlan(sql, GetTable)
if err != nil {
panic(NewTabletError(FAIL, "%s", err))
}
@ -489,7 +499,7 @@ func (si *SchemaInfo) getQueryStats(f queryStatsFunc) map[string]int64 {
type perQueryStats struct {
Query string
Table string
Plan sqlparser.PlanType
Plan planbuilder.PlanType
QueryCount int64
Time time.Duration
RowCount int64

Просмотреть файл

@ -23,6 +23,9 @@ type SrvShard struct {
KeyRange key.KeyRange
ServedTypes []TabletType
// MasterCell indicates the cell that master tablet resides
MasterCell string
// TabletTypes represents the list of types we have serving tablets
// for, in this cell only.
TabletTypes []TabletType

Просмотреть файл

@ -9,7 +9,6 @@ package topo
import (
"bytes"
"github.com/youtube/vitess/go/bson"
"github.com/youtube/vitess/go/bytes2"
)
@ -30,6 +29,7 @@ func (srvShard *SrvShard) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
}
lenWriter.Close()
}
bson.EncodeString(buf, "MasterCell", srvShard.MasterCell)
// []TabletType
{
bson.EncodePrefix(buf, bson.Array, "TabletTypes")
@ -76,6 +76,8 @@ func (srvShard *SrvShard) UnmarshalBson(buf *bytes.Buffer, kind byte) {
srvShard.ServedTypes = append(srvShard.ServedTypes, _v1)
}
}
case "MasterCell":
srvShard.MasterCell = bson.DecodeString(buf, kind)
case "TabletTypes":
// []TabletType
if kind != bson.Null {

Просмотреть файл

@ -41,6 +41,7 @@ func TestSrvKeySpace(t *testing.T) {
SrvShard{
Name: "test_shard",
ServedTypes: []TabletType{TYPE_MASTER},
MasterCell: "test_cell",
},
},
},
@ -64,6 +65,7 @@ func TestSrvKeySpace(t *testing.T) {
SrvShard{
Name: "test_shard",
ServedTypes: []TabletType{TYPE_MASTER},
MasterCell: "test_cell",
TabletTypes: []TabletType{},
},
},

Просмотреть файл

@ -207,6 +207,7 @@ func rebuildCellSrvShard(ts topo.Server, shardInfo *topo.ShardInfo, cell string,
Name: shardInfo.ShardName(),
KeyRange: shardInfo.KeyRange,
ServedTypes: shardInfo.ServedTypes,
MasterCell: shardInfo.MasterAlias.Cell,
TabletTypes: make([]topo.TabletType, 0, len(locationAddrsMap)),
}
for tabletType := range locationAddrsMap {

290
go/vt/vtgate/router.go Normal file
Просмотреть файл

@ -0,0 +1,290 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgate
import (
"fmt"
"strconv"
"github.com/youtube/vitess/go/vt/key"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// Node classifications returned by routingAnalyzeValue.
const (
	EID_NODE   = iota // reference to the sharding column "entity_id"
	VALUE_NODE        // a single routable value (literal or bind variable)
	LIST_NODE         // a tuple of routable values (used by IN clauses)
	OTHER_NODE        // anything else: not usable for routing
)

// RoutingPlan holds the single analyzable routing criteria extracted from a
// statement; a nil criteria means the statement fans out to all shards.
type RoutingPlan struct {
	criteria sqlparser.SQLNode
}
// GetShardList computes the indexes (into tabletKeys) of the shards the
// given SQL statement must be routed to, resolving bind variables from
// bindVariables.
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
	routingPlan, planErr := buildPlan(sql)
	if planErr != nil {
		return nil, planErr
	}
	return shardListFromPlan(routingPlan, bindVariables, tabletKeys)
}
// buildPlan parses sql and derives its routing plan.
func buildPlan(sql string) (plan *RoutingPlan, err error) {
	statement, parseErr := sqlparser.Parse(sql)
	if parseErr != nil {
		return nil, parseErr
	}
	return getRoutingPlan(statement)
}
// shardListFromPlan evaluates a routing plan against concrete bind variables
// and returns the candidate shard indexes (into tabletKeys). A nil criteria,
// or any operator not handled below, routes to all shards.
// NOTE(review): the range cases assume tabletKeys is sorted by keyspace id so
// that index order matches key order — confirm with callers.
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int, err error) {
	if plan.criteria == nil {
		return makeList(0, len(tabletKeys)), nil
	}
	switch criteria := plan.criteria.(type) {
	case sqlparser.Values:
		// Insert: every row must resolve to one shard.
		index, err := findInsertShard(criteria, bindVariables, tabletKeys)
		if err != nil {
			return nil, err
		}
		return []int{index}, nil
	case *sqlparser.ComparisonExpr:
		switch criteria.Operator {
		case "=", "<=>":
			// Equality: exactly one shard.
			index, err := findShard(criteria.Right, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			return []int{index}, nil
		case "<", "<=":
			// Upper bound: all shards up to and including the match.
			index, err := findShard(criteria.Right, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			return makeList(0, index+1), nil
		case ">", ">=":
			// Lower bound: the matching shard and everything after it.
			index, err := findShard(criteria.Right, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			return makeList(index, len(tabletKeys)), nil
		case "in":
			return findShardList(criteria.Right, bindVariables, tabletKeys)
		}
	case *sqlparser.RangeCond:
		if criteria.Operator == "between" {
			start, err := findShard(criteria.From, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			last, err := findShard(criteria.To, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			// Normalize a reversed BETWEEN so the range is always valid.
			if last < start {
				start, last = last, start
			}
			return makeList(start, last+1), nil
		}
	}
	return makeList(0, len(tabletKeys)), nil
}
// getRoutingPlan extracts the routing criteria from a parsed statement.
// INSERT ... SELECT recurses into the inner SELECT; a plain INSERT uses its
// VALUES list; SELECT/UPDATE/DELETE use the WHERE clause. Statements with no
// usable criteria yield a plan that fans out to all shards.
func getRoutingPlan(statement sqlparser.Statement) (plan *RoutingPlan, err error) {
	result := &RoutingPlan{}
	if ins, isInsert := statement.(*sqlparser.Insert); isInsert {
		if sel, isSelect := ins.Rows.(sqlparser.SelectStatement); isSelect {
			// INSERT ... SELECT: route on the inner select.
			return getRoutingPlan(sel)
		}
		values, valErr := routingAnalyzeValues(ins.Rows.(sqlparser.Values))
		if valErr != nil {
			return nil, valErr
		}
		result.criteria = values
		return result, nil
	}
	var where *sqlparser.Where
	switch stmt := statement.(type) {
	case *sqlparser.Select:
		where = stmt.Where
	case *sqlparser.Update:
		where = stmt.Where
	case *sqlparser.Delete:
		where = stmt.Where
	}
	if where != nil {
		result.criteria = routingAnalyzeBoolean(where.Expr)
	}
	return result, nil
}
// routingAnalyzeValues verifies that every row of an insert VALUES list
// begins with a routable value and returns the list unchanged; any other
// shape is rejected as too complex.
func routingAnalyzeValues(vals sqlparser.Values) (sqlparser.Values, error) {
	for _, row := range vals {
		tuple, isTuple := row.(sqlparser.ValTuple)
		if !isTuple {
			return nil, fmt.Errorf("insert is too complex")
		}
		// Only the first value of each row (the sharding key) matters.
		if routingAnalyzeValue(tuple[0]) != VALUE_NODE {
			return nil, fmt.Errorf("insert is too complex")
		}
	}
	return vals, nil
}
// routingAnalyzeBoolean walks a boolean expression looking for exactly one
// usable routing condition on the sharding column. It returns that condition,
// or nil when none (or more than one) is found.
func routingAnalyzeBoolean(node sqlparser.BoolExpr) sqlparser.BoolExpr {
	switch node := node.(type) {
	case *sqlparser.AndExpr:
		left := routingAnalyzeBoolean(node.Left)
		right := routingAnalyzeBoolean(node.Right)
		// Two routable conditions in one AND is ambiguous: give up and
		// fan out rather than pick one arbitrarily.
		if left != nil && right != nil {
			return nil
		} else if left != nil {
			return left
		} else {
			return right
		}
	case *sqlparser.ParenBoolExpr:
		// Parentheses are transparent for analysis.
		return routingAnalyzeBoolean(node.Expr)
	case *sqlparser.ComparisonExpr:
		switch {
		case sqlparser.StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
			// Comparison is routable when one side is the sharding column
			// and the other is a single value, in either order.
			left := routingAnalyzeValue(node.Left)
			right := routingAnalyzeValue(node.Right)
			if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
				return node
			}
		case node.Operator == "in":
			// IN is routable only as "entity_id IN (values...)".
			left := routingAnalyzeValue(node.Left)
			right := routingAnalyzeValue(node.Right)
			if left == EID_NODE && right == LIST_NODE {
				return node
			}
		}
	case *sqlparser.RangeCond:
		// Only BETWEEN (not NOT BETWEEN) is routable.
		if node.Operator != "between" {
			return nil
		}
		left := routingAnalyzeValue(node.Left)
		from := routingAnalyzeValue(node.From)
		to := routingAnalyzeValue(node.To)
		if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
			return node
		}
	}
	// OR expressions and anything unrecognized fall through here.
	return nil
}
// routingAnalyzeValue classifies a value expression as EID_NODE (the
// hard-coded sharding column "entity_id"), VALUE_NODE (a routable literal or
// bind variable), LIST_NODE (a tuple of routable values), or OTHER_NODE.
func routingAnalyzeValue(valExpr sqlparser.ValExpr) int {
	switch node := valExpr.(type) {
	case *sqlparser.ColName:
		// Only the sharding column itself is routable.
		if string(node.Name) == "entity_id" {
			return EID_NODE
		}
		return OTHER_NODE
	case sqlparser.ValTuple:
		// A tuple is routable only if every element is a plain value.
		for _, elem := range node {
			if routingAnalyzeValue(elem) != VALUE_NODE {
				return OTHER_NODE
			}
		}
		return LIST_NODE
	case sqlparser.StrVal, sqlparser.NumVal, sqlparser.ValArg:
		return VALUE_NODE
	default:
		return OTHER_NODE
	}
}
// findShardList resolves an IN-clause tuple to the deduplicated set of shard
// indexes its values map to. A non-tuple expression yields an empty list.
func findShardList(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) ([]int, error) {
	seen := make(map[int]bool)
	if tuple, isTuple := valExpr.(sqlparser.ValTuple); isTuple {
		for _, elem := range tuple {
			shard, err := findShard(elem, bindVariables, tabletKeys)
			if err != nil {
				return nil, err
			}
			seen[shard] = true
		}
	}
	shards := make([]int, 0, len(seen))
	for shard := range seen {
		shards = append(shards, shard)
	}
	return shards, nil
}
// findInsertShard determines the single shard targeted by an insert's VALUES
// list. Every row's first value (the sharding key) must map to the same
// shard; otherwise the insert is rejected.
func findInsertShard(vals sqlparser.Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
	shard := -1
	for _, row := range vals {
		firstVal := row.(sqlparser.ValTuple)[0]
		cur, err := findShard(firstVal, bindVariables, tabletKeys)
		if err != nil {
			return -1, err
		}
		switch {
		case shard == -1:
			shard = cur
		case shard != cur:
			return -1, fmt.Errorf("insert has multiple shard targets")
		}
	}
	return shard, nil
}
// findShard resolves a single value expression to the index of the shard
// that owns it.
func findShard(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
	boundVal, err := getBoundValue(valExpr, bindVariables)
	if err != nil {
		return -1, err
	}
	return key.FindShardForValue(boundVal, tabletKeys), nil
}
// getBoundValue resolves a value expression to the string form used for
// keyspace-id comparison: strings pass through, numbers become uint64 keys,
// and bind variables are looked up and encoded. Unsupported node types panic.
func getBoundValue(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}) (string, error) {
	switch node := valExpr.(type) {
	case sqlparser.ValTuple:
		if len(node) != 1 {
			return "", fmt.Errorf("tuples not allowed as insert values")
		}
		// TODO: Change parser to create single value tuples into non-tuples.
		return getBoundValue(node[0], bindVariables)
	case sqlparser.StrVal:
		return string(node), nil
	case sqlparser.NumVal:
		// NOTE(review): parsed as signed int64 but encoded via Uint64Key, so
		// a negative literal would wrap — confirm negatives cannot reach here.
		val, err := strconv.ParseInt(string(node), 10, 64)
		if err != nil {
			return "", err
		}
		return key.Uint64Key(val).String(), nil
	case sqlparser.ValArg:
		value, err := findBindValue(node, bindVariables)
		if err != nil {
			return "", err
		}
		return key.EncodeValue(value), nil
	}
	panic("Unexpected token")
}
// findBindValue looks up the value bound to valArg (a ":name" placeholder) in
// bindVariables. The leading ':' is stripped before the lookup. It returns an
// error when the map is nil or the name is not bound.
func findBindValue(valArg sqlparser.ValArg, bindVariables map[string]interface{}) (interface{}, error) {
	if bindVariables == nil {
		// Use %s instead of concatenating valArg into the format string:
		// a '%' inside the placeholder name would otherwise be misparsed
		// as a format verb and garble the message.
		return nil, fmt.Errorf("No bind variable for %s", string(valArg))
	}
	value, ok := bindVariables[string(valArg[1:])]
	if !ok {
		return nil, fmt.Errorf("No bind variable for %s", string(valArg))
	}
	return value, nil
}
// makeList returns the integers in the half-open range [start, end).
func makeList(start, end int) []int {
	list := make([]int, 0, end-start)
	for v := start; v < end; v++ {
		list = append(list, v)
	}
	return list
}

105
go/vt/vtgate/router_test.go Normal file
Просмотреть файл

@ -0,0 +1,105 @@
// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package vtgate
import (
"bufio"
"fmt"
"io"
"os"
"sort"
"strings"
"testing"
"github.com/youtube/vitess/go/testfiles"
"github.com/youtube/vitess/go/vt/key"
)
// TestRouting runs table-driven routing cases from
// sqlparser_test/routing_cases.txt against a fixed set of shard boundary
// keys. A case's expected output is either the sorted shard list or an
// error message; an empty expectation means the input echoes back.
func TestRouting(t *testing.T) {
	// Shard boundaries: mixed binary and printable keyspace ids.
	tabletkeys := []key.KeyspaceId{
		"\x00\x00\x00\x00\x00\x00\x00\x02",
		"\x00\x00\x00\x00\x00\x00\x00\x04",
		"\x00\x00\x00\x00\x00\x00\x00\x06",
		"a",
		"b",
		"d",
	}
	// Bind variables referenced by the test-case SQL via ":name".
	bindVariables := make(map[string]interface{})
	bindVariables["id0"] = 0
	bindVariables["id2"] = 2
	bindVariables["id3"] = 3
	bindVariables["id4"] = 4
	bindVariables["id6"] = 6
	bindVariables["id8"] = 8
	bindVariables["ids"] = []interface{}{1, 4}
	bindVariables["a"] = "a"
	bindVariables["b"] = "b"
	bindVariables["c"] = "c"
	bindVariables["d"] = "d"
	bindVariables["e"] = "e"
	for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
		if tcase.output == "" {
			tcase.output = tcase.input
		}
		out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
		if err != nil {
			// Error cases: the expected output is the error text.
			if err.Error() != tcase.output {
				t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
			}
			continue
		}
		// Shard order is nondeterministic for IN clauses; sort before compare.
		sort.Ints(out)
		outstr := fmt.Sprintf("%v", out)
		if outstr != tcase.output {
			t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
		}
	}
}
// TODO(sougou): This is now duplicated in three places. Refactor.

// testCase is one line of a routing test file: the input SQL and the
// expected output (a shard list or an error message).
type testCase struct {
	file   string // source file the case came from
	lineno int    // 1-based line number within that file
	input  string // text before the '#' separator
	output string // text after '#'; empty when no expectation is given
}
// iterateFiles globs for test data files matching pattern and streams
// their parsed test cases over the returned channel. The channel is
// closed once all files are exhausted. Parsing errors panic, which is
// acceptable in test helpers.
func iterateFiles(pattern string) (testCaseIterator chan testCase) {
	names := testfiles.Glob(pattern)
	testCaseIterator = make(chan testCase)
	go func() {
		defer close(testCaseIterator)
		for _, name := range names {
			// One helper call per file so the deferred Close fires
			// after each file rather than piling up until the
			// goroutine exits.
			iterateFile(name, testCaseIterator)
		}
	}()
	return testCaseIterator
}

// iterateFile reads one test data file line by line and sends a
// testCase for every non-empty line. Unlike the previous version, the
// file descriptor is closed when done, and a final line without a
// trailing newline is still emitted instead of being dropped at EOF.
func iterateFile(name string, out chan testCase) {
	fd, err := os.Open(name)
	if err != nil {
		panic(fmt.Sprintf("Could not open file %s", name))
	}
	defer fd.Close()
	r := bufio.NewReader(fd)
	lineno := 0
	for {
		line, err := r.ReadString('\n')
		lineno++
		if err != nil && err != io.EOF {
			panic(fmt.Sprintf("Error reading file %s: %s", name, err.Error()))
		}
		lines := strings.Split(strings.TrimRight(line, "\n"), "#")
		input := lines[0]
		output := ""
		if len(lines) > 1 {
			output = lines[1]
		}
		if input != "" {
			out <- testCase{name, lineno, input, output}
		}
		if err == io.EOF {
			return
		}
	}
}

Просмотреть файл

@ -51,6 +51,15 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
go func(tablet *topo.TabletInfo) {
defer wg.Done()
var err error
defer func() {
if err != nil {
mutex.Lock()
lastError = err
mutex.Unlock()
}
}()
if tablet.Type == topo.TYPE_LAG {
log.Infof(" skipping slave position check for %v tablet %v", tablet.Type, tablet.Alias)
return
@ -58,9 +67,6 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
replPos, err := wr.ai.SlavePosition(tablet, wr.actionTimeout())
if err != nil {
mutex.Lock()
lastError = err
mutex.Unlock()
if tablet.Type == topo.TYPE_BACKUP {
log.Warningf(" failed to get slave position from backup tablet %v, either wait for backup to finish or scrap tablet (%v)", tablet.Alias, err)
} else {
@ -70,13 +76,18 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
}
if !masterIsDead {
// This case used to be handled by the timeout check below, but checking
// it explicitly provides a more informative error message.
if replPos.SecondsBehindMaster == myproto.InvalidLagSeconds {
err = fmt.Errorf("slave %v is not replicating (Slave_IO or Slave_SQL not running), can't complete reparent in time", tablet.Alias)
log.Errorf(" %v", err)
return
}
var dur time.Duration = time.Duration(uint(time.Second) * replPos.SecondsBehindMaster)
if dur > wr.actionTimeout() {
err = fmt.Errorf("slave is too far behind to complete reparent in time (%v>%v), either increase timeout using 'vtctl -wait-time XXX ReparentShard ...' or scrap tablet %v", dur, wr.actionTimeout(), tablet.Alias)
log.Errorf(" %v", err)
mutex.Lock()
lastError = err
mutex.Unlock()
return
}

Просмотреть файл

@ -93,10 +93,10 @@ func (wr *Wrangler) shardExternallyReparentedLocked(keyspace, shard string, mast
}
// Compute the list of Cells we need to rebuild: old master and
// new master cells.
// all other cells if reparenting to another cell.
cells := []string{shardInfo.MasterAlias.Cell}
if shardInfo.MasterAlias.Cell != masterElectTabletAlias.Cell {
cells = append(cells, masterElectTabletAlias.Cell)
cells = nil
}
// now update the master record in the shard object

Просмотреть файл

@ -32,6 +32,9 @@ class KeyRange(codec.BSONCoding):
return keyrange_constants.NON_PARTIAL_KEYRANGE
return '%s-%s' % (self.Start, self.End)
def __repr__(self):
return 'KeyRange(%r-%r)' % (self.Start, self.End)
def bson_encode(self):
return {"Start": self.Start, "End": self.End}

Просмотреть файл

@ -82,17 +82,20 @@ class VTGateCursor(object):
# FIXME(shrutip): these checks maybe better on vtgate server.
if topology.is_sharded_keyspace(self.keyspace, self.tablet_type):
if self.keyspace_ids is None or len(self.keyspace_ids) != 1:
raise dbexceptions.ProgrammingError('DML on zero or multiple keyspace ids is not allowed')
raise dbexceptions.ProgrammingError('DML on zero or multiple keyspace ids is not allowed: %r'
% self.keyspace_ids)
else:
if not self.keyranges or str(self.keyranges[0]) != keyrange_constants.NON_PARTIAL_KEYRANGE:
raise dbexceptions.ProgrammingError('Keyrange not correct for non-sharded keyspace')
raise dbexceptions.ProgrammingError('Keyrange not correct for non-sharded keyspace: %r'
% self.keyranges)
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
self.index = 0
return self.rowcount
@ -107,12 +110,13 @@ class VTGateCursor(object):
if write_query:
raise dbexceptions.DatabaseError('execute_entity_ids is not allowed for write queries')
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(sql,
bind_variables,
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name)
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
entity_keyspace_id_map,
entity_column_name)
self.index = 0
return self.rowcount
@ -225,12 +229,13 @@ class StreamVTGateCursor(VTGateCursor):
raise dbexceptions.ProgrammingError('Streaming query cannot be writable')
self.description = None
x, y, z, self.description = self._conn._stream_execute(sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
x, y, z, self.description = self._conn._stream_execute(
sql,
bind_variables,
self.keyspace,
self.tablet_type,
keyspace_ids=self.keyspace_ids,
keyranges=self.keyranges)
self.index = 0
return 0
@ -241,9 +246,9 @@ class StreamVTGateCursor(VTGateCursor):
self.index += 1
return self._conn._stream_next()
# fetchmany can be called until it returns no rows. Returning less rows
# than what we asked for is also an indication we ran out, but the cursor
# API in PEP249 is silent about that.
# fetchmany can be called until it returns no rows. Returning less rows
# than what we asked for is also an indication we ran out, but the cursor
# API in PEP249 is silent about that.
def fetchmany(self, size=None):
if size is None:
size = self.arraysize

51
test/mysql_flavor.py Normal file
Просмотреть файл

@ -0,0 +1,51 @@
#!/usr/bin/python
import environment
class MysqlFlavor(object):
  """Base class providing the default SQL statements for a MySQL flavor."""

  def promote_slave_commands(self):
    """Returns commands to convert a slave to a master."""
    cmds = ["RESET MASTER", "STOP SLAVE", "RESET SLAVE"]
    cmds.append("CHANGE MASTER TO MASTER_HOST = ''")
    return cmds

  def reset_replication_commands(self):
    """Returns commands to wipe all replication state on a tablet."""
    cmds = ["RESET MASTER", "STOP SLAVE", "RESET SLAVE"]
    cmds.append('CHANGE MASTER TO MASTER_HOST = ""')
    return cmds
class GoogleMysql(MysqlFlavor):
  """Overrides specific to Google MySQL; currently inherits all defaults."""
class MariaDB(MysqlFlavor):
  """Overrides specific to MariaDB.

  Both command sets omit the CHANGE MASTER statement used by the base
  class; only the three reset statements are issued.
  """

  def promote_slave_commands(self):
    return ["RESET MASTER", "STOP SLAVE", "RESET SLAVE"]

  def reset_replication_commands(self):
    return ["RESET MASTER", "STOP SLAVE", "RESET SLAVE"]
# Module-level singleton selecting the flavor implementation from the
# environment. Any value other than "MariaDB" (including unset) falls
# back to GoogleMysql.
if environment.mysql_flavor == "MariaDB":
  mysql_flavor = MariaDB()
else:
  mysql_flavor = GoogleMysql()

Просмотреть файл

@ -490,3 +490,77 @@ class TestNocache(framework.TestCase):
error_count = self.env.run_cases(nocache_cases.cases)
if error_count != 0:
self.fail("test_execution errors: %d"%(error_count))
def test_table_acl_no_access(self):
self.env.conn.begin()
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("select * from vtocc_acl_no_access where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("delete from vtocc_acl_no_access where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("alter table vtocc_acl_no_access comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
cu.execute("select * from vtocc_acl_no_access where key1=1", {})
cu.close()
def test_table_acl_read_only(self):
self.env.conn.begin()
self.env.execute("select * from vtocc_acl_read_only where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("delete from vtocc_acl_read_only where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("alter table vtocc_acl_read_only comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
cu.execute("select * from vtocc_acl_read_only where key1=1", {})
cu.fetchall()
cu.close()
def test_table_acl_read_write(self):
self.env.conn.begin()
self.env.execute("select * from vtocc_acl_read_write where key1=1")
self.env.execute("delete from vtocc_acl_read_write where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("alter table vtocc_acl_read_write comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
cu.execute("select * from vtocc_acl_read_write where key1=1", {})
cu.fetchall()
cu.close()
def test_table_acl_admin(self):
self.env.conn.begin()
self.env.execute("select * from vtocc_acl_admin where key1=1")
self.env.execute("delete from vtocc_acl_admin where key1=1")
self.env.execute("alter table vtocc_acl_admin comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
cu.execute("select * from vtocc_acl_admin where key1=1", {})
cu.fetchall()
cu.close()
def test_table_acl_unmatched(self):
self.env.conn.begin()
self.env.execute("select * from vtocc_acl_unmatched where key1=1")
self.env.execute("delete from vtocc_acl_unmatched where key1=1")
self.env.execute("alter table vtocc_acl_unmatched comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
cu.execute("select * from vtocc_acl_unmatched where key1=1", {})
cu.fetchall()
cu.close()
def test_table_acl_all_user_read_only(self):
self.env.conn.begin()
self.env.execute("select * from vtocc_acl_all_user_read_only where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("delete from vtocc_acl_all_user_read_only where key1=1")
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
self.env.execute("alter table vtocc_acl_all_user_read_only comment 'comment'")
self.env.conn.commit()
cu = cursor.StreamCursor(self.env.conn)
cu.execute("select * from vtocc_acl_all_user_read_only where key1=1", {})
cu.fetchall()
cu.close()

Просмотреть файл

@ -52,7 +52,7 @@ class TestEnv(object):
return "localhost:%s" % self.port
def connect(self):
c = tablet_conn.connect(self.address, '', 'test_keyspace', '0', 2)
c = tablet_conn.connect(self.address, '', 'test_keyspace', '0', 2, user='youtube-dev-dedicated', password='vtpass')
c.max_attempts = 1
return c
@ -196,7 +196,14 @@ class VttabletTestEnv(TestEnv):
self.create_customrules(customrules)
schema_override = os.path.join(environment.tmproot, 'schema_override.json')
self.create_schema_override(schema_override)
self.tablet.start_vttablet(memcache=self.memcache, customrules=customrules, schema_override=schema_override)
table_acl_config = os.path.join(environment.vttop, 'test', 'test_data', 'table_acl_config.json')
self.tablet.start_vttablet(
memcache=self.memcache,
customrules=customrules,
schema_override=schema_override,
table_acl_config=table_acl_config,
auth=True,
)
# FIXME(szopa): This is necessary here only because of a bug that
# makes the qs reload its config only after an action.
@ -281,12 +288,15 @@ class VtoccTestEnv(TestEnv):
self.create_customrules(customrules)
schema_override = os.path.join(environment.tmproot, 'schema_override.json')
self.create_schema_override(schema_override)
table_acl_config = os.path.join(environment.vttop, 'test', 'test_data', 'table_acl_config.json')
occ_args = environment.binary_args('vtocc') + [
"-port", str(self.port),
"-customrules", customrules,
"-log_dir", environment.vtlogroot,
"-schema-override", schema_override,
"-table-acl-config", table_acl_config,
"-queryserver-config-strict-table-acl",
"-db-config-app-charset", "utf8",
"-db-config-app-dbname", "vt_test_keyspace",
"-db-config-app-host", "localhost",
@ -294,7 +304,9 @@ class VtoccTestEnv(TestEnv):
"-db-config-app-uname", 'vt_dba', # use vt_dba as some tests depend on 'drop'
"-db-config-app-keyspace", "test_keyspace",
"-db-config-app-shard", "0",
"-auth-credentials", os.path.join(environment.vttop, 'test', 'test_data', 'authcredentials_test.json'),
]
if self.memcache:
memcache = self.mysqldir+"/memcache.sock"
occ_args.extend(["-rowcache-bin", environment.memcached_bin()])
@ -346,11 +358,6 @@ class VtoccTestEnv(TestEnv):
except:
pass
def connect(self):
c = tablet_conn.connect("localhost:%s" % self.port, '', 'test_keyspace', '0', 2)
c.max_attempts = 1
return c
def mysql_connect(self):
return mysql.connect(
host='localhost',

Просмотреть файл

@ -16,6 +16,7 @@ import unittest
import environment
import utils
import tablet
from mysql_flavor import mysql_flavor
tablet_62344 = tablet.Tablet(62344)
tablet_62044 = tablet.Tablet(62044)
@ -70,8 +71,8 @@ class TestReparent(unittest.TestCase):
t.clean_dbs()
super(TestReparent, self).tearDown()
def _check_db_addr(self, shard, db_type, expected_port):
ep = utils.run_vtctl_json(['GetEndPoints', 'test_nj', 'test_keyspace/'+shard, db_type])
def _check_db_addr(self, shard, db_type, expected_port, cell='test_nj'):
ep = utils.run_vtctl_json(['GetEndPoints', cell, 'test_keyspace/'+shard, db_type])
self.assertEqual(len(ep['entries']), 1 , 'Wrong number of entries: %s' % str(ep))
port = ep['entries'][0]['named_port_map']['_vtocc']
self.assertEqual(port, expected_port, 'Unexpected port: %u != %u from %s' % (port, expected_port, str(ep)))
@ -184,6 +185,66 @@ class TestReparent(unittest.TestCase):
# so the other tests don't have any surprise
tablet_62344.start_mysql().wait()
def test_reparent_cross_cell(self, shard_id='0'):
utils.run_vtctl('CreateKeyspace test_keyspace')
# create the database so vttablets start, as they are serving
tablet_62344.create_db('vt_test_keyspace')
tablet_62044.create_db('vt_test_keyspace')
tablet_41983.create_db('vt_test_keyspace')
tablet_31981.create_db('vt_test_keyspace')
# Start up a master mysql and vttablet
tablet_62344.init_tablet('master', 'test_keyspace', shard_id, start=True)
if environment.topo_server_implementation == 'zookeeper':
shard = utils.run_vtctl_json(['GetShard', 'test_keyspace/'+shard_id])
self.assertEqual(shard['Cells'], ['test_nj'], 'wrong list of cell in Shard: %s' % str(shard['Cells']))
# Create a few slaves for testing reparenting.
tablet_62044.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
tablet_41983.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
tablet_31981.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
for t in [tablet_62044, tablet_41983, tablet_31981]:
t.wait_for_vttablet_state("SERVING")
if environment.topo_server_implementation == 'zookeeper':
shard = utils.run_vtctl_json(['GetShard', 'test_keyspace/'+shard_id])
self.assertEqual(shard['Cells'], ['test_nj', 'test_ny'], 'wrong list of cell in Shard: %s' % str(shard['Cells']))
# Recompute the shard layout node - until you do that, it might not be valid.
utils.run_vtctl('RebuildShardGraph test_keyspace/' + shard_id)
utils.validate_topology()
# Force the slaves to reparent assuming that all the datasets are identical.
for t in [tablet_62344, tablet_62044, tablet_41983, tablet_31981]:
t.reset_replication()
utils.pause("force ReparentShard?")
utils.run_vtctl('ReparentShard -force test_keyspace/%s %s' % (shard_id, tablet_62344.tablet_alias))
utils.validate_topology(ping_tablets=True)
self._check_db_addr(shard_id, 'master', tablet_62344.port)
# Verify MasterCell is properly set
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
# Perform a graceful reparent operation to another cell.
utils.pause("graceful ReparentShard?")
utils.run_vtctl('ReparentShard test_keyspace/%s %s' % (shard_id, tablet_31981.tablet_alias), auto_log=True)
utils.validate_topology()
self._check_db_addr(shard_id, 'master', tablet_31981.port, cell='test_ny')
# Verify MasterCell is set to new cell.
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_ny')
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_ny')
tablet.kill_tablets([tablet_62344, tablet_62044, tablet_41983, tablet_31981])
def test_reparent_graceful_range_based(self):
shard_id = '0000000000000000-FFFFFFFFFFFFFFFF'
self._test_reparent_graceful(shard_id)
@ -230,6 +291,12 @@ class TestReparent(unittest.TestCase):
self._check_db_addr(shard_id, 'master', tablet_62344.port)
# Verify MasterCell is set to new cell.
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
# Convert two replica to spare. That should leave only one node serving traffic,
# but still needs to appear in the replication graph.
utils.run_vtctl(['ChangeSlaveType', tablet_41983.tablet_alias, 'spare'])
@ -247,6 +314,12 @@ class TestReparent(unittest.TestCase):
self._check_db_addr(shard_id, 'master', tablet_62044.port)
# Verify MasterCell is set to new cell.
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
tablet.kill_tablets([tablet_62344, tablet_62044, tablet_41983, tablet_31981])
# Test address correction.
@ -337,12 +410,7 @@ class TestReparent(unittest.TestCase):
# now manually reparent 1 out of 2 tablets
# 62044 will be the new master
# 31981 won't be re-parented, so it will be busted
tablet_62044.mquery('', [
"RESET MASTER",
"STOP SLAVE",
"RESET SLAVE",
"CHANGE MASTER TO MASTER_HOST = ''",
])
tablet_62044.mquery('', mysql_flavor.promote_slave_commands())
new_pos = tablet_62044.mquery('', 'show master status')
logging.debug("New master position: %s" % str(new_pos))
@ -439,7 +507,6 @@ class TestReparent(unittest.TestCase):
utils.run_vtctl('ReparentShard -force test_keyspace/%s %s' % (shard_id, tablet_62344.tablet_alias))
utils.validate_topology(ping_tablets=True)
tablet_62344.create_db('vt_test_keyspace')
tablet_62344.mquery('vt_test_keyspace', self._create_vt_insert_test)
tablet_41983.mquery('', 'stop slave')

Просмотреть файл

@ -13,6 +13,7 @@ import MySQLdb
import environment
import utils
from mysql_flavor import mysql_flavor
tablet_cell_map = {
62344: 'nj',
@ -148,14 +149,7 @@ class Tablet(object):
raise utils.TestError("expected %u rows in %s" % (n, table), result)
def reset_replication(self):
commands = [
'RESET MASTER',
'STOP SLAVE',
'RESET SLAVE',
]
if environment.mysql_flavor == "GoogleMysql":
commands.append('CHANGE MASTER TO MASTER_HOST = ""')
self.mquery('', commands)
self.mquery('', mysql_flavor.reset_replication_commands())
def populate(self, dbname, create_sql, insert_sqls=[]):
self.create_db(dbname)
@ -283,7 +277,7 @@ class Tablet(object):
def start_vttablet(self, port=None, auth=False, memcache=False,
wait_for_state="SERVING", customrules=None,
schema_override=None, cert=None, key=None, ca_cert=None,
repl_extra_flags={},
repl_extra_flags={},table_acl_config=None,
target_tablet_type=None, lameduck_period=None,
extra_args=None, full_mycnf_args=False,
security_policy=None):
@ -350,6 +344,10 @@ class Tablet(object):
if schema_override:
args.extend(['-schema-override', schema_override])
if table_acl_config:
args.extend(['-table-acl-config', table_acl_config])
args.extend(['-queryserver-config-strict-table-acl'])
if cert:
self.secure_port = environment.reserve_ports(1)
args.extend(['-secure-port', '%s' % self.secure_port,

Просмотреть файл

@ -79,6 +79,8 @@ class TestTabletManager(unittest.TestCase):
tablet_62344.init_tablet('master', 'test_keyspace', '0', parent=False)
utils.run_vtctl(['RebuildKeyspaceGraph', 'test_keyspace'])
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
# if these statements don't run before the tablet it will wedge waiting for the
# db to become accessible. this is more a bug than a feature.
@ -116,6 +118,8 @@ class TestTabletManager(unittest.TestCase):
# not pinging tablets, as it enables replication checks, and they
# break because we only have a single master, no slaves
utils.run_vtctl('ValidateShard -ping-tablets=false test_keyspace/0')
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
tablet_62344.kill_vttablet()
@ -129,6 +133,8 @@ class TestTabletManager(unittest.TestCase):
tablet_62344.init_tablet('master', 'test_keyspace', '0', parent=False)
utils.run_vtctl(['RebuildKeyspaceGraph', 'test_keyspace'])
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
# if these statements don't run before the tablet it will wedge waiting for the
# db to become accessible. this is more a bug than a feature.
@ -226,9 +232,13 @@ class TestTabletManager(unittest.TestCase):
tablet_62044.init_tablet('replica', 'test_keyspace', '0')
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/*'])
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
tablet_62044.scrap(force=True)
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
_create_vt_select_test = '''create table vt_select_test (
@ -249,6 +259,8 @@ class TestTabletManager(unittest.TestCase):
tablet_62344.init_tablet('master', 'test_keyspace', '0')
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/0'])
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
tablet_62344.create_db('vt_test_keyspace')
tablet_62344.start_vttablet()
@ -306,6 +318,8 @@ class TestTabletManager(unittest.TestCase):
tablet_62344.init_tablet('master', 'test_keyspace', '0')
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/0'])
utils.validate_topology()
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
self.assertEqual(srvShard['MasterCell'], 'test_nj')
tablet_62344.populate('vt_test_keyspace', self._create_vt_select_test,
self._populate_vt_select_test)

Просмотреть файл

@ -1,3 +1,4 @@
{
"ala": ["ma kota", "miala kota"]
}
"ala": ["ma kota", "miala kota"],
"youtube-dev-dedicated": ["vtpass", "vtpasssec"]
}

Просмотреть файл

@ -0,0 +1,7 @@
{
"vtocc_acl_no_access": {},
"vtocc_acl_read_only": {"Reader": "youtube-dev-dedicated"},
"vtocc_acl_read_write": {"Writer": "youtube-dev-dedicated"},
"vtocc_acl_admin": {"Admin": "youtube-dev-dedicated"},
"vtocc_acl_all_user_read_only": {"READER":"*"}
}

Просмотреть файл

@ -62,6 +62,13 @@ insert into vtocc_part2 values(1, 3)
insert into vtocc_part2 values(2, 4)
commit
create table vtocc_acl_no_access(key1 bigint, key2 bigint, primary key(key1))
create table vtocc_acl_read_only(key1 bigint, key2 bigint, primary key(key1))
create table vtocc_acl_read_write(key1 bigint, key2 bigint, primary key(key1))
create table vtocc_acl_admin(key1 bigint, key2 bigint, primary key(key1))
create table vtocc_acl_unmatched(key1 bigint, key2 bigint, primary key(key1))
create table vtocc_acl_all_user_read_only(key1 bigint, key2 bigint, primary key(key1))
# clean
drop table if exists vtocc_test
drop table if exists vtocc_a
@ -82,3 +89,9 @@ drop table if exists vtocc_misc
drop view if exists vtocc_view
drop table if exists vtocc_part1
drop table if exists vtocc_part2
drop table if exists vtocc_acl_no_access
drop table if exists vtocc_acl_read_only
drop table if exists vtocc_acl_read_write
drop table if exists vtocc_acl_admin
drop table if exists vtocc_acl_unmatched
drop table if exists vtocc_acl_all_user_read_only