Improve rpc cancellation when there is no pending I/O
This commit is contained in:
Родитель
f13f7f6db6
Коммит
afca514667
30
stream.go
30
stream.go
|
@ -130,6 +130,12 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
|
|||
cs.t = t
|
||||
cs.s = s
|
||||
cs.p = &parser{s: s}
|
||||
// Listen on ctx.Done() to detect cancellation when there is no pending
|
||||
// I/O operations on this stream.
|
||||
go func() {
|
||||
<-s.Context().Done()
|
||||
cs.closeTransportStream(transport.ContextErr(s.Context().Err()))
|
||||
}()
|
||||
return cs, nil
|
||||
}
|
||||
|
||||
|
@ -143,7 +149,8 @@ type clientStream struct {
|
|||
|
||||
tracing bool // set to EnableTracing when the clientStream is created.
|
||||
|
||||
mu sync.Mutex // protects trInfo.tr
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
|
||||
// and is set to nil when the clientStream's finish method is called.
|
||||
trInfo traceInfo
|
||||
|
@ -157,7 +164,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
|
|||
m, err := cs.s.Header()
|
||||
if err != nil {
|
||||
if _, ok := err.(transport.ConnectionError); !ok {
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
cs.closeTransportStream(err)
|
||||
}
|
||||
}
|
||||
return m, err
|
||||
|
@ -180,7 +187,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
return
|
||||
}
|
||||
if _, ok := err.(transport.ConnectionError); !ok {
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
cs.closeTransportStream(err)
|
||||
}
|
||||
err = toRPCErr(err)
|
||||
}()
|
||||
|
@ -212,7 +219,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
}
|
||||
// Special handling for client streaming rpc.
|
||||
err = recv(cs.p, cs.codec, m)
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
cs.closeTransportStream(err)
|
||||
if err == nil {
|
||||
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
|
||||
}
|
||||
|
@ -225,7 +232,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
|||
return toRPCErr(err)
|
||||
}
|
||||
if _, ok := err.(transport.ConnectionError); !ok {
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
cs.closeTransportStream(err)
|
||||
}
|
||||
if err == io.EOF {
|
||||
if cs.s.StatusCode() == codes.OK {
|
||||
|
@ -243,12 +250,23 @@ func (cs *clientStream) CloseSend() (err error) {
|
|||
return
|
||||
}
|
||||
if _, ok := err.(transport.ConnectionError); !ok {
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
cs.closeTransportStream(err)
|
||||
}
|
||||
err = toRPCErr(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (cs *clientStream) closeTransportStream(err error) {
|
||||
cs.mu.Lock()
|
||||
if cs.closed {
|
||||
cs.mu.Unlock()
|
||||
return
|
||||
}
|
||||
cs.closed = true
|
||||
cs.mu.Unlock()
|
||||
cs.t.CloseStream(cs.s, err)
|
||||
}
|
||||
|
||||
func (cs *clientStream) finish(err error) {
|
||||
if !cs.tracing {
|
||||
return
|
||||
|
|
|
@ -798,6 +798,51 @@ func testCancel(t *testing.T, e env) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCancelNoIO(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
testCancelNoIO(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
func testCancelNoIO(t *testing.T, e env) {
|
||||
// Only allows 1 live stream per server transport.
|
||||
s, cc := setUp(t, nil, 1, "", e)
|
||||
tc := testpb.NewTestServiceClient(cc)
|
||||
defer tearDown(s, cc)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err := tc.StreamingInputCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
|
||||
}
|
||||
// Loop until receiving the new max stream setting from the server.
|
||||
for {
|
||||
ctx, _ := context.WithTimeout(context.Background(), time.Second)
|
||||
_, err := tc.StreamingInputCall(ctx)
|
||||
if err == nil {
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
if grpc.Code(err) == codes.DeadlineExceeded {
|
||||
break
|
||||
}
|
||||
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, %d", tc, err, codes.DeadlineExceeded)
|
||||
}
|
||||
// If there are any RPCs slipping before the client receives the max streams setting,
|
||||
// let them be expired.
|
||||
time.Sleep(2 * time.Second)
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// This should be blocked until the 1st is canceled.
|
||||
ctx, _ := context.WithTimeout(context.Background(), 2 * time.Second)
|
||||
if _, err := tc.StreamingInputCall(ctx); err != nil {
|
||||
t.Errorf("%v.StreamingInputCall(_) = _, %v, want _, <nil>", tc, err)
|
||||
}
|
||||
}()
|
||||
cancel();
|
||||
<-ch
|
||||
}
|
||||
|
||||
// The following tests the gRPC streaming RPC implementations.
|
||||
// TODO(zhaoq): Have better coverage on error cases.
|
||||
var (
|
||||
|
|
Загрузка…
Ссылка в новой задаче