273 строки
8.8 KiB
Go
273 строки
8.8 KiB
Go
//go:build go1.18
|
|
// +build go1.18
|
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
package azidentity
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
|
|
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
|
|
|
|
// this import ensures well-known configurations in azcore/cloud have ARM audiences for Authenticate()
|
|
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
|
|
)
|
|
|
|
type publicClientOptions struct {
|
|
azcore.ClientOptions
|
|
|
|
AdditionallyAllowedTenants []string
|
|
Cache Cache
|
|
DeviceCodePrompt func(context.Context, DeviceCodeMessage) error
|
|
DisableAutomaticAuthentication bool
|
|
DisableInstanceDiscovery bool
|
|
LoginHint, RedirectURL string
|
|
Record AuthenticationRecord
|
|
Username, Password string
|
|
}
|
|
|
|
// publicClient wraps the MSAL public client
|
|
type publicClient struct {
|
|
cae, noCAE msalPublicClient
|
|
caeMu, noCAEMu, clientMu *sync.Mutex
|
|
clientID, tenantID string
|
|
defaultScope []string
|
|
host string
|
|
name string
|
|
opts publicClientOptions
|
|
record AuthenticationRecord
|
|
azClient *azcore.Client
|
|
}
|
|
|
|
var errScopeRequired = errors.New("authenticating in this environment requires specifying a scope in TokenRequestOptions")
|
|
|
|
func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) {
|
|
if !validTenantID(tenantID) {
|
|
return nil, errInvalidTenantID
|
|
}
|
|
host, err := setAuthorityHost(o.Cloud)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// if the application specified a cloud configuration, use its ARM audience as the default scope for Authenticate()
|
|
audience := o.Cloud.Services[cloud.ResourceManager].Audience
|
|
if audience == "" {
|
|
// no cloud configuration, or no ARM audience, specified; try to map the host to a well-known one (all of which have a trailing slash)
|
|
if !strings.HasSuffix(host, "/") {
|
|
host += "/"
|
|
}
|
|
switch host {
|
|
case cloud.AzureChina.ActiveDirectoryAuthorityHost:
|
|
audience = cloud.AzureChina.Services[cloud.ResourceManager].Audience
|
|
case cloud.AzureGovernment.ActiveDirectoryAuthorityHost:
|
|
audience = cloud.AzureGovernment.Services[cloud.ResourceManager].Audience
|
|
case cloud.AzurePublic.ActiveDirectoryAuthorityHost:
|
|
audience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
|
|
}
|
|
}
|
|
// if we didn't come up with an audience, the application will have to specify a scope for Authenticate()
|
|
var defaultScope []string
|
|
if audience != "" {
|
|
defaultScope = []string{audience + defaultSuffix}
|
|
}
|
|
client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
|
|
Tracing: runtime.TracingOptions{
|
|
Namespace: traceNamespace,
|
|
},
|
|
}, &o.ClientOptions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants)
|
|
return &publicClient{
|
|
caeMu: &sync.Mutex{},
|
|
clientID: clientID,
|
|
clientMu: &sync.Mutex{},
|
|
defaultScope: defaultScope,
|
|
host: host,
|
|
name: name,
|
|
noCAEMu: &sync.Mutex{},
|
|
opts: o,
|
|
record: o.Record,
|
|
tenantID: tenantID,
|
|
azClient: client,
|
|
}, nil
|
|
}
|
|
|
|
func (p *publicClient) Authenticate(ctx context.Context, tro *policy.TokenRequestOptions) (AuthenticationRecord, error) {
|
|
if tro == nil {
|
|
tro = &policy.TokenRequestOptions{}
|
|
}
|
|
if len(tro.Scopes) == 0 {
|
|
if p.defaultScope == nil {
|
|
return AuthenticationRecord{}, errScopeRequired
|
|
}
|
|
tro.Scopes = p.defaultScope
|
|
}
|
|
client, mu, err := p.client(*tro)
|
|
if err != nil {
|
|
return AuthenticationRecord{}, err
|
|
}
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
_, err = p.reqToken(ctx, client, *tro)
|
|
if err == nil {
|
|
scope := strings.Join(tro.Scopes, ", ")
|
|
msg := fmt.Sprintf("%s.Authenticate() acquired a token for scope %q", p.name, scope)
|
|
log.Write(EventAuthentication, msg)
|
|
}
|
|
return p.record, err
|
|
}
|
|
|
|
// GetToken requests an access token from MSAL, checking the cache first.
|
|
func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
|
|
if len(tro.Scopes) < 1 {
|
|
return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name)
|
|
}
|
|
tenant, err := p.resolveTenant(tro.TenantID)
|
|
if err != nil {
|
|
return azcore.AccessToken{}, err
|
|
}
|
|
client, mu, err := p.client(tro)
|
|
if err != nil {
|
|
return azcore.AccessToken{}, err
|
|
}
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.record.account()), public.WithClaims(tro.Claims), public.WithTenantID(tenant))
|
|
if err == nil {
|
|
return p.token(ar, err)
|
|
}
|
|
if p.opts.DisableAutomaticAuthentication {
|
|
return azcore.AccessToken{}, newAuthenticationRequiredError(p.name, tro)
|
|
}
|
|
at, err := p.reqToken(ctx, client, tro)
|
|
if err == nil {
|
|
msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", p.name, strings.Join(ar.GrantedScopes, ", "))
|
|
log.Write(EventAuthentication, msg)
|
|
}
|
|
return at, err
|
|
}
|
|
|
|
// reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache.
|
|
func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
|
|
tenant, err := p.resolveTenant(tro.TenantID)
|
|
if err != nil {
|
|
return azcore.AccessToken{}, err
|
|
}
|
|
var ar public.AuthResult
|
|
switch p.name {
|
|
case credNameBrowser:
|
|
ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes,
|
|
public.WithClaims(tro.Claims),
|
|
public.WithLoginHint(p.opts.LoginHint),
|
|
public.WithRedirectURI(p.opts.RedirectURL),
|
|
public.WithTenantID(tenant),
|
|
)
|
|
case credNameDeviceCode:
|
|
dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
|
|
if e != nil {
|
|
return azcore.AccessToken{}, e
|
|
}
|
|
err = p.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{
|
|
Message: dc.Result.Message,
|
|
UserCode: dc.Result.UserCode,
|
|
VerificationURL: dc.Result.VerificationURL,
|
|
})
|
|
if err == nil {
|
|
ar, err = dc.AuthenticationResult(ctx)
|
|
}
|
|
case credNameUserPassword:
|
|
ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
|
|
default:
|
|
return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name)
|
|
}
|
|
return p.token(ar, err)
|
|
}
|
|
|
|
func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) {
|
|
p.clientMu.Lock()
|
|
defer p.clientMu.Unlock()
|
|
if tro.EnableCAE {
|
|
if p.cae == nil {
|
|
client, err := p.newMSALClient(true)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
p.cae = client
|
|
}
|
|
return p.cae, p.caeMu, nil
|
|
}
|
|
if p.noCAE == nil {
|
|
client, err := p.newMSALClient(false)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
p.noCAE = client
|
|
}
|
|
return p.noCAE, p.noCAEMu, nil
|
|
}
|
|
|
|
func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) {
|
|
c, err := internal.ExportReplace(p.opts.Cache, enableCAE)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
o := []public.Option{
|
|
public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)),
|
|
public.WithCache(c),
|
|
public.WithHTTPClient(p),
|
|
}
|
|
if enableCAE {
|
|
o = append(o, public.WithClientCapabilities(cp1))
|
|
}
|
|
if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" {
|
|
o = append(o, public.WithInstanceDiscovery(false))
|
|
}
|
|
return public.New(p.clientID, o...)
|
|
}
|
|
|
|
func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) {
|
|
if err == nil {
|
|
p.record, err = newAuthenticationRecord(ar)
|
|
} else {
|
|
err = newAuthenticationFailedErrorFromMSAL(p.name, err)
|
|
}
|
|
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
|
|
}
|
|
|
|
// resolveTenant returns the correct WithTenantID() argument for a token request given the client's
|
|
// configuration, or an error when that configuration doesn't allow the specified tenant
|
|
func (p *publicClient) resolveTenant(specified string) (string, error) {
|
|
t, err := resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants)
|
|
if t == p.tenantID {
|
|
// callers pass this value to MSAL's WithTenantID(). There's no need to redundantly specify
|
|
// the client's default tenant and doing so is an error when that tenant is "organizations"
|
|
t = ""
|
|
}
|
|
return t, err
|
|
}
|
|
|
|
// these methods satisfy the MSAL ops.HTTPClient interface
|
|
|
|
func (p *publicClient) CloseIdleConnections() {
|
|
// do nothing
|
|
}
|
|
|
|
func (p *publicClient) Do(r *http.Request) (*http.Response, error) {
|
|
return doForClient(p.azClient, r)
|
|
}
|