Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tenant ID to internal apis for CMSI usage #3655

Merged
merged 13 commits into from
Jul 4, 2024
1 change: 1 addition & 0 deletions pkg/api/openshiftcluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -808,4 +808,5 @@ type Identity struct {
Type string `json:"type,omitempty"`
UserAssignedIdentities UserAssignedIdentities `json:"userAssignedIdentities,omitempty"`
IdentityURL string `json:"identityURL,omitempty" mutable:"true"`
TenantID string `json:"tenantId,omitempty" mutable:"true"`
niontive marked this conversation as resolved.
Show resolved Hide resolved
}
43 changes: 27 additions & 16 deletions pkg/frontend/openshiftcluster_putorpatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"github.com/Azure/ARO-RP/pkg/util/version"
)

var errMissingIdentityURL error = fmt.Errorf("identityURL not provided but required for workload identity cluster")
var errMissingIdentityParmeter error = fmt.Errorf("identity parameter not provided but required for workload identity cluster")
niontive marked this conversation as resolved.
Show resolved Hide resolved

func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand All @@ -44,19 +44,21 @@ func (f *frontend) putOrPatchOpenShiftCluster(w http.ResponseWriter, r *http.Req
resourceProviderNamespace := chi.URLParam(r, "resourceProviderNamespace")

identityURL := r.Header.Get("x-ms-identity-url")
identityTenantID := r.Header.Get("x-ms-home-tenant-id")

apiVersion := r.URL.Query().Get(api.APIVersionKey)
err := cosmosdb.RetryOnPreconditionFailed(func() error {
var err error
b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion, identityURL)
b, err = f._putOrPatchOpenShiftCluster(ctx, log, body, correlationData, systemData, r.URL.Path, originalPath, r.Method, referer, &header, f.apis[apiVersion].OpenShiftClusterConverter, f.apis[apiVersion].OpenShiftClusterStaticValidator, subId, resourceProviderNamespace, apiVersion, identityURL, identityTenantID)
return err
})

frontendOperationResultLog(log, r.Method, err)
reply(log, w, header, b, err)
}

