transport: allow InTapHandle to return status errors (#4365)
This commit is contained in:
Родитель
aff517ba8a
Коммит
328b1d171a
|
@ -20,13 +20,17 @@ package transport
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) {
|
||||
|
@ -128,6 +132,14 @@ type cleanupStream struct {
|
|||
|
||||
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
|
||||
|
||||
type earlyAbortStream struct {
|
||||
streamID uint32
|
||||
contentSubtype string
|
||||
status *status.Status
|
||||
}
|
||||
|
||||
func (*earlyAbortStream) isTransportResponseFrame() bool { return false }
|
||||
|
||||
type dataFrame struct {
|
||||
streamID uint32
|
||||
endStream bool
|
||||
|
@ -749,6 +761,24 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error {
|
||||
if l.side == clientSide {
|
||||
return errors.New("earlyAbortStream not handled on client")
|
||||
}
|
||||
|
||||
headerFields := []hpack.HeaderField{
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)},
|
||||
{Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))},
|
||||
{Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())},
|
||||
}
|
||||
|
||||
if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *loopyWriter) incomingGoAwayHandler(*incomingGoAway) error {
|
||||
if l.side == clientSide {
|
||||
l.draining = true
|
||||
|
@ -787,6 +817,8 @@ func (l *loopyWriter) handle(i interface{}) error {
|
|||
return l.registerStreamHandler(i)
|
||||
case *cleanupStream:
|
||||
return l.cleanupStreamHandler(i)
|
||||
case *earlyAbortStream:
|
||||
return l.earlyAbortStreamHandler(i)
|
||||
case *incomingGoAway:
|
||||
return l.incomingGoAwayHandler(i)
|
||||
case *dataFrame:
|
||||
|
|
|
@ -356,26 +356,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
if state.data.statsTrace != nil {
|
||||
s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace)
|
||||
}
|
||||
if t.inTapHandle != nil {
|
||||
var err error
|
||||
info := &tap.Info{
|
||||
FullMethodName: state.data.method,
|
||||
}
|
||||
s.ctx, err = t.inTapHandle(s.ctx, info)
|
||||
if err != nil {
|
||||
if logger.V(logLevel) {
|
||||
logger.Warningf("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
||||
}
|
||||
t.controlBuf.put(&cleanupStream{
|
||||
streamID: s.id,
|
||||
rst: true,
|
||||
rstCode: http2.ErrCodeRefusedStream,
|
||||
onWrite: func() {},
|
||||
})
|
||||
s.cancel()
|
||||
return false
|
||||
}
|
||||
}
|
||||
t.mu.Lock()
|
||||
if t.state != reachable {
|
||||
t.mu.Unlock()
|
||||
|
@ -417,6 +397,25 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
s.cancel()
|
||||
return false
|
||||
}
|
||||
if t.inTapHandle != nil {
|
||||
var err error
|
||||
if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil {
|
||||
t.mu.Unlock()
|
||||
if logger.V(logLevel) {
|
||||
logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err)
|
||||
}
|
||||
stat, ok := status.FromError(err)
|
||||
if !ok {
|
||||
stat = status.New(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
t.controlBuf.put(&earlyAbortStream{
|
||||
streamID: s.id,
|
||||
contentSubtype: s.contentSubtype,
|
||||
status: stat,
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
t.activeStreams[streamID] = s
|
||||
if len(t.activeStreams) == 1 {
|
||||
t.idle = time.Time{}
|
||||
|
|
|
@ -418,6 +418,11 @@ func ChainStreamInterceptor(interceptors ...StreamServerInterceptor) ServerOptio
|
|||
|
||||
// InTapHandle returns a ServerOption that sets the tap handle for all the server
|
||||
// transport to be created. Only one can be installed.
|
||||
//
|
||||
// Experimental
|
||||
//
|
||||
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
|
||||
// later release.
|
||||
func InTapHandle(h tap.ServerInHandle) ServerOption {
|
||||
return newFuncServerOption(func(o *serverOptions) {
|
||||
if o.inTapHandle != nil {
|
||||
|
|
16
tap/tap.go
16
tap/tap.go
|
@ -37,16 +37,16 @@ type Info struct {
|
|||
// TODO: More to be added.
|
||||
}
|
||||
|
||||
// ServerInHandle defines the function which runs before a new stream is created
|
||||
// on the server side. If it returns a non-nil error, the stream will not be
|
||||
// created and a RST_STREAM will be sent back to the client with REFUSED_STREAM.
|
||||
// The client will receive an RPC error "code = Unavailable, desc = stream
|
||||
// terminated by RST_STREAM with error code: REFUSED_STREAM".
|
||||
// ServerInHandle defines the function which runs before a new stream is
|
||||
// created on the server side. If it returns a non-nil error, the stream will
|
||||
// not be created and an error will be returned to the client. If the error
|
||||
// returned is a status error, that status code and message will be used,
|
||||
// otherwise PermissionDenied will be the code and err.Error() will be the
|
||||
// message.
|
||||
//
|
||||
// It's intended to be used in situations where you don't want to waste the
|
||||
// resources to accept the new stream (e.g. rate-limiting). And the content of
|
||||
// the error will be ignored and won't be sent back to the client. For other
|
||||
// general usages, please use interceptors.
|
||||
// resources to accept the new stream (e.g. rate-limiting). For other general
|
||||
// usages, please use interceptors.
|
||||
//
|
||||
// Note that it is executed in the per-connection I/O goroutine(s) instead of
|
||||
// per-RPC goroutine. Therefore, users should NOT have any
|
||||
|
|
|
@ -2507,10 +2507,13 @@ type myTap struct {
|
|||
|
||||
func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, error) {
|
||||
if info != nil {
|
||||
if info.FullMethodName == "/grpc.testing.TestService/EmptyCall" {
|
||||
switch info.FullMethodName {
|
||||
case "/grpc.testing.TestService/EmptyCall":
|
||||
t.cnt++
|
||||
} else if info.FullMethodName == "/grpc.testing.TestService/UnaryCall" {
|
||||
case "/grpc.testing.TestService/UnaryCall":
|
||||
return nil, fmt.Errorf("tap error")
|
||||
case "/grpc.testing.TestService/FullDuplexCall":
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "test custom error")
|
||||
}
|
||||
}
|
||||
return ctx, nil
|
||||
|
@ -2550,8 +2553,15 @@ func testTap(t *testing.T, e env) {
|
|||
ResponseSize: 45,
|
||||
Payload: payload,
|
||||
}
|
||||
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.Unavailable {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.Unavailable)
|
||||
if _, err := tc.UnaryCall(ctx, req); status.Code(err) != codes.PermissionDenied {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, %s", err, codes.PermissionDenied)
|
||||
}
|
||||
str, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating stream: %v", err)
|
||||
}
|
||||
if _, err := str.Recv(); status.Code(err) != codes.FailedPrecondition {
|
||||
t.Fatalf("FullDuplexCall Recv() = _, %v, want _, %s", err, codes.FailedPrecondition)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3639,66 +3649,77 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
|
|||
}
|
||||
}
|
||||
|
||||
// Tests that the client transparently retries correctly when receiving a
|
||||
// RST_STREAM with code REFUSED_STREAM.
|
||||
func (s) TestTransparentRetry(t *testing.T) {
|
||||
for _, e := range listTestEnv() {
|
||||
if e.name == "handler-tls" {
|
||||
// Fails with RST_STREAM / FLOW_CONTROL_ERROR
|
||||
continue
|
||||
}
|
||||
testTransparentRetry(t, e)
|
||||
}
|
||||
}
|
||||
|
||||
// This test makes sure RPCs are retried times when they receive a RST_STREAM
|
||||
// with the REFUSED_STREAM error code, which the InTapHandle provokes.
|
||||
func testTransparentRetry(t *testing.T, e env) {
|
||||
te := newTest(t, e)
|
||||
attempts := 0
|
||||
successAttempt := 2
|
||||
te.tapHandle = func(ctx context.Context, _ *tap.Info) (context.Context, error) {
|
||||
attempts++
|
||||
if attempts < successAttempt {
|
||||
return nil, errors.New("not now")
|
||||
}
|
||||
return ctx, nil
|
||||
}
|
||||
te.startServer(&testServer{security: e.security})
|
||||
defer te.tearDown()
|
||||
|
||||
cc := te.clientConn()
|
||||
tsc := testpb.NewTestServiceClient(cc)
|
||||
testCases := []struct {
|
||||
successAttempt int
|
||||
failFast bool
|
||||
errCode codes.Code
|
||||
failFast bool
|
||||
errCode codes.Code
|
||||
}{{
|
||||
successAttempt: 1,
|
||||
// success attempt: 1, (stream ID 1)
|
||||
}, {
|
||||
successAttempt: 2,
|
||||
// success attempt: 2, (stream IDs 3, 5)
|
||||
}, {
|
||||
successAttempt: 3,
|
||||
errCode: codes.Unavailable,
|
||||
// no success attempt (stream IDs 7, 9)
|
||||
errCode: codes.Unavailable,
|
||||
}, {
|
||||
successAttempt: 1,
|
||||
failFast: true,
|
||||
// success attempt: 1 (stream ID 11),
|
||||
failFast: true,
|
||||
}, {
|
||||
successAttempt: 2,
|
||||
failFast: true,
|
||||
// success attempt: 2 (stream IDs 13, 15),
|
||||
failFast: true,
|
||||
}, {
|
||||
successAttempt: 3,
|
||||
failFast: true,
|
||||
errCode: codes.Unavailable,
|
||||
// no success attempt (stream IDs 17, 19)
|
||||
failFast: true,
|
||||
errCode: codes.Unavailable,
|
||||
}}
|
||||
for _, tc := range testCases {
|
||||
attempts = 0
|
||||
successAttempt = tc.successAttempt
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
_, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!tc.failFast))
|
||||
cancel()
|
||||
if status.Code(err) != tc.errCode {
|
||||
t.Errorf("%+v: tsc.EmptyCall(_, _) = _, %v, want _, Code=%v", tc, err, tc.errCode)
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen. Err: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
server := &httpServer{
|
||||
headerFields: [][]string{{
|
||||
":status", "200",
|
||||
"content-type", "application/grpc",
|
||||
"grpc-status", "0",
|
||||
}},
|
||||
refuseStream: func(i uint32) bool {
|
||||
switch i {
|
||||
case 1, 5, 11, 15: // these stream IDs succeed
|
||||
return false
|
||||
}
|
||||
return true // these are refused
|
||||
},
|
||||
}
|
||||
server.start(t, lis)
|
||||
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to dial due to err: %v", err)
|
||||
}
|
||||
defer cc.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := testpb.NewTestServiceClient(cc)
|
||||
|
||||
for i, tc := range testCases {
|
||||
stream, err := client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating stream due to err: %v", err)
|
||||
}
|
||||
code := func(err error) codes.Code {
|
||||
if err == io.EOF {
|
||||
return codes.OK
|
||||
}
|
||||
return status.Code(err)
|
||||
}
|
||||
if _, err := stream.Recv(); code(err) != tc.errCode {
|
||||
t.Fatalf("%v: stream.Recv() = _, %v, want error code: %v", i, err, tc.errCode)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7191,6 +7212,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingMoreThanTwoHeaders(t *testing.T) {
|
|||
|
||||
type httpServer struct {
|
||||
headerFields [][]string
|
||||
refuseStream func(uint32) bool
|
||||
}
|
||||
|
||||
func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error {
|
||||
|
@ -7238,25 +7260,34 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) {
|
|||
writer.Flush() // necessary since client is expecting preface before declaring connection fully setup.
|
||||
|
||||
var sid uint32
|
||||
// Read frames until a header is received.
|
||||
// Loop until conn is closed and framer returns io.EOF
|
||||
for {
|
||||
frame, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
||||
return
|
||||
// Read frames until a header is received.
|
||||
for {
|
||||
frame, err := framer.ReadFrame()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Errorf("Error at server-side while reading frame. Err: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
||||
sid = hframe.Header().StreamID
|
||||
if s.refuseStream == nil || !s.refuseStream(sid) {
|
||||
break
|
||||
}
|
||||
framer.WriteRSTStream(sid, http2.ErrCodeRefusedStream)
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
if hframe, ok := frame.(*http2.HeadersFrame); ok {
|
||||
sid = hframe.Header().StreamID
|
||||
break
|
||||
for i, headers := range s.headerFields {
|
||||
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
|
||||
t.Errorf("Error at server-side while writing headers. Err: %v", err)
|
||||
return
|
||||
}
|
||||
writer.Flush()
|
||||
}
|
||||
}
|
||||
for i, headers := range s.headerFields {
|
||||
if err = s.writeHeader(framer, sid, headers, i == len(s.headerFields)-1); err != nil {
|
||||
t.Errorf("Error at server-side while writing headers. Err: %v", err)
|
||||
return
|
||||
}
|
||||
writer.Flush()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче