From 5ba054bf3709c0a51d1ed76f68d274a8c1eef459 Mon Sep 17 00:00:00 2001 From: dfawley Date: Tue, 23 Jan 2018 11:39:40 -0800 Subject: [PATCH] encoding: Introduce new method for registering and choosing codecs (#1813) --- Documentation/encoding.md | 146 ++++++++++++++++++ benchmark/worker/benchmark_client.go | 2 +- call.go | 7 +- clientconn.go | 11 +- codec.go | 98 ++---------- codec_test.go | 106 +------------ encoding/encoding.go | 123 +++++++++++---- encoding/proto/proto.go | 108 +++++++++++++ .../proto/proto_benchmark_test.go | 17 +- encoding/proto/proto_test.go | 129 ++++++++++++++++ rpc_util.go | 76 ++++++++- rpc_util_test.go | 9 +- server.go | 31 +++- stream.go | 12 +- test/end2end_test.go | 11 +- transport/handler_server.go | 40 +++-- transport/handler_server_test.go | 7 +- transport/http2_client.go | 21 +-- transport/http2_server.go | 23 +-- transport/http_util.go | 67 ++++++-- transport/http_util_test.go | 29 ++-- transport/transport.go | 21 +++ 22 files changed, 775 insertions(+), 319 deletions(-) create mode 100644 Documentation/encoding.md create mode 100644 encoding/proto/proto.go rename codec_benchmark_test.go => encoding/proto/proto_benchmark_test.go (83%) create mode 100644 encoding/proto/proto_test.go diff --git a/Documentation/encoding.md b/Documentation/encoding.md new file mode 100644 index 00000000..31436609 --- /dev/null +++ b/Documentation/encoding.md @@ -0,0 +1,146 @@ +# Encoding + +The gRPC API for sending and receiving is based upon *messages*. However, +messages cannot be transmitted directly over a network; they must first be +converted into *bytes*. This document describes how gRPC-Go converts messages +into bytes and vice-versa for the purposes of network transmission. + +## Codecs (Serialization and Deserialization) + +A `Codec` contains code to serialize a message into a byte slice (`Marshal`) and +deserialize a byte slice back into a message (`Unmarshal`). `Codec`s are +registered by name into a global registry maintained in the `encoding` package. + +### Implementing a `Codec` + +A typical `Codec` will be implemented in its own package with an `init` function +that registers itself, and is imported anonymously. For example: + +```go +package proto + +import "google.golang.org/grpc/encoding" + +func init() { + encoding.RegisterCodec(protoCodec{}) +} + +// ... implementation of protoCodec ... +``` + +For an example, gRPC's implementation of the `proto` codec can be found in +[`encoding/proto`](https://godoc.org/google.golang.org/grpc/encoding/proto). + +### Using a `Codec` + +By default, gRPC registers and uses the "proto" codec, so it is not necessary to +do this in your own code to send and receive proto messages. To use another +`Codec` from a client or server: + +```go +package myclient + +import _ "path/to/another/codec" +``` + +`Codec`s, by definition, must be symmetric, so the same desired `Codec` should +be registered in both client and server binaries. + +On the client-side, to specify a `Codec` to use for message transmission, the +`CallOption` `CallContentSubtype` should be used as follows: + +```go + response, err := myclient.MyCall(ctx, request, grpc.CallContentSubtype("mycodec")) +``` + +As a reminder, all `CallOption`s may be converted into `DialOption`s that become +the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`: + +```go + myclient := grpc.Dial(ctx, target, grpc.WithDefaultCallOptions(grpc.CallContentSubtype("mycodec"))) +``` + +When specified in either of these ways, messages will be encoded using this +codec and sent along with headers indicating the codec (`content-type` set to +`application/grpc+`). + +On the server-side, using a `Codec` is as simple as registering it into the +global registry (i.e. `import`ing it). If a message is encoded with the content +sub-type supported by a registered `Codec`, it will be used automatically for +decoding the request and encoding the response. Otherwise, for +backward-compatibility reasons, gRPC will attempt to use the "proto" codec. In +an upcoming change (tracked in [this +issue](https://github.com/grpc/grpc-go/issues/1824)), such requests will be +rejected with status code `Unimplemented` instead. + +## Compressors (Compression and Decompression) + +Sometimes, the resulting serialization of a message is not space-efficient, and +it may be beneficial to compress this byte stream before transmitting it over +the network. To facilitate this operation, gRPC supports a mechanism for +performing compression and decompression. + +A `Compressor` contains code to compress and decompress by wrapping `io.Writer`s +and `io.Reader`s, respectively. (The form of `Compress` and `Decompress` were +chosen to most closely match Go's standard package +[implementations](https://golang.org/pkg/compress/) of compressors. Like +`Codec`s, `Compressor`s are registered by name into a global registry maintained +in the `encoding` package. + +### Implementing a `Compressor` + +A typical `Compressor` will be implemented in its own package with an `init` +function that registers itself, and is imported anonymously. For example: + +```go +package gzip + +import "google.golang.org/grpc/encoding" + +func init() { + encoding.RegisterCompressor(compressor{}) +} + +// ... implementation of compressor ... +``` + +An implementation of a `gzip` compressor can be found in +[`encoding/gzip`](https://godoc.org/google.golang.org/grpc/encoding/gzip). + +### Using a `Compressor` + +By default, gRPC does not register or use any compressors. To use a +`Compressor` from a client or server: + +```go +package myclient + +import _ "google.golang.org/grpc/encoding/gzip" +``` + +`Compressor`s, by definition, must be symmetric, so the same desired +`Compressor` should be registered in both client and server binaries. + +On the client-side, to specify a `Compressor` to use for message transmission, +the `CallOption` `UseCompressor` should be used as follows: + +```go + response, err := myclient.MyCall(ctx, request, grpc.UseCompressor("gzip")) +``` + +As a reminder, all `CallOption`s may be converted into `DialOption`s that become +the default for all RPCs sent through a client using `grpc.WithDefaultCallOptions`: + +```go + myclient := grpc.Dial(ctx, target, grpc.WithDefaultCallOptions(grpc.UseCompresor("gzip"))) +``` + +When specified in either of these ways, messages will be compressed using this +compressor and sent along with headers indicating the compressor +(`content-coding` set to ``). + +On the server-side, using a `Compressor` is as simple as registering it into the +global registry (i.e. `import`ing it). If a message is compressed with the +content coding supported by a registered `Compressor`, it will be used +automatically for decompressing the request and compressing the response. +Otherwise, the request will be rejected with status code `Unimplemented`. diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go index 10d82aee..f2a26503 100644 --- a/benchmark/worker/benchmark_client.go +++ b/benchmark/worker/benchmark_client.go @@ -139,7 +139,7 @@ func createConns(config *testpb.ClientConfig) ([]*grpc.ClientConn, func(), error if config.PayloadConfig != nil { switch config.PayloadConfig.Payload.(type) { case *testpb.PayloadConfig_BytebufParams: - opts = append(opts, grpc.WithCodec(byteBufCodec{})) + opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(byteBufCodec{}))) case *testpb.PayloadConfig_SimpleParams: default: return nil, nil, status.Errorf(codes.InvalidArgument, "unknow payload config: %v", config.PayloadConfig) diff --git a/call.go b/call.go index 5bbe9510..8fa542a7 100644 --- a/call.go +++ b/call.go @@ -67,7 +67,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran } else if rc != "" && rc != encoding.Identity { comp = encoding.GetCompressor(rc) } - if err = recv(p, dopts.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil { + if err = recv(p, c.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil { if err == io.EOF { break } @@ -111,7 +111,7 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) } } - hdr, data, err := encode(dopts.codec, args, compressor, outPayload, comp) + hdr, data, err := encode(c.codec, args, compressor, outPayload, comp) if err != nil { return err } @@ -182,6 +182,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize) c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) + if err := setCallInfoCodec(c); err != nil { + return err + } if EnableTracing { c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) diff --git a/clientconn.go b/clientconn.go index 3730c7bf..0cbb38ad 100644 --- a/clientconn.go +++ b/clientconn.go @@ -85,7 +85,6 @@ var ( type dialOptions struct { unaryInt UnaryClientInterceptor streamInt StreamClientInterceptor - codec Codec cp Compressor dc Decompressor bs backoffStrategy @@ -167,10 +166,10 @@ func WithDefaultCallOptions(cos ...CallOption) DialOption { } // WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. +// +// Deprecated: use WithDefaultCallOptions(CallCustomCodec(c)) instead. func WithCodec(c Codec) DialOption { - return func(o *dialOptions) { - o.codec = c - } + return WithDefaultCallOptions(CallCustomCodec(c)) } // WithCompressor returns a DialOption which sets a Compressor to use for @@ -486,10 +485,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * default: } } - // Set defaults. - if cc.dopts.codec == nil { - cc.dopts.codec = protoCodec{} - } if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } diff --git a/codec.go b/codec.go index 43d81ed2..12977654 100644 --- a/codec.go +++ b/codec.go @@ -19,96 +19,32 @@ package grpc import ( - "math" - "sync" - - "github.com/golang/protobuf/proto" + "google.golang.org/grpc/encoding" + _ "google.golang.org/grpc/encoding/proto" // to register the Codec for "proto" ) +// baseCodec contains the functionality of both Codec and encoding.Codec, but +// omits the name/string, which vary between the two and are not needed for +// anything besides the registry in the encoding package. +type baseCodec interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error +} + +var _ baseCodec = Codec(nil) +var _ baseCodec = encoding.Codec(nil) + // Codec defines the interface gRPC uses to encode and decode messages. // Note that implementations of this interface must be thread safe; // a Codec's methods can be called from concurrent goroutines. +// +// Deprecated: use encoding.Codec instead. type Codec interface { // Marshal returns the wire format of v. Marshal(v interface{}) ([]byte, error) // Unmarshal parses the wire format into v. Unmarshal(data []byte, v interface{}) error - // String returns the name of the Codec implementation. The returned - // string will be used as part of content type in transmission. + // String returns the name of the Codec implementation. This is unused by + // gRPC. String() string } - -// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC. -type protoCodec struct { -} - -type cachedProtoBuffer struct { - lastMarshaledSize uint32 - proto.Buffer -} - -func capToMaxInt32(val int) uint32 { - if val > math.MaxInt32 { - return uint32(math.MaxInt32) - } - return uint32(val) -} - -func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) { - protoMsg := v.(proto.Message) - newSlice := make([]byte, 0, cb.lastMarshaledSize) - - cb.SetBuf(newSlice) - cb.Reset() - if err := cb.Marshal(protoMsg); err != nil { - return nil, err - } - out := cb.Bytes() - cb.lastMarshaledSize = capToMaxInt32(len(out)) - return out, nil -} - -func (p protoCodec) Marshal(v interface{}) ([]byte, error) { - if pm, ok := v.(proto.Marshaler); ok { - // object can marshal itself, no need for buffer - return pm.Marshal() - } - - cb := protoBufferPool.Get().(*cachedProtoBuffer) - out, err := p.marshal(v, cb) - - // put back buffer and lose the ref to the slice - cb.SetBuf(nil) - protoBufferPool.Put(cb) - return out, err -} - -func (p protoCodec) Unmarshal(data []byte, v interface{}) error { - protoMsg := v.(proto.Message) - protoMsg.Reset() - - if pu, ok := protoMsg.(proto.Unmarshaler); ok { - // object can unmarshal itself, no need for buffer - return pu.Unmarshal(data) - } - - cb := protoBufferPool.Get().(*cachedProtoBuffer) - cb.SetBuf(data) - err := cb.Unmarshal(protoMsg) - cb.SetBuf(nil) - protoBufferPool.Put(cb) - return err -} - -func (protoCodec) String() string { - return "proto" -} - -var protoBufferPool = &sync.Pool{ - New: func() interface{} { - return &cachedProtoBuffer{ - Buffer: proto.Buffer{}, - lastMarshaledSize: 16, - } - }, -} diff --git a/codec_test.go b/codec_test.go index 246b13b0..3fda708e 100644 --- a/codec_test.go +++ b/codec_test.go @@ -19,110 +19,14 @@ package grpc import ( - "bytes" - "sync" "testing" - "google.golang.org/grpc/test/codec_perf" + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" ) -func marshalAndUnmarshal(t *testing.T, protoCodec Codec, expectedBody []byte) { - p := &codec_perf.Buffer{} - p.Body = expectedBody - - marshalledBytes, err := protoCodec.Marshal(p) - if err != nil { - t.Errorf("protoCodec.Marshal(_) returned an error") - } - - if err := protoCodec.Unmarshal(marshalledBytes, p); err != nil { - t.Errorf("protoCodec.Unmarshal(_) returned an error") - } - - if bytes.Compare(p.GetBody(), expectedBody) != 0 { - t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody) - } -} - -func TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) { - marshalAndUnmarshal(t, protoCodec{}, []byte{1, 2, 3}) -} - -// Try to catch possible race conditions around use of pools -func TestConcurrentUsage(t *testing.T) { - const ( - numGoRoutines = 100 - numMarshUnmarsh = 1000 - ) - - // small, arbitrary byte slices - protoBodies := [][]byte{ - []byte("one"), - []byte("two"), - []byte("three"), - []byte("four"), - []byte("five"), - } - - var wg sync.WaitGroup - codec := protoCodec{} - - for i := 0; i < numGoRoutines; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for k := 0; k < numMarshUnmarsh; k++ { - marshalAndUnmarshal(t, codec, protoBodies[k%len(protoBodies)]) - } - }() - } - - wg.Wait() -} - -// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get -// stomped on during reuse of a proto.Buffer. -func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) { - codec1 := protoCodec{} - codec2 := protoCodec{} - - expectedBody1 := []byte{1, 2, 3} - expectedBody2 := []byte{4, 5, 6} - - proto1 := codec_perf.Buffer{Body: expectedBody1} - proto2 := codec_perf.Buffer{Body: expectedBody2} - - var m1, m2 []byte - var err error - - if m1, err = codec1.Marshal(&proto1); err != nil { - t.Errorf("protoCodec.Marshal(%v) failed", proto1) - } - - if m2, err = codec2.Marshal(&proto2); err != nil { - t.Errorf("protoCodec.Marshal(%v) failed", proto2) - } - - if err = codec1.Unmarshal(m1, &proto1); err != nil { - t.Errorf("protoCodec.Unmarshal(%v) failed", m1) - } - - if err = codec2.Unmarshal(m2, &proto2); err != nil { - t.Errorf("protoCodec.Unmarshal(%v) failed", m2) - } - - b1 := proto1.GetBody() - b2 := proto2.GetBody() - - for i, v := range b1 { - if expectedBody1[i] != v { - t.Errorf("expected %v at index %v but got %v", i, expectedBody1[i], v) - } - } - - for i, v := range b2 { - if expectedBody2[i] != v { - t.Errorf("expected %v at index %v but got %v", i, expectedBody2[i], v) - } +func TestGetCodecForProtoIsNotNil(t *testing.T) { + if encoding.GetCodec(proto.Name) == nil { + t.Fatalf("encoding.GetCodec(%q) must not be nil by default", proto.Name) } } diff --git a/encoding/encoding.go b/encoding/encoding.go index 47d10b07..8e26c194 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -16,46 +16,103 @@ * */ -// Package encoding defines the interface for the compressor and the functions -// to register and get the compossor. +// Package encoding defines the interface for the compressor and codec, and +// functions to register and retrieve compressors and codecs. +// // This package is EXPERIMENTAL. package encoding import ( "io" + "strings" ) -var registerCompressor = make(map[string]Compressor) - -// Compressor is used for compressing and decompressing when sending or receiving messages. -type Compressor interface { - // Compress writes the data written to wc to w after compressing it. If an error - // occurs while initializing the compressor, that error is returned instead. - Compress(w io.Writer) (io.WriteCloser, error) - // Decompress reads data from r, decompresses it, and provides the uncompressed data - // via the returned io.Reader. If an error occurs while initializing the decompressor, that error - // is returned instead. - Decompress(r io.Reader) (io.Reader, error) - // Name is the name of the compression codec and is used to set the content coding header. - Name() string -} - -// RegisterCompressor registers the compressor with gRPC by its name. It can be activated when -// sending an RPC via grpc.UseCompressor(). It will be automatically accessed when receiving a -// message based on the content coding header. Servers also use it to send a response with the -// same encoding as the request. -// -// NOTE: this function must only be called during initialization time (i.e. in an init() function). If -// multiple Compressors are registered with the same name, the one registered last will take effect. -func RegisterCompressor(c Compressor) { - registerCompressor[c.Name()] = c -} - -// GetCompressor returns Compressor for the given compressor name. -func GetCompressor(name string) Compressor { - return registerCompressor[name] -} - // Identity specifies the optional encoding for uncompressed streams. // It is intended for grpc internal use only. const Identity = "identity" + +// Compressor is used for compressing and decompressing when sending or +// receiving messages. +type Compressor interface { + // Compress writes the data written to wc to w after compressing it. If an + // error occurs while initializing the compressor, that error is returned + // instead. + Compress(w io.Writer) (io.WriteCloser, error) + // Decompress reads data from r, decompresses it, and provides the + // uncompressed data via the returned io.Reader. If an error occurs while + // initializing the decompressor, that error is returned instead. + Decompress(r io.Reader) (io.Reader, error) + // Name is the name of the compression codec and is used to set the content + // coding header. The result must be static; the result cannot change + // between calls. + Name() string +} + +var registeredCompressor = make(map[string]Compressor) + +// RegisterCompressor registers the compressor with gRPC by its name. It can +// be activated when sending an RPC via grpc.UseCompressor(). It will be +// automatically accessed when receiving a message based on the content coding +// header. Servers also use it to send a response with the same encoding as +// the request. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. If multiple Compressors are +// registered with the same name, the one registered last will take effect. +func RegisterCompressor(c Compressor) { + registeredCompressor[c.Name()] = c +} + +// GetCompressor returns Compressor for the given compressor name. +func GetCompressor(name string) Compressor { + return registeredCompressor[name] +} + +// Codec defines the interface gRPC uses to encode and decode messages. Note +// that implementations of this interface must be thread safe; a Codec's +// methods can be called from concurrent goroutines. +type Codec interface { + // Marshal returns the wire format of v. + Marshal(v interface{}) ([]byte, error) + // Unmarshal parses the wire format into v. + Unmarshal(data []byte, v interface{}) error + // Name returns the name of the Codec implementation. The returned string + // will be used as part of content type in transmission. The result must be + // static; the result cannot change between calls. + Name() string +} + +var registeredCodecs = make(map[string]Codec, 0) + +// RegisterCodec registers the provided Codec for use with all gRPC clients and +// servers. +// +// The Codec will be stored and looked up by result of its Name() method, which +// should match the content-subtype of the encoding handled by the Codec. This +// is case-insensitive, and is stored and looked up as lowercase. If the +// result of calling Name() is an empty string, RegisterCodec will panic. See +// Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. If multiple Compressors are +// registered with the same name, the one registered last will take effect. +func RegisterCodec(codec Codec) { + if codec == nil { + panic("cannot register a nil Codec") + } + contentSubtype := strings.ToLower(codec.Name()) + if contentSubtype == "" { + panic("cannot register Codec with empty string result for String()") + } + registeredCodecs[contentSubtype] = codec +} + +// GetCodec gets a registered Codec by content-subtype, or nil if no Codec is +// registered for the content-subtype. +// +// The content-subtype is expected to be lowercase. +func GetCodec(contentSubtype string) Codec { + return registeredCodecs[contentSubtype] +} diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go new file mode 100644 index 00000000..a0c6ee93 --- /dev/null +++ b/encoding/proto/proto.go @@ -0,0 +1,108 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package proto + +import ( + "math" + "sync" + + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/encoding" +) + +// Name is the name registered for the proto compressor. +const Name = "proto" + +func init() { + encoding.RegisterCodec(codec{}) +} + +// codec is a Codec implementation with protobuf. It is the default codec for gRPC. +type codec struct{} + +type cachedProtoBuffer struct { + lastMarshaledSize uint32 + proto.Buffer +} + +func capToMaxInt32(val int) uint32 { + if val > math.MaxInt32 { + return uint32(math.MaxInt32) + } + return uint32(val) +} + +func marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) { + protoMsg := v.(proto.Message) + newSlice := make([]byte, 0, cb.lastMarshaledSize) + + cb.SetBuf(newSlice) + cb.Reset() + if err := cb.Marshal(protoMsg); err != nil { + return nil, err + } + out := cb.Bytes() + cb.lastMarshaledSize = capToMaxInt32(len(out)) + return out, nil +} + +func (codec) Marshal(v interface{}) ([]byte, error) { + if pm, ok := v.(proto.Marshaler); ok { + // object can marshal itself, no need for buffer + return pm.Marshal() + } + + cb := protoBufferPool.Get().(*cachedProtoBuffer) + out, err := marshal(v, cb) + + // put back buffer and lose the ref to the slice + cb.SetBuf(nil) + protoBufferPool.Put(cb) + return out, err +} + +func (codec) Unmarshal(data []byte, v interface{}) error { + protoMsg := v.(proto.Message) + protoMsg.Reset() + + if pu, ok := protoMsg.(proto.Unmarshaler); ok { + // object can unmarshal itself, no need for buffer + return pu.Unmarshal(data) + } + + cb := protoBufferPool.Get().(*cachedProtoBuffer) + cb.SetBuf(data) + err := cb.Unmarshal(protoMsg) + cb.SetBuf(nil) + protoBufferPool.Put(cb) + return err +} + +func (codec) Name() string { + return Name +} + +var protoBufferPool = &sync.Pool{ + New: func() interface{} { + return &cachedProtoBuffer{ + Buffer: proto.Buffer{}, + lastMarshaledSize: 16, + } + }, +} diff --git a/codec_benchmark_test.go b/encoding/proto/proto_benchmark_test.go similarity index 83% rename from codec_benchmark_test.go rename to encoding/proto/proto_benchmark_test.go index 2286fd81..63ea57de 100644 --- a/codec_benchmark_test.go +++ b/encoding/proto/proto_benchmark_test.go @@ -18,13 +18,14 @@ * */ -package grpc +package proto import ( "fmt" "testing" "github.com/golang/protobuf/proto" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/test/codec_perf" ) @@ -68,7 +69,7 @@ func BenchmarkProtoCodec(b *testing.B) { protoStructs := setupBenchmarkProtoCodecInputs(s) name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", s, p) b.Run(name, func(b *testing.B) { - codec := &protoCodec{} + codec := &codec{} b.SetParallelism(p) b.RunParallel(func(pb *testing.PB) { benchmarkProtoCodec(codec, protoStructs, pb, b) @@ -78,7 +79,7 @@ func BenchmarkProtoCodec(b *testing.B) { } } -func benchmarkProtoCodec(codec *protoCodec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) { +func benchmarkProtoCodec(codec *codec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) { counter := 0 for pb.Next() { counter++ @@ -87,13 +88,13 @@ func benchmarkProtoCodec(codec *protoCodec, protoStructs []proto.Message, pb *te } } -func fastMarshalAndUnmarshal(protoCodec Codec, protoStruct proto.Message, b *testing.B) { - marshaledBytes, err := protoCodec.Marshal(protoStruct) +func fastMarshalAndUnmarshal(codec encoding.Codec, protoStruct proto.Message, b *testing.B) { + marshaledBytes, err := codec.Marshal(protoStruct) if err != nil { - b.Errorf("protoCodec.Marshal(_) returned an error") + b.Errorf("codec.Marshal(_) returned an error") } res := codec_perf.Buffer{} - if err := protoCodec.Unmarshal(marshaledBytes, &res); err != nil { - b.Errorf("protoCodec.Unmarshal(_) returned an error") + if err := codec.Unmarshal(marshaledBytes, &res); err != nil { + b.Errorf("codec.Unmarshal(_) returned an error") } } diff --git a/encoding/proto/proto_test.go b/encoding/proto/proto_test.go new file mode 100644 index 00000000..b6a0b666 --- /dev/null +++ b/encoding/proto/proto_test.go @@ -0,0 +1,129 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package proto + +import ( + "bytes" + "sync" + "testing" + + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/test/codec_perf" +) + +func marshalAndUnmarshal(t *testing.T, codec encoding.Codec, expectedBody []byte) { + p := &codec_perf.Buffer{} + p.Body = expectedBody + + marshalledBytes, err := codec.Marshal(p) + if err != nil { + t.Errorf("codec.Marshal(_) returned an error") + } + + if err := codec.Unmarshal(marshalledBytes, p); err != nil { + t.Errorf("codec.Unmarshal(_) returned an error") + } + + if bytes.Compare(p.GetBody(), expectedBody) != 0 { + t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody) + } +} + +func TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) { + marshalAndUnmarshal(t, codec{}, []byte{1, 2, 3}) +} + +// Try to catch possible race conditions around use of pools +func TestConcurrentUsage(t *testing.T) { + const ( + numGoRoutines = 100 + numMarshUnmarsh = 1000 + ) + + // small, arbitrary byte slices + protoBodies := [][]byte{ + []byte("one"), + []byte("two"), + []byte("three"), + []byte("four"), + []byte("five"), + } + + var wg sync.WaitGroup + codec := codec{} + + for i := 0; i < numGoRoutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for k := 0; k < numMarshUnmarsh; k++ { + marshalAndUnmarshal(t, codec, protoBodies[k%len(protoBodies)]) + } + }() + } + + wg.Wait() +} + +// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get +// stomped on during reuse of a proto.Buffer. +func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) { + codec1 := codec{} + codec2 := codec{} + + expectedBody1 := []byte{1, 2, 3} + expectedBody2 := []byte{4, 5, 6} + + proto1 := codec_perf.Buffer{Body: expectedBody1} + proto2 := codec_perf.Buffer{Body: expectedBody2} + + var m1, m2 []byte + var err error + + if m1, err = codec1.Marshal(&proto1); err != nil { + t.Errorf("codec.Marshal(%v) failed", proto1) + } + + if m2, err = codec2.Marshal(&proto2); err != nil { + t.Errorf("codec.Marshal(%v) failed", proto2) + } + + if err = codec1.Unmarshal(m1, &proto1); err != nil { + t.Errorf("codec.Unmarshal(%v) failed", m1) + } + + if err = codec2.Unmarshal(m2, &proto2); err != nil { + t.Errorf("codec.Unmarshal(%v) failed", m2) + } + + b1 := proto1.GetBody() + b2 := proto2.GetBody() + + for i, v := range b1 { + if expectedBody1[i] != v { + t.Errorf("expected %v at index %v but got %v", i, expectedBody1[i], v) + } + } + + for i, v := range b2 { + if expectedBody2[i] != v { + t.Errorf("expected %v at index %v but got %v", i, expectedBody2[i], v) + } + } +} diff --git a/rpc_util.go b/rpc_util.go index ea673834..949fa05b 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -25,6 +25,7 @@ import ( "io" "io/ioutil" "math" + "strings" "sync" "time" @@ -32,6 +33,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/stats" @@ -130,6 +132,8 @@ type callInfo struct { maxReceiveMessageSize *int maxSendMessageSize *int creds credentials.PerRPCCredentials + contentSubtype string + codec baseCodec } func defaultCallInfo() *callInfo { @@ -252,6 +256,49 @@ func UseCompressor(name string) CallOption { }) } +// CallContentSubtype returns a CallOption that will set the content-subtype +// for a call. For example, if content-subtype is "json", the Content-Type over +// the wire will be "application/grpc+json". The content-subtype is converted +// to lowercase before being included in Content-Type. See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +// +// If CallCustomCodec is not also used, the content-subtype will be used to +// look up the Codec to use in the registry controlled by RegisterCodec. See +// the documention on RegisterCodec for details on registration. The lookup +// of content-subtype is case-insensitive. If no such Codec is found, the call +// will result in an error with code codes.Internal. +// +// If CallCustomCodec is also used, that Codec will be used for all request and +// response messages, with the content-subtype set to the given contentSubtype +// here for requests. +func CallContentSubtype(contentSubtype string) CallOption { + contentSubtype = strings.ToLower(contentSubtype) + return beforeCall(func(c *callInfo) error { + c.contentSubtype = contentSubtype + return nil + }) +} + +// CallCustomCodec returns a CallOption that will set the given Codec to be +// used for all request and response messages for a call. The result of calling +// String() will be used as the content-subtype in a case-insensitive manner. +// +// See Content-Type on +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. Also see the documentation on RegisterCodec and +// CallContentSubtype for more details on the interaction between Codec and +// content-subtype. +// +// This function is provided for advanced users; prefer to use only +// CallContentSubtype to select a registered codec instead. +func CallCustomCodec(codec Codec) CallOption { + return beforeCall(func(c *callInfo) error { + c.codec = codec + return nil + }) +} + // The format of the payload: compressed or not? type payloadFormat uint8 @@ -267,8 +314,8 @@ type parser struct { // error types. r io.Reader - // The header of a gRPC message. Find more detail - // at https://grpc.io/docs/guides/wire.html. + // The header of a gRPC message. Find more detail at + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md header [5]byte } @@ -317,7 +364,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt // encode serializes msg and returns a buffer of message header and a buffer of msg. // If msg is nil, it generates the message header and an empty msg buffer. // TODO(ddyihai): eliminate extra Compressor parameter. -func encode(c Codec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { +func encode(c baseCodec, msg interface{}, cp Compressor, outPayload *stats.OutPayload, compressor encoding.Compressor) ([]byte, []byte, error) { var ( b []byte cbuf *bytes.Buffer @@ -394,7 +441,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, haveCompressor bool // For the two compressor parameters, both should not be set, but if they are, // dc takes precedence over compressor. // TODO(dfawley): wrap the old compressor/decompressor using the new API? -func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { +func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload, compressor encoding.Compressor) error { pf, d, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return err @@ -489,6 +536,27 @@ func Errorf(c codes.Code, format string, a ...interface{}) error { return status.Errorf(c, format, a...) } +// setCallInfoCodec should only be called after CallOptions have been applied. +func setCallInfoCodec(c *callInfo) error { + if c.codec != nil { + // codec was already set by a CallOption; use it. + return nil + } + + if c.contentSubtype == "" { + // No codec specified in CallOptions; use proto by default. + c.codec = encoding.GetCodec(proto.Name) + return nil + } + + // c.contentSubtype is already lowercased in CallContentSubtype + c.codec = encoding.GetCodec(c.contentSubtype) + if c.codec == nil { + return status.Errorf(codes.Internal, "no codec registered for content-subtype %s", c.contentSubtype) + } + return nil +} + // The SupportPackageIsVersion variables are referenced from generated protocol // buffer files to ensure compatibility with the gRPC version used. The latest // support package version is 5. diff --git a/rpc_util_test.go b/rpc_util_test.go index 23c471e2..6e4f85ad 100644 --- a/rpc_util_test.go +++ b/rpc_util_test.go @@ -27,6 +27,8 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc/codes" + "google.golang.org/grpc/encoding" + protoenc "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/status" perfpb "google.golang.org/grpc/test/codec_perf" "google.golang.org/grpc/transport" @@ -110,7 +112,7 @@ func TestEncode(t *testing.T) { }{ {nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil}, } { - hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil) + hdr, data, err := encode(encoding.GetCodec(protoenc.Name), test.msg, nil, nil, nil) if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) { t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err) } @@ -164,13 +166,14 @@ func TestToRPCErr(t *testing.T) { // bmEncode benchmarks encoding a Protocol Buffer message containing mSize // bytes. func bmEncode(b *testing.B, mSize int) { + cdc := encoding.GetCodec(protoenc.Name) msg := &perfpb.Buffer{Body: make([]byte, mSize)} - encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil) + encodeHdr, encodeData, _ := encode(cdc, msg, nil, nil, nil) encodedSz := int64(len(encodeHdr) + len(encodeData)) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - encode(protoCodec{}, msg, nil, nil, nil) + encode(cdc, msg, nil, nil, nil) } b.SetBytes(encodedSz) } diff --git a/server.go b/server.go index f6516216..b80594f9 100644 --- a/server.go +++ b/server.go @@ -40,6 +40,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal" "google.golang.org/grpc/keepalive" @@ -105,7 +106,7 @@ type Server struct { type options struct { creds credentials.TransportCredentials - codec Codec + codec baseCodec cp Compressor dc Decompressor unaryInt UnaryServerInterceptor @@ -182,6 +183,8 @@ func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { } // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. +// +// This will override any lookups by content-subtype for Codecs registered with RegisterCodec. func CustomCodec(codec Codec) ServerOption { return func(o *options) { o.codec = codec @@ -327,10 +330,6 @@ func NewServer(opt ...ServerOption) *Server { for _, o := range opt { o(&opts) } - if opts.codec == nil { - // Set the default codec. - opts.codec = protoCodec{} - } s := &Server{ lis: make(map[net.Listener]bool), opts: opts, @@ -759,7 +758,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str if s.opts.statsHandler != nil { outPayload = &stats.OutPayload{} } - hdr, data, err := encode(s.opts.codec, msg, cp, outPayload, comp) + hdr, data, err := encode(s.getCodec(stream.ContentSubtype()), msg, cp, outPayload, comp) if err != nil { grpclog.Errorln("grpc: server failed to encode response: ", err) return err @@ -904,7 +903,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. // java implementation. return status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", len(req), s.opts.maxReceiveMessageSize) } - if err := s.opts.codec.Unmarshal(req, v); err != nil { + if err := s.getCodec(stream.ContentSubtype()).Unmarshal(req, v); err != nil { return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } if inPayload != nil { @@ -996,7 +995,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp t: t, s: stream, p: &parser{r: stream}, - codec: s.opts.codec, + codec: s.getCodec(stream.ContentSubtype()), maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, @@ -1262,6 +1261,22 @@ func init() { } } +// contentSubtype must be lowercase +// cannot return nil +func (s *Server) getCodec(contentSubtype string) baseCodec { + if s.opts.codec != nil { + return s.opts.codec + } + if contentSubtype == "" { + return encoding.GetCodec(proto.Name) + } + codec := encoding.GetCodec(contentSubtype) + if codec == nil { + return encoding.GetCodec(proto.Name) + } + return codec +} + // SetHeader sets the header metadata. // When called multiple times, all the provided metadata will be merged. // All the metadata will be sent out when one of the following happens: diff --git a/stream.go b/stream.go index 1b777ec1..8189e832 100644 --- a/stream.go +++ b/stream.go @@ -142,6 +142,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize) c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) + if err := setCallInfoCodec(c); err != nil { + return nil, err + } callHdr := &transport.CallHdr{ Host: cc.authority, @@ -150,7 +153,8 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth // so we don't flush the header. // If it's client streaming, the user may never send a request or send it any // time soon, so we ask the transport to flush the header. - Flush: desc.ClientStreams, + Flush: desc.ClientStreams, + ContentSubtype: c.contentSubtype, } // Set our outgoing compression according to the UseCompressor CallOption, if @@ -259,7 +263,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth opts: opts, c: c, desc: desc, - codec: cc.dopts.codec, + codec: c.codec, cp: cp, dc: cc.dopts.dc, comp: comp, @@ -311,7 +315,7 @@ type clientStream struct { p *parser desc *StreamDesc - codec Codec + codec baseCodec cp Compressor dc Decompressor comp encoding.Compressor @@ -591,7 +595,7 @@ type serverStream struct { t transport.ServerTransport s *transport.Stream p *parser - codec Codec + codec baseCodec cp Compressor dc Decompressor diff --git a/test/end2end_test.go b/test/end2end_test.go index 140673ef..48ccdca6 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -705,7 +705,7 @@ func (te *test) clientConn() *grpc.ClientConn { opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds)) } if te.customCodec != nil { - opts = append(opts, grpc.WithCodec(te.customCodec)) + opts = append(opts, grpc.WithDefaultCallOptions(grpc.CallCustomCodec(te.customCodec))) } if !te.nonBlockingDial && te.srvAddr != "" { // Only do a blocking dial if server is up. @@ -2607,6 +2607,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) { delete(header, "trailer") // RFC 2616 says server SHOULD (but optional) declare trailers delete(header, "date") // the Date header is also optional delete(header, "user-agent") + delete(header, "content-type") } if !reflect.DeepEqual(header, testMetadata) { t.Fatalf("Received header metadata %v, want %v", header, testMetadata) @@ -2723,6 +2724,7 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -2767,6 +2769,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -2810,6 +2813,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -2854,6 +2858,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) { t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -2917,6 +2922,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) { t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -2975,6 +2981,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) { t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) } delete(header, "user-agent") + delete(header, "content-type") expectedHeader := metadata.Join(testMetadata, testMetadata2) if !reflect.DeepEqual(header, expectedHeader) { t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) @@ -3335,6 +3342,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) { } delete(headerMD, "trailer") // ignore if present delete(headerMD, "user-agent") + delete(headerMD, "content-type") if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#1 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } @@ -3342,6 +3350,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) { headerMD, err = stream.Header() delete(headerMD, "trailer") // ignore if present delete(headerMD, "user-agent") + delete(headerMD, "content-type") if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#2 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } diff --git a/transport/handler_server.go b/transport/handler_server.go index 27c4ebb5..ce8ebece 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -53,7 +53,10 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr if r.Method != "POST" { return nil, errors.New("invalid gRPC request method") } - if !validContentType(r.Header.Get("Content-Type")) { + contentType := r.Header.Get("Content-Type") + // TODO: do we assume contentType is lowercase? we did before + contentSubtype, validContentType := contentSubtype(contentType) + if !validContentType { return nil, errors.New("invalid gRPC request content-type") } if _, ok := w.(http.Flusher); !ok { @@ -64,10 +67,12 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } st := &serverHandlerTransport{ - rw: w, - req: r, - closedCh: make(chan struct{}), - writes: make(chan func()), + rw: w, + req: r, + closedCh: make(chan struct{}), + writes: make(chan func()), + contentType: contentType, + contentSubtype: contentSubtype, } if v := r.Header.Get("grpc-timeout"); v != "" { @@ -79,7 +84,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr st.timeout = to } - var metakv []string + metakv := []string{"content-type", contentType} if r.Host != "" { metakv = append(metakv, ":authority", r.Host) } @@ -126,6 +131,12 @@ type serverHandlerTransport struct { // block concurrent WriteStatus calls // e.g. grpc/(*serverStream).SendMsg/RecvMsg writeStatusMu sync.Mutex + + // we just mirror the request content-type + contentType string + // we store both contentType and contentSubtype so we don't keep recreating them + // TODO make sure this is consistent across handler_server and http2_server + contentSubtype string } func (ht *serverHandlerTransport) Close() error { @@ -235,7 +246,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { h := ht.rw.Header() h["Date"] = nil // suppress Date to make tests happy; TODO: restore - h.Set("Content-Type", "application/grpc") + h.Set("Content-Type", ht.contentType) // Predeclare trailers we'll set later in WriteStatus (after the body). // This is a SHOULD in the HTTP RFC, and the way you add (known) @@ -313,13 +324,14 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace req := ht.req s := &Stream{ - id: 0, // irrelevant - requestRead: func(int) {}, - cancel: cancel, - buf: newRecvBuffer(), - st: ht, - method: req.URL.Path, - recvCompress: req.Header.Get("grpc-encoding"), + id: 0, // irrelevant + requestRead: func(int) {}, + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: req.URL.Path, + recvCompress: req.Header.Get("grpc-encoding"), + contentSubtype: ht.contentSubtype, } pr := &peer.Peer{ Addr: ht.RemoteAddr(), diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 8505e1a7..b7e9120e 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -199,9 +199,10 @@ func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { }, check: func(ht *serverHandlerTransport, tt *testCase) error { want := metadata.MD{ - "meta-bar": {"bar-val1", "bar-val2"}, - "user-agent": {"x/y a/b"}, - "meta-foo": {"foo-val"}, + "meta-bar": {"bar-val1", "bar-val2"}, + "user-agent": {"x/y a/b"}, + "meta-foo": {"foo-val"}, + "content-type": {"application/grpc"}, } if !reflect.DeepEqual(ht.headerMD, want) { diff --git a/transport/http2_client.go b/transport/http2_client.go index 97709069..964942f0 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -314,15 +314,16 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &Stream{ - id: t.nextID, - done: make(chan struct{}), - goAway: make(chan struct{}), - method: callHdr.Method, - sendCompress: callHdr.SendCompress, - buf: newRecvBuffer(), - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), - headerChan: make(chan struct{}), + id: t.nextID, + done: make(chan struct{}), + goAway: make(chan struct{}), + method: callHdr.Method, + sendCompress: callHdr.SendCompress, + buf: newRecvBuffer(), + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), + headerChan: make(chan struct{}), + contentSubtype: callHdr.ContentSubtype, } t.nextID += 2 s.requestRead = func(n int) { @@ -438,7 +439,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea headerFields = append(headerFields, hpack.HeaderField{Name: ":scheme", Value: t.scheme}) headerFields = append(headerFields, hpack.HeaderField{Name: ":path", Value: callHdr.Method}) headerFields = append(headerFields, hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(callHdr.ContentSubtype)}) headerFields = append(headerFields, hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) headerFields = append(headerFields, hpack.HeaderField{Name: "te", Value: "trailers"}) diff --git a/transport/http2_server.go b/transport/http2_server.go index 6d252c53..a1aed08c 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -281,12 +281,13 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( buf := newRecvBuffer() s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.encoding, - method: state.method, + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + recvCompress: state.encoding, + method: state.method, + contentSubtype: state.contentSubtype, } if frame.StreamEnded() { @@ -730,7 +731,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // first and create a slice of that exact size. headerFields := make([]hpack.HeaderField, 0, 2) // at least :status, content-type will be there if none else. headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) if s.sendCompress != "" { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) } @@ -749,9 +750,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { endStream: false, }) if t.stats != nil { - outHeader := &stats.OutHeader{ - //WireLength: // TODO(mmukhi): Revisit this later, if needed. - } + // Note: WireLength is not set in outHeader. + // TODO(mmukhi): Revisit this later, if needed. + outHeader := &stats.OutHeader{} t.stats.HandleRPC(s.Context(), outHeader) } return nil @@ -792,7 +793,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headerFields := make([]hpack.HeaderField, 0, 2) // grpc-status and grpc-message will be there if none else. if !headersSent { headerFields = append(headerFields, hpack.HeaderField{Name: ":status", Value: "200"}) - headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: contentType(s.contentSubtype)}) } headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) diff --git a/transport/http_util.go b/transport/http_util.go index 39f878cf..34476773 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -46,6 +46,12 @@ const ( // http2IOBufSize specifies the buffer size for sending frames. defaultWriteBufSize = 32 * 1024 defaultReadBufSize = 32 * 1024 + // baseContentType is the base content-type for gRPC. This is a valid + // content-type on it's own, but can also include a content-subtype such as + // "proto" as a suffix after "+" or ";". See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + baseContentType = "application/grpc" ) var ( @@ -111,9 +117,10 @@ type decodeState struct { timeout time.Duration method string // key-value metadata map from the peer. - mdata map[string][]string - statsTags []byte - statsTrace []byte + mdata map[string][]string + statsTags []byte + statsTrace []byte + contentSubtype string } // isReservedHeader checks whether hdr belongs to HTTP2 headers @@ -149,17 +156,44 @@ func isWhitelistedPseudoHeader(hdr string) bool { } } -func validContentType(t string) bool { - e := "application/grpc" - if !strings.HasPrefix(t, e) { - return false +// contentSubtype returns the content-subtype for the given content-type. The +// given content-type must be a valid content-type that starts with +// "application/grpc". A content-subtype will follow "application/grpc" after a +// "+" or ";". See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +// +// If contentType is not a valid content-type for gRPC, the boolean +// will be false, otherwise true. If content-type == "application/grpc", +// "application/grpc+", or "application/grpc;", the boolean will be true, +// but no content-subtype will be returned. +// +// contentType is assumed to be lowercase already. +func contentSubtype(contentType string) (string, bool) { + if contentType == baseContentType { + return "", true } - // Support variations on the content-type - // (e.g. "application/grpc+blah", "application/grpc;blah"). - if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' { - return false + if !strings.HasPrefix(contentType, baseContentType) { + return "", false } - return true + // guaranteed since != baseContentType and has baseContentType prefix + switch contentType[len(baseContentType)] { + case '+', ';': + // this will return true for "application/grpc+" or "application/grpc;" + // which the previous validContentType function tested to be valid, so we + // just say that no content-subtype is specified in this case + return contentType[len(baseContentType)+1:], true + default: + return "", false + } +} + +// contentSubtype is assumed to be lowercase +func contentType(contentSubtype string) string { + if contentSubtype == "" { + return baseContentType + } + return baseContentType + "+" + contentSubtype } func (d *decodeState) status() *status.Status { @@ -247,9 +281,16 @@ func (d *decodeState) addMetadata(k, v string) { func (d *decodeState) processHeaderField(f hpack.HeaderField) error { switch f.Name { case "content-type": - if !validContentType(f.Value) { + contentSubtype, validContentType := contentSubtype(f.Value) + if !validContentType { return streamErrorf(codes.FailedPrecondition, "transport: received the unexpected content-type %q", f.Value) } + d.contentSubtype = contentSubtype + // TODO: do we want to propagate the whole content-type in the metadata, + // or come up with a way to just propagate the content-subtype if it was set? + // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} + // in the metadata? + d.addMetadata(f.Name, f.Value) case "grpc-encoding": d.encoding = f.Value case "grpc-status": diff --git a/transport/http_util_test.go b/transport/http_util_test.go index 4ebb2390..c3754781 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -72,24 +72,25 @@ func TestTimeoutDecode(t *testing.T) { } } -func TestValidContentType(t *testing.T) { +func TestContentSubtype(t *testing.T) { tests := []struct { - h string - want bool + contentType string + want string + wantValid bool }{ - {"application/grpc", true}, - {"application/grpc+", true}, - {"application/grpc+blah", true}, - {"application/grpc;", true}, - {"application/grpc;blah", true}, - {"application/grpcd", false}, - {"application/grpd", false}, - {"application/grp", false}, + {"application/grpc", "", true}, + {"application/grpc+", "", true}, + {"application/grpc+blah", "blah", true}, + {"application/grpc;", "", true}, + {"application/grpc;blah", "blah", true}, + {"application/grpcd", "", false}, + {"application/grpd", "", false}, + {"application/grp", "", false}, } for _, tt := range tests { - got := validContentType(tt.h) - if got != tt.want { - t.Errorf("validContentType(%q) = %v; want %v", tt.h, got, tt.want) + got, gotValid := contentSubtype(tt.contentType) + if got != tt.want || gotValid != tt.wantValid { + t.Errorf("contentSubtype(%q) = (%v, %v); want (%v, %v)", tt.contentType, got, gotValid, tt.want, tt.wantValid) } } } diff --git a/transport/transport.go b/transport/transport.go index 2e7bcaea..f6cd62c4 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -246,6 +246,10 @@ type Stream struct { bytesReceived bool // indicates whether any bytes have been received on this stream unprocessed bool // set if the server sends a refused stream or GOAWAY including this stream + + // contentSubtype is the content-subtype for requests. + // this must be lowercase or the behavior is undefined. + contentSubtype string } func (s *Stream) waitOnHeader() error { @@ -321,6 +325,15 @@ func (s *Stream) ServerTransport() ServerTransport { return s.st } +// ContentSubtype returns the content-subtype for a request. For example, a +// content-subtype of "proto" will result in a content-type of +// "application/grpc+proto". This will always be lowercase. See +// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests for +// more details. +func (s *Stream) ContentSubtype() string { + return s.contentSubtype +} + // Context returns the context of the stream. func (s *Stream) Context() context.Context { return s.ctx @@ -553,6 +566,14 @@ type CallHdr struct { // for performance purposes. // If it's false, new stream will never be flushed. Flush bool + + // ContentSubtype specifies the content-subtype for a request. For example, a + // content-subtype of "proto" will result in a content-type of + // "application/grpc+proto". The value of ContentSubtype must be all + // lowercase, otherwise the behavior is undefined. See + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests + // for more details. + ContentSubtype string } // ClientTransport is the common interface for all gRPC client-side transport