From e3362d1341a6b76f05fa2e38cb1a0c334bbfdd94 Mon Sep 17 00:00:00 2001 From: Rajdeep Singh Chauhan Date: Wed, 12 Jun 2024 16:55:38 -0400 Subject: [PATCH] ARO-4376 add a stringutil funcs for string array comparison --- pkg/portal/info.go | 3 ++- pkg/portal/kubeconfig/kubeconfig.go | 3 ++- pkg/portal/middleware/aad.go | 16 ++----------- pkg/portal/ssh/ssh.go | 3 ++- pkg/util/stringutils/stringutils.go | 23 +++++++++++++++++++ pkg/validate/dynamic/dynamic.go | 4 ++-- .../platformworkloadidentityprofile.go | 18 ++++----------- 7 files changed, 37 insertions(+), 33 deletions(-) diff --git a/pkg/portal/info.go b/pkg/portal/info.go index 1a6a6285806..4a9aa7026f2 100644 --- a/pkg/portal/info.go +++ b/pkg/portal/info.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/csrf" "github.com/Azure/ARO-RP/pkg/portal/middleware" + "github.com/Azure/ARO-RP/pkg/util/stringutils" "github.com/Azure/ARO-RP/pkg/util/version" ) @@ -23,7 +24,7 @@ type PortalInfo struct { func (p *portal) info(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - elevated := len(middleware.GroupsIntersect(p.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 + elevated := len(stringutils.GroupsIntersect(p.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 resp := PortalInfo{ Location: p.env.Location(), diff --git a/pkg/portal/kubeconfig/kubeconfig.go b/pkg/portal/kubeconfig/kubeconfig.go index ec9808ffd1a..478cc243893 100644 --- a/pkg/portal/kubeconfig/kubeconfig.go +++ b/pkg/portal/kubeconfig/kubeconfig.go @@ -25,6 +25,7 @@ import ( "github.com/Azure/ARO-RP/pkg/portal/util/clientcache" "github.com/Azure/ARO-RP/pkg/proxy" "github.com/Azure/ARO-RP/pkg/util/roundtripper" + "github.com/Azure/ARO-RP/pkg/util/stringutils" ) const ( @@ -95,7 +96,7 @@ func (k *Kubeconfig) New(w http.ResponseWriter, r *http.Request) { return } - elevated := len(middleware.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 + elevated := len(stringutils.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 token := k.DbPortal.NewUUID() portalDoc := &api.PortalDocument{ diff --git a/pkg/portal/middleware/aad.go b/pkg/portal/middleware/aad.go index 294f7645f5e..73214b44f68 100644 --- a/pkg/portal/middleware/aad.go +++ b/pkg/portal/middleware/aad.go @@ -23,6 +23,7 @@ import ( "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/util/oidc" "github.com/Azure/ARO-RP/pkg/util/roundtripper" + "github.com/Azure/ARO-RP/pkg/util/stringutils" "github.com/Azure/ARO-RP/pkg/util/uuid" ) @@ -308,7 +309,7 @@ func (a *aad) callback(w http.ResponseWriter, r *http.Request) { return } - groupsIntersect := GroupsIntersect(a.allGroups, claims.Groups) + groupsIntersect := stringutils.GroupsIntersect(a.allGroups, claims.Groups) if len(groupsIntersect) == 0 { http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) } @@ -374,16 +375,3 @@ func (a *aad) internalServerError(w http.ResponseWriter, err error) { a.log.Warn(err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } - -func GroupsIntersect(as, bs []string) (gs []string) { - for _, a := range as { - for _, b := range bs { - if a == b { - gs = append(gs, a) - break - } - } - } - - return gs -} diff --git a/pkg/portal/ssh/ssh.go b/pkg/portal/ssh/ssh.go index 6977f7332e5..782151bacb0 100644 --- a/pkg/portal/ssh/ssh.go +++ b/pkg/portal/ssh/ssh.go @@ -25,6 +25,7 @@ import ( "github.com/Azure/ARO-RP/pkg/env" "github.com/Azure/ARO-RP/pkg/portal/middleware" "github.com/Azure/ARO-RP/pkg/proxy" + "github.com/Azure/ARO-RP/pkg/util/stringutils" ) const ( @@ -132,7 +133,7 @@ func (s *SSH) New(w http.ResponseWriter, r *http.Request) { return } - elevated := len(middleware.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 + elevated := len(stringutils.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0 if !elevated { s.sendResponse(w, "", "", "", "Elevated access is required.", s.env.IsLocalDevelopmentMode()) return diff --git a/pkg/util/stringutils/stringutils.go b/pkg/util/stringutils/stringutils.go index 0a48c671d55..6367dfd9f42 100644 --- a/pkg/util/stringutils/stringutils.go +++ b/pkg/util/stringutils/stringutils.go @@ -5,6 +5,8 @@ package stringutils import ( "strings" + + "github.com/stretchr/testify/assert" ) // LastTokenByte splits s on sep and returns the last token @@ -20,3 +22,24 @@ func Contains(list []string, value string) bool { } return false } + +func GroupsIntersect(as, bs []string) (gs []string) { + for _, a := range as { + for _, b := range bs { + if a == b { + gs = append(gs, a) + break + } + } + } + + return gs +} + +type mockT struct{} + +func (t mockT) Errorf(string, ...interface{}) {} + +func ElementsMatch(as, bs []string) (ok bool) { + return assert.ElementsMatch(mockT{}, as, bs) +} diff --git a/pkg/validate/dynamic/dynamic.go b/pkg/validate/dynamic/dynamic.go index d5cb56ffd57..9864a0ad841 100644 --- a/pkg/validate/dynamic/dynamic.go +++ b/pkg/validate/dynamic/dynamic.go @@ -24,12 +24,12 @@ import ( apisubnet "github.com/Azure/ARO-RP/pkg/api/util/subnet" "github.com/Azure/ARO-RP/pkg/database" "github.com/Azure/ARO-RP/pkg/env" - "github.com/Azure/ARO-RP/pkg/portal/middleware" "github.com/Azure/ARO-RP/pkg/util/azureclient" "github.com/Azure/ARO-RP/pkg/util/azureclient/authz/remotepdp" "github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armauthorization" "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/compute" "github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network" + "github.com/Azure/ARO-RP/pkg/util/stringutils" "github.com/Azure/ARO-RP/pkg/util/token" ) @@ -792,7 +792,7 @@ func (dv *dynamic) ValidatePreConfiguredNSGs(ctx context.Context, oc *api.OpenSh func (dv *dynamic) validateActions(ctx context.Context, r *azure.Resource, actions []string) error { if dv.platformIdentities != nil { for _, platformIdentity := range dv.platformIdentities { - actionsToValidate := middleware.GroupsIntersect(actions, dv.platformIdentitiesActionsMap[platformIdentity.OperatorName]) + actionsToValidate := stringutils.GroupsIntersect(actions, dv.platformIdentitiesActionsMap[platformIdentity.OperatorName]) if len(actionsToValidate) > 0 { if err := dv.validateActionsByOID(ctx, r, actionsToValidate, &platformIdentity.ObjectID); err != nil { return err diff --git a/pkg/validate/dynamic/platformworkloadidentityprofile.go b/pkg/validate/dynamic/platformworkloadidentityprofile.go index b9d53c7d5be..cd337c1b3ba 100644 --- a/pkg/validate/dynamic/platformworkloadidentityprofile.go +++ b/pkg/validate/dynamic/platformworkloadidentityprofile.go @@ -6,20 +6,16 @@ import ( sdkauthorization "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" "github.com/Azure/go-autorest/autorest/azure" - "github.com/stretchr/testify/assert" "github.com/Azure/ARO-RP/pkg/api" "github.com/Azure/ARO-RP/pkg/database" "github.com/Azure/ARO-RP/pkg/util/rbac" + "github.com/Azure/ARO-RP/pkg/util/stringutils" ) // Copyright (c) Microsoft Corporation. // Licensed under the Apache License 2.0. -type fakeT struct{} - -func (t fakeT) Errorf(string, ...interface{}) {} - func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context, oc *api.OpenShiftCluster, dbPlatformWorkloadIdentityRoleSets database.PlatformWorkloadIdentityRoleSets) error { dv.log.Print("ValidatePlatformWorkloadIdentityProfile") @@ -46,12 +42,11 @@ func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context, recievedOperatorIdentities := []string{} for _, role := range roles { requiredOperatorIdentities = append(requiredOperatorIdentities, role.OperatorName) - // platformIdentitiesRoleMap[role.OperatorName] = role } for _, role := range oc.Properties.PlatformWorkloadIdentityProfile.PlatformWorkloadIdentities { recievedOperatorIdentities = append(recievedOperatorIdentities, role.OperatorName) } - ok := assert.ElementsMatch(fakeT{}, requiredOperatorIdentities, recievedOperatorIdentities) + ok := stringutils.ElementsMatch(requiredOperatorIdentities, recievedOperatorIdentities) if !ok { return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodePlatformWorkloadIdentityMismatch, "properties.ValidatePlatformWorkloadIdentityProfile.PlatformWorkloadIdentities", "There's a mismatch between the required and expected set of platform workload identities for the requested OpenShift version '%s'.", requestedInstallVersion) @@ -86,14 +81,9 @@ func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context, } func (dv *dynamic) validateClusterMSI(ctx context.Context, oc *api.OpenShiftCluster) error { - if len(oc.Identity.UserAssignedIdentities) <= 0 { - return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodeInvalidClusterMSICount, - "identity.userAssignedIdentities", "No OpenShift Cluster associated User Assigned Identity is provided for the Workload Identity OpenShift cluster creation") - } - - if len(oc.Identity.UserAssignedIdentities) > 1 { + if len(oc.Identity.UserAssignedIdentities) != 1 { return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodeInvalidClusterMSICount, - "identity.userAssignedIdentities", "More than one OpenShift Cluster associated User Assigned Identity are provided for the Workload Identity OpenShift cluster creation") + "identity.userAssignedIdentities", "Unexpected number of OpenShift Cluster associated User Assigned Identity are provided for the Workload Identity OpenShift cluster, expected one User Assigned Identity") } for _, identity := range oc.Identity.UserAssignedIdentities {