зеркало из https://github.com/github/vitess-gh.git
fix splitquery related tests.
1. bind variables in SplitQuery should not contain leading ':'. 2. fix custom_sharding.py test to pass through returned bind variables from VTGate. 3. fix SplitQuery tests in Java.
This commit is contained in:
Родитель
0c03e61532
Коммит
3736372fee
|
@ -28,8 +28,8 @@ type QuerySplitter struct {
|
|||
}
|
||||
|
||||
const (
|
||||
startBindVarName = ":_splitquery_start"
|
||||
endBindVarName = ":_splitquery_end"
|
||||
startBindVarName = "_splitquery_start"
|
||||
endBindVarName = "_splitquery_end"
|
||||
)
|
||||
|
||||
// NewQuerySplitter creates a new QuerySplitter. query is the original query
|
||||
|
@ -157,7 +157,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m
|
|||
startClause = &sqlparser.ComparisonExpr{
|
||||
Operator: sqlparser.AST_GE,
|
||||
Left: pk,
|
||||
Right: sqlparser.ValArg([]byte(startBindVarName)),
|
||||
Right: sqlparser.ValArg([]byte(":" + startBindVarName)),
|
||||
}
|
||||
if start.IsNumeric() {
|
||||
v, _ := start.ParseInt64()
|
||||
|
@ -174,7 +174,7 @@ func (qs *QuerySplitter) getWhereClause(whereClause *sqlparser.Where, bindVars m
|
|||
endClause = &sqlparser.ComparisonExpr{
|
||||
Operator: sqlparser.AST_LT,
|
||||
Left: pk,
|
||||
Right: sqlparser.ValArg([]byte(endBindVarName)),
|
||||
Right: sqlparser.ValArg([]byte(":" + endBindVarName)),
|
||||
}
|
||||
if end.IsNumeric() {
|
||||
v, _ := end.ParseInt64()
|
||||
|
|
|
@ -165,7 +165,7 @@ func TestGetWhereClause(t *testing.T) {
|
|||
bindVars = make(map[string]interface{})
|
||||
bindVars[":count"] = 300
|
||||
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, nilValue)
|
||||
want = " where (count > :count) and (id >= " + startBindVarName + ")"
|
||||
want = " where (count > :count) and (id >= :" + startBindVarName + ")"
|
||||
got = sqlparser.String(clause)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
|
||||
|
@ -182,7 +182,7 @@ func TestGetWhereClause(t *testing.T) {
|
|||
end, _ := sqltypes.BuildValue(endVal)
|
||||
bindVars = make(map[string]interface{})
|
||||
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, nilValue, end)
|
||||
want = " where (count > :count) and (id < " + endBindVarName + ")"
|
||||
want = " where (count > :count) and (id < :" + endBindVarName + ")"
|
||||
got = sqlparser.String(clause)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
|
||||
|
@ -198,7 +198,7 @@ func TestGetWhereClause(t *testing.T) {
|
|||
// Set both bounds, should add two conditions to where clause
|
||||
bindVars = make(map[string]interface{})
|
||||
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
|
||||
want = fmt.Sprintf(" where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName)
|
||||
want = fmt.Sprintf(" where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName)
|
||||
got = sqlparser.String(clause)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
|
||||
|
@ -219,7 +219,7 @@ func TestGetWhereClause(t *testing.T) {
|
|||
bindVars = make(map[string]interface{})
|
||||
// Set both bounds, should add two conditions to where clause
|
||||
clause = splitter.getWhereClause(splitter.sel.Where, bindVars, start, end)
|
||||
want = fmt.Sprintf(" where id >= %s and id < %s", startBindVarName, endBindVarName)
|
||||
want = fmt.Sprintf(" where id >= :%s and id < :%s", startBindVarName, endBindVarName)
|
||||
got = sqlparser.String(clause)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("incorrect where clause, got:%v, want:%v", got, want)
|
||||
|
@ -355,18 +355,18 @@ func TestSplitQuery(t *testing.T) {
|
|||
}
|
||||
want := []proto.BoundQuery{
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{endBindVarName: int64(100)},
|
||||
},
|
||||
{
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
|
||||
BindVariables: map[string]interface{}{
|
||||
startBindVarName: int64(100),
|
||||
endBindVarName: int64(200),
|
||||
},
|
||||
},
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{startBindVarName: int64(200)},
|
||||
},
|
||||
}
|
||||
|
@ -411,18 +411,18 @@ func TestSplitQueryFractionalColumn(t *testing.T) {
|
|||
}
|
||||
want := []proto.BoundQuery{
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{endBindVarName: 170.5},
|
||||
},
|
||||
{
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
|
||||
BindVariables: map[string]interface{}{
|
||||
startBindVarName: 170.5,
|
||||
endBindVarName: 330.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{startBindVarName: 330.5},
|
||||
},
|
||||
}
|
||||
|
@ -451,18 +451,18 @@ func TestSplitQueryStringColumn(t *testing.T) {
|
|||
}
|
||||
want := []proto.BoundQuery{
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id < " + endBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id < :" + endBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{endBindVarName: hexToByteUInt64(0x55555555)[4:]},
|
||||
},
|
||||
{
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= %s and id < %s)", startBindVarName, endBindVarName),
|
||||
Sql: fmt.Sprintf("select * from test_table where (count > :count) and (id >= :%s and id < :%s)", startBindVarName, endBindVarName),
|
||||
BindVariables: map[string]interface{}{
|
||||
startBindVarName: hexToByteUInt64(0x55555555)[4:],
|
||||
endBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:],
|
||||
},
|
||||
},
|
||||
{
|
||||
Sql: "select * from test_table where (count > :count) and (id >= " + startBindVarName + ")",
|
||||
Sql: "select * from test_table where (count > :count) and (id >= :" + startBindVarName + ")",
|
||||
BindVariables: map[string]interface{}{startBindVarName: hexToByteUInt64(0xAAAAAAAA)[4:]},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -679,6 +679,9 @@ func getColumnType(qre *QueryExecutor, columnName, tableName string) (int64, err
|
|||
return mproto.VT_NULL, err
|
||||
}
|
||||
defer conn.Recycle()
|
||||
// TODO(shengzhe): use AST to represent the query to avoid sql injection.
|
||||
// current code is safe because QuerySplitter.validateQuery is called before
|
||||
// calling this.
|
||||
query := fmt.Sprintf("SELECT %v FROM %v LIMIT 0", columnName, tableName)
|
||||
result, err := qre.execSQL(conn, query, true)
|
||||
if err != nil {
|
||||
|
@ -696,6 +699,9 @@ func getColumnMinMax(qre *QueryExecutor, columnName, tableName string) (*mproto.
|
|||
return nil, err
|
||||
}
|
||||
defer conn.Recycle()
|
||||
// TODO(shengzhe): use AST to represent the query to avoid sql injection.
|
||||
// current code is safe because QuerySplitter.validateQuery is called before
|
||||
// calling this.
|
||||
minMaxSQL := fmt.Sprintf("SELECT MIN(%v), MAX(%v) FROM %v", columnName, columnName, tableName)
|
||||
return qre.execSQL(conn, minMaxSQL, true)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package com.youtube.vitess.vtgate.integration;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.primitives.UnsignedLong;
|
||||
|
||||
import com.youtube.vitess.vtgate.BindVariable;
|
||||
|
@ -303,13 +304,16 @@ public class VtGateIT {
|
|||
public void testSplitQueryMultipleSplitsPerShard() throws Exception {
|
||||
int rowCount = 30;
|
||||
Util.insertRows(testEnv, 1, 30);
|
||||
List<String> expectedSqls =
|
||||
Lists.newArrayList("select id, keyspace_id from vtgate_test where id < 10",
|
||||
"select id, keyspace_id from vtgate_test where id < 11",
|
||||
"select id, keyspace_id from vtgate_test where id >= 10 and id < 19",
|
||||
"select id, keyspace_id from vtgate_test where id >= 11 and id < 19",
|
||||
"select id, keyspace_id from vtgate_test where id >= 19",
|
||||
"select id, keyspace_id from vtgate_test where id >= 19");
|
||||
Map<String, List<BindVariable>> expectedSqls = Maps.newHashMap();
|
||||
expectedSqls.put("select id, keyspace_id from vtgate_test where id < :_splitquery_end",
|
||||
Lists.newArrayList(BindVariable.forInt("_splitquery_end", 10)));
|
||||
expectedSqls.put(
|
||||
"select id, keyspace_id from vtgate_test where id >= :_splitquery_start "
|
||||
+ "and id < :_splitquery_end",
|
||||
Lists.newArrayList(BindVariable.forInt("_splitquery_start", 10),
|
||||
BindVariable.forInt("_splitquery_end", 19)));
|
||||
expectedSqls.put("select id, keyspace_id from vtgate_test where id >= :_splitquery_start",
|
||||
Lists.newArrayList(BindVariable.forInt("_splitquery_start", 19)));
|
||||
Util.waitForTablet("rdonly", rowCount, 3, testEnv);
|
||||
VtGate vtgate = VtGate.connect("localhost:" + testEnv.getPort(), 0, testEnv.getRpcClientFactory());
|
||||
int splitCount = 6;
|
||||
|
@ -322,11 +326,11 @@ public class VtGateIT {
|
|||
Set<String> shardsInSplits = new HashSet<>();
|
||||
for (Query q : queries.keySet()) {
|
||||
String sql = q.getSql();
|
||||
Assert.assertTrue(expectedSqls.contains(sql));
|
||||
expectedSqls.remove(sql);
|
||||
List<BindVariable> bindVars = expectedSqls.get(sql);
|
||||
Assert.assertNotNull(bindVars);
|
||||
Assert.assertEquals("test_keyspace", q.getKeyspace());
|
||||
Assert.assertEquals("rdonly", q.getTabletType());
|
||||
Assert.assertEquals(0, q.getBindVars().size());
|
||||
Assert.assertEquals(bindVars.size(), q.getBindVars().size());
|
||||
Assert.assertEquals(null, q.getKeyspaceIds());
|
||||
String start = Hex.encodeHexString(q.getKeyRanges().get(0).get("Start"));
|
||||
String end = Hex.encodeHexString(q.getKeyRanges().get(0).get("End"));
|
||||
|
@ -335,7 +339,6 @@ public class VtGateIT {
|
|||
|
||||
// Verify the keyrange queries in splits cover the entire keyspace
|
||||
Assert.assertTrue(shardsInSplits.containsAll(testEnv.getShardKidMap().keySet()));
|
||||
Assert.assertTrue(expectedSqls.size() == 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -193,7 +193,8 @@ primary key (id)
|
|||
rows = {}
|
||||
for q in s:
|
||||
qr = utils.vtgate.execute_shard(q['QueryShard']['Sql'],
|
||||
'test_keyspace', ",".join(q['QueryShard']['Shards']))
|
||||
'test_keyspace', ",".join(q['QueryShard']['Shards']),
|
||||
tablet_type='master', bindvars=q['QueryShard']['BindVariables'])
|
||||
for r in qr['Rows']:
|
||||
id = int(r[0])
|
||||
rows[id] = r[1]
|
||||
|
|
Загрузка…
Ссылка в новой задаче