зеркало из https://github.com/github/vitess-gh.git
Merge pull request #7313 from planetscale/v4-subquery
Gen4: Handling subquery in query graph
This commit is contained in:
Коммит
d09dffef0c
|
@ -41,6 +41,9 @@ type (
|
|||
|
||||
// noDeps contains the predicates that can be evaluated anywhere.
|
||||
noDeps sqlparser.Expr
|
||||
|
||||
// subqueries contains the subqueries that depend on this query graph
|
||||
subqueries map[*sqlparser.Subquery][]*queryGraph
|
||||
}
|
||||
|
||||
// queryTable is a single FROM table, including all predicates particular to this table
|
||||
|
@ -79,9 +82,38 @@ func createQGFromSelect(sel *sqlparser.Select, semTable *semantics.SemTable) (*q
|
|||
return qg, nil
|
||||
}
|
||||
|
||||
func createQGFromSelectStatement(selStmt sqlparser.SelectStatement, semTable *semantics.SemTable) ([]*queryGraph, error) {
|
||||
switch stmt := selStmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
qg, err := createQGFromSelect(stmt, semTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []*queryGraph{qg}, err
|
||||
case *sqlparser.Union:
|
||||
qg, err := createQGFromSelectStatement(stmt.FirstStatement, semTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, sel := range stmt.UnionSelects {
|
||||
qgr, err := createQGFromSelectStatement(sel.Statement, semTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qg = append(qg, qgr...)
|
||||
}
|
||||
return qg, nil
|
||||
case *sqlparser.ParenSelect:
|
||||
return createQGFromSelectStatement(stmt.Select, semTable)
|
||||
}
|
||||
|
||||
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: not reachable %T", selStmt)
|
||||
}
|
||||
|
||||
func newQueryGraph() *queryGraph {
|
||||
return &queryGraph{
|
||||
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
|
||||
subqueries: map[*sqlparser.Subquery][]*queryGraph{},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -156,7 +188,20 @@ func (qg *queryGraph) collectPredicate(predicate sqlparser.Expr, semTable *seman
|
|||
}
|
||||
qg.crossTable[deps] = allPredicates
|
||||
}
|
||||
return nil
|
||||
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
|
||||
switch subQuery := node.(type) {
|
||||
case *sqlparser.Subquery:
|
||||
|
||||
qgr, err := createQGFromSelectStatement(subQuery.Select, semTable)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
qg.subqueries[subQuery] = qgr
|
||||
}
|
||||
return true, nil
|
||||
}, predicate)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (qg *queryGraph) addToSingleTable(table semantics.TableSet, predicate sqlparser.Expr) bool {
|
||||
|
|
|
@ -18,130 +18,84 @@ package planbuilder
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"vitess.io/vitess/go/test/utils"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"vitess.io/vitess/go/vt/sqlparser"
|
||||
"vitess.io/vitess/go/vt/vtgate/semantics"
|
||||
"vitess.io/vitess/go/vt/vtgate/vindexes"
|
||||
)
|
||||
|
||||
type tcase struct {
|
||||
input string
|
||||
output *queryGraph
|
||||
input, output string
|
||||
}
|
||||
|
||||
var threeWayJoin = &queryGraph{
|
||||
tables: []*queryTable{{
|
||||
tableID: 1,
|
||||
alias: tableAlias("t"),
|
||||
table: tableName("t"),
|
||||
predicates: []sqlparser.Expr{equals(colName("t", "name"), literalString("foo"))},
|
||||
}, {
|
||||
tableID: 2,
|
||||
alias: tableAlias("y"),
|
||||
table: tableName("y"),
|
||||
predicates: []sqlparser.Expr{equals(colName("y", "col"), literalInt(42))},
|
||||
}, {
|
||||
tableID: 4,
|
||||
alias: tableAlias("z"),
|
||||
table: tableName("z"),
|
||||
predicates: []sqlparser.Expr{equals(colName("z", "baz"), literalInt(101))},
|
||||
}},
|
||||
crossTable: map[semantics.TableSet][]sqlparser.Expr{
|
||||
1 | 2: {
|
||||
equals(
|
||||
colName("t", "id"),
|
||||
colName("y", "t_id"))},
|
||||
1 | 4: {
|
||||
equals(
|
||||
colName("t", "id"),
|
||||
colName("z", "t_id"))}}}
|
||||
|
||||
var tcases = []tcase{{
|
||||
input: "select * from t",
|
||||
output: &queryGraph{
|
||||
tables: []*queryTable{{
|
||||
tableID: 1,
|
||||
alias: tableAlias("t"),
|
||||
table: tableName("t"),
|
||||
}},
|
||||
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
|
||||
},
|
||||
output: `{
|
||||
Tables:
|
||||
1:t
|
||||
}`,
|
||||
}, {
|
||||
input: "select t.c from t,y,z where t.c = y.c and (t.a = z.a or t.a = y.a) and 1 < 2",
|
||||
output: &queryGraph{
|
||||
tables: []*queryTable{{
|
||||
tableID: 1,
|
||||
alias: tableAlias("t"),
|
||||
table: tableName("t"),
|
||||
}, {
|
||||
tableID: 2,
|
||||
alias: tableAlias("y"),
|
||||
table: tableName("y"),
|
||||
}, {
|
||||
tableID: 4,
|
||||
alias: tableAlias("z"),
|
||||
table: tableName("z"),
|
||||
}},
|
||||
crossTable: map[semantics.TableSet][]sqlparser.Expr{
|
||||
1 | 2: {
|
||||
equals(
|
||||
colName("t", "c"),
|
||||
colName("y", "c"))},
|
||||
1 | 2 | 4: {
|
||||
or(
|
||||
equals(
|
||||
colName("t", "a"),
|
||||
colName("z", "a")),
|
||||
equals(
|
||||
colName("t", "a"),
|
||||
colName("y", "a")))},
|
||||
},
|
||||
noDeps: &sqlparser.ComparisonExpr{
|
||||
Operator: sqlparser.LessThanOp,
|
||||
Left: literalInt(1),
|
||||
Right: literalInt(2)},
|
||||
},
|
||||
output: `{
|
||||
Tables:
|
||||
1:t
|
||||
2:y
|
||||
4:z
|
||||
JoinPredicates:
|
||||
1:2 - t.c = y.c
|
||||
1:2:4 - t.a = z.a or t.a = y.a
|
||||
ForAll: 1 < 2
|
||||
}`,
|
||||
}, {
|
||||
input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101",
|
||||
output: threeWayJoin,
|
||||
input: "select t.c from t join y on t.id = y.t_id join z on t.id = z.t_id where t.name = 'foo' and y.col = 42 and z.baz = 101",
|
||||
output: `{
|
||||
Tables:
|
||||
1:t where t.` + "`name`" + ` = 'foo'
|
||||
2:y where y.col = 42
|
||||
4:z where z.baz = 101
|
||||
JoinPredicates:
|
||||
1:2 - t.id = y.t_id
|
||||
1:4 - t.id = z.t_id
|
||||
}`,
|
||||
}, {
|
||||
input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id",
|
||||
output: threeWayJoin,
|
||||
input: "select t.c from t,y,z where t.name = 'foo' and y.col = 42 and z.baz = 101 and t.id = y.t_id and t.id = z.t_id",
|
||||
output: `{
|
||||
Tables:
|
||||
1:t where t.` + "`name`" + ` = 'foo'
|
||||
2:y where y.col = 42
|
||||
4:z where z.baz = 101
|
||||
JoinPredicates:
|
||||
1:2 - t.id = y.t_id
|
||||
1:4 - t.id = z.t_id
|
||||
}`,
|
||||
}, {
|
||||
input: "select 1 from t where '1' = 1 and 12 = '12'",
|
||||
output: &queryGraph{
|
||||
tables: []*queryTable{{
|
||||
tableID: 1,
|
||||
alias: tableAlias("t"),
|
||||
table: tableName("t"),
|
||||
}},
|
||||
crossTable: map[semantics.TableSet][]sqlparser.Expr{},
|
||||
noDeps: &sqlparser.AndExpr{
|
||||
Left: equals(literalString("1"), literalInt(1)),
|
||||
Right: equals(literalInt(12), literalString("12")),
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
func literalInt(i int) *sqlparser.Literal {
|
||||
return &sqlparser.Literal{Type: sqlparser.IntVal, Val: []byte(fmt.Sprintf("%d", i))}
|
||||
}
|
||||
|
||||
func literalString(s string) *sqlparser.Literal {
|
||||
return &sqlparser.Literal{Type: sqlparser.StrVal, Val: []byte(s)}
|
||||
}
|
||||
|
||||
func or(left, right sqlparser.Expr) sqlparser.Expr {
|
||||
return &sqlparser.OrExpr{
|
||||
Left: left,
|
||||
Right: right,
|
||||
output: `{
|
||||
Tables:
|
||||
1:t
|
||||
ForAll: '1' = 1 and 12 = '12'
|
||||
}`,
|
||||
}, {
|
||||
input: "select 1 from t where exists (select 1)",
|
||||
output: `{
|
||||
Tables:
|
||||
1:t
|
||||
ForAll: exists (select 1 from dual)
|
||||
SubQueries:
|
||||
(select 1 from dual) - {
|
||||
Tables:
|
||||
2:dual
|
||||
}
|
||||
}
|
||||
}`,
|
||||
}}
|
||||
|
||||
func equals(left, right sqlparser.Expr) sqlparser.Expr {
|
||||
return &sqlparser.ComparisonExpr{
|
||||
|
@ -155,40 +109,127 @@ func colName(table, column string) *sqlparser.ColName {
|
|||
return &sqlparser.ColName{Name: sqlparser.NewColIdent(column), Qualifier: tableName(table)}
|
||||
}
|
||||
|
||||
func tableAlias(name string) *sqlparser.AliasedTableExpr {
|
||||
return &sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: sqlparser.NewTableIdent(name)}}
|
||||
}
|
||||
|
||||
func tableName(name string) sqlparser.TableName {
|
||||
return sqlparser.TableName{Name: sqlparser.NewTableIdent(name)}
|
||||
}
|
||||
|
||||
type schemaInf struct{}
|
||||
|
||||
func (node *schemaInf) FindTable(tablename sqlparser.TableName) (*vindexes.Table, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestQueryGraph(t *testing.T) {
|
||||
for _, tc := range tcases {
|
||||
for i, tc := range tcases {
|
||||
sql := tc.input
|
||||
t.Run(sql, func(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) {
|
||||
tree, err := sqlparser.Parse(sql)
|
||||
require.NoError(t, err)
|
||||
semTable, err := semantics.Analyse(tree)
|
||||
require.NoError(t, err)
|
||||
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
|
||||
require.NoError(t, err)
|
||||
mustMatch(t, tc.output, qgraph, "incorrect query graph")
|
||||
fmt.Println(qgraph.testString())
|
||||
assert.Equal(t, tc.output, qgraph.testString())
|
||||
utils.MustMatch(t, tc.output, qgraph.testString(), "incorrect query graph")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var mustMatch = utils.MustMatchFn(
|
||||
[]interface{}{ // types with unexported fields
|
||||
queryGraph{},
|
||||
queryTable{},
|
||||
sqlparser.TableIdent{},
|
||||
},
|
||||
[]string{}, // ignored fields
|
||||
)
|
||||
func TestString(t *testing.T) {
|
||||
tree, err := sqlparser.Parse("select * from a,b join c on b.id = c.id where a.id = b.id and b.col IN (select 42) and func() = 'foo'")
|
||||
require.NoError(t, err)
|
||||
semTable, err := semantics.Analyse(tree)
|
||||
require.NoError(t, err)
|
||||
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
|
||||
require.NoError(t, err)
|
||||
utils.MustMatch(t, `{
|
||||
Tables:
|
||||
1:a
|
||||
2:b where b.col in (select 42 from dual)
|
||||
4:c
|
||||
JoinPredicates:
|
||||
1:2 - a.id = b.id
|
||||
2:4 - b.id = c.id
|
||||
ForAll: func() = 'foo'
|
||||
SubQueries:
|
||||
(select 42 from dual) - {
|
||||
Tables:
|
||||
8:dual
|
||||
}
|
||||
}`, qgraph.testString())
|
||||
}
|
||||
|
||||
func (qt *queryTable) testString() string {
|
||||
var alias string
|
||||
if !qt.alias.As.IsEmpty() {
|
||||
alias = " AS " + sqlparser.String(qt.alias.As)
|
||||
}
|
||||
var preds []string
|
||||
for _, predicate := range qt.predicates {
|
||||
preds = append(preds, sqlparser.String(predicate))
|
||||
}
|
||||
var where string
|
||||
if len(preds) > 0 {
|
||||
where = " where " + strings.Join(preds, " and ")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("\t%d:%s%s%s", qt.tableID, sqlparser.String(qt.table), alias, where)
|
||||
}
|
||||
|
||||
func (qg *queryGraph) testString() string {
|
||||
return fmt.Sprintf(`{
|
||||
Tables:
|
||||
%s%s%s%s
|
||||
}`, strings.Join(qg.tableNames(), "\n"), qg.crossPredicateString(), qg.noDepsString(), qg.subqueriesString())
|
||||
}
|
||||
|
||||
func (qg *queryGraph) crossPredicateString() string {
|
||||
if len(qg.crossTable) == 0 {
|
||||
return ""
|
||||
}
|
||||
var joinPreds []string
|
||||
for deps, predicates := range qg.crossTable {
|
||||
var tables []string
|
||||
for _, id := range deps.Constituents() {
|
||||
tables = append(tables, fmt.Sprintf("%d", id))
|
||||
}
|
||||
var expressions []string
|
||||
for _, expr := range predicates {
|
||||
expressions = append(expressions, sqlparser.String(expr))
|
||||
}
|
||||
tableConcat := strings.Join(tables, ":")
|
||||
exprConcat := strings.Join(expressions, " and ")
|
||||
joinPreds = append(joinPreds, fmt.Sprintf("\t%s - %s", tableConcat, exprConcat))
|
||||
}
|
||||
sort.Strings(joinPreds)
|
||||
return fmt.Sprintf("\nJoinPredicates:\n%s", strings.Join(joinPreds, "\n"))
|
||||
}
|
||||
|
||||
func (qg *queryGraph) tableNames() []string {
|
||||
var tables []string
|
||||
for _, t := range qg.tables {
|
||||
tables = append(tables, t.testString())
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func (qg *queryGraph) subqueriesString() string {
|
||||
if len(qg.subqueries) == 0 {
|
||||
return ""
|
||||
}
|
||||
var graphs []string
|
||||
for sq, qgraphs := range qg.subqueries {
|
||||
key := sqlparser.String(sq)
|
||||
for _, inner := range qgraphs {
|
||||
str := inner.testString()
|
||||
splitInner := strings.Split(str, "\n")
|
||||
for i, s := range splitInner {
|
||||
splitInner[i] = "\t" + s
|
||||
}
|
||||
graphs = append(graphs, fmt.Sprintf("%s - %s", key, strings.Join(splitInner, "\n")))
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("\nSubQueries:\n%s", strings.Join(graphs, "\n"))
|
||||
}
|
||||
|
||||
func (qg *queryGraph) noDepsString() string {
|
||||
if qg.noDeps == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("\nForAll: %s", sqlparser.String(qg.noDeps))
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool {
|
|||
a.err = err
|
||||
return false
|
||||
}
|
||||
case *sqlparser.DerivedTable, *sqlparser.Subquery:
|
||||
case *sqlparser.DerivedTable:
|
||||
a.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%T not supported", node)
|
||||
case *sqlparser.TableExprs:
|
||||
// this has already been visited when we encountered the SELECT struct
|
||||
|
|
|
@ -109,6 +109,18 @@ func (ts TableSet) NumberOfTables() int {
|
|||
return count
|
||||
}
|
||||
|
||||
// Constituents returns an slice with all the
|
||||
// individual tables in their own TableSet identifier
|
||||
func (ts TableSet) Constituents() (result []TableSet) {
|
||||
for i := 0; i < 64; i++ {
|
||||
i2 := TableSet(1 << i)
|
||||
if ts&i2 == i2 {
|
||||
result = append(result, i2)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Merge creates a TableSet that contains both inputs
|
||||
func (ts TableSet) Merge(other TableSet) TableSet {
|
||||
return ts | other
|
||||
|
|
Загрузка…
Ссылка в новой задаче