From e8aa83fcdc7ba473e9fa0214fbd72e12041b7001 Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Mon, 2 Nov 2015 21:17:46 -0800 Subject: [PATCH 1/6] query.proto: bind var conversion tweak --- go/vt/tabletserver/proto/proto3.go | 34 +++++++++++++++--------------- go/vt/vtgate/proto/proto3.go | 6 +----- go/vt/vtgate/topo_utils.go | 7 +----- 3 files changed, 19 insertions(+), 28 deletions(-) diff --git a/go/vt/tabletserver/proto/proto3.go b/go/vt/tabletserver/proto/proto3.go index 8494e64720..4cc23a2757 100644 --- a/go/vt/tabletserver/proto/proto3.go +++ b/go/vt/tabletserver/proto/proto3.go @@ -268,21 +268,20 @@ func Proto3ToBindVariables(bv map[string]*pb.BindVariable) (map[string]interface result := make(map[string]interface{}) var err error for k, v := range bv { - if v != nil && v.Type == sqltypes.Tuple { + if v == nil { + continue + } + if v.Type == sqltypes.Tuple { list := make([]interface{}, len(v.Values)) for i, lv := range v.Values { - asbind := &pb.BindVariable{ - Type: lv.Type, - Value: lv.Value, - } - list[i], err = BindVariableToNative(asbind) + list[i], err = SQLToNative(lv.Type, lv.Value) if err != nil { return nil, err } } result[k] = list } else { - result[k], err = BindVariableToNative(v) + result[k], err = SQLToNative(v.Type, v.Value) if err != nil { return nil, err } @@ -291,18 +290,19 @@ func Proto3ToBindVariables(bv map[string]*pb.BindVariable) (map[string]interface return result, nil } -// BindVariableToNative converts a proto bind var to a native go type. -func BindVariableToNative(v *pb.BindVariable) (interface{}, error) { - if v == nil || v.Type == sqltypes.Null { +// SQLToNative converts a SQL type & value to a native go type. +// This does not work for sqltypes.Tuple. +func SQLToNative(typ pb.Type, val []byte) (interface{}, error) { + if typ == sqltypes.Null { return nil, nil - } else if sqltypes.IsSigned(v.Type) { - return strconv.ParseInt(string(v.Value), 0, 64) - } else if sqltypes.IsUnsigned(v.Type) { - return strconv.ParseUint(string(v.Value), 0, 64) - } else if sqltypes.IsFloat(v.Type) { - return strconv.ParseFloat(string(v.Value), 64) + } else if sqltypes.IsSigned(typ) { + return strconv.ParseInt(string(val), 0, 64) + } else if sqltypes.IsUnsigned(typ) { + return strconv.ParseUint(string(val), 0, 64) + } else if sqltypes.IsFloat(typ) { + return strconv.ParseFloat(string(val), 64) } - return v.Value, nil + return val, nil } // Proto3ToQueryResultList converts a proto3 QueryResult to an internal data structure. diff --git a/go/vt/vtgate/proto/proto3.go b/go/vt/vtgate/proto/proto3.go index ff0ef9e2dd..75bb763203 100644 --- a/go/vt/vtgate/proto/proto3.go +++ b/go/vt/vtgate/proto/proto3.go @@ -83,11 +83,7 @@ func ProtoToEntityIds(l []*pb.ExecuteEntityIdsRequest_EntityId) []EntityId { result := make([]EntityId, len(l)) for i, e := range l { result[i].KeyspaceID = key.KeyspaceId(e.KeyspaceId) - bv := &pbq.BindVariable{ - Type: e.XidType, - Value: e.XidValue, - } - v, err := tproto.BindVariableToNative(bv) + v, err := tproto.SQLToNative(e.XidType, e.XidValue) if err != nil { panic(err) } diff --git a/go/vt/vtgate/topo_utils.go b/go/vt/vtgate/topo_utils.go index dc7b842f5d..483d81f28d 100644 --- a/go/vt/vtgate/topo_utils.go +++ b/go/vt/vtgate/topo_utils.go @@ -17,7 +17,6 @@ import ( "github.com/youtube/vitess/go/vt/vtgate/proto" "golang.org/x/net/context" - "github.com/youtube/vitess/go/vt/proto/query" pb "github.com/youtube/vitess/go/vt/proto/topodata" pbg "github.com/youtube/vitess/go/vt/proto/vtgate" "github.com/youtube/vitess/go/vt/proto/vtrpc" @@ -104,11 +103,7 @@ func mapEntityIdsToShards(ctx context.Context, topoServ SrvTopoServer, cell, key if err != nil { return "", nil, err } - bv := &query.BindVariable{ - Type: eid.XidType, - Value: eid.XidValue, - } - v, _ := tproto.BindVariableToNative(bv) + v, _ := tproto.SQLToNative(eid.XidType, eid.XidValue) shards[shard] = append(shards[shard], v) } return keyspace, shards, nil From be45074302485bd165c60b4f50dec11e3c145a5a Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 5 Nov 2015 21:03:18 -0800 Subject: [PATCH 2/6] schema: use vitess type system --- go/vt/schema/schema.go | 36 ++++++-------------- go/vt/tabletserver/codex.go | 5 ++- go/vt/tabletserver/codex_test.go | 25 +++++++------- go/vt/tabletserver/endtoend/cache_test.go | 12 +++---- go/vt/tabletserver/planbuilder/dml.go | 2 +- go/vt/tabletserver/query_executor.go | 4 +-- go/vt/tabletserver/query_executor_test.go | 18 ++++++++-- go/vt/tabletserver/query_splitter_test.go | 8 ++--- go/vt/tabletserver/rowcache.go | 3 +- go/vt/tabletserver/rowcache_invalidator.go | 4 +-- go/vt/tabletserver/schema_info.go | 12 +++---- go/vt/tabletserver/schema_info_test.go | 39 +++++++++++++++++++--- go/vt/tabletserver/schemaz.go | 10 +++--- go/vt/tabletserver/schemaz_test.go | 18 +++++----- go/vt/tabletserver/table_info.go | 30 ++++++++++++++--- go/vt/tabletserver/table_info_test.go | 29 ++++++++++++++-- go/vt/tabletserver/tabletserver_test.go | 3 ++ go/vt/vtgate/proto/proto3.go | 1 - 18 files changed, 166 insertions(+), 93 deletions(-) diff --git a/go/vt/schema/schema.go b/go/vt/schema/schema.go index 6d7d5ebfcc..aa448f7a39 100644 --- a/go/vt/schema/schema.go +++ b/go/vt/schema/schema.go @@ -8,32 +8,24 @@ package schema // It contains a data structure that's shared between sqlparser & tabletserver import ( - "strings" - "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/sync2" -) - -// Column categories -const ( - CAT_OTHER = iota - CAT_NUMBER - CAT_VARBINARY + "github.com/youtube/vitess/go/vt/proto/query" ) // Cache types const ( - CACHE_NONE = 0 - CACHE_RW = 1 - CACHE_W = 2 + CacheNone = 0 + CacheRW = 1 + CacheW = 2 ) // TableColumn contains info about a table's column. type TableColumn struct { - Name string - Category int - IsAuto bool - Default sqltypes.Value + Name string + Type query.Type + IsAuto bool + Default sqltypes.Value } // Table contains info about a table. @@ -61,16 +53,10 @@ func NewTable(name string) *Table { } // AddColumn adds a column to the Table. -func (ta *Table) AddColumn(name string, columnType string, defval sqltypes.Value, extra string) { +func (ta *Table) AddColumn(name string, columnType query.Type, defval sqltypes.Value, extra string) { index := len(ta.Columns) ta.Columns = append(ta.Columns, TableColumn{Name: name}) - if strings.Contains(columnType, "int") { - ta.Columns[index].Category = CAT_NUMBER - } else if strings.HasPrefix(columnType, "varbinary") { - ta.Columns[index].Category = CAT_VARBINARY - } else { - ta.Columns[index].Category = CAT_OTHER - } + ta.Columns[index].Type = columnType if extra == "auto_increment" { ta.Columns[index].IsAuto = true // Ignore default value, if any @@ -79,7 +65,7 @@ func (ta *Table) AddColumn(name string, columnType string, defval sqltypes.Value if defval.IsNull() { return } - if ta.Columns[index].Category == CAT_NUMBER { + if sqltypes.IsIntegral(ta.Columns[index].Type) { ta.Columns[index].Default = sqltypes.MakeNumeric(defval.Raw()) } else { ta.Columns[index].Default = sqltypes.MakeString(defval.Raw()) diff --git a/go/vt/tabletserver/codex.go b/go/vt/tabletserver/codex.go index 891430a219..f0cb8a8278 100644 --- a/go/vt/tabletserver/codex.go +++ b/go/vt/tabletserver/codex.go @@ -176,12 +176,11 @@ func validateValue(col *schema.TableColumn, value sqltypes.Value) error { if value.IsNull() { return nil } - switch col.Category { - case schema.CAT_NUMBER: + if sqltypes.IsIntegral(col.Type) { if !value.IsNumeric() { return NewTabletError(ErrFail, vtrpc.ErrorCode_BAD_INPUT, "type mismatch, expecting numeric type for %v for column: %v", value, col) } - case schema.CAT_VARBINARY: + } else if col.Type == sqltypes.VarBinary { if !value.IsString() { return NewTabletError(ErrFail, vtrpc.ErrorCode_BAD_INPUT, "type mismatch, expecting string type for %v for column: %v", value, col) } diff --git a/go/vt/tabletserver/codex_test.go b/go/vt/tabletserver/codex_test.go index 3779ad212f..6fd8578745 100644 --- a/go/vt/tabletserver/codex_test.go +++ b/go/vt/tabletserver/codex_test.go @@ -11,13 +11,14 @@ import ( "testing" "github.com/youtube/vitess/go/sqltypes" + "github.com/youtube/vitess/go/vt/proto/query" "github.com/youtube/vitess/go/vt/schema" ) func TestCodexBuildValuesList(t *testing.T) { tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) // simple PK clause. e.g. where pk1 = 1 @@ -201,7 +202,7 @@ func TestCodexResolvePKValues(t *testing.T) { testUtils := newTestUtils() tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) key := "var" bindVariables := make(map[string]interface{}) @@ -237,7 +238,7 @@ func TestCodexResolveListArg(t *testing.T) { testUtils := newTestUtils() tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) key := "var" @@ -264,7 +265,7 @@ func TestCodexBuildSecondaryList(t *testing.T) { pk2 := "pk2" tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{pk1, pk2}) // set pk2 = 'xyz' where pk1=1 and pk2 = 'abc' @@ -295,7 +296,7 @@ func TestCodexBuildStreamComment(t *testing.T) { pk2 := "pk2" tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{pk1, pk2}) // set pk2 = 'xyz' where pk1=1 and pk2 = 'abc' @@ -318,7 +319,7 @@ func TestCodexResolveValueWithIncompatibleValueType(t *testing.T) { testUtils := newTestUtils() tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) _, err := resolveValue(tableInfo.GetPKColumn(0), 0, nil) testUtils.checkTabletError(t, err, ErrFail, "incompatible value type ") @@ -328,7 +329,7 @@ func TestCodexValidateRow(t *testing.T) { testUtils := newTestUtils() tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) // #columns and #rows do not match err := validateRow(&tableInfo, []int{1}, []sqltypes.Value{}) @@ -414,7 +415,7 @@ func TestCodexApplyFilterWithPKDefaults(t *testing.T) { testUtils := newTestUtils() tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) output := applyFilterWithPKDefaults(&tableInfo, []int{-1}, []sqltypes.Value{}) if len(output) != 1 { @@ -432,7 +433,7 @@ func TestCodexValidateKey(t *testing.T) { queryServiceStats := NewQueryServiceStats("", false) tableInfo := createTableInfo("Table", []string{"pk1", "pk2", "col1"}, - []string{"int", "varbinary(128)", "int"}, + []query.Type{sqltypes.Int64, sqltypes.VarBinary, sqltypes.Int32}, []string{"pk1", "pk2"}) // validate empty key newKey := validateKey(&tableInfo, "", queryServiceStats) @@ -473,14 +474,14 @@ func TestCodexUnicoded(t *testing.T) { } func createTableInfo( - name string, colNames []string, colTypes []string, pKeys []string) TableInfo { + name string, colNames []string, colTypes []query.Type, pKeys []string) TableInfo { table := schema.NewTable(name) for i, colName := range colNames { colType := colTypes[i] defaultVal := sqltypes.Value{} - if strings.Contains(colType, "int") { + if sqltypes.IsIntegral(colType) { defaultVal = sqltypes.MakeNumeric([]byte("0")) - } else if strings.HasPrefix(colType, "varbinary") { + } else if colType == sqltypes.VarBinary { defaultVal = sqltypes.MakeString([]byte("")) } table.AddColumn(colName, colType, defaultVal, "") diff --git a/go/vt/tabletserver/endtoend/cache_test.go b/go/vt/tabletserver/endtoend/cache_test.go index 2a7cb3811f..37c55ca9ef 100644 --- a/go/vt/tabletserver/endtoend/cache_test.go +++ b/go/vt/tabletserver/endtoend/cache_test.go @@ -42,8 +42,8 @@ func TestUncacheableTables(t *testing.T) { t.Errorf("%s: table vitess_nocache not found in schema", tcase.create) continue } - if table.CacheType != schema.CACHE_NONE { - t.Errorf("CacheType: %d, want %d", table.CacheType, schema.CACHE_NONE) + if table.CacheType != schema.CacheNone { + t.Errorf("CacheType: %d, want %d", table.CacheType, schema.CacheNone) } } } @@ -54,16 +54,16 @@ func TestOverrideTables(t *testing.T) { cacheType int }{{ table: "vitess_cached2", - cacheType: schema.CACHE_RW, + cacheType: schema.CacheRW, }, { table: "vitess_view", - cacheType: schema.CACHE_RW, + cacheType: schema.CacheRW, }, { table: "vitess_part1", - cacheType: schema.CACHE_W, + cacheType: schema.CacheW, }, { table: "vitess_part2", - cacheType: schema.CACHE_W, + cacheType: schema.CacheW, }} for _, tcase := range testCases { table, ok := framework.DebugSchema()[tcase.table] diff --git a/go/vt/tabletserver/planbuilder/dml.go b/go/vt/tabletserver/planbuilder/dml.go index 5443651d20..03ae8b6c45 100644 --- a/go/vt/tabletserver/planbuilder/dml.go +++ b/go/vt/tabletserver/planbuilder/dml.go @@ -187,7 +187,7 @@ func analyzeSelect(sel *sqlparser.Select, getTable TableGetter) (plan *ExecPlan, } // Further improvements possible only if table is row-cached - if tableInfo.CacheType == schema.CACHE_NONE || tableInfo.CacheType == schema.CACHE_W { + if tableInfo.CacheType == schema.CacheNone || tableInfo.CacheType == schema.CacheW { plan.Reason = ReasonNocache return plan, nil } diff --git a/go/vt/tabletserver/query_executor.go b/go/vt/tabletserver/query_executor.go index e6faff6678..c0ee82b88c 100644 --- a/go/vt/tabletserver/query_executor.go +++ b/go/vt/tabletserver/query_executor.go @@ -84,7 +84,7 @@ func (qre *QueryExecutor) Execute() (reply *mproto.QueryResult, err error) { defer conn.Recycle() conn.RecordQuery(qre.query) var invalidator CacheInvalidator - if qre.plan.TableInfo != nil && qre.plan.TableInfo.CacheType != schema.CACHE_NONE { + if qre.plan.TableInfo != nil && qre.plan.TableInfo.CacheType != schema.CacheNone { invalidator = conn.DirtyKeys(qre.plan.TableName) } switch qre.plan.PlanID { @@ -186,7 +186,7 @@ func (qre *QueryExecutor) execDmlAutoCommit() (reply *mproto.QueryResult, err er defer conn.Recycle() conn.RecordQuery(qre.query) var invalidator CacheInvalidator - if qre.plan.TableInfo != nil && qre.plan.TableInfo.CacheType != schema.CACHE_NONE { + if qre.plan.TableInfo != nil && qre.plan.TableInfo.CacheType != schema.CacheNone { invalidator = conn.DirtyKeys(qre.plan.TableName) } switch qre.plan.PlanID { diff --git a/go/vt/tabletserver/query_executor_test.go b/go/vt/tabletserver/query_executor_test.go index d989373d96..446009bda9 100644 --- a/go/vt/tabletserver/query_executor_test.go +++ b/go/vt/tabletserver/query_executor_test.go @@ -1063,9 +1063,9 @@ func initQueryExecutorTestDB(db *fakesqldb.DB) { func getTestTableFields() []mproto.Field { return []mproto.Field{ - mproto.Field{Name: "pk", Type: mproto.VT_LONG}, - mproto.Field{Name: "name", Type: mproto.VT_LONG}, - mproto.Field{Name: "addr", Type: mproto.VT_LONG}, + mproto.Field{Name: "pk", Type: mysql.TypeLong}, + mproto.Field{Name: "name", Type: mysql.TypeLong}, + mproto.Field{Name: "addr", Type: mysql.TypeLong}, } } @@ -1125,6 +1125,18 @@ func getQueryExecutorSupportedQueries() map[string]*mproto.QueryResult { }, }, }, + "select * from `test_table` where 1 != 1": &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }, { + Name: "name", + Type: mysql.TypeLong, + }, { + Name: "addr", + Type: mysql.TypeLong, + }}, + }, "describe `test_table`": &mproto.QueryResult{ RowsAffected: 3, Rows: [][]sqltypes.Value{ diff --git a/go/vt/tabletserver/query_splitter_test.go b/go/vt/tabletserver/query_splitter_test.go index 127d5f5068..32103f0956 100644 --- a/go/vt/tabletserver/query_splitter_test.go +++ b/go/vt/tabletserver/query_splitter_test.go @@ -19,9 +19,9 @@ func getSchemaInfo() *SchemaInfo { Name: "test_table", } zero, _ := sqltypes.BuildValue(0) - table.AddColumn("id", "int", zero, "") - table.AddColumn("id2", "int", zero, "") - table.AddColumn("count", "int", zero, "") + table.AddColumn("id", sqltypes.Int64, zero, "") + table.AddColumn("id2", sqltypes.Int64, zero, "") + table.AddColumn("count", sqltypes.Int64, zero, "") table.PKColumns = []int{0} primaryIndex := table.AddIndex("PRIMARY") primaryIndex.AddColumn("id", 12345) @@ -35,7 +35,7 @@ func getSchemaInfo() *SchemaInfo { tableNoPK := &schema.Table{ Name: "test_table_no_pk", } - tableNoPK.AddColumn("id", "int", zero, "") + tableNoPK.AddColumn("id", sqltypes.Int64, zero, "") tableNoPK.PKColumns = []int{} tables["test_table_no_pk"] = &TableInfo{Table: tableNoPK} diff --git a/go/vt/tabletserver/rowcache.go b/go/vt/tabletserver/rowcache.go index 6bee73d99e..94a50034e8 100644 --- a/go/vt/tabletserver/rowcache.go +++ b/go/vt/tabletserver/rowcache.go @@ -12,7 +12,6 @@ import ( "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/stats" "github.com/youtube/vitess/go/vt/proto/vtrpc" - "github.com/youtube/vitess/go/vt/schema" "golang.org/x/net/context" ) @@ -170,7 +169,7 @@ func (rc *RowCache) decodeRow(b []byte) (row []sqltypes.Value) { // Corrupt data return nil } - if rc.tableInfo.Columns[i].Category == schema.CAT_NUMBER { + if sqltypes.IsIntegral(rc.tableInfo.Columns[i].Type) { row[i] = sqltypes.MakeNumeric(data[:length]) } else { row[i] = sqltypes.MakeString(data[:length]) diff --git a/go/vt/tabletserver/rowcache_invalidator.go b/go/vt/tabletserver/rowcache_invalidator.go index 44b899c5b7..7a92e56c95 100644 --- a/go/vt/tabletserver/rowcache_invalidator.go +++ b/go/vt/tabletserver/rowcache_invalidator.go @@ -187,7 +187,7 @@ func (rci *RowcacheInvalidator) handleDMLEvent(event *blproto.StreamEvent) { if tableInfo == nil { panic(NewTabletError(ErrFail, vtrpc.ErrorCode_BAD_INPUT, "Table %s not found", event.TableName)) } - if tableInfo.CacheType == schema.CACHE_NONE { + if tableInfo.CacheType == schema.CacheNone { return } @@ -251,7 +251,7 @@ func (rci *RowcacheInvalidator) handleUnrecognizedEvent(sql string) { rci.qe.queryServiceStats.InternalErrors.Add("Invalidation", 1) return } - if tableInfo.CacheType == schema.CACHE_NONE { + if tableInfo.CacheType == schema.CacheNone { return } diff --git a/go/vt/tabletserver/schema_info.go b/go/vt/tabletserver/schema_info.go index 4cd63215e5..a404c03258 100644 --- a/go/vt/tabletserver/schema_info.go +++ b/go/vt/tabletserver/schema_info.go @@ -233,10 +233,10 @@ func (si *SchemaInfo) override() { } switch override.Cache.Type { case "RW": - table.CacheType = schema.CACHE_RW + table.CacheType = schema.CacheRW table.Cache = NewRowCache(table, si.cachePool) case "W": - table.CacheType = schema.CACHE_W + table.CacheType = schema.CacheW if override.Cache.Table == "" { log.Warningf("Incomplete cache specs: %v", override) continue @@ -381,7 +381,7 @@ func (si *SchemaInfo) CreateOrUpdateTable(ctx context.Context, tableName string) } si.tables[tableName] = tableInfo - if tableInfo.CacheType == schema.CACHE_NONE { + if tableInfo.CacheType == schema.CacheNone { log.Infof("Initialized table: %s", tableName) } else { log.Infof("Initialized cached table: %s, prefix: %s", tableName, tableInfo.Cache.prefix) @@ -555,7 +555,7 @@ func (si *SchemaInfo) getRowcacheStats() map[string]int64 { defer si.mu.Unlock() tstats := make(map[string]int64) for k, v := range si.tables { - if v.CacheType != schema.CACHE_NONE { + if v.CacheType != schema.CacheNone { hits, absent, misses, _ := v.Stats() tstats[k+".Hits"] = hits tstats[k+".Absent"] = absent @@ -570,7 +570,7 @@ func (si *SchemaInfo) getRowcacheInvalidations() map[string]int64 { defer si.mu.Unlock() tstats := make(map[string]int64) for k, v := range si.tables { - if v.CacheType != schema.CACHE_NONE { + if v.CacheType != schema.CacheNone { _, _, _, invalidations := v.Stats() tstats[k] = invalidations } @@ -745,7 +745,7 @@ func (si *SchemaInfo) handleHTTPTableStats(response http.ResponseWriter, request si.mu.Lock() defer si.mu.Unlock() for k, v := range si.tables { - if v.CacheType != schema.CACHE_NONE { + if v.CacheType != schema.CacheNone { temp.hits, temp.absent, temp.misses, temp.invalidations = v.Stats() tstats[k] = temp totals.hits += temp.hits diff --git a/go/vt/tabletserver/schema_info_test.go b/go/vt/tabletserver/schema_info_test.go index 4038971ebd..4907f5815a 100644 --- a/go/vt/tabletserver/schema_info_test.go +++ b/go/vt/tabletserver/schema_info_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/youtube/vitess/go/mysql" mproto "github.com/youtube/vitess/go/mysql/proto" "github.com/youtube/vitess/go/sqldb" "github.com/youtube/vitess/go/sqltypes" @@ -148,7 +149,7 @@ func TestSchemaInfoOpenFailedDueToTableInfoErr(t *testing.T) { createTestTableBaseShowTable("test_table"), }, }) - db.AddQuery("describe `test_table`", &mproto.QueryResult{ + db.AddQuery("select * from `test_table` where 1 != 1", &mproto.QueryResult{ // this will cause NewTableInfo error RowsAffected: math.MaxUint64, }) @@ -180,14 +181,14 @@ func TestSchemaInfoOpenWithSchemaOverride(t *testing.T) { // test cache type RW schemaInfo.Open(&appParams, &dbaParams, schemaOverrides, true) testTableInfo := schemaInfo.GetTable("test_table_01") - if testTableInfo.Table.CacheType != schema.CACHE_RW { + if testTableInfo.Table.CacheType != schema.CacheRW { t.Fatalf("test_table_01's cache type should be RW") } schemaInfo.Close() // test cache type W schemaInfo.Open(&appParams, &dbaParams, schemaOverrides, true) testTableInfo = schemaInfo.GetTable("test_table_02") - if testTableInfo.Table.CacheType != schema.CACHE_W { + if testTableInfo.Table.CacheType != schema.CacheW { t.Fatalf("test_table_02's cache type should be W") } schemaInfo.Close() @@ -235,11 +236,16 @@ func TestSchemaInfoReload(t *testing.T) { }, }) + db.AddQuery("select * from `test_table_04` where 1 != 1", &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }) db.AddQuery("describe `test_table_04`", &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{createTestTableDescribe("pk")}, }) - db.AddQuery("show index from `test_table_04`", &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{createTestTableShowIndex("pk")}, @@ -490,6 +496,13 @@ func TestUpdatedMysqlStats(t *testing.T) { createTestTableBaseShowTable(tableName), }, }) + q = fmt.Sprintf("select * from `%s` where 1 != 1", tableName) + db.AddQuery(q, &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }) q = fmt.Sprintf("describe `%s`", tableName) db.AddQuery(q, &mproto.QueryResult{ RowsAffected: 1, @@ -797,6 +810,12 @@ func getSchemaInfoTestSupportedQueries() map[string]*mproto.QueryResult { }, }, }, + "select * from `test_table_01` where 1 != 1": &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }, "describe `test_table_01`": &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ @@ -810,6 +829,12 @@ func getSchemaInfoTestSupportedQueries() map[string]*mproto.QueryResult { }, }, }, + "select * from `test_table_02` where 1 != 1": &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }, "describe `test_table_02`": &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ @@ -823,6 +848,12 @@ func getSchemaInfoTestSupportedQueries() map[string]*mproto.QueryResult { }, }, }, + "select * from `test_table_03` where 1 != 1": &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }, "describe `test_table_03`": &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ diff --git a/go/vt/tabletserver/schemaz.go b/go/vt/tabletserver/schemaz.go index f1c9e5819e..22ebb433a3 100644 --- a/go/vt/tabletserver/schemaz.go +++ b/go/vt/tabletserver/schemaz.go @@ -30,7 +30,7 @@ var ( schemazTmpl = template.Must(template.New("example").Parse(` {{$top := .}}{{with .Table}} {{.Name}} - {{range .Columns}}{{.Name}}: {{index $top.ColumnCategory .Category}}, {{if .IsAuto}}autoinc{{end}}, {{.Default}}
{{end}} + {{range .Columns}}{{.Name}}: {{.Type}}, {{if .IsAuto}}autoinc{{end}}, {{.Default}}
{{end}} {{range .Indexes}}{{.Name}}: ({{range .Columns}}{{.}},{{end}}), ({{range .Cardinality}}{{.}},{{end}})
{{end}} {{index $top.CacheType .CacheType}} {{.TableRows.Get}} @@ -75,12 +75,10 @@ func schemazHandler(tables []*schema.Table, w http.ResponseWriter, r *http.Reque } sort.Sort(&sorter) envelope := struct { - ColumnCategory []string - CacheType []string - Table *schema.Table + CacheType []string + Table *schema.Table }{ - ColumnCategory: []string{"other", "number", "varbinary"}, - CacheType: []string{"none", "read-write", "write-only"}, + CacheType: []string{"none", "read-write", "write-only"}, } for _, Value := range sorter.rows { envelope.Table = Value diff --git a/go/vt/tabletserver/schemaz_test.go b/go/vt/tabletserver/schemaz_test.go index 2e0a4384d4..5df51420f9 100644 --- a/go/vt/tabletserver/schemaz_test.go +++ b/go/vt/tabletserver/schemaz_test.go @@ -23,17 +23,17 @@ func TestSchamazHandler(t *testing.T) { tableB := schema.NewTable("b") tableC := schema.NewTable("c") - tableA.AddColumn("column1", "int", sqltypes.MakeNumeric([]byte("0")), "auto_increment") + tableA.AddColumn("column1", sqltypes.Int64, sqltypes.MakeNumeric([]byte("0")), "auto_increment") tableA.AddIndex("index1").AddColumn("index_column", 1000) - tableA.CacheType = schema.CACHE_RW + tableA.CacheType = schema.CacheRW - tableB.AddColumn("column2", "string", sqltypes.MakeString([]byte("NULL")), "") + tableB.AddColumn("column2", sqltypes.VarChar, sqltypes.MakeString([]byte("NULL")), "") tableB.AddIndex("index2").AddColumn("index_column2", 200) - tableB.CacheType = schema.CACHE_W + tableB.CacheType = schema.CacheW - tableC.AddColumn("column3", "string", sqltypes.MakeString([]byte("")), "") + tableC.AddColumn("column3", sqltypes.VarChar, sqltypes.MakeString([]byte("")), "") tableC.AddIndex("index3").AddColumn("index_column3", 500) - tableC.CacheType = schema.CACHE_NONE + tableC.CacheType = schema.CacheNone tables := []*schema.Table{ tableA, tableB, tableC, @@ -42,7 +42,7 @@ func TestSchamazHandler(t *testing.T) { body, _ := ioutil.ReadAll(resp.Body) tableCPattern := []string{ `c`, - `column3: other, ,
`, + `column3: VARCHAR, ,
`, `index3: \(index_column3,\), \(500,\)
`, `none`, } @@ -55,7 +55,7 @@ func TestSchamazHandler(t *testing.T) { } tableBPattern := []string{ `b`, - `column2: other, , NULL
`, + `column2: VARCHAR, , NULL
`, `index2: \(index_column2,\), \(200,\)
`, `write-only`, } @@ -68,7 +68,7 @@ func TestSchamazHandler(t *testing.T) { } tableAPattern := []string{ `a`, - `column1: number, autoinc,
`, + `column1: INT64, autoinc,
`, `index1: \(index_column,\), \(1000,\)
`, `read-write`, } diff --git a/go/vt/tabletserver/table_info.go b/go/vt/tabletserver/table_info.go index 02053f3fc1..9b2036f11e 100644 --- a/go/vt/tabletserver/table_info.go +++ b/go/vt/tabletserver/table_info.go @@ -10,7 +10,9 @@ import ( "strings" log "github.com/golang/glog" + "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/sync2" + "github.com/youtube/vitess/go/vt/proto/query" "github.com/youtube/vitess/go/vt/schema" "golang.org/x/net/context" ) @@ -46,12 +48,29 @@ func loadTableInfo(conn *DBConn, tableName string) (ti *TableInfo, err error) { } func (ti *TableInfo) fetchColumns(conn *DBConn) error { + qr, err := conn.Exec(context.Background(), fmt.Sprintf("select * from `%s` where 1 != 1", ti.Name), 10000, true) + if err != nil { + return err + } + fieldTypes := make(map[string]query.Type, len(qr.Fields)) + for _, field := range qr.Fields { + fieldTypes[field.Name], err = sqltypes.MySQLToType(field.Type, field.Flags) + if err != nil { + return err + } + } columns, err := conn.Exec(context.Background(), fmt.Sprintf("describe `%s`", ti.Name), 10000, false) if err != nil { return err } for _, row := range columns.Rows { - ti.AddColumn(row[0].String(), row[1].String(), row[4], row[5].String()) + name := row[0].String() + columnType, ok := fieldTypes[name] + if !ok { + log.Warningf("Table: %s, column %s not found in select list, skipping.", ti.Name, name) + continue + } + ti.AddColumn(name, columnType, row[4], row[5].String()) } return nil } @@ -155,13 +174,14 @@ func (ti *TableInfo) initRowCache(conn *DBConn, tableType string, comment string return } for _, col := range ti.PKColumns { - if ti.Columns[col].Category == schema.CAT_OTHER { - log.Infof("Table %s pk has unsupported column types. Will not be cached.", ti.Name) - return + if sqltypes.IsIntegral(ti.Columns[col].Type) || ti.Columns[col].Type == sqltypes.VarBinary { + continue } + log.Infof("Table %s pk has unsupported column types. Will not be cached.", ti.Name) + return } - ti.CacheType = schema.CACHE_RW + ti.CacheType = schema.CacheRW ti.Cache = NewRowCache(ti, cachePool) } diff --git a/go/vt/tabletserver/table_info_test.go b/go/vt/tabletserver/table_info_test.go index 21dfeb7f7d..fc00008c67 100644 --- a/go/vt/tabletserver/table_info_test.go +++ b/go/vt/tabletserver/table_info_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/youtube/vitess/go/mysql" mproto "github.com/youtube/vitess/go/mysql/proto" "github.com/youtube/vitess/go/sqldb" "github.com/youtube/vitess/go/sqltypes" @@ -102,6 +103,12 @@ func TestTableInfoWithoutRowCacheViaNoPKColumn(t *testing.T) { fakecacheservice.Register() db := fakesqldb.Register() db.AddQuery("show index from `test_table`", &mproto.QueryResult{}) + db.AddQuery("select * from `test_table` where 1 != 1", &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }}, + }) db.AddQuery("describe `test_table`", &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ @@ -145,12 +152,18 @@ func TestTableInfoWithoutRowCacheViaUnknownPKColumnType(t *testing.T) { }, }, }) + db.AddQuery("select * from `test_table` where 1 != 1", &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeNewDecimal, + }}, + }) db.AddQuery("describe `test_table`", &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ sqltypes.MakeString([]byte("pk")), - sqltypes.MakeString([]byte("unknown_type")), + sqltypes.MakeString([]byte("decimal")), sqltypes.MakeString([]byte{}), sqltypes.MakeString([]byte{}), sqltypes.MakeString([]byte("1")), @@ -260,7 +273,7 @@ func TestTableInfoInvalidCardinalityInIndex(t *testing.T) { defer cachePool.Close() tableInfo, err := newTestTableInfo(cachePool, "USER_TABLE", "test table", db) if err != nil { - t.Fatalf("failed to create a table info") + t.Fatalf("failed to create a table info: %v", err) } if len(tableInfo.PKColumns) != 1 { t.Fatalf("table should have one PK column although the cardinality is invalid") @@ -309,6 +322,18 @@ func newTestTableInfoCachePool() *CachePool { func getTestTableInfoQueries() map[string]*mproto.QueryResult { return map[string]*mproto.QueryResult{ + "select * from `test_table` where 1 != 1": &mproto.QueryResult{ + Fields: []mproto.Field{{ + Name: "pk", + Type: mysql.TypeLong, + }, { + Name: "name", + Type: mysql.TypeLong, + }, { + Name: "addr", + Type: mysql.TypeLong, + }}, + }, "describe `test_table`": &mproto.QueryResult{ RowsAffected: 3, Rows: [][]sqltypes.Value{ diff --git a/go/vt/tabletserver/tabletserver_test.go b/go/vt/tabletserver/tabletserver_test.go index 9200ae87dc..139405f28c 100644 --- a/go/vt/tabletserver/tabletserver_test.go +++ b/go/vt/tabletserver/tabletserver_test.go @@ -1490,6 +1490,9 @@ func getSupportedQueries() map[string]*mproto.QueryResult { "select * from test_table where 1 != 1": &mproto.QueryResult{ Fields: getTestTableFields(), }, + "select * from `test_table` where 1 != 1": &mproto.QueryResult{ + Fields: getTestTableFields(), + }, baseShowTables: &mproto.QueryResult{ RowsAffected: 1, Rows: [][]sqltypes.Value{ diff --git a/go/vt/vtgate/proto/proto3.go b/go/vt/vtgate/proto/proto3.go index 29c0190cf0..a1ceaa8dad 100644 --- a/go/vt/vtgate/proto/proto3.go +++ b/go/vt/vtgate/proto/proto3.go @@ -8,7 +8,6 @@ import ( "github.com/youtube/vitess/go/vt/key" tproto "github.com/youtube/vitess/go/vt/tabletserver/proto" - pbq "github.com/youtube/vitess/go/vt/proto/query" pb "github.com/youtube/vitess/go/vt/proto/vtgate" ) From ccbe5915c568d69a6f9cf123e2268ff3c8a1fe2f Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 5 Nov 2015 21:19:23 -0800 Subject: [PATCH 3/6] sqltypes transition: Inner->inner --- go/sqltypes/value.go | 52 +++++++------- go/sqltypes/value_test.go | 28 ++++---- go/vt/tabletserver/endtoend/batch_test.go | 14 ++-- .../endtoend/compatibility_test.go | 70 +++++++++---------- go/vt/tabletserver/endtoend/nocache_test.go | 10 +-- go/vt/tabletserver/query_splitter.go | 2 +- 6 files changed, 88 insertions(+), 88 deletions(-) diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index 853cab52a2..e12fed44c8 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -38,7 +38,7 @@ type BinWriter interface { // Value can store any SQL value. NULL is stored as nil. type Value struct { - Inner InnerValue + inner innerValue } // Numeric represents non-fractional SQL number. @@ -68,26 +68,26 @@ func MakeString(b []byte) Value { // Raw returns the raw bytes. All types are currently implemented as []byte. func (v Value) Raw() []byte { - if v.Inner == nil { + if v.inner == nil { return nil } - return v.Inner.raw() + return v.inner.raw() } // String returns the raw value as a string func (v Value) String() string { - if v.Inner == nil { + if v.inner == nil { return "" } - return hack.String(v.Inner.raw()) + return hack.String(v.inner.raw()) } // ParseInt64 will parse a Numeric value into an int64 func (v Value) ParseInt64() (val int64, err error) { - if v.Inner == nil { + if v.inner == nil { return 0, fmt.Errorf("value is null") } - n, ok := v.Inner.(Numeric) + n, ok := v.inner.(Numeric) if !ok { return 0, fmt.Errorf("value is not Numeric") } @@ -96,10 +96,10 @@ func (v Value) ParseInt64() (val int64, err error) { // ParseUint64 will parse a Numeric value into a uint64 func (v Value) ParseUint64() (val uint64, err error) { - if v.Inner == nil { + if v.inner == nil { return 0, fmt.Errorf("value is null") } - n, ok := v.Inner.(Numeric) + n, ok := v.inner.(Numeric) if !ok { return 0, fmt.Errorf("value is not Numeric") } @@ -108,10 +108,10 @@ func (v Value) ParseUint64() (val uint64, err error) { // ParseFloat64 will parse a Fractional value into an float64 func (v Value) ParseFloat64() (val float64, err error) { - if v.Inner == nil { + if v.inner == nil { return 0, fmt.Errorf("value is null") } - n, ok := v.Inner.(Fractional) + n, ok := v.inner.(Fractional) if !ok { return 0, fmt.Errorf("value is not Fractional") } @@ -120,23 +120,23 @@ func (v Value) ParseFloat64() (val float64, err error) { // EncodeSQL encodes the value into an SQL statement. Can be binary. func (v Value) EncodeSQL(b BinWriter) { - if v.Inner == nil { + if v.inner == nil { if _, err := b.Write(nullstr); err != nil { panic(err) } } else { - v.Inner.encodeSQL(b) + v.inner.encodeSQL(b) } } // EncodeASCII encodes the value using 7-bit clean ascii bytes. func (v Value) EncodeASCII(b BinWriter) { - if v.Inner == nil { + if v.inner == nil { if _, err := b.Write(nullstr); err != nil { panic(err) } } else { - v.Inner.encodeASCII(b) + v.inner.encodeASCII(b) } } @@ -168,21 +168,21 @@ func (v *Value) UnmarshalBson(buf *bytes.Buffer, kind byte) { // IsNull returns true if Value is null. func (v Value) IsNull() bool { - return v.Inner == nil + return v.inner == nil } // IsNumeric returns true if Value is numeric. func (v Value) IsNumeric() (ok bool) { - if v.Inner != nil { - _, ok = v.Inner.(Numeric) + if v.inner != nil { + _, ok = v.inner.(Numeric) } return ok } // IsFractional returns true if Value is fractional. func (v Value) IsFractional() (ok bool) { - if v.Inner != nil { - _, ok = v.Inner.(Fractional) + if v.inner != nil { + _, ok = v.inner.(Fractional) } return ok } @@ -190,8 +190,8 @@ func (v Value) IsFractional() (ok bool) { // IsString returns true if Value is a string, or needs // to be quoted before sending to MySQL. func (v Value) IsString() (ok bool) { - if v.Inner != nil { - _, ok = v.Inner.(String) + if v.inner != nil { + _, ok = v.inner.(String) } return ok } @@ -199,7 +199,7 @@ func (v Value) IsString() (ok bool) { // MarshalJSON should only be used for testing. // It's not a complete implementation. func (v Value) MarshalJSON() ([]byte, error) { - return json.Marshal(v.Inner) + return json.Marshal(v.inner) } // UnmarshalJSON should only be used for testing. @@ -233,8 +233,8 @@ func (v *Value) UnmarshalJSON(b []byte) error { return err } -// InnerValue defines methods that need to be supported by all non-null value types. -type InnerValue interface { +// innerValue defines methods that need to be supported by all non-null value types. +type innerValue interface { raw() []byte encodeSQL(BinWriter) encodeASCII(BinWriter) @@ -267,7 +267,7 @@ func BuildValue(goval interface{}) (v Value, err error) { case time.Time: v = Value{String([]byte(bindVal.Format("2006-01-02 15:04:05")))} case Numeric, Fractional, String: - v = Value{bindVal.(InnerValue)} + v = Value{bindVal.(innerValue)} case Value: v = bindVal default: diff --git a/go/sqltypes/value_test.go b/go/sqltypes/value_test.go index 27a37b218c..d17efc63d6 100644 --- a/go/sqltypes/value_test.go +++ b/go/sqltypes/value_test.go @@ -169,21 +169,21 @@ func TestBuildValue(t *testing.T) { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "-1" { - t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting -1, received %T: %s", v.inner, v.String()) } v, err = BuildValue(int32(-1)) if err != nil { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "-1" { - t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting -1, received %T: %s", v.inner, v.String()) } v, err = BuildValue(int64(-1)) if err != nil { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "-1" { - t.Errorf("Expecting -1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting -1, received %T: %s", v.inner, v.String()) } n64, err = v.ParseUint64() if err == nil { @@ -201,14 +201,14 @@ func TestBuildValue(t *testing.T) { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "1" { - t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 1, received %T: %s", v.inner, v.String()) } v, err = BuildValue(uint32(1)) if err != nil { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "1" { - t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 1, received %T: %s", v.inner, v.String()) } v, err = BuildValue(uint64(1)) if err != nil { @@ -222,14 +222,14 @@ func TestBuildValue(t *testing.T) { t.Errorf("Expecting 1, got %v", n64) } if !v.IsNumeric() || v.String() != "1" { - t.Errorf("Expecting 1, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 1, received %T: %s", v.inner, v.String()) } v, err = BuildValue(1.23) if err != nil { t.Errorf("%v", err) } if !v.IsFractional() || v.String() != "1.23" { - t.Errorf("Expecting 1.23, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 1.23, received %T: %s", v.inner, v.String()) } n64, err = v.ParseUint64() if err == nil { @@ -240,14 +240,14 @@ func TestBuildValue(t *testing.T) { t.Errorf("%v", err) } if !v.IsString() || v.String() != "abcd" { - t.Errorf("Expecting abcd, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting abcd, received %T: %s", v.inner, v.String()) } v, err = BuildValue([]byte("abcd")) if err != nil { t.Errorf("%v", err) } if !v.IsString() || v.String() != "abcd" { - t.Errorf("Expecting abcd, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting abcd, received %T: %s", v.inner, v.String()) } n64, err = v.ParseUint64() if err == nil || err.Error() != "value is not Numeric" { @@ -258,28 +258,28 @@ func TestBuildValue(t *testing.T) { t.Errorf("%v", err) } if !v.IsString() || v.String() != "2012-02-24 23:19:43" { - t.Errorf("Expecting 2012-02-24 23:19:43, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 2012-02-24 23:19:43, received %T: %s", v.inner, v.String()) } v, err = BuildValue(Numeric([]byte("123"))) if err != nil { t.Errorf("%v", err) } if !v.IsNumeric() || v.String() != "123" { - t.Errorf("Expecting 123, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 123, received %T: %s", v.inner, v.String()) } v, err = BuildValue(Fractional([]byte("12.3"))) if err != nil { t.Errorf("%v", err) } if !v.IsFractional() || v.String() != "12.3" { - t.Errorf("Expecting 12.3, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting 12.3, received %T: %s", v.inner, v.String()) } v, err = BuildValue(String([]byte("abc"))) if err != nil { t.Errorf("%v", err) } if !v.IsString() || v.String() != "abc" { - t.Errorf("Expecting abc, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting abc, received %T: %s", v.inner, v.String()) } v, err = BuildValue(float32(1.23)) if err == nil { @@ -291,7 +291,7 @@ func TestBuildValue(t *testing.T) { t.Errorf("%v", err) } if !v.IsString() || v.String() != "ab" { - t.Errorf("Expecting ab, received %T: %s", v.Inner, v.String()) + t.Errorf("Expecting ab, received %T: %s", v.inner, v.String()) } v, err = BuildValue(float32(1.23)) if err == nil { diff --git a/go/vt/tabletserver/endtoend/batch_test.go b/go/vt/tabletserver/endtoend/batch_test.go index 1acec8df96..403ab7ea50 100644 --- a/go/vt/tabletserver/endtoend/batch_test.go +++ b/go/vt/tabletserver/endtoend/batch_test.go @@ -45,10 +45,10 @@ func TestBatchRead(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, - sqltypes.Value{Inner: sqltypes.Numeric("2")}, - sqltypes.Value{Inner: sqltypes.String("bcde")}, - sqltypes.Value{Inner: sqltypes.String("fghi")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeNumeric([]byte("2")), + sqltypes.MakeString([]byte("bcde")), + sqltypes.MakeString([]byte("fghi")), }, }, } @@ -65,8 +65,8 @@ func TestBatchRead(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, - sqltypes.Value{Inner: sqltypes.Numeric("2")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeNumeric([]byte("2")), }, }, } @@ -96,7 +96,7 @@ func TestBatchTransaction(t *testing.T) { wantRows := [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("4")}, + sqltypes.MakeNumeric([]byte("4")), sqltypes.Value{}, sqltypes.Value{}, sqltypes.Value{}, diff --git a/go/vt/tabletserver/endtoend/compatibility_test.go b/go/vt/tabletserver/endtoend/compatibility_test.go index 47cee389ad..c7b3028655 100644 --- a/go/vt/tabletserver/endtoend/compatibility_test.go +++ b/go/vt/tabletserver/endtoend/compatibility_test.go @@ -45,10 +45,10 @@ func TestCharaterSet(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, - sqltypes.Value{Inner: sqltypes.Fractional("1.12345")}, - sqltypes.Value{Inner: sqltypes.String("\xc2\xa2")}, - sqltypes.Value{Inner: sqltypes.String("\x00\xff")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeFractional([]byte("1.12345")), + sqltypes.MakeString([]byte("\xc2\xa2")), + sqltypes.MakeString([]byte("\x00\xff")), }, }, } @@ -138,17 +138,17 @@ func TestInts(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("-128")}, - sqltypes.Value{Inner: sqltypes.Numeric("255")}, - sqltypes.Value{Inner: sqltypes.Numeric("-32768")}, - sqltypes.Value{Inner: sqltypes.Numeric("65535")}, - sqltypes.Value{Inner: sqltypes.Numeric("-8388608")}, - sqltypes.Value{Inner: sqltypes.Numeric("16777215")}, - sqltypes.Value{Inner: sqltypes.Numeric("-2147483648")}, - sqltypes.Value{Inner: sqltypes.Numeric("4294967295")}, - sqltypes.Value{Inner: sqltypes.Numeric("-9223372036854775808")}, - sqltypes.Value{Inner: sqltypes.Numeric("18446744073709551615")}, - sqltypes.Value{Inner: sqltypes.Numeric("2012")}, + sqltypes.MakeNumeric([]byte("-128")), + sqltypes.MakeNumeric([]byte("255")), + sqltypes.MakeNumeric([]byte("-32768")), + sqltypes.MakeNumeric([]byte("65535")), + sqltypes.MakeNumeric([]byte("-8388608")), + sqltypes.MakeNumeric([]byte("16777215")), + sqltypes.MakeNumeric([]byte("-2147483648")), + sqltypes.MakeNumeric([]byte("4294967295")), + sqltypes.MakeNumeric([]byte("-9223372036854775808")), + sqltypes.MakeNumeric([]byte("18446744073709551615")), + sqltypes.MakeNumeric([]byte("2012")), }, }, } @@ -230,11 +230,11 @@ func TestFractionals(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, - sqltypes.Value{Inner: sqltypes.Fractional("1.99")}, - sqltypes.Value{Inner: sqltypes.Fractional("2.99")}, - sqltypes.Value{Inner: sqltypes.Fractional("3.99")}, - sqltypes.Value{Inner: sqltypes.Fractional("4.99")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeFractional([]byte("1.99")), + sqltypes.MakeFractional([]byte("2.99")), + sqltypes.MakeFractional([]byte("3.99")), + sqltypes.MakeFractional([]byte("4.99")), }, }, } @@ -336,16 +336,16 @@ func TestStrings(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.String("a")}, - sqltypes.Value{Inner: sqltypes.String("b")}, - sqltypes.Value{Inner: sqltypes.String("c")}, - sqltypes.Value{Inner: sqltypes.String("d\x00\x00\x00")}, - sqltypes.Value{Inner: sqltypes.String("e")}, - sqltypes.Value{Inner: sqltypes.String("f")}, - sqltypes.Value{Inner: sqltypes.String("g")}, - sqltypes.Value{Inner: sqltypes.String("h")}, - sqltypes.Value{Inner: sqltypes.String("a")}, - sqltypes.Value{Inner: sqltypes.String("a,b")}, + sqltypes.MakeString([]byte("a")), + sqltypes.MakeString([]byte("b")), + sqltypes.MakeString([]byte("c")), + sqltypes.MakeString([]byte("d\x00\x00\x00")), + sqltypes.MakeString([]byte("e")), + sqltypes.MakeString([]byte("f")), + sqltypes.MakeString([]byte("g")), + sqltypes.MakeString([]byte("h")), + sqltypes.MakeString([]byte("a")), + sqltypes.MakeString([]byte("a,b")), }, }, } @@ -426,11 +426,11 @@ func TestMiscTypes(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, - sqltypes.Value{Inner: sqltypes.String("\x01")}, - sqltypes.Value{Inner: sqltypes.String("2012-01-01")}, - sqltypes.Value{Inner: sqltypes.String("2012-01-01 15:45:45")}, - sqltypes.Value{Inner: sqltypes.String("15:45:45")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeString([]byte("\x01")), + sqltypes.MakeString([]byte("2012-01-01")), + sqltypes.MakeString([]byte("2012-01-01 15:45:45")), + sqltypes.MakeString([]byte("15:45:45")), }, }, } diff --git a/go/vt/tabletserver/endtoend/nocache_test.go b/go/vt/tabletserver/endtoend/nocache_test.go index eee627742d..59f8072ca3 100644 --- a/go/vt/tabletserver/endtoend/nocache_test.go +++ b/go/vt/tabletserver/endtoend/nocache_test.go @@ -69,7 +69,7 @@ func TestBinary(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.String(binaryData)}, + sqltypes.MakeString([]byte(binaryData)), }, }, } @@ -299,7 +299,7 @@ func TestBindInSelect(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.Numeric("1")}, + sqltypes.MakeNumeric([]byte("1")), }, }, } @@ -325,7 +325,7 @@ func TestBindInSelect(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.String("abcd")}, + sqltypes.MakeString([]byte("abcd")), }, }, } @@ -351,7 +351,7 @@ func TestBindInSelect(t *testing.T) { RowsAffected: 1, Rows: [][]sqltypes.Value{ []sqltypes.Value{ - sqltypes.Value{Inner: sqltypes.String("\x00\xff")}, + sqltypes.MakeString([]byte("\x00\xff")), }, }, } @@ -449,7 +449,7 @@ func TestDBAStatements(t *testing.T) { t.Error(err) return } - wantCol := sqltypes.Value{Inner: sqltypes.String("version")} + wantCol := sqltypes.MakeString([]byte("version")) if !reflect.DeepEqual(qr.Rows[0][0], wantCol) { t.Errorf("Execute: \n%#v, want \n%#v", qr.Rows[0][0], wantCol) } diff --git a/go/vt/tabletserver/query_splitter.go b/go/vt/tabletserver/query_splitter.go index 78713a3fe4..37817c7567 100644 --- a/go/vt/tabletserver/query_splitter.go +++ b/go/vt/tabletserver/query_splitter.go @@ -133,7 +133,7 @@ func (qs *QuerySplitter) split(columnType int64, pkMinMax *mproto.QueryResult) ( RowCount: qs.rowCount, } splits = append(splits, *split) - start.Inner = end.Inner + start = end } qs.sel.Where = whereClause // reset where clause } From 177277b53dcd084948024e4d0377c2cc54ce333c Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 5 Nov 2015 21:45:19 -0800 Subject: [PATCH 4/6] sqltypes transition: missed a few inner refs --- go/mysql/proto/bson_test.go | 2 +- go/vt/vtgate/router_dml_test.go | 8 +++---- go/vt/vtgate/sandbox_test.go | 4 ++-- go/vt/worker/diff_utils_test.go | 40 ++++++++++++++++----------------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/go/mysql/proto/bson_test.go b/go/mysql/proto/bson_test.go index 34a097c69c..d9fa115112 100644 --- a/go/mysql/proto/bson_test.go +++ b/go/mysql/proto/bson_test.go @@ -32,7 +32,7 @@ func TestQueryResult(t *testing.T) { RowsAffected: 2, InsertId: 3, Rows: [][]sqltypes.Value{ - {{sqltypes.Numeric("1")}, {sqltypes.String("aa")}}, + {sqltypes.MakeNumeric([]byte("1")), sqltypes.MakeString([]byte("aa"))}, }, Err: &RPCError{1000, "failed due to err"}, } diff --git a/go/vt/vtgate/router_dml_test.go b/go/vt/vtgate/router_dml_test.go index 0e40eac656..6939a2da25 100644 --- a/go/vt/vtgate/router_dml_test.go +++ b/go/vt/vtgate/router_dml_test.go @@ -126,8 +126,8 @@ func TestDeleteEqual(t *testing.T) { RowsAffected: 1, InsertId: 0, Rows: [][]sqltypes.Value{{ - {sqltypes.Numeric("1")}, - {sqltypes.String("myname")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeString([]byte("myname")), }}, }}) _, err := routerExec(router, "delete from user where id = 1", nil) @@ -263,8 +263,8 @@ func TestDeleteVindexFail(t *testing.T) { RowsAffected: 1, InsertId: 0, Rows: [][]sqltypes.Value{{ - {sqltypes.String("foo")}, - {sqltypes.String("myname")}, + sqltypes.MakeString([]byte("foo")), + sqltypes.MakeString([]byte("myname")), }}, }}) _, err = routerExec(router, "delete from user where id = 1", nil) diff --git a/go/vt/vtgate/sandbox_test.go b/go/vt/vtgate/sandbox_test.go index ad9d3795c4..f8cf4b8a7f 100644 --- a/go/vt/vtgate/sandbox_test.go +++ b/go/vt/vtgate/sandbox_test.go @@ -592,7 +592,7 @@ var singleRowResult = &mproto.QueryResult{ RowsAffected: 1, InsertId: 0, Rows: [][]sqltypes.Value{{ - {sqltypes.Numeric("1")}, - {sqltypes.String("foo")}, + sqltypes.MakeNumeric([]byte("1")), + sqltypes.MakeString([]byte("foo")), }}, } diff --git a/go/vt/worker/diff_utils_test.go b/go/vt/worker/diff_utils_test.go index c83b8b703d..5b1941d1fd 100644 --- a/go/vt/worker/diff_utils_test.go +++ b/go/vt/worker/diff_utils_test.go @@ -51,8 +51,8 @@ func TestCompareRows(t *testing.T) { }{ { fields: []mproto.Field{{"a", mproto.VT_LONG, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.Numeric("123")}}, - right: []sqltypes.Value{{sqltypes.Numeric("14")}}, + left: []sqltypes.Value{sqltypes.MakeNumeric([]byte("123"))}, + right: []sqltypes.Value{sqltypes.MakeNumeric([]byte("14"))}, want: 1, }, { @@ -61,55 +61,55 @@ func TestCompareRows(t *testing.T) { {"b", mproto.VT_LONG, mproto.VT_ZEROVALUE_FLAG}, }, left: []sqltypes.Value{ - {sqltypes.Numeric("555")}, - {sqltypes.Numeric("12")}, + sqltypes.MakeNumeric([]byte("555")), + sqltypes.MakeNumeric([]byte("12")), }, right: []sqltypes.Value{ - {sqltypes.Numeric("555")}, - {sqltypes.Numeric("144")}, + sqltypes.MakeNumeric([]byte("555")), + sqltypes.MakeNumeric([]byte("144")), }, want: -1, }, { fields: []mproto.Field{{"a", mproto.VT_LONG, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.Numeric("144")}}, - right: []sqltypes.Value{{sqltypes.Numeric("144")}}, + left: []sqltypes.Value{sqltypes.MakeNumeric([]byte("144"))}, + right: []sqltypes.Value{sqltypes.MakeNumeric([]byte("144"))}, want: 0, }, { fields: []mproto.Field{{"a", mproto.VT_LONGLONG, mproto.VT_UNSIGNED_FLAG}}, - left: []sqltypes.Value{{sqltypes.Numeric("9223372036854775809")}}, - right: []sqltypes.Value{{sqltypes.Numeric("9223372036854775810")}}, + left: []sqltypes.Value{sqltypes.MakeNumeric([]byte("9223372036854775809"))}, + right: []sqltypes.Value{sqltypes.MakeNumeric([]byte("9223372036854775810"))}, want: -1, }, { fields: []mproto.Field{{"a", mproto.VT_LONGLONG, mproto.VT_UNSIGNED_FLAG}}, - left: []sqltypes.Value{{sqltypes.Numeric("9223372036854775819")}}, - right: []sqltypes.Value{{sqltypes.Numeric("9223372036854775810")}}, + left: []sqltypes.Value{sqltypes.MakeNumeric([]byte("9223372036854775819"))}, + right: []sqltypes.Value{sqltypes.MakeNumeric([]byte("9223372036854775810"))}, want: 1, }, { fields: []mproto.Field{{"a", mproto.VT_DOUBLE, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.Fractional("3.14")}}, - right: []sqltypes.Value{{sqltypes.Fractional("3.2")}}, + left: []sqltypes.Value{sqltypes.MakeFractional([]byte("3.14"))}, + right: []sqltypes.Value{sqltypes.MakeFractional([]byte("3.2"))}, want: -1, }, { fields: []mproto.Field{{"a", mproto.VT_DOUBLE, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.Fractional("123.4")}}, - right: []sqltypes.Value{{sqltypes.Fractional("123.2")}}, + left: []sqltypes.Value{sqltypes.MakeFractional([]byte("123.4"))}, + right: []sqltypes.Value{sqltypes.MakeFractional([]byte("123.2"))}, want: 1, }, { fields: []mproto.Field{{"a", mproto.VT_STRING, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.String("abc")}}, - right: []sqltypes.Value{{sqltypes.String("abb")}}, + left: []sqltypes.Value{sqltypes.MakeString([]byte("abc"))}, + right: []sqltypes.Value{sqltypes.MakeString([]byte("abb"))}, want: 1, }, { fields: []mproto.Field{{"a", mproto.VT_STRING, mproto.VT_ZEROVALUE_FLAG}}, - left: []sqltypes.Value{{sqltypes.String("abc")}}, - right: []sqltypes.Value{{sqltypes.String("abd")}}, + left: []sqltypes.Value{sqltypes.MakeString([]byte("abc"))}, + right: []sqltypes.Value{sqltypes.MakeString([]byte("abd"))}, want: -1, }, } From ce67d707be13afb0c628f2593eccfdc703c108df Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 5 Nov 2015 23:37:27 -0800 Subject: [PATCH 5/6] sqltypes transition: MySQLToType panic on error --- go/mysql/proto/proto3.go | 24 ++-- go/mysql/proto/proto3_test.go | 15 +-- go/sqltypes/type.go | 9 +- go/sqltypes/type_test.go | 21 ++-- go/vt/binlog/grpcbinlogstreamer/streamer.go | 6 +- go/vt/binlog/proto/proto3.go | 10 +- go/vt/tabletmanager/grpctmserver/server.go | 8 +- .../endtoend/compatibility_test.go | 30 +---- go/vt/tabletserver/grpcqueryservice/server.go | 24 ++-- go/vt/tabletserver/proto/proto3.go | 12 +- go/vt/tabletserver/table_info.go | 5 +- go/vt/vtgate/grpcvtgateservice/server.go | 116 +++++++----------- 12 files changed, 98 insertions(+), 182 deletions(-) diff --git a/go/mysql/proto/proto3.go b/go/mysql/proto/proto3.go index e9bf2151b0..9fdeec08a5 100644 --- a/go/mysql/proto/proto3.go +++ b/go/mysql/proto/proto3.go @@ -39,23 +39,19 @@ func ProtoToCharset(c *pbb.Charset) *Charset { } // FieldsToProto3 converts an internal []Field to the proto3 version -func FieldsToProto3(f []Field) ([]*pbq.Field, error) { +func FieldsToProto3(f []Field) []*pbq.Field { if len(f) == 0 { - return nil, nil + return nil } result := make([]*pbq.Field, len(f)) for i, f := range f { - vitessType, err := sqltypes.MySQLToType(f.Type, f.Flags) - if err != nil { - return nil, err - } result[i] = &pbq.Field{ Name: f.Name, - Type: vitessType, + Type: sqltypes.MySQLToType(f.Type, f.Flags), } } - return result, nil + return result } // Proto3ToFields converts a proto3 []Fields to an internal data structure. @@ -132,20 +128,16 @@ func Proto3ToRows(rows []*pbq.Row) [][]sqltypes.Value { } // QueryResultToProto3 converts an internal QueryResult to the proto3 version -func QueryResultToProto3(qr *QueryResult) (*pbq.QueryResult, error) { +func QueryResultToProto3(qr *QueryResult) *pbq.QueryResult { if qr == nil { - return nil, nil - } - fields, err := FieldsToProto3(qr.Fields) - if err != nil { - return nil, err + return nil } return &pbq.QueryResult{ - Fields: fields, + Fields: FieldsToProto3(qr.Fields), RowsAffected: qr.RowsAffected, InsertId: qr.InsertId, Rows: RowsToProto3(qr.Rows), - }, nil + } } // Proto3ToQueryResult converts a proto3 QueryResult to an internal data structure. diff --git a/go/mysql/proto/proto3_test.go b/go/mysql/proto/proto3_test.go index ba9792f88a..8c35edf78a 100644 --- a/go/mysql/proto/proto3_test.go +++ b/go/mysql/proto/proto3_test.go @@ -21,10 +21,7 @@ func TestFields(t *testing.T) { Name: "bb", Type: 2, }} - p3, err := FieldsToProto3(fields) - if err != nil { - t.Error(err) - } + p3 := FieldsToProto3(fields) wantp3 := []*query.Field{ &query.Field{ Name: "aa", @@ -43,16 +40,6 @@ func TestFields(t *testing.T) { if !reflect.DeepEqual(reverse, fields) { t.Errorf("reverse: %v, want %v", reverse, fields) } - - fields = []Field{{ - Name: "aa", - Type: 15, - }} - _, err = FieldsToProto3(fields) - want := "Could not map: 15 to a vitess type" - if err == nil || err.Error() != want { - t.Errorf("Error: %v, want %v", err, want) - } } func TestRowsToProto3(t *testing.T) { diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index adcf6eb74d..52f5c70334 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -182,18 +182,19 @@ var typeToMySQL = map[query.Type]struct { } // MySQLToType computes the vitess type from mysql type and flags. -func MySQLToType(mysqlType, flags int64) (query.Type, error) { +// The function panics if the type is unrecognized. +func MySQLToType(mysqlType, flags int64) query.Type { result, ok := mysqlToType[mysqlType] if !ok { - return Null, fmt.Errorf("Could not map: %d to a vitess type", mysqlType) + panic(fmt.Errorf("Could not map: %d to a vitess type", mysqlType)) } converted := (flags << 16) & relevantFlags modified, ok := modifier[int64(result)|converted] if ok { - return modified, nil + return modified } - return result, nil + return result } // TypeToMySQL returns the equivalent mysql type and flag for a vitess type. diff --git a/go/sqltypes/type_test.go b/go/sqltypes/type_test.go index f3cf20203e..33b9fb1295 100644 --- a/go/sqltypes/type_test.go +++ b/go/sqltypes/type_test.go @@ -172,20 +172,19 @@ func TestTypeToMySQL(t *testing.T) { } func TestTypeFlexibility(t *testing.T) { - v, err := MySQLToType(1, mysqlBinary>>16) - if err != nil { - t.Error(err) - return - } + v := MySQLToType(1, mysqlBinary>>16) if v != Int8 { t.Errorf("conversion: %v, want %v", v, Int8) } } -func TestTypeError(t *testing.T) { - _, err := MySQLToType(15, 0) - want := "Could not map: 15 to a vitess type" - if err == nil || err.Error() != want { - t.Errorf("Error: %v, want %v", err, want) - } +func TestTypePanic(t *testing.T) { + defer func() { + err := recover().(error) + want := "Could not map: 15 to a vitess type" + if err == nil || err.Error() != want { + t.Errorf("Error: %v, want %v", err, want) + } + }() + _ = MySQLToType(15, 0) } diff --git a/go/vt/binlog/grpcbinlogstreamer/streamer.go b/go/vt/binlog/grpcbinlogstreamer/streamer.go index 13bcf64f77..980dca0334 100644 --- a/go/vt/binlog/grpcbinlogstreamer/streamer.go +++ b/go/vt/binlog/grpcbinlogstreamer/streamer.go @@ -31,12 +31,8 @@ func New(updateStream proto.UpdateStream) *UpdateStream { func (server *UpdateStream) StreamUpdate(req *pb.StreamUpdateRequest, stream pbs.UpdateStream_StreamUpdateServer) (err error) { defer server.updateStream.HandlePanic(&err) return server.updateStream.ServeUpdateStream(req.Position, func(reply *proto.StreamEvent) error { - event, err := proto.StreamEventToProto(reply) - if err != nil { - return err - } return stream.Send(&pb.StreamUpdateResponse{ - StreamEvent: event, + StreamEvent: proto.StreamEventToProto(reply), }) }) } diff --git a/go/vt/binlog/proto/proto3.go b/go/vt/binlog/proto/proto3.go index 8585bfce39..eabdee1d6d 100644 --- a/go/vt/binlog/proto/proto3.go +++ b/go/vt/binlog/proto/proto3.go @@ -19,14 +19,10 @@ import ( // structures internally, and this will be obsolete. // StreamEventToProto converts a StreamEvent to a proto3 -func StreamEventToProto(s *StreamEvent) (*pb.StreamEvent, error) { - fields, err := mproto.FieldsToProto3(s.PrimaryKeyFields) - if err != nil { - return nil, err - } +func StreamEventToProto(s *StreamEvent) *pb.StreamEvent { result := &pb.StreamEvent{ TableName: s.TableName, - PrimaryKeyFields: fields, + PrimaryKeyFields: mproto.FieldsToProto3(s.PrimaryKeyFields), PrimaryKeyValues: mproto.RowsToProto3(s.PrimaryKeyValues), Sql: s.Sql, Timestamp: s.Timestamp, @@ -42,7 +38,7 @@ func StreamEventToProto(s *StreamEvent) (*pb.StreamEvent, error) { default: result.Category = pb.StreamEvent_SE_ERR } - return result, nil + return result } // ProtoToStreamEvent converts a proto to a StreamEvent diff --git a/go/vt/tabletmanager/grpctmserver/server.go b/go/vt/tabletmanager/grpctmserver/server.go index 37bca011ef..8711e69ea8 100644 --- a/go/vt/tabletmanager/grpctmserver/server.go +++ b/go/vt/tabletmanager/grpctmserver/server.go @@ -185,8 +185,8 @@ func (s *server) ExecuteFetchAsDba(ctx context.Context, request *pb.ExecuteFetch if err != nil { return err } - response.Result, err = mproto.QueryResultToProto3(qr) - return err + response.Result = mproto.QueryResultToProto3(qr) + return nil }) } @@ -198,8 +198,8 @@ func (s *server) ExecuteFetchAsApp(ctx context.Context, request *pb.ExecuteFetch if err != nil { return err } - response.Result, err = mproto.QueryResultToProto3(qr) - return err + response.Result = mproto.QueryResultToProto3(qr) + return nil }) } diff --git a/go/vt/tabletserver/endtoend/compatibility_test.go b/go/vt/tabletserver/endtoend/compatibility_test.go index c7b3028655..ee035ac140 100644 --- a/go/vt/tabletserver/endtoend/compatibility_test.go +++ b/go/vt/tabletserver/endtoend/compatibility_test.go @@ -169,11 +169,7 @@ func TestInts(t *testing.T) { sqltypes.Year, } for i, field := range qr.Fields { - got, err := sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - t.Errorf("col: %d, err: %v", i, err) - continue - } + got := sqltypes.MySQLToType(field.Type, field.Flags) if got != wantTypes[i] { t.Errorf("Unexpected type: col: %d, %d, want %d", i, got, wantTypes[i]) } @@ -249,11 +245,7 @@ func TestFractionals(t *testing.T) { sqltypes.Float64, } for i, field := range qr.Fields { - got, err := sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - t.Errorf("col: %d, err: %v", i, err) - continue - } + got := sqltypes.MySQLToType(field.Type, field.Flags) if got != wantTypes[i] { t.Errorf("Unexpected type: col: %d, %d, want %d", i, got, wantTypes[i]) } @@ -365,11 +357,7 @@ func TestStrings(t *testing.T) { sqltypes.Set, } for i, field := range qr.Fields { - got, err := sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - t.Errorf("col: %d, err: %v", i, err) - continue - } + got := sqltypes.MySQLToType(field.Type, field.Flags) if got != wantTypes[i] { t.Errorf("Unexpected type: col: %d, %d, want %d", i, got, wantTypes[i]) } @@ -445,11 +433,7 @@ func TestMiscTypes(t *testing.T) { sqltypes.Time, } for i, field := range qr.Fields { - got, err := sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - t.Errorf("col: %d, err: %v", i, err) - continue - } + got := sqltypes.MySQLToType(field.Type, field.Flags) if got != wantTypes[i] { t.Errorf("Unexpected type: col: %d, %d, want %d", i, got, wantTypes[i]) } @@ -485,11 +469,7 @@ func TestNull(t *testing.T) { sqltypes.Null, } for i, field := range qr.Fields { - got, err := sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - t.Errorf("col: %d, err: %v", i, err) - continue - } + got := sqltypes.MySQLToType(field.Type, field.Flags) if got != wantTypes[i] { t.Errorf("Unexpected type: col: %d, %d, want %d", i, got, wantTypes[i]) } diff --git a/go/vt/tabletserver/grpcqueryservice/server.go b/go/vt/tabletserver/grpcqueryservice/server.go index fa8389816e..7c8b70e6cd 100644 --- a/go/vt/tabletserver/grpcqueryservice/server.go +++ b/go/vt/tabletserver/grpcqueryservice/server.go @@ -65,11 +65,9 @@ func (q *query) Execute(ctx context.Context, request *pb.ExecuteRequest) (respon }, reply); err != nil { return nil, tabletserver.ToGRPCError(err) } - result, err := mproto.QueryResultToProto3(reply) - if err != nil { - return nil, tabletserver.ToGRPCError(err) - } - return &pb.ExecuteResponse{Result: result}, nil + return &pb.ExecuteResponse{ + Result: mproto.QueryResultToProto3(reply), + }, nil } // ExecuteBatch is part of the queryservice.QueryServer interface @@ -92,11 +90,9 @@ func (q *query) ExecuteBatch(ctx context.Context, request *pb.ExecuteBatchReques }, reply); err != nil { return nil, tabletserver.ToGRPCError(err) } - results, err := proto.QueryResultListToProto3(reply.List) - if err != nil { - return nil, tabletserver.ToGRPCError(err) - } - return &pb.ExecuteBatchResponse{Results: results}, nil + return &pb.ExecuteBatchResponse{ + Results: proto.QueryResultListToProto3(reply.List), + }, nil } // StreamExecute is part of the queryservice.QueryServer interface @@ -115,11 +111,9 @@ func (q *query) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Query BindVariables: bv, SessionId: request.SessionId, }, func(reply *mproto.QueryResult) error { - result, err := mproto.QueryResultToProto3(reply) - if err != nil { - return err - } - return stream.Send(&pb.StreamExecuteResponse{Result: result}) + return stream.Send(&pb.StreamExecuteResponse{ + Result: mproto.QueryResultToProto3(reply), + }) }); err != nil { return tabletserver.ToGRPCError(err) } diff --git a/go/vt/tabletserver/proto/proto3.go b/go/vt/tabletserver/proto/proto3.go index 4cc23a2757..6ea7e4362b 100644 --- a/go/vt/tabletserver/proto/proto3.go +++ b/go/vt/tabletserver/proto/proto3.go @@ -317,19 +317,15 @@ func Proto3ToQueryResultList(results []*pb.QueryResult) *QueryResultList { } // QueryResultListToProto3 changes the internal array of QueryResult to the proto3 version -func QueryResultListToProto3(results []mproto.QueryResult) ([]*pb.QueryResult, error) { +func QueryResultListToProto3(results []mproto.QueryResult) []*pb.QueryResult { if len(results) == 0 { - return nil, nil + return nil } result := make([]*pb.QueryResult, len(results)) - var err error for i := range results { - result[i], err = mproto.QueryResultToProto3(&results[i]) - if err != nil { - return nil, err - } + result[i] = mproto.QueryResultToProto3(&results[i]) } - return result, nil + return result } // Proto3ToQuerySplits converts a proto3 QuerySplit array to a native QuerySplit array diff --git a/go/vt/tabletserver/table_info.go b/go/vt/tabletserver/table_info.go index 9b2036f11e..f26190494e 100644 --- a/go/vt/tabletserver/table_info.go +++ b/go/vt/tabletserver/table_info.go @@ -54,10 +54,7 @@ func (ti *TableInfo) fetchColumns(conn *DBConn) error { } fieldTypes := make(map[string]query.Type, len(qr.Fields)) for _, field := range qr.Fields { - fieldTypes[field.Name], err = sqltypes.MySQLToType(field.Type, field.Flags) - if err != nil { - return err - } + fieldTypes[field.Name] = sqltypes.MySQLToType(field.Type, field.Flags) } columns, err := conn.Exec(context.Background(), fmt.Sprintf("describe `%s`", ti.Name), 10000, false) if err != nil { diff --git a/go/vt/vtgate/grpcvtgateservice/server.go b/go/vt/vtgate/grpcvtgateservice/server.go index 1c1abe181a..9df486b353 100644 --- a/go/vt/vtgate/grpcvtgateservice/server.go +++ b/go/vt/vtgate/grpcvtgateservice/server.go @@ -43,14 +43,12 @@ func (vtg *VTGate) Execute(ctx context.Context, request *pb.ExecuteRequest) (res response = &pb.ExecuteResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Result, executeErr = mproto.QueryResultToProto3(reply.Result) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Result = mproto.QueryResultToProto3(reply.Result) + response.Session = reply.Session + return response, nil } // ExecuteShards is the RPC version of vtgateservice.VTGateService method @@ -76,14 +74,12 @@ func (vtg *VTGate) ExecuteShards(ctx context.Context, request *pb.ExecuteShardsR response = &pb.ExecuteShardsResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Result, executeErr = mproto.QueryResultToProto3(reply.Result) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Result = mproto.QueryResultToProto3(reply.Result) + response.Session = reply.Session + return response, nil } // ExecuteKeyspaceIds is the RPC version of vtgateservice.VTGateService method @@ -109,14 +105,12 @@ func (vtg *VTGate) ExecuteKeyspaceIds(ctx context.Context, request *pb.ExecuteKe response = &pb.ExecuteKeyspaceIdsResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Result, executeErr = mproto.QueryResultToProto3(reply.Result) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Result = mproto.QueryResultToProto3(reply.Result) + response.Session = reply.Session + return response, nil } // ExecuteKeyRanges is the RPC version of vtgateservice.VTGateService method @@ -142,14 +136,12 @@ func (vtg *VTGate) ExecuteKeyRanges(ctx context.Context, request *pb.ExecuteKeyR response = &pb.ExecuteKeyRangesResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Result, executeErr = mproto.QueryResultToProto3(reply.Result) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Result = mproto.QueryResultToProto3(reply.Result) + response.Session = reply.Session + return response, nil } // ExecuteEntityIds is the RPC version of vtgateservice.VTGateService method @@ -176,14 +168,12 @@ func (vtg *VTGate) ExecuteEntityIds(ctx context.Context, request *pb.ExecuteEnti response = &pb.ExecuteEntityIdsResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Result, executeErr = mproto.QueryResultToProto3(reply.Result) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Result = mproto.QueryResultToProto3(reply.Result) + response.Session = reply.Session + return response, nil } // ExecuteBatchShards is the RPC version of vtgateservice.VTGateService method @@ -206,14 +196,12 @@ func (vtg *VTGate) ExecuteBatchShards(ctx context.Context, request *pb.ExecuteBa response = &pb.ExecuteBatchShardsResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Results, executeErr = tproto.QueryResultListToProto3(reply.List) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Results = tproto.QueryResultListToProto3(reply.List) + response.Session = reply.Session + return response, nil } // ExecuteBatchKeyspaceIds is the RPC version of @@ -237,14 +225,12 @@ func (vtg *VTGate) ExecuteBatchKeyspaceIds(ctx context.Context, request *pb.Exec response = &pb.ExecuteBatchKeyspaceIdsResponse{ Error: vtgate.RPCErrorToVtRPCError(reply.Err), } - if executeErr == nil { - response.Results, executeErr = tproto.QueryResultListToProto3(reply.List) - if executeErr == nil { - response.Session = reply.Session - return response, nil - } + if executeErr != nil { + return nil, vterrors.ToGRPCError(executeErr) } - return nil, vterrors.ToGRPCError(executeErr) + response.Results = tproto.QueryResultListToProto3(reply.List) + response.Session = reply.Session + return response, nil } // StreamExecute is the RPC version of vtgateservice.VTGateService method @@ -262,11 +248,9 @@ func (vtg *VTGate) StreamExecute(request *pb.StreamExecuteRequest, stream pbs.Vi bv, request.TabletType, func(value *proto.QueryResult) error { - result, err := mproto.QueryResultToProto3(value.Result) - if err != nil { - return err - } - return stream.Send(&pb.StreamExecuteResponse{Result: result}) + return stream.Send(&pb.StreamExecuteResponse{ + Result: mproto.QueryResultToProto3(value.Result), + }) }) return vterrors.ToGRPCError(vtgErr) } @@ -288,11 +272,9 @@ func (vtg *VTGate) StreamExecuteShards(request *pb.StreamExecuteShardsRequest, s request.Shards, request.TabletType, func(value *proto.QueryResult) error { - result, err := mproto.QueryResultToProto3(value.Result) - if err != nil { - return err - } - return stream.Send(&pb.StreamExecuteShardsResponse{Result: result}) + return stream.Send(&pb.StreamExecuteShardsResponse{ + Result: mproto.QueryResultToProto3(value.Result), + }) }) return vterrors.ToGRPCError(vtgErr) } @@ -315,11 +297,9 @@ func (vtg *VTGate) StreamExecuteKeyspaceIds(request *pb.StreamExecuteKeyspaceIds request.KeyspaceIds, request.TabletType, func(value *proto.QueryResult) error { - result, err := mproto.QueryResultToProto3(value.Result) - if err != nil { - return err - } - return stream.Send(&pb.StreamExecuteKeyspaceIdsResponse{Result: result}) + return stream.Send(&pb.StreamExecuteKeyspaceIdsResponse{ + Result: mproto.QueryResultToProto3(value.Result), + }) }) return vterrors.ToGRPCError(vtgErr) } @@ -342,11 +322,9 @@ func (vtg *VTGate) StreamExecuteKeyRanges(request *pb.StreamExecuteKeyRangesRequ request.KeyRanges, request.TabletType, func(value *proto.QueryResult) error { - result, err := mproto.QueryResultToProto3(value.Result) - if err != nil { - return err - } - return stream.Send(&pb.StreamExecuteKeyRangesResponse{Result: result}) + return stream.Send(&pb.StreamExecuteKeyRangesResponse{ + Result: mproto.QueryResultToProto3(value.Result), + }) }) return vterrors.ToGRPCError(vtgErr) } From 1fe934663f1bb811a3cf67b35f9bed829e663848 Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 5 Nov 2015 23:52:05 -0800 Subject: [PATCH 6/6] proto3 conversion: don't bounds check the lengths It's better to panic if the proto3 lengths are invalid. --- go/mysql/proto/proto3.go | 11 +++-------- go/mysql/proto/proto3_test.go | 35 ----------------------------------- 2 files changed, 3 insertions(+), 43 deletions(-) diff --git a/go/mysql/proto/proto3.go b/go/mysql/proto/proto3.go index 9fdeec08a5..808c3b691e 100644 --- a/go/mysql/proto/proto3.go +++ b/go/mysql/proto/proto3.go @@ -110,17 +110,12 @@ func Proto3ToRows(rows []*pbq.Row) [][]sqltypes.Value { index := 0 result[i] = make([]sqltypes.Value, len(r.Lengths)) for j, l := range r.Lengths { - if l <= -1 { + if l < 0 { result[i][j] = sqltypes.NULL } else { end := index + int(l) - if end > len(r.Values) { - result[i][j] = sqltypes.NULL - index = len(r.Values) - } else { - result[i][j] = sqltypes.MakeString(r.Values[index:end]) - index = end - } + result[i][j] = sqltypes.MakeString(r.Values[index:end]) + index = end } } } diff --git a/go/mysql/proto/proto3_test.go b/go/mysql/proto/proto3_test.go index 8c35edf78a..7aa11c0a97 100644 --- a/go/mysql/proto/proto3_test.go +++ b/go/mysql/proto/proto3_test.go @@ -72,38 +72,3 @@ func TestRowsToProto3(t *testing.T) { t.Errorf("reverse: \n%#v, want \n%#v", reverse, rows) } } - -func TestInvalidRowsProto(t *testing.T) { - p3 := []*query.Row{ - &query.Row{ - Lengths: []int64{3, 5, -1, 6}, - Values: []byte("aa12"), - }, - } - rows := Proto3ToRows(p3) - want := [][]sqltypes.Value{{ - sqltypes.MakeString([]byte("aa1")), - sqltypes.NULL, - sqltypes.NULL, - sqltypes.NULL, - }} - if !reflect.DeepEqual(rows, want) { - t.Errorf("reverse: \n%#v, want \n%#v", rows, want) - } - - p3 = []*query.Row{ - &query.Row{ - Lengths: []int64{2, -2, 2}, - Values: []byte("aa12"), - }, - } - rows = Proto3ToRows(p3) - want = [][]sqltypes.Value{{ - sqltypes.MakeString([]byte("aa")), - sqltypes.NULL, - sqltypes.MakeString([]byte("12")), - }} - if !reflect.DeepEqual(rows, want) { - t.Errorf("reverse: \n%#v, want \n%#v", rows, want) - } -}