From b54cc6270756b9bb6309d0a5b7fbee040d902914 Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Sun, 16 Jun 2019 21:13:56 -0700 Subject: [PATCH] vtgate sql: more composable primitives Remove some tight coupling that existed between various primitives. In particular, orderedAggregate now points at builder instead of a route. Also, mergeSort takes on some of the work that route previously used to do. Boilerplate code has been moved to builderCommon and resultsBuilder. Introduced SupplyWeightString as a required function for all primitives. This is now used by all primitives that need to order by a text column. The end result: memorySort can now sort by text columns, and it can be on top of any primitive, like a join, subquery, etc. Signed-off-by: Sugu Sougoumarane --- go/vt/vtgate/engine/memory_sort.go | 20 ++- go/vt/vtgate/engine/memory_sort_test.go | 88 ++++++++++ go/vt/vtgate/engine/ordered_aggregate.go | 5 + go/vt/vtgate/engine/route.go | 5 + go/vt/vtgate/planbuilder/builder.go | 79 +++++++++ go/vt/vtgate/planbuilder/join.go | 31 +++- go/vt/vtgate/planbuilder/memory_sort.go | 43 +++-- go/vt/vtgate/planbuilder/merge_sort.go | 46 +++++- go/vt/vtgate/planbuilder/ordered_aggregate.go | 87 +++------- go/vt/vtgate/planbuilder/pullout_subquery.go | 5 + go/vt/vtgate/planbuilder/route.go | 67 +------- go/vt/vtgate/planbuilder/symtab.go | 22 +++ .../planbuilder/testdata/aggr_cases.txt | 61 ++++++- .../testdata/memory_sort_cases.txt | 151 ++++++++++++++++++ go/vt/vtgate/planbuilder/vindex_func.go | 5 + 15 files changed, 571 insertions(+), 144 deletions(-) diff --git a/go/vt/vtgate/engine/memory_sort.go b/go/vt/vtgate/engine/memory_sort.go index b81acd888c..ef05eaa31b 100644 --- a/go/vt/vtgate/engine/memory_sort.go +++ b/go/vt/vtgate/engine/memory_sort.go @@ -34,6 +34,11 @@ type MemorySort struct { UpperLimit sqltypes.PlanValue OrderBy []OrderbyParams Input Primitive + + // TruncateColumnCount specifies the number of columns to return + // in the final result. 
Rest of the columns are truncated + // from the result received. If 0, no truncation happens. + TruncateColumnCount int `json:",omitempty"` } // MarshalJSON serializes the MemorySort into a JSON representation. @@ -58,6 +63,11 @@ func (ms *MemorySort) RouteType() string { return ms.Input.RouteType() } +// SetTruncateColumnCount sets the truncate column count. +func (ms *MemorySort) SetTruncateColumnCount(count int) { + ms.TruncateColumnCount = count +} + // Execute satisfies the Primtive interface. func (ms *MemorySort) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { count, err := ms.fetchCount(bindVars) @@ -82,7 +92,7 @@ func (ms *MemorySort) Execute(vcursor VCursor, bindVars map[string]*querypb.Bind result.Rows = result.Rows[:count] result.RowsAffected = uint64(count) } - return result, nil + return result.Truncate(ms.TruncateColumnCount), nil } // StreamExecute satisfies the Primtive interface. @@ -92,6 +102,10 @@ func (ms *MemorySort) StreamExecute(vcursor VCursor, bindVars map[string]*queryp return err } + cb := func(qr *sqltypes.Result) error { + return callback(qr.Truncate(ms.TruncateColumnCount)) + } + // You have to reverse the ordering because the highest values // must be dropped once the upper limit is reached. sh := &sortHeap{ @@ -100,7 +114,7 @@ func (ms *MemorySort) StreamExecute(vcursor VCursor, bindVars map[string]*queryp } err = ms.Input.StreamExecute(vcursor, bindVars, wantfields, func(qr *sqltypes.Result) error { if len(qr.Fields) != 0 { - if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil { + if err := cb(&sqltypes.Result{Fields: qr.Fields}); err != nil { return err } } @@ -128,7 +142,7 @@ func (ms *MemorySort) StreamExecute(vcursor VCursor, bindVars map[string]*queryp // Unreachable. return sh.err } - return callback(&sqltypes.Result{Rows: sh.rows}) + return cb(&sqltypes.Result{Rows: sh.rows}) } // GetFields satisfies the Primtive interface. 
diff --git a/go/vt/vtgate/engine/memory_sort_test.go b/go/vt/vtgate/engine/memory_sort_test.go index 555c770458..a9732f4e23 100644 --- a/go/vt/vtgate/engine/memory_sort_test.go +++ b/go/vt/vtgate/engine/memory_sort_test.go @@ -184,6 +184,94 @@ func TestMemorySortGetFields(t *testing.T) { } } +func TestMemorySortExecuteTruncate(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|c3", + "varbinary|decimal|int64", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "a|1|1", + "b|2|1", + "a|1|1", + "c|4|1", + "c|3|1", + )}, + } + + ms := &MemorySort{ + OrderBy: []OrderbyParams{{ + Col: 1, + }}, + Input: fp, + TruncateColumnCount: 2, + } + + result, err := ms.Execute(nil, nil, false) + if err != nil { + t.Fatal(err) + } + + wantResult := sqltypes.MakeTestResult( + fields[:2], + "a|1", + "a|1", + "b|2", + "c|3", + "c|4", + ) + if !reflect.DeepEqual(result, wantResult) { + t.Errorf("oa.Execute:\n%v, want\n%v", result, wantResult) + } +} + +func TestMemorySortStreamExecuteTruncate(t *testing.T) { + fields := sqltypes.MakeTestFields( + "c1|c2|c3", + "varbinary|decimal|int64", + ) + fp := &fakePrimitive{ + results: []*sqltypes.Result{sqltypes.MakeTestResult( + fields, + "a|1|1", + "b|2|1", + "a|1|1", + "c|4|1", + "c|3|1", + )}, + } + + ms := &MemorySort{ + OrderBy: []OrderbyParams{{ + Col: 1, + }}, + Input: fp, + TruncateColumnCount: 2, + } + + var results []*sqltypes.Result + err := ms.StreamExecute(noopVCursor{}, nil, false, func(qr *sqltypes.Result) error { + results = append(results, qr) + return nil + }) + if err != nil { + t.Fatal(err) + } + + wantResults := sqltypes.MakeTestStreamingResults( + fields[:2], + "a|1", + "a|1", + "b|2", + "c|3", + "c|4", + ) + if !reflect.DeepEqual(results, wantResults) { + t.Errorf("oa.Execute:\n%v, want\n%v", results, wantResults) + } +} + func TestMemorySortMultiColumn(t *testing.T) { fields := sqltypes.MakeTestFields( "c1|c2", diff --git a/go/vt/vtgate/engine/ordered_aggregate.go 
b/go/vt/vtgate/engine/ordered_aggregate.go index e5b71ab11f..d36f8dd911 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -121,6 +121,11 @@ func (oa *OrderedAggregate) RouteType() string { return oa.Input.RouteType() } +// SetTruncateColumnCount sets the truncate column count. +func (oa *OrderedAggregate) SetTruncateColumnCount(count int) { + oa.TruncateColumnCount = count +} + // Execute is a Primitive function. func (oa *OrderedAggregate) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { qr, err := oa.execute(vcursor, bindVars, wantfields) diff --git a/go/vt/vtgate/engine/route.go b/go/vt/vtgate/engine/route.go index f24dd6f2e4..c9fc30ba7a 100644 --- a/go/vt/vtgate/engine/route.go +++ b/go/vt/vtgate/engine/route.go @@ -201,6 +201,11 @@ func (route *Route) RouteType() string { return routeName[route.Opcode] } +// SetTruncateColumnCount sets the truncate column count. +func (route *Route) SetTruncateColumnCount(count int) { + route.TruncateColumnCount = count +} + // Execute performs a non-streaming exec. func (route *Route) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { if route.QueryTimeout != 0 { diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 9f4cca1454..1a5aeaa6bf 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -104,6 +104,10 @@ type builder interface { // result column and returns a distinct symbol for it. SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) + // SupplyWeightString must supply a weight_string expression of the + // specified column. + SupplyWeightString(colNumber int) (weightcolNumber int, err error) + // Primitive returns the underlying primitive. // This function should only be called after Wireup is finished. 
Primitive() engine.Primitive @@ -170,6 +174,81 @@ func (bc *builderCommon) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, co return bc.input.SupplyCol(col) } +func (bc *builderCommon) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + return bc.input.SupplyWeightString(colNumber) +} + +//------------------------------------------------------------------------- + +type truncater interface { + SetTruncateColumnCount(int) +} + +// resultsBuilder is a superset of builderCommon. It also handles +// resultsColumn functionality. +type resultsBuilder struct { + builderCommon + resultColumns []*resultColumn + weightStrings map[*resultColumn]int + truncater truncater +} + +func newResultsBuilder(input builder, truncater truncater) resultsBuilder { + return resultsBuilder{ + builderCommon: newBuilderCommon(input), + resultColumns: input.ResultColumns(), + weightStrings: make(map[*resultColumn]int), + truncater: truncater, + } +} + +func (rsb *resultsBuilder) ResultColumns() []*resultColumn { + return rsb.resultColumns +} + +// SupplyCol is currently unreachable because the builders using resultsBuilder +// are currently above a join, which is the only builder that uses it for now. +// This can change if we start supporting correlated subqueries. +func (rsb *resultsBuilder) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { + c := col.Metadata.(*column) + for i, rc := range rsb.resultColumns { + if rc.column == c { + return rc, i + } + } + rc, colNumber = rsb.input.SupplyCol(col) + if colNumber < len(rsb.resultColumns) { + return rc, colNumber + } + // Add result columns from input until colNumber is reached. 
+ for colNumber >= len(rsb.resultColumns) { + rsb.resultColumns = append(rsb.resultColumns, rsb.input.ResultColumns()[len(rsb.resultColumns)]) + } + rsb.truncater.SetTruncateColumnCount(len(rsb.resultColumns)) + return rc, colNumber +} + +func (rsb *resultsBuilder) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + rc := rsb.resultColumns[colNumber] + if weightcolNumber, ok := rsb.weightStrings[rc]; ok { + return weightcolNumber, nil + } + weightcolNumber, err = rsb.input.SupplyWeightString(colNumber) + if err != nil { + return 0, err + } + rsb.weightStrings[rc] = weightcolNumber + if weightcolNumber < len(rsb.resultColumns) { + return weightcolNumber, nil + } + // Add result columns from input until weightcolNumber is reached. + for weightcolNumber >= len(rsb.resultColumns) { + rsb.resultColumns = append(rsb.resultColumns, rsb.input.ResultColumns()[len(rsb.resultColumns)]) + } + rsb.truncater.SetTruncateColumnCount(len(rsb.resultColumns)) + return weightcolNumber, nil +} + //------------------------------------------------------------------------- // Build builds a plan for a query based on the specified vschema. diff --git a/go/vt/vtgate/planbuilder/join.go b/go/vt/vtgate/planbuilder/join.go index dd8ba16515..86a99d9357 100644 --- a/go/vt/vtgate/planbuilder/join.go +++ b/go/vt/vtgate/planbuilder/join.go @@ -31,6 +31,7 @@ var _ builder = (*join)(nil) type join struct { order int resultColumns []*resultColumn + weightStrings map[*resultColumn]int // leftOrder stores the order number of the left node. This is // used for a b-tree style traversal towards the target route. 
@@ -98,8 +99,9 @@ func newJoin(lpb, rpb *primitiveBuilder, ajoin *sqlparser.JoinTableExpr) error { } } lpb.bldr = &join{ - Left: lpb.bldr, - Right: rpb.bldr, + weightStrings: make(map[*resultColumn]int), + Left: lpb.bldr, + Right: rpb.bldr, ejoin: &engine.Join{ Opcode: opcode, Vars: make(map[string]int), @@ -337,6 +339,31 @@ func (jb *join) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber i return rc, len(jb.ejoin.Cols) - 1 } +// SupplyWeightString satisfies the builder interface. +func (jb *join) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + rc := jb.resultColumns[colNumber] + if weightcolNumber, ok := jb.weightStrings[rc]; ok { + return weightcolNumber, nil + } + routeNumber := rc.column.Origin().Order() + if jb.isOnLeft(routeNumber) { + sourceCol, err := jb.Left.SupplyWeightString(-jb.ejoin.Cols[colNumber] - 1) + if err != nil { + return 0, err + } + jb.ejoin.Cols = append(jb.ejoin.Cols, -sourceCol-1) + } else { + sourceCol, err := jb.Right.SupplyWeightString(jb.ejoin.Cols[colNumber] - 1) + if err != nil { + return 0, err + } + jb.ejoin.Cols = append(jb.ejoin.Cols, sourceCol+1) + } + jb.resultColumns = append(jb.resultColumns, rc) + jb.weightStrings[rc] = len(jb.ejoin.Cols) - 1 + return len(jb.ejoin.Cols) - 1, nil +} + // isOnLeft returns true if the specified route number // is on the left side of the join. If false, it means // the node is on the right. diff --git a/go/vt/vtgate/planbuilder/memory_sort.go b/go/vt/vtgate/planbuilder/memory_sort.go index b8b465d55b..8fed34d0ba 100644 --- a/go/vt/vtgate/planbuilder/memory_sort.go +++ b/go/vt/vtgate/planbuilder/memory_sort.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -32,17 +33,16 @@ var _ builder = (*memorySort)(nil) // operation. Since a limit is the final operation // of a SELECT, most pushes are not applicable. 
type memorySort struct { - builderCommon - resultColumns []*resultColumn - eMemorySort *engine.MemorySort + resultsBuilder + eMemorySort *engine.MemorySort } // newMemorySort builds a new memorySort. func newMemorySort(bldr builder, orderBy sqlparser.OrderBy) (*memorySort, error) { + eMemorySort := &engine.MemorySort{} ms := &memorySort{ - builderCommon: newBuilderCommon(bldr), - resultColumns: bldr.ResultColumns(), - eMemorySort: &engine.MemorySort{}, + resultsBuilder: newResultsBuilder(bldr, eMemorySort), + eMemorySort: eMemorySort, } for _, order := range orderBy { colNumber := -1 @@ -83,11 +83,6 @@ func (ms *memorySort) Primitive() engine.Primitive { return ms.eMemorySort } -// ResultColumns satisfies the builder interface. -func (ms *memorySort) ResultColumns() []*resultColumn { - return ms.resultColumns -} - // PushFilter satisfies the builder interface. func (ms *memorySort) PushFilter(_ *primitiveBuilder, _ sqlparser.Expr, whereType string, _ builder) error { return errors.New("memorySort.PushFilter: unreachable") @@ -118,6 +113,32 @@ func (ms *memorySort) SetLimit(limit *sqlparser.Limit) error { return errors.New("memorySort.Limit: unreachable") } +// Wireup satisfies the builder interface. +// If text columns are detected in the keys, then the function modifies +// the primitive to pull a corresponding weight_string from mysql and +// compare those instead. This is because we currently don't have the +// ability to mimic mysql's collation behavior. +func (ms *memorySort) Wireup(bldr builder, jt *jointab) error { + for i, orderby := range ms.eMemorySort.OrderBy { + rc := ms.resultColumns[orderby.Col] + if sqltypes.IsText(rc.column.typ) { + // If a weight string was previously requested, reuse it. 
+ if weightcolNumber, ok := ms.weightStrings[rc]; ok { + ms.eMemorySort.OrderBy[i].Col = weightcolNumber + continue + } + weightcolNumber, err := ms.input.SupplyWeightString(orderby.Col) + if err != nil { + return err + } + ms.weightStrings[rc] = weightcolNumber + ms.eMemorySort.OrderBy[i].Col = weightcolNumber + ms.eMemorySort.TruncateColumnCount = len(ms.resultColumns) + } + } + return ms.input.Wireup(bldr, jt) +} + // SetUpperLimit satisfies the builder interface. // This is a no-op because we actually call SetLimit for this primitive. // In the future, we may have to honor this call for subqueries. diff --git a/go/vt/vtgate/planbuilder/merge_sort.go b/go/vt/vtgate/planbuilder/merge_sort.go index 79cce8e8c6..0b42332104 100644 --- a/go/vt/vtgate/planbuilder/merge_sort.go +++ b/go/vt/vtgate/planbuilder/merge_sort.go @@ -19,6 +19,7 @@ package planbuilder import ( "errors" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" ) @@ -33,14 +34,24 @@ var _ builder = (*mergeSort)(nil) // Since ORDER BY happens near the end of the SQL processing, // most functions of this primitive are unreachable. type mergeSort struct { - builderCommon + resultsBuilder + truncateColumnCount int } // newMergeSort builds a new mergeSort. func newMergeSort(rb *route) *mergeSort { - return &mergeSort{ - builderCommon: newBuilderCommon(rb), + ms := &mergeSort{ + resultsBuilder: newResultsBuilder(rb, nil), } + ms.truncater = ms + return ms +} + +// SetTruncateColumnCount satisfies the truncater interface. +// This function records the truncate column count and sets +// it later on the eroute during wire-up phase. +func (ms *mergeSort) SetTruncateColumnCount(count int) { + ms.truncateColumnCount = count } // Primitive satisfies the builder interface. 
@@ -74,3 +85,32 @@ func (ms *mergeSort) PushGroupBy(groupBy sqlparser.GroupBy) error { func (ms *mergeSort) PushOrderBy(orderBy sqlparser.OrderBy) (builder, error) { return nil, errors.New("mergeSort.PushOrderBy: unreachable") } + +// Wireup satisfies the builder interface. +func (ms *mergeSort) Wireup(bldr builder, jt *jointab) error { + // If the route has to do the ordering, and if any columns are Text, + // we have to request the corresponding weight_string from mysql + // and use that value instead. This is because we cannot mimic + // mysql's collation behavior yet. + rb := ms.input.(*route) + rb.finalizeOptions() + ro := rb.routeOptions[0] + for i, orderby := range ro.eroute.OrderBy { + rc := ms.resultColumns[orderby.Col] + if sqltypes.IsText(rc.column.typ) { + // If a weight string was previously requested, reuse it. + if colNumber, ok := ms.weightStrings[rc]; ok { + ro.eroute.OrderBy[i].Col = colNumber + continue + } + var err error + ro.eroute.OrderBy[i].Col, err = rb.SupplyWeightString(orderby.Col) + if err != nil { + return err + } + ms.truncateColumnCount = len(ms.resultColumns) + } + } + ro.eroute.TruncateColumnCount = ms.truncateColumnCount + return ms.input.Wireup(bldr, jt) +} diff --git a/go/vt/vtgate/planbuilder/ordered_aggregate.go b/go/vt/vtgate/planbuilder/ordered_aggregate.go index 3513f16ca4..e1855ed24c 100644 --- a/go/vt/vtgate/planbuilder/ordered_aggregate.go +++ b/go/vt/vtgate/planbuilder/ordered_aggregate.go @@ -54,10 +54,8 @@ var _ builder = (*orderedAggregate)(nil) // Input: (Scatter Route with the order by request), // } type orderedAggregate struct { - resultColumns []*resultColumn - order int + resultsBuilder extraDistinct *sqlparser.ColName - input *route eaggr *engine.OrderedAggregate } @@ -132,9 +130,10 @@ func (pb *primitiveBuilder) checkAggregates(sel *sqlparser.Select) error { } // We need an aggregator primitive. 
+ eaggr := &engine.OrderedAggregate{} pb.bldr = &orderedAggregate{ - input: rb, - eaggr: &engine.OrderedAggregate{}, + resultsBuilder: newResultsBuilder(rb, eaggr), + eaggr: eaggr, } pb.bldr.Reorder(0) return nil @@ -238,33 +237,12 @@ func findAlias(colname *sqlparser.ColName, selects sqlparser.SelectExprs) sqlpar return nil } -// Order satisfies the builder interface. -func (oa *orderedAggregate) Order() int { - return oa.order -} - -// Reorder satisfies the builder interface. -func (oa *orderedAggregate) Reorder(order int) { - oa.input.Reorder(order) - oa.order = oa.input.Order() + 1 -} - // Primitive satisfies the builder interface. func (oa *orderedAggregate) Primitive() engine.Primitive { oa.eaggr.Input = oa.input.Primitive() return oa.eaggr } -// First satisfies the builder interface. -func (oa *orderedAggregate) First() builder { - return oa.input.First() -} - -// ResultColumns satisfies the builder interface. -func (oa *orderedAggregate) ResultColumns() []*resultColumn { - return oa.resultColumns -} - // PushFilter satisfies the builder interface. func (oa *orderedAggregate) PushFilter(_ *primitiveBuilder, _ sqlparser.Expr, whereType string, _ builder) error { return errors.New("unsupported: filtering on results of aggregates") @@ -318,7 +296,7 @@ func (oa *orderedAggregate) pushAggr(pb *primitiveBuilder, expr *sqlparser.Alias // Push the expression that's inside the aggregate. // The column will eventually get added to the group by and order by clauses. 
innerRC, innerCol, _ = oa.input.PushSelect(pb, innerAliased, origin) - col, err := oa.input.BuildColName(innerCol) + col, err := BuildColName(oa.input.ResultColumns(), innerCol) if err != nil { return nil, 0, err } @@ -370,7 +348,12 @@ func (oa *orderedAggregate) needDistinctHandling(pb *primitiveBuilder, funcExpr if !ok { return false, nil, fmt.Errorf("syntax error: %s", sqlparser.String(funcExpr)) } - success := oa.input.removeOptions(func(ro *routeOption) bool { + rb, ok := oa.input.(*route) + if !ok { + // Unreachable + return true, innerAliased, nil + } + success := rb.removeOptions(func(ro *routeOption) bool { vindex := ro.FindVindex(pb, innerAliased.Expr) if vindex != nil && vindex.IsUnique() { return true @@ -497,7 +480,7 @@ func (oa *orderedAggregate) PushOrderBy(orderBy sqlparser.OrderBy) (builder, err continue } // Build a brand new reference for the key. - col, err := oa.input.BuildColName(key) + col, err := BuildColName(oa.input.ResultColumns(), key) if err != nil { return nil, fmt.Errorf("generating order by clause: %v", err) } @@ -513,12 +496,11 @@ func (oa *orderedAggregate) PushOrderBy(orderBy sqlparser.OrderBy) (builder, err // It's ok to push the original AST down because all references // should point to the route. Only aggregate functions are originated // by oa, and we currently don't allow the ORDER BY to reference them. - // TODO(sougou): PushOrderBy will return a mergeSort primitive, which - // we should ideally replace oa.input with. - _, err := oa.input.PushOrderBy(selOrderBy) + bldr, err := oa.input.PushOrderBy(selOrderBy) if err != nil { return nil, err } + oa.input = bldr if postSort { return newMemorySort(oa, orderBy) } @@ -542,37 +524,20 @@ func (oa *orderedAggregate) PushMisc(sel *sqlparser.Select) { // ability to mimic mysql's collation behavior. 
func (oa *orderedAggregate) Wireup(bldr builder, jt *jointab) error { for i, colNumber := range oa.eaggr.Keys { - if sqltypes.IsText(oa.resultColumns[colNumber].column.typ) { - // len(oa.resultColumns) does not change. No harm using the value multiple times. + rc := oa.resultColumns[colNumber] + if sqltypes.IsText(rc.column.typ) { + if weightcolNumber, ok := oa.weightStrings[rc]; ok { + oa.eaggr.Keys[i] = weightcolNumber + continue + } + weightcolNumber, err := oa.input.SupplyWeightString(colNumber) + if err != nil { + return err + } + oa.weightStrings[rc] = weightcolNumber + oa.eaggr.Keys[i] = weightcolNumber oa.eaggr.TruncateColumnCount = len(oa.resultColumns) - oa.eaggr.Keys[i] = oa.input.SupplyWeightString(colNumber) } } return oa.input.Wireup(bldr, jt) } - -// SupplyVar satisfies the builder interface. -func (oa *orderedAggregate) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { - oa.input.SupplyVar(from, to, col, varname) -} - -// SupplyCol satisfies the builder interface. -// This function is unreachable. It's just a reference implementation for now. -func (oa *orderedAggregate) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { - c := col.Metadata.(*column) - for i, rc := range oa.resultColumns { - if rc.column == c { - return rc, i - } - } - rc, colNumber = oa.input.SupplyCol(col) - if colNumber < len(oa.resultColumns) { - return rc, colNumber - } - // Add result columns from input until colNumber is reached. 
- for colNumber >= len(oa.resultColumns) { - oa.resultColumns = append(oa.resultColumns, oa.input.ResultColumns()[len(oa.resultColumns)]) - } - oa.eaggr.TruncateColumnCount = len(oa.resultColumns) - return rc, colNumber -} diff --git a/go/vt/vtgate/planbuilder/pullout_subquery.go b/go/vt/vtgate/planbuilder/pullout_subquery.go index d29b6aa9db..109b205c14 100644 --- a/go/vt/vtgate/planbuilder/pullout_subquery.go +++ b/go/vt/vtgate/planbuilder/pullout_subquery.go @@ -145,3 +145,8 @@ func (ps *pulloutSubquery) SupplyVar(from, to int, col *sqlparser.ColName, varna func (ps *pulloutSubquery) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { return ps.underlying.SupplyCol(col) } + +// SupplyWeightString satisfies the builder interface. +func (ps *pulloutSubquery) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + return ps.underlying.SupplyWeightString(colNumber) +} diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index 609a981db9..c8e58cf6bd 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -17,7 +17,6 @@ limitations under the License. package planbuilder import ( - "errors" "fmt" "strings" @@ -273,43 +272,6 @@ func (rb *route) Wireup(bldr builder, jt *jointab) error { } } - // If rb has to do the ordering, and if any columns are Text, - // we have to request the corresponding weight_string from mysql - // and use that value instead. This is because we cannot mimic - // mysql's collation behavior yet. - for i, orderby := range ro.eroute.OrderBy { - rc := rb.resultColumns[orderby.Col] - if sqltypes.IsText(rc.column.typ) { - // If a weight string was previously requested (by OrderedAggregator), - // reuse it. - if colNumber, ok := rb.weightStrings[rc]; ok { - ro.eroute.OrderBy[i].Col = colNumber - continue - } - - // len(rb.resultColumns) does not change. No harm using the value multiple times. 
- ro.eroute.TruncateColumnCount = len(rb.resultColumns) - - // This code is partially duplicated from SupplyWeightString and PushSelect. - // We should not update resultColumns because it's not returned in the result. - // This is why we don't call PushSelect (or SupplyWeightString). - expr := &sqlparser.AliasedExpr{ - Expr: &sqlparser.FuncExpr{ - Name: sqlparser.NewColIdent("weight_string"), - Exprs: []sqlparser.SelectExpr{ - rb.Select.(*sqlparser.Select).SelectExprs[orderby.Col], - }, - }, - } - sel := rb.Select.(*sqlparser.Select) - sel.SelectExprs = append(sel.SelectExprs, expr) - ro.eroute.OrderBy[i].Col = len(sel.SelectExprs) - 1 - // We don't really have to update weightStrings, but we're doing it - // for good measure. - rb.weightStrings[rc] = len(sel.SelectExprs) - 1 - } - } - // Fix up the AST. _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { switch node := node.(type) { @@ -455,10 +417,11 @@ func (rb *route) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber return rc, len(rb.resultColumns) - 1 } -func (rb *route) SupplyWeightString(colNumber int) (weightcolNumber int) { +// SupplyWeightString satisfies the builder interface. +func (rb *route) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { rc := rb.resultColumns[colNumber] if weightcolNumber, ok := rb.weightStrings[rc]; ok { - return weightcolNumber + return weightcolNumber, nil } expr := &sqlparser.AliasedExpr{ Expr: &sqlparser.FuncExpr{ @@ -471,29 +434,7 @@ func (rb *route) SupplyWeightString(colNumber int) (weightcolNumber int) { // It's ok to pass nil for pb and builder because PushSelect doesn't use them. _, weightcolNumber, _ = rb.PushSelect(nil, expr, nil) rb.weightStrings[rc] = weightcolNumber - return weightcolNumber -} - -// BuildColName builds a *sqlparser.ColName for the resultColumn specified -// by the index. The built ColName will correctly reference the resultColumn -// it was built from, which is safe to push down into the route. 
-func (rb *route) BuildColName(index int) (*sqlparser.ColName, error) { - alias := rb.resultColumns[index].alias - if alias.IsEmpty() { - return nil, errors.New("cannot reference a complex expression") - } - for i, rc := range rb.resultColumns { - if i == index { - continue - } - if rc.alias.Equal(alias) { - return nil, fmt.Errorf("ambiguous symbol reference: %v", alias) - } - } - return &sqlparser.ColName{ - Metadata: rb.resultColumns[index].column, - Name: alias, - }, nil + return weightcolNumber, nil } // MergeSubquery returns true if the subquery route could successfully be merged diff --git a/go/vt/vtgate/planbuilder/symtab.go b/go/vt/vtgate/planbuilder/symtab.go index 5ccc8bffc1..a676d9bfbc 100644 --- a/go/vt/vtgate/planbuilder/symtab.go +++ b/go/vt/vtgate/planbuilder/symtab.go @@ -400,6 +400,28 @@ func ResultFromNumber(rcs []*resultColumn, val *sqlparser.SQLVal) (int, error) { return int(num - 1), nil } +// BuildColName builds a *sqlparser.ColName for the resultColumn specified +// by the index. The built ColName will correctly reference the resultColumn +// it was built from. +func BuildColName(rcs []*resultColumn, index int) (*sqlparser.ColName, error) { + alias := rcs[index].alias + if alias.IsEmpty() { + return nil, errors.New("cannot reference a complex expression") + } + for i, rc := range rcs { + if i == index { + continue + } + if rc.alias.Equal(alias) { + return nil, fmt.Errorf("ambiguous symbol reference: %v", alias) + } + } + return &sqlparser.ColName{ + Metadata: rcs[index].column, + Name: alias, + }, nil +} + // ResolveSymbols resolves all column references against symtab. // This makes sure that they all have their Metadata initialized. 
// If a symbol cannot be resolved or if the expression contains diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index 11237d9880..a564b66b4e 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -115,7 +115,66 @@ "Col": 3, "Desc": false } - ] + ], + "TruncateColumnCount": 5 + } + } +} + +# scatter group by a text column, reuse existing weight_string +"select count(*) k, a, textcol1, b from user group by a, textcol1, b order by k, textcol1" +{ + "Original": "select count(*) k, a, textcol1, b from user group by a, textcol1, b order by k, textcol1", + "Instructions": { + "Opcode": "MemorySort", + "MaxRows": null, + "OrderBy": [ + { + "Col": 0, + "Desc": false + }, + { + "Col": 4, + "Desc": false + } + ], + "Input": { + "Aggregates": [ + { + "Opcode": "count", + "Col": 0 + } + ], + "Keys": [ + 1, + 4, + 3 + ], + "TruncateColumnCount": 5, + "Input": { + "Opcode": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "Query": "select count(*) as k, a, textcol1, b, weight_string(textcol1) from user group by a, textcol1, b order by textcol1 asc, a asc, b asc", + "FieldQuery": "select count(*) as k, a, textcol1, b, weight_string(textcol1) from user where 1 != 1 group by a, textcol1, b", + "OrderBy": [ + { + "Col": 4, + "Desc": false + }, + { + "Col": 1, + "Desc": false + }, + { + "Col": 3, + "Desc": false + } + ], + "TruncateColumnCount": 5 + } } } } diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt index 37f2367853..771abad962 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt @@ -227,6 +227,63 @@ } } +# scatter aggregate with memory sort and order by number, reuse weight_string +# we have to use a meaningless construct to test this. 
+"select textcol1, count(*) k from user group by textcol1 order by textcol1, k, textcol1" +{ + "Original": "select textcol1, count(*) k from user group by textcol1 order by textcol1, k, textcol1", + "Instructions": { + "Opcode": "MemorySort", + "MaxRows": null, + "OrderBy": [ + { + "Col": 2, + "Desc": false + }, + { + "Col": 1, + "Desc": false + }, + { + "Col": 2, + "Desc": false + } + ], + "Input": { + "Aggregates": [ + { + "Opcode": "count", + "Col": 1 + } + ], + "Keys": [ + 2 + ], + "TruncateColumnCount": 3, + "Input": { + "Opcode": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "Query": "select textcol1, count(*) as k, weight_string(textcol1) from user group by textcol1 order by textcol1 asc, textcol1 asc", + "FieldQuery": "select textcol1, count(*) as k, weight_string(textcol1) from user where 1 != 1 group by textcol1", + "OrderBy": [ + { + "Col": 2, + "Desc": false + }, + { + "Col": 2, + "Desc": false + } + ], + "TruncateColumnCount": 3 + } + } + } +} + # order by on a cross-shard subquery "select id from (select user.id, user.col from user join user_extra) as t order by id" { @@ -387,6 +444,100 @@ } } +# Order by for join, on text column in LHS. 
+"select u.a, u.textcol1, un.col2 from user u join unsharded un order by u.textcol1, un.col2" +{ + "Original": "select u.a, u.textcol1, un.col2 from user u join unsharded un order by u.textcol1, un.col2", + "Instructions": { + "Opcode": "MemorySort", + "MaxRows": null, + "OrderBy": [ + { + "Col": 3, + "Desc": false + }, + { + "Col": 2, + "Desc": false + } + ], + "Input": { + "Opcode": "Join", + "Left": { + "Opcode": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "Query": "select u.a, u.textcol1, weight_string(u.textcol1) from user as u", + "FieldQuery": "select u.a, u.textcol1, weight_string(u.textcol1) from user as u where 1 != 1" + }, + "Right": { + "Opcode": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Query": "select un.col2 from unsharded as un", + "FieldQuery": "select un.col2 from unsharded as un where 1 != 1" + }, + "Cols": [ + -1, + -2, + 1, + -3 + ] + } + } +} + +# Order by for join, on text column in RHS. +"select u.a, u.textcol1, un.col2 from unsharded un join user u order by u.textcol1, un.col2" +{ + "Original": "select u.a, u.textcol1, un.col2 from unsharded un join user u order by u.textcol1, un.col2", + "Instructions": { + "Opcode": "MemorySort", + "MaxRows": null, + "OrderBy": [ + { + "Col": 3, + "Desc": false + }, + { + "Col": 2, + "Desc": false + } + ], + "Input": { + "Opcode": "Join", + "Left": { + "Opcode": "SelectUnsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "Query": "select un.col2 from unsharded as un", + "FieldQuery": "select un.col2 from unsharded as un where 1 != 1" + }, + "Right": { + "Opcode": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "Query": "select u.a, u.textcol1, weight_string(u.textcol1) from user as u", + "FieldQuery": "select u.a, u.textcol1, weight_string(u.textcol1) from user as u where 1 != 1" + }, + "Cols": [ + 1, + 2, + -1, + 3 + ] + } + } +} + # order by for vindex func "select id, keyspace_id, 
range_start, range_end from user_index where id = :id order by range_start" { diff --git a/go/vt/vtgate/planbuilder/vindex_func.go b/go/vt/vtgate/planbuilder/vindex_func.go index c51380b748..21f36d7400 100644 --- a/go/vt/vtgate/planbuilder/vindex_func.go +++ b/go/vt/vtgate/planbuilder/vindex_func.go @@ -211,3 +211,8 @@ func (vf *vindexFunc) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNu vf.eVindexFunc.Cols = append(vf.eVindexFunc.Cols, c.colNumber) return rc, len(vf.resultColumns) - 1 } + +// SupplyWeightString satisfies the builder interface. +func (vf *vindexFunc) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + return 0, errors.New("cannot do collation on vindex function") +}