Improve RPC cancellation when there is no pending I/O

iamqizhao 2015-10-22 13:07:13 -07:00
Parent: f13f7f6db6
Commit: afca514667
2 changed files with 69 additions and 6 deletions

View file

@@ -130,6 +130,12 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	cs.t = t
 	cs.s = s
 	cs.p = &parser{s: s}
+	// Listen on ctx.Done() to detect cancellation when there are no pending
+	// I/O operations on this stream.
+	go func() {
+		<-s.Context().Done()
+		cs.closeTransportStream(transport.ContextErr(s.Context().Err()))
+	}()
 	return cs, nil
 }
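The watcher goroutine above is the heart of the change: before it, a canceled context was only noticed when some Send or Recv on the stream failed. Below is a minimal, runnable sketch of the same pattern, assuming a hypothetical fakeStream type and close method in place of grpc-go's transport stream:

	package main

	import (
		"context"
		"fmt"
		"time"
	)

	type fakeStream struct{ name string }

	func (s *fakeStream) close(err error) { fmt.Printf("%s closed: %v\n", s.name, err) }

	func open(ctx context.Context) *fakeStream {
		s := &fakeStream{name: "demo-stream"}
		go func() {
			// Blocks until the context is canceled or its deadline expires.
			<-ctx.Done()
			// Release the stream's resources even though no I/O is in flight.
			s.close(ctx.Err())
		}()
		return s
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		_ = open(ctx)
		cancel()
		// Give the watcher goroutine a moment to run before main exits.
		time.Sleep(50 * time.Millisecond)
	}

Without such a watcher, an idle stream keeps holding its server-side resources (for example, one of the MaxConcurrentStreams slots) until the next I/O error surfaces; the test added in this commit exercises exactly that scenario.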
@@ -143,7 +149,8 @@ type clientStream struct {
 
 	tracing bool // set to EnableTracing when the clientStream is created.
 
-	mu     sync.Mutex // protects trInfo.tr
+	mu     sync.Mutex
+	closed bool
 	// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
 	// and is set to nil when the clientStream's finish method is called.
 	trInfo traceInfo
@@ -157,7 +164,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
 	m, err := cs.s.Header()
 	if err != nil {
 		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.t.CloseStream(cs.s, err)
+			cs.closeTransportStream(err)
 		}
 	}
 	return m, err
@@ -180,7 +187,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
 			return
 		}
 		if _, ok := err.(transport.ConnectionError); !ok {
-			cs.t.CloseStream(cs.s, err)
+			cs.closeTransportStream(err)
 		}
 		err = toRPCErr(err)
 	}()
@@ -212,7 +219,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 		}
 		// Special handling for client streaming rpc.
 		err = recv(cs.p, cs.codec, m)
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 		if err == nil {
 			return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
 		}
@@ -225,7 +232,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
 		return toRPCErr(err)
 	}
 	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 	}
 	if err == io.EOF {
 		if cs.s.StatusCode() == codes.OK {
@@ -243,12 +250,23 @@ func (cs *clientStream) CloseSend() (err error) {
 		return
 	}
 	if _, ok := err.(transport.ConnectionError); !ok {
-		cs.t.CloseStream(cs.s, err)
+		cs.closeTransportStream(err)
 	}
 	err = toRPCErr(err)
 	return
 }
 
+func (cs *clientStream) closeTransportStream(err error) {
+	cs.mu.Lock()
+	if cs.closed {
+		cs.mu.Unlock()
+		return
+	}
+	cs.closed = true
+	cs.mu.Unlock()
+	cs.t.CloseStream(cs.s, err)
+}
+
 func (cs *clientStream) finish(err error) {
 	if !cs.tracing {
 		return
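closeTransportStream can now be reached from two racing paths, the ctx.Done() watcher and the ordinary error handling in Header, SendMsg, RecvMsg, and CloseSend, so the mutex-guarded closed flag ensures the underlying CloseStream call runs at most once. A minimal sketch of an equivalent guard built on sync.Once (the type and messages here are illustrative only, not grpc-go APIs):

	package main

	import (
		"errors"
		"fmt"
		"sync"
	)

	type guardedStream struct {
		once sync.Once
	}

	// closeTransportStream runs its cleanup at most once no matter how many
	// error paths reach it; later calls are silent no-ops.
	func (g *guardedStream) closeTransportStream(err error) {
		g.once.Do(func() {
			// In the commit above, this is where cs.t.CloseStream(cs.s, err) runs.
			fmt.Println("CloseStream:", err)
		})
	}

	func main() {
		g := &guardedStream{}
		g.closeTransportStream(errors.New("context canceled")) // first caller wins
		g.closeTransportStream(errors.New("EOF"))              // no-op
	}

The commit's mutex-plus-flag version behaves the same for this purpose; it keeps closed alongside the stream's other mutex-protected state instead of adding a separate sync.Once field.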

View file

@@ -798,6 +798,51 @@ func testCancel(t *testing.T, e env) {
 	}
 }
 
+func TestCancelNoIO(t *testing.T) {
+	for _, e := range listTestEnv() {
+		testCancelNoIO(t, e)
+	}
+}
+
+func testCancelNoIO(t *testing.T, e env) {
+	// Allow only 1 live stream per server transport.
+	s, cc := setUp(t, nil, 1, "", e)
+	tc := testpb.NewTestServiceClient(cc)
+	defer tearDown(s, cc)
+	ctx, cancel := context.WithCancel(context.Background())
+	_, err := tc.StreamingInputCall(ctx)
+	if err != nil {
+		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+	}
+	// Loop until the new max-streams setting arrives from the server.
+	for {
+		ctx, _ := context.WithTimeout(context.Background(), time.Second)
+		_, err := tc.StreamingInputCall(ctx)
+		if err == nil {
+			time.Sleep(time.Second)
+			continue
+		}
+		if grpc.Code(err) == codes.DeadlineExceeded {
+			break
+		}
+		t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
+	}
+	// If any RPCs slipped in before the client received the max-streams
+	// setting, let them expire.
+	time.Sleep(2 * time.Second)
+	ch := make(chan struct{})
+	go func() {
+		defer close(ch)
+		// This should block until the first RPC is canceled.
+		ctx, _ := context.WithTimeout(context.Background(), 2*time.Second)
+		if _, err := tc.StreamingInputCall(ctx); err != nil {
+			t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
+		}
+	}()
+	cancel()
+	<-ch
+}
+
 // The following tests the gRPC streaming RPC implementations.
 // TODO(zhaoq): Have better coverage on error cases.
 var (
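The test's synchronization boils down to: occupy the server's single stream slot, prove a second call blocks, then show that canceling the first call, with no I/O pending on it, frees the slot. A minimal channel model of that flow, where the buffered channel stands in for the MaxConcurrentStreams(1) limit (all names here are this sketch's assumptions, not grpc-go APIs):

	package main

	import (
		"context"
		"fmt"
	)

	func main() {
		slot := make(chan struct{}, 1) // capacity 1 models MaxConcurrentStreams(1)

		ctx, cancel := context.WithCancel(context.Background())
		slot <- struct{}{} // the first RPC takes the only slot
		go func() {
			<-ctx.Done() // cancellation arrives with no I/O pending...
			<-slot       // ...and still releases the slot
		}()

		done := make(chan struct{})
		go func() {
			defer close(done)
			slot <- struct{}{} // the second RPC blocks until a slot frees up
			fmt.Println("second call admitted")
		}()

		cancel() // cancel the first RPC
		<-done   // the blocked call proceeds only after the cancellation
	}

One wrinkle worth noting in the test itself: ctx, _ := context.WithTimeout(...) discards the CancelFunc, so each context lingers until its timer fires; later go vet versions flag this pattern, though it is harmless in a short-lived test.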