func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string, identityURL string) ([]byte, error) {
// TODO - refactor this function to reduce the number of parameters
func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.Entry, body []byte, correlationData *api.CorrelationData, systemData *api.SystemData, path, originalPath, method, referer string, header *http.Header, converter api.OpenShiftClusterConverter, staticValidator api.OpenShiftClusterStaticValidator, subId, resourceProviderNamespace string, apiVersion string, identityURL string, identityTenantID string) ([]byte, error) {
Comment on lines +60 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we often leave these types of comments with the best intentions but then never loop back on them.

From just jumping into this I would think we need a little struct here but maybe I'm missing something? I get not adding it here to reduce the PR complexity but could we have a branch off this with the refactor so we know it's all going to go together?

Happy to pair on this too if that helps :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have bandwidth at the moment to tackle this refactor. On top of that, I'm not the best person to work on this since I don't own the frontend code.

I'll defer to @bennerv and @hlipsig for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a little play with it and I think I see where you're coming from. Lots of cans and lots of worms. I opened a small refactror PR here that hopefully helps the next person who comes along and I'm okay to leave it there for now.

Also AFAIK we all own all the code, there are no specific teams that own any part of the codebase but do correct me if I'm wrong.

subscription, err := f.validateSubscriptionState(ctx, path, api.SubscriptionStateRegistered)
if err != nil {
return nil, err
Expand Down Expand Up @@ -96,9 +98,16 @@ func (f *frontend) _putOrPatchOpenShiftCluster(ctx context.Context, log *logrus.
}
}

err = validateIdentityUrl(doc.OpenShiftCluster, identityURL, isCreate)
if err != nil {
return nil, err
// Don't persist identity parameters in non-wimi clusters
if doc.OpenShiftCluster.Properties.ServicePrincipalProfile == nil || doc.OpenShiftCluster.Identity != nil {
niontive marked this conversation as resolved.
Show resolved Hide resolved
if isCreate {
niontive marked this conversation as resolved.
Show resolved Hide resolved
if err := validateIdentityUrl(doc.OpenShiftCluster, identityURL); err != nil {
return nil, err
}
if err := validateIdentityTenantID(doc.OpenShiftCluster, identityTenantID); err != nil {
return nil, err
}
}
}

doc.CorrelationData = correlationData
Expand Down Expand Up @@ -298,24 +307,26 @@ func enrichClusterSystemData(doc *api.OpenShiftClusterDocument, systemData *api.
}
}

func validateIdentityUrl(cluster *api.OpenShiftCluster, identityURL string, isCreate bool) error {
// Don't persist identity URL in non-wimi clusters
if cluster.Properties.ServicePrincipalProfile != nil || cluster.Identity == nil {
return nil
}

func validateIdentityUrl(cluster *api.OpenShiftCluster, identityURL string) error {
if identityURL == "" {
if isCreate {
return errMissingIdentityURL
}
return nil
return fmt.Errorf("%w: %s", errMissingIdentityParmeter, "identity URL")
}

cluster.Identity.IdentityURL = identityURL

return nil
}

func validateIdentityTenantID(cluster *api.OpenShiftCluster, identityTenantID string) error {
if identityTenantID == "" {
return fmt.Errorf("%w: %s", errMissingIdentityParmeter, "identity tenant ID")
}

cluster.Identity.TenantID = identityTenantID

return nil
}

func (f *frontend) ValidateNewCluster(ctx context.Context, subscription *api.SubscriptionDocument, cluster *api.OpenShiftCluster, staticValidator api.OpenShiftClusterStaticValidator, ext interface{}, path string) error {
err := staticValidator.Static(ext, nil, f.env.Location(), f.env.Domain(), f.env.FeatureIsSet(env.FeatureRequireD2sV3Workers), path)
if err != nil {
Expand Down
78 changes: 40 additions & 38 deletions pkg/frontend/openshiftcluster_putorpatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package frontend
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
Expand Down Expand Up @@ -3312,71 +3313,72 @@ func TestValidateIdentityUrl(t *testing.T) {
identityURL string
cluster *api.OpenShiftCluster
expected *api.OpenShiftCluster
isCreate bool
wantError error
}{
{
name: "identity URL is empty, is not wi/mi cluster create",
name: "identity URL is empty",
identityURL: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
isCreate: false,
wantError: errMissingIdentityParmeter,
},
{
name: "identity URL is empty, is wi/mi cluster create",
identityURL: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
isCreate: true,
wantError: errMissingIdentityURL,
},
{
name: "cluster is not wi/mi, identityURL passed",
identityURL: "http://foo.bar",
name: "pass - identity URL passed",
cluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
Identity: &api.Identity{},
},
identityURL: "http://foo.bar",
expected: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
Identity: &api.Identity{
IdentityURL: "http://foo.bar",
},
},
isCreate: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
err := validateIdentityUrl(tt.cluster, tt.identityURL)
if !errors.Is(err, tt.wantError) {
t.Error(cmp.Diff(err, tt.wantError))
}

if !reflect.DeepEqual(tt.cluster, tt.expected) {
t.Error(cmp.Diff(tt.cluster, tt.expected))
}
})
}
}

func TestValidateIdentityTenantID(t *testing.T) {
for _, tt := range []struct {
name string
tenantID string
cluster *api.OpenShiftCluster
expected *api.OpenShiftCluster
wantError error
}{
{
name: "cluster is not wi/mi, identityURL not passed",
identityURL: "",
cluster: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
},
expected: &api.OpenShiftCluster{
Properties: api.OpenShiftClusterProperties{
ServicePrincipalProfile: &api.ServicePrincipalProfile{},
},
},
isCreate: true,
name: "tenantID is empty",
tenantID: "",
cluster: &api.OpenShiftCluster{},
expected: &api.OpenShiftCluster{},
wantError: errMissingIdentityParmeter,
},
{
name: "pass - identity URL passed on wi/mi cluster",
name: "pass - tenantID passed",
cluster: &api.OpenShiftCluster{
Identity: &api.Identity{},
},
identityURL: "http://foo.bar",
tenantID: "bogus",
expected: &api.OpenShiftCluster{
Identity: &api.Identity{
IdentityURL: "http://foo.bar",
TenantID: "bogus",
},
},
isCreate: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
err := validateIdentityUrl(tt.cluster, tt.identityURL, tt.isCreate)
if err != nil && err != tt.wantError {
err := validateIdentityTenantID(tt.cluster, tt.tenantID)
if !errors.Is(err, tt.wantError) {
t.Error(cmp.Diff(err, tt.wantError))
}

Expand Down
Loading