зеркало из https://github.com/github/vitess-gh.git
Only rewrite database() against dual
Signed-off-by: Andres Taylor <andres@planetscale.com>
This commit is contained in:
Родитель
041757feaa
Коммит
0995d2a4f8
|
@ -32,6 +32,7 @@ func PrepareAST(in Statement, bindVars map[string]*querypb.BindVariable, prefix
|
|||
// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries
|
||||
func RewriteAST(in Statement) (*RewriteASTResult, error) {
|
||||
er := new(expressionRewriter)
|
||||
er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in)
|
||||
Rewrite(in, er.goingDown, nil)
|
||||
|
||||
return &RewriteASTResult{
|
||||
|
@ -41,6 +42,25 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func shouldRewriteDatabaseFunc(in Statement) bool {
|
||||
selct, ok := in.(*Select)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if len(selct.From) != 1 {
|
||||
return false
|
||||
}
|
||||
aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
tableName, ok := aliasedTable.Expr.(TableName)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return tableName.Name.String() == "dual"
|
||||
}
|
||||
|
||||
// RewriteASTResult contains the rewritten ast and meta information about it
|
||||
type RewriteASTResult struct {
|
||||
AST Statement
|
||||
|
@ -49,8 +69,9 @@ type RewriteASTResult struct {
|
|||
}
|
||||
|
||||
type expressionRewriter struct {
|
||||
lastInsertID, database bool
|
||||
err error
|
||||
lastInsertID, database bool
|
||||
shouldRewriteDatabaseFunc bool
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -67,6 +88,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
|
|||
buf := NewTrackedBuffer(nil)
|
||||
node.Expr.Format(buf)
|
||||
inner := new(expressionRewriter)
|
||||
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
|
||||
tmp := Rewrite(node.Expr, inner.goingDown, nil)
|
||||
newExpr, ok := tmp.(Expr)
|
||||
if !ok {
|
||||
|
@ -91,7 +113,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool {
|
|||
cursor.Replace(bindVarExpression(LastInsertIDName))
|
||||
er.lastInsertID = true
|
||||
}
|
||||
case node.Name.EqualString("database"):
|
||||
case node.Name.EqualString("database") && er.shouldRewriteDatabaseFunc:
|
||||
if len(node.Exprs) > 0 {
|
||||
er.err = vterrors.New(vtrpc.Code_INVALID_ARGUMENT, "Syntax error. DATABASE() takes no arguments")
|
||||
} else {
|
||||
|
|
|
@ -55,14 +55,19 @@ func TestRewrites(in *testing.T) {
|
|||
db: true, liid: true,
|
||||
},
|
||||
{
|
||||
in: "select (select database() from test) from test",
|
||||
expected: "select (select :__vtdbname as `database()` from test) as `(select database() from test)` from test",
|
||||
in: "select (select database() from dual) from dual",
|
||||
expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual",
|
||||
db: true, liid: false,
|
||||
},
|
||||
{
|
||||
in: "select id from user where database()",
|
||||
expected: "select id from user where :__vtdbname",
|
||||
db: true, liid: false,
|
||||
expected: "select id from user where database()",
|
||||
db: false, liid: false,
|
||||
},
|
||||
{
|
||||
in: "select table_name from information_schema.tables where table_schema = database()",
|
||||
expected: "select table_name from information_schema.tables where table_schema = database()",
|
||||
db: false, liid: false,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -77,16 +82,10 @@ func TestRewrites(in *testing.T) {
|
|||
expected, err := Parse(tc.expected)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := toString(expected)
|
||||
require.Equal(t, s, toString(result.AST))
|
||||
s := String(expected)
|
||||
require.Equal(t, s, String(result.AST))
|
||||
require.Equal(t, tc.liid, result.NeedLastInsertID, "should need last insert id")
|
||||
require.Equal(t, tc.db, result.NeedDatabase, "should need database name")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toString(node SQLNode) string {
|
||||
buf := NewTrackedBuffer(nil)
|
||||
node.Format(buf)
|
||||
return buf.String()
|
||||
}
|
||||
|
|
|
@ -282,7 +282,6 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
|
|||
sql = comments.Leading + normalized + comments.Trailing
|
||||
if rewriteResult.NeedDatabase {
|
||||
keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString)
|
||||
log.Warningf("This is the keyspace name: ---> %v", keyspace)
|
||||
if keyspace == "" {
|
||||
bindVars[sqlparser.DBVarName] = sqltypes.NullBindVariable
|
||||
} else {
|
||||
|
|
Загрузка…
Ссылка в новой задаче