Skip to content

Commit

Permalink
ARO-4376 add a stringutil funcs for string array comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
rajdeepc2792 committed Jun 12, 2024
1 parent e80c020 commit e3362d1
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 33 deletions.
3 changes: 2 additions & 1 deletion pkg/portal/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/gorilla/csrf"

"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
"github.com/Azure/ARO-RP/pkg/util/version"
)

Expand All @@ -23,7 +24,7 @@ type PortalInfo struct {

func (p *portal) info(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
elevated := len(middleware.GroupsIntersect(p.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0
elevated := len(stringutils.GroupsIntersect(p.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0

resp := PortalInfo{
Location: p.env.Location(),
Expand Down
3 changes: 2 additions & 1 deletion pkg/portal/kubeconfig/kubeconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/Azure/ARO-RP/pkg/portal/util/clientcache"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
)

const (
Expand Down Expand Up @@ -95,7 +96,7 @@ func (k *Kubeconfig) New(w http.ResponseWriter, r *http.Request) {
return
}

elevated := len(middleware.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0
elevated := len(stringutils.GroupsIntersect(k.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0

token := k.DbPortal.NewUUID()
portalDoc := &api.PortalDocument{
Expand Down
16 changes: 2 additions & 14 deletions pkg/portal/middleware/aad.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/util/oidc"
"github.com/Azure/ARO-RP/pkg/util/roundtripper"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
"github.com/Azure/ARO-RP/pkg/util/uuid"
)

Expand Down Expand Up @@ -308,7 +309,7 @@ func (a *aad) callback(w http.ResponseWriter, r *http.Request) {
return
}

groupsIntersect := GroupsIntersect(a.allGroups, claims.Groups)
groupsIntersect := stringutils.GroupsIntersect(a.allGroups, claims.Groups)
if len(groupsIntersect) == 0 {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
}
Expand Down Expand Up @@ -374,16 +375,3 @@ func (a *aad) internalServerError(w http.ResponseWriter, err error) {
a.log.Warn(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

func GroupsIntersect(as, bs []string) (gs []string) {
for _, a := range as {
for _, b := range bs {
if a == b {
gs = append(gs, a)
break
}
}
}

return gs
}
3 changes: 2 additions & 1 deletion pkg/portal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/proxy"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
)

const (
Expand Down Expand Up @@ -132,7 +133,7 @@ func (s *SSH) New(w http.ResponseWriter, r *http.Request) {
return
}

elevated := len(middleware.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0
elevated := len(stringutils.GroupsIntersect(s.elevatedGroupIDs, ctx.Value(middleware.ContextKeyGroups).([]string))) > 0
if !elevated {
s.sendResponse(w, "", "", "", "Elevated access is required.", s.env.IsLocalDevelopmentMode())
return
Expand Down
23 changes: 23 additions & 0 deletions pkg/util/stringutils/stringutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package stringutils

import (
"strings"

"github.com/stretchr/testify/assert"
)

// LastTokenByte splits s on sep and returns the last token
Expand All @@ -20,3 +22,24 @@ func Contains(list []string, value string) bool {
}
return false
}

func GroupsIntersect(as, bs []string) (gs []string) {
for _, a := range as {
for _, b := range bs {
if a == b {
gs = append(gs, a)
break
}
}
}

return gs
}

type mockT struct{}

func (t mockT) Errorf(string, ...interface{}) {}

func ElementsMatch(as, bs []string) (ok bool) {
return assert.ElementsMatch(mockT{}, as, bs)
}
4 changes: 2 additions & 2 deletions pkg/validate/dynamic/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ import (
apisubnet "github.com/Azure/ARO-RP/pkg/api/util/subnet"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/env"
"github.com/Azure/ARO-RP/pkg/portal/middleware"
"github.com/Azure/ARO-RP/pkg/util/azureclient"
"github.com/Azure/ARO-RP/pkg/util/azureclient/authz/remotepdp"
"github.com/Azure/ARO-RP/pkg/util/azureclient/azuresdk/armauthorization"
"github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/compute"
"github.com/Azure/ARO-RP/pkg/util/azureclient/mgmt/network"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
"github.com/Azure/ARO-RP/pkg/util/token"
)

Expand Down Expand Up @@ -792,7 +792,7 @@ func (dv *dynamic) ValidatePreConfiguredNSGs(ctx context.Context, oc *api.OpenSh
func (dv *dynamic) validateActions(ctx context.Context, r *azure.Resource, actions []string) error {
if dv.platformIdentities != nil {
for _, platformIdentity := range dv.platformIdentities {
actionsToValidate := middleware.GroupsIntersect(actions, dv.platformIdentitiesActionsMap[platformIdentity.OperatorName])
actionsToValidate := stringutils.GroupsIntersect(actions, dv.platformIdentitiesActionsMap[platformIdentity.OperatorName])
if len(actionsToValidate) > 0 {
if err := dv.validateActionsByOID(ctx, r, actionsToValidate, &platformIdentity.ObjectID); err != nil {
return err
Expand Down
18 changes: 4 additions & 14 deletions pkg/validate/dynamic/platformworkloadidentityprofile.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,16 @@ import (

sdkauthorization "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"

"github.com/Azure/ARO-RP/pkg/api"
"github.com/Azure/ARO-RP/pkg/database"
"github.com/Azure/ARO-RP/pkg/util/rbac"
"github.com/Azure/ARO-RP/pkg/util/stringutils"
)

// Copyright (c) Microsoft Corporation.
// Licensed under the Apache License 2.0.

type fakeT struct{}

func (t fakeT) Errorf(string, ...interface{}) {}

func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context, oc *api.OpenShiftCluster, dbPlatformWorkloadIdentityRoleSets database.PlatformWorkloadIdentityRoleSets) error {
dv.log.Print("ValidatePlatformWorkloadIdentityProfile")

Expand All @@ -46,12 +42,11 @@ func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context,
recievedOperatorIdentities := []string{}
for _, role := range roles {
requiredOperatorIdentities = append(requiredOperatorIdentities, role.OperatorName)
// platformIdentitiesRoleMap[role.OperatorName] = role
}
for _, role := range oc.Properties.PlatformWorkloadIdentityProfile.PlatformWorkloadIdentities {
recievedOperatorIdentities = append(recievedOperatorIdentities, role.OperatorName)
}
ok := assert.ElementsMatch(fakeT{}, requiredOperatorIdentities, recievedOperatorIdentities)
ok := stringutils.ElementsMatch(requiredOperatorIdentities, recievedOperatorIdentities)
if !ok {
return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodePlatformWorkloadIdentityMismatch,
"properties.ValidatePlatformWorkloadIdentityProfile.PlatformWorkloadIdentities", "There's a mismatch between the required and expected set of platform workload identities for the requested OpenShift version '%s'.", requestedInstallVersion)
Expand Down Expand Up @@ -86,14 +81,9 @@ func (dv *dynamic) ValidatePlatformWorkloadIdentityProfile(ctx context.Context,
}

func (dv *dynamic) validateClusterMSI(ctx context.Context, oc *api.OpenShiftCluster) error {
if len(oc.Identity.UserAssignedIdentities) <= 0 {
return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodeInvalidClusterMSICount,
"identity.userAssignedIdentities", "No OpenShift Cluster associated User Assigned Identity is provided for the Workload Identity OpenShift cluster creation")
}

if len(oc.Identity.UserAssignedIdentities) > 1 {
if len(oc.Identity.UserAssignedIdentities) != 1 {
return api.NewCloudError(http.StatusBadRequest, api.CloudErrorCodeInvalidClusterMSICount,
"identity.userAssignedIdentities", "More than one OpenShift Cluster associated User Assigned Identity are provided for the Workload Identity OpenShift cluster creation")
"identity.userAssignedIdentities", "Unexpected number of OpenShift Cluster associated User Assigned Identity are provided for the Workload Identity OpenShift cluster, expected one User Assigned Identity")
}

for _, identity := range oc.Identity.UserAssignedIdentities {
Expand Down

0 comments on commit e3362d1

Please sign in to comment.