diff --git a/client/client.go b/client/client.go index a62d859..bc62387 100644 --- a/client/client.go +++ b/client/client.go @@ -1,15 +1,12 @@ package client import ( - "crypto/tls" "fmt" - "net" "net/http" "net/url" "os" "path/filepath" "strings" - "time" "github.com/docker/go-connections/tlsconfig" ) @@ -23,12 +20,8 @@ type Client struct { addr string // basePath holds the path to prepend to the requests basePath string - // scheme holds the scheme of the client i.e. https. - scheme string - // tlsConfig holds the tls configuration to use in hijacked requests. - tlsConfig *tls.Config - // httpClient holds the client transport instance. Exported to keep the old code running. - httpClient *http.Client + // apiTransport holds information about the http transport + apiTransport *apiTransport // version of the server to talk to. version string // custom http headers configured by users @@ -41,7 +34,7 @@ type Client struct { // Use DOCKER_CERT_PATH to load the tls certificates from. // Use DOCKER_TLS_VERIFY to enable or disable TLS verification, off by default. func NewEnvClient() (*Client, error) { - var transport *http.Transport + var client *http.Client if dockerCertPath := os.Getenv("DOCKER_CERT_PATH"); dockerCertPath != "" { options := tlsconfig.Options{ CAFile: filepath.Join(dockerCertPath, "ca.pem"), @@ -54,8 +47,10 @@ func NewEnvClient() (*Client, error) { return nil, err } - transport = &http.Transport{ - TLSClientConfig: tlsc, + client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsc, + }, } } @@ -63,42 +58,29 @@ func NewEnvClient() (*Client, error) { if host == "" { host = DefaultDockerHost } - return NewClient(host, os.Getenv("DOCKER_API_VERSION"), transport, nil) + return NewClient(host, os.Getenv("DOCKER_API_VERSION"), client, nil) } // NewClient initializes a new API client for the given host and API version. // It won't send any version information if the version number is empty. -// It uses the transport to create a new http client. +// It uses the given http client as transport. // It also initializes the custom http headers to add to each request. -func NewClient(host string, version string, transport *http.Transport, httpHeaders map[string]string) (*Client, error) { - var ( - basePath string - scheme = "http" - protoAddrParts = strings.SplitN(host, "://", 2) - proto, addr = protoAddrParts[0], protoAddrParts[1] - ) - - if proto == "tcp" { - parsed, err := url.Parse("tcp://" + addr) - if err != nil { - return nil, err - } - addr = parsed.Host - basePath = parsed.Path +func NewClient(host string, version string, client *http.Client, httpHeaders map[string]string) (*Client, error) { + proto, addr, basePath, err := parseHost(host) + if err != nil { + return nil, err } - transport = configureTransport(transport, proto, addr) - if transport.TLSClientConfig != nil { - scheme = "https" + apiTransport, err := newAPITransport(proto, addr, client) + if err != nil { + return nil, err } return &Client{ proto: proto, addr: addr, basePath: basePath, - scheme: scheme, - tlsConfig: transport.TLSClientConfig, - httpClient: &http.Client{Transport: transport}, + apiTransport: apiTransport, version: version, customHTTPHeaders: httpHeaders, }, nil @@ -127,23 +109,22 @@ func (cli *Client) ClientVersion() string { return cli.version } -func configureTransport(tr *http.Transport, proto, addr string) *http.Transport { - if tr == nil { - tr = &http.Transport{} +// parseHost verifies that the given host strings is valid. +func parseHost(host string) (string, string, string, error) { + protoAddrParts := strings.SplitN(host, "://", 2) + if len(protoAddrParts) == 1 { + return "", "", "", fmt.Errorf("unable to parse docker host `%s`", host) } - // Why 32? See https://github.com/docker/docker/pull/8035. - timeout := 32 * time.Second - if proto == "unix" { - // No need for compression in local communications. - tr.DisableCompression = true - tr.Dial = func(_, _ string) (net.Conn, error) { - return net.DialTimeout(proto, addr, timeout) + var basePath string + proto, addr := protoAddrParts[0], protoAddrParts[1] + if proto == "tcp" { + parsed, err := url.Parse("tcp://" + addr) + if err != nil { + return "", "", "", err } - } else { - tr.Proxy = http.ProxyFromEnvironment - tr.Dial = (&net.Dialer{Timeout: timeout}).Dial + addr = parsed.Host + basePath = parsed.Path } - - return tr + return proto, addr, basePath, nil } diff --git a/client/client_test.go b/client/client_test.go index b758b8d..1485f60 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -34,3 +34,38 @@ func TestGetAPIPath(t *testing.T) { } } } + +func TestParseHost(t *testing.T) { + cases := []struct { + host string + proto string + addr string + base string + err bool + }{ + {"", "", "", "", true}, + {"foobar", "", "", "", true}, + {"foo://bar", "foo", "bar", "", false}, + {"tcp://localhost:2476", "tcp", "localhost:2476", "", false}, + {"tcp://localhost:2476/path", "tcp", "localhost:2476", "/path", false}, + } + + for _, cs := range cases { + p, a, b, e := parseHost(cs.host) + if cs.err && e == nil { + t.Fatalf("expected error, got nil") + } + if !cs.err && e != nil { + t.Fatal(e) + } + if cs.proto != p { + t.Fatalf("expected proto %s, got %s", cs.proto, p) + } + if cs.addr != a { + t.Fatalf("expected addr %s, got %s", cs.addr, a) + } + if cs.base != b { + t.Fatalf("expected base %s, got %s", cs.base, b) + } + } +} diff --git a/client/hijack.go b/client/hijack.go index 5835d8c..e863613 100644 --- a/client/hijack.go +++ b/client/hijack.go @@ -44,7 +44,7 @@ func (cli *Client) postHijacked(path string, query url.Values, body interface{}, req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "tcp") - conn, err := dial(cli.proto, cli.addr, cli.tlsConfig) + conn, err := dial(cli.proto, cli.addr, cli.apiTransport.TLSConfig()) if err != nil { if strings.Contains(err.Error(), "connection refused") { return types.HijackedResponse{}, fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker daemon' running on this host?") diff --git a/client/request.go b/client/request.go index cd8c317..f1c0a35 100644 --- a/client/request.go +++ b/client/request.go @@ -82,13 +82,13 @@ func (cli *Client) sendClientRequest(method, path string, query url.Values, body req, err := cli.newRequest(method, path, query, body, headers) req.URL.Host = cli.addr - req.URL.Scheme = cli.scheme + req.URL.Scheme = cli.apiTransport.Scheme() if expectedPayload && req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "text/plain") } - resp, err := cli.httpClient.Do(req) + resp, err := cli.apiTransport.HTTPClient().Do(req) if resp != nil { serverResp.statusCode = resp.StatusCode } @@ -98,10 +98,10 @@ func (cli *Client) sendClientRequest(method, path string, query url.Values, body return serverResp, ErrConnectionFailed } - if cli.scheme == "http" && strings.Contains(err.Error(), "malformed HTTP response") { + if !cli.apiTransport.IsTLS() && strings.Contains(err.Error(), "malformed HTTP response") { return serverResp, fmt.Errorf("%v.\n* Are you trying to connect to a TLS-enabled daemon without TLS?", err) } - if cli.scheme == "https" && strings.Contains(err.Error(), "remote error: bad certificate") { + if cli.apiTransport.IsTLS() && strings.Contains(err.Error(), "remote error: bad certificate") { return serverResp, fmt.Errorf("The server probably has client authentication (--tlsverify) enabled. Please check your TLS client certification settings: %v", err) } diff --git a/client/transport.go b/client/transport.go new file mode 100644 index 0000000..509b970 --- /dev/null +++ b/client/transport.go @@ -0,0 +1,91 @@ +package client + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "time" +) + +// apiTransport holds information about the http transport to connect with the API. +type apiTransport struct { + // httpClient holds the client transport instance. Exported to keep the old code running. + httpClient *http.Client + // scheme holds the scheme of the client i.e. https. + scheme string + // tlsConfig holds the tls configuration to use in hijacked requests. + tlsConfig *tls.Config +} + +// newAPITransport creates a new transport based on the provided proto, address and client. +// It uses Docker's default http transport configuration if the client is nil. +// It does not modify the client's transport if it's not nil. +func newAPITransport(proto, addr string, client *http.Client) (*apiTransport, error) { + scheme := "http" + var transport *http.Transport + + if client != nil { + tr, ok := client.Transport.(*http.Transport) + if !ok { + return nil, fmt.Errorf("unable to verify TLS configuration, invalid transport %v", client.Transport) + } + transport = tr + } else { + transport = defaultTransport(proto, addr) + client = &http.Client{ + Transport: transport, + } + } + + if transport.TLSClientConfig != nil { + scheme = "https" + } + + return &apiTransport{ + httpClient: client, + scheme: scheme, + tlsConfig: transport.TLSClientConfig, + }, nil +} + +// HTTPClient returns the http client. +func (a *apiTransport) HTTPClient() *http.Client { + return a.httpClient +} + +// Scheme returns the api scheme. +func (a *apiTransport) Scheme() string { + return a.scheme +} + +// TLSConfig returns the TLS configuration. +func (a *apiTransport) TLSConfig() *tls.Config { + return a.tlsConfig +} + +// IsTLS returns true if there is a TLS configuration. +func (a *apiTransport) IsTLS() bool { + return a.tlsConfig != nil +} + +// defaultTransport creates a new http.Transport with Docker's +// default transport configuration. +func defaultTransport(proto, addr string) *http.Transport { + tr := new(http.Transport) + + // Why 32? See https://github.com/docker/docker/pull/8035. + timeout := 32 * time.Second + if proto == "unix" { + // No need for compression in local communications. + tr.DisableCompression = true + tr.Dial = func(_, _ string) (net.Conn, error) { + return net.DialTimeout(proto, addr, timeout) + } + } else { + tr.Proxy = http.ProxyFromEnvironment + tr.Dial = (&net.Dialer{Timeout: timeout}).Dial + } + + return tr +}