Skip to content

Commit

Permalink
feat: Storage/SharedMemory configuration support in EAS service (#34)
Browse files Browse the repository at this point in the history
* feat: EAS storage config support

* fix: fix `has_docker` check pending

* fix: fix AlgorithmEstimator init

* add RawStorageConfig used for eas storage configuration

* fix test case run for spot_instance

* fix ModelRecipe.run not returning the job instance

* add test case for custom args in model recipe

* add test case for RawStorageConfig
  • Loading branch information
pitt-liang authored Jul 12, 2024
1 parent 2644f04 commit ed3d434
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 16 deletions.
10 changes: 5 additions & 5 deletions pai/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,9 +1164,9 @@ def _get_hyperparameters(
if hps_def:
# Get default hyperparameters.
for hp in hps_def:
hp_name = hp.get("Name")
hp_value = hp.get("DefaultValue", "")
hp_type = hp.get("Type", "String")
hp_name = hp.name
hp_value = hp.default_value
hp_type = hp.type or "String"
# For hyperparameters with type INT or FLOAT, if the default value is
# empty, skip it.
if (
Expand Down Expand Up @@ -1232,8 +1232,8 @@ def fit(
`/ml/input/data/{channel_name}` directory in the training container.
wait (bool): Specifies whether to block until the training job is completed,
either succeeded, failed, or stopped. (Default True).
show_logs (bool): Specifies whether to show the logs produced by the
training job (Default True).
show_logs (bool): Whether to show the logs of the training job. Default to True.
Note that the logs will be shown only when the `wait` is set to True.
job_name (str, optional): The name of the training job.
Returns:
Expand Down
7 changes: 7 additions & 0 deletions pai/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DefaultServiceConfig,
ModelBase,
ResourceConfig,
StorageConfigBase,
container_serving_spec,
)
from ..serializers import SerializerBase
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
requirements: Optional[List[str]] = None,
requirements_path: Optional[str] = None,
health_check: Optional[Dict[str, Any]] = None,
storage_configs: Optional[List[StorageConfigBase]] = None,
session: Optional[Session] = None,
):
"""Initialize a HuggingFace Model.
Expand Down Expand Up @@ -144,6 +146,9 @@ def __init__(
health_check (Dict[str, Any], optional): The health check configuration. If it
not set, A TCP readiness probe will be used to check the health of the
Model server.
storage_configs (List[StorageConfigBase], optional): A list of storage configs
used to mount the storage to the container. The storage can be OSS, NFS,
SharedMemory, or NodeStorage, etc.
session (:class:`pai.session.Session`, optional): A pai session object
manages interactions with PAI REST API.
Expand All @@ -170,6 +175,7 @@ def __init__(
self.requirements = requirements
self.requirements_path = requirements_path
self.health_check = health_check
self.storage_configs = storage_configs

super(HuggingFaceModel, self).__init__(
model_data=self.model_data,
Expand Down Expand Up @@ -342,6 +348,7 @@ def deploy(
requirements=self.requirements,
requirements_path=self.requirements_path,
health_check=self.health_check,
storage_configs=self.storage_configs,
session=self.session,
)

Expand Down
1 change: 1 addition & 0 deletions pai/job/_training_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def _submit(
)
if wait:
training_job.wait(show_logs=show_logs)
return training_job

@classmethod
def _get_input_config(
Expand Down
12 changes: 12 additions & 0 deletions pai/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
InferenceSpec,
Model,
ModelFormat,
NfsStorageConfig,
NodeStorageConfig,
OssStorageConfig,
RawStorageConfig,
RegisteredModel,
ResourceConfig,
SharedMemoryConfig,
StorageConfigBase,
container_serving_spec,
)
from ._model_recipe import ModelRecipe, ModelRecipeType, ModelTrainingRecipe
Expand All @@ -33,4 +39,10 @@
"ModelTrainingRecipe",
"ModelRecipe",
"ModelRecipeType",
"StorageConfigBase",
"NfsStorageConfig",
"NodeStorageConfig",
"SharedMemoryConfig",
"OssStorageConfig",
"RawStorageConfig",
]
124 changes: 123 additions & 1 deletion pai/model/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import time
import typing
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import requests
Expand Down Expand Up @@ -74,6 +75,120 @@ class DefaultServiceConfig(object):
code_path = "/ml/usercode/"


class StorageConfigBase(metaclass=ABCMeta):
    """Abstract base class for service storage mount configurations.

    Each concrete subclass translates one storage backend (OSS, NFS, node
    disk, shared memory, or a raw pass-through dict) into the dict payload
    used in the service configuration.
    """

    @abstractmethod
    def to_dict(self):
        """Return the storage configuration as a dict."""
        ...


class RawStorageConfig(StorageConfigBase):
    """Wrap a pre-built storage configuration dict.

    Escape hatch for storage settings not covered by the typed config
    classes: the given dict is forwarded to the service unchanged.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Args:
            config (Dict[str, Any]): The raw storage configuration dict.
        """
        self.config = config

    def to_dict(self):
        """Return the wrapped configuration dict as-is (no copy is made)."""
        return self.config


class OssStorageConfig(StorageConfigBase):
    """Mount an OSS path into the service container."""

    def __init__(
        self, mount_path: str, oss_path: str, oss_endpoint: Optional[str] = None
    ) -> None:
        """
        Args:
            mount_path (str): The target path where the OSS storage will be mounted.
            oss_path (str): The source OSS path, must start with ``oss://``,
                e.g. ``oss://bucket-name/path/to/data``.
            oss_endpoint (Optional[str]): The endpoint address of the OSS bucket.
                If not provided, the internal endpoint for the bucket is used.
        """
        self.mount_path = mount_path
        self.oss_path = oss_path
        self.oss_endpoint = oss_endpoint

    def to_dict(self) -> Dict[str, Any]:
        """Return the storage config dict for this OSS mount."""
        oss_conf: Dict[str, Any] = {"path": self.oss_path}
        # The endpoint key is only emitted when explicitly configured.
        if self.oss_endpoint:
            oss_conf["endpoint"] = self.oss_endpoint
        return {"mount_path": self.mount_path, "oss": oss_conf}


class NfsStorageConfig(StorageConfigBase):
    """Mount an NFS (NAS) volume into the service container."""

    def __init__(
        self,
        mount_path: str,
        nfs_server: str,
        nfs_path: str = "/",
        read_only: bool = False,
    ) -> None:
        """
        Args:
            mount_path (str): The target path where the NFS storage will be mounted.
            nfs_server (str): The NFS server address, e.g.
                ``xxx.cn-shanghai.nas.aliyuncs.com``.
            nfs_path (str): The source path in the NFS storage. Default to "/".
            read_only (bool): Whether to mount the NFS storage as read-only.
                Default to False.
        """
        self.mount_path = mount_path
        self.nfs_server = nfs_server
        self.nfs_path = nfs_path
        self.read_only = read_only

    def to_dict(self) -> Dict[str, Any]:
        """Return the storage config dict for this NFS mount."""
        nfs_conf = {
            "path": self.nfs_path,
            "readOnly": self.read_only,
            "server": self.nfs_server,
        }
        return {"mount_path": self.mount_path, "nfs": nfs_conf}


class NodeStorageConfig(StorageConfigBase):
    """Use to mount the local node disk storage to the container."""

    def __init__(self, mount_path: str) -> None:
        """
        Args:
            mount_path (str): The target path where the node disk storage will be
                mounted.
        """
        self.mount_path = mount_path

    def to_dict(self) -> Dict[str, Any]:
        """Return the storage config dict: an ``empty_dir`` backed by node disk."""
        return {
            "empty_dir": {},
            "mount_path": self.mount_path,
        }


class SharedMemoryConfig(StorageConfigBase):
    """Configure the shared memory (``/dev/shm``) size for the container."""

    def __init__(self, size_limit: int) -> None:
        """
        Args:
            size_limit (int): Size limit of the shared memory, in GB.
        """
        self.size_limit = size_limit

    def to_dict(self) -> Dict[str, Any]:
        """Return the storage config dict: a memory-backed ``empty_dir``
        mounted at ``/dev/shm``."""
        empty_dir = {
            "medium": "memory",
            "size_limit": self.size_limit,
        }
        return {"empty_dir": empty_dir, "mount_path": "/dev/shm"}


class ResourceConfig(object):
"""A class that represents the resource used by a PAI prediction service
instance."""
Expand Down Expand Up @@ -465,6 +580,7 @@ def container_serving_spec(
requirements: Optional[List[str]] = None,
requirements_path: Optional[str] = None,
health_check: Optional[Dict[str, Any]] = None,
storage_configs: Optional[List[StorageConfigBase]] = None,
session: Optional[Session] = None,
) -> InferenceSpec:
"""A convenient function to create an InferenceSpec instance that serving the model
Expand Down Expand Up @@ -539,6 +655,9 @@ def container_serving_spec(
health_check (Dict[str, Any], optional): The health check configuration. If it
not set, A TCP readiness probe will be used to check the health of the
HTTP server.
storage_configs (List[StorageConfigBase], optional): A list of storage configs
used to mount the storage to the container. The storage can be OSS, NFS,
SharedMemory, or NodeStorage, etc.
session (Session, optional): A PAI session instance used for communicating
with PAI service.
Expand Down Expand Up @@ -619,9 +738,12 @@ def container_serving_spec(
container_spec["prepare"] = {
"pythonRequirementsPath": requirements_path,
}

inference_spec = InferenceSpec(containers=[container_spec])

if storage_configs:
storage = [s.to_dict() for s in storage_configs]
inference_spec.storage = storage

# mount the uploaded serving scripts to the serving container.
if source_dir:
inference_spec.mount(
Expand Down
9 changes: 8 additions & 1 deletion pai/model/_model_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ComputeResource,
DatasetConfig,
ExperimentConfig,
HyperParameterDefinition,
InstanceSpec,
ModelRecipeSpec,
OssLocation,
Expand Down Expand Up @@ -53,7 +54,7 @@ class RecipeInitKwargs(object):
model_channel_name: Optional[str]
model_uri: Optional[str]
hyperparameters: Optional[Dict[str, Any]]
# hyperparameter_definitions: Optional[List[HyperParameterDefinition]]
hyperparameter_definitions: Optional[List[HyperParameterDefinition]]
job_type: Optional[str]
image_uri: Optional[str]
source_dir: Optional[str]
Expand Down Expand Up @@ -161,6 +162,7 @@ def __init__(
self.supported_instance_types = init_kwargs.supported_instance_types
self.input_channels = init_kwargs.input_channels
self.output_channels = init_kwargs.output_channels
self.hyperparameter_definitions = init_kwargs.hyperparameter_definitions

super().__init__(
resource_type=resource_type,
Expand Down Expand Up @@ -249,6 +251,7 @@ def _init_kwargs(
default_inputs=default_inputs,
customization=customization,
supported_instance_types=supported_instance_types,
hyperparameter_definitions=None,
)
if not model_uri:
input_ = next(
Expand Down Expand Up @@ -281,6 +284,7 @@ def _init_kwargs(
supported_instance_types = (
supported_instance_types or model_recipe_spec.supported_instance_types
)
hyperparameter_definitions = None
if algorithm_spec:
if (
not source_dir
Expand All @@ -303,6 +307,7 @@ def _init_kwargs(
supported_instance_types = (
supported_instance_types or algorithm_spec.supported_channel_types
)
hyperparameter_definitions = algorithm_spec.hyperparameter_definitions

instance_type, instance_spec, instance_count = cls._get_compute_resource_config(
instance_type=instance_type,
Expand Down Expand Up @@ -352,6 +357,7 @@ def _init_kwargs(
default_inputs=default_inputs,
customization=customization,
supported_instance_types=supported_instance_types,
hyperparameter_definitions=hyperparameter_definitions,
)

@staticmethod
Expand Down Expand Up @@ -684,6 +690,7 @@ def train(
job_name (str, optional): The name of the training job. If not provided, a default
job name will be generated.
show_logs (bool): Whether to show the logs of the training job. Default to True.
Note that the logs will be shown only when the `wait` is set to True.
Returns:
:class:`pai.training.TrainingJob`: A submitted training job.
Expand Down
7 changes: 7 additions & 0 deletions pai/modelscope/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
DefaultServiceConfig,
ModelBase,
ResourceConfig,
StorageConfigBase,
container_serving_spec,
)
from ..serializers import SerializerBase
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
requirements: Optional[List[str]] = None,
requirements_path: Optional[str] = None,
health_check: Optional[Dict[str, Any]] = None,
storage_configs: Optional[List[StorageConfigBase]] = None,
session: Optional[Session] = None,
):
"""Initialize a ModelScope Model.
Expand Down Expand Up @@ -144,6 +146,9 @@ def __init__(
health_check (Dict[str, Any], optional): The health check configuration. If it
not set, A TCP readiness probe will be used to check the health of the
Model server.
storage_configs (List[StorageConfigBase], optional): A list of storage configs
used to mount the storage to the container. The storage can be OSS, NFS,
SharedMemory, or NodeStorage, etc.
session (:class:`pai.session.Session`, optional): A pai session object
manages interactions with PAI REST API.
Expand All @@ -168,6 +173,7 @@ def __init__(
self.requirements = requirements
self.requirements_path = requirements_path
self.health_check = health_check
self.storage_configs = storage_configs
super(ModelScopeModel, self).__init__(
model_data=self.model_data,
session=session,
Expand Down Expand Up @@ -340,6 +346,7 @@ def deploy(
requirements_path=self.requirements_path,
health_check=self.health_check,
session=self.session,
storage_configs=self.storage_configs,
)
return super(ModelScopeModel, self).deploy(
service_name=service_name,
Expand Down
4 changes: 2 additions & 2 deletions pai/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def run(
`/ml/outputs/data/{channel_name}` directory in the job container.
wait (bool): Specifies whether to block until the training job is completed,
either succeeded, failed, or stopped. (Default True).
show_logs (bool): Specifies whether to show the logs produced by the
job (Default True).
show_logs (bool): Whether to show the logs of the job. Default to True.
Note that the logs will be shown only when the `wait` is set to True.
Returns:
:class:`pai.job.TrainingJob`: A submitted training job.
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import posixpath
import re
from unittest import skipIf, skipUnless
from unittest import skipUnless

import pytest

Expand Down Expand Up @@ -75,7 +75,7 @@ def test_xgb_train(self):
model_path = os.path.join(os.path.join(est.model_data(), "model.json"))
self.assertTrue(self.is_oss_object_exists(model_path))

@skipIf(t_context.support_spot_instance, "Skip spot instance test")
@skipUnless(t_context.support_spot_instance, "Skip spot instance test")
def test_use_spot_instance(self):
xgb_image_uri = retrieve("xgboost", framework_version="latest").image_uri
est = Estimator(
Expand Down
Loading

0 comments on commit ed3d434

Please sign in to comment.