azure-sdk-for-go/sdk/azidentity/chained_token_credential.go

147 строки
4.9 KiB
Go

//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azidentity
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)
// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
// The zero value is used when nil options are passed to NewChainedTokenCredential.
type ChainedTokenCredentialOptions struct {
// RetrySources configures how the credential uses its sources. When true, every GetToken call attempts to
// authenticate through each source in turn, stopping when one succeeds. When false (the zero value), the
// credential tries its sources in turn only until one succeeds; from then on it authenticates only through
// that first successful source--it never again tries any other source.
RetrySources bool
}
// ChainedTokenCredential links together multiple credentials and tries them sequentially when authenticating. By default,
// it tries all the credentials until one authenticates, after which it always uses that credential.
type ChainedTokenCredential struct {
// cond's lock guards iterating and successfulCredential; goroutines wait on cond while another
// goroutine is iterating the sources
cond *sync.Cond
// iterating is true while a goroutine is trying the sources in turn. GetToken sets it only when
// retrySources is false.
iterating bool
// name identifies this credential in log and error messages
name string
// retrySources is the value of ChainedTokenCredentialOptions.RetrySources given at construction
retrySources bool
// sources holds the chained credentials in the order they're tried
sources []azcore.TokenCredential
// successfulCredential is the source that most recently returned a token during an iteration performed
// with retrySources false; it's nil until a source succeeds
successfulCredential azcore.TokenCredential
}
// NewChainedTokenCredential creates a ChainedTokenCredential. Pass nil for options to accept defaults.
// It returns an error when sources is empty or contains a nil element. The credential keeps its own
// copy of sources, so later changes to the caller's slice don't affect it.
func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
	if len(sources) == 0 {
		return nil, errors.New("sources must contain at least one TokenCredential")
	}
	// validate and copy in a single pass
	chain := make([]azcore.TokenCredential, 0, len(sources))
	for i := range sources {
		// a nil credential in the chain would make the application panic when GetToken() is called on it
		if sources[i] == nil {
			return nil, errors.New("sources cannot contain nil")
		}
		chain = append(chain, sources[i])
	}
	// nil options means defaults, i.e. RetrySources is false
	retry := false
	if options != nil {
		retry = options.RetrySources
	}
	cred := &ChainedTokenCredential{
		name:         "ChainedTokenCredential",
		retrySources: retry,
		sources:      chain,
	}
	cred.cond = sync.NewCond(&sync.Mutex{})
	return cred, nil
}
// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token.
// This method is called automatically by Azure SDK clients.
//
// When the credential's retrySources is false, only one goroutine at a time iterates the sources;
// concurrent callers block until that iteration finishes. Once a source has returned a token, every
// later call is delegated directly to that source and its result, including any error, is returned
// unchanged.
//
// Iteration stops at the first source that returns a token, or at the first source whose error isn't a
// credentialUnavailable. When no source returns a token, the returned error's message lists the error
// from each source that was tried.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
if !c.retrySources {
// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
c.cond.L.Lock()
for {
if c.successfulCredential != nil {
// an earlier iteration found a working source; use it exclusively
c.cond.L.Unlock()
return c.successfulCredential.GetToken(ctx, opts)
}
if !c.iterating {
// no other goroutine is iterating, so this one takes that role
c.iterating = true
// allow other goroutines to wait while this one iterates
c.cond.L.Unlock()
break
}
// another goroutine is iterating. Wait releases the lock while blocked and holds it again when it
// returns, after which the loop re-checks both conditions above.
// NOTE(review): this wait doesn't observe ctx, so cancellation won't release a caller blocked here
c.cond.Wait()
}
}
var (
err error
// errs collects the error from every source tried, in order
errs []error
// successfulCredential is the source that returned a token; it stays nil if none did
successfulCredential azcore.TokenCredential
token azcore.AccessToken
// unavailableErr is the errors.As target used to recognize credentialUnavailable errors
unavailableErr credentialUnavailable
)
for _, cred := range c.sources {
token, err = cred.GetToken(ctx, opts)
if err == nil {
log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
successfulCredential = cred
break
}
errs = append(errs, err)
// continue to the next source iff this one returned credentialUnavailableError
if !errors.As(err, &unavailableErr) {
break
}
}
// In the code visible here, c.iterating is set to true only in the !c.retrySources branch above, so this
// branch runs only for the goroutine that took the iterating role.
if c.iterating {
c.cond.L.Lock()
// this is nil when all credentials returned an error
c.successfulCredential = successfulCredential
c.iterating = false
c.cond.L.Unlock()
// wake every waiter: each either uses the stored credential or, if none succeeded, one of them
// becomes the next iterator
c.cond.Broadcast()
}
// err is the error returned by the last GetToken call. It will be nil when that call succeeds
if err != nil {
// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
msg := createChainedErrorMessage(errs)
var authFailedErr *AuthenticationFailedError
switch {
case errors.As(err, &authFailedErr):
// the last source failed with an AuthenticationFailedError: build a new error under this credential's
// name, carrying the combined message and the original error's RawResponse
err = newAuthenticationFailedError(c.name, msg, authFailedErr.RawResponse)
if af, ok := err.(*AuthenticationFailedError); ok {
// stop Error() printing the response again; it's already in msg
af.omitResponse = true
}
case errors.As(err, &unavailableErr):
// the loop above breaks on the first error of any other kind, so reaching this case means every
// source tried returned a credentialUnavailable error
err = newCredentialUnavailableError(c.name, msg)
default:
// any other error type. getResponseFromError is defined elsewhere; presumably it extracts an HTTP
// response from err when there is one -- confirm against its definition
res := getResponseFromError(err)
err = newAuthenticationFailedError(c.name, msg, res)
}
}
return token, err
}
// createChainedErrorMessage builds the error message reported when no chained credential returned a
// token. The message starts with a fixed header, followed by one tab-indented entry per error in errs,
// in order. Line breaks inside an individual error's text are indented one level further so a
// multi-line error stays visually grouped under its own entry. With no errors, only the header is returned.
func createChainedErrorMessage(errs []error) string {
	// a Builder avoids re-copying the whole message on every iteration, as repeated += would
	var b strings.Builder
	b.WriteString("failed to acquire a token.\nAttempted credentials:")
	for _, err := range errs {
		fmt.Fprintf(&b, "\n\t%s", strings.ReplaceAll(err.Error(), "\n", "\n\t\t"))
	}
	return b.String()
}
// extractCredentialName returns the credential's dynamic type as formatted by %T, with a leading
// "*azidentity." removed when present. Type names without that prefix are returned as formatted.
func extractCredentialName(credential azcore.TokenCredential) string {
	const pkgPrefix = "*azidentity."
	typeName := fmt.Sprintf("%T", credential)
	return strings.TrimPrefix(typeName, pkgPrefix)
}
// compile-time check that *ChainedTokenCredential implements azcore.TokenCredential
var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil)