Fix Azure Entra ID only datastore auth access (Azure#38967)
- Fixes a regression introduced when authenticating with datastores that have shared key access disabled and are only accessible with Entra ID credentials (with the proper roles assigned)
- Always retrieves time-limited SAS tokens for datastores configured with shared key access, per the latest Azure security recommendations
- Adds explicit tests for Entra ID-only ("none") AI projects
- Updates test recordings

Bug 3686546
ralph-msft authored Dec 23, 2024
1 parent 8259791 commit 487d0b4
Showing 8 changed files with 70 additions and 16 deletions.
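To illustrate the call path this commit fixes, here is a minimal sketch (not part of the commit) of resolving the default datastore for a workspace whose storage account has shared key access disabled. It assumes azure-identity's DefaultAzureCredential with the proper storage roles assigned; the LiteMLClient construction mirrors the tests below, and the angle-bracket values are placeholders:

    import logging

    from azure.identity import DefaultAzureCredential
    from azure.ai.evaluation._azure._clients import LiteMLClient

    client = LiteMLClient(
        subscription_id="<subscription-id>",
        resource_group="<resource-group>",
        logger=logging.getLogger(__name__),
        credential=DefaultAzureCredential(),
    )
    store = client.workspace_get_default_datastore("<workspace-name>", include_credentials=True)

    # With this fix, store.credential is the Entra ID token credential itself for
    # "none"-type datastores; previously the client attempted a listSecrets call that failed.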
1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -1382,6 +1382,7 @@
         "fmeasure",
         "upia",
         "xpia",
+        "expirable",
       ]
     },
     {
2 changes: 1 addition & 1 deletion sdk/evaluation/azure-ai-evaluation/assets.json
@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "python",
   "TagPrefix": "python/evaluation/azure-ai-evaluation",
-  "Tag": "python/evaluation/azure-ai-evaluation_4f3f9f39dc"
+  "Tag": "python/evaluation/azure-ai-evaluation_326efc986d"
 }
@@ -17,7 +17,7 @@
 from ._models import BlobStoreInfo, Workspace


-API_VERSION: Final[str] = "2024-10-01"
+API_VERSION: Final[str] = "2024-07-01-preview"
 QUERY_KEY_API_VERSION: Final[str] = "api-version"
 PATH_ML_WORKSPACES = ("providers", "Microsoft.MachineLearningServices", "workspaces")

@@ -69,7 +69,9 @@ def get_credential(self) -> TokenCredential:
         self._get_token_manager()
         return cast(TokenCredential, self._credential)

-    def workspace_get_default_datastore(self, workspace_name: str, include_credentials: bool = False) -> BlobStoreInfo:
+    def workspace_get_default_datastore(
+        self, workspace_name: str, *, include_credentials: bool = False, **kwargs: Any
+    ) -> BlobStoreInfo:
         # 1. Get the default blob store
         # REST API documentation:
         # https://learn.microsoft.com/rest/api/azureml/datastores/list?view=rest-azureml-2024-10-01
@@ -92,18 +94,29 @@ def workspace_get_default_datastore(self, workspace_name: str, include_credentia
         account_name = props_json["accountName"]
         endpoint = props_json["endpoint"]
         container_name = props_json["containerName"]
+        credential_type = props_json.get("credentials", {}).get("credentialsType")

         # 2. Get the SAS token to use for accessing the blob store
         # REST API documentation:
         # https://learn.microsoft.com/rest/api/azureml/datastores/list-secrets?view=rest-azureml-2024-10-01
-        blob_store_credential: Optional[Union[AzureSasCredential, str]] = None
-        if include_credentials:
+        blob_store_credential: Optional[Union[AzureSasCredential, TokenCredential, str]]
+        if not include_credentials:
+            blob_store_credential = None
+        elif credential_type and credential_type.lower() == "none":
+            # If storage account key access is disabled, and only Microsoft Entra ID authentication is available,
+            # the credentialsType will be "None" and we should not attempt to get the secrets.
+            blob_store_credential = self.get_credential()
+        else:
             url = self._generate_path(
                 *PATH_ML_WORKSPACES, workspace_name, "datastores", "workspaceblobstore", "listSecrets"
             )
             secrets_response = self._http_client.request(
                 method="POST",
                 url=url,
+                json={
+                    "expirableSecret": True,
+                    "expireAfterHours": int(kwargs.get("key_expiration_hours", 1)),
+                },
                 params={
                     QUERY_KEY_API_VERSION: self._api_version,
                 },
@@ -114,10 +127,13 @@ def workspace_get_default_datastore(self, workspace_name: str, include_credentia
             secrets_json = secrets_response.json()
             secrets_type = secrets_json["secretsType"].lower()

+            # As per this website, only SAS tokens, access tokens, or Entra IDs are valid for accessing blob data
+            # stores:
+            # https://learn.microsoft.com/rest/api/storageservices/authorize-requests-to-azure-storage.
             if secrets_type == "sas":
                 blob_store_credential = AzureSasCredential(secrets_json["sasToken"])
             elif secrets_type == "accountkey":
-                # To support olders versions of azure-storage-blob better, we return a string here instead of
+                # To support older versions of azure-storage-blob better, we return a string here instead of
                 # an AzureNamedKeyCredential
                 blob_store_credential = secrets_json["key"]
             else:
@@ -164,19 +180,19 @@ def _throw_on_http_error(response: HttpResponse, description: str, valid_status:
             # nothing to see here, move along
             return

         additional_info: Optional[str] = None
         message = f"The {description} request failed with HTTP {response.status_code}"
         try:
             error_json = response.json()["error"]
             additional_info = f"({error_json['code']}) {error_json['message']}"
             message += f" - {additional_info}"
         except (JSONDecodeError, ValueError, KeyError):
             pass

         raise EvaluationException(
-            message=f"The {description} request failed with HTTP {response.status_code}",
+            message=message,
             target=ErrorTarget.EVALUATE,
             category=ErrorCategory.FAILED_EXECUTION,
             blame=ErrorBlame.SYSTEM_ERROR,
             internal_message=additional_info,
         )

     def _generate_path(self, *paths: str) -> str:
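The new **kwargs parameter also surfaces a key_expiration_hours option (read via kwargs.get above, defaulting to 1 hour). A hedged usage sketch, with the two-hour value purely illustrative and client being a LiteMLClient as in the sketch near the top:

    # Request a SAS token that expires after 2 hours instead of the 1-hour default.
    store = client.workspace_get_default_datastore(
        "<workspace-name>",
        include_credentials=True,
        key_expiration_hours=2,  # forwarded in the listSecrets body as "expireAfterHours"
    )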
@@ -8,15 +8,15 @@

 from typing import Dict, List, NamedTuple, Optional, Union
 from msrest.serialization import Model
-from azure.core.credentials import AzureSasCredential
+from azure.core.credentials import AzureSasCredential, TokenCredential


 class BlobStoreInfo(NamedTuple):
     name: str
     account_name: str
     endpoint: str
     container_name: str
-    credential: Optional[Union[AzureSasCredential, str]]
+    credential: Optional[Union[AzureSasCredential, TokenCredential, str]]


 class WorkspaceHubConfig(Model):
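With the union widened, consumers of BlobStoreInfo.credential can see four shapes. A small illustrative sketch (the isinstance checks mirror the test assertions further down; describe_credential is a hypothetical helper, not part of this commit):

    from typing import Optional, Union

    from azure.core.credentials import AzureSasCredential, TokenCredential

    def describe_credential(cred: Optional[Union[AzureSasCredential, TokenCredential, str]]) -> str:
        # Map each credential shape the datastore lookup can return.
        if cred is None:
            return "no credentials requested"
        if isinstance(cred, str):
            return "storage account key"
        if isinstance(cred, AzureSasCredential):
            return "time-limited SAS token"
        return "Entra ID token credential"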
@@ -421,7 +421,9 @@ def log_artifact(self, artifact_folder: str, artifact_name: str = EVALUATION_ART
                 local_paths.append(local_file_path)

         # We will write the artifacts to the workspaceblobstore
-        datastore = self._management_client.workspace_get_default_datastore(self._workspace_name, True)
+        datastore = self._management_client.workspace_get_default_datastore(
+            self._workspace_name, include_credentials=True
+        )
         account_url = f"{datastore.account_name}.blob.{datastore.endpoint}"

         svc_client = BlobServiceClient(account_url=account_url, credential=datastore.credential)
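No per-credential-type branching is needed at the upload site because BlobServiceClient accepts a SAS credential, an account key string, or a token credential. A short sketch under that assumption, with datastore being the BlobStoreInfo fetched above:

    from azure.storage.blob import BlobServiceClient

    account_url = f"{datastore.account_name}.blob.{datastore.endpoint}"  # mirrors the diff above
    svc_client = BlobServiceClient(account_url=account_url, credential=datastore.credential)
    container_client = svc_client.get_container_client(datastore.container_name)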
16 changes: 16 additions & 0 deletions sdk/evaluation/azure-ai-evaluation/tests/conftest.py
@@ -393,6 +393,22 @@ def project_scope(request, dev_connections: Dict[str, Any]) -> dict:
     return dev_connections[conn_name]["value"]


+@pytest.fixture
+def datastore_project_scopes(connection_file, project_scope, mock_project_scope):
+    conn_name = "azure_ai_entra_id_project_scope"
+    if not is_live():
+        entra_id = mock_project_scope
+    else:
+        entra_id = connection_file.get(conn_name)
+        if not entra_id:
+            raise ValueError(f"Connection '{conn_name}' not found in dev connections.")
+
+    return {
+        "sas": project_scope,
+        "none": entra_id,
+    }
+
+
 @pytest.fixture
 def mock_trace_destination_to_cloud(project_scope: dict):
     """Mock trace destination to cloud."""
@@ -1,6 +1,6 @@
 import pytest
 import logging
-from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential
+from azure.core.credentials import AzureSasCredential, TokenCredential
 from azure.ai.evaluation._azure._clients import LiteMLClient


@@ -34,7 +34,12 @@ def test_get_token(self, project_scope, azure_cred):

     @pytest.mark.azuretest
     @pytest.mark.parametrize("include_credentials", [False, True])
-    def test_workspace_get_default_store(self, project_scope, azure_cred, include_credentials: bool):
+    @pytest.mark.parametrize("config_name", ["sas", "none"])
+    def test_workspace_get_default_store(
+        self, azure_cred, datastore_project_scopes, config_name: str, include_credentials: bool
+    ):
+        project_scope = datastore_project_scopes[config_name]
+
         client = LiteMLClient(
             subscription_id=project_scope["subscription_id"],
             resource_group=project_scope["resource_group_name"],
@@ -52,7 +57,11 @@ def test_workspace_get_default_store(self, project_scope, azure_cred, include_cr
         assert store.endpoint
         assert store.container_name
         if include_credentials:
-            assert isinstance(store.credential, str) or isinstance(store.credential, AzureSasCredential)
+            assert (
+                (config_name == "account_key" and isinstance(store.credential, str))
+                or (config_name == "sas" and isinstance(store.credential, AzureSasCredential))
+                or (config_name == "none" and isinstance(store.credential, TokenCredential))
+            )
         else:
             assert store.credential == None

@@ -118,13 +118,23 @@ def test_logging_metrics(self, caplog, project_scope, azure_ml_client):
         self._assert_no_errors_for_module(caplog.records, EvalRun.__module__)

     @pytest.mark.azuretest
-    def test_log_artifact(self, project_scope, azure_ml_client, caplog, tmp_path):
+    @pytest.mark.parametrize("config_name", ["sas", "none"])
+    def test_log_artifact(self, project_scope, azure_cred, datastore_project_scopes, caplog, tmp_path, config_name):
         """Test uploading artifact to the service."""
         logger = logging.getLogger(EvalRun.__module__)
         # All loggers, having promptflow. prefix will have "promptflow" logger
         # as a parent. This logger does not propagate the logs and cannot be
         # captured by caplog. Here we will skip this logger to capture logs.
         logger.parent = logging.root

+        project_scope = datastore_project_scopes[config_name]
+        azure_ml_client = LiteMLClient(
+            subscription_id=project_scope["subscription_id"],
+            resource_group=project_scope["resource_group_name"],
+            logger=logger,
+            credential=azure_cred,
+        )
+
         with EvalRun(
             run_name="test",
             tracking_uri=_get_tracking_uri(azure_ml_client, project_scope),
