Skip to content

Commit

Permalink
Implement enable and platform workload identity flags on cluster create
Browse files Browse the repository at this point in the history
  • Loading branch information
tsatam committed Jun 17, 2024
1 parent f4ccab0 commit 685ec68
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 32 deletions.
29 changes: 29 additions & 0 deletions python/az/aro/azext_aro/_actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the Apache License 2.0.

import argparse

from azext_aro.vendored_sdks.azure.mgmt.redhatopenshift.v2024_08_12_preview.models import PlatformWorkloadIdentity
from azure.cli.core.azclierror import CLIError


# pylint:disable=protected-access
# pylint:disable=too-few-public-methods
class AROPlatformWorkloadIdentityAddAction(argparse._AppendAction):

def __call__(self, parser, namespace, values, option_string=None):
try:
if len(values) != 2:
msg = f"{option_string} requires 2 values in format: `OPERATOR_NAME RESOURCE_ID`"
raise argparse.ArgumentError(self, msg)

operator_name, resource_id = values
parsed = PlatformWorkloadIdentity(
operator_name=operator_name,
resource_id=resource_id
)

super().__call__(parser, namespace, parsed, option_string)

except ValueError as e:
raise CLIError(f"usage error: {option_string} NAME ID") from e
18 changes: 15 additions & 3 deletions python/az/aro/azext_aro/_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the Apache License 2.0.

from azext_aro._actions import AROPlatformWorkloadIdentityAddAction
from azext_aro._validators import validate_cidr
from azext_aro._validators import validate_client_id
from azext_aro._validators import validate_client_secret
Expand Down Expand Up @@ -56,7 +57,8 @@ def load_arguments(self, _):
validator=validate_client_secret(isCreate=True))

c.argument('version',
options_list=['--version', c.deprecate(target='--install-version', redirect='--version', hide=True)],
options_list=[
'--version', c.deprecate(target='--install-version', redirect='--version', hide=True)],
help='OpenShift version to use for cluster creation.',
validator=validate_version_format)

Expand All @@ -76,13 +78,15 @@ def load_arguments(self, _):
help='ResourceID of the DiskEncryptionSet to be used for master and worker VMs.',
validator=validate_disk_encryption_set)
c.argument('master_encryption_at_host', arg_type=get_three_state_flag(),
options_list=['--master-encryption-at-host', '--master-enc-host'],
options_list=['--master-encryption-at-host',
'--master-enc-host'],
help='Encryption at host flag for master VMs. [Default: false]')
c.argument('master_vm_size',
help='Size of master VMs. [Default: Standard_D8s_v3]')

c.argument('worker_encryption_at_host', arg_type=get_three_state_flag(),
options_list=['--worker-encryption-at-host', '--worker-enc-host'],
options_list=['--worker-encryption-at-host',
'--worker-enc-host'],
help='Encryption at host flag for worker VMs. [Default: false]')
c.argument('worker_vm_size',
help='Size of worker VMs. [Default: Standard_D4s_v3]')
Expand Down Expand Up @@ -123,6 +127,14 @@ def load_arguments(self, _):
validator=validate_load_balancer_managed_outbound_ip_count,
options_list=['--load-balancer-managed-outbound-ip-count', '--lb-ip-count'])

c.argument('enable_managed_identity', arg_group='Identity', arg_type=get_three_state_flag(),
options_list=['--enable-managed-identity', '--enable-mi'],
help='Enable managed identity for this cluster.', is_preview=True)
c.argument('platform_workload_identities', arg_group='Identity',
help='Assign a platform workload identity used within the cluster', is_preview=True,
options_list=['--assign-platform-workload-identity', '--assign-platform-wi'],
action=AROPlatformWorkloadIdentityAddAction, nargs='+')

