internal/resolver/unix: Implemented unix resolver. (#3890)
This commit is contained in:
Родитель
ea47aa91b3
Коммит
4be647f7f6
|
@ -201,7 +201,7 @@ func (*rlsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig,
|
||||||
if lookupService == "" {
|
if lookupService == "" {
|
||||||
return nil, fmt.Errorf("rls: empty lookup_service in service config {%+v}", string(c))
|
return nil, fmt.Errorf("rls: empty lookup_service in service config {%+v}", string(c))
|
||||||
}
|
}
|
||||||
parsedTarget := grpcutil.ParseTarget(lookupService)
|
parsedTarget := grpcutil.ParseTarget(lookupService, false)
|
||||||
if parsedTarget.Scheme == "" {
|
if parsedTarget.Scheme == "" {
|
||||||
parsedTarget.Scheme = resolver.GetDefaultScheme()
|
parsedTarget.Scheme = resolver.GetDefaultScheme()
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -48,6 +47,7 @@ import (
|
||||||
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
|
_ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin.
|
||||||
_ "google.golang.org/grpc/internal/resolver/dns" // To register dns resolver.
|
_ "google.golang.org/grpc/internal/resolver/dns" // To register dns resolver.
|
||||||
_ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver.
|
_ "google.golang.org/grpc/internal/resolver/passthrough" // To register passthrough resolver.
|
||||||
|
_ "google.golang.org/grpc/internal/resolver/unix" // To register unix resolver.
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -191,16 +191,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||||
}
|
}
|
||||||
cc.mkp = cc.dopts.copts.KeepaliveParams
|
cc.mkp = cc.dopts.copts.KeepaliveParams
|
||||||
|
|
||||||
if cc.dopts.copts.Dialer == nil {
|
|
||||||
cc.dopts.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
network, addr := parseDialTarget(addr)
|
|
||||||
return (&net.Dialer{}).DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
if cc.dopts.withProxy {
|
|
||||||
cc.dopts.copts.Dialer = newProxyDialer(cc.dopts.copts.Dialer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cc.dopts.copts.UserAgent != "" {
|
if cc.dopts.copts.UserAgent != "" {
|
||||||
cc.dopts.copts.UserAgent += " " + grpcUA
|
cc.dopts.copts.UserAgent += " " + grpcUA
|
||||||
} else {
|
} else {
|
||||||
|
@ -244,8 +234,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the resolver to use.
|
// Determine the resolver to use.
|
||||||
cc.parsedTarget = grpcutil.ParseTarget(cc.target)
|
cc.parsedTarget = grpcutil.ParseTarget(cc.target, cc.dopts.copts.Dialer != nil)
|
||||||
unixScheme := strings.HasPrefix(cc.target, "unix:")
|
|
||||||
channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
|
channelz.Infof(logger, cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
|
||||||
resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
|
resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
|
||||||
if resolverBuilder == nil {
|
if resolverBuilder == nil {
|
||||||
|
@ -268,7 +257,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
|
||||||
cc.authority = creds.Info().ServerName
|
cc.authority = creds.Info().ServerName
|
||||||
} else if cc.dopts.insecure && cc.dopts.authority != "" {
|
} else if cc.dopts.insecure && cc.dopts.authority != "" {
|
||||||
cc.authority = cc.dopts.authority
|
cc.authority = cc.dopts.authority
|
||||||
} else if unixScheme {
|
} else if strings.HasPrefix(cc.target, "unix:") {
|
||||||
cc.authority = "localhost"
|
cc.authority = "localhost"
|
||||||
} else {
|
} else {
|
||||||
// Use endpoint from "scheme://authority/endpoint" as the default
|
// Use endpoint from "scheme://authority/endpoint" as the default
|
||||||
|
|
|
@ -71,7 +71,6 @@ type dialOptions struct {
|
||||||
// we need to be able to configure this in tests.
|
// we need to be able to configure this in tests.
|
||||||
resolveNowBackoff func(int) time.Duration
|
resolveNowBackoff func(int) time.Duration
|
||||||
resolvers []resolver.Builder
|
resolvers []resolver.Builder
|
||||||
withProxy bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOption configures how we set up the connection.
|
// DialOption configures how we set up the connection.
|
||||||
|
@ -325,7 +324,7 @@ func WithInsecure() DialOption {
|
||||||
// later release.
|
// later release.
|
||||||
func WithNoProxy() DialOption {
|
func WithNoProxy() DialOption {
|
||||||
return newFuncDialOption(func(o *dialOptions) {
|
return newFuncDialOption(func(o *dialOptions) {
|
||||||
o.withProxy = false
|
o.copts.UseProxy = false
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -595,9 +594,9 @@ func defaultDialOptions() dialOptions {
|
||||||
copts: transport.ConnectOptions{
|
copts: transport.ConnectOptions{
|
||||||
WriteBufferSize: defaultWriteBufSize,
|
WriteBufferSize: defaultWriteBufSize,
|
||||||
ReadBufferSize: defaultReadBufSize,
|
ReadBufferSize: defaultReadBufSize,
|
||||||
|
UseProxy: true,
|
||||||
},
|
},
|
||||||
resolveNowBackoff: internalbackoff.DefaultExponential.Backoff,
|
resolveNowBackoff: internalbackoff.DefaultExponential.Backoff,
|
||||||
withProxy: true,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -37,19 +37,32 @@ func split2(s, sep string) (string, string, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseTarget splits target into a resolver.Target struct containing scheme,
|
// ParseTarget splits target into a resolver.Target struct containing scheme,
|
||||||
// authority and endpoint.
|
// authority and endpoint. skipUnixColonParsing indicates that the parse should
|
||||||
|
// not parse "unix:[path]" cases. This should be true in cases where a custom
|
||||||
|
// dialer is present, to prevent a behavior change.
|
||||||
//
|
//
|
||||||
// If target is not a valid scheme://authority/endpoint, it returns {Endpoint:
|
// If target is not a valid scheme://authority/endpoint, it returns {Endpoint:
|
||||||
// target}.
|
// target}.
|
||||||
func ParseTarget(target string) (ret resolver.Target) {
|
func ParseTarget(target string, skipUnixColonParsing bool) (ret resolver.Target) {
|
||||||
var ok bool
|
var ok bool
|
||||||
ret.Scheme, ret.Endpoint, ok = split2(target, "://")
|
ret.Scheme, ret.Endpoint, ok = split2(target, "://")
|
||||||
if !ok {
|
if !ok {
|
||||||
|
if strings.HasPrefix(target, "unix:") && !skipUnixColonParsing {
|
||||||
|
// Handle the "unix:[path]" case, because splitting on :// only
|
||||||
|
// handles the "unix://[/absolute/path]" case. Only handle if the
|
||||||
|
// dialer is nil, to avoid a behavior change with custom dialers.
|
||||||
|
return resolver.Target{Scheme: "unix", Endpoint: target[len("unix:"):]}
|
||||||
|
}
|
||||||
return resolver.Target{Endpoint: target}
|
return resolver.Target{Endpoint: target}
|
||||||
}
|
}
|
||||||
ret.Authority, ret.Endpoint, ok = split2(ret.Endpoint, "/")
|
ret.Authority, ret.Endpoint, ok = split2(ret.Endpoint, "/")
|
||||||
if !ok {
|
if !ok {
|
||||||
return resolver.Target{Endpoint: target}
|
return resolver.Target{Endpoint: target}
|
||||||
}
|
}
|
||||||
|
if ret.Scheme == "unix" {
|
||||||
|
// Add the "/" back in the unix case, so the unix resolver receives the
|
||||||
|
// actual endpoint.
|
||||||
|
ret.Endpoint = "/" + ret.Endpoint
|
||||||
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,17 +32,22 @@ func TestParseTarget(t *testing.T) {
|
||||||
{Scheme: "passthrough", Authority: "", Endpoint: "/unix/socket/address"},
|
{Scheme: "passthrough", Authority: "", Endpoint: "/unix/socket/address"},
|
||||||
} {
|
} {
|
||||||
str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint
|
str := test.Scheme + "://" + test.Authority + "/" + test.Endpoint
|
||||||
got := ParseTarget(str)
|
got := ParseTarget(str, false)
|
||||||
if got != test {
|
if got != test {
|
||||||
t.Errorf("ParseTarget(%q) = %+v, want %+v", str, got, test)
|
t.Errorf("ParseTarget(%q, false) = %+v, want %+v", str, got, test)
|
||||||
|
}
|
||||||
|
got = ParseTarget(str, true)
|
||||||
|
if got != test {
|
||||||
|
t.Errorf("ParseTarget(%q, true) = %+v, want %+v", str, got, test)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseTargetString(t *testing.T) {
|
func TestParseTargetString(t *testing.T) {
|
||||||
for _, test := range []struct {
|
for _, test := range []struct {
|
||||||
targetStr string
|
targetStr string
|
||||||
want resolver.Target
|
want resolver.Target
|
||||||
|
wantWithDialer resolver.Target
|
||||||
}{
|
}{
|
||||||
{targetStr: "", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}},
|
{targetStr: "", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}},
|
||||||
{targetStr: ":///", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}},
|
{targetStr: ":///", want: resolver.Target{Scheme: "", Authority: "", Endpoint: ""}},
|
||||||
|
@ -70,10 +75,24 @@ func TestParseTargetString(t *testing.T) {
|
||||||
{targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}},
|
{targetStr: "a:/b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a:/b"}},
|
||||||
{targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}},
|
{targetStr: "a//b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a//b"}},
|
||||||
{targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}},
|
{targetStr: "a://b", want: resolver.Target{Scheme: "", Authority: "", Endpoint: "a://b"}},
|
||||||
|
|
||||||
|
// Unix cases without custom dialer.
|
||||||
|
// unix:[local_path] and unix:[/absolute] have different behaviors with
|
||||||
|
// a custom dialer, to prevent behavior changes with custom dialers.
|
||||||
|
{targetStr: "unix:domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:domain"}},
|
||||||
|
{targetStr: "unix:/domain", want: resolver.Target{Scheme: "unix", Authority: "", Endpoint: "/domain"}, wantWithDialer: resolver.Target{Scheme: "", Authority: "", Endpoint: "unix:/domain"}},
|
||||||
} {
|
} {
|
||||||
got := ParseTarget(test.targetStr)
|
got := ParseTarget(test.targetStr, false)
|
||||||
if got != test.want {
|
if got != test.want {
|
||||||
t.Errorf("ParseTarget(%q) = %+v, want %+v", test.targetStr, got, test.want)
|
t.Errorf("ParseTarget(%q, false) = %+v, want %+v", test.targetStr, got, test.want)
|
||||||
|
}
|
||||||
|
wantWithDialer := test.wantWithDialer
|
||||||
|
if wantWithDialer == (resolver.Target{}) {
|
||||||
|
wantWithDialer = test.want
|
||||||
|
}
|
||||||
|
got = ParseTarget(test.targetStr, true)
|
||||||
|
if got != wantWithDialer {
|
||||||
|
t.Errorf("ParseTarget(%q, true) = %+v, want %+v", test.targetStr, got, wantWithDialer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2020 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package unix implements a resolver for unix targets.
|
||||||
|
package unix
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc/internal/transport/networktype"
|
||||||
|
"google.golang.org/grpc/resolver"
|
||||||
|
)
|
||||||
|
|
||||||
|
const scheme = "unix"
|
||||||
|
|
||||||
|
type builder struct{}
|
||||||
|
|
||||||
|
func (*builder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
|
||||||
|
cc.UpdateState(resolver.State{Addresses: []resolver.Address{networktype.Set(resolver.Address{Addr: target.Endpoint}, "unix")}})
|
||||||
|
return &nopResolver{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*builder) Scheme() string {
|
||||||
|
return scheme
|
||||||
|
}
|
||||||
|
|
||||||
|
type nopResolver struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*nopResolver) ResolveNow(resolver.ResolveNowOptions) {}
|
||||||
|
|
||||||
|
func (*nopResolver) Close() {}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
resolver.Register(&builder{})
|
||||||
|
}
|
|
@ -33,6 +33,7 @@ import (
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
"google.golang.org/grpc/internal/grpcutil"
|
"google.golang.org/grpc/internal/grpcutil"
|
||||||
|
"google.golang.org/grpc/internal/transport/networktype"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
@ -137,11 +138,18 @@ type http2Client struct {
|
||||||
connectionID uint64
|
connectionID uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
|
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) {
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
return fn(ctx, addr)
|
return fn(ctx, addr.Addr)
|
||||||
}
|
}
|
||||||
return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
|
networkType := "tcp"
|
||||||
|
if n, ok := networktype.Get(addr); ok {
|
||||||
|
networkType = n
|
||||||
|
}
|
||||||
|
if networkType == "tcp" && useProxy {
|
||||||
|
return proxyDial(ctx, addr.Addr, grpcUA)
|
||||||
|
}
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, networkType, addr.Addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isTemporary(err error) bool {
|
func isTemporary(err error) bool {
|
||||||
|
@ -172,7 +180,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, err := dial(connectCtx, opts.Dialer, addr.Addr)
|
conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if opts.FailOnNonTempDialError {
|
if opts.FailOnNonTempDialError {
|
||||||
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
|
return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err)
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2020 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package networktype declares the network type to be used in the default
|
||||||
|
// dailer. Attribute of a resolver.Address.
|
||||||
|
package networktype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc/resolver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// keyType is the key to use for storing State in Attributes.
|
||||||
|
type keyType string
|
||||||
|
|
||||||
|
const key = keyType("grpc.internal.transport.networktype")
|
||||||
|
|
||||||
|
// Set returns a copy of the provided address with attributes containing networkType.
|
||||||
|
func Set(address resolver.Address, networkType string) resolver.Address {
|
||||||
|
address.Attributes = address.Attributes.WithValues(key, networkType)
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the network type in the resolver.Address and true, or "", false
|
||||||
|
// if not present.
|
||||||
|
func Get(address resolver.Address) (string, bool) {
|
||||||
|
v := address.Attributes.Value(key)
|
||||||
|
if v == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return v.(string), true
|
||||||
|
}
|
|
@ -16,13 +16,12 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package grpc
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -34,8 +33,6 @@ import (
|
||||||
const proxyAuthHeaderKey = "Proxy-Authorization"
|
const proxyAuthHeaderKey = "Proxy-Authorization"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// errDisabled indicates that proxy is disabled for the address.
|
|
||||||
errDisabled = errors.New("proxy is disabled for the address")
|
|
||||||
// The following variable will be overwritten in the tests.
|
// The following variable will be overwritten in the tests.
|
||||||
httpProxyFromEnvironment = http.ProxyFromEnvironment
|
httpProxyFromEnvironment = http.ProxyFromEnvironment
|
||||||
)
|
)
|
||||||
|
@ -51,9 +48,6 @@ func mapAddress(ctx context.Context, address string) (*url.URL, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if url == nil {
|
|
||||||
return nil, errDisabled
|
|
||||||
}
|
|
||||||
return url, nil
|
return url, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +70,7 @@ func basicAuth(username, password string) string {
|
||||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
|
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -115,32 +109,28 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
|
||||||
return &bufConn{Conn: conn, r: r}, nil
|
return &bufConn{Conn: conn, r: r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newProxyDialer returns a dialer that connects to proxy first if necessary.
|
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
|
||||||
// The returned dialer checks if a proxy is necessary, dial to the proxy with the
|
// is necessary, dials, does the HTTP CONNECT handshake, and returns the
|
||||||
// provided dialer, does HTTP CONNECT handshake and returns the connection.
|
// connection.
|
||||||
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
|
func proxyDial(ctx context.Context, addr string, grpcUA string) (conn net.Conn, err error) {
|
||||||
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
|
newAddr := addr
|
||||||
var newAddr string
|
proxyURL, err := mapAddress(ctx, addr)
|
||||||
proxyURL, err := mapAddress(ctx, addr)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
if err != errDisabled {
|
}
|
||||||
return nil, err
|
if proxyURL != nil {
|
||||||
}
|
newAddr = proxyURL.Host
|
||||||
newAddr = addr
|
}
|
||||||
} else {
|
|
||||||
newAddr = proxyURL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err = dialer(ctx, newAddr)
|
conn, err = (&net.Dialer{}).DialContext(ctx, "tcp", newAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
|
||||||
}
|
|
||||||
if proxyURL != nil {
|
|
||||||
// proxy is disabled if proxyURL is nil.
|
|
||||||
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if proxyURL != nil {
|
||||||
|
// proxy is disabled if proxyURL is nil.
|
||||||
|
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
|
func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error {
|
|
@ -18,7 +18,7 @@
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package grpc
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
@ -138,15 +138,9 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
|
||||||
defer overwrite(hpfe)()
|
defer overwrite(hpfe)()
|
||||||
|
|
||||||
// Dial to proxy server.
|
// Dial to proxy server.
|
||||||
dialer := newProxyDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
|
||||||
return net.DialTimeout("tcp", addr, time.Until(deadline))
|
|
||||||
}
|
|
||||||
return net.Dial("tcp", addr)
|
|
||||||
})
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
c, err := dialer(ctx, blis.Addr().String())
|
c, err := proxyDial(ctx, blis.Addr().String(), "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("http connect Dial failed: %v", err)
|
t.Fatalf("http connect Dial failed: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -173,9 +167,6 @@ func (s) TestHTTPConnect(t *testing.T) {
|
||||||
if req.Method != http.MethodConnect {
|
if req.Method != http.MethodConnect {
|
||||||
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
||||||
}
|
}
|
||||||
if req.UserAgent() != grpcUA {
|
|
||||||
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -195,9 +186,6 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
|
||||||
if req.Method != http.MethodConnect {
|
if req.Method != http.MethodConnect {
|
||||||
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
|
||||||
}
|
}
|
||||||
if req.UserAgent() != grpcUA {
|
|
||||||
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
|
|
||||||
}
|
|
||||||
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
|
||||||
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
|
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
|
||||||
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
|
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
|
|
@ -569,6 +569,8 @@ type ConnectOptions struct {
|
||||||
ChannelzParentID int64
|
ChannelzParentID int64
|
||||||
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
|
// MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received.
|
||||||
MaxHeaderListSize *uint32
|
MaxHeaderListSize *uint32
|
||||||
|
// UseProxy specifies if a proxy should be used.
|
||||||
|
UseProxy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||||
|
|
|
@ -45,9 +45,6 @@ func (s) TestDialParseTargetUnknownScheme(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"/unix/socket/address", "/unix/socket/address"},
|
{"/unix/socket/address", "/unix/socket/address"},
|
||||||
|
|
||||||
// Special test for "unix:///".
|
|
||||||
{"unix:///unix/socket/address", "unix:///unix/socket/address"},
|
|
||||||
|
|
||||||
// For known scheme.
|
// For known scheme.
|
||||||
{"passthrough://a.server.com/google.com", "google.com"},
|
{"passthrough://a.server.com/google.com", "google.com"},
|
||||||
} {
|
} {
|
||||||
|
|
35
rpc_util.go
35
rpc_util.go
|
@ -27,7 +27,6 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math"
|
"math"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -872,40 +871,6 @@ func setCallInfoCodec(c *callInfo) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseDialTarget returns the network and address to pass to dialer
|
|
||||||
func parseDialTarget(target string) (net string, addr string) {
|
|
||||||
net = "tcp"
|
|
||||||
|
|
||||||
m1 := strings.Index(target, ":")
|
|
||||||
m2 := strings.Index(target, ":/")
|
|
||||||
|
|
||||||
// handle unix:addr which will fail with url.Parse
|
|
||||||
if m1 >= 0 && m2 < 0 {
|
|
||||||
if n := target[0:m1]; n == "unix" {
|
|
||||||
net = n
|
|
||||||
addr = target[m1+1:]
|
|
||||||
return net, addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if m2 >= 0 {
|
|
||||||
t, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
return net, target
|
|
||||||
}
|
|
||||||
scheme := t.Scheme
|
|
||||||
addr = t.Path
|
|
||||||
if scheme == "unix" {
|
|
||||||
net = scheme
|
|
||||||
if addr == "" {
|
|
||||||
addr = t.Host
|
|
||||||
}
|
|
||||||
return net, addr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return net, target
|
|
||||||
}
|
|
||||||
|
|
||||||
// channelzData is used to store channelz related data for ClientConn, addrConn and Server.
|
// channelzData is used to store channelz related data for ClientConn, addrConn and Server.
|
||||||
// These fields cannot be embedded in the original structs (e.g. ClientConn), since to do atomic
|
// These fields cannot be embedded in the original structs (e.g. ClientConn), since to do atomic
|
||||||
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
|
// operation on int64 variable on 32-bit machine, user is responsible to enforce memory alignment.
|
||||||
|
|
|
@ -191,27 +191,6 @@ func (s) TestToRPCErr(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestParseDialTarget(t *testing.T) {
|
|
||||||
for _, test := range []struct {
|
|
||||||
target, wantNet, wantAddr string
|
|
||||||
}{
|
|
||||||
{"unix:etcd:0", "unix", "etcd:0"},
|
|
||||||
{"unix:///tmp/unix-3", "unix", "/tmp/unix-3"},
|
|
||||||
{"unix://domain", "unix", "domain"},
|
|
||||||
{"unix://etcd:0", "unix", "etcd:0"},
|
|
||||||
{"unix:///etcd:0", "unix", "/etcd:0"},
|
|
||||||
{"passthrough://unix://domain", "tcp", "passthrough://unix://domain"},
|
|
||||||
{"https://google.com:443", "tcp", "https://google.com:443"},
|
|
||||||
{"dns:///google.com", "tcp", "dns:///google.com"},
|
|
||||||
{"/unix/socket/address", "tcp", "/unix/socket/address"},
|
|
||||||
} {
|
|
||||||
gotNet, gotAddr := parseDialTarget(test.target)
|
|
||||||
if gotNet != test.wantNet || gotAddr != test.wantAddr {
|
|
||||||
t.Errorf("parseDialTarget(%q) = %s, %s want %s, %s", test.target, gotNet, gotAddr, test.wantNet, test.wantAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
// bmEncode benchmarks encoding a Protocol Buffer message containing mSize
|
||||||
// bytes.
|
// bytes.
|
||||||
func bmEncode(b *testing.B, mSize int) {
|
func bmEncode(b *testing.B, mSize int) {
|
||||||
|
|
|
@ -21,17 +21,19 @@ package test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
testpb "google.golang.org/grpc/test/grpc_testing"
|
testpb "google.golang.org/grpc/test/grpc_testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func runUnixTest(t *testing.T, address, target, expectedAuthority string) {
|
func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) {
|
||||||
if err := os.RemoveAll(address); err != nil {
|
if err := os.RemoveAll(address); err != nil {
|
||||||
t.Fatalf("Error removing socket file %v: %v\n", address, err)
|
t.Fatalf("Error removing socket file %v: %v\n", address, err)
|
||||||
}
|
}
|
||||||
|
@ -57,7 +59,11 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string) {
|
||||||
address: address,
|
address: address,
|
||||||
target: target,
|
target: target,
|
||||||
}
|
}
|
||||||
if err := us.Start(nil); err != nil {
|
opts := []grpc.DialOption{}
|
||||||
|
if dialer != nil {
|
||||||
|
opts = append(opts, grpc.WithContextDialer(dialer))
|
||||||
|
}
|
||||||
|
if err := us.Start(nil, opts...); err != nil {
|
||||||
t.Fatalf("Error starting endpoint server: %v", err)
|
t.Fatalf("Error starting endpoint server: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -70,6 +76,8 @@ func runUnixTest(t *testing.T, address, target, expectedAuthority string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUnix does end to end tests with the various supported unix target
|
||||||
|
// formats, ensuring that the authority is set to localhost in every case.
|
||||||
func (s) TestUnix(t *testing.T) {
|
func (s) TestUnix(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -78,19 +86,19 @@ func (s) TestUnix(t *testing.T) {
|
||||||
authority string
|
authority string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Unix1",
|
name: "UnixRelative",
|
||||||
address: "sock.sock",
|
address: "sock.sock",
|
||||||
target: "unix:sock.sock",
|
target: "unix:sock.sock",
|
||||||
authority: "localhost",
|
authority: "localhost",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Unix2",
|
name: "UnixAbsolute",
|
||||||
address: "/tmp/sock.sock",
|
address: "/tmp/sock.sock",
|
||||||
target: "unix:/tmp/sock.sock",
|
target: "unix:/tmp/sock.sock",
|
||||||
authority: "localhost",
|
authority: "localhost",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Unix3",
|
name: "UnixAbsoluteAlternate",
|
||||||
address: "/tmp/sock.sock",
|
address: "/tmp/sock.sock",
|
||||||
target: "unix:///tmp/sock.sock",
|
target: "unix:///tmp/sock.sock",
|
||||||
authority: "localhost",
|
authority: "localhost",
|
||||||
|
@ -98,7 +106,44 @@ func (s) TestUnix(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
runUnixTest(t, test.address, test.target, test.authority)
|
runUnixTest(t, test.address, test.target, test.authority, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnixCustomDialer does end to end tests with various supported unix target
|
||||||
|
// formats, ensuring that the target sent to the dialer does NOT have the
|
||||||
|
// "unix:" prefix stripped.
|
||||||
|
func (s) TestUnixCustomDialer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
address string
|
||||||
|
target string
|
||||||
|
authority string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "UnixRelative",
|
||||||
|
address: "sock.sock",
|
||||||
|
target: "unix:sock.sock",
|
||||||
|
authority: "localhost",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UnixAbsolute",
|
||||||
|
address: "/tmp/sock.sock",
|
||||||
|
target: "unix:/tmp/sock.sock",
|
||||||
|
authority: "localhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
dialer := func(ctx context.Context, address string) (net.Conn, error) {
|
||||||
|
if address != test.target {
|
||||||
|
return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.target, address)
|
||||||
|
}
|
||||||
|
address = address[len("unix:"):]
|
||||||
|
return (&net.Dialer{}).DialContext(ctx, "unix", address)
|
||||||
|
}
|
||||||
|
runUnixTest(t, test.address, test.target, test.authority, dialer)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Загрузка…
Ссылка в новой задаче