[release-15.0] Make sure to not push down expressions when not possible (#12607) (#12647)

* [gen4 planner] Make sure to not push down expressions when not possible (#12607)

* Fix random aggregation to not select Null column
* stop pushing down projections that should be evaluated at the vtgate level
* undo changes to AggregateRandom
* clean up code
* fix executor test mock

Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
Signed-off-by: Andres Taylor <andres@planetscale.com>

* Fix schema error

Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>

---------

Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
Signed-off-by: Andres Taylor <andres@planetscale.com>
Co-authored-by: Andres Taylor <andres@planetscale.com>
This commit is contained in:
Florent Poinsard 2023-03-21 15:28:03 +02:00 коммит произвёл GitHub
Родитель f0cfda7983
Коммит f56e64a5a6
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 222 добавлений и 22 удалений

Просмотреть файл

@ -33,7 +33,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
deleteAll := func() {
_, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp")
tables := []string{"t9", "aggr_test", "t3", "t7_xxhash", "aggr_test_dates", "t7_xxhash_idx", "t1", "t2"}
tables := []string{"t9", "aggr_test", "t3", "t7_xxhash", "aggr_test_dates", "t7_xxhash_idx", "t1", "t2", "t11"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
@ -427,3 +427,13 @@ func TestScalarAggregate(t *testing.T) {
mcmp.Exec("insert into aggr_test(id, val1, val2) values(1,'a',1), (2,'A',1), (3,'b',1), (4,'c',3), (5,'c',4)")
mcmp.AssertMatches("select /*vt+ PLANNER=gen4 */ count(distinct val1) from aggr_test", `[[INT64(3)]]`)
}
// TestAggregationRandomOnAnAggregatedValue inserts two rows into t11 and then
// runs a gen4-planned query that selects two aggregated columns from a derived
// table together with their quotient. It asserts that the single expected row
// comes back, without regard to row order.
func TestAggregationRandomOnAnAggregatedValue(t *testing.T) {
	compare, cleanup := start(t)
	defer cleanup()

	const (
		insertRows = "insert into t11(k, a, b) values (0, 100, 10), (10, 200, 20);"
		query      = "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from t11 where a = 100) A;"
		want       = `[[DECIMAL(100) DECIMAL(10) DECIMAL(10.0000)]]`
	)

	compare.Exec(insertRows)
	compare.AssertMatchesNoOrder(query, want)
}

Просмотреть файл

@ -70,3 +70,8 @@ CREATE TABLE t2 (
PRIMARY KEY (id)
) ENGINE InnoDB;
CREATE TABLE t11 (
k BIGINT PRIMARY KEY,
a INT,
b INT
);

Просмотреть файл

@ -123,6 +123,14 @@
"name": "hash"
}
]
},
"t11": {
"column_vindexes": [
{
"column": "k",
"name": "hash"
}
]
}
}
}

Просмотреть файл

