Merge pull request #7313 from planetscale/v4-subquery

Gen4: Handling subquery in query graph
This commit is contained in:
Andres Taylor 2021-01-18 05:14:08 +01:00 коммит произвёл GitHub
Родитель 915f80a8f2 5dbda83a0f
Коммит d09dffef0c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 223 добавлений и 125 удалений

Просмотреть файл

@ -41,6 +41,9 @@ type (
// noDeps contains the predicates that can be evaluated anywhere.
noDeps sqlparser.Expr
// subqueries contains the subqueries that depend on this query graph
subqueries map[*sqlparser.Subquery][]*queryGraph
}
// queryTable is a single FROM table, including all predicates particular to this table
@ -79,9 +82,38 @@ func createQGFromSelect(sel *sqlparser.Select, semTable *semantics.SemTable) (*q
return qg, nil
}
func createQGFromSelectStatement(selStmt sqlparser.SelectStatement, semTable *semantics.SemTable) ([]*queryGraph, error) {
switch stmt := selStmt.(type) {
case *sqlparser.Select:
qg, err := createQGFromSelect(stmt, semTable)
if err != nil {
return nil, err
}
return []*queryGraph{qg}, err
case *sqlparser.Union:
qg, err := createQGFromSelectStatement(stmt.FirstStatement, semTable)
if err != nil {
return nil, err
}
for _, sel := range stmt.UnionSelects {
qgr, err := createQGFromSelectStatement(sel.Statement, semTable)
if err != nil {
return nil, err
}
qg = append(qg, qgr...)
}
return qg, nil
case *sqlparser.ParenSelect:
return createQGFromSelectStatement(stmt.Select, semTable)
}
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: not reachable %T", selStmt)
}
func newQueryGraph() *queryGraph {
return &queryGraph{
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
subqueries: map[*sqlparser.Subquery][]*queryGraph{},
}
}
@ -156,7 +188,20 @@ func (qg *queryGraph) collectPredicate(predicate sqlparser.Expr, semTable *seman
}
qg.crossTable[deps] = allPredicates
}
return nil
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch subQuery := node.(type) {
case *sqlparser.Subquery:
qgr, err := createQGFromSelectStatement(subQuery.Select, semTable)
if err != nil {
return false, err
}
qg.subqueries[subQuery] = qgr
}
return true, nil
}, predicate)
return err
}
func (qg *queryGraph) addToSingleTable(table semantics.TableSet, predicate sqlparser.Expr) bool {

Просмотреть файл

@ -18,130 +18,84 @@ package planbuilder
import (
"fmt"
"sort"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/test/utils"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/semantics"
"vitess.io/vitess/go/vt/vtgate/vindexes"
)
type tcase struct {
input string
output *queryGraph
input, output string
}
var threeWayJoin = &queryGraph{
tables: []*queryTable{{
tableID: 1,
alias: tableAlias("t"),
table: tableName("t"),
predicates: []sqlparser.Expr{equals(colName("t", "name"), literalString("foo"))},
}, {
tableID: 2,
alias: tableAlias("y"),
table: tableName("y"),
predicates: []sqlparser.Expr{equals(colName("y", "col"), literalInt(42))},
}, {
tableID: 4,
alias: tableAlias("z"),
table: tableName("z"),
predicates: []sqlparser.Expr{equals(colName("z", "baz"), literalInt(101))},
}},
crossTable: map[semantics.TableSet][]sqlparser.Expr{
1 | 2: {
equals(
colName("t", "id"),
colName("y", "t_id"))},
1 | 4: {
equals(
colName("t", "id"),
colName("z", "t_id"))}}}
var tcases = []tcase{{
input: "select * from t",
output: &queryGraph{
tables: []*queryTable{{
tableID: 1,
alias: tableAlias("t"),
table: tableName("t"),
}},
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
},
output: `{
Tables:
1:t
}`,
}, {
input: "select t.c from t,y,z where t.c = y.c and (t.a = z.a or t.a = y.a) and 1 < 2",
output: &queryGraph{
tables: []*queryTable{{
tableID: 1,
alias: tableAlias("t"),
table: tableName("t"),
}, {
tableID: 2,
alias: tableAlias("y"),
table: tableName("y"),
}, {
tableID: 4,
alias: tableAlias("z"),
table: tableName("z"),
}},
crossTable: map[semantics.TableSet][]sqlparser.Expr{
1 | 2: {
equals(
colName("t", "c"),
colName("y", "c"))},
1 | 2 | 4: {
or(
equals(
colName("t", "a"),
colName("z", "a")),
equals(
colName("t", "a"),
colName("y", "a")))},
},
noDeps: &sqlparser.ComparisonExpr{
Operator: sqlparser.LessThanOp,
Left: literalInt(1),
Right: literalInt(2)},
},
output: `{
Tables:
1:t
2:y
4:z
JoinPredicates:
1:2 - t.c = y.c
1:2:4 - t.a = z.a or t.a = y.a
ForAll: 1 < 2
}`,
}, {
input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101",
output: threeWayJoin,
input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101",
output: `{
Tables:
1:t where t.` + "`name`" + ` = 'foo'
2:y where y.col = 42
4:z where z.baz = 101
JoinPredicates:
1:2 - t.id = y.t_id
1:4 - t.id = z.t_id
}`,
}, {
input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id",
output: threeWayJoin,
input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id",
output: `{
Tables:
1:t where t.` + "`name`" + ` = 'foo'
2:y where y.col = 42
4:z where z.baz = 101
JoinPredicates:
1:2 - t.id = y.t_id
1:4 - t.id = z.t_id
}`,
}, {
input: "select 1 from t where '1' = 1 and 12 = '12'",
output: &queryGraph{
tables: []*queryTable{{
tableID: 1,
alias: tableAlias("t"),
table: tableName("t"),
}},
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
noDeps: &sqlparser.AndExpr{
Left: equals(literalString("1"), literalInt(1)),
Right: equals(literalInt(12), literalString("12")),
},
},
}}
func literalInt(i int) *sqlparser.Literal {
return &sqlparser.Literal{Type: sqlparser.IntVal, Val: []byte(fmt.Sprintf("%d", i))}
}
func literalString(s string) *sqlparser.Literal {
return &sqlparser.Literal{Type: sqlparser.StrVal, Val: []byte(s)}
}
func or(left, right sqlparser.Expr) sqlparser.Expr {
return &sqlparser.OrExpr{
Left: left,
Right: right,
output: `{
Tables:
1:t
ForAll: '1' = 1 and 12 = '12'
}`,
}, {
input: "select 1 from t where exists (select 1)",
output: `{
Tables:
1:t
ForAll: exists (select 1 from dual)
SubQueries:
(select 1 from dual) - {
Tables:
2:dual
}
}
}`,
}}
func equals(left, right sqlparser.Expr) sqlparser.Expr {
return &sqlparser.ComparisonExpr{
@ -155,40 +109,127 @@ func colName(table, column string) *sqlparser.ColName {
return &sqlparser.ColName{Name: sqlparser.NewColIdent(column), Qualifier: tableName(table)}
}
func tableAlias(name string) *sqlparser.AliasedTableExpr {
return &sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: sqlparser.NewTableIdent(name)}}
}
func tableName(name string) sqlparser.TableName {
return sqlparser.TableName{Name: sqlparser.NewTableIdent(name)}
}
type schemaInf struct{}
func (node *schemaInf) FindTable(tablename sqlparser.TableName) (*vindexes.Table, error) {
return nil, nil
}
func TestQueryGraph(t *testing.T) {
for _, tc := range tcases {
for i, tc := range tcases {
sql := tc.input
t.Run(sql, func(t *testing.T) {
t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) {
tree, err := sqlparser.Parse(sql)
require.NoError(t, err)
semTable, err := semantics.Analyse(tree)
require.NoError(t, err)
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
require.NoError(t, err)
mustMatch(t, tc.output, qgraph, "incorrect query graph")
fmt.Println(qgraph.testString())
assert.Equal(t, tc.output, qgraph.testString())
utils.MustMatch(t, tc.output, qgraph.testString(), "incorrect query graph")
})
}
}
var mustMatch = utils.MustMatchFn(
[]interface{}{ // types with unexported fields
queryGraph{},
queryTable{},
sqlparser.TableIdent{},
},
[]string{}, // ignored fields
)
func TestString(t *testing.T) {
tree, err := sqlparser.Parse("select * from a,b join c on b.id = c.id where a.id = b.id and b.col IN (select 42) and func() = 'foo'")
require.NoError(t, err)
semTable, err := semantics.Analyse(tree)
require.NoError(t, err)
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
require.NoError(t, err)
utils.MustMatch(t, `{
Tables:
1:a
2:b where b.col in (select 42 from dual)
4:c
JoinPredicates:
1:2 - a.id = b.id
2:4 - b.id = c.id
ForAll: func() = 'foo'
SubQueries:
(select 42 from dual) - {
Tables:
8:dual
}
}`, qgraph.testString())
}
func (qt *queryTable) testString() string {
var alias string
if !qt.alias.As.IsEmpty() {
alias = " AS " + sqlparser.String(qt.alias.As)
}
var preds []string
for _, predicate := range qt.predicates {
preds = append(preds, sqlparser.String(predicate))
}
var where string
if len(preds) > 0 {
where = " where " + strings.Join(preds, " and ")
}
return fmt.Sprintf("\t%d:%s%s%s", qt.tableID, sqlparser.String(qt.table), alias, where)
}
func (qg *queryGraph) testString() string {
return fmt.Sprintf(`{
Tables:
%s%s%s%s
}`, strings.Join(qg.tableNames(), "\n"), qg.crossPredicateString(), qg.noDepsString(), qg.subqueriesString())
}
func (qg *queryGraph) crossPredicateString() string {
if len(qg.crossTable) == 0 {
return ""
}
var joinPreds []string
for deps, predicates := range qg.crossTable {
var tables []string
for _, id := range deps.Constituents() {
tables = append(tables, fmt.Sprintf("%d", id))
}
var expressions []string
for _, expr := range predicates {
expressions = append(expressions, sqlparser.String(expr))
}
tableConcat := strings.Join(tables, ":")
exprConcat := strings.Join(expressions, " and ")
joinPreds = append(joinPreds, fmt.Sprintf("\t%s - %s", tableConcat, exprConcat))
}
sort.Strings(joinPreds)
return fmt.Sprintf("\nJoinPredicates:\n%s", strings.Join(joinPreds, "\n"))
}
func (qg *queryGraph) tableNames() []string {
var tables []string
for _, t := range qg.tables {
tables = append(tables, t.testString())
}
return tables
}
func (qg *queryGraph) subqueriesString() string {
if len(qg.subqueries) == 0 {
return ""
}
var graphs []string
for sq, qgraphs := range qg.subqueries {
key := sqlparser.String(sq)
for _, inner := range qgraphs {
str := inner.testString()
splitInner := strings.Split(str, "\n")
for i, s := range splitInner {
splitInner[i] = "\t" + s
}
graphs = append(graphs, fmt.Sprintf("%s - %s", key, strings.Join(splitInner, "\n")))
}
}
return fmt.Sprintf("\nSubQueries:\n%s", strings.Join(graphs, "\n"))
}
func (qg *queryGraph) noDepsString() string {
if qg.noDeps == nil {
return ""
}
return fmt.Sprintf("\nForAll: %s", sqlparser.String(qg.noDeps))
}

Просмотреть файл

@ -53,7 +53,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool {
a.err = err
return false
}
case *sqlparser.DerivedTable, *sqlparser.Subquery:
case *sqlparser.DerivedTable:
a.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%T not supported", node)
case *sqlparser.TableExprs:
// this has already been visited when we encountered the SELECT struct

Просмотреть файл

@ -109,6 +109,18 @@ func (ts TableSet) NumberOfTables() int {
return count
}
// Constituents returns an slice with all the
// individual tables in their own TableSet identifier
func (ts TableSet) Constituents() (result []TableSet) {
for i := 0; i < 64; i++ {
i2 := TableSet(1 << i)
if ts&i2 == i2 {
result = append(result, i2)
}
}
return
}
// Merge creates a TableSet that contains both inputs
func (ts TableSet) Merge(other TableSet) TableSet {
return ts | other