diff --git a/go/vt/vtgate/sandbox_test.go b/go/vt/vtgate/sandbox_test.go index 588eafa003..6b2091a0eb 100644 --- a/go/vt/vtgate/sandbox_test.go +++ b/go/vt/vtgate/sandbox_test.go @@ -16,6 +16,7 @@ import ( "github.com/youtube/vitess/go/sync2" tproto "github.com/youtube/vitess/go/vt/tabletserver/proto" "github.com/youtube/vitess/go/vt/topo" + "github.com/youtube/vitess/go/vt/vtgate/proto" ) // sandbox_test.go provides a sandbox for unit testing Barnacle. @@ -56,6 +57,7 @@ func resetSandbox() { dialMustFail = 0 } +// sandboxTopo satisfies the SrvTopoServer interface type sandboxTopo struct { } @@ -98,10 +100,13 @@ func sandboxDialer(endPoint topo.EndPoint, keyspace, shard string) (TabletConn, if tconn == nil { panic(fmt.Sprintf("can't find conn %v", endPoint.Uid)) } + tconn.(*sandboxConn).endPoint = endPoint return tconn, nil } +// sandboxConn satisfies the TabletConn interface type sandboxConn struct { + endPoint topo.EndPoint mustFailRetry int mustFailFatal int mustFailServer int @@ -117,9 +122,6 @@ type sandboxConn struct { CommitCount int RollbackCount int CloseCount int - - // TransactionId is auto-generated on Begin - transactionId int64 } func (sbc *sandboxConn) getError() error { @@ -150,7 +152,7 @@ func (sbc *sandboxConn) getError() error { return nil } -func (sbc *sandboxConn) Execute(query string, bindVars map[string]interface{}) (*mproto.QueryResult, error) { +func (sbc *sandboxConn) Execute(query string, bindVars map[string]interface{}, transactionId int64) (*proto.QueryResult, error) { sbc.ExecCount++ if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) @@ -161,7 +163,7 @@ func (sbc *sandboxConn) Execute(query string, bindVars map[string]interface{}) ( return singleRowResult, nil } -func (sbc *sandboxConn) ExecuteBatch(queries []tproto.BoundQuery) (*tproto.QueryResultList, error) { +func (sbc *sandboxConn) ExecuteBatch(queries []tproto.BoundQuery, transactionId int64) (*proto.QueryResultList, error) { sbc.ExecCount++ if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) @@ -169,75 +171,76 @@ func (sbc *sandboxConn) ExecuteBatch(queries []tproto.BoundQuery) (*tproto.Query if err := sbc.getError(); err != nil { return nil, err } - qrl := &tproto.QueryResultList{List: make([]mproto.QueryResult, 0, len(queries))} + qrl := &proto.QueryResultList{} + qrl.List = make([]mproto.QueryResult, 0, len(queries)) for _ = range queries { - qrl.List = append(qrl.List, *singleRowResult) + qrl.List = append(qrl.List, singleRowResult.QueryResult) } return qrl, nil } -func (sbc *sandboxConn) StreamExecute(query string, bindVars map[string]interface{}) (<-chan *mproto.QueryResult, ErrFunc) { +func (sbc *sandboxConn) StreamExecute(query string, bindVars map[string]interface{}, transactionId int64) (<-chan *proto.QueryResult, ErrFunc) { sbc.ExecCount++ if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) } - ch := make(chan *mproto.QueryResult, 1) + ch := make(chan *proto.QueryResult, 1) ch <- singleRowResult close(ch) err := sbc.getError() return ch, func() error { return err } } -func (sbc *sandboxConn) Begin() error { +func (sbc *sandboxConn) Begin() (int64, error) { sbc.ExecCount++ sbc.BeginCount++ if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) } err := sbc.getError() - if err == nil { - sbc.transactionId = transactionId.Add(1) + if err != nil { + return 0, err } - return err + return transactionId.Add(1), nil } -func (sbc *sandboxConn) Commit() error { +func (sbc *sandboxConn) Commit(transactionId int64) error { sbc.ExecCount++ sbc.CommitCount++ - sbc.transactionId = 0 if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) } return sbc.getError() } -func (sbc *sandboxConn) Rollback() error { +func (sbc *sandboxConn) Rollback(transactionId int64) error { sbc.ExecCount++ sbc.RollbackCount++ - sbc.transactionId = 0 if sbc.mustDelay != 0 { time.Sleep(sbc.mustDelay) } return sbc.getError() } -func (sbc *sandboxConn) TransactionId() int64 { - return sbc.transactionId -} - // Close does not change ExecCount func (sbc *sandboxConn) Close() { sbc.CloseCount++ } -var singleRowResult = &mproto.QueryResult{ - Fields: []mproto.Field{ - {"id", 3}, - {"value", 253}}, - RowsAffected: 1, - InsertId: 0, - Rows: [][]sqltypes.Value{{ - {sqltypes.Numeric("1")}, - {sqltypes.String("foo")}, - }}, +func (sbc *sandboxConn) EndPoint() topo.EndPoint { + return sbc.endPoint +} + +var singleRowResult = &proto.QueryResult{ + QueryResult: mproto.QueryResult{ + Fields: []mproto.Field{ + {"id", 3}, + {"value", 253}}, + RowsAffected: 1, + InsertId: 0, + Rows: [][]sqltypes.Value{{ + {sqltypes.Numeric("1")}, + {sqltypes.String("foo")}, + }}, + }, } diff --git a/go/vt/vtgate/shard_conn.go b/go/vt/vtgate/shard_conn.go index 0c5b6d0046..d33f2554a7 100644 --- a/go/vt/vtgate/shard_conn.go +++ b/go/vt/vtgate/shard_conn.go @@ -78,13 +78,16 @@ func (sdc *ShardConn) ExecuteBatch(queries []tproto.BoundQuery, transactionId in // StreamExecute executes a streaming query on vttablet. The retry rules are the same as Execute. func (sdc *ShardConn) StreamExecute(query string, bindVars map[string]interface{}, transactionId int64) (results <-chan *proto.QueryResult, errFunc ErrFunc) { var usedConn TabletConn - // We can ignore the error return because errFunc will have it - _ = sdc.withRetry(func(conn TabletConn) error { - results, errFunc = conn.StreamExecute(query, bindVars, transactionId) + var erFunc ErrFunc + err := sdc.withRetry(func(conn TabletConn) error { + results, erFunc = conn.StreamExecute(query, bindVars, transactionId) usedConn = conn - return errFunc() + return erFunc() }, transactionId) - return results, func() error { return sdc.WrapError(errFunc(), usedConn) } + if err != nil { + return results, func() error { return err } + } + return results, func() error { return sdc.WrapError(erFunc(), usedConn) } } // Begin begins a transaction. The retry rules are the same as Execute. @@ -111,24 +114,6 @@ func (sdc *ShardConn) Rollback(transactionId int64) (err error) { }, transactionId) } -// withRetry sets up the connection and exexutes the action. If there are connection errors, -// it retries retryCount times before failing. It does not retry if the connection is in -// the middle of a transaction. -func (sdc *ShardConn) withRetry(action func(conn TabletConn) error, transactionId int64) error { - var conn TabletConn - var err error - for i := 0; i < sdc.retryCount; i++ { - if conn, err = sdc.getConn(); err != nil { - continue - } - if err = action(conn); sdc.canRetry(err, transactionId, conn) { - continue - } - return sdc.WrapError(err, conn) - } - return sdc.WrapError(err, conn) -} - // Close closes the underlying TabletConn. ShardConn can be // reused after this because it opens connections on demand. func (sdc *ShardConn) Close() { @@ -141,26 +126,50 @@ func (sdc *ShardConn) Close() { sdc.conn = nil } +// withRetry sets up the connection and exexutes the action. If there are connection errors, +// it retries retryCount times before failing. It does not retry if the connection is in +// the middle of a transaction. +func (sdc *ShardConn) withRetry(action func(conn TabletConn) error, transactionId int64) error { + var conn TabletConn + var err error + var retry bool + for i := 0; i < sdc.retryCount; i++ { + conn, err, retry = sdc.getConn() + if err != nil { + if retry { + continue + } + return sdc.WrapError(err, conn) + } + if err = action(conn); sdc.canRetry(err, transactionId, conn) { + continue + } + return sdc.WrapError(err, conn) + } + return sdc.WrapError(err, conn) +} + // getConn reuses an existing connection if possible. Otherwise // it returns a connection which it will save for future reuse. -func (sdc *ShardConn) getConn() (TabletConn, error) { +// If it returns an error, retry will tell you if getConn can be retried. +func (sdc *ShardConn) getConn() (conn TabletConn, err error, retry bool) { sdc.mu.Lock() defer sdc.mu.Unlock() if sdc.conn != nil { - return sdc.conn, nil + return sdc.conn, nil, false } endPoint, err := sdc.balancer.Get() if err != nil { - return nil, err + return nil, err, false } - conn, err := GetDialer()(endPoint, sdc.keyspace, sdc.shard) + conn, err = GetDialer()(endPoint, sdc.keyspace, sdc.shard) if err != nil { sdc.balancer.MarkDown(endPoint.Uid) - return nil, err + return nil, err, true } sdc.conn = conn - return sdc.conn, nil + return sdc.conn, nil, false } // canRetry determines whether a query can be retried or not. diff --git a/go/vt/vtgate/shard_conn_test.go b/go/vt/vtgate/shard_conn_test.go index c2408672fc..ca36e7b08e 100644 --- a/go/vt/vtgate/shard_conn_test.go +++ b/go/vt/vtgate/shard_conn_test.go @@ -14,38 +14,65 @@ import ( // This file uses the sandbox_test framework. func TestShardConnExecute(t *testing.T) { - blm := NewBalancerMap(new(sandboxTopo), "aa") testShardConnGeneric(t, func() error { - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - _, err := sdc.Execute("query", nil) + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + _, err := sdc.Execute("query", nil, 0) + return err + }) + testShardConnTransact(t, func() error { + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + _, err := sdc.Execute("query", nil, 1) return err }) } func TestShardConnExecuteBatch(t *testing.T) { - blm := NewBalancerMap(new(sandboxTopo), "aa") testShardConnGeneric(t, func() error { - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) queries := []tproto.BoundQuery{{"query", nil}} - _, err := sdc.ExecuteBatch(queries) + _, err := sdc.ExecuteBatch(queries, 0) + return err + }) + testShardConnTransact(t, func() error { + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + queries := []tproto.BoundQuery{{"query", nil}} + _, err := sdc.ExecuteBatch(queries, 1) return err }) } func TestShardConnExecuteStream(t *testing.T) { - blm := NewBalancerMap(new(sandboxTopo), "aa") testShardConnGeneric(t, func() error { - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - _, errfunc := sdc.StreamExecute("query", nil) + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + _, errfunc := sdc.StreamExecute("query", nil, 0) + return errfunc() + }) + testShardConnTransact(t, func() error { + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + _, errfunc := sdc.StreamExecute("query", nil, 1) return errfunc() }) } func TestShardConnBegin(t *testing.T) { - blm := NewBalancerMap(new(sandboxTopo), "aa") testShardConnGeneric(t, func() error { - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - return sdc.Begin() + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + _, err := sdc.Begin() + return err + }) +} + +func TestShardConnCommi(t *testing.T) { + testShardConnTransact(t, func() error { + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + return sdc.Commit(1) + }) +} + +func TestShardConnRollback(t *testing.T) { + testShardConnTransact(t, func() error { + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 1*time.Millisecond, 3) + return sdc.Rollback(1) }) } @@ -80,7 +107,7 @@ func testShardConnGeneric(t *testing.T, f func() error) { sbc := &sandboxConn{mustFailRetry: 3} testConns[0] = sbc err = f() - want = "retry: err, shard: (.0.), host: " + want = "retry: err, shard: (.0.), host: 0" if err == nil || err.Error() != want { t.Errorf("want %s, got %v", want, err) } @@ -162,28 +189,6 @@ func testShardConnGeneric(t *testing.T, f func() error) { t.Errorf("want 2, got %v", sbc.ExecCount) } - // conn error (in transaction) - resetSandbox() - sbc = &sandboxConn{mustFailConn: 1, transactionId: 1} - testConns[0] = sbc - err = f() - want = "error: conn, shard: (.0.), host: " - if err == nil || err.Error() != want { - t.Errorf("want %s, got %v", want, err) - } - // Ensure we did not redial. - if dialCounter != 1 { - t.Errorf("want 1, got %v", dialCounter) - } - // One rollback followed by execution. - if sbc.ExecCount != 2 { - t.Errorf("want 2, got %v", sbc.ExecCount) - } - // Ensure one of those ExecCounts was a Rollback - if sbc.RollbackCount != 1 { - t.Errorf("want 1, got %v", sbc.ExecCount) - } - // no failures resetSandbox() sbc = &sandboxConn{} @@ -200,28 +205,44 @@ func testShardConnGeneric(t *testing.T, f func() error) { } } -func TestShardConnBeginOther(t *testing.T) { - // already in transaction +func testShardConnTransact(t *testing.T, f func() error) { + // retry error resetSandbox() - blm := NewBalancerMap(new(sandboxTopo), "aa") - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - testConns[0] = &sandboxConn{transactionId: 1} - // call Execute to cause connection to be opened - sdc.Execute("query", nil) - err := sdc.Begin() - // Begin should not be allowed if already in a transaction. - want := "cannot begin: already in transaction, shard: (.0.), host: 0" + sbc := &sandboxConn{mustFailRetry: 3} + testConns[0] = sbc + err := f() + want := "retry: err, shard: (.0.), host: 0" if err == nil || err.Error() != want { t.Errorf("want %s, got %v", want, err) } + // Should not retry if we're in transaction + if sbc.ExecCount != 1 { + t.Errorf("want 1, got %v", sbc.ExecCount) + } + // conn error + resetSandbox() + sbc = &sandboxConn{mustFailConn: 3} + testConns[0] = sbc + err = f() + want = "error: conn, shard: (.0.), host: 0" + if err == nil || err.Error() != want { + t.Errorf("want %s, got %v", want, err) + } + // Should not retry if we're in transaction + if sbc.ExecCount != 1 { + t.Errorf("want 1, got %v", sbc.ExecCount) + } +} + +func TestShardConnBeginOther(t *testing.T) { // tx_pool_full resetSandbox() sbc := &sandboxConn{mustFailTxPool: 1} testConns[0] = sbc - sdc = NewShardConn(blm, "", "0", "", 10*time.Millisecond, 3) + sdc := NewShardConn(new(sandboxTopo), "aa", "", "0", "", 10*time.Millisecond, 3) startTime := time.Now() - err = sdc.Begin() + _, err := sdc.Begin() // If transaction pool is full, Begin should wait and retry. if time.Now().Sub(startTime) < (10 * time.Millisecond) { t.Errorf("want >10ms, got %v", time.Now().Sub(startTime)) @@ -238,79 +259,3 @@ func TestShardConnBeginOther(t *testing.T) { t.Errorf("want 2, got %v", sbc.ExecCount) } } - -func TestShardConnCommit(t *testing.T) { - // not in transaction - resetSandbox() - blm := NewBalancerMap(new(sandboxTopo), "aa") - testConns[0] = &sandboxConn{} - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - err := sdc.Commit() - // Commit should fail if we're not in a transaction. - want := "cannot commit: not in transaction, shard: (.0.), host: 0" - if err == nil || err.Error() != want { - t.Errorf("want %s, got %v", want, err) - } - - // valid commit - testConns[0] = &sandboxConn{transactionId: 1} - sdc = NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - err = sdc.Commit() - if err != nil { - t.Errorf("want nil, got %v", err) - } - - // commit fail - sbc := &sandboxConn{} - testConns[0] = sbc - sdc = NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - sbc.mustFailServer = 1 - sbc.transactionId = 1 - err = sdc.Commit() - // Commit should fail if server returned an error. - want = "error: err, shard: (.0.), host: 0" - if err == nil || err.Error() != want { - t.Errorf("want %s, got %v", want, err) - } -} - -func TestShardConnRollback(t *testing.T) { - // not in transaction - resetSandbox() - blm := NewBalancerMap(new(sandboxTopo), "aa") - testConns[0] = &sandboxConn{} - sdc := NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - err := sdc.Rollback() - // Rollback should fail if we're not in a transaction. - want := "cannot rollback: not in transaction, shard: (.0.), host: 0" - if err == nil || err.Error() != want { - t.Errorf("want %s, got %v", want, err) - } - - // valid rollback - testConns[0] = &sandboxConn{transactionId: 1} - sdc = NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - err = sdc.Rollback() - if err != nil { - t.Errorf("want nil, got %v", err) - } - - // rollback fail - sbc := &sandboxConn{} - testConns[0] = sbc - sdc = NewShardConn(blm, "", "0", "", 1*time.Millisecond, 3) - sdc.Execute("query", nil) - sbc.mustFailServer = 1 - sbc.transactionId = 1 - err = sdc.Rollback() - want = "error: err, shard: (.0.), host: 0" - // Rollback should fail if server returned an error. - if err == nil || err.Error() != want { - t.Errorf("want %s, got %v", want, err) - } -} diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index c7457cc54f..55bca620e6 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -7,6 +7,7 @@ package vtgate import ( + "fmt" "time" log "github.com/golang/glog" @@ -84,8 +85,13 @@ func (vtg *VTGate) StreamExecuteShard(context *rpcproto.Context, query *proto.Qu // Begin begins a transaction. It has to be concluded by a Commit or Rollback. func (vtg *VTGate) Begin(context *rpcproto.Context, inSession, outSession *proto.Session) error { - inSession.InTransaction = true + if inSession.InTransaction { + err := fmt.Errorf("Already in transaction") + log.Errorf("Begin: %v", err) + return err + } *outSession = *inSession + outSession.InTransaction = true return nil }