Fix scalar aggregation engine primitive for column truncation (#12468) (#12472)

* fix: scalar aggregation truncation



* test: added scalar aggr engine unit test



* remove onecase change



---------

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
Co-authored-by: Harshit Gangal <harshit@planetscale.com>
This commit is contained in:
Manan Gupta 2023-02-27 12:23:02 +05:30 коммит произвёл GitHub
Родитель 911f246149
Коммит 7a594612dc
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 66 добавлений и 8 удалений

Просмотреть файл

@ -400,3 +400,30 @@ func TestAggregateLeftJoin(t *testing.T) {
mcmp.AssertMatches("SELECT count(*) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[INT64(2)]]`)
mcmp.AssertMatches("SELECT sum(t1.shardkey) FROM t1 LEFT JOIN t2 ON t1.t1_id = t2.id", `[[DECIMAL(1)]]`)
}
// TestScalarAggregate tests validates that only count is returned and no additional field is returned.gst
func TestScalarAggregate(t *testing.T) {
// disable schema tracking to have weight_string column added to query send down to mysql.
clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--schema_change_signal=false")
require.NoError(t,
clusterInstance.RestartVtgate())
// update vtgate params
vtParams = clusterInstance.GetVTParams(keyspaceName)
defer func() {
// roll it back
clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--schema_change_signal")
require.NoError(t,
clusterInstance.RestartVtgate())
// update vtgate params
vtParams = clusterInstance.GetVTParams(keyspaceName)
}()
mcmp, closer := start(t)
defer closer()
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`)
}

Просмотреть файл

@ -122,7 +122,7 @@ func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bind
}
out.Rows = [][]sqltypes.Value{resultRow}
return out, nil
return out.Truncate(sa.TruncateColumnCount), nil
}
// TryStreamExecute implements the Primitive interface

Просмотреть файл

@ -106,16 +106,16 @@ func TestEmptyRows(outer *testing.T) {
func TestScalarAggregateStreamExecute(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"count(*)",
"uint64",
"col|weight_string(col)",
"uint64|varbinary",
)
fp := &fakePrimitive{
allResultsInOneCall: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(fields,
"1",
"1|null",
), sqltypes.MakeTestResult(fields,
"3",
"3|null",
)},
}
@ -141,3 +141,34 @@ func TestScalarAggregateStreamExecute(t *testing.T) {
got := fmt.Sprintf("%v", results[1].Rows)
assert.Equal("[[UINT64(4)]]", got)
}
// TestScalarAggregateExecuteTruncate checks if truncate works
func TestScalarAggregateExecuteTruncate(t *testing.T) {
assert := assert.New(t)
fields := sqltypes.MakeTestFields(
"col|weight_string(col)",
"uint64|varbinary",
)
fp := &fakePrimitive{
allResultsInOneCall: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(fields,
"1|null", "3|null",
)},
}
oa := &ScalarAggregate{
Aggregates: []*AggregateParams{{
Opcode: AggregateSum,
Col: 0,
}},
Input: fp,
TruncateColumnCount: 1,
PreProcess: true,
}
qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, true)
assert.NoError(err)
assert.Equal("[[UINT64(4)]]", fmt.Sprintf("%v", qr.Rows))
}

Просмотреть файл

@ -3614,7 +3614,7 @@ func TestSelectAggregationData(t *testing.T) {
}{
{
sql: `select count(distinct col) from user`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col", "int64"), "1", "2", "2", "3"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|weight_string(col)", "int64|varbinary"), "1|NULL", "2|NULL", "2|NULL", "3|NULL"),
expSandboxQ: "select col, weight_string(col) from `user` group by col, weight_string(col) order by col asc",
expField: `[name:"count(distinct col)" type:INT64]`,
expRow: `[[INT64(3)]]`,
@ -3628,14 +3628,14 @@ func TestSelectAggregationData(t *testing.T) {
},
{
sql: `select col, count(*) from user group by col`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)", "int64|int64"), "1|3"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)|weight_string(col)", "int64|int64|varbinary"), "1|3|NULL"),
expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc",
expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(1) INT64(24)]]`,
},
{
sql: `select col, count(*) from user group by col limit 2`,
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)", "int64|int64"), "1|2", "2|1", "3|4"),
sandboxRes: sqltypes.MakeTestResult(sqltypes.MakeTestFields("col|count(*)|weight_string(col)", "int64|int64|varbinary"), "1|2|NULL", "2|1|NULL", "3|4|NULL"),
expSandboxQ: "select col, count(*), weight_string(col) from `user` group by col, weight_string(col) order by col asc limit :__upper_limit",
expField: `[name:"col" type:INT64 name:"count(*)" type:INT64]`,
expRow: `[[INT64(1) INT64(16)] [INT64(2) INT64(8)]]`,