Add tests for StreamExecute* for vtgateconn

This commit is contained in:
Ammar Aijazi 2015-08-07 18:27:57 -07:00
Родитель a8f043866b
Коммит 97757d049f
1 изменённых файлов: 216 добавлений и 42 удалений

Просмотреть файл

@ -38,6 +38,7 @@ type fakeVTGateService struct {
// we can test subsequent calls in the transaction (e.g., Commit, Rollback).
forceBeginSuccess bool
hasCallerID bool
errorWait chan struct{}
}
var errTestVtGateError = errors.New("test vtgate error")
@ -245,6 +246,13 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, query *proto.Quer
if err := sendReply(&result); err != nil {
return err
}
if f.hasError {
// wait until the client has the response, since all streaming implementation may not
// send previous messages if an error has been triggered.
<-f.errorWait
f.errorWait = make(chan struct{}) // for next test
return errTestVtGateError
}
for _, row := range execCase.reply.Result.Rows {
result := proto.QueryResult{Result: &mproto.QueryResult{}}
result.Result.Rows = [][]sqltypes.Value{row}
@ -280,6 +288,13 @@ func (f *fakeVTGateService) StreamExecuteShard(ctx context.Context, query *proto
if err := sendReply(&result); err != nil {
return err
}
if f.hasError {
// wait until the client has the response, since all streaming implementation may not
// send previous messages if an error has been triggered.
<-f.errorWait
f.errorWait = make(chan struct{}) // for next test
return errTestVtGateError
}
for _, row := range execCase.reply.Result.Rows {
result := proto.QueryResult{Result: &mproto.QueryResult{}}
result.Result.Rows = [][]sqltypes.Value{row}
@ -315,6 +330,13 @@ func (f *fakeVTGateService) StreamExecuteKeyRanges(ctx context.Context, query *p
if err := sendReply(&result); err != nil {
return err
}
if f.hasError {
// wait until the client has the response, since all streaming implementation may not
// send previous messages if an error has been triggered.
<-f.errorWait
f.errorWait = make(chan struct{}) // for next test
return errTestVtGateError
}
for _, row := range execCase.reply.Result.Rows {
result := proto.QueryResult{Result: &mproto.QueryResult{}}
result.Result.Rows = [][]sqltypes.Value{row}
@ -350,6 +372,13 @@ func (f *fakeVTGateService) StreamExecuteKeyspaceIds(ctx context.Context, query
if err := sendReply(&result); err != nil {
return err
}
if f.hasError {
// wait until the client has the response, since all streaming implementation may not
// send previous messages if an error has been triggered.
<-f.errorWait
f.errorWait = make(chan struct{}) // for next test
return errTestVtGateError
}
for _, row := range execCase.reply.Result.Rows {
result := proto.QueryResult{Result: &mproto.QueryResult{}}
result.Result.Rows = [][]sqltypes.Value{row}
@ -446,6 +475,7 @@ func CreateFakeServer(t *testing.T) vtgateservice.VTGateService {
t: t,
panics: false,
hasCallerID: true,
errorWait: make(chan struct{}),
}
}
@ -466,6 +496,8 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
t.Fatalf("Got err: %v from vtgateconn.DialProtocol", err)
}
fs := fakeServer.(*fakeVTGateService)
testExecute(t, conn)
testExecuteShard(t, conn)
testExecuteKeyspaceIds(t, conn)
@ -477,11 +509,11 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testStreamExecuteShard(t, conn)
testStreamExecuteKeyRanges(t, conn)
testStreamExecuteKeyspaceIds(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
fs.hasCallerID = false
testTxPass(t, conn)
testTxPassNotInTransaction(t, conn)
testTxFail(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
fs.hasCallerID = true
testTx2Pass(t, conn)
testTx2PassNotInTransaction(t, conn)
testTx2Fail(t, conn)
@ -489,16 +521,17 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testGetSrvKeyspace(t, conn)
// return an error for every call, make sure they're handled properly
fakeServer.(*fakeVTGateService).hasError = true
fs.hasError = true
// First test errors in Begin, and then force it to succeed so we can test
// subsequent calls in the transaction.
fakeServer.(*fakeVTGateService).forceBeginSuccess = false
fakeServer.(*fakeVTGateService).hasCallerID = false
fs.hasCallerID = false
testBeginError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommitError(t, conn, fs)
testRollbackError(t, conn, fs)
fs.hasCallerID = true
testBegin2Error(t, conn)
fakeServer.(*fakeVTGateService).forceBeginSuccess = true
testCommit2Error(t, conn, fs)
testRollback2Error(t, conn, fs)
testExecuteError(t, conn)
testExecuteShardError(t, conn)
@ -507,31 +540,26 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testExecuteEntityIdsError(t, conn)
testExecuteBatchShardError(t, conn)
testExecuteBatchKeyspaceIdsError(t, conn)
// testStreamExecuteError(t, conn)
// testStreamExecuteShardError(t, conn)
// testStreamExecuteKeyRangesError(t, conn)
// testStreamExecuteKeyspaceIdsError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
testCommitError(t, conn)
testRollbackError(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommit2Error(t, conn)
testRollback2Error(t, conn)
testStreamExecuteError(t, conn, fs)
testStreamExecuteShardError(t, conn, fs)
testStreamExecuteKeyRangesError(t, conn, fs)
testStreamExecuteKeyspaceIdsError(t, conn, fs)
testSplitQueryError(t, conn)
testGetSrvKeyspaceError(t, conn)
fakeServer.(*fakeVTGateService).hasError = false
fs.hasError = false
// force a panic at every call, then test that works
fakeServer.(*fakeVTGateService).panics = true
fs.panics = true
// First test errors in Begin, and then force it to succeed so we can test
// subsequent calls in the transaction.
fakeServer.(*fakeVTGateService).forceBeginSuccess = false
fakeServer.(*fakeVTGateService).hasCallerID = false
fs.hasCallerID = false
testBeginPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommitPanic(t, conn, fs)
testRollbackPanic(t, conn, fs)
fs.hasCallerID = true
testBegin2Panic(t, conn)
fakeServer.(*fakeVTGateService).forceBeginSuccess = true
testCommit2Panic(t, conn, fs)
testRollback2Panic(t, conn, fs)
testExecutePanic(t, conn)
testExecuteShardPanic(t, conn)
@ -544,15 +572,9 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa
testStreamExecuteShardPanic(t, conn)
testStreamExecuteKeyRangesPanic(t, conn)
testStreamExecuteKeyspaceIdsPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = false
testCommitPanic(t, conn)
testRollbackPanic(t, conn)
fakeServer.(*fakeVTGateService).hasCallerID = true
testCommit2Panic(t, conn)
testRollback2Panic(t, conn)
testSplitQueryPanic(t, conn)
testGetSrvKeyspacePanic(t, conn)
fakeServer.(*fakeVTGateService).panics = false
fs.panics = false
}
func expectPanic(t *testing.T, err error) {
@ -901,6 +923,32 @@ func testStreamExecute(t *testing.T, conn *vtgateconn.VTGateConn) {
}
}
func testStreamExecuteError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
execCase := execMap["request1"]
stream, errFunc, err := conn.StreamExecute(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
t.Fatalf("StreamExecute failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecute failed: cannot read result1")
}
if !reflect.DeepEqual(qr, &streamResult1) {
t.Errorf("Unexpected result from StreamExecute: got %#v want %#v", qr, &streamResult1)
}
// signal to the server that the first result has been received
close(fake.errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecute channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecute")
}
func testStreamExecutePanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := newContext()
execCase := execMap["request1"]
@ -969,6 +1017,32 @@ func testStreamExecuteShard(t *testing.T, conn *vtgateconn.VTGateConn) {
}
}
func testStreamExecuteShardError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
execCase := execMap["request1"]
stream, errFunc, err := conn.StreamExecuteShard(ctx, execCase.shardQuery.Sql, execCase.shardQuery.Keyspace, execCase.shardQuery.Shards, execCase.execQuery.BindVariables, execCase.execQuery.TabletType)
if err != nil {
t.Fatalf("StreamExecuteShard failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecuteShard failed: cannot read result1")
}
if !reflect.DeepEqual(qr, &streamResult1) {
t.Errorf("Unexpected result from StreamExecuteShard: got %#v want %#v", qr, &streamResult1)
}
// signal to the server that the first result has been received
close(fake.errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecuteShard channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecuteShard")
}
func testStreamExecuteShardPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := newContext()
execCase := execMap["request1"]
@ -1037,6 +1111,32 @@ func testStreamExecuteKeyRanges(t *testing.T, conn *vtgateconn.VTGateConn) {
}
}
func testStreamExecuteKeyRangesError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
execCase := execMap["request1"]
stream, errFunc, err := conn.StreamExecuteKeyRanges(ctx, execCase.keyRangeQuery.Sql, execCase.keyRangeQuery.Keyspace, execCase.keyRangeQuery.KeyRanges, execCase.keyRangeQuery.BindVariables, execCase.keyRangeQuery.TabletType)
if err != nil {
t.Fatalf("StreamExecuteKeyRanges failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecuteKeyRanges failed: cannot read result1")
}
if !reflect.DeepEqual(qr, &streamResult1) {
t.Errorf("Unexpected result from StreamExecuteKeyRanges: got %#v want %#v", qr, &streamResult1)
}
// signal to the server that the first result has been received
close(fake.errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecuteKeyRanges channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecuteKeyRanges")
}
func testStreamExecuteKeyRangesPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := newContext()
execCase := execMap["request1"]
@ -1105,6 +1205,32 @@ func testStreamExecuteKeyspaceIds(t *testing.T, conn *vtgateconn.VTGateConn) {
}
}
func testStreamExecuteKeyspaceIdsError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
execCase := execMap["request1"]
stream, errFunc, err := conn.StreamExecuteKeyspaceIds(ctx, execCase.keyspaceIdQuery.Sql, execCase.keyspaceIdQuery.Keyspace, execCase.keyspaceIdQuery.KeyspaceIds, execCase.keyspaceIdQuery.BindVariables, execCase.keyspaceIdQuery.TabletType)
if err != nil {
t.Fatalf("StreamExecuteKeyspaceIds failed: %v", err)
}
qr, ok := <-stream
if !ok {
t.Fatalf("StreamExecuteKeyspaceIds failed: cannot read result1")
}
if !reflect.DeepEqual(qr, &streamResult1) {
t.Errorf("Unexpected result from StreamExecuteKeyspaceIds: got %#v want %#v", qr, &streamResult1)
}
// signal to the server that the first result has been received
close(fake.errorWait)
// After 1 result, we expect to get an error (no more results).
qr, ok = <-stream
if ok {
t.Fatalf("StreamExecuteKeyspaceIds channel wasn't closed")
}
err = errFunc()
verifyError(t, err, "StreamExecuteKeyspaceIds")
}
func testStreamExecuteKeyspaceIdsPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
ctx := newContext()
execCase := execMap["request1"]
@ -1411,9 +1537,13 @@ func testBeginError(t *testing.T, conn *vtgateconn.VTGateConn) {
verifyError(t, err, "Begin")
}
func testCommitError(t *testing.T, conn *vtgateconn.VTGateConn) {
func testCommitError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1421,9 +1551,13 @@ func testCommitError(t *testing.T, conn *vtgateconn.VTGateConn) {
verifyError(t, err, "Commit")
}
func testRollbackError(t *testing.T, conn *vtgateconn.VTGateConn) {
func testRollbackError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1437,9 +1571,13 @@ func testBegin2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
verifyError(t, err, "Begin2")
}
func testCommit2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
func testCommit2Error(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin2(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1447,9 +1585,13 @@ func testCommit2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
verifyError(t, err, "Commit2")
}
func testRollback2Error(t *testing.T, conn *vtgateconn.VTGateConn) {
func testRollback2Error(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin2(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1463,9 +1605,13 @@ func testBeginPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
expectPanic(t, err)
}
func testCommitPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
func testCommitPanic(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1473,9 +1619,13 @@ func testCommitPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
expectPanic(t, err)
}
func testRollbackPanic(t *testing.T, conn *vtgateconn.VTGateConn) {
func testRollbackPanic(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1489,9 +1639,13 @@ func testBegin2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
expectPanic(t, err)
}
func testCommit2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
func testCommit2Panic(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin2(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -1499,9 +1653,13 @@ func testCommit2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
expectPanic(t, err)
}
func testRollback2Panic(t *testing.T, conn *vtgateconn.VTGateConn) {
func testRollback2Panic(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) {
ctx := newContext()
fake.forceBeginSuccess = true
tx, err := conn.Begin2(ctx)
fake.forceBeginSuccess = false
if err != nil {
t.Error(err)
}
@ -2158,6 +2316,22 @@ var result1 = mproto.QueryResult{
},
}
var streamResult1 = mproto.QueryResult{
Fields: []mproto.Field{
mproto.Field{
Name: "field1",
Type: 42,
},
mproto.Field{
Name: "field2",
Type: 73,
},
},
RowsAffected: 0,
InsertId: 0,
Rows: [][]sqltypes.Value{},
}
var session1 = &proto.Session{
InTransaction: true,
ShardSessions: []*proto.ShardSession{},