diff --git a/pai/api/training_job.py b/pai/api/training_job.py index 4f56b02..5c99c08 100644 --- a/pai/api/training_job.py +++ b/pai/api/training_job.py @@ -19,6 +19,7 @@ AlgorithmSpec, CreateTrainingJobRequest, CreateTrainingJobRequestComputeResource, + CreateTrainingJobRequestComputeResourceInstanceSpec, CreateTrainingJobRequestHyperParameters, CreateTrainingJobRequestInputChannels, CreateTrainingJobRequestLabels, @@ -85,6 +86,8 @@ def create( instance_type, instance_count, job_name, + instance_spec: Optional[Dict[str, str]] = None, + resource_id: Optional[str] = None, hyperparameters: Optional[Dict[str, Any]] = None, input_channels: Optional[List[Dict[str, Any]]] = None, output_channels: Optional[List[Dict[str, Any]]] = None, @@ -119,10 +122,22 @@ def create( CreateTrainingJobRequestOutputChannels().from_map(ch) for ch in output_channels ] - compute_resource = CreateTrainingJobRequestComputeResource( - ecs_count=instance_count, - ecs_spec=instance_type, - ) + if instance_type: + compute_resource = CreateTrainingJobRequestComputeResource( + ecs_count=instance_count, + ecs_spec=instance_type, + ) + elif instance_spec: + compute_resource = CreateTrainingJobRequestComputeResource( + resource_id=resource_id, + instance_count=instance_count, + instance_spec=CreateTrainingJobRequestComputeResourceInstanceSpec().from_map( + instance_spec + ), + ) + else: + raise ValueError("Please provide instance_type or instance_spec.") + hyper_parameters = [ CreateTrainingJobRequestHyperParameters( name=name, diff --git a/pai/estimator.py b/pai/estimator.py index da85be6..dc9cd44 100644 --- a/pai/estimator.py +++ b/pai/estimator.py @@ -239,6 +239,8 @@ def __init__( output_path: Optional[str] = None, checkpoints_path: Optional[str] = None, instance_type: Optional[str] = None, + instance_spec: Optional[Dict] = None, + resource_id: Optional[Dict] = None, instance_count: Optional[int] = None, user_vpc_config: Optional[UserVpcConfig] = None, session: Optional[Session] = None, @@ -311,6 +313,8 @@ def __init__( """ self.hyperparameters = hyperparameters or dict() self.instance_type = instance_type + self.instance_spec = instance_spec + self.resource_id = resource_id self.instance_count = instance_count if instance_count else 1 self.max_run_time = max_run_time self.base_job_name = base_job_name @@ -320,8 +324,6 @@ def __init__( self.session = session or get_default_session() self._latest_training_job = None - self._check_instance_type() - def set_hyperparameters(self, **kwargs): """Set hyperparameters for the training job. @@ -335,16 +337,6 @@ def latest_training_job(self): """Return the latest submitted training job.""" return self._latest_training_job - def _check_instance_type(self): - """Check if the given instance_type is supported for training job.""" - if not is_local_run_instance_type( - self.instance_type - ) and not self.session.is_supported_training_instance(self.instance_type): - raise ValueError( - f"Instance type='{self.instance_type}' is not supported." - " Please provide a supported instance type to create the job." - ) - def _gen_job_display_name(self, job_name=None): """Generate job display name.""" if job_name: @@ -663,7 +655,9 @@ def __init__( instance_type: Optional[str] = None, instance_count: Optional[int] = None, user_vpc_config: Optional[UserVpcConfig] = None, + resource_id: Optional[str] = None, session: Optional[Session] = None, + **kwargs, ): """Estimator constructor. @@ -991,7 +985,9 @@ def _fit(self, job_name, inputs: Dict[str, Any] = None): training_job_id = self.session.training_job_api.create( instance_count=self.instance_count, + instance_spec=self.instance_spec, instance_type=self.instance_type, + resource_id=self.resource_id, job_name=job_name, hyperparameters=self.hyperparameters, max_running_in_seconds=self.max_run_time, @@ -1075,6 +1071,8 @@ def __init__( instance_count: Optional[int] = None, user_vpc_config: Optional[UserVpcConfig] = None, session: Optional[Session] = None, + instance_spec: Optional[Dict[str, Union[int, str]]] = None, + **kwargs, ): """Initialize an AlgorithmEstimator. @@ -1150,17 +1148,19 @@ def __init__( self.algorithm_provider = None self.algorithm_spec = algorithm_spec + if not instance_type and not instance_spec: + instance_type = self._get_default_training_instance_type() super(AlgorithmEstimator, self).__init__( hyperparameters=self._get_hyperparameters(hyperparameters), base_job_name=base_job_name, max_run_time=max_run_time, output_path=output_path, - instance_type=instance_type - if instance_type - else self._get_default_training_instance_type(), + instance_type=instance_type, instance_count=instance_count, session=session, user_vpc_config=user_vpc_config, + instance_spec=instance_spec, + **kwargs, ) # TODO: check if the hyperparameters are valid @@ -1223,14 +1223,6 @@ def _check_args( " The provided algorithm_spec will be ignored." ) - def _check_instance_type(self): - """Check if the given instance_type is supported for training job.""" - if not self.session.is_supported_training_instance(self.instance_type): - raise ValueError( - f"Instance type='{self.instance_type}' is not supported." - " Please provide a supported instance type to create the job." - ) - def _get_algo_version( self, algorithm_name: str, @@ -1391,6 +1383,8 @@ def _fit(self, job_name, inputs: Dict[str, Any] = None): training_job_id = self.session.training_job_api.create( instance_count=self.instance_count, instance_type=self.instance_type, + instance_spec=self.instance_spec, + resource_id=self.resource_id, job_name=job_name, hyperparameters=self.hyperparameters, max_running_in_seconds=self.max_run_time, diff --git a/pai/model.py b/pai/model.py index 609c354..70022c3 100644 --- a/pai/model.py +++ b/pai/model.py @@ -325,19 +325,20 @@ def mount( ) if "storage" in self._cfg_dict: - configs = self._cfg_dict.get("storage", []) + storages = copy.deepcopy(self._cfg_dict.get("storage", [])) else: - configs = [] + storages = [] + configs = [] uris = set() - for conf in configs: - # check if target mount path is already used. - if conf.get("mount_path") == mount_path: - raise MountPathIsOccupiedException( - f"The mount path '{mount_path}' has already been used." - ) - mount_uri = conf.get("oss", {}).get("path") - uris.add(mount_uri) + for s in storages: + # overwrite the existing mount path + if s.get("mount_path") == mount_path: + continue + oss_uri = s.get("oss", {}).get("path") + if oss_uri: + uris.add(oss_uri) + configs.append(s) if is_oss_uri(source): oss_uri_obj = OssUriObj(source) @@ -1758,6 +1759,7 @@ def get_estimator( base_job_name: Optional[str] = None, output_path: Optional[str] = None, max_run_time: Optional[int] = None, + **kwargs, ): """Generate an AlgorithmEstimator. @@ -1828,10 +1830,15 @@ def get_estimator( max_run_time = ts.get("Scheduler", {}).get("MaxRunningTimeInSeconds") train_compute_resource = ts.get("ComputeResource") - if train_compute_resource and (not instance_type or not instance_count): - # If instance_type or instance_count is not provided, use the default + instance_spec = kwargs.get("instance_spec") + if train_compute_resource: instance_type = instance_type or train_compute_resource.get("EcsSpec") - instance_count = instance_count or train_compute_resource.get("EcsCount") + instance_count = ( + instance_count + or train_compute_resource.get("EcsCount") + or train_compute_resource.get("InstanceCount") + ) + instance_spec = instance_spec or train_compute_resource.get("InstanceSpec") return AlgorithmEstimator( algorithm_name=algorithm_name, @@ -1843,7 +1850,9 @@ def get_estimator( max_run_time=max_run_time, instance_type=instance_type, instance_count=instance_count, + instance_spec=instance_spec, output_path=output_path, + **kwargs, ) def get_estimator_inputs(self) -> Dict[str, str]: diff --git a/tests/unit/test_inference_spec.py b/tests/unit/test_inference_spec.py index 49733ce..36eeeb1 100644 --- a/tests/unit/test_inference_spec.py +++ b/tests/unit/test_inference_spec.py @@ -94,3 +94,7 @@ def test_inference_spec(self): infer_spec.mount( "oss://pai-sdk-example/path/to/abc/", mount_path="/ml/code/" ) + + infer_spec.mount( + "oss://pai-sdk-example/path/to/abc/edfg", mount_path="/ml/code/", force=True + )