vitess-gh/go/vt/sqlparser/ast.go

357 строки
9.1 KiB
Go

// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sqlparser
import (
"bytes"
"fmt"
"github.com/youtube/vitess/go/sqltypes"
)
type ParserError struct {
Message string
}
func NewParserError(format string, args ...interface{}) ParserError {
return ParserError{fmt.Sprintf(format, args...)}
}
func (err ParserError) Error() string {
return err.Message
}
func handleError(err *error) {
if x := recover(); x != nil {
*err = x.(ParserError)
}
}
type Node struct {
Type int
Value []byte
Sub []*Node
}
func Parse(sql string) (*Node, error) {
tokenizer := NewStringTokenizer(sql)
if yyParse(tokenizer) != 0 {
return nil, NewParserError("%s", tokenizer.LastError)
}
return tokenizer.ParseTree, nil
}
func NewSimpleParseNode(Type int, value string) *Node {
return &Node{Type: Type, Value: []byte(value)}
}
func NewParseNode(Type int, value []byte) *Node {
return &Node{Type: Type, Value: value}
}
func (node *Node) PushTwo(left *Node, right *Node) *Node {
node.Push(left)
return node.Push(right)
}
func (node *Node) Push(value *Node) *Node {
if node.Sub == nil {
node.Sub = make([]*Node, 0, 2)
}
node.Sub = append(node.Sub, value)
return node
}
func (node *Node) Pop() *Node {
node.Sub = node.Sub[:len(node.Sub)-1]
return node
}
func (node *Node) At(index int) *Node {
return node.Sub[index]
}
func (node *Node) Set(index int, val *Node) {
node.Sub[index] = val
}
func (node *Node) Len() int {
return len(node.Sub)
}
func (node *Node) LowerCase() {
node.Value = bytes.ToLower(node.Value)
}
func (node *Node) String() (out string) {
buf := NewTrackedBuffer(nil)
buf.Fprintf("%v", node)
return buf.String()
}
func (node *Node) TreeString() string {
buf := bytes.NewBuffer(make([]byte, 0, 8))
node.NodeString(0, buf)
return buf.String()
}
func (node *Node) NodeString(level int, buf *bytes.Buffer) {
for i := 0; i < level; i++ {
buf.WriteString("|-")
}
buf.Write(node.Value)
buf.WriteByte('\n')
for i := 0; i < node.Len(); i++ {
node.At(i).NodeString(level+1, buf)
}
}
// FormatNode is the standard node formatter that
// generates the SQL statement from the AST.
func FormatNode(buf *TrackedBuffer, node *Node) {
switch node.Type {
case SELECT:
buf.Fprintf("select %v%v%v from %v%v%v%v%v%v%v",
node.At(SELECT_COMMENT_OFFSET),
node.At(SELECT_DISTINCT_OFFSET),
node.At(SELECT_EXPR_OFFSET),
node.At(SELECT_FROM_OFFSET),
node.At(SELECT_WHERE_OFFSET),
node.At(SELECT_GROUP_OFFSET),
node.At(SELECT_HAVING_OFFSET),
node.At(SELECT_ORDER_OFFSET),
node.At(SELECT_LIMIT_OFFSET),
node.At(SELECT_LOCK_OFFSET),
)
case INSERT:
buf.Fprintf("insert %vinto %v%v %v%v",
node.At(INSERT_COMMENT_OFFSET),
node.At(INSERT_TABLE_OFFSET),
node.At(INSERT_COLUMN_LIST_OFFSET),
node.At(INSERT_VALUES_OFFSET),
node.At(INSERT_ON_DUP_OFFSET),
)
case UPDATE:
buf.Fprintf("update %v%v set %v%v%v%v",
node.At(UPDATE_COMMENT_OFFSET),
node.At(UPDATE_TABLE_OFFSET),
node.At(UPDATE_LIST_OFFSET),
node.At(UPDATE_WHERE_OFFSET),
node.At(UPDATE_ORDER_OFFSET),
node.At(UPDATE_LIMIT_OFFSET),
)
case DELETE:
buf.Fprintf("delete %vfrom %v%v%v%v",
node.At(DELETE_COMMENT_OFFSET),
node.At(DELETE_TABLE_OFFSET),
node.At(DELETE_WHERE_OFFSET),
node.At(DELETE_ORDER_OFFSET),
node.At(DELETE_LIMIT_OFFSET),
)
case SET:
buf.Fprintf("set %v%v", node.At(0), node.At(1))
case CREATE, ALTER, DROP:
buf.Fprintf("%s table %v", node.Value, node.At(0))
case RENAME:
buf.Fprintf("%s table %v %v", node.Value, node.At(0), node.At(1))
case TABLE_EXPR:
buf.Fprintf("%v", node.At(0))
if node.At(1).Len() == 1 {
buf.Fprintf(" as %v", node.At(1).At(0))
}
buf.Fprintf("%v", node.At(2))
case USE, FORCE:
if node.Len() != 0 {
buf.Fprintf(" %s index %v", node.Value, node.At(0))
}
case WHERE, HAVING:
if node.Len() > 0 {
buf.Fprintf(" %s %v", node.Value, node.At(0))
}
case ORDER, GROUP:
if node.Len() > 0 {
buf.Fprintf(" %s by %v", node.Value, node.At(0))
}
case LIMIT:
if node.Len() > 0 {
buf.Fprintf(" %s %v", node.Value, node.At(0))
if node.Len() > 1 {
buf.Fprintf(", %v", node.At(1))
}
}
case COLUMN_LIST, INDEX_LIST:
if node.Len() > 0 {
buf.Fprintf("(%v", node.At(0))
for i := 1; i < node.Len(); i++ {
buf.Fprintf(", %v", node.At(i))
}
buf.WriteByte(')')
}
case NODE_LIST:
if node.Len() > 0 {
buf.Fprintf("%v", node.At(0))
for i := 1; i < node.Len(); i++ {
buf.Fprintf(", %v", node.At(i))
}
}
case COMMENT_LIST:
if node.Len() > 0 {
for i := 0; i < node.Len(); i++ {
buf.Fprintf("%v", node.At(i))
}
}
case WHEN_LIST:
buf.Fprintf("%v", node.At(0))
for i := 1; i < node.Len(); i++ {
buf.Fprintf(" %v", node.At(i))
}
case JOIN, STRAIGHT_JOIN, LEFT, RIGHT, CROSS, NATURAL:
buf.Fprintf("%v %s %v", node.At(0), node.Value, node.At(1))
if node.Len() > 2 {
buf.Fprintf(" on %v", node.At(2))
}
case DUPLICATE:
if node.Len() != 0 {
buf.Fprintf(" on duplicate key update %v", node.At(0))
}
case NUMBER, NULL, SELECT_STAR, NO_DISTINCT, COMMENT, NO_LOCK, FOR_UPDATE, LOCK_IN_SHARE_MODE, TABLE:
buf.Fprintf("%s", node.Value)
case ID:
if _, ok := keywords[string(node.Value)]; ok {
buf.Fprintf("`%s`", node.Value)
} else {
buf.Fprintf("%s", node.Value)
}
case VALUE_ARG:
buf.WriteArg(string(node.Value[1:]))
case STRING:
s := sqltypes.MakeString(node.Value)
s.EncodeSql(buf)
case '+', '-', '*', '/', '%', '&', '|', '^', '.':
buf.Fprintf("%v%s%v", node.At(0), node.Value, node.At(1))
case CASE_WHEN:
buf.Fprintf("case %v end", node.At(0))
case CASE:
buf.Fprintf("case %v %v end", node.At(0), node.At(1))
case WHEN:
buf.Fprintf("when %v then %v", node.At(0), node.At(1))
case ELSE:
buf.Fprintf("else %v", node.At(0))
case '=', '>', '<', GE, LE, NE, NULL_SAFE_EQUAL, AS, AND, OR, UNION, UNION_ALL, MINUS, EXCEPT, INTERSECT, LIKE, NOT_LIKE, IN, NOT_IN:
buf.Fprintf("%v %s %v", node.At(0), node.Value, node.At(1))
case '(':
buf.Fprintf("(%v)", node.At(0))
case EXISTS:
buf.Fprintf("%s (%v)", node.Value, node.At(0))
case FUNCTION:
if node.Len() == 2 { // DISTINCT
buf.Fprintf("%s(%v%v)", node.Value, node.At(0), node.At(1))
} else {
buf.Fprintf("%s(%v)", node.Value, node.At(0))
}
case UPLUS, UMINUS, '~':
buf.Fprintf("%s%v", node.Value, node.At(0))
case NOT, VALUES:
buf.Fprintf("%s %v", node.Value, node.At(0))
case ASC, DESC, IS_NULL, IS_NOT_NULL:
buf.Fprintf("%v %s", node.At(0), node.Value)
case BETWEEN, NOT_BETWEEN:
buf.Fprintf("%v %s %v and %v", node.At(0), node.Value, node.At(1), node.At(2))
case DISTINCT:
buf.Fprintf("%s ", node.Value)
default:
buf.Fprintf("Unknown: %s", node.Value)
}
}
// AnonymizedFormatNode is just like FormatNode except that
// it anonymizes all values in the SQL.
func AnonymizedFormatNode(buf *TrackedBuffer, node *Node) {
switch node.Type {
case STRING, NUMBER:
buf.Fprintf("?")
default:
FormatNode(buf, node)
}
}
// TrackedBuffer is used to rebuild a query from the ast.
// bindLocations keeps track of locations in the buffer that
// use bind variables for efficient future substitutions.
// nodeFormatter is the formatting function the buffer will
// use to format a node. By default(nil), it's FormatNode.
// But you can supply a different formatting function if you
// want to generate a query that's different from the default.
type TrackedBuffer struct {
*bytes.Buffer
bindLocations []BindLocation
nodeFormatter func(buf *TrackedBuffer, node *Node)
}
func NewTrackedBuffer(nodeFormatter func(buf *TrackedBuffer, node *Node)) *TrackedBuffer {
if nodeFormatter == nil {
nodeFormatter = FormatNode
}
buf := &TrackedBuffer{
Buffer: bytes.NewBuffer(make([]byte, 0, 128)),
bindLocations: make([]BindLocation, 0, 4),
nodeFormatter: nodeFormatter,
}
return buf
}
// Fprintf mimics fmt.Fprintf, but limited to Node(%v), Node.Value(%s) and string(%s).
// It also allows a %a for a value argument, in which case it adds tracking info for
// future substitutions.
func (buf *TrackedBuffer) Fprintf(format string, values ...interface{}) {
end := len(format)
fieldnum := 0
for i := 0; i < end; {
lasti := i
for i < end && format[i] != '%' {
i++
}
if i > lasti {
buf.WriteString(format[lasti:i])
}
if i >= end {
break
}
i++ // '%'
switch format[i] {
case 's':
switch v := values[fieldnum].(type) {
case []byte:
buf.Write(v)
case string:
buf.WriteString(v)
default:
panic(fmt.Sprintf("unexpected type %T", v))
}
case 'v':
node := values[fieldnum].(*Node)
buf.nodeFormatter(buf, node)
case 'a':
buf.WriteArg(values[fieldnum].(string))
default:
panic("unexpected")
}
fieldnum++
i++
}
}
// WriteArg writes a value argument into the buffer. arg should not contain
// the ':' prefix. It also adds tracking info for future substitutions.
func (buf *TrackedBuffer) WriteArg(arg string) {
buf.bindLocations = append(buf.bindLocations, BindLocation{buf.Len(), len(arg) + 1})
buf.WriteString(":")
buf.WriteString(arg)
}
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
return &ParsedQuery{buf.String(), buf.bindLocations}
}