support inbound flow control checking to protect against misbehaved peers

This commit is contained in:
iamqizhao 2015-04-03 15:24:05 -07:00
Родитель c7b9fa27dc
Коммит 4320b5b158
5 изменённых файлов: 318 добавлений и 53 удалений

Просмотреть файл

@ -34,6 +34,7 @@
package transport
import (
"fmt"
"sync"
"github.com/bradfitz/http2"
@ -151,3 +152,67 @@ func (qb *quotaPool) reset(v int) {
func (qb *quotaPool) acquire() <-chan int {
return qb.c
}
type inFlow struct {
limit uint32
conn *inFlow
mu sync.Mutex
pendingData uint32
// The amount of data user has consumed but grpc has not sent window update
// for them. Used to reduce window update frequency. It is always part of
// pendingData.
pendingUpdate uint32
}
func (f *inFlow) onData(n uint32) error {
if n == 0 {
return nil
}
f.mu.Lock()
defer f.mu.Unlock()
if f.pendingData+n > f.limit {
return fmt.Errorf("recieved %d-bytes data exceeding the limit %d bytes", f.pendingData+n, f.limit)
}
if f.conn != nil {
if err := f.conn.onData(n); err != nil {
return ConnectionErrorf("%v", err)
}
}
f.pendingData += n
return nil
}
func (f *inFlow) onRead(n uint32) uint32 {
if n == 0 {
return 0
}
f.mu.Lock()
defer f.mu.Unlock()
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
ret := f.pendingUpdate
f.pendingData -= ret
f.pendingUpdate = 0
return ret
}
return 0
}
func (f *inFlow) restoreConn() uint32 {
if f.conn == nil {
return 0
}
f.mu.Lock()
defer f.mu.Unlock()
ret := f.pendingData
f.conn.mu.Lock()
f.conn.pendingData -= ret
if f.conn.pendingUpdate > f.conn.pendingData {
f.conn.pendingUpdate = f.conn.pendingData
}
f.conn.mu.Unlock()
f.pendingData = 0
f.pendingUpdate = 0
return ret
}

Просмотреть файл

@ -76,8 +76,7 @@ type http2Client struct {
// controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller.
controlBuf *recvBuffer
// The inbound quota being set
recvQuota uint32
fc *inFlow
// sendQuotaPool provides flow control to outbound message.
sendQuotaPool *quotaPool
@ -91,8 +90,6 @@ type http2Client struct {
activeStreams map[uint32]*Stream
// The max number of concurrent streams
maxStreams uint32
// The accumulated inbound quota pending for window update.
updateQuota uint32
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
}
@ -164,7 +161,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
hBuf: &buf,
hEnc: hpack.NewEncoder(&buf),
controlBuf: newRecvBuffer(),
recvQuota: initialConnWindowSize,
fc: &inFlow{limit: initialConnWindowSize},
sendQuotaPool: newQuotaPool(defaultWindowSize),
scheme: scheme,
state: reachable,
@ -184,12 +181,16 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
}
func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
fc := &inFlow{
limit: initialWindowSize,
conn: t.fc,
}
// TODO(zhaoq): Handle uint32 overflow of Stream.id.
s := &Stream{
id: t.nextID,
method: callHdr.Method,
buf: newRecvBuffer(),
recvQuota: initialWindowSize,
fc: fc,
sendQuotaPool: newQuotaPool(int(t.streamSendQuota)),
headerChan: make(chan struct{}),
}
@ -311,6 +312,9 @@ func (t *http2Client) CloseStream(s *Stream, err error) {
delete(t.activeStreams, s.id)
t.mu.Unlock()
s.mu.Lock()
if q := s.fc.restoreConn(); q > 0 {
t.controlBuf.put(&windowUpdate{0, q})
}
if s.state == streamDone {
s.mu.Unlock()
return
@ -475,18 +479,11 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Client) updateWindow(s *Stream, n uint32) {
t.mu.Lock()
t.updateQuota += n
if t.updateQuota >= t.recvQuota/4 {
t.controlBuf.put(&windowUpdate{0, t.updateQuota})
t.updateQuota = 0
if q := t.fc.onRead(n); q > 0 {
t.controlBuf.put(&windowUpdate{0, q})
}
t.mu.Unlock()
s.updateQuota += n
if s.updateQuota >= s.recvQuota/4 {
t.controlBuf.put(&windowUpdate{s.id, s.updateQuota})
s.updateQuota = 0
if q := s.fc.onRead(n); q > 0 {
t.controlBuf.put(&windowUpdate{s.id, q})
}
}
@ -496,10 +493,29 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
if !ok {
return
}
size := len(f.Data())
if err := s.fc.onData(uint32(size)); err != nil {
if _, ok := err.(ConnectionError); ok {
t.notifyError(err)
return
}
s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()
return
}
s.state = streamDone
s.statusCode = codes.ResourceExhausted
s.statusDesc = err.Error()
s.mu.Unlock()
s.write(recvMsg{err: io.EOF})
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
data := make([]byte, len(f.Data()))
data := make([]byte, size)
copy(data, f.Data())
s.write(recvMsg{data: data})
}

Просмотреть файл

@ -75,16 +75,13 @@ type http2Server struct {
// controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller.
controlBuf *recvBuffer
// The inbound quota being set
recvQuota uint32
fc *inFlow
// sendQuotaPool provides flow control to outbound message.
sendQuotaPool *quotaPool
mu sync.Mutex // guard the following
state transportState
activeStreams map[uint32]*Stream
// The accumulated inbound quota pending for window update.
updateQuota uint32
// the per-stream outbound flow control window size set by the peer.
streamSendQuota uint32
}
@ -124,7 +121,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err er
hEnc: hpack.NewEncoder(&buf),
maxStreams: maxStreams,
controlBuf: newRecvBuffer(),
recvQuota: initialConnWindowSize,
fc: &inFlow{limit: initialConnWindowSize},
sendQuotaPool: newQuotaPool(defaultWindowSize),
state: reachable,
writableChan: make(chan int, 1),
@ -256,11 +253,15 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
t.maxStreamID = id
buf := newRecvBuffer()
fc := &inFlow{
limit: initialWindowSize,
conn: t.fc,
}
curStream = &Stream{
id: frame.Header().StreamID,
st: t,
buf: buf,
recvQuota: initialWindowSize,
id: frame.Header().StreamID,
st: t,
buf: buf,
fc: fc,
}
endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg)
@ -304,18 +305,11 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
func (t *http2Server) updateWindow(s *Stream, n uint32) {
t.mu.Lock()
t.updateQuota += n
if t.updateQuota >= t.recvQuota/4 {
t.controlBuf.put(&windowUpdate{0, t.updateQuota})
t.updateQuota = 0
if q := t.fc.onRead(n); q > 0 {
t.controlBuf.put(&windowUpdate{0, q})
}
t.mu.Unlock()
s.updateQuota += n
if s.updateQuota >= s.recvQuota/4 {
t.controlBuf.put(&windowUpdate{s.id, s.updateQuota})
s.updateQuota = 0
if q := s.fc.onRead(n); q > 0 {
t.controlBuf.put(&windowUpdate{s.id, q})
}
}
@ -325,10 +319,21 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
if !ok {
return
}
size := len(f.Data())
if err := s.fc.onData(uint32(size)); err != nil {
if _, ok := err.(ConnectionError); ok {
log.Printf("transport: http2Server %v", err)
t.Close()
return
}
t.closeStream(s)
t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl})
return
}
// TODO(bradfitz, zhaoq): A copy is required here because there is no
// guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated?
data := make([]byte, len(f.Data()))
data := make([]byte, size)
copy(data, f.Data())
s.write(recvMsg{data: data})
if f.Header().Flags.Has(http2.FlagDataEndStream) {
@ -643,6 +648,9 @@ func (t *http2Server) closeStream(s *Stream) {
t.mu.Lock()
delete(t.activeStreams, s.id)
t.mu.Unlock()
if q := s.fc.restoreConn(); q > 0 {
t.controlBuf.put(&windowUpdate{0, q})
}
s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()

Просмотреть файл

@ -173,7 +173,7 @@ type Stream struct {
buf *recvBuffer
dec io.Reader
// The inbound quota being set
fc *inFlow
recvQuota uint32
// The accumulated inbound quota pending for window update.
updateQuota uint32
@ -197,8 +197,9 @@ type Stream struct {
// multiple times.
headerDone bool
// the status received from the server.
statusCode codes.Code
statusDesc string
statusCode codes.Code
statusDesc string
pendingData uint32
}
// Header acquires the key-value pairs of header metadata once it

Просмотреть файл

@ -45,6 +45,7 @@ import (
"testing"
"time"
"github.com/bradfitz/http2"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
@ -71,6 +72,14 @@ type testStreamHandler struct {
t ServerTransport
}
type hType int
const (
normal hType = iota
suspended
misbehaved
)
func (h *testStreamHandler) handleStream(s *Stream) {
req := expectedRequest
resp := expectedResponse
@ -97,8 +106,29 @@ func (h *testStreamHandler) handleStreamSuspension(s *Stream) {
<-s.ctx.Done()
}
func (h *testStreamHandler) handleStreamMisbehave(s *Stream) {
conn, ok := s.ServerTransport().(*http2Server)
if !ok {
log.Fatalf("Failed to convert %v to *http2Server")
}
size := 1
if s.Method() == "foo.MaxFrame" {
size = http2MaxFrameLen
}
// Drain the client flow control window.
var err error
var sent int
for sent <= initialWindowSize {
<-conn.writableChan
if err = conn.framer.writeData(true, s.id, false, make([]byte, size)); err != nil {
}
conn.writableChan <- 0
sent += 1
}
}
// start starts server. Other goroutines should block on s.readyChan for futher operations.
func (s *server) start(useTLS bool, port int, maxStreams uint32, suspend bool) {
func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
var err error
if port == 0 {
s.lis, err = net.Listen("tcp", ":0")
@ -142,9 +172,12 @@ func (s *server) start(useTLS bool, port int, maxStreams uint32, suspend bool) {
s.conns[t] = true
s.mu.Unlock()
h := &testStreamHandler{t}
if suspend {
switch ht {
case suspended:
go t.HandleStreams(h.handleStreamSuspension)
} else {
case misbehaved:
go t.HandleStreams(h.handleStreamMisbehave)
default:
go t.HandleStreams(h.handleStream)
}
}
@ -168,9 +201,9 @@ func (s *server) stop() {
s.mu.Unlock()
}
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool) (*server, ClientTransport) {
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
server := &server{readyChan: make(chan bool)}
go server.start(useTLS, port, maxStreams, suspend)
go server.start(useTLS, port, maxStreams, ht)
server.wait(t, 2*time.Second)
addr := "localhost:" + server.port
var (
@ -196,7 +229,7 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, suspend bool)
}
func TestClientSendAndReceive(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, false)
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
@ -236,7 +269,7 @@ func TestClientSendAndReceive(t *testing.T) {
}
func TestClientErrorNotify(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, false)
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
go server.stop()
// ct.reader should detect the error and activate ct.Error().
<-ct.Error()
@ -270,7 +303,7 @@ func performOneRPC(ct ClientTransport) {
}
func TestClientMix(t *testing.T) {
s, ct := setUp(t, true, 0, math.MaxUint32, false)
s, ct := setUp(t, true, 0, math.MaxUint32, normal)
go func(s *server) {
time.Sleep(5 * time.Second)
s.stop()
@ -286,7 +319,7 @@ func TestClientMix(t *testing.T) {
}
func TestExceedMaxStreamsLimit(t *testing.T) {
server, ct := setUp(t, true, 0, 1, false)
server, ct := setUp(t, true, 0, 1, normal)
defer func() {
ct.Close()
server.stop()
@ -334,7 +367,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
}
func TestLargeMessage(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, false)
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
@ -368,7 +401,7 @@ func TestLargeMessage(t *testing.T) {
}
func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, true)
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
@ -389,6 +422,148 @@ func TestLargeMessageSuspension(t *testing.T) {
server.stop()
}
func TestServerWithMisbehavedClient(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
}
var sc *http2Server
for k, _ := range server.conns {
var ok bool
sc, ok = k.(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", k)
}
}
cc, ok := ct.(*http2Client)
if !ok {
t.Fatalf("Failed to convert %v to *http2Client", ct)
}
// Test server behavior for violation of stream flow control window size restriction.
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
var sent int
// Drain the stream flow control window
<-cc.writableChan
if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
t.Fatalf("Failed to write data: ", err)
}
cc.writableChan <- 0
// Wait until the server creates the corresponding stream.
for {
time.Sleep(time.Millisecond)
sc.mu.Lock()
if len(sc.activeStreams) > 0 {
sc.mu.Unlock()
break
}
sc.mu.Unlock()
}
ss := sc.activeStreams[s.id]
if ss.fc.pendingData != http2MaxFrameLen || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != http2MaxFrameLen || sc.fc.pendingUpdate != 0 {
t.Fatalf("Server mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, http2MaxFrameLen, 0, http2MaxFrameLen, 0)
}
sent += http2MaxFrameLen
// Keep sending until the server inbound window is drained for that stream.
for sent <= initialWindowSize {
<-cc.writableChan
if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil {
t.Fatalf("Failed to write data: ", err)
}
cc.writableChan <- 0
sent += http2MaxFrameLen
}
// Server sent a resetStream for s already.
code := http2RSTErrConvTab[http2.ErrCodeFlowControl]
if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code {
t.Fatalf("%v got err %v with statusCode %d, want err <EOF> with statusCode %d", s, err, s.statusCode, code)
}
if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != 0 {
t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, 0", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate)
}
ct.CloseStream(s, nil)
// Test server behavior for violation of connection flow control window size restriction.
//
// Keep creating new streams until the connection window is drained on the server and
// the server tears down the connection.
for {
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
<-cc.writableChan
// Write will fail when connection flow control window runs out.
if err := cc.framer.writeData(true, s.id, true, make([]byte, http2MaxFrameLen)); err != nil {
// The server tears down the connection.
break
}
cc.writableChan <- 0
}
ct.Close()
server.stop()
}
func TestClientWithMisbehavedServer(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, misbehaved)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
}
conn, ok := ct.(*http2Client)
if !ok {
t.Fatalf("Failed to convert %v to *http2Client", ct)
}
// Test the logic for the violation of stream flow control window size restriction.
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
t.Fatalf("Failed to write: %v", err)
}
// Read without window update.
for {
p := make([]byte, http2MaxFrameLen)
if _, err = s.dec.Read(p); err != nil {
break
}
}
if s.fc.pendingData != initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != initialWindowSize || conn.fc.pendingUpdate != 0 {
t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0)
}
if err != io.EOF || s.statusCode != codes.ResourceExhausted {
t.Fatalf("Got err %v and the status code %d, want <EOF> and the code %d", err, s.statusCode, codes.ResourceExhausted)
}
conn.CloseStream(s, err)
if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != 0 {
t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, 0", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate)
}
// Test the logic for the violation of the connection flow control window size restriction.
//
// Generate enough streams to drain the connection window.
callHdr = &CallHdr{
Host: "localhost",
Method: "foo.MaxFrame",
}
for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ {
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil {
break
}
}
// http2Client.errChan is closed due to connection flow control window size violation.
<-conn.Error()
ct.Close()
server.stop()
}
func TestStreamContext(t *testing.T) {
expectedStream := Stream{}
ctx := newContextWithStream(context.Background(), &expectedStream)