azure-sdk-for-go/sdk/azidentity/workload_identity_test.go

293 строки
9.0 KiB
Go

//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azidentity
import (
"context"
"crypto"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/recording"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func assertion(cert *x509.Certificate, key crypto.PrivateKey) (string, error) {
j := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
"aud": fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", liveSP.tenantID),
"exp": json.Number(strconv.FormatInt(time.Now().Add(10*time.Minute).Unix(), 10)),
"iss": liveSP.clientID,
"jti": uuid.New().String(),
"nbf": json.Number(strconv.FormatInt(time.Now().Unix(), 10)),
"sub": liveSP.clientID,
})
x5t := sha1.Sum(cert.Raw) // nosec
j.Header = map[string]interface{}{
"alg": "RS256",
"typ": "JWT",
"x5t": base64.StdEncoding.EncodeToString(x5t[:]),
}
return j.SignedString(key)
}
func TestWorkloadIdentityCredential_Live(t *testing.T) {
// This test triggers the managed identity test app deployed to Azure Kubernetes Service.
// See the bicep file and test resources scripts for details.
// It triggers the app with kubectl because the test subscription prohibits opening ports to the internet.
pod := os.Getenv("AZIDENTITY_POD_NAME")
if pod == "" {
t.Skip("set AZIDENTITY_POD_NAME to run this test")
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "kubectl", "exec", pod, "--", "wget", "-qO-", "localhost")
b, err := cmd.CombinedOutput()
s := string(b)
require.NoError(t, err, s)
require.Equal(t, "test passed", s)
}
func TestWorkloadIdentityCredential_Recorded(t *testing.T) {
if recording.GetRecordMode() == recording.LiveMode {
t.Skip("https://github.com/Azure/azure-sdk-for-go/issues/22879")
}
// workload identity and client cert auth use the same flow. This test
// implements cert auth with WorkloadIdentityCredential as a way to test
// that credential in an environment that's easier to set up than AKS
cert, err := os.ReadFile(liveSP.pemPath)
if err != nil {
t.Fatal(err)
}
certs, key, err := ParseCertificates(cert, nil)
if err != nil {
t.Fatal(err)
}
a, err := assertion(certs[0], key)
if err != nil {
t.Fatal(err)
}
f := filepath.Join(t.TempDir(), t.Name())
if err := os.WriteFile(f, []byte(a), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
for _, b := range []bool{true, false} {
name := "default options"
if b {
name = "instance discovery disabled"
}
t.Run(name, func(t *testing.T) {
co, stop := initRecording(t)
defer stop()
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: liveSP.clientID,
ClientOptions: co,
DisableInstanceDiscovery: b,
TenantID: liveSP.tenantID,
TokenFilePath: f,
})
if err != nil {
t.Fatal(err)
}
testGetTokenSuccess(t, cred)
})
}
}
func TestWorkloadIdentityCredential(t *testing.T) {
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
if err := os.WriteFile(tempFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) *http.Response {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
if actual, ok := req.PostForm["client_assertion"]; !ok {
t.Error("expected a client_assertion")
} else if len(actual) != 1 || actual[0] != tokenValue {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual, ok := req.PostForm["client_id"]; !ok {
t.Error("expected a client_id")
} else if len(actual) != 1 || actual[0] != fakeClientID {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual := strings.Split(req.URL.Path, "/")[1]; actual != fakeTenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
return nil
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
testGetTokenSuccess(t, cred)
_, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{"scope"}})
if err != nil {
t.Fatal(err)
}
}
func TestWorkloadIdentityCredential_Expiration(t *testing.T) {
tokenReqs := 0
tempFile := filepath.Join(t.TempDir(), "test-workload-token-file")
sts := mockSTS{tenant: fakeTenantID, tokenRequestCallback: func(req *http.Request) *http.Response {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
if actual, ok := req.PostForm["client_assertion"]; !ok {
t.Error("expected a client_assertion")
} else if len(actual) != 1 || actual[0] != fmt.Sprint(tokenReqs) {
t.Errorf(`expected assertion "%d", got "%s"`, tokenReqs, actual[0])
}
tokenReqs++
return nil
}}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: fakeClientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: fakeTenantID,
TokenFilePath: tempFile,
})
if err != nil {
t.Fatal(err)
}
for i := 0; i < 2; i++ {
// tokenReqs counts requests, and its latest value is the expected client assertion and the requested scope.
// Each iteration of this loop therefore sends a token request with a unique assertion.
s := fmt.Sprint(tokenReqs)
if err = os.WriteFile(tempFile, []byte(fmt.Sprint(s)), os.ModePerm); err != nil {
t.Fatalf("failed to write token file: %v", err)
}
if _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{s}}); err != nil {
t.Fatal(err)
}
cred.expires = time.Now().Add(-time.Second)
}
if tokenReqs != 2 {
t.Fatalf("expected 2 token requests, got %d", tokenReqs)
}
}
func TestTestWorkloadIdentityCredential_IncompleteConfig(t *testing.T) {
f := filepath.Join(t.TempDir(), t.Name())
for _, env := range []map[string]string{
{},
{azureClientID: fakeClientID},
{azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID},
{azureClientID: fakeClientID, azureTenantID: fakeTenantID},
{azureClientID: fakeClientID, azureFederatedTokenFile: f},
{azureTenantID: fakeTenantID, azureFederatedTokenFile: f},
} {
t.Run("", func(t *testing.T) {
for k, v := range env {
t.Setenv(k, v)
}
if _, err := NewWorkloadIdentityCredential(nil); err == nil {
t.Fatal("expected an error")
}
})
}
}
func TestWorkloadIdentityCredential_NoFile(t *testing.T) {
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: filepath.Join(t.TempDir(), t.Name()),
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientOptions: policy.ClientOptions{Transport: &mockSTS{}},
})
if err != nil {
t.Fatal(err)
}
if _, err = cred.GetToken(context.Background(), testTRO); err == nil {
t.Fatal("expected an error")
}
}
func TestWorkloadIdentityCredential_Options(t *testing.T) {
clientID := "not-" + fakeClientID
tenantID := "not-" + fakeTenantID
wrongFile := filepath.Join(t.TempDir(), "wrong")
rightFile := filepath.Join(t.TempDir(), "right")
if err := os.WriteFile(rightFile, []byte(tokenValue), os.ModePerm); err != nil {
t.Fatal(err)
}
sts := mockSTS{
tenant: tenantID,
tokenRequestCallback: func(req *http.Request) *http.Response {
if err := req.ParseForm(); err != nil {
t.Error(err)
}
if actual, ok := req.PostForm["client_assertion"]; !ok {
t.Error("expected a client_assertion")
} else if len(actual) != 1 || actual[0] != tokenValue {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual, ok := req.PostForm["client_id"]; !ok {
t.Error("expected a client_id")
} else if len(actual) != 1 || actual[0] != clientID {
t.Errorf(`unexpected assertion "%s"`, actual[0])
}
if actual := strings.Split(req.URL.Path, "/")[1]; actual != tenantID {
t.Errorf(`unexpected tenant "%s"`, actual)
}
return nil
},
}
// options should override environment variables
for k, v := range map[string]string{
azureClientID: fakeClientID,
azureFederatedTokenFile: wrongFile,
azureTenantID: fakeTenantID,
} {
t.Setenv(k, v)
}
cred, err := NewWorkloadIdentityCredential(&WorkloadIdentityCredentialOptions{
ClientID: clientID,
ClientOptions: policy.ClientOptions{Transport: &sts},
TenantID: tenantID,
TokenFilePath: rightFile,
})
if err != nil {
t.Fatal(err)
}
tk, err := cred.GetToken(context.Background(), testTRO)
if err != nil {
t.Fatal(err)
}
if tk.Token != tokenValue {
t.Fatalf("unexpected token %q", tk.Token)
}
}