зеркало из https://github.com/github/vitess-gh.git
Merge branch 'master' into gtid
This commit is contained in:
Коммит
8206436dee
|
@ -1,13 +1,64 @@
|
|||
create table a(abcd)#{"Action": "create", "NewName": "a"}
|
||||
drop table b#{"Action": "drop", "TableName": "b"}
|
||||
alter table c alter foo#{"Action": "alter", "TableName": "c", "NewTable": "c"}
|
||||
alter table c comment 'aa'#{"Action": "alter", "TableName": "c", "NewTable": "c"}
|
||||
drop index a on b#{"Action": "alter", "TableName": "b", "NewName": "b"}
|
||||
rename table a to b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
|
||||
alter table a rename b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
|
||||
alter table a rename to b#{"Action": "rename", "TableName": "a", "NewTable": "b"}
|
||||
create view a asdasd#{"Action": "create", "NewName": "a"}
|
||||
alter view c alter foo#{"Action": "alter", "TableName": "c", "NewTable": "c"}
|
||||
drop view b#{"Action": "drop", "TableName": "b"}
|
||||
select * from a#{"Action": ""}
|
||||
syntax error#{"Action": ""}
|
||||
"create table a(abcd)"
|
||||
{
|
||||
"Action": "create", "NewName": "a"
|
||||
}
|
||||
|
||||
"drop table b"
|
||||
{
|
||||
"Action": "drop", "TableName": "b"
|
||||
}
|
||||
|
||||
"alter table c alter foo"
|
||||
{
|
||||
"Action": "alter", "TableName": "c", "NewTable": "c"
|
||||
}
|
||||
|
||||
"alter table c comment 'aa'"
|
||||
{
|
||||
"Action": "alter", "TableName": "c", "NewTable": "c"
|
||||
}
|
||||
|
||||
"drop index a on b"
|
||||
{
|
||||
"Action": "alter", "TableName": "b", "NewName": "b"
|
||||
}
|
||||
|
||||
"rename table a to b"
|
||||
{
|
||||
"Action": "rename", "TableName": "a", "NewTable": "b"
|
||||
}
|
||||
|
||||
"alter table a rename b"
|
||||
{
|
||||
"Action": "rename", "TableName": "a", "NewTable": "b"
|
||||
}
|
||||
|
||||
"alter table a rename to b"
|
||||
{
|
||||
"Action": "rename", "TableName": "a", "NewTable": "b"
|
||||
}
|
||||
|
||||
"create view a asdasd"
|
||||
{
|
||||
"Action": "create", "NewName": "a"
|
||||
}
|
||||
|
||||
"alter view c alter foo"
|
||||
{
|
||||
"Action": "alter", "TableName": "c", "NewTable": "c"
|
||||
}
|
||||
|
||||
"drop view b"
|
||||
{
|
||||
"Action": "drop", "TableName": "b"
|
||||
}
|
||||
|
||||
"select * from a"
|
||||
{
|
||||
"Action": ""
|
||||
}
|
||||
|
||||
"syntax error"
|
||||
{
|
||||
"Action": ""
|
||||
}
|
||||
|
|
|
@ -2057,7 +2057,7 @@
|
|||
{
|
||||
"PlanId":"DDL",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"TableName":"a",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":null,
|
||||
"OuterQuery":null,
|
||||
|
@ -2076,7 +2076,7 @@
|
|||
{
|
||||
"PlanId":"DDL",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"TableName":"a",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":null,
|
||||
"OuterQuery":null,
|
||||
|
@ -2095,7 +2095,7 @@
|
|||
{
|
||||
"PlanId":"DDL",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"TableName":"a",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":null,
|
||||
"OuterQuery":null,
|
||||
|
@ -2114,7 +2114,7 @@
|
|||
{
|
||||
"PlanId":"DDL",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"TableName":"a",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":null,
|
||||
"OuterQuery":null,
|
||||
|
|
|
@ -1,7 +1,39 @@
|
|||
# select
|
||||
"select * from a"
|
||||
{
|
||||
"FullQuery": "select * from a"
|
||||
"PlanId":"PASS_SELECT",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"a",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":"select * from a",
|
||||
"OuterQuery":null,
|
||||
"Subquery":null,
|
||||
"IndexUsed":"",
|
||||
"ColumnNumbers":null,
|
||||
"PKValues":null,
|
||||
"SecondaryPKValues":null,
|
||||
"SubqueryPKColumns":null,
|
||||
"SetKey":"",
|
||||
"SetValue":null
|
||||
}
|
||||
|
||||
# select join
|
||||
"select * from a join b"
|
||||
{
|
||||
"PlanId":"PASS_SELECT",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"FieldQuery":null,
|
||||
"FullQuery":"select * from a join b",
|
||||
"OuterQuery":null,
|
||||
"Subquery":null,
|
||||
"IndexUsed":"",
|
||||
"ColumnNumbers":null,
|
||||
"PKValues":null,
|
||||
"SecondaryPKValues":null,
|
||||
"SubqueryPKColumns":null,
|
||||
"SetKey":"",
|
||||
"SetValue":null
|
||||
}
|
||||
|
||||
# select for update
|
||||
|
@ -11,7 +43,20 @@
|
|||
# union
|
||||
"select * from a union select * from b"
|
||||
{
|
||||
"FullQuery": "select * from a union select * from b"
|
||||
"PlanId":"PASS_SELECT",
|
||||
"Reason":"DEFAULT",
|
||||
"TableName":"",
|
||||
"FieldQuery":null,
|
||||
"FullQuery": "select * from a union select * from b",
|
||||
"OuterQuery":null,
|
||||
"Subquery":null,
|
||||
"IndexUsed":"",
|
||||
"ColumnNumbers":null,
|
||||
"PKValues":null,
|
||||
"SecondaryPKValues":null,
|
||||
"SubqueryPKColumns":null,
|
||||
"SetKey":"",
|
||||
"SetValue":null
|
||||
}
|
||||
|
||||
# dml
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/dbconfigs"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
)
|
||||
|
||||
|
@ -21,6 +22,7 @@ var (
|
|||
enableRowcache = flag.Bool("enable-rowcache", false, "enable rowcacche")
|
||||
enableInvalidator = flag.Bool("enable-invalidator", false, "enable rowcache invalidator")
|
||||
binlogPath = flag.String("binlog-path", "", "binlog path used by rowcache invalidator")
|
||||
tableAclConfig = flag.String("table-acl-config", "", "path to table access checker config file")
|
||||
)
|
||||
|
||||
var schemaOverrides []tabletserver.SchemaOverride
|
||||
|
@ -55,6 +57,9 @@ func main() {
|
|||
data, _ := json.MarshalIndent(schemaOverrides, "", " ")
|
||||
log.Infof("schemaOverrides: %s\n", data)
|
||||
|
||||
if *tableAclConfig != "" {
|
||||
tableacl.Init(*tableAclConfig)
|
||||
}
|
||||
tabletserver.InitQueryService()
|
||||
|
||||
err = tabletserver.AllowQueries(&dbConfigs.App, schemaOverrides, tabletserver.LoadCustomRules(), mysqld, true)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/dbconfigs"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/servenv"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletmanager"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
|
@ -23,6 +24,7 @@ var (
|
|||
tabletPath = flag.String("tablet-path", "", "tablet alias or path to zk node representing the tablet")
|
||||
enableRowcache = flag.Bool("enable-rowcache", false, "enable rowcacche")
|
||||
overridesFile = flag.String("schema-override", "", "schema overrides file")
|
||||
tableAclConfig = flag.String("table-acl-config", "", "path to table access checker config file")
|
||||
|
||||
agent *tabletmanager.ActionAgent
|
||||
)
|
||||
|
@ -77,6 +79,9 @@ func main() {
|
|||
}
|
||||
dbcfgs.App.EnableRowcache = *enableRowcache
|
||||
|
||||
if *tableAclConfig != "" {
|
||||
tableacl.Init(*tableAclConfig)
|
||||
}
|
||||
tabletserver.InitQueryService()
|
||||
binlog.RegisterUpdateStreamService(mycnf)
|
||||
|
||||
|
|
|
@ -28,13 +28,13 @@ type Histogram struct {
|
|||
// NewHistogram creates a histogram with auto-generated labels
|
||||
// based on the cutoffs. The buckets are categorized using the
|
||||
// following criterion: cutoff[i-1] < value <= cutoff[i]. Anything
|
||||
// higher than the highest cutoff is labeled as "Max".
|
||||
// higher than the highest cutoff is labeled as "inf".
|
||||
func NewHistogram(name string, cutoffs []int64) *Histogram {
|
||||
labels := make([]string, len(cutoffs)+1)
|
||||
for i, v := range cutoffs {
|
||||
labels[i] = fmt.Sprintf("%d", v)
|
||||
}
|
||||
labels[len(labels)-1] = "Max"
|
||||
labels[len(labels)-1] = "inf"
|
||||
return NewGenericHistogram(name, cutoffs, labels, "Count", "Total")
|
||||
}
|
||||
|
||||
|
@ -85,8 +85,8 @@ func (h *Histogram) MarshalJSON() ([]byte, error) {
|
|||
fmt.Fprintf(b, "{")
|
||||
totalCount := int64(0)
|
||||
for i, label := range h.labels {
|
||||
fmt.Fprintf(b, "\"%v\": %v, ", label, h.buckets[i])
|
||||
totalCount += h.buckets[i]
|
||||
fmt.Fprintf(b, "\"%v\": %v, ", label, totalCount)
|
||||
}
|
||||
fmt.Fprintf(b, "\"%s\": %v, ", h.countLabel, totalCount)
|
||||
fmt.Fprintf(b, "\"%s\": %v", h.totalLabel, h.total)
|
||||
|
@ -128,3 +128,13 @@ func (h *Histogram) Total() (total int64) {
|
|||
defer h.mu.Unlock()
|
||||
return h.total
|
||||
}
|
||||
|
||||
func (h *Histogram) Labels() []string {
|
||||
return h.labels
|
||||
}
|
||||
|
||||
func (h *Histogram) Buckets() []int64 {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return h.buckets
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestHistogram(t *testing.T) {
|
|||
for i := 0; i < 10; i++ {
|
||||
h.Add(int64(i))
|
||||
}
|
||||
want := `{"1": 2, "5": 4, "Max": 4, "Count": 10, "Total": 45}`
|
||||
want := `{"1": 2, "5": 6, "inf": 10, "Count": 10, "Total": 45}`
|
||||
if h.String() != want {
|
||||
t.Errorf("want %s, got %s", want, h.String())
|
||||
}
|
||||
|
@ -26,8 +26,8 @@ func TestHistogram(t *testing.T) {
|
|||
if counts["5"] != 4 {
|
||||
t.Errorf("want 4, got %d", counts["2"])
|
||||
}
|
||||
if counts["Max"] != 4 {
|
||||
t.Errorf("want 4, got %d", counts["Max"])
|
||||
if counts["inf"] != 4 {
|
||||
t.Errorf("want 4, got %d", counts["inf"])
|
||||
}
|
||||
if h.Count() != 10 {
|
||||
t.Errorf("want 10, got %d", h.Count())
|
||||
|
|
|
@ -118,9 +118,9 @@ var bucketLabels []string
|
|||
func init() {
|
||||
bucketLabels = make([]string, len(bucketCutoffs)+1)
|
||||
for i, v := range bucketCutoffs {
|
||||
bucketLabels[i] = fmt.Sprintf("%.4f", float64(v)/1e9)
|
||||
bucketLabels[i] = fmt.Sprintf("%d", v)
|
||||
}
|
||||
bucketLabels[len(bucketLabels)-1] = "Max"
|
||||
bucketLabels[len(bucketLabels)-1] = "inf"
|
||||
}
|
||||
|
||||
// MultiTimings is meant to tracks timing data by categories as well
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestTimings(t *testing.T) {
|
|||
tm.Add("tag1", 500*time.Microsecond)
|
||||
tm.Add("tag1", 1*time.Millisecond)
|
||||
tm.Add("tag2", 1*time.Millisecond)
|
||||
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1":{"0.0005":1,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":2,"Time":1500000},"tag2":{"0.0005":0,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":1,"Time":1000000}}}`
|
||||
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1":{"500000":1,"1000000":2,"5000000":2,"10000000":2,"50000000":2,"100000000":2,"500000000":2,"1000000000":2,"5000000000":2,"10000000000":2,"inf":2,"Count":2,"Time":1500000},"tag2":{"500000":0,"1000000":1,"5000000":1,"10000000":1,"50000000":1,"100000000":1,"500000000":1,"1000000000":1,"5000000000":1,"10000000000":1,"inf":1,"Count":1,"Time":1000000}}}`
|
||||
if tm.String() != want {
|
||||
t.Errorf("want %s, got %s", want, tm.String())
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func TestMultiTimings(t *testing.T) {
|
|||
mtm.Add([]string{"tag1a", "tag1b"}, 500*time.Microsecond)
|
||||
mtm.Add([]string{"tag1a", "tag1b"}, 1*time.Millisecond)
|
||||
mtm.Add([]string{"tag2a", "tag2b"}, 1*time.Millisecond)
|
||||
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1a.tag1b":{"0.0005":1,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":2,"Time":1500000},"tag2a.tag2b":{"0.0005":0,"0.0010":1,"0.0050":0,"0.0100":0,"0.0500":0,"0.1000":0,"0.5000":0,"1.0000":0,"5.0000":0,"10.0000":0,"Max":0,"Count":1,"Time":1000000}}}`
|
||||
want := `{"TotalCount":3,"TotalTime":2500000,"Histograms":{"tag1a.tag1b":{"500000":1,"1000000":2,"5000000":2,"10000000":2,"50000000":2,"100000000":2,"500000000":2,"1000000000":2,"5000000000":2,"10000000000":2,"inf":2,"Count":2,"Time":1500000},"tag2a.tag2b":{"500000":0,"1000000":1,"5000000":1,"10000000":1,"50000000":1,"100000000":1,"500000000":1,"1000000000":1,"5000000000":1,"10000000000":1,"inf":1,"Count":1,"Time":1000000}}}`
|
||||
if mtm.String() != want {
|
||||
t.Errorf("want %s, got %s", want, mtm.String())
|
||||
}
|
||||
|
|
|
@ -15,8 +15,8 @@ import (
|
|||
mproto "github.com/youtube/vitess/go/mysql/proto"
|
||||
"github.com/youtube/vitess/go/vt/client2/tablet"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
"github.com/youtube/vitess/go/vt/vtgate"
|
||||
"github.com/youtube/vitess/go/vt/zktopo"
|
||||
"github.com/youtube/vitess/go/zk"
|
||||
)
|
||||
|
@ -233,7 +233,7 @@ func (sc *ShardedConn) Exec(query string, bindVars map[string]interface{}) (db.R
|
|||
if sc.srvKeyspace == nil {
|
||||
return nil, ErrNotConnected
|
||||
}
|
||||
shards, err := sqlparser.GetShardList(query, bindVars, sc.shardMaxKeys)
|
||||
shards, err := vtgate.GetShardList(query, bindVars, sc.shardMaxKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -6,23 +6,12 @@ package sqlparser
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
)
|
||||
|
||||
func TestGen(t *testing.T) {
|
||||
|
@ -32,118 +21,6 @@ func TestGen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var (
|
||||
SQLZERO = sqltypes.MakeString([]byte("0"))
|
||||
)
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
testSchema := loadSchema("schema_test.json")
|
||||
for tcase := range iterateExecFile("exec_cases.txt") {
|
||||
plan, err := ExecParse(tcase.input, func(name string) (*schema.Table, bool) {
|
||||
r, ok := testSchema[name]
|
||||
return r, ok
|
||||
})
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
|
||||
}
|
||||
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomExec(t *testing.T) {
|
||||
testSchemas := testfiles.Glob("sqlparser_test/*_schema.json")
|
||||
if len(testSchemas) == 0 {
|
||||
t.Log("No schemas to test")
|
||||
return
|
||||
}
|
||||
for _, schemFile := range testSchemas {
|
||||
schem := loadSchema(schemFile)
|
||||
t.Logf("Testing schema %s", schemFile)
|
||||
files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(files) == 0 {
|
||||
t.Fatalf("No test files for %s", schemFile)
|
||||
}
|
||||
getter := func(name string) (*schema.Table, bool) {
|
||||
r, ok := schem[name]
|
||||
return r, ok
|
||||
}
|
||||
for _, file := range files {
|
||||
t.Logf("Testing file %s", file)
|
||||
for tcase := range iterateExecFile(file) {
|
||||
plan, err := ExecParse(tcase.input, getter)
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Errorf("File: %s: Line:%v\n%s\n%s", file, tcase.lineno, tcase.output, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamExec(t *testing.T) {
|
||||
for tcase := range iterateExecFile("stream_cases.txt") {
|
||||
plan, err := StreamExecParse(tcase.input)
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
|
||||
}
|
||||
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDDL(t *testing.T) {
|
||||
for tcase := range iterateFiles("sqlparser_test/ddl_cases.txt") {
|
||||
plan := DDLParse(tcase.input)
|
||||
expected := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(tcase.output), &expected)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v", plan))
|
||||
}
|
||||
matchString(t, tcase.lineno, expected["Action"], plan.Action)
|
||||
matchString(t, tcase.lineno, expected["TableName"], plan.TableName)
|
||||
matchString(t, tcase.lineno, expected["NewName"], plan.NewName)
|
||||
}
|
||||
}
|
||||
|
||||
func matchString(t *testing.T, line int, expected interface{}, actual string) {
|
||||
if expected != nil {
|
||||
if expected.(string) != actual {
|
||||
t.Error(fmt.Sprintf("Line %d: expected: %v, received %s", line, expected, actual))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
for tcase := range iterateFiles("sqlparser_test/*.sql") {
|
||||
if tcase.output == "" {
|
||||
|
@ -162,47 +39,6 @@ func TestParse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
tabletkeys := []key.KeyspaceId{
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x02",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x04",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x06",
|
||||
"a",
|
||||
"b",
|
||||
"d",
|
||||
}
|
||||
bindVariables := make(map[string]interface{})
|
||||
bindVariables["id0"] = 0
|
||||
bindVariables["id2"] = 2
|
||||
bindVariables["id3"] = 3
|
||||
bindVariables["id4"] = 4
|
||||
bindVariables["id6"] = 6
|
||||
bindVariables["id8"] = 8
|
||||
bindVariables["ids"] = []interface{}{1, 4}
|
||||
bindVariables["a"] = "a"
|
||||
bindVariables["b"] = "b"
|
||||
bindVariables["c"] = "c"
|
||||
bindVariables["d"] = "d"
|
||||
bindVariables["e"] = "e"
|
||||
for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
|
||||
if tcase.output == "" {
|
||||
tcase.output = tcase.input
|
||||
}
|
||||
out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
|
||||
if err != nil {
|
||||
if err.Error() != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
sort.Ints(out)
|
||||
outstr := fmt.Sprintf("%v", out)
|
||||
if outstr != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParse1(b *testing.B) {
|
||||
sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'"
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
@ -223,23 +59,6 @@ func BenchmarkParse2(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func loadSchema(name string) map[string]*schema.Table {
|
||||
b, err := ioutil.ReadFile(locateFile(name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tables := make([]*schema.Table, 0, 8)
|
||||
err = json.Unmarshal(b, &tables)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s := make(map[string]*schema.Table)
|
||||
for _, t := range tables {
|
||||
s[t.Name] = t
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
file string
|
||||
lineno int
|
||||
|
@ -284,71 +103,3 @@ func iterateFiles(pattern string) (testCaseIterator chan testCase) {
|
|||
}()
|
||||
return testCaseIterator
|
||||
}
|
||||
|
||||
func iterateExecFile(name string) (testCaseIterator chan testCase) {
|
||||
name = locateFile(name)
|
||||
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Could not open file %s", name))
|
||||
}
|
||||
testCaseIterator = make(chan testCase)
|
||||
go func() {
|
||||
defer close(testCaseIterator)
|
||||
|
||||
r := bufio.NewReader(fd)
|
||||
lineno := 0
|
||||
for {
|
||||
binput, err := r.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
fmt.Printf("Line: %d\n", lineno)
|
||||
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
break
|
||||
}
|
||||
lineno++
|
||||
input := string(binput)
|
||||
if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") {
|
||||
//fmt.Printf("%s\n", input)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(binput, &input)
|
||||
if err != nil {
|
||||
fmt.Printf("Line: %d, input: %s\n", lineno, binput)
|
||||
panic(err)
|
||||
}
|
||||
input = strings.Trim(input, "\"")
|
||||
var output []byte
|
||||
for {
|
||||
l, err := r.ReadBytes('\n')
|
||||
lineno++
|
||||
if err != nil {
|
||||
fmt.Printf("Line: %d\n", lineno)
|
||||
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
output = append(output, l...)
|
||||
if l[0] == '}' {
|
||||
output = output[:len(output)-1]
|
||||
b := bytes.NewBuffer(make([]byte, 0, 64))
|
||||
if err := json.Compact(b, output); err == nil {
|
||||
output = b.Bytes()
|
||||
}
|
||||
break
|
||||
}
|
||||
if l[0] == '"' {
|
||||
output = output[1 : len(output)-2]
|
||||
break
|
||||
}
|
||||
}
|
||||
testCaseIterator <- testCase{name, lineno, input, string(output)}
|
||||
}
|
||||
}()
|
||||
return testCaseIterator
|
||||
}
|
||||
|
||||
func locateFile(name string) string {
|
||||
if path.IsAbs(name) {
|
||||
return name
|
||||
}
|
||||
return testfiles.Locate("sqlparser_test/" + name)
|
||||
}
|
||||
|
|
|
@ -1,254 +0,0 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlparser
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
)
|
||||
|
||||
const (
|
||||
EID_NODE = iota
|
||||
VALUE_NODE
|
||||
LIST_NODE
|
||||
OTHER_NODE
|
||||
)
|
||||
|
||||
type RoutingPlan struct {
|
||||
criteria SQLNode
|
||||
}
|
||||
|
||||
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
|
||||
defer handleError(&err)
|
||||
|
||||
plan := buildPlan(sql)
|
||||
return shardListFromPlan(plan, bindVariables, tabletKeys), nil
|
||||
}
|
||||
|
||||
func buildPlan(sql string) (plan *RoutingPlan) {
|
||||
statement, err := Parse(sql)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return getRoutingPlan(statement)
|
||||
}
|
||||
|
||||
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int) {
|
||||
if plan.criteria == nil {
|
||||
return makeList(0, len(tabletKeys))
|
||||
}
|
||||
|
||||
switch criteria := plan.criteria.(type) {
|
||||
case Values:
|
||||
index := findInsertShard(criteria, bindVariables, tabletKeys)
|
||||
return []int{index}
|
||||
case *ComparisonExpr:
|
||||
switch criteria.Operator {
|
||||
case "=", "<=>":
|
||||
index := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
return []int{index}
|
||||
case "<", "<=":
|
||||
index := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
return makeList(0, index+1)
|
||||
case ">", ">=":
|
||||
index := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
return makeList(index, len(tabletKeys))
|
||||
case "in":
|
||||
return findShardList(criteria.Right, bindVariables, tabletKeys)
|
||||
}
|
||||
case *RangeCond:
|
||||
if criteria.Operator == "between" {
|
||||
start := findShard(criteria.From, bindVariables, tabletKeys)
|
||||
last := findShard(criteria.To, bindVariables, tabletKeys)
|
||||
if last < start {
|
||||
start, last = last, start
|
||||
}
|
||||
return makeList(start, last+1)
|
||||
}
|
||||
}
|
||||
return makeList(0, len(tabletKeys))
|
||||
}
|
||||
|
||||
func getRoutingPlan(statement Statement) (plan *RoutingPlan) {
|
||||
plan = &RoutingPlan{}
|
||||
if ins, ok := statement.(*Insert); ok {
|
||||
if sel, ok := ins.Rows.(SelectStatement); ok {
|
||||
return getRoutingPlan(sel)
|
||||
}
|
||||
plan.criteria = routingAnalyzeValues(ins.Rows.(Values))
|
||||
return plan
|
||||
}
|
||||
var where *Where
|
||||
switch stmt := statement.(type) {
|
||||
case *Select:
|
||||
where = stmt.Where
|
||||
case *Update:
|
||||
where = stmt.Where
|
||||
case *Delete:
|
||||
where = stmt.Where
|
||||
}
|
||||
if where != nil {
|
||||
plan.criteria = routingAnalyzeBoolean(where.Expr)
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
func routingAnalyzeValues(vals Values) Values {
|
||||
// Analyze first value of every item in the list
|
||||
for i := 0; i < len(vals); i++ {
|
||||
switch tuple := vals[i].(type) {
|
||||
case ValTuple:
|
||||
result := routingAnalyzeValue(tuple[0])
|
||||
if result != VALUE_NODE {
|
||||
panic(NewParserError("insert is too complex"))
|
||||
}
|
||||
default:
|
||||
panic(NewParserError("insert is too complex"))
|
||||
}
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
func routingAnalyzeBoolean(node BoolExpr) BoolExpr {
|
||||
switch node := node.(type) {
|
||||
case *AndExpr:
|
||||
left := routingAnalyzeBoolean(node.Left)
|
||||
right := routingAnalyzeBoolean(node.Right)
|
||||
if left != nil && right != nil {
|
||||
return nil
|
||||
} else if left != nil {
|
||||
return left
|
||||
} else {
|
||||
return right
|
||||
}
|
||||
case *ParenBoolExpr:
|
||||
return routingAnalyzeBoolean(node.Expr)
|
||||
case *ComparisonExpr:
|
||||
switch {
|
||||
case StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
|
||||
return node
|
||||
}
|
||||
case node.Operator == "in":
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if left == EID_NODE && right == LIST_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
case *RangeCond:
|
||||
if node.Operator != "between" {
|
||||
return nil
|
||||
}
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
from := routingAnalyzeValue(node.From)
|
||||
to := routingAnalyzeValue(node.To)
|
||||
if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func routingAnalyzeValue(valExpr ValExpr) int {
|
||||
switch node := valExpr.(type) {
|
||||
case *ColName:
|
||||
if string(node.Name) == "entity_id" {
|
||||
return EID_NODE
|
||||
}
|
||||
case ValTuple:
|
||||
for _, n := range node {
|
||||
if routingAnalyzeValue(n) != VALUE_NODE {
|
||||
return OTHER_NODE
|
||||
}
|
||||
}
|
||||
return LIST_NODE
|
||||
case StrVal, NumVal, ValArg:
|
||||
return VALUE_NODE
|
||||
}
|
||||
return OTHER_NODE
|
||||
}
|
||||
|
||||
func findShardList(valExpr ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) []int {
|
||||
shardset := make(map[int]bool)
|
||||
switch node := valExpr.(type) {
|
||||
case ValTuple:
|
||||
for _, n := range node {
|
||||
index := findShard(n, bindVariables, tabletKeys)
|
||||
shardset[index] = true
|
||||
}
|
||||
}
|
||||
shardlist := make([]int, len(shardset))
|
||||
index := 0
|
||||
for k := range shardset {
|
||||
shardlist[index] = k
|
||||
index++
|
||||
}
|
||||
return shardlist
|
||||
}
|
||||
|
||||
func findInsertShard(vals Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) int {
|
||||
index := -1
|
||||
for i := 0; i < len(vals); i++ {
|
||||
first_value_expression := vals[i].(ValTuple)[0]
|
||||
newIndex := findShard(first_value_expression, bindVariables, tabletKeys)
|
||||
if index == -1 {
|
||||
index = newIndex
|
||||
} else if index != newIndex {
|
||||
panic(NewParserError("insert has multiple shard targets"))
|
||||
}
|
||||
}
|
||||
return index
|
||||
}
|
||||
|
||||
func findShard(valExpr ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) int {
|
||||
value := getBoundValue(valExpr, bindVariables)
|
||||
return key.FindShardForValue(value, tabletKeys)
|
||||
}
|
||||
|
||||
func getBoundValue(valExpr ValExpr, bindVariables map[string]interface{}) string {
|
||||
switch node := valExpr.(type) {
|
||||
case ValTuple:
|
||||
if len(node) != 1 {
|
||||
panic(NewParserError("tuples not allowed as insert values"))
|
||||
}
|
||||
// TODO: Change parser to create single value tuples into non-tuples.
|
||||
return getBoundValue(node[0], bindVariables)
|
||||
case StrVal:
|
||||
return string(node)
|
||||
case NumVal:
|
||||
val, err := strconv.ParseInt(string(node), 10, 64)
|
||||
if err != nil {
|
||||
panic(NewParserError("%s", err.Error()))
|
||||
}
|
||||
return key.Uint64Key(val).String()
|
||||
case ValArg:
|
||||
value := findBindValue(node, bindVariables)
|
||||
return key.EncodeValue(value)
|
||||
}
|
||||
panic("Unexpected token")
|
||||
}
|
||||
|
||||
func findBindValue(valArg ValArg, bindVariables map[string]interface{}) interface{} {
|
||||
if bindVariables == nil {
|
||||
panic(NewParserError("No bind variable for " + string(valArg)))
|
||||
}
|
||||
value, ok := bindVariables[string(valArg[1:])]
|
||||
if !ok {
|
||||
panic(NewParserError("No bind variable for " + string(valArg)))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func makeList(start, end int) []int {
|
||||
list := make([]int, end-start)
|
||||
for i := start; i < end; i++ {
|
||||
list[i-start] = i
|
||||
}
|
||||
return list
|
||||
}
|
|
@ -9,26 +9,6 @@ import (
|
|||
"fmt"
|
||||
)
|
||||
|
||||
// ParserError: To be deprecated.
|
||||
// TODO(sougou): deprecate.
|
||||
type ParserError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func NewParserError(format string, args ...interface{}) ParserError {
|
||||
return ParserError{fmt.Sprintf(format, args...)}
|
||||
}
|
||||
|
||||
func (err ParserError) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
func handleError(err *error) {
|
||||
if x := recover(); x != nil {
|
||||
*err = x.(ParserError)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackedBuffer is used to rebuild a query from the ast.
|
||||
// bindLocations keeps track of locations in the buffer that
|
||||
// use bind variables for efficient future substitutions.
|
||||
|
@ -116,3 +96,7 @@ func (buf *TrackedBuffer) WriteArg(arg string) {
|
|||
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
|
||||
return &ParsedQuery{buf.String(), buf.bindLocations}
|
||||
}
|
||||
|
||||
func (buf *TrackedBuffer) HasBindVars() bool {
|
||||
return len(buf.bindLocations) != 0
|
||||
}
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
package tableacl
|
||||
|
||||
import "strings"
|
||||
|
||||
// Role defines the level of access on a table
|
||||
type Role int
|
||||
|
||||
const (
|
||||
// READER can run SELECT statements
|
||||
READER Role = iota
|
||||
// WRITER can run SELECT, INSERT & UPDATE statements
|
||||
WRITER
|
||||
// ADMIN can run any statements including DDLs
|
||||
ADMIN
|
||||
// NumRoles is number of Roles defined
|
||||
NumRoles
|
||||
)
|
||||
|
||||
var roleNames = []string{
|
||||
"READER",
|
||||
"WRITER",
|
||||
"ADMIN",
|
||||
}
|
||||
|
||||
// Name returns the name of a role
|
||||
func (r Role) Name() string {
|
||||
if r < READER || r > ADMIN {
|
||||
return ""
|
||||
}
|
||||
return roleNames[r]
|
||||
}
|
||||
|
||||
// RoleByName returns the Role corresponding to a name
|
||||
func RoleByName(s string) (Role, bool) {
|
||||
for i, v := range roleNames {
|
||||
if v == strings.ToUpper(s) {
|
||||
return Role(i), true
|
||||
}
|
||||
}
|
||||
return NumRoles, false
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package tableacl
|
||||
|
||||
var allAcl simpleACL
|
||||
|
||||
const (
|
||||
ALL = "*"
|
||||
)
|
||||
|
||||
// NewACL returns an ACL with the specified entries
|
||||
func NewACL(entries []string) (ACL, error) {
|
||||
a := simpleACL(map[string]bool{})
|
||||
for _, e := range entries {
|
||||
a[e] = true
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// simpleACL keeps all entries in a unique in-memory list
|
||||
type simpleACL map[string]bool
|
||||
|
||||
// IsMember checks the membership of a principal in this ACL
|
||||
func (a simpleACL) IsMember(principal string) bool {
|
||||
return a[principal] || a[ALL]
|
||||
}
|
||||
|
||||
func all() ACL {
|
||||
if allAcl == nil {
|
||||
allAcl = simpleACL(map[string]bool{ALL: true})
|
||||
}
|
||||
return allAcl
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
package tableacl
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ACL is an interface for Access Control List
|
||||
type ACL interface {
|
||||
// IsMember checks the membership of a principal in this ACL
|
||||
IsMember(principal string) bool
|
||||
}
|
||||
|
||||
var tableAcl map[*regexp.Regexp]map[Role]ACL
|
||||
|
||||
// Init initiates table ACLs
|
||||
func Init(configFile string) {
|
||||
config, err := ioutil.ReadFile(configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("unable to read tableACL config file: %v", err)
|
||||
}
|
||||
tableAcl, err = load(config)
|
||||
if err != nil {
|
||||
log.Fatalf("tableACL initialization error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// load loads configurations from a JSON byte array
|
||||
//
|
||||
// Sample configuration
|
||||
// []byte (`{
|
||||
// <tableRegexPattern1>: {"READER": "*", "WRITER": "<u2>,<u4>...","ADMIN": "<u5>"},
|
||||
// <tableRegexPattern2>: {"ADMIN": "<u5>"}
|
||||
//}`)
|
||||
func load(config []byte) (map[*regexp.Regexp]map[Role]ACL, error) {
|
||||
var contents map[string]map[string]string
|
||||
err := json.Unmarshal(config, &contents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableAcl := make(map[*regexp.Regexp]map[Role]ACL)
|
||||
for tblPattern, accessMap := range contents {
|
||||
re, err := regexp.Compile(tblPattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regexp compile error %v: %v", tblPattern, err)
|
||||
}
|
||||
tableAcl[re] = make(map[Role]ACL)
|
||||
|
||||
entriesByRole := make(map[Role][]string)
|
||||
for i := READER; i < NumRoles; i++ {
|
||||
entriesByRole[i] = []string{}
|
||||
}
|
||||
for role, entries := range accessMap {
|
||||
r, ok := RoleByName(role)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parse error, invalid role %v", role)
|
||||
}
|
||||
// Entries must be assigned to all roles up to r
|
||||
for i := READER; i <= r; i++ {
|
||||
entriesByRole[i] = append(entriesByRole[i], strings.Split(entries, ",")...)
|
||||
}
|
||||
}
|
||||
for r, entries := range entriesByRole {
|
||||
a, err := NewACL(entries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableAcl[re][r] = a
|
||||
}
|
||||
|
||||
}
|
||||
return tableAcl, nil
|
||||
}
|
||||
|
||||
// Authorized returns the list of entities who have at least the
|
||||
// minimum specified Role on a table
|
||||
func Authorized(table string, minRole Role) ACL {
|
||||
if tableAcl == nil {
|
||||
// No ACLs, allow all access
|
||||
return all()
|
||||
}
|
||||
for re, accessMap := range tableAcl {
|
||||
if !re.MatchString(table) {
|
||||
continue
|
||||
}
|
||||
return accessMap[minRole]
|
||||
}
|
||||
// No matching patterns for table, allow all access
|
||||
return all()
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tableacl
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/context"
|
||||
)
|
||||
|
||||
func currentUser() string {
|
||||
ctx := &context.DummyContext{}
|
||||
return ctx.GetUsername()
|
||||
}
|
||||
|
||||
func TestParseInvalidJSON(t *testing.T) {
|
||||
checkLoad([]byte(`{1:2}`), false, t)
|
||||
checkLoad([]byte(`{"1":"2"}`), false, t)
|
||||
checkLoad([]byte(`{"table1":{1:2}}`), false, t)
|
||||
}
|
||||
|
||||
func TestInvalidRoleName(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"SOMEROLE":"u1"}}`), false, t)
|
||||
}
|
||||
|
||||
func TestInvalidRegex(t *testing.T) {
|
||||
checkLoad([]byte(`{"table(1":{"READER":"u1"}}`), false, t)
|
||||
}
|
||||
|
||||
func TestValidConfigs(t *testing.T) {
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1"}}`), true, t)
|
||||
checkLoad([]byte(`{"table1":{"READER":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,u2", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"}}`), true, t)
|
||||
checkLoad([]byte(`{
|
||||
"table[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3"},
|
||||
"tbl[0-9]+":{"Reader":"u1,`+ALL+`", "WRITER":"u3", "ADMIN":"u4"}
|
||||
}`), true, t)
|
||||
|
||||
}
|
||||
|
||||
func TestDenyReaderInsert(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", WRITER, t, false)
|
||||
}
|
||||
|
||||
func TestAllowReaderSelect(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", READER, t, true)
|
||||
}
|
||||
|
||||
func TestDenyReaderDDL(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", ADMIN, t, false)
|
||||
}
|
||||
|
||||
func TestAllowUnmatchedTable(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "UNMATCHED_TABLE", ADMIN, t, true)
|
||||
}
|
||||
|
||||
func TestAllUserReadAcess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + ALL + `", "WRITER":"u3"}}`)
|
||||
checkAccess(configData, "table1", READER, t, true)
|
||||
}
|
||||
|
||||
func TestAllUserWriteAccess(t *testing.T) {
|
||||
configData := []byte(`{"table[0-9]+":{"Reader":"` + currentUser() + `", "WRITER":"` + ALL + `"}}`)
|
||||
checkAccess(configData, "table1", WRITER, t, true)
|
||||
}
|
||||
|
||||
func checkLoad(configData []byte, valid bool, t *testing.T) {
|
||||
var err error
|
||||
tableAcl, err = load(configData)
|
||||
if !valid && err == nil {
|
||||
t.Errorf("expecting parse error none returned")
|
||||
}
|
||||
|
||||
if valid && err != nil {
|
||||
t.Errorf("unexpected load error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func checkAccess(configData []byte, tableName string, role Role, t *testing.T, want bool) {
|
||||
checkLoad(configData, true, t)
|
||||
got := Authorized(tableName, role).IsMember(currentUser())
|
||||
if want != got {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
|
@ -342,6 +342,7 @@ func updateReplicationGraphForPromotedSlave(ts topo.Server, tablet *topo.TabletI
|
|||
tablet.Type = topo.TYPE_MASTER
|
||||
tablet.Parent.Cell = ""
|
||||
tablet.Parent.Uid = topo.NO_TABLET
|
||||
tablet.Health = nil
|
||||
err := topo.UpdateTablet(ts, tablet)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -13,8 +13,8 @@ import (
|
|||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/stats"
|
||||
"github.com/youtube/vitess/go/vt/binlog"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
"github.com/youtube/vitess/go/vt/topo"
|
||||
)
|
||||
|
||||
|
@ -52,7 +52,7 @@ func (agent *ActionAgent) allowQueries(tablet *topo.Tablet) error {
|
|||
qrs := tabletserver.LoadCustomRules()
|
||||
if tablet.KeyRange.IsPartial() {
|
||||
qr := tabletserver.NewQueryRule("enforce keyspace_id range", "keyspace_id_not_in_range", tabletserver.QR_FAIL)
|
||||
qr.AddPlanCond(sqlparser.PLAN_INSERT_PK)
|
||||
qr.AddPlanCond(planbuilder.PLAN_INSERT_PK)
|
||||
err := qr.AddBindVarCond("keyspace_id", true, true, tabletserver.QR_NOTIN, tablet.KeyRange)
|
||||
if err != nil {
|
||||
log.Warningf("Unable to add keyspace rule: %v", err)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import "github.com/youtube/vitess/go/vt/sqlparser"
|
||||
|
||||
type DDLPlan struct {
|
||||
Action string
|
||||
TableName string
|
||||
NewName string
|
||||
}
|
||||
|
||||
func DDLParse(sql string) (plan *DDLPlan) {
|
||||
statement, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return &DDLPlan{Action: ""}
|
||||
}
|
||||
stmt, ok := statement.(*sqlparser.DDL)
|
||||
if !ok {
|
||||
return &DDLPlan{Action: ""}
|
||||
}
|
||||
return &DDLPlan{
|
||||
Action: stmt.Action,
|
||||
TableName: string(stmt.Table),
|
||||
NewName: string(stmt.NewName),
|
||||
}
|
||||
}
|
||||
|
||||
func analyzeDDL(ddl *sqlparser.DDL, getTable TableGetter) *ExecPlan {
|
||||
plan := &ExecPlan{PlanId: PLAN_DDL}
|
||||
tableName := string(ddl.Table)
|
||||
// Skip TableName if table is empty (create statements) or not found in schema
|
||||
if tableName != "" {
|
||||
tableInfo, ok := getTable(tableName)
|
||||
if ok {
|
||||
plan.TableName = tableInfo.Name
|
||||
}
|
||||
}
|
||||
return plan
|
||||
}
|
|
@ -0,0 +1,161 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
func analyzeUpdate(upd *sqlparser.Update, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
// Default plan
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_PASS_DML,
|
||||
FullQuery: GenerateFullQuery(upd),
|
||||
}
|
||||
|
||||
tableName := sqlparser.GetTableName(upd.Table)
|
||||
if tableName == "" {
|
||||
plan.Reason = REASON_TABLE
|
||||
return plan, nil
|
||||
}
|
||||
tableInfo, err := plan.setTableInfo(tableName, getTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
|
||||
log.Warningf("no primary key for table %s", tableName)
|
||||
plan.Reason = REASON_TABLE_NOINDEX
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
plan.SecondaryPKValues, err = analyzeUpdateExpressions(upd.Exprs, tableInfo.Indexes[0])
|
||||
if err != nil {
|
||||
if err == TooComplex {
|
||||
plan.Reason = REASON_PK_CHANGE
|
||||
return plan, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan.PlanId = PLAN_DML_SUBQUERY
|
||||
plan.OuterQuery = GenerateUpdateOuterQuery(upd, tableInfo.Indexes[0])
|
||||
plan.Subquery = GenerateUpdateSubquery(upd, tableInfo)
|
||||
|
||||
conditions := analyzeWhere(upd.Where)
|
||||
if conditions == nil {
|
||||
plan.Reason = REASON_WHERE
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
pkValues, err := getPKValues(conditions, tableInfo.Indexes[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pkValues != nil {
|
||||
plan.PlanId = PLAN_DML_PK
|
||||
plan.OuterQuery = plan.FullQuery
|
||||
plan.PKValues = pkValues
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func analyzeDelete(del *sqlparser.Delete, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
// Default plan
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_PASS_DML,
|
||||
FullQuery: GenerateFullQuery(del),
|
||||
}
|
||||
|
||||
tableName := sqlparser.GetTableName(del.Table)
|
||||
if tableName == "" {
|
||||
plan.Reason = REASON_TABLE
|
||||
return plan, nil
|
||||
}
|
||||
tableInfo, err := plan.setTableInfo(tableName, getTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
|
||||
log.Warningf("no primary key for table %s", tableName)
|
||||
plan.Reason = REASON_TABLE_NOINDEX
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
plan.PlanId = PLAN_DML_SUBQUERY
|
||||
plan.OuterQuery = GenerateDeleteOuterQuery(del, tableInfo.Indexes[0])
|
||||
plan.Subquery = GenerateDeleteSubquery(del, tableInfo)
|
||||
|
||||
conditions := analyzeWhere(del.Where)
|
||||
if conditions == nil {
|
||||
plan.Reason = REASON_WHERE
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
pkValues, err := getPKValues(conditions, tableInfo.Indexes[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pkValues != nil {
|
||||
plan.PlanId = PLAN_DML_PK
|
||||
plan.OuterQuery = plan.FullQuery
|
||||
plan.PKValues = pkValues
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func analyzeSet(set *sqlparser.Set) (plan *ExecPlan) {
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_SET,
|
||||
FullQuery: GenerateFullQuery(set),
|
||||
}
|
||||
if len(set.Exprs) > 1 { // Multiple set values
|
||||
return plan
|
||||
}
|
||||
update_expression := set.Exprs[0]
|
||||
plan.SetKey = string(update_expression.Name.Name)
|
||||
numExpr, ok := update_expression.Expr.(sqlparser.NumVal)
|
||||
if !ok {
|
||||
return plan
|
||||
}
|
||||
val := string(numExpr)
|
||||
if ival, err := strconv.ParseInt(val, 0, 64); err == nil {
|
||||
plan.SetValue = ival
|
||||
} else if fval, err := strconv.ParseFloat(val, 64); err == nil {
|
||||
plan.SetValue = fval
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
func analyzeUpdateExpressions(exprs sqlparser.UpdateExprs, pkIndex *schema.Index) (pkValues []interface{}, err error) {
|
||||
for _, expr := range exprs {
|
||||
index := pkIndex.FindColumn(sqlparser.GetColName(expr.Name))
|
||||
if index == -1 {
|
||||
continue
|
||||
}
|
||||
if !sqlparser.IsValue(expr.Expr) {
|
||||
log.Warningf("expression is too complex %v", expr)
|
||||
return nil, TooComplex
|
||||
}
|
||||
if pkValues == nil {
|
||||
pkValues = make([]interface{}, len(pkIndex.Columns))
|
||||
}
|
||||
var err error
|
||||
pkValues[index], err = sqlparser.AsInterface(expr.Expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return pkValues, nil
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
var (
|
||||
TooComplex = errors.New("Complex")
|
||||
execLimit = &sqlparser.Limit{Rowcount: sqlparser.ValArg(":_vtMaxResultSize")}
|
||||
)
|
||||
|
||||
// ExecPlan is built for selects and DMLs.
|
||||
// PK Values values within ExecPlan can be:
|
||||
// sqltypes.Value: sourced form the query, or
|
||||
// string: bind variable name starting with ':', or
|
||||
// nil if no value was specified
|
||||
type ExecPlan struct {
|
||||
PlanId PlanType
|
||||
Reason ReasonType
|
||||
TableName string
|
||||
|
||||
// FieldQuery is used to fetch field info
|
||||
FieldQuery *sqlparser.ParsedQuery
|
||||
|
||||
// FullQuery will be set for all plans.
|
||||
FullQuery *sqlparser.ParsedQuery
|
||||
|
||||
// For PK plans, only OuterQuery is set.
|
||||
// For SUBQUERY plans, Subquery is also set.
|
||||
// IndexUsed is set only for PLAN_SELECT_SUBQUERY
|
||||
OuterQuery *sqlparser.ParsedQuery
|
||||
Subquery *sqlparser.ParsedQuery
|
||||
IndexUsed string
|
||||
|
||||
// For selects, columns to be returned
|
||||
// For PLAN_INSERT_SUBQUERY, columns to be inserted
|
||||
ColumnNumbers []int
|
||||
|
||||
// PLAN_PK_EQUAL, PLAN_DML_PK: where clause values
|
||||
// PLAN_PK_IN: IN clause values
|
||||
// PLAN_INSERT_PK: values clause
|
||||
PKValues []interface{}
|
||||
|
||||
// For update: set clause if pk is changing
|
||||
SecondaryPKValues []interface{}
|
||||
|
||||
// For PLAN_INSERT_SUBQUERY: pk columns in the subquery result
|
||||
SubqueryPKColumns []int
|
||||
|
||||
// PLAN_SET
|
||||
SetKey string
|
||||
SetValue interface{}
|
||||
}
|
||||
|
||||
func (node *ExecPlan) setTableInfo(tableName string, getTable TableGetter) (*schema.Table, error) {
|
||||
tableInfo, ok := getTable(tableName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("table %s not found in schema", tableName)
|
||||
}
|
||||
node.TableName = tableInfo.Name
|
||||
return tableInfo, nil
|
||||
}
|
||||
|
||||
type TableGetter func(tableName string) (*schema.Table, bool)
|
||||
|
||||
func GetExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
statement, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plan, err = analyzeSQL(statement, getTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if plan.PlanId == PLAN_PASS_DML {
|
||||
log.Warningf("PASS_DML: %s", sql)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func GetStreamExecPlan(sql string, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
statement, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_PASS_SELECT,
|
||||
FullQuery: GenerateFullQuery(statement),
|
||||
}
|
||||
|
||||
switch stmt := statement.(type) {
|
||||
case *sqlparser.Select:
|
||||
if stmt.Lock != "" {
|
||||
return nil, errors.New("select with lock disallowed with streaming")
|
||||
}
|
||||
tableName, _ := analyzeFrom(stmt.From)
|
||||
if tableName != "" {
|
||||
plan.setTableInfo(tableName, getTable)
|
||||
}
|
||||
|
||||
case *sqlparser.Union:
|
||||
// pass
|
||||
default:
|
||||
return nil, fmt.Errorf("'%v' not allowed for streaming", sqlparser.String(stmt))
|
||||
}
|
||||
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func analyzeSQL(statement sqlparser.Statement, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
switch stmt := statement.(type) {
|
||||
case *sqlparser.Union:
|
||||
return &ExecPlan{
|
||||
PlanId: PLAN_PASS_SELECT,
|
||||
FieldQuery: GenerateFieldQuery(stmt),
|
||||
FullQuery: GenerateFullQuery(stmt),
|
||||
Reason: REASON_SELECT,
|
||||
}, nil
|
||||
case *sqlparser.Select:
|
||||
return analyzeSelect(stmt, getTable)
|
||||
case *sqlparser.Insert:
|
||||
return analyzeInsert(stmt, getTable)
|
||||
case *sqlparser.Update:
|
||||
return analyzeUpdate(stmt, getTable)
|
||||
case *sqlparser.Delete:
|
||||
return analyzeDelete(stmt, getTable)
|
||||
case *sqlparser.Set:
|
||||
return analyzeSet(stmt), nil
|
||||
case *sqlparser.DDL:
|
||||
return analyzeDDL(stmt, getTable), nil
|
||||
}
|
||||
return nil, errors.New("invalid SQL")
|
||||
}
|
|
@ -0,0 +1,158 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
type IndexScore struct {
|
||||
Index *schema.Index
|
||||
ColumnMatch []bool
|
||||
MatchFailed bool
|
||||
}
|
||||
|
||||
type scoreValue int64
|
||||
|
||||
const (
|
||||
NO_MATCH = scoreValue(-1)
|
||||
PERFECT_SCORE = scoreValue(0)
|
||||
)
|
||||
|
||||
func NewIndexScore(index *schema.Index) *IndexScore {
|
||||
return &IndexScore{index, make([]bool, len(index.Columns)), false}
|
||||
}
|
||||
|
||||
func (is *IndexScore) FindMatch(columnName string) int {
|
||||
if is.MatchFailed {
|
||||
return -1
|
||||
}
|
||||
if index := is.Index.FindColumn(columnName); index != -1 {
|
||||
is.ColumnMatch[index] = true
|
||||
return index
|
||||
}
|
||||
// If the column is among the data columns, we can still use
|
||||
// the index without going to the main table
|
||||
if index := is.Index.FindDataColumn(columnName); index == -1 {
|
||||
is.MatchFailed = true
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (is *IndexScore) GetScore() scoreValue {
|
||||
if is.MatchFailed {
|
||||
return NO_MATCH
|
||||
}
|
||||
score := NO_MATCH
|
||||
for i, indexColumn := range is.ColumnMatch {
|
||||
if indexColumn {
|
||||
score = scoreValue(is.Index.Cardinality[i])
|
||||
continue
|
||||
}
|
||||
return score
|
||||
}
|
||||
return PERFECT_SCORE
|
||||
}
|
||||
|
||||
func NewIndexScoreList(indexes []*schema.Index) []*IndexScore {
|
||||
scoreList := make([]*IndexScore, len(indexes))
|
||||
for i, v := range indexes {
|
||||
scoreList[i] = NewIndexScore(v)
|
||||
}
|
||||
return scoreList
|
||||
}
|
||||
|
||||
func getSelectPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (planId PlanType, pkValues []interface{}, err error) {
|
||||
pkValues, err = getPKValues(conditions, pkIndex)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
if pkValues == nil {
|
||||
return PLAN_PASS_SELECT, nil, nil
|
||||
}
|
||||
for _, pkValue := range pkValues {
|
||||
inList, ok := pkValue.([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if len(pkValues) == 1 {
|
||||
return PLAN_PK_IN, inList, nil
|
||||
}
|
||||
return PLAN_PASS_SELECT, nil, nil
|
||||
}
|
||||
return PLAN_PK_EQUAL, pkValues, nil
|
||||
}
|
||||
|
||||
func getPKValues(conditions []sqlparser.BoolExpr, pkIndex *schema.Index) (pkValues []interface{}, err error) {
|
||||
pkIndexScore := NewIndexScore(pkIndex)
|
||||
pkValues = make([]interface{}, len(pkIndexScore.ColumnMatch))
|
||||
for _, condition := range conditions {
|
||||
condition, ok := condition.(*sqlparser.ComparisonExpr)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
if !sqlparser.StringIn(condition.Operator, sqlparser.AST_EQ, sqlparser.AST_IN) {
|
||||
return nil, nil
|
||||
}
|
||||
index := pkIndexScore.FindMatch(string(condition.Left.(*sqlparser.ColName).Name))
|
||||
if index == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
switch condition.Operator {
|
||||
case sqlparser.AST_EQ, sqlparser.AST_IN:
|
||||
var err error
|
||||
pkValues[index], err = sqlparser.AsInterface(condition.Right)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
if pkIndexScore.GetScore() == PERFECT_SCORE {
|
||||
return pkValues, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func getIndexMatch(conditions []sqlparser.BoolExpr, indexes []*schema.Index) string {
|
||||
indexScores := NewIndexScoreList(indexes)
|
||||
for _, condition := range conditions {
|
||||
var col string
|
||||
switch condition := condition.(type) {
|
||||
case *sqlparser.ComparisonExpr:
|
||||
col = string(condition.Left.(*sqlparser.ColName).Name)
|
||||
case *sqlparser.RangeCond:
|
||||
col = string(condition.Left.(*sqlparser.ColName).Name)
|
||||
default:
|
||||
panic("unreachaable")
|
||||
}
|
||||
for _, index := range indexScores {
|
||||
index.FindMatch(col)
|
||||
}
|
||||
}
|
||||
highScore := NO_MATCH
|
||||
highScorer := -1
|
||||
for i, index := range indexScores {
|
||||
curScore := index.GetScore()
|
||||
if curScore == NO_MATCH {
|
||||
continue
|
||||
}
|
||||
if curScore == PERFECT_SCORE {
|
||||
highScorer = i
|
||||
break
|
||||
}
|
||||
// Prefer secondary index over primary key
|
||||
if curScore >= highScore {
|
||||
highScore = curScore
|
||||
highScorer = i
|
||||
}
|
||||
}
|
||||
if highScorer == -1 {
|
||||
return ""
|
||||
}
|
||||
return indexes[highScorer].Name
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
func analyzeInsert(ins *sqlparser.Insert, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_PASS_DML,
|
||||
FullQuery: GenerateFullQuery(ins),
|
||||
}
|
||||
tableName := sqlparser.GetTableName(ins.Table)
|
||||
if tableName == "" {
|
||||
plan.Reason = REASON_TABLE
|
||||
return plan, nil
|
||||
}
|
||||
tableInfo, err := plan.setTableInfo(tableName, getTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
|
||||
log.Warningf("no primary key for table %s", tableName)
|
||||
plan.Reason = REASON_TABLE_NOINDEX
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
pkColumnNumbers := getInsertPKColumns(ins.Columns, tableInfo)
|
||||
|
||||
if ins.OnDup != nil {
|
||||
// Upserts are not safe for statement based replication:
|
||||
// http://bugs.mysql.com/bug.php?id=58637
|
||||
plan.Reason = REASON_UPSERT
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
if sel, ok := ins.Rows.(sqlparser.SelectStatement); ok {
|
||||
plan.PlanId = PLAN_INSERT_SUBQUERY
|
||||
plan.OuterQuery = GenerateInsertOuterQuery(ins)
|
||||
plan.Subquery = GenerateSelectLimitQuery(sel)
|
||||
if len(ins.Columns) != 0 {
|
||||
plan.ColumnNumbers, err = analyzeSelectExprs(sqlparser.SelectExprs(ins.Columns), tableInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// StarExpr node will expand into all columns
|
||||
n := sqlparser.SelectExprs{&sqlparser.StarExpr{}}
|
||||
plan.ColumnNumbers, err = analyzeSelectExprs(n, tableInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
plan.SubqueryPKColumns = pkColumnNumbers
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// If it's not a sqlparser.SelectStatement, it's Values.
|
||||
rowList := ins.Rows.(sqlparser.Values)
|
||||
pkValues, err := getInsertPKValues(pkColumnNumbers, rowList, tableInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pkValues != nil {
|
||||
plan.PlanId = PLAN_INSERT_PK
|
||||
plan.OuterQuery = plan.FullQuery
|
||||
plan.PKValues = pkValues
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func getInsertPKColumns(columns sqlparser.Columns, tableInfo *schema.Table) (pkColumnNumbers []int) {
|
||||
if len(columns) == 0 {
|
||||
return tableInfo.PKColumns
|
||||
}
|
||||
pkIndex := tableInfo.Indexes[0]
|
||||
pkColumnNumbers = make([]int, len(pkIndex.Columns))
|
||||
for i := range pkColumnNumbers {
|
||||
pkColumnNumbers[i] = -1
|
||||
}
|
||||
for i, column := range columns {
|
||||
index := pkIndex.FindColumn(sqlparser.GetColName(column.(*sqlparser.NonStarExpr).Expr))
|
||||
if index == -1 {
|
||||
continue
|
||||
}
|
||||
pkColumnNumbers[index] = i
|
||||
}
|
||||
return pkColumnNumbers
|
||||
}
|
||||
|
||||
func getInsertPKValues(pkColumnNumbers []int, rowList sqlparser.Values, tableInfo *schema.Table) (pkValues []interface{}, err error) {
|
||||
pkValues = make([]interface{}, len(pkColumnNumbers))
|
||||
for index, columnNumber := range pkColumnNumbers {
|
||||
if columnNumber == -1 {
|
||||
pkValues[index] = tableInfo.GetPKColumn(index).Default
|
||||
continue
|
||||
}
|
||||
values := make([]interface{}, len(rowList))
|
||||
for j := 0; j < len(rowList); j++ {
|
||||
if _, ok := rowList[j].(*sqlparser.Subquery); ok {
|
||||
return nil, errors.New("row subquery not supported for inserts")
|
||||
}
|
||||
row := rowList[j].(sqlparser.ValTuple)
|
||||
if columnNumber >= len(row) {
|
||||
return nil, errors.New("column count doesn't match value count")
|
||||
}
|
||||
node := row[columnNumber]
|
||||
if !sqlparser.IsValue(node) {
|
||||
log.Warningf("insert is too complex %v", node)
|
||||
return nil, nil
|
||||
}
|
||||
var err error
|
||||
values[j], err = sqlparser.AsInterface(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(values) == 1 {
|
||||
pkValues[index] = values[0]
|
||||
} else {
|
||||
pkValues[index] = values
|
||||
}
|
||||
}
|
||||
return pkValues, nil
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
)
|
||||
|
||||
func TestPlan(t *testing.T) {
|
||||
testSchema := loadSchema("schema_test.json")
|
||||
for tcase := range iterateExecFile("exec_cases.txt") {
|
||||
plan, err := GetExecPlan(tcase.input, func(name string) (*schema.Table, bool) {
|
||||
r, ok := testSchema[name]
|
||||
return r, ok
|
||||
})
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
|
||||
}
|
||||
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustom(t *testing.T) {
|
||||
testSchemas := testfiles.Glob("sqlparser_test/*_schema.json")
|
||||
if len(testSchemas) == 0 {
|
||||
t.Log("No schemas to test")
|
||||
return
|
||||
}
|
||||
for _, schemFile := range testSchemas {
|
||||
schem := loadSchema(schemFile)
|
||||
t.Logf("Testing schema %s", schemFile)
|
||||
files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(files) == 0 {
|
||||
t.Fatalf("No test files for %s", schemFile)
|
||||
}
|
||||
getter := func(name string) (*schema.Table, bool) {
|
||||
r, ok := schem[name]
|
||||
return r, ok
|
||||
}
|
||||
for _, file := range files {
|
||||
t.Logf("Testing file %s", file)
|
||||
for tcase := range iterateExecFile(file) {
|
||||
plan, err := GetExecPlan(tcase.input, getter)
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Errorf("File: %s: Line:%v\n%s\n%s", file, tcase.lineno, tcase.output, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamPlan(t *testing.T) {
|
||||
testSchema := loadSchema("schema_test.json")
|
||||
for tcase := range iterateExecFile("stream_cases.txt") {
|
||||
plan, err := GetStreamExecPlan(tcase.input, func(name string) (*schema.Table, bool) {
|
||||
r, ok := testSchema[name]
|
||||
return r, ok
|
||||
})
|
||||
var out string
|
||||
if err != nil {
|
||||
out = err.Error()
|
||||
} else {
|
||||
bout, err := json.Marshal(plan)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v: %v", plan, err))
|
||||
}
|
||||
out = string(bout)
|
||||
}
|
||||
if out != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, out))
|
||||
}
|
||||
//fmt.Printf("%s\n%s\n\n", tcase.input, out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDDLPlan(t *testing.T) {
|
||||
for tcase := range iterateExecFile("ddl_cases.txt") {
|
||||
plan := DDLParse(tcase.input)
|
||||
expected := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(tcase.output), &expected)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Error marshalling %v", plan))
|
||||
}
|
||||
matchString(t, tcase.lineno, expected["Action"], plan.Action)
|
||||
matchString(t, tcase.lineno, expected["TableName"], plan.TableName)
|
||||
matchString(t, tcase.lineno, expected["NewName"], plan.NewName)
|
||||
}
|
||||
}
|
||||
|
||||
func matchString(t *testing.T, line int, expected interface{}, actual string) {
|
||||
if expected != nil {
|
||||
if expected.(string) != actual {
|
||||
t.Error(fmt.Sprintf("Line %d: expected: %v, received %s", line, expected, actual))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadSchema(name string) map[string]*schema.Table {
|
||||
b, err := ioutil.ReadFile(locateFile(name))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tables := make([]*schema.Table, 0, 8)
|
||||
err = json.Unmarshal(b, &tables)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s := make(map[string]*schema.Table)
|
||||
for _, t := range tables {
|
||||
s[t.Name] = t
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
file string
|
||||
lineno int
|
||||
input string
|
||||
output string
|
||||
}
|
||||
|
||||
func iterateExecFile(name string) (testCaseIterator chan testCase) {
|
||||
name = locateFile(name)
|
||||
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Could not open file %s", name))
|
||||
}
|
||||
testCaseIterator = make(chan testCase)
|
||||
go func() {
|
||||
defer close(testCaseIterator)
|
||||
|
||||
r := bufio.NewReader(fd)
|
||||
lineno := 0
|
||||
for {
|
||||
binput, err := r.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
fmt.Printf("Line: %d\n", lineno)
|
||||
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
break
|
||||
}
|
||||
lineno++
|
||||
input := string(binput)
|
||||
if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") {
|
||||
//fmt.Printf("%s\n", input)
|
||||
continue
|
||||
}
|
||||
err = json.Unmarshal(binput, &input)
|
||||
if err != nil {
|
||||
fmt.Printf("Line: %d, input: %s\n", lineno, binput)
|
||||
panic(err)
|
||||
}
|
||||
input = strings.Trim(input, "\"")
|
||||
var output []byte
|
||||
for {
|
||||
l, err := r.ReadBytes('\n')
|
||||
lineno++
|
||||
if err != nil {
|
||||
fmt.Printf("Line: %d\n", lineno)
|
||||
panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
output = append(output, l...)
|
||||
if l[0] == '}' {
|
||||
output = output[:len(output)-1]
|
||||
b := bytes.NewBuffer(make([]byte, 0, 64))
|
||||
if err := json.Compact(b, output); err == nil {
|
||||
output = b.Bytes()
|
||||
}
|
||||
break
|
||||
}
|
||||
if l[0] == '"' {
|
||||
output = output[1 : len(output)-2]
|
||||
break
|
||||
}
|
||||
}
|
||||
testCaseIterator <- testCase{name, lineno, input, string(output)}
|
||||
}
|
||||
}()
|
||||
return testCaseIterator
|
||||
}
|
||||
|
||||
func locateFile(name string) string {
|
||||
if path.IsAbs(name) {
|
||||
return name
|
||||
}
|
||||
return testfiles.Locate("sqlparser_test/" + name)
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
)
|
||||
|
||||
type PlanType int
|
||||
|
||||
const (
|
||||
// PLAN_PASS_SELECT is pass through select statements. This is the
|
||||
// default plan for select statements.
|
||||
PLAN_PASS_SELECT PlanType = iota
|
||||
// PLAN_PASS_DML is pass through update & delete statements. This is
|
||||
// the default plan for update and delete statements.
|
||||
PLAN_PASS_DML
|
||||
// PLAN_PK_EQUAL is select statement which has an equality where clause
|
||||
// on primary key
|
||||
PLAN_PK_EQUAL
|
||||
// PLAN_PK_IN is select statement with a single IN clause on primary key
|
||||
PLAN_PK_IN
|
||||
// PLAN_SELECT_SUBQUERY is select statement with a subselect statement
|
||||
PLAN_SELECT_SUBQUERY
|
||||
// PLAN_DML_PK is an update or delete with an equality where clause(s)
|
||||
// on primary key(s)
|
||||
PLAN_DML_PK
|
||||
// PLAN_DML_SUBQUERY is an update or delete with a subselect statement
|
||||
PLAN_DML_SUBQUERY
|
||||
// PLAN_INSERT_PK is insert statement where the PK value is
|
||||
// supplied with the query
|
||||
PLAN_INSERT_PK
|
||||
// PLAN_INSERT_SUBQUERY is same as PLAN_DML_SUBQUERY but for inserts
|
||||
PLAN_INSERT_SUBQUERY
|
||||
// PLAN_SET is for SET statements
|
||||
PLAN_SET
|
||||
// PLAN_DDL is for DDL statements
|
||||
PLAN_DDL
|
||||
NumPlans
|
||||
)
|
||||
|
||||
// Must exactly match order of plan constants.
|
||||
var planName = []string{
|
||||
"PASS_SELECT",
|
||||
"PASS_DML",
|
||||
"PK_EQUAL",
|
||||
"PK_IN",
|
||||
"SELECT_SUBQUERY",
|
||||
"DML_PK",
|
||||
"DML_SUBQUERY",
|
||||
"INSERT_PK",
|
||||
"INSERT_SUBQUERY",
|
||||
"SET",
|
||||
"DDL",
|
||||
}
|
||||
|
||||
func (pt PlanType) String() string {
|
||||
if pt < 0 || pt >= NumPlans {
|
||||
return ""
|
||||
}
|
||||
return planName[pt]
|
||||
}
|
||||
|
||||
func PlanByName(s string) (pt PlanType, ok bool) {
|
||||
for i, v := range planName {
|
||||
if v == s {
|
||||
return PlanType(i), true
|
||||
}
|
||||
}
|
||||
return NumPlans, false
|
||||
}
|
||||
|
||||
func (pt PlanType) IsSelect() bool {
|
||||
return pt == PLAN_PASS_SELECT || pt == PLAN_PK_EQUAL || pt == PLAN_PK_IN || pt == PLAN_SELECT_SUBQUERY
|
||||
}
|
||||
|
||||
func (pt PlanType) MarshalJSON() ([]byte, error) {
|
||||
return ([]byte)(fmt.Sprintf("\"%s\"", pt.String())), nil
|
||||
}
|
||||
|
||||
// MinRole is the minimum Role required to execute this PlanType
|
||||
func (pt PlanType) MinRole() tableacl.Role {
|
||||
return tableAclRoles[pt]
|
||||
}
|
||||
|
||||
var tableAclRoles = map[PlanType]tableacl.Role{
|
||||
PLAN_PASS_SELECT: tableacl.READER,
|
||||
PLAN_PK_EQUAL: tableacl.READER,
|
||||
PLAN_PK_IN: tableacl.READER,
|
||||
PLAN_SELECT_SUBQUERY: tableacl.READER,
|
||||
PLAN_SET: tableacl.READER,
|
||||
PLAN_PASS_DML: tableacl.WRITER,
|
||||
PLAN_DML_PK: tableacl.WRITER,
|
||||
PLAN_DML_SUBQUERY: tableacl.WRITER,
|
||||
PLAN_INSERT_PK: tableacl.WRITER,
|
||||
PLAN_INSERT_SUBQUERY: tableacl.WRITER,
|
||||
PLAN_DDL: tableacl.ADMIN,
|
||||
}
|
||||
|
||||
type ReasonType int
|
||||
|
||||
const (
|
||||
REASON_DEFAULT ReasonType = iota
|
||||
REASON_SELECT
|
||||
REASON_TABLE
|
||||
REASON_NOCACHE
|
||||
REASON_SELECT_LIST
|
||||
REASON_LOCK
|
||||
REASON_WHERE
|
||||
REASON_ORDER
|
||||
REASON_PKINDEX
|
||||
REASON_NOINDEX_MATCH
|
||||
REASON_TABLE_NOINDEX
|
||||
REASON_PK_CHANGE
|
||||
REASON_COMPOSITE_PK
|
||||
REASON_HAS_HINTS
|
||||
REASON_UPSERT
|
||||
)
|
||||
|
||||
// Must exactly match order of reason constants.
|
||||
var reasonName = []string{
|
||||
"DEFAULT",
|
||||
"SELECT",
|
||||
"TABLE",
|
||||
"NOCACHE",
|
||||
"SELECT_LIST",
|
||||
"LOCK",
|
||||
"WHERE",
|
||||
"ORDER",
|
||||
"PKINDEX",
|
||||
"NOINDEX_MATCH",
|
||||
"TABLE_NOINDEX",
|
||||
"PK_CHANGE",
|
||||
"COMPOSITE_PK",
|
||||
"HAS_HINTS",
|
||||
"UPSERT",
|
||||
}
|
||||
|
||||
func (rt ReasonType) String() string {
|
||||
return reasonName[rt]
|
||||
}
|
||||
|
||||
func (rt ReasonType) MarshalJSON() ([]byte, error) {
|
||||
return ([]byte)(fmt.Sprintf("\"%s\"", rt.String())), nil
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
func GenerateFullQuery(statement sqlparser.Statement) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
statement.Format(buf)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateFieldQuery(statement sqlparser.Statement) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(FormatImpossible)
|
||||
buf.Fprintf("%v", statement)
|
||||
if buf.HasBindVars() {
|
||||
return nil
|
||||
}
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
// FormatImpossible is a callback function used by TrackedBuffer
|
||||
// to generate a modified version of the query where all selects
|
||||
// have impossible where clauses. It overrides a few node types
|
||||
// and passes the rest down to the default FormatNode.
|
||||
func FormatImpossible(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) {
|
||||
switch node := node.(type) {
|
||||
case *sqlparser.Select:
|
||||
buf.Fprintf("select %v from %v where 1 != 1", node.SelectExprs, node.From)
|
||||
case *sqlparser.JoinTableExpr:
|
||||
if node.Join == sqlparser.AST_LEFT_JOIN || node.Join == sqlparser.AST_RIGHT_JOIN {
|
||||
// ON clause is requried
|
||||
buf.Fprintf("%v %s %v on 1 != 1", node.LeftExpr, node.Join, node.RightExpr)
|
||||
} else {
|
||||
buf.Fprintf("%v %s %v", node.LeftExpr, node.Join, node.RightExpr)
|
||||
}
|
||||
default:
|
||||
node.Format(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateSelectLimitQuery(selStmt sqlparser.SelectStatement) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
sel, ok := selStmt.(*sqlparser.Select)
|
||||
if ok {
|
||||
limit := sel.Limit
|
||||
if limit == nil {
|
||||
sel.Limit = execLimit
|
||||
defer func() {
|
||||
sel.Limit = nil
|
||||
}()
|
||||
}
|
||||
}
|
||||
buf.Fprintf("%v", selStmt)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateEqualOuterQuery(sel *sqlparser.Select, tableInfo *schema.Table) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
fmt.Fprintf(buf, "select ")
|
||||
writeColumnList(buf, tableInfo.Columns)
|
||||
buf.Fprintf(" from %v where ", sel.From)
|
||||
generatePKWhere(buf, tableInfo.Indexes[0])
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateInOuterQuery(sel *sqlparser.Select, tableInfo *schema.Table) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
fmt.Fprintf(buf, "select ")
|
||||
writeColumnList(buf, tableInfo.Columns)
|
||||
// We assume there is one and only one PK column.
|
||||
// A '*' argument name means all variables of the list.
|
||||
buf.Fprintf(" from %v where %s in (%a)", sel.From, tableInfo.Indexes[0].Columns[0], "*")
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateInsertOuterQuery(ins *sqlparser.Insert) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
buf.Fprintf("insert %vinto %v%v values %a%v",
|
||||
ins.Comments,
|
||||
ins.Table,
|
||||
ins.Columns,
|
||||
"_rowValues",
|
||||
ins.OnDup,
|
||||
)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateUpdateOuterQuery(upd *sqlparser.Update, pkIndex *schema.Index) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
buf.Fprintf("update %v%v set %v where ", upd.Comments, upd.Table, upd.Exprs)
|
||||
generatePKWhere(buf, pkIndex)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func GenerateDeleteOuterQuery(del *sqlparser.Delete, pkIndex *schema.Index) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
buf.Fprintf("delete %vfrom %v where ", del.Comments, del.Table)
|
||||
generatePKWhere(buf, pkIndex)
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func generatePKWhere(buf *sqlparser.TrackedBuffer, pkIndex *schema.Index) {
|
||||
for i := 0; i < len(pkIndex.Columns); i++ {
|
||||
if i != 0 {
|
||||
buf.WriteString(" and ")
|
||||
}
|
||||
buf.Fprintf("%s = %a", pkIndex.Columns[i], strconv.FormatInt(int64(i), 10))
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateSelectSubquery(sel *sqlparser.Select, tableInfo *schema.Table, index string) *sqlparser.ParsedQuery {
|
||||
hint := &sqlparser.IndexHints{Type: sqlparser.AST_USE, Indexes: [][]byte{[]byte(index)}}
|
||||
table_expr := sel.From[0].(*sqlparser.AliasedTableExpr)
|
||||
savedHint := table_expr.Hints
|
||||
table_expr.Hints = hint
|
||||
defer func() {
|
||||
table_expr.Hints = savedHint
|
||||
}()
|
||||
return GenerateSubquery(
|
||||
tableInfo.Indexes[0].Columns,
|
||||
table_expr,
|
||||
sel.Where,
|
||||
sel.OrderBy,
|
||||
sel.Limit,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
func GenerateUpdateSubquery(upd *sqlparser.Update, tableInfo *schema.Table) *sqlparser.ParsedQuery {
|
||||
return GenerateSubquery(
|
||||
tableInfo.Indexes[0].Columns,
|
||||
&sqlparser.AliasedTableExpr{Expr: upd.Table},
|
||||
upd.Where,
|
||||
upd.OrderBy,
|
||||
upd.Limit,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func GenerateDeleteSubquery(del *sqlparser.Delete, tableInfo *schema.Table) *sqlparser.ParsedQuery {
|
||||
return GenerateSubquery(
|
||||
tableInfo.Indexes[0].Columns,
|
||||
&sqlparser.AliasedTableExpr{Expr: del.Table},
|
||||
del.Where,
|
||||
del.OrderBy,
|
||||
del.Limit,
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
func GenerateSubquery(columns []string, table *sqlparser.AliasedTableExpr, where *sqlparser.Where, order sqlparser.OrderBy, limit *sqlparser.Limit, for_update bool) *sqlparser.ParsedQuery {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
if limit == nil {
|
||||
limit = execLimit
|
||||
}
|
||||
fmt.Fprintf(buf, "select ")
|
||||
i := 0
|
||||
for i = 0; i < len(columns)-1; i++ {
|
||||
fmt.Fprintf(buf, "%s, ", columns[i])
|
||||
}
|
||||
fmt.Fprintf(buf, "%s", columns[i])
|
||||
buf.Fprintf(" from %v%v%v%v", table, where, order, limit)
|
||||
if for_update {
|
||||
buf.Fprintf(sqlparser.AST_FOR_UPDATE)
|
||||
}
|
||||
return buf.ParsedQuery()
|
||||
}
|
||||
|
||||
func writeColumnList(buf *sqlparser.TrackedBuffer, columns []schema.TableColumn) {
|
||||
i := 0
|
||||
for i = 0; i < len(columns)-1; i++ {
|
||||
fmt.Fprintf(buf, "%s, ", columns[i].Name)
|
||||
}
|
||||
fmt.Fprintf(buf, "%s", columns[i].Name)
|
||||
}
|
|
@ -0,0 +1,219 @@
|
|||
// Copyright 2014, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package planbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
func analyzeSelect(sel *sqlparser.Select, getTable TableGetter) (plan *ExecPlan, err error) {
|
||||
// Default plan
|
||||
plan = &ExecPlan{
|
||||
PlanId: PLAN_PASS_SELECT,
|
||||
FieldQuery: GenerateFieldQuery(sel),
|
||||
FullQuery: GenerateSelectLimitQuery(sel),
|
||||
}
|
||||
|
||||
// There are bind variables in the SELECT list
|
||||
if plan.FieldQuery == nil {
|
||||
plan.Reason = REASON_SELECT_LIST
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
if sel.Distinct != "" || sel.GroupBy != nil || sel.Having != nil {
|
||||
plan.Reason = REASON_SELECT
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// from
|
||||
tableName, hasHints := analyzeFrom(sel.From)
|
||||
if tableName == "" {
|
||||
plan.Reason = REASON_TABLE
|
||||
return plan, nil
|
||||
}
|
||||
tableInfo, err := plan.setTableInfo(tableName, getTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Don't improve the plan if the select is locking the row
|
||||
if sel.Lock != "" {
|
||||
plan.Reason = REASON_LOCK
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// Further improvements possible only if table is row-cached
|
||||
if tableInfo.CacheType == schema.CACHE_NONE || tableInfo.CacheType == schema.CACHE_W {
|
||||
plan.Reason = REASON_NOCACHE
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// Select expressions
|
||||
selects, err := analyzeSelectExprs(sel.SelectExprs, tableInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if selects == nil {
|
||||
plan.Reason = REASON_SELECT_LIST
|
||||
return plan, nil
|
||||
}
|
||||
plan.ColumnNumbers = selects
|
||||
|
||||
// where
|
||||
conditions := analyzeWhere(sel.Where)
|
||||
if conditions == nil {
|
||||
plan.Reason = REASON_WHERE
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// order
|
||||
if sel.OrderBy != nil {
|
||||
plan.Reason = REASON_ORDER
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// This check should never fail because we only cache tables with primary keys.
|
||||
if len(tableInfo.Indexes) == 0 || tableInfo.Indexes[0].Name != "PRIMARY" {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
// Attempt PK match only if there's no limit clause
|
||||
if sel.Limit == nil {
|
||||
planId, pkValues, err := getSelectPKValues(conditions, tableInfo.Indexes[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch planId {
|
||||
case PLAN_PK_EQUAL:
|
||||
plan.PlanId = PLAN_PK_EQUAL
|
||||
plan.OuterQuery = GenerateEqualOuterQuery(sel, tableInfo)
|
||||
plan.PKValues = pkValues
|
||||
return plan, nil
|
||||
case PLAN_PK_IN:
|
||||
plan.PlanId = PLAN_PK_IN
|
||||
plan.OuterQuery = GenerateInOuterQuery(sel, tableInfo)
|
||||
plan.PKValues = pkValues
|
||||
return plan, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(tableInfo.Indexes[0].Columns) != 1 {
|
||||
plan.Reason = REASON_COMPOSITE_PK
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
// TODO: Analyze hints to improve plan.
|
||||
if hasHints {
|
||||
plan.Reason = REASON_HAS_HINTS
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
plan.IndexUsed = getIndexMatch(conditions, tableInfo.Indexes)
|
||||
if plan.IndexUsed == "" {
|
||||
plan.Reason = REASON_NOINDEX_MATCH
|
||||
return plan, nil
|
||||
}
|
||||
if plan.IndexUsed == "PRIMARY" {
|
||||
plan.Reason = REASON_PKINDEX
|
||||
return plan, nil
|
||||
}
|
||||
// TODO: We can further optimize. Change this to pass-through if select list matches all columns in index.
|
||||
plan.PlanId = PLAN_SELECT_SUBQUERY
|
||||
plan.OuterQuery = GenerateInOuterQuery(sel, tableInfo)
|
||||
plan.Subquery = GenerateSelectSubquery(sel, tableInfo, plan.IndexUsed)
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func analyzeSelectExprs(exprs sqlparser.SelectExprs, table *schema.Table) (selects []int, err error) {
|
||||
selects = make([]int, 0, len(exprs))
|
||||
for _, expr := range exprs {
|
||||
switch expr := expr.(type) {
|
||||
case *sqlparser.StarExpr:
|
||||
// Append all columns.
|
||||
for colIndex := range table.Columns {
|
||||
selects = append(selects, colIndex)
|
||||
}
|
||||
case *sqlparser.NonStarExpr:
|
||||
name := sqlparser.GetColName(expr.Expr)
|
||||
if name == "" {
|
||||
// Not a simple column name.
|
||||
return nil, nil
|
||||
}
|
||||
colIndex := table.FindColumn(name)
|
||||
if colIndex == -1 {
|
||||
return nil, fmt.Errorf("column %s not found in table %s", name, table.Name)
|
||||
}
|
||||
selects = append(selects, colIndex)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
return selects, nil
|
||||
}
|
||||
|
||||
func analyzeFrom(tableExprs sqlparser.TableExprs) (tablename string, hasHints bool) {
|
||||
if len(tableExprs) > 1 {
|
||||
return "", false
|
||||
}
|
||||
node, ok := tableExprs[0].(*sqlparser.AliasedTableExpr)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return sqlparser.GetTableName(node.Expr), node.Hints != nil
|
||||
}
|
||||
|
||||
func analyzeWhere(node *sqlparser.Where) (conditions []sqlparser.BoolExpr) {
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return analyzeBoolean(node.Expr)
|
||||
}
|
||||
|
||||
func analyzeBoolean(node sqlparser.BoolExpr) (conditions []sqlparser.BoolExpr) {
|
||||
switch node := node.(type) {
|
||||
case *sqlparser.AndExpr:
|
||||
left := analyzeBoolean(node.Left)
|
||||
right := analyzeBoolean(node.Right)
|
||||
if left == nil || right == nil {
|
||||
return nil
|
||||
}
|
||||
if sqlparser.HasINClause(left) && sqlparser.HasINClause(right) {
|
||||
return nil
|
||||
}
|
||||
return append(left, right...)
|
||||
case *sqlparser.ParenBoolExpr:
|
||||
return analyzeBoolean(node.Expr)
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch {
|
||||
case sqlparser.StringIn(
|
||||
node.Operator,
|
||||
sqlparser.AST_EQ,
|
||||
sqlparser.AST_LT,
|
||||
sqlparser.AST_GT,
|
||||
sqlparser.AST_LE,
|
||||
sqlparser.AST_GE,
|
||||
sqlparser.AST_NSE,
|
||||
sqlparser.AST_LIKE):
|
||||
if sqlparser.IsColName(node.Left) && sqlparser.IsValue(node.Right) {
|
||||
return []sqlparser.BoolExpr{node}
|
||||
}
|
||||
case node.Operator == sqlparser.AST_IN:
|
||||
if sqlparser.IsColName(node.Left) && sqlparser.IsSimpleTuple(node.Right) {
|
||||
return []sqlparser.BoolExpr{node}
|
||||
}
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
if node.Operator != sqlparser.AST_BETWEEN {
|
||||
return nil
|
||||
}
|
||||
if sqlparser.IsColName(node.Left) && sqlparser.IsValue(node.From) && sqlparser.IsValue(node.To) {
|
||||
return []sqlparser.BoolExpr{node}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
)
|
||||
|
||||
func TestQueryRules(t *testing.T) {
|
||||
|
@ -52,7 +52,7 @@ func TestQueryRules(t *testing.T) {
|
|||
func TestCopy(t *testing.T) {
|
||||
qrs := NewQueryRules()
|
||||
qr1 := NewQueryRule("rule 1", "r1", QR_FAIL)
|
||||
qr1.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
|
||||
qr1.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
|
||||
qr1.AddTableCond("aa")
|
||||
qr1.AddBindVarCond("a", true, false, QR_NOOP, nil)
|
||||
|
||||
|
@ -70,7 +70,7 @@ func TestCopy(t *testing.T) {
|
|||
t.Errorf("want false, got true")
|
||||
}
|
||||
|
||||
qr1.plans[0] = sqlparser.PLAN_INSERT_PK
|
||||
qr1.plans[0] = planbuilder.PLAN_INSERT_PK
|
||||
if qr1.plans[0] == qrf1.plans[0] {
|
||||
t.Errorf("want false, got true")
|
||||
}
|
||||
|
@ -95,12 +95,12 @@ func TestFilterByPlan(t *testing.T) {
|
|||
qr1 := NewQueryRule("rule 1", "r1", QR_FAIL)
|
||||
qr1.SetIPCond("123")
|
||||
qr1.SetQueryCond("select")
|
||||
qr1.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
|
||||
qr1.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
|
||||
qr1.AddBindVarCond("a", true, false, QR_NOOP, nil)
|
||||
|
||||
qr2 := NewQueryRule("rule 2", "r2", QR_FAIL)
|
||||
qr2.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
|
||||
qr2.AddPlanCond(sqlparser.PLAN_PK_EQUAL)
|
||||
qr2.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
|
||||
qr2.AddPlanCond(planbuilder.PLAN_PK_EQUAL)
|
||||
qr2.AddBindVarCond("a", true, false, QR_NOOP, nil)
|
||||
|
||||
qr3 := NewQueryRule("rule 3", "r3", QR_FAIL)
|
||||
|
@ -116,7 +116,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
qrs.Add(qr3)
|
||||
qrs.Add(qr4)
|
||||
|
||||
qrs1 := qrs.filterByPlan("select", sqlparser.PLAN_PASS_SELECT, "a")
|
||||
qrs1 := qrs.filterByPlan("select", planbuilder.PLAN_PASS_SELECT, "a")
|
||||
if l := len(qrs1.rules); l != 3 {
|
||||
t.Errorf("want 3, got %d", l)
|
||||
}
|
||||
|
@ -130,7 +130,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
t.Errorf("want nil, got non-nil")
|
||||
}
|
||||
|
||||
qrs1 = qrs.filterByPlan("insert", sqlparser.PLAN_PASS_SELECT, "a")
|
||||
qrs1 = qrs.filterByPlan("insert", planbuilder.PLAN_PASS_SELECT, "a")
|
||||
if l := len(qrs1.rules); l != 1 {
|
||||
t.Errorf("want 1, got %d", l)
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
t.Errorf("want r2, got %s", qrs1.rules[0].Name)
|
||||
}
|
||||
|
||||
qrs1 = qrs.filterByPlan("insert", sqlparser.PLAN_PK_EQUAL, "a")
|
||||
qrs1 = qrs.filterByPlan("insert", planbuilder.PLAN_PK_EQUAL, "a")
|
||||
if l := len(qrs1.rules); l != 1 {
|
||||
t.Errorf("want 1, got %d", l)
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
t.Errorf("want r2, got %s", qrs1.rules[0].Name)
|
||||
}
|
||||
|
||||
qrs1 = qrs.filterByPlan("select", sqlparser.PLAN_INSERT_PK, "a")
|
||||
qrs1 = qrs.filterByPlan("select", planbuilder.PLAN_INSERT_PK, "a")
|
||||
if l := len(qrs1.rules); l != 1 {
|
||||
t.Errorf("want 1, got %d", l)
|
||||
}
|
||||
|
@ -154,12 +154,12 @@ func TestFilterByPlan(t *testing.T) {
|
|||
t.Errorf("want r3, got %s", qrs1.rules[0].Name)
|
||||
}
|
||||
|
||||
qrs1 = qrs.filterByPlan("sel", sqlparser.PLAN_INSERT_PK, "a")
|
||||
qrs1 = qrs.filterByPlan("sel", planbuilder.PLAN_INSERT_PK, "a")
|
||||
if qrs1.rules != nil {
|
||||
t.Errorf("want nil, got non-nil")
|
||||
}
|
||||
|
||||
qrs1 = qrs.filterByPlan("table", sqlparser.PLAN_PASS_DML, "b")
|
||||
qrs1 = qrs.filterByPlan("table", planbuilder.PLAN_PASS_DML, "b")
|
||||
if l := len(qrs1.rules); l != 1 {
|
||||
t.Errorf("want 1, got %#v, %#v", qrs1.rules[0], qrs1.rules[1])
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
qr5 := NewQueryRule("rule 5", "r5", QR_FAIL)
|
||||
qrs.Add(qr5)
|
||||
|
||||
qrs1 = qrs.filterByPlan("sel", sqlparser.PLAN_INSERT_PK, "a")
|
||||
qrs1 = qrs.filterByPlan("sel", planbuilder.PLAN_INSERT_PK, "a")
|
||||
if l := len(qrs1.rules); l != 1 {
|
||||
t.Errorf("want 1, got %d", l)
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ func TestFilterByPlan(t *testing.T) {
|
|||
}
|
||||
|
||||
qrsnil1 := NewQueryRules()
|
||||
if qrsnil2 := qrsnil1.filterByPlan("", sqlparser.PLAN_PASS_SELECT, "a"); qrsnil2.rules != nil {
|
||||
if qrsnil2 := qrsnil1.filterByPlan("", planbuilder.PLAN_PASS_SELECT, "a"); qrsnil2.rules != nil {
|
||||
t.Errorf("want nil, got non-nil")
|
||||
}
|
||||
}
|
||||
|
@ -204,13 +204,13 @@ func TestQueryRule(t *testing.T) {
|
|||
t.Errorf("want error")
|
||||
}
|
||||
|
||||
qr.AddPlanCond(sqlparser.PLAN_PASS_SELECT)
|
||||
qr.AddPlanCond(sqlparser.PLAN_INSERT_PK)
|
||||
qr.AddPlanCond(planbuilder.PLAN_PASS_SELECT)
|
||||
qr.AddPlanCond(planbuilder.PLAN_INSERT_PK)
|
||||
|
||||
if qr.plans[0] != sqlparser.PLAN_PASS_SELECT {
|
||||
if qr.plans[0] != planbuilder.PLAN_PASS_SELECT {
|
||||
t.Errorf("want PASS_SELECT, got %s", qr.plans[0].String())
|
||||
}
|
||||
if qr.plans[1] != sqlparser.PLAN_INSERT_PK {
|
||||
if qr.plans[1] != planbuilder.PLAN_INSERT_PK {
|
||||
t.Errorf("want INSERT_PK, got %s", qr.plans[1].String())
|
||||
}
|
||||
|
||||
|
@ -563,10 +563,10 @@ func TestImport(t *testing.T) {
|
|||
if qrs.rules[0].query == nil {
|
||||
t.Errorf("want non-nil")
|
||||
}
|
||||
if qrs.rules[0].plans[0] != sqlparser.PLAN_PASS_SELECT {
|
||||
if qrs.rules[0].plans[0] != planbuilder.PLAN_PASS_SELECT {
|
||||
t.Errorf("want PASS_SELECT, got %s", qrs.rules[0].plans[0].String())
|
||||
}
|
||||
if qrs.rules[0].plans[1] != sqlparser.PLAN_INSERT_PK {
|
||||
if qrs.rules[0].plans[1] != planbuilder.PLAN_INSERT_PK {
|
||||
t.Errorf("want PASS_INSERT_PK, got %s", qrs.rules[0].plans[0].String())
|
||||
}
|
||||
if qrs.rules[0].tableNames[0] != "a" {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package tabletserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/golang/glog"
|
||||
|
@ -15,9 +16,12 @@ import (
|
|||
"github.com/youtube/vitess/go/sync2"
|
||||
"github.com/youtube/vitess/go/vt/dbconfigs"
|
||||
"github.com/youtube/vitess/go/vt/dbconnpool"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
"github.com/youtube/vitess/go/vt/mysqlctl"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/proto"
|
||||
)
|
||||
|
||||
|
@ -67,6 +71,10 @@ type QueryEngine struct {
|
|||
strictMode sync2.AtomicInt64
|
||||
maxResultSize sync2.AtomicInt64
|
||||
streamBufferSize sync2.AtomicInt64
|
||||
strictTableAcl bool
|
||||
|
||||
// loggers
|
||||
accessCheckerLogger *logutil.ThrottledLogger
|
||||
}
|
||||
|
||||
type compiledPlan struct {
|
||||
|
@ -140,9 +148,13 @@ func NewQueryEngine(config Config) *QueryEngine {
|
|||
if config.StrictMode {
|
||||
qe.strictMode.Set(1)
|
||||
}
|
||||
qe.strictTableAcl = config.StrictTableAcl
|
||||
qe.maxResultSize = sync2.AtomicInt64(config.MaxResultSize)
|
||||
qe.streamBufferSize = sync2.AtomicInt64(config.StreamBufferSize)
|
||||
|
||||
// loggers
|
||||
qe.accessCheckerLogger = logutil.NewThrottledLogger("accessChecker", 1*time.Second)
|
||||
|
||||
// Stats
|
||||
stats.Publish("MaxResultSize", stats.IntFunc(qe.maxResultSize.Get))
|
||||
stats.Publish("StreamBufferSize", stats.IntFunc(qe.streamBufferSize.Get))
|
||||
|
@ -311,7 +323,9 @@ func (qe *QueryEngine) Execute(logStats *SQLQueryStats, query *proto.Query) (rep
|
|||
panic(NewTabletError(RETRY, "Query disallowed due to rule: %s", desc))
|
||||
}
|
||||
|
||||
if basePlan.PlanId == sqlparser.PLAN_DDL {
|
||||
qe.checkTableAcl(basePlan.TableName, basePlan.PlanId, basePlan.Authorized, logStats.context.GetUsername())
|
||||
|
||||
if basePlan.PlanId == planbuilder.PLAN_DDL {
|
||||
return qe.execDDL(logStats, query.Sql)
|
||||
}
|
||||
|
||||
|
@ -331,36 +345,36 @@ func (qe *QueryEngine) Execute(logStats *SQLQueryStats, query *proto.Query) (rep
|
|||
invalidator = conn.DirtyKeys(plan.TableName)
|
||||
}
|
||||
switch plan.PlanId {
|
||||
case sqlparser.PLAN_PASS_DML:
|
||||
case planbuilder.PLAN_PASS_DML:
|
||||
if qe.strictMode.Get() != 0 {
|
||||
panic(NewTabletError(FAIL, "DML too complex"))
|
||||
}
|
||||
reply = qe.directFetch(logStats, conn, plan.FullQuery, plan.BindVars, nil, nil)
|
||||
case sqlparser.PLAN_INSERT_PK:
|
||||
case planbuilder.PLAN_INSERT_PK:
|
||||
reply = qe.execInsertPK(logStats, conn, plan, invalidator)
|
||||
case sqlparser.PLAN_INSERT_SUBQUERY:
|
||||
case planbuilder.PLAN_INSERT_SUBQUERY:
|
||||
reply = qe.execInsertSubquery(logStats, conn, plan, invalidator)
|
||||
case sqlparser.PLAN_DML_PK:
|
||||
case planbuilder.PLAN_DML_PK:
|
||||
reply = qe.execDMLPK(logStats, conn, plan, invalidator)
|
||||
case sqlparser.PLAN_DML_SUBQUERY:
|
||||
case planbuilder.PLAN_DML_SUBQUERY:
|
||||
reply = qe.execDMLSubquery(logStats, conn, plan, invalidator)
|
||||
default: // select or set in a transaction, just count as select
|
||||
reply = qe.execDirect(logStats, plan, conn)
|
||||
}
|
||||
} else {
|
||||
switch plan.PlanId {
|
||||
case sqlparser.PLAN_PASS_SELECT:
|
||||
if plan.Reason == sqlparser.REASON_LOCK {
|
||||
case planbuilder.PLAN_PASS_SELECT:
|
||||
if plan.Reason == planbuilder.REASON_LOCK {
|
||||
panic(NewTabletError(FAIL, "Disallowed outside transaction"))
|
||||
}
|
||||
reply = qe.execSelect(logStats, plan)
|
||||
case sqlparser.PLAN_PK_EQUAL:
|
||||
case planbuilder.PLAN_PK_EQUAL:
|
||||
reply = qe.execPKEqual(logStats, plan)
|
||||
case sqlparser.PLAN_PK_IN:
|
||||
case planbuilder.PLAN_PK_IN:
|
||||
reply = qe.execPKIN(logStats, plan)
|
||||
case sqlparser.PLAN_SELECT_SUBQUERY:
|
||||
case planbuilder.PLAN_SELECT_SUBQUERY:
|
||||
reply = qe.execSubquery(logStats, plan)
|
||||
case sqlparser.PLAN_SET:
|
||||
case planbuilder.PLAN_SET:
|
||||
waitingForConnectionStart := time.Now()
|
||||
conn := getOrPanic(qe.connPool)
|
||||
logStats.WaitingForConnection += time.Now().Sub(waitingForConnectionStart)
|
||||
|
@ -395,6 +409,9 @@ func (qe *QueryEngine) StreamExecute(logStats *SQLQueryStats, query *proto.Query
|
|||
logStats.OriginalSql = query.Sql
|
||||
defer queryStats.Record("SELECT_STREAM", time.Now())
|
||||
|
||||
authorized := tableacl.Authorized(plan.TableName, plan.PlanId.MinRole())
|
||||
qe.checkTableAcl(plan.TableName, plan.PlanId, authorized, logStats.context.GetUsername())
|
||||
|
||||
// does the real work: first get a connection
|
||||
waitingForConnectionStart := time.Now()
|
||||
conn := getOrPanic(qe.streamConnPool)
|
||||
|
@ -410,6 +427,16 @@ func (qe *QueryEngine) StreamExecute(logStats *SQLQueryStats, query *proto.Query
|
|||
qe.fullStreamFetch(logStats, conn, plan.FullQuery, query.BindVariables, nil, nil, sendReply)
|
||||
}
|
||||
|
||||
func (qe *QueryEngine) checkTableAcl(table string, planId planbuilder.PlanType, authorized tableacl.ACL, user string) {
|
||||
if !authorized.IsMember(user) {
|
||||
err := fmt.Sprintf("table acl error: %v cannot run %v on table %v", user, planId, table)
|
||||
if qe.strictTableAcl {
|
||||
panic(NewTabletError(FAIL, err))
|
||||
}
|
||||
qe.accessCheckerLogger.Errorf(err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateForDml performs rowcache invalidations for the dml.
|
||||
func (qe *QueryEngine) InvalidateForDml(dml *proto.DmlType) {
|
||||
if qe.cachePool.IsClosed() {
|
||||
|
@ -435,7 +462,7 @@ func (qe *QueryEngine) InvalidateForDml(dml *proto.DmlType) {
|
|||
|
||||
// InvalidateForDDL performs schema and rowcache changes for the ddl.
|
||||
func (qe *QueryEngine) InvalidateForDDL(ddlInvalidate *proto.DDLInvalidate) {
|
||||
ddlPlan := sqlparser.DDLParse(ddlInvalidate.DDL)
|
||||
ddlPlan := planbuilder.DDLParse(ddlInvalidate.DDL)
|
||||
if ddlPlan.Action == "" {
|
||||
panic(NewTabletError(FAIL, "DDL is not understood"))
|
||||
}
|
||||
|
@ -449,7 +476,7 @@ func (qe *QueryEngine) InvalidateForDDL(ddlInvalidate *proto.DDLInvalidate) {
|
|||
// DDL
|
||||
|
||||
func (qe *QueryEngine) execDDL(logStats *SQLQueryStats, ddl string) *mproto.QueryResult {
|
||||
ddlPlan := sqlparser.DDLParse(ddl)
|
||||
ddlPlan := planbuilder.DDLParse(ddl)
|
||||
if ddlPlan.Action == "" {
|
||||
panic(NewTabletError(FAIL, "DDL is not understood"))
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
)
|
||||
|
||||
//-----------------------------------------------
|
||||
|
@ -90,7 +90,7 @@ func (qrs *QueryRules) UnmarshalJSON(data []byte) (err error) {
|
|||
// filterByPlan creates a new QueryRules by prefiltering on the query and planId. This allows
|
||||
// us to create query plan specific QueryRules out of the original QueryRules. In the new rules,
|
||||
// query, plans and tableNames predicates are empty.
|
||||
func (qrs *QueryRules) filterByPlan(query string, planid sqlparser.PlanType, tableName string) (newqrs *QueryRules) {
|
||||
func (qrs *QueryRules) filterByPlan(query string, planid planbuilder.PlanType, tableName string) (newqrs *QueryRules) {
|
||||
var newrules []*QueryRule
|
||||
for _, qr := range qrs.rules {
|
||||
if newrule := qr.filterByPlan(query, planid, tableName); newrule != nil {
|
||||
|
@ -129,7 +129,7 @@ type QueryRule struct {
|
|||
requestIP, user, query *regexp.Regexp
|
||||
|
||||
// Any matched plan will make this condition true (OR)
|
||||
plans []sqlparser.PlanType
|
||||
plans []planbuilder.PlanType
|
||||
|
||||
// Any matched tableNames will make this condition true (OR)
|
||||
tableNames []string
|
||||
|
@ -158,7 +158,7 @@ func (qr *QueryRule) Copy() (newqr *QueryRule) {
|
|||
act: qr.act,
|
||||
}
|
||||
if qr.plans != nil {
|
||||
newqr.plans = make([]sqlparser.PlanType, len(qr.plans))
|
||||
newqr.plans = make([]planbuilder.PlanType, len(qr.plans))
|
||||
copy(newqr.plans, qr.plans)
|
||||
}
|
||||
if qr.tableNames != nil {
|
||||
|
@ -189,7 +189,7 @@ func (qr *QueryRule) SetUserCond(pattern string) (err error) {
|
|||
// AddPlanCond adds to the list of plans that can be matched for
|
||||
// the rule to fire.
|
||||
// This function acts as an OR: Any plan id match is considered a match.
|
||||
func (qr *QueryRule) AddPlanCond(planType sqlparser.PlanType) {
|
||||
func (qr *QueryRule) AddPlanCond(planType planbuilder.PlanType) {
|
||||
qr.plans = append(qr.plans, planType)
|
||||
}
|
||||
|
||||
|
@ -279,7 +279,7 @@ Error:
|
|||
// The new QueryRule will contain all the original constraints other
|
||||
// than the plan and query. If the plan and query don't match the QueryRule,
|
||||
// then it returns nil.
|
||||
func (qr *QueryRule) filterByPlan(query string, planid sqlparser.PlanType, tableName string) (newqr *QueryRule) {
|
||||
func (qr *QueryRule) filterByPlan(query string, planid planbuilder.PlanType, tableName string) (newqr *QueryRule) {
|
||||
if !reMatch(qr.query, query) {
|
||||
return nil
|
||||
}
|
||||
|
@ -315,7 +315,7 @@ func reMatch(re *regexp.Regexp, val string) bool {
|
|||
return re == nil || re.MatchString(val)
|
||||
}
|
||||
|
||||
func planMatch(plans []sqlparser.PlanType, plan sqlparser.PlanType) bool {
|
||||
func planMatch(plans []planbuilder.PlanType, plan planbuilder.PlanType) bool {
|
||||
if plans == nil {
|
||||
return true
|
||||
}
|
||||
|
@ -728,7 +728,7 @@ func buildQueryRule(ruleInfo map[string]interface{}) (qr *QueryRule, err error)
|
|||
if !ok {
|
||||
return nil, NewTabletError(FAIL, "want string for Plans")
|
||||
}
|
||||
pt, ok := sqlparser.PlanByName(pv)
|
||||
pt, ok := planbuilder.PlanByName(pv)
|
||||
if !ok {
|
||||
return nil, NewTabletError(FAIL, "invalid plan name: %s", pv)
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ func init() {
|
|||
flag.Float64Var(&qsConfig.IdleTimeout, "queryserver-config-idle-timeout", DefaultQsConfig.IdleTimeout, "query server idle timeout")
|
||||
flag.Float64Var(&qsConfig.SpotCheckRatio, "queryserver-config-spot-check-ratio", DefaultQsConfig.SpotCheckRatio, "query server rowcache spot check frequency")
|
||||
flag.BoolVar(&qsConfig.StrictMode, "queryserver-config-strict-mode", DefaultQsConfig.StrictMode, "allow only predictable DMLs and enforces MySQL's STRICT_TRANS_TABLES")
|
||||
flag.BoolVar(&qsConfig.StrictTableAcl, "queryserver-config-strict-table-acl", DefaultQsConfig.StrictTableAcl, "only allow queries that pass table acl checks")
|
||||
flag.StringVar(&qsConfig.RowCache.Binary, "rowcache-bin", DefaultQsConfig.RowCache.Binary, "rowcache binary file")
|
||||
flag.IntVar(&qsConfig.RowCache.Memory, "rowcache-memory", DefaultQsConfig.RowCache.Memory, "rowcache max memory usage in MB")
|
||||
flag.StringVar(&qsConfig.RowCache.Socket, "rowcache-socket", DefaultQsConfig.RowCache.Socket, "rowcache socket path to listen on")
|
||||
|
@ -102,6 +103,7 @@ type Config struct {
|
|||
RowCache RowCacheConfig
|
||||
SpotCheckRatio float64
|
||||
StrictMode bool
|
||||
StrictTableAcl bool
|
||||
}
|
||||
|
||||
// DefaultQSConfig is the default value for the query service config.
|
||||
|
@ -126,6 +128,7 @@ var DefaultQsConfig = Config{
|
|||
RowCache: RowCacheConfig{Memory: -1, TcpPort: -1, Connections: -1, Threads: -1},
|
||||
SpotCheckRatio: 0,
|
||||
StrictMode: true,
|
||||
StrictTableAcl: false,
|
||||
}
|
||||
|
||||
var qsConfig Config
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/youtube/vitess/go/acl"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -52,7 +52,7 @@ var (
|
|||
type queryzRow struct {
|
||||
Query string
|
||||
Table string
|
||||
Plan sqlparser.PlanType
|
||||
Plan planbuilder.PlanType
|
||||
Count int64
|
||||
tm time.Duration
|
||||
Rows int64
|
||||
|
|
|
@ -22,7 +22,8 @@ import (
|
|||
"github.com/youtube/vitess/go/timer"
|
||||
"github.com/youtube/vitess/go/vt/dbconnpool"
|
||||
"github.com/youtube/vitess/go/vt/schema"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
"github.com/youtube/vitess/go/vt/tableacl"
|
||||
"github.com/youtube/vitess/go/vt/tabletserver/planbuilder"
|
||||
)
|
||||
|
||||
const base_show_tables = "select table_name, table_type, unix_timestamp(create_time), table_comment from information_schema.tables where table_schema = database()"
|
||||
|
@ -30,10 +31,11 @@ const base_show_tables = "select table_name, table_type, unix_timestamp(create_t
|
|||
const maxTableCount = 10000
|
||||
|
||||
type ExecPlan struct {
|
||||
*sqlparser.ExecPlan
|
||||
TableInfo *TableInfo
|
||||
Fields []mproto.Field
|
||||
Rules *QueryRules
|
||||
*planbuilder.ExecPlan
|
||||
TableInfo *TableInfo
|
||||
Fields []mproto.Field
|
||||
Rules *QueryRules
|
||||
Authorized tableacl.ACL
|
||||
|
||||
mu sync.Mutex
|
||||
QueryCount int64
|
||||
|
@ -318,12 +320,13 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
|
|||
}
|
||||
return tableInfo.Table, true
|
||||
}
|
||||
splan, err := sqlparser.ExecParse(sql, GetTable)
|
||||
splan, err := planbuilder.GetExecPlan(sql, GetTable)
|
||||
if err != nil {
|
||||
panic(NewTabletError(FAIL, "%s", err))
|
||||
}
|
||||
plan = &ExecPlan{ExecPlan: splan, TableInfo: tableInfo}
|
||||
plan.Rules = si.rules.filterByPlan(sql, plan.PlanId, plan.TableName)
|
||||
plan.Authorized = tableacl.Authorized(plan.TableName, plan.PlanId.MinRole())
|
||||
if plan.PlanId.IsSelect() {
|
||||
if plan.FieldQuery == nil {
|
||||
log.Warningf("Cannot cache field info: %s", sql)
|
||||
|
@ -340,7 +343,7 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
|
|||
}
|
||||
plan.Fields = r.Fields
|
||||
}
|
||||
} else if plan.PlanId == sqlparser.PLAN_DDL || plan.PlanId == sqlparser.PLAN_SET {
|
||||
} else if plan.PlanId == planbuilder.PLAN_DDL || plan.PlanId == planbuilder.PLAN_SET {
|
||||
return plan
|
||||
}
|
||||
si.queries.Set(sql, plan)
|
||||
|
@ -349,8 +352,15 @@ func (si *SchemaInfo) GetPlan(logStats *SQLQueryStats, sql string) (plan *ExecPl
|
|||
|
||||
// GetStreamPlan is similar to GetPlan, but doesn't use the cache
|
||||
// and doesn't enforce a limit. It also just returns the parsed query.
|
||||
func (si *SchemaInfo) GetStreamPlan(sql string) *sqlparser.StreamExecPlan {
|
||||
plan, err := sqlparser.StreamExecParse(sql)
|
||||
func (si *SchemaInfo) GetStreamPlan(sql string) *planbuilder.ExecPlan {
|
||||
GetTable := func(tableName string) (*schema.Table, bool) {
|
||||
tableInfo, ok := si.tables[tableName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return tableInfo.Table, true
|
||||
}
|
||||
plan, err := planbuilder.GetStreamExecPlan(sql, GetTable)
|
||||
if err != nil {
|
||||
panic(NewTabletError(FAIL, "%s", err))
|
||||
}
|
||||
|
@ -489,7 +499,7 @@ func (si *SchemaInfo) getQueryStats(f queryStatsFunc) map[string]int64 {
|
|||
type perQueryStats struct {
|
||||
Query string
|
||||
Table string
|
||||
Plan sqlparser.PlanType
|
||||
Plan planbuilder.PlanType
|
||||
QueryCount int64
|
||||
Time time.Duration
|
||||
RowCount int64
|
||||
|
|
|
@ -23,6 +23,9 @@ type SrvShard struct {
|
|||
KeyRange key.KeyRange
|
||||
ServedTypes []TabletType
|
||||
|
||||
// MasterCell indicates the cell that master tablet resides
|
||||
MasterCell string
|
||||
|
||||
// TabletTypes represents the list of types we have serving tablets
|
||||
// for, in this cell only.
|
||||
TabletTypes []TabletType
|
||||
|
|
|
@ -9,7 +9,6 @@ package topo
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/youtube/vitess/go/bson"
|
||||
"github.com/youtube/vitess/go/bytes2"
|
||||
)
|
||||
|
@ -30,6 +29,7 @@ func (srvShard *SrvShard) MarshalBson(buf *bytes2.ChunkedWriter, key string) {
|
|||
}
|
||||
lenWriter.Close()
|
||||
}
|
||||
bson.EncodeString(buf, "MasterCell", srvShard.MasterCell)
|
||||
// []TabletType
|
||||
{
|
||||
bson.EncodePrefix(buf, bson.Array, "TabletTypes")
|
||||
|
@ -76,6 +76,8 @@ func (srvShard *SrvShard) UnmarshalBson(buf *bytes.Buffer, kind byte) {
|
|||
srvShard.ServedTypes = append(srvShard.ServedTypes, _v1)
|
||||
}
|
||||
}
|
||||
case "MasterCell":
|
||||
srvShard.MasterCell = bson.DecodeString(buf, kind)
|
||||
case "TabletTypes":
|
||||
// []TabletType
|
||||
if kind != bson.Null {
|
||||
|
|
|
@ -41,6 +41,7 @@ func TestSrvKeySpace(t *testing.T) {
|
|||
SrvShard{
|
||||
Name: "test_shard",
|
||||
ServedTypes: []TabletType{TYPE_MASTER},
|
||||
MasterCell: "test_cell",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -64,6 +65,7 @@ func TestSrvKeySpace(t *testing.T) {
|
|||
SrvShard{
|
||||
Name: "test_shard",
|
||||
ServedTypes: []TabletType{TYPE_MASTER},
|
||||
MasterCell: "test_cell",
|
||||
TabletTypes: []TabletType{},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -207,6 +207,7 @@ func rebuildCellSrvShard(ts topo.Server, shardInfo *topo.ShardInfo, cell string,
|
|||
Name: shardInfo.ShardName(),
|
||||
KeyRange: shardInfo.KeyRange,
|
||||
ServedTypes: shardInfo.ServedTypes,
|
||||
MasterCell: shardInfo.MasterAlias.Cell,
|
||||
TabletTypes: make([]topo.TabletType, 0, len(locationAddrsMap)),
|
||||
}
|
||||
for tabletType := range locationAddrsMap {
|
||||
|
|
|
@ -0,0 +1,290 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vtgate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
"github.com/youtube/vitess/go/vt/sqlparser"
|
||||
)
|
||||
|
||||
const (
|
||||
EID_NODE = iota
|
||||
VALUE_NODE
|
||||
LIST_NODE
|
||||
OTHER_NODE
|
||||
)
|
||||
|
||||
type RoutingPlan struct {
|
||||
criteria sqlparser.SQLNode
|
||||
}
|
||||
|
||||
func GetShardList(sql string, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardlist []int, err error) {
|
||||
plan, err := buildPlan(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return shardListFromPlan(plan, bindVariables, tabletKeys)
|
||||
}
|
||||
|
||||
func buildPlan(sql string) (plan *RoutingPlan, err error) {
|
||||
statement, err := sqlparser.Parse(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getRoutingPlan(statement)
|
||||
}
|
||||
|
||||
func shardListFromPlan(plan *RoutingPlan, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (shardList []int, err error) {
|
||||
if plan.criteria == nil {
|
||||
return makeList(0, len(tabletKeys)), nil
|
||||
}
|
||||
|
||||
switch criteria := plan.criteria.(type) {
|
||||
case sqlparser.Values:
|
||||
index, err := findInsertShard(criteria, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []int{index}, nil
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch criteria.Operator {
|
||||
case "=", "<=>":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []int{index}, nil
|
||||
case "<", "<=":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return makeList(0, index+1), nil
|
||||
case ">", ">=":
|
||||
index, err := findShard(criteria.Right, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return makeList(index, len(tabletKeys)), nil
|
||||
case "in":
|
||||
return findShardList(criteria.Right, bindVariables, tabletKeys)
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
if criteria.Operator == "between" {
|
||||
start, err := findShard(criteria.From, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
last, err := findShard(criteria.To, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if last < start {
|
||||
start, last = last, start
|
||||
}
|
||||
return makeList(start, last+1), nil
|
||||
}
|
||||
}
|
||||
return makeList(0, len(tabletKeys)), nil
|
||||
}
|
||||
|
||||
func getRoutingPlan(statement sqlparser.Statement) (plan *RoutingPlan, err error) {
|
||||
plan = &RoutingPlan{}
|
||||
if ins, ok := statement.(*sqlparser.Insert); ok {
|
||||
if sel, ok := ins.Rows.(sqlparser.SelectStatement); ok {
|
||||
return getRoutingPlan(sel)
|
||||
}
|
||||
plan.criteria, err = routingAnalyzeValues(ins.Rows.(sqlparser.Values))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
var where *sqlparser.Where
|
||||
switch stmt := statement.(type) {
|
||||
case *sqlparser.Select:
|
||||
where = stmt.Where
|
||||
case *sqlparser.Update:
|
||||
where = stmt.Where
|
||||
case *sqlparser.Delete:
|
||||
where = stmt.Where
|
||||
}
|
||||
if where != nil {
|
||||
plan.criteria = routingAnalyzeBoolean(where.Expr)
|
||||
}
|
||||
return plan, nil
|
||||
}
|
||||
|
||||
func routingAnalyzeValues(vals sqlparser.Values) (sqlparser.Values, error) {
|
||||
// Analyze first value of every item in the list
|
||||
for i := 0; i < len(vals); i++ {
|
||||
switch tuple := vals[i].(type) {
|
||||
case sqlparser.ValTuple:
|
||||
result := routingAnalyzeValue(tuple[0])
|
||||
if result != VALUE_NODE {
|
||||
return nil, fmt.Errorf("insert is too complex")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("insert is too complex")
|
||||
}
|
||||
}
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
func routingAnalyzeBoolean(node sqlparser.BoolExpr) sqlparser.BoolExpr {
|
||||
switch node := node.(type) {
|
||||
case *sqlparser.AndExpr:
|
||||
left := routingAnalyzeBoolean(node.Left)
|
||||
right := routingAnalyzeBoolean(node.Right)
|
||||
if left != nil && right != nil {
|
||||
return nil
|
||||
} else if left != nil {
|
||||
return left
|
||||
} else {
|
||||
return right
|
||||
}
|
||||
case *sqlparser.ParenBoolExpr:
|
||||
return routingAnalyzeBoolean(node.Expr)
|
||||
case *sqlparser.ComparisonExpr:
|
||||
switch {
|
||||
case sqlparser.StringIn(node.Operator, "=", "<", ">", "<=", ">=", "<=>"):
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if (left == EID_NODE && right == VALUE_NODE) || (left == VALUE_NODE && right == EID_NODE) {
|
||||
return node
|
||||
}
|
||||
case node.Operator == "in":
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
right := routingAnalyzeValue(node.Right)
|
||||
if left == EID_NODE && right == LIST_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
case *sqlparser.RangeCond:
|
||||
if node.Operator != "between" {
|
||||
return nil
|
||||
}
|
||||
left := routingAnalyzeValue(node.Left)
|
||||
from := routingAnalyzeValue(node.From)
|
||||
to := routingAnalyzeValue(node.To)
|
||||
if left == EID_NODE && from == VALUE_NODE && to == VALUE_NODE {
|
||||
return node
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func routingAnalyzeValue(valExpr sqlparser.ValExpr) int {
|
||||
switch node := valExpr.(type) {
|
||||
case *sqlparser.ColName:
|
||||
if string(node.Name) == "entity_id" {
|
||||
return EID_NODE
|
||||
}
|
||||
case sqlparser.ValTuple:
|
||||
for _, n := range node {
|
||||
if routingAnalyzeValue(n) != VALUE_NODE {
|
||||
return OTHER_NODE
|
||||
}
|
||||
}
|
||||
return LIST_NODE
|
||||
case sqlparser.StrVal, sqlparser.NumVal, sqlparser.ValArg:
|
||||
return VALUE_NODE
|
||||
}
|
||||
return OTHER_NODE
|
||||
}
|
||||
|
||||
func findShardList(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) ([]int, error) {
|
||||
shardset := make(map[int]bool)
|
||||
switch node := valExpr.(type) {
|
||||
case sqlparser.ValTuple:
|
||||
for _, n := range node {
|
||||
index, err := findShard(n, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shardset[index] = true
|
||||
}
|
||||
}
|
||||
shardlist := make([]int, len(shardset))
|
||||
index := 0
|
||||
for k := range shardset {
|
||||
shardlist[index] = k
|
||||
index++
|
||||
}
|
||||
return shardlist, nil
|
||||
}
|
||||
|
||||
func findInsertShard(vals sqlparser.Values, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
|
||||
index := -1
|
||||
for i := 0; i < len(vals); i++ {
|
||||
first_value_expression := vals[i].(sqlparser.ValTuple)[0]
|
||||
newIndex, err := findShard(first_value_expression, bindVariables, tabletKeys)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
if index == -1 {
|
||||
index = newIndex
|
||||
} else if index != newIndex {
|
||||
return -1, fmt.Errorf("insert has multiple shard targets")
|
||||
}
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func findShard(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}, tabletKeys []key.KeyspaceId) (int, error) {
|
||||
value, err := getBoundValue(valExpr, bindVariables)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return key.FindShardForValue(value, tabletKeys), nil
|
||||
}
|
||||
|
||||
func getBoundValue(valExpr sqlparser.ValExpr, bindVariables map[string]interface{}) (string, error) {
|
||||
switch node := valExpr.(type) {
|
||||
case sqlparser.ValTuple:
|
||||
if len(node) != 1 {
|
||||
return "", fmt.Errorf("tuples not allowed as insert values")
|
||||
}
|
||||
// TODO: Change parser to create single value tuples into non-tuples.
|
||||
return getBoundValue(node[0], bindVariables)
|
||||
case sqlparser.StrVal:
|
||||
return string(node), nil
|
||||
case sqlparser.NumVal:
|
||||
val, err := strconv.ParseInt(string(node), 10, 64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key.Uint64Key(val).String(), nil
|
||||
case sqlparser.ValArg:
|
||||
value, err := findBindValue(node, bindVariables)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return key.EncodeValue(value), nil
|
||||
}
|
||||
panic("Unexpected token")
|
||||
}
|
||||
|
||||
func findBindValue(valArg sqlparser.ValArg, bindVariables map[string]interface{}) (interface{}, error) {
|
||||
if bindVariables == nil {
|
||||
return nil, fmt.Errorf("No bind variable for " + string(valArg))
|
||||
}
|
||||
value, ok := bindVariables[string(valArg[1:])]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No bind variable for " + string(valArg))
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func makeList(start, end int) []int {
|
||||
list := make([]int, end-start)
|
||||
for i := start; i < end; i++ {
|
||||
list[i-start] = i
|
||||
}
|
||||
return list
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
// Copyright 2012, Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package vtgate
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/youtube/vitess/go/testfiles"
|
||||
"github.com/youtube/vitess/go/vt/key"
|
||||
)
|
||||
|
||||
func TestRouting(t *testing.T) {
|
||||
tabletkeys := []key.KeyspaceId{
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x02",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x04",
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x06",
|
||||
"a",
|
||||
"b",
|
||||
"d",
|
||||
}
|
||||
bindVariables := make(map[string]interface{})
|
||||
bindVariables["id0"] = 0
|
||||
bindVariables["id2"] = 2
|
||||
bindVariables["id3"] = 3
|
||||
bindVariables["id4"] = 4
|
||||
bindVariables["id6"] = 6
|
||||
bindVariables["id8"] = 8
|
||||
bindVariables["ids"] = []interface{}{1, 4}
|
||||
bindVariables["a"] = "a"
|
||||
bindVariables["b"] = "b"
|
||||
bindVariables["c"] = "c"
|
||||
bindVariables["d"] = "d"
|
||||
bindVariables["e"] = "e"
|
||||
for tcase := range iterateFiles("sqlparser_test/routing_cases.txt") {
|
||||
if tcase.output == "" {
|
||||
tcase.output = tcase.input
|
||||
}
|
||||
out, err := GetShardList(tcase.input, bindVariables, tabletkeys)
|
||||
if err != nil {
|
||||
if err.Error() != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.input, err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
sort.Ints(out)
|
||||
outstr := fmt.Sprintf("%v", out)
|
||||
if outstr != tcase.output {
|
||||
t.Error(fmt.Sprintf("Line:%v\n%s\n%s", tcase.lineno, tcase.output, outstr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(sougou): This is now duplicated in three plcaes. Refactor.
|
||||
type testCase struct {
|
||||
file string
|
||||
lineno int
|
||||
input string
|
||||
output string
|
||||
}
|
||||
|
||||
func iterateFiles(pattern string) (testCaseIterator chan testCase) {
|
||||
names := testfiles.Glob(pattern)
|
||||
testCaseIterator = make(chan testCase)
|
||||
go func() {
|
||||
defer close(testCaseIterator)
|
||||
for _, name := range names {
|
||||
fd, err := os.OpenFile(name, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Could not open file %s", name))
|
||||
}
|
||||
|
||||
r := bufio.NewReader(fd)
|
||||
lineno := 0
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
lines := strings.Split(strings.TrimRight(line, "\n"), "#")
|
||||
lineno++
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
panic(fmt.Sprintf("Error reading file %s: %s", name, err.Error()))
|
||||
}
|
||||
break
|
||||
}
|
||||
input := lines[0]
|
||||
output := ""
|
||||
if len(lines) > 1 {
|
||||
output = lines[1]
|
||||
}
|
||||
if input == "" {
|
||||
continue
|
||||
}
|
||||
testCaseIterator <- testCase{name, lineno, input, output}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return testCaseIterator
|
||||
}
|
|
@ -51,6 +51,15 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
|
|||
go func(tablet *topo.TabletInfo) {
|
||||
defer wg.Done()
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
mutex.Lock()
|
||||
lastError = err
|
||||
mutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
if tablet.Type == topo.TYPE_LAG {
|
||||
log.Infof(" skipping slave position check for %v tablet %v", tablet.Type, tablet.Alias)
|
||||
return
|
||||
|
@ -58,9 +67,6 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
|
|||
|
||||
replPos, err := wr.ai.SlavePosition(tablet, wr.actionTimeout())
|
||||
if err != nil {
|
||||
mutex.Lock()
|
||||
lastError = err
|
||||
mutex.Unlock()
|
||||
if tablet.Type == topo.TYPE_BACKUP {
|
||||
log.Warningf(" failed to get slave position from backup tablet %v, either wait for backup to finish or scrap tablet (%v)", tablet.Alias, err)
|
||||
} else {
|
||||
|
@ -70,13 +76,18 @@ func (wr *Wrangler) checkSlaveReplication(tabletMap map[topo.TabletAlias]*topo.T
|
|||
}
|
||||
|
||||
if !masterIsDead {
|
||||
// This case used to be handled by the timeout check below, but checking
|
||||
// it explicitly provides a more informative error message.
|
||||
if replPos.SecondsBehindMaster == myproto.InvalidLagSeconds {
|
||||
err = fmt.Errorf("slave %v is not replicating (Slave_IO or Slave_SQL not running), can't complete reparent in time", tablet.Alias)
|
||||
log.Errorf(" %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var dur time.Duration = time.Duration(uint(time.Second) * replPos.SecondsBehindMaster)
|
||||
if dur > wr.actionTimeout() {
|
||||
err = fmt.Errorf("slave is too far behind to complete reparent in time (%v>%v), either increase timeout using 'vtctl -wait-time XXX ReparentShard ...' or scrap tablet %v", dur, wr.actionTimeout(), tablet.Alias)
|
||||
log.Errorf(" %v", err)
|
||||
mutex.Lock()
|
||||
lastError = err
|
||||
mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -93,10 +93,10 @@ func (wr *Wrangler) shardExternallyReparentedLocked(keyspace, shard string, mast
|
|||
}
|
||||
|
||||
// Compute the list of Cells we need to rebuild: old master and
|
||||
// new master cells.
|
||||
// all other cells if reparenting to another cell.
|
||||
cells := []string{shardInfo.MasterAlias.Cell}
|
||||
if shardInfo.MasterAlias.Cell != masterElectTabletAlias.Cell {
|
||||
cells = append(cells, masterElectTabletAlias.Cell)
|
||||
cells = nil
|
||||
}
|
||||
|
||||
// now update the master record in the shard object
|
||||
|
|
|
@ -32,6 +32,9 @@ class KeyRange(codec.BSONCoding):
|
|||
return keyrange_constants.NON_PARTIAL_KEYRANGE
|
||||
return '%s-%s' % (self.Start, self.End)
|
||||
|
||||
def __repr__(self):
|
||||
return 'KeyRange(%r-%r)' % (self.Start, self.End)
|
||||
|
||||
def bson_encode(self):
|
||||
return {"Start": self.Start, "End": self.End}
|
||||
|
||||
|
|
|
@ -82,17 +82,20 @@ class VTGateCursor(object):
|
|||
# FIXME(shrutip): these checks maybe better on vtgate server.
|
||||
if topology.is_sharded_keyspace(self.keyspace, self.tablet_type):
|
||||
if self.keyspace_ids is None or len(self.keyspace_ids) != 1:
|
||||
raise dbexceptions.ProgrammingError('DML on zero or multiple keyspace ids is not allowed')
|
||||
raise dbexceptions.ProgrammingError('DML on zero or multiple keyspace ids is not allowed: %r'
|
||||
% self.keyspace_ids)
|
||||
else:
|
||||
if not self.keyranges or str(self.keyranges[0]) != keyrange_constants.NON_PARTIAL_KEYRANGE:
|
||||
raise dbexceptions.ProgrammingError('Keyrange not correct for non-sharded keyspace')
|
||||
raise dbexceptions.ProgrammingError('Keyrange not correct for non-sharded keyspace: %r'
|
||||
% self.keyranges)
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -107,12 +110,13 @@ class VTGateCursor(object):
|
|||
if write_query:
|
||||
raise dbexceptions.DatabaseError('execute_entity_ids is not allowed for write queries')
|
||||
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
entity_keyspace_id_map,
|
||||
entity_column_name)
|
||||
self.results, self.rowcount, self.lastrowid, self.description = self._conn._execute_entity_ids(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
entity_keyspace_id_map,
|
||||
entity_column_name)
|
||||
self.index = 0
|
||||
return self.rowcount
|
||||
|
||||
|
@ -225,12 +229,13 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
raise dbexceptions.ProgrammingError('Streaming query cannot be writable')
|
||||
|
||||
self.description = None
|
||||
x, y, z, self.description = self._conn._stream_execute(sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
x, y, z, self.description = self._conn._stream_execute(
|
||||
sql,
|
||||
bind_variables,
|
||||
self.keyspace,
|
||||
self.tablet_type,
|
||||
keyspace_ids=self.keyspace_ids,
|
||||
keyranges=self.keyranges)
|
||||
self.index = 0
|
||||
return 0
|
||||
|
||||
|
@ -241,9 +246,9 @@ class StreamVTGateCursor(VTGateCursor):
|
|||
self.index += 1
|
||||
return self._conn._stream_next()
|
||||
|
||||
# fetchmany can be called until it returns no rows. Returning less rows
|
||||
# than what we asked for is also an indication we ran out, but the cursor
|
||||
# API in PEP249 is silent about that.
|
||||
# fetchmany can be called until it returns no rows. Returning less rows
|
||||
# than what we asked for is also an indication we ran out, but the cursor
|
||||
# API in PEP249 is silent about that.
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
import environment
|
||||
|
||||
|
||||
class MysqlFlavor(object):
|
||||
"""Base class with default SQL statements"""
|
||||
|
||||
def promote_slave_commands(self):
|
||||
"""Returns commands to convert a slave to a master."""
|
||||
return [
|
||||
"RESET MASTER",
|
||||
"STOP SLAVE",
|
||||
"RESET SLAVE",
|
||||
"CHANGE MASTER TO MASTER_HOST = ''",
|
||||
]
|
||||
|
||||
def reset_replication_commands(self):
|
||||
return [
|
||||
"RESET MASTER",
|
||||
"STOP SLAVE",
|
||||
"RESET SLAVE",
|
||||
'CHANGE MASTER TO MASTER_HOST = ""',
|
||||
]
|
||||
|
||||
|
||||
class GoogleMysql(MysqlFlavor):
|
||||
"""Overrides specific to Google MySQL"""
|
||||
|
||||
|
||||
class MariaDB(MysqlFlavor):
|
||||
"""Overrides specific to MariaDB"""
|
||||
|
||||
def promote_slave_commands(self):
|
||||
return [
|
||||
"RESET MASTER",
|
||||
"STOP SLAVE",
|
||||
"RESET SLAVE",
|
||||
]
|
||||
|
||||
def reset_replication_commands(self):
|
||||
return [
|
||||
"RESET MASTER",
|
||||
"STOP SLAVE",
|
||||
"RESET SLAVE",
|
||||
]
|
||||
|
||||
if environment.mysql_flavor == "MariaDB":
|
||||
mysql_flavor = MariaDB()
|
||||
else:
|
||||
mysql_flavor = GoogleMysql()
|
|
@ -490,3 +490,77 @@ class TestNocache(framework.TestCase):
|
|||
error_count = self.env.run_cases(nocache_cases.cases)
|
||||
if error_count != 0:
|
||||
self.fail("test_execution errors: %d"%(error_count))
|
||||
|
||||
def test_table_acl_no_access(self):
|
||||
self.env.conn.begin()
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("select * from vtocc_acl_no_access where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("delete from vtocc_acl_no_access where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("alter table vtocc_acl_no_access comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
cu.execute("select * from vtocc_acl_no_access where key1=1", {})
|
||||
cu.close()
|
||||
|
||||
def test_table_acl_read_only(self):
|
||||
self.env.conn.begin()
|
||||
self.env.execute("select * from vtocc_acl_read_only where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("delete from vtocc_acl_read_only where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("alter table vtocc_acl_read_only comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
cu.execute("select * from vtocc_acl_read_only where key1=1", {})
|
||||
cu.fetchall()
|
||||
cu.close()
|
||||
|
||||
def test_table_acl_read_write(self):
|
||||
self.env.conn.begin()
|
||||
self.env.execute("select * from vtocc_acl_read_write where key1=1")
|
||||
self.env.execute("delete from vtocc_acl_read_write where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("alter table vtocc_acl_read_write comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
cu.execute("select * from vtocc_acl_read_write where key1=1", {})
|
||||
cu.fetchall()
|
||||
cu.close()
|
||||
|
||||
def test_table_acl_admin(self):
|
||||
self.env.conn.begin()
|
||||
self.env.execute("select * from vtocc_acl_admin where key1=1")
|
||||
self.env.execute("delete from vtocc_acl_admin where key1=1")
|
||||
self.env.execute("alter table vtocc_acl_admin comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
cu.execute("select * from vtocc_acl_admin where key1=1", {})
|
||||
cu.fetchall()
|
||||
cu.close()
|
||||
|
||||
def test_table_acl_unmatched(self):
|
||||
self.env.conn.begin()
|
||||
self.env.execute("select * from vtocc_acl_unmatched where key1=1")
|
||||
self.env.execute("delete from vtocc_acl_unmatched where key1=1")
|
||||
self.env.execute("alter table vtocc_acl_unmatched comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
cu.execute("select * from vtocc_acl_unmatched where key1=1", {})
|
||||
cu.fetchall()
|
||||
cu.close()
|
||||
|
||||
def test_table_acl_all_user_read_only(self):
|
||||
self.env.conn.begin()
|
||||
self.env.execute("select * from vtocc_acl_all_user_read_only where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("delete from vtocc_acl_all_user_read_only where key1=1")
|
||||
with self.assertRaisesRegexp(dbexceptions.DatabaseError, '.*table acl error.*'):
|
||||
self.env.execute("alter table vtocc_acl_all_user_read_only comment 'comment'")
|
||||
self.env.conn.commit()
|
||||
cu = cursor.StreamCursor(self.env.conn)
|
||||
cu.execute("select * from vtocc_acl_all_user_read_only where key1=1", {})
|
||||
cu.fetchall()
|
||||
cu.close()
|
||||
|
|
|
@ -52,7 +52,7 @@ class TestEnv(object):
|
|||
return "localhost:%s" % self.port
|
||||
|
||||
def connect(self):
|
||||
c = tablet_conn.connect(self.address, '', 'test_keyspace', '0', 2)
|
||||
c = tablet_conn.connect(self.address, '', 'test_keyspace', '0', 2, user='youtube-dev-dedicated', password='vtpass')
|
||||
c.max_attempts = 1
|
||||
return c
|
||||
|
||||
|
@ -196,7 +196,14 @@ class VttabletTestEnv(TestEnv):
|
|||
self.create_customrules(customrules)
|
||||
schema_override = os.path.join(environment.tmproot, 'schema_override.json')
|
||||
self.create_schema_override(schema_override)
|
||||
self.tablet.start_vttablet(memcache=self.memcache, customrules=customrules, schema_override=schema_override)
|
||||
table_acl_config = os.path.join(environment.vttop, 'test', 'test_data', 'table_acl_config.json')
|
||||
self.tablet.start_vttablet(
|
||||
memcache=self.memcache,
|
||||
customrules=customrules,
|
||||
schema_override=schema_override,
|
||||
table_acl_config=table_acl_config,
|
||||
auth=True,
|
||||
)
|
||||
|
||||
# FIXME(szopa): This is necessary here only because of a bug that
|
||||
# makes the qs reload its config only after an action.
|
||||
|
@ -281,12 +288,15 @@ class VtoccTestEnv(TestEnv):
|
|||
self.create_customrules(customrules)
|
||||
schema_override = os.path.join(environment.tmproot, 'schema_override.json')
|
||||
self.create_schema_override(schema_override)
|
||||
table_acl_config = os.path.join(environment.vttop, 'test', 'test_data', 'table_acl_config.json')
|
||||
|
||||
occ_args = environment.binary_args('vtocc') + [
|
||||
"-port", str(self.port),
|
||||
"-customrules", customrules,
|
||||
"-log_dir", environment.vtlogroot,
|
||||
"-schema-override", schema_override,
|
||||
"-table-acl-config", table_acl_config,
|
||||
"-queryserver-config-strict-table-acl",
|
||||
"-db-config-app-charset", "utf8",
|
||||
"-db-config-app-dbname", "vt_test_keyspace",
|
||||
"-db-config-app-host", "localhost",
|
||||
|
@ -294,7 +304,9 @@ class VtoccTestEnv(TestEnv):
|
|||
"-db-config-app-uname", 'vt_dba', # use vt_dba as some tests depend on 'drop'
|
||||
"-db-config-app-keyspace", "test_keyspace",
|
||||
"-db-config-app-shard", "0",
|
||||
"-auth-credentials", os.path.join(environment.vttop, 'test', 'test_data', 'authcredentials_test.json'),
|
||||
]
|
||||
|
||||
if self.memcache:
|
||||
memcache = self.mysqldir+"/memcache.sock"
|
||||
occ_args.extend(["-rowcache-bin", environment.memcached_bin()])
|
||||
|
@ -346,11 +358,6 @@ class VtoccTestEnv(TestEnv):
|
|||
except:
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
c = tablet_conn.connect("localhost:%s" % self.port, '', 'test_keyspace', '0', 2)
|
||||
c.max_attempts = 1
|
||||
return c
|
||||
|
||||
def mysql_connect(self):
|
||||
return mysql.connect(
|
||||
host='localhost',
|
||||
|
|
|
@ -16,6 +16,7 @@ import unittest
|
|||
import environment
|
||||
import utils
|
||||
import tablet
|
||||
from mysql_flavor import mysql_flavor
|
||||
|
||||
tablet_62344 = tablet.Tablet(62344)
|
||||
tablet_62044 = tablet.Tablet(62044)
|
||||
|
@ -70,8 +71,8 @@ class TestReparent(unittest.TestCase):
|
|||
t.clean_dbs()
|
||||
super(TestReparent, self).tearDown()
|
||||
|
||||
def _check_db_addr(self, shard, db_type, expected_port):
|
||||
ep = utils.run_vtctl_json(['GetEndPoints', 'test_nj', 'test_keyspace/'+shard, db_type])
|
||||
def _check_db_addr(self, shard, db_type, expected_port, cell='test_nj'):
|
||||
ep = utils.run_vtctl_json(['GetEndPoints', cell, 'test_keyspace/'+shard, db_type])
|
||||
self.assertEqual(len(ep['entries']), 1 , 'Wrong number of entries: %s' % str(ep))
|
||||
port = ep['entries'][0]['named_port_map']['_vtocc']
|
||||
self.assertEqual(port, expected_port, 'Unexpected port: %u != %u from %s' % (port, expected_port, str(ep)))
|
||||
|
@ -184,6 +185,66 @@ class TestReparent(unittest.TestCase):
|
|||
# so the other tests don't have any surprise
|
||||
tablet_62344.start_mysql().wait()
|
||||
|
||||
def test_reparent_cross_cell(self, shard_id='0'):
|
||||
utils.run_vtctl('CreateKeyspace test_keyspace')
|
||||
|
||||
# create the database so vttablets start, as they are serving
|
||||
tablet_62344.create_db('vt_test_keyspace')
|
||||
tablet_62044.create_db('vt_test_keyspace')
|
||||
tablet_41983.create_db('vt_test_keyspace')
|
||||
tablet_31981.create_db('vt_test_keyspace')
|
||||
|
||||
# Start up a master mysql and vttablet
|
||||
tablet_62344.init_tablet('master', 'test_keyspace', shard_id, start=True)
|
||||
if environment.topo_server_implementation == 'zookeeper':
|
||||
shard = utils.run_vtctl_json(['GetShard', 'test_keyspace/'+shard_id])
|
||||
self.assertEqual(shard['Cells'], ['test_nj'], 'wrong list of cell in Shard: %s' % str(shard['Cells']))
|
||||
|
||||
# Create a few slaves for testing reparenting.
|
||||
tablet_62044.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
|
||||
tablet_41983.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
|
||||
tablet_31981.init_tablet('replica', 'test_keyspace', shard_id, start=True, wait_for_start=False)
|
||||
for t in [tablet_62044, tablet_41983, tablet_31981]:
|
||||
t.wait_for_vttablet_state("SERVING")
|
||||
if environment.topo_server_implementation == 'zookeeper':
|
||||
shard = utils.run_vtctl_json(['GetShard', 'test_keyspace/'+shard_id])
|
||||
self.assertEqual(shard['Cells'], ['test_nj', 'test_ny'], 'wrong list of cell in Shard: %s' % str(shard['Cells']))
|
||||
|
||||
# Recompute the shard layout node - until you do that, it might not be valid.
|
||||
utils.run_vtctl('RebuildShardGraph test_keyspace/' + shard_id)
|
||||
utils.validate_topology()
|
||||
|
||||
# Force the slaves to reparent assuming that all the datasets are identical.
|
||||
for t in [tablet_62344, tablet_62044, tablet_41983, tablet_31981]:
|
||||
t.reset_replication()
|
||||
utils.pause("force ReparentShard?")
|
||||
utils.run_vtctl('ReparentShard -force test_keyspace/%s %s' % (shard_id, tablet_62344.tablet_alias))
|
||||
utils.validate_topology(ping_tablets=True)
|
||||
|
||||
self._check_db_addr(shard_id, 'master', tablet_62344.port)
|
||||
|
||||
# Verify MasterCell is properly set
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
# Perform a graceful reparent operation to another cell.
|
||||
utils.pause("graceful ReparentShard?")
|
||||
utils.run_vtctl('ReparentShard test_keyspace/%s %s' % (shard_id, tablet_31981.tablet_alias), auto_log=True)
|
||||
utils.validate_topology()
|
||||
|
||||
self._check_db_addr(shard_id, 'master', tablet_31981.port, cell='test_ny')
|
||||
|
||||
# Verify MasterCell is set to new cell.
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_ny')
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_ny')
|
||||
|
||||
tablet.kill_tablets([tablet_62344, tablet_62044, tablet_41983, tablet_31981])
|
||||
|
||||
|
||||
def test_reparent_graceful_range_based(self):
|
||||
shard_id = '0000000000000000-FFFFFFFFFFFFFFFF'
|
||||
self._test_reparent_graceful(shard_id)
|
||||
|
@ -230,6 +291,12 @@ class TestReparent(unittest.TestCase):
|
|||
|
||||
self._check_db_addr(shard_id, 'master', tablet_62344.port)
|
||||
|
||||
# Verify MasterCell is set to new cell.
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
# Convert two replica to spare. That should leave only one node serving traffic,
|
||||
# but still needs to appear in the replication graph.
|
||||
utils.run_vtctl(['ChangeSlaveType', tablet_41983.tablet_alias, 'spare'])
|
||||
|
@ -247,6 +314,12 @@ class TestReparent(unittest.TestCase):
|
|||
|
||||
self._check_db_addr(shard_id, 'master', tablet_62044.port)
|
||||
|
||||
# Verify MasterCell is set to new cell.
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_ny', 'test_keyspace/%s' % (shard_id)])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
tablet.kill_tablets([tablet_62344, tablet_62044, tablet_41983, tablet_31981])
|
||||
|
||||
# Test address correction.
|
||||
|
@ -337,12 +410,7 @@ class TestReparent(unittest.TestCase):
|
|||
# now manually reparent 1 out of 2 tablets
|
||||
# 62044 will be the new master
|
||||
# 31981 won't be re-parented, so it will be busted
|
||||
tablet_62044.mquery('', [
|
||||
"RESET MASTER",
|
||||
"STOP SLAVE",
|
||||
"RESET SLAVE",
|
||||
"CHANGE MASTER TO MASTER_HOST = ''",
|
||||
])
|
||||
tablet_62044.mquery('', mysql_flavor.promote_slave_commands())
|
||||
new_pos = tablet_62044.mquery('', 'show master status')
|
||||
logging.debug("New master position: %s" % str(new_pos))
|
||||
|
||||
|
@ -439,7 +507,6 @@ class TestReparent(unittest.TestCase):
|
|||
utils.run_vtctl('ReparentShard -force test_keyspace/%s %s' % (shard_id, tablet_62344.tablet_alias))
|
||||
utils.validate_topology(ping_tablets=True)
|
||||
|
||||
tablet_62344.create_db('vt_test_keyspace')
|
||||
tablet_62344.mquery('vt_test_keyspace', self._create_vt_insert_test)
|
||||
|
||||
tablet_41983.mquery('', 'stop slave')
|
||||
|
|
|
@ -13,6 +13,7 @@ import MySQLdb
|
|||
|
||||
import environment
|
||||
import utils
|
||||
from mysql_flavor import mysql_flavor
|
||||
|
||||
tablet_cell_map = {
|
||||
62344: 'nj',
|
||||
|
@ -148,14 +149,7 @@ class Tablet(object):
|
|||
raise utils.TestError("expected %u rows in %s" % (n, table), result)
|
||||
|
||||
def reset_replication(self):
|
||||
commands = [
|
||||
'RESET MASTER',
|
||||
'STOP SLAVE',
|
||||
'RESET SLAVE',
|
||||
]
|
||||
if environment.mysql_flavor == "GoogleMysql":
|
||||
commands.append('CHANGE MASTER TO MASTER_HOST = ""')
|
||||
self.mquery('', commands)
|
||||
self.mquery('', mysql_flavor.reset_replication_commands())
|
||||
|
||||
def populate(self, dbname, create_sql, insert_sqls=[]):
|
||||
self.create_db(dbname)
|
||||
|
@ -283,7 +277,7 @@ class Tablet(object):
|
|||
def start_vttablet(self, port=None, auth=False, memcache=False,
|
||||
wait_for_state="SERVING", customrules=None,
|
||||
schema_override=None, cert=None, key=None, ca_cert=None,
|
||||
repl_extra_flags={},
|
||||
repl_extra_flags={},table_acl_config=None,
|
||||
target_tablet_type=None, lameduck_period=None,
|
||||
extra_args=None, full_mycnf_args=False,
|
||||
security_policy=None):
|
||||
|
@ -350,6 +344,10 @@ class Tablet(object):
|
|||
if schema_override:
|
||||
args.extend(['-schema-override', schema_override])
|
||||
|
||||
if table_acl_config:
|
||||
args.extend(['-table-acl-config', table_acl_config])
|
||||
args.extend(['-queryserver-config-strict-table-acl'])
|
||||
|
||||
if cert:
|
||||
self.secure_port = environment.reserve_ports(1)
|
||||
args.extend(['-secure-port', '%s' % self.secure_port,
|
||||
|
|
|
@ -79,6 +79,8 @@ class TestTabletManager(unittest.TestCase):
|
|||
tablet_62344.init_tablet('master', 'test_keyspace', '0', parent=False)
|
||||
utils.run_vtctl(['RebuildKeyspaceGraph', 'test_keyspace'])
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
# if these statements don't run before the tablet it will wedge waiting for the
|
||||
# db to become accessible. this is more a bug than a feature.
|
||||
|
@ -116,6 +118,8 @@ class TestTabletManager(unittest.TestCase):
|
|||
# not pinging tablets, as it enables replication checks, and they
|
||||
# break because we only have a single master, no slaves
|
||||
utils.run_vtctl('ValidateShard -ping-tablets=false test_keyspace/0')
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
tablet_62344.kill_vttablet()
|
||||
|
||||
|
@ -129,6 +133,8 @@ class TestTabletManager(unittest.TestCase):
|
|||
tablet_62344.init_tablet('master', 'test_keyspace', '0', parent=False)
|
||||
utils.run_vtctl(['RebuildKeyspaceGraph', 'test_keyspace'])
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
# if these statements don't run before the tablet it will wedge waiting for the
|
||||
# db to become accessible. this is more a bug than a feature.
|
||||
|
@ -226,9 +232,13 @@ class TestTabletManager(unittest.TestCase):
|
|||
tablet_62044.init_tablet('replica', 'test_keyspace', '0')
|
||||
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/*'])
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
tablet_62044.scrap(force=True)
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
|
||||
_create_vt_select_test = '''create table vt_select_test (
|
||||
|
@ -249,6 +259,8 @@ class TestTabletManager(unittest.TestCase):
|
|||
tablet_62344.init_tablet('master', 'test_keyspace', '0')
|
||||
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/0'])
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
tablet_62344.create_db('vt_test_keyspace')
|
||||
tablet_62344.start_vttablet()
|
||||
|
||||
|
@ -306,6 +318,8 @@ class TestTabletManager(unittest.TestCase):
|
|||
tablet_62344.init_tablet('master', 'test_keyspace', '0')
|
||||
utils.run_vtctl(['RebuildShardGraph', 'test_keyspace/0'])
|
||||
utils.validate_topology()
|
||||
srvShard = utils.run_vtctl_json(['GetSrvShard', 'test_nj', 'test_keyspace/0'])
|
||||
self.assertEqual(srvShard['MasterCell'], 'test_nj')
|
||||
|
||||
tablet_62344.populate('vt_test_keyspace', self._create_vt_select_test,
|
||||
self._populate_vt_select_test)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
{
|
||||
"ala": ["ma kota", "miala kota"]
|
||||
}
|
||||
"ala": ["ma kota", "miala kota"],
|
||||
"youtube-dev-dedicated": ["vtpass", "vtpasssec"]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"vtocc_acl_no_access": {},
|
||||
"vtocc_acl_read_only": {"Reader": "youtube-dev-dedicated"},
|
||||
"vtocc_acl_read_write": {"Writer": "youtube-dev-dedicated"},
|
||||
"vtocc_acl_admin": {"Admin": "youtube-dev-dedicated"},
|
||||
"vtocc_acl_all_user_read_only": {"READER":"*"}
|
||||
}
|
|
@ -62,6 +62,13 @@ insert into vtocc_part2 values(1, 3)
|
|||
insert into vtocc_part2 values(2, 4)
|
||||
commit
|
||||
|
||||
create table vtocc_acl_no_access(key1 bigint, key2 bigint, primary key(key1))
|
||||
create table vtocc_acl_read_only(key1 bigint, key2 bigint, primary key(key1))
|
||||
create table vtocc_acl_read_write(key1 bigint, key2 bigint, primary key(key1))
|
||||
create table vtocc_acl_admin(key1 bigint, key2 bigint, primary key(key1))
|
||||
create table vtocc_acl_unmatched(key1 bigint, key2 bigint, primary key(key1))
|
||||
create table vtocc_acl_all_user_read_only(key1 bigint, key2 bigint, primary key(key1))
|
||||
|
||||
# clean
|
||||
drop table if exists vtocc_test
|
||||
drop table if exists vtocc_a
|
||||
|
@ -82,3 +89,9 @@ drop table if exists vtocc_misc
|
|||
drop view if exists vtocc_view
|
||||
drop table if exists vtocc_part1
|
||||
drop table if exists vtocc_part2
|
||||
drop table if exists vtocc_acl_no_access
|
||||
drop table if exists vtocc_acl_read_only
|
||||
drop table if exists vtocc_acl_read_write
|
||||
drop table if exists vtocc_acl_admin
|
||||
drop table if exists vtocc_acl_unmatched
|
||||
drop table if exists vtocc_acl_all_user_read_only
|
||||
|
|
Загрузка…
Ссылка в новой задаче