tabletserver: list bind vars support complete

You can now specify a new syntax for queries that
have an IN clause:
select * from t where id in ::list
The corresponding bind variable would be:
{"list": field_types.List([1,2,3])}

- List bind vars are only allowed for IN and NOT IN clauses.
- A bind variable can be a list if and only if
  it's referred as ::list.
- A list must contain at least one value.

For the python client, you need to supply lists using the
field_types.List class. This is because of legacy behavior
where lists were previously encoded as strings.
We should soon change this API to use native lists once we
confirm that no one will be affected by this change.
This commit is contained in:
Sugu Sougoumarane 2014-10-24 13:53:25 -07:00
Родитель b5bed9f0f4
Коммит ce89746dab
10 изменённых файлов: 340 добавлений и 69 удалений

Просмотреть файл

@ -115,6 +115,8 @@ select /* value argument with digit */ :a1 from t
select /* value argument with dot */ :a.b from t
select /* positional argument */ ? from t#select /* positional argument */ :v1 from t
select /* multiple positional arguments */ ?, ? from t#select /* multiple positional arguments */ :v1, :v2 from t
select /* list arg */ * from t where a in ::list
select /* list arg not in */ * from t where a not in ::list
select /* null */ null from t
select /* octal */ 010 from t
select /* hex */ 0xf0 from t

Просмотреть файл

