Skip to content

Commit

Permalink
Support dedicated resource in trainingjob
Browse files Browse the repository at this point in the history
Signed-off-by: pitt-liang <[email protected]>
  • Loading branch information
pitt-liang committed Jan 11, 2024
1 parent 06d5bac commit a58f92a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 40 deletions.
23 changes: 19 additions & 4 deletions pai/api/training_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AlgorithmSpec,
CreateTrainingJobRequest,
CreateTrainingJobRequestComputeResource,
CreateTrainingJobRequestComputeResourceInstanceSpec,
CreateTrainingJobRequestHyperParameters,
CreateTrainingJobRequestInputChannels,
CreateTrainingJobRequestLabels,
Expand Down Expand Up @@ -85,6 +86,8 @@ def create(
instance_type,
instance_count,
job_name,
instance_spec: Optional[Dict[str, str]] = None,
resource_id: Optional[str] = None,
hyperparameters: Optional[Dict[str, Any]] = None,
input_channels: Optional[List[Dict[str, Any]]] = None,
output_channels: Optional[List[Dict[str, Any]]] = None,
Expand Down Expand Up @@ -119,10 +122,22 @@ def create(
CreateTrainingJobRequestOutputChannels().from_map(ch)
for ch in output_channels
]
compute_resource = CreateTrainingJobRequestComputeResource(
ecs_count=instance_count,
ecs_spec=instance_type,
)
if instance_type:
compute_resource = CreateTrainingJobRequestComputeResource(
ecs_count=instance_count,
ecs_spec=instance_type,
)
elif instance_spec:
compute_resource = CreateTrainingJobRequestComputeResource(
resource_id=resource_id,
instance_count=instance_count,
instance_spec=CreateTrainingJobRequestComputeResourceInstanceSpec().from_map(
instance_spec
),
)
else:
raise ValueError("Please provide instance_type or instance_spec.")

hyper_parameters = [
CreateTrainingJobRequestHyperParameters(
name=name,
Expand Down
40 changes: 17 additions & 23 deletions pai/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def __init__(
output_path: Optional[str] = None,
checkpoints_path: Optional[str] = None,
instance_type: Optional[str] = None,
instance_spec: Optional[Dict] = None,
resource_id: Optional[Dict] = None,
instance_count: Optional[int] = None,
user_vpc_config: Optional[UserVpcConfig] = None,
session: Optional[Session] = None,
Expand Down Expand Up @@ -311,6 +313,8 @@ def __init__(
"""
self.hyperparameters = hyperparameters or dict()
self.instance_type = instance_type
self.instance_spec = instance_spec
self.resource_id = resource_id
self.instance_count = instance_count if instance_count else 1
self.max_run_time = max_run_time
self.base_job_name = base_job_name
Expand All @@ -320,8 +324,6 @@ def __init__(
self.session = session or get_default_session()
self._latest_training_job = None

self._check_instance_type()

def set_hyperparameters(self, **kwargs):
"""Set hyperparameters for the training job.
Expand All @@ -335,16 +337,6 @@ def latest_training_job(self):
"""Return the latest submitted training job."""
return self._latest_training_job

def _check_instance_type(self):
"""Check if the given instance_type is supported for training job."""
if not is_local_run_instance_type(
self.instance_type
) and not self.session.is_supported_training_instance(self.instance_type):
raise ValueError(
f"Instance type='{self.instance_type}' is not supported."
" Please provide a supported instance type to create the job."
)

def _gen_job_display_name(self, job_name=None):
"""Generate job display name."""
if job_name:
Expand Down Expand Up @@ -663,7 +655,9 @@ def __init__(
instance_type: Optional[str] = None,
instance_count: Optional[int] = None,
user_vpc_config: Optional[UserVpcConfig] = None,
resource_id: Optional[str] = None,
session: Optional[Session] = None,
**kwargs,
):
"""Estimator constructor.
Expand Down Expand Up @@ -991,7 +985,9 @@ def _fit(self, job_name, inputs: Dict[str, Any] = None):

training_job_id = self.session.training_job_api.create(
instance_count=self.instance_count,
instance_spec=self.instance_spec,
instance_type=self.instance_type,
resource_id=self.resource_id,
job_name=job_name,
hyperparameters=self.hyperparameters,
max_running_in_seconds=self.max_run_time,
Expand Down Expand Up @@ -1075,6 +1071,8 @@ def __init__(
instance_count: Optional[int] = None,
user_vpc_config: Optional[UserVpcConfig] = None,
session: Optional[Session] = None,
instance_spec: Optional[Dict[str, Union[int, str]]] = None,
**kwargs,
):
"""Initialize an AlgorithmEstimator.
Expand Down Expand Up @@ -1150,17 +1148,19 @@ def __init__(
self.algorithm_provider = None
self.algorithm_spec = algorithm_spec

if not instance_type and not instance_spec:
instance_type = self._get_default_training_instance_type()
super(AlgorithmEstimator, self).__init__(
hyperparameters=self._get_hyperparameters(hyperparameters),
base_job_name=base_job_name,
max_run_time=max_run_time,
output_path=output_path,
instance_type=instance_type
if instance_type
else self._get_default_training_instance_type(),
instance_type=instance_type,
instance_count=instance_count,
session=session,
user_vpc_config=user_vpc_config,
instance_spec=instance_spec,
**kwargs,
)

# TODO: check if the hyperparameters are valid
Expand Down Expand Up @@ -1223,14 +1223,6 @@ def _check_args(
" The provided algorithm_spec will be ignored."
)

def _check_instance_type(self):
"""Check if the given instance_type is supported for training job."""
if not self.session.is_supported_training_instance(self.instance_type):
raise ValueError(
f"Instance type='{self.instance_type}' is not supported."
" Please provide a supported instance type to create the job."
)

def _get_algo_version(
self,
algorithm_name: str,
Expand Down Expand Up @@ -1391,6 +1383,8 @@ def _fit(self, job_name, inputs: Dict[str, Any] = None):
training_job_id = self.session.training_job_api.create(
instance_count=self.instance_count,
instance_type=self.instance_type,
instance_spec=self.instance_spec,
resource_id=self.resource_id,
job_name=job_name,
hyperparameters=self.hyperparameters,
max_running_in_seconds=self.max_run_time,
Expand Down
35 changes: 22 additions & 13 deletions pai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,19 +325,20 @@ def mount(
)

if "storage" in self._cfg_dict:
configs = self._cfg_dict.get("storage", [])
storages = copy.deepcopy(self._cfg_dict.get("storage", []))
else:
configs = []
storages = []

configs = []
uris = set()
for conf in configs:
# check if target mount path is already used.
if conf.get("mount_path") == mount_path:
raise MountPathIsOccupiedException(
f"The mount path '{mount_path}' has already been used."
)
mount_uri = conf.get("oss", {}).get("path")
uris.add(mount_uri)
for s in storages:
# overwrite the existing mount path
if s.get("mount_path") == mount_path:
continue
oss_uri = s.get("oss", {}).get("path")
if oss_uri:
uris.add(oss_uri)
configs.append(s)

if is_oss_uri(source):
oss_uri_obj = OssUriObj(source)
Expand Down Expand Up @@ -1758,6 +1759,7 @@ def get_estimator(
base_job_name: Optional[str] = None,
output_path: Optional[str] = None,
max_run_time: Optional[int] = None,
**kwargs,
):
"""Generate an AlgorithmEstimator.
Expand Down Expand Up @@ -1828,10 +1830,15 @@ def get_estimator(
max_run_time = ts.get("Scheduler", {}).get("MaxRunningTimeInSeconds")

train_compute_resource = ts.get("ComputeResource")
if train_compute_resource and (not instance_type or not instance_count):
# If instance_type or instance_count is not provided, use the default
instance_spec = kwargs.get("instance_spec")
if train_compute_resource:
instance_type = instance_type or train_compute_resource.get("EcsSpec")
instance_count = instance_count or train_compute_resource.get("EcsCount")
instance_count = (
instance_count
or train_compute_resource.get("EcsCount")
or train_compute_resource.get("InstanceCount")
)
instance_spec = instance_spec or train_compute_resource.get("InstanceSpec")

return AlgorithmEstimator(
algorithm_name=algorithm_name,
Expand All @@ -1843,7 +1850,9 @@ def get_estimator(
max_run_time=max_run_time,
instance_type=instance_type,
instance_count=instance_count,
instance_spec=instance_spec,
output_path=output_path,
**kwargs,
)

def get_estimator_inputs(self) -> Dict[str, str]:
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_inference_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ def test_inference_spec(self):
infer_spec.mount(
"oss://pai-sdk-example/path/to/abc/", mount_path="/ml/code/"
)

infer_spec.mount(
"oss://pai-sdk-example/path/to/abc/edfg", mount_path="/ml/code/", force=True
)

0 comments on commit a58f92a

Please sign in to comment.