@ -3744,6 +3744,40 @@ func TestSelectHexAndBit(t *testing.T) {
require.Equal(t, `[[UINT64(10) UINT64(10) UINT64(10) UINT64(10)]]`, fmt.Sprintf("%v", qr.Rows))
}
// TestSelectAggregationRandom executes a gen4-planned query that selects two
// aggregated columns from a derived table, plus their quotient, against eight
// sandboxed shards. Every shard is first primed with a "null|null" row, after
// which the first shard is re-primed with "10|1". The test expects the result
// to be [[INT64(10) INT64(1) DECIMAL(10.0000)]].
func TestSelectAggregationRandom(t *testing.T) {
	cell := "aa"
	hc := discovery.NewFakeHealthCheck(nil)
	createSandbox(KsTestSharded).VSchema = executorVSchema
	getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
	serv := newSandboxForCells([]string{cell})
	resolver := newTestResolver(hc, serv, cell)

	// singleRow builds a fresh one-row result with the two int64 columns a and b.
	singleRow := func(row string) []*sqltypes.Result {
		return []*sqltypes.Result{sqltypes.MakeTestResult(
			sqltypes.MakeTestFields("a|b", "int64|int64"),
			row,
		)}
	}

	shardNames := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}
	var tablets []*sandboxconn.SandboxConn
	for _, name := range shardNames {
		tablet := hc.AddTestTablet(cell, name, 1, KsTestSharded, name, topodatapb.TabletType_PRIMARY, true, 1, nil)
		tablets = append(tablets, tablet)
		tablet.SetResults(singleRow("null|null"))
	}
	// Only the first shard answers with actual values.
	tablets[0].SetResults(singleRow("10|1"))

	executor := createExecutor(serv, cell, resolver)
	executor.pv = querypb.ExecuteOptions_Gen4
	session := NewAutocommitSession(&vtgatepb.Session{})

	const query = "select /*vt+ PLANNER=gen4 */ A.a, A.b, (A.a / A.b) as c from (select sum(a) as a, sum(b) as b from user) A"
	result, err := executor.Execute(context.Background(), "TestSelectCFC", session, query, nil)
	require.NoError(t, err)

	got := fmt.Sprintf("%v", result.Rows)
	assert.Equal(t, `[[INT64(10) INT64(1) DECIMAL(10.0000)]]`, got)
}
func TestMain(m *testing.M) {
_flag.ParseFlagsForTest()
os.Exit(m.Run())

Просмотреть файл

@ -22,6 +22,8 @@ import (
"strings"
"vitess.io/vitess/go/vt/vtgate/engine"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
"vitess.io/vitess/go/vt/vtgate/semantics"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
@ -575,6 +577,85 @@ func (qp *QueryProjection) GetColumnCount() int {
return len(qp.SelectExprs) - qp.AddedColumn
}
// NeedsProjecting returns true if we have projections that need to be evaluated at the vtgate level
// and can't be pushed down to MySQL.
//
// pusher is called for every expression (or column inside an expression) that has to be provided
// by the input; it returns the offset at which the value can be read.
//
// On success, expressions and colNames hold one entry per select expression, in order:
//   - a plain column, or an expression that did not have to be rewritten, is pushed as a whole
//     and represented by an offset into the input;
//   - an expression referencing derived table columns that resolve to aggregations is kept for
//     vtgate evaluation, with those columns replaced by offsets, and needsVtGateEval is set.
//
// On error, the zero value is returned for all other results.
func (qp *QueryProjection) NeedsProjecting(
	ctx *plancontext.PlanningContext,
	pusher func(expr *sqlparser.AliasedExpr) (int, error),
) (needsVtGateEval bool, expressions []sqlparser.Expr, colNames []string, err error) {
	for _, se := range qp.SelectExprs {
		var ae *sqlparser.AliasedExpr
		ae, err = se.GetAliasedExpr()
		if err != nil {
			return false, nil, nil, err
		}

		expr := ae.Expr
		colNames = append(colNames, ae.ColumnName())

		if _, isCol := expr.(*sqlparser.ColName); isCol {
			var offset int
			offset, err = pusher(ae)
			if err != nil {
				return false, nil, nil, err
			}
			expressions = append(expressions, sqlparser.NewOffset(offset, expr))
			continue
		}

		// rewriteErr keeps the first failure seen while rewriting this expression. Returning false
		// from the visitor only skips the children of the current node, so the visitor has to check
		// it explicitly; otherwise a later column could silently overwrite an earlier error, and
		// more columns would be pushed after a failure.
		var rewriteErr error
		rExpr := sqlparser.Rewrite(sqlparser.CloneExpr(expr), func(cursor *sqlparser.Cursor) bool {
			if rewriteErr != nil {
				return false
			}
			col, isCol := cursor.Node().(*sqlparser.ColName)
			if !isCol {
				return true
			}
			tableInfo, tErr := ctx.SemTable.TableInfoForExpr(col)
			if tErr != nil {
				rewriteErr = tErr
				return false
			}
			if _, isDT := tableInfo.(*semantics.DerivedTable); !isDT {
				return true
			}
			rewritten, tErr := semantics.RewriteDerivedTableExpression(col, tableInfo)
			if tErr != nil {
				rewriteErr = tErr
				return false
			}

			if sqlparser.ContainsAggregation(rewritten) {
				// the column is an aggregation inside the derived table: fetch it from the input
				// and let the surrounding expression be evaluated on top of it
				offset, pErr := pusher(&sqlparser.AliasedExpr{Expr: col})
				if pErr != nil {
					rewriteErr = pErr
					return false
				}
				cursor.Replace(sqlparser.NewOffset(offset, col))
			}
			return true
		}, nil).(sqlparser.Expr)
		if rewriteErr != nil {
			return false, nil, nil, rewriteErr
		}

		if !sqlparser.EqualsExpr(rExpr, expr) {
			// if we changed the expression, it means that we have to evaluate the rest at the vtgate level
			expressions = append(expressions, rExpr)
			needsVtGateEval = true
			continue
		}

		// we did not need to push any parts of this expression down. Let's check if we can push all of it
		var offset int
		offset, err = pusher(ae)
		if err != nil {
			return false, nil, nil, err
		}
		expressions = append(expressions, sqlparser.NewOffset(offset, expr))
	}

	return needsVtGateEval, expressions, colNames, nil
}
func checkForInvalidGroupingExpressions(expr sqlparser.Expr) error {
return sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
if _, isAggregate := node.(sqlparser.AggrFunc); isAggregate {

Просмотреть файл

@ -222,13 +222,13 @@ func newBuildSelectPlan(
return nil, nil, err
}
plan = optimizePlan(plan)
plan, err = planHorizon(ctx, plan, selStmt, true)
if err != nil {
return nil, nil, err
}
optimizePlan(plan)
sel, isSel := selStmt.(*sqlparser.Select)
if isSel {
if err := setMiscFunc(plan, sel); err != nil {
@ -249,25 +249,25 @@ func newBuildSelectPlan(
}
// optimizePlan removes unnecessary simpleProjections that have been created while planning
func optimizePlan(plan logicalPlan) logicalPlan {
newPlan, _ := visit(plan, func(plan logicalPlan) (bool, logicalPlan, error) {
this, ok := plan.(*simpleProjection)
if !ok {
return true, plan, nil
}
func optimizePlan(plan logicalPlan) {
for _, lp := range plan.Inputs() {
optimizePlan(lp)
}
input, ok := this.input.(*simpleProjection)
if !ok {
return true, plan, nil
}
this, ok := plan.(*simpleProjection)
if !ok {
return
}
for i, col := range this.eSimpleProj.Cols {
this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col]
}
this.input = input.input
return true, this, nil
})
return newPlan
input, ok := this.input.(*simpleProjection)
if !ok {
return
}
for i, col := range this.eSimpleProj.Cols {
this.eSimpleProj.Cols[i] = input.eSimpleProj.Cols[col]
}
this.input = input.input
}
func gen4UpdateStmtPlanner(

Просмотреть файл

@ -59,7 +59,8 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
// a simpleProjection. We create a new Route that contains the derived table in the
// FROM clause. Meaning that, when we push expressions to the select list of this
// new Route, we do not want them to be rewritten.
if _, isSimpleProj := plan.(*simpleProjection); isSimpleProj {
sp, derivedTable := plan.(*simpleProjection)
if derivedTable {
oldRewriteDerivedExpr := ctx.RewriteDerivedExpr
defer func() {
ctx.RewriteDerivedExpr = oldRewriteDerivedExpr
@ -74,10 +75,11 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
}
needsOrdering := len(hp.qp.OrderExprs) > 0
canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering
// If we still have a HAVING clause, it's because it could not be pushed to the WHERE,
// so it probably has aggregations
canShortcut := isRoute && hp.sel.Having == nil && !needsOrdering
switch {
case hp.qp.NeedsAggregation() || hp.sel.Having != nil:
plan, err = hp.planAggregations(ctx, plan)
@ -91,6 +93,26 @@ func (hp *horizonPlanning) planHorizon(ctx *plancontext.PlanningContext, plan lo
if err != nil {
return nil, err
}
case derivedTable:
pusher := func(ae *sqlparser.AliasedExpr) (int, error) {
offset, _, err := pushProjection(ctx, ae, sp.input, true, true, false)
return offset, err
}
needsVtGate, projections, colNames, err := hp.qp.NeedsProjecting(ctx, pusher)
if err != nil {
return nil, err
}
if !needsVtGate {
break
}
// there were some expressions we could not push down entirely,
// so replace the simpleProjection with a real projection
plan = &projection{
source: sp.input,
columns: projections,
columnNames: colNames,
}
default:
err = pushProjections(ctx, plan, hp.qp.SelectExprs)
if err != nil {

Просмотреть файл

@ -4956,5 +4956,45 @@
"user.user_extra"
]
}
},
{
"comment": "Aggregations from derived table used in arithmetic outside derived table",
"query": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A",
"v3-plan": "unsupported: expression on results of a cross-shard subquery",
"gen4-plan": {
"QueryType": "SELECT",
"Original": "select A.a, A.b, (A.a / A.b) as d from (select sum(a) as a, sum(b) as b from user) A",
"Instructions": {
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] as a",
"[COLUMN 1] as b",
"[COLUMN 0] / [COLUMN 1] as d"
],
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Scalar",
"Aggregates": "sum(0) AS a, sum(1) AS b",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select sum(a) as a, sum(b) as b from `user` where 1 != 1",
"Query": "select sum(a) as a, sum(b) as b from `user`",
"Table": "`user`"
}
]
}
]
},
"TablesUsed": [
"user.user"
]
}
}
]