azure-sdk-for-go/sdk/azidentity/default_azure_credential_te...

423 lines
14 KiB
Go

//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azidentity
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/stretchr/testify/require"
)
// TestDefaultAzureCredential_GetTokenSuccess checks that DefaultAzureCredential can
// return a token via its first chained source, an EnvironmentCredential configured
// for a client secret. The underlying ClientSecretCredential's client is swapped for
// fakeConfidentialClient so the token comes from the fake.
func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) {
	setEnvironmentVariables(t, map[string]string{
		azureTenantID:     fakeTenantID,
		azureClientID:     fakeClientID,
		azureClientSecret: fakeSecret,
	})
	cred, err := NewDefaultAzureCredential(nil)
	if err != nil {
		t.Fatalf("Unable to create credential. Received: %v", err)
	}
	// the chain's first source is expected to be an EnvironmentCredential wrapping a ClientSecretCredential
	envCred := cred.chain.sources[0].(*EnvironmentCredential)
	secretCred := envCred.cred.(*ClientSecretCredential)
	secretCred.client.noCAE = fakeConfidentialClient{}
	tro := policy.TokenRequestOptions{Scopes: []string{"scope"}}
	if _, err = cred.GetToken(context.Background(), tro); err != nil {
		t.Fatalf("GetToken error: %v", err)
	}
}
// TestDefaultAzureCredential_ConstructorErrors verifies that when some chained
// credentials can't be constructed because their configuration is absent,
// NewDefaultAzureCredential still succeeds, logs the constructor errors, and
// includes them in the error GetToken returns when every source fails.
func TestDefaultAzureCredential_ConstructorErrors(t *testing.T) {
	// ensure NewEnvironmentCredential returns an error
	t.Setenv(azureTenantID, "")
	logMsgs := []string{}
	log.SetListener(func(e log.Event, s string) {
		if e == EventAuthentication {
			logMsgs = append(logMsgs, s)
		}
	})
	// unregister the listener so it doesn't outlive this test and keep capturing other tests' log events
	t.Cleanup(func() { log.SetListener(nil) })
	cred, err := NewDefaultAzureCredential(nil)
	if err != nil {
		t.Fatal(err)
	}
	// make GetToken return an error in any runtime environment
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	_, err = cred.GetToken(ctx, testTRO)
	if err == nil {
		t.Fatal("expected an error")
	}
	// these credentials' constructors returned errors because their configuration is absent;
	// those errors should be represented in the error returned by DefaultAzureCredential.GetToken()
	// and NewDefaultAzureCredential should have logged them
	for _, name := range []string{"EnvironmentCredential", credNameWorkloadIdentity} {
		// matchErr is distinct from err so the GetToken error being inspected isn't shadowed
		matched, matchErr := regexp.MatchString(regexp.QuoteMeta(name)+`: .+\n`, err.Error())
		if matchErr != nil {
			t.Fatal(matchErr)
		}
		if !matched {
			t.Errorf("expected an error message from %s", name)
		}
	}
	r := regexp.MustCompile(fmt.Sprintf(`(?m)NewDefaultAzureCredential failed to initialize some credentials:\n.*EnvironmentCredential:.+\n.*%s:`, credNameWorkloadIdentity))
	for _, msg := range logMsgs {
		if r.MatchString(msg) {
			return
		}
	}
	t.Fatalf("expected a log message about the constructor errors, got %s", strings.Join(logMsgs, "\n"))
}
// TestDefaultAzureCredential_TenantID verifies how DefaultAzureCredentialOptions.TenantID
// reaches AzureCLICredential, AzureDeveloperCLICredential and WorkloadIdentityCredential.
// Every case runs twice: once without TenantID ("default tenant") and once with
// TenantID set to "expected" ("TenantID set").
//
// For the CLI credentials, the mocked token providers must receive an empty tenant
// when TenantID is unset, and "expected" when it is set.
//
// For workload identity, the environment sets tenant "unexpected", and the token
// request must be sent to tenant "expected" in both runs.
// NOTE(review): in the default-tenant run this appears to depend on mockSTS's
// tenant field steering the token endpoint, not on DefaultAzureCredential — confirm
// against mockSTS.
func TestDefaultAzureCredential_TenantID(t *testing.T) {
	// cases below replace the az token provider; restore it once all cases finish
	azBefore := defaultAzTokenProvider
	t.Cleanup(func() { defaultAzTokenProvider = azBefore })
	expected := "expected"
	for _, override := range []bool{false, true} {
		name := "default tenant"
		if override {
			name = "TenantID set"
		}
		for _, credName := range []string{credNameAzureCLI, credNameAzureDeveloperCLI} {
			t.Run(fmt.Sprintf("%s_%s", credName, name), func(t *testing.T) {
				// called records whether the mocked provider for credName ran at all
				called := false
				// verifyTenant fails the test unless tenantID is "expected" when TenantID
				// is set, or empty when it isn't
				verifyTenant := func(tenantID string) {
					called = true
					if (override && tenantID != expected) || (!override && tenantID != "") {
						t.Fatalf("unexpected tenantID %q", tenantID)
					}
				}
				// the providers must be swapped before NewDefaultAzureCredential runs below
				switch credName {
				case credNameAzureCLI:
					defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenantID, subscription string) ([]byte, error) {
						verifyTenant(tenantID)
						return mockAzTokenProviderSuccess(ctx, scopes, tenantID, subscription)
					}
				case credNameAzureDeveloperCLI:
					// ensure az returns an error so DefaultAzureCredential tries azd
					defaultAzTokenProvider = func(context.Context, []string, string, string) ([]byte, error) {
						return nil, newCredentialUnavailableError(credNameAzureCLI, "it didn't work")
					}
					azdBefore := defaultAzdTokenProvider
					t.Cleanup(func() { defaultAzdTokenProvider = azdBefore })
					defaultAzdTokenProvider = func(ctx context.Context, scopes []string, tenant string) ([]byte, error) {
						verifyTenant(tenant)
						return mockAzdTokenProviderSuccess(ctx, scopes, tenant)
					}
				}
				// mock IMDS failure because managed identity precedes dev tools in the chain
				srv, close := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
				defer close()
				srv.SetResponse(mock.WithStatusCode(400))
				o := DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: srv}}
				if override {
					o.TenantID = expected
				}
				cred, err := NewDefaultAzureCredential(&o)
				if err != nil {
					t.Fatal(err)
				}
				_, err = cred.GetToken(context.Background(), testTRO)
				if err != nil {
					t.Fatal(err)
				}
				// a token obtained without calling the mocked provider would mean verifyTenant never checked anything
				if !called {
					t.Fatalf("%s wasn't invoked", credName)
				}
			})
		}
		t.Run(fmt.Sprintf("%s_%s", credNameWorkloadIdentity, name), func(t *testing.T) {
			af := filepath.Join(t.TempDir(), "assertions")
			if err := os.WriteFile(af, []byte("assertion"), os.ModePerm); err != nil {
				t.Fatal(err)
			}
			// configure workload identity through the environment, with a tenant other than expected
			for k, v := range map[string]string{
				azureAuthorityHost:      "https://login.microsoftonline.com",
				azureClientID:           fakeClientID,
				azureFederatedTokenFile: af,
				azureTenantID:           "un" + expected,
			} {
				t.Setenv(k, v)
			}
			o := DefaultAzureCredentialOptions{
				ClientOptions: policy.ClientOptions{
					Transport: &mockSTS{
						tenant: expected,
						tokenRequestCallback: func(r *http.Request) *http.Response {
							// the first path segment of the token request URL is the tenant
							if actual := strings.Split(r.URL.Path, "/")[1]; actual != expected {
								t.Fatalf("expected tenant %q, got %q", expected, actual)
							}
							// nil lets mockSTS answer with its own token response
							return nil
						},
					},
				},
			}
			if override {
				o.TenantID = expected
			}
			cred, err := NewDefaultAzureCredential(&o)
			if err != nil {
				t.Fatal(err)
			}
			_, err = cred.GetToken(context.Background(), testTRO)
			if err != nil {
				t.Fatal(err)
			}
		})
	}
}
// TestDefaultAzureCredential_UserAssignedIdentity verifies that the
// ManagedIdentityCredential in the default chain is configured with a client ID
// taken from the azureClientID environment variable when it's set, and with no
// user-assigned ID when it isn't.
func TestDefaultAzureCredential_UserAssignedIdentity(t *testing.T) {
	for _, id := range []ManagedIDKind{nil, ClientID("client-id")} {
		t.Run(fmt.Sprintf("%v", id), func(t *testing.T) {
			if id != nil {
				t.Setenv(azureClientID, id.String())
			}
			cred, err := NewDefaultAzureCredential(nil)
			if err != nil {
				t.Fatal(err)
			}
			for _, c := range cred.chain.sources {
				if mic, ok := c.(*ManagedIdentityCredential); ok {
					if actual := mic.mic.id; actual != id {
						// %v rather than %s: id is a nil interface in the first case,
						// which %s would render as "%!s(<nil>)"
						t.Fatalf(`expected "%v", got "%v"`, id, actual)
					}
					return
				}
			}
			t.Fatal("default chain should include ManagedIdentityCredential")
		})
	}
}
// TestDefaultAzureCredential_Workload verifies that DefaultAzureCredential
// authenticates with workload identity when the authority host, client ID,
// federated token file and tenant ID environment variables are set. The token
// request must carry the file's contents as client_assertion, the configured
// client_id, and target the configured tenant.
func TestDefaultAzureCredential_Workload(t *testing.T) {
	expectedAssertion := "service account token"
	tempFile := filepath.Join(t.TempDir(), "service-account-token-file")
	if err := os.WriteFile(tempFile, []byte(expectedAssertion), os.ModePerm); err != nil {
		t.Fatalf(`failed to write temporary file "%s": %v`, tempFile, err)
	}
	sts := mockSTS{tokenRequestCallback: func(req *http.Request) *http.Response {
		if err := req.ParseForm(); err != nil {
			t.Fatal(err)
		}
		// PostForm.Get returns "" for an absent field, so a missing field fails the
		// check instead of panicking on an index into an empty slice
		if actual := req.PostForm.Get("client_assertion"); actual != expectedAssertion {
			t.Fatalf(`unexpected assertion "%s"`, actual)
		}
		if actual := req.PostForm.Get("client_id"); actual != fakeClientID {
			t.Fatalf(`unexpected client ID "%s"`, actual)
		}
		// the first path segment of the token request URL is the tenant
		if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID {
			t.Fatalf(`unexpected tenant "%s"`, actual)
		}
		// nil lets mockSTS answer with its own token response
		return nil
	}}
	for k, v := range map[string]string{
		azureAuthorityHost:      cloud.AzurePublic.ActiveDirectoryAuthorityHost,
		azureClientID:           fakeClientID,
		azureFederatedTokenFile: tempFile,
		azureTenantID:           fakeTenantID,
	} {
		t.Setenv(k, v)
	}
	cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: policy.ClientOptions{Transport: &sts}})
	if err != nil {
		t.Fatal(err)
	}
	testGetTokenSuccess(t, cred)
}
// TestDefaultAzureCredential_IMDSLive gets tokens from a real (or recorded) IMDS
// endpoint through DefaultAzureCredential. First it uses the system-assigned
// identity. Then, in the ClientID subtest, it uses a user-assigned identity chosen
// via azureClientID. It runs live only when liveManagedIdentity reports IMDS is
// available, and otherwise only in playback mode.
func TestDefaultAzureCredential_IMDSLive(t *testing.T) {
	if recording.GetRecordMode() != recording.PlaybackMode && !liveManagedIdentity.imds {
		t.Skip("set IDENTITY_IMDS_AVAILABLE to run this test")
	}
	// clear variables that would select EnvironmentCredential or another managed
	// identity source; each original value is restored when this function returns
	for _, name := range []string{azureTenantID, identityEndpoint, msiEndpoint} {
		prev, present := os.LookupEnv(name)
		if !present {
			continue
		}
		require.NoError(t, os.Unsetenv(name))
		defer os.Setenv(name, prev)
	}
	opts, stopRecording := initRecording(t)
	defer stopRecording()
	dac, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: opts})
	require.NoError(t, err)
	testGetTokenSuccess(t, dac)

	t.Run("ClientID", func(t *testing.T) {
		if recording.GetRecordMode() != recording.PlaybackMode && liveManagedIdentity.clientID == "" {
			t.Skip("set IDENTITY_VM_USER_ASSIGNED_MI_CLIENT_ID to run this test")
		}
		t.Setenv(azureClientID, liveManagedIdentity.clientID)
		subOpts, stopSubRecording := initRecording(t)
		defer stopSubRecording()
		userAssigned, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{ClientOptions: subOpts})
		require.NoError(t, err)
		testGetTokenSuccess(t, userAssigned)
	})
}
// delayPolicy is a pipeline policy that holds each request for delay before
// sending it on. Used to test timeout behavior. A delay of zero or less forwards
// requests immediately.
type delayPolicy struct {
	delay time.Duration
}

