acme/autocert: context propagation and doc tweaks
Change-Id: I061b797d46097e37880bea1911475e2b2f1a0378 Reviewed-on: https://go-review.googlesource.com/39270 Reviewed-by: Alex Vaghin <ddos@google.com>
This commit is contained in:
Родитель
3cb07270c9
Коммит
b5cf4d8d48
|
@ -41,8 +41,9 @@ func init() {
|
|||
pseudoRand = &lockedMathRand{rnd: mathrand.New(src)}
|
||||
}
|
||||
|
||||
// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service
|
||||
// during account registration.
|
||||
// AcceptTOS is a Manager.Prompt function that always returns true to
|
||||
// indicate acceptance of the CA's Terms of Service during account
|
||||
// registration.
|
||||
func AcceptTOS(tosURL string) bool { return true }
|
||||
|
||||
// HostPolicy specifies which host names the Manager is allowed to respond to.
|
||||
|
@ -178,6 +179,9 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
|
|||
return nil, errors.New("acme/autocert: missing server name")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// check whether this is a token cert requested for TLS-SNI challenge
|
||||
if strings.HasSuffix(name, ".acme.invalid") {
|
||||
m.tokenCertMu.RLock()
|
||||
|
@ -185,7 +189,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
|
|||
if cert := m.tokenCert[name]; cert != nil {
|
||||
return cert, nil
|
||||
}
|
||||
if cert, err := m.cacheGet(name); err == nil {
|
||||
if cert, err := m.cacheGet(ctx, name); err == nil {
|
||||
return cert, nil
|
||||
}
|
||||
// TODO: cache error results?
|
||||
|
@ -194,7 +198,7 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
|
|||
|
||||
// regular domain
|
||||
name = strings.TrimSuffix(name, ".") // golang.org/issue/18114
|
||||
cert, err := m.cert(name)
|
||||
cert, err := m.cert(ctx, name)
|
||||
if err == nil {
|
||||
return cert, nil
|
||||
}
|
||||
|
@ -203,7 +207,6 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
|
|||
}
|
||||
|
||||
// first-time
|
||||
ctx := context.Background() // TODO: use a deadline?
|
||||
if err := m.hostPolicy()(ctx, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -211,14 +214,14 @@ func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.cachePut(name, cert)
|
||||
m.cachePut(ctx, name, cert)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// cert returns an existing certificate either from m.state or cache.
|
||||
// If a certificate is found in cache but not in m.state, the latter will be filled
|
||||
// with the cached value.
|
||||
func (m *Manager) cert(name string) (*tls.Certificate, error) {
|
||||
func (m *Manager) cert(ctx context.Context, name string) (*tls.Certificate, error) {
|
||||
m.stateMu.Lock()
|
||||
if s, ok := m.state[name]; ok {
|
||||
m.stateMu.Unlock()
|
||||
|
@ -227,7 +230,7 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
|
|||
return s.tlscert()
|
||||
}
|
||||
defer m.stateMu.Unlock()
|
||||
cert, err := m.cacheGet(name)
|
||||
cert, err := m.cacheGet(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -249,12 +252,10 @@ func (m *Manager) cert(name string) (*tls.Certificate, error) {
|
|||
}
|
||||
|
||||
// cacheGet always returns a valid certificate, or an error otherwise.
|
||||
func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
|
||||
func (m *Manager) cacheGet(ctx context.Context, domain string) (*tls.Certificate, error) {
|
||||
if m.Cache == nil {
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
// TODO: might want to define a cache timeout on m
|
||||
ctx := context.Background()
|
||||
data, err := m.Cache.Get(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -297,7 +298,7 @@ func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) {
|
|||
return tlscert, nil
|
||||
}
|
||||
|
||||
func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
|
||||
func (m *Manager) cachePut(ctx context.Context, domain string, tlscert *tls.Certificate) error {
|
||||
if m.Cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -329,8 +330,6 @@ func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: might want to define a cache timeout on m
|
||||
ctx := context.Background()
|
||||
return m.Cache.Put(ctx, domain, buf.Bytes())
|
||||
}
|
||||
|
||||
|
@ -494,7 +493,7 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.putTokenCert(name, &cert)
|
||||
m.putTokenCert(ctx, name, &cert)
|
||||
defer func() {
|
||||
// verification has ended at this point
|
||||
// don't need token cert anymore
|
||||
|
@ -512,14 +511,14 @@ func (m *Manager) verify(ctx context.Context, domain string) error {
|
|||
|
||||
// putTokenCert stores the cert under the named key in both m.tokenCert map
|
||||
// and m.Cache.
|
||||
func (m *Manager) putTokenCert(name string, cert *tls.Certificate) {
|
||||
func (m *Manager) putTokenCert(ctx context.Context, name string, cert *tls.Certificate) {
|
||||
m.tokenCertMu.Lock()
|
||||
defer m.tokenCertMu.Unlock()
|
||||
if m.tokenCert == nil {
|
||||
m.tokenCert = make(map[string]*tls.Certificate)
|
||||
}
|
||||
m.tokenCert[name] = cert
|
||||
m.cachePut(name, cert)
|
||||
m.cachePut(ctx, name, cert)
|
||||
}
|
||||
|
||||
// deleteTokenCert removes the token certificate for the specified domain name
|
||||
|
|
|
@ -150,7 +150,7 @@ func TestGetCertificate_ForceRSA(t *testing.T) {
|
|||
hello := &tls.ClientHelloInfo{ServerName: "example.org"}
|
||||
testGetCertificate(t, man, "example.org", hello)
|
||||
|
||||
cert, err := man.cacheGet("example.org")
|
||||
cert, err := man.cacheGet(context.Background(), "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("man.cacheGet: %v", err)
|
||||
}
|
||||
|
@ -335,10 +335,11 @@ func TestCache(t *testing.T) {
|
|||
|
||||
man := &Manager{Cache: newMemCache()}
|
||||
defer man.stopRenew()
|
||||
if err := man.cachePut("example.org", tlscert); err != nil {
|
||||
ctx := context.Background()
|
||||
if err := man.cachePut(ctx, "example.org", tlscert); err != nil {
|
||||
t.Fatalf("man.cachePut: %v", err)
|
||||
}
|
||||
res, err := man.cacheGet("example.org")
|
||||
res, err := man.cacheGet(ctx, "example.org")
|
||||
if err != nil {
|
||||
t.Fatalf("man.cacheGet: %v", err)
|
||||
}
|
||||
|
|
|
@ -83,7 +83,7 @@ func (dr *domainRenewal) renew() {
|
|||
func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
|
||||
// a race is likely unavoidable in a distributed environment
|
||||
// but we try nonetheless
|
||||
if tlscert, err := dr.m.cacheGet(dr.domain); err == nil {
|
||||
if tlscert, err := dr.m.cacheGet(ctx, dr.domain); err == nil {
|
||||
next := dr.next(tlscert.Leaf.NotAfter)
|
||||
if next > dr.m.renewBefore()+maxRandRenew {
|
||||
return next, nil
|
||||
|
@ -103,7 +103,7 @@ func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) {
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
dr.m.cachePut(dr.domain, tlscert)
|
||||
dr.m.cachePut(ctx, dr.domain, tlscert)
|
||||
dr.m.stateMu.Lock()
|
||||
defer dr.m.stateMu.Unlock()
|
||||
// m.state is guaranteed to be non-nil at this point
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"time"
|
||||
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestRenewalNext(t *testing.T) {
|
||||
|
@ -127,7 +128,7 @@ func TestRenewFromCache(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}}
|
||||
if err := man.cachePut(domain, tlscert); err != nil {
|
||||
if err := man.cachePut(context.Background(), domain, tlscert); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
@ -151,7 +152,7 @@ func TestRenewFromCache(t *testing.T) {
|
|||
|
||||
// ensure the new cert is cached
|
||||
after := time.Now().Add(future)
|
||||
tlscert, err := man.cacheGet(domain)
|
||||
tlscert, err := man.cacheGet(context.Background(), domain)
|
||||
if err != nil {
|
||||
t.Fatalf("man.cacheGet: %v", err)
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче