//go:build go1.18 // +build go1.18 // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package azidentity import ( "context" "crypto" "crypto/sha1" "crypto/x509" "encoding/base64" "encoding/json" "fmt" "net/http" "os" "os/exec" "path/filepath" "strconv" "strings" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" ) func assertion(cert *x509.Certificate, key crypto.PrivateKey) (string, error) { j := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ "aud": fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", liveSP.tenantID), "exp": json.Number(strconv.FormatInt(time.Now().Add(10*time.Minute).Unix(), 10)), "iss": liveSP.clientID, "jti": uuid.New().String(), "nbf": json.Number(strconv.FormatInt(time.Now().Unix(), 10)), "sub": liveSP.clientID, }) x5t := sha1.Sum(cert.Raw) // nosec j.Header = map[string]interface{}{ "alg": "RS256", "typ": "JWT", "x5t": base64.StdEncoding.EncodeToString(x5t[:]), } return j.SignedString(key) } func TestWorkloadIdentityCredential_Live(t *testing.T) { // This test triggers the managed identity test app deployed to Azure Kubernetes Service. // See the bicep file and test resources scripts for details. // It triggers the app with kubectl because the test subscription prohibits opening ports to the internet. pod := os.Getenv("AZIDENTITY_POD_NAME") if pod == "" { t.Skip("set AZIDENTITY_POD_NAME to run this test") } ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "kubectl", "exec", pod, "--", "wget", "-qO-", "localhost") b, err := cmd.CombinedOutput() s := string(b) require.NoError(t, err, s) require.Equal(t, "test passed", s) } func TestWorkloadIdentityCredential_Recorded(t *testing.T) { if recording.GetRecordMode() == recording.LiveMode { t.Skip("https://github.com/Azure/azure-sdk-for-go/issues/22879") } // workload identity and client cert auth use the same flow. This test // implements cert auth with WorkloadIdentityCredential as a way to test // that credential in an environment that's easier to set up than AKS cert, err := os.ReadFile(liveSP.pemPath) if err != nil { t.Fatal(err) } certs, key, err := ParseCertificates(cert, nil) if err != nil { t.Fatal(err) } a, err := assertion(certs[0], key) if err != nil { t.Fatal(err) } f := filepath.Join(t.TempDir(), t.Name()) if err := os.WriteFile(f, []byte(a), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) } for _, b := range []bool{true, false} { name := "default options" if b { name = "instance discovery disabled" } t.Run(name, func(t *testing.T) { co, stop := initRecording(t) defer stop() cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: liveSP.clientID, ClientOptions: co, DisableInstanceDiscovery: b, TenantID: liveSP.tenantID, TokenFilePath: f, }) if err != nil { t.Fatal(err) } testGetTokenSuccess(t, cred) }) } } func TestWorkloadIdentityCredential(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) } sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) *http.Response { if err := req.ParseForm(); err != nil { t.Error(err) } if actual, ok := req.PostForm["client_assertion"]; !ok { t.Error("expected a client_assertion") } else if len(actual) != 1 || actual[0] != tokenValue { t.Errorf(`unexpected assertion "%s"`, actual[0]) } if actual, ok := req.PostForm["client_id"]; !ok { t.Error("expected a client_id") } else if len(actual) != 1 || actual[0] != fakeClientID { t.Errorf(`unexpected assertion "%s"`, actual[0]) } if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID { t.Errorf(`unexpected tenant "%s"`, actual) } return nil }} cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: fakeClientID, ClientOptions: policy.ClientOptions{Transport: &sts}, TenantID: fakeTenantID, TokenFilePath: tempFile, }) if err != nil { t.Fatal(err) } testGetTokenSuccess(t, cred) _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}}) if err != nil { t.Fatal(err) } } func TestWorkloadIdentityCredential_Expiration(t *testing.T) { tokenReqs := 0 tempFile := filepath.Join(t.TempDir(), "test-workload-token-file") sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) *http.Response { if err := req.ParseForm(); err != nil { t.Error(err) } if actual, ok := req.PostForm["client_assertion"]; !ok { t.Error("expected a client_assertion") } else if len(actual) != 1 || actual[0] != fmt.Sprint(tokenReqs) { t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0]) } tokenReqs++ return nil }} cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: fakeClientID, ClientOptions: policy.ClientOptions{Transport: &sts}, TenantID: fakeTenantID, TokenFilePath: tempFile, }) if err != nil { t.Fatal(err) } for i := 0; i < 2; i++ { // tokenReqs counts requests, and its latest value is the expected client assertion and the requested scope. // Each iteration of this loop therefore sends a token request with a unique assertion. s := fmt.Sprint(tokenReqs) if err = os.WriteFile(tempFile, []byte(fmt.Sprint(s)), os.ModePerm); err != nil { t.Fatalf("failed to write token file: %v", err) } if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{s}}); err != nil { t.Fatal(err) } cred.expires = time.Now().Add(-time.Second) } if tokenReqs != 2 { t.Fatalf("expected 2 token requests, got %d", tokenReqs) } } func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) { f := filepath.Join(t.TempDir(), t.Name()) for _, env := range []map[string]string{ {}, {azureClientID: fakeClientID}, {azureFederatedTokenFile: f}, {azureTenantID: fakeTenantID}, {azureClientID: fakeClientID, azureTenantID: fakeTenantID}, {azureClientID: fakeClientID, azureFederatedTokenFile: f}, {azureTenantID: fakeTenantID, azureFederatedTokenFile: f}, } { t.Run("", func(t *testing.T) { for k, v := range env { t.Setenv(k, v) } if _, err := NewWorkloadIdentityCredential(nil); err == nil { t.Fatal("expected an error") } }) } } func TestWorkloadIdentityCredential_NoFile(t *testing.T) { for k, v := range map[string]string{ azureClientID: fakeClientID, azureFederatedTokenFile: filepath.Join(t.TempDir(), t.Name()), azureTenantID: fakeTenantID, } { t.Setenv(k, v) } cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientOptions: policy.ClientOptions{Transport: &mockSTS{}}, }) if err != nil { t.Fatal(err) } if _, err = cred.GetToken(context.Background(), testTRO); err == nil { t.Fatal("expected an error") } } func TestWorkloadIdentityCredential_Options(t *testing.T) { clientID := "not-" + fakeClientID tenantID := "not-" + fakeTenantID wrongFile := filepath.Join(t.TempDir(), "wrong") rightFile := filepath.Join(t.TempDir(), "right") if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil { t.Fatal(err) } sts := mockSTS{ tenant: tenantID, tokenRequestCallback: func(req *http.Request) *http.Response { if err := req.ParseForm(); err != nil { t.Error(err) } if actual, ok := req.PostForm["client_assertion"]; !ok { t.Error("expected a client_assertion") } else if len(actual) != 1 || actual[0] != tokenValue { t.Errorf(`unexpected assertion "%s"`, actual[0]) } if actual, ok := req.PostForm["client_id"]; !ok { t.Error("expected a client_id") } else if len(actual) != 1 || actual[0] != clientID { t.Errorf(`unexpected assertion "%s"`, actual[0]) } if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID { t.Errorf(`unexpected tenant "%s"`, actual) } return nil }, } // options should override environment variables for k, v := range map[string]string{ azureClientID: fakeClientID, azureFederatedTokenFile: wrongFile, azureTenantID: fakeTenantID, } { t.Setenv(k, v) } cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{ ClientID: clientID, ClientOptions: policy.ClientOptions{Transport: &sts}, TenantID: tenantID, TokenFilePath: rightFile, }) if err != nil { t.Fatal(err) } tk, err := cred.GetToken(context.Background(), testTRO) if err != nil { t.Fatal(err) } if tk.Token != tokenValue { t.Fatalf("unexpected token %q", tk.Token) } }