// Do waits for p.delay and then passes req to the next policy. If req's context
// is done before the delay elapses, Do returns the context's error without
// sending the request.
func (p *delayPolicy) Do(req *policy.Request) (resp *http.Response, err error) {
	if p.delay > 0 {
		// an explicitly stopped timer rather than time.After, so the timer is released
		// when the context finishes first (before Go 1.23, an unfired time.After timer
		// isn't collected until it fires)
		timer := time.NewTimer(p.delay)
		defer timer.Stop()
		select {
		case <-req.Raw().Context().Done():
			return nil, req.Raw().Context().Err()
		case <-timer.C:
			// delay has elapsed, continue on
		}
	}
	return req.Next()
}
// TestDefaultAzureCredential_IMDS verifies how DefaultAzureCredential handles IMDS:
//   - the availability probe isn't retried;
//   - non-JSON IMDS responses don't stop the chain;
//   - the probe timeout applies only until IMDS has responded once.
func TestDefaultAzureCredential_IMDS(t *testing.T) {
	// unsetting environment variables to skip EnvironmentCredential and other managed identity sources
	for _, k := range []string{azureTenantID, identityEndpoint, msiEndpoint} {
		if v, set := os.LookupEnv(k); set {
			require.NoError(t, os.Unsetenv(k))
			defer os.Setenv(k, v)
		}
	}
	// AzureCLICredential returning an error ensures we see fatal errors from ManagedIdentityCredential
	before := defaultAzTokenProvider
	defer func() { defaultAzTokenProvider = before }()
	defaultAzTokenProvider = mockAzTokenProviderFailure
	t.Run("probe", func(t *testing.T) {
		// probed becomes true once the first (probe) request reaches the transport
		probed := false
		cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
			ClientOptions: policy.ClientOptions{
				// this retry policy would retry the probe's 500 response if the probe honored it
				Retry: policy.RetryOptions{
					MaxRetries:  5,
					StatusCodes: []int{http.StatusInternalServerError},
				},
				Transport: &mockSTS{
					tokenRequestCallback: func(req *http.Request) *http.Response {
						hdr := req.Header.Get(headerMetadata)
						if probed {
							// This should be a token request. Return nil, mockSTS will respond with a token
							require.NotEmpty(t, hdr, "credential shouldn't retry probe request")
							return nil
						}
						// probe request. Respond with retriable status. The credential shouldn't retry
						probed = true
						require.Empty(t, hdr, "probe request shouldn't have Metadata header")
						return &http.Response{
							Body:       io.NopCloser(strings.NewReader("{}")),
							StatusCode: http.StatusInternalServerError,
						}
					},
				},
			},
		})
		require.NoError(t, err)
		tk, err := cred.GetToken(context.Background(), testTRO)
		require.NoError(t, err)
		require.True(t, probed)
		require.Equal(t, tokenValue, tk.Token)
		t.Run("non-JSON response", func(t *testing.T) {
			// let AzureCLICredential succeed so a token is available after managed identity fails
			before := defaultAzTokenProvider
			defer func() { defaultAzTokenProvider = before }()
			defaultAzTokenProvider = mockAzTokenProviderSuccess
			for _, res := range [][]mock.ResponseOption{
				{mock.WithStatusCode(http.StatusNotFound)},
				{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusBadRequest)},
				{mock.WithBody([]byte("not json")), mock.WithStatusCode(http.StatusOK)},
			} {
				// each case runs in its own function so its server is closed when the case ends;
				// deferring directly in the loop would keep every server open until the subtest returns
				func() {
					srv, closeSrv := mock.NewTLSServer(mock.WithTransformAllRequestsToTestServerUrl())
					defer closeSrv()
					srv.SetResponse(res...)
					cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
						ClientOptions: policy.ClientOptions{
							Transport: srv,
						},
					})
					require.NoError(t, err)
					_, err = cred.GetToken(ctx, testTRO)
					require.NoError(t, err, "DefaultAzureCredential should continue after receiving a non-JSON response from IMDS")
				}()
			}
		})
	})
	t.Run("timeout", func(t *testing.T) {
		// shorten the timeout to speed up this test
		before := imdsProbeTimeout
		defer func() { imdsProbeTimeout = before }()
		imdsProbeTimeout = 100 * time.Millisecond
		// every request is delayed past the probe timeout until dp.delay is changed below
		dp := delayPolicy{2 * imdsProbeTimeout}
		chain, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{
			ClientOptions: policy.ClientOptions{
				PerCallPolicies: []policy.Policy{&dp},
				Retry:           policy.RetryOptions{MaxRetries: -1},
				Transport:       &mockSTS{},
			},
		})
		require.NoError(t, err)
		for i := 0; i < 2; i++ {
			// expecting an error because managed identity times out and AzureCLICredential returns an error
			_, err = chain.GetToken(context.Background(), testTRO)
			require.ErrorContains(t, err, credNameManagedIdentity+": managed identity timed out")
		}
		// remove the delay so ManagedIdentityCredential can get a token from the fake STS
		dp.delay = 0
		tk, err := chain.GetToken(context.Background(), testTRO)
		require.NoError(t, err)
		require.Equal(t, tokenValue, tk.Token)
		// now there should be no timeout on token requests
		dp.delay = 2 * imdsProbeTimeout
		tk, err = chain.GetToken(context.Background(), policy.TokenRequestOptions{
			// using a different scope forces a token request by bypassing the cache
			Scopes: []string{"not-" + testTRO.Scopes[0]},
		})
		require.NoError(t, err)
		require.Equal(t, tokenValue, tk.Token)
	})
}
// TestDefaultAzureCredential_UnsupportedMIClientID checks what happens when a client
// ID is set together with msiEndpoint:
//   - NewDefaultAzureCredential still succeeds;
//   - while the az token provider fails, GetToken's error mentions "Cloud Shell";
//   - once the az provider succeeds, GetToken returns a token via AzureCLICredential.
func TestDefaultAzureCredential_UnsupportedMIClientID(t *testing.T) {
	azSucceeds := false
	original := defaultAzTokenProvider
	defer func() { defaultAzTokenProvider = original }()
	defaultAzTokenProvider = func(ctx context.Context, scopes []string, tenant, subscription string) ([]byte, error) {
		if !azSucceeds {
			return nil, errors.New("fail")
		}
		return mockAzTokenProviderSuccess(ctx, scopes, tenant, subscription)
	}
	t.Setenv(azureClientID, fakeClientID)
	t.Setenv(msiEndpoint, fakeMIEndpoint)

	cred, err := NewDefaultAzureCredential(nil)
	require.NoError(t, err, "an unsupported client ID isn't a constructor error")

	// with az failing, the managed identity error surfaces in GetToken's aggregate error
	_, err = cred.GetToken(ctx, testTRO)
	require.ErrorContains(t, err, "Cloud Shell", "error should mention the unsupported ID")

	azSucceeds = true
	_, err = cred.GetToken(ctx, testTRO)
	require.NoError(t, err, "expected a token from AzureCLICredential")
}