Allow to configure the client with a given http.Client.

- Give more control over the http connection to the library callers.
- Validate the host provided when the client is created.

Signed-off-by: David Calavera <david.calavera@gmail.com>
This commit is contained in:
David Calavera 2016-01-20 14:01:35 -05:00
Родитель bdbab71ec2
Коммит e558e63f36
5 изменённых файлов: 162 добавлений и 55 удалений

Просмотреть файл

@ -1,15 +1,12 @@
package client package client
import ( import (
"crypto/tls"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time"
"github.com/docker/go-connections/tlsconfig" "github.com/docker/go-connections/tlsconfig"
) )
@ -23,12 +20,8 @@ type Client struct {
addr string addr string
// basePath holds the path to prepend to the requests // basePath holds the path to prepend to the requests
basePath string basePath string
// scheme holds the scheme of the client i.e. https. // apiTransport holds information about the http transport
scheme string apiTransport *apiTransport
// tlsConfig holds the tls configuration to use in hijacked requests.
tlsConfig *tls.Config
// httpClient holds the client transport instance. Exported to keep the old code running.
httpClient *http.Client
// version of the server to talk to. // version of the server to talk to.
version string version string
// custom http headers configured by users // custom http headers configured by users
@ -41,7 +34,7 @@ type Client struct {
// Use DOCKER_CERT_PATH to load the tls certificates from. // Use DOCKER_CERT_PATH to load the tls certificates from.
// Use DOCKER_TLS_VERIFY to enable or disable TLS verification, off by default. // Use DOCKER_TLS_VERIFY to enable or disable TLS verification, off by default.
func NewEnvClient() (*Client, error) { func NewEnvClient() (*Client, error) {
var transport *http.Transport var client *http.Client
if dockerCertPath := os.Getenv("DOCKER_CERT_PATH"); dockerCertPath != "" { if dockerCertPath := os.Getenv("DOCKER_CERT_PATH"); dockerCertPath != "" {
options := tlsconfig.Options{ options := tlsconfig.Options{
CAFile: filepath.Join(dockerCertPath, "ca.pem"), CAFile: filepath.Join(dockerCertPath, "ca.pem"),
@ -54,8 +47,10 @@ func NewEnvClient() (*Client, error) {
return nil, err return nil, err
} }
transport = &http.Transport{ client = &http.Client{
TLSClientConfig: tlsc, Transport: &http.Transport{
TLSClientConfig: tlsc,
},
} }
} }
@ -63,42 +58,29 @@ func NewEnvClient() (*Client, error) {
if host == "" { if host == "" {
host = DefaultDockerHost host = DefaultDockerHost
} }
return NewClient(host, os.Getenv("DOCKER_API_VERSION"), transport, nil) return NewClient(host, os.Getenv("DOCKER_API_VERSION"), client, nil)
} }
// NewClient initializes a new API client for the given host and API version. // NewClient initializes a new API client for the given host and API version.
// It won't send any version information if the version number is empty. // It won't send any version information if the version number is empty.
// It uses the transport to create a new http client. // It uses the given http client as transport.
// It also initializes the custom http headers to add to each request. // It also initializes the custom http headers to add to each request.
func NewClient(host string, version string, transport *http.Transport, httpHeaders map[string]string) (*Client, error) { func NewClient(host string, version string, client *http.Client, httpHeaders map[string]string) (*Client, error) {
var ( proto, addr, basePath, err := parseHost(host)
basePath string if err != nil {
scheme = "http" return nil, err
protoAddrParts = strings.SplitN(host, "://", 2)
proto, addr = protoAddrParts[0], protoAddrParts[1]
)
if proto == "tcp" {
parsed, err := url.Parse("tcp://" + addr)
if err != nil {
return nil, err
}
addr = parsed.Host
basePath = parsed.Path
} }
transport = configureTransport(transport, proto, addr) apiTransport, err := newAPITransport(proto, addr, client)
if transport.TLSClientConfig != nil { if err != nil {
scheme = "https" return nil, err
} }
return &Client{ return &Client{
proto: proto, proto: proto,
addr: addr, addr: addr,
basePath: basePath, basePath: basePath,
scheme: scheme, apiTransport: apiTransport,
tlsConfig: transport.TLSClientConfig,
httpClient: &http.Client{Transport: transport},
version: version, version: version,
customHTTPHeaders: httpHeaders, customHTTPHeaders: httpHeaders,
}, nil }, nil
@ -127,23 +109,22 @@ func (cli *Client) ClientVersion() string {
return cli.version return cli.version
} }
func configureTransport(tr *http.Transport, proto, addr string) *http.Transport { // parseHost verifies that the given host strings is valid.
if tr == nil { func parseHost(host string) (string, string, string, error) {
tr = &http.Transport{} protoAddrParts := strings.SplitN(host, "://", 2)
if len(protoAddrParts) == 1 {
return "", "", "", fmt.Errorf("unable to parse docker host `%s`", host)
} }
// Why 32? See https://github.com/docker/docker/pull/8035. var basePath string
timeout := 32 * time.Second proto, addr := protoAddrParts[0], protoAddrParts[1]
if proto == "unix" { if proto == "tcp" {
// No need for compression in local communications. parsed, err := url.Parse("tcp://" + addr)
tr.DisableCompression = true if err != nil {
tr.Dial = func(_, _ string) (net.Conn, error) { return "", "", "", err
return net.DialTimeout(proto, addr, timeout)
} }
} else { addr = parsed.Host
tr.Proxy = http.ProxyFromEnvironment basePath = parsed.Path
tr.Dial = (&net.Dialer{Timeout: timeout}).Dial
} }
return proto, addr, basePath, nil
return tr
} }

Просмотреть файл

@ -34,3 +34,38 @@ func TestGetAPIPath(t *testing.T) {
} }
} }
} }
func TestParseHost(t *testing.T) {
cases := []struct {
host string
proto string
addr string
base string
err bool
}{
{"", "", "", "", true},
{"foobar", "", "", "", true},
{"foo://bar", "foo", "bar", "", false},
{"tcp://localhost:2476", "tcp", "localhost:2476", "", false},
{"tcp://localhost:2476/path", "tcp", "localhost:2476", "/path", false},
}
for _, cs := range cases {
p, a, b, e := parseHost(cs.host)
if cs.err && e == nil {
t.Fatalf("expected error, got nil")
}
if !cs.err && e != nil {
t.Fatal(e)
}
if cs.proto != p {
t.Fatalf("expected proto %s, got %s", cs.proto, p)
}
if cs.addr != a {
t.Fatalf("expected addr %s, got %s", cs.addr, a)
}
if cs.base != b {
t.Fatalf("expected base %s, got %s", cs.base, b)
}
}
}

