зеркало из https://github.com/github/vitess-gh.git
Improve set statements
Signed-off-by: yuananf <yuananf@gmail.com>
This commit is contained in:
Родитель
e30c1e17d5
Коммит
29fdfca3dc
|
@ -132,7 +132,7 @@ func (mp *Proxy) doSet(ctx context.Context, session *ProxySession, sql string, b
|
|||
}
|
||||
|
||||
for k, v := range vals {
|
||||
switch k {
|
||||
switch k.Key {
|
||||
case "autocommit":
|
||||
val, ok := v.(int64)
|
||||
if !ok {
|
||||
|
|
|
@ -265,11 +265,17 @@ func StringIn(str string, values ...string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// SetKey is the extracted key from one SetExpr
|
||||
type SetKey struct {
|
||||
Key string
|
||||
Scope string
|
||||
}
|
||||
|
||||
// ExtractSetValues returns a map of key-value pairs
|
||||
// if the query is a SET statement. Values can be int64 or string.
|
||||
// if the query is a SET statement. Values can be bool, int64 or string.
|
||||
// Since set variable names are case insensitive, all keys are returned
|
||||
// as lower case.
|
||||
func ExtractSetValues(sql string) (keyValues map[string]interface{}, scope string, err error) {
|
||||
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
|
||||
stmt, err := Parse(sql)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
@ -278,28 +284,59 @@ func ExtractSetValues(sql string) (keyValues map[string]interface{}, scope strin
|
|||
if !ok {
|
||||
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
|
||||
}
|
||||
result := make(map[string]interface{})
|
||||
result := make(map[SetKey]interface{})
|
||||
for _, expr := range setStmt.Exprs {
|
||||
scope := SessionStr
|
||||
key := expr.Name.Lowered()
|
||||
switch {
|
||||
case strings.HasPrefix(key, "@@global."):
|
||||
scope = GlobalStr
|
||||
key = strings.TrimPrefix(key, "@@global.")
|
||||
case strings.HasPrefix(key, "@@session."):
|
||||
key = strings.TrimPrefix(key, "@@session.")
|
||||
case strings.HasPrefix(key, "@@"):
|
||||
key = strings.TrimPrefix(key, "@@")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
|
||||
if setStmt.Scope != "" && scope != "" {
|
||||
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
|
||||
}
|
||||
_, out := NewStringTokenizer(key).Scan()
|
||||
key = string(out)
|
||||
}
|
||||
|
||||
setKey := SetKey{
|
||||
Key: key,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
switch expr := expr.Expr.(type) {
|
||||
case *SQLVal:
|
||||
switch expr.Type {
|
||||
case StrVal:
|
||||
result[key] = string(expr.Val)
|
||||
result[setKey] = string(expr.Val)
|
||||
case IntVal:
|
||||
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
result[key] = num
|
||||
result[setKey] = num
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
|
||||
}
|
||||
case BoolVal:
|
||||
var val int64
|
||||
if expr {
|
||||
val = 1
|
||||
}
|
||||
result[setKey] = val
|
||||
case *ColName:
|
||||
result[setKey] = expr.Name.String()
|
||||
case *NullVal:
|
||||
result[key] = nil
|
||||
result[setKey] = nil
|
||||
case *Default:
|
||||
result[key] = "default"
|
||||
result[setKey] = "default"
|
||||
default:
|
||||
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
|
||||
}
|
||||
|
|
|
@ -361,7 +361,7 @@ func TestStringIn(t *testing.T) {
|
|||
func TestExtractSetValues(t *testing.T) {
|
||||
testcases := []struct {
|
||||
sql string
|
||||
out map[string]interface{}
|
||||
out map[SetKey]interface{}
|
||||
scope string
|
||||
err string
|
||||
}{{
|
||||
|
@ -375,38 +375,74 @@ func TestExtractSetValues(t *testing.T) {
|
|||
err: "invalid syntax: 1 + 1",
|
||||
}, {
|
||||
sql: "set transaction_mode='single'",
|
||||
out: map[string]interface{}{"transaction_mode": "single"},
|
||||
out: map[SetKey]interface{}{{Key: "transaction_mode", Scope: "session"}: "single"},
|
||||
}, {
|
||||
sql: "set autocommit=1",
|
||||
out: map[string]interface{}{"autocommit": int64(1)},
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set autocommit=true",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set autocommit=false",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(0)},
|
||||
}, {
|
||||
sql: "set autocommit=on",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: "on"},
|
||||
}, {
|
||||
sql: "set autocommit=off",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: "off"},
|
||||
}, {
|
||||
sql: "set @@global.autocommit=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "global"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@global.autocommit=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "global"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.autocommit=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.`autocommit`=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.'autocommit'=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.\"autocommit\"=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.'\"autocommit'=1",
|
||||
out: map[SetKey]interface{}{{Key: "\"autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set @@session.`autocommit'`=1",
|
||||
out: map[SetKey]interface{}{{Key: "autocommit'", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "set AUTOCOMMIT=1",
|
||||
out: map[string]interface{}{"autocommit": int64(1)},
|
||||
out: map[SetKey]interface{}{{Key: "autocommit", Scope: "session"}: int64(1)},
|
||||
}, {
|
||||
sql: "SET character_set_results = NULL",
|
||||
out: map[string]interface{}{"character_set_results": nil},
|
||||
out: map[SetKey]interface{}{{Key: "character_set_results", Scope: "session"}: nil},
|
||||
}, {
|
||||
sql: "SET foo = 0x1234",
|
||||
err: "invalid value type: 0x1234",
|
||||
}, {
|
||||
sql: "SET names utf8",
|
||||
out: map[string]interface{}{"names": "utf8"},
|
||||
out: map[SetKey]interface{}{{Key: "names", Scope: "session"}: "utf8"},
|
||||
}, {
|
||||
sql: "SET names ascii collate ascii_bin",
|
||||
out: map[string]interface{}{"names": "ascii"},
|
||||
out: map[SetKey]interface{}{{Key: "names", Scope: "session"}: "ascii"},
|
||||
}, {
|
||||
sql: "SET charset default",
|
||||
out: map[string]interface{}{"charset": "default"},
|
||||
out: map[SetKey]interface{}{{Key: "charset", Scope: "session"}: "default"},
|
||||
}, {
|
||||
sql: "SET character set ascii",
|
||||
out: map[string]interface{}{"charset": "ascii"},
|
||||
out: map[SetKey]interface{}{{Key: "charset", Scope: "session"}: "ascii"},
|
||||
}, {
|
||||
sql: "SET SESSION wait_timeout = 3600",
|
||||
out: map[string]interface{}{"wait_timeout": int64(3600)},
|
||||
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: "session"}: int64(3600)},
|
||||
scope: "session",
|
||||
}, {
|
||||
sql: "SET GLOBAL wait_timeout = 3600",
|
||||
out: map[string]interface{}{"wait_timeout": int64(3600)},
|
||||
out: map[SetKey]interface{}{{Key: "wait_timeout", Scope: "session"}: int64(3600)},
|
||||
scope: "global",
|
||||
}}
|
||||
for _, tcase := range testcases {
|
||||
|
|
|
@ -3350,8 +3350,13 @@ func (node *TableIdent) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
func formatID(buf *TrackedBuffer, original, lowered string) {
|
||||
isDbSystemVariable := false
|
||||
if len(original) > 1 && original[:2] == "@@" {
|
||||
isDbSystemVariable = true
|
||||
}
|
||||
|
||||
for i, c := range original {
|
||||
if !isLetter(uint16(c)) {
|
||||
if !isLetter(uint16(c)) && (!isDbSystemVariable || !isCarat(uint16(c))) {
|
||||
if i == 0 || !isDigit(uint16(c)) {
|
||||
goto mustEscape
|
||||
}
|
||||
|
|
|
@ -678,6 +678,14 @@ var (
|
|||
input: "set #simple\n b = 4",
|
||||
}, {
|
||||
input: "set character_set_results = utf8",
|
||||
}, {
|
||||
input: "set @@session.autocommit = true",
|
||||
}, {
|
||||
input: "set @@session.`autocommit` = true",
|
||||
}, {
|
||||
input: "set @@session.'autocommit' = true",
|
||||
}, {
|
||||
input: "set @@session.\"autocommit\" = true",
|
||||
}, {
|
||||
input: "set names utf8 collate foo",
|
||||
output: "set names 'utf8'",
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -2663,7 +2663,11 @@ set_list:
|
|||
}
|
||||
|
||||
set_expression:
|
||||
reserved_sql_id '=' expression
|
||||
reserved_sql_id '=' ON
|
||||
{
|
||||
$$ = &SetExpr{Name: $1, Expr: NewStrVal([]byte("on"))}
|
||||
}
|
||||
| reserved_sql_id '=' expression
|
||||
{
|
||||
$$ = &SetExpr{Name: $1, Expr: $3}
|
||||
}
|
||||
|
|
|
@ -476,7 +476,11 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
|
|||
return tkn.scanBitLiteral()
|
||||
}
|
||||
}
|
||||
return tkn.scanIdentifier(byte(ch))
|
||||
isDbSystemVariable := false
|
||||
if ch == '@' && tkn.lastChar == '@' {
|
||||
isDbSystemVariable = true
|
||||
}
|
||||
return tkn.scanIdentifier(byte(ch), isDbSystemVariable)
|
||||
case isDigit(ch):
|
||||
return tkn.scanNumber(false)
|
||||
case ch == ':':
|
||||
|
@ -608,10 +612,10 @@ func (tkn *Tokenizer) skipBlank() {
|
|||
}
|
||||
}
|
||||
|
||||
func (tkn *Tokenizer) scanIdentifier(firstByte byte) (int, []byte) {
|
||||
func (tkn *Tokenizer) scanIdentifier(firstByte byte, isDbSystemVariable bool) (int, []byte) {
|
||||
buffer := &bytes2.Buffer{}
|
||||
buffer.WriteByte(firstByte)
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) {
|
||||
for isLetter(tkn.lastChar) || isDigit(tkn.lastChar) || (isDbSystemVariable && isCarat(tkn.lastChar)) {
|
||||
buffer.WriteByte(byte(tkn.lastChar))
|
||||
tkn.next()
|
||||
}
|
||||
|
@ -915,6 +919,10 @@ func isLetter(ch uint16) bool {
|
|||
return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '@'
|
||||
}
|
||||
|
||||
func isCarat(ch uint16) bool {
|
||||
return ch == '.' || ch == '\'' || ch == '"' || ch == '`'
|
||||
}
|
||||
|
||||
func digitVal(ch uint16) int {
|
||||
switch {
|
||||
case '0' <= ch && ch <= '9':
|
||||
|
|
|
@ -526,17 +526,32 @@ func (e *Executor) handleSet(ctx context.Context, safeSession *SafeSession, sql
|
|||
return &sqltypes.Result{}, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, err.Error())
|
||||
}
|
||||
|
||||
if scope == "global" {
|
||||
if scope == sqlparser.GlobalStr {
|
||||
return &sqltypes.Result{}, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported in set: global")
|
||||
}
|
||||
|
||||
for k, v := range vals {
|
||||
switch k {
|
||||
if k.Scope == sqlparser.GlobalStr {
|
||||
return &sqltypes.Result{}, vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unsupported in set: global")
|
||||
}
|
||||
switch k.Key {
|
||||
case "autocommit":
|
||||
val, ok := v.(int64)
|
||||
if !ok {
|
||||
var val int64
|
||||
switch v := v.(type) {
|
||||
case int64:
|
||||
val = v
|
||||
case string:
|
||||
if v == "on" {
|
||||
val = 1
|
||||
} else if v == "off" {
|
||||
val = 0
|
||||
} else {
|
||||
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value for autocommit: %s", v)
|
||||
}
|
||||
default:
|
||||
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected value type for autocommit: %T", v)
|
||||
}
|
||||
|
||||
switch val {
|
||||
case 0:
|
||||
safeSession.Autocommit = false
|
||||
|
|
|
@ -225,28 +225,79 @@ func TestExecutorSet(t *testing.T) {
|
|||
out *vtgatepb.Session
|
||||
err string
|
||||
}{{
|
||||
in: "set autocommit=1",
|
||||
in: "set autocommit = 1",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set @@autocommit = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set @@session.autocommit = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set @@session.`autocommit` = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set @@session.'autocommit' = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set @@session.\"autocommit\" = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = true",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = on",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = 'on'",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = `on`",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = \"on\"",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
}, {
|
||||
in: "set autocommit = false",
|
||||
out: &vtgatepb.Session{},
|
||||
}, {
|
||||
in: "set autocommit = off",
|
||||
out: &vtgatepb.Session{},
|
||||
}, {
|
||||
in: "set AUTOCOMMIT = 0",
|
||||
out: &vtgatepb.Session{},
|
||||
}, {
|
||||
in: "set AUTOCOMMIT = 'aa'",
|
||||
err: "unexpected value type for autocommit: string",
|
||||
err: "unexpected value for autocommit: aa",
|
||||
}, {
|
||||
in: "set autocommit = 2",
|
||||
err: "unexpected value for autocommit: 2",
|
||||
}, {
|
||||
in: "set client_found_rows=1",
|
||||
in: "set client_found_rows = 1",
|
||||
out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{ClientFoundRows: true}},
|
||||
}, {
|
||||
in: "set client_found_rows=0",
|
||||
in: "set client_found_rows = true",
|
||||
out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{ClientFoundRows: true}},
|
||||
}, {
|
||||
in: "set client_found_rows = 0",
|
||||
out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{}},
|
||||
}, {
|
||||
in: "set client_found_rows='aa'",
|
||||
in: "set client_found_rows = false",
|
||||
out: &vtgatepb.Session{Autocommit: true, Options: &querypb.ExecuteOptions{}},
|
||||
}, {
|
||||
in: "set @@global.client_found_rows = 1",
|
||||
err: "unsupported in set: global",
|
||||
}, {
|
||||
in: "set global client_found_rows = 1",
|
||||
err: "unsupported in set: global",
|
||||
}, {
|
||||
in: "set global @@session.client_found_rows = 1",
|
||||
err: "unsupported in set: mixed using of variable scope",
|
||||
}, {
|
||||
in: "set client_found_rows = 'aa'",
|
||||
err: "unexpected value type for client_found_rows: string",
|
||||
}, {
|
||||
in: "set client_found_rows=2",
|
||||
in: "set client_found_rows = 2",
|
||||
err: "unexpected value for client_found_rows: 2",
|
||||
}, {
|
||||
in: "set transaction_mode = 'unspecified'",
|
||||
|
@ -260,6 +311,9 @@ func TestExecutorSet(t *testing.T) {
|
|||
}, {
|
||||
in: "set transaction_mode = 'twopc'",
|
||||
out: &vtgatepb.Session{Autocommit: true, TransactionMode: vtgatepb.TransactionMode_TWOPC},
|
||||
}, {
|
||||
in: "set transaction_mode = twopc",
|
||||
out: &vtgatepb.Session{Autocommit: true, TransactionMode: vtgatepb.TransactionMode_TWOPC},
|
||||
}, {
|
||||
in: "set transaction_mode = 'aa'",
|
||||
err: "invalid transaction_mode: aa",
|
||||
|
@ -297,7 +351,7 @@ func TestExecutorSet(t *testing.T) {
|
|||
in: "set sql_select_limit = 'asdfasfd'",
|
||||
err: "unexpected string value for sql_select_limit: asdfasfd",
|
||||
}, {
|
||||
in: "set autocommit=1+1",
|
||||
in: "set autocommit = 1+1",
|
||||
err: "invalid syntax: 1 + 1",
|
||||
}, {
|
||||
in: "set character_set_results=null",
|
||||
|
@ -306,8 +360,8 @@ func TestExecutorSet(t *testing.T) {
|
|||
in: "set character_set_results='abcd'",
|
||||
err: "disallowed value for character_set_results: abcd",
|
||||
}, {
|
||||
in: "set foo=1",
|
||||
err: "unsupported construct: set foo=1",
|
||||
in: "set foo = 1",
|
||||
err: "unsupported construct: set foo = 1",
|
||||
}, {
|
||||
in: "set names utf8",
|
||||
out: &vtgatepb.Session{Autocommit: true},
|
||||
|
|
Загрузка…
Ссылка в новой задаче