Skip to content

Commit

Permalink
feat: setup default session in DSW
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Aug 15, 2024
1 parent f3b759b commit c51845f
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 30 deletions.
148 changes: 143 additions & 5 deletions pai/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
import os.path
import posixpath
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Tuple, Union

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.exceptions import CredentialException
from alibabacloud_credentials.models import Config as CredentialConfig
from alibabacloud_credentials.utils import auth_constant
from Tea.exceptions import TeaException

from .api.api_container import ResourceAPIsContainerMixin
from .common.consts import DEFAULT_CONFIG_PATH, Network
from .api.base import ServiceName
from .api.client_factory import ClientFactory
from .api.workspace import WorkspaceAPI, WorkspaceConfigKeys
from .common.consts import DEFAULT_CONFIG_PATH, PAI_VPC_ENDPOINT, Network
from .common.logging import get_logger
from .common.oss_utils import CredentialProviderWrapper, OssUriObj
from .common.utils import is_domain_connectable, make_list_resource_iterator
Expand Down Expand Up @@ -150,12 +156,81 @@ def get_default_session() -> "Session":
global _default_session
if not _default_session:
config = load_default_config_file()
if not config:
return
_default_session = Session(**config)
if config:
_default_session = Session(**config)
else:
_default_session = _init_default_session_from_env()
return _default_session


def _init_default_session_from_env() -> Optional["Session"]:
credential_client = Session._get_default_credential_client()
if not credential_client:
logger.debug("Not found credential from default credential provider chain.")
return

# legacy region id env var in DSW
region_id = os.getenv("dsw_region")
region_id = os.getenv("REGION", region_id)
if not region_id:
logger.debug(
"No region id found(env var: REGION or dsw_region), skip init default session"
)
return

dsw_instance_id = os.getenv("DSW_INSTANCE_ID")
if not dsw_instance_id:
logger.debug(
"No dsw instance id (env var: DSW_INSTANCE_ID) found, skip init default session"
)
return

workspace_id = os.getenv("PAI_AI_WORKSPACE_ID")
workspace_id = os.getenv("PAI_WORKSPACE_ID", workspace_id)

network = (
Network.VPC
if is_domain_connectable(
PAI_VPC_ENDPOINT,
timeout=1,
)
else Network.PUBLIC
)

if dsw_instance_id and not workspace_id:
logger.debug("Getting workspace id by dsw instance id: %s", dsw_instance_id)
workspace_id = Session._get_workspace_id_by_dsw_instance_id(
dsw_instance_id=dsw_instance_id,
cred=credential_client,
region_id=region_id,
network=network,
)
if not workspace_id:
logger.warning(
"Failed to get workspace id by dsw instance id: %s", dsw_instance_id
)
return
bucket_name, oss_endpoint = Session.get_default_oss_storage(
workspace_id, credential_client, region_id, network
)

if not bucket_name:
logger.warning(
"Default OSS storage is not configured for the workspace: %s", workspace_id
)

sess = Session(
region_id=region_id,
workspace_id=workspace_id,
credential_config=None,
oss_bucket_name=bucket_name,
oss_endpoint=oss_endpoint,
network=network,
)

return sess


def load_default_config_file() -> Optional[Dict[str, Any]]:
"""Read config file"""

Expand Down Expand Up @@ -451,3 +526,66 @@ def is_gpu_inference_instance(self, instance_type: str) -> bool:
"Please provide a supported instance type."
)
return bool(spec["GPU"])

@staticmethod
def get_default_oss_storage(
workspace_id: str, cred: CredentialClient, region_id: str, network: Network
) -> Tuple[Optional[str], Optional[str]]:
acs_ws_client = ClientFactory.create_client(
service_name=ServiceName.PAI_WORKSPACE,
credential_client=cred,
region_id=region_id,
network=network,
)
workspace_api = WorkspaceAPI(
acs_client=acs_ws_client,
)
resp = workspace_api.list_configs(
workspace_id=workspace_id,
config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI,
)
oss_storage_uri = next(
(
item["ConfigValue"]
for item in resp["Configs"]
if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI
),
None,
)

