diff --git a/clientcredentials/clientcredentials.go b/clientcredentials/clientcredentials.go index 7a0b9ed1..2459d069 100644 --- a/clientcredentials/clientcredentials.go +++ b/clientcredentials/clientcredentials.go @@ -47,6 +47,10 @@ type Config struct { // client ID & client secret sent. The zero value means to // auto-detect. AuthStyle oauth2.AuthStyle + + // authStyleCache caches which auth style to use when Endpoint.AuthStyle is + // the zero value (AuthStyleAutoDetect). + authStyleCache internal.LazyAuthStyleCache } // Token uses client credentials to retrieve a token. @@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) { v[k] = p } - tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) + tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get()) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*oauth2.RetrieveError)(rErr) diff --git a/clientcredentials/clientcredentials_test.go b/clientcredentials/clientcredentials_test.go index 02a1c89a..078e75ec 100644 --- a/clientcredentials/clientcredentials_test.go +++ b/clientcredentials/clientcredentials_test.go @@ -12,8 +12,6 @@ import ( "net/http/httptest" "net/url" "testing" - - "golang.org/x/oauth2/internal" ) func newConf(serverURL string) *Config { @@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { - internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return diff --git a/internal/token.go b/internal/token.go index 58901bda..e83ddeef 100644 --- a/internal/token.go +++ b/internal/token.go @@ -18,6 +18,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" ) @@ -115,41 +116,60 @@ const ( AuthStyleInHeader AuthStyle = 2 ) -// authStyleCache is the set of tokenURLs we've successfully used via +// LazyAuthStyleCache is a backwards compatibility compromise to let Configs +// have a lazily-initialized AuthStyleCache. +// +// The two users of this, oauth2.Config and oauth2/clientcredentials.Config, +// both would ideally just embed an unexported AuthStyleCache but because both +// were historically allowed to be copied by value we can't retroactively add an +// uncopyable Mutex to them. +// +// We could use an atomic.Pointer, but that was added recently enough (in Go +// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03 +// still pass. By using an atomic.Value, it supports both Go 1.17 and +// copying by value, even if that's not ideal. +type LazyAuthStyleCache struct { + v atomic.Value // of *AuthStyleCache +} + +func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { + if c, ok := lc.v.Load().(*AuthStyleCache); ok { + return c + } + c := new(AuthStyleCache) + if !lc.v.CompareAndSwap(nil, c) { + c = lc.v.Load().(*AuthStyleCache) + } + return c +} + +// AuthStyleCache is the set of tokenURLs we've successfully used via // RetrieveToken and which style auth we ended up using. // It's called a cache, but it doesn't (yet?) shrink. It's expected that // the set of OAuth2 servers a program contacts over time is fixed and // small. -var authStyleCache struct { - sync.Mutex - m map[string]AuthStyle // keyed by tokenURL -} - -// ResetAuthCache resets the global authentication style cache used -// for AuthStyleUnknown token requests. -func ResetAuthCache() { - authStyleCache.Lock() - defer authStyleCache.Unlock() - authStyleCache.m = nil +type AuthStyleCache struct { + mu sync.Mutex + m map[string]AuthStyle // keyed by tokenURL } // lookupAuthStyle reports which auth style we last used with tokenURL // when calling RetrieveToken and whether we have ever done so. -func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { - authStyleCache.Lock() - defer authStyleCache.Unlock() - style, ok = authStyleCache.m[tokenURL] +func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + style, ok = c.m[tokenURL] return } // setAuthStyle adds an entry to authStyleCache, documented above. -func setAuthStyle(tokenURL string, v AuthStyle) { - authStyleCache.Lock() - defer authStyleCache.Unlock() - if authStyleCache.m == nil { - authStyleCache.m = make(map[string]AuthStyle) +func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { + c.mu.Lock() + defer c.mu.Unlock() + if c.m == nil { + c.m = make(map[string]AuthStyle) } - authStyleCache.m[tokenURL] = v + c.m[tokenURL] = v } // newTokenRequest returns a new *http.Request to retrieve a new token @@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values { return v2 } -func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { +func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { needsAuthStyleProbe := authStyle == 0 if needsAuthStyleProbe { - if style, ok := lookupAuthStyle(tokenURL); ok { + if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { authStyle = style needsAuthStyleProbe = false } else { @@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, token, err = doTokenRoundTrip(ctx, req) } if needsAuthStyleProbe && err == nil { - setAuthStyle(tokenURL, authStyle) + styleCache.setAuthStyle(tokenURL, authStyle) } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. diff --git a/internal/token_test.go b/internal/token_test.go index c54095ab..c08862ae 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -16,7 +16,7 @@ import ( ) func TestRetrieveToken_InParams(t *testing.T) { - ResetAuthCache() + styleCache := new(AuthStyleCache) const clientID = "client-id" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got, want := r.FormValue("client_id"), clientID; got != want { @@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) { io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) })) defer ts.Close() - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache) if err != nil { t.Errorf("RetrieveToken = %v; want no error", err) } } func TestRetrieveTokenWithContexts(t *testing.T) { - ResetAuthCache() + styleCache := new(AuthStyleCache) const clientID = "client-id" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { })) defer ts.Close() - _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown) + _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache) if err != nil { t.Errorf("RetrieveToken (with background context) = %v; want no error", err) } @@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown) + _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache) close(retrieved) if err == nil { t.Errorf("RetrieveToken (with cancelled context) = nil; want error") diff --git a/oauth2.go b/oauth2.go index 9085fabe..cc7c98c2 100644 --- a/oauth2.go +++ b/oauth2.go @@ -58,6 +58,10 @@ type Config struct { // Scope specifies optional requested permissions. Scopes []string + + // authStyleCache caches which auth style to use when Endpoint.AuthStyle is + // the zero value (AuthStyleAutoDetect). + authStyleCache internal.LazyAuthStyleCache } // A TokenSource is anything that can return a token. diff --git a/oauth2_test.go b/oauth2_test.go index a3517767..37f0580d 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -15,8 +15,6 @@ import ( "net/url" "testing" "time" - - "golang.org/x/oauth2/internal" ) type mockTransport struct { @@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { } func TestExchangeRequest_NonBasicAuth(t *testing.T) { - internal.ResetAuthCache() tr := &mockTransport{ rt: func(r *http.Request) (w *http.Response, err error) { headerAuth := r.Header.Get("Authorization") @@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { } func TestTokenRefreshRequest(t *testing.T) { - internal.ResetAuthCache() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.String() == "/somethingelse" { return diff --git a/token.go b/token.go index 5ffce976..5bbb3321 100644 --- a/token.go +++ b/token.go @@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token { // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // with an error.. func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { - tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle)) + tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) if err != nil { if rErr, ok := err.(*internal.RetrieveError); ok { return nil, (*RetrieveError)(rErr)