diff --git a/pkg/api/admin/openshiftcluster.go b/pkg/api/admin/openshiftcluster.go index d1bab4bb6..c3d168f98 100644 --- a/pkg/api/admin/openshiftcluster.go +++ b/pkg/api/admin/openshiftcluster.go @@ -442,6 +442,7 @@ type UserAssignedIdentities map[string]ClusterUserAssignedIdentity type Identity struct { Type string `json:"type,omitempty"` UserAssignedIdentities UserAssignedIdentities `json:"userAssignedIdentities,omitempty"` + IdentityURL string `json:"identityURL,omitempty"` } // Install represents an install process. diff --git a/pkg/api/openshiftcluster.go b/pkg/api/openshiftcluster.go index 0826520d0..367f49ac4 100644 --- a/pkg/api/openshiftcluster.go +++ b/pkg/api/openshiftcluster.go @@ -803,4 +803,5 @@ type Identity struct { Type string `json:"type,omitempty"` UserAssignedIdentities UserAssignedIdentities `json:"userAssignedIdentities,omitempty"` + IdentityURL string `json:"identityURL,omitempty"` } diff --git a/pkg/frontend/openshiftcluster_putorpatch.go b/pkg/frontend/openshiftcluster_putorpatch.go index 1fa6f26fd..c35f7255e 100644 --- a/pkg/frontend/openshiftcluster_putorpatch.go +++ b/pkg/frontend/openshiftcluster_putorpatch.go @@ -25,6 +25,8 @@ import ( "github.com/Azure/ARO-RP/pkg/util/version" ) +var errMissingIdentityURL error = fmt.Errorf("identityURL not provided but required for workload identity cluster") + func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := ctx.Value(middleware.ContextKeyLog).(*logrus.Entry) @@ -41,10 +43,12 @@ func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Req subId := chi.URLParam(r, "subscriptionId") resourceProviderNamespace := chi.URLParam(r, "resourceProviderNamespace") + identityURL := r.Header.Get("x-ms-identity-url") + apiVersion := r.URL.Query().Get(api.APIVersionKey) err := cosmosdb.RetryOnPreconditionFailed(func() error { var err error - b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion) + b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion, identityURL) return err }) @@ -52,7 +56,7 @@ func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Req reply(log, w, header, b, err) } -func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string) ([]byte, error) { +func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string, identityURL string) ([]byte, error) { subscription, err := f.validateSubscriptionState(ctx, path, api.SubscriptionStateRegistered) if err != nil { return nil, err @@ -86,11 +90,17 @@ func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus. }, }, } + if !f.env.IsLocalDevelopmentMode() /* not local dev or CI */ { doc.OpenShiftCluster.Properties.FeatureProfile.GatewayEnabled = true } } + err = validateIdentityUrl(doc.OpenShiftCluster, identityURL, isCreate) + if err != nil { + return nil, err + } + doc.CorrelationData = correlationData err = validateTerminalProvisioningState(doc.OpenShiftCluster.Properties.ProvisioningState) @@ -288,6 +298,26 @@ func enrichClusterSystemData(doc *api.OpenShiftClusterDocument, systemData *api. } } +func validateIdentityUrl(cluster *api.OpenShiftCluster, identityURL string, isCreate bool) error { + // Don't persist identity URL in non-wimi clusters + if cluster.Properties.ServicePrincipalProfile != nil || cluster.Identity == nil { + return nil + } + + if identityURL == "" { + if isCreate { + return errMissingIdentityURL + } + return nil + } + + if cluster.Identity != nil { + cluster.Identity.IdentityURL = identityURL + } + + return nil +} + func (f *frontend) ValidateNewCluster(ctx context.Context, subscription *api.SubscriptionDocument, cluster *api.OpenShiftCluster, staticValidator api.OpenShiftClusterStaticValidator, ext interface{}, path string) error { err := staticValidator.Static(ext, nil, f.env.Location(), f.env.Domain(), f.env.FeatureIsSet(env.FeatureRequireD2sV3Workers), path) if err != nil { diff --git a/pkg/frontend/openshiftcluster_putorpatch_test.go b/pkg/frontend/openshiftcluster_putorpatch_test.go index f7c10ade9..540616ccf 100644 --- a/pkg/frontend/openshiftcluster_putorpatch_test.go +++ b/pkg/frontend/openshiftcluster_putorpatch_test.go @@ -3305,3 +3305,84 @@ func TestEnrichClusterSystemData(t *testing.T) { }) } } + +func TestValidateIdentityUrl(t *testing.T) { + for _, tt := range []struct { + name string + identityURL string + cluster *api.OpenShiftCluster + expected *api.OpenShiftCluster + isCreate bool + wantError error + }{ + { + name: "identity URL is empty, is not wi/mi cluster create", + identityURL: "", + cluster: &api.OpenShiftCluster{}, + expected: &api.OpenShiftCluster{}, + isCreate: false, + }, + { + name: "identity URL is empty, is wi/mi cluster create", + identityURL: "", + cluster: &api.OpenShiftCluster{}, + expected: &api.OpenShiftCluster{}, + isCreate: true, + wantError: errMissingIdentityURL, + }, + { + name: "cluster is not wi/mi, identityURL passed", + identityURL: "http://foo.bar", + cluster: &api.OpenShiftCluster{ + Properties: api.OpenShiftClusterProperties{ + ServicePrincipalProfile: &api.ServicePrincipalProfile{}, + }, + }, + expected: &api.OpenShiftCluster{ + Properties: api.OpenShiftClusterProperties{ + ServicePrincipalProfile: &api.ServicePrincipalProfile{}, + }, + }, + isCreate: true, + }, + { + name: "cluster is not wi/mi, identityURL not passed", + identityURL: "", + cluster: &api.OpenShiftCluster{ + Properties: api.OpenShiftClusterProperties{ + ServicePrincipalProfile: &api.ServicePrincipalProfile{}, + }, + }, + expected: &api.OpenShiftCluster{ + Properties: api.OpenShiftClusterProperties{ + ServicePrincipalProfile: &api.ServicePrincipalProfile{}, + }, + }, + isCreate: true, + }, + { + name: "pass - identity URL passed on wi/mi cluster", + cluster: &api.OpenShiftCluster{ + Identity: &api.Identity{}, + }, + identityURL: "http://foo.bar", + expected: &api.OpenShiftCluster{ + Identity: &api.Identity{ + IdentityURL: "http://foo.bar", + }, + }, + isCreate: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + err := validateIdentityUrl(tt.cluster, tt.identityURL, tt.isCreate) + if err != nil && err != tt.wantError { + t.Error(cmp.Diff(err, tt.wantError)) + } + + if !reflect.DeepEqual(tt.cluster, tt.expected) { + t.Error(cmp.Diff(tt.cluster, tt.expected)) + } + }) + } +}