From 8f06f82ca394b1ac837d4b0c0cfa07188b0e9dee Mon Sep 17 00:00:00 2001 From: mmukhi Date: Mon, 21 May 2018 15:59:39 -0700 Subject: [PATCH] Synchronize WriteStatus with WriteHeader on server. (#2074) --- transport/http2_server.go | 93 ++++++++++++++++++++------------------- transport/transport.go | 37 +++++++++++++--- 2 files changed, 78 insertions(+), 52 deletions(-) diff --git a/transport/http2_server.go b/transport/http2_server.go index 347edbda..af009574 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -683,28 +683,7 @@ func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { }) } -// WriteHeader sends the header metedata md back to the client. -func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { - if s.headerOk || s.getState() == streamDone { - return ErrIllegalHeaderWrite - } - s.headerOk = true - if md.Len() > 0 { - if s.header.Len() > 0 { - s.header = metadata.Join(s.header, md) - } else { - s.header = md - } - } - md = s.header - // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields - // first and create a slice of that exact size. - headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. - headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) - if s.sendCompress != "" { - headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) - } +func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) []hpack.HeaderField { for k, vv := range md { if isReservedHeader(k) { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. @@ -714,6 +693,37 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) } } + return headerFields +} + +// WriteHeader sends the header metedata md back to the client. +func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + if s.updateHeaderSent() || s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.hdrMu.Lock() + if md.Len() > 0 { + if s.header.Len() > 0 { + s.header = metadata.Join(s.header, md) + } else { + s.header = md + } + } + t.writeHeaderLocked(s) + s.hdrMu.Unlock() + return nil +} + +func (t *http2Server) writeHeaderLocked(s *Stream) { + // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields + // first and create a slice of that exact size. + headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + if s.sendCompress != "" { + headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) + } + headerFields = appendHeaderFieldsFromMD(headerFields, s.header) t.controlBuf.put(&headerFrame{ streamID: s.id, hf: headerFields, @@ -729,7 +739,6 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { outHeader := &stats.OutHeader{} t.stats.HandleRPC(s.Context(), outHeader) } - return nil } // WriteStatus sends stream status to the client and terminates the stream. @@ -737,21 +746,20 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { - if !s.headerOk && s.header.Len() > 0 { - if err := t.WriteHeader(s, nil); err != nil { - return err - } - } else { - if s.getState() == streamDone { - return nil - } + if s.getState() == streamDone { + return nil } + s.hdrMu.Lock() // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. - if !s.headerOk { - headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + if !s.updateHeaderSent() { // No headers have been sent. + if len(s.header) > 0 { // Send a separate header frame. + t.writeHeaderLocked(s) + } else { // Send a trailer only response. + headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) + } } headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) @@ -767,16 +775,8 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { } // Attach the trailer metadata. - for k, vv := range s.trailer { - // Clients don't tolerate reading restricted headers after some non restricted ones were sent. - if isReservedHeader(k) { - continue - } - for _, v := range vv { - headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)}) - } - } - trailer := &headerFrame{ + headerFields = appendHeaderFieldsFromMD(headerFields, s.trailer) + trailingHeader := &headerFrame{ streamID: s.id, hf: headerFields, endStream: true, @@ -784,7 +784,8 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { atomic.StoreUint32(&t.resetPingStrikes, 1) }, } - t.closeStream(s, false, 0, trailer, true) + s.hdrMu.Unlock() + t.closeStream(s, false, 0, trailingHeader, true) if t.stats != nil { t.stats.HandleRPC(s.Context(), &stats.OutTrailer{}) } @@ -794,7 +795,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { // Write converts the data into HTTP2 data frame and sends it out. Non-nil error // is returns if it fails (e.g., framing error, transport error). func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error { - if !s.headerOk { // Headers haven't been written yet. + if !s.isHeaderSent() { // Headers haven't been written yet. if err := t.WriteHeader(s, nil); err != nil { // TODO(mmukhi, dfawley): Make sure this is the right code to return. return streamErrorf(codes.Internal, "transport: %v", err) diff --git a/transport/transport.go b/transport/transport.go index 2f643a3d..f51f8788 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -185,13 +185,20 @@ type Stream struct { headerChan chan struct{} // closed to indicate the end of header metadata. headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. - header metadata.MD // the received header metadata. - trailer metadata.MD // the key-value map of trailer metadata. - headerOk bool // becomes true from the first header is about to send - state streamState + // hdrMu protects header and trailer metadata on the server-side. + hdrMu sync.Mutex + header metadata.MD // the received header metadata. + trailer metadata.MD // the key-value map of trailer metadata. - status *status.Status // the status error received from the server + // On the server-side, headerSent is atomically set to 1 when the headers are sent out. + headerSent uint32 + + state streamState + + // On client-side it is the status error received from the server. + // On server-side it is unused. + status *status.Status bytesReceived uint32 // indicates whether any bytes have been received on this stream unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream @@ -201,6 +208,17 @@ type Stream struct { contentSubtype string } +// isHeaderSent is only valid on the server-side. +func (s *Stream) isHeaderSent() bool { + return atomic.LoadUint32(&s.headerSent) == 1 +} + +// updateHeaderSent updates headerSent and returns true +// if it was alreay set. It is valid only on server-side. +func (s *Stream) updateHeaderSent() bool { + return atomic.SwapUint32(&s.headerSent, 1) == 1 +} + func (s *Stream) swapState(st streamState) streamState { return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) } @@ -313,10 +331,12 @@ func (s *Stream) SetHeader(md metadata.MD) error { if md.Len() == 0 { return nil } - if s.headerOk || atomic.LoadUint32((*uint32)(&s.state)) == uint32(streamDone) { + if s.isHeaderSent() || s.getState() == streamDone { return ErrIllegalHeaderWrite } + s.hdrMu.Lock() s.header = metadata.Join(s.header, md) + s.hdrMu.Unlock() return nil } @@ -335,7 +355,12 @@ func (s *Stream) SetTrailer(md metadata.MD) error { if md.Len() == 0 { return nil } + if s.getState() == streamDone { + return ErrIllegalHeaderWrite + } + s.hdrMu.Lock() s.trailer = metadata.Join(s.trailer, md) + s.hdrMu.Unlock() return nil }