add expanded colnames to semTable, error out when aliased star expr table is not found

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
This commit is contained in:
Harshit Gangal 2021-06-17 16:16:35 +05:30
Родитель ed06e513e7
Коммит 0f520d5e0a
6 изменённых файлов: 175 добавлений и 65 удалений

Просмотреть файл

@ -163,6 +163,7 @@ var stateToMysqlCode = map[vterrors.State]struct {
vterrors.AccessDeniedError: {num: ERAccessDeniedError, state: SSAccessDeniedError},
vterrors.BadDb: {num: ERBadDb, state: SSClientError},
vterrors.BadFieldError: {num: ERBadFieldError, state: SSBadFieldError},
vterrors.BadTableError: {num: ERBadTable, state: SSUnknownTable},
vterrors.CantUseOptionHere: {num: ERCantUseOptionHere, state: SSClientError},
vterrors.DataOutOfRange: {num: ERDataOutOfRange, state: SSDataOutOfRange},
vterrors.DbCreateExists: {num: ERDbCreateExists, state: SSUnknownSQLState},

Просмотреть файл

@ -25,6 +25,7 @@ const (
// invalid argument
BadFieldError
BadTableError
CantUseOptionHere
DataOutOfRange
EmptyQuery

Просмотреть файл

@ -55,7 +55,10 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
return nil, err
}
sel = expandStar(sel, semTable)
sel, err = expandStar(sel, semTable)
if err != nil {
return nil, err
}
qgraph, err := createQGFromSelect(sel, semTable)
if err != nil {
@ -99,66 +102,100 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
return plan.Primitive(), nil
}
func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) *sqlparser.Select {
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
type starRewriter struct {
err error
semTable *semantics.SemTable
}
_ = sqlparser.Rewrite(sel, func(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case *sqlparser.Select:
tables := semTable.GetSelectTables(node)
var selExprs sqlparser.SelectExprs
for _, selectExpr := range node.SelectExprs {
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
if !isStarExpr {
selExprs = append(selExprs, selectExpr)
continue
}
var colNames sqlparser.SelectExprs
expandStar := false
for _, tbl := range tables {
if !starExpr.TableName.IsEmpty() {
if !tbl.ASTNode.As.IsEmpty() {
if !starExpr.TableName.Qualifier.IsEmpty() {
continue
}
if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() {
continue
}
} else {
if !starExpr.TableName.Qualifier.IsEmpty() {
if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name {
continue
}
}
tblName := tbl.ASTNode.Expr.(sqlparser.TableName)
if starExpr.TableName.Name.String() != tblName.Name.String() {
func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool {
switch node := cursor.Node().(type) {
case *sqlparser.Select:
tables := sr.semTable.GetSelectTables(node)
var selExprs sqlparser.SelectExprs
for _, selectExpr := range node.SelectExprs {
starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr)
if !isStarExpr {
selExprs = append(selExprs, selectExpr)
continue
}
expStar := &expandStarInfo{
tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{},
}
var colNames sqlparser.SelectExprs
unknownTbl := true
for _, tbl := range tables {
if !starExpr.TableName.IsEmpty() {
if !tbl.ASTNode.As.IsEmpty() {
if !starExpr.TableName.Qualifier.IsEmpty() {
continue
}
if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() {
continue
}
} else {
if !starExpr.TableName.Qualifier.IsEmpty() {
if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name {
continue
}
}
}
if !tbl.Table.ColumnListAuthoritative {
expandStar = false
break
}
expandStar = true
for _, col := range tbl.Table.Columns {
colNames = append(colNames, &sqlparser.AliasedExpr{
Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), sqlparser.TableName{Name: tbl.Table.Name}),
})
tblName := tbl.ASTNode.Expr.(sqlparser.TableName)
if starExpr.TableName.Name.String() != tblName.Name.String() {
continue
}
}
}
if !expandStar {
selExprs = append(selExprs, selectExpr)
continue
unknownTbl = false
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
expStar.proceed = false
break
}
selExprs = append(selExprs, colNames...)
expStar.proceed = true
tblName, err := tbl.ASTNode.TableName()
if err != nil {
sr.err = err
return false
}
for _, col := range tbl.Table.Columns {
colNames = append(colNames, &sqlparser.AliasedExpr{
Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), tblName),
As: sqlparser.NewColIdent(col.Name.String()),
})
}
expStar.tblColMap[tbl.ASTNode] = colNames
}
if unknownTbl {
// This will only happen for case when starExpr has qualifier.
sr.err = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName))
return false
}
if !expStar.proceed {
selExprs = append(selExprs, selectExpr)
continue
}
selExprs = append(selExprs, colNames...)
for tbl, cols := range expStar.tblColMap {
sr.semTable.AddExprs(tbl, cols)
}
node.SelectExprs = selExprs
}
node.SelectExprs = selExprs
}
return true
}
return true
}, nil)
return sel
type expandStarInfo struct {
proceed bool
tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs
}
func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) {
// TODO we could store in semTable whether there are any * in the query that needs expanding or not
sr := &starRewriter{semTable: semTable}
_ = sqlparser.Rewrite(sel, sr.starRewrite, nil)
if sr.err != nil {
return nil, sr.err
}
return sel, nil
}
func planLimit(limit *sqlparser.Limit, plan logicalPlan) (logicalPlan, error) {

Просмотреть файл

@ -169,27 +169,31 @@ func TestExpandStar(t *testing.T) {
tcases := []struct {
sql string
expSQL string
expErr string
}{{
sql: "select * from t1",
expSQL: "select t1.a, t1.b, t1.c from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select t1.* from t1",
expSQL: "select t1.a, t1.b, t1.c from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select *, 42, t1.* from t1",
expSQL: "select t1.a, t1.b, t1.c, 42, t1.a, t1.b, t1.c from t1",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select 42, t1.* from t1",
expSQL: "select 42, t1.a, t1.b, t1.c from t1",
expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1",
}, {
sql: "select * from t1, t2",
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2 from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2",
}, {
sql: "select t1.* from t1, t2",
expSQL: "select t1.a, t1.b, t1.c from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, {
sql: "select *, t1.* from t1, t2",
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2, t1.a, t1.b, t1.c from t1, t2",
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2",
}, { // aliased table
sql: "select * from t1 a, t2 b",
expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b",
}, { // t3 is non-authoritative table
sql: "select * from t3",
expSQL: "select * from t3",
@ -198,19 +202,78 @@ func TestExpandStar(t *testing.T) {
expSQL: "select * from t1, t2, t3",
}, { // t3 is non-authoritative table
sql: "select t1.*, t2.*, t3.* from t1, t2, t3",
expSQL: "select t1.a, t1.b, t1.c, t2.c1, t2.c2, t3.* from t1, t2, t3",
}, { // TODO: This should fail on analyze step and should not reach down.
expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3",
}, {
sql: "select foo.* from t1, t2",
expSQL: "select foo.* from t1, t2",
expErr: "Unknown table 'foo'",
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semState, err := semantics.Analyze(ast, cDB, schemaInfo)
semTable, err := semantics.Analyze(ast, cDB, schemaInfo)
require.NoError(t, err)
expanded := expandStar(ast.(*sqlparser.Select), semState)
assert.Equal(t, tcase.expSQL, sqlparser.String(expanded))
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
if tcase.expErr == "" {
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
} else {
require.EqualError(t, err, tcase.expErr)
}
})
}
}
func TestSemTableDependenciesAfterExpandStar(t *testing.T) {
schemaInfo := &fakeSI{tables: map[string]*vindexes.Table{
"t1": {
Name: sqlparser.NewTableIdent("t1"),
Columns: []vindexes.Column{{
Name: sqlparser.NewColIdent("a"),
Type: sqltypes.VarChar,
}},
ColumnListAuthoritative: true,
}}}
tcases := []struct {
sql string
expSQL string
sameTbl int
otherTbl int
expandedCol int
}{{
sql: "select a, * from t1",
expSQL: "select a, t1.a as a from t1",
otherTbl: -1, sameTbl: 0, expandedCol: 1,
}, {
sql: "select t2.a, t1.a, t1.* from t1, t2",
expSQL: "select t2.a, t1.a, t1.a as a from t1, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}, {
sql: "select t2.a, t.a, t.* from t1 t, t2",
expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2",
otherTbl: 0, sameTbl: 1, expandedCol: 2,
}}
for _, tcase := range tcases {
t.Run(tcase.sql, func(t *testing.T) {
ast, err := sqlparser.Parse(tcase.sql)
require.NoError(t, err)
semTable, err := semantics.Analyze(ast, "", schemaInfo)
require.NoError(t, err)
expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable)
require.NoError(t, err)
assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect))
if tcase.otherTbl != -1 {
assert.NotEqual(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
if tcase.sameTbl != -1 {
assert.Equal(t,
semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr),
semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr),
)
}
})
}
}

Просмотреть файл

@ -168,7 +168,7 @@ func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColN
var tblInfo *TableInfo
for _, tbl := range current.tables {
if !tbl.Table.ColumnListAuthoritative {
if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative {
return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr)))
}
for _, col := range tbl.Table.Columns {

Просмотреть файл

@ -107,6 +107,14 @@ func (st *SemTable) GetSelectTables(node *sqlparser.Select) []*TableInfo {
return scope.tables
}
// AddExprs adds new select exprs to the SemTable.
func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) {
tableSet := st.TableSetFor(tbl)
for _, col := range cols {
st.exprDependencies[col.(*sqlparser.AliasedExpr).Expr] = tableSet
}
}
func newScope(parent *scope) *scope {
return &scope{parent: parent}
}