Skip to content

Commit

Permalink
feat: setup default session in DSW environment (#40)
Browse files Browse the repository at this point in the history
* fix: correct vpc endpoint used for network detection

* feat: setup default session in DSW

* fix network config

* fix endpoint of pai-eas vpc network

* bump to dev version
  • Loading branch information
pitt-liang authored Aug 22, 2024
1 parent 8eaf8b3 commit 74d1362
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
else:
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
if is_domain_connectable(PAI_VPC_ENDPOINT.format(self._region_id))
else Network.PUBLIC
)

Expand Down
7 changes: 6 additions & 1 deletion pai/api/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@ def get_endpoint(
raise ValueError("Please provide region_id to get the endpoint.")

if network and network != Network.PUBLIC:
subdomain = f"{service_name}-{network.value.lower()}"
if service_name == "pai-eas":
# see endpoint list provided by PAI-EAS
# https://next.api.aliyun.com/product/eas
subdomain = f"pai-eas-manage-{network.value.lower()}"
else:
subdomain = f"{service_name}-{network.value.lower()}"
else:
subdomain = service_name
return DEFAULT_SERVICE_ENDPOINT_PATTERN.format(subdomain, region_id)
2 changes: 1 addition & 1 deletion pai/common/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
DEFAULT_NETWORK_TYPE = os.environ.get("PAI_NETWORK_TYPE", None)

# PAI VPC endpoint
PAI_VPC_ENDPOINT = "pai-vpc.aliyuncs.com"
PAI_VPC_ENDPOINT = "pai-vpc.{}.aliyuncs.com"


class Network(enum.Enum):
Expand Down
148 changes: 143 additions & 5 deletions pai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
import os.path
import posixpath
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.exceptions import CredentialException
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_credentials.utils import auth_constant
from Tea.exceptions import TeaException

from .api.api_container import ResourceAPIsContainerMixin
from .common.consts import DEFAULT_CONFIG_PATH, Network
from .api.base import ServiceName
from .api.client_factory import ClientFactory
from .api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from .common.consts import DEFAULT_CONFIG_PATH, PAI_VPC_ENDPOINT, Network
from .common.logging import get_logger
from .common.oss_utils import CredentialProviderWrapper, OssUriObj
from .common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -150,12 +156,81 @@ def get_default_session() -> "Session":
global _default_session
if not _default_session:
config = load_default_config_file()
if not config:
return
_default_session = Session(**config)
if config:
_default_session = Session(**config)
else:
_default_session = _init_default_session_from_env()
return _default_session


def _init_default_session_from_env() -> Optional["Session"]:
    """Build the default session from the environment of a DSW instance.

    The session is assembled from the default credential provider chain and
    from the environment variables ``REGION``/``dsw_region``,
    ``DSW_INSTANCE_ID`` and ``PAI_WORKSPACE_ID``/``PAI_AI_WORKSPACE_ID``.
    When no workspace id is present in the environment, it is looked up
    through the DSW instance id.

    Returns:
        Session: The initialized session, or ``None`` when the credential,
            the region id, the DSW instance id or the workspace id cannot
            be resolved.
    """
    cred = Session._get_default_credential_client()
    if not cred:
        logger.debug("Not found credential from default credential provider chain.")
        return None

    # ``REGION`` takes precedence; ``dsw_region`` is the legacy env var in DSW.
    region = os.getenv("REGION", os.getenv("dsw_region"))
    if not region:
        logger.debug(
            "No region id found(env var: REGION or dsw_region), skip init default session"
        )
        return None

    instance_id = os.getenv("DSW_INSTANCE_ID")
    if not instance_id:
        logger.debug(
            "No dsw instance id (env var: DSW_INSTANCE_ID) found, skip init default session"
        )
        return None

    # ``PAI_WORKSPACE_ID`` takes precedence over ``PAI_AI_WORKSPACE_ID``.
    ws_id = os.getenv("PAI_WORKSPACE_ID", os.getenv("PAI_AI_WORKSPACE_ID"))

    # Use the VPC network when the regional VPC endpoint is reachable.
    vpc_reachable = is_domain_connectable(
        PAI_VPC_ENDPOINT.format(region),
        timeout=1,
    )
    if vpc_reachable:
        network = Network.VPC
    else:
        network = Network.PUBLIC

    if not ws_id:
        # The instance id is known to be set here, so ask the DSW service
        # which workspace the instance belongs to.
        logger.debug("Getting workspace id by dsw instance id: %s", instance_id)
        ws_id = Session._get_workspace_id_by_dsw_instance_id(
            dsw_instance_id=instance_id,
            cred=cred,
            region_id=region,
            network=network,
        )
        if not ws_id:
            logger.warning(
                "Failed to get workspace id by dsw instance id: %s", instance_id
            )
            return None

    bucket, endpoint = Session.get_default_oss_storage(ws_id, cred, region, network)
    if not bucket:
        # A missing default bucket is only reported; the session is still created.
        logger.warning(
            "Default OSS storage is not configured for the workspace: %s", ws_id
        )

    return Session(
        region_id=region,
        workspace_id=ws_id,
        credential_config=None,
        oss_bucket_name=bucket,
        oss_endpoint=endpoint,
        network=network,
    )


def load_default_config_file() -> Optional[Dict[str, Any]]:
"""Read config file"""

Expand Down Expand Up @@ -451,3 +526,66 @@ def is_gpu_inference_instance(self, instance_type: str) -> bool:
"Please provide a supported instance type."
)
return bool(spec["GPU"])

