This commit is contained in:
iamqizhao 2015-05-12 17:59:20 -07:00
Родитель 923d211a3d
Коммит 3617cd5ab3
10 изменённых файлов: 67 добавлений и 89 удалений

Просмотреть файл

@ -107,11 +107,11 @@ func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) Di
// WithHandshaker returns a DialOption that specifies a function to perform some handshaking
// with the server. It is typically used to negotiate the wire protocol version and security
// protocol with the server.
func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption {
return func(o *dialOptions) {
o.copts.Handshaker = h
}
}
//func WithHandshaker(h func(conn net.Conn) (credentials.TransportAuthenticator, error)) DialOption {
// return func(o *dialOptions) {
// o.copts.Handshaker = h
// }
//}
// Dial creates a client connection the given target.
// TODO(zhaoq): Have an option to make Dial return immediately without waiting

Просмотреть файл

@ -84,12 +84,14 @@ type ProtocolInfo struct {
// TransportAuthenticator defines the common interface for all the live gRPC wire
// protocols and supported transport security protocols (e.g., TLS, SSL).
type TransportAuthenticator interface {
// Handshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn.
Handshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
// ClientHandshake does the authentication handshake specified by the corresponding
// authentication protocol on rawConn for clients.
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
// ServerHandshake does the authentication handshake for servers.
ServerHandshake(rawConn net.Conn) (net.Conn, error)
// NewListener creates a listener which accepts connections with requested
// authentication handshake.
NewListener(lis net.Listener) net.Listener
//NewListener(lis net.Listener) net.Listener
// Info provides the ProtocolInfo of this TransportAuthenticator.
Info() ProtocolInfo
Credentials
@ -120,7 +122,7 @@ func (timeoutError) Error() string { return "credentials: Dial timed out" }
func (timeoutError) Timeout() bool { return true }
func (timeoutError) Temporary() bool { return true }
func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
// borrow some code from tls.DialWithDialer
var errChannel chan error
if timeout != 0 {
@ -152,9 +154,13 @@ func (c *tlsCreds) Handshake(addr string, rawConn net.Conn, timeout time.Duratio
return conn, nil
}
// NewListener creates a net.Listener using the information in tlsCreds.
func (c *tlsCreds) NewListener(lis net.Listener) net.Listener {
return tls.NewListener(lis, &c.config)
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) {
conn := tls.Server(rawConn, &c.config)
if err := conn.Handshake(); err != nil {
rawConn.Close()
return nil, err
}
return conn, nil
}
// NewTLS uses c to construct a TransportAuthenticator based on TLS.

Просмотреть файл

@ -225,15 +225,15 @@ func main() {
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
grpcServer := grpc.NewServer()
pb.RegisterRouteGuideServer(grpcServer, newServer())
var opts []grpc.ServerOption
if *tls {
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
grpcServer.Serve(creds.NewListener(lis))
} else {
opts = []grpc.ServerOption{grpc.Creds(creds)}
}
grpcServer := grpc.NewServer(opts...)
pb.RegisterRouteGuideServer(grpcServer, newServer())
grpcServer.Serve(lis)
}
}

Просмотреть файл

@ -15,6 +15,8 @@ creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil {
log.Fatalf("Failed to generate credentials %v", err)
}
server := grpc.NewServer(grpc.Creds(creds))
...
server.Serve(creds.NewListener(lis))
```

Просмотреть файл

@ -195,15 +195,15 @@ func main() {
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
server := grpc.NewServer()
testpb.RegisterTestServiceServer(server, &testServer{})
var opts []grpc.ServerOption
if *useTLS {
creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile)
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
server.Serve(creds.NewListener(lis))
} else {
opts = []grpc.ServerOption{grpc.Creds(creds)}
}
server := grpc.NewServer(opts...)
testpb.RegisterTestServiceServer(server, &testServer{})
server.Serve(lis)
}
}

Просмотреть файл

@ -44,6 +44,7 @@ import (
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/transport"
@ -85,7 +86,7 @@ type Server struct {
}
type options struct {
handshaker func(net.Conn) error
creds []credentials.Credentials
codec Codec
maxConcurrentStreams uint32
}
@ -93,14 +94,6 @@ type options struct {
// A ServerOption sets options.
type ServerOption func(*options)
// Handshaker returns a ServerOption that specifies a function to perform user-specified
// handshaking on the connection before it becomes usable for gRPC.
func Handshaker(f func(net.Conn) error) ServerOption {
return func(o *options) {
o.handshaker = f
}
}
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
func CustomCodec(codec Codec) ServerOption {
return func(o *options) {
@ -116,6 +109,13 @@ func MaxConcurrentStreams(n uint32) ServerOption {
}
}
// Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.Credentials) ServerOption {
return func(o *options) {
o.creds = append(o.creds, c)
}
}
// NewServer creates a gRPC server which has no service registered and has not
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
@ -195,13 +195,15 @@ func (s *Server) Serve(lis net.Listener) error {
if err != nil {
return err
}
// Perform handshaking if it is required.
if s.opts.handshaker != nil {
if err := s.opts.handshaker(c); err != nil {
grpclog.Println("grpc: Server.Serve failed to complete handshake.")
c.Close()
for _, o := range s.opts.creds {
if creds, ok := o.(credentials.TransportAuthenticator); ok {
c, err = creds.ServerHandshake(c)
if err != nil {
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
continue
}
break
}
}
s.mu.Lock()
if s.conns == nil {

Просмотреть файл

@ -284,27 +284,27 @@ func listTestEnv() []env {
}
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream))
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)}
la := ":0"
switch e.network {
case "unix":
la = "/tmp/testsock" + fmt.Sprintf("%p", s)
la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now())
syscall.Unlink(la)
}
lis, err := net.Listen(e.network, la)
if err != nil {
grpclog.Fatalf("Failed to listen: %v", err)
}
testpb.RegisterTestServiceServer(s, &testServer{})
if e.security == "tls" {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
go s.Serve(creds.NewListener(lis))
} else {
go s.Serve(lis)
sopts = append(sopts, grpc.Creds(creds))
}
s = grpc.NewServer(sopts...)
testpb.RegisterTestServiceServer(s, &testServer{})
go s.Serve(lis)
addr := la
switch e.network {
case "unix":

Просмотреть файл

@ -111,17 +111,6 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)
}
// Perform handshake if opts.Handshaker is set.
if opts.Handshaker != nil {
auth, err := opts.Handshaker(conn)
if err != nil {
return nil, ConnectionErrorf("transport: handshaking failed %v", err)
}
// Prepend the resulting authenticator to opts.AuthOptions.
if auth != nil {
opts.AuthOptions = append([]credentials.Credentials{auth}, opts.AuthOptions...)
}
}
for _, c := range opts.AuthOptions {
if ccreds, ok := c.(credentials.TransportAuthenticator); ok {
scheme = "https"
@ -132,7 +121,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
if timeout > 0 {
timeout -= time.Since(startT)
}
conn, connErr = ccreds.Handshake(addr, conn, timeout)
conn, connErr = ccreds.ClientHandshake(addr, conn, timeout)
break
}
}

Просмотреть файл

@ -316,7 +316,6 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
// ConnectOptions covers all relevant options for dialing a server.
type ConnectOptions struct {
Dialer func(string, time.Duration) (net.Conn, error)
Handshaker func(conn net.Conn) (credentials.TransportAuthenticator, error)
AuthOptions []credentials.Credentials
Timeout time.Duration
}

Просмотреть файл

@ -47,7 +47,6 @@ import (
"github.com/bradfitz/http2"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
)
@ -61,7 +60,6 @@ type server struct {
}
var (
tlsDir = "testdata/"
expectedRequest = []byte("ping")
expectedResponse = []byte("pong")
expectedRequestLarge = make([]byte, initialWindowSize*2)
@ -129,7 +127,7 @@ func (h *testStreamHandler) handleStreamMisbehave(s *Stream) {
}
// start starts server. Other goroutines should block on s.readyChan for futher operations.
func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
func (s *server) start(port int, maxStreams uint32, ht hType) {
var err error
if port == 0 {
s.lis, err = net.Listen("tcp", ":0")
@ -139,13 +137,6 @@ func (s *server) start(useTLS bool, port int, maxStreams uint32, ht hType) {
if err != nil {
grpclog.Fatalf("failed to listen: %v", err)
}
if useTLS {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
s.lis = creds.NewListener(s.lis)
}
_, p, err := net.SplitHostPort(s.lis.Addr().String())
if err != nil {
grpclog.Fatalf("failed to parse listener address: %v", err)
@ -202,27 +193,16 @@ func (s *server) stop() {
s.mu.Unlock()
}
func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
server := &server{readyChan: make(chan bool)}
go server.start(useTLS, port, maxStreams, ht)
go server.start(port, maxStreams, ht)
server.wait(t, 2*time.Second)
addr := "localhost:" + server.port
var (
ct ClientTransport
connErr error
)
if useTLS {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
t.Fatalf("Failed to create credentials %v", err)
}
dopts := ConnectOptions{
AuthOptions: []credentials.Credentials{creds},
}
ct, connErr = NewClientTransport(addr, &dopts)
} else {
ct, connErr = NewClientTransport(addr, &ConnectOptions{})
}
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}
@ -230,7 +210,7 @@ func setUp(t *testing.T, useTLS bool, port int, maxStreams uint32, ht hType) (*s
}
func TestClientSendAndReceive(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Small",
@ -270,7 +250,7 @@ func TestClientSendAndReceive(t *testing.T) {
}
func TestClientErrorNotify(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
server, ct := setUp(t, 0, math.MaxUint32, normal)
go server.stop()
// ct.reader should detect the error and activate ct.Error().
<-ct.Error()
@ -304,7 +284,7 @@ func performOneRPC(ct ClientTransport) {
}
func TestClientMix(t *testing.T) {
s, ct := setUp(t, true, 0, math.MaxUint32, normal)
s, ct := setUp(t, 0, math.MaxUint32, normal)
go func(s *server) {
time.Sleep(5 * time.Second)
s.stop()
@ -320,7 +300,7 @@ func TestClientMix(t *testing.T) {
}
func TestExceedMaxStreamsLimit(t *testing.T) {
server, ct := setUp(t, true, 0, 1, normal)
server, ct := setUp(t, 0, 1, normal)
defer func() {
ct.Close()
server.stop()
@ -368,7 +348,7 @@ func TestExceedMaxStreamsLimit(t *testing.T) {
}
func TestLargeMessage(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, normal)
server, ct := setUp(t, 0, math.MaxUint32, normal)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
@ -402,7 +382,7 @@ func TestLargeMessage(t *testing.T) {
}
func TestLargeMessageSuspension(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo.Large",
@ -424,7 +404,7 @@ func TestLargeMessageSuspension(t *testing.T) {
}
func TestServerWithMisbehavedClient(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, suspended)
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
@ -524,7 +504,7 @@ func TestServerWithMisbehavedClient(t *testing.T) {
}
func TestClientWithMisbehavedServer(t *testing.T) {
server, ct := setUp(t, true, 0, math.MaxUint32, misbehaved)
server, ct := setUp(t, 0, math.MaxUint32, misbehaved)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",