diff --git a/test/end2end_test.go b/test/end2end_test.go
index 3f55024b..73e6cb65 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -795,17 +795,23 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) {
 	s, cc := setUp(1, e)
 	tc := testpb.NewTestServiceClient(cc)
 	defer tearDown(s, cc)
-	var err error
-	for {
-		time.Sleep(2 * time.Millisecond)
-		_, err = tc.StreamingInputCall(context.Background())
-		// Loop until the settings of max concurrent streams is
-		// received by the client.
-		if err != nil {
-			break
+	// Perform a unary RPC to make sure the new settings were propagated to the client.
+	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", tc, err)
+	}
+	// Initiate the 1st stream.
+	if _, err := tc.StreamingInputCall(context.Background()); err != nil {
+		t.Fatalf("%v.StreamingInputCall(_) = %v, want <nil>", tc, err)
+	}
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// The 2nd stream should block until its deadline expires.
+		ctx, _ := context.WithTimeout(context.Background(), time.Second)
+		if _, err := tc.StreamingInputCall(ctx); grpc.Code(err) != codes.DeadlineExceeded {
+			t.Errorf("%v.StreamingInputCall(%v) = _, %v, want error code %d", tc, ctx, err, codes.DeadlineExceeded)
 		}
-	}
-	if grpc.Code(err) != codes.Unavailable {
-		t.Fatalf("got %v, want error code %d", err, codes.Unavailable)
-	}
+	}()
+	wg.Wait()
 }
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 6ba93448..a32d416c 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -79,6 +79,8 @@ type http2Client struct {
 	fc *inFlow
 	// sendQuotaPool provides flow control to outbound message.
 	sendQuotaPool *quotaPool
+	// streamsQuota limits the max number of concurrent streams.
+	streamsQuota *quotaPool
 
 	// The scheme used: https if TLS is on, http otherwise.
 	scheme string
@@ -89,7 +91,7 @@ type http2Client struct {
 	state transportState // the state of underlying connection
 	activeStreams map[uint32]*Stream
 	// The max number of concurrent streams
-	maxStreams uint32
+	maxStreams int
 	// the per-stream outbound flow control window size set by the peer.
 	streamSendQuota uint32
 }
@@ -174,8 +176,8 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 		scheme:          scheme,
 		state:           reachable,
 		activeStreams:   make(map[uint32]*Stream),
-		maxStreams:      math.MaxUint32,
 		authCreds:       opts.AuthOptions,
+		maxStreams:      math.MaxInt32,
 		streamSendQuota: defaultWindowSize,
 	}
 	go t.controller()
@@ -188,7 +190,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 	return t, nil
 }
 
-func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
+func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, sq bool) *Stream {
 	fc := &inFlow{
 		limit: initialWindowSize,
 		conn:  t.fc,
@@ -198,6 +200,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
 		id:            t.nextID,
 		method:        callHdr.Method,
 		buf:           newRecvBuffer(),
+		updateStreams: sq,
 		fc:            fc,
 		sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
 		headerChan:    make(chan struct{}),
@@ -236,20 +239,29 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			authData[k] = v
 		}
 	}
-	if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
-		return nil, err
-	}
 	t.mu.Lock()
 	if t.state != reachable {
 		t.mu.Unlock()
 		return nil, ErrConnClosing
 	}
-	if uint32(len(t.activeStreams)) >= t.maxStreams {
-		t.mu.Unlock()
-		t.writableChan <- 0
-		return nil, StreamErrorf(codes.Unavailable, "transport: failed to create new stream because the limit has been reached.")
+	checkStreamsQuota := t.streamsQuota != nil
+	t.mu.Unlock()
+	if checkStreamsQuota {
+		sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire())
+		if err != nil {
+			return nil, err
+		}
+		// Returns the quota balance back.
+		if sq > 1 {
+			t.streamsQuota.add(sq - 1)
+		}
 	}
-	s := t.newStream(ctx, callHdr)
+	if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil {
+		// t.streamsQuota will be updated when t.CloseStream is invoked.
+		return nil, err
+	}
+	t.mu.Lock()
+	s := t.newStream(ctx, callHdr, checkStreamsQuota)
 	t.activeStreams[s.id] = s
 	t.mu.Unlock()
 	// HPACK encodes various headers. Note that once WriteField(...) is
@@ -319,6 +331,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
 	t.mu.Lock()
 	delete(t.activeStreams, s.id)
 	t.mu.Unlock()
+	if s.updateStreams {
+		t.streamsQuota.add(1)
+	}
 	s.mu.Lock()
 	if q := s.fc.restoreConn(); q > 0 {
 		t.controlBuf.put(&windowUpdate{0, q})
@@ -554,17 +569,32 @@ func (t *http2Client) handleSettings(f *http2.SettingsFrame) {
 	}
 	f.ForeachSetting(func(s http2.Setting) error {
 		if v, ok := f.Value(s.ID); ok {
-			t.mu.Lock()
-			defer t.mu.Unlock()
 			switch s.ID {
 			case http2.SettingMaxConcurrentStreams:
-				t.maxStreams = v
+				// TODO(zhaoq): This is a hack to avoid significant refactoring of the
+				// code to deal with the unrealistic int32 overflow. Probably will try
+				// to find a better way to handle this later.
+				if v > math.MaxInt32 {
+					v = math.MaxInt32
+				}
+				t.mu.Lock()
+				reset := t.streamsQuota != nil
+				ms := t.maxStreams
+				t.maxStreams = int(v)
+				t.mu.Unlock()
+				if !reset {
+					t.streamsQuota = newQuotaPool(int(v))
+				} else {
+					t.streamsQuota.reset(int(v) - ms)
+				}
 			case http2.SettingInitialWindowSize:
+				t.mu.Lock()
 				for _, s := range t.activeStreams {
 					// Adjust the sending quota for each s.
 					s.sendQuotaPool.reset(int(v - t.streamSendQuota))
 				}
 				t.streamSendQuota = v
+				t.mu.Unlock()
 			}
 		}
 		return nil
diff --git a/transport/transport.go b/transport/transport.go
index 5dfd89f0..498cee54 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -173,8 +173,13 @@ type Stream struct {
 	buf *recvBuffer
 	dec io.Reader
 
-	fc        *inFlow
-	recvQuota uint32
+	// updateStreams indicates whether the transport's streamsQuota needs
+	// to be updated when this stream is closed. It is false when the transport
+	// sticks to the initial infinite value of the number of concurrent streams.
+	// True otherwise.
+	updateStreams bool
+	fc            *inFlow
+	recvQuota     uint32
 	// The accumulated inbound quota pending for window update.
 	updateQuota uint32
 	// The handler to control the window update procedure for both this
diff --git a/transport/transport_test.go b/transport/transport_test.go
index adbb3e00..8529e2af 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -299,54 +299,6 @@ func TestClientMix(t *testing.T) {
 	}
 }
 
-func TestExceedMaxStreamsLimit(t *testing.T) {
-	server, ct := setUp(t, 0, 1, normal)
-	defer func() {
-		ct.Close()
-		server.stop()
-	}()
-	callHdr := &CallHdr{
-		Host:   "localhost",
-		Method: "foo.Small",
-	}
-	// Creates the 1st stream and keep it alive.
-	_, err1 := ct.NewStream(context.Background(), callHdr)
-	if err1 != nil {
-		t.Fatalf("failed to open stream: %v", err1)
-	}
-	// Creates the 2nd stream. It has chance to succeed when the settings
-	// frame from the server has not received at the client.
-	s, err2 := ct.NewStream(context.Background(), callHdr)
-	if err2 != nil {
-		se, ok := err2.(StreamError)
-		if !ok {
-			t.Fatalf("Received unexpected error %v", err2)
-		}
-		if se.Code != codes.Unavailable {
-			t.Fatalf("Got error code: %d, want: %d", se.Code, codes.Unavailable)
-		}
-		return
-	}
-	// If the 2nd stream is created successfully, sends the request.
-	if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
-		t.Fatalf("failed to send data: %v", err)
-	}
-	// The 2nd stream was rejected by the server via a reset.
-	p := make([]byte, len(expectedResponse))
-	_, recvErr := io.ReadFull(s, p)
-	if recvErr != io.EOF || s.StatusCode() != codes.Unavailable {
-		t.Fatalf("Error: %v, StatusCode: %d; want <EOF>, %d", recvErr, s.StatusCode(), codes.Unavailable)
-	}
-	// Server's setting has been received. From now on, new stream will be rejected instantly.
-	_, err3 := ct.NewStream(context.Background(), callHdr)
-	if err3 == nil {
-		t.Fatalf("Received unexpected <nil>, want an error with code %d", codes.Unavailable)
-	}
-	if se, ok := err3.(StreamError); !ok || se.Code != codes.Unavailable {
-		t.Fatalf("Got: %v, want a StreamError with error code %d", err3, codes.Unavailable)
-	}
-}
-
 func TestLargeMessage(t *testing.T) {
 	server, ct := setUp(t, 0, math.MaxUint32, normal)
 	callHdr := &CallHdr{
@@ -357,23 +309,23 @@
 	for i := 0; i < 2; i++ {
 		wg.Add(1)
 		go func() {
+			defer wg.Done()
 			s, err := ct.NewStream(context.Background(), callHdr)
 			if err != nil {
-				t.Fatalf("failed to open stream: %v", err)
+				t.Errorf("failed to open stream: %v", err)
 			}
 			if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil {
-				t.Fatalf("failed to send data: %v", err)
+				t.Errorf("failed to send data: %v", err)
 			}
 			p := make([]byte, len(expectedResponseLarge))
 			_, recvErr := io.ReadFull(s, p)
 			if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) {
-				t.Fatalf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
+				t.Errorf("Error: %v, want <nil>; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge))
 			}
 			_, recvErr = io.ReadFull(s, p)
 			if recvErr != io.EOF {
-				t.Fatalf("Error: %v; want <EOF>", recvErr)
+				t.Errorf("Error: %v; want <EOF>", recvErr)
 			}
-			wg.Done()
 		}()
 	}
 	wg.Wait()
@@ -548,7 +500,7 @@ func TestClientWithMisbehavedServer(t *testing.T) {
 	for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
 		s, err := ct.NewStream(context.Background(), callHdr)
 		if err != nil {
-			t.Fatalf("Failed to open stream: %v", err)
+			break
 		}
 		if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
 			break
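
Reviewer note (not part of the patch): the stream cap reuses the transport's quotaPool primitive, and the changes above only make sense given its batching contract: acquire() exposes a channel that delivers the entire available balance in one receive, which is why NewStream hands back sq - 1 after consuming a single unit, and why handleSettings can shrink the pool with a negative reset delta when the server lowers SETTINGS_MAX_CONCURRENT_STREAMS. Below is a minimal, self-contained sketch of those assumed semantics; the names mirror transport/control.go, but this is an illustration rather than the package's actual code, and the wait helper is likewise simplified.

package transport

import (
	"context"
	"errors"
	"sync"
)

// quotaPool batches the available quota on a channel of capacity 1. A
// receiver takes the entire balance in one receive and must add() back
// whatever it does not consume.
type quotaPool struct {
	c chan int

	mu    sync.Mutex
	quota int // balance not yet published on c; may go negative via reset()
}

// newQuotaPool creates a pool with q units available.
func newQuotaPool(q int) *quotaPool {
	qp := &quotaPool{c: make(chan int, 1)}
	if q > 0 {
		qp.c <- q
	} else {
		qp.quota = q
	}
	return qp
}

// add returns n units to the pool and publishes the balance once positive.
func (qp *quotaPool) add(n int) {
	qp.mu.Lock()
	defer qp.mu.Unlock()
	qp.quota += n
	if qp.quota <= 0 {
		return
	}
	select {
	case qp.c <- qp.quota:
		qp.quota = 0
	default:
	}
}

// reset withdraws any published balance, applies the delta v (negative when
// the peer lowers the limit below the number of streams in flight), and
// republishes the result once it turns positive again.
func (qp *quotaPool) reset(v int) {
	qp.mu.Lock()
	defer qp.mu.Unlock()
	select {
	case n := <-qp.c:
		qp.quota += n
	default:
	}
	qp.quota += v
	if qp.quota <= 0 {
		return
	}
	select {
	case qp.c <- qp.quota:
		qp.quota = 0
	default:
	}
}

// acquire returns the channel callers select on to obtain quota.
func (qp *quotaPool) acquire() <-chan int {
	return qp.c
}

// wait mirrors the helper NewStream uses: it blocks until quota arrives on
// proceed, the transport shuts down, or ctx is done.
func wait(ctx context.Context, done chan struct{}, proceed <-chan int) (int, error) {
	select {
	case <-ctx.Done():
		return 0, ctx.Err()
	case <-done:
		return 0, errors.New("transport: connection closing")
	case i := <-proceed:
		return i, nil
	}
}

Under this contract, CloseStream's streamsQuota.add(1) releases exactly the unit its stream consumed, and a lowered limit simply drives the balance negative until enough in-flight streams finish, which is the behavior the new testExceedMaxStreamsLimit exercises via the DeadlineExceeded second stream.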