diff --git a/go/vt/sqlparser/analyzer.go b/go/vt/sqlparser/analyzer.go index eb090ee819..c603635a2a 100644 --- a/go/vt/sqlparser/analyzer.go +++ b/go/vt/sqlparser/analyzer.go @@ -47,6 +47,15 @@ func IsValue(node ValExpr) bool { return false } +// IsNull returns true if the ValExpr is nil. +func IsNull(node ValExpr) bool { + switch node.(type) { + case *NullVal: + return true + } + return false +} + // HasINClause returns true if any of the conditions has an IN clause. func HasINClause(conditions []BoolExpr) bool { for _, node := range conditions { diff --git a/go/vt/tabletserver/endtoend/nocache_case_test.go b/go/vt/tabletserver/endtoend/nocache_case_test.go index 3606cd5de5..d1c051aa13 100644 --- a/go/vt/tabletserver/endtoend/nocache_case_test.go +++ b/go/vt/tabletserver/endtoend/nocache_case_test.go @@ -617,6 +617,30 @@ func TestNocacheCases(t *testing.T) { framework.TestQuery("commit"), }, }, + &framework.MultiCase{ + Name: "insert with null auto_increment", + Cases: []framework.Testable{ + framework.TestQuery("alter table vitess_e auto_increment = 1"), + framework.TestQuery("begin"), + &framework.TestCase{ + Query: "insert /* auto_increment */ into vitess_e(eid, name, foo) values (NULL, 'aaaa', 'cccc')", + Rewritten: []string{ + "insert /* auto_increment */ into vitess_e(eid, name, foo) values (null, 'aaaa', 'cccc') /* _stream vitess_e (eid id name ) (null 1 'YWFhYQ==' )", + }, + RowsAffected: 1, + }, + framework.TestQuery("commit"), + &framework.TestCase{ + Query: "select * from vitess_e", + Result: [][]string{ + {"1", "1", "aaaa", "cccc"}, + }, + }, + framework.TestQuery("begin"), + framework.TestQuery("delete from vitess_e"), + framework.TestQuery("commit"), + }, + }, &framework.MultiCase{ Name: "insert with number default value", Cases: []framework.Testable{ diff --git a/go/vt/tabletserver/planbuilder/dml.go b/go/vt/tabletserver/planbuilder/dml.go index 03ae8b6c45..0c29c4421c 100644 --- a/go/vt/tabletserver/planbuilder/dml.go +++ b/go/vt/tabletserver/planbuilder/dml.go @@ -494,7 +494,7 @@ func getInsertPKValues(pkColumnNumbers []int, rowList sqlparser.Values, tableInf return nil, errors.New("column count doesn't match value count") } node := row[columnNumber] - if !sqlparser.IsValue(node) { + if !sqlparser.IsNull(node) && !sqlparser.IsValue(node) { return nil, nil } var err error