azure-container-networking/keyvault/certrefresher.go

131 строка
3.8 KiB
Go

package keyvault
import (
"context"
//nolint:gosec // sha1 only used to display cert thumbprint in logs for cross-verification with keyvault.
"crypto/sha1"
"crypto/tls"
"fmt"
"sync"
"time"
"github.com/avast/retry-go/v3"
"github.com/pkg/errors"
)
type EventualExpirationErr struct {
time.Time
}
func (e *EventualExpirationErr) Error() string {
return fmt.Sprintf("could not refresh before expiration on %s", e.Time.String())
}
// tlsCertFetcher is the certificate source used by CertRefresher. It is
// consulted once during construction and again on every refresh.
type tlsCertFetcher interface {
// GetLatestTLSCertificate is called with the refresher's certificate name and
// is expected to return the latest version of that certificate, or an error.
GetLatestTLSCertificate(ctx context.Context, certName string) (tls.Certificate, error)
}
// logger is the minimal logging surface CertRefresher needs.
type logger interface {
// Printf receives informational messages (initial fetch, refreshed or unchanged certificate).
Printf(format string, args ...any)
// Errorf receives fetch and refresh failures.
Errorf(format string, args ...any)
}
// CertRefresher offers a mechanism to present the latest version of a tls.Certificate from KeyVault, refreshed at an interval.
type CertRefresher struct {
// certName is the certificate name passed to kvc on every fetch.
certName string
// kvc fetches the certificate; NewCertRefresher and refresh both call it.
kvc tlsCertFetcher
// logger receives progress and failure messages.
logger logger
// m guards cert: GetCertificate reads it under RLock, refresh replaces it under Lock.
m sync.RWMutex
// cert is the most recently fetched certificate.
cert *tls.Certificate
}
// NewCertRefresher returns a CertRefresher. When there's no error, the CertRefresher's GetCertificate method is ready
// for use, returning a valid tls.Certificate fetched from KeyVault during construction.
//
// ctx bounds the initial fetch, kvc is the certificate source, l receives log
// output, and certName is the name passed to kvc. An error is returned when
// the initial fetch fails or when the fetched certificate has no parsed Leaf.
func NewCertRefresher(ctx context.Context, kvc tlsCertFetcher, l logger, certName string) (*CertRefresher, error) {
	cert, err := kvc.GetLatestTLSCertificate(ctx, certName)
	if err != nil {
		return nil, errors.Wrap(err, "could not fetch initial cert")
	}

	// String, Refresh and refresh all dereference cert.Leaf (thumbprint and
	// NotAfter). Reject a certificate without a parsed leaf here rather than
	// panicking on the log line below or on the first refresh.
	if cert.Leaf == nil {
		return nil, errors.New("initial cert has no parsed leaf certificate")
	}

	cf := &CertRefresher{
		certName: certName,
		kvc:      kvc,
		logger:   l,
		cert:     &cert,
	}
	cf.logger.Printf("initial certificate fetched: %s", cf)
	return cf, nil
}
// String describes the current certificate for log output: its name, the
// SHA-1 thumbprint of the leaf's Raw bytes, and the leaf's expiration.
// It does not acquire c.m; refresh formats c while already holding the write lock.
func (c *CertRefresher) String() string {
	leaf := c.cert.Leaf
	thumbprint := sha1String(leaf.Raw)
	expiration := leaf.NotAfter.String()
	return fmt.Sprintf("cert name: %s, sha1 thumbprint: %s, expiration: %s", c.certName, thumbprint, expiration)
}
// GetCertificate returns the latest certificate fetched from KeyVault.
// The pointer is read while holding the read lock, so it never observes a
// swap by refresh halfway through.
func (c *CertRefresher) GetCertificate() *tls.Certificate {
	c.m.RLock()
	current := c.cert
	c.m.RUnlock()
	return current
}
// Refresh starts refreshing the certificate at the interval provided.
// It blocks until context is done or refreshing fails.
//
// It returns:
//   - an error when interval is not positive (time.NewTicker would panic);
//   - ctx.Err() wrapped with "refresh canceled" when ctx is done, including
//     when it becomes done while a refresh attempt is in flight;
//   - *EventualExpirationErr carrying the current certificate's NotAfter when
//     a refresh attempt fails for any other reason.
func (c *CertRefresher) Refresh(ctx context.Context, interval time.Duration) error {
	if interval <= 0 {
		return errors.Errorf("refresh interval must be positive, got %s", interval)
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return errors.Wrap(ctx.Err(), "refresh canceled")
		case <-ticker.C:
			err := c.refresh(ctx)
			if err == nil {
				continue
			}
			// A done caller context is not an expiration problem; report it
			// the same way as cancellation observed between ticks.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return errors.Wrap(ctxErr, "refresh canceled")
			}
			c.logger.Errorf("could not refresh before certificate expiration on %s: %v", c.cert.Leaf.NotAfter, err)
			return &EventualExpirationErr{c.cert.Leaf.NotAfter}
		}
	}
}
// refresh fetches the latest version of the certificate and, when it differs
// from the one currently held, swaps it in under the write lock.
//
// Fetching is retried with a fixed one-second delay under a context whose
// deadline is the current certificate's NotAfter. An error is returned when
// the retries give up or that context ends.
//
// NOTE(review): retry.Do is given no Attempts option, so retry-go's default
// attempt limit presumably ends the retries well before the expiration
// deadline does — confirm this matches the "until the current one expires" intent.
func (c *CertRefresher) refresh(ctx context.Context) error {
	deadline := c.cert.Leaf.NotAfter
	ctx, cancel := context.WithDeadline(ctx, deadline)
	defer cancel()

	var fetched tls.Certificate
	fetch := func() error {
		cert, fetchErr := c.kvc.GetLatestTLSCertificate(ctx, c.certName)
		if fetchErr != nil {
			c.logger.Errorf("could not fetch latest tls certificate: %v. retrying...", fetchErr)
			return errors.Wrap(fetchErr, "could not fetch latest tls certificate")
		}
		fetched = cert
		return nil
	}

	retryOpts := []retry.Option{
		retry.Context(ctx),
		retry.Delay(time.Second),
		retry.DelayType(retry.FixedDelay),
	}
	if err := retry.Do(fetch, retryOpts...); err != nil {
		return errors.Wrap(err, "could not refresh cert")
	}

	// Hold the write lock for the comparison, the swap, and the log lines that
	// format c (String reads c.cert without locking).
	c.m.Lock()
	defer c.m.Unlock()

	if !fetched.Leaf.Equal(c.cert.Leaf) {
		previousThumbprint := sha1String(c.cert.Leaf.Raw)
		c.cert = &fetched
		c.logger.Printf("certificate refreshed. old sha1 thumbprint: %s, certificate: %s", previousThumbprint, c)
		return nil
	}

	c.logger.Printf("certificate unchanged. certificate %s", c)
	return nil
}
// sha1String returns the SHA-1 digest of bs as upper-case hexadecimal, two
// characters per byte, for use as a certificate thumbprint in log output.
func sha1String(bs []byte) string {
	//nolint:gosec // sha1 only used to display cert thumbprint in logs for cross-verification with keyvault.
	digest := sha1.Sum(bs)
	return fmt.Sprintf("%X", digest[:])
}