Skip to content

Commit

Permalink
feat: support dynamic mounting in DSW
Browse files Browse the repository at this point in the history
  • Loading branch information
pitt-liang committed Sep 8, 2024
1 parent 9649fc4 commit 46c9c99
Show file tree
Hide file tree
Showing 5 changed files with 10,478 additions and 3,832 deletions.
5 changes: 5 additions & 0 deletions pai/api/api_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ..common.consts import DEFAULT_NETWORK_TYPE, PAI_VPC_ENDPOINT, Network
from ..common.utils import is_domain_connectable
from ..libs.alibabacloud_pai_dsw20220101.client import Client as DswClient
from .algorithm import AlgorithmAPI
from .base import PAIRestResourceTypes, ServiceName, WorkspaceScopedResourceAPI
from .client_factory import ClientFactory
Expand Down Expand Up @@ -128,6 +129,10 @@ def _acs_training_client(self):
def _acs_sts_client(self) -> StsClient:
return self._get_acs_client(ServiceName.STS)

@property
def _acs_dsw_client(self) -> DswClient:
    """Client for the PAI DSW service, obtained through ``_get_acs_client``."""
    dsw_client = self._get_acs_client(ServiceName.PAI_DSW)
    return dsw_client

def get_api_by_resource(self, resource_type):
if resource_type in self.api_container:
return self.api_container[resource_type]
Expand Down
226 changes: 226 additions & 0 deletions pai/dsw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import os
import posixpath
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from .common.logging import get_logger
from .common.oss_utils import OssUriObj, is_oss_uri
from .common.utils import is_dataset_id
from .libs.alibabacloud_pai_dsw20220101.models import (
GetInstanceResponse,
GetInstanceResponseBody,
UpdateInstanceRequest,
UpdateInstanceRequestDatasets,
)
from .session import Session, get_default_session

logger = get_logger()


class OptionType(str, Enum):
    """Preset option profiles that can be chosen when mounting a data source in DSW.

    Members also subclass ``str``, so every member compares equal to its
    literal value and can be used wherever the plain string is accepted.
    """

    # Declaration order is the iteration order of the enum; keep it stable.
    FastReadWrite = "FastReadWrite"
    IncrementalReadWrite = "IncrementalReadWrite"
    ConsistentReadWrite = "ConsistentReadWrite"
    ReadOnly = "ReadOnly"


def _default_instance() -> "DswInstance":
    """
    Build the DswInstance for the instance this process is running in.

    The instance id is read from the ``DSW_INSTANCE_ID`` environment variable.

    Returns:
        DswInstance: The instance identified by ``DSW_INSTANCE_ID``.

    Raises:
        RuntimeError: If ``DSW_INSTANCE_ID`` is unset or empty.
    """
    current_id = os.environ.get("DSW_INSTANCE_ID")
    if current_id:
        return DswInstance(current_id)
    raise RuntimeError(
        "Environment variable 'DSW_INSTANCE_ID' is not set, please check if you are running in DSW environment"
    )


def mount(
    source: str,
    target: str = None,
    options: Optional[Dict[str, Any]] = None,
    option_type: Optional[OptionType] = None,
) -> str:
    """
    Dynamically mount a data source onto the current DSW Instance.

    Thin convenience wrapper: resolves the instance from the environment and
    forwards every argument to its ``mount`` method.

    Args:
        source (str): The source to be mounted, either a dataset id or an OSS uri.
        target (str): Mount point inside the instance. When omitted, a mount
            point is derived from ``source`` under the default mount point.
        options (dict): Options applied when mounting the data source; mutually
            exclusive with ``option_type``.
        option_type (str): Preset data source mount options; mutually exclusive
            with ``options``.

    Returns:
        str: The mount point of the data source.
    """
    # NOTE(review): ``options`` is annotated as a dict here, while the instance
    # method documents it as a str -- confirm which form the service expects.
    return _default_instance().mount(
        source,
        target,
        options=options,
        option_type=option_type,
    )


def list_datasets() -> List[Dict[str, Any]]:
    """
    List all the datasets available in the DSW Instance.

    Returns:
        list: A list of dataset details, one dict per dataset. Empty when the
            instance info carries no datasets.
    """
    instance = _default_instance()
    # The constructor has already fetched the instance details and cached them
    # on ``_instance_info``; reading the cache avoids an immediate second API
    # call and does not depend on how ``_get_instance_info`` is bound.
    datasets = instance._instance_info.datasets or []
    return [d.to_map() for d in datasets]


def default_dynamic_mount_point():
    """Return the default dynamic mount point of the current DSW Instance.

    The instance is resolved from the environment and the lookup is delegated
    to its ``default_dynamic_mount_point`` method.

    Returns:
        str: The default dynamic mount point of the DSW Instance.
    """
    return _default_instance().default_dynamic_mount_point()


def get_dynamic_mount_config() -> Dict[str, Any]:
    """
    Return the dynamic mount configuration of the current DSW Instance.

    The instance is resolved from the environment and the lookup is delegated
    to its ``get_dynamic_mount_config`` method.

    Returns:
        dict: The dynamic mount config of the DSW Instance.
    """
    return _default_instance().get_dynamic_mount_config()


