vitess-gh/go/vt/tabletserver/query_splitter_test.go

476 строки
16 KiB
Go

package tabletserver
import (
"encoding/binary"
"fmt"
"reflect"
"strings"
"testing"
"github.com/youtube/vitess/go/sqltypes"
querypb "github.com/youtube/vitess/go/vt/proto/query"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
"github.com/youtube/vitess/go/vt/tabletserver/proto"
)
func getSchemaInfo() *SchemaInfo {
table := &schema.Table{
Name: "test_table",
}
zero, _ := sqltypes.BuildValue(0)
table.AddColumn("id", sqltypes.Int64, zero, "")
table.AddColumn("id2", sqltypes.Int64, zero, "")
table.AddColumn("count", sqltypes.Int64, zero, "")
table.PKColumns = []int{0}
primaryIndex := table.AddIndex("PRIMARY")
primaryIndex.AddColumn("id", 12345)
id2Index := table.AddIndex("idx_id2")
id2Index.AddColumn("id2", 1234)
tables := make(map[string]*TableInfo, 1)
tables["test_table"] = &TableInfo{Table: table}
tableNoPK := &schema.Table{
Name: "test_table_no_pk",
}
tableNoPK.AddColumn("id", sqltypes.Int64, zero, "")
tableNoPK.PKColumns = []int{}
tables["test_table_no_pk"] = &TableInfo{Table: tableNoPK}
return &SchemaInfo{tables: tables}
}
func TestValidateQuery(t *testing.T) {
schemaInfo := getSchemaInfo()
query := &proto.BoundQuery{}
splitter := NewQuerySplitter(query, "", 3, schemaInfo)
query.Sql = "delete from test_table"
got := splitter.validateQuery()
want := fmt.Errorf("not a select statement")
if !reflect.DeepEqual(got, want) {
t.Errorf("non-select validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from test_table order by id"
got = splitter.validateQuery()
want = fmt.Errorf("unsupported query")
if !reflect.DeepEqual(got, want) {
t.Errorf("order by query validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from test_table group by id"
got = splitter.validateQuery()
want = fmt.Errorf("unsupported query")
if !reflect.DeepEqual(got, want) {
t.Errorf("group by query validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select A.* from test_table A JOIN test_table B"
got = splitter.validateQuery()
want = fmt.Errorf("unsupported query")
if !reflect.DeepEqual(got, want) {
t.Errorf("join query validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from test_table_no_pk"
got = splitter.validateQuery()
want = fmt.Errorf("no primary keys")
if !reflect.DeepEqual(got, want) {
t.Errorf("no PK table validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from unknown_table"
got = splitter.validateQuery()
want = fmt.Errorf("can't find table in schema")
if !reflect.DeepEqual(got, want) {
t.Errorf("unknown table validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from test_table"
got = splitter.validateQuery()
want = nil
if !reflect.DeepEqual(got, want) {
t.Errorf("valid query validation failed, got:%v, want:%v", got, want)
}
query.Sql = "select * from test_table where count > :count"
got = splitter.validateQuery()
want = nil
if !reflect.DeepEqual(got, want) {
t.Errorf("valid query validation failed, got:%v, want:%v", got, want)
}
splitter = NewQuerySplitter(query, "id2", 0, schemaInfo)
query.Sql = "select * from test_table where count > :count"
got = splitter.validateQuery()
want = nil
if !reflect.DeepEqual(got, want) {
t.Errorf("valid query validation failed, got:%v, want:%v", got, want)
}
splitter = NewQuerySplitter(query, "id2", 0, schemaInfo)
query.Sql = "invalid select * from test_table where count > :count"
if err := splitter.validateQuery(); err == nil {
t.Fatalf("validateQuery() = %v, want: nil", err)
}
// column id2 is indexed
splitter = NewQuerySplitter(query, "id2", 3, schemaInfo)
query.Sql = "select * from test_table where count > :count"
got = splitter.validateQuery()
want = nil
if !reflect.DeepEqual(got, want) {
t.Errorf("valid query validation failed, got:%v, want:%v", got, want)
}
// column does not exist
splitter = NewQuerySplitter(query, "unknown_column", 3, schemaInfo)
got = splitter.validateQuery()
wantStr := "split column is not indexed or does not exist in table schema"
if !strings.Contains(got.Error(), wantStr) {
t.Errorf("unknown table validation failed, got:%v, want:%v", got, wantStr)
}
// column is not indexed
splitter = NewQuerySplitter(query, "count", 3, schemaInfo)
got = splitter.validateQuery()
wantStr = "split column is not indexed or does not exist in table schema"
if !strings.Contains(got.Error(), wantStr) {
t.Errorf("unknown table validation failed, got:%v, want:%v", got, wantStr)
}
}
func TestGetWhereClause(t *testing.T) {
splitter := &QuerySplitter{}
sql := "select * from test_table where count > :count"
statement, _ := sqlparser.Parse(sql)
splitter.sel, _ = statement.(*sqlparser.Select)
splitter.splitColumn = "id"
bindVars := make(map[string]interface{})
// no boundary case, start = end = nil, should not change the where clause
nilValue := sqltypes.Value{}
clause := splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue)
want := " where count > :count"
got := sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
}
// Set lower bound, should add the lower bound condition to where clause
startVal := int64(20)
start, _ := sqltypes.BuildValue(startVal)
bindVars = make(map[string]interface{})
bindVars[":count"] = 300
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue)
want = " where (count > :count) and (id >= :" + startBindVarName + ")"
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
}
v, ok := bindVars[startBindVarName]
if !ok {
t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal)
}
if v != startVal {
t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal)
}
// Set upper bound, should add the upper bound condition to where clause
endVal := int64(40)
end, _ := sqltypes.BuildValue(endVal)
bindVars = make(map[string]interface{})
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end)
want = " where (count > :count) and (id < :" + endBindVarName + ")"
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
}
v, ok = bindVars[endBindVarName]
if !ok {
t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal)
}
if v != endVal {
t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal)
}
// Set both bounds, should add two conditions to where clause
bindVars = make(map[string]interface{})
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName)
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
}
// Original query with no where clause
sql = "select * from test_table"
statement, _ = sqlparser.Parse(sql)
splitter.sel, _ = statement.(*sqlparser.Select)
bindVars = make(map[string]interface{})
// no boundary case, start = end = nil should return no where clause
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, nilValue)
want = ""
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause for nil ranges, got:%v, want:%v", got, want)
}
bindVars = make(map[string]interface{})
// Set both bounds, should add two conditions to where clause
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName)
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
}
v, ok = bindVars[startBindVarName]
if !ok {
t.Fatalf("bind var: %s not found got: nil, want: %v", startBindVarName, startVal)
}
if v != startVal {
t.Fatalf("bind var: %s not found got: %v, want: %v", startBindVarName, v, startVal)
}
v, ok = bindVars[endBindVarName]
if !ok {
t.Fatalf("bind var: %s not found got: nil, want: %v", endBindVarName, endVal)
}
if v != endVal {
t.Fatalf("bind var: %s not found got: %v, want: %v", endBindVarName, v, endVal)
}
}
func TestSplitBoundaries(t *testing.T) {
min, _ := sqltypes.BuildValue(10)
max, _ := sqltypes.BuildValue(60)
row := []sqltypes.Value{min, max}
rows := [][]sqltypes.Value{row}
minField := &querypb.Field{Name: "min", Type: sqltypes.Int64}
maxField := &querypb.Field{Name: "max", Type: sqltypes.Int64}
fields := []*querypb.Field{minField, maxField}
pkMinMax := &sqltypes.Result{
Fields: fields,
Rows: rows,
}
splitter := &QuerySplitter{}
splitter.splitCount = 5
boundaries, err := splitter.splitBoundaries(sqltypes.Int64, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(boundaries) != splitter.splitCount-1 {
t.Errorf("wrong number of boundaries got: %v, want: %v", len(boundaries), splitter.splitCount-1)
}
got, err := splitter.splitBoundaries(sqltypes.Int64, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []sqltypes.Value{buildVal(20), buildVal(30), buildVal(40), buildVal(50)}
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
}
// Test negative min value
min, _ = sqltypes.BuildValue(-100)
max, _ = sqltypes.BuildValue(100)
row = []sqltypes.Value{min, max}
rows = [][]sqltypes.Value{row}
pkMinMax.Rows = rows
got, err = splitter.splitBoundaries(sqltypes.Int64, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want = []sqltypes.Value{buildVal(-60), buildVal(-20), buildVal(20), buildVal(60)}
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
}
// Test float min max
min, _ = sqltypes.BuildValue(10.5)
max, _ = sqltypes.BuildValue(60.5)
row = []sqltypes.Value{min, max}
rows = [][]sqltypes.Value{row}
minField = &querypb.Field{Name: "min", Type: sqltypes.Float64}
maxField = &querypb.Field{Name: "max", Type: sqltypes.Float64}
fields = []*querypb.Field{minField, maxField}
pkMinMax.Rows = rows
pkMinMax.Fields = fields
got, err = splitter.splitBoundaries(sqltypes.Float64, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want = []sqltypes.Value{buildVal(20.5), buildVal(30.5), buildVal(40.5), buildVal(50.5)}
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect boundaries, got: %v, want: %v", got, want)
}
}
func buildVal(val interface{}) sqltypes.Value {
v, _ := sqltypes.BuildValue(val)
return v
}
func TestSplitQuery(t *testing.T) {
schemaInfo := getSchemaInfo()
query := &proto.BoundQuery{
Sql: "select * from test_table where count > :count",
}
splitter := NewQuerySplitter(query, "", 3, schemaInfo)
splitter.validateQuery()
min, _ := sqltypes.BuildValue(0)
max, _ := sqltypes.BuildValue(300)
minField := &querypb.Field{
Name: "min",
Type: sqltypes.Int64,
}
maxField := &querypb.Field{
Name: "max",
Type: sqltypes.Int64,
}
fields := []*querypb.Field{minField, maxField}
pkMinMax := &sqltypes.Result{
Fields: fields,
}
// Ensure that empty min max does not cause panic or return any error
splits, err := splitter.split(sqltypes.Int64, pkMinMax)
if err != nil {
t.Errorf("unexpected error while splitting on empty pkMinMax, %s", err)
}
pkMinMax.Rows = [][]sqltypes.Value{[]sqltypes.Value{min, max}}
splits, err = splitter.split(sqltypes.Int64, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got := []proto.BoundQuery{}
for _, split := range splits {
if split.RowCount != 100 {
t.Errorf("wrong RowCount, got: %v, want: %v", split.RowCount, 100)
}
got = append(got, split.Query)
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: int64(100)},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: int64(100),
endBindVarName: int64(200),
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: int64(200)},
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("wrong splits, got: %v, want: %v", got, want)
}
}
func TestSplitQueryFractionalColumn(t *testing.T) {
schemaInfo := getSchemaInfo()
query := &proto.BoundQuery{
Sql: "select * from test_table where count > :count",
}
splitter := NewQuerySplitter(query, "", 3, schemaInfo)
splitter.validateQuery()
min, _ := sqltypes.BuildValue(10.5)
max, _ := sqltypes.BuildValue(490.5)
minField := &querypb.Field{
Name: "min",
Type: sqltypes.Float32,
}
maxField := &querypb.Field{
Name: "max",
Type: sqltypes.Float32,
}
fields := []*querypb.Field{minField, maxField}
pkMinMax := &sqltypes.Result{
Fields: fields,
Rows: [][]sqltypes.Value{[]sqltypes.Value{min, max}},
}
splits, err := splitter.split(sqltypes.Float32, pkMinMax)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got := []proto.BoundQuery{}
for _, split := range splits {
if split.RowCount != 160 {
t.Errorf("wrong RowCount, got: %v, want: %v", split.RowCount, 160)
}
got = append(got, split.Query)
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: 170.5},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: 170.5,
endBindVarName: 330.5,
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: 330.5},
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("wrong splits, got: %v, want: %v", got, want)
}
}
func TestSplitQueryStringColumn(t *testing.T) {
schemaInfo := getSchemaInfo()
query := &proto.BoundQuery{
Sql: "select * from test_table where count > :count",
}
splitter := NewQuerySplitter(query, "", 3, schemaInfo)
splitter.validateQuery()
splits, err := splitter.split(sqltypes.VarChar, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got := []proto.BoundQuery{}
for _, split := range splits {
got = append(got, split.Query)
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: hexToByteUInt32(0x55555555)},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: hexToByteUInt32(0x55555555),
endBindVarName: hexToByteUInt32(0xAAAAAAAA),
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: hexToByteUInt32(0xAAAAAAAA)},
},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("wrong splits, got: %v, want: %v", got, want)
}
}
func hexToByteUInt32(val uint32) []byte {
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, val)
return buf
}