Skip to content

Commit

Permalink
add disableLocalAuth while creating aml compute (Azure#38913)
Browse files Browse the repository at this point in the history
* add disableLocalAuth while creating aml compute

* add unit tests

* push recordings

* optimize expression

* format

* fix black error format

* update logic for disable local auth

* formatting

* add info to changelog
  • Loading branch information
pdhotems authored Dec 20, 2024
1 parent ce16093 commit 6c0728a
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 14 deletions.
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
### Features Added

### Bugs Fixed

- Fixed disableLocalAuthentication handling while creating amlCompute

## 1.23.0 (2024-12-05)

Expand Down
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/ml/azure-ai-ml",
"Tag": "python/ml/azure-ai-ml_d220df7fea"
"Tag": "python/ml/azure-ai-ml_003b900b39"
}
33 changes: 24 additions & 9 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_compute/aml_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from typing import Any, Dict, Optional

from azure.ai.ml._restclient.v2022_10_01_preview.models import AmlCompute as AmlComputeRest
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
AmlCompute as AmlComputeRest,
)
from azure.ai.ml._restclient.v2022_12_01_preview.models import (
AmlComputeProperties,
ComputeResource,
ResourceId,
Expand All @@ -16,7 +18,11 @@
)
from azure.ai.ml._schema._utils.utils import get_subnet_str
from azure.ai.ml._schema.compute.aml_compute import AmlComputeSchema
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_pascal, to_iso_duration_format
from azure.ai.ml._utils.utils import (
camel_to_snake,
snake_to_pascal,
to_iso_duration_format,
)
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, TYPE
from azure.ai.ml.constants._compute import ComputeDefaults, ComputeType
from azure.ai.ml.entities._credentials import IdentityConfiguration
Expand Down Expand Up @@ -180,7 +186,7 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
name=rest_obj.name,
id=rest_obj.id,
description=prop.description,
location=prop.compute_location if prop.compute_location else rest_obj.location,
location=(prop.compute_location if prop.compute_location else rest_obj.location),
tags=rest_obj.tags if rest_obj.tags else None,
provisioning_state=prop.provisioning_state,
provisioning_errors=(
Expand All @@ -190,8 +196,8 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
),
size=prop.properties.vm_size,
tier=camel_to_snake(prop.properties.vm_priority),
min_instances=prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None,
max_instances=prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None,
min_instances=(prop.properties.scale_settings.min_node_count if prop.properties.scale_settings else None),
max_instances=(prop.properties.scale_settings.max_node_count if prop.properties.scale_settings else None),
network_settings=network_settings or None,
ssh_settings=ssh_settings,
ssh_public_access_enabled=(prop.properties.remote_login_port_public_access == "Enabled"),
Expand All @@ -200,7 +206,9 @@ def _load_from_rest(cls, rest_obj: ComputeResource) -> "AmlCompute":
if prop.properties.scale_settings and prop.properties.scale_settings.node_idle_time_before_scale_down
else None
),
identity=IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None,
identity=(
IdentityConfiguration._from_compute_rest_object(rest_obj.identity) if rest_obj.identity else None
),
created_on=prop.additional_properties.get("createdOn", None),
enable_node_public_ip=(
prop.properties.enable_node_public_ip if prop.properties.enable_node_public_ip is not None else True
Expand Down Expand Up @@ -244,21 +252,28 @@ def _to_rest_object(self) -> ComputeResource:
),
)
remote_login_public_access = "Enabled"
disableLocalAuth = not (self.ssh_public_access_enabled and self.ssh_settings is not None)
if self.ssh_public_access_enabled is not None:
remote_login_public_access = "Enabled" if self.ssh_public_access_enabled else "Disabled"

