oauth2: move global auth style cache to be per-Config

In 80673b4a4 (https://go.dev/cl/157820) I added a never-shrinking
package-global cache to remember which auto-detected auth style (HTTP
headers vs POST) was supported by a certain OAuth2 server, keyed by
its URL.

Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and
have one global OpenID configuration document for all of their
customers which says ("we support all auth styles! you pick!") but
then give each customer control of which style they specifically
accept. This is bogus behavior on their part, but the oauth2 package's
global caching per URL isn't helping. (It's also bad to have a
package-global cache that can never be GC'ed)

So, this change moves the cache to hang off the oauth *Configs
instead. Unfortunately, it does so with some backwards compatiblity
compromises (an atomic.Value hack), lest people are using old versions
of Go still or copying a Config by value, both of which this package
previously accidentally supported, even though they weren't tested.

This change also means that anybody that's repeatedly making ephemeral
oauth.Configs without an explicit auth style will be losing &
reinitializing their cache on any auth style failures + fallbacks to
the other style. I think that should be pretty rare. People seem to
make an oauth2.Config once earlier and stash it away somewhere (often
deep in a token fetcher or HTTP client/transport).

Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b
Signed-off-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
Brad Fitzpatrick 2023-08-03 09:40:32 -07:00
Родитель 2e4a4e2bfb
Коммит a835fc4358
7 изменённых файлов: 60 добавлений и 39 удалений

Просмотреть файл

@ -47,6 +47,10 @@ type Config struct {
// client ID & client secret sent. The zero value means to // client ID & client secret sent. The zero value means to
// auto-detect. // auto-detect.
AuthStyle oauth2.AuthStyle AuthStyle oauth2.AuthStyle
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
} }
// Token uses client credentials to retrieve a token. // Token uses client credentials to retrieve a token.
@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v[k] = p v[k] = p
} }
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle)) tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil { if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok { if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr) return nil, (*oauth2.RetrieveError)(rErr)

Просмотреть файл

@ -12,8 +12,6 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"golang.org/x/oauth2/internal"
) )
func newConf(serverURL string) *Config { func newConf(serverURL string) *Config {
@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
} }
func TestTokenRefreshRequest(t *testing.T) { func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" { if r.URL.String() == "/somethingelse" {
return return

Просмотреть файл

@ -18,6 +18,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@ -115,41 +116,60 @@ const (
AuthStyleInHeader AuthStyle = 2 AuthStyleInHeader AuthStyle = 2
) )
// authStyleCache is the set of tokenURLs we've successfully used via // LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
type LazyAuthStyleCache struct {
v atomic.Value // of *AuthStyleCache
}
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
return c
}
c := new(AuthStyleCache)
if !lc.v.CompareAndSwap(nil, c) {
c = lc.v.Load().(*AuthStyleCache)
}
return c
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using. // RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that // It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and // the set of OAuth2 servers a program contacts over time is fixed and
// small. // small.
var authStyleCache struct { type AuthStyleCache struct {
sync.Mutex mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL m map[string]AuthStyle // keyed by tokenURL
} }
// ResetAuthCache resets the global authentication style cache used
// for AuthStyleUnknown token requests.
func ResetAuthCache() {
authStyleCache.Lock()
defer authStyleCache.Unlock()
authStyleCache.m = nil
}
// lookupAuthStyle reports which auth style we last used with tokenURL // lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so. // when calling RetrieveToken and whether we have ever done so.
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
authStyleCache.Lock() c.mu.Lock()
defer authStyleCache.Unlock() defer c.mu.Unlock()
style, ok = authStyleCache.m[tokenURL] style, ok = c.m[tokenURL]
return return
} }
// setAuthStyle adds an entry to authStyleCache, documented above. // setAuthStyle adds an entry to authStyleCache, documented above.
func setAuthStyle(tokenURL string, v AuthStyle) { func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
authStyleCache.Lock() c.mu.Lock()
defer authStyleCache.Unlock() defer c.mu.Unlock()
if authStyleCache.m == nil { if c.m == nil {
authStyleCache.m = make(map[string]AuthStyle) c.m = make(map[string]AuthStyle)
} }
authStyleCache.m[tokenURL] = v c.m[tokenURL] = v
} }
// newTokenRequest returns a new *http.Request to retrieve a new token // newTokenRequest returns a new *http.Request to retrieve a new token
@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
return v2 return v2
} }
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == 0 needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe { if needsAuthStyleProbe {
if style, ok := lookupAuthStyle(tokenURL); ok { if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style authStyle = style
needsAuthStyleProbe = false needsAuthStyleProbe = false
} else { } else {
@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req) token, err = doTokenRoundTrip(ctx, req)
} }
if needsAuthStyleProbe && err == nil { if needsAuthStyleProbe && err == nil {
setAuthStyle(tokenURL, authStyle) styleCache.setAuthStyle(tokenURL, authStyle)
} }
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request. // if this was a token refreshing request.

Просмотреть файл

@ -16,7 +16,7 @@ import (
) )
func TestRetrieveToken_InParams(t *testing.T) { func TestRetrieveToken_InParams(t *testing.T) {
ResetAuthCache() styleCache := new(AuthStyleCache)
const clientID = "client-id" const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want { if got, want := r.FormValue("client_id"), clientID; got != want {
@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
})) }))
defer ts.Close() defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams) _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil { if err != nil {
t.Errorf("RetrieveToken = %v; want no error", err) t.Errorf("RetrieveToken = %v; want no error", err)
} }
} }
func TestRetrieveTokenWithContexts(t *testing.T) { func TestRetrieveTokenWithContexts(t *testing.T) {
ResetAuthCache() styleCache := new(AuthStyleCache)
const clientID = "client-id" const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
})) }))
defer ts.Close() defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown) _, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil { if err != nil {
t.Errorf("RetrieveToken (with background context) = %v; want no error", err) t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
} }
@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown) _, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
close(retrieved) close(retrieved)
if err == nil { if err == nil {
t.Errorf("RetrieveToken (with cancelled context) = nil; want error") t.Errorf("RetrieveToken (with cancelled context) = nil; want error")

Просмотреть файл

@ -58,6 +58,10 @@ type Config struct {
// Scope specifies optional requested permissions. // Scope specifies optional requested permissions.
Scopes []string Scopes []string
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
} }
// A TokenSource is anything that can return a token. // A TokenSource is anything that can return a token.

Просмотреть файл

@ -15,8 +15,6 @@ import (
"net/url" "net/url"
"testing" "testing"
"time" "time"
"golang.org/x/oauth2/internal"
) )
type mockTransport struct { type mockTransport struct {
@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
} }
func TestExchangeRequest_NonBasicAuth(t *testing.T) { func TestExchangeRequest_NonBasicAuth(t *testing.T) {
internal.ResetAuthCache()
tr := &mockTransport{ tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) { rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization") headerAuth := r.Header.Get("Authorization")
@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
} }
func TestTokenRefreshRequest(t *testing.T) { func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" { if r.URL.String() == "/somethingelse" {
return return

Просмотреть файл

@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error.. // with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle)) tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil { if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok { if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr) return nil, (*RetrieveError)(rErr)