Просмотреть файл

@ -44,7 +44,7 @@ func (cli *Client) postHijacked(path string, query url.Values, body interface{},
req.Header.Set("Connection", "Upgrade") req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "tcp") req.Header.Set("Upgrade", "tcp")
conn, err := dial(cli.proto, cli.addr, cli.tlsConfig) conn, err := dial(cli.proto, cli.addr, cli.apiTransport.TLSConfig())
if err != nil { if err != nil {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
return types.HijackedResponse{}, fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker daemon' running on this host?") return types.HijackedResponse{}, fmt.Errorf("Cannot connect to the Docker daemon. Is 'docker daemon' running on this host?")

Просмотреть файл

@ -82,13 +82,13 @@ func (cli *Client) sendClientRequest(method, path string, query url.Values, body
req, err := cli.newRequest(method, path, query, body, headers) req, err := cli.newRequest(method, path, query, body, headers)
req.URL.Host = cli.addr req.URL.Host = cli.addr
req.URL.Scheme = cli.scheme req.URL.Scheme = cli.apiTransport.Scheme()
if expectedPayload && req.Header.Get("Content-Type") == "" { if expectedPayload && req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "text/plain") req.Header.Set("Content-Type", "text/plain")
} }
resp, err := cli.httpClient.Do(req) resp, err := cli.apiTransport.HTTPClient().Do(req)
if resp != nil { if resp != nil {
serverResp.statusCode = resp.StatusCode serverResp.statusCode = resp.StatusCode
} }
@ -98,10 +98,10 @@ func (cli *Client) sendClientRequest(method, path string, query url.Values, body
return serverResp, ErrConnectionFailed return serverResp, ErrConnectionFailed
} }
if cli.scheme == "http" && strings.Contains(err.Error(), "malformed HTTP response") { if !cli.apiTransport.IsTLS() && strings.Contains(err.Error(), "malformed HTTP response") {
return serverResp, fmt.Errorf("%v.\n* Are you trying to connect to a TLS-enabled daemon without TLS?", err) return serverResp, fmt.Errorf("%v.\n* Are you trying to connect to a TLS-enabled daemon without TLS?", err)
} }
if cli.scheme == "https" && strings.Contains(err.Error(), "remote error: bad certificate") { if cli.apiTransport.IsTLS() && strings.Contains(err.Error(), "remote error: bad certificate") {
return serverResp, fmt.Errorf("The server probably has client authentication (--tlsverify) enabled. Please check your TLS client certification settings: %v", err) return serverResp, fmt.Errorf("The server probably has client authentication (--tlsverify) enabled. Please check your TLS client certification settings: %v", err)
} }