# Default OSS storage uri is not set.
if not oss_storage_uri:
return None, None
uri_obj = OssUriObj(oss_storage_uri)
if network == Network.VPC:
endpoint = "oss-{}-internal.aliyuncs.com".format(region_id)
else:
endpoint = "oss-{}.aliyuncs.com".format(region_id)
return uri_obj.bucket_name, endpoint

@staticmethod
def _get_default_credential_client() -> Optional[CredentialClient]:
try:
# Initialize the credential client with default credential chain.
# see: https://help.aliyun.com/zh/sdk/developer-reference/v2-manage-python-access-credentials#3ca299f04bw3c
return CredentialClient()
except CredentialException:
return

@staticmethod
def _get_workspace_id_by_dsw_instance_id(
dsw_instance_id: str, cred: CredentialClient, region_id: str, network: Network
) -> Optional[str]:
"""Get workspace id by dsw instance id"""
dsw_client = ClientFactory.create_client(
service_name=ServiceName.PAI_DSW,
credential_client=cred,
region_id=region_id,
network=network,
)
try:
resp = dsw_client.get_instance(dsw_instance_id)
return resp.body.workspace_id
except TeaException as e:
logger.warning("Failed to get instance info by dsw instance id: %s", e)
return
5 changes: 2 additions & 3 deletions pai/toolkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def workspace_choice_name(workspace: Dict[str, Any]):


def prompt_for_oss_bucket(user_profile: UserProfile, workspace_id: str):
default_storage_uri = user_profile.get_default_oss_storage_uri(
default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri(
workspace_id=workspace_id
)
print(
Expand Down Expand Up @@ -667,7 +667,7 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
)
)

default_storage_uri = user_profile.get_default_oss_storage_uri(
default_storage_uri, endpoint = user_profile.get_default_oss_storage_uri(
workspace_id=workspace_id,
)

Expand All @@ -687,7 +687,6 @@ def prompt_config_with_default_dsw_role(user_profile: UserProfile):
bucket_name, endpoint = None, None
else:
bucket_name = OssUriObj(default_storage_uri).bucket_name
endpoint = f"oss-{user_profile.region_id}-internal.aliyuncs.com"
return workspace_id, bucket_name, endpoint


Expand Down
36 changes: 14 additions & 22 deletions pai/toolkit/helper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import locale
import os
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import oss2
from alibabacloud_credentials.client import Client as CredentialClient
Expand All @@ -41,6 +41,7 @@
from ...common.oss_utils import CredentialProviderWrapper, OssUriObj
from ...common.utils import is_domain_connectable, make_list_resource_iterator
from ...libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from ...session import Session

logger = get_logger(__name__)

Expand Down Expand Up @@ -137,9 +138,11 @@ def _get_caller_identity(self) -> CallerIdentity:
config=open_api_models.Config(
credential=self._get_credential_client(),
region_id=self.region_id,
network=None
if self.network == Network.PUBLIC
else self.network.value.lower(),
network=(
None
if self.network == Network.PUBLIC
else self.network.value.lower()
),
)
)
.get_caller_identity()
Expand Down Expand Up @@ -261,26 +264,15 @@ def get_workspace_api(self) -> WorkspaceAPI:
acs_client=acs_ws_client,
)

def get_default_oss_storage_uri(self, workspace_id: str):
workspace_api = self.get_workspace_api()
resp = workspace_api.list_configs(
def get_default_oss_storage_uri(
self, workspace_id: str
) -> Tuple[Optional[str], Optional[str]]:
return Session._get_default_oss_storage(
workspace_id=workspace_id,
config_keys=WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI,
)

oss_storage_uri = next(
(
item["ConfigValue"]
for item in resp["Configs"]
if item["ConfigKey"] == WorkspaceConfigKeys.DEFAULT_OSS_STORAGE_URI
),
None,
cred=self._get_credential_client(),
region_id=self.region_id,
network=self.network,
)
if not oss_storage_uri:
return

uri_obj = OssUriObj(oss_storage_uri)
return "oss://{}".format(uri_obj.bucket_name)

def set_default_oss_storage(
self, workspace_id, bucket_name: str, intranet_endpoint: str
Expand Down

0 comments on commit c51845f

Please sign in to comment.