class DswInstance:
    """An object representing a DSW notebook instance.

    The instance details are fetched once at construction time through the
    default session's DSW client and cached on ``_instance_info``.
    """

    def __init__(self, instance_id: str):
        """Bind to ``instance_id`` and fetch the instance details.

        Args:
            instance_id (str): Id of the DSW instance.
        """
        self.instance_id = instance_id
        self._instance_info: "GetInstanceResponseBody" = self._get_instance_info()

    def _get_instance_info(self) -> "GetInstanceResponseBody":
        """Fetch the details of this instance with the default session.

        This has to be a plain instance method: it reads ``self.instance_id``
        and is called without arguments.

        Returns:
            GetInstanceResponseBody: Body of the ``get_instance`` response.
        """
        session = get_default_session()
        resp: "GetInstanceResponse" = session._acs_dsw_client.get_instance(
            self.instance_id
        )
        return resp.body

    def get_dynamic_mount_config(self) -> Dict[str, Any]:
        """Get the dynamic mount config of the DSW Instance.

        Returns:
            dict: The dynamic mount config of the DSW Instance; an empty dict
                when the instance info carries no dynamic mount section.
        """
        dynamic_mount = self._instance_info.dynamic_mount
        if dynamic_mount is None:
            # No dynamic mount section on the cached info: report an empty
            # config instead of failing with an AttributeError on ``None``.
            return {}
        return dynamic_mount.to_map()

    def default_dynamic_mount_point(self) -> str:
        """Get the default dynamic mount point of the DSW Instance.

        Returns:
            str: Root path of the first dynamic mount point of the instance.

        Raises:
            RuntimeError: If dynamic mount is not enabled for the instance, or
                the instance has no dynamic mount points.
        """
        dynamic_mount = self._instance_info.dynamic_mount
        # A missing dynamic mount section is treated the same as "disabled".
        if dynamic_mount is None or not dynamic_mount.enable:
            raise RuntimeError(
                "Dynamic mount is not enabled for the DSW instance: {}".format(
                    self.instance_id
                )
            )
        if not dynamic_mount.mount_points:
            raise RuntimeError(
                "No dynamic mount points found for the DSW instance: {}".format(
                    self.instance_id
                )
            )
        return dynamic_mount.mount_points[0].root_path

    def mount(
        self,
        source: str,
        mount_point: Optional[str] = None,
        options: Optional[str] = None,
        option_type: Optional["OptionType"] = None,
    ) -> str:
        """
        Dynamic mount a data source to the DSW Instance.

        Args:
            source (str): The source to be mounted, can be a dataset id or an OSS uri.
            mount_point (str): Target mount point in the instance. If not specified,
                the mount point is generated from the given source. A relative
                mount point is placed under the default dynamic mount point.
            options (str): Options that apply when mounting a data source, can not
                be specified with option_type.
            option_type (str): Preset data source mount options, can not be
                specified with options.

        Returns:
            str: The mount point of the data source.

        Raises:
            ValueError: If both ``options`` and ``option_type`` are given, or if
                ``source`` is neither an OSS uri nor a dataset id.
            RuntimeError: If dynamic mount is not available for the instance.
        """
        if options and option_type:
            raise ValueError(
                "options and option_type cannot be specified at the same time"
            )

        sess = get_default_session()
        # Resolved up front: this also rejects instances without dynamic mount.
        default_root_path = self.default_dynamic_mount_point()

        if is_oss_uri(source):
            obj = OssUriObj(source)
            # Derived before the endpoint is filled in, i.e. from the uri as given.
            default_mount_point = f"{obj.bucket_name}/{obj.object_key}"
            if not obj.endpoint:
                obj.endpoint = sess.oss_endpoint or sess._get_default_oss_endpoint()
            # ensure mount source OSS uri is a directory
            _, dir_path, _ = obj.parse_object_key()
            uri = f"oss://{obj.bucket_name}.{obj.endpoint}{dir_path}"
            dataset_id = None
        elif is_dataset_id(source):
            default_mount_point = source
            dataset_id = source
            uri = None
        else:
            raise ValueError("Source must be oss uri or dataset id")

        if not mount_point:
            mount_point = default_mount_point
        if not posixpath.isabs(mount_point):
            mount_point = posixpath.join(default_root_path, mount_point)

        # The update request carries the full dataset list, so re-read the
        # current datasets and append the new one to them.
        resp: "GetInstanceResponse" = sess._acs_dsw_client.get_instance(
            self.instance_id
        )
        datasets = [
            UpdateInstanceRequestDatasets.from_map(ds.to_map())
            for ds in (resp.body.datasets or [])
        ]
        datasets.append(
            UpdateInstanceRequestDatasets(
                dataset_id=dataset_id,
                dynamic=True,
                mount_path=mount_point,
                option_type=option_type,
                options=options,
                uri=uri,
            )
        )
        request = UpdateInstanceRequest(
            datasets=datasets,
        )
        sess._acs_dsw_client.update_instance(
            instance_id=self.instance_id, request=request
        )
        return mount_point
2 changes: 1 addition & 1 deletion pai/libs/alibabacloud_pai_dsw20220101/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.3.0'
__version__ = '1.5.4'
Loading

0 comments on commit 46c9c99

Please sign in to comment.