Skip to content

Commit

Permalink
Gen2download failure (Azure#38986)
Browse files Browse the repository at this point in the history
  • Loading branch information
achauhan-scc authored Dec 27, 2024
1 parent b52b61c commit c694e3f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
13 changes: 11 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_artifacts/_gen2_storage_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,19 @@

class Gen2StorageClient:
def __init__(self, credential: str, file_system: str, account_url: str):
service_client = DataLakeServiceClient(account_url=account_url, credential=credential)
self.account_name = account_url.split(".")[0].split("//")[1]
self.file_system = file_system
self.file_system_client = service_client.get_file_system_client(file_system=file_system)

try:
service_client = DataLakeServiceClient(account_url=account_url, credential=credential)
self.file_system_client = service_client.get_file_system_client(file_system=file_system)
except ValueError as e:
api_version = e.args[0].split("\n")[-1]
service_client = DataLakeServiceClient(
account_url=account_url, credential=credential, api_version=api_version
)
self.file_system_client = service_client.get_file_system_client(file_system=file_system)

try:
self.file_system_client.create_file_system()
except ResourceExistsError:
Expand Down
20 changes: 18 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/operations/_model_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix, get_storage_client
from azure.ai.ml._utils.utils import _is_evaluator, resolve_short_datastore_url, validate_ml_flow_folder
from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, REGISTRY_URI_FORMAT, AzureMLResourceType
from azure.ai.ml.entities import AzureDataLakeGen2Datastore
from azure.ai.ml.entities._assets import Environment, Model, ModelPackage
from azure.ai.ml.entities._assets._artifacts.code import Code
from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference
Expand Down Expand Up @@ -415,9 +416,24 @@ def download(self, name: str, version: str, download_path: Union[PathLike, str]
else:
raise e

container = ds.container_name
datastore_type = ds.type
if isinstance(ds, AzureDataLakeGen2Datastore):
container = ds.filesystem
try:
from azure.identity import ClientSecretCredential

token_credential = ClientSecretCredential(
tenant_id=ds.credentials["tenant_id"],
client_id=ds.credentials["client_id"],
client_secret=ds.credentials["client_secret"],
authority=ds.credentials["authority_url"],
)
credential = token_credential
except (KeyError, TypeError):
pass

else:
container = ds.container_name
datastore_type = ds.type
storage_client = get_storage_client(
credential=credential,
container_name=container,
Expand Down
57 changes: 57 additions & 0 deletions sdk/ml/azure-ai-ml/tests/model/unittests/test_model_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,63 @@ def test_restore_container(self, mock_model_operation: ModelOperations) -> None:
resource_group_name=mock_model_operation._resource_group_name,
)

def test_download_from_gen2_with_none_cred(self, mock_model_operation: ModelOperations) -> None:
name = "random_string"
version = "1"
model = Model(
name=name,
version=version,
path="azureml://subscriptions/subscription_id/resourcegroups/rg-name/workspaces/gen2test/datastores/adls_gen2/paths/gen2test/",
)
from azure.ai.ml.entities import AzureDataLakeGen2Datastore

datastore = AzureDataLakeGen2Datastore(name="gen2_datastore", account_name="gen2_account", filesystem="gen2")
storage_client = Mock()
with patch(
"azure.ai.ml.operations._model_operations.get_storage_client", return_value=storage_client
) as get_client_mock, patch(
"azure.ai.ml.operations._model_operations.Model._from_rest_object",
return_value=model,
), patch(
"azure.ai.ml.operations._model_operations.DatastoreOperations.get", return_value=datastore
):
mock_model_operation.download(name=name, version=version)
get_client_mock.assert_called_once()

def test_download_from_gen2_with_sp_cred(self, mock_model_operation: ModelOperations) -> None:
name = "random_string"
version = "1"
model = Model(
name=name,
version=version,
path="azureml://subscriptions/subscription_id/resourcegroups/rg-name/workspaces/gen2test/datastores/adls_gen2/paths/gen2test/",
)
from azure.ai.ml.entities import AzureDataLakeGen2Datastore
from azure.ai.ml.entities._credentials import ServicePrincipalConfiguration

datastore = AzureDataLakeGen2Datastore(
name="adls_gen2_example",
description="Datastore pointing to an Azure Data Lake Storage Gen2.",
account_name="mytestdatalakegen2",
filesystem="my-gen2-container",
credentials=ServicePrincipalConfiguration(
tenant_id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX",
client_id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX",
client_secret="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
),
)
storage_client = Mock()
with patch(
"azure.ai.ml.operations._model_operations.get_storage_client", return_value=storage_client
) as get_client_mock, patch(
"azure.ai.ml.operations._model_operations.Model._from_rest_object",
return_value=model,
), patch(
"azure.ai.ml.operations._model_operations.DatastoreOperations.get", return_value=datastore
):
mock_model_operation.download(name=name, version=version)
get_client_mock.assert_called_once()

def test_create_with_datastore(
self,
mock_workspace_scope: OperationScope,
Expand Down

0 comments on commit c694e3f

Please sign in to comment.