ManagedIdentityCredential caches tokens (#18968)

This commit is contained in:
Charles Lowell 2022-09-19 14:34:02 -07:00 коммит произвёл GitHub
Родитель e4aa0f7387
Коммит 1cdf922bb9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
10 изменённых файлов: 78 добавлений и 113 удалений

Просмотреть файл

@ -3,14 +3,11 @@
## 1.2.0-beta.3 (Unreleased)
### Features Added
### Breaking Changes
* `ManagedIdentityCredential` caches tokens in memory
### Bugs Fixed
* `ClientCertificateCredential` sends only the leaf cert for SNI authentication
### Other Changes
## 1.2.0-beta.2 (2022-08-10)
### Features Added

Просмотреть файл

@ -28,7 +28,6 @@ import (
// constants used throughout this package
const (
accessTokenRespSuccess = `{"access_token": "` + tokenValue + `", "expires_in": 3600}`
accessTokenRespMalformed = `{"access_token": 0, "expires_in": 3600}`
badTenantID = "bad_tenant"
tenantDiscoveryResponse = `{
@ -98,10 +97,13 @@ const (
"msgraph_host": "graph.microsoft.com",
"rbac_url": "https://pas.windows.net"
}`
tokenValue = "new_token"
tokenExpiresIn = 3600
tokenValue = "new_token"
)
var instanceDiscoveryResponse = []byte(`{
var (
accessTokenRespSuccess = fmt.Sprintf(`{"access_token": "%s", "expires_in": %d}`, tokenValue, tokenExpiresIn)
instanceDiscoveryResponse = []byte(`{
"tenant_discovery_endpoint": "https://login.microsoftonline.com/tenant/v2.0/.well-known/openid-configuration",
"api-version": "1.1",
"metadata": [
@ -117,6 +119,7 @@ var instanceDiscoveryResponse = []byte(`{
}
]
}`)
)
// constants for this file
const (

Просмотреть файл

@ -67,7 +67,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default
msiCred, err := NewManagedIdentityCredential(o)
if err == nil {
creds = append(creds, msiCred)
msiCred.client.imdsTimeout = time.Second
msiCred.mic.imdsTimeout = time.Second
} else {
errorMessages = append(errorMessages, credNameManagedIdentity+": "+err.Error())
creds = append(creds, &defaultCredentialErrorReporter{credType: credNameManagedIdentity, err: err})

Просмотреть файл

@ -78,9 +78,9 @@ func TestDefaultAzureCredential_UserAssignedIdentity(t *testing.T) {
t.Fatal(err)
}
for _, c := range cred.chain.sources {
if mic, ok := c.(*ManagedIdentityCredential); ok {
if mic.id != ID {
t.Fatalf(`expected %v, got "%v"`, ID, mic.id)
if m, ok := c.(*ManagedIdentityCredential); ok {
if actual := m.mic.id; actual != ID {
t.Fatalf(`expected "%s", got "%v"`, ID, actual)
}
return
}

Просмотреть файл

@ -5,7 +5,7 @@ go 1.18
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0
github.com/AzureAD/microsoft-authentication-library-for-go v0.6.0
github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0
github.com/golang-jwt/jwt/v4 v4.4.2
golang.org/x/crypto v0.0.0-20220511200225-c6db032c6c88
)

Просмотреть файл

@ -2,8 +2,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0 h1:sVPhtT2qjO86rTUaWMr4WoES4
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.0.0/go.mod h1:uGG2W01BaETf0Ozp+QxxKJdMBNRWPdstHG0Fmdwn1/U=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0 h1:jp0dGvZ7ZK0mgqnTSClMxa5xuRL7NZgHameVYF6BurY=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w=
github.com/AzureAD/microsoft-authentication-library-for-go v0.6.0 h1:XMEdVDFxgulDDl0lQmAZS6j8gRQ/0pJ+ZpXH2FHVtDc=
github.com/AzureAD/microsoft-authentication-library-for-go v0.6.0/go.mod h1:BDJ5qMFKx9DugEg3+uQSDCdbYPr5s9vBTrL9P8TpqOU=
github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0 h1:VgSJlZH5u0k2qxSpqyghcFQKmvYckj46uymKK5XzkBM=
github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0/go.mod h1:BDJ5qMFKx9DugEg3+uQSDCdbYPr5s9vBTrL9P8TpqOU=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c=

Просмотреть файл

@ -200,19 +200,17 @@ func testGetTokenSuccess(t *testing.T, cred azcore.TokenCredential) {
if tk.Token == "" {
t.Fatal("GetToken returned an invalid token")
}
if tk.ExpiresOn.Before(time.Now().UTC()) {
if tk.ExpiresOn.Before(time.Now()) {
t.Fatal("GetToken returned an invalid expiration time")
}
_, actual := tk.ExpiresOn.Zone()
_, expected := time.Now().UTC().Zone()
if actual != expected {
if tk.ExpiresOn.Location() != time.UTC {
t.Fatal("ExpiresOn isn't UTC")
}
tk2, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk2.Token != tk.Token || tk2.ExpiresOn.After(tk.ExpiresOn) {
if tk2.Token != tk.Token || tk2.ExpiresOn != tk.ExpiresOn {
t.Fatal("expected a cached token")
}
}

Просмотреть файл

@ -23,6 +23,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)
const (
@ -148,10 +149,18 @@ func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*manag
return &c, nil
}
// authenticate creates an authentication request for a Managed Identity and returns the resulting Access Token if successful.
// ctx: The current context for controlling the request lifetime.
// clientID: The client (application) ID of the service principal.
// scopes: The scopes required for the token.
// provideToken acquires a token for MSAL's confidential.Client, which caches the token
func (c *managedIdentityClient) provideToken(ctx context.Context, params confidential.TokenProviderParameters) (confidential.TokenProviderResult, error) {
result := confidential.TokenProviderResult{}
tk, err := c.authenticate(ctx, c.id, params.Scopes)
if err == nil {
result.AccessToken = tk.Token
result.ExpiresInSeconds = int(time.Until(tk.ExpiresOn).Seconds())
}
return result, err
}
// authenticate acquires an access token
func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (azcore.AccessToken, error) {
var cancel context.CancelFunc
if c.imdsTimeout > 0 && c.msiType == msiTypeIMDS {

Просмотреть файл

@ -14,6 +14,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)
const credNameManagedIdentity = "ManagedIdentityCredential"
@ -70,8 +71,8 @@ type ManagedIdentityCredentialOptions struct {
// user-assigned identity. See Azure Active Directory documentation for more information about managed identities:
// https://docs.microsoft.com/azure/active-directory/managed-identities-azure-resources/overview
type ManagedIdentityCredential struct {
id ManagedIDKind
client *managedIdentityClient
client confidentialClient
mic *managedIdentityClient
}
// NewManagedIdentityCredential creates a ManagedIdentityCredential. Pass nil to accept default options.
@ -79,11 +80,25 @@ func NewManagedIdentityCredential(options *ManagedIdentityCredentialOptions) (*M
if options == nil {
options = &ManagedIdentityCredentialOptions{}
}
client, err := newManagedIdentityClient(options)
mic, err := newManagedIdentityClient(options)
if err != nil {
return nil, err
}
return &ManagedIdentityCredential{id: options.ID, client: client}, nil
cred := confidential.NewCredFromTokenProvider(mic.provideToken)
if err != nil {
return nil, err
}
// It's okay to give MSAL an invalid client ID because MSAL will use it only as part of a cache key.
// ManagedIdentityClient handles all the details of authentication and won't receive this value from MSAL.
clientID := "SYSTEM-ASSIGNED-MANAGED-IDENTITY"
if options.ID != nil {
clientID = options.ID.String()
}
c, err := confidential.New(clientID, cred)
if err != nil {
return nil, err
}
return &ManagedIdentityCredential{client: c, mic: mic}, nil
}
// GetToken requests an access token from the hosting environment. This method is called automatically by Azure SDK clients.
@ -94,12 +109,17 @@ func (c *ManagedIdentityCredential) GetToken(ctx context.Context, opts policy.To
}
// managed identity endpoints require an AADv1 resource (i.e. token audience), not a v2 scope, so we remove "/.default" here
scopes := []string{strings.TrimSuffix(opts.Scopes[0], defaultSuffix)}
tk, err := c.client.authenticate(ctx, c.id, scopes)
ar, err := c.client.AcquireTokenSilent(ctx, scopes)
if err == nil {
logGetTokenSuccess(c, opts)
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, nil
}
ar, err = c.client.AcquireTokenByCredential(ctx, scopes)
if err != nil {
return azcore.AccessToken{}, err
}
logGetTokenSuccess(c, opts)
return tk, err
return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
}
var _ azcore.TokenCredential = (*ManagedIdentityCredential)(nil)

Просмотреть файл

@ -25,7 +25,6 @@ import (
)
const (
appServiceSuccessResp = `{"access_token": "` + tokenValue + `", "expires_on": "1560974028", "resource": "https://vault.azure.net", "token_type": "Bearer", "client_id": "some-guid"}`
expiresOnIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": "1560974028", "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
expiresOnNonStringIntResp = `{"access_token": "new_token", "refresh_token": "", "expires_in": "", "expires_on": 1560974028, "not_before": "1560970130", "resource": "https://vault.azure.net", "token_type": "Bearer"}`
)
@ -94,16 +93,7 @@ func TestManagedIdentityCredential_AzureArc(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf("unexpected token: %s", tk.Token)
}
if tk.ExpiresOn.Before(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
}
testGetTokenSuccess(t, cred)
}
func TestManagedIdentityCredential_CloudShell(t *testing.T) {
@ -131,19 +121,7 @@ func TestManagedIdentityCredential_CloudShell(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := msiCred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf("unexpected token value: %s", tk.Token)
}
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithStatusCode(http.StatusUnauthorized))
srv.AppendResponse()
_, err = msiCred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err == nil {
t.Fatal("expected an error but didn't receive one")
}
testGetTokenSuccess(t, msiCred)
}
func TestManagedIdentityCredential_AppService(t *testing.T) {
@ -185,7 +163,14 @@ func TestManagedIdentityCredential_AppService(t *testing.T) {
t.Run(fmt.Sprintf("%T", id), func(t *testing.T) {
srv, close := mock.NewServer()
defer close()
srv.AppendResponse(mock.WithPredicate(validateReq), mock.WithBody([]byte(appServiceSuccessResp)))
srv.AppendResponse(
mock.WithPredicate(validateReq),
mock.WithBody([]byte(fmt.Sprintf(
`{"access_token": "%s", "expires_on": "%d", "resource": "https://vault.azure.net", "token_type": "Bearer", "client_id": "some-guid"}`,
tokenValue,
time.Now().Add(time.Hour).Unix(),
))),
)
srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest))
setEnvironmentVariables(t, map[string]string{identityEndpoint: srv.URL(), identityHeader: expectedHeader})
options := ManagedIdentityCredentialOptions{ID: id}
@ -194,13 +179,7 @@ func TestManagedIdentityCredential_AppService(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf(`unexpected token "%s"`, tk.Token)
}
testGetTokenSuccess(t, cred)
})
}
}
@ -283,8 +262,7 @@ func TestManagedIdentityCredential_CreateIMDSAuthRequest(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
cred.client.endpoint = imdsEndpoint
req, err := cred.client.createIMDSAuthRequest(context.Background(), ClientID(fakeClientID), []string{liveTestScope})
req, err := cred.mic.createIMDSAuthRequest(context.Background(), ClientID(fakeClientID), []string{liveTestScope})
if err != nil {
t.Fatal(err)
}
@ -355,16 +333,12 @@ func TestManagedIdentityCredential_ScopesImmutable(t *testing.T) {
}
func TestManagedIdentityCredential_ResourceID_IMDS(t *testing.T) {
// setting a dummy value for MSI_ENDPOINT in order to avoid failure in the constructor
setEnvironmentVariables(t, map[string]string{msiEndpoint: "http://localhost"})
resID := "sample/resource/id"
cred, err := NewManagedIdentityCredential(&ManagedIdentityCredentialOptions{ID: ResourceID(resID)})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
cred.client.msiType = msiTypeIMDS
cred.client.endpoint = imdsEndpoint
req, err := cred.client.createAuthRequest(context.Background(), cred.id, []string{liveTestScope})
req, err := cred.mic.createAuthRequest(context.Background(), cred.mic.id, []string{liveTestScope})
if err != nil {
t.Fatal(err)
}
@ -435,16 +409,7 @@ func TestManagedIdentityCredential_IMDSLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token == "" {
t.Fatal("GetToken returned an invalid token")
}
if tk.ExpiresOn.Before(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
}
testGetTokenSuccess(t, cred)
}
func TestManagedIdentityCredential_IMDSClientIDLive(t *testing.T) {
@ -463,16 +428,7 @@ func TestManagedIdentityCredential_IMDSClientIDLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token == "" {
t.Fatal("GetToken returned an invalid token")
}
if tk.ExpiresOn.Before(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
}
testGetTokenSuccess(t, cred)
}
func TestManagedIdentityCredential_IMDSResourceIDLive(t *testing.T) {
@ -491,16 +447,7 @@ func TestManagedIdentityCredential_IMDSResourceIDLive(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token == "" {
t.Fatal("GetToken returned an invalid token")
}
if tk.ExpiresOn.Before(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
}
testGetTokenSuccess(t, cred)
}
func TestManagedIdentityCredential_IMDSTimeoutExceeded(t *testing.T) {
@ -513,7 +460,7 @@ func TestManagedIdentityCredential_IMDSTimeoutExceeded(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cred.client.imdsTimeout = time.Nanosecond
cred.mic.imdsTimeout = time.Nanosecond
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if _, ok := err.(*credentialUnavailableError); !ok {
t.Fatalf("expected credentialUnavailableError, received %T", err)
@ -534,7 +481,7 @@ func TestManagedIdentityCredential_IMDSTimeoutSuccess(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cred.client.imdsTimeout = time.Minute
cred.mic.imdsTimeout = time.Minute
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
@ -542,10 +489,10 @@ func TestManagedIdentityCredential_IMDSTimeoutSuccess(t *testing.T) {
if tk.Token != tokenValue {
t.Fatalf(`got unexpected token "%s"`, tk.Token)
}
if !tk.ExpiresOn.After(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
if v := time.Until(tk.ExpiresOn); v > tokenExpiresIn*time.Second || tokenExpiresIn-v > time.Second {
t.Fatalf("expected token to expire in about %d seconds but it expires in %f seconds", tokenExpiresIn, v.Seconds())
}
if cred.client.imdsTimeout > 0 {
if cred.mic.imdsTimeout > 0 {
t.Fatal("credential didn't remove IMDS timeout after receiving a response")
}
}
@ -574,14 +521,5 @@ func TestManagedIdentityCredential_ServiceFabric(t *testing.T) {
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf(`got unexpected token "%s"`, tk.Token)
}
if !tk.ExpiresOn.After(time.Now().UTC()) {
t.Fatal("GetToken returned an invalid expiration time")
}
testGetTokenSuccess(t, cred)
}