diff --git a/pai/api/api_container.py b/pai/api/api_container.py index 3faace8..ba24d58 100644 --- a/pai/api/api_container.py +++ b/pai/api/api_container.py @@ -82,7 +82,7 @@ def __init__( else: self.network = ( Network.VPC - if is_domain_connectable(PAI_VPC_ENDPOINT) + if is_domain_connectable(PAI_VPC_ENDPOINT.format(self._region_id)) else Network.PUBLIC ) diff --git a/pai/api/client_factory.py b/pai/api/client_factory.py index 07d3dd1..0f0552a 100644 --- a/pai/api/client_factory.py +++ b/pai/api/client_factory.py @@ -87,7 +87,12 @@ def get_endpoint( raise ValueError("Please provide region_id to get the endpoint.") if network and network != Network.PUBLIC: - subdomain = f"{service_name}-{network.value.lower()}" + if service_name == "pai-eas": + # see endpoint list provided by PAI-EAS + # https://next.api.aliyun.com/product/eas + subdomain = f"pai-eas-manage-{network.value.lower()}" + else: + subdomain = f"{service_name}-{network.value.lower()}" else: subdomain = service_name return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id) diff --git a/pai/common/consts.py b/pai/common/consts.py index b2cce95..c57c70c 100644 --- a/pai/common/consts.py +++ b/pai/common/consts.py @@ -23,7 +23,7 @@ DEFAULT_NETWORK_TYPE = os.environ.get("PAI_NETWORK_TYPE", None) # PAI VPC endpoint -PAI_VPC_ENDPOINT = "pai-vpc.aliyuncs.com" +PAI_VPC_ENDPOINT = "pai-vpc.{}.aliyuncs.com" class Network(enum.Enum): diff --git a/pai/session.py b/pai/session.py index 7db682d..ae84987 100644 --- a/pai/session.py +++ b/pai/session.py @@ -18,14 +18,20 @@ import os.path import posixpath from datetime import datetime -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import oss2 +from alibabacloud_credentials.client import Client as CredentialClient +from alibabacloud_credentials.exceptions import CredentialException from alibabacloud_credentials.models import Config as CredentialConfig from alibabacloud_credentials.utils import auth_constant +from Tea.exceptions import TeaException from .api.api_container import ResourceAPIsContainerMixin -from .common.consts import DEFAULT_CONFIG_PATH, Network +from .api.base import ServiceName +from .api.client_factory import ClientFactory +from .api.workspace import WorkspaceAPI, WorkspaceConfigKeys +from .common.consts import DEFAULT_CONFIG_PATH, PAI_VPC_ENDPOINT, Network from .common.logging import get_logger from .common.oss_utils import CredentialProviderWrapper, OssUriObj from .common.utils import is_domain_connectable, make_list_resource_iterator @@ -150,12 +156,81 @@ def get_default_session() -> "Session": global _default_session if not _default_session: config = load_default_config_file() - if not config: - return - _default_session = Session(**config) + if config: + _default_session = Session(**config) + else: + _default_session = _init_default_session_from_env() return _default_session +def _init_default_session_from_env() -> Optional["Session"]: + credential_client = Session._get_default_credential_client() + if not credential_client: + logger.debug("Not found credential from default credential provider chain.") + return + + # legacy region id env var in DSW + region_id = os.getenv("dsw_region") + region_id = os.getenv("REGION", region_id) + if not region_id: + logger.debug( + "No region id found(env var: REGION or dsw_region), skip init default session" + ) + return + + dsw_instance_id = os.getenv("DSW_INSTANCE_ID") + if not dsw_instance_id: + logger.debug( + "No dsw instance id (env var: DSW_INSTANCE_ID) found, skip init default session" + ) + return + + workspace_id = os.getenv("PAI_AI_WORKSPACE_ID") + workspace_id = os.getenv("PAI_WORKSPACE_ID", workspace_id) + + network = ( + Network.VPC + if is_domain_connectable( + PAI_VPC_ENDPOINT.format(region_id), + timeout=1, + ) + else Network.PUBLIC + ) + + if dsw_instance_id and not workspace_id: + logger.debug("Getting workspace id by dsw instance id: %s", dsw_instance_id) + workspace_id = Session._get_workspace_id_by_dsw_instance_id( + dsw_instance_id=dsw_instance_id, + cred=credential_client, + region_id=region_id, + network=network, + ) + if not workspace_id: + logger.warning( + "Failed to get workspace id by dsw instance id: %s", dsw_instance_id + ) + return + bucket_name, oss_endpoint = Session.get_default_oss_storage( + workspace_id, credential_client, region_id, network + ) + + if not bucket_name: + logger.warning( + "Default OSS storage is not configured for the workspace: %s", workspace_id + ) + + sess = Session( + region_id=region_id, + workspace_id=workspace_id, + credential_config=None, + oss_bucket_name=bucket_name, + oss_endpoint=oss_endpoint, + network=network, + ) + + return sess + + def load_default_config_file() -> Optional[Dict[str, Any]]: """Read config file""" @@ -451,3 +526,66 @@ def is_gpu_inference_instance(self, instance_type: str) -> bool: "Please provide a supported instance type." ) return bool(spec["GPU"]) + + @staticmethod + def get_default_oss_storage( + workspace_id: str, cred: CredentialClient, region_id: str, network: Network + ) -> Tuple[Optional[str], Optional[str]]: + acs_ws_client = ClientFactory.create_client( + service_name=ServiceName.PAI_WORKSPACE, + credential_client=cred, + region_id=region_id, + network=network, + ) + workspace_api = WorkspaceAPI( + acs_client=acs_ws_client, + ) + resp = workspace_api.list_configs( + workspace_id=workspace_id, + config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI, + ) + oss_storage_uri = next( + ( + item["ConfigValue"] + for item in resp["Configs"] + if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI + ), + None, + ) + + # Default OSS storage uri is not set. + if not oss_storage_uri: + return None, None + uri_obj = OssUriObj(oss_storage_uri) + if network == Network.VPC: + endpoint = "oss-{}-internal.aliyuncs.com".format(region_id) + else: + endpoint = "oss-{}.aliyuncs.com".format(region_id) + return uri_obj.bucket_name, endpoint + + @staticmethod + def _get_default_credential_client() -> Optional[CredentialClient]: + try: + # Initialize the credential client with default credential chain. + # see: https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c + return CredentialClient() + except CredentialException: + return + + @staticmethod + def _get_workspace_id_by_dsw_instance_id( + dsw_instance_id: str, cred: CredentialClient, region_id: str, network: Network + ) -> Optional[str]: + """Get workspace id by dsw instance id""" + dsw_client = ClientFactory.create_client( + service_name=ServiceName.PAI_DSW, + credential_client=cred, + region_id=region_id, + network=network, + ) + try: + resp = dsw_client.get_instance(dsw_instance_id) + return resp.body.workspace_id + except TeaException as e: + logger.warning("Failed to get instance info by dsw instance id: %s", e) + return diff --git a/pai/toolkit/config.py b/pai/toolkit/config.py index 538ad97..a03465a 100644 --- a/pai/toolkit/config.py +++ b/pai/toolkit/config.py @@ -392,7 +392,7 @@ def workspace_choice_name(workspace: Dict[str, Any]): def prompt_for_oss_bucket(user_profile: UserProfile, workspace_id: str): - default_storage_uri = user_profile.get_default_oss_storage_uri( + default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri( workspace_id=workspace_id ) print( @@ -667,7 +667,7 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile): ) ) - default_storage_uri = user_profile.get_default_oss_storage_uri( + default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri( workspace_id=workspace_id, ) @@ -687,7 +687,6 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile): bucket_name, endpoint = None, None else: bucket_name = OssUriObj(default_storage_uri).bucket_name - endpoint = f"oss-{user_profile.region_id}-internal.aliyuncs.com" return workspace_id, bucket_name, endpoint diff --git a/pai/toolkit/helper/utils.py b/pai/toolkit/helper/utils.py index 2089ccb..1a93eef 100644 --- a/pai/toolkit/helper/utils.py +++ b/pai/toolkit/helper/utils.py @@ -15,7 +15,7 @@ import locale import os import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import oss2 from alibabacloud_credentials.client import Client as CredentialClient @@ -41,6 +41,7 @@ from ...common.oss_utils import CredentialProviderWrapper, OssUriObj from ...common.utils import is_domain_connectable, make_list_resource_iterator from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient +from ...session import Session logger = get_logger(__name__) @@ -111,7 +112,7 @@ def __init__( else: self.network = ( Network.VPC - if is_domain_connectable(PAI_VPC_ENDPOINT) + if is_domain_connectable(PAI_VPC_ENDPOINT.format(self.region_id)) else Network.PUBLIC ) self._caller_identify = self._get_caller_identity() @@ -137,9 +138,11 @@ def _get_caller_identity(self) -> CallerIdentity: config=open_api_models.Config( credential=self._get_credential_client(), region_id=self.region_id, - network=None - if self.network == Network.PUBLIC - else self.network.value.lower(), + network=( + None + if self.network == Network.PUBLIC + else self.network.value.lower() + ), ) ) .get_caller_identity() @@ -261,26 +264,15 @@ def get_workspace_api(self) -> WorkspaceAPI: acs_client=acs_ws_client, ) - def get_default_oss_storage_uri(self, workspace_id: str): - workspace_api = self.get_workspace_api() - resp = workspace_api.list_configs( + def get_default_oss_storage_uri( + self, workspace_id: str + ) -> Tuple[Optional[str], Optional[str]]: + return Session._get_default_oss_storage( workspace_id=workspace_id, - config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI, - ) - - oss_storage_uri = next( - ( - item["ConfigValue"] - for item in resp["Configs"] - if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI - ), - None, + cred=self._get_credential_client(), + region_id=self.region_id, + network=self.network, ) - if not oss_storage_uri: - return - - uri_obj = OssUriObj(oss_storage_uri) - return "oss://{}".format(uri_obj.bucket_name) def set_default_oss_storage( self, workspace_id, bucket_name: str, intranet_endpoint: str diff --git a/pai/version.py b/pai/version.py index 7764bf3..b669ad5 100644 --- a/pai/version.py +++ b/pai/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -VERSION = "0.4.9.post0" +VERSION = "0.4.10.dev0"