diff --git a/go/vt/tabletserver/query_splitter.go b/go/vt/tabletserver/query_splitter.go index eac4018b68..664a4f1d79 100644 --- a/go/vt/tabletserver/query_splitter.go +++ b/go/vt/tabletserver/query_splitter.go @@ -28,8 +28,8 @@ type QuerySplitter struct { } const ( - startBindVarName = ":_splitquery_start" - endBindVarName = ":_splitquery_end" + startBindVarName = "_splitquery_start" + endBindVarName = "_splitquery_end" ) // NewQuerySplitter creates a new QuerySplitter. query is the original query @@ -157,7 +157,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m startClause = &sqlparser.ComparisonExpr{ Operator: sqlparser.AST_GE, Left: pk, - Right: sqlparser.ValArg([]byte(startBindVarName)), + Right: sqlparser.ValArg([]byte(":" + startBindVarName)), } if start.IsNumeric() { v, _ := start.ParseInt64() @@ -174,7 +174,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m endClause = &sqlparser.ComparisonExpr{ Operator: sqlparser.AST_LT, Left: pk, - Right: sqlparser.ValArg([]byte(endBindVarName)), + Right: sqlparser.ValArg([]byte(":" + endBindVarName)), } if end.IsNumeric() { v, _ := end.ParseInt64() diff --git a/go/vt/tabletserver/query_splitter_test.go b/go/vt/tabletserver/query_splitter_test.go index 5cf1a66849..353150e5f3 100644 --- a/go/vt/tabletserver/query_splitter_test.go +++ b/go/vt/tabletserver/query_splitter_test.go @@ -165,7 +165,7 @@ func TestGetWhereClause(t *testing.T) { bindVars = make(map[string]interface{}) bindVars[":count"] = 300 clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue) - want = " where (count > :count) and (id >= " + startBindVarName + ")" + want = " where (count > :count) and (id >= :" + startBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) @@ -182,7 +182,7 @@ func TestGetWhereClause(t *testing.T) { end, _ := sqltypes.BuildValue(endVal) bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end) - want = " where (count > :count) and (id < " + endBindVarName + ")" + want = " where (count > :count) and (id < :" + endBindVarName + ")" got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) @@ -198,7 +198,7 @@ func TestGetWhereClause(t *testing.T) { // Set both bounds, should add two conditions to where clause bindVars = make(map[string]interface{}) clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) - want = fmt.Sprintf(" where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName) + want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) @@ -219,7 +219,7 @@ func TestGetWhereClause(t *testing.T) { bindVars = make(map[string]interface{}) // Set both bounds, should add two conditions to where clause clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end) - want = fmt.Sprintf(" where id >= %s and id < %s", startBindVarName, endBindVarName) + want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName) got = sqlparser.String(clause) if !reflect.DeepEqual(got, want) { t.Errorf("incorrect where clause, got:%v, want:%v", got, want) @@ -355,18 +355,18 @@ func TestSplitQuery(t *testing.T) { } want := []proto.BoundQuery{ { - Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")", BindVariables: map[string]interface{}{endBindVarName: int64(100)}, }, { - Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName), + Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName), BindVariables: map[string]interface{}{ startBindVarName: int64(100), endBindVarName: int64(200), }, }, { - Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")", BindVariables: map[string]interface{}{startBindVarName: int64(200)}, }, } @@ -411,18 +411,18 @@ func TestSplitQueryFractionalColumn(t *testing.T) { } want := []proto.BoundQuery{ { - Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")", BindVariables: map[string]interface{}{endBindVarName: 170.5}, }, { - Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName), + Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName), BindVariables: map[string]interface{}{ startBindVarName: 170.5, endBindVarName: 330.5, }, }, { - Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")", BindVariables: map[string]interface{}{startBindVarName: 330.5}, }, } @@ -451,18 +451,18 @@ func TestSplitQueryStringColumn(t *testing.T) { } want := []proto.BoundQuery{ { - Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")", BindVariables: map[string]interface{}{endBindVarName: hexToByteUInt64(0x55555555)[4:]}, }, { - Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName), + Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName), BindVariables: map[string]interface{}{ startBindVarName: hexToByteUInt64(0x55555555)[4:], endBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:], }, }, { - Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")", + Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")", BindVariables: map[string]interface{}{startBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:]}, }, } diff --git a/go/vt/tabletserver/sqlquery.go b/go/vt/tabletserver/sqlquery.go index 9bf862b6bd..fa0b0c8b21 100644 --- a/go/vt/tabletserver/sqlquery.go +++ b/go/vt/tabletserver/sqlquery.go @@ -679,6 +679,9 @@ func getColumnType(qre *QueryExecutor, columnName, tableName string) (int64, err return mproto.VT_NULL, err } defer conn.Recycle() + // TODO(shengzhe): use AST to represent the query to avoid sql injection. + // current code is safe because QuerySplitter.validateQuery is called before + // calling this. query := fmt.Sprintf("SELECT %v FROM %v LIMIT 0", columnName, tableName) result, err := qre.execSQL(conn, query, true) if err != nil { @@ -696,6 +699,9 @@ func getColumnMinMax(qre *QueryExecutor, columnName, tableName string) (*mproto. return nil, err } defer conn.Recycle() + // TODO(shengzhe): use AST to represent the query to avoid sql injection. + // current code is safe because QuerySplitter.validateQuery is called before + // calling this. minMaxSQL := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", columnName, columnName, tableName) return qre.execSQL(conn, minMaxSQL, true) } diff --git a/java/vtgate-client/src/test/java/com/youtube/vitess/vtgate/integration/VtGateIT.java b/java/vtgate-client/src/test/java/com/youtube/vitess/vtgate/integration/VtGateIT.java index a68062aede..ed35d3d4e3 100644 --- a/java/vtgate-client/src/test/java/com/youtube/vitess/vtgate/integration/VtGateIT.java +++ b/java/vtgate-client/src/test/java/com/youtube/vitess/vtgate/integration/VtGateIT.java @@ -1,6 +1,7 @@ package com.youtube.vitess.vtgate.integration; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.primitives.UnsignedLong; import com.youtube.vitess.vtgate.BindVariable; @@ -303,13 +304,16 @@ public class VtGateIT { public void testSplitQueryMultipleSplitsPerShard() throws Exception { int rowCount = 30; Util.insertRows(testEnv, 1, 30); - List expectedSqls = - Lists.newArrayList("select id, keyspace_id from vtgate_test where id < 10", - "select id, keyspace_id from vtgate_test where id < 11", - "select id, keyspace_id from vtgate_test where id >= 10 and id < 19", - "select id, keyspace_id from vtgate_test where id >= 11 and id < 19", - "select id, keyspace_id from vtgate_test where id >= 19", - "select id, keyspace_id from vtgate_test where id >= 19"); + Map> expectedSqls = Maps.newHashMap(); + expectedSqls.put("select id, keyspace_id from vtgate_test where id < :_splitquery_end", + Lists.newArrayList(BindVariable.forInt("_splitquery_end", 10))); + expectedSqls.put( + "select id, keyspace_id from vtgate_test where id >= :_splitquery_start " + + "and id < :_splitquery_end", + Lists.newArrayList(BindVariable.forInt("_splitquery_start", 10), + BindVariable.forInt("_splitquery_end", 19))); + expectedSqls.put("select id, keyspace_id from vtgate_test where id >= :_splitquery_start", + Lists.newArrayList(BindVariable.forInt("_splitquery_start", 19))); Util.waitForTablet("rdonly", rowCount, 3, testEnv); VtGate vtgate = VtGate.connect("localhost:" + testEnv.getPort(), 0, testEnv.getRpcClientFactory()); int splitCount = 6; @@ -322,11 +326,11 @@ public class VtGateIT { Set shardsInSplits = new HashSet<>(); for (Query q : queries.keySet()) { String sql = q.getSql(); - Assert.assertTrue(expectedSqls.contains(sql)); - expectedSqls.remove(sql); + List bindVars = expectedSqls.get(sql); + Assert.assertNotNull(bindVars); Assert.assertEquals("test_keyspace", q.getKeyspace()); Assert.assertEquals("rdonly", q.getTabletType()); - Assert.assertEquals(0, q.getBindVars().size()); + Assert.assertEquals(bindVars.size(), q.getBindVars().size()); Assert.assertEquals(null, q.getKeyspaceIds()); String start = Hex.encodeHexString(q.getKeyRanges().get(0).get("Start")); String end = Hex.encodeHexString(q.getKeyRanges().get(0).get("End")); @@ -335,7 +339,6 @@ public class VtGateIT { // Verify the keyrange queries in splits cover the entire keyspace Assert.assertTrue(shardsInSplits.containsAll(testEnv.getShardKidMap().keySet())); - Assert.assertTrue(expectedSqls.size() == 0); } @Test diff --git a/test/custom_sharding.py b/test/custom_sharding.py index 0f90ae235e..a5b314b883 100755 --- a/test/custom_sharding.py +++ b/test/custom_sharding.py @@ -193,7 +193,8 @@ primary key (id) rows = {} for q in s: qr = utils.vtgate.execute_shard(q['QueryShard']['Sql'], - 'test_keyspace', ",".join(q['QueryShard']['Shards'])) + 'test_keyspace', ",".join(q['QueryShard']['Shards']), + tablet_type='master', bindvars=q['QueryShard']['BindVariables']) for r in qr['Rows']: id = int(r[0]) rows[id] = r[1]