diff --git a/go/vt/vtgate/fakerpcvtgateconn/conn.go b/go/vt/vtgate/fakerpcvtgateconn/conn.go index 8f25b26ca8..2d6409c215 100644 --- a/go/vt/vtgate/fakerpcvtgateconn/conn.go +++ b/go/vt/vtgate/fakerpcvtgateconn/conn.go @@ -215,6 +215,11 @@ func (conn *FakeVTGateConn) StreamExecute(ctx context.Context, query string, bin return resultChan, func() error { return nil }, nil } +// StreamExecute2 please see vtgateconn.Impl.StreamExecute2 +func (conn *FakeVTGateConn) StreamExecute2(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { + panic("not implemented") +} + // StreamExecuteShard please see vtgateconn.Impl.StreamExecuteShard func (conn *FakeVTGateConn) StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { panic("not implemented") diff --git a/go/vt/vtgate/gorpcvtgateconn/conn.go b/go/vt/vtgate/gorpcvtgateconn/conn.go index 7824894e4a..4e73b29761 100644 --- a/go/vt/vtgate/gorpcvtgateconn/conn.go +++ b/go/vt/vtgate/gorpcvtgateconn/conn.go @@ -257,6 +257,19 @@ func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVar return sendStreamResults(c, sr) } +func (conn *vtgateConn) StreamExecute2(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { + req := &proto.Query{ + CallerID: getEffectiveCallerID(ctx), + Sql: query, + BindVariables: bindVars, + TabletType: tabletType, + Session: nil, + } + sr := make(chan *proto.QueryResult, 10) + c := conn.rpcConn.StreamGo("VTGate.StreamExecute2", req, sr) + return sendStreamResults(c, sr) +} + func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { req := &proto.QueryShard{ CallerID: getEffectiveCallerID(ctx), @@ -304,13 +317,29 @@ func (conn *vtgateConn) StreamExecuteKeyspaceIds(ctx context.Context, query stri func sendStreamResults(c *rpcplus.Call, sr chan *proto.QueryResult) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { srout := make(chan *mproto.QueryResult, 1) + var vtErr error go func() { defer close(srout) for r := range sr { - srout <- r.Result + vtErr = vterrors.FromRPCError(r.Err) + // If we get a QueryResult with an RPCError, that was an extra QueryResult sent by + // the server specifically to indicate an error, and we shouldn't surface it to clients. + if vtErr == nil { + srout <- r.Result + } } }() - return srout, func() error { return c.Error }, nil + // errFunc will return either an RPC-layer error or an application error, if one exists. + // It will only return the most recent application error (i.e, from the QueryResult that + // most recently contained an error). It will prioritize an RPC-layer error over an apperror, + // if both exist. + errFunc := func() error { + if c.Error != nil { + return c.Error + } + return vtErr + } + return srout, errFunc, nil } func (conn *vtgateConn) Begin(ctx context.Context) (interface{}, error) { diff --git a/go/vt/vtgate/gorpcvtgateservice/server.go b/go/vt/vtgate/gorpcvtgateservice/server.go index 7c5dbba718..f8f74c2182 100644 --- a/go/vt/vtgate/gorpcvtgateservice/server.go +++ b/go/vt/vtgate/gorpcvtgateservice/server.go @@ -152,6 +152,32 @@ func (vtg *VTGate) StreamExecute(ctx context.Context, request *proto.Query, send }) } +// StreamExecute2 is the RPC version of vtgateservice.VTGateService method +func (vtg *VTGate) StreamExecute2(ctx context.Context, request *proto.Query, sendReply func(interface{}) error) (err error) { + defer vtg.server.HandlePanic(&err) + ctx = callerid.NewContext(ctx, + callerid.GoRPCEffectiveCallerID(request.CallerID), + callerid.NewImmediateCallerID("gorpc client")) + vtgErr := vtg.server.StreamExecute(ctx, request, func(value *proto.QueryResult) error { + return sendReply(value) + }) + if vtgErr == nil { + return nil + } + if *vtgate.RPCErrorOnlyInReply { + // If there was an app error, send a QueryResult back with it. + qr := new(proto.QueryResult) + vtgate.AddVtGateErrorToQueryResult(vtgErr, qr) + // Sending back errors this way is not backwards compatible. If a (new) server sends an additional + // QueryResult with an error, and the (old) client doesn't know how to read it, it will cause + // problems where the client will get out of sync with the number of QueryResults sent. + // That's why this the error is only sent this way when the --rpc_errors_only_in_reply flag is set + // (signalling that all clients are able to handle new-style errors). + return sendReply(qr) + } + return vtgErr +} + // StreamExecuteShard is the RPC version of vtgateservice.VTGateService method func (vtg *VTGate) StreamExecuteShard(ctx context.Context, request *proto.QueryShard, sendReply func(interface{}) error) (err error) { defer vtg.server.HandlePanic(&err) diff --git a/go/vt/vtgate/grpcvtgateconn/conn.go b/go/vt/vtgate/grpcvtgateconn/conn.go index c75e7de439..52b561d91f 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn.go +++ b/go/vt/vtgate/grpcvtgateconn/conn.go @@ -244,6 +244,10 @@ func (conn *vtgateConn) StreamExecute(ctx context.Context, query string, bindVar }, nil } +func (conn *vtgateConn) StreamExecute2(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { + return conn.StreamExecute(ctx, query, bindVars, tabletType) +} + func (conn *vtgateConn) StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, vtgateconn.ErrFunc, error) { req := &pb.StreamExecuteShardsRequest{ CallerId: callerid.EffectiveCallerIDFromContext(ctx), diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index 86d25980ba..957e1a6028 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -101,6 +101,15 @@ func (conn *VTGateConn) StreamExecute(ctx context.Context, query string, bindVar return conn.impl.StreamExecute(ctx, query, bindVars, tabletType) } +// StreamExecute2 executes a streaming query on vtgate. It returns a +// channel, an ErrFunc, and error. First check the error. Then you can +// pull values from the channel till it's closed. Following this, you +// can call ErrFunc to see if the stream ended normally or due to a +// failure. +func (conn *VTGateConn) StreamExecute2(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, ErrFunc, error) { + return conn.impl.StreamExecute2(ctx, query, bindVars, tabletType) +} + // StreamExecuteShard executes a streaming query on vtgate, on a set // of shards. It returns a channel, an ErrFunc, and error. First // check the error. Then you can pull values from the channel till @@ -322,6 +331,9 @@ type Impl interface { // StreamExecute executes a streaming query on vtgate. StreamExecute(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, ErrFunc, error) + // StreamExecute2 executes a streaming query on vtgate. + StreamExecute2(ctx context.Context, query string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, ErrFunc, error) + // StreamExecuteShard executes a streaming query on vtgate, on a set of shards. StreamExecuteShard(ctx context.Context, query string, keyspace string, shards []string, bindVars map[string]interface{}, tabletType topo.TabletType) (<-chan *mproto.QueryResult, ErrFunc, error) diff --git a/go/vt/vtgate/vtgateconntest/client.go b/go/vt/vtgate/vtgateconntest/client.go index 7de694f053..1c0c9e93f0 100644 --- a/go/vt/vtgate/vtgateconntest/client.go +++ b/go/vt/vtgate/vtgateconntest/client.go @@ -541,6 +541,7 @@ func TestSuite(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGa testExecuteBatchShardError(t, conn) testExecuteBatchKeyspaceIdsError(t, conn) testStreamExecuteError(t, conn, fs) + testStreamExecute2Error(t, conn, fs) testStreamExecuteShardError(t, conn, fs) testStreamExecuteKeyRangesError(t, conn, fs) testStreamExecuteKeyspaceIdsError(t, conn, fs) @@ -949,6 +950,32 @@ func testStreamExecuteError(t *testing.T, conn *vtgateconn.VTGateConn, fake *fak verifyError(t, err, "StreamExecute") } +func testStreamExecute2Error(t *testing.T, conn *vtgateconn.VTGateConn, fake *fakeVTGateService) { + ctx := newContext() + execCase := execMap["request1"] + stream, errFunc, err := conn.StreamExecute2(ctx, execCase.execQuery.Sql, execCase.execQuery.BindVariables, execCase.execQuery.TabletType) + if err != nil { + t.Fatalf("StreamExecute2 failed: %v", err) + } + qr, ok := <-stream + if !ok { + t.Fatalf("StreamExecute2 failed: cannot read result1") + } + + if !reflect.DeepEqual(qr, &streamResult1) { + t.Errorf("Unexpected result from StreamExecute2: got %#v want %#v", qr, &streamResult1) + } + // signal to the server that the first result has been received + close(fake.errorWait) + // After 1 result, we expect to get an error (no more results). + qr, ok = <-stream + if ok { + t.Fatalf("StreamExecute2 channel wasn't closed") + } + err = errFunc() + verifyError(t, err, "StreamExecute2") +} + func testStreamExecutePanic(t *testing.T, conn *vtgateconn.VTGateConn) { ctx := newContext() execCase := execMap["request1"]