oauth2: move global auth style cache to be per-Config

In 80673b4a4 (https://go.dev/cl/157820) I added a never-shrinking
package-global cache to remember which auto-detected auth style (HTTP
headers vs POST) was supported by a certain OAuth2 server, keyed by
its URL.

Unfortunately, some multi-tenant SaaS OIDC servers behave poorly and
have one global OpenID configuration document for all of their
customers which says ("we support all auth styles! you pick!") but
then give each customer control of which style they specifically
accept. This is bogus behavior on their part, but the oauth2 package's
global caching per URL isn't helping. (It's also bad to have a
package-global cache that can never be GC'ed)

So, this change moves the cache to hang off the oauth *Configs
instead. Unfortunately, it does so with some backwards compatiblity
compromises (an atomic.Value hack), lest people are using old versions
of Go still or copying a Config by value, both of which this package
previously accidentally supported, even though they weren't tested.

This change also means that anybody that's repeatedly making ephemeral
oauth.Configs without an explicit auth style will be losing &
reinitializing their cache on any auth style failures + fallbacks to
the other style. I think that should be pretty rare. People seem to
make an oauth2.Config once earlier and stash it away somewhere (often
deep in a token fetcher or HTTP client/transport).

Change-Id: I91f107368ab3c3d77bc425eeef65372a589feb7b
Signed-off-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/515675
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
This commit is contained in:
Brad Fitzpatrick 2023-08-03 09:40:32 -07:00
Родитель 2e4a4e2bfb
Коммит a835fc4358
7 изменённых файлов: 60 добавлений и 39 удалений

Просмотреть файл

@ -47,6 +47,10 @@ type Config struct {
// client ID & client secret sent. The zero value means to
// auto-detect.
AuthStyle oauth2.AuthStyle
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}
// Token uses client credentials to retrieve a token.
@ -103,7 +107,7 @@ func (c *tokenSource) Token() (*oauth2.Token, error) {
v[k] = p
}
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle))
tk, err := internal.RetrieveToken(c.ctx, c.conf.ClientID, c.conf.ClientSecret, c.conf.TokenURL, v, internal.AuthStyle(c.conf.AuthStyle), c.conf.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*oauth2.RetrieveError)(rErr)

Просмотреть файл

@ -12,8 +12,6 @@ import (
"net/http/httptest"
"net/url"
"testing"
"golang.org/x/oauth2/internal"
)
func newConf(serverURL string) *Config {
@ -114,7 +112,6 @@ func TestTokenRequest(t *testing.T) {
}
func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return

Просмотреть файл

@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
@ -115,41 +116,60 @@ const (
AuthStyleInHeader AuthStyle = 2
)
// authStyleCache is the set of tokenURLs we've successfully used via
// LazyAuthStyleCache is a backwards compatibility compromise to let Configs
// have a lazily-initialized AuthStyleCache.
//
// The two users of this, oauth2.Config and oauth2/clientcredentials.Config,
// both would ideally just embed an unexported AuthStyleCache but because both
// were historically allowed to be copied by value we can't retroactively add an
// uncopyable Mutex to them.
//
// We could use an atomic.Pointer, but that was added recently enough (in Go
// 1.18) that we'd break Go 1.17 users where the tests as of 2023-08-03
// still pass. By using an atomic.Value, it supports both Go 1.17 and
// copying by value, even if that's not ideal.
type LazyAuthStyleCache struct {
v atomic.Value // of *AuthStyleCache
}
func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
if c, ok := lc.v.Load().(*AuthStyleCache); ok {
return c
}
c := new(AuthStyleCache)
if !lc.v.CompareAndSwap(nil, c) {
c = lc.v.Load().(*AuthStyleCache)
}
return c
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
// the set of OAuth2 servers a program contacts over time is fixed and
// small.
var authStyleCache struct {
sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
}
// ResetAuthCache resets the global authentication style cache used
// for AuthStyleUnknown token requests.
func ResetAuthCache() {
authStyleCache.Lock()
defer authStyleCache.Unlock()
authStyleCache.m = nil
type AuthStyleCache struct {
mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
}
// lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so.
func lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
style, ok = authStyleCache.m[tokenURL]
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
style, ok = c.m[tokenURL]
return
}
// setAuthStyle adds an entry to authStyleCache, documented above.
func setAuthStyle(tokenURL string, v AuthStyle) {
authStyleCache.Lock()
defer authStyleCache.Unlock()
if authStyleCache.m == nil {
authStyleCache.m = make(map[string]AuthStyle)
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
c.mu.Lock()
defer c.mu.Unlock()
if c.m == nil {
c.m = make(map[string]AuthStyle)
}
authStyleCache.m[tokenURL] = v
c.m[tokenURL] = v
}
// newTokenRequest returns a new *http.Request to retrieve a new token
@ -189,10 +209,10 @@ func cloneURLValues(v url.Values) url.Values {
return v2
}
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle) (*Token, error) {
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe {
if style, ok := lookupAuthStyle(tokenURL); ok {
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
@ -222,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req)
}
if needsAuthStyleProbe && err == nil {
setAuthStyle(tokenURL, authStyle)
styleCache.setAuthStyle(tokenURL, authStyle)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.

Просмотреть файл

@ -16,7 +16,7 @@ import (
)
func TestRetrieveToken_InParams(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.FormValue("client_id"), clientID; got != want {
@ -29,14 +29,14 @@ func TestRetrieveToken_InParams(t *testing.T) {
io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`)
}))
defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleInParams, styleCache)
if err != nil {
t.Errorf("RetrieveToken = %v; want no error", err)
}
}
func TestRetrieveTokenWithContexts(t *testing.T) {
ResetAuthCache()
styleCache := new(AuthStyleCache)
const clientID = "client-id"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -45,7 +45,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
}))
defer ts.Close()
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown)
_, err := RetrieveToken(context.Background(), clientID, "", ts.URL, url.Values{}, AuthStyleUnknown, styleCache)
if err != nil {
t.Errorf("RetrieveToken (with background context) = %v; want no error", err)
}
@ -58,7 +58,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown)
_, err = RetrieveToken(ctx, clientID, "", cancellingts.URL, url.Values{}, AuthStyleUnknown, styleCache)
close(retrieved)
if err == nil {
t.Errorf("RetrieveToken (with cancelled context) = nil; want error")

Просмотреть файл

@ -58,6 +58,10 @@ type Config struct {
// Scope specifies optional requested permissions.
Scopes []string
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
// the zero value (AuthStyleAutoDetect).
authStyleCache internal.LazyAuthStyleCache
}
// A TokenSource is anything that can return a token.

Просмотреть файл

@ -15,8 +15,6 @@ import (
"net/url"
"testing"
"time"
"golang.org/x/oauth2/internal"
)
type mockTransport struct {
@ -355,7 +353,6 @@ func TestExchangeRequest_BadResponseType(t *testing.T) {
}
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
internal.ResetAuthCache()
tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization")
@ -427,7 +424,6 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
}
func TestTokenRefreshRequest(t *testing.T) {
internal.ResetAuthCache()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return

Просмотреть файл

@ -164,7 +164,7 @@ func tokenFromInternal(t *internal.Token) *Token {
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle))
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)