@staticmethod
def get_default_oss_storage(
    workspace_id: str, cred: CredentialClient, region_id: str, network: Network
) -> Tuple[Optional[str], Optional[str]]:
    """Look up the default OSS storage configured for a workspace.

    Args:
        workspace_id: ID of the workspace whose configuration is queried.
        cred: Credential client used to create the workspace API client.
        region_id: Region ID used for the API client and the OSS endpoint.
        network: Network type; ``Network.VPC`` selects the internal OSS
            endpoint, any other value selects the public one.

    Returns:
        tuple: ``(bucket_name, endpoint)``, or ``(None, None)`` when the
            workspace has no default OSS storage URI configured.
    """
    storage_key = WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI
    workspace_api = WorkspaceAPI(
        acs_client=ClientFactory.create_client(
            service_name=ServiceName.PAI_WORKSPACE,
            credential_client=cred,
            region_id=region_id,
            network=network,
        ),
    )
    configs = workspace_api.list_configs(
        workspace_id=workspace_id,
        config_keys=storage_key,
    )["Configs"]

    # Take the value of the first entry that carries the storage URI key.
    storage_uri = None
    for entry in configs:
        if entry["ConfigKey"] == storage_key:
            storage_uri = entry["ConfigValue"]
            break

    # Default OSS storage uri is not set.
    if not storage_uri:
        return None, None

    parsed_uri = OssUriObj(storage_uri)
    endpoint_pattern = (
        "oss-{}-internal.aliyuncs.com"
        if network == Network.VPC
        else "oss-{}.aliyuncs.com"
    )
    return parsed_uri.bucket_name, endpoint_pattern.format(region_id)

@staticmethod
def _get_default_credential_client() -> Optional[CredentialClient]:
    """Create a credential client from the default credential chain.

    See: https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c

    Returns:
        CredentialClient: The client, or ``None`` when constructing it
            raises ``CredentialException``.
    """
    try:
        client = CredentialClient()
    except CredentialException:
        return None
    else:
        return client

@staticmethod
def _get_workspace_id_by_dsw_instance_id(
    dsw_instance_id: str, cred: CredentialClient, region_id: str, network: Network
) -> Optional[str]:
    """Resolve the workspace ID that a DSW instance belongs to.

    Args:
        dsw_instance_id: ID of the DSW instance to query.
        cred: Credential client used to create the DSW API client.
        region_id: Region ID used to create the DSW API client.
        network: Network type used to create the DSW API client.

    Returns:
        str: The workspace ID reported for the instance, or ``None`` when
            the request raises ``TeaException`` (a warning is logged).
    """
    client = ClientFactory.create_client(
        service_name=ServiceName.PAI_DSW,
        credential_client=cred,
        region_id=region_id,
        network=network,
    )
    try:
        instance = client.get_instance(dsw_instance_id)
        workspace_id = instance.body.workspace_id
    except TeaException as err:
        logger.warning("Failed to get instance info by dsw instance id: %s", err)
        return None
    return workspace_id
5 changes: 2 additions & 3 deletions pai/toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def workspace_choice_name(workspace: Dict[str, Any]):


def prompt_for_oss_bucket(user_profile: UserProfile, workspace_id: str):
default_storage_uri = user_profile.get_default_oss_storage_uri(
default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri(
workspace_id=workspace_id
)
print(
Expand Down Expand Up @@ -667,7 +667,7 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
)
)

default_storage_uri = user_profile.get_default_oss_storage_uri(
default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri(
workspace_id=workspace_id,
)

Expand All @@ -687,7 +687,6 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
bucket_name, endpoint = None, None
else:
bucket_name = OssUriObj(default_storage_uri).bucket_name
endpoint = f"oss-{user_profile.region_id}-internal.aliyuncs.com"
return workspace_id, bucket_name, endpoint


Expand Down
38 changes: 15 additions & 23 deletions pai/toolkit/helper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import locale
import os
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
Expand All @@ -41,6 +41,7 @@
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from ...session import Session

logger = get_logger(__name__)

Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__(
else:
self.network = (
Network.VPC
if is_domain_connectable(PAI_VPC_ENDPOINT)
if is_domain_connectable(PAI_VPC_ENDPOINT.format(self.region_id))
else Network.PUBLIC
)
self._caller_identify = self._get_caller_identity()
Expand All @@ -137,9 +138,11 @@ def _get_caller_identity(self) -> CallerIdentity:
config=open_api_models.Config(
credential=self._get_credential_client(),
region_id=self.region_id,
network=None
if self.network == Network.PUBLIC
else self.network.value.lower(),
network=(
None
if self.network == Network.PUBLIC
else self.network.value.lower()
),
)
)
.get_caller_identity()
Expand Down Expand Up @@ -261,26 +264,15 @@ def get_workspace_api(self) -> WorkspaceAPI:
acs_client=acs_ws_client,
)

def get_default_oss_storage_uri(
    self, workspace_id: str
) -> Tuple[Optional[str], Optional[str]]:
    """Get the default OSS storage configured for the given workspace.

    Delegates to ``Session.get_default_oss_storage`` using this profile's
    credential client, region and network.

    Args:
        workspace_id: ID of the workspace whose configuration is queried.

    Returns:
        tuple: ``(bucket_name, endpoint)``, or ``(None, None)`` when the
            workspace has no default OSS storage URI configured.
    """
    # ``Session`` defines this helper as the public static method
    # ``get_default_oss_storage``; calling an underscore-prefixed name
    # would raise AttributeError.
    return Session.get_default_oss_storage(
        workspace_id=workspace_id,
        cred=self._get_credential_client(),
        region_id=self.region_id,
        network=self.network,
    )

def set_default_oss_storage(
self, workspace_id, bucket_name: str, intranet_endpoint: str
Expand Down
2 changes: 1 addition & 1 deletion pai/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

VERSION = "0.4.9.post0"
VERSION = "0.4.10.dev0"

0 comments on commit 74d1362

Please sign in to comment.