From ce89746dabac0117818b7cb14ad604668f20ff10 Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Fri, 24 Oct 2014 13:53:25 -0700 Subject: [PATCH] tabletserver: list bind vars support complete You can now specify a new syntax for queries that have an IN clause: select * from t where id in ::list The corresponding bind variable would be: {"list": field_types.List([1,2,3])} - List bind vars are only allowed for IN and NOT IN clauses. - A bind variable can be a list if and only if it's referred to as ::list. - A list must contain at least one value. For the Python client, you need to supply lists using the field_types.List class. This is because of legacy behavior where lists were previously encoded as strings. We should soon change this API to use native lists once we confirm that no one will be affected by this change. --- data/test/sqlparser_test/parse_pass.sql | 2 + go/vt/sqlparser/analyzer.go | 24 ++-- go/vt/sqlparser/parsed_query.go | 47 +++++++- go/vt/sqlparser/parsed_query_test.go | 40 ++++++ go/vt/sqlparser/token.go | 2 +- go/vt/tabletserver/codex.go | 118 +++++++++++++----- go/vt/tabletserver/codex_test.go | 147 ++++++++++++++++++++--- py/vtdb/field_types.py | 7 +- test/queryservice_tests/cache_tests.py | 11 ++ test/queryservice_tests/nocache_tests.py | 11 ++ 10 files changed, 340 insertions(+), 69 deletions(-) diff --git a/data/test/sqlparser_test/parse_pass.sql b/data/test/sqlparser_test/parse_pass.sql index 238118d21e..826c07c9ab 100644 --- a/data/test/sqlparser_test/parse_pass.sql +++ b/data/test/sqlparser_test/parse_pass.sql @@ -115,6 +115,8 @@ select /* value argument with digit */ :a1 from t select /* value argument with dot */ :a.b from t select /* positional argument */ ? from t#select /* positional argument */ :v1 from t select /* multiple positional arguments */ ?, ? 
from t#select /* multiple positional arguments */ :v1, :v2 from t +select /* list arg */ * from t where a in ::list +select /* list arg not in */ * from t where a not in ::list select /* null */ null from t select /* octal */ 010 from t select /* hex */ 0xf0 from t diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index 73b8b0bcb3..69f8c3c42d 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -58,19 +58,21 @@ func HasINClause(conditions []BoolExpr) bool { } // IsSimpleTuple returns true if the ValExpr is a ValTuple that -// contains simple values. +// contains simple values or if it's a list arg. func IsSimpleTuple(node ValExpr) bool { - list, ok := node.(ValTuple) - if !ok { - // It's a subquery. - return false - } - for _, n := range list { - if !IsValue(n) { - return false + switch vals := node.(type) { + case ValTuple: + for _, n := range vals { + if !IsValue(n) { + return false + } } + return true + case ListArg: + return true } - return true + // It's a subquery + return false } // AsInterface converts the ValExpr to an interface. 
It converts @@ -90,6 +92,8 @@ func AsInterface(node ValExpr) (interface{}, error) { return vals, nil case ValArg: return string(node), nil + case ListArg: + return string(node), nil case StrVal: return sqltypes.MakeString(node), nil case NumVal: diff --git a/go/vt/sqlparser/parsed_query.go b/go/vt/sqlparser/parsed_query.go index 8dd84bf066..ecdb4b24ca 100644 --- a/go/vt/sqlparser/parsed_query.go +++ b/go/vt/sqlparser/parsed_query.go @@ -32,12 +32,10 @@ func (pq *ParsedQuery) GenerateQuery(bindVariables map[string]interface{}) ([]by current := 0 for _, loc := range pq.bindLocations { buf.WriteString(pq.Query[current:loc.offset]) - varName := pq.Query[loc.offset+1 : loc.offset+loc.length] - var supplied interface{} - var ok bool - supplied, ok = bindVariables[varName] - if !ok { - return nil, fmt.Errorf("missing bind var %s", varName) + name := pq.Query[loc.offset : loc.offset+loc.length] + supplied, _, err := FetchBindVar(name, bindVariables) + if err != nil { + return nil, err } if err := EncodeValue(buf, supplied); err != nil { return nil, err @@ -76,6 +74,17 @@ func EncodeValue(buf *bytes.Buffer, value interface{}) error { } buf.WriteByte(')') } + case []interface{}: + buf.WriteByte('(') + for i, v := range bindVal { + if i != 0 { + buf.WriteString(", ") + } + if err := EncodeValue(buf, v); err != nil { + return err + } + } + buf.WriteByte(')') case TupleEqualityList: if err := bindVal.Encode(buf); err != nil { return err @@ -148,3 +157,29 @@ func (tpl *TupleEqualityList) encodeLHS(buf *bytes.Buffer) { } buf.WriteByte(')') } + +func FetchBindVar(name string, bindVariables map[string]interface{}) (val interface{}, isList bool, err error) { + name = name[1:] + if name[0] == ':' { + name = name[1:] + isList = true + } + supplied, ok := bindVariables[name] + if !ok { + return nil, false, fmt.Errorf("missing bind var %s", name) + } + list, gotList := supplied.([]interface{}) + if isList { + if !gotList { + return nil, false, fmt.Errorf("unexpected list arg type %T for 
key %s", supplied, name) + } + if len(list) == 0 { + return nil, false, fmt.Errorf("empty list supplied for %s", name) + } + return list, true, nil + } + if gotList { + return nil, false, fmt.Errorf("unexpected arg type %T for key %s", supplied, name) + } + return supplied, false, nil +} diff --git a/go/vt/sqlparser/parsed_query_test.go b/go/vt/sqlparser/parsed_query_test.go index 1825e154f1..f9f200163d 100644 --- a/go/vt/sqlparser/parsed_query_test.go +++ b/go/vt/sqlparser/parsed_query_test.go @@ -72,6 +72,46 @@ func TestParsedQuery(t *testing.T) { }, }, "select * from a where id in ((1, 'aa'), (null, 'bb'))", + }, { + "list bind vars", + "select * from a where id in ::vals", + map[string]interface{}{ + "vals": []interface{}{ + 1, + "aa", + }, + }, + "select * from a where id in (1, 'aa')", + }, { + "list bind vars single argument", + "select * from a where id in ::vals", + map[string]interface{}{ + "vals": []interface{}{ + 1, + }, + }, + "select * from a where id in (1)", + }, { + "list bind vars 0 arguments", + "select * from a where id in ::vals", + map[string]interface{}{ + "vals": []interface{}{}, + }, + "empty list supplied for vals", + }, { + "non-list bind var supplied", + "select * from a where id in ::vals", + map[string]interface{}{ + "vals": 1, + }, + "unexpected list arg type int for key vals", + }, { + "list bind var for non-list", + "select * from a where id = :vals", + map[string]interface{}{ + "vals": []interface{}{1}, + }, + "unexpected arg type []interface {} for key vals", }, { "single column tuple equality", // We have to use an incorrect construct to get around the parser. 
diff --git a/go/vt/sqlparser/token.go b/go/vt/sqlparser/token.go index 1a5b15d130..dc4340340d 100644 --- a/go/vt/sqlparser/token.go +++ b/go/vt/sqlparser/token.go @@ -117,7 +117,7 @@ func (tkn *Tokenizer) Lex(lval *yySymType) int { typ, val = tkn.Scan() } switch typ { - case ID, STRING, NUMBER, VALUE_ARG, COMMENT: + case ID, STRING, NUMBER, VALUE_ARG, LIST_ARG, COMMENT: lval.bytes = val } tkn.errorToken = val diff --git a/go/vt/tabletserver/codex.go b/go/vt/tabletserver/codex.go index b3e48898df..57ad55c5ce 100644 --- a/go/vt/tabletserver/codex.go +++ b/go/vt/tabletserver/codex.go @@ -13,44 +13,98 @@ import ( log "github.com/golang/glog" "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/vt/schema" + "github.com/youtube/vitess/go/vt/sqlparser" ) // buildValueList builds the set of PK reference rows used to drive the next query. // It uses the PK values supplied in the original query and bind variables. // The generated reference rows are validated for type match against the PK of the table. 
func buildValueList(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) ([][]sqltypes.Value, error) { - length := -1 - for _, pkValue := range pkValues { - if list, ok := pkValue.([]interface{}); ok { - if length == -1 { - if length = len(list); length == 0 { - panic(fmt.Sprintf("empty list for values %v", pkValues)) + resolved, length, err := resolvePKValues(tableInfo, pkValues, bindVars) + if err != nil { + return nil, err + } + valueList := make([][]sqltypes.Value, length) + for i := 0; i < length; i++ { + valueList[i] = make([]sqltypes.Value, len(resolved)) + for j, val := range resolved { + if list, ok := val.([]sqltypes.Value); ok { + valueList[i][j] = list[i] + } else { + valueList[i][j] = val.(sqltypes.Value) + } + } + } + return valueList, nil +} + +func resolvePKValues(tableInfo *TableInfo, pkValues []interface{}, bindVars map[string]interface{}) (resolved []interface{}, length int, err error) { + length = -1 + setLength := func(list []sqltypes.Value) { + if length == -1 { + length = len(list) + } else if len(list) != length { + panic(fmt.Sprintf("mismatched lengths for values %v", pkValues)) + } + } + resolved = make([]interface{}, len(pkValues)) + for i, val := range pkValues { + switch val := val.(type) { + case string: + if val[1] != ':' { + resolved[i], err = resolveValue(tableInfo.GetPKColumn(i), val, bindVars) + if err != nil { + return nil, 0, err } - } else if length != len(list) { - panic(fmt.Sprintf("mismatched lengths for values %v", pkValues)) + } else { + list, err := resolveListArg(tableInfo.GetPKColumn(i), val, bindVars) + if err != nil { + return nil, 0, err + } + setLength(list) + resolved[i] = list + } + case []interface{}: + list := make([]sqltypes.Value, len(val)) + for j, listVal := range val { + list[j], err = resolveValue(tableInfo.GetPKColumn(i), listVal, bindVars) + if err != nil { + return nil, 0, err + } + } + setLength(list) + resolved[i] = list + default: + resolved[i], err = 
resolveValue(tableInfo.GetPKColumn(i), val, nil) + if err != nil { + return nil, 0, err } } } if length == -1 { length = 1 } - valueList := make([][]sqltypes.Value, length) - for i := 0; i < length; i++ { - valueList[i] = make([]sqltypes.Value, len(pkValues)) - for j, pkValue := range pkValues { - var value interface{} - if list, ok := pkValue.([]interface{}); ok { - value = list[i] - } else { - value = pkValue - } - var err error - if valueList[i][j], err = resolveValue(tableInfo.GetPKColumn(j), value, bindVars); err != nil { - return valueList, err - } - } + return resolved, length, nil +} + +func resolveListArg(col *schema.TableColumn, key string, bindVars map[string]interface{}) ([]sqltypes.Value, error) { + val, _, err := sqlparser.FetchBindVar(key, bindVars) + if err != nil { + return nil, NewTabletError(FAIL, "%v", err) } - return valueList, nil + list := val.([]interface{}) + resolved := make([]sqltypes.Value, len(list)) + for i, v := range list { + sqlval, err := sqltypes.BuildValue(v) + if err != nil { + return nil, NewTabletError(FAIL, "%v", err) + } + if err = validateValue(col, sqlval); err != nil { + return nil, err + } + resolved[i] = sqlval + } + return resolved, nil } // buildSecondaryList is used for handling ON DUPLICATE DMLs, or those that change the PK. 
@@ -78,19 +132,17 @@ func buildSecondaryList(tableInfo *TableInfo, pkList [][]sqltypes.Value, seconda func resolveValue(col *schema.TableColumn, value interface{}, bindVars map[string]interface{}) (result sqltypes.Value, err error) { switch v := value.(type) { case string: - lookup, ok := bindVars[v[1:]] - if !ok { - return result, NewTabletError(FAIL, "missing bind var %s", v) + val, _, err := sqlparser.FetchBindVar(v, bindVars) + if err != nil { + return result, NewTabletError(FAIL, "%v", err) } - if sqlval, err := sqltypes.BuildValue(lookup); err != nil { - return result, NewTabletError(FAIL, err.Error()) - } else { - result = sqlval + sqlval, err := sqltypes.BuildValue(val) + if err != nil { + return result, NewTabletError(FAIL, "%v", err) } + result = sqlval case sqltypes.Value: result = v - case nil: - // no op default: panic(fmt.Sprintf("incompatible value type %v", v)) } diff --git a/go/vt/tabletserver/codex_test.go b/go/vt/tabletserver/codex_test.go index a20f3c5fb0..c1378812ec 100644 --- a/go/vt/tabletserver/codex_test.go +++ b/go/vt/tabletserver/codex_test.go @@ -9,13 +9,11 @@ import ( ) func TestBuildValuesList(t *testing.T) { - pk1 := "pk1" - pk2 := "pk2" tableInfo := createTableInfo("Table", - map[string]string{pk1: "int", pk2: "varchar(128)", "col1": "int"}, - []string{pk1, pk2}) + map[string]string{"pk1": "int", "pk2": "varbinary(128)", "col1": "int"}, + []string{"pk1", "pk2"}) - // case 1: simple PK clause. e.g. where pk1 = 1 + // simple PK clause. e.g. where pk1 = 1 bindVars := map[string]interface{}{} pk1Val, _ := sqltypes.BuildValue(1) pkValues := []interface{}{pk1Val} @@ -23,60 +21,173 @@ func TestBuildValuesList(t *testing.T) { want := [][]sqltypes.Value{[]sqltypes.Value{pk1Val}} got, _ := buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { - t.Errorf("case 1 failed, got %v, want %v", got, want) + t.Errorf("got %v, want %v", got, want) } - // case 2: simple PK clause with bindVars. e.g. 
where pk1 = :pk1 - bindVars[pk1] = 1 + // simple PK clause with bindVars. e.g. where pk1 = :pk1 + bindVars["pk1"] = 1 pkValues = []interface{}{":pk1"} // want [[1]] want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { - t.Errorf("case 2 failed, got %v, want %v", got, want) + t.Errorf("got %v, want %v", got, want) } - // case 3: composite pK clause. e.g. where pk1 = 1 and pk2 = "abc" + // null value + bindVars["pk1"] = nil + pkValues = []interface{}{":pk1"} + // want [[1]] + want = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.Value{}}} + got, _ = buildValueList(&tableInfo, pkValues, bindVars) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + // invalid value + bindVars["pk1"] = struct{}{} + pkValues = []interface{}{":pk1"} + wantErr := "error: unsupported bind variable type struct {}: {}" + + got, err := buildValueList(&tableInfo, pkValues, bindVars) + if err == nil || err.Error() != wantErr { + t.Errorf("got %v, want %v", err, wantErr) + } + + // type mismatch int + bindVars["pk1"] = "str" + pkValues = []interface{}{":pk1"} + wantErr = "error: type mismatch, expecting numeric type for str" + + got, err = buildValueList(&tableInfo, pkValues, bindVars) + if err == nil || err.Error() != wantErr { + t.Errorf("got %v, want %v", err, wantErr) + } + + // type mismatch binary + bindVars["pk1"] = 1 + bindVars["pk2"] = 1 + pkValues = []interface{}{":pk1", ":pk2"} + wantErr = "error: type mismatch, expecting string type for 1" + + got, err = buildValueList(&tableInfo, pkValues, bindVars) + t.Logf("%v", got) + if err == nil || err.Error() != wantErr { + t.Errorf("got %v, want %v", err, wantErr) + } + + // composite PK clause. e.g. 
where pk1 = 1 and pk2 = "abc" pk2Val, _ := sqltypes.BuildValue("abc") pkValues = []interface{}{pk1Val, pk2Val} // want [[1 abc]] want = [][]sqltypes.Value{[]sqltypes.Value{pk1Val, pk2Val}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { - t.Errorf("case 3 failed, got %v, want %v", got, want) + t.Errorf("got %v, want %v", got, want) } - // case 4: multi row composite PK insert + // multi row composite PK insert // e.g. insert into Table(pk1,pk2) values (1, "abc"), (2, "xyz") pk1Val2, _ := sqltypes.BuildValue(2) pk2Val2, _ := sqltypes.BuildValue("xyz") pkValues = []interface{}{ []interface{}{pk1Val, pk1Val2}, - []interface{}{pk2Val, pk2Val2}} + []interface{}{pk2Val, pk2Val2}, + } // want [[1 abc][2 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, []sqltypes.Value{pk1Val2, pk2Val2}} got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { - t.Errorf("case 4 failed, got %v, want %v", got, want) + t.Errorf("got %v, want %v", got, want) } - // case 5: composite PK IN clause + // composite PK IN clause // e.g. where pk1 = 1 and pk2 IN ("abc", "xyz") pkValues = []interface{}{ pk1Val, - []interface{}{pk2Val, pk2Val2}} + []interface{}{pk2Val, pk2Val2}, + } // want [[1 abc][1 xyz]] want = [][]sqltypes.Value{ []sqltypes.Value{pk1Val, pk2Val}, - []sqltypes.Value{pk1Val, pk2Val2}} + []sqltypes.Value{pk1Val, pk2Val2}, + } got, _ = buildValueList(&tableInfo, pkValues, bindVars) if !reflect.DeepEqual(got, want) { - t.Errorf("case 5 failed, got %v, want %v", got, want) + t.Errorf("got %v, want %v", got, want) } + // list arg + // e.g. where pk1 = 1 and pk2 IN ::list + bindVars = map[string]interface{}{ + "list": []interface{}{ + "abc", + "xyz", + }, + } + pkValues = []interface{}{ + pk1Val, + "::list", + } + // want [[1 abc][1 xyz]] + want = [][]sqltypes.Value{ + []sqltypes.Value{pk1Val, pk2Val}, + []sqltypes.Value{pk1Val, pk2Val2}, + } + + // list arg one value + // e.g. 
where pk1 = 1 and pk2 IN ::list + bindVars = map[string]interface{}{ + "list": []interface{}{ + "abc", + }, + } + pkValues = []interface{}{ + pk1Val, + "::list", + } + // want [[1 abc][1 xyz]] + want = [][]sqltypes.Value{ + []sqltypes.Value{pk1Val, pk2Val}, + } + + got, _ = buildValueList(&tableInfo, pkValues, bindVars) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + // list arg empty list + bindVars = map[string]interface{}{ + "list": []interface{}{}, + } + pkValues = []interface{}{ + pk1Val, + "::list", + } + wantErr = "error: empty list supplied for list" + + got, err = buildValueList(&tableInfo, pkValues, bindVars) + if err == nil || err.Error() != wantErr { + t.Errorf("got %v, want %v", err, wantErr) + } + + // list arg for non-list + bindVars = map[string]interface{}{ + "list": []interface{}{}, + } + pkValues = []interface{}{ + pk1Val, + ":list", + } + wantErr = "error: unexpected arg type []interface {} for key list" + + got, err = buildValueList(&tableInfo, pkValues, bindVars) + if err == nil || err.Error() != wantErr { + t.Errorf("got %v, want %v", err, wantErr) + } } func TestBuildSecondaryList(t *testing.T) { diff --git a/py/vtdb/field_types.py b/py/vtdb/field_types.py index 2256ece8de..403fcbf077 100755 --- a/py/vtdb/field_types.py +++ b/py/vtdb/field_types.py @@ -71,6 +71,11 @@ conversions = { VT_NEWDECIMAL : Decimal, } +# This is a temporary workaround till we figure out how to support +# native lists in our API. +class List(list): + pass + NoneType = type(None) # FIXME(msolomon) we could make a SqlLiteral ABC and just type check. 
@@ -88,7 +93,7 @@ def convert_bind_vars(bind_variables): new_vars[key] = times.DateTimeToString(val) elif isinstance(val, datetime.date): new_vars[key] = times.DateToString(val) - elif isinstance(val, (int, long, float, str, NoneType)): + elif isinstance(val, (int, long, float, str, List, NoneType)): new_vars[key] = val else: # NOTE(msolomon) begrudgingly I allow this - we just have too much code diff --git a/test/queryservice_tests/cache_tests.py b/test/queryservice_tests/cache_tests.py index c19c71fd9d..52c3df977d 100644 --- a/test/queryservice_tests/cache_tests.py +++ b/test/queryservice_tests/cache_tests.py @@ -1,4 +1,5 @@ from vtdb import dbexceptions +from vtdb import field_types import framework import cache_cases1 @@ -36,6 +37,16 @@ class TestCache(framework.TestCase): else: self.fail("Did not receive exception") + def test_cache_list_arg(self): + cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3, 4, 32768])}) + self.assertEqual(cu.rowcount, 2) + cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3, 4])}) + self.assertEqual(cu.rowcount, 2) + cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List([3])}) + self.assertEqual(cu.rowcount, 1) + with self.assertRaises(dbexceptions.DatabaseError): + cu = self.env.execute("select * from vtocc_cached1 where eid in ::list", {"list": field_types.List()}) + def test_uncache(self): try: # Verify row cache is working diff --git a/test/queryservice_tests/nocache_tests.py b/test/queryservice_tests/nocache_tests.py index 83277bee23..4c737991e4 100644 --- a/test/queryservice_tests/nocache_tests.py +++ b/test/queryservice_tests/nocache_tests.py @@ -1,5 +1,6 @@ import time +from vtdb import field_types from vtdb import dbexceptions from vtdb import tablet as tablet_conn from vtdb import cursor @@ -39,6 +40,16 @@ class TestNocache(framework.TestCase): 
self.assertEqual(vstart.mget("Queries.TotalCount", 0)+1, vend.Queries.TotalCount) self.assertEqual(vstart.mget("Queries.Histograms.PASS_SELECT.Count", 0)+1, vend.Queries.Histograms.PASS_SELECT.Count) + def test_nocache_list_arg(self): + cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([2, 3, 4])}) + self.assertEqual(cu.rowcount, 2) + cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([3, 4])}) + self.assertEqual(cu.rowcount, 1) + cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List([3])}) + self.assertEqual(cu.rowcount, 1) + with self.assertRaises(dbexceptions.DatabaseError): + cu = self.env.execute("select * from vtocc_test where intval in ::list", {"list": field_types.List()}) + def test_commit(self): vstart = self.env.debug_vars() self.env.txlog.reset()