diff --git a/keyvaultcertdownloader/internal/corehelper/corehelper.go b/keyvaultcertdownloader/internal/corehelper/corehelper.go index 6f01897..9cc81fe 100644 --- a/keyvaultcertdownloader/internal/corehelper/corehelper.go +++ b/keyvaultcertdownloader/internal/corehelper/corehelper.go @@ -16,9 +16,13 @@ import ( "net/url" "os" "strings" + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" gpkcs12 "software.sslmate.com/src/go-pkcs12" ) @@ -126,10 +130,76 @@ func GetPrivateKeyFromPEMBlocks(blocks interface{}) (privateKey interface{}, err } // -//!SECTION - SDK dependent functions +//!SECTION - Azcore SDK related functions +// Source: https://github.com/Azure/azure-workload-identity/tree/main/examples // -//!SECTION - Internal functions +// clientAssertionCredential authenticates an application with assertions provided by a callback function. +type clientAssertionCredential struct { + assertion, file string + client confidential.Client + lastRead time.Time +} + +// clientAssertionCredentialOptions contains optional parameters for ClientAssertionCredential. +type clientAssertionCredentialOptions struct { + azcore.ClientOptions +} + +// NewClientAssertionCredential constructs a clientAssertionCredential. Pass nil for options to accept defaults. +func NewClientAssertionCredential(tenantID, clientID, authorityHost, file string, options *clientAssertionCredentialOptions) (*clientAssertionCredential, error) { + c := &clientAssertionCredential{file: file} + + if options == nil { + options = &clientAssertionCredentialOptions{} + } + + cred := confidential.NewCredFromAssertionCallback( + func(ctx context.Context, _ confidential.AssertionRequestOptions) (string, error) { + return c.getAssertion(ctx) + }, + ) + + client, err := confidential.New(clientID, cred, confidential.WithAuthority(fmt.Sprintf("%s%s/oauth2/token", authorityHost, tenantID))) + if err != nil { + return nil, fmt.Errorf("failed to create confidential client: %w", err) + } + c.client = client + + return c, nil +} + +// GetToken implements the TokenCredential interface +func (c *clientAssertionCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + // get the token from the confidential client + token, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes) + if err != nil { + return azcore.AccessToken{}, err + } + + return azcore.AccessToken{ + Token: token.AccessToken, + ExpiresOn: token.ExpiresOn, + }, nil +} + +// getAssertion reads the assertion from the file and returns it +// if the file has not been read in the last 5 minutes +func (c *clientAssertionCredential) getAssertion(context.Context) (string, error) { + if now := time.Now(); c.lastRead.Add(5 * time.Minute).Before(now) { + content, err := os.ReadFile(c.file) + if err != nil { + return "", err + } + c.assertion = string(content) + c.lastRead = now + } + return c.assertion, nil +} + +// +//!SECTION - Keyvault SDK related functions +// func getAKVCertificateBundle(cntx context.Context, client *azcertificates.Client, certURL url.URL) (azcertificates.CertificateBundle, error) { cert, err := client.GetCertificate(cntx, certURL.Path, "", nil) @@ -140,8 +210,6 @@ func getAKVCertificateBundle(cntx context.Context, client *azcertificates.Client return cert.CertificateBundle, nil } -//!SECTION - Public functions - // GetAKVCertificate - Gets a certificate from AKV func GetAKVCertificate(cntx context.Context, client *azsecrets.Client, certURL url.URL) (azsecrets.SecretBundle, error) { certSecret, err := client.GetSecret(cntx, certURL.Path, "", nil) diff --git a/keyvaultcertdownloader/keyvaultcertdownloader.go b/keyvaultcertdownloader/keyvaultcertdownloader.go index 718e202..b8c08f7 100644 --- a/keyvaultcertdownloader/keyvaultcertdownloader.go +++ b/keyvaultcertdownloader/keyvaultcertdownloader.go @@ -22,6 +22,7 @@ import ( "internal/corehelper" "internal/utils" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azcertificates" "github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets" @@ -44,12 +45,11 @@ const ( var ( validEnvironments = []string{"AZUREPUBLICCLOUD", "AZUREUSGOVERNMENTCLOUD", "AZUREGERMANCLOUD", "AZURECHINACLOUD"} certURL = flag.String("certurl", "", "certificate URL, e.g. \"https://mykeyvault.vault.azure.net/mycertificate\"") - keyVaultUrl = "" outputFolder = flag.String("outputfolder", "", "folder where PEM file with certificate and private key will be saved") environment = flag.String("environment", "AZUREPUBLICCLOUD", fmt.Sprintf("valid azure cloud environments: %v", validEnvironments)) cmdlineversion = flag.Bool("version", false, "shows current tool version") exitCode = 0 - version = "0.2.0" + version = "1.0.0" stdout = log.New(os.Stdout, "", log.LstdFlags) stderr = log.New(os.Stderr, "", log.LstdFlags) ) @@ -69,7 +69,7 @@ func main() { } // Checks if version output is needed - if *cmdlineversion == true { + if *cmdlineversion { fmt.Println(version) exitCode = 0 return @@ -105,44 +105,72 @@ func main() { utils.ConsoleOutput(fmt.Sprintf("Using Certificate URL: %v", *certURL), stdout) utils.ConsoleOutput(fmt.Sprintf("Environment: %v", *environment), stdout) - //utils.ConsoleOutput("Checking if this session needs to rely on AD Workload Identity webhook", stdout) - // client := keyvault.New() - // var authorizer autorest.Authorizer + utils.ConsoleOutput("Checking if this session needs to rely on AD Workload Identity webhook", stdout) + var cred azcore.TokenCredential tokenFilePath := os.Getenv("AZURE_FEDERATED_TOKEN_FILE") if tokenFilePath == "" { - // utils.ConsoleOutput("Getting authorizer", stdout) - // os.Setenv("AZURE_ENVIRONMENT", *environment) - // authorizer, err = kvauth.NewAuthorizerFromEnvironment() - // if err != nil { - // utils.ConsoleOutput(fmt.Sprintf(" unable to create vault authorizer: %v\n", err), stderr) - // exitCode = ERR_AUTHORIZER - // return - // } - - // utils.ConsoleOutput("Creating KeyVault base client", stdout) - + // Not running within a container with azwi webhook configured + utils.ConsoleOutput("Obtaining credentials", stdout) + cred, err = azidentity.NewDefaultAzureCredential(nil) + if err != nil { + utils.ConsoleOutput(fmt.Sprintf(" %v\n", err), stderr) + exitCode = ERR_CREDENTIALS + return + } } else { - } + // NOTE: following block is based on azure workload identity sample: + // https://github.dev/Azure/azure-workload-identity/blob/main/examples/msal-net/akvdotnet/TokenCredential.cs + // - // client.Authorizer = authorizer - utils.ConsoleOutput("Obtaining credentials", stdout) - cred, err := azidentity.NewDefaultAzureCredential(nil) - if err != nil { - utils.ConsoleOutput(fmt.Sprintf(" %v\n", err), stderr) - exitCode = ERR_CREDENTIALS - return + // Azure AD Workload Identity webhook will inject the following env vars + // AZURE_CLIENT_ID with the clientID set in the service account annotation + // AZURE_TENANT_ID with the tenantID set in the service account annotation. If not defined, then + // the tenantID provided via azure-wi-webhook-config for the webhook will be used. + // AZURE_FEDERATED_TOKEN_FILE is the service account token path + // AZURE_AUTHORITY_HOST is the AAD authority hostname + clientID := os.Getenv("AZURE_CLIENT_ID") + tenantID := os.Getenv("AZURE_TENANT_ID") + tokenFilePath := os.Getenv("AZURE_FEDERATED_TOKEN_FILE") + authorityHost := os.Getenv("AZURE_AUTHORITY_HOST") + + if clientID == "" { + utils.ConsoleOutput("AZURE_CLIENT_ID environment variable is not set", stderr) + exitCode = ERR_CREDENTIALS + return + } + if tenantID == "" { + utils.ConsoleOutput("AZURE_TENANT_ID environment variable is not set", stderr) + exitCode = ERR_CREDENTIALS + return + } + if authorityHost == "" { + utils.ConsoleOutput("AZURE_AUTHORITY_HOST environment variable is not set", stderr) + exitCode = ERR_CREDENTIALS + return + } + + cred, err = corehelper.NewClientAssertionCredential(tenantID, clientID, authorityHost, tokenFilePath, nil) + if err != nil { + utils.ConsoleOutput(fmt.Sprintf(" failed to create client assertion credential: %v\n", err), stderr) + exitCode = ERR_CREDENTIALS + return + } } utils.ConsoleOutput("Creating clients", stdout) azsecretsClient, err := azsecrets.NewClient(keyVaultUrl, cred, nil) if err != nil { - log.Fatalf("failed to create azsecrets client: %v", err) + utils.ConsoleOutput(fmt.Sprintf(" failed to create azsecrets client: %v\n", err), stderr) + exitCode = ERR_CREDENTIALS + return } azcertsClient, err := azcertificates.NewClient(keyVaultUrl, cred, nil) if err != nil { - log.Fatalf("failed to create azcertificates client: %v", err) + utils.ConsoleOutput(fmt.Sprintf(" failed to create azcertificates client: %v\n", err), stderr) + exitCode = ERR_CREDENTIALS + return } utils.ConsoleOutput("Getting certificate thumbprint", stdout)