diff --git a/stream.go b/stream.go index 008ff10e..0d6cf539 100644 --- a/stream.go +++ b/stream.go @@ -237,6 +237,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth select { case <-t.Error(): // Incur transport error, simply exit. + case <-cc.ctx.Done(): + cs.finish(ErrClientConnClosing) + cs.closeTransportStream(ErrClientConnClosing) case <-s.Done(): // TODO: The trace of the RPC is terminated here when there is no pending // I/O, which is probably not the optimal solution. diff --git a/test/end2end_test.go b/test/end2end_test.go index 54840ee4..fd77cd7c 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -995,6 +995,41 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) { awaitNewConnLogOutput() } +func TestClientConnCloseAfterGoAwayWithActiveStream(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testClientConnCloseAfterGoAwayWithActiveStream(t, e) + } +} + +func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security}) + defer te.tearDown() + cc := te.clientConn() + tc := testpb.NewTestServiceClient(cc) + + if _, err := tc.FullDuplexCall(context.Background()); err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, ", tc, err) + } + done := make(chan struct{}) + go func() { + te.srv.GracefulStop() + close(done) + }() + time.Sleep(time.Second) + cc.Close() + timeout := time.NewTimer(time.Second) + select { + case <-done: + case <-timeout.C: + t.Fatalf("Test timed-out.") + } +} + func TestFailFast(t *testing.T) { defer leakCheck(t)() for _, e := range listTestEnv() {