diff --git a/go/sqltypes/bind_variables.go b/go/sqltypes/bind_variables.go new file mode 100644 index 0000000000..eed730bf59 --- /dev/null +++ b/go/sqltypes/bind_variables.go @@ -0,0 +1,32 @@ +package sqltypes + +import ( + "reflect" + + "github.com/golang/protobuf/proto" +) + +// BindVariablesEqual compares two maps of bind variables. +// For protobuf messages we have to use "proto.Equal". +func BindVariablesEqual(x, y map[string]interface{}) bool { + if len(x) != len(y) { + return false + } + for k := range x { + vx, vy := x[k], y[k] + if reflect.TypeOf(vx) != reflect.TypeOf(vy) { + return false + } + switch vx.(type) { + case proto.Message: + if !proto.Equal(vx.(proto.Message), vy.(proto.Message)) { + return false + } + default: + if !reflect.DeepEqual(vx, vy) { + return false + } + } + } + return true +} diff --git a/go/vt/binlog/binlogplayertest/player.go b/go/vt/binlog/binlogplayertest/player.go index a0c91d26ed..ea82f5a122 100644 --- a/go/vt/binlog/binlogplayertest/player.go +++ b/go/vt/binlog/binlogplayertest/player.go @@ -179,7 +179,7 @@ func testStreamTables(t *testing.T, bpc binlogplayer.Client) { if se, err := stream.Recv(); err != nil { t.Fatalf("got error: %v", err) } else { - if !reflect.DeepEqual(*se, *testBinlogTransaction) { + if !proto.Equal(se, testBinlogTransaction) { t.Errorf("got wrong result, got %v expected %v", *se, *testBinlogTransaction) } } diff --git a/go/vt/mysqlctl/tmutils/schema.go b/go/vt/mysqlctl/tmutils/schema.go index edc54196e2..a56e214a1e 100644 --- a/go/vt/mysqlctl/tmutils/schema.go +++ b/go/vt/mysqlctl/tmutils/schema.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" + "github.com/golang/protobuf/proto" "github.com/youtube/vitess/go/vt/concurrency" tabletmanagerdatapb "github.com/youtube/vitess/go/vt/proto/tabletmanagerdata" @@ -279,3 +280,12 @@ type SchemaChange struct { BeforeSchema *tabletmanagerdatapb.SchemaDefinition AfterSchema *tabletmanagerdatapb.SchemaDefinition } + +// Equal compares two SchemaChange objects. +func (s *SchemaChange) Equal(s2 *SchemaChange) bool { + return s.SQL == s2.SQL && + s.Force == s2.Force && + s.AllowReplication == s2.AllowReplication && + proto.Equal(s.BeforeSchema, s2.BeforeSchema) && + proto.Equal(s.AfterSchema, s2.AfterSchema) +} diff --git a/go/vt/topo/replication.go b/go/vt/topo/replication.go index bd50ab6b1b..d302365036 100644 --- a/go/vt/topo/replication.go +++ b/go/vt/topo/replication.go @@ -6,6 +6,7 @@ package topo import ( log "github.com/golang/glog" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "github.com/youtube/vitess/go/trace" @@ -104,7 +105,7 @@ func RemoveShardReplicationRecord(ctx context.Context, ts Server, cell, keyspace err := ts.UpdateShardReplicationFields(ctx, cell, keyspace, shard, func(sr *topodatapb.ShardReplication) error { nodes := make([]*topodatapb.ShardReplication_Node, 0, len(sr.Nodes)) for _, node := range sr.Nodes { - if *node.TabletAlias != *tabletAlias { + if !proto.Equal(node.TabletAlias, tabletAlias) { nodes = append(nodes, node) } } diff --git a/go/vt/vtgate/vtgateconntest/client.go b/go/vt/vtgate/vtgateconntest/client.go index 1cd4f933e6..0d6fbf7a18 100644 --- a/go/vt/vtgate/vtgateconntest/client.go +++ b/go/vt/vtgate/vtgateconntest/client.go @@ -76,7 +76,7 @@ type queryExecute struct { func (q *queryExecute) equal(q2 *queryExecute) bool { return q.SQL == q2.SQL && - reflect.DeepEqual(q.BindVariables, q2.BindVariables) && + sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && q.Keyspace == q2.Keyspace && q.TabletType == q2.TabletType && proto.Equal(q.Session, q2.Session) && @@ -166,7 +166,7 @@ type queryExecuteShards struct { func (q *queryExecuteShards) equal(q2 *queryExecuteShards) bool { return q.SQL == q2.SQL && - reflect.DeepEqual(q.BindVariables, q2.BindVariables) && + sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && q.Keyspace == q2.Keyspace && reflect.DeepEqual(q.Shards, q2.Shards) && q.TabletType == q2.TabletType && @@ -223,7 +223,7 @@ type queryExecuteKeyspaceIds struct { func (q *queryExecuteKeyspaceIds) equal(q2 *queryExecuteKeyspaceIds) bool { return q.SQL == q2.SQL && - reflect.DeepEqual(q.BindVariables, q2.BindVariables) && + sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && q.Keyspace == q2.Keyspace && reflect.DeepEqual(q.KeyspaceIds, q2.KeyspaceIds) && q.TabletType == q2.TabletType && @@ -279,7 +279,7 @@ type queryExecuteKeyRanges struct { func (q *queryExecuteKeyRanges) equal(q2 *queryExecuteKeyRanges) bool { if q.SQL != q2.SQL || - !reflect.DeepEqual(q.BindVariables, q2.BindVariables) || + !sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) || q.Keyspace != q2.Keyspace || len(q.KeyRanges) != len(q2.KeyRanges) || q.TabletType != q2.TabletType || @@ -344,7 +344,7 @@ type queryExecuteEntityIds struct { func (q *queryExecuteEntityIds) equal(q2 *queryExecuteEntityIds) bool { if q.SQL != q2.SQL || - !reflect.DeepEqual(q.BindVariables, q2.BindVariables) || + !sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) || q.Keyspace != q2.Keyspace || q.EntityColumnName != q2.EntityColumnName || len(q.EntityKeyspaceIDs) != len(q2.EntityKeyspaceIDs) || @@ -824,6 +824,16 @@ type querySplitQuery struct { Algorithm querypb.SplitQueryRequest_Algorithm } +func (q *querySplitQuery) equal(q2 *querySplitQuery) bool { + return q.Keyspace == q2.Keyspace && + q.SQL == q2.SQL && + sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && + reflect.DeepEqual(q.SplitColumns, q2.SplitColumns) && + q.SplitCount == q2.SplitCount && + q.NumRowsPerQueryPart == q2.NumRowsPerQueryPart && + q.Algorithm == q2.Algorithm +} + // SplitQuery is part of the VTGateService interface func (f *fakeVTGateService) SplitQuery( ctx context.Context, @@ -850,7 +860,7 @@ func (f *fakeVTGateService) SplitQuery( NumRowsPerQueryPart: numRowsPerQueryPart, Algorithm: algorithm, } - if !reflect.DeepEqual(query, splitQueryRequest) { + if !query.equal(splitQueryRequest) { f.t.Errorf("SplitQuery has wrong input: got %#v wanted %#v", query, splitQueryRequest) } return splitQueryResult, nil diff --git a/go/vt/vttablet/agentrpctest/test_agent_rpc.go b/go/vt/vttablet/agentrpctest/test_agent_rpc.go index af8b8a0add..cb36b758f9 100644 --- a/go/vt/vttablet/agentrpctest/test_agent_rpc.go +++ b/go/vt/vttablet/agentrpctest/test_agent_rpc.go @@ -15,6 +15,7 @@ import ( "golang.org/x/net/context" + "github.com/golang/protobuf/proto" "github.com/youtube/vitess/go/sqltypes" "github.com/youtube/vitess/go/vt/hook" "github.com/youtube/vitess/go/vt/logutil" @@ -59,10 +60,36 @@ func NewFakeRPCAgent(t *testing.T) tabletmanager.RPCAgent { // for each possible method of the interface. // This makes the implementations all in the same spot. +var protoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() + func compare(t *testing.T, name string, got, want interface{}) { - if !reflect.DeepEqual(got, want) { - t.Errorf("Unexpected %v: got %v expected %v", name, got, want) + typ := reflect.TypeOf(got) + if reflect.TypeOf(got) != reflect.TypeOf(want) { + goto fail } + switch { + case typ.Implements(protoMessage): + if !proto.Equal(got.(proto.Message), want.(proto.Message)) { + goto fail + } + case typ.Kind() == reflect.Slice && typ.Elem().Implements(protoMessage): + vx, vy := reflect.ValueOf(got), reflect.ValueOf(want) + if vx.Len() != vy.Len() { + goto fail + } + for i := 0; i < vx.Len(); i++ { + if !proto.Equal(vx.Index(i).Interface().(proto.Message), vy.Index(i).Interface().(proto.Message)) { + goto fail + } + } + default: + if !reflect.DeepEqual(got, want) { + goto fail + } + } + return +fail: + t.Errorf("Unexpected %v:\ngot %#v\nwant %#v", name, got, want) } func compareBool(t *testing.T, name string, got bool) { @@ -527,7 +554,9 @@ func (fra *fakeRPCAgent) ApplySchema(ctx context.Context, change *tmutils.Schema if fra.panics { panic(fmt.Errorf("test-triggered panic")) } - compare(fra.t, "ApplySchema change", change, testSchemaChange) + if !change.Equal(testSchemaChange) { + fra.t.Errorf("Unexpected ApplySchema change:\ngot %#v\nwant %#v", change, testSchemaChange) + } return testSchemaChangeResult[0], nil } diff --git a/go/vt/vttablet/tabletconntest/fakequeryservice.go b/go/vt/vttablet/tabletconntest/fakequeryservice.go index c534e38b11..df0f05f5e1 100644 --- a/go/vt/vttablet/tabletconntest/fakequeryservice.go +++ b/go/vt/vttablet/tabletconntest/fakequeryservice.go @@ -376,7 +376,7 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *querypb.Target, if sql != ExecuteQuery { f.t.Errorf("invalid Execute.Query.Sql: got %v expected %v", sql, ExecuteQuery) } - if !reflect.DeepEqual(bindVariables, ExecuteBindVars) { + if !sqltypes.BindVariablesEqual(bindVariables, ExecuteBindVars) { f.t.Errorf("invalid Execute.BindVariables: got %v expected %v", bindVariables, ExecuteBindVars) } if !proto.Equal(options, TestExecuteOptions) { @@ -432,7 +432,7 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *querypb.Ta if sql != StreamExecuteQuery { f.t.Errorf("invalid StreamExecute.Sql: got %v expected %v", sql, StreamExecuteQuery) } - if !reflect.DeepEqual(bindVariables, StreamExecuteBindVars) { + if !sqltypes.BindVariablesEqual(bindVariables, StreamExecuteBindVars) { f.t.Errorf("invalid StreamExecute.BindVariables: got %v expected %v", bindVariables, StreamExecuteBindVars) } if !proto.Equal(options, TestExecuteOptions) { @@ -538,7 +538,7 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *querypb.Tar if f.Panics { panic(fmt.Errorf("test-triggered panic")) } - if !reflect.DeepEqual(queries, ExecuteBatchQueries) { + if !querytypes.BoundQueriesEqual(queries, ExecuteBatchQueries) { f.t.Errorf("invalid ExecuteBatch.Queries: got %v expected %v", queries, ExecuteBatchQueries) } if !proto.Equal(options, TestExecuteOptions) { @@ -684,7 +684,7 @@ func (f *FakeQueryService) SplitQuery( panic(fmt.Errorf("test-triggered panic")) } f.checkTargetCallerID(ctx, "SplitQuery", target) - if !reflect.DeepEqual(query, SplitQueryBoundQuery) { + if !querytypes.BoundQueryEqual(&query, &SplitQueryBoundQuery) { f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v", querytypes.QueryAsString(query.Sql, query.BindVariables), SplitQueryBoundQuery) } diff --git a/go/vt/vttablet/tabletconntest/tabletconntest.go b/go/vt/vttablet/tabletconntest/tabletconntest.go index 3d5d0d48e4..3f1dc96f59 100644 --- a/go/vt/vttablet/tabletconntest/tabletconntest.go +++ b/go/vt/vttablet/tabletconntest/tabletconntest.go @@ -8,7 +8,6 @@ package tabletconntest import ( "io" - "reflect" "strings" "testing" "time" @@ -21,6 +20,7 @@ import ( "github.com/youtube/vitess/go/vt/vterrors" "github.com/youtube/vitess/go/vt/vttablet/queryservice" "github.com/youtube/vitess/go/vt/vttablet/tabletconn" + "github.com/youtube/vitess/go/vt/vttablet/tabletserver/querytypes" querypb "github.com/youtube/vitess/go/vt/proto/query" topodatapb "github.com/youtube/vitess/go/vt/proto/topodata" @@ -728,7 +728,7 @@ func testSplitQuery(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe if err != nil { t.Fatalf("SplitQuery failed: %v", err) } - if !reflect.DeepEqual(qsl, SplitQueryQuerySplitList) { + if !querytypes.QuerySplitsEqual(qsl, SplitQueryQuerySplitList) { t.Errorf("Unexpected result from SplitQuery: got %v wanted %v", qsl, SplitQueryQuerySplitList) } } diff --git a/go/vt/vttablet/tabletserver/querytypes/bound_query.go b/go/vt/vttablet/tabletserver/querytypes/bound_query.go index 92e57d90b6..ccdc2aeb53 100644 --- a/go/vt/vttablet/tabletserver/querytypes/bound_query.go +++ b/go/vt/vttablet/tabletserver/querytypes/bound_query.go @@ -9,6 +9,8 @@ package querytypes import ( "bytes" "fmt" + + "github.com/youtube/vitess/go/sqltypes" ) // This file defines the BoundQuery type. @@ -55,3 +57,22 @@ func slimit(s string, max int) string { } return s } + +// BoundQueriesEqual compares two slices of BoundQuery objects. +func BoundQueriesEqual(x, y []BoundQuery) bool { + if len(x) != len(y) { + return false + } + for i := range x { + if !BoundQueryEqual(&x[i], &y[i]) { + return false + } + } + return true +} + +// BoundQueryEqual compares two BoundQuery objects. +func BoundQueryEqual(x, y *BoundQuery) bool { + return x.Sql == y.Sql && + sqltypes.BindVariablesEqual(x.BindVariables, y.BindVariables) +} diff --git a/go/vt/vttablet/tabletserver/querytypes/query_split.go b/go/vt/vttablet/tabletserver/querytypes/query_split.go index 0dd4134790..406c07edc2 100644 --- a/go/vt/vttablet/tabletserver/querytypes/query_split.go +++ b/go/vt/vttablet/tabletserver/querytypes/query_split.go @@ -4,6 +4,8 @@ package querytypes +import "github.com/youtube/vitess/go/sqltypes" + // This file defines QuerySplit // QuerySplit represents a split of a query, used for MapReduce purposes. @@ -17,3 +19,23 @@ type QuerySplit struct { // RowCount is the approximate number of rows this query will return RowCount int64 } + +// Equal compares two QuerySplit objects. +func (q *QuerySplit) Equal(q2 *QuerySplit) bool { + return q.Sql == q2.Sql && + sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) && + q.RowCount == q2.RowCount +} + +// QuerySplitsEqual compares two slices of QuerySplit objects. +func QuerySplitsEqual(x, y []QuerySplit) bool { + if len(x) != len(y) { + return false + } + for i := range x { + if !x[i].Equal(&y[i]) { + return false + } + } + return true +}