@ -58,19 +58,21 @@ func HasINClause(conditions []BoolExpr) bool {
}
// IsSimpleTuple returns true if the ValExpr is a ValTuple that
// contains simple values.
// contains simple values or if it's a list arg.
func IsSimpleTuple(node ValExpr) bool {
list, ok := node.(ValTuple)
if !ok {
// It's a subquery.
return false
}
for _, n := range list {
if !IsValue(n) {
return false
switch vals := node.(type) {
case ValTuple:
for _, n := range vals {
if !IsValue(n) {
return false
}
}
return true
case ListArg:
return true
}
return true
// It's a subquery
return false
}
// AsInterface converts the ValExpr to an interface. It converts
@ -90,6 +92,8 @@ func AsInterface(node ValExpr) (interface{}, error) {
return vals, nil
case ValArg:
return string(node), nil
case ListArg:
return string(node), nil
case StrVal:
return sqltypes.MakeString(node), nil
case NumVal:

Просмотреть файл

@ -32,12 +32,10 @@ func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}) ([]by
current := 0
for _, loc := range pq.bindLocations {
buf.WriteString(pq.Query[current:loc.offset])
varName := pq.Query[loc.offset+1 : loc.offset+loc.length]
var supplied interface{}
var ok bool
supplied, ok = bindVariables[varName]
if !ok {
return nil, fmt.Errorf("missing bind var %s", varName)
name := pq.Query[loc.offset : loc.offset+loc.length]
supplied, _, err := FetchBindVar(name, bindVariables)
if err != nil {
return nil, err
}
if err := EncodeValue(buf, supplied); err != nil {
return nil, err
@ -76,6 +74,17 @@ func EncodeValue(buf *bytes.Buffer, value interface{}) error {
}
buf.WriteByte(')')
}
case []interface{}:
buf.WriteByte('(')
for i, v := range bindVal {
if i != 0 {
buf.WriteString(", ")
}
if err := EncodeValue(buf, v); err != nil {
return err
}
}
buf.WriteByte(')')
case TupleEqualityList:
if err := bindVal.Encode(buf); err != nil {
return err
@ -148,3 +157,29 @@ func (tpl *TupleEqualityList) encodeLHS(buf *bytes.Buffer) {
}
buf.WriteByte(')')
}
func FetchBindVar(name string, bindVariables map[string]interface{}) (val interface{}, isList bool, err error) {
name = name[1:]
if name[0] == ':' {
name = name[1:]
isList = true
}
supplied, ok := bindVariables[name]
if !ok {
return nil, false, fmt.Errorf("missing bind var %s", name)
}
list, gotList := supplied.([]interface{})
if isList {
if !gotList {
return nil, false, fmt.Errorf("unexpected list arg type %T for key %s", supplied, name)
}
if len(list) == 0 {
return nil, false, fmt.Errorf("empty list supplied for %s", name)
}
return list, true, nil
}
if gotList {
return nil, false, fmt.Errorf("unexpected arg type %T for key %s", supplied, name)
}
return supplied, false, nil
}

Просмотреть файл

@ -72,6 +72,46 @@ func TestParsedQuery(t *testing.T) {
},
},
"select * from a where id in ((1, 'aa'), (null, 'bb'))",
}, {
"list bind vars",
"select * from a where id in ::vals",
map[string]interface{}{
"vals": []interface{}{
1,
"aa",
},
},
"select * from a where id in (1, 'aa')",
}, {
"list bind vars single argument",
"select * from a where id in ::vals",
map[string]interface{}{
"vals": []interface{}{
1,
},
},
"select * from a where id in (1)",
}, {
"list bind vars 0 arguments",
"select * from a where id in ::vals",
map[string]interface{}{
"vals": []interface{}{},
},
"empty list supplied for vals",
}, {
"non-list bind var supplied",
"select * from a where id in ::vals",
map[string]interface{}{
"vals": 1,
},
"unexpected list arg type int for key vals",
}, {
"list bind var for non-list",
"select * from a where id = :vals",
map[string]interface{}{
"vals": []interface{}{1},
},
"unexpected arg type []interface {} for key vals",
}, {
"single column tuple equality",
// We have to use an incorrect construct to get around the parser.

Просмотреть файл

@ -117,7 +117,7 @@ func (tkn *Tokenizer) Lex(lval *yySymType) int {
typ, val = tkn.Scan()
}
switch typ {
case ID, STRING, NUMBER, VALUE_ARG, COMMENT:
case ID, STRING, NUMBER, VALUE_ARG, LIST_ARG, COMMENT:
lval.bytes = val
}
tkn.errorToken = val

Просмотреть файл

@ -13,44 +13,98 @@ import (
log "github.com/golang/glog"
"github.com/youtube/vitess/go/sqltypes"
"github.com/youtube/vitess/go/vt/schema"
"github.com/youtube/vitess/go/vt/sqlparser"
)
// buildValueList builds the set of PK reference rows used to drive the next query.
// It uses the PK values supplied in the original query and bind variables.
// The generated reference rows are validated for type match against the PK of the table.
func buildValueList(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) ([][]sqltypes.Value, error) {
length := -1
for _, pkValue := range pkValues {
if list, ok := pkValue.([]interface{}); ok {
if length == -1 {
if length = len(list); length == 0 {
panic(fmt.Sprintf("empty list for values %v", pkValues))
resolved, length, err := resolvePKValues(tableInfo, pkValues, bindVars)
if err != nil {
return nil, err
}
valueList := make([][]sqltypes.Value, length)
for i := 0; i < length; i++ {
valueList[i] = make([]sqltypes.Value, len(resolved))
for j, val := range resolved {
if list, ok := val.([]sqltypes.Value); ok {
valueList[i][j] = list[i]
} else {
valueList[i][j] = val.(sqltypes.Value)
}
}
}
return valueList, nil
}
func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) (resolved []interface{}, length int, err error) {
length = -1
setLength := func(list []sqltypes.Value) {
if length == -1 {
length = len(list)
} else if len(list) != length {
panic(fmt.Sprintf("mismatched lengths for values %v", pkValues))
}
}
resolved = make([]interface{}, len(pkValues))
for i, val := range pkValues {
switch val := val.(type) {
case string:
if val[1] != ':' {
resolved[i], err = resolveValue(tableInfo.GetPKColumn(i), val, bindVars)
if err != nil {
return nil, 0, err
}
} else if length != len(list) {
panic(fmt.Sprintf("mismatched lengths for values %v", pkValues))
} else {
list, err := resolveListArg(tableInfo.GetPKColumn(i), val, bindVars)
if err != nil {
return nil, 0, err
}
setLength(list)
resolved[i] = list
}
case []interface{}:
list := make([]sqltypes.Value, len(val))
for j, listVal := range val {
list[j], err = resolveValue(tableInfo.GetPKColumn(i), listVal, bindVars)
if err != nil {
return nil, 0, err
}
}
setLength(list)
resolved[i] = list
default:
resolved[i], err = resolveValue(tableInfo.GetPKColumn(i), val, nil)
if err != nil {
return nil, 0, err
}
}
}
if length == -1 {
length = 1
}
valueList := make([][]sqltypes.Value, length)
for i := 0; i < length; i++ {
valueList[i] = make([]sqltypes.Value, len(pkValues))
for j, pkValue := range pkValues {
var value interface{}
if list, ok := pkValue.([]interface{}); ok {
value = list[i]
} else {
value = pkValue
}
var err error
if valueList[i][j], err = resolveValue(tableInfo.GetPKColumn(j), value, bindVars); err != nil {
return valueList, err
}
}
return resolved, length, nil
}
func resolveListArg(col *schema.TableColumn, key string, bindVars map[string]interface{}) ([]sqltypes.Value, error) {
val, _, err := sqlparser.FetchBindVar(key, bindVars)
if err != nil {
return nil, NewTabletError(FAIL, "%v", err)
}
return valueList, nil
list := val.([]interface{})
resolved := make([]sqltypes.Value, len(list))
for i, v := range list {
sqlval, err := sqltypes.BuildValue(v)
if err != nil {
return nil, NewTabletError(FAIL, "%v", err)
}
if err = validateValue(col, sqlval); err != nil {
return nil, err
}
resolved[i] = sqlval
}
return resolved, nil
}
// buildSecondaryList is used for handling ON DUPLICATE DMLs, or those that change the PK.
@ -78,19 +132,17 @@ func buildSecondaryList(tableInfo *TableInfo, pkList [][]sqltypes.Value, seconda
func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[string]interface{}) (result sqltypes.Value, err error) {
switch v := value.(type) {
case string:
lookup, ok := bindVars[v[1:]]
if !ok {
return result, NewTabletError(FAIL, "missing bind var %s", v)
val, _, err := sqlparser.FetchBindVar(v, bindVars)
if err != nil {
return result, NewTabletError(FAIL, "%v", err)
}
if sqlval, err := sqltypes.BuildValue(lookup); err != nil {
return result, NewTabletError(FAIL, err.Error())
} else {
result = sqlval
sqlval, err := sqltypes.BuildValue(val)
if err != nil {
return result, NewTabletError(FAIL, "%v", err)
}
result = sqlval
case sqltypes.Value:
result = v
case nil:
// no op
default:
panic(fmt.Sprintf("incompatible value type %v", v))
}

Просмотреть файл

@ -9,13 +9,11 @@ import (
)
func TestBuildValuesList(t *testing.T) {
pk1 := "pk1"
pk2 := "pk2"
tableInfo := createTableInfo("Table",
map[string]string{pk1: "int", pk2: "varchar(128)", "col1": "int"},
[]string{pk1, pk2})
map[string]string{"pk1": "int", "pk2": "varbinary(128)", "col1": "int"},
[]string{"pk1", "pk2"})
// case 1: simple PK clause. e.g. where pk1 = 1
// simple PK clause. e.g. where pk1 = 1
bindVars := map[string]interface{}{}
pk1Val, _ := sqltypes.BuildValue(1)
pkValues := []interface{}{pk1Val}
@ -23,60 +21,173 @@ func TestBuildValuesList(t *testing.T) {
want := [][]sqltypes.Value{[]sqltypes.Value{pk1Val}}
got, _ := buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("case 1 failed, got %v, want %v", got, want)
t.Errorf("got %v, want %v", got, want)
}
// case 2: simple PK clause with bindVars. e.g. where pk1 = :pk1
bindVars[pk1] = 1
// simple PK clause with bindVars. e.g. where pk1 = :pk1
bindVars["pk1"] = 1
pkValues = []interface{}{":pk1"}
// want [[1]]
want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val}}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("case 2 failed, got %v, want %v", got, want)
t.Errorf("got %v, want %v", got, want)
}
// case 3: composite pK clause. e.g. where pk1 = 1 and pk2 = "abc"
// null value
bindVars["pk1"] = nil
pkValues = []interface{}{":pk1"}
// want [[1]]
want = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.Value{}}}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
// invalid value
bindVars["pk1"] = struct{}{}
pkValues = []interface{}{":pk1"}
wantErr := "error: unsupported bind variable type struct {}: {}"
got, err := buildValueList(&tableInfo, pkValues, bindVars)
if err == nil || err.Error() != wantErr {
t.Errorf("got %v, want %v", err, wantErr)
}
// type mismatch int
bindVars["pk1"] = "str"
pkValues = []interface{}{":pk1"}
wantErr = "error: type mismatch, expecting numeric type for str"
got, err = buildValueList(&tableInfo, pkValues, bindVars)
if err == nil || err.Error() != wantErr {
t.Errorf("got %v, want %v", err, wantErr)
}
// type mismatch binary
bindVars["pk1"] = 1
bindVars["pk2"] = 1
pkValues = []interface{}{":pk1", ":pk2"}
wantErr = "error: type mismatch, expecting string type for 1"
got, err = buildValueList(&tableInfo, pkValues, bindVars)
t.Logf("%v", got)
if err == nil || err.Error() != wantErr {
t.Errorf("got %v, want %v", err, wantErr)
}
// composite PK clause. e.g. where pk1 = 1 and pk2 = "abc"
pk2Val, _ := sqltypes.BuildValue("abc")
pkValues = []interface{}{pk1Val, pk2Val}
// want [[1 abc]]
want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val, pk2Val}}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("case 3 failed, got %v, want %v", got, want)
t.Errorf("got %v, want %v", got, want)
}
// case 4: multi row composite PK insert
// multi row composite PK insert
// e.g. insert into Table(pk1,pk2) values (1, "abc"), (2, "xyz")
pk1Val2, _ := sqltypes.BuildValue(2)
pk2Val2, _ := sqltypes.BuildValue("xyz")
pkValues = []interface{}{
[]interface{}{pk1Val, pk1Val2},
[]interface{}{pk2Val, pk2Val2}}
[]interface{}{pk2Val, pk2Val2},
}
// want [[1 abc][2 xyz]]
want = [][]sqltypes.Value{
[]sqltypes.Value{pk1Val, pk2Val},
[]sqltypes.Value{pk1Val2, pk2Val2}}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("case 4 failed, got %v, want %v", got, want)
t.Errorf("got %v, want %v", got, want)
}
// case 5: composite PK IN clause
// composite PK IN clause
// e.g. where pk1 = 1 and pk2 IN ("abc", "xyz")
pkValues = []interface{}{
pk1Val,
[]interface{}{pk2Val, pk2Val2}}
[]interface{}{pk2Val, pk2Val2},
}
// want [[1 abc][1 xyz]]
want = [][]sqltypes.Value{
[]sqltypes.Value{pk1Val, pk2Val},
[]sqltypes.Value{pk1Val, pk2Val2}}
[]sqltypes.Value{pk1Val, pk2Val2},
}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("case 5 failed, got %v, want %v", got, want)
t.Errorf("got %v, want %v", got, want)
}
// list arg
// e.g. where pk1 = 1 and pk2 IN ::list
bindVars = map[string]interface{}{
"list": []interface{}{
"abc",
"xyz",
},
}
pkValues = []interface{}{
pk1Val,
"::list",
}
// want [[1 abc][1 xyz]]
want = [][]sqltypes.Value{
[]sqltypes.Value{pk1Val, pk2Val},
[]sqltypes.Value{pk1Val, pk2Val2},
}
// list arg one value
// e.g. where pk1 = 1 and pk2 IN ::list
bindVars = map[string]interface{}{
"list": []interface{}{
"abc",
},
}
pkValues = []interface{}{
pk1Val,
"::list",
}
// want [[1 abc][1 xyz]]
want = [][]sqltypes.Value{
[]sqltypes.Value{pk1Val, pk2Val},
}
got, _ = buildValueList(&tableInfo, pkValues, bindVars)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
// list arg empty list
bindVars = map[string]interface{}{
"list": []interface{}{},
}
pkValues = []interface{}{
pk1Val,
"::list",
}
wantErr = "error: empty list supplied for list"
got, err = buildValueList(&tableInfo, pkValues, bindVars)
if err == nil || err.Error() != wantErr {
t.Errorf("got %v, want %v", err, wantErr)
}
// list arg for non-list
bindVars = map[string]interface{}{
"list": []interface{}{},
}
pkValues = []interface{}{
pk1Val,
":list",
}
wantErr = "error: unexpected arg type []interface {} for key list"
got, err = buildValueList(&tableInfo, pkValues, bindVars)
if err == nil || err.Error() != wantErr {
t.Errorf("got %v, want %v", err, wantErr)
}
}
func TestBuildSecondaryList(t *testing.T) {

Просмотреть файл

@ -71,6 +71,11 @@ conversions = {
VT_NEWDECIMAL : Decimal,
}
# This is a temporary workaround till we figure out how to support
# native lists in our API.
class List(list):
pass
NoneType = type(None)
# FIXME(msolomon) we could make a SqlLiteral ABC and just type check.
@ -88,7 +93,7 @@ def convert_bind_vars(bind_variables):
new_vars[key] = times.DateTimeToString(val)
elif isinstance(val, datetime.date):
new_vars[key] = times.DateToString(val)
elif isinstance(val, (int, long, float, str, NoneType)):
elif isinstance(val, (int, long, float, str, List, NoneType)):
new_vars[key] = val
else:
# NOTE(msolomon) begrudgingly I allow this - we just have too much code

Просмотреть файл

@ -1,4 +1,5 @@
from vtdb import dbexceptions
from vtdb import field_types
import framework
import cache_cases1
@ -36,6 +37,16 @@ class TestCache(framework.TestCase):
else:
self.fail("Did not receive exception")
def test_cache_list_arg(self):
cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3, 4, 32768])})
self.assertEqual(cu.rowcount, 2)
cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3, 4])})
self.assertEqual(cu.rowcount, 2)
cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3])})
self.assertEqual(cu.rowcount, 1)
with self.assertRaises(dbexceptions.DatabaseError):
cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List()})
def test_uncache(self):
try:
# Verify row cache is working

Просмотреть файл

@ -1,5 +1,6 @@
import time
from vtdb import field_types
from vtdb import dbexceptions
from vtdb import tablet as tablet_conn
from vtdb import cursor
@ -39,6 +40,16 @@ class TestNocache(framework.TestCase):
self.assertEqual(vstart.mget("Queries.TotalCount", 0)+1, vend.Queries.TotalCount)
self.assertEqual(vstart.mget("Queries.Histograms.PASS_SELECT.Count", 0)+1, vend.Queries.Histograms.PASS_SELECT.Count)
def test_nocache_list_arg(self):
cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([2, 3, 4])})
self.assertEqual(cu.rowcount, 2)
cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([3, 4])})
self.assertEqual(cu.rowcount, 1)
cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([3])})
self.assertEqual(cu.rowcount, 1)
with self.assertRaises(dbexceptions.DatabaseError):
cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List()})
def test_commit(self):
vstart = self.env.debug_vars()
self.env.txlog.reset()