зеркало из https://github.com/github/vitess-gh.git
add expanded colnames to semTable, error out when aliased star expr table is not found
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
This commit is contained in:
Родитель
ed06e513e7
Коммит
0f520d5e0a
|
@ -163,6 +163,7 @@ var stateToMysqlCode = map[vterrors.State]struct {
|
|||
vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError},
|
||||
vterrors.BadDb: {num: ERBadDb, state: SSClientError},
|
||||
vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError},
|
||||
vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable},
|
||||
vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError},
|
||||
vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange},
|
||||
vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState},
|
||||
|
|
|
@ -25,6 +25,7 @@ const (
|
|||
|
||||
// invalid argument
|
||||
BadFieldError
|
||||
BadTableError
|
||||
CantUseOptionHere
|
||||
DataOutOfRange
|
||||
EmptyQuery
|
||||
|
|
|
@ -55,7 +55,10 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
|
|||
return nil, err
|
||||
}
|
||||
|
||||
sel = expandStar(sel, semTable)
|
||||
sel, err = expandStar(sel, semTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qgraph, err := createQGFromSelect(sel, semTable)
|
||||
if err != nil {
|
||||
|
@ -99,66 +102,100 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
|
|||
return plan.Primitive(), nil
|
||||
}
|
||||
|
||||
func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) *sqlparser.Select {
|
||||
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
|
||||
type starRewriter struct {
|
||||
err error
|
||||
semTable *semantics.SemTable
|
||||
}
|
||||
|
||||
_ = sqlparser.Rewrite(sel, func(cursor *sqlparser.Cursor) bool {
|
||||
switch node := cursor.Node().(type) {
|
||||
case *sqlparser.Select:
|
||||
tables := semTable.GetSelectTables(node)
|
||||
var selExprs sqlparser.SelectExprs
|
||||
for _, selectExpr := range node.SelectExprs {
|
||||
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
|
||||
if !isStarExpr {
|
||||
selExprs = append(selExprs, selectExpr)
|
||||
continue
|
||||
}
|
||||
var colNames sqlparser.SelectExprs
|
||||
expandStar := false
|
||||
for _, tbl := range tables {
|
||||
if !starExpr.TableName.IsEmpty() {
|
||||
if !tbl.ASTNode.As.IsEmpty() {
|
||||
if !starExpr.TableName.Qualifier.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !starExpr.TableName.Qualifier.IsEmpty() {
|
||||
if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name {
|
||||
continue
|
||||
}
|
||||
}
|
||||
tblName := tbl.ASTNode.Expr.(sqlparser.TableName)
|
||||
if starExpr.TableName.Name.String() != tblName.Name.String() {
|
||||
func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool {
|
||||
switch node := cursor.Node().(type) {
|
||||
case *sqlparser.Select:
|
||||
tables := sr.semTable.GetSelectTables(node)
|
||||
var selExprs sqlparser.SelectExprs
|
||||
for _, selectExpr := range node.SelectExprs {
|
||||
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
|
||||
if !isStarExpr {
|
||||
selExprs = append(selExprs, selectExpr)
|
||||
continue
|
||||
}
|
||||
expStar := &expandStarInfo{
|
||||
tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{},
|
||||
}
|
||||
var colNames sqlparser.SelectExprs
|
||||
unknownTbl := true
|
||||
for _, tbl := range tables {
|
||||
if !starExpr.TableName.IsEmpty() {
|
||||
if !tbl.ASTNode.As.IsEmpty() {
|
||||
if !starExpr.TableName.Qualifier.IsEmpty() {
|
||||
continue
|
||||
}
|
||||
if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !starExpr.TableName.Qualifier.IsEmpty() {
|
||||
if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if !tbl.Table.ColumnListAuthoritative {
|
||||
expandStar = false
|
||||
break
|
||||
}
|
||||
expandStar = true
|
||||
for _, col := range tbl.Table.Columns {
|
||||
colNames = append(colNames, &sqlparser.AliasedExpr{
|
||||
Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), sqlparser.TableName{Name: tbl.Table.Name}),
|
||||
})
|
||||
tblName := tbl.ASTNode.Expr.(sqlparser.TableName)
|
||||
if starExpr.TableName.Name.String() != tblName.Name.String() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if !expandStar {
|
||||
selExprs = append(selExprs, selectExpr)
|
||||
continue
|
||||
unknownTbl = false
|
||||
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
|
||||
expStar.proceed = false
|
||||
break
|
||||
}
|
||||
selExprs = append(selExprs, colNames...)
|
||||
expStar.proceed = true
|
||||
tblName, err := tbl.ASTNode.TableName()
|
||||
if err != nil {
|
||||
sr.err = err
|
||||
return false
|
||||
}
|
||||
for _, col := range tbl.Table.Columns {
|
||||
colNames = append(colNames, &sqlparser.AliasedExpr{
|
||||
Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), tblName),
|
||||
As: sqlparser.NewColIdent(col.Name.String()),
|
||||
})
|
||||
}
|
||||
expStar.tblColMap[tbl.ASTNode] = colNames
|
||||
}
|
||||
if unknownTbl {
|
||||
// This will only happen for case when starExpr has qualifier.
|
||||
sr.err = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName))
|
||||
return false
|
||||
}
|
||||
if !expStar.proceed {
|
||||
selExprs = append(selExprs, selectExpr)
|
||||
continue
|
||||
}
|
||||
selExprs = append(selExprs, colNames...)
|
||||
for tbl, cols := range expStar.tblColMap {
|
||||
sr.semTable.AddExprs(tbl, cols)
|
||||
}
|
||||
node.SelectExprs = selExprs
|
||||
}
|
||||
node.SelectExprs = selExprs
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}, nil)
|
||||
return sel
|
||||
type expandStarInfo struct {
|
||||
proceed bool
|
||||
tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs
|
||||
}
|
||||
|
||||
func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) {
|
||||
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
|
||||
sr := &starRewriter{semTable: semTable}
|
||||
|
||||
_ = sqlparser.Rewrite(sel, sr.starRewrite, nil)
|
||||
if sr.err != nil {
|
||||
return nil, sr.err
|
||||
}
|
||||
return sel, nil
|
||||
}
|
||||
|
||||
func planLimit(limit *sqlparser.Limit, plan logicalPlan) (logicalPlan, error) {
|
||||
|
|
|
@ -169,27 +169,31 @@ func TestExpandStar(t *testing.T) {
|
|||
tcases := []struct {
|
||||
sql string
|
||||
expSQL string
|
||||
expErr string
|
||||
}{{
|
||||
sql: "select * from t1",
|
||||
expSQL: "select t1.a, t1.b, t1.c from t1",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
|
||||
}, {
|
||||
sql: "select t1.* from t1",
|
||||
expSQL: "select t1.a, t1.b, t1.c from t1",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
|
||||
}, {
|
||||
sql: "select *, 42, t1.* from t1",
|
||||
expSQL: "select t1.a, t1.b, t1.c, 42, t1.a, t1.b, t1.c from t1",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1",
|
||||
}, {
|
||||
sql: "select 42, t1.* from t1",
|
||||
expSQL: "select 42, t1.a, t1.b, t1.c from t1",
|
||||
expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1",
|
||||
}, {
|
||||
sql: "select * from t1, t2",
|
||||
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2 from t1, t2",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2",
|
||||
}, {
|
||||
sql: "select t1.* from t1, t2",
|
||||
expSQL: "select t1.a, t1.b, t1.c from t1, t2",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2",
|
||||
}, {
|
||||
sql: "select *, t1.* from t1, t2",
|
||||
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2, t1.a, t1.b, t1.c from t1, t2",
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2",
|
||||
}, { // aliased table
|
||||
sql: "select * from t1 a, t2 b",
|
||||
expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b",
|
||||
}, { // t3 is non-authoritative table
|
||||
sql: "select * from t3",
|
||||
expSQL: "select * from t3",
|
||||
|
@ -198,19 +202,78 @@ func TestExpandStar(t *testing.T) {
|
|||
expSQL: "select * from t1, t2, t3",
|
||||
}, { // t3 is non-authoritative table
|
||||
sql: "select t1.*, t2.*, t3.* from t1, t2, t3",
|
||||
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2, t3.* from t1, t2, t3",
|
||||
}, { // TODO: This should fail on analyze step and should not reach down.
|
||||
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3",
|
||||
}, {
|
||||
sql: "select foo.* from t1, t2",
|
||||
expSQL: "select foo.* from t1, t2",
|
||||
expErr: "Unknown table 'foo'",
|
||||
}}
|
||||
for _, tcase := range tcases {
|
||||
t.Run(tcase.sql, func(t *testing.T) {
|
||||
ast, err := sqlparser.Parse(tcase.sql)
|
||||
require.NoError(t, err)
|
||||
semState, err := semantics.Analyze(ast, cDB, schemaInfo)
|
||||
semTable, err := semantics.Analyze(ast, cDB, schemaInfo)
|
||||
require.NoError(t, err)
|
||||
expanded := expandStar(ast.(*sqlparser.Select), semState)
|
||||
assert.Equal(t, tcase.expSQL, sqlparser.String(expanded))
|
||||
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
|
||||
if tcase.expErr == "" {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
|
||||
} else {
|
||||
require.EqualError(t, err, tcase.expErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
|
||||
schemaInfo := &fakeSI{tables: map[string]*vindexes.Table{
|
||||
"t1": {
|
||||
Name: sqlparser.NewTableIdent("t1"),
|
||||
Columns: []vindexes.Column{{
|
||||
Name: sqlparser.NewColIdent("a"),
|
||||
Type: sqltypes.VarChar,
|
||||
}},
|
||||
ColumnListAuthoritative: true,
|
||||
}}}
|
||||
tcases := []struct {
|
||||
sql string
|
||||
expSQL string
|
||||
sameTbl int
|
||||
otherTbl int
|
||||
expandedCol int
|
||||
}{{
|
||||
sql: "select a, * from t1",
|
||||
expSQL: "select a, t1.a as a from t1",
|
||||
otherTbl: -1, sameTbl: 0, expandedCol: 1,
|
||||
}, {
|
||||
sql: "select t2.a, t1.a, t1.* from t1, t2",
|
||||
expSQL: "select t2.a, t1.a, t1.a as a from t1, t2",
|
||||
otherTbl: 0, sameTbl: 1, expandedCol: 2,
|
||||
}, {
|
||||
sql: "select t2.a, t.a, t.* from t1 t, t2",
|
||||
expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2",
|
||||
otherTbl: 0, sameTbl: 1, expandedCol: 2,
|
||||
}}
|
||||
for _, tcase := range tcases {
|
||||
t.Run(tcase.sql, func(t *testing.T) {
|
||||
ast, err := sqlparser.Parse(tcase.sql)
|
||||
require.NoError(t, err)
|
||||
semTable, err := semantics.Analyze(ast, "", schemaInfo)
|
||||
require.NoError(t, err)
|
||||
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
|
||||
if tcase.otherTbl != -1 {
|
||||
assert.NotEqual(t,
|
||||
semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr),
|
||||
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
|
||||
)
|
||||
}
|
||||
if tcase.sameTbl != -1 {
|
||||
assert.Equal(t,
|
||||
semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr),
|
||||
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -168,7 +168,7 @@ func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColN
|
|||
|
||||
var tblInfo *TableInfo
|
||||
for _, tbl := range current.tables {
|
||||
if !tbl.Table.ColumnListAuthoritative {
|
||||
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
|
||||
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr)))
|
||||
}
|
||||
for _, col := range tbl.Table.Columns {
|
||||
|
|
|
@ -107,6 +107,14 @@ func (st *SemTable) GetSelectTables(node *sqlparser.Select) []*TableInfo {
|
|||
return scope.tables
|
||||
}
|
||||
|
||||
// AddExprs adds new select exprs to the SemTable.
|
||||
func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) {
|
||||
tableSet := st.TableSetFor(tbl)
|
||||
for _, col := range cols {
|
||||
st.exprDependencies[col.(*sqlparser.AliasedExpr).Expr] = tableSet
|
||||
}
|
||||
}
|
||||
|
||||
func newScope(parent *scope) *scope {
|
||||
return &scope{parent: parent}
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче