1. bind variables in SplitQuery should not contain leading ':'.
2. fix custom_sharding.py test to pass through returned bind variables from VTGate.
3. fix SplitQuery tests in Java.
This commit is contained in:
Shengzhe Yao 2015-08-17 15:33:16 -07:00
Родитель 0c03e61532
Коммит 3736372fee
5 изменённых файлов: 39 добавлений и 29 удалений

Просмотреть файл

@ -28,8 +28,8 @@ type QuerySplitter struct {
}
const (
startBindVarName = ":_splitquery_start"
endBindVarName = ":_splitquery_end"
startBindVarName = "_splitquery_start"
endBindVarName = "_splitquery_end"
)
// NewQuerySplitter creates a new QuerySplitter. query is the original query
@ -157,7 +157,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m
startClause = &sqlparser.ComparisonExpr{
Operator: sqlparser.AST_GE,
Left: pk,
Right: sqlparser.ValArg([]byte(startBindVarName)),
Right: sqlparser.ValArg([]byte(":" + startBindVarName)),
}
if start.IsNumeric() {
v, _ := start.ParseInt64()
@ -174,7 +174,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m
endClause = &sqlparser.ComparisonExpr{
Operator: sqlparser.AST_LT,
Left: pk,
Right: sqlparser.ValArg([]byte(endBindVarName)),
Right: sqlparser.ValArg([]byte(":" + endBindVarName)),
}
if end.IsNumeric() {
v, _ := end.ParseInt64()

Просмотреть файл

@ -165,7 +165,7 @@ func TestGetWhereClause(t *testing.T) {
bindVars = make(map[string]interface{})
bindVars[":count"] = 300
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue)
want = " where (count > :count) and (id >= " + startBindVarName + ")"
want = " where (count > :count) and (id >= :" + startBindVarName + ")"
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
@ -182,7 +182,7 @@ func TestGetWhereClause(t *testing.T) {
end, _ := sqltypes.BuildValue(endVal)
bindVars = make(map[string]interface{})
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end)
want = " where (count > :count) and (id < " + endBindVarName + ")"
want = " where (count > :count) and (id < :" + endBindVarName + ")"
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
@ -198,7 +198,7 @@ func TestGetWhereClause(t *testing.T) {
// Set both bounds, should add two conditions to where clause
bindVars = make(map[string]interface{})
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
want = fmt.Sprintf(" where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName)
want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName)
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
@ -219,7 +219,7 @@ func TestGetWhereClause(t *testing.T) {
bindVars = make(map[string]interface{})
// Set both bounds, should add two conditions to where clause
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
want = fmt.Sprintf(" where id >= %s and id < %s", startBindVarName, endBindVarName)
want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName)
got = sqlparser.String(clause)
if !reflect.DeepEqual(got, want) {
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
@ -355,18 +355,18 @@ func TestSplitQuery(t *testing.T) {
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: int64(100)},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: int64(100),
endBindVarName: int64(200),
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: int64(200)},
},
}
@ -411,18 +411,18 @@ func TestSplitQueryFractionalColumn(t *testing.T) {
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: 170.5},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: 170.5,
endBindVarName: 330.5,
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: 330.5},
},
}
@ -451,18 +451,18 @@ func TestSplitQueryStringColumn(t *testing.T) {
}
want := []proto.BoundQuery{
{
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
BindVariables: map[string]interface{}{endBindVarName: hexToByteUInt64(0x55555555)[4:]},
},
{
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
BindVariables: map[string]interface{}{
startBindVarName: hexToByteUInt64(0x55555555)[4:],
endBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:],
},
},
{
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
BindVariables: map[string]interface{}{startBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:]},
},
}

Просмотреть файл