91
client/transport.go Normal file
Просмотреть файл

@ -0,0 +1,91 @@
package client
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// apiTransport holds information about the http transport to connect with the API.
type apiTransport struct {
// httpClient holds the client transport instance. Exported to keep the old code running.
httpClient *http.Client
// scheme holds the scheme of the client i.e. https.
scheme string
// tlsConfig holds the tls configuration to use in hijacked requests.
tlsConfig *tls.Config
}
// newAPITransport creates a new transport based on the provided proto, address and client.
// It uses Docker's default http transport configuration if the client is nil.
// It does not modify the client's transport if it's not nil.
func newAPITransport(proto, addr string, client *http.Client) (*apiTransport, error) {
scheme := "http"
var transport *http.Transport
if client != nil {
tr, ok := client.Transport.(*http.Transport)
if !ok {
return nil, fmt.Errorf("unable to verify TLS configuration, invalid transport %v", client.Transport)
}
transport = tr
} else {
transport = defaultTransport(proto, addr)
client = &http.Client{
Transport: transport,
}
}
if transport.TLSClientConfig != nil {
scheme = "https"
}
return &apiTransport{
httpClient: client,
scheme: scheme,
tlsConfig: transport.TLSClientConfig,
}, nil
}
// HTTPClient returns the http client.
func (a *apiTransport) HTTPClient() *http.Client {
return a.httpClient
}
// Scheme returns the api scheme.
func (a *apiTransport) Scheme() string {
return a.scheme
}
// TLSConfig returns the TLS configuration.
func (a *apiTransport) TLSConfig() *tls.Config {
return a.tlsConfig
}
// IsTLS returns true if there is a TLS configuration.
func (a *apiTransport) IsTLS() bool {
return a.tlsConfig != nil
}
// defaultTransport creates a new http.Transport with Docker's
// default transport configuration.
func defaultTransport(proto, addr string) *http.Transport {
tr := new(http.Transport)
// Why 32? See https://github.com/docker/docker/pull/8035.
timeout := 32 * time.Second
if proto == "unix" {
// No need for compression in local communications.
tr.DisableCompression = true
tr.Dial = func(_, _ string) (net.Conn, error) {
return net.DialTimeout(proto, addr, timeout)
}
} else {
tr.Proxy = http.ProxyFromEnvironment
tr.Dial = (&net.Dialer{Timeout: timeout}).Dial
}
return tr
}