diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py index 0a5b1287e6d2..bf3e63a48bc2 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py @@ -36,10 +36,19 @@ class Gen2StorageClient: def __init__(self, credential: str, file_system: str, account_url: str): - service_client = DataLakeServiceClient(account_url=account_url, credential=credential) self.account_name = account_url.split(".")[0].split("//")[1] self.file_system = file_system - self.file_system_client = service_client.get_file_system_client(file_system=file_system) + + try: + service_client = DataLakeServiceClient(account_url=account_url, credential=credential) + self.file_system_client = service_client.get_file_system_client(file_system=file_system) + except ValueError as e: + api_version = e.args[0].split("\n")[-1] + service_client = DataLakeServiceClient( + account_url=account_url, credential=credential, api_version=api_version + ) + self.file_system_client = service_client.get_file_system_client(file_system=file_system) + try: self.file_system_client.create_file_system() except ResourceExistsError: diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py index c9f09436aa87..e6f0b2a00d3b 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py @@ -52,6 +52,7 @@ from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix, get_storage_client from azure.ai.ml._utils.utils import _is_evaluator, resolve_short_datastore_url, validate_ml_flow_folder from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, REGISTRY_URI_FORMAT, AzureMLResourceType +from azure.ai.ml.entities import AzureDataLakeGen2Datastore from azure.ai.ml.entities._assets import Environment, Model, ModelPackage from azure.ai.ml.entities._assets._artifacts.code import Code from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference @@ -415,9 +416,24 @@ def download(self, name: str, version: str, download_path: Union[PathLike, str] else: raise e - container = ds.container_name - datastore_type = ds.type + if isinstance(ds, AzureDataLakeGen2Datastore): + container = ds.filesystem + try: + from azure.identity import ClientSecretCredential + + token_credential = ClientSecretCredential( + tenant_id=ds.credentials["tenant_id"], + client_id=ds.credentials["client_id"], + client_secret=ds.credentials["client_secret"], + authority=ds.credentials["authority_url"], + ) + credential = token_credential + except (KeyError, TypeError): + pass + else: + container = ds.container_name + datastore_type = ds.type storage_client = get_storage_client( credential=credential, container_name=container, diff --git a/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_operations.py b/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_operations.py index 1a5c48d6cfe1..997beb651e3d 100644 --- a/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_operations.py +++ b/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_operations.py @@ -235,6 +235,63 @@ def test_restore_container(self, mock_model_operation: ModelOperations) -> None: resource_group_name=mock_model_operation._resource_group_name, ) + def test_download_from_gen2_with_none_cred(self, mock_model_operation: ModelOperations) -> None: + name = "random_string" + version = "1" + model = Model( + name=name, + version=version, + path="azureml://subscriptions/subscription_id/resourcegroups/rg-name/workspaces/gen2test/datastores/adls_gen2/paths/gen2test/", + ) + from azure.ai.ml.entities import AzureDataLakeGen2Datastore + + datastore = AzureDataLakeGen2Datastore(name="gen2_datastore", account_name="gen2_account", filesystem="gen2") + storage_client = Mock() + with patch( + "azure.ai.ml.operations._model_operations.get_storage_client", return_value=storage_client + ) as get_client_mock, patch( + "azure.ai.ml.operations._model_operations.Model._from_rest_object", + return_value=model, + ), patch( + "azure.ai.ml.operations._model_operations.DatastoreOperations.get", return_value=datastore + ): + mock_model_operation.download(name=name, version=version) + get_client_mock.assert_called_once() + + def test_download_from_gen2_with_sp_cred(self, mock_model_operation: ModelOperations) -> None: + name = "random_string" + version = "1" + model = Model( + name=name, + version=version, + path="azureml://subscriptions/subscription_id/resourcegroups/rg-name/workspaces/gen2test/datastores/adls_gen2/paths/gen2test/", + ) + from azure.ai.ml.entities import AzureDataLakeGen2Datastore + from azure.ai.ml.entities._credentials import ServicePrincipalConfiguration + + datastore = AzureDataLakeGen2Datastore( + name="adls_gen2_example", + description="Datastore pointing to an Azure Data Lake Storage Gen2.", + account_name="mytestdatalakegen2", + filesystem="my-gen2-container", + credentials=ServicePrincipalConfiguration( + tenant_id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX", + client_id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX", + client_secret="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", + ), + ) + storage_client = Mock() + with patch( + "azure.ai.ml.operations._model_operations.get_storage_client", return_value=storage_client + ) as get_client_mock, patch( + "azure.ai.ml.operations._model_operations.Model._from_rest_object", + return_value=model, + ), patch( + "azure.ai.ml.operations._model_operations.DatastoreOperations.get", return_value=datastore + ): + mock_model_operation.download(name=name, version=version) + get_client_mock.assert_called_once() + def test_create_with_datastore( self, mock_workspace_scope: OperationScope,