From a01cc3ca7729c3ce635fef7c1db837b5c6ae1028 Mon Sep 17 00:00:00 2001 From: Tibor Vass Date: Thu, 14 May 2015 07:12:54 -0700 Subject: [PATCH 1/3] registry: Refactor requestfactory to use http.RoundTrippers This patch removes the need for requestFactories and decorators by implementing http.RoundTripper transports instead. It refactors some challenging-to-read code. NewSession now takes an *http.Client that can already have a custom Transport, it will add its own auth transport by wrapping it. The idea is that callers of http.Client should not bother setting custom headers for every handler but instead it should be transparent to the callers of a same context. This patch is needed for future refactorings of registry, namely refactoring of the v1 client code. Signed-off-by: Tibor Vass --- graph/pull.go | 16 +- graph/push.go | 17 +- pkg/requestdecorator/requestdecorator.go | 4 +- registry/auth.go | 32 ++- registry/endpoint.go | 41 ++-- registry/httpfactory.go | 30 --- registry/registry.go | 226 ++++++++++++------ registry/registry_test.go | 66 +++--- registry/service.go | 5 +- registry/session.go | 288 +++++++++++------------ registry/session_v2.go | 32 +-- registry/token.go | 6 +- 12 files changed, 397 insertions(+), 366 deletions(-) delete mode 100644 registry/httpfactory.go diff --git a/graph/pull.go b/graph/pull.go index 5fd837b429..013729a8c0 100644 --- a/graph/pull.go +++ b/graph/pull.go @@ -60,7 +60,13 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf return err } - r, err := registry.NewSession(imagePullConfig.AuthConfig, registry.HTTPRequestFactory(imagePullConfig.MetaHeaders), endpoint, true) + // Adds Docker-specific headers as well as user-specified headers (metaHeaders) + tr := ®istry.DockerHeaders{ + registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure), + imagePullConfig.MetaHeaders, + } + client := registry.HTTPClient(tr) + r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint) if err != nil { return err } @@ -109,7 +115,7 @@ func (s *TagStore) pullRepository(r *registry.Session, out io.Writer, repoInfo * } logrus.Debugf("Retrieving the tag list") - tagsList, err := r.GetRemoteTags(repoData.Endpoints, repoInfo.RemoteName, repoData.Tokens) + tagsList, err := r.GetRemoteTags(repoData.Endpoints, repoInfo.RemoteName) if err != nil { logrus.Errorf("unable to get remote tags: %s", err) return err @@ -240,7 +246,7 @@ func (s *TagStore) pullRepository(r *registry.Session, out io.Writer, repoInfo * } func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint string, token []string, sf *streamformatter.StreamFormatter) (bool, error) { - history, err := r.GetRemoteHistory(imgID, endpoint, token) + history, err := r.GetRemoteHistory(imgID, endpoint) if err != nil { return false, err } @@ -269,7 +275,7 @@ func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint ) retries := 5 for j := 1; j <= retries; j++ { - imgJSON, imgSize, err = r.GetRemoteImageJSON(id, endpoint, token) + imgJSON, imgSize, err = r.GetRemoteImageJSON(id, endpoint) if err != nil && j == retries { out.Write(sf.FormatProgress(stringid.TruncateID(id), "Error pulling dependent layers", nil)) return layersDownloaded, err @@ -297,7 +303,7 @@ func (s *TagStore) pullImage(r *registry.Session, out io.Writer, imgID, endpoint status = fmt.Sprintf("Pulling fs layer [retries: %d]", j) } out.Write(sf.FormatProgress(stringid.TruncateID(id), status, nil)) - layer, err := r.GetRemoteImageLayer(img.ID, endpoint, token, int64(imgSize)) + layer, err := r.GetRemoteImageLayer(img.ID, endpoint, int64(imgSize)) if uerr, ok := err.(*url.Error); ok { err = uerr.Err } diff --git a/graph/push.go b/graph/push.go index 1ae1fc1d00..6e9a367bcb 100644 --- a/graph/push.go +++ b/graph/push.go @@ -141,7 +141,7 @@ func lookupImageOnEndpoint(wg *sync.WaitGroup, r *registry.Session, out io.Write images chan imagePushData, imagesToPush chan string) { defer wg.Done() for image := range images { - if err := r.LookupRemoteImage(image.id, image.endpoint, image.tokens); err != nil { + if err := r.LookupRemoteImage(image.id, image.endpoint); err != nil { logrus.Errorf("Error in LookupRemoteImage: %s", err) imagesToPush <- image.id continue @@ -199,7 +199,7 @@ func (s *TagStore) pushImageToEndpoint(endpoint string, out io.Writer, remoteNam } for _, tag := range tags[id] { out.Write(sf.FormatStatus("", "Pushing tag for rev [%s] on {%s}", stringid.TruncateID(id), endpoint+"repositories/"+remoteName+"/tags/"+tag)) - if err := r.PushRegistryTag(remoteName, id, tag, endpoint, repo.Tokens); err != nil { + if err := r.PushRegistryTag(remoteName, id, tag, endpoint); err != nil { return err } } @@ -258,7 +258,7 @@ func (s *TagStore) pushImage(r *registry.Session, out io.Writer, imgID, ep strin } // Send the json - if err := r.PushImageJSONRegistry(imgData, jsonRaw, ep, token); err != nil { + if err := r.PushImageJSONRegistry(imgData, jsonRaw, ep); err != nil { if err == registry.ErrAlreadyExists { out.Write(sf.FormatProgress(stringid.TruncateID(imgData.ID), "Image already pushed, skipping", nil)) return "", nil @@ -284,14 +284,14 @@ func (s *TagStore) pushImage(r *registry.Session, out io.Writer, imgID, ep strin NewLines: false, ID: stringid.TruncateID(imgData.ID), Action: "Pushing", - }), ep, token, jsonRaw) + }), ep, jsonRaw) if err != nil { return "", err } imgData.Checksum = checksum imgData.ChecksumPayload = checksumPayload // Send the checksum - if err := r.PushImageChecksumRegistry(imgData, ep, token); err != nil { + if err := r.PushImageChecksumRegistry(imgData, ep); err != nil { return "", err } @@ -514,7 +514,12 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro return err } - r, err := registry.NewSession(imagePushConfig.AuthConfig, registry.HTTPRequestFactory(imagePushConfig.MetaHeaders), endpoint, false) + // Adds Docker-specific headers as well as user-specified headers (metaHeaders) + tr := ®istry.DockerHeaders{ + registry.NewTransport(registry.NoTimeout, endpoint.IsSecure), + imagePushConfig.MetaHeaders, + } + r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint) if err != nil { return err } diff --git a/pkg/requestdecorator/requestdecorator.go b/pkg/requestdecorator/requestdecorator.go index c236e3fe3f..079e38b21e 100644 --- a/pkg/requestdecorator/requestdecorator.go +++ b/pkg/requestdecorator/requestdecorator.go @@ -47,7 +47,7 @@ func (vi *UAVersionInfo) isValid() bool { // "product/version", where the "product" is get from the name field, while // version is get from the version field. Several pieces of verson information // will be concatinated and separated by space. -func appendVersions(base string, versions ...UAVersionInfo) string { +func AppendVersions(base string, versions ...UAVersionInfo) string { if len(versions) == 0 { return base } @@ -87,7 +87,7 @@ func (h *UserAgentDecorator) ChangeRequest(req *http.Request) (*http.Request, er return req, ErrNilRequest } - userAgent := appendVersions(req.UserAgent(), h.Versions...) + userAgent := AppendVersions(req.UserAgent(), h.Versions...) if len(userAgent) > 0 { req.Header.Set("User-Agent", userAgent) } diff --git a/registry/auth.go b/registry/auth.go index 1ac1ca984e..0b6c3b0f95 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -11,7 +11,6 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/requestdecorator" ) type RequestAuthorization struct { @@ -46,7 +45,6 @@ func (auth *RequestAuthorization) getToken() (string, error) { } client := auth.registryEndpoint.HTTPClient() - factory := HTTPRequestFactory(nil) for _, challenge := range auth.registryEndpoint.AuthChallenges { switch strings.ToLower(challenge.Scheme) { @@ -59,7 +57,7 @@ func (auth *RequestAuthorization) getToken() (string, error) { params[k] = v } params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ",")) - token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client, factory) + token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client) if err != nil { return "", err } @@ -92,16 +90,16 @@ func (auth *RequestAuthorization) Authorize(req *http.Request) error { } // Login tries to register/login to the registry server. -func Login(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func Login(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { // Separates the v2 registry login logic from the v1 logic. if registryEndpoint.Version == APIVersion2 { - return loginV2(authConfig, registryEndpoint, factory) + return loginV2(authConfig, registryEndpoint) } - return loginV1(authConfig, registryEndpoint, factory) + return loginV1(authConfig, registryEndpoint) } // loginV1 tries to register/login to the v1 registry server. -func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { var ( status string reqBody []byte @@ -151,7 +149,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto } } else if reqStatusCode == 400 { if string(reqBody) == "\"Username or email already exists\"" { - req, err := factory.NewRequest("GET", serverAddress+"users/", nil) + req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) resp, err := client.Do(req) if err != nil { @@ -180,7 +178,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto } else if reqStatusCode == 401 { // This case would happen with private registries where /v1/users is // protected, so people can use `docker login` as an auth check. - req, err := factory.NewRequest("GET", serverAddress+"users/", nil) + req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) resp, err := client.Do(req) if err != nil { @@ -214,7 +212,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto // now, users should create their account through other means like directly from a web page // served by the v2 registry service provider. Whether this will be supported in the future // is to be determined. -func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, factory *requestdecorator.RequestFactory) (string, error) { +func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (string, error) { logrus.Debugf("attempting v2 login to registry endpoint %s", registryEndpoint) var ( err error @@ -227,9 +225,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto switch strings.ToLower(challenge.Scheme) { case "basic": - err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client, factory) + err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) case "bearer": - err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client, factory) + err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) default: // Unsupported challenge types are explicitly skipped. err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme) @@ -247,8 +245,8 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint, facto return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors) } -func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) error { - req, err := factory.NewRequest("GET", registryEndpoint.Path(""), nil) +func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { + req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err } @@ -268,13 +266,13 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str return nil } -func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) error { - token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client, factory) +func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { + token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client) if err != nil { return err } - req, err := factory.NewRequest("GET", registryEndpoint.Path(""), nil) + req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err } diff --git a/registry/endpoint.go b/registry/endpoint.go index 84b11a987b..25f66ad25f 100644 --- a/registry/endpoint.go +++ b/registry/endpoint.go @@ -1,7 +1,6 @@ package registry import ( - "crypto/tls" "encoding/json" "fmt" "io/ioutil" @@ -12,7 +11,6 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/api/v2" - "github.com/docker/docker/pkg/requestdecorator" ) // for mocking in unit tests @@ -109,6 +107,7 @@ func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) { // Endpoint stores basic information about a registry endpoint. type Endpoint struct { + client *http.Client URL *url.URL Version APIVersion IsSecure bool @@ -135,25 +134,24 @@ func (e *Endpoint) Path(path string) string { func (e *Endpoint) Ping() (RegistryInfo, error) { // The ping logic to use is determined by the registry endpoint version. - factory := HTTPRequestFactory(nil) switch e.Version { case APIVersion1: - return e.pingV1(factory) + return e.pingV1() case APIVersion2: - return e.pingV2(factory) + return e.pingV2() } // APIVersionUnknown // We should try v2 first... e.Version = APIVersion2 - regInfo, errV2 := e.pingV2(factory) + regInfo, errV2 := e.pingV2() if errV2 == nil { return regInfo, nil } // ... then fallback to v1. e.Version = APIVersion1 - regInfo, errV1 := e.pingV1(factory) + regInfo, errV1 := e.pingV1() if errV1 == nil { return regInfo, nil } @@ -162,7 +160,7 @@ func (e *Endpoint) Ping() (RegistryInfo, error) { return RegistryInfo{}, fmt.Errorf("unable to ping registry endpoint %s\nv2 ping attempt failed with error: %s\n v1 ping attempt failed with error: %s", e, errV2, errV1) } -func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInfo, error) { +func (e *Endpoint) pingV1() (RegistryInfo, error) { logrus.Debugf("attempting v1 ping for registry endpoint %s", e) if e.String() == IndexServerAddress() { @@ -171,12 +169,12 @@ func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInf return RegistryInfo{Standalone: false}, nil } - req, err := factory.NewRequest("GET", e.Path("_ping"), nil) + req, err := http.NewRequest("GET", e.Path("_ping"), nil) if err != nil { return RegistryInfo{Standalone: false}, err } - resp, _, err := doRequest(req, nil, ConnectTimeout, e.IsSecure) + resp, err := e.HTTPClient().Do(req) if err != nil { return RegistryInfo{Standalone: false}, err } @@ -216,15 +214,15 @@ func (e *Endpoint) pingV1(factory *requestdecorator.RequestFactory) (RegistryInf return info, nil } -func (e *Endpoint) pingV2(factory *requestdecorator.RequestFactory) (RegistryInfo, error) { +func (e *Endpoint) pingV2() (RegistryInfo, error) { logrus.Debugf("attempting v2 ping for registry endpoint %s", e) - req, err := factory.NewRequest("GET", e.Path(""), nil) + req, err := http.NewRequest("GET", e.Path(""), nil) if err != nil { return RegistryInfo{}, err } - resp, _, err := doRequest(req, nil, ConnectTimeout, e.IsSecure) + resp, err := e.HTTPClient().Do(req) if err != nil { return RegistryInfo{}, err } @@ -265,18 +263,9 @@ HeaderLoop: } func (e *Endpoint) HTTPClient() *http.Client { - tlsConfig := tls.Config{ - MinVersion: tls.VersionTLS10, - } - if !e.IsSecure { - tlsConfig.InsecureSkipVerify = true - } - return &http.Client{ - Transport: &http.Transport{ - DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, - }, - CheckRedirect: AddRequiredHeadersToRedirectedRequests, + if e.client == nil { + tr := NewTransport(ConnectTimeout, e.IsSecure) + e.client = HTTPClient(tr) } + return e.client } diff --git a/registry/httpfactory.go b/registry/httpfactory.go deleted file mode 100644 index f1b89e5829..0000000000 --- a/registry/httpfactory.go +++ /dev/null @@ -1,30 +0,0 @@ -package registry - -import ( - "runtime" - - "github.com/docker/docker/autogen/dockerversion" - "github.com/docker/docker/pkg/parsers/kernel" - "github.com/docker/docker/pkg/requestdecorator" -) - -func HTTPRequestFactory(metaHeaders map[string][]string) *requestdecorator.RequestFactory { - // FIXME: this replicates the 'info' job. - httpVersion := make([]requestdecorator.UAVersionInfo, 0, 4) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("docker", dockerversion.VERSION)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("go", runtime.Version())) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("git-commit", dockerversion.GITCOMMIT)) - if kernelVersion, err := kernel.GetKernelVersion(); err == nil { - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("kernel", kernelVersion.String())) - } - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("os", runtime.GOOS)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("arch", runtime.GOARCH)) - uad := &requestdecorator.UserAgentDecorator{ - Versions: httpVersion, - } - mhd := &requestdecorator.MetaHeadersDecorator{ - Headers: metaHeaders, - } - factory := requestdecorator.NewRequestFactory(uad, mhd) - return factory -} diff --git a/registry/registry.go b/registry/registry.go index aff28eaa42..db77a985eb 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -8,12 +8,17 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httputil" "os" "path" + "runtime" "strings" "time" "github.com/Sirupsen/logrus" + "github.com/docker/docker/autogen/dockerversion" + "github.com/docker/docker/pkg/parsers/kernel" + "github.com/docker/docker/pkg/requestdecorator" "github.com/docker/docker/pkg/timeoutconn" ) @@ -31,66 +36,23 @@ const ( ConnectTimeout ) -func newClient(jar http.CookieJar, roots *x509.CertPool, certs []tls.Certificate, timeout TimeoutType, secure bool) *http.Client { - tlsConfig := tls.Config{ - RootCAs: roots, - // Avoid fallback to SSL protocols < TLS1.0 - MinVersion: tls.VersionTLS10, - Certificates: certs, - } - - if !secure { - tlsConfig.InsecureSkipVerify = true - } - - httpTransport := &http.Transport{ - DisableKeepAlives: true, - Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, - } - - switch timeout { - case ConnectTimeout: - httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { - // Set the connect timeout to 30 seconds to allow for slower connection - // times... - d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} - - conn, err := d.Dial(proto, addr) - if err != nil { - return nil, err - } - // Set the recv timeout to 10 seconds - conn.SetDeadline(time.Now().Add(10 * time.Second)) - return conn, nil - } - case ReceiveTimeout: - httpTransport.Dial = func(proto string, addr string) (net.Conn, error) { - d := net.Dialer{DualStack: true} - - conn, err := d.Dial(proto, addr) - if err != nil { - return nil, err - } - conn = timeoutconn.New(conn, 1*time.Minute) - return conn, nil - } - } - - return &http.Client{ - Transport: httpTransport, - CheckRedirect: AddRequiredHeadersToRedirectedRequests, - Jar: jar, - } +type httpsTransport struct { + *http.Transport } -func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secure bool) (*http.Response, *http.Client, error) { +// DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip, +// it's because it's so as to match the current behavior in master: we generate the +// certpool on every-goddam-request. It's not great, but it allows people to just put +// the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would +// prefer an fsnotify implementation, but that was out of scope of my refactoring. +// TODO: improve things +func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { var ( - pool *x509.CertPool + roots *x509.CertPool certs []tls.Certificate ) - if secure && req.URL.Scheme == "https" { + if req.URL.Scheme == "https" { hasFile := func(files []os.FileInfo, name string) bool { for _, f := range files { if f.Name() == name { @@ -104,31 +66,31 @@ func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secur logrus.Debugf("hostDir: %s", hostDir) fs, err := ioutil.ReadDir(hostDir) if err != nil && !os.IsNotExist(err) { - return nil, nil, err + return nil, err } for _, f := range fs { if strings.HasSuffix(f.Name(), ".crt") { - if pool == nil { - pool = x509.NewCertPool() + if roots == nil { + roots = x509.NewCertPool() } logrus.Debugf("crt: %s", hostDir+"/"+f.Name()) data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) if err != nil { - return nil, nil, err + return nil, err } - pool.AppendCertsFromPEM(data) + roots.AppendCertsFromPEM(data) } if strings.HasSuffix(f.Name(), ".cert") { certName := f.Name() keyName := certName[:len(certName)-5] + ".key" logrus.Debugf("cert: %s", hostDir+"/"+f.Name()) if !hasFile(fs, keyName) { - return nil, nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) + return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) } cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) if err != nil { - return nil, nil, err + return nil, err } certs = append(certs, cert) } @@ -137,24 +99,142 @@ func doRequest(req *http.Request, jar http.CookieJar, timeout TimeoutType, secur certName := keyName[:len(keyName)-4] + ".cert" logrus.Debugf("key: %s", hostDir+"/"+f.Name()) if !hasFile(fs, certName) { - return nil, nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) + return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) } } } - } - - if len(certs) == 0 { - client := newClient(jar, pool, nil, timeout, secure) - res, err := client.Do(req) - if err != nil { - return nil, nil, err + if tr.Transport.TLSClientConfig == nil { + tr.Transport.TLSClientConfig = &tls.Config{ + // Avoid fallback to SSL protocols < TLS1.0 + MinVersion: tls.VersionTLS10, + } } - return res, client, nil + tr.Transport.TLSClientConfig.RootCAs = roots + tr.Transport.TLSClientConfig.Certificates = certs + } + return tr.Transport.RoundTrip(req) +} + +func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { + tlsConfig := tls.Config{ + // Avoid fallback to SSL protocols < TLS1.0 + MinVersion: tls.VersionTLS10, + InsecureSkipVerify: !secure, } - client := newClient(jar, pool, certs, timeout, secure) - res, err := client.Do(req) - return res, client, err + transport := &http.Transport{ + DisableKeepAlives: true, + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tlsConfig, + } + + switch timeout { + case ConnectTimeout: + transport.Dial = func(proto string, addr string) (net.Conn, error) { + // Set the connect timeout to 30 seconds to allow for slower connection + // times... + d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} + + conn, err := d.Dial(proto, addr) + if err != nil { + return nil, err + } + // Set the recv timeout to 10 seconds + conn.SetDeadline(time.Now().Add(10 * time.Second)) + return conn, nil + } + case ReceiveTimeout: + transport.Dial = func(proto string, addr string) (net.Conn, error) { + d := net.Dialer{DualStack: true} + + conn, err := d.Dial(proto, addr) + if err != nil { + return nil, err + } + conn = timeoutconn.New(conn, 1*time.Minute) + return conn, nil + } + } + + if secure { + // note: httpsTransport also handles http transport + // but for HTTPS, it sets up the certs + return &httpsTransport{transport} + } + + return transport +} + +type DockerHeaders struct { + http.RoundTripper + Headers http.Header +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + return r2 +} + +func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + httpVersion := make([]requestdecorator.UAVersionInfo, 0, 4) + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("docker", dockerversion.VERSION)) + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("go", runtime.Version())) + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("git-commit", dockerversion.GITCOMMIT)) + if kernelVersion, err := kernel.GetKernelVersion(); err == nil { + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("kernel", kernelVersion.String())) + } + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("os", runtime.GOOS)) + httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("arch", runtime.GOARCH)) + + userAgent := requestdecorator.AppendVersions(req.UserAgent(), httpVersion...) + + req.Header.Set("User-Agent", userAgent) + + for k, v := range tr.Headers { + req.Header[k] = v + } + return tr.RoundTripper.RoundTrip(req) +} + +type debugTransport struct{ http.RoundTripper } + +func (tr debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { + dump, err := httputil.DumpRequestOut(req, false) + if err != nil { + fmt.Println("could not dump request") + } + fmt.Println(string(dump)) + resp, err := tr.RoundTripper.RoundTrip(req) + if err != nil { + return nil, err + } + dump, err = httputil.DumpResponse(resp, false) + if err != nil { + fmt.Println("could not dump response") + } + fmt.Println(string(dump)) + return resp, err +} + +func HTTPClient(transport http.RoundTripper) *http.Client { + if transport == nil { + transport = NewTransport(ConnectTimeout, true) + } + + return &http.Client{ + Transport: transport, + CheckRedirect: AddRequiredHeadersToRedirectedRequests, + } } func trustedLocation(req *http.Request) bool { diff --git a/registry/registry_test.go b/registry/registry_test.go index 799d080ed7..d4a5ded082 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/docker/docker/cliconfig" - "github.com/docker/docker/pkg/requestdecorator" ) var ( @@ -26,38 +25,27 @@ func spawnTestRegistrySession(t *testing.T) *Session { if err != nil { t.Fatal(err) } - r, err := NewSession(authConfig, requestdecorator.NewRequestFactory(), endpoint, true) + var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)} + tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil} + client := HTTPClient(tr) + r, err := NewSession(client, authConfig, endpoint) if err != nil { t.Fatal(err) } + // In a normal scenario for the v1 registry, the client should send a `X-Docker-Token: true` + // header while authenticating, in order to retrieve a token that can be later used to + // perform authenticated actions. + // + // The mock v1 registry does not support that, (TODO(tiborvass): support it), instead, + // it will consider authenticated any request with the header `X-Docker-Token: fake-token`. + // + // Because we know that the client's transport is an `*authTransport` we simply cast it, + // in order to set the internal cached token to the fake token, and thus send that fake token + // upon every subsequent requests. + r.client.Transport.(*authTransport).token = token return r } -func TestPublicSession(t *testing.T) { - authConfig := &cliconfig.AuthConfig{} - - getSessionDecorators := func(index *IndexInfo) int { - endpoint, err := NewEndpoint(index) - if err != nil { - t.Fatal(err) - } - r, err := NewSession(authConfig, requestdecorator.NewRequestFactory(), endpoint, true) - if err != nil { - t.Fatal(err) - } - return len(r.reqFactory.GetDecorators()) - } - - decorators := getSessionDecorators(makeIndex("/v1/")) - assertEqual(t, decorators, 0, "Expected no decorator on http session") - - decorators = getSessionDecorators(makeHttpsIndex("/v1/")) - assertNotEqual(t, decorators, 0, "Expected decorator on https session") - - decorators = getSessionDecorators(makePublicIndex()) - assertEqual(t, decorators, 0, "Expected no decorator on public session") -} - func TestPingRegistryEndpoint(t *testing.T) { testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) { ep, err := NewEndpoint(index) @@ -170,7 +158,7 @@ func TestEndpoint(t *testing.T) { func TestGetRemoteHistory(t *testing.T) { r := spawnTestRegistrySession(t) - hist, err := r.GetRemoteHistory(imageID, makeURL("/v1/"), token) + hist, err := r.GetRemoteHistory(imageID, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -182,16 +170,16 @@ func TestGetRemoteHistory(t *testing.T) { func TestLookupRemoteImage(t *testing.T) { r := spawnTestRegistrySession(t) - err := r.LookupRemoteImage(imageID, makeURL("/v1/"), token) + err := r.LookupRemoteImage(imageID, makeURL("/v1/")) assertEqual(t, err, nil, "Expected error of remote lookup to nil") - if err := r.LookupRemoteImage("abcdef", makeURL("/v1/"), token); err == nil { + if err := r.LookupRemoteImage("abcdef", makeURL("/v1/")); err == nil { t.Fatal("Expected error of remote lookup to not nil") } } func TestGetRemoteImageJSON(t *testing.T) { r := spawnTestRegistrySession(t) - json, size, err := r.GetRemoteImageJSON(imageID, makeURL("/v1/"), token) + json, size, err := r.GetRemoteImageJSON(imageID, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -200,7 +188,7 @@ func TestGetRemoteImageJSON(t *testing.T) { t.Fatal("Expected non-empty json") } - _, _, err = r.GetRemoteImageJSON("abcdef", makeURL("/v1/"), token) + _, _, err = r.GetRemoteImageJSON("abcdef", makeURL("/v1/")) if err == nil { t.Fatal("Expected image not found error") } @@ -208,7 +196,7 @@ func TestGetRemoteImageJSON(t *testing.T) { func TestGetRemoteImageLayer(t *testing.T) { r := spawnTestRegistrySession(t) - data, err := r.GetRemoteImageLayer(imageID, makeURL("/v1/"), token, 0) + data, err := r.GetRemoteImageLayer(imageID, makeURL("/v1/"), 0) if err != nil { t.Fatal(err) } @@ -216,7 +204,7 @@ func TestGetRemoteImageLayer(t *testing.T) { t.Fatal("Expected non-nil data result") } - _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), token, 0) + _, err = r.GetRemoteImageLayer("abcdef", makeURL("/v1/"), 0) if err == nil { t.Fatal("Expected image not found error") } @@ -224,14 +212,14 @@ func TestGetRemoteImageLayer(t *testing.T) { func TestGetRemoteTags(t *testing.T) { r := spawnTestRegistrySession(t) - tags, err := r.GetRemoteTags([]string{makeURL("/v1/")}, REPO, token) + tags, err := r.GetRemoteTags([]string{makeURL("/v1/")}, REPO) if err != nil { t.Fatal(err) } assertEqual(t, len(tags), 1, "Expected one tag") assertEqual(t, tags["latest"], imageID, "Expected tag latest to map to "+imageID) - _, err = r.GetRemoteTags([]string{makeURL("/v1/")}, "foo42/baz", token) + _, err = r.GetRemoteTags([]string{makeURL("/v1/")}, "foo42/baz") if err == nil { t.Fatal("Expected error when fetching tags for bogus repo") } @@ -265,7 +253,7 @@ func TestPushImageJSONRegistry(t *testing.T) { Checksum: "sha256:1ac330d56e05eef6d438586545ceff7550d3bdcb6b19961f12c5ba714ee1bb37", } - err := r.PushImageJSONRegistry(imgData, []byte{0x42, 0xdf, 0x0}, makeURL("/v1/"), token) + err := r.PushImageJSONRegistry(imgData, []byte{0x42, 0xdf, 0x0}, makeURL("/v1/")) if err != nil { t.Fatal(err) } @@ -274,7 +262,7 @@ func TestPushImageJSONRegistry(t *testing.T) { func TestPushImageLayerRegistry(t *testing.T) { r := spawnTestRegistrySession(t) layer := strings.NewReader("") - _, _, err := r.PushImageLayerRegistry(imageID, layer, makeURL("/v1/"), token, []byte{}) + _, _, err := r.PushImageLayerRegistry(imageID, layer, makeURL("/v1/"), []byte{}) if err != nil { t.Fatal(err) } @@ -694,7 +682,7 @@ func TestNewIndexInfo(t *testing.T) { func TestPushRegistryTag(t *testing.T) { r := spawnTestRegistrySession(t) - err := r.PushRegistryTag("foo42/bar", imageID, "stable", makeURL("/v1/"), token) + err := r.PushRegistryTag("foo42/bar", imageID, "stable", makeURL("/v1/")) if err != nil { t.Fatal(err) } diff --git a/registry/service.go b/registry/service.go index 87fc1d076f..067df107c2 100644 --- a/registry/service.go +++ b/registry/service.go @@ -32,7 +32,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) { return "", err } authConfig.ServerAddress = endpoint.String() - return Login(authConfig, endpoint, HTTPRequestFactory(nil)) + return Login(authConfig, endpoint) } // Search queries the public registry for images matching the specified @@ -42,12 +42,13 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers if err != nil { return nil, err } + // *TODO: Search multiple indexes. endpoint, err := repoInfo.GetEndpoint() if err != nil { return nil, err } - r, err := NewSession(authConfig, HTTPRequestFactory(headers), endpoint, true) + r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint) if err != nil { return nil, err } diff --git a/registry/session.go b/registry/session.go index e65f82cd61..686e322dab 100644 --- a/registry/session.go +++ b/registry/session.go @@ -3,6 +3,7 @@ package registry import ( "bytes" "crypto/sha256" + "errors" // this is required for some certificates _ "crypto/sha512" "encoding/hex" @@ -20,64 +21,105 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/cliconfig" "github.com/docker/docker/pkg/httputils" - "github.com/docker/docker/pkg/requestdecorator" "github.com/docker/docker/pkg/tarsum" ) type Session struct { - authConfig *cliconfig.AuthConfig - reqFactory *requestdecorator.RequestFactory indexEndpoint *Endpoint - jar *cookiejar.Jar - timeout TimeoutType + client *http.Client + // TODO(tiborvass): remove authConfig + authConfig *cliconfig.AuthConfig } -func NewSession(authConfig *cliconfig.AuthConfig, factory *requestdecorator.RequestFactory, endpoint *Endpoint, timeout bool) (r *Session, err error) { - r = &Session{ - authConfig: authConfig, - indexEndpoint: endpoint, +// authTransport handles the auth layer when communicating with a v1 registry (private or official) +// +// For private v1 registries, set alwaysSetBasicAuth to true. +// +// For the official v1 registry, if there isn't already an Authorization header in the request, +// but there is an X-Docker-Token header set to true, then Basic Auth will be used to set the Authorization header. +// After sending the request with the provided base http.RoundTripper, if an X-Docker-Token header, representing +// a token, is present in the response, then it gets cached and sent in the Authorization header of all subsequent +// requests. +// +// If the server sends a token without the client having requested it, it is ignored. +// +// This RoundTripper also has a CancelRequest method important for correct timeout handling. +type authTransport struct { + http.RoundTripper + *cliconfig.AuthConfig + + alwaysSetBasicAuth bool + token []string +} + +func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = cloneRequest(req) + + if tr.alwaysSetBasicAuth { + req.SetBasicAuth(tr.Username, tr.Password) + return tr.RoundTripper.RoundTrip(req) } - if timeout { - r.timeout = ReceiveTimeout - } + var askedForToken bool - r.jar, err = cookiejar.New(nil) + // Don't override + if req.Header.Get("Authorization") == "" { + if req.Header.Get("X-Docker-Token") == "true" { + req.SetBasicAuth(tr.Username, tr.Password) + askedForToken = true + } else if len(tr.token) > 0 { + req.Header.Set("Authorization", "Token "+strings.Join(tr.token, ",")) + } + } + resp, err := tr.RoundTripper.RoundTrip(req) if err != nil { return nil, err } + if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 { + tr.token = resp.Header["X-Docker-Token"] + } + return resp, nil +} + +// TODO(tiborvass): remove authConfig param once registry client v2 is vendored +func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) { + r = &Session{ + authConfig: authConfig, + client: client, + indexEndpoint: endpoint, + } + + var alwaysSetBasicAuth bool // If we're working with a standalone private registry over HTTPS, send Basic Auth headers - // alongside our requests. - if r.indexEndpoint.VersionString(1) != IndexServerAddress() && r.indexEndpoint.URL.Scheme == "https" { - info, err := r.indexEndpoint.Ping() + // alongside all our requests. + if endpoint.VersionString(1) != IndexServerAddress() && endpoint.URL.Scheme == "https" { + info, err := endpoint.Ping() if err != nil { return nil, err } - if info.Standalone && authConfig != nil && factory != nil { - logrus.Debugf("Endpoint %s is eligible for private registry. Enabling decorator.", r.indexEndpoint.String()) - dec := requestdecorator.NewAuthDecorator(authConfig.Username, authConfig.Password) - factory.AddDecorator(dec) + + if info.Standalone && authConfig != nil { + logrus.Debugf("Endpoint %s is eligible for private registry. Enabling decorator.", endpoint.String()) + alwaysSetBasicAuth = true } } - r.reqFactory = factory - return r, nil -} + client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth} -func (r *Session) doRequest(req *http.Request) (*http.Response, *http.Client, error) { - return doRequest(req, r.jar, r.timeout, r.indexEndpoint.IsSecure) + jar, err := cookiejar.New(nil) + if err != nil { + return nil, errors.New("cookiejar.New is not supposed to return an error") + } + client.Jar = jar + + return r, nil } // Retrieve the history of a given image from the Registry. // Return a list of the parent's json (requested image included) -func (r *Session) GetRemoteHistory(imgID, registry string, token []string) ([]string, error) { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/ancestry", nil) - if err != nil { - return nil, err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) GetRemoteHistory(imgID, registry string) ([]string, error) { + res, err := r.client.Get(registry + "images/" + imgID + "/ancestry") if err != nil { return nil, err } @@ -89,27 +131,18 @@ func (r *Session) GetRemoteHistory(imgID, registry string, token []string) ([]st return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Server error: %d trying to fetch remote history for %s", res.StatusCode, imgID), res) } - jsonString, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("Error while reading the http response: %s", err) + var history []string + if err := json.NewDecoder(res.Body).Decode(&history); err != nil { + return nil, fmt.Errorf("Error while reading the http response: %v", err) } - logrus.Debugf("Ancestry: %s", jsonString) - history := new([]string) - if err := json.Unmarshal(jsonString, history); err != nil { - return nil, err - } - return *history, nil + logrus.Debugf("Ancestry: %v", history) + return history, nil } // Check if an image exists in the Registry -func (r *Session) LookupRemoteImage(imgID, registry string, token []string) error { - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/json", nil) - if err != nil { - return err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) LookupRemoteImage(imgID, registry string) error { + res, err := r.client.Get(registry + "images/" + imgID + "/json") if err != nil { return err } @@ -121,14 +154,8 @@ func (r *Session) LookupRemoteImage(imgID, registry string, token []string) erro } // Retrieve an image from the Registry. -func (r *Session) GetRemoteImageJSON(imgID, registry string, token []string) ([]byte, int, error) { - // Get the JSON - req, err := r.reqFactory.NewRequest("GET", registry+"images/"+imgID+"/json", nil) - if err != nil { - return nil, -1, fmt.Errorf("Failed to download json: %s", err) - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) +func (r *Session) GetRemoteImageJSON(imgID, registry string) ([]byte, int, error) { + res, err := r.client.Get(registry + "images/" + imgID + "/json") if err != nil { return nil, -1, fmt.Errorf("Failed to download json: %s", err) } @@ -147,44 +174,44 @@ func (r *Session) GetRemoteImageJSON(imgID, registry string, token []string) ([] jsonString, err := ioutil.ReadAll(res.Body) if err != nil { - return nil, -1, fmt.Errorf("Failed to parse downloaded json: %s (%s)", err, jsonString) + return nil, -1, fmt.Errorf("Failed to parse downloaded json: %v (%s)", err, jsonString) } return jsonString, imageSize, nil } -func (r *Session) GetRemoteImageLayer(imgID, registry string, token []string, imgSize int64) (io.ReadCloser, error) { +func (r *Session) GetRemoteImageLayer(imgID, registry string, imgSize int64) (io.ReadCloser, error) { var ( retries = 5 statusCode = 0 - client *http.Client res *http.Response + err error imageURL = fmt.Sprintf("%simages/%s/layer", registry, imgID) ) - req, err := r.reqFactory.NewRequest("GET", imageURL, nil) + req, err := http.NewRequest("GET", imageURL, nil) if err != nil { - return nil, fmt.Errorf("Error while getting from the server: %s\n", err) + return nil, fmt.Errorf("Error while getting from the server: %v", err) } - setTokenAuth(req, token) + // TODO: why are we doing retries at this level? + // These retries should be generic to both v1 and v2 for i := 1; i <= retries; i++ { statusCode = 0 - res, client, err = r.doRequest(req) - if err != nil { - logrus.Debugf("Error contacting registry: %s", err) - if res != nil { - if res.Body != nil { - res.Body.Close() - } - statusCode = res.StatusCode - } - if i == retries { - return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", - statusCode, imgID) - } - time.Sleep(time.Duration(i) * 5 * time.Second) - continue + res, err = r.client.Do(req) + if err == nil { + break } - break + logrus.Debugf("Error contacting registry %s: %v", registry, err) + if res != nil { + if res.Body != nil { + res.Body.Close() + } + statusCode = res.StatusCode + } + if i == retries { + return nil, fmt.Errorf("Server error: Status %d while fetching image layer (%s)", + statusCode, imgID) + } + time.Sleep(time.Duration(i) * 5 * time.Second) } if res.StatusCode != 200 { @@ -195,13 +222,13 @@ func (r *Session) GetRemoteImageLayer(imgID, registry string, token []string, im if res.Header.Get("Accept-Ranges") == "bytes" && imgSize > 0 { logrus.Debugf("server supports resume") - return httputils.ResumableRequestReaderWithInitialResponse(client, req, 5, imgSize, res), nil + return httputils.ResumableRequestReaderWithInitialResponse(r.client, req, 5, imgSize, res), nil } logrus.Debugf("server doesn't support resume") return res.Body, nil } -func (r *Session) GetRemoteTags(registries []string, repository string, token []string) (map[string]string, error) { +func (r *Session) GetRemoteTags(registries []string, repository string) (map[string]string, error) { if strings.Count(repository, "/") == 0 { // This will be removed once the Registry supports auto-resolution on // the "library" namespace @@ -209,13 +236,7 @@ func (r *Session) GetRemoteTags(registries []string, repository string, token [] } for _, host := range registries { endpoint := fmt.Sprintf("%srepositories/%s/tags", host, repository) - req, err := r.reqFactory.NewRequest("GET", endpoint, nil) - - if err != nil { - return nil, err - } - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Get(endpoint) if err != nil { return nil, err } @@ -263,16 +284,13 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { logrus.Debugf("[registry] Calling GET %s", repositoryTarget) - req, err := r.reqFactory.NewRequest("GET", repositoryTarget, nil) + req, err := http.NewRequest("GET", repositoryTarget, nil) if err != nil { return nil, err } - if r.authConfig != nil && len(r.authConfig.Username) > 0 { - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) - } + // this will set basic auth in r.client.Transport and send cached X-Docker-Token headers for all subsequent requests req.Header.Set("X-Docker-Token", "true") - - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, err } @@ -292,11 +310,6 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Error: Status %d trying to pull repository %s: %q", res.StatusCode, remote, errBody), res) } - var tokens []string - if res.Header.Get("X-Docker-Token") != "" { - tokens = res.Header["X-Docker-Token"] - } - var endpoints []string if res.Header.Get("X-Docker-Endpoints") != "" { endpoints, err = buildEndpointsList(res.Header["X-Docker-Endpoints"], r.indexEndpoint.VersionString(1)) @@ -322,29 +335,29 @@ func (r *Session) GetRepositoryData(remote string) (*RepositoryData, error) { return &RepositoryData{ ImgList: imgsData, Endpoints: endpoints, - Tokens: tokens, }, nil } -func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string, token []string) error { +func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string) error { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgData.ID+"/checksum") + u := registry + "images/" + imgData.ID + "/checksum" - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgData.ID+"/checksum", nil) + logrus.Debugf("[registry] Calling PUT %s", u) + + req, err := http.NewRequest("PUT", u, nil) if err != nil { return err } - setTokenAuth(req, token) req.Header.Set("X-Docker-Checksum", imgData.Checksum) req.Header.Set("X-Docker-Checksum-Payload", imgData.ChecksumPayload) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { - return fmt.Errorf("Failed to upload metadata: %s", err) + return fmt.Errorf("Failed to upload metadata: %v", err) } defer res.Body.Close() if len(res.Cookies()) > 0 { - r.jar.SetCookies(req.URL, res.Cookies()) + r.client.Jar.SetCookies(req.URL, res.Cookies()) } if res.StatusCode != 200 { errBody, err := ioutil.ReadAll(res.Body) @@ -363,18 +376,19 @@ func (r *Session) PushImageChecksumRegistry(imgData *ImgData, registry string, t } // Push a local image to the registry -func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, registry string, token []string) error { +func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, registry string) error { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgData.ID+"/json") + u := registry + "images/" + imgData.ID + "/json" - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgData.ID+"/json", bytes.NewReader(jsonRaw)) + logrus.Debugf("[registry] Calling PUT %s", u) + + req, err := http.NewRequest("PUT", u, bytes.NewReader(jsonRaw)) if err != nil { return err } req.Header.Add("Content-type", "application/json") - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return fmt.Errorf("Failed to upload metadata: %s", err) } @@ -398,9 +412,11 @@ func (r *Session) PushImageJSONRegistry(imgData *ImgData, jsonRaw []byte, regist return nil } -func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry string, token []string, jsonRaw []byte) (checksum string, checksumPayload string, err error) { +func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry string, jsonRaw []byte) (checksum string, checksumPayload string, err error) { - logrus.Debugf("[registry] Calling PUT %s", registry+"images/"+imgID+"/layer") + u := registry + "images/" + imgID + "/layer" + + logrus.Debugf("[registry] Calling PUT %s", u) tarsumLayer, err := tarsum.NewTarSum(layer, false, tarsum.Version0) if err != nil { @@ -411,17 +427,16 @@ func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry h.Write([]byte{'\n'}) checksumLayer := io.TeeReader(tarsumLayer, h) - req, err := r.reqFactory.NewRequest("PUT", registry+"images/"+imgID+"/layer", checksumLayer) + req, err := http.NewRequest("PUT", u, checksumLayer) if err != nil { return "", "", err } req.Header.Add("Content-Type", "application/octet-stream") req.ContentLength = -1 req.TransferEncoding = []string{"chunked"} - setTokenAuth(req, token) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { - return "", "", fmt.Errorf("Failed to upload layer: %s", err) + return "", "", fmt.Errorf("Failed to upload layer: %v", err) } if rc, ok := layer.(io.Closer); ok { if err := rc.Close(); err != nil { @@ -444,19 +459,18 @@ func (r *Session) PushImageLayerRegistry(imgID string, layer io.Reader, registry // push a tag on the registry. // Remote has the format '/ -func (r *Session) PushRegistryTag(remote, revision, tag, registry string, token []string) error { +func (r *Session) PushRegistryTag(remote, revision, tag, registry string) error { // "jsonify" the string revision = "\"" + revision + "\"" path := fmt.Sprintf("repositories/%s/tags/%s", remote, tag) - req, err := r.reqFactory.NewRequest("PUT", registry+path, strings.NewReader(revision)) + req, err := http.NewRequest("PUT", registry+path, strings.NewReader(revision)) if err != nil { return err } req.Header.Add("Content-type", "application/json") - setTokenAuth(req, token) req.ContentLength = int64(len(revision)) - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -491,7 +505,8 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate logrus.Debugf("[registry] PUT %s", u) logrus.Debugf("Image list pushed to index:\n%s", imgListJSON) headers := map[string][]string{ - "Content-type": {"application/json"}, + "Content-type": {"application/json"}, + // this will set basic auth in r.client.Transport and send cached X-Docker-Token headers for all subsequent requests "X-Docker-Token": {"true"}, } if validate { @@ -526,9 +541,6 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate } return nil, httputils.NewHTTPRequestError(fmt.Sprintf("Error: Status %d trying to push repository %s: %q", res.StatusCode, remote, errBody), res) } - if res.Header.Get("X-Docker-Token") == "" { - return nil, fmt.Errorf("Index response didn't contain an access token") - } tokens = res.Header["X-Docker-Token"] logrus.Debugf("Auth token: %v", tokens) @@ -539,8 +551,7 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate if err != nil { return nil, err } - } - if validate { + } else { if res.StatusCode != 204 { errBody, err := ioutil.ReadAll(res.Body) if err != nil { @@ -551,22 +562,20 @@ func (r *Session) PushImageJSONIndex(remote string, imgList []*ImgData, validate } return &RepositoryData{ - Tokens: tokens, Endpoints: endpoints, }, nil } func (r *Session) putImageRequest(u string, headers map[string][]string, body []byte) (*http.Response, error) { - req, err := r.reqFactory.NewRequest("PUT", u, bytes.NewReader(body)) + req, err := http.NewRequest("PUT", u, bytes.NewReader(body)) if err != nil { return nil, err } - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) req.ContentLength = int64(len(body)) for k, v := range headers { req.Header[k] = v } - response, _, err := r.doRequest(req) + response, err := r.client.Do(req) if err != nil { return nil, err } @@ -580,15 +589,7 @@ func shouldRedirect(response *http.Response) bool { func (r *Session) SearchRepositories(term string) (*SearchResults, error) { logrus.Debugf("Index server: %s", r.indexEndpoint) u := r.indexEndpoint.VersionString(1) + "search?q=" + url.QueryEscape(term) - req, err := r.reqFactory.NewRequest("GET", u, nil) - if err != nil { - return nil, err - } - if r.authConfig != nil && len(r.authConfig.Username) > 0 { - req.SetBasicAuth(r.authConfig.Username, r.authConfig.Password) - } - req.Header.Set("X-Docker-Token", "true") - res, _, err := r.doRequest(req) + res, err := r.client.Get(u) if err != nil { return nil, err } @@ -600,6 +601,7 @@ func (r *Session) SearchRepositories(term string) (*SearchResults, error) { return result, json.NewDecoder(res.Body).Decode(result) } +// TODO(tiborvass): remove this once registry client v2 is vendored func (r *Session) GetAuthConfig(withPasswd bool) *cliconfig.AuthConfig { password := "" if withPasswd { @@ -611,9 +613,3 @@ func (r *Session) GetAuthConfig(withPasswd bool) *cliconfig.AuthConfig { Email: r.authConfig.Email, } } - -func setTokenAuth(req *http.Request, token []string) { - if req.Header.Get("Authorization") == "" { // Don't override - req.Header.Set("Authorization", "Token "+strings.Join(token, ",")) - } -} diff --git a/registry/session_v2.go b/registry/session_v2.go index 4188e505bd..c639f9226a 100644 --- a/registry/session_v2.go +++ b/registry/session_v2.go @@ -77,14 +77,14 @@ func (r *Session) GetV2ImageManifest(ep *Endpoint, imageName, tagName string, au method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, "", err } if err := auth.Authorize(req); err != nil { return nil, "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, "", err } @@ -118,14 +118,14 @@ func (r *Session) HeadV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Di method := "HEAD" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return false, err } if err := auth.Authorize(req); err != nil { return false, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return false, err } @@ -152,14 +152,14 @@ func (r *Session) GetV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return err } if err := auth.Authorize(req); err != nil { return err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -183,14 +183,14 @@ func (r *Session) GetV2ImageBlobReader(ep *Endpoint, imageName string, dgst dige method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, 0, err } if err := auth.Authorize(req); err != nil { return nil, 0, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, 0, err } @@ -220,7 +220,7 @@ func (r *Session) PutV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig method := "PUT" logrus.Debugf("[registry] Calling %q %s", method, location) - req, err := r.reqFactory.NewRequest(method, location, ioutil.NopCloser(blobRdr)) + req, err := http.NewRequest(method, location, ioutil.NopCloser(blobRdr)) if err != nil { return err } @@ -230,7 +230,7 @@ func (r *Session) PutV2ImageBlob(ep *Endpoint, imageName string, dgst digest.Dig if err := auth.Authorize(req); err != nil { return err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return err } @@ -259,7 +259,7 @@ func (r *Session) initiateBlobUpload(ep *Endpoint, imageName string, auth *Reque } logrus.Debugf("[registry] Calling %q %s", "POST", routeURL) - req, err := r.reqFactory.NewRequest("POST", routeURL, nil) + req, err := http.NewRequest("POST", routeURL, nil) if err != nil { return "", err } @@ -267,7 +267,7 @@ func (r *Session) initiateBlobUpload(ep *Endpoint, imageName string, auth *Reque if err := auth.Authorize(req); err != nil { return "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return "", err } @@ -305,14 +305,14 @@ func (r *Session) PutV2ImageManifest(ep *Endpoint, imageName, tagName string, si method := "PUT" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, bytes.NewReader(signedManifest)) + req, err := http.NewRequest(method, routeURL, bytes.NewReader(signedManifest)) if err != nil { return "", err } if err := auth.Authorize(req); err != nil { return "", err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return "", err } @@ -366,14 +366,14 @@ func (r *Session) GetV2RemoteTags(ep *Endpoint, imageName string, auth *RequestA method := "GET" logrus.Debugf("[registry] Calling %q %s", method, routeURL) - req, err := r.reqFactory.NewRequest(method, routeURL, nil) + req, err := http.NewRequest(method, routeURL, nil) if err != nil { return nil, err } if err := auth.Authorize(req); err != nil { return nil, err } - res, _, err := r.doRequest(req) + res, err := r.client.Do(req) if err != nil { return nil, err } diff --git a/registry/token.go b/registry/token.go index b03bd891bb..af7d5f3fcb 100644 --- a/registry/token.go +++ b/registry/token.go @@ -7,15 +7,13 @@ import ( "net/http" "net/url" "strings" - - "github.com/docker/docker/pkg/requestdecorator" ) type tokenResponse struct { Token string `json:"token"` } -func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client, factory *requestdecorator.RequestFactory) (token string, err error) { +func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) { realm, ok := params["realm"] if !ok { return "", errors.New("no realm specified for token auth challenge") @@ -34,7 +32,7 @@ func getToken(username, password string, params map[string]string, registryEndpo } } - req, err := factory.NewRequest("GET", realmURL.String(), nil) + req, err := http.NewRequest("GET", realmURL.String(), nil) if err != nil { return "", err } From cf8c0d0f56021fbea7bc89e378bb937b53aeca3b Mon Sep 17 00:00:00 2001 From: Tibor Vass Date: Fri, 15 May 2015 15:03:08 -0700 Subject: [PATCH 2/3] requestdecorator: repurpose the package and rename to useragent Signed-off-by: Tibor Vass --- pkg/requestdecorator/README.md | 2 - pkg/requestdecorator/requestdecorator.go | 172 -------------- pkg/requestdecorator/requestdecorator_test.go | 222 ------------------ pkg/useragent/README.md | 1 + pkg/useragent/useragent.go | 60 +++++ pkg/useragent/useragent_test.go | 31 +++ registry/registry.go | 18 +- 7 files changed, 101 insertions(+), 405 deletions(-) delete mode 100644 pkg/requestdecorator/README.md delete mode 100644 pkg/requestdecorator/requestdecorator.go delete mode 100644 pkg/requestdecorator/requestdecorator_test.go create mode 100644 pkg/useragent/README.md create mode 100644 pkg/useragent/useragent.go create mode 100644 pkg/useragent/useragent_test.go diff --git a/pkg/requestdecorator/README.md b/pkg/requestdecorator/README.md deleted file mode 100644 index 76f8ca798f..0000000000 --- a/pkg/requestdecorator/README.md +++ /dev/null @@ -1,2 +0,0 @@ -This package provides helper functions for decorating a request with user agent -versions, auth, meta headers. diff --git a/pkg/requestdecorator/requestdecorator.go b/pkg/requestdecorator/requestdecorator.go deleted file mode 100644 index 079e38b21e..0000000000 --- a/pkg/requestdecorator/requestdecorator.go +++ /dev/null @@ -1,172 +0,0 @@ -// Package requestdecorator provides helper functions to decorate a request with -// user agent versions, auth, meta headers. -package requestdecorator - -import ( - "errors" - "io" - "net/http" - "strings" - - "github.com/Sirupsen/logrus" -) - -var ( - ErrNilRequest = errors.New("request cannot be nil") -) - -// UAVersionInfo is used to model UserAgent versions. -type UAVersionInfo struct { - Name string - Version string -} - -func NewUAVersionInfo(name, version string) UAVersionInfo { - return UAVersionInfo{ - Name: name, - Version: version, - } -} - -func (vi *UAVersionInfo) isValid() bool { - const stopChars = " \t\r\n/" - name := vi.Name - vers := vi.Version - if len(name) == 0 || strings.ContainsAny(name, stopChars) { - return false - } - if len(vers) == 0 || strings.ContainsAny(vers, stopChars) { - return false - } - return true -} - -// Convert versions to a string and append the string to the string base. -// -// Each UAVersionInfo will be converted to a string in the format of -// "product/version", where the "product" is get from the name field, while -// version is get from the version field. Several pieces of verson information -// will be concatinated and separated by space. -func AppendVersions(base string, versions ...UAVersionInfo) string { - if len(versions) == 0 { - return base - } - - verstrs := make([]string, 0, 1+len(versions)) - if len(base) > 0 { - verstrs = append(verstrs, base) - } - - for _, v := range versions { - if !v.isValid() { - continue - } - verstrs = append(verstrs, v.Name+"/"+v.Version) - } - return strings.Join(verstrs, " ") -} - -// Decorator is used to change an instance of -// http.Request. It could be used to add more header fields, -// change body, etc. -type Decorator interface { - // ChangeRequest() changes the request accordingly. - // The changed request will be returned or err will be non-nil - // if an error occur. - ChangeRequest(req *http.Request) (newReq *http.Request, err error) -} - -// UserAgentDecorator appends the product/version to the user agent field -// of a request. -type UserAgentDecorator struct { - Versions []UAVersionInfo -} - -func (h *UserAgentDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if req == nil { - return req, ErrNilRequest - } - - userAgent := AppendVersions(req.UserAgent(), h.Versions...) - if len(userAgent) > 0 { - req.Header.Set("User-Agent", userAgent) - } - return req, nil -} - -type MetaHeadersDecorator struct { - Headers map[string][]string -} - -func (h *MetaHeadersDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if h.Headers == nil { - return req, ErrNilRequest - } - for k, v := range h.Headers { - req.Header[k] = v - } - return req, nil -} - -type AuthDecorator struct { - login string - password string -} - -func NewAuthDecorator(login, password string) Decorator { - return &AuthDecorator{ - login: login, - password: password, - } -} - -func (self *AuthDecorator) ChangeRequest(req *http.Request) (*http.Request, error) { - if req == nil { - return req, ErrNilRequest - } - req.SetBasicAuth(self.login, self.password) - return req, nil -} - -// RequestFactory creates an HTTP request -// and applies a list of decorators on the request. -type RequestFactory struct { - decorators []Decorator -} - -func NewRequestFactory(d ...Decorator) *RequestFactory { - return &RequestFactory{ - decorators: d, - } -} - -func (f *RequestFactory) AddDecorator(d ...Decorator) { - f.decorators = append(f.decorators, d...) -} - -func (f *RequestFactory) GetDecorators() []Decorator { - return f.decorators -} - -// NewRequest() creates a new *http.Request, -// applies all decorators in the Factory on the request, -// then applies decorators provided by d on the request. -func (h *RequestFactory) NewRequest(method, urlStr string, body io.Reader, d ...Decorator) (*http.Request, error) { - req, err := http.NewRequest(method, urlStr, body) - if err != nil { - return nil, err - } - - // By default, a nil factory should work. - if h == nil { - return req, nil - } - for _, dec := range h.decorators { - req, _ = dec.ChangeRequest(req) - } - for _, dec := range d { - req, _ = dec.ChangeRequest(req) - } - logrus.Debugf("%v -- HEADERS: %v", req.URL, req.Header) - return req, err -} diff --git a/pkg/requestdecorator/requestdecorator_test.go b/pkg/requestdecorator/requestdecorator_test.go deleted file mode 100644 index ed61135467..0000000000 --- a/pkg/requestdecorator/requestdecorator_test.go +++ /dev/null @@ -1,222 +0,0 @@ -package requestdecorator - -import ( - "net/http" - "strings" - "testing" -) - -func TestUAVersionInfo(t *testing.T) { - uavi := NewUAVersionInfo("foo", "bar") - if !uavi.isValid() { - t.Fatalf("UAVersionInfo should be valid") - } - uavi = NewUAVersionInfo("", "bar") - if uavi.isValid() { - t.Fatalf("Expected UAVersionInfo to be invalid") - } - uavi = NewUAVersionInfo("foo", "") - if uavi.isValid() { - t.Fatalf("Expected UAVersionInfo to be invalid") - } -} - -func TestUserAgentDecorator(t *testing.T) { - httpVersion := make([]UAVersionInfo, 2) - httpVersion = append(httpVersion, NewUAVersionInfo("testname", "testversion")) - httpVersion = append(httpVersion, NewUAVersionInfo("name", "version")) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := uad.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - if reqDecorated.Header.Get("User-Agent") != "testname/testversion name/version" { - t.Fatalf("Request should have User-Agent 'testname/testversion name/version'") - } -} - -func TestUserAgentDecoratorErr(t *testing.T) { - httpVersion := make([]UAVersionInfo, 0) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - var req *http.Request - _, err := uad.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestMetaHeadersDecorator(t *testing.T) { - var headers = map[string][]string{ - "key1": {"value1"}, - "key2": {"value2"}, - } - mhd := &MetaHeadersDecorator{ - Headers: headers, - } - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := mhd.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - v, ok := reqDecorated.Header["key1"] - if !ok { - t.Fatalf("Expected to have header key1") - } - if v[0] != "value1" { - t.Fatalf("Expected value for key1 isn't value1") - } - - v, ok = reqDecorated.Header["key2"] - if !ok { - t.Fatalf("Expected to have header key2") - } - if v[0] != "value2" { - t.Fatalf("Expected value for key2 isn't value2") - } -} - -func TestMetaHeadersDecoratorErr(t *testing.T) { - mhd := &MetaHeadersDecorator{} - - var req *http.Request - _, err := mhd.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestAuthDecorator(t *testing.T) { - ad := NewAuthDecorator("test", "password") - - req, err := http.NewRequest("GET", "/something", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - reqDecorated, err := ad.ChangeRequest(req) - if err != nil { - t.Fatal(err) - } - - username, password, ok := reqDecorated.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password" { - t.Fatalf("Expected password to be password, got %s", password) - } -} - -func TestAuthDecoratorErr(t *testing.T) { - ad := &AuthDecorator{} - - var req *http.Request - _, err := ad.ChangeRequest(req) - if err == nil { - t.Fatalf("Expected to get ErrNilRequest instead no error was returned") - } -} - -func TestRequestFactory(t *testing.T) { - ad := NewAuthDecorator("test", "password") - httpVersion := make([]UAVersionInfo, 2) - httpVersion = append(httpVersion, NewUAVersionInfo("testname", "testversion")) - httpVersion = append(httpVersion, NewUAVersionInfo("name", "version")) - uad := &UserAgentDecorator{ - Versions: httpVersion, - } - - requestFactory := NewRequestFactory(ad, uad) - - if l := len(requestFactory.GetDecorators()); l != 2 { - t.Fatalf("Expected to have two decorators, got %d", l) - } - - req, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test")) - if err != nil { - t.Fatal(err) - } - - username, password, ok := req.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password" { - t.Fatalf("Expected password to be password, got %s", password) - } - if req.Header.Get("User-Agent") != "testname/testversion name/version" { - t.Fatalf("Request should have User-Agent 'testname/testversion name/version'") - } -} - -func TestRequestFactoryNewRequestWithDecorators(t *testing.T) { - ad := NewAuthDecorator("test", "password") - - requestFactory := NewRequestFactory(ad) - - if l := len(requestFactory.GetDecorators()); l != 1 { - t.Fatalf("Expected to have one decorators, got %d", l) - } - - ad2 := NewAuthDecorator("test2", "password2") - - req, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test"), ad2) - if err != nil { - t.Fatal(err) - } - - username, password, ok := req.BasicAuth() - if !ok { - t.Fatalf("Cannot retrieve basic auth info from request") - } - if username != "test2" { - t.Fatalf("Expected username to be test, got %s", username) - } - if password != "password2" { - t.Fatalf("Expected password to be password, got %s", password) - } -} - -func TestRequestFactoryAddDecorator(t *testing.T) { - requestFactory := NewRequestFactory() - - if l := len(requestFactory.GetDecorators()); l != 0 { - t.Fatalf("Expected to have zero decorators, got %d", l) - } - - ad := NewAuthDecorator("test", "password") - requestFactory.AddDecorator(ad) - - if l := len(requestFactory.GetDecorators()); l != 1 { - t.Fatalf("Expected to have one decorators, got %d", l) - } -} - -func TestRequestFactoryNil(t *testing.T) { - var requestFactory RequestFactory - _, err := requestFactory.NewRequest("GET", "/test", strings.NewReader("test")) - if err != nil { - t.Fatalf("Expected not to get and error, got %s", err) - } -} diff --git a/pkg/useragent/README.md b/pkg/useragent/README.md new file mode 100644 index 0000000000..d9cb367d10 --- /dev/null +++ b/pkg/useragent/README.md @@ -0,0 +1 @@ +This package provides helper functions to pack version information into a single User-Agent header. diff --git a/pkg/useragent/useragent.go b/pkg/useragent/useragent.go new file mode 100644 index 0000000000..9e35d1c70d --- /dev/null +++ b/pkg/useragent/useragent.go @@ -0,0 +1,60 @@ +// Package useragent provides helper functions to pack +// version information into a single User-Agent header. +package useragent + +import ( + "errors" + "strings" +) + +var ( + ErrNilRequest = errors.New("request cannot be nil") +) + +// VersionInfo is used to model UserAgent versions. +type VersionInfo struct { + Name string + Version string +} + +func (vi *VersionInfo) isValid() bool { + const stopChars = " \t\r\n/" + name := vi.Name + vers := vi.Version + if len(name) == 0 || strings.ContainsAny(name, stopChars) { + return false + } + if len(vers) == 0 || strings.ContainsAny(vers, stopChars) { + return false + } + return true +} + +// Convert versions to a string and append the string to the string base. +// +// Each VersionInfo will be converted to a string in the format of +// "product/version", where the "product" is get from the name field, while +// version is get from the version field. Several pieces of verson information +// will be concatinated and separated by space. +// +// Example: +// AppendVersions("base", VersionInfo{"foo", "1.0"}, VersionInfo{"bar", "2.0"}) +// results in "base foo/1.0 bar/2.0". +func AppendVersions(base string, versions ...VersionInfo) string { + if len(versions) == 0 { + return base + } + + verstrs := make([]string, 0, 1+len(versions)) + if len(base) > 0 { + verstrs = append(verstrs, base) + } + + for _, v := range versions { + if !v.isValid() { + continue + } + verstrs = append(verstrs, v.Name+"/"+v.Version) + } + return strings.Join(verstrs, " ") +} diff --git a/pkg/useragent/useragent_test.go b/pkg/useragent/useragent_test.go new file mode 100644 index 0000000000..0ad7243a6d --- /dev/null +++ b/pkg/useragent/useragent_test.go @@ -0,0 +1,31 @@ +package useragent + +import "testing" + +func TestVersionInfo(t *testing.T) { + vi := VersionInfo{"foo", "bar"} + if !vi.isValid() { + t.Fatalf("VersionInfo should be valid") + } + vi = VersionInfo{"", "bar"} + if vi.isValid() { + t.Fatalf("Expected VersionInfo to be invalid") + } + vi = VersionInfo{"foo", ""} + if vi.isValid() { + t.Fatalf("Expected VersionInfo to be invalid") + } +} + +func TestAppendVersions(t *testing.T) { + vis := []VersionInfo{ + {"foo", "1.0"}, + {"bar", "0.1"}, + {"pi", "3.1.4"}, + } + v := AppendVersions("base", vis...) + expect := "base foo/1.0 bar/0.1 pi/3.1.4" + if v != expect { + t.Fatalf("expected %q, got %q", expect, v) + } +} diff --git a/registry/registry.go b/registry/registry.go index db77a985eb..4f5403002a 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -18,8 +18,8 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/docker/autogen/dockerversion" "github.com/docker/docker/pkg/parsers/kernel" - "github.com/docker/docker/pkg/requestdecorator" "github.com/docker/docker/pkg/timeoutconn" + "github.com/docker/docker/pkg/useragent" ) var ( @@ -186,17 +186,17 @@ func cloneRequest(r *http.Request) *http.Request { func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) { req = cloneRequest(req) - httpVersion := make([]requestdecorator.UAVersionInfo, 0, 4) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("docker", dockerversion.VERSION)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("go", runtime.Version())) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("git-commit", dockerversion.GITCOMMIT)) + httpVersion := make([]useragent.VersionInfo, 0, 4) + httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) + httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) + httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) if kernelVersion, err := kernel.GetKernelVersion(); err == nil { - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("kernel", kernelVersion.String())) + httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) } - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("os", runtime.GOOS)) - httpVersion = append(httpVersion, requestdecorator.NewUAVersionInfo("arch", runtime.GOARCH)) + httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) + httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) - userAgent := requestdecorator.AppendVersions(req.UserAgent(), httpVersion...) + userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...) req.Header.Set("User-Agent", userAgent) From 73823e5e56446b23ce01bb8e44a9670ab2552b0a Mon Sep 17 00:00:00 2001 From: Tibor Vass Date: Fri, 15 May 2015 18:35:04 -0700 Subject: [PATCH 3/3] Add transport package to support CancelRequest Signed-off-by: Tibor Vass --- graph/pull.go | 12 +-- graph/push.go | 12 +-- pkg/transport/LICENSE | 27 +++++++ pkg/transport/transport.go | 148 +++++++++++++++++++++++++++++++++++++ registry/auth.go | 26 +++---- registry/endpoint.go | 25 +++---- registry/endpoint_test.go | 3 +- registry/registry.go | 106 +++++++++++--------------- registry/registry_test.go | 15 ++-- registry/service.go | 12 ++- registry/session.go | 59 ++++++++++++--- registry/session_v2.go | 4 +- registry/token.go | 4 +- 13 files changed, 325 insertions(+), 128 deletions(-) create mode 100644 pkg/transport/LICENSE create mode 100644 pkg/transport/transport.go diff --git a/graph/pull.go b/graph/pull.go index 013729a8c0..0c33e313a8 100644 --- a/graph/pull.go +++ b/graph/pull.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/utils" ) @@ -55,16 +56,17 @@ func (s *TagStore) Pull(image string, tag string, imagePullConfig *ImagePullConf defer s.poolRemove("pull", utils.ImageReference(repoInfo.LocalName, tag)) logrus.Debugf("pulling image from host %q with remote name %q", repoInfo.Index.Name, repoInfo.RemoteName) - endpoint, err := repoInfo.GetEndpoint() + + endpoint, err := repoInfo.GetEndpoint(imagePullConfig.MetaHeaders) if err != nil { return err } - + // TODO(tiborvass): reuse client from endpoint? // Adds Docker-specific headers as well as user-specified headers (metaHeaders) - tr := ®istry.DockerHeaders{ + tr := transport.NewTransport( registry.NewTransport(registry.ReceiveTimeout, endpoint.IsSecure), - imagePullConfig.MetaHeaders, - } + registry.DockerHeaders(imagePullConfig.MetaHeaders)..., + ) client := registry.HTTPClient(tr) r, err := registry.NewSession(client, imagePullConfig.AuthConfig, endpoint) if err != nil { diff --git a/graph/push.go b/graph/push.go index 6e9a367bcb..817ef707fc 100644 --- a/graph/push.go +++ b/graph/push.go @@ -18,6 +18,7 @@ import ( "github.com/docker/docker/pkg/progressreader" "github.com/docker/docker/pkg/streamformatter" "github.com/docker/docker/pkg/stringid" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/registry" "github.com/docker/docker/runconfig" "github.com/docker/docker/utils" @@ -509,16 +510,17 @@ func (s *TagStore) Push(localName string, imagePushConfig *ImagePushConfig) erro } defer s.poolRemove("push", repoInfo.LocalName) - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(imagePushConfig.MetaHeaders) if err != nil { return err } - + // TODO(tiborvass): reuse client from endpoint? // Adds Docker-specific headers as well as user-specified headers (metaHeaders) - tr := ®istry.DockerHeaders{ + tr := transport.NewTransport( registry.NewTransport(registry.NoTimeout, endpoint.IsSecure), - imagePushConfig.MetaHeaders, - } + registry.DockerHeaders(imagePushConfig.MetaHeaders)..., + ) + client := registry.HTTPClient(tr) r, err := registry.NewSession(client, imagePushConfig.AuthConfig, endpoint) if err != nil { return err diff --git a/pkg/transport/LICENSE b/pkg/transport/LICENSE new file mode 100644 index 0000000000..d02f24fd52 --- /dev/null +++ b/pkg/transport/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The oauth2 Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go new file mode 100644 index 0000000000..510d8b4bc2 --- /dev/null +++ b/pkg/transport/transport.go @@ -0,0 +1,148 @@ +package transport + +import ( + "io" + "net/http" + "sync" +) + +type RequestModifier interface { + ModifyRequest(*http.Request) error +} + +type headerModifier http.Header + +// NewHeaderRequestModifier returns a RequestModifier that merges the HTTP headers +// passed as an argument, with the HTTP headers of a request. +// +// If the same key is present in both, the modifying header values for that key, +// are appended to the values for that same key in the request header. +func NewHeaderRequestModifier(header http.Header) RequestModifier { + return headerModifier(header) +} + +func (h headerModifier) ModifyRequest(req *http.Request) error { + for k, s := range http.Header(h) { + req.Header[k] = append(req.Header[k], s...) + } + + return nil +} + +// NewTransport returns an http.RoundTripper that modifies requests according to +// the RequestModifiers passed in the arguments, before sending the requests to +// the base http.RoundTripper (which, if nil, defaults to http.DefaultTransport). +func NewTransport(base http.RoundTripper, modifiers ...RequestModifier) http.RoundTripper { + return &transport{ + Modifiers: modifiers, + Base: base, + } +} + +// transport is an http.RoundTripper that makes HTTP requests after +// copying and modifying the request +type transport struct { + Modifiers []RequestModifier + Base http.RoundTripper + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := CloneRequest(req) + for _, modifier := range t.Modifiers { + if err := modifier.ModifyRequest(req2); err != nil { + return nil, err + } + } + + t.setModReq(req, req2) + res, err := t.base().RoundTrip(req2) + if err != nil { + t.setModReq(req, nil) + return nil, err + } + res.Body = &OnEOFReader{ + Rc: res.Body, + Fn: func() { t.setModReq(req, nil) }, + } + return res, nil +} + +// CancelRequest cancels an in-flight request by closing its connection. +func (t *transport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := t.base().(canceler); ok { + t.mu.Lock() + modReq := t.modReq[req] + delete(t.modReq, req) + t.mu.Unlock() + cr.CancelRequest(modReq) + } +} + +func (t *transport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} + +func (t *transport) setModReq(orig, mod *http.Request) { + t.mu.Lock() + defer t.mu.Unlock() + if t.modReq == nil { + t.modReq = make(map[*http.Request]*http.Request) + } + if mod == nil { + delete(t.modReq, orig) + } else { + t.modReq[orig] = mod + } +} + +// CloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +func CloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + + return r2 +} + +// OnEOFReader ensures a callback function is called +// on Close() and when the underlying Reader returns an io.EOF error +type OnEOFReader struct { + Rc io.ReadCloser + Fn func() +} + +func (r *OnEOFReader) Read(p []byte) (n int, err error) { + n, err = r.Rc.Read(p) + if err == io.EOF { + r.runFunc() + } + return +} + +func (r *OnEOFReader) Close() error { + err := r.Rc.Close() + r.runFunc() + return err +} + +func (r *OnEOFReader) runFunc() { + if fn := r.Fn; fn != nil { + fn() + r.Fn = nil + } +} diff --git a/registry/auth.go b/registry/auth.go index 0b6c3b0f95..33f8fa0689 100644 --- a/registry/auth.go +++ b/registry/auth.go @@ -44,8 +44,6 @@ func (auth *RequestAuthorization) getToken() (string, error) { return auth.tokenCache, nil } - client := auth.registryEndpoint.HTTPClient() - for _, challenge := range auth.registryEndpoint.AuthChallenges { switch strings.ToLower(challenge.Scheme) { case "basic": @@ -57,7 +55,7 @@ func (auth *RequestAuthorization) getToken() (string, error) { params[k] = v } params["scope"] = fmt.Sprintf("%s:%s:%s", auth.resource, auth.scope, strings.Join(auth.actions, ",")) - token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint, client) + token, err := getToken(auth.authConfig.Username, auth.authConfig.Password, params, auth.registryEndpoint) if err != nil { return "", err } @@ -104,7 +102,6 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri status string reqBody []byte err error - client = registryEndpoint.HTTPClient() reqStatusCode = 0 serverAddress = authConfig.ServerAddress ) @@ -128,7 +125,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri // using `bytes.NewReader(jsonBody)` here causes the server to respond with a 411 status. b := strings.NewReader(string(jsonBody)) - req1, err := client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) + req1, err := registryEndpoint.client.Post(serverAddress+"users/", "application/json; charset=utf-8", b) if err != nil { return "", fmt.Errorf("Server Error: %s", err) } @@ -151,7 +148,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri if string(reqBody) == "\"Username or email already exists\"" { req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -180,7 +177,7 @@ func loginV1(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri // protected, so people can use `docker login` as an auth check. req, err := http.NewRequest("GET", serverAddress+"users/", nil) req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err } @@ -217,7 +214,6 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri var ( err error allErrors []error - client = registryEndpoint.HTTPClient() ) for _, challenge := range registryEndpoint.AuthChallenges { @@ -225,9 +221,9 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri switch strings.ToLower(challenge.Scheme) { case "basic": - err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) + err = tryV2BasicAuthLogin(authConfig, challenge.Parameters, registryEndpoint) case "bearer": - err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint, client) + err = tryV2TokenAuthLogin(authConfig, challenge.Parameters, registryEndpoint) default: // Unsupported challenge types are explicitly skipped. err = fmt.Errorf("unsupported auth scheme: %q", challenge.Scheme) @@ -245,7 +241,7 @@ func loginV2(authConfig *cliconfig.AuthConfig, registryEndpoint *Endpoint) (stri return "", fmt.Errorf("no successful auth challenge for %s - errors: %s", registryEndpoint, allErrors) } -func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { +func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { req, err := http.NewRequest("GET", registryEndpoint.Path(""), nil) if err != nil { return err @@ -253,7 +249,7 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str req.SetBasicAuth(authConfig.Username, authConfig.Password) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } @@ -266,8 +262,8 @@ func tryV2BasicAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str return nil } -func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint, client *http.Client) error { - token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint, client) +func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]string, registryEndpoint *Endpoint) error { + token, err := getToken(authConfig.Username, authConfig.Password, params, registryEndpoint) if err != nil { return err } @@ -279,7 +275,7 @@ func tryV2TokenAuthLogin(authConfig *cliconfig.AuthConfig, params map[string]str req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return err } diff --git a/registry/endpoint.go b/registry/endpoint.go index 25f66ad25f..ce92668f41 100644 --- a/registry/endpoint.go +++ b/registry/endpoint.go @@ -11,6 +11,7 @@ import ( "github.com/Sirupsen/logrus" "github.com/docker/distribution/registry/api/v2" + "github.com/docker/docker/pkg/transport" ) // for mocking in unit tests @@ -41,9 +42,9 @@ func scanForAPIVersion(address string) (string, APIVersion) { } // NewEndpoint parses the given address to return a registry endpoint. -func NewEndpoint(index *IndexInfo) (*Endpoint, error) { +func NewEndpoint(index *IndexInfo, metaHeaders http.Header) (*Endpoint, error) { // *TODO: Allow per-registry configuration of endpoints. - endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure) + endpoint, err := newEndpoint(index.GetAuthConfigKey(), index.Secure, metaHeaders) if err != nil { return nil, err } @@ -81,7 +82,7 @@ func validateEndpoint(endpoint *Endpoint) error { return nil } -func newEndpoint(address string, secure bool) (*Endpoint, error) { +func newEndpoint(address string, secure bool, metaHeaders http.Header) (*Endpoint, error) { var ( endpoint = new(Endpoint) trimmedAddress string @@ -98,11 +99,13 @@ func newEndpoint(address string, secure bool) (*Endpoint, error) { return nil, err } endpoint.IsSecure = secure + tr := NewTransport(ConnectTimeout, endpoint.IsSecure) + endpoint.client = HTTPClient(transport.NewTransport(tr, DockerHeaders(metaHeaders)...)) return endpoint, nil } -func (repoInfo *RepositoryInfo) GetEndpoint() (*Endpoint, error) { - return NewEndpoint(repoInfo.Index) +func (repoInfo *RepositoryInfo) GetEndpoint(metaHeaders http.Header) (*Endpoint, error) { + return NewEndpoint(repoInfo.Index, metaHeaders) } // Endpoint stores basic information about a registry endpoint. @@ -174,7 +177,7 @@ func (e *Endpoint) pingV1() (RegistryInfo, error) { return RegistryInfo{Standalone: false}, err } - resp, err := e.HTTPClient().Do(req) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{Standalone: false}, err } @@ -222,7 +225,7 @@ func (e *Endpoint) pingV2() (RegistryInfo, error) { return RegistryInfo{}, err } - resp, err := e.HTTPClient().Do(req) + resp, err := e.client.Do(req) if err != nil { return RegistryInfo{}, err } @@ -261,11 +264,3 @@ HeaderLoop: return RegistryInfo{}, fmt.Errorf("v2 registry endpoint returned status %d: %q", resp.StatusCode, http.StatusText(resp.StatusCode)) } - -func (e *Endpoint) HTTPClient() *http.Client { - if e.client == nil { - tr := NewTransport(ConnectTimeout, e.IsSecure) - e.client = HTTPClient(tr) - } - return e.client -} diff --git a/registry/endpoint_test.go b/registry/endpoint_test.go index 9567ba2352..6f67867bbb 100644 --- a/registry/endpoint_test.go +++ b/registry/endpoint_test.go @@ -19,7 +19,7 @@ func TestEndpointParse(t *testing.T) { {"0.0.0.0:5000", "https://0.0.0.0:5000/v0/"}, } for _, td := range testData { - e, err := newEndpoint(td.str, false) + e, err := newEndpoint(td.str, false, nil) if err != nil { t.Errorf("%q: %s", td.str, err) } @@ -60,6 +60,7 @@ func TestValidateEndpointAmbiguousAPIVersion(t *testing.T) { testEndpoint := Endpoint{ URL: testServerURL, Version: APIVersionUnknown, + client: HTTPClient(NewTransport(ConnectTimeout, false)), } if err = validateEndpoint(&testEndpoint); err != nil { diff --git a/registry/registry.go b/registry/registry.go index 4f5403002a..b0706e348e 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -19,6 +19,7 @@ import ( "github.com/docker/docker/autogen/dockerversion" "github.com/docker/docker/pkg/parsers/kernel" "github.com/docker/docker/pkg/timeoutconn" + "github.com/docker/docker/pkg/transport" "github.com/docker/docker/pkg/useragent" ) @@ -36,17 +37,32 @@ const ( ConnectTimeout ) -type httpsTransport struct { - *http.Transport +// dockerUserAgent is the User-Agent the Docker client uses to identify itself. +// It is populated on init(), comprising version information of different components. +var dockerUserAgent string + +func init() { + httpVersion := make([]useragent.VersionInfo, 0, 6) + httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) + httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) + httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) + if kernelVersion, err := kernel.GetKernelVersion(); err == nil { + httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) + } + httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) + httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) + + dockerUserAgent = useragent.AppendVersions("", httpVersion...) } +type httpsRequestModifier struct{ tlsConfig *tls.Config } + // DRAGONS(tiborvass): If someone wonders why do we set tlsconfig in a roundtrip, // it's because it's so as to match the current behavior in master: we generate the // certpool on every-goddam-request. It's not great, but it allows people to just put // the certs in /etc/docker/certs.d/.../ and let docker "pick it up" immediately. Would // prefer an fsnotify implementation, but that was out of scope of my refactoring. -// TODO: improve things -func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (m *httpsRequestModifier) ModifyRequest(req *http.Request) error { var ( roots *x509.CertPool certs []tls.Certificate @@ -66,7 +82,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { logrus.Debugf("hostDir: %s", hostDir) fs, err := ioutil.ReadDir(hostDir) if err != nil && !os.IsNotExist(err) { - return nil, err + return nil } for _, f := range fs { @@ -77,7 +93,7 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { logrus.Debugf("crt: %s", hostDir+"/"+f.Name()) data, err := ioutil.ReadFile(path.Join(hostDir, f.Name())) if err != nil { - return nil, err + return err } roots.AppendCertsFromPEM(data) } @@ -86,11 +102,11 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { keyName := certName[:len(certName)-5] + ".key" logrus.Debugf("cert: %s", hostDir+"/"+f.Name()) if !hasFile(fs, keyName) { - return nil, fmt.Errorf("Missing key %s for certificate %s", keyName, certName) + return fmt.Errorf("Missing key %s for certificate %s", keyName, certName) } cert, err := tls.LoadX509KeyPair(path.Join(hostDir, certName), path.Join(hostDir, keyName)) if err != nil { - return nil, err + return err } certs = append(certs, cert) } @@ -99,38 +115,32 @@ func (tr *httpsTransport) RoundTrip(req *http.Request) (*http.Response, error) { certName := keyName[:len(keyName)-4] + ".cert" logrus.Debugf("key: %s", hostDir+"/"+f.Name()) if !hasFile(fs, certName) { - return nil, fmt.Errorf("Missing certificate %s for key %s", certName, keyName) + return fmt.Errorf("Missing certificate %s for key %s", certName, keyName) } } } - if tr.Transport.TLSClientConfig == nil { - tr.Transport.TLSClientConfig = &tls.Config{ - // Avoid fallback to SSL protocols < TLS1.0 - MinVersion: tls.VersionTLS10, - } - } - tr.Transport.TLSClientConfig.RootCAs = roots - tr.Transport.TLSClientConfig.Certificates = certs + m.tlsConfig.RootCAs = roots + m.tlsConfig.Certificates = certs } - return tr.Transport.RoundTrip(req) + return nil } func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { - tlsConfig := tls.Config{ + tlsConfig := &tls.Config{ // Avoid fallback to SSL protocols < TLS1.0 MinVersion: tls.VersionTLS10, InsecureSkipVerify: !secure, } - transport := &http.Transport{ + tr := &http.Transport{ DisableKeepAlives: true, Proxy: http.ProxyFromEnvironment, - TLSClientConfig: &tlsConfig, + TLSClientConfig: tlsConfig, } switch timeout { case ConnectTimeout: - transport.Dial = func(proto string, addr string) (net.Conn, error) { + tr.Dial = func(proto string, addr string) (net.Conn, error) { // Set the connect timeout to 30 seconds to allow for slower connection // times... d := net.Dialer{Timeout: 30 * time.Second, DualStack: true} @@ -144,7 +154,7 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { return conn, nil } case ReceiveTimeout: - transport.Dial = func(proto string, addr string) (net.Conn, error) { + tr.Dial = func(proto string, addr string) (net.Conn, error) { d := net.Dialer{DualStack: true} conn, err := d.Dial(proto, addr) @@ -159,51 +169,23 @@ func NewTransport(timeout TimeoutType, secure bool) http.RoundTripper { if secure { // note: httpsTransport also handles http transport // but for HTTPS, it sets up the certs - return &httpsTransport{transport} + return transport.NewTransport(tr, &httpsRequestModifier{tlsConfig}) } - return transport + return tr } -type DockerHeaders struct { - http.RoundTripper - Headers http.Header -} - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) +// DockerHeaders returns request modifiers that ensure requests have +// the User-Agent header set to dockerUserAgent and that metaHeaders +// are added. +func DockerHeaders(metaHeaders http.Header) []transport.RequestModifier { + modifiers := []transport.RequestModifier{ + transport.NewHeaderRequestModifier(http.Header{"User-Agent": []string{dockerUserAgent}}), } - return r2 -} - -func (tr *DockerHeaders) RoundTrip(req *http.Request) (*http.Response, error) { - req = cloneRequest(req) - httpVersion := make([]useragent.VersionInfo, 0, 4) - httpVersion = append(httpVersion, useragent.VersionInfo{"docker", dockerversion.VERSION}) - httpVersion = append(httpVersion, useragent.VersionInfo{"go", runtime.Version()}) - httpVersion = append(httpVersion, useragent.VersionInfo{"git-commit", dockerversion.GITCOMMIT}) - if kernelVersion, err := kernel.GetKernelVersion(); err == nil { - httpVersion = append(httpVersion, useragent.VersionInfo{"kernel", kernelVersion.String()}) + if metaHeaders != nil { + modifiers = append(modifiers, transport.NewHeaderRequestModifier(metaHeaders)) } - httpVersion = append(httpVersion, useragent.VersionInfo{"os", runtime.GOOS}) - httpVersion = append(httpVersion, useragent.VersionInfo{"arch", runtime.GOARCH}) - - userAgent := useragent.AppendVersions(req.UserAgent(), httpVersion...) - - req.Header.Set("User-Agent", userAgent) - - for k, v := range tr.Headers { - req.Header[k] = v - } - return tr.RoundTripper.RoundTrip(req) + return modifiers } type debugTransport struct{ http.RoundTripper } diff --git a/registry/registry_test.go b/registry/registry_test.go index d4a5ded082..33e86ff43a 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/docker/docker/cliconfig" + "github.com/docker/docker/pkg/transport" ) var ( @@ -21,12 +22,12 @@ const ( func spawnTestRegistrySession(t *testing.T) *Session { authConfig := &cliconfig.AuthConfig{} - endpoint, err := NewEndpoint(makeIndex("/v1/")) + endpoint, err := NewEndpoint(makeIndex("/v1/"), nil) if err != nil { t.Fatal(err) } var tr http.RoundTripper = debugTransport{NewTransport(ReceiveTimeout, endpoint.IsSecure)} - tr = &DockerHeaders{&authTransport{RoundTripper: tr, AuthConfig: authConfig}, nil} + tr = transport.NewTransport(AuthTransport(tr, authConfig, false), DockerHeaders(nil)...) client := HTTPClient(tr) r, err := NewSession(client, authConfig, endpoint) if err != nil { @@ -48,7 +49,7 @@ func spawnTestRegistrySession(t *testing.T) *Session { func TestPingRegistryEndpoint(t *testing.T) { testPing := func(index *IndexInfo, expectedStandalone bool, assertMessage string) { - ep, err := NewEndpoint(index) + ep, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -68,7 +69,7 @@ func TestPingRegistryEndpoint(t *testing.T) { func TestEndpoint(t *testing.T) { // Simple wrapper to fail test if err != nil expandEndpoint := func(index *IndexInfo) *Endpoint { - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { t.Fatal(err) } @@ -77,7 +78,7 @@ func TestEndpoint(t *testing.T) { assertInsecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected error for insecure index") assertEqual(t, strings.Contains(err.Error(), "insecure-registry"), true, index.Name+": Expected insecure-registry error for insecure index") index.Secure = false @@ -85,7 +86,7 @@ func TestEndpoint(t *testing.T) { assertSecureIndex := func(index *IndexInfo) { index.Secure = true - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) assertNotEqual(t, err, nil, index.Name+": Expected cert error for secure index") assertEqual(t, strings.Contains(err.Error(), "certificate signed by unknown authority"), true, index.Name+": Expected cert error for secure index") index.Secure = false @@ -151,7 +152,7 @@ func TestEndpoint(t *testing.T) { } for _, address := range badEndpoints { index.Name = address - _, err := NewEndpoint(index) + _, err := NewEndpoint(index, nil) checkNotEqual(t, err, nil, "Expected error while expanding bad endpoint") } } diff --git a/registry/service.go b/registry/service.go index 067df107c2..6811749272 100644 --- a/registry/service.go +++ b/registry/service.go @@ -1,6 +1,10 @@ package registry -import "github.com/docker/docker/cliconfig" +import ( + "net/http" + + "github.com/docker/docker/cliconfig" +) type Service struct { Config *ServiceConfig @@ -27,7 +31,7 @@ func (s *Service) Auth(authConfig *cliconfig.AuthConfig) (string, error) { if err != nil { return "", err } - endpoint, err := NewEndpoint(index) + endpoint, err := NewEndpoint(index, nil) if err != nil { return "", err } @@ -44,11 +48,11 @@ func (s *Service) Search(term string, authConfig *cliconfig.AuthConfig, headers } // *TODO: Search multiple indexes. - endpoint, err := repoInfo.GetEndpoint() + endpoint, err := repoInfo.GetEndpoint(http.Header(headers)) if err != nil { return nil, err } - r, err := NewSession(endpoint.HTTPClient(), authConfig, endpoint) + r, err := NewSession(endpoint.client, authConfig, endpoint) if err != nil { return nil, err } diff --git a/registry/session.go b/registry/session.go index 686e322dab..8e54bc8211 100644 --- a/registry/session.go +++ b/registry/session.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/sha256" "errors" + "sync" // this is required for some certificates _ "crypto/sha512" "encoding/hex" @@ -22,6 +23,7 @@ import ( "github.com/docker/docker/cliconfig" "github.com/docker/docker/pkg/httputils" "github.com/docker/docker/pkg/tarsum" + "github.com/docker/docker/pkg/transport" ) type Session struct { @@ -31,7 +33,18 @@ type Session struct { authConfig *cliconfig.AuthConfig } -// authTransport handles the auth layer when communicating with a v1 registry (private or official) +type authTransport struct { + http.RoundTripper + *cliconfig.AuthConfig + + alwaysSetBasicAuth bool + token []string + + mu sync.Mutex // guards modReq + modReq map[*http.Request]*http.Request // original -> modified +} + +// AuthTransport handles the auth layer when communicating with a v1 registry (private or official) // // For private v1 registries, set alwaysSetBasicAuth to true. // @@ -44,16 +57,23 @@ type Session struct { // If the server sends a token without the client having requested it, it is ignored. // // This RoundTripper also has a CancelRequest method important for correct timeout handling. -type authTransport struct { - http.RoundTripper - *cliconfig.AuthConfig - - alwaysSetBasicAuth bool - token []string +func AuthTransport(base http.RoundTripper, authConfig *cliconfig.AuthConfig, alwaysSetBasicAuth bool) http.RoundTripper { + if base == nil { + base = http.DefaultTransport + } + return &authTransport{ + RoundTripper: base, + AuthConfig: authConfig, + alwaysSetBasicAuth: alwaysSetBasicAuth, + modReq: make(map[*http.Request]*http.Request), + } } -func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req = cloneRequest(req) +func (tr *authTransport) RoundTrip(orig *http.Request) (*http.Response, error) { + req := transport.CloneRequest(orig) + tr.mu.Lock() + tr.modReq[orig] = req + tr.mu.Unlock() if tr.alwaysSetBasicAuth { req.SetBasicAuth(tr.Username, tr.Password) @@ -73,14 +93,33 @@ func (tr *authTransport) RoundTrip(req *http.Request) (*http.Response, error) { } resp, err := tr.RoundTripper.RoundTrip(req) if err != nil { + delete(tr.modReq, orig) return nil, err } if askedForToken && len(resp.Header["X-Docker-Token"]) > 0 { tr.token = resp.Header["X-Docker-Token"] } + resp.Body = &transport.OnEOFReader{ + Rc: resp.Body, + Fn: func() { delete(tr.modReq, orig) }, + } return resp, nil } +// CancelRequest cancels an in-flight request by closing its connection. +func (tr *authTransport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := tr.RoundTripper.(canceler); ok { + tr.mu.Lock() + modReq := tr.modReq[req] + delete(tr.modReq, req) + tr.mu.Unlock() + cr.CancelRequest(modReq) + } +} + // TODO(tiborvass): remove authConfig param once registry client v2 is vendored func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint *Endpoint) (r *Session, err error) { r = &Session{ @@ -105,7 +144,7 @@ func NewSession(client *http.Client, authConfig *cliconfig.AuthConfig, endpoint } } - client.Transport = &authTransport{RoundTripper: client.Transport, AuthConfig: authConfig, alwaysSetBasicAuth: alwaysSetBasicAuth} + client.Transport = AuthTransport(client.Transport, authConfig, alwaysSetBasicAuth) jar, err := cookiejar.New(nil) if err != nil { diff --git a/registry/session_v2.go b/registry/session_v2.go index c639f9226a..b660172898 100644 --- a/registry/session_v2.go +++ b/registry/session_v2.go @@ -27,7 +27,7 @@ func getV2Builder(e *Endpoint) *v2.URLBuilder { func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) { // TODO check if should use Mirror if index.Official { - ep, err = newEndpoint(REGISTRYSERVER, true) + ep, err = newEndpoint(REGISTRYSERVER, true, nil) if err != nil { return } @@ -38,7 +38,7 @@ func (r *Session) V2RegistryEndpoint(index *IndexInfo) (ep *Endpoint, err error) } else if r.indexEndpoint.String() == index.GetAuthConfigKey() { ep = r.indexEndpoint } else { - ep, err = NewEndpoint(index) + ep, err = NewEndpoint(index, nil) if err != nil { return } diff --git a/registry/token.go b/registry/token.go index af7d5f3fcb..e27cb6f528 100644 --- a/registry/token.go +++ b/registry/token.go @@ -13,7 +13,7 @@ type tokenResponse struct { Token string `json:"token"` } -func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint, client *http.Client) (token string, err error) { +func getToken(username, password string, params map[string]string, registryEndpoint *Endpoint) (token string, err error) { realm, ok := params["realm"] if !ok { return "", errors.New("no realm specified for token auth challenge") @@ -56,7 +56,7 @@ func getToken(username, password string, params map[string]string, registryEndpo req.URL.RawQuery = reqParams.Encode() - resp, err := client.Do(req) + resp, err := registryEndpoint.client.Do(req) if err != nil { return "", err }