merge master resolve conflicts
This commit is contained in:
Коммит
8788b75675
|
@ -66,11 +66,11 @@ md := metadata.Pairs(
|
|||
|
||||
## Retrieving metadata from context
|
||||
|
||||
Metadata can be retrieved from context using `FromContext`:
|
||||
Metadata can be retrieved from context using `FromIncomingContext`:
|
||||
|
||||
```go
|
||||
func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, err) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
|
@ -88,7 +88,7 @@ To send metadata to server, the client can wrap the metadata into a context usin
|
|||
md := metadata.Pairs("key", "val")
|
||||
|
||||
// create a new context with this metadata
|
||||
ctx := metadata.NewContext(context.Background(), md)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), md)
|
||||
|
||||
// make unary RPC
|
||||
response, err := client.SomeRPC(ctx, someRequest)
|
||||
|
@ -96,6 +96,9 @@ response, err := client.SomeRPC(ctx, someRequest)
|
|||
// or make streaming RPC
|
||||
stream, err := client.SomeStreamingRPC(ctx)
|
||||
```
|
||||
|
||||
To read this back from the context on the client (e.g. in an interceptor) before the RPC is sent, use `FromOutgoingContext`.
|
||||
|
||||
### Receiving metadata
|
||||
|
||||
Metadata that a client can receive includes header and trailer.
|
||||
|
@ -152,7 +155,7 @@ For streaming calls, the server needs to get context from the stream.
|
|||
|
||||
```go
|
||||
func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
|
@ -161,7 +164,7 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someRespo
|
|||
|
||||
```go
|
||||
func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context()) // get context from stream
|
||||
md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream
|
||||
// do something with metadata
|
||||
}
|
||||
```
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#gRPC-Go
|
||||
# gRPC-Go
|
||||
|
||||
[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc)
|
||||
|
||||
|
|
|
@ -331,6 +331,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
|||
for _, opt := range opts {
|
||||
opt(&cc.dopts)
|
||||
}
|
||||
cc.mkp = cc.dopts.copts.KeepaliveParams
|
||||
|
||||
grpcUA := "grpc-go/" + Version
|
||||
if cc.dopts.copts.UserAgent != "" {
|
||||
|
@ -479,6 +480,8 @@ type ClientConn struct {
|
|||
mu sync.RWMutex
|
||||
sc ServiceConfig
|
||||
conns map[Address]*addrConn
|
||||
// Keepalive parameter can be udated if a GoAway is received.
|
||||
mkp keepalive.ClientParameters
|
||||
}
|
||||
|
||||
// lbWatcher watches the Notify channel of the balancer in cc and manages
|
||||
|
@ -554,6 +557,9 @@ func (cc *ClientConn) resetAddrConn(addr Address, block bool, tearDownErr error)
|
|||
addr: addr,
|
||||
dopts: cc.dopts,
|
||||
}
|
||||
cc.mu.RLock()
|
||||
ac.dopts.copts.KeepaliveParams = cc.mkp
|
||||
cc.mu.RUnlock()
|
||||
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
|
||||
ac.stateCV = sync.NewCond(&ac.mu)
|
||||
if EnableTracing {
|
||||
|
@ -740,6 +746,20 @@ type addrConn struct {
|
|||
tearDownErr error
|
||||
}
|
||||
|
||||
// adjustParams updates parameters used to create transports upon
|
||||
// receiving a GoAway.
|
||||
func (ac *addrConn) adjustParams(r transport.GoAwayReason) {
|
||||
switch r {
|
||||
case transport.TooManyPings:
|
||||
v := 2 * ac.dopts.copts.KeepaliveParams.Time
|
||||
ac.cc.mu.Lock()
|
||||
if v > ac.cc.mkp.Time {
|
||||
ac.cc.mkp.Time = v
|
||||
}
|
||||
ac.cc.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// printf records an event in ac's event log, unless ac has been closed.
|
||||
// REQUIRES ac.mu is held.
|
||||
func (ac *addrConn) printf(format string, a ...interface{}) {
|
||||
|
@ -896,6 +916,7 @@ func (ac *addrConn) transportMonitor() {
|
|||
}
|
||||
return
|
||||
case <-t.GoAway():
|
||||
ac.adjustParams(t.GetGoAwayReason())
|
||||
// If GoAway happens without any network I/O error, ac is closed without shutting down the
|
||||
// underlying transport (the transport will be closed when all the pending RPCs finished or
|
||||
// failed.).
|
||||
|
@ -915,6 +936,7 @@ func (ac *addrConn) transportMonitor() {
|
|||
t.Close()
|
||||
return
|
||||
case <-t.GoAway():
|
||||
ac.adjustParams(t.GetGoAwayReason())
|
||||
ac.cc.resetAddrConn(ac.addr, false, errNetworkIO)
|
||||
return
|
||||
default:
|
||||
|
|
|
@ -41,6 +41,7 @@ import (
|
|||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const tlsDir = "testdata/"
|
||||
|
@ -306,3 +307,31 @@ func TestNonblockingDialWithEmptyBalancer(t *testing.T) {
|
|||
<-dialDone
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestClientUpdatesParamsAfterGoAway(t *testing.T) {
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen. Err: %v", err)
|
||||
}
|
||||
defer lis.Close()
|
||||
addr := lis.Addr().String()
|
||||
s := NewServer()
|
||||
go s.Serve(lis)
|
||||
defer s.Stop()
|
||||
cc, err := Dial(addr, WithBlock(), WithInsecure(), WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 50 * time.Millisecond,
|
||||
Timeout: 1 * time.Millisecond,
|
||||
PermitWithoutStream: true,
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
|
||||
}
|
||||
defer cc.Close()
|
||||
time.Sleep(1 * time.Second)
|
||||
cc.mu.RLock()
|
||||
defer cc.mu.RUnlock()
|
||||
v := cc.mkp.Time
|
||||
if v < 100*time.Millisecond {
|
||||
t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 100ms", v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2014, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
// Codec defines the interface gRPC uses to encode and decode messages.
|
||||
// Note that implementations of this interface must be thread safe;
|
||||
// a Codec's methods can be called from concurrent goroutines.
|
||||
type Codec interface {
|
||||
// Marshal returns the wire format of v.
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
// Unmarshal parses the wire format into v.
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
// String returns the name of the Codec implementation. The returned
|
||||
// string will be used as part of content type in transmission.
|
||||
String() string
|
||||
}
|
||||
|
||||
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
|
||||
type protoCodec struct {
|
||||
}
|
||||
|
||||
type cachedProtoBuffer struct {
|
||||
lastMarshaledSize uint32
|
||||
proto.Buffer
|
||||
}
|
||||
|
||||
func capToMaxInt32(val int) uint32 {
|
||||
if val > math.MaxInt32 {
|
||||
return uint32(math.MaxInt32)
|
||||
}
|
||||
return uint32(val)
|
||||
}
|
||||
|
||||
func (p protoCodec) marshal(v interface{}, cb *cachedProtoBuffer) ([]byte, error) {
|
||||
protoMsg := v.(proto.Message)
|
||||
newSlice := make([]byte, 0, cb.lastMarshaledSize)
|
||||
|
||||
cb.SetBuf(newSlice)
|
||||
cb.Reset()
|
||||
if err := cb.Marshal(protoMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := cb.Bytes()
|
||||
cb.lastMarshaledSize = capToMaxInt32(len(out))
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (p protoCodec) Marshal(v interface{}) ([]byte, error) {
|
||||
cb := protoBufferPool.Get().(*cachedProtoBuffer)
|
||||
out, err := p.marshal(v, cb)
|
||||
|
||||
// put back buffer and lose the ref to the slice
|
||||
cb.SetBuf(nil)
|
||||
protoBufferPool.Put(cb)
|
||||
return out, err
|
||||
}
|
||||
|
||||
func (p protoCodec) Unmarshal(data []byte, v interface{}) error {
|
||||
cb := protoBufferPool.Get().(*cachedProtoBuffer)
|
||||
cb.SetBuf(data)
|
||||
err := cb.Unmarshal(v.(proto.Message))
|
||||
cb.SetBuf(nil)
|
||||
protoBufferPool.Put(cb)
|
||||
return err
|
||||
}
|
||||
|
||||
func (protoCodec) String() string {
|
||||
return "proto"
|
||||
}
|
||||
|
||||
var (
|
||||
protoBufferPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &cachedProtoBuffer{
|
||||
Buffer: proto.Buffer{},
|
||||
lastMarshaledSize: 16,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
|
@ -0,0 +1,115 @@
|
|||
// +build go1.7
|
||||
|
||||
/*
|
||||
*
|
||||
* Copyright 2014, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"google.golang.org/grpc/test/codec_perf"
|
||||
)
|
||||
|
||||
func setupBenchmarkProtoCodecInputs(b *testing.B, payloadBaseSize uint32) []proto.Message {
|
||||
payloadBase := make([]byte, payloadBaseSize)
|
||||
// arbitrary byte slices
|
||||
payloadSuffixes := [][]byte{
|
||||
[]byte("one"),
|
||||
[]byte("two"),
|
||||
[]byte("three"),
|
||||
[]byte("four"),
|
||||
[]byte("five"),
|
||||
}
|
||||
protoStructs := make([]proto.Message, 0)
|
||||
|
||||
for _, p := range payloadSuffixes {
|
||||
ps := &codec_perf.Buffer{}
|
||||
ps.Body = append(payloadBase, p...)
|
||||
protoStructs = append(protoStructs, ps)
|
||||
}
|
||||
|
||||
return protoStructs
|
||||
}
|
||||
|
||||
// The possible use of certain protobuf APIs like the proto.Buffer API potentially involves caching
|
||||
// on our side. This can add checks around memory allocations and possible contention.
|
||||
// Example run: go test -v -run=^$ -bench=BenchmarkProtoCodec -benchmem
|
||||
func BenchmarkProtoCodec(b *testing.B) {
|
||||
// range of message sizes
|
||||
payloadBaseSizes := make([]uint32, 0)
|
||||
for i := uint32(0); i <= 12; i += 4 {
|
||||
payloadBaseSizes = append(payloadBaseSizes, 1<<i)
|
||||
}
|
||||
// range of SetParallelism
|
||||
parallelisms := make([]uint32, 0)
|
||||
for i := uint32(0); i <= 16; i += 4 {
|
||||
parallelisms = append(parallelisms, 1<<i)
|
||||
}
|
||||
for _, s := range payloadBaseSizes {
|
||||
for _, p := range parallelisms {
|
||||
func(parallelism int, payloadBaseSize uint32) {
|
||||
protoStructs := setupBenchmarkProtoCodecInputs(b, payloadBaseSize)
|
||||
name := fmt.Sprintf("MinPayloadSize:%v/SetParallelism(%v)", payloadBaseSize, parallelism)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
codec := &protoCodec{}
|
||||
b.SetParallelism(parallelism)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
benchmarkProtoCodec(codec, protoStructs, pb, b)
|
||||
})
|
||||
})
|
||||
}(int(p), s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkProtoCodec(codec *protoCodec, protoStructs []proto.Message, pb *testing.PB, b *testing.B) {
|
||||
counter := 0
|
||||
for pb.Next() {
|
||||
counter++
|
||||
ps := protoStructs[counter%len(protoStructs)]
|
||||
fastMarshalAndUnmarshal(codec, ps, b)
|
||||
}
|
||||
}
|
||||
|
||||
func fastMarshalAndUnmarshal(protoCodec Codec, protoStruct proto.Message, b *testing.B) {
|
||||
marshaledBytes, err := protoCodec.Marshal(protoStruct)
|
||||
if err != nil {
|
||||
b.Errorf("protoCodec.Marshal(_) returned an error")
|
||||
}
|
||||
if err := protoCodec.Unmarshal(marshaledBytes, protoStruct); err != nil {
|
||||
b.Errorf("protoCodec.Unmarshal(_) returned an error")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,143 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2014, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/test/codec_perf"
|
||||
)
|
||||
|
||||
func marshalAndUnmarshal(t *testing.T, protoCodec Codec, expectedBody []byte) {
|
||||
p := &codec_perf.Buffer{}
|
||||
p.Body = expectedBody
|
||||
|
||||
marshalledBytes, err := protoCodec.Marshal(p)
|
||||
if err != nil {
|
||||
t.Errorf("protoCodec.Marshal(_) returned an error")
|
||||
}
|
||||
|
||||
if err := protoCodec.Unmarshal(marshalledBytes, p); err != nil {
|
||||
t.Errorf("protoCodec.Unmarshal(_) returned an error")
|
||||
}
|
||||
|
||||
if bytes.Compare(p.GetBody(), expectedBody) != 0 {
|
||||
t.Errorf("Unexpected body; got %v; want %v", p.GetBody(), expectedBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicProtoCodecMarshalAndUnmarshal(t *testing.T) {
|
||||
marshalAndUnmarshal(t, protoCodec{}, []byte{1, 2, 3})
|
||||
}
|
||||
|
||||
// Try to catch possible race conditions around use of pools
|
||||
func TestConcurrentUsage(t *testing.T) {
|
||||
const (
|
||||
numGoRoutines = 100
|
||||
numMarshUnmarsh = 1000
|
||||
)
|
||||
|
||||
// small, arbitrary byte slices
|
||||
protoBodies := [][]byte{
|
||||
[]byte("one"),
|
||||
[]byte("two"),
|
||||
[]byte("three"),
|
||||
[]byte("four"),
|
||||
[]byte("five"),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
codec := protoCodec{}
|
||||
|
||||
for i := 0; i < numGoRoutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for k := 0; k < numMarshUnmarsh; k++ {
|
||||
marshalAndUnmarshal(t, codec, protoBodies[k%len(protoBodies)])
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestStaggeredMarshalAndUnmarshalUsingSamePool tries to catch potential errors in which slices get
|
||||
// stomped on during reuse of a proto.Buffer.
|
||||
func TestStaggeredMarshalAndUnmarshalUsingSamePool(t *testing.T) {
|
||||
codec1 := protoCodec{}
|
||||
codec2 := protoCodec{}
|
||||
|
||||
expectedBody1 := []byte{1, 2, 3}
|
||||
expectedBody2 := []byte{4, 5, 6}
|
||||
|
||||
proto1 := codec_perf.Buffer{Body: expectedBody1}
|
||||
proto2 := codec_perf.Buffer{Body: expectedBody2}
|
||||
|
||||
var m1, m2 []byte
|
||||
var err error
|
||||
|
||||
if m1, err = codec1.Marshal(&proto1); err != nil {
|
||||
t.Errorf("protoCodec.Marshal(%v) failed", proto1)
|
||||
}
|
||||
|
||||
if m2, err = codec2.Marshal(&proto2); err != nil {
|
||||
t.Errorf("protoCodec.Marshal(%v) failed", proto2)
|
||||
}
|
||||
|
||||
if err = codec1.Unmarshal(m1, &proto1); err != nil {
|
||||
t.Errorf("protoCodec.Unmarshal(%v) failed", m1)
|
||||
}
|
||||
|
||||
if err = codec2.Unmarshal(m2, &proto2); err != nil {
|
||||
t.Errorf("protoCodec.Unmarshal(%v) failed", m2)
|
||||
}
|
||||
|
||||
b1 := proto1.GetBody()
|
||||
b2 := proto2.GetBody()
|
||||
|
||||
for i, v := range b1 {
|
||||
if expectedBody1[i] != v {
|
||||
t.Errorf("expected %v at index %v but got %v", i, expectedBody1[i], v)
|
||||
}
|
||||
}
|
||||
|
||||
for i, v := range b2 {
|
||||
if expectedBody2[i] != v {
|
||||
t.Errorf("expected %v at index %v but got %v", i, expectedBody2[i], v)
|
||||
}
|
||||
}
|
||||
}
|
128
grpclb/grpclb.go
128
grpclb/grpclb.go
|
@ -111,7 +111,7 @@ type balancer struct {
|
|||
rand *rand.Rand
|
||||
}
|
||||
|
||||
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
|
||||
func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
|
||||
updates, err := w.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -121,10 +121,6 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
|
|||
if b.done {
|
||||
return grpc.ErrClientConnClosing
|
||||
}
|
||||
var bAddr remoteBalancerInfo
|
||||
if len(b.rbs) > 0 {
|
||||
bAddr = b.rbs[0]
|
||||
}
|
||||
for _, update := range updates {
|
||||
switch update.Op {
|
||||
case naming.Add:
|
||||
|
@ -173,21 +169,11 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
|
|||
}
|
||||
// TODO: Fall back to the basic round-robin load balancing if the resulting address is
|
||||
// not a load balancer.
|
||||
if len(b.rbs) > 0 {
|
||||
// For simplicity, always use the first one now. May revisit this decision later.
|
||||
if b.rbs[0] != bAddr {
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
}
|
||||
// Pick a random one from the list, instead of always using the first one.
|
||||
if l := len(b.rbs); l > 1 {
|
||||
tmpIdx := b.rand.Intn(l - 1)
|
||||
b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
|
||||
}
|
||||
ch <- b.rbs[0]
|
||||
}
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
}
|
||||
ch <- b.rbs
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -261,7 +247,7 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
|
|||
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
|
||||
stream, err := lbc.BalanceLoad(ctx)
|
||||
if err != nil {
|
||||
grpclog.Printf("Failed to perform RPC to the remote balancer %v", err)
|
||||
return
|
||||
|
@ -340,32 +326,98 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
|
|||
}
|
||||
b.w = w
|
||||
b.mu.Unlock()
|
||||
balancerAddrCh := make(chan remoteBalancerInfo, 1)
|
||||
balancerAddrsCh := make(chan []remoteBalancerInfo, 1)
|
||||
// Spawn a goroutine to monitor the name resolution of remote load balancer.
|
||||
go func() {
|
||||
for {
|
||||
if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil {
|
||||
if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil {
|
||||
grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err)
|
||||
close(balancerAddrCh)
|
||||
close(balancerAddrsCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
// Spawn a goroutine to talk to the remote load balancer.
|
||||
go func() {
|
||||
var cc *grpc.ClientConn
|
||||
for {
|
||||
rb, ok := <-balancerAddrCh
|
||||
var (
|
||||
cc *grpc.ClientConn
|
||||
// ccError is closed when there is an error in the current cc.
|
||||
// A new rb should be picked from rbs and connected.
|
||||
ccError chan struct{}
|
||||
rb *remoteBalancerInfo
|
||||
rbs []remoteBalancerInfo
|
||||
rbIdx int
|
||||
)
|
||||
|
||||
defer func() {
|
||||
if ccError != nil {
|
||||
select {
|
||||
case <-ccError:
|
||||
default:
|
||||
close(ccError)
|
||||
}
|
||||
}
|
||||
if cc != nil {
|
||||
cc.Close()
|
||||
}
|
||||
if !ok {
|
||||
// b is closing.
|
||||
return
|
||||
}()
|
||||
|
||||
for {
|
||||
var ok bool
|
||||
select {
|
||||
case rbs, ok = <-balancerAddrsCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
foundIdx := -1
|
||||
if rb != nil {
|
||||
for i, trb := range rbs {
|
||||
if trb == *rb {
|
||||
foundIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if foundIdx >= 0 {
|
||||
if foundIdx >= 1 {
|
||||
// Move the address in use to the beginning of the list.
|
||||
b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0]
|
||||
rbIdx = 0
|
||||
}
|
||||
continue // If found, don't dial new cc.
|
||||
} else if len(rbs) > 0 {
|
||||
// Pick a random one from the list, instead of always using the first one.
|
||||
if l := len(rbs); l > 1 && rb != nil {
|
||||
tmpIdx := b.rand.Intn(l - 1)
|
||||
b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
|
||||
}
|
||||
rbIdx = 0
|
||||
rb = &rbs[0]
|
||||
} else {
|
||||
// foundIdx < 0 && len(rbs) <= 0.
|
||||
rb = nil
|
||||
}
|
||||
case <-ccError:
|
||||
ccError = nil
|
||||
if rbIdx < len(rbs)-1 {
|
||||
rbIdx++
|
||||
rb = &rbs[rbIdx]
|
||||
} else {
|
||||
rb = nil
|
||||
}
|
||||
}
|
||||
|
||||
if rb == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
cc.Close()
|
||||
}
|
||||
// Talk to the remote load balancer to get the server list.
|
||||
var err error
|
||||
creds := config.DialCreds
|
||||
ccError = make(chan struct{})
|
||||
if creds == nil {
|
||||
cc, err = grpc.Dial(rb.addr, grpc.WithInsecure())
|
||||
} else {
|
||||
|
@ -379,22 +431,24 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
|
|||
}
|
||||
if err != nil {
|
||||
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
|
||||
return
|
||||
close(ccError)
|
||||
continue
|
||||
}
|
||||
b.mu.Lock()
|
||||
b.seq++ // tick when getting a new balancer address
|
||||
seq := b.seq
|
||||
b.next = 0
|
||||
b.mu.Unlock()
|
||||
go func(cc *grpc.ClientConn) {
|
||||
go func(cc *grpc.ClientConn, ccError chan struct{}) {
|
||||
lbc := lbpb.NewLoadBalancerClient(cc)
|
||||
for {
|
||||
if retry := b.callRemoteBalancer(lbc, seq); !retry {
|
||||
cc.Close()
|
||||
return
|
||||
}
|
||||
b.callRemoteBalancer(lbc, seq)
|
||||
cc.Close()
|
||||
select {
|
||||
case <-ccError:
|
||||
default:
|
||||
close(ccError)
|
||||
}
|
||||
}(cc)
|
||||
}(cc, ccError)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
|
|
|
@ -99,24 +99,26 @@ func (w *testWatcher) inject(updates []*naming.Update) {
|
|||
}
|
||||
|
||||
type testNameResolver struct {
|
||||
w *testWatcher
|
||||
addr string
|
||||
w *testWatcher
|
||||
addrs []string
|
||||
}
|
||||
|
||||
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
|
||||
r.w = &testWatcher{
|
||||
update: make(chan *naming.Update, 1),
|
||||
update: make(chan *naming.Update, len(r.addrs)),
|
||||
side: make(chan int, 1),
|
||||
readDone: make(chan int),
|
||||
}
|
||||
r.w.side <- 1
|
||||
r.w.update <- &naming.Update{
|
||||
Op: naming.Add,
|
||||
Addr: r.addr,
|
||||
Metadata: &Metadata{
|
||||
AddrType: GRPCLB,
|
||||
ServerName: lbsn,
|
||||
},
|
||||
r.w.side <- len(r.addrs)
|
||||
for _, addr := range r.addrs {
|
||||
r.w.update <- &naming.Update{
|
||||
Op: naming.Add,
|
||||
Addr: addr,
|
||||
Metadata: &Metadata{
|
||||
AddrType: GRPCLB,
|
||||
ServerName: lbsn,
|
||||
},
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
<-r.w.readDone
|
||||
|
@ -124,6 +126,12 @@ func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
|
|||
return r.w, nil
|
||||
}
|
||||
|
||||
func (r *testNameResolver) inject(updates []*naming.Update) {
|
||||
if r.w != nil {
|
||||
r.w.inject(updates)
|
||||
}
|
||||
}
|
||||
|
||||
type serverNameCheckCreds struct {
|
||||
expected string
|
||||
sn string
|
||||
|
@ -212,10 +220,11 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
|
|||
}
|
||||
|
||||
type helloServer struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
|
||||
}
|
||||
|
@ -223,17 +232,17 @@ func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwp
|
|||
return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
|
||||
}
|
||||
return &hwpb.HelloReply{
|
||||
Message: "Hello " + in.Name,
|
||||
Message: "Hello " + in.Name + " for " + s.addr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
|
||||
func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
|
||||
for _, l := range lis {
|
||||
creds := &serverNameCheckCreds{
|
||||
sn: sn,
|
||||
}
|
||||
s := grpc.NewServer(grpc.Creds(creds))
|
||||
hwpb.RegisterGreeterServer(s, &helloServer{})
|
||||
hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()})
|
||||
servers = append(servers, s)
|
||||
go func(s *grpc.Server, l net.Listener) {
|
||||
s.Serve(l)
|
||||
|
@ -248,32 +257,86 @@ func stopBackends(servers []*grpc.Server) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGRPCLB(t *testing.T) {
|
||||
// Start a backend.
|
||||
beLis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen %v", err)
|
||||
type testServers struct {
|
||||
lbAddr string
|
||||
ls *remoteBalancer
|
||||
lb *grpc.Server
|
||||
beIPs []net.IP
|
||||
bePorts []int
|
||||
}
|
||||
|
||||
func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
|
||||
var (
|
||||
beListeners []net.Listener
|
||||
ls *remoteBalancer
|
||||
lb *grpc.Server
|
||||
beIPs []net.IP
|
||||
bePorts []int
|
||||
)
|
||||
for i := 0; i < numberOfBackends; i++ {
|
||||
// Start a backend.
|
||||
beLis, e := net.Listen("tcp", "localhost:0")
|
||||
if e != nil {
|
||||
err = fmt.Errorf("Failed to listen %v", err)
|
||||
return
|
||||
}
|
||||
beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
|
||||
|
||||
beAddr := strings.Split(beLis.Addr().String(), ":")
|
||||
bePort, _ := strconv.Atoi(beAddr[1])
|
||||
bePorts = append(bePorts, bePort)
|
||||
|
||||
beListeners = append(beListeners, beLis)
|
||||
}
|
||||
beAddr := strings.Split(beLis.Addr().String(), ":")
|
||||
bePort, err := strconv.Atoi(beAddr[1])
|
||||
backends := startBackends(t, besn, beLis)
|
||||
defer stopBackends(backends)
|
||||
backends := startBackends(besn, beListeners...)
|
||||
|
||||
// Start a load balancer.
|
||||
lbLis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
||||
err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
|
||||
return
|
||||
}
|
||||
lbCreds := &serverNameCheckCreds{
|
||||
sn: lbsn,
|
||||
}
|
||||
lb := grpc.NewServer(grpc.Creds(lbCreds))
|
||||
lb = grpc.NewServer(grpc.Creds(lbCreds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate the port number %v", err)
|
||||
err = fmt.Errorf("Failed to generate the port number %v", err)
|
||||
return
|
||||
}
|
||||
ls = newRemoteBalancer(nil, nil)
|
||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||
go func() {
|
||||
lb.Serve(lbLis)
|
||||
}()
|
||||
|
||||
tss = &testServers{
|
||||
lbAddr: lbLis.Addr().String(),
|
||||
ls: ls,
|
||||
lb: lb,
|
||||
beIPs: beIPs,
|
||||
bePorts: bePorts,
|
||||
}
|
||||
cleanup = func() {
|
||||
defer stopBackends(backends)
|
||||
defer func() {
|
||||
ls.stop()
|
||||
lb.Stop()
|
||||
}()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestGRPCLB(t *testing.T) {
|
||||
tss, cleanup, err := newLoadBalancer(1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new load balancer: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
be := &lbpb.Server{
|
||||
IpAddress: beLis.Addr().(*net.TCPAddr).IP,
|
||||
Port: int32(bePort),
|
||||
IpAddress: tss.beIPs[0],
|
||||
Port: int32(tss.bePorts[0]),
|
||||
LoadBalanceToken: lbToken,
|
||||
}
|
||||
var bes []*lbpb.Server
|
||||
|
@ -281,23 +344,14 @@ func TestGRPCLB(t *testing.T) {
|
|||
sl := &lbpb.ServerList{
|
||||
Servers: bes,
|
||||
}
|
||||
sls := []*lbpb.ServerList{sl}
|
||||
intervals := []time.Duration{0}
|
||||
ls := newRemoteBalancer(sls, intervals)
|
||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||
go func() {
|
||||
lb.Serve(lbLis)
|
||||
}()
|
||||
defer func() {
|
||||
ls.stop()
|
||||
lb.Stop()
|
||||
}()
|
||||
tss.ls.sls = []*lbpb.ServerList{sl}
|
||||
tss.ls.intervals = []time.Duration{0}
|
||||
creds := serverNameCheckCreds{
|
||||
expected: besn,
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
|
||||
addr: lbLis.Addr().String(),
|
||||
addrs: []string{tss.lbAddr},
|
||||
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
|
@ -310,65 +364,31 @@ func TestGRPCLB(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDropRequest(t *testing.T) {
|
||||
// Start 2 backends.
|
||||
beLis1, err := net.Listen("tcp", "localhost:0")
|
||||
tss, cleanup, err := newLoadBalancer(2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen %v", err)
|
||||
t.Fatalf("failed to create new load balancer: %v", err)
|
||||
}
|
||||
beAddr1 := strings.Split(beLis1.Addr().String(), ":")
|
||||
bePort1, err := strconv.Atoi(beAddr1[1])
|
||||
|
||||
beLis2, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen %v", err)
|
||||
}
|
||||
beAddr2 := strings.Split(beLis2.Addr().String(), ":")
|
||||
bePort2, err := strconv.Atoi(beAddr2[1])
|
||||
|
||||
backends := startBackends(t, besn, beLis1, beLis2)
|
||||
defer stopBackends(backends)
|
||||
|
||||
// Start a load balancer.
|
||||
lbLis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
||||
}
|
||||
lbCreds := &serverNameCheckCreds{
|
||||
sn: lbsn,
|
||||
}
|
||||
lb := grpc.NewServer(grpc.Creds(lbCreds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate the port number %v", err)
|
||||
}
|
||||
sls := []*lbpb.ServerList{{
|
||||
defer cleanup()
|
||||
tss.ls.sls = []*lbpb.ServerList{{
|
||||
Servers: []*lbpb.Server{{
|
||||
IpAddress: beLis1.Addr().(*net.TCPAddr).IP,
|
||||
Port: int32(bePort1),
|
||||
IpAddress: tss.beIPs[0],
|
||||
Port: int32(tss.bePorts[0]),
|
||||
LoadBalanceToken: lbToken,
|
||||
DropRequest: true,
|
||||
}, {
|
||||
IpAddress: beLis2.Addr().(*net.TCPAddr).IP,
|
||||
Port: int32(bePort2),
|
||||
IpAddress: tss.beIPs[1],
|
||||
Port: int32(tss.bePorts[1]),
|
||||
LoadBalanceToken: lbToken,
|
||||
DropRequest: false,
|
||||
}},
|
||||
}}
|
||||
intervals := []time.Duration{0}
|
||||
ls := newRemoteBalancer(sls, intervals)
|
||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||
go func() {
|
||||
lb.Serve(lbLis)
|
||||
}()
|
||||
defer func() {
|
||||
ls.stop()
|
||||
lb.Stop()
|
||||
}()
|
||||
tss.ls.intervals = []time.Duration{0}
|
||||
creds := serverNameCheckCreds{
|
||||
expected: besn,
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
|
||||
addr: lbLis.Addr().String(),
|
||||
addrs: []string{tss.lbAddr},
|
||||
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
|
@ -395,31 +415,14 @@ func TestDropRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDropRequestFailedNonFailFast(t *testing.T) {
|
||||
// Start a backend.
|
||||
beLis, err := net.Listen("tcp", "localhost:0")
|
||||
tss, cleanup, err := newLoadBalancer(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen %v", err)
|
||||
}
|
||||
beAddr := strings.Split(beLis.Addr().String(), ":")
|
||||
bePort, err := strconv.Atoi(beAddr[1])
|
||||
backends := startBackends(t, besn, beLis)
|
||||
defer stopBackends(backends)
|
||||
|
||||
// Start a load balancer.
|
||||
lbLis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
||||
}
|
||||
lbCreds := &serverNameCheckCreds{
|
||||
sn: lbsn,
|
||||
}
|
||||
lb := grpc.NewServer(grpc.Creds(lbCreds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate the port number %v", err)
|
||||
t.Fatalf("failed to create new load balancer: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
be := &lbpb.Server{
|
||||
IpAddress: beLis.Addr().(*net.TCPAddr).IP,
|
||||
Port: int32(bePort),
|
||||
IpAddress: tss.beIPs[0],
|
||||
Port: int32(tss.bePorts[0]),
|
||||
LoadBalanceToken: lbToken,
|
||||
DropRequest: true,
|
||||
}
|
||||
|
@ -428,23 +431,14 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
|
|||
sl := &lbpb.ServerList{
|
||||
Servers: bes,
|
||||
}
|
||||
sls := []*lbpb.ServerList{sl}
|
||||
intervals := []time.Duration{0}
|
||||
ls := newRemoteBalancer(sls, intervals)
|
||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||
go func() {
|
||||
lb.Serve(lbLis)
|
||||
}()
|
||||
defer func() {
|
||||
ls.stop()
|
||||
lb.Stop()
|
||||
}()
|
||||
tss.ls.sls = []*lbpb.ServerList{sl}
|
||||
tss.ls.intervals = []time.Duration{0}
|
||||
creds := serverNameCheckCreds{
|
||||
expected: besn,
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
|
||||
addr: lbLis.Addr().String(),
|
||||
addrs: []string{tss.lbAddr},
|
||||
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
|
@ -458,31 +452,14 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerExpiration(t *testing.T) {
|
||||
// Start a backend.
|
||||
beLis, err := net.Listen("tcp", "localhost:0")
|
||||
tss, cleanup, err := newLoadBalancer(1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen %v", err)
|
||||
}
|
||||
beAddr := strings.Split(beLis.Addr().String(), ":")
|
||||
bePort, err := strconv.Atoi(beAddr[1])
|
||||
backends := startBackends(t, besn, beLis)
|
||||
defer stopBackends(backends)
|
||||
|
||||
// Start a load balancer.
|
||||
lbLis, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create the listener for the load balancer %v", err)
|
||||
}
|
||||
lbCreds := &serverNameCheckCreds{
|
||||
sn: lbsn,
|
||||
}
|
||||
lb := grpc.NewServer(grpc.Creds(lbCreds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate the port number %v", err)
|
||||
t.Fatalf("failed to create new load balancer: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
be := &lbpb.Server{
|
||||
IpAddress: beLis.Addr().(*net.TCPAddr).IP,
|
||||
Port: int32(bePort),
|
||||
IpAddress: tss.beIPs[0],
|
||||
Port: int32(tss.bePorts[0]),
|
||||
LoadBalanceToken: lbToken,
|
||||
}
|
||||
var bes []*lbpb.Server
|
||||
|
@ -504,21 +481,14 @@ func TestServerExpiration(t *testing.T) {
|
|||
var intervals []time.Duration
|
||||
intervals = append(intervals, 0)
|
||||
intervals = append(intervals, 500*time.Millisecond)
|
||||
ls := newRemoteBalancer(sls, intervals)
|
||||
lbpb.RegisterLoadBalancerServer(lb, ls)
|
||||
go func() {
|
||||
lb.Serve(lbLis)
|
||||
}()
|
||||
defer func() {
|
||||
ls.stop()
|
||||
lb.Stop()
|
||||
}()
|
||||
tss.ls.sls = sls
|
||||
tss.ls.intervals = intervals
|
||||
creds := serverNameCheckCreds{
|
||||
expected: besn,
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
|
||||
addr: lbLis.Addr().String(),
|
||||
addrs: []string{tss.lbAddr},
|
||||
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
|
@ -539,3 +509,90 @@ func TestServerExpiration(t *testing.T) {
|
|||
}
|
||||
cc.Close()
|
||||
}
|
||||
|
||||
// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
|
||||
func TestBalancerDisconnects(t *testing.T) {
|
||||
var (
|
||||
lbAddrs []string
|
||||
lbs []*grpc.Server
|
||||
)
|
||||
for i := 0; i < 3; i++ {
|
||||
tss, cleanup, err := newLoadBalancer(1)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new load balancer: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
be := &lbpb.Server{
|
||||
IpAddress: tss.beIPs[0],
|
||||
Port: int32(tss.bePorts[0]),
|
||||
LoadBalanceToken: lbToken,
|
||||
}
|
||||
var bes []*lbpb.Server
|
||||
bes = append(bes, be)
|
||||
sl := &lbpb.ServerList{
|
||||
Servers: bes,
|
||||
}
|
||||
tss.ls.sls = []*lbpb.ServerList{sl}
|
||||
tss.ls.intervals = []time.Duration{0}
|
||||
|
||||
lbAddrs = append(lbAddrs, tss.lbAddr)
|
||||
lbs = append(lbs, tss.lb)
|
||||
}
|
||||
|
||||
creds := serverNameCheckCreds{
|
||||
expected: besn,
|
||||
}
|
||||
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
resolver := &testNameResolver{
|
||||
addrs: lbAddrs[:2],
|
||||
}
|
||||
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(resolver)), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial to the backend %v", err)
|
||||
}
|
||||
helloC := hwpb.NewGreeterClient(cc)
|
||||
var message string
|
||||
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
|
||||
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
|
||||
} else {
|
||||
message = resp.Message
|
||||
}
|
||||
// The initial resolver update contains lbs[0] and lbs[1].
|
||||
// When lbs[0] is stopped, lbs[1] should be used.
|
||||
lbs[0].Stop()
|
||||
for {
|
||||
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
|
||||
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
|
||||
} else if resp.Message != message {
|
||||
// A new backend server should receive the request.
|
||||
// The response contains the backend address, so the message should be different from the previous one.
|
||||
message = resp.Message
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
// Inject a update to add lbs[2] to resolved addresses.
|
||||
resolver.inject([]*naming.Update{
|
||||
{Op: naming.Add,
|
||||
Addr: lbAddrs[2],
|
||||
Metadata: &Metadata{
|
||||
AddrType: GRPCLB,
|
||||
ServerName: lbsn,
|
||||
},
|
||||
},
|
||||
})
|
||||
// Stop lbs[1]. Now lbs[0] and lbs[1] are all stopped. lbs[2] should be used.
|
||||
lbs[1].Stop()
|
||||
for {
|
||||
if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
|
||||
t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
|
||||
} else if resp.Message != message {
|
||||
// A new backend server should receive the request.
|
||||
// The response contains the backend address, so the message should be different from the previous one.
|
||||
break
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
cc.Close()
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ import (
|
|||
// UnaryInvoker is called by UnaryClientInterceptor to complete RPCs.
|
||||
type UnaryInvoker func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error
|
||||
|
||||
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. inovker is the handler to complete the RPC
|
||||
// UnaryClientInterceptor intercepts the execution of a unary RPC on the client. invoker is the handler to complete the RPC
|
||||
// and it is the responsibility of the interceptor to call it.
|
||||
// This is the EXPERIMENTAL API.
|
||||
type UnaryClientInterceptor func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error
|
||||
|
|
|
@ -392,7 +392,7 @@ func DoPerRPCCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScop
|
|||
}
|
||||
token := GetToken(serviceAccountKeyFile, oauthScope)
|
||||
kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken}
|
||||
ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}})
|
||||
reply, err := tc.UnaryCall(ctx, req)
|
||||
if err != nil {
|
||||
grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err)
|
||||
|
@ -416,7 +416,7 @@ var (
|
|||
|
||||
// DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent.
|
||||
func DoCancelAfterBegin(tc testpb.TestServiceClient, args ...grpc.CallOption) {
|
||||
ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata))
|
||||
ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata))
|
||||
stream, err := tc.StreamingInputCall(ctx, args...)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err)
|
||||
|
@ -491,7 +491,7 @@ func DoCustomMetadata(tc testpb.TestServiceClient, args ...grpc.CallOption) {
|
|||
ResponseSize: proto.Int32(int32(1)),
|
||||
Payload: pl,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), customMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), customMetadata)
|
||||
var header, trailer metadata.MD
|
||||
args = append(args, grpc.Header(&header), grpc.Trailer(&trailer))
|
||||
reply, err := tc.UnaryCall(
|
||||
|
@ -627,7 +627,7 @@ func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error)
|
|||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
status := in.GetResponseStatus()
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
if initialMetadata, ok := md[initialMetadataKey]; ok {
|
||||
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
|
||||
grpc.SendHeader(ctx, header)
|
||||
|
@ -686,7 +686,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
|
|||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
|
||||
if initialMetadata, ok := md[initialMetadataKey]; ok {
|
||||
header := metadata.Pairs(initialMetadataKey, initialMetadata[0])
|
||||
stream.SendHeader(header)
|
||||
|
|
|
@ -136,17 +136,41 @@ func Join(mds ...MD) MD {
|
|||
return out
|
||||
}
|
||||
|
||||
type mdKey struct{}
|
||||
type mdIncomingKey struct{}
|
||||
type mdOutgoingKey struct{}
|
||||
|
||||
// NewContext creates a new context with md attached.
|
||||
// NewContext is a wrapper for NewOutgoingContext(ctx, md). Deprecated.
|
||||
func NewContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdKey{}, md)
|
||||
return NewOutgoingContext(ctx, md)
|
||||
}
|
||||
|
||||
// FromContext returns the MD in ctx if it exists.
|
||||
// The returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
// NewIncomingContext creates a new context with incoming md attached.
|
||||
func NewIncomingContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdIncomingKey{}, md)
|
||||
}
|
||||
|
||||
// NewOutgoingContext creates a new context with outgoing md attached.
|
||||
func NewOutgoingContext(ctx context.Context, md MD) context.Context {
|
||||
return context.WithValue(ctx, mdOutgoingKey{}, md)
|
||||
}
|
||||
|
||||
// FromContext is a wrapper for FromIncomingContext(ctx). Deprecated.
|
||||
func FromContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdKey{}).(MD)
|
||||
return FromIncomingContext(ctx)
|
||||
}
|
||||
|
||||
// FromIncomingContext returns the incoming MD in ctx if it exists. The
|
||||
// returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdIncomingKey{}).(MD)
|
||||
return
|
||||
}
|
||||
|
||||
// FromOutgoingContext returns the outgoing MD in ctx if it exists. The
|
||||
// returned md should be immutable, writing to it may cause races.
|
||||
// Modification should be made to the copies of the returned md.
|
||||
func FromOutgoingContext(ctx context.Context) (md MD, ok bool) {
|
||||
md, ok = ctx.Value(mdOutgoingKey{}).(MD)
|
||||
return
|
||||
}
|
||||
|
|
27
rpc_util.go
27
rpc_util.go
|
@ -43,7 +43,6 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -53,32 +52,6 @@ import (
|
|||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
// Codec defines the interface gRPC uses to encode and decode messages.
|
||||
type Codec interface {
|
||||
// Marshal returns the wire format of v.
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
// Unmarshal parses the wire format into v.
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
// String returns the name of the Codec implementation. The returned
|
||||
// string will be used as part of content type in transmission.
|
||||
String() string
|
||||
}
|
||||
|
||||
// protoCodec is a Codec implementation with protobuf. It is the default codec for gRPC.
|
||||
type protoCodec struct{}
|
||||
|
||||
func (protoCodec) Marshal(v interface{}) ([]byte, error) {
|
||||
return proto.Marshal(v.(proto.Message))
|
||||
}
|
||||
|
||||
func (protoCodec) Unmarshal(data []byte, v interface{}) error {
|
||||
return proto.Unmarshal(data, v.(proto.Message))
|
||||
}
|
||||
|
||||
func (protoCodec) String() string {
|
||||
return "proto"
|
||||
}
|
||||
|
||||
// Compressor defines the interface gRPC uses to compress a message.
|
||||
type Compressor interface {
|
||||
// Do compresses p into w.
|
||||
|
|
25
server.go
25
server.go
|
@ -792,19 +792,24 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
Delay: false,
|
||||
}
|
||||
if err := s.sendResponse(t, stream, reply, s.opts.cp, opts); err != nil {
|
||||
// TODO: Translate error into a status.Status error if necessary?
|
||||
// TODO: Write status when appropriate.
|
||||
s, ok := status.FromError(err)
|
||||
if !ok {
|
||||
// TODO: Parse possible non-status error
|
||||
if err == io.EOF {
|
||||
// The entire stream is done (for unary RPC only).
|
||||
return err
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
if e := t.WriteStatus(stream, s); e != nil {
|
||||
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
|
||||
}
|
||||
} else {
|
||||
switch s.Code() {
|
||||
case codes.InvalidArgument:
|
||||
if e := t.WriteStatus(stream, s); e != nil {
|
||||
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", e)
|
||||
switch st := err.(type) {
|
||||
case transport.ConnectionError:
|
||||
// Nothing to do here.
|
||||
case transport.StreamError:
|
||||
if e := t.WriteStatus(stream, status.New(st.Code, st.Desc)); e != nil {
|
||||
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", e)
|
||||
}
|
||||
// TODO: Add cases if needed
|
||||
default:
|
||||
panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
|
|
@ -75,7 +75,7 @@ var (
|
|||
type testServer struct{}
|
||||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
if err := grpc.SendHeader(ctx, md); err != nil {
|
||||
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
|
||||
|
@ -93,7 +93,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
|||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context())
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
|
||||
|
@ -219,10 +219,11 @@ func (te *test) clientConn() *grpc.ClientConn {
|
|||
}
|
||||
|
||||
type rpcConfig struct {
|
||||
count int // Number of requests and responses for streaming RPCs.
|
||||
success bool // Whether the RPC should succeed or return error.
|
||||
failfast bool
|
||||
streaming bool // Whether the rpc should be a streaming RPC.
|
||||
count int // Number of requests and responses for streaming RPCs.
|
||||
success bool // Whether the RPC should succeed or return error.
|
||||
failfast bool
|
||||
streaming bool // Whether the rpc should be a streaming RPC.
|
||||
noLastRecv bool // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs.
|
||||
}
|
||||
|
||||
func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
|
||||
|
@ -237,7 +238,7 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple
|
|||
} else {
|
||||
req = &testpb.SimpleRequest{Id: errorID}
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
|
||||
resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast))
|
||||
return req, resp, err
|
||||
|
@ -250,7 +251,7 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
|
|||
err error
|
||||
)
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
|
||||
stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast))
|
||||
if err != nil {
|
||||
return reqs, resps, err
|
||||
}
|
||||
|
@ -275,8 +276,14 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest
|
|||
if err = stream.CloseSend(); err != nil && err != io.EOF {
|
||||
return reqs, resps, err
|
||||
}
|
||||
if _, err = stream.Recv(); err != io.EOF {
|
||||
return reqs, resps, err
|
||||
if !c.noLastRecv {
|
||||
if _, err = stream.Recv(); err != io.EOF {
|
||||
return reqs, resps, err
|
||||
}
|
||||
} else {
|
||||
// In the case of not calling the last recv, sleep to avoid
|
||||
// returning too fast to miss the remaining stats (InTrailer and End).
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
return reqs, resps, nil
|
||||
|
@ -968,6 +975,20 @@ func TestClientStatsStreamingRPC(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// If the user doesn't call the last recv() on clientSteam.
|
||||
func TestClientStatsStreamingRPCNotCallingLastRecv(t *testing.T) {
|
||||
count := 1
|
||||
testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, streaming: true, noLastRecv: true}, map[int]*checkFuncWithCount{
|
||||
begin: {checkBegin, 1},
|
||||
outHeader: {checkOutHeader, 1},
|
||||
outPayload: {checkOutPayload, count},
|
||||
inHeader: {checkInHeader, 1},
|
||||
inPayload: {checkInPayload, count},
|
||||
inTrailer: {checkInTrailer, 1},
|
||||
end: {checkEnd, 1},
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientStatsStreamingRPCError(t *testing.T) {
|
||||
count := 5
|
||||
testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, streaming: true}, map[int]*checkFuncWithCount{
|
||||
|
|
|
@ -46,7 +46,7 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
spb "github.com/google/go-genproto/googleapis/rpc/status"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ import (
|
|||
"testing"
|
||||
|
||||
apb "github.com/golang/protobuf/ptypes/any"
|
||||
spb "github.com/google/go-genproto/googleapis/rpc/status"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
|
|
41
stream.go
41
stream.go
|
@ -303,9 +303,10 @@ type clientStream struct {
|
|||
|
||||
tracing bool // set to EnableTracing when the clientStream is created.
|
||||
|
||||
mu sync.Mutex
|
||||
put func()
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
put func()
|
||||
closed bool
|
||||
finished bool
|
||||
// trInfo.tr is set when the clientStream is created (if EnableTracing is true),
|
||||
// and is set to nil when the clientStream's finish method is called.
|
||||
trInfo traceInfo
|
||||
|
@ -394,21 +395,6 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
|
|||
}
|
||||
|
||||
func (cs *clientStream) RecvMsg(m interface{}) (err error) {
|
||||
defer func() {
|
||||
if err != nil && cs.statsHandler != nil {
|
||||
// Only generate End if err != nil.
|
||||
// If err == nil, it's not the last RecvMsg.
|
||||
// The last RecvMsg gets either an RPC error or io.EOF.
|
||||
end := &stats.End{
|
||||
Client: true,
|
||||
EndTime: time.Now(),
|
||||
}
|
||||
if err != io.EOF {
|
||||
end.Error = toRPCErr(err)
|
||||
}
|
||||
cs.statsHandler.HandleRPC(cs.statsCtx, end)
|
||||
}
|
||||
}()
|
||||
var inPayload *stats.InPayload
|
||||
if cs.statsHandler != nil {
|
||||
inPayload = &stats.InPayload{
|
||||
|
@ -494,13 +480,17 @@ func (cs *clientStream) closeTransportStream(err error) {
|
|||
}
|
||||
|
||||
func (cs *clientStream) finish(err error) {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
if cs.finished {
|
||||
return
|
||||
}
|
||||
cs.finished = true
|
||||
defer func() {
|
||||
if cs.cancel != nil {
|
||||
cs.cancel()
|
||||
}
|
||||
}()
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
for _, o := range cs.opts {
|
||||
o.after(&cs.c)
|
||||
}
|
||||
|
@ -508,6 +498,17 @@ func (cs *clientStream) finish(err error) {
|
|||
cs.put()
|
||||
cs.put = nil
|
||||
}
|
||||
if cs.statsHandler != nil {
|
||||
end := &stats.End{
|
||||
Client: true,
|
||||
EndTime: time.Now(),
|
||||
}
|
||||
if err != io.EOF {
|
||||
// end.Error is nil if the RPC finished successfully.
|
||||
end.Error = toRPCErr(err)
|
||||
}
|
||||
cs.statsHandler.HandleRPC(cs.statsCtx, end)
|
||||
}
|
||||
if !cs.tracing {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -55,9 +55,9 @@ import (
|
|||
|
||||
"github.com/golang/protobuf/proto"
|
||||
anypb "github.com/golang/protobuf/ptypes/any"
|
||||
spb "github.com/google/go-genproto/googleapis/rpc/status"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -118,7 +118,7 @@ type testServer struct {
|
|||
}
|
||||
|
||||
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
// For testing purpose, returns an error if user-agent is failAppUA.
|
||||
// To test that client gets the correct error.
|
||||
if ua, ok := md["user-agent"]; !ok || strings.HasPrefix(ua[0], failAppUA) {
|
||||
|
@ -152,7 +152,7 @@ func newPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) {
|
|||
}
|
||||
|
||||
func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
|
||||
md, ok := metadata.FromContext(ctx)
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
if _, exists := md[":authority"]; !exists {
|
||||
return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md)
|
||||
|
@ -223,7 +223,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
|||
}
|
||||
|
||||
func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error {
|
||||
if md, ok := metadata.FromContext(stream.Context()); ok {
|
||||
if md, ok := metadata.FromIncomingContext(stream.Context()); ok {
|
||||
if _, exists := md[":authority"]; !exists {
|
||||
return grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md)
|
||||
}
|
||||
|
@ -274,7 +274,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
|
|||
}
|
||||
|
||||
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
md, ok := metadata.FromContext(stream.Context())
|
||||
md, ok := metadata.FromIncomingContext(stream.Context())
|
||||
if ok {
|
||||
if s.setAndSendHeader {
|
||||
if err := stream.SetHeader(md); err != nil {
|
||||
|
@ -1943,7 +1943,7 @@ func testFailedEmptyUnary(t *testing.T, e env) {
|
|||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
wantErr := detailedError
|
||||
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) {
|
||||
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
|
||||
|
@ -2161,7 +2161,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
|
|||
Payload: payload,
|
||||
}
|
||||
var header, trailer metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
|
@ -2207,7 +2207,7 @@ func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) {
|
|||
Payload: payload,
|
||||
}
|
||||
var trailer metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
|
@ -2230,7 +2230,7 @@ func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
|
|||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -2281,7 +2281,7 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
|
|||
Payload: payload,
|
||||
}
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
|
@ -2325,7 +2325,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
|
|||
}
|
||||
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
|
@ -2368,7 +2368,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
|
|||
Payload: payload,
|
||||
}
|
||||
var header metadata.MD
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err == nil {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <non-nil>", ctx, err)
|
||||
}
|
||||
|
@ -2400,7 +2400,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
|
|||
argSize = 1
|
||||
respSize = 1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -2444,7 +2444,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
|
|||
argSize = 1
|
||||
respSize = 1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -2508,7 +2508,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
|
|||
argSize = 1
|
||||
respSize = -1
|
||||
)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -2573,7 +2573,7 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) {
|
|||
ResponseSize: proto.Int32(314),
|
||||
Payload: payload,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), malformedHTTP2Metadata)
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), malformedHTTP2Metadata)
|
||||
if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Internal {
|
||||
t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %s", ctx, err, codes.Internal)
|
||||
}
|
||||
|
@ -2903,7 +2903,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
|||
defer te.tearDown()
|
||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||
|
||||
ctx := metadata.NewContext(te.ctx, testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(te.ctx, testMetadata)
|
||||
stream, err := tc.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -3042,7 +3042,7 @@ func testFailedServerStreaming(t *testing.T, e env) {
|
|||
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||
ResponseParameters: respParam,
|
||||
}
|
||||
ctx := metadata.NewContext(te.ctx, testMetadata)
|
||||
ctx := metadata.NewOutgoingContext(te.ctx, testMetadata)
|
||||
stream, err := tc.StreamingOutputCall(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want <nil>", tc, err)
|
||||
|
@ -3446,7 +3446,7 @@ func testCompressOK(t *testing.T, e env) {
|
|||
ResponseSize: proto.Int32(respSize),
|
||||
Payload: payload,
|
||||
}
|
||||
ctx := metadata.NewContext(context.Background(), metadata.Pairs("something", "something"))
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something"))
|
||||
if _, err := tc.UnaryCall(ctx, req); err != nil {
|
||||
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
|
||||
}
|
||||
|
@ -4238,3 +4238,168 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) {
|
|||
}
|
||||
return fw.dst.Write(p)
|
||||
}
|
||||
|
||||
// stubServer is a server that is easy to customize within individual test
|
||||
// cases.
|
||||
type stubServer struct {
|
||||
// Guarantees we satisfy this interface; panics if unimplemented methods are called.
|
||||
testpb.TestServiceServer
|
||||
|
||||
// Customizable implementations of server handlers.
|
||||
emptyCall func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error)
|
||||
fullDuplexCall func(stream testpb.TestService_FullDuplexCallServer) error
|
||||
|
||||
// A client connected to this service the test may use. Created in Start().
|
||||
client testpb.TestServiceClient
|
||||
|
||||
cleanups []func() // Lambdas executed in Stop(); populated by Start().
|
||||
}
|
||||
|
||||
func (ss *stubServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
return ss.emptyCall(ctx, in)
|
||||
}
|
||||
|
||||
func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
return ss.fullDuplexCall(stream)
|
||||
}
|
||||
|
||||
// Start starts the server and creates a client connected to it.
|
||||
func (ss *stubServer) Start() error {
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`net.Listen("tcp", ":0") = %v`, err)
|
||||
}
|
||||
ss.cleanups = append(ss.cleanups, func() { lis.Close() })
|
||||
|
||||
s := grpc.NewServer()
|
||||
testpb.RegisterTestServiceServer(s, ss)
|
||||
go s.Serve(lis)
|
||||
ss.cleanups = append(ss.cleanups, s.Stop)
|
||||
|
||||
cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
|
||||
if err != nil {
|
||||
return fmt.Errorf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
|
||||
}
|
||||
ss.cleanups = append(ss.cleanups, func() { cc.Close() })
|
||||
|
||||
ss.client = testpb.NewTestServiceClient(cc)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ss *stubServer) Stop() {
|
||||
for i := len(ss.cleanups) - 1; i >= 0; i-- {
|
||||
ss.cleanups[i]()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
|
||||
const mdkey = "somedata"
|
||||
|
||||
// endpoint ensures mdkey is NOT in metadata and returns an error if it is.
|
||||
endpoint := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil {
|
||||
return nil, status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return &testpb.Empty{}, nil
|
||||
},
|
||||
}
|
||||
if err := endpoint.Start(); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer endpoint.Stop()
|
||||
|
||||
// proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint
|
||||
// without explicitly copying the metadata.
|
||||
proxy := &stubServer{
|
||||
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil {
|
||||
return nil, status.Errorf(codes.Internal, "proxy: md=%v; want contains(%q)", md, mdkey)
|
||||
}
|
||||
return endpoint.client.EmptyCall(ctx, in)
|
||||
},
|
||||
}
|
||||
if err := proxy.Start(); err != nil {
|
||||
t.Fatalf("Error starting proxy server: %v", err)
|
||||
}
|
||||
defer proxy.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
md := metadata.Pairs(mdkey, "val")
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
// Sanity check that endpoint properly errors when it sees mdkey.
|
||||
_, err := endpoint.client.EmptyCall(ctx, &testpb.Empty{})
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal {
|
||||
t.Fatalf("endpoint.client.EmptyCall(_, _) = _, %v; want _, <status with Code()=Internal>", err)
|
||||
}
|
||||
|
||||
if _, err := proxy.client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingProxyDoesNotForwardMetadata(t *testing.T) {
|
||||
const mdkey = "somedata"
|
||||
|
||||
// doFDC performs a FullDuplexCall with client and returns the error from the
|
||||
// first stream.Recv call, or nil if that error is io.EOF. Calls t.Fatal if
|
||||
// the stream cannot be established.
|
||||
doFDC := func(ctx context.Context, client testpb.TestServiceClient) error {
|
||||
stream, err := client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unwanted error: %v", err)
|
||||
}
|
||||
if _, err := stream.Recv(); err != io.EOF {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// endpoint ensures mdkey is NOT in metadata and returns an error if it is.
|
||||
endpoint := &stubServer{
|
||||
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
ctx := stream.Context()
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil {
|
||||
return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := endpoint.Start(); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer endpoint.Stop()
|
||||
|
||||
// proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint
|
||||
// without explicitly copying the metadata.
|
||||
proxy := &stubServer{
|
||||
fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
ctx := stream.Context()
|
||||
if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil {
|
||||
return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey)
|
||||
}
|
||||
return doFDC(ctx, endpoint.client)
|
||||
},
|
||||
}
|
||||
if err := proxy.Start(); err != nil {
|
||||
t.Fatalf("Error starting proxy server: %v", err)
|
||||
}
|
||||
defer proxy.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
md := metadata.Pairs(mdkey, "val")
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
|
||||
// Sanity check that endpoint properly errors when it sees mdkey in ctx.
|
||||
err := doFDC(ctx, endpoint.client)
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal {
|
||||
t.Fatalf("stream.Recv() = _, %v; want _, <status with Code()=Internal>", err)
|
||||
}
|
||||
|
||||
if err := doFDC(ctx, proxy.client); err != nil {
|
||||
t.Fatalf("doFDC(_, proxy.client) = %v; want nil", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -319,7 +319,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||
if req.TLS != nil {
|
||||
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
||||
}
|
||||
ctx = metadata.NewContext(ctx, ht.headerMD)
|
||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
s.ctx = newContextWithStream(ctx, s)
|
||||
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
|
||||
|
|
|
@ -121,6 +121,9 @@ type http2Client struct {
|
|||
goAwayID uint32
|
||||
// prevGoAway ID records the Last-Stream-ID in the previous GOAway frame.
|
||||
prevGoAwayID uint32
|
||||
// goAwayReason records the http2.ErrCode and debug data received with the
|
||||
// GoAway frame.
|
||||
goAwayReason GoAwayReason
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
|
||||
|
@ -432,7 +435,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
hasMD bool
|
||||
endHeaders bool
|
||||
)
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
if md, ok := metadata.FromOutgoingContext(ctx); ok {
|
||||
hasMD = true
|
||||
for k, v := range md {
|
||||
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
|
||||
|
@ -909,6 +912,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
|||
t.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
t.setGoAwayReason(f)
|
||||
}
|
||||
t.goAwayID = f.LastStreamID
|
||||
close(t.goAway)
|
||||
|
@ -916,6 +920,26 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
|
|||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// setGoAwayReason sets the value of t.goAwayReason based
|
||||
// on the GoAway frame received.
|
||||
// It expects a lock on transport's mutext to be held by
|
||||
// the caller.
|
||||
func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
|
||||
t.goAwayReason = NoReason
|
||||
switch f.ErrCode {
|
||||
case http2.ErrCodeEnhanceYourCalm:
|
||||
if string(f.DebugData()) == "too_many_pings" {
|
||||
t.goAwayReason = TooManyPings
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *http2Client) GetGoAwayReason() GoAwayReason {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
return t.goAwayReason
|
||||
}
|
||||
|
||||
func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) {
|
||||
id := f.Header().StreamID
|
||||
incr := f.Increment
|
||||
|
|
|
@ -261,7 +261,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
s.ctx = newContextWithStream(s.ctx, s)
|
||||
// Attach the received metadata to the context.
|
||||
if len(state.mdata) > 0 {
|
||||
s.ctx = metadata.NewContext(s.ctx, state.mdata)
|
||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
||||
}
|
||||
|
||||
s.dec = &recvBufferReader{
|
||||
|
|
|
@ -45,9 +45,9 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
spb "github.com/google/go-genproto/googleapis/rpc/status"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -383,6 +383,9 @@ func newFramer(conn net.Conn) *framer {
|
|||
writer: bufio.NewWriterSize(conn, http2IOBufSize),
|
||||
}
|
||||
f.fr = http2.NewFramer(f.writer, f.reader)
|
||||
// Opt-in to Frame reuse API on framer to reduce garbage.
|
||||
// Frames aren't safe to read from after a subsequent call to ReadFrame.
|
||||
f.fr.SetReuseFrames()
|
||||
f.fr.ReadMetaHeaders = hpack.NewDecoder(http2InitHeaderTableSize, nil)
|
||||
return f
|
||||
}
|
||||
|
|
|
@ -341,6 +341,12 @@ func (s *Stream) finish(st *status.Status) {
|
|||
close(s.done)
|
||||
}
|
||||
|
||||
// GoString is implemented by Stream so context.String() won't
|
||||
// race when printing %#v.
|
||||
func (s *Stream) GoString() string {
|
||||
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
||||
}
|
||||
|
||||
// The key to save transport.Stream in the context.
|
||||
type streamKey struct{}
|
||||
|
||||
|
@ -487,6 +493,9 @@ type ClientTransport interface {
|
|||
// receives the draining signal from the server (e.g., GOAWAY frame in
|
||||
// HTTP/2).
|
||||
GoAway() <-chan struct{}
|
||||
|
||||
// GetGoAwayReason returns the reason why GoAway frame was received.
|
||||
GetGoAwayReason() GoAwayReason
|
||||
}
|
||||
|
||||
// ServerTransport is the common interface for all gRPC server-side transport
|
||||
|
@ -624,3 +633,16 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
|
|||
return i, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GoAwayReason contains the reason for the GoAway frame received.
|
||||
type GoAwayReason uint8
|
||||
|
||||
const (
|
||||
// Invalid indicates that no GoAway frame is received.
|
||||
Invalid GoAwayReason = 0
|
||||
// NoReason is the default value when GoAway frame is received.
|
||||
NoReason GoAwayReason = 1
|
||||
// TooManyPings indicates that a GoAway frame with ErrCodeEnhanceYourCalm
|
||||
// was recieved and that the debug data said "too_many_pings".
|
||||
TooManyPings GoAwayReason = 2
|
||||
)
|
||||
|
|
Загрузка…
Ссылка в новой задаче