From b5cf4d8d48698c1f6d3b57b8c893e580aa2a4db1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 2 Apr 2017 12:37:13 -0700 Subject: [PATCH] acme/autocert: context propagation and doc tweaks Change-Id: I061b797d46097e37880bea1911475e2b2f1a0378 Reviewed-on: https://go-review.googlesource.com/39270 Reviewed-by: Alex Vaghin --- acme/autocert/autocert.go | 33 ++++++++++++++++----------------- acme/autocert/autocert_test.go | 7 ++++--- acme/autocert/renewal.go | 4 ++-- acme/autocert/renewal_test.go | 5 +++-- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go index 4b15816a..dfb860f4 100644 --- a/acme/autocert/autocert.go +++ b/acme/autocert/autocert.go @@ -41,8 +41,9 @@ func init() { pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} } -// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service -// during account registration. +// AcceptTOS is a Manager.Prompt function that always returns true to +// indicate acceptance of the CA's Terms of Service during account +// registration. func AcceptTOS(tosURL string) bool { return true } // HostPolicy specifies which host names the Manager is allowed to respond to. @@ -178,6 +179,9 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, return nil, errors.New("acme/autocert: missing server name") } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // check whether this is a token cert requested for TLS-SNI challenge if strings.HasSuffix(name, ".acme.invalid") { m.tokenCertMu.RLock() @@ -185,7 +189,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, if cert := m.tokenCert[name]; cert != nil { return cert, nil } - if cert, err := m.cacheGet(name); err == nil { + if cert, err := m.cacheGet(ctx, name); err == nil { return cert, nil } // TODO: cache error results? @@ -194,7 +198,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, // regular domain name = strings.TrimSuffix(name, ".") // golang.org/issue/18114 - cert, err := m.cert(name) + cert, err := m.cert(ctx, name) if err == nil { return cert, nil } @@ -203,7 +207,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, } // first-time - ctx := context.Background() // TODO: use a deadline? if err := m.hostPolicy()(ctx, name); err != nil { return nil, err } @@ -211,14 +214,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, if err != nil { return nil, err } - m.cachePut(name, cert) + m.cachePut(ctx, name, cert) return cert, nil } // cert returns an existing certificate either from m.state or cache. // If a certificate is found in cache but not in m.state, the latter will be filled // with the cached value. -func (m *Manager) cert(name string) (*tls.Certificate, error) { +func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) { m.stateMu.Lock() if s, ok := m.state[name]; ok { m.stateMu.Unlock() @@ -227,7 +230,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) { return s.tlscert() } defer m.stateMu.Unlock() - cert, err := m.cacheGet(name) + cert, err := m.cacheGet(ctx, name) if err != nil { return nil, err } @@ -249,12 +252,10 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) { } // cacheGet always returns a valid certificate, or an error otherwise. -func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { +func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) { if m.Cache == nil { return nil, ErrCacheMiss } - // TODO: might want to define a cache timeout on m - ctx := context.Background() data, err := m.Cache.Get(ctx, domain) if err != nil { return nil, err @@ -297,7 +298,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { return tlscert, nil } -func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { +func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error { if m.Cache == nil { return nil } @@ -329,8 +330,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { } } - // TODO: might want to define a cache timeout on m - ctx := context.Background() return m.Cache.Put(ctx, domain, buf.Bytes()) } @@ -494,7 +493,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error { if err != nil { return err } - m.putTokenCert(name, &cert) + m.putTokenCert(ctx, name, &cert) defer func() { // verification has ended at this point // don't need token cert anymore @@ -512,14 +511,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error { // putTokenCert stores the cert under the named key in both m.tokenCert map // and m.Cache. -func (m *Manager) putTokenCert(name string, cert *tls.Certificate) { +func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) { m.tokenCertMu.Lock() defer m.tokenCertMu.Unlock() if m.tokenCert == nil { m.tokenCert = make(map[string]*tls.Certificate) } m.tokenCert[name] = cert - m.cachePut(name, cert) + m.cachePut(ctx, name, cert) } // deleteTokenCert removes the token certificate for the specified domain name diff --git a/acme/autocert/autocert_test.go b/acme/autocert/autocert_test.go index 7afb2133..c3f3f66e 100644 --- a/acme/autocert/autocert_test.go +++ b/acme/autocert/autocert_test.go @@ -150,7 +150,7 @@ func TestGetCertificate_ForceRSA(t *testing.T) { hello := &tls.ClientHelloInfo{ServerName: "example.org"} testGetCertificate(t, man, "example.org", hello) - cert, err := man.cacheGet("example.org") + cert, err := man.cacheGet(context.Background(), "example.org") if err != nil { t.Fatalf("man.cacheGet: %v", err) } @@ -335,10 +335,11 @@ func TestCache(t *testing.T) { man := &Manager{Cache: newMemCache()} defer man.stopRenew() - if err := man.cachePut("example.org", tlscert); err != nil { + ctx := context.Background() + if err := man.cachePut(ctx, "example.org", tlscert); err != nil { t.Fatalf("man.cachePut: %v", err) } - res, err := man.cacheGet("example.org") + res, err := man.cacheGet(ctx, "example.org") if err != nil { t.Fatalf("man.cacheGet: %v", err) } diff --git a/acme/autocert/renewal.go b/acme/autocert/renewal.go index 1a5018c8..0d2eb607 100644 --- a/acme/autocert/renewal.go +++ b/acme/autocert/renewal.go @@ -83,7 +83,7 @@ func (dr *domainRenewal) renew() { func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { // a race is likely unavoidable in a distributed environment // but we try nonetheless - if tlscert, err := dr.m.cacheGet(dr.domain); err == nil { + if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil { next := dr.next(tlscert.Leaf.NotAfter) if next > dr.m.renewBefore()+maxRandRenew { return next, nil @@ -103,7 +103,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { if err != nil { return 0, err } - dr.m.cachePut(dr.domain, tlscert) + dr.m.cachePut(ctx, dr.domain, tlscert) dr.m.stateMu.Lock() defer dr.m.stateMu.Unlock() // m.state is guaranteed to be non-nil at this point diff --git a/acme/autocert/renewal_test.go b/acme/autocert/renewal_test.go index 10c811ac..27958bb0 100644 --- a/acme/autocert/renewal_test.go +++ b/acme/autocert/renewal_test.go @@ -18,6 +18,7 @@ import ( "time" "golang.org/x/crypto/acme" + "golang.org/x/net/context" ) func TestRenewalNext(t *testing.T) { @@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) { t.Fatal(err) } tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}} - if err := man.cachePut(domain, tlscert); err != nil { + if err := man.cachePut(context.Background(), domain, tlscert); err != nil { t.Fatal(err) } @@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) { // ensure the new cert is cached after := time.Now().Add(future) - tlscert, err := man.cacheGet(domain) + tlscert, err := man.cacheGet(context.Background(), domain) if err != nil { t.Fatalf("man.cacheGet: %v", err) }