else:
remote_login_public_access = "NotSpecified"
aml_prop = AmlComputeProperties(
vm_size=self.size if self.size else ComputeDefaults.VMSIZE,
vm_priority=snake_to_pascal(self.tier),
user_account_credentials=self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None,
user_account_credentials=(self.ssh_settings._to_user_account_credentials() if self.ssh_settings else None),
scale_settings=scale_settings,
subnet=subnet_resource,
remote_login_port_public_access=remote_login_public_access,
enable_node_public_ip=self.enable_node_public_ip,
)

aml_comp = AmlComputeRest(description=self.description, compute_type=self.type, properties=aml_prop)
aml_comp = AmlComputeRest(
description=self.description,
compute_type=self.type,
properties=aml_prop,
disable_local_auth=disableLocalAuth,
)
return ComputeResource(
location=self.location,
properties=aml_comp,
Expand Down
39 changes: 36 additions & 3 deletions sdk/ml/azure-ai-ml/tests/compute/unittests/test_compute_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,24 @@ def test_compute_from_rest(self):

def _test_loaded_compute(self, compute: AmlCompute):
assert compute.name == "banchaml"
assert compute.ssh_settings.admin_username == "azureuser"
assert compute.identity.type == "user_assigned"
assert compute.type == "amlcompute"
assert compute.location == "eastus"
assert compute.description == "some_desc_aml"

def test_compute_from_yaml(self):
compute: AmlCompute = verify_entity_load_and_dump(
load_compute,
self._test_loaded_compute,
"tests/test_configs/compute/compute-aml.yaml",
)[0]
assert compute.location == "eastus"
assert compute.ssh_settings.admin_username == "azureuser"
assert compute.identity.type == "user_assigned"

rest_intermediate = compute._to_rest_object()
assert rest_intermediate.properties.compute_type == "AmlCompute"
assert rest_intermediate.properties.properties.user_account_credentials.admin_user_name == "azureuser"
assert rest_intermediate.properties.properties.enable_node_public_ip
assert rest_intermediate.properties.disable_local_auth is False
assert rest_intermediate.location == compute.location
assert rest_intermediate.tags is not None
assert rest_intermediate.tags["test"] == "true"
Expand All @@ -81,6 +84,36 @@ def test_compute_from_yaml(self):
)
assert body["location"] == compute.location

def test_aml_compute_from_yaml_with_disable_public_access(self):

compute: AmlCompute = verify_entity_load_and_dump(
load_compute,
self._test_loaded_compute,
"tests/test_configs/compute/compute-aml-disable-public-access.yaml",
)[0]

rest_intermediate = compute._to_rest_object()

assert rest_intermediate.properties.compute_type == "AmlCompute"
assert rest_intermediate.properties.properties.enable_node_public_ip
assert rest_intermediate.properties.disable_local_auth is True
assert rest_intermediate.location == compute.location

def test_aml_compute_from_yaml_with_disable_public_access_when_no_sshSettings(self):

compute: AmlCompute = verify_entity_load_and_dump(
load_compute,
self._test_loaded_compute,
"tests/test_configs/compute/compute-aml-public-access-no-ssh.yaml",
)[0]

rest_intermediate = compute._to_rest_object()

assert rest_intermediate.properties.compute_type == "AmlCompute"
assert rest_intermediate.properties.properties.enable_node_public_ip
assert rest_intermediate.properties.disable_local_auth is True
assert rest_intermediate.location == compute.location

def test_compute_vm_from_yaml(self):
resource_id = "/subscriptions/13e50845-67bc-4ac5-94db-48d493a6d9e8/resourceGroups/myrg/providers/Microsoft.Compute/virtualMachines/myvm"
fake_key = "myfakekey"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: banchaml
type: amlcompute
tier: dedicated
description: some_desc_aml
size: Standard_DS2_v2
min_instances: 0
max_instances: 2
location: eastus
idle_time_before_scale_down: 120
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: banchaml
type: amlcompute
description: some_desc_aml
size: Standard_DS2_v2
location: eastus
tags:
test: "true"
ssh_public_access_enabled: true
max_instances: 2
idle_time_before_scale_down: 100
enable_node_public_ip: true

0 comments on commit 6c0728a

Please sign in to comment.