@ -679,6 +679,9 @@ func getColumnType(qre *QueryExecutor, columnName, tableName string) (int64, err
return mproto.VT_NULL, err
}
defer conn.Recycle()
// TODO(shengzhe): use AST to represent the query to avoid sql injection.
// current code is safe because QuerySplitter.validateQuery is called before
// calling this.
query := fmt.Sprintf("SELECT %v FROM %v LIMIT 0", columnName, tableName)
result, err := qre.execSQL(conn, query, true)
if err != nil {
@ -696,6 +699,9 @@ func getColumnMinMax(qre *QueryExecutor, columnName, tableName string) (*mproto.
return nil, err
}
defer conn.Recycle()
// TODO(shengzhe): use AST to represent the query to avoid sql injection.
// current code is safe because QuerySplitter.validateQuery is called before
// calling this.
minMaxSQL := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", columnName, columnName, tableName)
return qre.execSQL(conn, minMaxSQL, true)
}

Просмотреть файл

@ -1,6 +1,7 @@
package com.youtube.vitess.vtgate.integration;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.primitives.UnsignedLong;
import com.youtube.vitess.vtgate.BindVariable;
@ -303,13 +304,16 @@ public class VtGateIT {
public void testSplitQueryMultipleSplitsPerShard() throws Exception {
int rowCount = 30;
Util.insertRows(testEnv, 1, 30);
List<String> expectedSqls =
Lists.newArrayList("select id, keyspace_id from vtgate_test where id < 10",
"select id, keyspace_id from vtgate_test where id < 11",
"select id, keyspace_id from vtgate_test where id >= 10 and id < 19",
"select id, keyspace_id from vtgate_test where id >= 11 and id < 19",
"select id, keyspace_id from vtgate_test where id >= 19",
"select id, keyspace_id from vtgate_test where id >= 19");
Map<String, List<BindVariable>> expectedSqls = Maps.newHashMap();
expectedSqls.put("select id, keyspace_id from vtgate_test where id < :_splitquery_end",
Lists.newArrayList(BindVariable.forInt("_splitquery_end", 10)));
expectedSqls.put(
"select id, keyspace_id from vtgate_test where id >= :_splitquery_start "
+ "and id < :_splitquery_end",
Lists.newArrayList(BindVariable.forInt("_splitquery_start", 10),
BindVariable.forInt("_splitquery_end", 19)));
expectedSqls.put("select id, keyspace_id from vtgate_test where id >= :_splitquery_start",
Lists.newArrayList(BindVariable.forInt("_splitquery_start", 19)));
Util.waitForTablet("rdonly", rowCount, 3, testEnv);
VtGate vtgate = VtGate.connect("localhost:" + testEnv.getPort(), 0, testEnv.getRpcClientFactory());
int splitCount = 6;
@ -322,11 +326,11 @@ public class VtGateIT {
Set<String> shardsInSplits = new HashSet<>();
for (Query q : queries.keySet()) {
String sql = q.getSql();
Assert.assertTrue(expectedSqls.contains(sql));
expectedSqls.remove(sql);
List<BindVariable> bindVars = expectedSqls.get(sql);
Assert.assertNotNull(bindVars);
Assert.assertEquals("test_keyspace", q.getKeyspace());
Assert.assertEquals("rdonly", q.getTabletType());
Assert.assertEquals(0, q.getBindVars().size());
Assert.assertEquals(bindVars.size(), q.getBindVars().size());
Assert.assertEquals(null, q.getKeyspaceIds());
String start = Hex.encodeHexString(q.getKeyRanges().get(0).get("Start"));
String end = Hex.encodeHexString(q.getKeyRanges().get(0).get("End"));
@ -335,7 +339,6 @@ public class VtGateIT {
// Verify the keyrange queries in splits cover the entire keyspace
Assert.assertTrue(shardsInSplits.containsAll(testEnv.getShardKidMap().keySet()));
Assert.assertTrue(expectedSqls.size() == 0);
}
@Test

Просмотреть файл

@ -193,7 +193,8 @@ primary key (id)
rows = {}
for q in s:
qr = utils.vtgate.execute_shard(q['QueryShard']['Sql'],
'test_keyspace', ",".join(q['QueryShard']['Shards']))
'test_keyspace', ",".join(q['QueryShard']['Shards']),
tablet_type='master', bindvars=q['QueryShard']['BindVariables'])
for r in qr['Rows']:
id = int(r[0])
rows[id] = r[1]