with self.argument_context('aro update') as c:
c.argument('client_secret',
help='Client secret of cluster service principal.',
Expand Down
43 changes: 40 additions & 3 deletions python/az/aro/azext_aro/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from azure.cli.core.commands.client_factory import get_mgmt_service_client, get_subscription_id
from azure.cli.core.profiles import ResourceType
from azure.cli.core.azclierror import CLIInternalError, \
InvalidArgumentValueError, RequiredArgumentMissingError
InvalidArgumentValueError, RequiredArgumentMissingError, \
MutuallyExclusiveArgumentError
from azure.core.exceptions import ResourceNotFoundError
from knack.log import get_logger
from msrestazure.azure_exceptions import CloudError
Expand All @@ -36,19 +37,24 @@ def _validate_cidr(namespace):
def validate_client_id(namespace):
if namespace.client_id is None:
return
if namespace.enable_managed_identity is True:
raise MutuallyExclusiveArgumentError('Must not specify --client-id when --enable-managed-identity is True') # pylint: disable=line-too-long

try:
uuid.UUID(namespace.client_id)
except ValueError as e:
raise InvalidArgumentValueError(f"Invalid --client-id '{namespace.client_id}'.") from e
raise InvalidArgumentValueError(f"Invalid --client-id '{namespace.client_id}'.") from e # pylint: disable=line-too-long

if namespace.client_secret is None or not str(namespace.client_secret):
raise RequiredArgumentMissingError('Must specify --client-secret with --client-id.')
raise RequiredArgumentMissingError('Must specify --client-secret with --client-id.') # pylint: disable=line-too-long


def validate_client_secret(isCreate):
def _validate_client_secret(namespace):
if not isCreate or namespace.client_secret is None:
return
if namespace.enable_managed_identity is True:
raise MutuallyExclusiveArgumentError('Must not specify --client-secret when --enable-managed-identity is True') # pylint: disable=line-too-long
if namespace.client_id is None or not str(namespace.client_id):
raise RequiredArgumentMissingError('Must specify --client-id with --client-secret.')

Expand Down Expand Up @@ -283,3 +289,34 @@ def validate_load_balancer_managed_outbound_ip_count(namespace):
if namespace.load_balancer_managed_outbound_ip_count < minimum_managed_outbound_ips or namespace.load_balancer_managed_outbound_ip_count > maximum_managed_outbound_ips: # pylint: disable=line-too-long
error_msg = f"--load-balancer-managed-outbound-ip-count must be between {minimum_managed_outbound_ips} and {maximum_managed_outbound_ips} (inclusive)." # pylint: disable=line-too-long
raise InvalidArgumentValueError(error_msg)


def validate_enable_managed_identity(namespace):
if namespace.client_id is not None:
raise InvalidArgumentValueError('Must not specify --client-id when --enable-managed-identity is True')

if namespace.client_secret is not None:
raise InvalidArgumentValueError('Must not specify --client-secret when --enable-managed-identity is True')


def validate_platform_workload_identities(cmd, namespace):
if namespace.assign_platform_workload_identities is None:
return

if not namespace.enable_managed_identity:
raise RequiredArgumentMissingError('Must set --enable-managed-identity when providing platform workload identities') # pylint: disable=line-too-long

for identity in namespace.platform_workload_identities:
if not is_valid_resource_id(identity.resource_id):
identity.resource_id = resource_id(
subscription=get_subscription_id(cmd.cli_ctx),
resource_group=namespace.resource_group_name,
namespace='Microsoft.ManagedIdentity',
type='userAssignedIdentities',
name=identity.resource_id,
)

parsed_resource_id = parse_resource_id(identity.resource_id)
if parsed_resource_id['namespace'] != 'Microsoft.ManagedIdentity' or \
parsed_resource_id['type'] != 'userAssignedIdentities':
raise InvalidArgumentValueError(f"Resource {identity.resource_id} used for platform workload identity {identity.name} is not a valid userAssignedIdentity") # pylint: disable=line-too-long
43 changes: 29 additions & 14 deletions python/az/aro/azext_aro/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def aro_create(cmd, # pylint: disable=too-many-locals
apiserver_visibility=None,
ingress_visibility=None,
load_balancer_managed_outbound_ip_count=None,
enable_managed_identity=False,
platform_workload_identities=None,
tags=None,
version=None,
no_wait=False):
Expand Down Expand Up @@ -107,16 +109,18 @@ def aro_create(cmd, # pylint: disable=too-many-locals
random_id = generate_random_id()

aad = AADManager(cmd.cli_ctx)
if client_id is None:
client_id, client_secret = aad.create_application(cluster_resource_group or 'aro-' + random_id)

client_sp_id = aad.get_service_principal_id(client_id)
if not client_sp_id:
client_sp_id = aad.create_service_principal(client_id)
if not enable_managed_identity:
if client_id is None:
client_id, client_secret = aad.create_application(cluster_resource_group or 'aro-' + random_id)

rp_client_sp_id = aad.get_service_principal_id(resolve_rp_client_id())
if not rp_client_sp_id:
raise ResourceNotFoundError("RP service principal not found.")
client_sp_id = aad.get_service_principal_id(client_id)
if not client_sp_id:
client_sp_id = aad.create_service_principal(client_id)

rp_client_sp_id = aad.get_service_principal_id(resolve_rp_client_id())
if not rp_client_sp_id:
raise ResourceNotFoundError("RP service principal not found.")

if rp_mode_development():
worker_vm_size = worker_vm_size or 'Standard_D2s_v3'
Expand Down Expand Up @@ -146,10 +150,6 @@ def aro_create(cmd, # pylint: disable=too-many-locals
fips_validated_modules='Enabled' if fips_validated_modules else 'Disabled',
version=version or '',
),
service_principal_profile=openshiftcluster.ServicePrincipalProfile(
client_id=client_id,
client_secret=client_secret,
),
network_profile=openshiftcluster.NetworkProfile(
pod_cidr=pod_cidr or '10.128.0.0/14',
service_cidr=service_cidr or '172.30.0.0/16',
Expand Down Expand Up @@ -183,10 +183,25 @@ def aro_create(cmd, # pylint: disable=too-many-locals
visibility=ingress_visibility or 'Public',
)
],
service_principal_profile=None,
platform_workload_identity_profile=None,
)

sp_obj_ids = [client_sp_id, rp_client_sp_id]
ensure_resource_permissions(cmd.cli_ctx, oc, True, sp_obj_ids)
if enable_managed_identity is True:
oc.platform_workload_identity_profile = openshiftcluster.PlatformWorkloadIdentityProfile(
platform_workload_identities=platform_workload_identities
)

# TODO - perform client-side validation of required identity permissions

else:
oc.service_principal_profile = openshiftcluster.ServicePrincipalProfile(
client_id=client_id,
client_secret=client_secret,
)

sp_obj_ids = [client_sp_id, rp_client_sp_id]
ensure_resource_permissions(cmd.cli_ctx, oc, True, sp_obj_ids)

return sdk_no_wait(no_wait, client.open_shift_clusters.begin_create_or_update,
resource_group_name=resource_group_name,
Expand Down
Loading

0 comments on commit 685ec68

Please sign in to comment.