diff --git a/cmd/webhook/main.go b/cmd/webhook/main.go index 47162f9..ebccc22 100644 --- a/cmd/webhook/main.go +++ b/cmd/webhook/main.go @@ -14,16 +14,21 @@ import ( wh "github.com/Azure/aad-pod-managed-identity/pkg/webhook" ) +var ( + arcCluster bool + audience string +) + func init() { log.SetLogger(zap.New()) } func main() { - var arcCluster bool // TODO (aramase) once webhook is added as an arc extension, use extension // util to check if running in arc cluster. flag.BoolVar(&arcCluster, "arc-cluster", false, "Running on arc cluster") + flag.StringVar(&audience, "audience", "", "Audience for service account token") flag.Parse() entryLog := log.Log.WithName("entrypoint") @@ -41,7 +46,7 @@ func main() { hookServer := mgr.GetWebhookServer() entryLog.Info("registering webhook to the webhook server") - podMutator, err := wh.NewPodMutator(mgr.GetClient(), arcCluster) + podMutator, err := wh.NewPodMutator(mgr.GetClient(), arcCluster, audience) if err != nil { entryLog.Error(err, "unable to set up pod mutator") os.Exit(1) diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go index cfd3a4c..b3eba81 100644 --- a/pkg/webhook/webhook.go +++ b/pkg/webhook/webhook.go @@ -29,19 +29,30 @@ type podMutator struct { config *config.Config isARCCluster bool decoder *admission.Decoder + audience string } // NewPodMutator returns a pod mutation handler -func NewPodMutator(client client.Client, arcCluster bool) (admission.Handler, error) { +func NewPodMutator(client client.Client, arcCluster bool, audience string) (admission.Handler, error) { c, err := config.ParseConfig() if err != nil { return nil, err } + if audience == "" { + // get aad endpoint to configure as audience + aadEndpoint, err := getAADEndpoint(c) + if err != nil { + return nil, errors.Wrap(err, "failed to get AAD endpoint") + } + aadEndpoint = strings.TrimRight(aadEndpoint, "/") + audience = fmt.Sprintf("%s/federatedidentity", aadEndpoint) + } return &podMutator{ client: client, config: c, isARCCluster: arcCluster, + audience: audience, }, nil } @@ -98,7 +109,7 @@ func (m *podMutator) Handle(ctx context.Context, req admission.Request) admissio if !m.isARCCluster { // add the projected service account token volume to the pod if not exists - if err = addProjectedServiceAccountTokenVolume(pod, m.config, serviceAccountTokenExpiration); err != nil { + if err = addProjectedServiceAccountTokenVolume(pod, serviceAccountTokenExpiration, m.audience); err != nil { logger.Error(err, "failed to add projected service account volume") return admission.Errored(http.StatusBadRequest, err) } @@ -238,7 +249,7 @@ func addProjectedTokenVolumeMount(container corev1.Container) corev1.Container { return container } -func addProjectedServiceAccountTokenVolume(pod *corev1.Pod, config *config.Config, serviceAccountTokenExpiration int64) error { +func addProjectedServiceAccountTokenVolume(pod *corev1.Pod, serviceAccountTokenExpiration int64, audience string) error { // add the projected service account token volume to the pod if not exists for _, volume := range pod.Spec.Volumes { if volume.Projected == nil { @@ -254,13 +265,6 @@ func addProjectedServiceAccountTokenVolume(pod *corev1.Pod, config *config.Confi } } - // get aad endpoint to configure as audience - aadEndpoint, err := getAADEndpoint(config) - if err != nil { - return errors.Wrap(err, "failed to get AAD endpoint") - } - aadEndpoint = strings.TrimRight(aadEndpoint, "/") - // add the projected service account token volume // the path for this volume will always be set to "azure-identity-token" pod.Spec.Volumes = append( @@ -274,7 +278,7 @@ func addProjectedServiceAccountTokenVolume(pod *corev1.Pod, config *config.Confi ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ Path: TokenFilePathName, ExpirationSeconds: &serviceAccountTokenExpiration, - Audience: fmt.Sprintf("%s/federatedidentity", aadEndpoint), + Audience: audience, }, }, }, diff --git a/pkg/webhook/webhook_test.go b/pkg/webhook/webhook_test.go index 10bf92e..9651edf 100644 --- a/pkg/webhook/webhook_test.go +++ b/pkg/webhook/webhook_test.go @@ -477,7 +477,7 @@ func TestAddProjectedServiceAccountTokenVolume(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - err := addProjectedServiceAccountTokenVolume(test.pod, &config.Config{}, serviceAccountTokenExpiry) + err := addProjectedServiceAccountTokenVolume(test.pod, serviceAccountTokenExpiry, "https://login.microsoftonline.com/federatedidentity") if err != nil { t.Fatalf("expected err to be nil, got: %v", err) }