зеркало из https://github.com/github/vitess-gh.git
Merge pull request #2801 from alainjobart/proto
Fixing more proto equalities.
This commit is contained in:
Коммит
9389cf5c26
|
@ -0,0 +1,32 @@
|
|||
package sqltypes
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
// BindVariablesEqual compares two maps of bind variables.
|
||||
// For protobuf messages we have to use "proto.Equal".
|
||||
func BindVariablesEqual(x, y map[string]interface{}) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for k := range x {
|
||||
vx, vy := x[k], y[k]
|
||||
if reflect.TypeOf(vx) != reflect.TypeOf(vy) {
|
||||
return false
|
||||
}
|
||||
switch vx.(type) {
|
||||
case proto.Message:
|
||||
if !proto.Equal(vx.(proto.Message), vy.(proto.Message)) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(vx, vy) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
|
@ -179,7 +179,7 @@ func testStreamTables(t *testing.T, bpc binlogplayer.Client) {
|
|||
if se, err := stream.Recv(); err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
} else {
|
||||
if !reflect.DeepEqual(*se, *testBinlogTransaction) {
|
||||
if !proto.Equal(se, testBinlogTransaction) {
|
||||
t.Errorf("got wrong result, got %v expected %v", *se, *testBinlogTransaction)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/youtube/vitess/go/vt/concurrency"
|
||||
|
||||
tabletmanagerdatapb "github.com/youtube/vitess/go/vt/proto/tabletmanagerdata"
|
||||
|
@ -279,3 +280,12 @@ type SchemaChange struct {
|
|||
BeforeSchema *tabletmanagerdatapb.SchemaDefinition
|
||||
AfterSchema *tabletmanagerdatapb.SchemaDefinition
|
||||
}
|
||||
|
||||
// Equal compares two SchemaChange objects.
|
||||
func (s *SchemaChange) Equal(s2 *SchemaChange) bool {
|
||||
return s.SQL == s2.SQL &&
|
||||
s.Force == s2.Force &&
|
||||
s.AllowReplication == s2.AllowReplication &&
|
||||
proto.Equal(s.BeforeSchema, s2.BeforeSchema) &&
|
||||
proto.Equal(s.AfterSchema, s2.AfterSchema)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package topo
|
|||
|
||||
import (
|
||||
log "github.com/golang/glog"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/youtube/vitess/go/trace"
|
||||
|
@ -104,7 +105,7 @@ func RemoveShardReplicationRecord(ctx context.Context, ts Server, cell, keyspace
|
|||
err := ts.UpdateShardReplicationFields(ctx, cell, keyspace, shard, func(sr *topodatapb.ShardReplication) error {
|
||||
nodes := make([]*topodatapb.ShardReplication_Node, 0, len(sr.Nodes))
|
||||
for _, node := range sr.Nodes {
|
||||
if *node.TabletAlias != *tabletAlias {
|
||||
if !proto.Equal(node.TabletAlias, tabletAlias) {
|
||||
nodes = append(nodes, node)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ type queryExecute struct {
|
|||
|
||||
func (q *queryExecute) equal(q2 *queryExecute) bool {
|
||||
return q.SQL == q2.SQL &&
|
||||
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
|
||||
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
|
||||
q.Keyspace == q2.Keyspace &&
|
||||
q.TabletType == q2.TabletType &&
|
||||
proto.Equal(q.Session, q2.Session) &&
|
||||
|
@ -166,7 +166,7 @@ type queryExecuteShards struct {
|
|||
|
||||
func (q *queryExecuteShards) equal(q2 *queryExecuteShards) bool {
|
||||
return q.SQL == q2.SQL &&
|
||||
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
|
||||
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
|
||||
q.Keyspace == q2.Keyspace &&
|
||||
reflect.DeepEqual(q.Shards, q2.Shards) &&
|
||||
q.TabletType == q2.TabletType &&
|
||||
|
@ -223,7 +223,7 @@ type queryExecuteKeyspaceIds struct {
|
|||
|
||||
func (q *queryExecuteKeyspaceIds) equal(q2 *queryExecuteKeyspaceIds) bool {
|
||||
return q.SQL == q2.SQL &&
|
||||
reflect.DeepEqual(q.BindVariables, q2.BindVariables) &&
|
||||
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
|
||||
q.Keyspace == q2.Keyspace &&
|
||||
reflect.DeepEqual(q.KeyspaceIds, q2.KeyspaceIds) &&
|
||||
q.TabletType == q2.TabletType &&
|
||||
|
@ -279,7 +279,7 @@ type queryExecuteKeyRanges struct {
|
|||
|
||||
func (q *queryExecuteKeyRanges) equal(q2 *queryExecuteKeyRanges) bool {
|
||||
if q.SQL != q2.SQL ||
|
||||
!reflect.DeepEqual(q.BindVariables, q2.BindVariables) ||
|
||||
!sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) ||
|
||||
q.Keyspace != q2.Keyspace ||
|
||||
len(q.KeyRanges) != len(q2.KeyRanges) ||
|
||||
q.TabletType != q2.TabletType ||
|
||||
|
@ -344,7 +344,7 @@ type queryExecuteEntityIds struct {
|
|||
|
||||
func (q *queryExecuteEntityIds) equal(q2 *queryExecuteEntityIds) bool {
|
||||
if q.SQL != q2.SQL ||
|
||||
!reflect.DeepEqual(q.BindVariables, q2.BindVariables) ||
|
||||
!sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) ||
|
||||
q.Keyspace != q2.Keyspace ||
|
||||
q.EntityColumnName != q2.EntityColumnName ||
|
||||
len(q.EntityKeyspaceIDs) != len(q2.EntityKeyspaceIDs) ||
|
||||
|
@ -824,6 +824,16 @@ type querySplitQuery struct {
|
|||
Algorithm querypb.SplitQueryRequest_Algorithm
|
||||
}
|
||||
|
||||
func (q *querySplitQuery) equal(q2 *querySplitQuery) bool {
|
||||
return q.Keyspace == q2.Keyspace &&
|
||||
q.SQL == q2.SQL &&
|
||||
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
|
||||
reflect.DeepEqual(q.SplitColumns, q2.SplitColumns) &&
|
||||
q.SplitCount == q2.SplitCount &&
|
||||
q.NumRowsPerQueryPart == q2.NumRowsPerQueryPart &&
|
||||
q.Algorithm == q2.Algorithm
|
||||
}
|
||||
|
||||
// SplitQuery is part of the VTGateService interface
|
||||
func (f *fakeVTGateService) SplitQuery(
|
||||
ctx context.Context,
|
||||
|
@ -850,7 +860,7 @@ func (f *fakeVTGateService) SplitQuery(
|
|||
NumRowsPerQueryPart: numRowsPerQueryPart,
|
||||
Algorithm: algorithm,
|
||||
}
|
||||
if !reflect.DeepEqual(query, splitQueryRequest) {
|
||||
if !query.equal(splitQueryRequest) {
|
||||
f.t.Errorf("SplitQuery has wrong input: got %#v wanted %#v", query, splitQueryRequest)
|
||||
}
|
||||
return splitQueryResult, nil
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
"github.com/youtube/vitess/go/vt/hook"
|
||||
"github.com/youtube/vitess/go/vt/logutil"
|
||||
|
@ -59,10 +60,36 @@ func NewFakeRPCAgent(t *testing.T) tabletmanager.RPCAgent {
|
|||
// for each possible method of the interface.
|
||||
// This makes the implementations all in the same spot.
|
||||
|
||||
var protoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
|
||||
|
||||
func compare(t *testing.T, name string, got, want interface{}) {
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Unexpected %v: got %v expected %v", name, got, want)
|
||||
typ := reflect.TypeOf(got)
|
||||
if reflect.TypeOf(got) != reflect.TypeOf(want) {
|
||||
goto fail
|
||||
}
|
||||
switch {
|
||||
case typ.Implements(protoMessage):
|
||||
if !proto.Equal(got.(proto.Message), want.(proto.Message)) {
|
||||
goto fail
|
||||
}
|
||||
case typ.Kind() == reflect.Slice && typ.Elem().Implements(protoMessage):
|
||||
vx, vy := reflect.ValueOf(got), reflect.ValueOf(want)
|
||||
if vx.Len() != vy.Len() {
|
||||
goto fail
|
||||
}
|
||||
for i := 0; i < vx.Len(); i++ {
|
||||
if !proto.Equal(vx.Index(i).Interface().(proto.Message), vy.Index(i).Interface().(proto.Message)) {
|
||||
goto fail
|
||||
}
|
||||
}
|
||||
default:
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
goto fail
|
||||
}
|
||||
}
|
||||
return
|
||||
fail:
|
||||
t.Errorf("Unexpected %v:\ngot %#v\nwant %#v", name, got, want)
|
||||
}
|
||||
|
||||
func compareBool(t *testing.T, name string, got bool) {
|
||||
|
@ -527,7 +554,9 @@ func (fra *fakeRPCAgent) ApplySchema(ctx context.Context, change *tmutils.Schema
|
|||
if fra.panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
compare(fra.t, "ApplySchema change", change, testSchemaChange)
|
||||
if !change.Equal(testSchemaChange) {
|
||||
fra.t.Errorf("Unexpected ApplySchema change:\ngot %#v\nwant %#v", change, testSchemaChange)
|
||||
}
|
||||
return testSchemaChangeResult[0], nil
|
||||
}
|
||||
|
||||
|
|
|
@ -376,7 +376,7 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *querypb.Target,
|
|||
if sql != ExecuteQuery {
|
||||
f.t.Errorf("invalid Execute.Query.Sql: got %v expected %v", sql, ExecuteQuery)
|
||||
}
|
||||
if !reflect.DeepEqual(bindVariables, ExecuteBindVars) {
|
||||
if !sqltypes.BindVariablesEqual(bindVariables, ExecuteBindVars) {
|
||||
f.t.Errorf("invalid Execute.BindVariables: got %v expected %v", bindVariables, ExecuteBindVars)
|
||||
}
|
||||
if !proto.Equal(options, TestExecuteOptions) {
|
||||
|
@ -432,7 +432,7 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *querypb.Ta
|
|||
if sql != StreamExecuteQuery {
|
||||
f.t.Errorf("invalid StreamExecute.Sql: got %v expected %v", sql, StreamExecuteQuery)
|
||||
}
|
||||
if !reflect.DeepEqual(bindVariables, StreamExecuteBindVars) {
|
||||
if !sqltypes.BindVariablesEqual(bindVariables, StreamExecuteBindVars) {
|
||||
f.t.Errorf("invalid StreamExecute.BindVariables: got %v expected %v", bindVariables, StreamExecuteBindVars)
|
||||
}
|
||||
if !proto.Equal(options, TestExecuteOptions) {
|
||||
|
@ -538,7 +538,7 @@ func (f *FakeQueryService) ExecuteBatch(ctx context.Context, target *querypb.Tar
|
|||
if f.Panics {
|
||||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
if !reflect.DeepEqual(queries, ExecuteBatchQueries) {
|
||||
if !querytypes.BoundQueriesEqual(queries, ExecuteBatchQueries) {
|
||||
f.t.Errorf("invalid ExecuteBatch.Queries: got %v expected %v", queries, ExecuteBatchQueries)
|
||||
}
|
||||
if !proto.Equal(options, TestExecuteOptions) {
|
||||
|
@ -684,7 +684,7 @@ func (f *FakeQueryService) SplitQuery(
|
|||
panic(fmt.Errorf("test-triggered panic"))
|
||||
}
|
||||
f.checkTargetCallerID(ctx, "SplitQuery", target)
|
||||
if !reflect.DeepEqual(query, SplitQueryBoundQuery) {
|
||||
if !querytypes.BoundQueryEqual(&query, &SplitQueryBoundQuery) {
|
||||
f.t.Errorf("invalid SplitQuery.SplitQueryRequest.Query: got %v expected %v",
|
||||
querytypes.QueryAsString(query.Sql, query.BindVariables), SplitQueryBoundQuery)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ package tabletconntest
|
|||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -21,6 +20,7 @@ import (
|
|||
"github.com/youtube/vitess/go/vt/vterrors"
|
||||
"github.com/youtube/vitess/go/vt/vttablet/queryservice"
|
||||
"github.com/youtube/vitess/go/vt/vttablet/tabletconn"
|
||||
"github.com/youtube/vitess/go/vt/vttablet/tabletserver/querytypes"
|
||||
|
||||
querypb "github.com/youtube/vitess/go/vt/proto/query"
|
||||
topodatapb "github.com/youtube/vitess/go/vt/proto/topodata"
|
||||
|
@ -728,7 +728,7 @@ func testSplitQuery(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe
|
|||
if err != nil {
|
||||
t.Fatalf("SplitQuery failed: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(qsl, SplitQueryQuerySplitList) {
|
||||
if !querytypes.QuerySplitsEqual(qsl, SplitQueryQuerySplitList) {
|
||||
t.Errorf("Unexpected result from SplitQuery: got %v wanted %v", qsl, SplitQueryQuerySplitList)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ package querytypes
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/youtube/vitess/go/sqltypes"
|
||||
)
|
||||
|
||||
// This file defines the BoundQuery type.
|
||||
|
@ -55,3 +57,22 @@ func slimit(s string, max int) string {
|
|||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// BoundQueriesEqual compares two slices of BoundQuery objects.
|
||||
func BoundQueriesEqual(x, y []BoundQuery) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i := range x {
|
||||
if !BoundQueryEqual(&x[i], &y[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// BoundQueryEqual compares two BoundQuery objects.
|
||||
func BoundQueryEqual(x, y *BoundQuery) bool {
|
||||
return x.Sql == y.Sql &&
|
||||
sqltypes.BindVariablesEqual(x.BindVariables, y.BindVariables)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
|
||||
package querytypes
|
||||
|
||||
import "github.com/youtube/vitess/go/sqltypes"
|
||||
|
||||
// This file defines QuerySplit
|
||||
|
||||
// QuerySplit represents a split of a query, used for MapReduce purposes.
|
||||
|
@ -17,3 +19,23 @@ type QuerySplit struct {
|
|||
// RowCount is the approximate number of rows this query will return
|
||||
RowCount int64
|
||||
}
|
||||
|
||||
// Equal compares two QuerySplit objects.
|
||||
func (q *QuerySplit) Equal(q2 *QuerySplit) bool {
|
||||
return q.Sql == q2.Sql &&
|
||||
sqltypes.BindVariablesEqual(q.BindVariables, q2.BindVariables) &&
|
||||
q.RowCount == q2.RowCount
|
||||
}
|
||||
|
||||
// QuerySplitsEqual compares two slices of QuerySplit objects.
|
||||
func QuerySplitsEqual(x, y []QuerySplit) bool {
|
||||
if len(x) != len(y) {
|
||||
return false
|
||||
}
|
||||
for i := range x {
|
||||
if !x[i].Equal(&y[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче