revert handshaker changes
This commit is contained in:
Родитель
923d211a3d
Коммит
3617cd5ab3
|
@ -107,11 +107,11 @@ func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) Di
|
|||
// WithHandshaker returns a DialOption that specifies a function to perform some handshaking
|
||||
// with the server. It is typically used to negotiate the wire protocol version and security
|
||||
// protocol with the server.
|
||||
func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.Handshaker = h
|
||||
}
|
||||
}
|
||||
//func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption {
|
||||
// return func(o *dialOptions) {
|
||||
// o.copts.Handshaker = h
|
||||
// }
|
||||
//}
|
||||
|
||||
// Dial creates a client connection the given target.
|
||||
// TODO(zhaoq): Have an option to make Dial return immediately without waiting
|
||||
|
|
|
@ -84,12 +84,14 @@ type ProtocolInfo struct {
|
|||
// TransportAuthenticator defines the common interface for all the live gRPC wire
|
||||
// protocols and supported transport security protocols (e.g., TLS, SSL).
|
||||
type TransportAuthenticator interface {
|
||||
// Handshake does the authentication handshake specified by the corresponding
|
||||
// authentication protocol on rawConn.
|
||||
Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
|
||||
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||
// authentication protocol on rawConn for clients.
|
||||
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
|
||||
// ServerHandshake does the authentication handshake for servers.
|
||||
ServerHandshake(rawConn net.Conn) (net.Conn, error)
|
||||
// NewListener creates a listener which accepts connections with requested
|
||||
// authentication handshake.
|
||||
NewListener(lis net.Listener) net.Listener
|
||||
//NewListener(lis net.Listener) net.Listener
|
||||
// Info provides the ProtocolInfo of this TransportAuthenticator.
|
||||
Info() ProtocolInfo
|
||||
Credentials
|
||||
|
@ -120,7 +122,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
|||
func (timeoutError) Timeout() bool { return true }
|
||||
func (timeoutError) Temporary() bool { return true }
|
||||
|
||||
func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
|
||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
|
||||
// borrow some code from tls.DialWithDialer
|
||||
var errChannel chan error
|
||||
if timeout != 0 {
|
||||
|
@ -152,9 +154,13 @@ func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duratio
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
// NewListener creates a net.Listener using the information in tlsCreds.
|
||||
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
|
||||
return tls.NewListener(lis, &c.config)
|
||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) {
|
||||
conn := tls.Server(rawConn, &c.config)
|
||||
if err := conn.Handshake(); err != nil {
|
||||
rawConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
|
||||
|
|
|
@ -225,15 +225,15 @@ func main() {
|
|||
if err != nil {
|
||||
grpclog.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
grpcServer := grpc.NewServer()
|
||||
pb.RegisterRouteGuideServer(grpcServer, newServer())
|
||||
var opts []grpc.ServerOption
|
||||
if *tls {
|
||||
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
grpcServer.Serve(creds.NewListener(lis))
|
||||
} else {
|
||||
opts = []grpc.ServerOption{grpc.Creds(creds)}
|
||||
}
|
||||
grpcServer := grpc.NewServer(opts...)
|
||||
pb.RegisterRouteGuideServer(grpcServer, newServer())
|
||||
grpcServer.Serve(lis)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,8 @@ creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
|
|||
if err != nil {
|
||||
log.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
server := grpc.NewServer(grpc.Creds(creds))
|
||||
...
|
||||
server.Serve(creds.NewListener(lis))
|
||||
```
|
||||
|
||||
|
|
|
@ -195,15 +195,15 @@ func main() {
|
|||
if err != nil {
|
||||
grpclog.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
server := grpc.NewServer()
|
||||
testpb.RegisterTestServiceServer(server, &testServer{})
|
||||
var opts []grpc.ServerOption
|
||||
if *useTLS {
|
||||
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
server.Serve(creds.NewListener(lis))
|
||||
} else {
|
||||
opts = []grpc.ServerOption{grpc.Creds(creds)}
|
||||
}
|
||||
server := grpc.NewServer(opts...)
|
||||
testpb.RegisterTestServiceServer(server, &testServer{})
|
||||
server.Serve(lis)
|
||||
}
|
||||
}
|
||||
|
|
30
server.go
30
server.go
|
@ -44,6 +44,7 @@ import (
|
|||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/transport"
|
||||
|
@ -85,7 +86,7 @@ type Server struct {
|
|||
}
|
||||
|
||||
type options struct {
|
||||
handshaker func(net.Conn) error
|
||||
creds []credentials.Credentials
|
||||
codec Codec
|
||||
maxConcurrentStreams uint32
|
||||
}
|
||||
|
@ -93,14 +94,6 @@ type options struct {
|
|||
// A ServerOption sets options.
|
||||
type ServerOption func(*options)
|
||||
|
||||
// Handshaker returns a ServerOption that specifies a function to perform user-specified
|
||||
// handshaking on the connection before it becomes usable for gRPC.
|
||||
func Handshaker(f func(net.Conn) error) ServerOption {
|
||||
return func(o *options) {
|
||||
o.handshaker = f
|
||||
}
|
||||
}
|
||||
|
||||
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
|
||||
func CustomCodec(codec Codec) ServerOption {
|
||||
return func(o *options) {
|
||||
|
@ -116,6 +109,13 @@ func MaxConcurrentStreams(n uint32) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// Creds returns a ServerOption that sets credentials for server connections.
|
||||
func Creds(c credentials.Credentials) ServerOption {
|
||||
return func(o *options) {
|
||||
o.creds = append(o.creds, c)
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer creates a gRPC server which has no service registered and has not
|
||||
// started to accept requests yet.
|
||||
func NewServer(opt ...ServerOption) *Server {
|
||||
|
@ -195,13 +195,15 @@ func (s *Server) Serve(lis net.Listener) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Perform handshaking if it is required.
|
||||
if s.opts.handshaker != nil {
|
||||
if err := s.opts.handshaker(c); err != nil {
|
||||
grpclog.Println("grpc: Server.Serve failed to complete handshake.")
|
||||
c.Close()
|
||||
for _, o := range s.opts.creds {
|
||||
if creds, ok := o.(credentials.TransportAuthenticator); ok {
|
||||
c, err = creds.ServerHandshake(c)
|
||||
if err != nil {
|
||||
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.conns == nil {
|
||||
|
|
|
@ -284,27 +284,27 @@ func listTestEnv() []env {
|
|||
}
|
||||
|
||||
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
|
||||
s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream))
|
||||
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)}
|
||||
la := ":0"
|
||||
switch e.network {
|
||||
case "unix":
|
||||
la = "/tmp/testsock" + fmt.Sprintf("%p", s)
|
||||
la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now())
|
||||
syscall.Unlink(la)
|
||||
}
|
||||
lis, err := net.Listen(e.network, la)
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
testpb.RegisterTestServiceServer(s, &testServer{})
|
||||
if e.security == "tls" {
|
||||
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
go s.Serve(creds.NewListener(lis))
|
||||
} else {
|
||||
go s.Serve(lis)
|
||||
sopts = append(sopts, grpc.Creds(creds))
|
||||
}
|
||||
s = grpc.NewServer(sopts...)
|
||||
testpb.RegisterTestServiceServer(s, &testServer{})
|
||||
go s.Serve(lis)
|
||||
addr := la
|
||||
switch e.network {
|
||||
case "unix":
|
||||
|
|
|
@ -111,17 +111,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
}
|
||||
// Perform handshake if opts.Handshaker is set.
|
||||
if opts.Handshaker != nil {
|
||||
auth, err := opts.Handshaker(conn)
|
||||
if err != nil {
|
||||
return nil, ConnectionErrorf("transport: handshaking failed %v", err)
|
||||
}
|
||||
// Prepend the resulting authenticator to opts.AuthOptions.
|
||||
if auth != nil {
|
||||
opts.AuthOptions = append([]credentials.Credentials{auth}, opts.AuthOptions...)
|
||||
}
|
||||
}
|
||||
for _, c := range opts.AuthOptions {
|
||||
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
|
||||
scheme = "https"
|
||||
|
@ -132,7 +121,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
|
|||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, connErr = ccreds.Handshake(addr, conn, timeout)
|
||||
conn, connErr = ccreds.ClientHandshake(addr, conn, timeout)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
@ -316,7 +316,6 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
|
|||
// ConnectOptions covers all relevant options for dialing a server.
|
||||
type ConnectOptions struct {
|
||||
Dialer func(string, time.Duration) (net.Conn, error)
|
||||
Handshaker func(conn net.Conn) (credentials.TransportAuthenticator, error)
|
||||
AuthOptions []credentials.Credentials
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
|
|
@ -47,7 +47,6 @@ import (
|
|||
"github.com/bradfitz/http2"
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
)
|
||||
|
||||
|
@ -61,7 +60,6 @@ type server struct {
|
|||
}
|
||||
|
||||
var (
|
||||
tlsDir = "testdata/"
|
||||
expectedRequest = []byte("ping")
|
||||
expectedResponse = []byte("pong")
|
||||
expectedRequestLarge = make([]byte, initialWindowSize*2)
|
||||
|
@ -129,7 +127,7 @@ func (h *testStreamHandler) handleStreamMisbehave(s *Stream) {
|
|||
}
|
||||
|
||||
// start starts server. Other goroutines should block on s.readyChan for futher operations.
|
||||
func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
|
||||
func (s *server) start(port int, maxStreams uint32, ht hType) {
|
||||
var err error
|
||||
if port == 0 {
|
||||
s.lis, err = net.Listen("tcp", ":0")
|
||||
|
@ -139,13 +137,6 @@ func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
|
|||
if err != nil {
|
||||
grpclog.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
if useTLS {
|
||||
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
|
||||
if err != nil {
|
||||
grpclog.Fatalf("Failed to generate credentials %v", err)
|
||||
}
|
||||
s.lis = creds.NewListener(s.lis)
|
||||
}
|
||||
_, p, err := net.SplitHostPort(s.lis.Addr().String())
|
||||
if err != nil {
|
||||
grpclog.Fatalf("failed to parse listener address: %v", err)
|
||||
|
@ -202,27 +193,16 @@ func (s *server) stop() {
|
|||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
|
||||
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
|
||||
server := &server{readyChan: make(chan bool)}
|
||||
go server.start(useTLS, port, maxStreams, ht)
|
||||
go server.start(port, maxStreams, ht)
|
||||
server.wait(t, 2*time.Second)
|
||||
addr := "localhost:" + server.port
|
||||
var (
|
||||
ct ClientTransport
|
||||
connErr error
|
||||
)
|
||||
if useTLS {
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create credentials %v", err)
|
||||
}
|
||||
dopts := ConnectOptions{
|
||||
AuthOptions: []credentials.Credentials{creds},
|
||||
}
|
||||
ct, connErr = NewClientTransport(addr, &dopts)
|
||||
} else {
|
||||
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
|
||||
}
|
||||
if connErr != nil {
|
||||
t.Fatalf("failed to create transport: %v", connErr)
|
||||
}
|
||||
|
@ -230,7 +210,7 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*s
|
|||
}
|
||||
|
||||
func TestClientSendAndReceive(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, normal)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo.Small",
|
||||
|
@ -270,7 +250,7 @@ func TestClientSendAndReceive(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientErrorNotify(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, normal)
|
||||
go server.stop()
|
||||
// ct.reader should detect the error and activate ct.Error().
|
||||
<-ct.Error()
|
||||
|
@ -304,7 +284,7 @@ func performOneRPC(ct ClientTransport) {
|
|||
}
|
||||
|
||||
func TestClientMix(t *testing.T) {
|
||||
s, ct := setUp(t, true, 0, math.MaxUint32, normal)
|
||||
s, ct := setUp(t, 0, math.MaxUint32, normal)
|
||||
go func(s *server) {
|
||||
time.Sleep(5 * time.Second)
|
||||
s.stop()
|
||||
|
@ -320,7 +300,7 @@ func TestClientMix(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestExceedMaxStreamsLimit(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, 1, normal)
|
||||
server, ct := setUp(t, 0, 1, normal)
|
||||
defer func() {
|
||||
ct.Close()
|
||||
server.stop()
|
||||
|
@ -368,7 +348,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLargeMessage(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, normal)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo.Large",
|
||||
|
@ -402,7 +382,7 @@ func TestLargeMessage(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLargeMessageSuspension(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, suspended)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo.Large",
|
||||
|
@ -424,7 +404,7 @@ func TestLargeMessageSuspension(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerWithMisbehavedClient(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, suspended)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo",
|
||||
|
@ -524,7 +504,7 @@ func TestServerWithMisbehavedClient(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClientWithMisbehavedServer(t *testing.T) {
|
||||
server, ct := setUp(t, true, 0, math.MaxUint32, misbehaved)
|
||||
server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
|
||||
callHdr := &CallHdr{
|
||||
Host: "localhost",
|
||||
Method: "foo",
|
||||
|
|
Загрузка…
Ссылка в новой задаче