sqlparser: Limit nesting of parenthesized exprs

This addresses https://github.com/youtube/vitess/issues/767
as well as other situations where there's possibility of
indefinite nesting of SQL constructs. There may be other
non-parenthesized constructs that allow nesting, but this fix
doesn't address them for now.
I've also made a few lint fixes. sql.go is still in violation,
but that requires bigger work.
This commit is contained in:
Sugu Sougoumarane 2015-07-02 06:31:34 -07:00
Родитель a9d85bec8f
Коммит ccf0fcec25
4 изменённых файлов: 487 добавлений и 444 удалений

Просмотреть файл

@ -11,4 +11,6 @@ select * from t where :1 = 2#syntax error at position 24 near :
select * from t where :. = 2#syntax error at position 24 near : select * from t where :. = 2#syntax error at position 24 near :
select * from t where ::1 = 2#syntax error at position 25 near :: select * from t where ::1 = 2#syntax error at position 25 near ::
select * from t where ::. = 2#syntax error at position 25 near :: select * from t where ::. = 2#syntax error at position 25 near ::
select(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(#max nesting level reached at position 409
select(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(F(#syntax error at position 407
select /* aa#syntax error at position 13 near /* aa select /* aa#syntax error at position 13 near /* aa

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -7,15 +7,27 @@ package sqlparser
import "bytes" import "bytes"
func SetParseTree(yylex interface{}, stmt Statement) { func setParseTree(yylex interface{}, stmt Statement) {
yylex.(*Tokenizer).ParseTree = stmt yylex.(*Tokenizer).ParseTree = stmt
} }
func SetAllowComments(yylex interface{}, allow bool) { func setAllowComments(yylex interface{}, allow bool) {
yylex.(*Tokenizer).AllowComments = allow yylex.(*Tokenizer).AllowComments = allow
} }
func ForceEOF(yylex interface{}) { func incNesting(yylex interface{}) bool {
yylex.(*Tokenizer).nesting++
if yylex.(*Tokenizer).nesting == 200 {
return true
}
return false
}
func decNesting(yylex interface{}) {
yylex.(*Tokenizer).nesting--
}
func forceEOF(yylex interface{}) {
yylex.(*Tokenizer).ForceEOF = true yylex.(*Tokenizer).ForceEOF = true
} }
@ -149,7 +161,7 @@ var (
any_command: any_command:
command command
{ {
SetParseTree(yylex, $1) setParseTree(yylex, $1)
} }
command: command:
@ -285,12 +297,12 @@ other_statement:
comment_opt: comment_opt:
{ {
SetAllowComments(yylex, true) setAllowComments(yylex, true)
} }
comment_list comment_list
{ {
$$ = $2 $$ = $2
SetAllowComments(yylex, false) setAllowComments(yylex, false)
} }
comment_list: comment_list:
@ -708,7 +720,7 @@ value_expression:
{ {
$$ = &FuncExpr{Name: $1} $$ = &FuncExpr{Name: $1}
} }
| sql_id '(' select_expression_list ')' | sql_id openb select_expression_list closeb
{ {
$$ = &FuncExpr{Name: $1, Exprs: $3} $$ = &FuncExpr{Name: $1, Exprs: $3}
} }
@ -1029,7 +1041,22 @@ sql_id:
$$ = bytes.ToLower($1) $$ = bytes.ToLower($1)
} }
openb:
'('
{
if incNesting(yylex) {
yylex.Error("max nesting level reached")
return 1
}
}
closeb:
')'
{
decNesting(yylex)
}
force_eof: force_eof:
{ {
ForceEOF(yylex) forceEOF(yylex)
} }

Просмотреть файл

@ -12,7 +12,7 @@ import (
"github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/sqltypes"
) )
const EOFCHAR = 0x100 const eofChar = 0x100
// Tokenizer is the struct used to generate SQL // Tokenizer is the struct used to generate SQL
// tokens for the parser. // tokens for the parser.
@ -26,6 +26,7 @@ type Tokenizer struct {
LastError string LastError string
posVarIndex int posVarIndex int
ParseTree Statement ParseTree Statement
nesting int
} }
// NewStringTokenizer creates a new Tokenizer for the // NewStringTokenizer creates a new Tokenizer for the
@ -157,7 +158,7 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
default: default:
tkn.next() tkn.next()
switch ch { switch ch {
case EOFCHAR: case eofChar:
return 0, nil return 0, nil
case '=', ',', ';', '(', ')', '+', '*', '%', '&', '|', '^', '~': case '=', ',', ';', '(', ')', '+', '*', '%', '&', '|', '^', '~':
return int(ch), nil return int(ch), nil
@ -169,9 +170,8 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
case '.': case '.':
if isDigit(tkn.lastChar) { if isDigit(tkn.lastChar) {
return tkn.scanNumber(true) return tkn.scanNumber(true)
} else {
return int(ch), nil
} }
return int(ch), nil
case '/': case '/':
switch tkn.lastChar { switch tkn.lastChar {
case '/': case '/':
@ -187,9 +187,8 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
if tkn.lastChar == '-' { if tkn.lastChar == '-' {
tkn.next() tkn.next()
return tkn.scanCommentType1("--") return tkn.scanCommentType1("--")
} else {
return int(ch), nil
} }
return int(ch), nil
case '<': case '<':
switch tkn.lastChar { switch tkn.lastChar {
case '>': case '>':
@ -211,16 +210,14 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
if tkn.lastChar == '=' { if tkn.lastChar == '=' {
tkn.next() tkn.next()
return GE, nil return GE, nil
} else {
return int(ch), nil
} }
return int(ch), nil
case '!': case '!':
if tkn.lastChar == '=' { if tkn.lastChar == '=' {
tkn.next() tkn.next()
return NE, nil return NE, nil
} else {
return LEX_ERROR, []byte("!")
} }
return LEX_ERROR, []byte("!")
case '\'', '"': case '\'', '"':
return tkn.scanString(ch, STRING) return tkn.scanString(ch, STRING)
case '`': case '`':
@ -246,8 +243,8 @@ func (tkn *Tokenizer) scanIdentifier() (int, []byte) {
buffer.WriteByte(byte(tkn.lastChar)) buffer.WriteByte(byte(tkn.lastChar))
} }
lowered := bytes.ToLower(buffer.Bytes()) lowered := bytes.ToLower(buffer.Bytes())
if keywordId, found := keywords[string(lowered)]; found { if keywordID, found := keywords[string(lowered)]; found {
return keywordId, lowered return keywordID, lowered
} }
return ID, buffer.Bytes() return ID, buffer.Bytes()
} }
@ -290,7 +287,7 @@ func (tkn *Tokenizer) scanBindVar() (int, []byte) {
func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes.Buffer) { func (tkn *Tokenizer) scanMantissa(base int, buffer *bytes.Buffer) {
for digitVal(tkn.lastChar) < base { for digitVal(tkn.lastChar) < base {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
} }
} }
@ -304,10 +301,10 @@ func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
if tkn.lastChar == '0' { if tkn.lastChar == '0' {
// int or float // int or float
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
if tkn.lastChar == 'x' || tkn.lastChar == 'X' { if tkn.lastChar == 'x' || tkn.lastChar == 'X' {
// hexadecimal int // hexadecimal int
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
tkn.scanMantissa(16, buffer) tkn.scanMantissa(16, buffer)
} else { } else {
// octal int or float // octal int or float
@ -334,15 +331,15 @@ func (tkn *Tokenizer) scanNumber(seenDecimalPoint bool) (int, []byte) {
fraction: fraction:
if tkn.lastChar == '.' { if tkn.lastChar == '.' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
tkn.scanMantissa(10, buffer) tkn.scanMantissa(10, buffer)
} }
exponent: exponent:
if tkn.lastChar == 'e' || tkn.lastChar == 'E' { if tkn.lastChar == 'e' || tkn.lastChar == 'E' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
if tkn.lastChar == '+' || tkn.lastChar == '-' { if tkn.lastChar == '+' || tkn.lastChar == '-' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
} }
tkn.scanMantissa(10, buffer) tkn.scanMantissa(10, buffer)
} }
@ -363,7 +360,7 @@ func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
break break
} }
} else if ch == '\\' { } else if ch == '\\' {
if tkn.lastChar == EOFCHAR { if tkn.lastChar == eofChar {
return LEX_ERROR, buffer.Bytes() return LEX_ERROR, buffer.Bytes()
} }
if decodedChar := sqltypes.SqlDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DONTESCAPE { if decodedChar := sqltypes.SqlDecodeMap[byte(tkn.lastChar)]; decodedChar == sqltypes.DONTESCAPE {
@ -373,7 +370,7 @@ func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
} }
tkn.next() tkn.next()
} }
if ch == EOFCHAR { if ch == eofChar {
return LEX_ERROR, buffer.Bytes() return LEX_ERROR, buffer.Bytes()
} }
buffer.WriteByte(byte(ch)) buffer.WriteByte(byte(ch))
@ -384,12 +381,12 @@ func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, []byte) {
func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) { func (tkn *Tokenizer) scanCommentType1(prefix string) (int, []byte) {
buffer := bytes.NewBuffer(make([]byte, 0, 8)) buffer := bytes.NewBuffer(make([]byte, 0, 8))
buffer.WriteString(prefix) buffer.WriteString(prefix)
for tkn.lastChar != EOFCHAR { for tkn.lastChar != eofChar {
if tkn.lastChar == '\n' { if tkn.lastChar == '\n' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
break break
} }
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
} }
return COMMENT, buffer.Bytes() return COMMENT, buffer.Bytes()
} }
@ -399,23 +396,23 @@ func (tkn *Tokenizer) scanCommentType2() (int, []byte) {
buffer.WriteString("/*") buffer.WriteString("/*")
for { for {
if tkn.lastChar == '*' { if tkn.lastChar == '*' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
if tkn.lastChar == '/' { if tkn.lastChar == '/' {
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
break break
} }
continue continue
} }
if tkn.lastChar == EOFCHAR { if tkn.lastChar == eofChar {
return LEX_ERROR, buffer.Bytes() return LEX_ERROR, buffer.Bytes()
} }
tkn.ConsumeNext(buffer) tkn.consumeNext(buffer)
} }
return COMMENT, buffer.Bytes() return COMMENT, buffer.Bytes()
} }
func (tkn *Tokenizer) ConsumeNext(buffer *bytes.Buffer) { func (tkn *Tokenizer) consumeNext(buffer *bytes.Buffer) {
if tkn.lastChar == EOFCHAR { if tkn.lastChar == eofChar {
// This should never happen. // This should never happen.
panic("unexpected EOF") panic("unexpected EOF")
} }
@ -426,7 +423,7 @@ func (tkn *Tokenizer) ConsumeNext(buffer *bytes.Buffer) {
func (tkn *Tokenizer) next() { func (tkn *Tokenizer) next() {
if ch, err := tkn.InStream.ReadByte(); err != nil { if ch, err := tkn.InStream.ReadByte(); err != nil {
// Only EOF is possible. // Only EOF is possible.
tkn.lastChar = EOFCHAR tkn.lastChar = eofChar
} else { } else {
tkn.lastChar = uint16(ch) tkn.lastChar = uint16(ch)
} }